基于pytorch的SRGAN实现(全网最细!!!)

article/2025/11/8 10:10:50

基于pytorch的SRGAN实现

    • 前言
    • SRGAN论文概要(贡献)
    • 网络结构和损失函数
    • pytorch代码实现
      • 1. 准备工作
        • 1.1 数据下载并放到合适位置
      • 2. 开始训练和测试
    • 源码详解
      • 1. 数据集的加载: data_utils.py
      • 2. 网络模型: model.py
        • 2.1 生成器: Generator
        • 2.2 判别器:Discriminator
        • 2.3 残差块: ResidualBlock
        • 2.4 上采样块: UpsampleBLock
      • 3. 损失函数: loss.py
      • 4.训练:train.py
      • 5. 测试基准数据集: test_benchmark.py
      • 6. 测试单张图片:test_image.py(原理和基准测试集相同,只需只能需要测试的图片名字即可)

前言

SRGAN是发表在顶会CVPR2017的文章, 利用GAN进行超分实现了不错的效果!
注意到的是,SRGAN是目前SR领域中引用量最高的论文.
链接地址: https://arxiv.org/abs/1609.04802v5
作者代码链接: https://github.com/leftthomas/SRGAN

SRGAN论文概要(贡献)

  1. 深度RESNet(SRRESNet)针对MSE进行了优化,通过PSNR和结构相似度(SSIM)来测量图像SR的高放大因子
  2. SRGAN,是一种基于GAN的网络,针对一种新的感知损失进行了优化。用在VGG网络的特征映射上计算的损失来代替基于MSE的内容损失,该特征映射对像素空间的变化更加不变,这样相较于原来像素损失超分的图像更具有纹理等高频细节.
  3. 对来自三个公共基准数据集的图像进行广泛的平均意见得分(MOS)测试,证实SRGAN在很大程度上是高放大因子(4×)的照片真实感SR图像估计的最新技术, 即超分后的图像更加接近自然图像.

网络结构和损失函数

对应的详解在相应的代码实现处
网络模型:
在这里插入图片描述

Perceptual loss function(感知损失函数或总损失)
在这里插入图片描述

Content loss(内容损失)
在这里插入图片描述

Adversarial loss(对抗损失)
在这里插入图片描述

pytorch代码实现

1. 准备工作

1.1 数据下载并放到合适位置

train 和 val 数据集是从VOC2012中采样得到的
VOC2012:链接地址 提取码: 5tzp

测试图像数据集来自Set5 Set14 BSD100 Urban100 SunHays80 链接地址

下载图像数据集,然后将其解压到data目录中
如图所示:
在这里插入图片描述

注意: 如需训练自己的数据集,请准备好原图和对应插值缩放后的图片

2. 开始训练和测试

训练: (1) 打开终端,进入当前文件目录
在这里插入图片描述

(2) 选择指定的参数, 未指定的情况下按代码中的默认值处理
在这里插入图片描述

也可以直接运行README文件中相应代码:
在这里插入图片描述

训练完成后,训练结果会保存到benchmark_results 文件夹中
![在这里插入图片描述](https://img-blog.csdnimg.cn/e58ebf2affbc49709b4d838aef2dc1ca.png

测试过程同训练过程一样, 对应实现即可!其它参数细节论文里面均有说明

源码详解

1. 数据集的加载: data_utils.py

from os import listdir
from os.path import joinfrom PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resizedef is_image_file(filename):# 判断文件名是否是图像文件return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])def calculate_valid_crop_size(crop_size, upscale_factor):# 计算可用的裁剪尺寸return crop_size - (crop_size % upscale_factor)def train_hr_transform(crop_size):# 训练集的高分辨率图像转换return Compose([RandomCrop(crop_size),  # 随机裁剪图像到指定尺寸ToTensor(),  # 将图像转换为张量])def train_lr_transform(crop_size, upscale_factor):# 训练集的低分辨率图像转换return Compose([ToPILImage(),  # 将张量转换为PIL图像Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC), # 将图像缩放到指定尺寸,使用双三次插值方法ToTensor()   # 将图像转换为张量])def display_transform():# 显示图像的转换return Compose([ToPILImage(),   # 将张量转换为PIL图像Resize(400),  # 将图像调整大小为400x400CenterCrop(400), # 对图像进行中心裁剪为400x400ToTensor()   # 将图像转换为张量])#  加载训练集中的图像数据
class TrainDatasetFromFolder(Dataset):def __init__(self, dataset_dir, crop_size, upscale_factor):super(TrainDatasetFromFolder, self).__init__()# 获取目录中的所有图像文件名,并使用is_image_file函数来筛选出图像文件self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]# 计算可用的裁剪尺寸crop_size = calculate_valid_crop_size(crop_size, upscale_factor)# 分别创建高分辨率和低分辨率图像的转换操作self.hr_transform = train_hr_transform(crop_size)self.lr_transform = train_lr_transform(crop_size, upscale_factor)def __getitem__(self, index):# 获取给定索引的图像数据hr_image = self.hr_transform(Image.open(self.image_filenames[index]))lr_image = self.lr_transform(hr_image)return lr_image, hr_imagedef __len__(self):# 返回数据集的大小(图像数量)return len(self.image_filenames)#  加载验证集中的图像数据 同训练集
class ValDatasetFromFolder(Dataset):def __init__(self, dataset_dir, upscale_factor):super(ValDatasetFromFolder, self).__init__()self.upscale_factor = upscale_factorself.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]def __getitem__(self, index):# 打开高分辨率图像文件hr_image = Image.open(self.image_filenames[index])w, h = hr_image.size  # 获取图像的宽度和高度crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)   # 计算可用的裁剪尺寸lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)  # 缩放图像为低分辨率图像hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)  # 缩放图像为高分辨率图像hr_image = CenterCrop(crop_size)(hr_image)    # 对高分辨率图像进行中心裁剪lr_image = lr_scale(hr_image)   # 缩放得到低分辨率图像hr_restore_img = hr_scale(lr_image) # 缩放得到还原后的高分辨率图像return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)def __len__(self):return len(self.image_filenames)#  加载测试集中的图像数据 同训练集
class TestDatasetFromFolder(Dataset):def __init__(self, dataset_dir, upscale_factor):super(TestDatasetFromFolder, self).__init__()self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'  # 构建低分辨率图像文件路径self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/' # 构建高分辨率图像文件路径self.upscale_factor = upscale_factor# 获取两个路径下的图像文件名,并保存在lr_filenames和hr_filenames列表中。self.lr_filenames = [join(self.lr_path, x) for x in listdir(self.lr_path) if is_image_file(x)]self.hr_filenames = [join(self.hr_path, x) for x in listdir(self.hr_path) if is_image_file(x)]def __getitem__(self, index):# 获取给定索引的图像数据image_name = self.lr_filenames[index].split('/')[-1]lr_image = Image.open(self.lr_filenames[index])  # 打开低分辨率图像文件w, h = lr_image.size  # 获取低分辨率图像的宽度和高度hr_image = Image.open(self.hr_filenames[index])  # 打开高分辨率图像文件hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC) # 缩放高分辨率图像hr_restore_img = hr_scale(lr_image)  # 缩放得到还原后的高分辨率图像# 将图像文件名、低分辨率图像、还原后的高分辨率图像和原始高分辨率图像转换为张量并返回return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)def __len__(self):return len(self.lr_filenames)

2. 网络模型: model.py

2.1 生成器: Generator

import math
import torch
from torch import nn# 生成器模型
class Generator(nn.Module):def __init__(self, scale_factor):# 计算需要进行上采样的块的数量upsample_block_num = int(math.log(scale_factor, 2))super(Generator, self).__init__()# 二维卷积层,输入通道数为3,输出通道数为64,卷积核大小为9,填充为4self.block1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=9, padding=4),nn.PReLU()  # Parametric ReLU激活函数)self.block2 = ResidualBlock(64) # 定义(残差)ResidualBlock模块self.block3 = ResidualBlock(64)self.block4 = ResidualBlock(64)self.block5 = ResidualBlock(64)self.block6 = ResidualBlock(64)self.block7 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64))# 由多个UpsampleBlock模块组成的列表block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))self.block8 = nn.Sequential(*block8)  # 由block8列表中的模块组成的序列模块def forward(self, x):block1 = self.block1(x)block2 = self.block2(block1)block3 = self.block3(block2)block4 = self.block4(block3)block5 = self.block5(block4)block6 = self.block6(block5)block7 = self.block7(block6)block8 = self.block8(block1 + block7)# 将输出限制在0到1之间,通过tanh激活函数和缩放操作得到最终生成的图像return (torch.tanh(block8) + 1) / 2

2.2 判别器:Discriminator

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.net = nn.Sequential(# 二维卷积层,输入通道数为3,输出通道数为64,卷积核大小为3,填充为1nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.LeakyReLU(0.2),  # LeakyReLU激活函数nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(64),nn.LeakyReLU(0.2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.BatchNorm2d(128),nn.LeakyReLU(0.2),nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(128),nn.LeakyReLU(0.2),nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.BatchNorm2d(512),nn.LeakyReLU(0.2),nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(512),nn.LeakyReLU(0.2),# 自适应平均池化层,将输入特征图转换为大小为1x1的特征图nn.AdaptiveAvgPool2d(1),nn.Conv2d(512, 1024, kernel_size=1),nn.LeakyReLU(0.2),nn.Conv2d(1024, 1, kernel_size=1))def forward(self, x):# 输入批次的大小batch_size = x.size(0)# 使用torch.sigmoid函数将特征图映射到0到1之间,表示输入图像为真实图像的概率。return torch.sigmoid(self.net(x).view(batch_size))

2.3 残差块: ResidualBlock

class ResidualBlock(nn.Module):def __init__(self, channels):super(ResidualBlock, self).__init__()# 二维卷积层,输入通道数为channels,输出通道数为channels,卷积核大小为3,填充为1self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(channels)  # 二维批归一化层self.prelu = nn.PReLU()  # Parametric ReLU激活函数self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)self.bn2 = nn.BatchNorm2d(channels)  # 二维批归一化层def forward(self, x):# 应用对应的layer得到前向传播的输出(残差项)residual = self.conv1(x)residual = self.bn1(residual)residual = self.prelu(residual)residual = self.conv2(residual)residual = self.bn2(residual)return x + residual  # 将输入x与残差项相加,得到最终输出

2.4 上采样块: UpsampleBLock

# 上采样块
class UpsampleBLock(nn.Module):def __init__(self, in_channels, up_scale):super(UpsampleBLock, self).__init__()# 卷积层,输入通道数为in_channels,输出通道数为in_channels * 2 ** 2,卷积核大小为3,填充为1self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)# 像素重排操作,上采样因子为up_scaleself.pixel_shuffle = nn.PixelShuffle(up_scale)self.prelu = nn.PReLU()def forward(self, x):x = self.conv(x)x = self.pixel_shuffle(x)x = self.prelu(x)return x

3. 损失函数: loss.py

import torch
from torch import nn
from torchvision.models.vgg import vgg16class GeneratorLoss(nn.Module):def __init__(self):super(GeneratorLoss, self).__init__()# 使用预训练的 VGG16 模型来构建特征提取网络vgg = vgg16(pretrained=True)# 选择 VGG16 模型的前 31 层作为损失网络,并将其设置为评估模式(不进行梯度更新)loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()# 冻结其参数,不进行梯度更新for param in loss_network.parameters():param.requires_grad = Falseself.loss_network = loss_network# 定义均方误差损失函数: 计算生成器生成图像与目标图像之间的均方误差损失self.mse_loss = nn.MSELoss()# 定义总变差损失函数: 计算生成器生成图像的总变差损失,用于平滑生成的图像self.tv_loss = TVLoss()def forward(self, out_labels, out_images, target_images):# Adversarial Loss(对抗损失):使生成的图像更接近真实图像,目标是最小化生成器对图像的判别结果的平均值与 1 的差距adversarial_loss = torch.mean(1 - out_labels)# Perception Loss(感知损失):计算生成图像和目标图像在特征提取网络中提取的特征之间的均方误差损失perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))# Image Loss(图像损失):计算生成图像和目标图像之间的均方误差损失image_loss = self.mse_loss(out_images, target_images)# TV Loss(总变差损失):计算生成图像的总变差损失,用于平滑生成的图像tv_loss = self.tv_loss(out_images)# 返回生成器的总损失,四个损失项加权求和return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_lossclass TVLoss(nn.Module):def __init__(self, tv_loss_weight=1):super(TVLoss, self).__init__()self.tv_loss_weight = tv_loss_weightdef forward(self, x):batch_size = x.size()[0]h_x = x.size()[2]w_x = x.size()[3]count_h = self.tensor_size(x[:, :, 1:, :])count_w = self.tensor_size(x[:, :, :, 1:])# 计算水平方向上的总变差损失h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()# 计算垂直方向上的总变差损失w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()# 返回总变差损失return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size@staticmethoddef tensor_size(t):# 返回张量的尺寸大小,即通道数乘以高度乘以宽度return t.size()[1] * t.size()[2] * t.size()[3]if __name__ == "__main__":g_loss = GeneratorLoss()print(g_loss)

4.训练:train.py

import argparse
import os
from math import log10import pandas as pd
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdmimport pytorch_ssim
from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform
from loss import GeneratorLoss
from model import Generator, Discriminator# 创建一个命令行参数解析器对象
parser = argparse.ArgumentParser(description='Train Super Resolution Models')
# 用于指定训练图像的裁剪尺寸,默认为88
parser.add_argument('--crop_size', default=88, type=int, help='training images crop size')
# 用于指定超分辨率的放大因子,默认为4
parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8],help='super resolution upscale factor')
# 用于指定训练的轮数,默认为100
parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')if __name__ == '__main__':# 解析命令行参数并将结果存储在变量opt中opt = parser.parse_args()# 从opt中获取crop_size、upscale_factor和num_epochs的值,并分别赋给对应的变量CROP_SIZE = opt.crop_sizeUPSCALE_FACTOR = opt.upscale_factorNUM_EPOCHS = opt.num_epochs# 创建训练数据集对象TrainDatasetFromFolder,指定数据集路径、裁剪尺寸和放大因子train_set = TrainDatasetFromFolder('data/VOC2012/train', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)# 创建验证数据集对象ValDatasetFromFolder,指定数据集路径和放大因子val_set = ValDatasetFromFolder('data/VOC2012/val', upscale_factor=UPSCALE_FACTOR)# 创建训练数据加载器,指定数据集对象、工作线程数、批量大小和是否打乱数据顺序train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)# 创建验证数据加载器,指定数据集对象、工作线程数、批量大小和是否打乱数据顺序val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)# 创建生成器模型对象Generator,指定放大因子netG = Generator(UPSCALE_FACTOR)# 输出生成器模型参数的数量print('# generator parameters:', sum(param.numel() for param in netG.parameters()))# 创建生成器损失函数对象GeneratorLossnetD = Discriminator()# 输出判别器模型参数的数量print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))# 创建生成器损失函数对象GeneratorLossgenerator_criterion = GeneratorLoss()# GPU如果可用的话,将生成器模型、判别器模型和生成器损失函数移动到GPU上进行计算if torch.cuda.is_available():netG.cuda()netD.cuda()generator_criterion.cuda()# 创建生成器和判别器的优化器对象,用于更新模型参数optimizerG = optim.Adam(netG.parameters())optimizerD = optim.Adam(netD.parameters())# 创建一个字典用于存储训练过程中的判别器和生成器的损失、分数和评估指标结果(信噪比和相似性)results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}for epoch in range(1, NUM_EPOCHS + 1):# 创建训练数据的进度条train_bar = tqdm(train_loader)running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}netG.train()  # 将生成器设置为训练模式netD.train()  # 将判别器设置为训练模式for data, target in train_bar:g_update_first = Truebatch_size = data.size(0)running_results['batch_sizes'] += batch_size# (1) Update D network: maximize D(x)-1-D(G(z))real_img = Variable(target)if torch.cuda.is_available():real_img = real_img.cuda()z = Variable(data)if torch.cuda.is_available():z = z.cuda()fake_img = netG(z)  # 通过生成器生成伪图像# 清除判别器的梯度netD.zero_grad()# 通过判别器对真实图像进行前向传播,并计算其输出的平均值real_out = netD(real_img).mean()# 通过判别器对伪图像进行前向传播,并计算其输出的平均值fake_out = netD(fake_img).mean()# 计算判别器的损失d_loss = 1 - real_out + fake_out# 在判别器网络中进行反向传播,并保留计算图以进行后续优化步骤d_loss.backward(retain_graph=True)# 利用优化器对判别器网络的参数进行更新optimizerD.step()# (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV LossnetG.zero_grad()# The two lines below are added to prevent runtime error in Google Colab# 通过生成器对输入图像(z)进行生成,生成伪图像(fake_img)fake_img = netG(z)# 通过判别器对伪图像进行前向传播,并计算其输出的平均值fake_out = netD(fake_img).mean()### 计算生成器的损失,包括对抗损失、感知损失、图像损失和TV损失g_loss = generator_criterion(fake_out, fake_img, real_img)# 在生成器网络中进行反向传播,计算生成器的梯度g_loss.backward()# 再次通过生成器对输入图像(z)进行生成,得到新的伪图像(fake_img)fake_img = netG(z)# 通过判别器对新的伪图像进行前向传播,并计算其输出的平均值fake_out = netD(fake_img).mean()# 利用优化器对生成器网络的参数进行更新optimizerG.step()# loss for current batch before optimization# 累加当前批次生成器的损失值乘以批次大小,用于计算平均损失running_results['g_loss'] += g_loss.item() * batch_size# 累加当前批次判别器的损失值乘以批次大小,用于计算平均损失running_results['d_loss'] += d_loss.item() * batch_size# 累加当前批次真实图像在判别器的输出得分乘以批次大小,用于计算平均得分running_results['d_score'] += real_out.item() * batch_size# 累加当前批次伪图像在判别器的输出得分乘以批次大小,用于计算平均得分running_results['g_score'] += fake_out.item() * batch_size# 更新训练进度条的描述信息train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],running_results['g_loss'] / running_results['batch_sizes'],running_results['d_score'] / running_results['batch_sizes'],running_results['g_score'] / running_results['batch_sizes']))netG.eval()# 创建用于保存训练结果的目录out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'if not os.path.exists(out_path):os.makedirs(out_path)with torch.no_grad():val_bar = tqdm(val_loader)valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}val_images = []# 遍历验证数据集(低分辨率图 恢复的高分辨率图 高分辨率图)for val_lr, val_hr_restore, val_hr in val_bar:batch_size = val_lr.size(0)valing_results['batch_sizes'] += batch_sizelr = val_lrhr = val_hrif torch.cuda.is_available():lr = lr.cuda()hr = hr.cuda()# 生成超分辨率图像sr = netG(lr)# 计算批量图像的均方误差batch_mse = ((sr - hr) ** 2).data.mean()# 累加均方误差valing_results['mse'] += batch_mse * batch_size# 计算批量图像的结构相似度指数batch_ssim = pytorch_ssim.ssim(sr, hr).item()# 累加结构相似度指数valing_results['ssims'] += batch_ssim * batch_size# 计算平均峰值信噪比valing_results['psnr'] = 10 * log10((hr.max()**2) / (valing_results['mse'] / valing_results['batch_sizes']))# 计算平均结构相似度指数valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']# 更新训练进度条的描述信息val_bar.set_description(desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (valing_results['psnr'], valing_results['ssim']))val_images.extend(# 将图像应用转换函数,并添加到验证图像列表[display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),display_transform()(sr.data.cpu().squeeze(0))])# 将验证图像列表堆叠为张量val_images = torch.stack(val_images)# 将堆叠后的张量分割为多个小块,每个小块包含15张图像val_images = torch.chunk(val_images, val_images.size(0) // 15)# 创建进度条,并设置描述为“[saving training results]”val_save_bar = tqdm(val_images, desc='[saving training results]')index = 1for image in val_save_bar:# 将小块中的图像创建为一个网格,每行显示3张图像,图像之间有5个像素的间隔image = utils.make_grid(image, nrow=3, padding=5)# 将网格图像保存为文件,文件名包含epoch和index信息utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)index += 1# save model parameters# 将判别器和生成器的参数保存到指定文件torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))# save loss\scores\psnr\ssimresults['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])results['psnr'].append(valing_results['psnr'])results['ssim'].append(valing_results['ssim'])if epoch % 10 == 0 and epoch != 0:out_path = 'statistics/'# 创建一个DataFrame对象,用于存储训练结果数据data_frame = pd.DataFrame(data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},index=range(1, epoch + 1))# 将DataFrame对象保存为CSV文件data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')

5. 测试基准数据集: test_benchmark.py

import argparse
import os
from math import log10import numpy as np
import pandas as pd
import torch
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdmimport pytorch_ssim
from data_utils import TestDatasetFromFolder, display_transform
from model import Generatorparser = argparse.ArgumentParser(description='Test Benchmark Datasets')
parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
parser.add_argument('--model_name', default='netG_epoch_4_150.pth', type=str, help='generator model epoch name')
opt = parser.parse_args()UPSCALE_FACTOR = opt.upscale_factor
MODEL_NAME = opt.model_name# 保存每个测试数据集的结果
results = {'Set5': {'psnr': [], 'ssim': []}, 'Set14': {'psnr': [], 'ssim': []}, 'BSD100': {'psnr': [], 'ssim': []},'Urban100': {'psnr': [], 'ssim': []}, 'SunHays80': {'psnr': [], 'ssim': []}}# 创建一个 Generator 对象
model = Generator(UPSCALE_FACTOR).eval()
if torch.cuda.is_available():model = model.cuda()
# 加载训练好的模型参数
model.load_state_dict(torch.load('epochs/' + MODEL_NAME))# 加载测试数据集
test_set = TestDatasetFromFolder('data/test', upscale_factor=UPSCALE_FACTOR)
test_loader = DataLoader(dataset=test_set, num_workers=4, batch_size=1, shuffle=False)
# 创建一个用于 test_loader 的 tqdm 进度条
test_bar = tqdm(test_loader, desc='[testing benchmark datasets]')# 测试结果输出路径
out_path = 'benchmark_results/SRF_' + str(UPSCALE_FACTOR) + '/'
if not os.path.exists(out_path):os.makedirs(out_path)for image_name, lr_image, hr_restore_img, hr_image in test_bar:# 由于 image_name 是一个包含单个元素的列表,所以将其取出image_name = image_name[0]# 将 lr_image 转换为 Variable 对象,并设置 volatile=True# volatile=True 表示不会计算梯度,这在推理阶段通常是需要的lr_image = Variable(lr_image, volatile=True)hr_image = Variable(hr_image, volatile=True)if torch.cuda.is_available():lr_image = lr_image.cuda()hr_image = hr_image.cuda()# 生成超分变率图像sr_image = model(lr_image)mse = ((hr_image - sr_image) ** 2).data.mean()# 计算峰值信噪比(Peak Signal-to-Noise Ratio)psnr = 10 * log10(255 ** 2 / mse)# 计算结构相似性指数(Structural Similarity Index)# 使用 pytorch_ssim 库中的 ssim 函数计算 SSIMssim = pytorch_ssim.ssim(sr_image, hr_image).data[0]# 创建一个包含三张图像的张量,分别是原始恢复的高分辨率图像、原始高分辨率图像和生成的超分辨率图像# 将每张图像应用 display_transform() 转换,并通过 squeeze(0) 去除批次维度test_images = torch.stack([display_transform()(hr_restore_img.squeeze(0)), display_transform()(hr_image.data.cpu().squeeze(0)),display_transform()(sr_image.data.cpu().squeeze(0))])# 使用 make_grid 函数将三张图像拼接成一张大图像# nrow=3 表示每行显示 3 张图像,padding=5 表示图像之间的间距为 5image = utils.make_grid(test_images, nrow=3, padding=5)# 使用 save_image 函数将合成的图像保存到指定路径utils.save_image(image, out_path + image_name.split('.')[0] + '_psnr_%.4f_ssim_%.4f.' % (psnr, ssim) +image_name.split('.')[-1], padding=5)# 将对应数据集的PSNR和SSIM保存到对应的字典当中results[image_name.split('_')[0]]['psnr'].append(psnr)results[image_name.split('_')[0]]['ssim'].append(ssim)# 最终结果保存路径
out_path = 'statistics/'
saved_results = {'psnr': [], 'ssim': []}# 遍历 results 字典中的每个值
for item in results.values():# 获取 PSNR 和 SSIM 的列表psnr = np.array(item['psnr'])ssim = np.array(item['ssim'])# 如果列表为空,将 PSNR 和 SSIM 设置为 'No data'if (len(psnr) == 0) or (len(ssim) == 0):psnr = 'No data'ssim = 'No data'else:# 如果列表不为空,计算 PSNR 和 SSIM 的均值psnr = psnr.mean()ssim = ssim.mean()# 将计算得到的 PSNR 和 SSIM 添加到 saved_results 字典的相应列表中saved_results['psnr'].append(psnr)saved_results['ssim'].append(ssim)# 创建一个 DataFrame 对象,使用 saved_results 字典作为数据,以 results.keys() 作为列标签
data_frame = pd.DataFrame(saved_results, results.keys())
# 将 DataFrame 对象保存为 CSV 文件
# 文件路径由 out_path、'srf_'、UPSCALE_FACTOR 值和 '_test_results.csv' 组成
# index_label='DataSet' 表示使用 'DataSet' 作为索引标签
data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_test_results.csv', index_label='DataSet')

6. 测试单张图片:test_image.py(原理和基准测试集相同,只需只能需要测试的图片名字即可)

import argparse
import timeimport torch
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImagefrom model import Generatorparser = argparse.ArgumentParser(description='Test Single Image')
parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
parser.add_argument('--test_mode', default='CPU', type=str, choices=['GPU', 'CPU'], help='using GPU or CPU')
parser.add_argument('--image_name', default='data/test/SRF_4/data/Set5_003.png',type=str, help='test low resolution image name')
parser.add_argument('--model_name', default='netG_epoch_4_150.pth', type=str, help='generator model epoch name')
opt = parser.parse_args()UPSCALE_FACTOR = opt.upscale_factor
TEST_MODE = True if opt.test_mode == 'GPU' else False
IMAGE_NAME = opt.image_name
MODEL_NAME = opt.model_namemodel = Generator(UPSCALE_FACTOR).eval()
if TEST_MODE:model.cuda()model.load_state_dict(torch.load('epochs/' + MODEL_NAME))
else:model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage))image = Image.open(IMAGE_NAME)
image = Variable(ToTensor()(image), volatile=True).unsqueeze(0)
if TEST_MODE:image = image.cuda()start = time.time()
out = model(image)
elapsed = (time.time() - start)
print('cost: ' + str(elapsed) + 's')
out_img = ToPILImage()(out[0].data.cpu())
# out_img.save('out_srf_' + str(UPSCALE_FACTOR) + '_' + IMAGE_NAME)out_img.show()
save_path = 'single_test/image_name.jpg'
out_img.save(save_path)
print("图像已保存到文件夹中。")

本人水平有限,文中发现错误敬请指正,如有相同研究方向的同学可以互相学习,共同进步。


http://chatgpt.dhexx.cn/article/EdkN8CNK.shtml

相关文章

超分辨率——基于SRGAN的图像超分辨率重建(Pytorch实现)

基于SRGAN的图像超分辨率重建 本文偏新手项,因此只是作为定性学习使用,因此不涉及最后的定量评估环节 目录 基于SRGAN的图像超分辨率重建1 简要介绍2 代码实现2.1 开发环境2.2 主要流程2.3 构建数据集2.4 构建生成模型(Generator&#xff09…

SRCNN神经网络

0 前言 超分辨率技术(Super Resolution,SR)是指从观测到的低分辨率图像重建出相应的高分辨率图像,在监控设备、卫星图像和医学影像等领域都拥有着重要的应用价值。 1 SRCNN SRCNN是深度学习用在超分辨率重建上的开山之作。 其结构十分简单&#xff0c…

SRGAN的理解

全文翻译见:https://blog.csdn.net/weixin_42113955/article/details/89001989 和https://blog.csdn.net/c2a2o2/article/details/78930865 1. ptrain是真正的HR图像,也就是data要预测的。 pG是生成的超分辨图像 好处在于:固定 G&#xff0c…

GANs综述

生成式对抗网络GANs及其变体 基础GAN 生成式对抗网络,是lan Goodfellow 等人在2014年开发的,GANs 属于生成式模型,GANs是基于最小值和最大值的零和博弈理论。 为此,GANs是由两个神经网络组成一个Generator。另一个是Discriminat…

图像的超分辨率重建SRGAN与ESRGAN

SRGAN 传统的图像超分辨率重建方法一般都是放大较小的倍数,当放大倍数在4倍以上时就会出现过度平滑的现象,使得图像出现一些非真实感。SRGAN借助于GAN的网络架构生成图像中的细节。 训练网络使用均方误差(MSE)能够获得较高的峰值…

SRGAN With WGAN

SRGAN With WGAN RGAN 是基于 GAN 方法进行训练的,有一个生成器和一个判别器,判别器的主体使用 VGG19,生成器是一连串的 Residual block 连接,同时在模型后部也加入了 subpixel 模块,借鉴了 Shi et al 的 Subpixel Ne…

SRGAN论文与ESRGAN论文总结

博客结构 SRGANContribution:Network Architecture:Generator NetworkDiscriminator Network Perceptual loss function:Experiments:Mean opinion score (MOS) testing: ESRGANContribution:Network Architecture:ESR…

SR-GNN

Session-based Recommendation with Graph Neural Networks 一、论文 1、理论 ​ SR-GNN是一种基于会话序列建模的推荐系统。会话序列专门表示一个用户过往一段时间的交互序列。 ​ 常用的会话推荐包括循环神经网络和马尔科夫链,但有两个缺点: 当一…

SRGAN(SRResNet)介绍

生成对抗网络GAN是由蒙特利尔大学Ian Goodfellow在2014年提出的机器学习架构。 生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至…

SRGAN

摘要: 尽管使用更快更深的卷积神经网络在单图像超分辨率的准确性和速度方面取得了突破,但一个核心问题仍然很大程度上未解决:当我们在大的升级因子上超分辨时,我们如何恢复更精细的纹理细节?基于优化的超分辨率方法的行…

深度学习计划(4)SRGan简析

SRGAN 一种用于图像超分辨率(SR)的生成对抗网络(GAN) 超分辨率:从低分辨率(LR)图像来估计其对应高分辨率(HR)图像的高挑战性任务被称作超分辨率(SR) 问题: 重建的SR图像中通常缺少纹理细节。有监督SR算法的优化目标通常是最小化恢复的HR图像和真实图像…

图像超分经典网络 SRGAN精确解析

SRGAN 核心思想 早期超分辨率方法的优化目标都是降低低清图像和高清图像之间的均方误差。降低均方误差,确实让增强图像和原高清图像的相似度更高。但是,图像的相似度指标高并不能代表图像的增强质量就很高。 为什么 SRGAN 的增强结果那么清楚呢&#x…

SRGAN简单了解

超分辨率问题的病态性质尤其表现在取较高的放大因子时,重构的超分辨率图像通常会缺失纹理细节。监督SR算法的优化目标函数通常取重建高分辨率图像和地面真值之间的均方误差,在减小均方误差的同时又可以增大峰值信噪比(PSNR),PSNR是评价和比较…

【超分辨】SRGAN详解及其pytorch代码解释

SRGAN详解 介绍网络结构损失函数数据处理网络训练 介绍 「2023年更新」本代码是学习参考代码,一般不能直接运行,想找现成能运行的建议看看其他的。 SRGAN是一个超分辨网络,利用生成对抗网络的方法实现图片的超分辨。 关于生成对抗网络&#…

超分之一文读懂SRGAN

这篇文章介绍SRResNet网络,以及将SRResNet作为生成网络的GAN模型用于超分,即SRGAN模型。这是首篇在人类感知视觉上进行超分的文章,而以往的文章以PSNR为导向,但那些方式并不能让人眼觉得感知到了高分辨率——Photo-Realistic。 参…

图像超分经典网络 SRGAN 解析 ~ 如何把 GAN 运用在其他视觉任务上

生成对抗网络(GAN)是一类非常有趣的神经网络。借助GAN,计算机能够生成逼真的图片。近年来有许多“AI绘画”的新闻,这些应用大多是通过GAN实现的。实际上,GAN不仅能做图像生成,还能辅助其他输入信息不足的视觉任务。比如SRGAN&…

Oracle常用函数汇总记录

Oracle常用函数汇总记录 一、SUBSTR 截取函数 用法:substr(字符串,截取开始位置,截取长度) //返回截取的字, 字符串的起始位置为1,截取时包含起始位置字符 1.SUBSTR( “Hello World”, 2 ) //返回结果为:ello World,从第二个字符开始截取至末位 2.SUBSTR( “Hello World”, -2…

oracle一些常用函数用法,Oracle常用函数及其用法

01、入门Oracle 本章目标: 掌握oracle安装、启动和关闭 基本管理以及常用工具 简单备份和恢复 熟练使用sql,掌握oracle常用对象 掌握数据库设计和优化基本方法 http://jingyan.baidu.com/article/5d6edee228308899eadeec3f.html oracle数据库&#xff1a…

oracle常用函数详解(详细)

Oracle SQL 提供了用于执行特定操作的专用函数。这些函数大大增强了 SQL 语言的功能。函数可以接受零个或者多个输入参数,并返回一个输出结果。 Oracle 数据库中主要使用两种类型的函数: 1. 单行函数:对每一个函数应用在表的记录中时&#…

event对象的offsetX、clientX、pageX、screenX及 window.innerWidth、outerWidth使用详解

目录 offset client screen page window.innerWidht offset offsetX、offsetY为当前鼠标点击位置距离当前元素参考原点(左上角)的距离,而不同浏览器参考原点的位置不尽相同,FF及Chrome中参考原点为内容区域左上角,不…