Eryck Zhou

A super simple blog for Artificial Intelligence.

从零复现YOLOv4系列一:Backbone部分

28 February 2023

Photo by Patryk Wojcieszak on Unsplash

Paper: YOLOv4: Optimal Speed and Accuracy of Object Detection

CBM 模块（Conv + BatchNorm + Mish）

class CBM(nn.Module):
    """Conv2d -> BatchNorm2d -> Mish, the basic layer used throughout this backbone.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: square kernel size; padding is ``kernel_size // 2``, so an
            odd kernel at stride 1 keeps the spatial size unchanged.
        stride: convolution stride (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()

        # No conv bias: the BatchNorm that follows supplies its own shift term.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = Mish()

    def forward(self, x):
        """Apply conv, then batch norm, then the Mish activation."""
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)

Res Unit

class ResUnit(nn.Module):
    """Residual unit: a 1x1 CBM and a 3x3 CBM whose output is added to the input.

    Args:
        channels: number of input channels, which is also the number of output
            channels.
        hidden_channels: output channels of the 1x1 CBM. Defaults to
            ``channels`` when None.
    """

    def __init__(self, channels, hidden_channels = None):
        super(ResUnit, self).__init__()

        if hidden_channels is None:
            hidden_channels = channels

        self.block = nn.Sequential(
            CBM(channels, hidden_channels, 1),
            CBM(hidden_channels, channels, 3),
        )

    def forward(self, x):
        # Both CBMs run at stride 1 and the block maps channels -> channels,
        # so the identity shortcut can be added element-wise.
        return x + self.block(x)

CSPX 模块

class CSPX(nn.Module):
    """Cross Stage Partial stage: a stride-2 downsample followed by a two-path split.

    Processing steps:
    1. The downsampled map is fed to two 1x1 CBMs.
    2. ``split_conv0`` is kept as a shortcut path.
    3. ``split_conv1`` continues through ``num_blocks`` ResUnits and a 1x1 CBM.
    4. The two paths are concatenated on the channel axis.
    5. A final 1x1 CBM fuses them to ``out_channels``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        num_blocks: number of ResUnits in the residual path (honored in both
            variants).
        first: if True, build the first-stage variant. Both split paths keep
            ``out_channels`` channels instead of ``out_channels // 2``, and
            each ResUnit uses ``out_channels // 2`` hidden channels.
    """

    def __init__(self, in_channels, out_channels, num_blocks, first=False):
        super(CSPX, self).__init__()

        # The first stage keeps full width in both paths; later stages halve it.
        split_channels = out_channels if first else out_channels // 2
        # None lets ResUnit default its hidden width to split_channels.
        hidden_channels = out_channels // 2 if first else None

        self.downsample_conv = CBM(in_channels, out_channels, 3, stride=2)
        self.split_conv0 = CBM(out_channels, split_channels, 1)
        self.split_conv1 = CBM(out_channels, split_channels, 1)
        self.blocks_conv = nn.Sequential(
            *[ResUnit(channels=split_channels, hidden_channels=hidden_channels)
              for _ in range(num_blocks)],
            CBM(split_channels, split_channels, 1),
        )
        # The concat yields 2 * split_channels channels. That is 2 * out_channels
        # in the first stage and out_channels (or out_channels - 1 if odd)
        # otherwise. Deriving the width from split_channels keeps it in sync
        # with torch.cat in every case.
        self.concat_conv = CBM(2 * split_channels, out_channels, 1)

    def forward(self, x):
        x = self.downsample_conv(x)

        # Shortcut path.
        x0 = self.split_conv0(x)
        # Residual path.
        x1 = self.blocks_conv(self.split_conv1(x))

        # Join on dim=1 (the channel axis for Conv2d's NCHW layout), residual
        # path first, then fuse back to out_channels.
        x = torch.cat([x1, x0], dim=1)
        x = self.concat_conv(x)

        return x

CSPDarkNet(Backbone)

class CSPDarkNet(nn.Module):
    """CSPDarknet backbone: a stride-1 3x3 stem CBM followed by five CSPX stages.

    Every CSPX stage downsamples by 2. The three returned maps are therefore
    at 1/8, 1/16 and 1/32 of the input resolution.

    Args:
        layers: indexable of (at least) five ints, the number of ResUnits used
            by each stage.

    ``forward`` returns the outputs of the last three stages as a tuple. They
    have 256, 512 and 1024 channels respectively.
    """

    def __init__(self, layers):
        super().__init__()
        self.inplanes = 32
        self.conv1 = CBM(3, self.inplanes, kernel_size=3, stride=1)
        self.feature_channels = [64, 128, 256, 512, 1024]

        # Each stage consumes the previous stage's width; only stage 0 uses
        # the first-stage CSPX variant.
        in_widths = [self.inplanes] + self.feature_channels[:-1]
        self.stages = nn.ModuleList([
            CSPX(c_in, c_out, layers[idx], first=(idx == 0))
            for idx, (c_in, c_out) in enumerate(zip(in_widths, self.feature_channels))
        ])

        # Weight initialization. Convs: zero-mean normal with
        # std = sqrt(2 / (k_h * k_w * out_channels)).
        # BatchNorm: scale 1, shift 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kh, kw = module.kernel_size
                std = math.sqrt(2. / (kh * kw * module.out_channels))
                nn.init.normal_(module.weight, mean=0.0, std=std)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        x = self.conv1(x)
        # The first two stages only feed later ones; their outputs are not returned.
        x = self.stages[1](self.stages[0](x))

        features = []
        for stage in self.stages[2:]:
            x = stage(x)
            features.append(x)
        return tuple(features)