1. Background
This project reproduces the LIIF paper with PaddlePaddle. The encoder used in LIIF is RDN, so this article gives an introduction to RDN.
RDN paper: https://arxiv.org/abs/1802.08797
PyTorch code: https://github.com/yinboc/liif/blob/main/models/rdn.py
2. Components of RDN
RDN network architecture

2.1 Shallow Feature Extraction Net (SFENet)
These are simply the first two convolutional layers at the very beginning of the network.
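A minimal sketch of these two layers in Paddle, assuming the same G0 = 64 feature width and 3 x 3 kernels used by the implementation in Section 4.1 (the names mirror SFENet1/SFENet2 there):

import paddle.nn as nn

# Minimal sketch of the shallow feature extraction net (SFENet).
G0, kSize = 64, 3  # feature width and kernel size, matching Section 4.1
SFENet1 = nn.Conv2D(3, G0, kSize, padding=(kSize - 1) // 2)   # RGB input -> shallow feature F_{-1}
SFENet2 = nn.Conv2D(G0, G0, kSize, padding=(kSize - 1) // 2)  # F_{-1} -> F_0, which feeds the first RDB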
2.2 Residual Dense Block (RDB)
(residual dense block, RDB)

2.3 Residual Dense Blocks (RDBs)

An RDB extracts rich local features by densely connecting its convolutional layers.
An RDB also allows direct connections from the state of the preceding RDB to all layers of the current RDB, which forms the contiguous memory (CM) mechanism.
[Residual dense block (RDB) = densely connected layers + local feature fusion (LFF) + local residual learning], which together form the contiguous memory (CM) mechanism.
Contiguous memory (CM) mechanism
The output of the (d-1)-th RDB is fed directly into every layer of the d-th RDB (see the red lines in the dense part of the figure above). Through these dense connections, the features $F_{d-1}, F_{d,1}, \ldots, F_{d,c}, \ldots, F_{d,C}$ can all be exploited.
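As a minimal sketch (it mirrors the RDB_Conv layer defined later in Section 4.1), each densely connected layer simply concatenates its new output onto everything it received, so every earlier state keeps flowing forward:

import paddle
import paddle.nn as nn

class DenseConv(nn.Layer):  # illustrative name; the full version is RDB_Conv in Section 4.1
    def __init__(self, in_channels, grow_rate, kSize=3):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2D(in_channels, grow_rate, kSize, padding=(kSize - 1) // 2),
            nn.ReLU(),
        )

    def forward(self, x):
        # x already holds F_{d-1} plus all earlier states of this RDB;
        # the new grow_rate-channel output is appended along the channel axis.
        return paddle.concat([x, self.conv(x)], axis=1)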
Local feature fusion (LFF)
This is the concat inside the RDB: it fuses the output $F_{d-1}$ of the previous RDB with the states produced by every layer of the current RDB $F_d$ by concatenating them. A 1 x 1 convolution is then applied to the concatenated result to reduce the number of channels and simplify the data.
Local residual learning (LRL)
Since an RDB contains multiple convolutional layers, local residual learning is introduced to further improve the information flow.
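In the notation used above (following the RDN paper), the d-th RDB can thus be summarized as local feature fusion followed by the local residual:

$$F_{d,LF} = H_{LFF}^{d}\big([F_{d-1}, F_{d,1}, \ldots, F_{d,C}]\big), \qquad F_{d} = F_{d-1} + F_{d,LF}$$

where $[\cdot]$ denotes channel-wise concatenation and $H_{LFF}^{d}$ is the 1 x 1 convolution of the d-th RDB.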
2.4 Dense Feature Fusion (DFF)
After local dense features have been extracted by a series of RDBs, dense feature fusion (DFF) is further proposed to exploit hierarchical features from a global perspective. DFF consists of two parts: global feature fusion (GFF) and global residual learning (GRL).
Global feature fusion (GFF)

As shown in Figure 2 above, global feature fusion works as follows:
- the outputs of the RDBs ($F_1, \ldots, F_D$) are concatenated together;
- a 1 x 1 Conv layer then adaptively fuses these features from different levels;
- a 3 x 3 Conv layer further extracts features to obtain $F_{GF}$, which is used by the following global residual learning (GRL).
Global residual learning (GRL)
Global residual learning, as shown in Figure 2 above, adds the shallow feature map $F_{-1}$ produced by the first Conv layer element-wise to the $F_{GF}$ obtained from global feature fusion, giving $F_{DF}$ (a minimal sketch covering both GFF and GRL follows).
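A minimal sketch of DFF with illustrative stand-in tensors (D, G0, RDBs_out and f__1 are assumptions here for the sake of a self-contained example; the real versions live inside the RDN class in Section 4.1):

import paddle
import paddle.nn as nn

# Stand-ins: D RDB outputs with G0 channels each, plus the shallow feature f__1 (F_{-1}).
D, G0, kSize = 16, 64, 3
f__1 = paddle.randn([1, G0, 48, 48])
RDBs_out = [paddle.randn([1, G0, 48, 48]) for _ in range(D)]

GFF = nn.Sequential(
    nn.Conv2D(D * G0, G0, 1),                            # 1 x 1 conv: adaptively fuse the concatenated RDB outputs
    nn.Conv2D(G0, G0, kSize, padding=(kSize - 1) // 2),  # 3 x 3 conv: further extraction -> F_GF
)
F_GF = GFF(paddle.concat(RDBs_out, axis=1))  # global feature fusion
F_DF = F_GF + f__1                           # global residual learning: element-wise add with F_{-1}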
2.5 Up-Sampling Net (UPNet)
This is simply an up-sampling step followed by convolution, which finally outputs the HR result $I_{HR}$.
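A minimal sketch of one x2 sub-pixel up-sampling step (the UPNet in Section 4.1 is built the same way, chaining two PixelShuffle(2) stages for x4):

import paddle
import paddle.nn as nn

# Expand channels by r*r, rearrange them into an r-times larger spatial grid with
# PixelShuffle, then map back to 3 RGB channels.
r, G0, G, kSize = 2, 64, 64, 3
UPNet = nn.Sequential(
    nn.Conv2D(G0, G * r * r, kSize, padding=(kSize - 1) // 2),
    nn.PixelShuffle(r),
    nn.Conv2D(G, 3, kSize, padding=(kSize - 1) // 2),
)
x = paddle.randn([1, G0, 48, 48])
print(UPNet(x).shape)  # [1, 3, 96, 96]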
3. Dataset
3.1 Dataset introduction
DIV2K is a popular single-image super-resolution dataset. It contains 1000 images of different scenes, split into 800 for training, 100 for validation, and 100 for testing. It was collected for the NTIRE 2017 and NTIRE 2018 super-resolution challenges to encourage research on image super-resolution with more realistic degradations, and it provides low-resolution images with different types of degradation.
Official DIV2K dataset page: https://data.vision.ee.ethz.ch/cvl/DIV2K/
This project uses the DIV2K copy already available in the AI Studio public datasets: https://aistudio.baidu.com/aistudio/datasetdetail/104667
# Unzip the dataset
!unzip -qo /home/aistudio/data/data104667/DIV2K_train_HR.zip -d /home/aistudio/DIV2K
!unzip -qo /home/aistudio/data/data104667/DIV2K_valid_HR.zip -d /home/aistudio/DIV2K
# Import packages
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import time
import warnings
warnings.filterwarnings('ignore')
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Define constants
BATCHSIZE=16
SCALE = 4
PATCHSIZE = [48,48]
3.2 Writing the dataset reader
Read the image directory and process it into paired (LR, HR) samples.
# Define a batch reader over the dataset directory
def reader_patch(batchsize, scale=SCALE, patchsize=PATCHSIZE):
    dirsname = '/home/aistudio/DIV2K/DIV2K_train_HR/'
    dirs = os.listdir(dirsname)
    np.random.shuffle(dirs)
    LRs = np.zeros((batchsize,3,patchsize[0],patchsize[1])).astype("float32")
    HRs = np.zeros((batchsize,3,patchsize[0]*scale,patchsize[1]*scale)).astype("float32")
    for filename in dirs:
        image = Image.open(dirsname+filename)
        sz = image.size
        # Crop HR so its size is a multiple of patchsize * scale, with a random offset
        sz_row = sz[1]//(patchsize[0]*scale)*patchsize[0]*scale
        diff_row = sz[1] - sz_row
        sz_col = sz[0]//(patchsize[1]*scale)*patchsize[1]*scale
        diff_col = sz[0] - sz_col
        row_min = np.random.randint(diff_row+1)
        col_min = np.random.randint(diff_col+1)
        HR = image.crop((col_min,row_min,col_min+sz_col,row_min+sz_row))
        # Bicubic downsample to get LR, then normalize both to [-1, 1]
        LR = HR.resize((sz[0]//(patchsize[1]*scale)*patchsize[1],sz[1]//(patchsize[0]*scale)*patchsize[0]), Image.BICUBIC)
        LR = np.array(LR).astype("float32").transpose([2,0,1]) / 255 * 2 - 1
        HR = np.array(HR).astype("float32").transpose([2,0,1]) / 255 * 2 - 1
        # Sample `batchsize` aligned LR/HR patch pairs from this image
        for batch in range(batchsize):
            rowMin, colMin = np.random.randint(0,LR.shape[1]-patchsize[0]+1), np.random.randint(0,LR.shape[2]-patchsize[1]+1)
            LRs[batch,:,:,:] = LR[:,rowMin:rowMin+patchsize[0], colMin:colMin+patchsize[1]]
            HRs[batch,:,:,:] = HR[:,scale*rowMin:scale*(rowMin+patchsize[0]), scale*colMin:scale*(colMin+patchsize[1])]
        yield LRs, HRs
# Randomly inspect LR/HR pairs from the dataset
data = reader_patch(1)
for i in range(2):
    LR, HR = next(data)
    LR = LR.transpose([2,3,1,0]).reshape(PATCHSIZE[0],PATCHSIZE[1],3)
    LR = Image.fromarray(np.uint8((LR+1)/2*255))
    HR = HR.transpose([2,3,1,0]).reshape(PATCHSIZE[0]*SCALE,PATCHSIZE[1]*SCALE,3)
    HR = Image.fromarray(np.uint8((HR+1)/2*255))
    plt.subplot(1,2,1), plt.imshow(LR), plt.title('LRx'+str(SCALE))  # the factor by which HR was downsampled to produce LR
    plt.subplot(1,2,2), plt.imshow(HR), plt.title('HR')
    plt.show()


# Define the data-loading function
def load_data(mode='train', batchsize=BATCHSIZE, scale=SCALE, patchsize=PATCHSIZE):
    if mode=='train':
        dirsname = '/home/aistudio/DIV2K/DIV2K_train_HR/'
    elif mode=='valid':
        dirsname = '/home/aistudio/DIV2K/DIV2K_valid_HR/'
    dirs = os.listdir(dirsname)
    # Define the data generator
    def data_generator():
        # In training mode, shuffle the training data
        if mode == 'train':
            np.random.shuffle(dirs)
        LRs = np.zeros((batchsize,3,patchsize[0],patchsize[1])).astype("float32")
        HRs = np.zeros((batchsize,3,patchsize[0]*scale,patchsize[1]*scale)).astype("float32")
        for filename in dirs:
            # print(filename)
            image = Image.open(dirsname+filename)
            sz = image.size
            sz_row = sz[1]//(patchsize[0]*scale)*patchsize[0]*scale
            diff_row = sz[1] - sz_row
            sz_col = sz[0]//(patchsize[1]*scale)*patchsize[1]*scale
            diff_col = sz[0] - sz_col
            row_min = np.random.randint(diff_row+1)
            col_min = np.random.randint(diff_col+1)
            HR = image.crop((col_min,row_min,col_min+sz_col,row_min+sz_row))
            LR = HR.resize((sz[0]//(patchsize[1]*scale)*patchsize[1],sz[1]//(patchsize[0]*scale)*patchsize[0]), Image.BICUBIC)
            LR = np.array(LR).astype("float32").transpose([2,0,1]) / 255 * 2 - 1
            HR = np.array(HR).astype("float32").transpose([2,0,1]) / 255 * 2 - 1
            for batch in range(batchsize):
                rowMin, colMin = np.random.randint(0,LR.shape[1]-patchsize[0]+1), np.random.randint(0,LR.shape[2]-patchsize[1]+1)
                LRs[batch,:,:,:] = LR[:,rowMin:rowMin+patchsize[0], colMin:colMin+patchsize[1]]
                HRs[batch,:,:,:] = HR[:,scale*rowMin:scale*(rowMin+patchsize[0]), scale*colMin:scale*(colMin+patchsize[1])]
            yield LRs, HRs
    return data_generator
# Randomly read a pair and compare the low-resolution and high-resolution images
train_loader = load_data('train',1)
for i in range(2):
    LR, HR = next(train_loader())
    LR = LR.transpose([2,3,1,0]).reshape(PATCHSIZE[0],PATCHSIZE[1],3)
    LR = Image.fromarray(np.uint8((LR+1)/2*255))
    HR = HR.transpose([2,3,1,0]).reshape(PATCHSIZE[0]*SCALE,PATCHSIZE[1]*SCALE,3)
    HR = Image.fromarray(np.uint8((HR+1)/2*255))
    plt.subplot(1,2,1), plt.imshow(LR), plt.title('LRx'+str(SCALE))
    plt.subplot(1,2,2), plt.imshow(HR), plt.title('HR')
    plt.show()


4. Building the network
4.1 RDN network model
# RDB conv layer
class RDB_Conv(nn.Layer):
    def __init__(self, inChannels, growRate, kSize=3):
        super().__init__()
        Cin = inChannels
        G = growRate
        self.conv = nn.Sequential(*[
            nn.Conv2D(Cin, G, kSize, padding=(kSize-1)//2, stride=1),
            nn.ReLU()
        ])

    def forward(self, x):
        out = self.conv(x)
        return paddle.concat([x, out], 1)

# Residual dense block (RDB)
class RDB(nn.Layer):
    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
        super().__init__()
        G0 = growRate0
        G = growRate
        C = nConvLayers

        convs = []
        for c in range(C):
            convs.append(RDB_Conv(G0 + c*G, G))
        self.convs = nn.Sequential(*convs)

        # Local Feature Fusion
        self.LFF = nn.Conv2D(G0 + C*G, G0, 1, padding=0, stride=1)

    def forward(self, x):
        return self.LFF(self.convs(x)) + x
# Define the RDN network
class RDN(nn.Layer):
    def __init__(self):
        super(RDN, self).__init__()
        # self.args = args
        # r = args.scale[0]
        # G0 = args.G0
        # kSize = args.RDNkSize
        r = SCALE
        G0 = 64
        kSize = 3
        n_colors = 3
        self.no_upsampling = False

        # number of RDB blocks, conv layers, out channels
        self.D, C, G = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
        }['B']

        # Shallow feature extraction net
        self.SFENet1 = nn.Conv2D(n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)
        self.SFENet2 = nn.Conv2D(G0, G0, kSize, padding=(kSize-1)//2, stride=1)

        # Residual dense blocks and dense feature fusion
        self.RDBs = nn.LayerList()
        for i in range(self.D):
            self.RDBs.append(
                RDB(growRate0 = G0, growRate = G, nConvLayers = C)
            )

        # Global Feature Fusion
        self.GFF = nn.Sequential(*[
            nn.Conv2D(self.D * G0, G0, 1, padding=0, stride=1),
            nn.Conv2D(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
        ])

        if self.no_upsampling:
            self.out_dim = G0
        else:
            self.out_dim = n_colors
            # Up-sampling net
            if r == 2 or r == 3:
                self.UPNet = nn.Sequential(*[
                    nn.Conv2D(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),
                    nn.PixelShuffle(r),
                    nn.Conv2D(G, n_colors, kSize, padding=(kSize-1)//2, stride=1)
                ])
            elif r == 4:
                self.UPNet = nn.Sequential(*[
                    nn.Conv2D(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),
                    nn.PixelShuffle(2),
                    nn.Conv2D(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),
                    nn.PixelShuffle(2),
                    nn.Conv2D(G, n_colors, kSize, padding=(kSize-1)//2, stride=1)
                ])
            else:
                raise ValueError("scale must be 2 or 3 or 4.")

    def forward(self, x):
        f__1 = self.SFENet1(x)
        x = self.SFENet2(f__1)

        RDBs_out = []
        for i in range(self.D):
            x = self.RDBs[i](x)
            RDBs_out.append(x)

        x = self.GFF(paddle.concat(RDBs_out, 1))
        x += f__1

        if self.no_upsampling:
            return x
        else:
            return self.UPNet(x)
Define a function to display images: LR (low-resolution), HR (high-resolution), and SR (super-resolved).
# SR image display function
def show_image(G=None):
    if G==None:
        G = RDN()
    G.eval()
    dirsname = '/home/aistudio/DIV2K/DIV2K_train_HR/'
    dirs = os.listdir(dirsname)
    np.random.shuffle(dirs)
    fig = plt.figure(figsize=(25, 25))
    gs = plt.GridSpec(1, 3)
    gs.update(wspace=0.1, hspace=0.1)
    image = Image.open(dirsname+dirs[0])
    # image = image.crop([0,0,image.size[0]//SCALE*SCALE,image.size[1]//SCALE*SCALE])
    # image = image.crop([0,0,100,100])
    image = image.crop([0,0,200,200])
    LR0 = image.resize((image.size[0]//SCALE,image.size[1]//SCALE), Image.BICUBIC)
    LR = np.array(LR0).astype('float32').reshape([image.size[1]//SCALE,image.size[0]//SCALE,3,1]).transpose([3,2,0,1]) / 255 * 2 - 1
    LSR = G(paddle.to_tensor(LR)).numpy()
    print(np.max(LSR), np.min(LSR))
    LSR = LSR.reshape([3,image.size[1]//SCALE*SCALE,image.size[0]//SCALE*SCALE]).transpose([1,2,0])
    # LSR = Image.fromarray(np.uint8((LSR+1)/2*255))  ### culprit of the bright-spot artifacts
    LSR = (LSR+1)/2
    ax = plt.subplot(gs[0])
    plt.imshow(LR0)
    plt.title('LR')
    ax = plt.subplot(gs[1])
    plt.imshow(LSR)
    plt.title('SR')
    ax = plt.subplot(gs[2])
    plt.imshow(image)
    plt.title('HR')
    plt.show()
4.2 Model training
Define the data augmentation function.
# Data augmentation (random flips and 90-degree rotations)
def data_augmentation(LR, HR):
    if np.random.randint(2) == 1:
        LR = LR[:,:,:,::-1]
        HR = HR[:,:,:,::-1]
    n = np.random.randint(4)
    if n == 1:
        LR = LR.transpose([0,1,3,2])
        LR = LR[:,:,::-1,:]
        HR = HR.transpose([0,1,3,2])
        HR = HR[:,:,::-1,:]
    if n == 2:
        LR = LR[:,:,:,::-1]
        LR = LR[:,:,::-1,:]
        HR = HR[:,:,:,::-1]
        HR = HR[:,:,::-1,:]
    if n == 3:
        LR = LR.transpose([0,1,3,2])
        LR = LR[:,:,:,::-1]
        HR = HR.transpose([0,1,3,2])
        HR = HR[:,:,:,::-1]
    return LR, HR
Define the data loading and training functions.
from visualdl import LogWriter
log_writer = LogWriter(logdir="./output/RDN/log")

# Call the data-loading function
train_loader = load_data('train')
# LR, HR = next(train_loader())
# print(LR, HR)

def train(model, epoch_num=200, batchsize=1, load_model=False):
    model.train()
    optimizer = paddle.optimizer.Adam(learning_rate=1e-4, beta1=0.9, parameters=model.parameters())
    model_path = './output/RDN/'
    if load_model == True:
        model.set_state_dict(paddle.load(model_path+'Model.pdparams'))
    iteration_num = 0
    iters = []
    losses = []
    for epoch_id in range(epoch_num):
        for batch_id, data in enumerate(train_loader()):
            iteration_num += 1
            LR, HR = data
            LR, HR = data_augmentation(LR, HR)  # data augmentation
            LR = paddle.to_tensor(LR)
            HR = paddle.to_tensor(HR)
            y = model(LR)
            loss = paddle.mean(paddle.abs(y - HR))
            # Print the current loss every 20 iterations
            if(iteration_num % 20 == 0):
                datetime = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))
                print("{} epoch: {}, batch_id: {}, iter: {}, loss is: {}".format(datetime, epoch_id, batch_id, iteration_num, loss.numpy()))
                # Log the iteration count and the corresponding loss
                log_writer.add_scalar(tag = 'loss', step = iteration_num, value = loss.numpy())
            # Backward pass and parameter update
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
        # Save the model parameters
        paddle.save(model.state_dict(), model_path+'Model.pdparams')
        print('save model in {}'.format(model_path+'Model.pdparams'))
Start training.
# train(epoch_num=1, load_model=False, batchsize=12)  # first epoch: saves the model; remember to comment this out afterwards
# train(epoch_num=1000, load_model=True, batchsize=16)  # later runs can load the previous model and continue training

# Start the training process
model = RDN()
train(model=model,epoch_num=10, load_model=False, batchsize=16)
4.3 Model prediction
# Load the trained RDN model
GG = RDN()
GG.eval()
GG.set_state_dict(paddle.load('./Best.pdparams'))
# Display images: compare SR against LR and HR
show_image(GG)
0.6906135 -0.94815046

5. Conclusion
Once you understand RDN, you can move on to LIIF. See "Super-resolution model LIIF: upscaling by more than 30x".
If this helped you, please follow, like, and fork.
This article is a repost.
Original project link