Xception: Deep Learning with Depthwise Seperable Convolutions, F Chollet, CVPR2017
Xception 방법 요약
Inception Module에서 regular convolution 연산을 depthwise convolution 연산으로 변환환 Xception을 소개한다.
이전 연구인 Inception의 가설
Inception에서는 cross-channel correlations와 spatial correlations이 decoupled 되어있어 jointly 맵핑하지 않는 것이 가능하다는 것이다. Inception에서는 1x1 convolution들로 cross channels relation 기능을 하는 것 처럼 보이고, 나머지 3x3이나 5x5 convolution으로 3D space에 대한 mapping을 하는 것 처럼 보인다.
하지만 이러한 cross-channel correlations와 spatial correlations를 완전하게 분리하여 맵핑할 수 있을까? 이 질문의 답을 Xception 논문에서 제시한다.
Depthwise Separable Convolution
Depthwise separable convolution 또는 separable convolution이라고 부른다. separable convolution은 depthwise convolution을 수행하고, pointwise convolution을 수행하는 구조로 되어있다. depthwise convolution은 input의 각각의 channel에 대해서 독립적으로 수행하는 convolution이고, 그 다음에 pointwise convolution은 depthwise convolution에 의해 연산된 output을 새로운 channel space로 projecting 하는 방법이다.
이것을 Inception의 “extreme” 버전이라고 하여 Xception이라고 한다.
- cross-channel convolution: depthwise convolution (3x3 convolution)
- spatial convolution: pointwise convolution (1x1 convolution)
논문에 있는 convolution은 기존의 Inception과 비교한 것이다. 실제로 사용하게 되는 Depthwise Separable Convolution은 논문에 있는 그림의 반대로 연결했다고 보면 된다.
class SeparableConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(SeparableConv2d, self).__init__()
self.depthwise = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, padding=1),
nn.BatchNorm2d(in_channels),
)
self.pointwise = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1),
nn.BatchNorm2d(out_channels),
)
def forward(self, x):
x = self.depthwise(x)
x = self.pointwise(x)
return x
Xception의 architecture
Xception은 전체적으로 depthwise separable convolution을 residual connection과 함께 linear stack으로 쌓은 것으로 볼 수 있다. Residual을 더하기 전의 연산 모듈들을 nn.Sequential로 하나의 block으로 생성하였다. Middle flow 부분은 같은 모듈을 8개 쌓아야 하므로 층이 많아서 class로 따로 구현을 해준 후, main 모델에 추가해주었다.
import torch
import torch.nn as nn
class SeparableConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(SeparableConv2d, self).__init__()
self.depthwise = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, padding=1),
nn.BatchNorm2d(in_channels),
nn.ReLU(inplace=True),
)
self.pointwise = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.depthwise(x)
x = self.pointwise(x)
return x
class Block(nn.Module):
def __init__(self, in_channels):
super(Block, self).__init__()
self.feature = nn.Sequential(
nn.ReLU(),
SeparableConv2d(in_channels, in_channels, kernel_size=3),
nn.BatchNorm2d(in_channels),
nn.ReLU(),
SeparableConv2d(in_channels, in_channels, 3),
nn.BatchNorm2d(in_channels),
nn.ReLU(),
SeparableConv2d(in_channels, in_channels, 3),
nn.BatchNorm2d(in_channels),
)
def forward(self, x):
x = self.feature(x) + x
x = self.feature(x) + x
x = self.feature(x) + x
x = self.feature(x) + x
x = self.feature(x) + x
x = self.feature(x) + x
x = self.feature(x) + x
x = self.feature(x) + x
return x
class Xception(nn.Module):
def __init__(self):
super(Xception, self).__init__()
self.feature = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True)
)
self.conv3 = nn.Conv2d(64, 128, kernel_size=1, stride=2)
self.block1 = nn.Sequential(
SeparableConv2d(64, 128, kernel_size=3),
nn.ReLU(),
SeparableConv2d(128, 128, 3),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
self.conv4 = nn.Conv2d(128, 256, kernel_size=1, stride=2)
self.block2 = nn.Sequential(
nn.ReLU(),
SeparableConv2d(128, 256, kernel_size=3),
nn.BatchNorm2d(256),
nn.ReLU(),
SeparableConv2d(256, 256, 3),
nn.BatchNorm2d(256),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
self.conv5 = nn.Conv2d(256, 728, kernel_size=1, stride=2)
self.block3 = nn.Sequential(
nn.ReLU(),
SeparableConv2d(256, 728, kernel_size=3),
nn.BatchNorm2d(728),
nn.ReLU(),
SeparableConv2d(728, 728, 3),
nn.BatchNorm2d(728),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
#self.conv6 = nn.Conv2d(728, 728, kernel_size=1, stride=2)
self.block4 = Block(728)
self.conv7 = nn.Conv2d(728, 1024, kernel_size=1, stride=2)
self.block5 = nn.Sequential(
nn.ReLU(),
SeparableConv2d(728, 728, kernel_size=3),
nn.BatchNorm2d(728),
nn.ReLU(),
SeparableConv2d(728, 1024, 3),
nn.BatchNorm2d(1024),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
self.block6 = nn.Sequential(
nn.ReLU(),
SeparableConv2d(1024, 1536, kernel_size=3),
nn.BatchNorm2d(1536),
nn.ReLU(),
SeparableConv2d(1536, 2048, 3),
nn.BatchNorm2d(2048),
nn.AvgPool2d(kernel_size=2, stride=1)
)
self.fc = nn.Linear(2048, 100)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def make_layer(self, iter):
layers = []
feature = nn.Sequential(
nn.ReLU(),
SeparableConv2d(728, 728, kernel_size=3),
nn.BatchNorm2d(728),
nn.ReLU(),
SeparableConv2d(728, 728, 3),
nn.BatchNorm2d(728),
nn.ReLU(),
SeparableConv2d(728, 728, 3),
nn.BatchNorm2d(728),
)
for i in range(iter):
layers.append(feature)
return nn.Sequential(*layers)
def forward(self, x):
# Entry flow
x = self.feature(x)
x = self.block1(x) + self.conv3(x)
x = self.block2(x) + self.conv4(x)
x = self.block3(x) + self.conv5(x)
# Middle flow
x = self.block4(x)
# Exit flow
x = self.block5(x) + self.conv7(x)
x = self.block6(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def xception():
return Xception()