본문 바로가기
Deep Learning

[DL] 딥러닝에서 미세 조정(Fine Tuning) vs 특징 추출(Feature Extraction)

by JJuOn 2021. 9. 6.

Convolutional Neural Network를 살펴보면 대개 두가지의 부분으로 구성되어 있다.

 

첫번째는 Convolutional Layers로, 입력(대개 이미지)에 대한 feature를 추출하는 역할을 맡는다.

두번째는 Classifier로 Conv. Layers로부터 획득한 feature를 토대로 해당 입력이 어떤 class에 속하는지 판단한다.

 

이전 포스팅에서 언급한 바와 같이, 시간과 비용 문제로 인해 대개 well-trained network를 transfer learning하여 사용하곤 한다.

그런데 transfer learning에 관한 자료들을 찾다 보면 fine tuning이라는 용어도 보이고, feature extraction이라는 용어도 보이곤 한다.

이번 포스팅에서는 두 방법의 차이점과 두 방법이 어떠한 상황에서 적용 되어야 하는지 서술해 보고자 한다.

Feature Extraction

Feature Extraction


Feature Extraction은 특징 추출을 의미한다.
network의 classifier는 해결하고자 하는 문제에 맞게 초기화 한 후 trainable하게 둔다.
convolutional layers는 freeze하여 새로운 데이터들의 feature들이 학습되지 않도록 한다.

Fine Tuning

Fine Tuning


Fine Tuning은 미세 조정을 의미한다.
feature extraction의 경우와 마찬가지고 network의 classifier는 trainable하게 두어
새로운 데이터를 올바르게 분류하도록 학습시킨다.
convolutional layers들은 기존 pretrained network의 weight를 최대한 망가뜨리지 않고
새로운 데이터의 feature를 아주 조금씩만 학습하도록 하기 위해 learning rate를 아주 작게 설정하고
weight들을 trainable하게 둔다.
또는, conv layers의 초기 layer들이 edge나 blob등의 저차원 feature를 추출하는 역할을 하여
데이터가 달라져도 그 역할이 크게 달라지지 않는다.
그리고 learning rate를 작게 조절하더라도 overfitting이 발생 할 수 있기 때문에

conv layers의 초기 layer는 freeze하고 이후 layer는 trainable하게 두어

새로운 데이터의 feature를 학습하게 할 수도 있다.

 

Implementation

그렇다면 어떻게 특정 layer들을 freeze하거나 trainable하게 할 수 있을까?

 

pytorch의 parameter들은 requires_grad라는 attribute를 모두 가지고 있다.

이러한 requires_grad=True로 하면 trainable한 parameter가 되는 것이고,  

requires_grad=False로 하면 freeze된다.

 

아래 예시는 모두 feature extraction의 예시로 classfier는 trainable하게,  

그 외의 layer는 freeze하였다.

 

# torchvision version
import torchvision

import torch.nn as nn

num_classes=10

model=torchvision.models.resnet18(pretrained=True)

for p in model.parameters():
	p.requires_grad=False
    
model.fc=nn.Linear(model.fc.in_features,num_classes,bias=True)

torchvision의 경우 ImageNet으로 pretrain된 network를 기존 classifier까지 함께 download 된다.

그렇기 때문에 기존 layer들의 parameter를 모두 freeze하고

새로운 classifier를 생성한 것이다.

 

이후 새로 만든 layer에 필요하다면 weight initialization을 적용하여 학습을 시작하면 된다.

 

# timm version
import timm

num_classes=10

model=timm.create_model('resnet18',pretrained=True,num_classes=num_classes)

for n,p in model.named_parameters():
	if n in ['fc.weight','fc.bias']:
		p.requires_grad=True
	else:
		p.requires_grad=False

timm의 경우 torchvision과 마찬가지로 직접 model.fc에 새로운 nn.Linear()를 붙여줘도 되지만

timm의 경우 model에 따라 create_model()중에 weight initialization을 진행하기 때문에

model.named_parameters()를 사용하여 classfier의 weight, bias는 trainable하게 두고

다른 layer의 weight와 bias는 freeze하였다.

 

References

https://cs231n.github.io/transfer-learning/

댓글