文章目录
- 引入
- 1 生成器
- 2 鉴别器
- 3 模型训练:生成器与鉴别器的交互
- 4 参数设置
- 5 数据载入
- 6 完整代码
- 7 部分输出图像示意
- 7.1 真实图像
- 7.2 训练200个批次
- 7.2 训练400个批次
- 7.2 训练600个批次
引入
论文详解:Unsupervised representation learning with deep convolutional generative adversarial networks
对抗生成网络的核心在于生成器、鉴别器以及两者之间的交互,本文将详细对这几个部分进行介绍。
1 生成器
DCGAN生成器的本质是多个卷积层、批量归一化、激活函数的堆叠,具体结构如下表:
结构 | 输入通道 | 输出通道 | 卷积核大小 | 步幅 | 填充 | 后续 |
---|---|---|---|---|---|---|
ConvTranspose2d | nz | ngf × \times × 8 | 4 | 1 | 0 | BatchNorm2d+ReLU |
ConvTranspose2d | ngf × \times × 8 | ngf × \times × 4 | 4 | 2 | 1 | BatchNorm2d+ReLU |
ConvTranspose2d | ngf × \times × 4 | ngf × \times × 2 | 4 | 2 | 1 | BatchNorm2d+ReLU |
ConvTranspose2d | ngf × \times × 2 | ngf | 4 | 2 | 1 | BatchNorm2d+ReLU |
ConvTranspose2d | ngf | nc | 4 | 2 | 1 | Tanh |
其中nz为输入通道数,ngf为给定结点数,nc是输出类别数。例如对于MNIST数据集,可设置nz=100,ngf=64,nc=1。
对应代码如下:
class Generator(nn.Module):"""生成器"""def __init__(self):super(Generator, self).__init__()# 使用的GPU数量self.ngpu = ngpu# 生成器结构,与表格中一致self.main = nn.Sequential(# 输入大小:nznn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(True),# 大小:(ngf*8) x 4 x 4nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# 大小:(ngf*4) x 8 x 8nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# 大小:(ngf*2) x 16 x 16nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),# 大小:(ngf) x 32 x 32nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),nn.Tanh()# 大小:(nc) x 64 x 64)def forward(self, input):if input.is_cuda and self.ngpu > 1:output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))else:output = self.main(input)return output
生成器的结构输出如下 (接下来都以mnist为例):
Generator((main): Sequential((0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(8): ReLU(inplace=True)(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(11): ReLU(inplace=True)(12): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(13): Tanh())
)
2 鉴别器
鉴别器与生成器的不同之处在于,其卷积层、激活函数的设置不同,输入通道数也是逐渐增加的:
结构 | 输入通道 | 输出通道 | 卷积核大小 | 步幅 | 填充 | 后续 |
---|---|---|---|---|---|---|
Conv2d | nc | ndf | 4 | 2 | 1 | LeakyReLU(0.2) |
Conv2d | ndf | ndf × \times × 2 | 4 | 2 | 1 | BatchNorm2d+LeakyReLU(0.2) |
Conv2d | ndf × \times × 2 | ndf × \times × 4 | 4 | 2 | 1 | BatchNorm2d+LeakyReLU(0.2) |
Conv2d | ndf × \times × 4 | ndf × \times × 8 | 4 | 2 | 1 | BatchNorm2d+LeakyReLU(0.2) |
Conv2d | ndf × \times × 8 | 1 | 4 | 1 | 0 | Sigmid() |
其中nc为输入通道数,ndf为给定结点数。注意输出通道变为1了哟。
对应代码如下:
class Discriminator(nn.Module):"""鉴别器"""def __init__(self):super(Discriminator, self).__init__()# 使用的GPU数量self.ngpu = ngpu# 鉴别器的结构self.main = nn.Sequential(# 输入大小: (nc) x 64 x 64nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf) x 32 x 32nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*2) x 16 x 16nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*4) x 8 x 8nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*8) x 4 x 4nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):# 与生成器类似哟if input.is_cuda and self.ngpu > 1:output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))else:output = self.main(input)# 注意输出已经延展成一列的张量了return output.view(-1, 1).squeeze(1)
鉴别器的结构输出如下:
Discriminator((main): Sequential((0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(1): LeakyReLU(negative_slope=0.2, inplace=True)(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(4): LeakyReLU(negative_slope=0.2, inplace=True)(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): LeakyReLU(negative_slope=0.2, inplace=True)(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(10): LeakyReLU(negative_slope=0.2, inplace=True)(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)(12): Sigmoid())
)
3 模型训练:生成器与鉴别器的交互
下图绘制的是DCGAN生成器与鉴别器的交互过程,数字代表该步骤在程序中的运行过程:
训练过程的代码如下:
def DCGAN():"""DCGAN主函数"""# 每一轮训练for epoch in range(opt.nepoch):# 每一个批次的训练for i, data in enumerate(dataset, 0):"""步骤1:训练鉴别器,即最大化log(D(x)) + log(1 - D(G(z)))"""# 首先基于真实图像进行训练# 鉴别器的梯度清零netD.zero_grad()# 格式化当前批次,这里data = data[0]的原因是因为所有批次的图像是放在一个列表里面的data = data[0].to(device)# 获取当前批次的图像的数量batch_size = data.size(0)# 将当前批次所有图像的标签设置为指定的真实标签,如1label = torch.full((batch_size, ), real_label, dtype=data.dtype, device=device)# 先鉴别器输出一下output = netD(data)# 计算鉴别器上基于真实图像计算的损失errorD_real = loss(output, label)errorD_real.backward()D_x = output.mean().item()# 训练虚假图像# 随机生成一个虚假图像noise = torch.randn(batch_size, nz, 1, 1, device=device)# 生成器开始造假fake = netG(noise)# 标签设置为假的标签label.fill_(fake_label)# 鉴别器来判断output = netD(fake.detach())# 假图片的损失errorD_fake = loss(output, label)errorD_fake.backward()# 假图片的梯度D_G_z1 = output.mean().item()# 鉴别器的总损失errorD = errorD_real + errorD_fake# 鉴别器优化一下optimD.step()"""训练生成器"""# 生成器梯度清零netG.zero_grad()# 生成器填真实标签,毕竟想造假label.fill_(real_label)# 得到假图片的输出output = netD(fake)# 计算生成器的损失errorG = loss(output, label)# 生成器梯度清零errorG.backward()# 再一次假图片的梯度D_G_z2 = output.mean().item()optimG.step()# 输出一些关键信息print("[%d/%d][%d/%d] lossD: %.4f lossG: %.4f ""D(x): %.4f D(G(z)): %.4f/%.4f" % (epoch, opt.nepoch, i, len(dataset),errorD.item(), errorG.item(),D_x, D_G_z1, D_G_z2))# 存储图像,可以设置想要的存储时间结点哈if i % 100 == 0:vutils.save_image(data, "real_image.png", normalize=True)fake = netG(fixed_noise)vutils.save_image(fake.detach(),'fake_image_%03d_%03d.png' % (epoch, i), normalize=True)
4 参数设置
用的库如下:
from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
相关参数设置,变量的信息于help中给出:
def get_parser():"""获取参数设置器"""parser = argparse.ArgumentParser()# 设置实验用数据集的类型,help中为所支持的数据集类型parser.add_argument("--dataset", required=False, default="mnist",help="数据集类型:cifar10 | lsun | mnist |imagenet | folder | lfw | fake")parser.add_argument("--data_root", required=False, help="数据集的存储路径",default=r"D:\Data\OneDrive\Code\MIL1\Data")parser.add_argument("--workers", type=int, default=2, help="数据集下载并行数")parser.add_argument("--batch_size", type=int, default=64, help="数据集的输入批次大小")parser.add_argument("--image_size", type=int, default=64, help="输入图像的高/宽")parser.add_argument("--nz", type=int, default=100, help="隐含向量z的大小")parser.add_argument("--ngf", type=int, default=64, help="生成器隐藏层结点数")parser.add_argument("--ndf", type=int, default=64, help="鉴别器隐藏层结点数")parser.add_argument("--nepoch", type=int, default=5, help="训练轮次数")parser.add_argument("--lr", type=float, default=0.0002, help="学习率")parser.add_argument("--beta1", type=float, default=0.5, help="Adam的beta1")parser.add_argument("--cuda", action="store_true", help="CUDA是否可用")parser.add_argument("--ngpu", type=int, default=1, help="GPU使用数量")parser.add_argument("--manual_seed", type=int, help="随机种子")parser.add_argument("--classes", default="bedroom", help="LSUN卧室数据集的列表划分分隔符")return parser.parse_args()
管理随机种子:
def get_seed():"""管理随机种子"""if opt.manual_seed is None:opt.manual_seed = random.randint(1, 10000)random.seed(opt.manual_seed)torch.manual_seed(opt.manual_seed)
网络的权重等设置:
def init_weight(m):"""初始化权重"""classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight, 0.0, 0.02)elif classname.find("BatchNorm") != -1:torch.nn.init.normal_(m.weight, 1.0, 0.02)torch.nn.init.zeros_(m.bias)
主函数:
if __name__ == '__main__':# 参数管理器opt = get_parser()# 设备device = torch.device("cuda:0" if opt.cuda else "cpu")# GPU数量、生成器输入通道数、生成器结点数设置、鉴别器结点数设置ngpu, nz, ngf, ndf = int(opt.ngpu), int(opt.nz), int(opt.ngf), int(opt.ndf)# 数据集、输出通道数dataset, nc = get_data()# 启动生成器netG = Generator().to(device)netG.apply(init_weight)# 启动鉴别器netD = Discriminator().to(device)netD.apply(init_weight)# 损失函数loss = nn.BCELoss()# 设置噪声及标签fixed_noise, real_label, fake_label = torch.randn(opt.batch_size, nz, 1, 1, device=device), 1, 0# 启动优化器optimG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))optimD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))DCGAN()
5 数据载入
所集成的数据集如下:
def get_data():"""获取数据集"""# 输出通道数nc = 3if opt.dataset in ["imagenet", "folder", "lfw"]:dataset = dset.ImageFolder(root=opt.data_root,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.CenterCrop(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))elif opt.dataset == "lsun":classes = [c + "_train" for c in opt.classes.split(',')]dataset = dset.LSUN(root=opt.data_root, classes=classes,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.CenterCrop(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))elif opt.dataset == "cifar10":dataset = dset.CIFAR10(root=opt.data_root, download=True,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))elif opt.dataset == "mnist":dataset = dset.MNIST(root=opt.data_root, download=True,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),]))nc = 1else:dataset = dset.FakeData(image_size=(3, opt.image_size, opt.image_size),transform=transforms.ToTensor())return (torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers)),nc)
6 完整代码
from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutilsdef get_parser():"""获取参数设置器"""parser = argparse.ArgumentParser()# 设置实验用数据集的类型,help中为所支持的数据集类型parser.add_argument("--dataset", required=False, default="mnist",help="数据集类型:cifar10 | lsun | mnist |imagenet | folder | lfw | fake")parser.add_argument("--data_root", required=False, help="数据集的存储路径",default=r"D:\Data\OneDrive\Code\MIL1\Data")parser.add_argument("--workers", type=int, default=2, help="数据集下载并行数")parser.add_argument("--batch_size", type=int, default=64, help="数据集的输入批次大小")parser.add_argument("--image_size", type=int, default=64, help="输入图像的高/宽")parser.add_argument("--nz", type=int, default=100, help="隐含向量z的大小")parser.add_argument("--ngf", type=int, default=64, help="生成器隐藏层结点数")parser.add_argument("--ndf", type=int, default=64, help="鉴别器隐藏层结点数")parser.add_argument("--nepoch", type=int, default=5, help="训练轮次数")parser.add_argument("--lr", type=float, default=0.0002, help="学习率")parser.add_argument("--beta1", type=float, default=0.5, help="Adam的beta1")parser.add_argument("--cuda", action="store_true", help="CUDA是否可用")parser.add_argument("--ngpu", type=int, default=1, help="GPU使用数量")parser.add_argument("--manual_seed", type=int, help="随机种子")parser.add_argument("--classes", default="bedroom", help="LSUN卧室数据集的列表划分分隔符")return parser.parse_args()def get_seed():"""管理随机种子"""if opt.manual_seed is None:opt.manual_seed = random.randint(1, 10000)random.seed(opt.manual_seed)torch.manual_seed(opt.manual_seed)def get_data():"""获取数据集"""# 输出通道数nc = 3if opt.dataset in ["imagenet", "folder", "lfw"]:dataset = dset.ImageFolder(root=opt.data_root,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.CenterCrop(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))elif opt.dataset == "lsun":classes = [c + "_train" for c in opt.classes.split(',')]dataset = dset.LSUN(root=opt.data_root, classes=classes,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.CenterCrop(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))elif opt.dataset == "cifar10":dataset = dset.CIFAR10(root=opt.data_root, download=True,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))elif opt.dataset == "mnist":dataset = dset.MNIST(root=opt.data_root, download=True,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),]))nc = 1else:dataset = dset.FakeData(image_size=(3, opt.image_size, opt.image_size),transform=transforms.ToTensor())return (torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers)),nc)def init_weight(m):"""初始化权重"""classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight, 0.0, 0.02)elif classname.find("BatchNorm") != -1:torch.nn.init.normal_(m.weight, 1.0, 0.02)torch.nn.init.zeros_(m.bias)class Generator(nn.Module):"""生成器"""def __init__(self):super(Generator, self).__init__()# 使用的GPU数量self.ngpu = ngpu# 生成器结构,与表格中一致self.main = nn.Sequential(# 输入大小:nznn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(True),# 大小:(ngf*8) x 4 x 4nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# 大小:(ngf*4) x 8 x 8nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# 大小:(ngf*2) x 16 x 16nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),# 大小:(ngf) x 32 x 32nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),nn.Tanh()# 大小:(nc) x 64 x 64)def forward(self, input):if input.is_cuda and self.ngpu > 1:output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))else:output = self.main(input)return outputclass Discriminator(nn.Module):"""鉴别器"""def __init__(self):super(Discriminator, self).__init__()# 使用的GPU数量self.ngpu = ngpu# 鉴别器的结构self.main = nn.Sequential(# 输入大小: (nc) x 64 x 64nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf) x 32 x 32nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*2) x 16 x 16nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*4) x 8 x 8nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*8) x 4 x 4nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):# 与生成器类似哟if input.is_cuda and self.ngpu > 1:output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))else:output = self.main(input)# 注意输出已经延展成一列的张量了return output.view(-1, 1).squeeze(1)def DCGAN():"""DCGAN主函数"""# 每一轮训练for epoch in range(opt.nepoch):# 每一个批次的训练for i, data in enumerate(dataset, 0):"""步骤1:训练鉴别器,即最大化log(D(x)) + log(1 - D(G(z)))"""# 首先基于真实图像进行训练# 鉴别器的梯度清零netD.zero_grad()# 格式化当前批次,这里data = data[0]的原因是因为所有批次的图像是放在一个列表里面的data = data[0].to(device)# 获取当前批次的图像的数量batch_size = data.size(0)# 将当前批次所有图像的标签设置为指定的真实标签,如1label = torch.full((batch_size, ), real_label, dtype=data.dtype, device=device)# 先鉴别器输出一下output = netD(data)# 计算鉴别器上基于真实图像计算的损失errorD_real = loss(output, label)errorD_real.backward()D_x = output.mean().item()# 训练虚假图像# 随机生成一个虚假图像noise = torch.randn(batch_size, nz, 1, 1, device=device)# 生成器开始造假fake = netG(noise)# 标签设置为假的标签label.fill_(fake_label)# 鉴别器来判断output = netD(fake.detach())# 假图片的损失errorD_fake = loss(output, label)errorD_fake.backward()# 假图片的梯度D_G_z1 = output.mean().item()# 鉴别器的总损失errorD = errorD_real + errorD_fake# 鉴别器优化一下optimD.step()"""训练生成器"""# 生成器梯度清零netG.zero_grad()# 生成器填真实标签,毕竟想造假label.fill_(real_label)# 得到假图片的输出output = netD(fake)# 计算生成器的损失errorG = loss(output, label)# 生成器梯度清零errorG.backward()# 再一次假图片的梯度D_G_z2 = output.mean().item()optimG.step()# 输出一些关键信息print("[%d/%d][%d/%d] lossD: %.4f lossG: %.4f ""D(x): %.4f D(G(z)): %.4f/%.4f" % (epoch, opt.nepoch, i, len(dataset),errorD.item(), errorG.item(),D_x, D_G_z1, D_G_z2))# 存储图像,可以设置想要的存储时间结点哈if i % 100 == 0:vutils.save_image(data, "real_image.png", normalize=True)fake = netG(fixed_noise)vutils.save_image(fake.detach(),'fake_image_%03d_%03d.png' % (epoch, i), normalize=True)if __name__ == '__main__':# 参数管理器opt = get_parser()# 设备device = torch.device("cuda:0" if opt.cuda else "cpu")# GPU数量、生成器输入通道数、生成器结点数设置、鉴别器结点数设置ngpu, nz, ngf, ndf = int(opt.ngpu), int(opt.nz), int(opt.ngf), int(opt.ndf)# 数据集、输出通道数dataset, nc = get_data()# 启动生成器netG = Generator().to(device)netG.apply(init_weight)# 启动鉴别器netD = Discriminator().to(device)netD.apply(init_weight)# 损失函数loss = nn.BCELoss()# 设置噪声及标签fixed_noise, real_label, fake_label = torch.randn(opt.batch_size, nz, 1, 1, device=device), 1, 0# 启动优化器optimG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))optimD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))DCGAN()
7 部分输出图像示意
7.1 真实图像
7.2 训练200个批次
7.2 训练400个批次
7.2 训练600个批次
设备有限,就这么多了: