Abstract
•
큰 transformer model은 성능이 좋지만 특히 긴 문장에서 학습하기에는 비용이 많이든다
•
2가지 방법으로 이 문제를 개선했다
◦
LSH를 이용하여 에서 로 complexity를 개선했다
◦
reversible residual layer 방법을 사용하여 N번 activation을 저장하는게 아니라 한번만 저장하도록 개선했다
Introduction
•
transformer 기반의 모델들은 성능을 높이기위해 layer당 parameter를 0.5B, layer의 수는 64개까지 늘려왔다
•
11K 정도의 long sequence 에서도 사용이되며 music, image같은 다른 영역에서도 long sequence는 필요하다
•
하지만 큰 규모의 transformer를 학습하기위해서는 대규모의 GPU 인프라가 필요하며 1개의 GPU로는 finetunning조차 힘들다
•
이런 transformer에서 연산량은 아래와 같다
◦
layer당 0.5B가 필요한 모델의 경우 layer당 2GB의 GPU memory 필요
◦
64K token, embedding size 1024, batch size 8일때는 64K*1K*8 = 0.5B floats가 저장되어야 하며 2GB GPU memory가 필요하다
◦
학습할때 layer당으로만 memory를 사용하면 이런 큰 모델도 1개의 GPU로 충분히 fine-tune 가능하다(기본 transformer의 경우 layer별 activation값들을 다 저장해서 back prop해야한다)
•
기존 Transformer의 문제
◦
back propagation을 하기 위해 N layer에서는 single layer보다 N배 memory가 필요하다
◦
feed-forward layer의 depth는 attention activation보다 커서 memory 사용에 큰 부분을 차지한다
◦
dot product attention은 의 연산, 메모리 사용량을 가져서 비효율적이다
•
아래 방법을 사용하여 위 문제를 해결
◦
reversible layer —> N개 block의 activation을 저장하는게 아니고 1개의 activation copy만으로 skip connection 방식을 학습할 수 있는 방법
◦
feed-forward layer의 activation을 chunk들로 쪼개서 메모리를 절약하는 방법
◦
LSH를 이용해서 에서 로 효율화
•
이러한 기법들은 base transformer에 비해 구현적으로도 크게 변화가 없이 할 수 있고, 성능 감소도 없이 진행하였다
•
LSH는 2가지 기법보다는 변화량이 컸는데 concurrent hash의 수의 따라 결과가 다르게 나왔다 후에 실험에서 자세히 설명
LOCALITY-SENSITIVE HASHING ATTENTION
Dot-product attention
Memory-efficient attention
•
64K length의 경우 batch size가 1이라 가정해도 64K * 64K의 matrix가 필요하며 32-bit floats이라 할때 16GB memory가 필요하다 이건 long sequence transformer를 학습하기에 적은 GPU로는 쉽지 않다
•
근본적으로 가 하는 동작을 보면 전체가 필요하지 않다 결국 Query와 Key의 내적은 query와 유사한 key값을 찾기 위함이며 softmax를 통과하고 나오면 대부분의 값들은 사용하지 않는다
Where do Q,K,V come from?
•
Q,K,V는 이전 input의 activation값 기반으로 나온다 예를 들어 dim 512 multi head 8의 경우 각각 64 dim을 갖음
•
transformer에서는 Q,K,V를 각각 다른 parameter로 사용하였는데 LSH attention에서는 Q,K는 같은 parameter로 QK를 공유하는 transformer 방식을 사용하였다 이렇게 사용해도 성능에는 큰 영향이 없었고 실험에서 자세히 설명 예정
Hashing attention
•
위 attention 식을 보면 QK의 경우 64K의 length를 가정하면 결국 softmax에서 나온 값은 대부분은 낮은 값을 갖을 것이고 비슷한 상위 32~64개 정도의 값만 갖고 계산을 해도 충분할 것이라는 가정 그렇다면 이런 QK의 유사한 값을 계산하는 걸 nearest neighbor search 방식으로 찾을수는 없을까?
Locality sensitive hashing
•
각각의 vector를 hashing할 수 있는 hashing function을 정의하면 같은 hash값이 나오는 vector는 근접한 vector라고 가정할 수 있다
•
LSH 방식으로 이러한 hash를 하고 자세한 알고리즘은 아래와 같다
•
정한 bucket 수의 절반의 random matrix를 생성하여 input과 곱했을때 높은 값(가까운 위치)에 있는 포인트를 같은 bucket이라고 가정하는 방식
•
positive와 negative의 값을 concat하여 bucket에 들어가는 오차를 줄이기위해 아래와같은 hashing function이 나옴
LSH attention
•
아래 는 i번째 query와 같은 bucket에 들어가는 key set이라고 생각하면된다
LSH attention의 식이다 LSH attention은 값이 같은 자기자신의 값은 제거했다 LSH attention에서는 현재 위치보다 이전의 key값만 확인하여 attention을 동작한다 그리고 scaling factor도 제거한걸 확인할 수 있다
batch operation을 하기 위한 LSH attention 식이다 에 없는 값들은 infinity값을 빼줘서 반영이 안되도록 하는걸 확인할 수 있다 자기자신은 이미 뺐으므로 제거된다 후에 나오지만 Shared QK를 사용했기때문에 자기자신은 항상 높은 유사성을 갖고 이건 저자가 의도한 방향이 아니므로 이런 방법론을 취했다
(b)를 보면 Q에 가까운 K bucketing이 완료된 걸 볼 수 있다
여기서 생기는 문제점은 2가지가 있는데 bucket별 들어가는 key set의 사이즈가 다르고 sequence의 순서가 섞여서 있으므로 동시에 연산하기가 힘들다 —> order하여 해결
Q와 유사한 key set이 1개도 없는 bucket이 생길 수 있다 —> Q=K로 하여 문제 해결 후에 실험결과에서 Shared QK가 성능에는 큰 영향이 없다는 내용이 나온다
bucket size의 경우 length의 2배로 정하면 보통 잘 처리된다고 한다
위와같이 chunking attention을 적용하므로 문장 길이에 memory 효율적인 attention 연산이 가능해집니다
Multi-round LSH attention
위와 같이 했을때 hash function 자체가 random하게 bucket을 정하는 값을 구하는 것이므로 정확도를 올리기위에 위 동작을 여러번 돌린다 아래에 학습결과를 보면 8번정도 multi-round를 했을때 성능이 좋다
Reversible Transformer
transformer는 residual connection 방식을 사용하고있어서 back propagation을 하기 위해 중간 activation 값들을 memory에 올려야 한다
이 방법을 개선한 2017년에 나온 RevNet 방식이 있는데 이 방법을 이용해 memory를 절약하도록 적용
위 그림은 resnet과 revnet의 내용이다 resnet의 경우 y2로 y1을 y1으로 x1을 역으로 계산할 수 없지만
revnet의 경우 y2, y1을 갖고 있으면 x2, x1을 산할 수 있도록 구성한 네트워크이다
그래서 중간 activation값을 저장하지 않아도 되지만 중간 activation값을 계산하기 위해 학습 시간이 1.5~2배 정도 오래걸린다
Reversible Transformer
revnet의 F(x)를 attention G(x)를 FeedForward 로 대응하여 revnet을 적용
Chunking
전체의 tensor를 다 봐야하는 attention과 다르게 FeedForward의 경우 각각의 연산이 독립적으로 발생한다
이점을 이용해서 FeedForward 연산을 부분으로 chunk한 후 메모리에는 chunk에 해당하는 부분만 올려서 학습 물론 이런 방식은 학습 속도의 저하를 일으킨다
Experiments
저자는 imagenet64 dataset으로 image generation task와 enwiki8 dataset으로 text task들을 실험하였다
Effect of sharing QK
Figure 3을 보면 Shared QK attention과 reversibility가 성능에 영향을 주지 않는 걸 확인할 수 있고 오히려 Shared QK가 좀 더 빨리 수렴되는 걸 확인할 수 있다
Effect of reversible layers
아래 표를 보면 reversible layer가 BLEU score 하락이 발생하지 않는 걸 확인할 수 있다
weight sharing의 여부에서도 보면 성능하락이 발생하지 않는 걸 볼 수 있다
WMT14 데이터에서는 LSH attention은 사용하지 않았는데 LSH attention의 128개 token 별로 LSH attention chunk가 할당되는데 WMT14 data는 데이터가 128 token보다 짧아서 사용하지 않았다
LSH attention in Transformer
아래 그래프를 보면 8 round LSH attention의 성능이 full attention과 비슷한 걸 볼 수 있다
Large Reformer models
그래프를 보면 LSH Attention은 깊은 layer와 긴 sequence에서 robust한 것 도 확인할 수 있다
결론
•
reversible layer는 학습 속도가 상당히 오래 걸린다