Gated-SCNN: Gated Shape CNNs for Semantic Segmentation

article/2025/10/2 6:38:04

目录

作者

一、Model of Gated-SCNN

二、 Gated Shape CNN

1.Regular Stream

2.Shape Stream

3. Gate Conv Layer

4.ASPP

5 总代码

三 损失函数

1.BoundaryBCELoss

2.DualTaskLoss



作者

一、Model of Gated-SCNN

文章使用了双流CNN来处理语义分割中的边界问题,分为Regular streamShape stream.
作者认为,在编码器提取的很多特征中,例如纹理,色彩,梯度等很多细节信息,随着卷积的不断加深,会干扰定位信息,因为包含了很多与识别无关的信息,比如整体图中的背景信息,树的细节信息对于网络分割而言,属于噪声,但是轮廓这种信息又不能完全舍去(汽车轮廓).
综上,作者将形状信息作为一个单独的分支,目的是只提取对应的轮廓信息,对应图中的shape stream

二、 Gated Shape CNN

1.Regular Stream

作者使用主流的 Resnet-101 and WideResNet 作为本文的Regular stream,对输入的2D图像提取了分辨率不同的5种特征图。我们把这五种特征图记作 c1,c2,c3,c4,c5. 与此同时,在输入图像时,提取了 图像的梯度 image gradients (channel==1). 记为 grad

2.Shape Stream

1.  在shape stream 中,使用了c1,c3,c4,c5以及grad 作为形状流的输入。
2.  c3,c4,c5 先经过1x1卷积降为残差结构(channel==1)对应代码中self.res()_conv.  channel of c1 保持不变(64),之后将所有特征图上采样到原图大小(size of grad).
3.  在2的基础上,shape stream 会通过3个residual block,将通道数再次降低一半。
4.  在通过residual block之后,会分别与c3,c4,c5进入Gate Conv Layer 得到最终通道数为8的特征  图gate3,再次将gate3降维至1并转化为权重分数表示得到gate。
5.  gate 与 grad 以第一维度(channel)拼接融合,形成新的权重。作者也对gate进行了边界损失,防止边界预测错误。最终 shape stream 输出的feat 作为针对形状的预测与regular stream 特征进行融合(加强边界信息)。

因为edge bce loss的原因,会限制其他细节的得分,例如,色彩,斑点,纹理,以及一些小的梯度,都会被bce loss在反向传播的过程中逐层减弱。 这也是形状流只关注形状的主要原因。

class ShapeStream(nn.Module):def __init__(self):super().__init__()self.res2_conv = nn.Conv2d(512, 1, 1)self.res3_conv = nn.Conv2d(1024, 1, 1)self.res4_conv = nn.Conv2d(2048, 1, 1)self.res1 = BasicBlock(64, 64, 1)self.res2 = BasicBlock(32, 32, 1)self.res3 = BasicBlock(16, 16, 1)self.res1_pre = nn.Conv2d(64, 32, 1)self.res2_pre = nn.Conv2d(32, 16, 1)self.res3_pre = nn.Conv2d(16, 8, 1)self.gate1 = GatedConv(32, 32)self.gate2 = GatedConv(16, 16)self.gate3 = GatedConv(8, 8)self.gate = nn.Conv2d(8, 1, 1, bias=False)self.fuse = nn.Conv2d(2, 1, 1, bias=False)def forward(self, c1, c2, c3, c4, grad):size = grad.size()[-2:]c1 = F.interpolate(c1, size, mode='bilinear', align_corners=True)c2 = F.interpolate(self.res2_conv(c2), size, mode='bilinear', align_corners=True)c3 = F.interpolate(self.res3_conv(c3), size, mode='bilinear', align_corners=True)c4 = F.interpolate(self.res4_conv(c4), size, mode='bilinear', align_corners=True)gate1 = self.gate1(self.res1_pre(self.res1(c1)), c2)gate2 = self.gate2(self.res2_pre(self.res2(gate1)), c3)gate3 = self.gate3(self.res3_pre(self.res3(gate2)), c4)gate = torch.sigmoid(self.gate(gate3))feat = torch.sigmoid(self.fuse(torch.cat((gate, grad), dim=1)))return gate, feat

3. Gate Conv Layer

需要注意的是 每次进入GCL层的 feat的channel为 32,16,8.  gate channel ==1

class GatedConv(nn.Conv2d):def __init__(self, in_channels, out_channels):super().__init__(in_channels, out_channels, 1, bias=False)self.attention = nn.Sequential(nn.BatchNorm2d(in_channels + 1),nn.Conv2d(in_channels + 1, in_channels + 1, 1),nn.ReLU(),nn.Conv2d(in_channels + 1, 1, 1),nn.BatchNorm2d(1),nn.Sigmoid())def forward(self, feat, gate):attention = self.attention(torch.cat((feat, gate), dim=1))out = F.conv2d(feat * (attention + 1), 1, out_channels)return out

4.ASPP

最终ASPP的输入是c1 c4 以及 feat,rate of dilation are 6,12,and 18.

class FeatureFusion(ASPP):def __init__(self, in_channels, atrous_rates=(6, 12, 18), out_channels=256):# atrous_rates (6, 12, 18) is for stride 16super().__init__(in_channels, atrous_rates, out_channels)self.shape_conv = nn.Sequential(nn.Conv2d(1, out_channels, 1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU())self.project = nn.Conv2d((len(atrous_rates) + 3) * out_channels, out_channels, 1, bias=False)self.fine = nn.Conv2d(256, 48, kernel_size=1, bias=False)def forward(self, c1, c4, feat):res = []for conv in self.convs:res.append(conv(c4))res = torch.cat(res, dim=1)feat = F.interpolate(feat, res.size()[-2:], mode='bilinear', align_corners=True)res = torch.cat((res, self.shape_conv(feat)), dim=1)coarse = F.interpolate(self.project(res), c1.size()[-2:], mode='bilinear', align_corners=True)fine = self.fine(c1)out = torch.cat((coarse, fine), dim=1)return out

5 总代码

class GatedSCNN(nn.Module):def __init__(self, backbone_type='resnet50', num_classes=19):super().__init__()self.regular_stream = RegularStream(backbone_type)self.shape_stream = ShapeStream()self.feature_fusion = FeatureFusion(2048, (12, 24, 36), 256)self.seg = nn.Sequential(nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Conv2d(256, num_classes, kernel_size=1, bias=False))def forward(self, x, grad):x, res1, res2, res3, res4 = self.regular_stream(x)gate, feat = self.shape_stream(x, res2, res3, res4, grad)out = self.feature_fusion(res1, res4, feat)seg = F.interpolate(self.seg(out), grad.size()[-2:], mode='bilinear', align_corners=False)# [B, N, H, W], [B, 1, H, W]return seg, gate

三 损失函数

1.BoundaryBCELoss

torch.clamp 的作用是将输入的tensor 缩放到最小值和最大值之间
edge是网络预测的输出,而boundary 是针对 GT 的边界

class BoundaryBCELoss(nn.Module):def __init__(self, ignore_index=255):super().__init__()self.ignore_index = ignore_indexdef forward(self, edge, target, boundary):edge = edge.squeeze(dim=1)mask = target != self.ignore_indexpos_mask = (boundary == 1.0) & maskneg_mask = (boundary == 0.0) & masknum = torch.clamp(mask.sum(), min=1)pos_weight = neg_mask.sum() / numneg_weight = pos_mask.sum() / numweight = torch.zeros_like(boundary)weight[pos_mask] = pos_weightweight[neg_mask] = neg_weightloss = F.binary_cross_entropy(edge, boundary, weight, reduction='sum') / numreturn loss

2.DualTaskLoss

threshold的作用是将太细的梯度过滤

class DualTaskLoss(nn.Module):def __init__(self, threshold=0.8, ignore_index=255):super().__init__()self.threshold = thresholdself.ignore_index = ignore_indexdef forward(self, seg, edge, target):edge = edge.squeeze(dim=1)logit = F.cross_entropy(seg, target, ignore_index=self.ignore_index, reduction='none')mask = target != self.ignore_indexnum = torch.clamp(((edge > self.threshold) & mask).sum(), min=1)loss = (logit[edge > self.threshold].sum()) / numreturn loss


http://chatgpt.dhexx.cn/article/hCDFXGOc.shtml

相关文章

(整理)吊炸天的CNNs,这是我见过最详尽的图解!(下)

之前在CSDN上看到这篇文章,觉得通俗易懂,写的非常好。不过近来再次查看,发现文章的照片莫名其妙的没有了,没有图就根本看不懂了。找到了之前关注的微信公众号:AI传送门 。 在里面找到了这篇文章,决定再把这…

CNNs详尽图解

已经成为每一个初入人工智能——特别是图像识别领域的朋友,都渴望探究的秘密。 本文通过“算法可视化”的方法,将卷积神经网络的原理,呈献给大家。教程分为上、下两个部分,通篇长度不超过7000字,没有复杂的数学公式&a…

学习笔记:利用CNNs进行图像分类

1.神经网络图像(CNNs)分类简介 本文将重点关注卷积神经网络,也被称为CNNs或Convnets。CNNs是一种特殊类型的神经网络,特别适合于图像数据。自2012年以来,ImageNet竞赛(ImageNet)一直由CNN架构赢得。 在本文中&#x…

(整理)吊炸天的CNNs,这是我见过最详尽的图解!(上)

之前在CSDN上看到这篇文章,觉得通俗易懂,写的非常好。不过近来再次查看,发现文章的照片莫名其妙的没有了,没有图就根本看不懂了。找到了之前关注的微信公众号:AI传送门 。 在里面找到了这篇文章,决定再把这…

交叉验证

概述Holdout 交叉验证K-Fold 交叉验证Leave-P-Out 交叉验证总结 概述 交叉验证是在机器学习建立模型和验证模型参数时常用的办法。 顾名思义,就是重复的使用数据,把得到的样本数据进行切分,组合为不同的训练集和测试集。 用训练集来训练模型&…

交叉验证评估模型性能

在构建一个机器学习模型之后,我们需要对模型的性能进行评估。如果一个模型过于简单,就会导致欠拟合(高偏差)问题,如果模型过于复杂,就会导致过拟合(高方差)问题。为了使模型能够在欠拟合和过拟合之间找到一个折中方案,我们需要对模型进行评估,后面将会介绍holdout交叉…

Python实现:Hold-Out、k折交叉验证、分层k折交叉验证、留一交叉验证

模型在统计中是极其重要的,可以通过模型来描述数据集的内在关系,了解数据的内在关系有助于对未来进行预测。一个模型可以通过设置不同的参数来描述不同的数据集,有的参数需要根据数据集估计,有的参数需要人为设定(超参…

深度理解hold-out Method(留出法)和K-fold Cross-Validation(k折交叉验证法)

模型评估(Model Evaluation) 1.测试集(testing set) 测试集(testing set): 通常,我们可通过实验测验来对学习器的泛化误差进行评估并进而做出选择,为此,需要一个“测试集”来测试学习器对新样本的判别能力。然后以测试集上的“测…

cross-validation:从 holdout validation 到 k-fold validation

构建机器学习模型的一个重要环节是评价模型在新的数据集上的性能。模型过于简单时,容易发生欠拟合(high bias);模型过于复杂时,又容易发生过拟合(high variance)。为了达到一个合理的 bias-vari…

《The reusable holdout: Preserving validity in adaptive data analysis》中文翻译

写在前面:这是我看到的第一篇发在《science》上的文章,将近年来比较火的差分隐私用在解决过机器学习中的过拟合上,效果很棒。这是15年的文章,现在已经17年了,网上居然没有中文翻译,我就粗略的翻译一下给后来…

机器学习模型评测:holdout cross-validation k-fold cross-validation

cross-validation:从 holdout validation 到 k-fold validation 2016年01月15日 11:06:00 Inside_Zhang 阅读数:4445 版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/lanchunhui/article/details/5…

三种模型验证方法:holdout, K-fold, leave one out cross validation(LOOCV)

Cross Validation: A Beginner’s Guide An introduction to LOO, K-Fold, and Holdout model validation By: Caleb Neale, Demetri Workman, Abhinay Dommalapati 源自:https://towardsdatascience.com/cross-validation-a-beginners-guide-5b8ca04962cd 文章目录…

模型检验方法:holdout、k-fold、bootstrap

参考:https://www.cnblogs.com/chay/articles/10745417.html https://www.cnblogs.com/xiaosongshine/p/10557891.html 1.Holdout检验 Holdout 检验是最简单也是最直接的验证方法, 它将原始的样本集合随机划分成训练集和验证集两部分。 比方说&#x…

多种方式Map集合遍历

1.如何遍历Map中的key-value对,代码实现(至少2种) Map集合的遍历(方式1)键找值: package com.B.Container_13.Map;import java.util.HashMap; import java.util.Map; import java.util.Set;//Map集合的遍历(方式1)键找值 public class Map04_01 {publi…

Map集合中的四种遍历方式

1.Map接口的概述 &#xff08;1&#xff09;它是双列集合&#xff1b; &#xff08;2&#xff09;格式&#xff1a;Interface Map<k,v> K:键的类型 V&#xff1a;值得类型 &#xff08;3&#xff09;它的每个元素都包含一个键对象Key和值对象Value&#xff0c;并且他们…

Java中的Map集合以及Map集合遍历实例

文章目录 一、Map集合二、Map集合遍历实例 一、Map集合 Map<K,V>k是键&#xff0c;v是值 1、 将键映射到值的对象&#xff0c;一个映射不能包含重复的键&#xff0c;每个键最多只能映射的一个值 2、 实现类  a) HashMap  b) TreeMap 3、 Map集合和Collection集合的区别…

Map集合的四种遍历

Map集合的四种遍历 这里记录一下map集合的4种遍历&#xff1a; 第一种 得到所有的key–map.keySet() ,根据key拿到value–map.get(key) public static void main(String[] args) {Map<String, String> map new HashMap();map.put("1", "刘备");…

Map集合遍历的三种方式

Map集合遍历的三种方式 遍历Map集合的三种方式 键找值键值对Lambda表达式 方式一 : 键找值 先获取Map集合的全部键的Set集合遍历键的Set集合,然后通过键提取对应值 原理图 键找值涉及到的API 方法名称说明Set keySet()获取所有键的集合V get(Object key)根据键获取值 Map…

java中Map集合的四种遍历方式

java中Map集合的四种遍历方式 Map接口和Collection接口的集合不同,Map集合是双列的,Collection是单列的.Map集合将键映射到值的对象. 双列的集合遍历起来也是比较麻烦些的,特别是嵌套的map集合,这里说下MAP集合的四种遍历方式&#xff0c;并且以嵌套的hashMap集合为例, 遍历一…