티스토리 뷰
[논문 100개 구현 1탄] U-Net: Convolutional Networks for Biomedical Image Segmentation
sikaro 2024. 11. 27. 11:47논문 100개 구현의 1탄을 장식한다.
최근 어떤 부트캠프 홈페이지에서 논문을 50~100개 구현할 수 있으면 전문가 수준이 된다고 해서 직접 해보려고 한다.
4~6개월에 1천만원이라면 비싼 돈 들일 필요 없이 직접 하면 돈이 아껴지지 않겠는가.
어떻게 보면 의지로 돈을 사는 셈이다.
고로 오늘 해볼 논문은 공모전에서도 활용해볼 겸 U-Net의 논문이다.
https://arxiv.org/abs/1505.04597
논문 구현에 앞서 확인해야 할 포인트
Read
1. 논문 제목(title)과 초록(abstract), 도표(figures) 읽기
2. 도입(introduction), 결론(conclusion), 도표(figures)를 읽고 필요없는 부분 생략
3. 수식은 처음 읽을 때는 과감하게 생략
4. 이해가 안되는 부분은 빼고 전체적으로 읽는다.
QnA
1. 저자가 뭘 해내고 싶어했는가?
2. 이 연구의 접근에서 중요한 요소는 무엇인가?
3. 당신(논문독자)는 스스로 이 논문을 이용할 수 있는가?
4. 당신이 참고하고 싶은 다른 레퍼런스에는 어떤 것이 있는가?
구현하기
1. 수식 이해하고 직접 연산하기
2. 코드 연습하기(오픈소스를 받아 직접 구현)
Read
1. 논문 제목(title)과 초록(abstract), 도표(figures) 읽기
논문 제목부터 눈에 띈다. Convolutional Networks for Biomedical Image Segmentation이다.
즉, 이 저자는 처음부터 CNN을 기반으로 Biomedical 분야에서 Segmentation을 하기 위해 이 논문을 만들었다는 걸 알 수 있다.
그리고 초록을 보면 어떤 의도를 가지고 이 논문을 만들었는지 알 수 있다.
In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently
의학 분야는 annotated sample이 부족하다. 또한 Segmentation 자체가 원래 annotation을 하는데 굉장히 많은 시간과 비용이 들어간다. 즉, data augmentation의 효율을 높이기 위한 네트워크와 학습 방법론을 위해 이 논문을 만들었다는 셈이다.
이게 어떻게 가능한가?
The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks.
간략하게 번역하자면,
이 구조는 맥락을 포착하기 위한 축소 경로와 정확한 위치를 파악할 수 있도록 대칭적인 확장 경로로 구성되어 있습니다. 우리는 이러한 네트워크가 매우 적은 수의 이미지로부터 종단 간(end-to-end) 학습이 가능하며, 전자현미경 스택에서 신경 구조를 분할하는 ISBI 챌린지에서 이전의 최고 방법(슬라이딩 윈도우 컨볼루션 네트워크)을 능가한다는 것을 보여줍니다.
이 부분은 U-Net의 구조와 일맥상통한다. 말 그대로 Decoder 과정에서 Encoder에서 썼던 Feature Map을 가져다 쓰는데, 대칭적인 구조로서 되어 있다는 거고, 그렇기에 네트워크가 매우 적어도 end-to-end로 학습이 가능하다는 것이다.
2. 도입(introduction), 결론(conclusion), 도표(figures)를 읽고 필요없는 부분 생략
Introduction
Introduction 부분에서 필요 없는 부분을 생략하고 요약하자면 이렇다.
- 딥 컨볼루션 네트워크는 ImageNet 데이터셋을 사용한 대규모 학습으로 성공을 이루었다.
- 기존 방법인 슬라이딩 윈도우는 각 픽셀 주변 패치를 입력으로 제공해 클래스 라벨을 예측한다.
- 그러나 너무 느린 속도와 중복 처리, 맥락 활용과 위치 정확성 간의 트레이드오프가 있다.
그래서 U-Net을 제시한다.
특징 : 축소 경로의 고해상도 특징(Feature Map)을 확장 경로의 업샘플링된 결과와 결합(concatenate)해 더 정밀한 출력 생성(위 사진에서 회색 화살표들)
CNN과 달리 완전 연결층(FCN) 없음. GPU 메모리 제한 없이 큰 이미지를 분할 가능
효율적 학습 전략
데이터 증강 - Elastic Deformation(탄성 변형)을 통해 비선형적인 이미지로 증강한 걸 사용하여 효율을 높였다.
손실함수 개선 : 인접 객체 구분을 위해 경계 배경 픽셀에 높은 가중치 부여
Conclusion
U-Net 아키텍처는 다양한 생의학적 분할 작업에서 매우 뛰어난 성능을 보임.
탄성 변형을 활용한 데이터 증강 덕분에 소량의 주석 이미지만으로도 효과적인 학습 가능.
NVidia Titan GPU(6GB)에서 약 10시간만에 훈련 완료.
Caffe 기반 구현과 사전 훈련된 네트워크를 함께 제공.
이 아키텍처는 다양한 다른 작업에도 쉽게 적용 가능할 것으로 기대됨.
이렇게 보고, 다른 Figure에 대해서는 QnA 부분에서 다뤄보도록 한다.
QnA
1. 저자가 뭘 해내고 싶어했는가?
앞서 언급했던 대로 소량의 라벨링 이미지만으로도 효과적인 Segmentation을 하기 원했던 것이다.
2. 이 연구의 접근에서 중요한 요소는 무엇인가?
U-Net 아키텍처: 축소(Contracting) 경로와 확장(Expanding) 경로로 구성된 대칭 구조로, 맥락 정보(픽셀의 맥락)와 세부 정보를 효과적으로 결합.
데이터 증강: 탄성 변형(elastic deformation)을 포함한 다양한 증강 기법으로 학습 데이터를 확장.
효율성: GPU 메모리 한계 내에서 큰 이미지를 처리하기 위해 오버랩 타일링(overlap-tile) 전략 적용.
손실 함수 개선: 인접 객체를 구분하기 위해 경계 픽셀에 가중치를 부여.
3. 당신(논문독자)는 스스로 이 논문을 이용할 수 있는가?
Caffe 기반의 U-Net 구현과 사전 훈련된 네트워크가 공개되어 있음으로 이용할 수 있게 되어있다.
4. 당신이 참고하고 싶은 다른 레퍼런스에는 어떤 것이 있는가?
U-Net의 원리와 관련된 연구: Fully Convolutional Networks (Long et al.)
데이터 증강: 탄성 변형과 관련된 unsupervised feature learning (Dosovitskiy et al.)
초기 슬라이딩 윈도우 접근법: Ciresan et al.
생의학 이미지 데이터셋: ISBI 2012 EM Segmentation Challenge 및 ISBI 2015 Cell Tracking Challenge 데이터셋.
구현하기
1. 수식 이해하고 직접 연산하기
이제 어느정도 가닥이 잡혔으니, 본격적으로 수학적인 구현에 들어가보자.
이 논문을 구현할 때 주의해야 하는 부분은 이 부분이다.
이런식으로 경계를 명확하게 하기 위한 손실함수를 사용하는데, 가중치 함수 w를 사용하므로 이 부분에 대한 구현도 요구된다.
이는 Pytorch에서 BCEWithLogitsLoss()를 사용할 것이다. 참고자료는 이진분류긴 하지만, 다음과 같이 고친다.
pos_weight = torch.tensor([2.0]).to(device) # 예: 객체 픽셀에 더 높은 가중치를 부여
fn_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(device)
그리고 Softmax를 사용하려면 마지막 레이어를 소프트맥스로 바꾼다.
# 기존: 출력 채널이 1인 경우
# self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)
# 수정: 다중 클래스 (예: 3개의 클래스)
self.fc = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1, padding=0, bias=True)
# CrossEntropyLoss 정의 (클래스별 가중치 설정)
fn_loss = nn.CrossEntropyLoss(weight=class_weights).to(device)
이 부분만 이해한다면 크게 달라지는 건 없다.
2. 코드 연습하기(오픈소스를 받아 직접 구현)
Encoder 부분은 64 -> 128 -> 256 -> 512 -> 1024 순으로 커진다.
반면 Decoder 부분은 그 반대로 작아지고, 처음 각 부분의 처음에는 Feature Map과 Concatenate 해준다.
따라서 Functional로 정의하면 다음과 같다.
## 라이브러리 불러오기
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
## 네트워크 구축하기
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
# Convolution + BatchNormalization + Relu 정의하기
def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
layers = []
layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
bias=bias)]
layers += [nn.BatchNorm2d(num_features=out_channels)]
layers += [nn.ReLU()]
cbr = nn.Sequential(*layers)
return cbr
# 수축 경로(Contracting path)
self.enc1_1 = CBR2d(in_channels=1, out_channels=64)
self.enc1_2 = CBR2d(in_channels=64, out_channels=64)
self.pool1 = nn.MaxPool2d(kernel_size=2)
self.enc2_1 = CBR2d(in_channels=64, out_channels=128)
self.enc2_2 = CBR2d(in_channels=128, out_channels=128)
self.pool2 = nn.MaxPool2d(kernel_size=2)
self.enc3_1 = CBR2d(in_channels=128, out_channels=256)
self.enc3_2 = CBR2d(in_channels=256, out_channels=256)
self.pool3 = nn.MaxPool2d(kernel_size=2)
self.enc4_1 = CBR2d(in_channels=256, out_channels=512)
self.enc4_2 = CBR2d(in_channels=512, out_channels=512)
self.pool4 = nn.MaxPool2d(kernel_size=2)
self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)
# 확장 경로(Expansive path)
self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)
self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,
kernel_size=2, stride=2, padding=0, bias=True)
self.dec4_2 = CBR2d(in_channels=2 * 512, out_channels=512)
self.dec4_1 = CBR2d(in_channels=512, out_channels=256)
self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
kernel_size=2, stride=2, padding=0, bias=True)
self.dec3_2 = CBR2d(in_channels=2 * 256, out_channels=256)
self.dec3_1 = CBR2d(in_channels=256, out_channels=128)
self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
kernel_size=2, stride=2, padding=0, bias=True)
self.dec2_2 = CBR2d(in_channels=2 * 128, out_channels=128)
self.dec2_1 = CBR2d(in_channels=128, out_channels=64)
self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
kernel_size=2, stride=2, padding=0, bias=True)
self.dec1_2 = CBR2d(in_channels=2 * 64, out_channels=64)
self.dec1_1 = CBR2d(in_channels=64, out_channels=64)
self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)
# forward 함수 정의하기
def forward(self, x):
enc1_1 = self.enc1_1(x)
enc1_2 = self.enc1_2(enc1_1)
pool1 = self.pool1(enc1_2)
enc2_1 = self.enc2_1(pool1)
enc2_2 = self.enc2_2(enc2_1)
pool2 = self.pool2(enc2_2)
enc3_1 = self.enc3_1(pool2)
enc3_2 = self.enc3_2(enc3_1)
pool3 = self.pool3(enc3_2)
enc4_1 = self.enc4_1(pool3)
enc4_2 = self.enc4_2(enc4_1)
pool4 = self.pool4(enc4_2)
enc5_1 = self.enc5_1(pool4)
dec5_1 = self.dec5_1(enc5_1)
unpool4 = self.unpool4(dec5_1)
cat4 = torch.cat((unpool4, enc4_2), dim=1)
dec4_2 = self.dec4_2(cat4)
dec4_1 = self.dec4_1(dec4_2)
unpool3 = self.unpool3(dec4_1)
cat3 = torch.cat((unpool3, enc3_2), dim=1)
dec3_2 = self.dec3_2(cat3)
dec3_1 = self.dec3_1(dec3_2)
unpool2 = self.unpool2(dec3_1)
cat2 = torch.cat((unpool2, enc2_2), dim=1)
dec2_2 = self.dec2_2(cat2)
dec2_1 = self.dec2_1(dec2_2)
unpool1 = self.unpool1(dec2_1)
cat1 = torch.cat((unpool1, enc1_2), dim=1)
dec1_2 = self.dec1_2(cat1)
dec1_1 = self.dec1_1(dec1_2)
x = self.fc(dec1_1)
return x
코드가 굉장히 길지만, 하나씩 짚어보자.
def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
# 해당 부분은 말 그대로 CNN 자체에 Relu와 Batch Normalization까지 합해서 함수 블록을 만들어준 것이다.
이렇게 하면 정의만 하면 되기 때문에 손쉽게 Forward를 정의할 수 있게 된다.
pooling도 적용이 되고, 그 이외에 중요한 부분은 cat4 = torch.cat((unpool4, enc4_2), dim=1)이다.
이 부분에서 마지막 엔코더 부분과 Decoder의 시작 부분이 합쳐진다.
그리고 엔코더 부분의 시작점들은 전부 torch.cat으로 시작되는 걸 알 수 있다.
이렇게 나온 x값은 logit값으로, BCEWithLogitsLoss()에 들어가서 확률 값으로 출력된다.
전체 예제 코드는 여기에서 볼 수 있다.
마치며
이상으로 논문 100탄 시리즈의 포문을 여는 U-Net 논문 구현을 작성해보았다.
과연 정말일지, 얼마나 실력이 향상될지 기대가 된다.