////
Search
Duplicate

Copy of Ramalho and Garnelo [2019] Adaptive Posterior Learning: few-shot learning with a surprise-based memory module

원래 Episodic Memory in Lifelong Language Learning 이라는 딥마인드에서 NIPS2019에 제출한 논문을 보려다 참고논문으로 된 본 논문을 보고는 이 개념을 먼저 정리하고 넘어가야 되겠다 싶어 보게 되었다.

문제의식

본 논문은 아래 그림과 같은 문제를 푸는 모델을 구현할 수 있는 방법에 대한 고민으로부터 시작한다.
숫자 이미지(MNIST같은)와 세모/별표 같은 이미지를 주고 두개를 더한 값을 맞추라는 문제다. 이런 Observation 데이터를 수만 개 주고 나서 맞추라는 것이 아니다. Few-shot learning이 가능한가를 묻는 것이다.
이런 문제를 풀기 위해 다음 3가지 인지적 능력이 필요하다고 정리한다.
1.
Meta-Learning(Learning to Learn)
이것은 평범한 지도학습으로는 달성할 수 없다. train-set과 test-set의 분포가 같고, train-set이 전체 prior를 충분히 대표할 수 있으며, 모델이 train-set을 충분히 generalize한다는 가정이 있어야 성립하는 지도학습의 논리와 달리, 이런 문제에서는 train단계에서 거의(혹은 아예) 보지 못한 데이터에 대해서도 일반화를 통한 추론이 가능해야 한다.
2.
과거 경험의 기억
메모리를 어떻게 구현할 것인가? LSTM 으로 충분할까?
얼마나 많은 기억들을 저장해야 충분할까? 일정 이상 비슷한 기억들을 중복 저장하는 일이 없도록 하려면 어떻게 해야 할까?
3. 연역적 논리추론
저장된 기억(memory)들끼리, 그리고 기억들과 현재 관측(input) 사이의 유사도를 어떻게 측정할 수 있을까? 과연 어떤 기억과의 연관성이 현재의 새로운 input을 해석하는데 더 도움이 될 수 있을지 어떻게 알 수 있을까?
본 논문의 솔루션은 다음과 같다.
1.
어떤 기억을 저장할지 결정하기 위해 surprise based signal를 바탕으로 결정하는 간단한 memory controller
2.
확장 메모리와 현재 input 사이의 attention 메커니즘
3.
메모리 관리에 backprop이 사용되지 않는(그래서 computationally efficient한 - 이점이 기존 memory network들과 차별화되는) 방식으로 adaptive posterior learning

APL(Adaptive Posterior Learning) 모델 구조

Training 과정
새로운(surprise가 큰) 인풋이 들어올 때마다 M에 쌓여가는 과정 위 그림에서 점선 화살표는 backprop이 적용되지 않는다.
위 그림에서 Encoder(e) : 입력 데이터의 representation 생성
입력 xt를 받아서 low dimension representation et 를 생성 (Pretrained Resnet 등을 사용)
Memory Store(M)
Encoder가 생성한 embedding em와 true label ym pair를 저장
input 과 가장 가까운 memory 를 K-nearest neighbor로 구한다.
Memory Controller(c)
어떤 메모리를 저장할지 선택하는 로직
surprise는 다음과 같이 정의된다. 그래서 classifier의 cross entropy loss를 그대로 사용해도 된다.
surprise 의 threashhold (N은 클래스의 개수)
Decoder(d) : Classifier에 입력될 최종 logit을 생성
3가지 decoder를 생각해 볼 수 있다.
1) Relational self-attention feed-forward decoder
2) Relational working memory decoder(Santoro et al., 2018)
3) LSTM decoder

Experiments

1.
FEW-SHOT OMNIGLOT CLASSIFICATION
a) Saturation : 일정 iteration 이상 훈련되면 accuracy 와 memory size 가 saturate한다.
b) evolution of posterior distribution : decoder 의 logit 이 초반에는 uniform 하다가 훈련이 진행되면 특정 class에 confidence 가 올라가서 수렴한다.
c) 클래스별로 필요한 example(memory)의 개수가 그리 많지 않다. (거의 1:1에 가깝다)
d) 2000개 이상 샘플이 저장되면 5-shot 고정 케이스보다 accuracy 가 좋다.
2. IMAGENET
위 OMNIGLOT과 달리 full scale이다. (테스트시 못본 클래스는 없다)
N=1000개 클래스까지 가면 accuracy 가 좀 떨어진다. (역시 쉽지않은 문제)
3. NUMBER ANALOGY
example은 8 정도면 충분
SAFF가 decoder 로서는 더 좋다. WM은 LSTM과 유사