
Pytorch ResNet18 사용시 CIFAR10 데이터셋에서 낮은 성능을 보이는 경우
PyTorch에서 제공하는 ResNet18을 사용할 때, CIFAR-100 데이터셋에 대한 성능이 기대보다 훨씬 낮게 나타납니다. 정확도가 약 30~40%에 불과하여 보고된 벤치마크보다 크게 떨어집니다.
왜 이런 현상이 발생할까요?
PyTorch의 ResNet18은 원래 ImageNet 데이터셋을 위해 설계되었습니다. ImageNet에서는 이미지 크기가 일반적으로 224×224 또는 256×256으로 리사이즈됩니다. 따라서 모델 아키텍처는 특히 첫 번째 합성곱 계층에서 큰 커널 크기를 사용하도록 최적화되어 있습니다:
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
반면, CIFAR-100의 이미지는 훨씬 작아서 (32×32), 이 큰 커널과 스트라이드를 그대로 사용하면 과도한 다운샘플링이 발생하여 성능에 부정적인 영향을 미칩니다.
어떻게 수정할 수 있을까요?
CIFAR-100에 ResNet18을 더 잘 맞추기 위해서는 다음과 같이 해야 합니다:
- 첫 번째 합성곱 계층의 커널 크기와 스트라이드를 더 작게 수정합니다.
- 최대 풀링(max pooling) 계층을 제거합니다 (CIFAR 이미지가 이미 작기 때문에 추가 풀링이 필요하지 않습니다).
다음은 이러한 수정을 위한 코드 스니펫입니다:
import torch.nn as nn
from torchvision.models import resnet18
# 사전 학습된 가중치 없이 ResNet18 로드
model = resnet18(pretrained=False)
# 32x32 이미지를 위해 첫 번째 합성곱 계층 수정
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# 최대 풀링 계층 제거
model.maxpool = nn.Identity()
이와 같이 수정하면 ResNet18은 CIFAR-100에 더 적합해져 성능이 개선될 것으로 기대됩니다. 🚀
References: