Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
Summary
Motivation
- Transformer는 성능은 매우 좋으나 내부 구조의 복잡함으로 인해서 계산비용이 기하급수적으로 증가하는 문제가 있다. 이 부분에 대한 개선은 Dot Product를 이용해서 지속해서 개선해왔으나, 여전히 개선이 필요한 분야이다.
- 연구분야는 1)크게 메모리사용량을 줄이기 위한 방안을 찾는 부분과 2)Sequence Length를 늘려서 맥락을 최대한 이해할 수 있게 하려는 부분인데 계산량 자체를 줄인다면 이 두 부분에 기여할 수 있을 것이라고 보았다.
Approach
- 계산량이 많이 발생하는 Self Attention 부분을 재해석해서 커널 Function으로 간주하고 내부 수식을 내적으로 바꿈으로써 계산량을 Input Sequence를 $N$이라고 할 때, $O(N^2)$에서 $O(N)$로 단축시켰다.
- 이 때구체적으로 접근한 부분은 Attention Structure로 아래와 같이 수식으로 계된단다고 할 때, $$A_l(x)=V'=softmax( {{QK^T} \over {\sqrt{D}}}V )$$
- 임의의 시퀀스 $i$만 놓고 본다고 하면 이렇게 정리할 수 있다. 이 때 $sim(Q,K)$는 $exp({{Q,K^T} \over {\sqrt{D}}})$으로 하면 우리가 알고 있는 Transformer 모델이 된다. 이 때 중요한 것은 위에서 보고 있는 $softmax$가 필수가 아닐 수도 있다는 사실이다. $$V' = {{\sum\limits_{j=1}^Nsim(Q_i,K_j)V_j} \over {\sum\limits_{j=1}^Nsim(Q_i,K_j)}}$$
- 여기가 핵심인게 그러면 $sim(x,y)$이 하나의 유사도 함수라고 보면 유사도함수는 벡터간의 계산을 스칼라로 뱉어주는게 핵심이 이릍 특정 차원에 사상하고 내적하는 계산이라고 보면 그냥 커널함수라고 봐도 되서 아래와 같이 정리할 수 있게 된다. $$sim(x,y) = \phi(x)^T\phi(y)$$
- 이 걸 이제 위에 $V'$에 대한 식에 대입해서 정리해 보면 $K$,$V$에 대한 부분을 한번에 계산하고 반복해서 쓰는 형태로 다음과 같이 정리할 수 있고 자연스럽게 계산량을 줄일 수 있게 된다. $$V_i' = {{\phi(Q_i)^T\sum\limits_{j=1}^N\phi(K_j)V_j^T} \over {\phi(Q_i)^T\sum\limits_{j=1}^N\phi(K_j)}}$$
- Encoder, Decoder에 쓰이는 Multi Head Attention은 구조상 약간의 차이가 있기 때문에 Encoder, Decoder에 모두 Linear 구조로 바꾸어서 개선할 수 있었다.
Experiment
- Synthetic Task, Image Generation, Speech Recognition Task를 기준으로 실험을 진행하였다.
- 비교를 위한 Baseline Model은 Softmax Attention을 사용하는 Transformer, Reformer(Kitaev et al., 2020)를 선정 → 최근 모델 동향을 제대로 보지 않아 Reformer는 한 번 훑어봐야 할 듯
- 데이터셋은 MNIST, CIFAR-10, 80 hour WSJ dataset을 사용하였음
- fast_transformer 라는 별도의 Pytorch based Framework를 구축해서 사용
Transformers are RNNs
- Linear Transformer의 Decoder 기준 매번 $i$시점의 입력갑을 이전 $i-i$ 시점의 정보를 함께 넣어준다는 점에 RNN의 Hidden State와 비슷하기 때문에 사실상 Linear Transformer는 RNN과 비슷하다는 주장을 짧게 섹션으로 소개하는데 다소 뜬금없기도 했지만, 그것보다 이 구조 때문에 점화식의 형태로 이전시점의 데이터만 참고하면 될 수 있고 성능개선의 중요한 포인트라는 것은 확실하였다. 마치 Markov Property를 보는 느낌적 느낌.
Thoughts
- 더 훑어봐야겠지만 우선 테스트에서 NLP가 안보여서 안한 걸까, 못한 걸까는 좀 궁금하다.
- fast transformer framework의 확장성등이 언급되지 않아서 활용가능성에 대해서 의문이 있지만 확인해보 하지만 UberEats에서 성능 개선을 위해서 사용한 기록(링크)가 있어 우선은 가능한 것으로 보인다.