https://smartest-suri.tistory.com/49

 

딥러닝 | U-Net(2015) 논문 리뷰

[주의] 본 포스팅은 수리링이 직접 U-Net 논문 원문을 읽고 리뷰한 내용을 담았으며, 참고 문헌이 있는 경우 출처를 명시하였습니다. 본문 내용에 틀린 부분이 있다면 댓글로 말씀해 주시고, 포스

smartest-suri.tistory.com

지난 번 포스팅에서 리뷰한 U-Net 논문을 파이토치를 이용한 코드로 구현한 과정을 정리해 보겠습니다.


1. [연습] Class 없이 한줄씩 구현

직관적인 이해를 위해서 파이토치 코드로 클래스 없이 한줄씩 유넷 구조를 구현해 보도록 하겠습니다. 

# 먼저 필요한 모듈을 임포트 해줍니다.

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

1-1. Contracting path - 인코딩 파트

  • 논문에서는 valid padding(패딩 없음)을 사용하지만 코드상의 편의를 위해서 모든 콘볼루션 레이어를 same padding(패딩 1)로 구현하겠습니다.
  • 이렇게 할 경우 나중에 skip-connection 파트에서 concatenate할 때 크기가 딱 맞아서 crop할 필요가 없습니다.
input_channels = 3 # 일단 RGB라고 생각하고 3으로 설정하겠습니다.

# 첫번째 블락
conv1 = nn.Conv2d(input_channels, 64, kernel_size = 3, padding = 1)
conv2 = nn.Conv2d(64, 64, kernel_size = 3, padding = 1)
pool1 = nn.MaxPool2d(kernel_size = 2, stride = 2)

# 두번째 블락
conv3 = nn.Conv2d(64, 128, kernel_size = 3, padding = 1)
conv4 = nn.Conv2d(128, 128, kernel_size = 3, padding = 1)
pool2 = nn.MaxPool2d(kernel_size = 2, stride = 2)

# 세번째 블락
conv5 = nn.Conv2d(128, 256, kernel_size = 3, padding = 1)
conv6 = nn.Conv2d(256, 256, kernel_size = 3, padding = 1)
pool3 = nn.MaxPool2d(kernel_size = 2, stride = 2)

# 네번째 블락
conv7 = nn.Conv2d(256, 512, kernel_size = 3, padding = 1)
conv8 = nn.Conv2d(512, 512, kernel_size = 3, padding = 1)
pool4 = nn.MaxPool2d(kernel_size = 2, stride = 2)

1-2. Bottleneck - 연결 파트

bottleneck

  • 연결 구간에 해당하는 Bottleneck 파트를 작성하겠습니다.
  • 여기까지 오면 최종 채널의 수는 1024가 됩니다.
conv9 = nn.Conv2d(512, 1024, kernel_size = 3, padding = 1)
conv10 = nn.Conv2d(1024, 1024, kernel_size = 3, padding = 1)

1-3. Expanding path - 디코딩 파트

  • 디코딩 파트에는 Skip-connection을 통한 사이즈와 필터의 변화에 주목해서 보시면 좋습니다.
# 첫 번째 블락
up1 = nn.ConvTranspose2d(1024, 512, kernel_size = 2, stride = 2)
# 위에 아웃풋은 512지만 나중에 코드에서 skip-connection을 통해 다시 인풋이 1024가 됨
conv11 = nn.Conv2d(1024, 512, kernel_size = 3, padding = 1)
conv12 = nn.Conv2d(512, 512, kernel_size = 3, padding = 1)

# 두 번째 블락
up2 = nn.ConvTranspose2d(512, 256, kernel_size = 2, stride = 2)
conv13 = nn.Conv2d(512, 256, kernel_size = 3, padding = 1) # Skip-connection 포함
conv14 = nn.Conv2d(256, 256, kernel_size = 3, padding = 1)

# 세 번째 블락
up3 = nn.ConvTranspose2d(256, 128, kernel_size = 2, stride = 2)
conv15 = nn.Conv2d(256, 128, kernel_size = 3, padding = 1) # Skip-connection 포함
conv16 = nn.Conv2d(128, 128, kernel_size = 3, padding = 1)

# 네 번째 블락
up4 = nn.ConvTranspose2d(128, 64, kernel_size = 2, stride = 2)
conv17 = nn.Conv2d(128, 64, kernel_size = 3, padding = 1) # Skip-connection 포함
conv18 = nn.Conv2d(64, 64, kernel_size = 3, padding = 1)

# 마지막 아웃풋에서는 1x1 사이즈 콘볼루션을 사용한다고 이야기함.
output = nn.Conv2d(64, 2, kernel_size = 1, padding = 1)

1-4. Forward-pass

포워드 학습 과정을 한줄씩 구현해 보겠습니다.

def unet_forward(x):
    # 인코더 파트
    x = TF.relu(conv1(x))
    x1 = TF.relu(conv2d(x))
    x = pool1(x1)

    x = TF.relu(conv3(x))
    x2 = TF.relu(conv4(x))
    x = pool2(x2)

    x = TF.relu(conv5(x))
    x3 = TF.relu(conv6(x))
    x = pool3(x3)

    x = TF.relu(conv7(x))
    x4 = TF.relu(conv8(x))
    x = pool4(x4)

    # 연결 bottleneck 파트
    x = TF.relu(conv9(x))
    x = TF.relu(conv10(x))

    # 디코더 파트
    x5 = up1(x)
    x = torch.cat([x5, x4], dim = 1) # skip-connection
    x = TF.relu(conv11(x))
    x = TF.relu(conv12(x))

    x6 = up2(x)
    x = torch.cat([x6, x3], dim = 1) # skip-connection
    x = TF.relu(conv13(x))
    x = TF.relu(conv14(x))

    x7 = up3(x)
    x = torch.cat([x7, x2], dim = 1) # skip-connection
    x = TF.relu(conv15(x))
    x = TF.relu(conv16(x))

    x8 = up4(x)
    x = torch.cat([x8, x1], dim = 1) # skip-connection
    x = TF.relu(conv17(x))
    x = TF.relu(conv18(x))

    # 아웃풋 파트
    output = output(x)

    return output

지금까지 클래스 없이 파이토치로 코드를 짜면서 유넷 구조를 직관적으로 이해해 보았습니다.



2. [실전] Class 이용해서 구현

이번엔 파이토치를 사용하는 만큼 클래스를 이용해서 실전 유넷 코드 구현을 해보겠습니다. 먼저 코드를 짜는 과정은 유튜브 [PyTorch Image Segmentation Tutorial with U-NET: everything from scratch baby (Aladdin Persson)]를 참고하여 작성했음을 미리 밝히겠습니다. 

https://youtu.be/IHq1t7NxS8k?si=776huHRjVsIlf_rS


2-1.  모든 블락의 2개 콘볼루션(Conv2d) 레이어를 정의하는 클래스

 

먼저 unet.py 파일을 만들고 클래스들을 작성하여 파일을 모듈화 하겠습니다.

class DoubleConv(mm.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias = False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True)
        )

    def forward(self, x):
        return self.conv(x)
  • 1장과 마찬가지로 편의를 위해서 valid padding 대신 same padding을 사용합니다. Carvana 대회 1등한 팀도 same padding을 사용했다고 하네요. (저는 다른걸로 할거지만, 영상에서 Carvana 데이터셋을 이용해서 학습을 합니다)
  • Batch Normalization이 추가되었습니다. 논문에서 따로 Batch Normalization을 해준다는 언급은 없었는데요. 영상에 따르면 UNet이 2015년도 발표되었고, BatchNorm은 2016년에 고안된 아이디어라서 그렇다고 합니다. 찾아보니 유넷을 구현하는 많은 코드가 Batch Normalization을 추가하여 Gradient Vanishing/Exploding 문제를 보완하는 것 같습니다.
  • bias를 False로 설정해 주는 이유는 중간에 BatchNorm2d를 추가해주기 때문입니다. bias가 있어봤자 BatchNorm에 의해서 상쇄(cancel)되기 때문에 굳이 bias가 필요 없다고 영상에서 말하고 있습니다. 이부분은 나중에 BatchNorm 논문을 통해 따로 확인해 보겠습니다.

2-2. 유넷 전체 구조를 정의하는 클래스

해당 파트는 코드가 길어지기 때문에 주석을 이용해서 각 부분을 설명했습니다 :) 비교적 간단하지만 그래도 코드를 이해하시려면 유넷 전체 구조에 대한 이해가 필수적입니다.

  • 논문에서는 마지막 아웃풋의 채널이 2였습니다만, 저는 레이블이 0, 1로 binary인 데이터를 다룰 예정이기 때문에 output_channels의 default 값을 1로 두었습니다.
  • features 리스트는 각 블록에서 피처맵의 사이즈를 순서대로 정의하는 역할을 합니다. (contracting path의 경우 순서대로 64, 128, 256, 512 - expanding path의 경우 그 반대)
class UNET(nn.Module):
    def __init__(
            self, in_channels = 3, out_channels = 1, features = [64, 128, 256, 512]
    ):
        super(UNET, self).__init__()
        self.downs = nn.ModuleList() # Contracting path - 인코딩 파트의 모듈을 담을 리스트 선언
        self.ups = nn.ModuleList()   # Expanding path - 디코딩 파트의 모듈을 담을 리스트 선언
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2) # 풀링은 모든 블럭에서 공통 사용됨

        # Contracting path (Down - 인코딩 파트) ------------------------
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature)) # 블록마다 더블콘볼루션 해주고 아웃풋은 feature맵 리스트 순서대로 할당(64, 128...)
            in_channels = feature # 다음 모듈의 인풋 사이즈를 feature로 업데이트

        # Bottleneck (인코딩, 디코딩 연결 파트) ---------------------------
        size = features[-1] # 512
        self.bottleneck = DoubleConv(size, size * 2) # 인풋 512 아웃풋 1024

        # Expanding path (Up - 디코딩 파트) ----------------------------
        for feature in features[::-1]: # 피처맵 사이즈 반대로!
            # 먼저 초록색 화살표에 해당하는 up-conv 레이어를 먼저 추가해 줍니다.
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size = 2, stride = 2
                    # 인풋에 *2 해주는 이유 : 나중에 skip-connection을 통해서 들어오는 인풋 사이즈가 더블이 되기 때문!
                    # kernel과 stride size가 2인 이유는.. 논문에서 그렇게 하겠다고 했음 '_^ 
                )
            )
            # 이제 더블 콘볼루션 레이어 추가
            self.ups.append(DoubleConv(feature * 2, feature))

        # Output (아웃풋 파트) -----------------------------------------
        last_input = features[0] # 64
        self.final_conv = nn.Conv2d(last_input, out_channels, kernel_size = 1)

    ####################### **-- forward pass --** #######################
    def forward(self, x):
        skip_connections = []

        # Contracting path (Down - 인코딩 파트) ------------------------
        for down in self.downs: # 인코딩 파트를 지나면서 각 블록에서 저장된 마지막 모듈 하나씩 이터레이션
            x = down(x)
            skip_connections.append(x) # skip_connection 리스트에 추가
            x = self.pool(x)

        # Bottleneck (인코딩, 디코딩 연결 파트) ---------------------------
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1] # 디코딩 파트에서 순서대로 하나씩 뽑기 편하게 리스트 순서 반대로 뒤집어주기

        # Expanding path (Up - 디코딩 파트) ----------------------------
        for i in range(len(skip_connections)):
            x = self.ups[i * 2](x) # self_ups에는 순서대로 ConvTranspose2d와 DoubleConv가 들어가 있음. 
            # 0, 2, 4... 짝수번째에 해당하는 인덱스만 지정하면 순서대로 ConvTranspose2d와(up-conv) 모듈만 지정하게 됨
            skip_connection = skip_connections[i] # skip_connection 순서대로 하나씩 뽑아서
            # concatenate 해서 붙여줄(connection) 차례!
            # 그런데 만약 붙일때 shape이 맞지 않는다면... (특히 이미지의 input_shape이 홀수인 경우 이런 뻑이 나게됨)
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size = skip_connection.shape[2:]) # 간단히resize로 맞춰주겠음
            concat_skip = torch.cat((skip_connection, x), dim = 1) # 이제 붙임!
            x = self.ups[i * 2 + 1](concat_skip)
            # 1, 3, 5... 홀수번째에 해당하는 인덱스만 지정하면 순서대로 DoubleConv 모듈만 지정하게 됨
            
        return self.final_conv(x)

2-3. 테스트

  • 간단한 랜덤 텐서를 생성해서 우리가 구현한 유넷 모델이 제대로 작동하는지 확인하겠습니다.
  • 모델 인풋과 아웃풋의 shape이 정확히 같은지 확인하는 작업이 들어가는데, image segmentation 작업의 특성상 output mask의 크기가 인풋과 동일해야 하기 때문에 그렇습니다.
  • assert 문법을 사용해서 인풋 아웃풋 shape이 다른 경우가 감지되면 AsserstionError를 발생시켜서 Debug에 활용하도록 합니다.
# 유넷 모델 테스트하는 함수 작성

def test():
    # 3장의 1채널(grayscale), width height가 각각 160인 랜덤 텐서 생성 (테스트용으로!)
    x = torch.randn((3, 1, 160, 160))
    # in_channels 1로 설정 (그레이스케일 이미지), out_channels 1로 설정 (binary output)
    model = UNET(in_channels = 1, out_channels = 1)
    # forward pass
    preds = model(x)
    # preds와 x의 shape 확인하기 - 두개가 같아야 함
    print(preds.shape)
    print(x.shape)
    assert preds.shape == x.shape 
        # True : 오케이
        # False : AssertionError

if __name__ == "__main__": # 메인 파일에서만 작동하고 모듈로 import된 경우에는 작동하지 않도록 함
    test()
  
# 실행 결과 ---------------------------
# torch.Size([3, 1, 160, 160])
# torch.Size([3, 1, 160, 160])

구동 결과 shape이 정확히 같아서 test 함수가 제대로 작동한 것을 확인했습니다.

 

여기까지 작성한 코드를 unet.py 파일로 저장해서 모듈화 해주었습니다.

이후 데이터를 로드에 필요한 dataset.py를 작성한 다음 학습에 필요한 utils.py train.py를 차례대로 작성하는 순서로 학습을 진행합니다. 해당 파트는 본 포스팅에서 제외하도록 하겠습니다.

지금까지 유넷 논문에서 살펴본 구조를 파이토치 코드로 구현하는 과정을 포스팅해 보았습니다 :-) 다음 포스팅에서는 트랜스포머 코드화 작업을 수행해 보도록 하겠습니다. 감사합니다!

+ Recent posts