1. 加载数据集 MNIST
# Build the MNIST training set with preprocessing:
# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to
# [-1, 1], matching the tanh output range of the GAN generator.
from torchvision import datasets, transforms
transform = transforms.Compose([
    transforms.ToTensor(),  # value of pixel: [0, 255] -> [0, 1]
    transforms.Normalize(mean = (0.5,), std = (0.5,))  # value of tensor: [0, 1] -> [-1, 1]
])
mnist = datasets.MNIST(root='data', train=True, download=True, transform=transform)
transforms.Normalize()
用于将图像进行标准化:\(\frac{x - \text{mean}}{\text{std}}\),使得处理的数据呈正态分布。
由于 MNIST 数据集图像为灰度图只有一个通道,因此只需要设置单个通道的 mean 与 std 即可。
这里的取值,可以是将图像像素值 [0, 255] 缩放至 [0, 1] 后求得均值和标准差,也可以是根据经验设置,即 mean=0.5, std=0.5。
2. 查看数据
# Inspect one arbitrary sample to confirm the normalized value range.
img, label = mnist[len(mnist)-500]
print(f"Label: {label}")
print(f"Some pixel values: {img[0, 10:15, 10:15]}")  # a 5x5 crop near the image center
print(f"Min value: {img.min()}, Max value: {img.max()}")  # expected within [-1, 1] after Normalize
Label: 3
Some pixel values: tensor([[-0.9451, -0.6392, -0.9843, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -0.9529, -0.7725],
[-1.0000, -0.8745, -0.0196, 0.5765, 0.7725],
[-1.0000, 0.0902, 0.9922, 0.9922, 0.9922],
[-1.0000, -0.3569, 0.1216, 0.1216, -0.5686]])
Min value: -1.0, Max value: 1.0
import matplotlib.pyplot as plt
import torch
def dnorm(x: torch.Tensor, min_value: float = -1.0, max_value: float = 1.0) -> torch.Tensor:
    """Map *x* linearly from [min_value, max_value] back to [0, 1] for display.

    Inverse of Normalize(mean=0.5, std=0.5) with the defaults; the range is
    parameterized so the helper works for other normalizations too.

    Args:
        x: tensor of values (nominally) in [min_value, max_value].
        min_value: lower bound of the input range (default -1).
        max_value: upper bound of the input range (default 1).

    Returns:
        Tensor rescaled and clamped to [0, 1] — matplotlib expects that range.
    """
    out = (x - min_value) / (max_value - min_value)
    return out.clamp(0, 1)
# Undo the normalization and drop the channel dimension for grayscale display.
img_norm = dnorm(img) # shape: (1, 28, 28)
plt.imshow(img_norm.squeeze(0), cmap='gray')
<matplotlib.image.AxesImage at 0x187d76c7990>
3. 制作数据加载器Dataloader
# Wrap the dataset in a DataLoader that reshuffles every epoch.
# Note: 60000 % 100 == 0, so every batch here is a full batch.
from torch.utils.data import DataLoader
batch_size = 100
data_loader = DataLoader(mnist, batch_size, shuffle=True)
4. 创建GAN的生成器与判别器并测试
查看
model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
# Discriminator network
class Discriminator(nn.Module):
    """MLP discriminator: flattened image -> probability that the image is real.

    Args:
        image_size: number of input pixels (e.g. 28*28 for MNIST).
        hidden_size: width of the two hidden layers.
    """

    def __init__(self, image_size: int, hidden_size: int):
        super(Discriminator, self).__init__()
        self.linear1 = nn.Linear(image_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # LeakyReLU (slope 0.2) is the usual GAN choice: it keeps a gradient
        # for negative activations so the discriminator does not saturate.
        out = F.leaky_relu(self.linear1(x), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.linear2(out), negative_slope=0.2, inplace=True)
        # torch.sigmoid replaces the deprecated F.sigmoid; output lies in
        # (0, 1) so it can feed nn.BCELoss directly.
        return torch.sigmoid(self.linear3(out))
# Generator network
class Generator(nn.Module):
    """MLP generator: latent noise vector -> flattened fake image in [-1, 1].

    Args:
        image_size: number of output pixels (e.g. 28*28 for MNIST).
        latent_size: dimensionality of the input noise vector.
        hidden_size: width of the two hidden layers.
    """

    def __init__(self, image_size: int, latent_size: int, hidden_size: int):
        super(Generator, self).__init__()
        self.linear1 = nn.Linear(latent_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, image_size)

    def forward(self, x):
        out = F.relu(self.linear1(x))
        out = F.relu(self.linear2(out))
        # torch.tanh replaces the deprecated F.tanh; the [-1, 1] output range
        # matches the Normalize(0.5, 0.5) preprocessing of the real images.
        return torch.tanh(self.linear3(out))
# Smoke-test the untrained networks end to end.
from model import Generator, Discriminator
image_size = 28 * 28   # flattened MNIST image
hidden_size = 256      # width of the hidden layers in both networks
latent_size = 64       # dimensionality of the generator's noise input
G = Generator(image_size=image_size, hidden_size=hidden_size, latent_size=latent_size)
D = Discriminator(image_size=image_size, hidden_size=hidden_size)
untrained_G_out = G(torch.randn(latent_size))  # noise in: [latent_size]; image out: [image_size]
untrained_D_out = D(untrained_G_out.view(1, -1))  # add a batch dimension for the discriminator
print(f"Result from Discriminator: {untrained_D_out.item():.4f}")  # ~0.5 for an untrained D
plt.imshow(untrained_G_out.view(28, 28).detach(), cmap='gray')
Result from Discriminator: 0.5166
5. 对抗训练模型
from torch import optim
from torch import nn
num_epochs = 300
device = "cuda:0" if torch.cuda.is_available() else "cpu"
D.to(device=device)
G.to(device=device)
# Separate Adam optimizers: each adversarial step updates only one network.
d_optim = optim.Adam(D.parameters(), lr=0.002)
g_optim = optim.Adam(G.parameters(), lr=0.002)
criterion = nn.BCELoss()  # binary real-vs-fake loss; expects sigmoid probabilities
# History buffers for plotting losses and discriminator scores later.
d_loss_list, g_loss_list, real_score_list, fake_score_list = ([] for _ in range(4))
查看
training.py
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
def run_discriminator_one_batch(d_net: nn.Module,
                                g_net: nn.Module,
                                batch_size: int,
                                latent_size: int,
                                images: torch.Tensor,
                                criterion: nn.Module,
                                optimizer: optim.Optimizer,
                                device: str):
    """Run one discriminator update on a batch of real and generated images.

    Real images are scored against label 1, a freshly generated fake batch
    against label 0; the summed BCE loss updates only d_net's parameters.

    Args:
        d_net: discriminator to update.
        g_net: generator used (frozen) to produce the fake batch.
        batch_size: nominal batch size (kept for interface compatibility;
            actual sizes are derived from `images`).
        latent_size: dimensionality of the generator's noise input.
        images: real images, shape (n, image_size), already on `device`.
        criterion: BCE-style loss on probabilities.
        optimizer: optimizer over d_net's parameters.
        device: target device string.

    Returns:
        (d_loss, real_score, fake_score) — total loss and the per-sample
        discriminator outputs for the real and fake batches.
    """
    # Size labels from the actual batch: the final DataLoader batch can be
    # smaller than the nominal batch_size.
    n = images.size(0)
    real_labels = torch.ones(n, 1).to(device)
    fake_labels = torch.zeros(n, 1).to(device)
    # Clear gradients BEFORE backward, not after step: the generator's update
    # back-propagates through d_net too, so d_net's .grad buffers hold stale
    # gradients from the previous generator step that would otherwise
    # accumulate into this discriminator update.
    optimizer.zero_grad()
    # Score the real samples.
    outputs = d_net(images)
    d_loss_real = criterion(outputs, real_labels)
    real_score = outputs
    # Score a generated batch; detach() so this update does not reach g_net.
    z = torch.randn(n, latent_size).to(device)
    fake_images = g_net(z)
    outputs = d_net(fake_images.detach())
    d_loss_fake = criterion(outputs, fake_labels)
    fake_score = outputs
    d_loss = d_loss_real + d_loss_fake  # total loss over both halves
    d_loss.backward()
    optimizer.step()
    return d_loss, real_score, fake_score
def run_generator_one_batch(d_net: nn.Module,
                            g_net: nn.Module,
                            batch_size: int,
                            latent_size: int,
                            criterion: nn.Module,
                            optimizer: optim.Optimizer,
                            device: str):
    """Run one generator update using the non-saturating GAN objective.

    A fresh batch of fakes is labeled as REAL (all ones) so that minimizing
    the BCE loss pushes the generator toward fooling the discriminator.

    Args:
        d_net: discriminator used to score the fakes (its weights are not
            stepped here; only `optimizer`'s parameters are updated).
        g_net: generator to update.
        batch_size: number of fake samples to draw.
        latent_size: dimensionality of the noise vectors.
        criterion: BCE-style loss on probabilities.
        optimizer: optimizer over g_net's parameters.
        device: target device string.

    Returns:
        (g_loss, fake_images) — the generator loss and the generated batch.
    """
    # Target labels: pretend every fake is real.
    target = torch.ones(batch_size, 1).to(device)
    noise = torch.randn(batch_size, latent_size).to(device)
    # Generate, score, and compute the loss against the "real" labels.
    generated = g_net(noise)
    g_loss = criterion(d_net(generated), target)
    g_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return g_loss, generated
def generate_and_save_images(g_net: nn.Module,
                             batch_size: int,
                             latent_size: int,
                             device: str,
                             image_prefix: str,
                             index: int) -> bool:
    """Sample `batch_size` images from g_net and save them as one PNG grid.

    Args:
        g_net: trained (or in-training) generator.
        batch_size: number of images to sample.
        latent_size: dimensionality of the noise vectors.
        device: device to sample on.
        image_prefix: output directory (created if missing).
        index: epoch number used in the file name (zero-padded to 3 digits).

    Returns:
        True once the grid has been written.
    """
    def dnorm(x: torch.Tensor) -> torch.Tensor:
        # Inverse of Normalize(0.5, 0.5): map [-1, 1] back to [0, 1] for saving.
        min_value = -1
        max_value = 1
        out = (x - min_value) / (max_value - min_value)
        return out.clamp(0, 1)

    # Sampling is inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        sample_vectors = torch.randn(batch_size, latent_size).to(device)
        fake_images = g_net(sample_vectors)
        fake_images = fake_images.view(batch_size, 1, 28, 28)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(image_prefix, exist_ok=True)
    save_image(dnorm(fake_images), os.path.join(image_prefix, f'fake_images-{index:03d}.png'), nrow=10)
    return True
def run_epoch(d_net: nn.Module,
              g_net: nn.Module,
              train_loader: DataLoader,
              criterion: nn.Module,
              d_optim: optim.Optimizer,
              g_optim: optim.Optimizer,
              batch_size: int,
              latent_size: int,
              device: str,
              d_loss_list: list,
              g_loss_list: list,
              real_score_list: list,
              fake_score_list: list,
              epoch: int, num_epochs: int):
    """Run one adversarial training epoch: alternate one discriminator step
    and one generator step per batch, logging every 300 batches and recording
    per-batch losses/scores into the provided history lists (mutated in place).

    Args:
        d_net / g_net: the two adversarial networks.
        train_loader: yields (images, labels); labels are ignored.
        criterion: BCE-style loss shared by both updates.
        d_optim / g_optim: optimizers for d_net and g_net respectively.
        batch_size: nominal batch size (actual sizes come from each batch).
        latent_size: dimensionality of the generator's noise input.
        device: device to train on.
        d_loss_list, g_loss_list, real_score_list, fake_score_list: history
            buffers appended to once per batch.
        epoch / num_epochs: current epoch index and total, for log messages.
    """
    d_net.train()
    g_net.train()
    for idx, (images, _) in enumerate(train_loader):
        # Use the actual batch size: unless DataLoader(drop_last=True) is set,
        # the last batch of an epoch may hold fewer than batch_size samples,
        # and view(batch_size, -1) would raise on it.
        n = images.size(0)
        images = images.view(n, -1).to(device)
        # Discriminator step
        d_loss, real_score, fake_score = run_discriminator_one_batch(d_net, g_net, n, latent_size, images,
                                                                     criterion, d_optim, device)
        # Generator step
        g_loss, _ = run_generator_one_batch(d_net, g_net, n, latent_size, criterion, g_optim, device)
        if (idx + 1) % 300 == 0:
            num = f"Epoch: [{epoch + 1}/{num_epochs}], Batch: [{idx + 1}/{len(train_loader)}]"
            loss_info = f"Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}"
            real_sample_score = f"Real sample score for Discriminator D(x): {real_score.mean().item():.4f}"
            fake_sample_score = f"Fake sample score for Discriminator D(G(x)): {fake_score.mean().item():.4f}"
            print(num + loss_info)
            print(num + real_sample_score)
            print(num + fake_sample_score)
        # Record history once per batch.
        # NOTE(review): the source's indentation was lost in extraction;
        # per-batch appending best matches the [::200] downsampling used when
        # plotting these lists — confirm against the original notebook.
        d_loss_list.append(d_loss.item())
        g_loss_list.append(g_loss.item())
        real_score_list.append(real_score.mean().item())
        fake_score_list.append(fake_score.mean().item())
from training import run_epoch, generate_and_save_images
image_prefix = "./sample"  # output directory for sampled image grids
# Main adversarial training loop: one run_epoch call per epoch.
for epoch in range(num_epochs):
    run_epoch(d_net=D, g_net=G,
              train_loader=data_loader, criterion=criterion,
              d_optim=d_optim, g_optim=g_optim,
              batch_size=batch_size, latent_size=latent_size, device=device,
              d_loss_list=d_loss_list, g_loss_list=g_loss_list,
              real_score_list=real_score_list, fake_score_list=fake_score_list,
              epoch=epoch, num_epochs=num_epochs)
    # Save a grid of generated samples every 10 epochs to track progress.
    if (epoch+1) % 10 == 0:
        if generate_and_save_images(g_net=G, batch_size=batch_size,
                                    latent_size=latent_size, device=device,
                                    image_prefix=image_prefix, index=epoch+1):
            print(f"Generated images at epoch {epoch+1}")
Epoch: [1/300], Batch: [300/600]Discriminator Loss: 1.1440, Generator Loss: 0.5215
Epoch: [1/300], Batch: [300/600]Real sample score for Discriminator D(x): 0.8644
Epoch: [1/300], Batch: [300/600]Fake sample score for Discriminator D(G(x)): 0.6283
Epoch: [1/300], Batch: [600/600]Discriminator Loss: 1.3556, Generator Loss: 0.8904
Epoch: [1/300], Batch: [600/600]Real sample score for Discriminator D(x): 0.9466
Epoch: [1/300], Batch: [600/600]Fake sample score for Discriminator D(G(x)): 0.6932
...
Epoch: [300/300], Batch: [600/600]Discriminator Loss: 1.1809, Generator Loss: 0.5166
Epoch: [300/300], Batch: [600/600]Real sample score for Discriminator D(x): 0.8612
Epoch: [300/300], Batch: [600/600]Fake sample score for Discriminator D(G(x)): 0.6094
Generated images at epoch 300
6. 保存checkpoint
import os

# Persist the trained weights (state_dicts only, not the module objects).
checkpoint_path = "./checkpoints"
# exist_ok makes the call idempotent and avoids the exists()-then-create race.
os.makedirs(checkpoint_path, exist_ok=True)
torch.save(G.state_dict(), os.path.join(checkpoint_path, "G.pt"))
torch.save(D.state_dict(), os.path.join(checkpoint_path, "D.pt"))
7. 检查训练结果
损失变化与判别器评判分数
# Plot every 200th recorded value so the per-batch curves stay readable.
plt.plot(d_loss_list[::200], label="Discriminator Loss")
plt.plot(g_loss_list[::200], label="Generator Loss")
plt.xlabel("Step")
plt.ylabel("Loss")
plt.legend(loc='upper right', bbox_to_anchor=(1, 1))
plt.show()
# Discriminator scores: D(x) should stay high, D(G(z)) low if D is winning;
# both drifting toward 0.5 indicates the generator is catching up.
plt.plot(real_score_list[::200], label="Real Score")
plt.plot(fake_score_list[::200], label="Fake Score")
plt.xlabel("Step")
plt.ylabel("Score")
plt.legend(loc='upper right', bbox_to_anchor=(1, 1))
plt.show()
生成的图像
# Compare early (epoch 10) and final (epoch 300) generated sample grids.
from IPython.display import Image
Image(os.path.join(image_prefix, "fake_images-010.png"))
Image(os.path.join(image_prefix, "fake_images-300.png"))
运行环境
torch==2.1.1
torchvision==0.16.1
1.本站内容仅供参考,不作为任何法律依据。用户在使用本站内容时,应自行判断其真实性、准确性和完整性,并承担相应风险。
2.本站部分内容来源于互联网,仅用于交流学习研究知识,若侵犯了您的合法权益,请及时邮件或站内私信与本站联系,我们将尽快予以处理。
3.本文采用知识共享 署名4.0国际许可协议 [BY-NC-SA] 进行授权
4.根据《计算机软件保护条例》第十七条规定“为了学习和研究软件内含的设计思想和原理,通过安装、显示、传输或者存储软件等方式使用软件的,可以不经软件著作权人许可,不向其支付报酬。”您需知晓本站所有内容资源均来源于网络,仅供用户交流学习与研究使用,版权归属原版权方所有,版权争议与本站无关,用户本人下载后不能用作商业或非法用途,需在24个小时之内从您的电脑中彻底删除上述内容,否则后果均由用户承担责任;如果您访问和下载此文件,表示您同意只将此文件用于参考、学习而非其他用途,否则一切后果请您自行承担,如果您喜欢该程序,请支持正版软件,购买注册,得到更好的正版服务。
5.本站是非经营性个人站点,所有软件信息均来自网络,所有资源仅供学习参考研究目的,并不贩卖软件,不存在任何商业目的及用途
暂无评论内容