본문 바로가기
Deep Learning

[DL - 논문 리뷰] Few-Shot Image Generation with Elastic Weight Consolidation

by JJuOn 2021. 12. 26.

 이번 포스팅에서는 Yijun Li가 NeurIPS 2020에서 발표한 "Few-Shot Image Generation with Elastic Weight Solidation"을 읽고 정리해보도록 하겠습니다.


1. Introduction

 GAN의 성공은 특정 도메인에 해당하는 70,000장의 사람 얼굴 데이터셋인 FFHQ나 여러 클래스에 걸쳐 130만장에 해당하는 ImageNet 데이터셋과 함께 보여졌습니다. 그러나 예술과 같은 영역은 데이터의 개수가 한정적입니다.

 

 이러한 영역에도 적은 데이터(a few)만으로도 일반화 성능을 향상시키기 위해서는 선험 지식에 의존해야합니다. 본 논문에서는 아래 사진과 같은 continuous learning framework에서 few-shot image generation을 연구했습니다. 이는 target domain의 적은 데이터로도 해당 domain의 데이터들을 생성하는 것 입니다. 연구 과정에서는 source domain과 target domain 사이에는 latent factor가 공유되어 있어야 한다는 가정이 있었습니다. 예를 들자면 사람의 얼굴과 이모지(emoji) 사이에는 외향적인 차이는 있지만, 자세나 표정 등의 요소는 공유합니다.

 이를 달성하기 위해 본 논문에서는 간단하고 효율적인 adaptation technique를 도입합니다. pretrained model의 weight를 가져오면서, 다른 parameter를 추가하지 않습니다. 그렇게 되면 어떻게 한정된 target domain에 대해 학습을 시키면서 동시의 전이(transfer)된 지식과 다양성(diversity)를 유지할 지가 관건입니다. 이에 핵심은, 모든 parameter를 adaptation 과정에서 똑같이 취급하지 않는 것입니다. 본 논문에서는 각 매개 변수마다 중요성을 계산하여 tuning 과정에서 해당 parameter가 보존 되도록 하였습니다. Krikpatric et alElastic Weight Consolidation이라 하는 방법을 제시했는데, 이 방법은 objective likelihood와 연관있는 Fisher Information을 계산하여 각 parameter의 중요성을 평가했습니다. 본 논문은 Fisher Information이 proxy objective(a frozen discriminator)로 부터 측정될 수 있음을 보여주고, 10개 이하의 극단적으로 작은 예시를 가지고도 다른 target domain에 대해 높은 수준의 결과를 생성할 수 있음을 보여줍니다.

 

 게다가, 필연적으로 source domain의 정보를 유지하는 것과 target domain에 adapt하는 것 사이에는 trade-off가 존재합니다. 본 논문에서는 target example의 수나 source domain과 target domain 간의 차이점과 같은 중요한 요소에 관해서도 깊은 분석을 했습니다.

 

 본 논문이 기여한 바는 다음과 같습니다.

  • Pretrained generative model을 별도의 parameter를 추가하지 않고 새로운 target domain에 adapt 시킨 것
  • 이러한 방법이 artistic domain과 같은 실제론 데이터의 양의 한정적인 경우에 효과적인 것
  • 이전 방법들이 photo domain에 한정된 것에 반해, 여러 cross-domain source/target pair에 대해서도 실험한 것

2. Related Work

Few-shot learning

 Few-shot learning은 처음엔 few-shot image classification 분야에서 연구되었습니다. 몇몇 대표적인 방법은 다음과 같습니다.

  • Metric learning methods
  • Meta-learning methods
  • Dynamically weight prediction methods
  • Some semi-supervised learning methods

 그러나 few-shot image generation 문제에서는 적은 양의 이미지만 주어지고 target domain에 대한 다른 정보는 주어지지 않는다고 가정합니다. Few-shot image classification과의 차이점은 generation은 다양한 결과를 생성하는 것에 초점을 두는 반면, classification은 한정돼있는 label을 예측하는 것을 목표로 두고 있다는 점입니다.

 

 Few-shot image generation에 관한 몇몇 연구는

  • matching networks
  • squential generative models
  • autoregressive model

에 기반한 few-shot density estimation에 집중했습니다. 그러나 이러한 방법은 간단한 패턴이나 저해상도의 이미지를 생성하는 데 그쳤습니다. 그러나 최근 BSA, MineGAN에서는 고해상도의 이미지를 생성하는것을 보여주었습니다. 둘 다 pretrained GAN을 fine-tuning하는 것 부터 시작하여 학습을 위해 별도의 parameter를 원본 network에 추가하였습니다. 이러한 방법은 4장에서 서술되겠지만 극단적으로 데이터의 수가 적은 경우에는 비효율적입니다. BSA와 MineGAN은 source domain의 데이터를 계속 생성할 수는 있습니다. 그러나 본 논문에서는 source domain의 diversity는 보존하면서 target domain의 데이터를 생성하는 데 더 집중했습니다.

Style transfer

 Target domain의 이미지를 생성하는 다른 방안은 style transfer를 적용하는 것입니다. Style transfer methods에는 주로 두 가지종류가 있습니다.

 

 첫 번째는 example-based 방법입니다. Example-based 방법은 오직 하나의 style example에서만 작동하지만 정렬을 요구하거나 단지 색상이나 질감만 transfer(NST)하는 등 한계가 있습니다. 하나의 example은 domain의 style을 완전히 표현하지 못하고, domain의 style이 색상이나 질감에만 한정되어있는 것은 아닙니다.

 

 반면에, 두 번째 방법인 domain-based 방법은 많은 양의 source domain 데이터와 target domain 데이터를 요구합니다. 그렇기 때문에 few-shot task에는 바로 적용할 수는 없지만 최근 연구는 style label이 있는 여러 source domain을 구축하여 meta-learning을 실행하거나 content와 style representation들을 구분하도록 학습합니다. 본 논문에서는 한 가지의 source domain이 주어지기 때문에 위와는 다릅니다.

Continuous learning

 Source domain에서 pretrain된 GAN을 target domain에서 adapt하는 것은 두 tasks에서 순차적으로 학습 시키는 것과 같으므로 자연스레 continous learning에 해당합니다. Continous learning은 주로 castrophic forgetting 현상을 다루게 되는 데 이는 이전에 학습된 task를 잘 수행하면서 연속적인 task를 학습하는 것입니다. 몇몇 연구들은

  • distillation-based methods
  • memory-based methods
  • attention-based methods
  • regularization-based methods

를 통해 classification task에서 이뤄졌습니다. 이를 연장해 최근 연구들은 generative domain에서 진행됐습니다. 그러나 이런 과정들은 충분한 데이터를 전제로하기 때문에 본 논문의 연구과는 다릅니다. 앞서 언급한 것 처럼 source domain의 데이터는 더이상 생성할 수 없습니다. 그러나 본 논문에서 유지하고자 하는 것은 source domain의 diversity이고 그로 인해 적은 양의 target domain data로도 다양한 결과를 내고자 했습니다.


3. Proposed Method

 본 논문에서 다룰 접근 방식의 목표는 source domain에서 pretrain된 model의 weights를 단지 몇개의 target domain에 adapt하는 것 입니다. Regularization 없이 바로 adapting 하는 것은 over-fitting을 초래하게 됩니다. 그렇기 때문에

  1. 어떤 weights가 중요해서 보존해야 하는 지 또는 어떤 weights가 변경시키기에 자유로운 지
  2. 어떻게 이러한 중요도를 계산하여 loss function을 통해 regularize할 것 인지

를 파악해야 합니다.

3.1 Rate of changes on weights

 첫번째로, target domain에 많은 데이터가 있어 제대로 된 generative model을 학습시킬 수 있고 좋은 weights란 어떻게 생겼는 지 영감을 받을 수 있다고 가정했습니다. 데이터의 수가 많다면 처음부터 모델을 학습 시키거나 pretrained model을 사용하는 것 둘 다 좋은 성과를 낼 것 입니다. 그러나 few-shot 상황에서는 처음부터 모델을 학습시키는 것은 불가능 하므로 pretrained model을 fine tuning하는 과정에서 weight가 어떻게 변하는 지 확인했습니다. Source domain은 20만 장 가량의 CelebA dataset을 사용했고, Target domain으로는 Bitmoji API로 부터 8만장 가량의 emoji data를 사용했습니다. Five-Layer의 DCGAN을 처음에는 source domain(real faces)에 대해 학습시키고, 다음으로 target domain(emojis)에 대해 fine tuning을 진행했습니다. 두 domain에서 모두 아래의 adversarial loss를 사용했습니다.

Adversarial loss

여기서 p_data(x)와 p_z(z)는 각각 real data x와 noise variable z의 분포를 나타냅니다. 생성된 real faces와 emoji는 다음과 같습니다.

Generated real faces(from G) and emojis(from G')

그 다음 아래 수식을 통해서 i번째 layer에서 weights의 평균 변화율을 계산했습니다. 

Computing the average change rate of weights at i-th layer

그 결과 각 layer마다 평균 변화율은 다음과 같이 나타났습니다.

Rate of changes on weights

이를 통해 가장 마지막 layer(Conv5)가 평균적으로 다른 layer에 비해 가장 적게 변화한 것을 알 수 있었습니다. 다른 source-target domain pair를 다른 GAN architecture(LapGAN, StyleGAN)에 대해 진행했을때도 비슷한 결과가 나타났습니다. 이를 통해 알 수 있는 것은 마지막 layer에 있는 weight들은 더욱 중요하고, 다른 가중치들 보다 더 보존시켜야 한다는 것입니다.

3.2 Importance Measure

 이번 절에서 다룰 내용은 이전 절의 내용을 바탕으로 각 weights의 중요성을 측정하여 regularize하는 것입니다. 각 

weights의 중요성을 Fisher Information을 사용하여 측정하고자 했습니다. 특정 파라미터 θ_s에 대한 Fisher Information은 다음과 같이 구해집니다.

Fisher Information

여기서 L(X|θ_s)는 log-likelihood function입니다. Reconstruction 혹은 perceptual loss를 사용했을 때에도 비슷한 결과 나왔다고 합니다. 각 layer의 평균 Fisher Information은 다음과 같이 나왔습니다.

Fisher information

본 논문에서는 F를 바로 weight의 중요도로 사용하고, target domain에서 fine-tuning 할 때 loss function에 규제항으로 사용했습니다. 

Loss function at adaptation

두번째 항인 규제항은 처음에 Krikpatric et al에서 classification 분야에서 전체 데이터는 충분히 있는 상황에서 예전 클래스에 대한 performance도 유지하기 위해 사용되었고 EWC라 명명되었습니다. 본 논문에서는 few-shot generative 환경에서도 효과적임을 보여줍니다. EWC를 사용하지 않으면 adapataion과정에서 over-fitting을 유발하여 단지 주어진 target domain에서의 입력에 대해 재생산하게 됩니다.

The effectiveness of EWC loss

위 그래프는 EWC를 사용하지 않았을 때 기존 weight에서 매우 가파르게 벗어난 것을 보여주고, 그 오른쪽의 사진은 EWC를 적용하지 않았을 때 단순히 입력된 target domain의 데이터를 재생산하는 것을 보여줍니다.


4. Experimental Results

 본 논문에서는 qualitative comparison과 quantitative comparison의 두 가지로 나누어 실험합니다. 그리고 마지막에서는 target domain의 데이터의 수나 EWC loss의 λ에 따른 성능의 차이를 분석합니다.

 

 데이터셋은 face와 landscape 두가지 부류의 데이터셋을 사용했습니다. Face 분야에서는 실제 얼굴들을 다룬 FFHQ 데이터셋을 source domain으로 사용했으며 targe domain으로 emoji 데이터셋인 Bitmoji API, 동물 얼굴들을 다룬 AFHQ 데이터셋, 그리고 초상화를 다룬 Artistic-Faces 데이터셋을 사용했습니다. AFHQ 데이터셋에서는 10장의 강아지와 고양이 얼굴 이미지를 사용했고, Artistic-Faces 데이터셋은 16명의 화가가 있고 각 화가마다 10장의 초상화 이미지가 있습니다. Landscape에서는 CLP 데이터셋을 source domain으로 사용했고, 10장의 연필로 그린 풍경화 사진을 target domain으로 사용했습니다.

 

 평가는 본 논문의 방법(EWC)과 NST, BSA, MineGAN사이에서 진행됐습니다. Neural Style(NST)는 style transfer 방법의 일종입니다. 앞서 2장에서 언급했던 것 처럼, NST는 example-based 방법이기 때문에 target domain에서 랜덤으로 한장의 이미지를 선택해서 사용했습니다. BSA와 MineGAN은 추가 parameter를 도입하여 target domain에 대한 adaptation을 진행했습니다. BSA는 BigGAN에 batch norm layer를 추가하였고 adaptation 과정에서만 새로운 parameter를 학습시켰습니다. MineGAN은 Progressive GAN generator에 작은 mining network를 추가하였고 mining network 만 처음에 fine-tuning 한 후 generator와 함께 다시 fine-tuning하는 두 단계의 방법을 채택했습니다. 본 논문의 방법은 StyleGAN을 기반으로 하였습니다.

4.1 Qualitative results

Visual comparisons for different methods for few-shot generation

 위 사진은 다른 방법들간의 차이를 시각화 한 것 입니다. NST는 target example에 대한 전반적인 색깔과 질감을 전이했습니다. 그러나 상대적으로 지저분하고 고차원 특징은 포착하지 못했습니다. BSA는 처음 봤을 때 뿌옇게 보이고 비슷한 결과를 내는 걸로 보아 model collapse 현상이 보여집니다. Model collapse 현상은 generator가 다른 input에 대해 같은 output을 내는 것 입니다. MineGAN은 입력을 재생산하는 걸로 보아 diversity가 떨어집니다. 그에 반에 본 논문의 방법은 target domain의 style에 충실한 채 diversity까지 확보했습니다.

4.2 Quantitative comparisons

 Quantitative study는 target domain이 충분한 데이터를 갖고 있는 지 여부에 따라 다르게 진행됐습니다. Targe domain이 충분한 데이터를 갖고 있다면 Frechet Inception Distance(FID)를 사용하여 generated image의 quality를 측정했습니다. Target domain의 데이터의 수가 많지 않을 경우 Amazon Mechanical Turk에서 user study를 진행했습니다. User study는 user에게 artistic domain의 몇몇 예시를 보여주고 어떤 그림이 생성 된 그림인지 맞추도록 하였습니다. 총 10라운드로 구성되어 있고 각각의 방법마다 300표씩 수집했습니다. Diversity를 측정하기 위해 LPIPS metric을 사용하여 결과들 간의 유사성을 측정했습니다.

Quantitative comparisons between different few-shot generation methods

 위 표는 10-shot generation의 결과입니다. Pretrained model을 첫번째로 평가하여 reference로 사용했습니다. 첫번째 행에서 본 논문의 방법이 quality를 의미하는 FID가 가장 낮은 것을 나타내고 있습니다. 이는 위 Figure 4에도 나타나있습니다. 그리고 diversity를 의미하는 LPIPS가 NST가 가장 높습니다. 그러나 이는 NST의 결과가 다소 어지러운 경향을 나타내기 때문입니다. 그렇게 때문에 BSA와 MineGAN을 본 논문의 방법과 비교하는 것이 적절하고 그 결과 본 논문의 방법이 다양한 결과를 생성한다고 합니다. 마지막으로 User는 user study에서 진짜와 가짜를 구분 못하는 것이 좋은 성능을 낸다고 해석될 수 있습니다. 본 논문의 방법이 fooling rate 47.92%로 진짜와 가짜를 구분하기 어려운 결과을 생성했음을 나타냈습니다.

4.3 Discussion

Quantitative comparisons between different few-shot generation methods with respect to the number of shots

 Few-shot generation task에서 입력 데이터의 수는 성능에 영향을 미치는 중요한 역할을 합니다. 위 표는 입력 데이터의 수에 따른 FID를 비교한 것입니다. NST는 성능이 크게 향상되지 못했습니다. BSA는 입력 데이터의 수가 증가함에 따라 오히려 성능이 떨어졌음을 보입니다. 이는 BSA의 학습 이미지를 재구성하는 방식 때문입니다. 많은 입력 데이터들이 오히려 낮은 수준의 재구성을 초래합니다. MineGAN의 경우 본 논문의 방법처럼 비슷한 성능 향상을 보이지만 10개 이하의 극도록 적은 학습 데이터에서는 본 논문의 방법이 훨씬 더 우수합니다. 본 논문의 방법은 또한 1-shot인 경우에서도 좋은 성능을 냅니다. 아래 사진은 1-shot generation의 예시입니다.

1-shot generation example

한장의 학습 데이터만 주어졌음에도 갈색 피부, 붉은 머리는 유지한 채 안경이나 미소, 포즈, 성별등이 다르게 생성되었습니다.

위 표는 10-shot 환경에서 EWC loss의 λ값에 따른 성능을 비교한 것입니다. λ값이 너무 커지면 target domain을 제대로 학습하지 못하게 되지만 source domain의 diversity는 유지하게 됩니다. 반대로 λ값이 너무 작아지면 target domain을 학습하는 데 있어 over-fitting이 발생하고 diversity는 떨어지게 됩니다. 모든 실험에서는 λ=5x10^8로 설정했습니다. 본 논문에서는 source domain과 target domain이 더욱 비슷하다면 λ을 더 크게하여 diversity를 더욱 확보하고, 만약 target domain의 데이터가 좀 더 있다면 더 작은 λ를 선택해 target domain을 더욱 잘 학습하도록 하는 것이 낫다고 합니다.

 

 본 논문에서는 또한 source domain과 target domain의 유사도가 성능에 영향을 미치는 지도 실험해보았습니다. 상대적으로 source domain인 FFHQ(real faces) 데이터셋은 target domain 중 CelebA-Female face, emoji face, cat face, color pencil landscape 순서대로 비슷합니다. Dissimilarity는 diversity를 측정할 때 사용되었던 LPIPS로 측정되었고, 각각의 결과는 FID로 측정되었습니다. 결과는 아래 표와 그림과 같습니다.

특히 real faces->color pencil landscape 과정에서 실루엣이 보이는 등 domain이 극도로 차이나는 경우에는 성능이 하락하는 것을 보이는데, EWC regularization을 적용해도 큰 효과는 없다고 합니다. 이로 인해 source domain과 target domain간의 유사도 또한 중요했음을 알 수 있습니다.

 

 

 마지막으로 본 논문에서는 source domain의 diversity가 보존됐는지 실험을 진행했습니다.

위 사진의 상단은 source domain에서 생성된 이미지의 모습이고, 하단은 Moise Kisling face 도메인으로 adapt한 후 생성된 이미지의 모습입니다. 외향은 달라졌지만 태도나 안경, 헤어스타일등의 속성은 보존된 것을 알 수 있습니다.


Refereneces

[1] https://en.wikipedia.org/wiki/Fisher_information

[2] https://wandb.ai/wandb_fc/korean/reports/-Frechet-Inception-distance-FID-GANs---Vmlldzo0MzQ3Mzc

[3] https://raon1123.blogspot.com/2019/10/gan-model-collapse.html

[4] https://deepai.org/machine-learning-glossary-and-terms/perceptual-loss-function

 

댓글