Search
🌨️

MobileStyleGAN: A Lightweight Convolutional Neural Network for High-Fidelity Image Synthesis

Created
6/28/2021, 9:54:15 AM
Tags
MobileStyleGAN
CNN
일시
2021/06/26
발표자
김연태
✅main
포스팅 종류
논문리뷰
Github : https://github.com/bes-dev/MobileStyleGAN.pytorch

Abstract

최근 이미지 생성에서 GAN이 많이 사용되고 있고 특히 StyleGAN은 매우 좋은 결과물을 생성하나 매우 큰 연산량을 가진다는 문제가 있다
그래서 이 모델의 최적화에 중점을 두어 계산량이 많은 부분을 분석하고 생성기를 엣지 디바이스에서도 사용가능하도록 하였다
제안된 MobileGAN은 StyleGAN 보다 3.5배 적은 파라미터와 9.5배 적은 연산량으로 비슷한 품질을 제공한다

1. Introduction

현재 GAN으로 생성하는 이미지의 퀄리티가 매우 높아지고 있는데 초기 DCGAN 의 경우 64x64의 작은 이미지를 생성하였으나 BigGAN 혹은 StyleGAN의 경우는 512, 심지어 1024의 이미지도 생성이 가능하다
물론 품질이 좋아진 만큼 연산량도 많이 필요하여 SOTA급 모델들을 엣지 디바이스에서 사용하는것은 어렵다
StyleGAN v2의 경우 28.27M 의 파라미터와 143.15GMAC 의 연산복잡도를 가짐
그래서 MobileStyleGAN을 제안
StyleGAN v2 를 기반으로 모델에서 연산량이 많은 부분을 분석하여 품질의 손실 없이 모델을 경량화
8.01M의 파라미터와 15.09 GMAC의 매우 작은 연산 복잡도를 가지지만 FFHQ 데이터에서 FID 12.38이라는 좋은 값을 가짐
Main Contribution 은 총 4가지
고품질 이미지 생성을 위한 wavelet-based CNN 기반의 end-to-end 모델 생성
Depthwise Separable Modulated Convolution 을 경량화한 Modulated Convolution 제안
demodulation 부분을 개선 (그래프 최적화에 적용가능)
모델 학습을 위한 KD 기반의 파이프라인

2. Related Work

StyleGAN

StyleGAN
큰 이미지 생성을 위해 모델을 점진적으로 증가시키며 훈련하는 Progressive Growing 방식
고정된 값으로부터 이미지 생성을 시작하고 확률적으로 생성된 latent를 기반으로 이미지를 생성
8개의 FCL 을 거쳐 latent를 비선형으로 변환후 AdaIN을 사용해 각 해상도에 스타일벡터로 적용
StyleGAN2
실제 통계가 아닌 추정된 통계로 정규화를 하여 droplet modes를 제거
Progressive Growing방식 대신 Skip Connection을 사용하여 눈과 이빨 등의 위치가 강제되는 현상 제거
latent 를 smoothing 하여 PPL을 줄여 이미지 퀄리티를 높임
StyleGAN2-ADA
Adaptive Discriminator augmentaion 을 사용하여 적은 데이터로 GAN을 학습하는 방법 제안

Model acceleration

CNN 모델의 경량화에 초점을 두고 연구도 많이 진행되고 있음
MobileNets
모바일에서 사용 가능한 경량화된 모델
Nas와 KD 를 기반으로 Contional GAN 을 자동으로 최적화 하는 프레임워크 제안
BigGAN을 경량화 할수 있는 distillation pipeline 제안
Attention 기반의 경량 GAN ( https://arxiv.org/abs/2101.04775 )

Knowledge distillation

큰 모델을 사용하여 작은 모델을 학습시켜 작은 모델
최근 많은 모델들이 KD 방식을 pipeline의 일부로 사용 중

Wavelet transform

wavelet transform 참고 블로그 ( https://bskyvision.com/404 )
Wabelet 기반 방식은 새로운 방식은 아니고 Texture classification, Image restoration, Super resolution 등 많은 CV Task 에서 사용되오던 방식이다
Not-So-Big-GAN 이라는 낮은 해상도의 생성모델과 업샘플링을 위한 wabelet 기반의 서브 네트워크를 사용한모델도 있는데 Pixel기반보다 좋은 성능을 보여줌
그래서 우리도 Wavelet 기반의 CNN 을 사용하여 네트워크를 경량화시키고 더 smooth 한 latent space를 생성함

3. MobileStyleGAN Architecture

기존 StyleGAN2를 기반으로 했기에 기존 Mapping network와 Synthesis network를 사용

Image representational revisited (link)

StyleGAN과 같은 기존 대부분의 GAN들은 이미지의 픽셀값을 직접 예측하는것을 목표로하지만 우리의 경우 이미지의 주파수 기반의 representation을 학습
출력 영상의 discrete wavelet transform (DWT) 을 예측
2d 이미지에 DWT를 적용 시 낮은 공간해상도와 다른 주파수 영역을 갖는 4개의 채널로 변환
IDWT를 사용하여 원래 이미지로 복원 가능
wavelet 방식은 몇가지 장점을 가짐
Wavelet 방식은 픽셀방식보다 더 많은 이미지의 구조적 정보를 가지고 있기에 저해상도 feature에서 손실없이 고해상도 이미지를 생성가능
Harr wavelet 방식을 사용하였는데 이방식에서의 IDWT는 곱셈없이 효율적으로 구현이 가능
Haar Wavlet 설명 블로그 : (https://zockr.tistory.com/1065)
이미지의 디테일을 생성하는것은 매우 복잡하다
StyleGAN의 latent space는 낮은 주파수에선 smooth 하나 높은 주파수에서는 rough 하다
하지만 wavelet 기반에선 고주파 성분에 직접 정규화를 추가할수 있기에 저주파와 고주파 전체에서 smooth한 latent space를 가질 수 있음

Progressive growing revisited (link)

StyleGAN2 에서는 여러 해상도에서 명시적인 RGB값을 유지하기위해 Skip Connection을 사용하나 wavelet기반에서는 이것이 이미지 품질에 큰 기여를 하지 않는것을 발견
따라서 연산량 감소를 위해 마지막 블록에서만 단일 헤드로 Skip Connection을 대체
하지만 중간 해상도에서의 예측도 이미지 안정을 위해 중요하므로 Auxiliary loss를 추가함

Depthwise Separable Modulated Convolution (link)

Depthwise Separable convolution은 MobileNet에서 사용된 방식인데 일반적인 Convolution layer는 채널과 필터의 곱의 연산량을 갖는데 이것을 채널별연산 (Depthwise convolution) 과 채널간 연산 (Pointwise Convolution)으로 분리하여 연산량을 줄이는 방식이다
이것과 StyleGAN에서 제안한 Modulate Convolution 을 합성한 Depthwise Separable Modulated Convolution( 아 이름 너무길...)은 Modulate Convolution과 마찬가지로 Modulation, convolution, Normalization의 세파트로 구성된다
하지만 기존 방식과는 다르게 가중치에 modulate를 적용하는 것이 아닌 실제 입출력 값에 적용을 하여 Depthwise Convolution에 대한 적용을 쉽게 한다
Depthwise Separable Modulated Convolution 적용 과정
modulation은 Style ss를 기반으로 각 feature map 을 스케일링함
xx는 입력값, xx'는 스타일로 스케일된 입력값이고 이것에 Depthwise, Pointwise Convolution이 순차적으로 적용
출력 feature map 통계에서 ss를 제거하기 위해서 demodulation을 적용하는데 Convolution 연산이 선형성이므로 위 두가지 연산을 순차적으로 적용해도 일반 Convolution을 적용한 것과 같은 결과를 가짐
여기에 demodulation 계수를 계산하면
여기서 i,j,ki,j,k 는 각각 입력, 출력, feature에 해당하고 연산된 계수를 Convolution 결과에 적용

Demodulation Fusion (Link)

Batch normalization과 Convolution의 Fusion은 inference 단계에서 연산량을 감소시키는 매우 좋은 방법이다
이것은 두가지 연산이 모두 선형성을 가지기에 하나로 병합하기에 가능한 것인데 Demodulation도 Batch normaliztion과 학습시엔 유사하게 동작하지만 inference 단계에서는 선형적으로 동작하진 못한다
스타일이 적용되기 때문인데 이 연산이 선형성을 갖기 위해 스타일 대신 학습 가능한 파라미터로 대체하여 PointWise Convolution과 Fusion을 가능하게 하였고 이방식은 결과의 품질을 손실시키지 않음을 발견함

Upscale revisited (link)

기존엔 Upscale을 위해 Transpose Convolution을 사용하였으나 이것을 IDWT로 대체함
파라미터가 없기에 추가적으로 DWModulated Convolution블럭을 추가함
최종 블럭의 구조는 아래와 같음

Training Framework

KD 기반의 훈련 프레임워크를 가짐
SytleGAN2를 Teacher 모델로 사용

Data preparation (Link)

StyleGAN2 생성기가 있다면 paired 학습으로 변경이 가능해짐
style,noise,Iteacherstyle, noise, I_{teacher} 3가지의 paired 데이터
stylestyle : StyleGAN2 Mapping network의 출력
noisenoise : StyleGAN2에 입력되는 noise
IteacherI_{teacher} : StyleGAN2로 생성된 해상도별 이미지
중간중간 auxiliary prediction이 사용되는데 이부분을 위해 IteacherI_{teacher}로부터 생성된 저해상도 이미지 피라미드인 IteacherpyramidI^{pyramid}_{teacher}를 사용
과적합 방지를 위해 사전 생성된 데이터가 아닌 실제 학습중 데이터를 생성하고 메모리 소모를 줄이기 위해 실제 데이터는 사용하지않고 생성된 데이터만 사용

Training Objective (Link)

Pixel-Level Distillation Loss
Wavelet 영역에서 Pixel 기반 이미지를 예측하는것이 목표
이미지를 wavelet으로 변환 후 두값 사이의 L1 거리를 계산하고 정규화를 위해 추가로 Pixel 영역에서의 비교도 추가
Perceptual Loss
위에 pixel 기반 loss 만으로는 이미지의 perceptual의 차이는 알 수가 없기에 이것을 해결하는 loss 추가
VGG-16을 기반으로 4개 feature map에서 나온 값에 대한 L2 distance를 계산 하여 합산
GAN loss
위 두가지 로스만 사용해서는 blurred 한 이미지가 생성되기에 추가로 GAN에서 사용되는 기본 loss를 채용
각각 Generator 와 Discriminator에 사용되는 loss로 DTD_T는 differentiable augmentation이 적용된 Discriminator network를 의미
Full objective

5. Experiments

세상에 Experimnets가 이렇게 짧은 논문 처음봄
실험결과가 별로인 건지 실험을 별로 안하고 급하게 낸건지...

Training (Link)

FFHQ 데이터셋으로 사전학습된 StyleGAN2를 사용하여 학습
Optimizer : Adam (0.9, 0.999) , 5e-4
매 step마다 Generator와 Discriminator 전부 update
Augmentation : Affine, Cutout
λ1=λ2=1.0,λ3=0.1\lambda_1 = \lambda_2 = 1.0, \lambda_3 = 0.1
4 x NVIDIA 2080Ti, batch_size = 8 , 3일 소요

Result

위 내용이 전부인데 품질의 저하가 없었다는것 치곤 FID가 많이 높아진 느낌
아래는 i5-8279U 를 장착한 노트북으로 테스트 했는데 MobileStyleGAN에서 사용된 구조는 그래프 최적화가 가능한 구조여서 그래프 최적화를 자동으로 하게 해주는 OpenVINO를 사용시 매우 빠른 속도를 가질 수 있다고 한다

Conclusion

엣지 디바이스에서 접근하기 좋은 모델과 그 모델을 학습하는 pipeline 제안
Quantization 이나 prouning등으로 더 향상 시킬수 있을것이다

추가 코멘트

조금 하나씩 적용하면서 변화하는 비교를 만들어 줬으면 싶었는데 없어서 좀 아쉽
안경 사라짐