Hits

Xception: Deep Learning with Depthwise Seperable Convolutions, F Chollet, CVPR2017

Paper


Xception 방법 요약

Inception Module에서 regular convolution 연산을 depthwise convolution 연산으로 변환환 Xception을 소개한다.


이전 연구인 Inception의 가설

Inception에서는 cross-channel correlations와 spatial correlations이 decoupled 되어있어 jointly 맵핑하지 않는 것이 가능하다는 것이다. Inception에서는 1x1 convolution들로 cross channels relation 기능을 하는 것 처럼 보이고, 나머지 3x3이나 5x5 convolution으로 3D space에 대한 mapping을 하는 것 처럼 보인다.

하지만 이러한 cross-channel correlations와 spatial correlations를 완전하게 분리하여 맵핑할 수 있을까? 이 질문의 답을 Xception 논문에서 제시한다.

1


Depthwise Separable Convolution

Depthwise separable convolution 또는 separable convolution이라고 부른다. separable convolution은 depthwise convolution을 수행하고, pointwise convolution을 수행하는 구조로 되어있다. depthwise convolution은 input의 각각의 channel에 대해서 독립적으로 수행하는 convolution이고, 그 다음에 pointwise convolution은 depthwise convolution에 의해 연산된 output을 새로운 channel space로 projecting 하는 방법이다.

이것을 Inception의 “extreme” 버전이라고 하여 Xception이라고 한다.

  • cross-channel convolution: depthwise convolution (3x3 convolution)
  • spatial convolution: pointwise convolution (1x1 convolution)

2

논문에 있는 convolution은 기존의 Inception과 비교한 것이다. 실제로 사용하게 되는 Depthwise Separable Convolution은 논문에 있는 그림의 반대로 연결했다고 보면 된다.

5 6


class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(SeparableConv2d, self).__init__()

        self.depthwise = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, padding=1),
            nn.BatchNorm2d(in_channels),
        )
        self.pointwise = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x


Xception의 architecture

Xception은 전체적으로 depthwise separable convolution을 residual connection과 함께 linear stack으로 쌓은 것으로 볼 수 있다. Residual을 더하기 전의 연산 모듈들을 nn.Sequential로 하나의 block으로 생성하였다. Middle flow 부분은 같은 모듈을 8개 쌓아야 하므로 층이 많아서 class로 따로 구현을 해준 후, main 모델에 추가해주었다.

3

import torch
import torch.nn as nn

class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(SeparableConv2d, self).__init__()

        self.depthwise = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )
        self.pointwise = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

class Block(nn.Module):
    def __init__(self, in_channels):
        super(Block, self).__init__()
        self.feature = nn.Sequential(
            nn.ReLU(),
            SeparableConv2d(in_channels, in_channels, kernel_size=3),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            SeparableConv2d(in_channels, in_channels, 3),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            SeparableConv2d(in_channels, in_channels, 3),
            nn.BatchNorm2d(in_channels),  
        )

    def forward(self, x):
        x = self.feature(x) + x
        x = self.feature(x) + x
        x = self.feature(x) + x
        x = self.feature(x) + x
        x = self.feature(x) + x
        x = self.feature(x) + x
        x = self.feature(x) + x
        x = self.feature(x) + x
        return x

class Xception(nn.Module):
    def __init__(self):
        super(Xception, self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.conv3 = nn.Conv2d(64, 128, kernel_size=1, stride=2)

        self.block1 = nn.Sequential(
            SeparableConv2d(64, 128, kernel_size=3),
            nn.ReLU(),
            SeparableConv2d(128, 128, 3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.conv4 = nn.Conv2d(128, 256, kernel_size=1, stride=2)

        self.block2 = nn.Sequential(
            nn.ReLU(),
            SeparableConv2d(128, 256, kernel_size=3),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            SeparableConv2d(256, 256, 3),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.conv5 = nn.Conv2d(256, 728, kernel_size=1, stride=2)

        self.block3 = nn.Sequential(
            nn.ReLU(),
            SeparableConv2d(256, 728, kernel_size=3),
            nn.BatchNorm2d(728),
            nn.ReLU(),
            SeparableConv2d(728, 728, 3),
            nn.BatchNorm2d(728),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        #self.conv6 = nn.Conv2d(728, 728, kernel_size=1, stride=2)

        self.block4 = Block(728)

        self.conv7 = nn.Conv2d(728, 1024, kernel_size=1, stride=2)

        self.block5 = nn.Sequential(
            nn.ReLU(),
            SeparableConv2d(728, 728, kernel_size=3),
            nn.BatchNorm2d(728),
            nn.ReLU(),
            SeparableConv2d(728, 1024, 3),
            nn.BatchNorm2d(1024),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)   
        )

        self.block6 = nn.Sequential(
            nn.ReLU(),
            SeparableConv2d(1024, 1536, kernel_size=3),
            nn.BatchNorm2d(1536),
            nn.ReLU(),
            SeparableConv2d(1536, 2048, 3),
            nn.BatchNorm2d(2048),
            nn.AvgPool2d(kernel_size=2, stride=1)
        )

        self.fc = nn.Linear(2048, 100)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def make_layer(self, iter):
        layers = []
        feature = nn.Sequential(
            nn.ReLU(),
            SeparableConv2d(728, 728, kernel_size=3),
            nn.BatchNorm2d(728),
            nn.ReLU(),
            SeparableConv2d(728, 728, 3),
            nn.BatchNorm2d(728),
            nn.ReLU(),
            SeparableConv2d(728, 728, 3),
            nn.BatchNorm2d(728),  
        )
        for i in range(iter):
            layers.append(feature)
        return nn.Sequential(*layers)

    def forward(self, x):
        # Entry flow
        x = self.feature(x)
        x = self.block1(x) + self.conv3(x)
        x = self.block2(x) + self.conv4(x)
        x = self.block3(x) + self.conv5(x)

        # Middle flow
        x = self.block4(x)
        
        # Exit flow
        x = self.block5(x) + self.conv7(x)
        x = self.block6(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

def xception():
    return Xception()