Transformer 아키텍쳐(Vaswani et al., 2017)가 나온 이래 BERT 등으로 발전하면서 NLP 분야에 획기적인 발전이 계속되었지만, 너무 큰 GPU 메모리를 요구하는 구조로 인해 스타트업 직원으로서는 감히 따라가 볼 엄두가 안나는 분야가 되어버렸다.
2020년 새해가 밝아오면서, 16GB짜리 GPU 1장으로도 NLP 연구를 가능하게 한다는 엄청나게 Memory-efficient한 Transformer 구조가 제안되었다는 소식이 전해졌다. 희망찬 소식에 큰 기대를 가지고 바로 그 논문을 들여다본다.
이 논문의 핵심 포인트 3가지
이 논문을 그림 한장으로 요약하면 다음그림과 같다(고 한다.) 지금은 이해하기 어렵다.
LSH Transformer
•
기존 transformer의 dot-product attention을 locality-sensitive hashing attention (이하 LSH attention)으로 바꾸어서 복잡도를 L^2에서 L*logL로 감소시켰다. (L은 sequence 길이)
Reversible Transformer
•
기존 transformer의 standard residual layer를 reversible residual layer 로 변경하면서 N개 layer마다 별도로 사용하던 메모리를 1번만 사용하도록 개선하였다.
Chunked Reversible Transformer
Transformer 내부의 feed forward network의 activation 구간을 chunk 단위로 쪼개서 메모리사용량을 감소
아래는 Illustrated Transformer(http://jalammar.github.io/illustrated-transformer/) 에서 가져온 Transformer 의 기본 구조 그림이다. LSH는 Self-Attention 부분을, Reversible은 Feed Forward 부분을 개선했다고 볼 수 있겠다.
각각에 대해 좀더 디테일하게 살펴보자.
LSH Transformer
문제의식
•
위 dot-product attention 구조에서 QKT 부분이 batch_size * seq_len * seq_len 의 메모리를 요구한다.
◦
만약 개별 qi 별로 K와의 product을 한꺼번에 다 구해야 한다는 전제를 깨뜨린다면?
•
왜 굳이 Q와 K를 다르게 두어야 할까? Q=K라고 둘 수 없나?
◦
shared-QK Transformer
◦
놀랍게도 shared-QK Transformer는 separated QK Transformer 대비 성능저하가 없음을 실험으로 확인했다.
•
다시 돌아가, 그렇다면 qi와 가까운 kj에 대해서만 product을 취한다면 어떨까?
•
그런데 어떤 qi의 nearest neighbor를 어떻게 찾지?
Locality sensitive Hashing
만약 x의 해시함수 h(x)가, 가까운 x끼리는 h(x)가 같고, 먼 x끼리는 h(x)가 다르게 리턴한다면 이 h(x)를 locality-sensitive하다고 한다.
이런 해시함수를 구하는 원리를 다음 그림으로 설명할 수 있다.
TODO : 설명을 간단히 추가할 것. (Andoni et al., 2015 참조할 것)
그렇다면 qi와 가까운 k의 인덱스 j를 모은 hash bucket을 다음과 같이 정의할 수 있다.
그래서 LSH attention을 다음과 같이 고쳐쓸 수 있다. 계산량과 메모리 요구량이 엄청나게 줄어든다.
만약 K가 64k 길이라면, 하나의 q에 대해 이 방식으로 32~64개의 가장 가까운 k에 대해서만 연산한다고 생각해 보라.
배치처리를 고려하여, 위 식은 다음과 같이 좀더 제너럴하게 표현될 수 있다.
특히 decoder 구현 부분의 future masking 처리에도 이 m이 유용하게 이용된다.
이제 LSH attention의 원리를 설명하는 이 그림을 다시 살펴보자.
(a)는 memory effcient하지 않은 sparse한 attention의 구조를 보여준다.
(b)는 hash bucket을 구성하여, hash bucket index 순서대로 소팅하여 재구성(i→Si)한 attention map이다.
(c) 아래와 같은 트릭이 중요하다.
이제 attention matrix는 diagonal 해진다.
(d) 적정길이 m으로 hash bucket을 뚝뚝 자른다.
이제 이 그림이 이해가 될 것이다.
Reversible Transformer
이 구조는 RevNet (The Reversible Residual Network: Backpropagation Without Storing Activations, Gomez et. al., 2017) 의 아이디어를 그대로 차용하였다.
backpropagation을 위해서 intermediate activation layer를 메모리에 저장해야 한다. 이렇게 하지 않고 backpropagation을 하는 방법을 제안한 논문이다.
Reversible Transformer는 위 F와 G에 Attention와 FFN 을 적용한 솔루션이다.
Chunking
Reversible Transformer 구조를 위와 같이 쪼개주자. dff=4k 에서 chuck size c에 비례하게끔 메모리 요구량이 줄어든다.