# Date: 28 February 2023
# Paper: "YOLOv4: Optimal Speed and Accuracy of Object Detection"
class CBM(nn.Module):
    """Conv2d -> BatchNorm2d -> Mish building block ("CBM").

    The convolution is created without a bias term (the BatchNorm that follows
    provides the learned shift) and is zero-padded by ``kernel_size // 2``.
    For an odd ``kernel_size`` this gives an output spatial size of
    ``ceil(H / stride)``; e.g. stride 1 keeps the size, stride 2 halves it.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size; assumed to be an int because it
            is integer-divided to derive the padding.
        stride: convolution stride (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()
        pad = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            pad,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        # NOTE(review): `Mish` is defined elsewhere in the file; presumably the
        # Mish activation x * tanh(softplus(x)) — confirm.
        self.act = Mish()

    def forward(self, x):
        """Run the input through conv, batch norm and the activation, in that order."""
        normed = self.bn(self.conv(x))
        return self.act(normed)
class ResUnit(nn.Module):
    """Residual unit: ``x + CBM_3x3(CBM_1x1(x))``.

    A 1x1 CBM maps ``channels -> hidden_channels`` and a 3x3 CBM maps back to
    ``channels``. The result is added to the input through an identity shortcut.

    Args:
        channels: channel count of both the input and the output. Input and
            output must match so that the residual addition is shape-compatible.
        hidden_channels: width of the intermediate 1x1 projection; defaults to
            ``channels``.
    """

    def __init__(self, channels, hidden_channels=None):
        super(ResUnit, self).__init__()
        if hidden_channels is None:
            hidden_channels = channels
        self.block = nn.Sequential(
            CBM(channels, hidden_channels, 1),
            CBM(hidden_channels, channels, 3),
        )

    def forward(self, x):
        # Both CBMs use the default stride 1 with padding kernel_size // 2, so
        # the block preserves the spatial size and the identity add lines up.
        return x + self.block(x)
class CSPX(nn.Module):
    """Cross Stage Partial (CSP) stage used by the CSPDarknet backbone.

    Pipeline:

    1. A 3x3 stride-2 CBM (``downsample_conv``) maps ``in_channels -> out_channels``.
    2. Two parallel 1x1 CBMs split the result into a shortcut path
       (``split_conv0``) and a main path (``split_conv1``).
    3. The main path runs through ``num_blocks`` ``ResUnit``s followed by a
       1x1 CBM (``blocks_conv``).
    4. The paths are concatenated along dim 1 (main path first) and fused by a
       1x1 CBM (``concat_conv``) back to ``out_channels``.

    Args:
        in_channels: channels of the stage input.
        out_channels: channels of the stage output.
        num_blocks: number of ``ResUnit``s on the main path. This is now honored
            for the first stage too; previously that stage always used exactly
            one unit, which equals the canonical ``layers[0] == 1``.
        first: if True, use the first-stage layout. Both paths keep the full
            ``out_channels`` width, and each ``ResUnit`` uses an
            ``out_channels // 2`` hidden width. Otherwise each path is
            ``out_channels // 2`` wide.
    """

    def __init__(self, in_channels, out_channels, num_blocks, first=False):
        super(CSPX, self).__init__()
        self.downsample_conv = CBM(in_channels, out_channels, 3, stride=2)

        # Width of each of the two CSP paths, and the ResUnit bottleneck width
        # (None lets ResUnit default it to the path width).
        path_channels = out_channels if first else out_channels // 2
        hidden_channels = out_channels // 2 if first else None

        self.split_conv0 = CBM(out_channels, path_channels, 1)
        self.split_conv1 = CBM(out_channels, path_channels, 1)
        self.blocks_conv = nn.Sequential(
            *[
                ResUnit(channels=path_channels, hidden_channels=hidden_channels)
                for _ in range(num_blocks)
            ],
            CBM(path_channels, path_channels, 1),
        )
        # The concatenation of the two paths has 2 * path_channels channels:
        # 2 * out_channels in the first stage, ~out_channels afterwards.
        self.concat_conv = CBM(path_channels * 2, out_channels, 1)

    def forward(self, x):
        x = self.downsample_conv(x)
        shortcut = self.split_conv0(x)
        main = self.blocks_conv(self.split_conv1(x))
        # Main path first, then the shortcut: this order fixes which input
        # channels of concat_conv see which path.
        x = torch.cat([main, shortcut], dim=1)
        return self.concat_conv(x)
class CSPDarkNet(nn.Module):
    """CSPDarknet backbone (as used by YOLOv4).

    A 3x3 stride-1 CBM stem maps 3 -> 32 channels. Five ``CSPX`` stages
    follow, with output widths 64, 128, 256, 512 and 1024. Each stage
    downsamples by 2, so the stage outputs have strides 2, 4, 8, 16 and 32
    relative to the input. The outputs of the last three stages are returned.

    Args:
        layers: indexable sequence; ``layers[i]`` is the number of residual
            units passed to stage ``i`` (entries 0..4 are used). Stage 0 is
            built with ``first=True``.

    After construction, every Conv2d weight is redrawn from
    N(0, sqrt(2 / (k_h * k_w * out_channels))). Every BatchNorm2d gets
    weight 1 and bias 0.
    """

    def __init__(self, layers):
        super(CSPDarkNet, self).__init__()
        self.inplanes = 32
        self.conv1 = CBM(3, self.inplanes, kernel_size=3, stride=1)
        self.feature_channels = [64, 128, 256, 512, 1024]

        # Each stage consumes the previous stage's width; stage 0 consumes the stem.
        stage_inputs = [self.inplanes] + self.feature_channels[:-1]
        self.stages = nn.ModuleList([
            CSPX(c_in, c_out, layers[idx], first=(idx == 0))
            for idx, (c_in, c_out) in enumerate(zip(stage_inputs, self.feature_channels))
        ])

        self._init_weights()

    def _init_weights(self):
        """Fan-out normal init for convolutions; unit/zero init for BatchNorm."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                k_h, k_w = module.kernel_size
                fan_out = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        """Return the outputs of stages 2, 3 and 4 (strides 8, 16, 32) as a tuple."""
        h = self.conv1(x)
        stage_outputs = []
        for stage in self.stages:
            h = stage(h)
            stage_outputs.append(h)
        return tuple(stage_outputs[2:5])