UNet详解(附图文和代码实现)

article/2025/8/24 8:19:59

卷积神经网络被大规模的应用在分类任务中,输出的结果是整个图像的类标签。但是UNet是像素级分类,输出的则是每个像素点的类别,且不同类别的像素会显示不同颜色,UNet常常用在生物医学图像上,而该任务中图片数据往往较少。所以,Ciresan等人训练了一个卷积神经网络,用滑动窗口提供像素的周围区域(patch)作为输入来预测每个像素的类标签。这个网络有两个优点:(1)输出结果可以定位出目标类别的位置;(2)由于输入的训练数据是patches,这样就相当于进行了数据增强,从而解决了生物医学图像数量少的问题。

但是,采用该方法的神经网络也有两个很明显的缺点:(1)它很慢,因为这个网络必须训练每个patch,并且因为patch之间的重叠有很多冗余,这样会导致同样特征被多次训练,造成资源的浪费,导致训练时间的加长且效率也会有所降低,也有人会问神经网络经过多次训练这个特征后,会对这个特征的印象加深,从而准确率也会上升,但是举个例子一个图片复制50张,用这50张图片去训练网络,虽说数据集增大了,可是导致的后果是神经网络会出现过拟合,也就是说神经网络对训练图片很熟悉,可是换了一张图片,神经网络就有可能分辨不出来了。(2)定位准确性和获取上下文信息不可兼得,大的patches需要更多的max-pooling,这样会减少定位准确性,因为最大池化会丢失目标像素和周围像素之间的空间关系,而小patches只能看到很小的局部信息,包含的背景信息不够。

UNet主要贡献是在U型结构上,该结构可以使它使用更少的训练图片的同时,且分割的准确度也不会差,UNet的网络结构如下图:

在这里插入图片描述
(1)UNet采用全卷积神经网络。
(2)左边网络为特征提取网络:使用conv和pooling
(3)右边网络为特征融合网络:使用上采样产生的特征图与左侧特征图进行concatenate操作。(pooling层会丢失图像信息和降低图像分辨率且是永久性的,对于图像分割任务有一些影响,对图像分类任务的影响不大,为什么要做上采样呢?上采样可以让包含高级抽象特征低分辨率图片在保留高级抽象特征的同时变为高分辨率,然后再与左边低级表层特征高分辨率图片进行concatenate操作)
(4)最后再经过两次卷积操作,生成特征图,再用两个卷积核大小为1*1的卷积做分类得到最后的两张heatmap,例如第一张表示第一类的得分,第二张表示第二类的得分heatmap,然后作为softmax函数的输入,算出概率比较大的softmax,然后再进行loss,反向传播计算。

Unet模型的代码实现(基于keras):

def get_unet():inputs = Input((img_rows, img_cols, 1))conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)# pool1 = Dropout(0.25)(pool1)# pool1 = BatchNormalization()(pool1)conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)# pool2 = Dropout(0.5)(pool2)# pool2 = BatchNormalization()(pool2)conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)# pool3 = Dropout(0.5)(pool3)# pool3 = BatchNormalization()(pool3)conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)# pool4 = Dropout(0.5)(pool4)# pool4 = BatchNormalization()(pool4)conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)# up6 = Dropout(0.5)(up6)# up6 = BatchNormalization()(up6)conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)# up7 = Dropout(0.5)(up7)# up7 = BatchNormalization()(up7)conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)# up8 = Dropout(0.5)(up8)# up8 = BatchNormalization()(up8)conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)# up9 = Dropout(0.5)(up9)# up9 = BatchNormalization()(up9)conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)# conv9 = Dropout(0.5)(conv9)conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)model = Model(inputs=[inputs], outputs=[conv10])model.compile(optimizer=Adam(lr=1e-5),loss=dice_coef_loss, metrics=[dice_coef])return model

Unet的代码实现(pytorch版)

"""
这是根据UNet模型搭建出的一个基本网络结构
输入和输出大小是一样的,可以根据需求进行修改
"""
import torch
import torch.nn as nn
from torch.nn import functional as F# 基本卷积块
class Conv(nn.Module):def __init__(self, C_in, C_out):super(Conv, self).__init__()self.layer = nn.Sequential(nn.Conv2d(C_in, C_out, 3, 1, 1),nn.BatchNorm2d(C_out),# 防止过拟合nn.Dropout(0.3),nn.LeakyReLU(),nn.Conv2d(C_out, C_out, 3, 1, 1),nn.BatchNorm2d(C_out),# 防止过拟合nn.Dropout(0.4),nn.LeakyReLU(),)def forward(self, x):return self.layer(x)# 下采样模块
class DownSampling(nn.Module):def __init__(self, C):super(DownSampling, self).__init__()self.Down = nn.Sequential(# 使用卷积进行2倍的下采样,通道数不变nn.Conv2d(C, C, 3, 2, 1),nn.LeakyReLU())def forward(self, x):return self.Down(x)# 上采样模块
class UpSampling(nn.Module):def __init__(self, C):super(UpSampling, self).__init__()# 特征图大小扩大2倍,通道数减半self.Up = nn.Conv2d(C, C // 2, 1, 1)def forward(self, x, r):# 使用邻近插值进行下采样up = F.interpolate(x, scale_factor=2, mode="nearest")x = self.Up(up)# 拼接,当前上采样的,和之前下采样过程中的return torch.cat((x, r), 1)# 主干网络
class UNet(nn.Module):def __init__(self):super(UNet, self).__init__()# 4次下采样self.C1 = Conv(3, 64)self.D1 = DownSampling(64)self.C2 = Conv(64, 128)self.D2 = DownSampling(128)self.C3 = Conv(128, 256)self.D3 = DownSampling(256)self.C4 = Conv(256, 512)self.D4 = DownSampling(512)self.C5 = Conv(512, 1024)# 4次上采样self.U1 = UpSampling(1024)self.C6 = Conv(1024, 512)self.U2 = UpSampling(512)self.C7 = Conv(512, 256)self.U3 = UpSampling(256)self.C8 = Conv(256, 128)self.U4 = UpSampling(128)self.C9 = Conv(128, 64)self.Th = torch.nn.Sigmoid()self.pred = torch.nn.Conv2d(64, 3, 3, 1, 1)def forward(self, x):# 下采样部分R1 = self.C1(x)R2 = self.C2(self.D1(R1))R3 = self.C3(self.D2(R2))R4 = self.C4(self.D3(R3))Y1 = self.C5(self.D4(R4))# 上采样部分# 上采样的时候需要拼接起来O1 = self.C6(self.U1(Y1, R4))O2 = self.C7(self.U2(O1, R3))O3 = self.C8(self.U3(O2, R2))O4 = self.C9(self.U4(O3, R1))# 输出预测,这里大小跟输入是一致的# 可以把下采样时的中间抠出来再进行拼接,这样修改后输出就会更小return self.Th(self.pred(O4))if __name__ == '__main__':a = torch.randn(2, 3, 256, 256)net = UNet()print(net(a).shape)

Unet的代码实现(TensorFlow版)

# -*-coding: utf-8 -*-
import tensorflow as tf
import tensorflow.contrib.slim as slimdef lrelu(x):return tf.maximum(x * 0.2, x)activation_fn=lreludef UNet(inputs, reg):  # Unetconv1 = slim.conv2d(inputs, 32, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv1_1', weights_regularizer=reg)conv1 = slim.conv2d(conv1, 32, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv1_2',weights_regularizer=reg)pool1 = slim.max_pool2d(conv1, [2, 2], padding='SAME')conv2 = slim.conv2d(pool1, 64, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv2_1',weights_regularizer=reg)conv2 = slim.conv2d(conv2, 64, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv2_2',weights_regularizer=reg)pool2 = slim.max_pool2d(conv2, [2, 2], padding='SAME')conv3 = slim.conv2d(pool2, 128, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv3_1',weights_regularizer=reg)conv3 = slim.conv2d(conv3, 128, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv3_2',weights_regularizer=reg)pool3 = slim.max_pool2d(conv3, [2, 2], padding='SAME')conv4 = slim.conv2d(pool3, 256, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv4_1',weights_regularizer=reg)conv4 = slim.conv2d(conv4, 256, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv4_2',weights_regularizer=reg)pool4 = slim.max_pool2d(conv4, [2, 2], padding='SAME')conv5 = slim.conv2d(pool4, 512, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv5_1',weights_regularizer=reg)conv5 = slim.conv2d(conv5, 512, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv5_2',weights_regularizer=reg)up6 = upsample_and_concat(conv5, conv4, 256, 512)conv6 = slim.conv2d(up6, 256, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv6_1',weights_regularizer=reg)conv6 = slim.conv2d(conv6, 256, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv6_2',weights_regularizer=reg)up7 = upsample_and_concat(conv6, conv3, 128, 256)conv7 = slim.conv2d(up7, 128, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv7_1',weights_regularizer=reg)conv7 = slim.conv2d(conv7, 128, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv7_2',weights_regularizer=reg)up8 = upsample_and_concat(conv7, conv2, 64, 128)conv8 = slim.conv2d(up8, 64, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv8_1',weights_regularizer=reg)conv8 = slim.conv2d(conv8, 64, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv8_2',weights_regularizer=reg)up9 = upsample_and_concat(conv8, conv1, 32, 64)conv9 = slim.conv2d(up9, 32, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv9_1', weights_regularizer=reg)conv9 = slim.conv2d(conv9, 32, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv9_2',weights_regularizer=reg)print("conv9.shape:{}".format(conv9.get_shape()))type='UNet_1X'with tf.variable_scope(name_or_scope="output"):if type=='UNet_3X':#UNet放大三倍conv10 = slim.conv2d(conv9, 27, [1, 1], rate=1, activation_fn=None, scope='g_conv10',weights_regularizer=reg)out = tf.depth_to_space(conv10, 3)if type=='UNet_1X':#输入输出维度相同out = slim.conv2d(conv9, 6, [1, 1], rate=1, activation_fn=None, scope='g_conv10',weights_regularizer=reg)return outdef upsample_and_concat(x1, x2, output_channels, in_channels):pool_size = 2deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2), strides=[1, pool_size, pool_size, 1])deconv_output = tf.concat([deconv, x2], 3)deconv_output.set_shape([None, None, None, output_channels * 2])return deconv_outputif __name__=="__main__":weight_decay=0.001reg = slim.l2_regularizer(scale=weight_decay)inputs = tf.ones(shape=[4, 256, 256, 3])out=UNet(inputs,reg)print("net1.shape:{}".format(inputs.get_shape()))print("out.shape:{}".format(out.get_shape()))with tf.Session() as sess:sess.run(tf.global_variables_initializer())

http://chatgpt.dhexx.cn/article/xp8qbObR.shtml

相关文章

条纹进度条

最开始学习qml的时候,想实现一个条纹进度条,当时还不熟悉动画,做不出来,只做了个静止的。qml学习和使用了快1年之际,把这个遗憾弥补下。 先上效果图,颜色取自Bootflat。 以下是源码,先上Stripe…

数据批量插入与逐条插入分析

简述 今天抽空做了一下使用Jdbc对数据操作的实际性能。在平时开发过程中我们经常会使用Hibernate来操作数据库,所以我们很少会去使用批量插入数据。一般都是通过hibernate的insert、addSave等方法来一条条地插入数据。所以很少去考虑这个问题。下面是针对Jdbc进行的…

css滚动条

此部分针对webkit内核的浏览器,使用伪类来改变滚动条的默认样式,详情如下: 滚动条组成部分 1. ::-webkit-scrollbar 滚动条整体部分2. ::-webkit-scrollbar-thumb 滚动条里面的小方块,能向上向下移动(或向左向右移动…

计算机组成原理(五)-一条指令是怎么被执行的

什么是指令: 程序代码的本质就是一条一条的指令,我们需要通过编码的方式让CPU知道我们需要它干什么,最后由译码器翻译成一条条的机器指令。机器指令主要有两部分组成:操作码、地址码。地址码直接给出操作数和操作数的地址&#x…

CSS 斜条纹进度条动画

这是第一版进度条 ,用css写的.但是后续因为数据不同,要显示不同的颜色和数据,所以又改了一版,直接用的el-progress.自定义的样式.对于新手小白来说比较友好.先上这一版代码. <div class"state"><span>开机时间</span><!-- 进度条 --><div…

Acrobat DC 更改背景颜色会有一条条白色横纹

解决方法如下&#xff1a; 编辑->首选项->页面显示->取消 使用2D图形加速

turtle模块还能这样玩?(一条条金龙鱼、雨景)

文章目录 一条条金龙鱼雨景 Python的turtle模块不仅可以用来绘制一些基本的图形&#xff0c;还有与图片结合&#xff0c;做出一些特殊的效果&#xff0c;还可以用来做二维小游戏。本篇是介绍用turtle模块做出的一幅动态的鱼儿游过的画面和动态的雨景图 一条条金龙鱼 1、先看一…

2.Python # 代码注释

2. # 代码注释 文章目录 2. # 代码注释1. 什么是代码注释2. 注释语法3. 注释位置1. 注释在代码的上一行2. 注释在代码的末端 4. 课堂练习 1. 什么是代码注释 代码注释即对代码进行批注说明。 相当于给一个英文单词批注中文释义。 【温馨提示】注释是给程序员自己看的&#xf…

python:导入第三方库greenlet,gevent方法

greenlet&#xff0c;gevent greenlet&#xff0c;gevent是python支持的第三方库&#xff0c;它们可以帮助我们完成协程的使用&#xff0c;其中greenlet是手动调换方式&#xff08;switch方法&#xff09;&#xff0c;gevent是自动调换方式&#xff08;遇到IO操作&#xff09;…

python gevent使用

对大部分语言来说&#xff0c;经常用到并发来处理一些情况。比如必须要多次查询数据库&#xff0c;多次请求API&#xff0c;python内置的gevent就很简单好用。传参&#xff0c;获取返回值&#xff0c;捕获协程的错误都很方便。 直接上例子&#xff1a; import gevent as gevent…

指定Geany使用的Python版本

本文介绍&#xff1a; 在win7下配置Geany,使其使用 Python 3 因为电脑上安装了不同版本的Python&#xff0c;需要根据实际情况来进行版本切换。 第一步&#xff1a;首先点击"生成"按钮的三角箭头&#xff0c;再点击"设置生成命令" 第二步&#xff1a;在弹出…

ModuleNotFountError:No module named ‘gensim‘(在python代码中导入gensim模块)

运行窗口&#xff1a; conda install 模块 pip install 模块 第一个命令应该由于网速太慢没有下载完全&#xff0c;第二个命令我手动输入n退出了。 网速慢&#xff0c;可以使用如下命令&#xff1a; pip install -i https://pypi.douban.com/simple gensim pip install -i…

Python学习笔记--图例 legend

Python学习笔记--图例 legend 参靠视频:《Python数据可视化分析 matplotlib教程》链接&#xff1a;https://www.bilibili.com/video/av6989413/?p6 所用的库及环境: IDE:Pycharm Python环境&#xff1a;python3.7 Matplotlib: Matplotlib 1.11 Numpy&#xff1a; Numpy1.1…

【pybind11笔记】eigen与numpy数据交互

系列文章 【pybind11笔记】eigen与numpy数据交互 【pybind11笔记】python调用c函数 【pybind11笔记】python调用c结构体 【pybind11笔记】python调用c类 文件结构 为了方便演示&#xff0c;我们使用cmake构建该样例&#xff0c;文件结构如下&#xff1a; pybind11与eigen…

python学习笔记:问题一,Geany编辑器无法使用中文注释

python学习笔记&#xff1a; 问题一&#xff1a; Geany编辑器无法使用中文注释 Geany编译python时运行弹出SyntaxError: (unicode error) ‘utf-8’ codec can’t提升&#xff0c;文本编辑器Geany无法使用中文注释&#xff0c;可以设置一下文本编码格式就好了设置方法为&…

Python--注释

Python--注释 <font size4, colorblue> 一、Python中注释的形式<font size4, colorblue> 1、单行注释&#xff1a;使用“#”符号注释<font size4, colorblue> 2、多行注释&#xff1a;使用一对三个英文单引号注释<font size4, colorblue> 3、多行注释&…

python中generate什么意思_python generate怎么用

generate语句允许细化时间(Elaboration-time)的选取或者某些语句的重复。这些语句可以包括模块实例引用的语句、连续赋值语句、always语句、initial语句和门级实例引用语句等。细化时间是指仿真开始前的一个阶段&#xff0c;此时所有的设计模块已经被链接到一起&#xff0c;并完…

Python Gevent

参考资料 http://www.gevent.org/contents.htmlhttps://uwsgi-docs-zh.readthedocs.io/zh_CN/latest/Gevent.html Python脚本的执行效率一直来说并不是很高&#xff0c;特别是Python下的多线程机制&#xff0c;长久以来一直被人们诟病。很多人都在思考如何让Python执行的更快…

符号回归工具之 geppy: Python中的基因表达编程框架

符号回归工具之 geppy&#xff1a; Python中的基因表达编程框架 geppy是一个专门用于基因表达编程&#xff08;GEP&#xff09;的计算框架&#xff0c;由 C. Ferreira 在 2001 年提出 [1]。 geppy是在 Python 3 中开发的。这个框架个人认为稍微了解下遗传算法和遗传规划即可入…

如何在Geany中添加python的中文注释

在Geany中编译Python中直接添加中文注释会出现如下错误 只需要在程序的开始位置添加一句&#xff1a;# coding:utf-8