从ResNet到ResNeSt:手把手教你用PyTorch复现Split-Attention注意力机制

张开发
2026/6/8 12:08:52 15 分钟阅读
从ResNet到ResNeSt:手把手教你用PyTorch复现Split-Attention注意力机制
从ResNet到ResNeSt手把手教你用PyTorch复现Split-Attention注意力机制在计算机视觉领域注意力机制已经成为提升模型性能的关键技术。ResNeSt作为ResNet的改进版本通过引入Split-Attention机制在保持ResNet简洁架构的同时显著提升了特征表达能力。本文将深入解析Split-Attention的实现细节带你从零开始用PyTorch实现这一创新模块。1. Split-Attention核心原理剖析Split-Attention的核心思想是将特征图在通道维度上进行多级分组并在不同组之间建立注意力交互。这种设计既保留了分组卷积的计算效率又通过注意力机制增强了特征表达能力。具体来说Split-Attention包含三个关键步骤基数分组(Cardinal Groups)将输入特征图划分为K个基数组径向划分(Radix Splits)在每个基数组内进一步划分为R个子组注意力融合基于全局上下文信息计算各子组的注意力权重这种双重分组结构可以用以下公式表示总分组数 基数(K) × 径向数(R)在PyTorch中我们可以通过group参数实现基数分组而径向划分则需要更精细的张量操作。下面是一个简单的分组示意图操作步骤输入形状输出形状说明基数分组(B,C,H,W)(B,K,C/K,H,W)沿通道维度分组径向划分(B,K,C/K,H,W)(B,K,R,C/(KR),H,W)每组内再划分注意力计算(B,K,R,C/(KR),H,W)(B,K,C/K,H,W)加权融合2. RadixSoftmax模块实现RadixSoftmax是Split-Attention的核心组件负责计算各子组的注意力权重。与常规Softmax不同它需要在特定维度上进行归一化。class RadixSoftmax(nn.Module): def __init__(self, radix, cardinality): super().__init__() self.radix radix # 每个基数组下的子组数 self.cardinality cardinality # 基数组数量 def forward(self, x): batch x.size(0) if self.radix 1: # 将输入重塑为(B, K, R, C/(KR))形式 x x.view(batch, self.cardinality, self.radix, -1) # 在径向维度(R)上计算Softmax x F.softmax(x, dim2) x x.reshape(batch, -1) else: x torch.sigmoid(x) return x这个实现有几个关键点当radix1时退化为Sigmoid相当于SE模块的注意力机制通过view和reshape操作实现张量的高效重组Softmax仅在径向维度计算保持基数组间的独立性3. SplitAttn模块完整实现基于RadixSoftmax我们可以构建完整的SplitAttn模块。以下是逐步实现过程class SplitAttn(nn.Module): def __init__(self, in_channels, out_channelsNone, kernel_size3, stride1, radix2, groups1, norm_layernn.BatchNorm2d): super().__init__() out_channels out_channels or in_channels self.radix radix # 中间通道数 输出通道 × radix mid_chs out_channels * radix # 注意力计算通道数 attn_chs max(in_channels * radix // 8, 32) # 主卷积路径 self.conv nn.Conv2d( in_channels, mid_chs, kernel_size, stridestride, paddingkernel_size//2, groupsgroups * radix, biasFalse) self.bn0 norm_layer(mid_chs) self.act0 nn.ReLU(inplaceTrue) # 注意力路径 self.fc1 nn.Conv2d(out_channels, attn_chs, 1, groupsgroups) self.bn1 norm_layer(attn_chs) self.act1 nn.ReLU(inplaceTrue) self.fc2 nn.Conv2d(attn_chs, mid_chs, 1, groupsgroups) self.rsoftmax RadixSoftmax(radix, groups)前向传播过程需要仔细处理张量形状变换def forward(self, x): x self.conv(x) x self.bn0(x) x self.act0(x) B, RC, H, W x.shape if self.radix 1: # 将特征图拆分为radix个子组 x x.view(B, self.radix, RC//self.radix, H, W) # 对各子组特征求和 x_gap x.sum(dim1) else: x_gap x # 计算全局平均池化 x_gap x_gap.mean([2,3], keepdimTrue) # 计算注意力权重 x_attn self.fc1(x_gap) x_attn self.bn1(x_attn) x_attn self.act1(x_attn) x_attn self.fc2(x_attn) x_attn self.rsoftmax(x_attn).view(B, -1, 1, 1) # 应用注意力权重 if self.radix 1: out (x * x_attn.view(B, self.radix, RC//self.radix, 1, 1)).sum(dim1) else: out x * x_attn return out.contiguous()4. ResNeSt Bottleneck集成将SplitAttn集成到ResNet的Bottleneck中形成完整的ResNeSt模块class ResNestBottleneck(nn.Module): expansion 4 def __init__(self, inplanes, planes, stride1, downsampleNone, radix2, cardinality1, base_width64): super().__init__() group_width int(planes * (base_width / 64.)) * cardinality self.conv1 nn.Conv2d(inplanes, group_width, kernel_size1, biasFalse) self.bn1 nn.BatchNorm2d(group_width) self.act1 nn.ReLU(inplaceTrue) self.conv2 SplitAttn( group_width, group_width, kernel_size3, stridestride, radixradix, groupscardinality) self.conv3 nn.Conv2d(group_width, planes * self.expansion, kernel_size1, biasFalse) self.bn3 nn.BatchNorm2d(planes * self.expansion) self.act3 nn.ReLU(inplaceTrue) self.downsample downsample前向传播保持了ResNet的经典残差结构def forward(self, x): identity x out self.conv1(x) out self.bn1(out) out self.act1(out) out self.conv2(out) out self.conv3(out) out self.bn3(out) if self.downsample is not None: identity self.downsample(x) out identity out self.act3(out) return out5. 实战技巧与性能优化在实际实现ResNeSt时有几个关键点需要注意基数与径向数的选择基数(Cardinality)通常设置为1或2径向数(Radix)常用值为2或4两者乘积不宜过大否则会显著增加计算量内存优化使用contiguous()确保张量内存布局连续合理设置groups参数利用分组卷积优化训练技巧学习率warmup有助于稳定训练标签平滑(Label Smoothing)可以提升泛化能力大型batch训练时需要调整BN参数以下是不同配置下的计算量对比模型Params(M)FLOPs(G)Top-1 Acc(%)ResNet-5025.54.176.2ResNeSt-50 (radix2)27.54.378.3ResNeSt-50 (radix4)30.14.779.1在实现过程中我发现最易出错的地方是张量形状变换。特别是在SplitAttn模块中需要确保分组卷积的groups参数正确设置为cardinality × radix注意力权重的计算与原始特征图维度匹配残差连接时的通道数对齐一个实用的调试技巧是添加shape检查断言assert x.shape (B,C,H,W), fExpected shape {(B,C,H,W)}, got {x.shape}通过PyTorch的灵活张量操作我们可以高效实现Split-Attention机制。相比原始论文的TensorFlow实现PyTorch版本通常能获得更好的运行时性能特别是在使用torch.jit.script优化后。

更多文章