AI VIDEO BRIEFING

플래시 어텐션(FlashAttention) 원리: 타일링과 온라인 소프트맥스로 어텐션 가속

트랜스포머의 셀프 어텐션이 왜 느리고 메모리를 많이 쓰는지, 그리고 타일링과 온라인 소프트맥스를 결합한 플래시 어텐션이 정확도를 그대로 지키면서 GPU 메모리 접근을 줄여 어떻게 속도를 끌어올리는지 쉽게 정리했다.

플래시 어텐션(FlashAttention)은 어떻게 트랜스포머를 빠르게 만드나 영상 대표 이미지

핵심 메시지

  • 셀프 어텐션은 토큰이 수만 개로 늘어나면 N×N 어텐션 행렬을 만들어 메모리와 연산이 급증하는 병목이 된다.
  • 플래시 어텐션은 근사(approximation)와 달리 정확한 결과를 그대로 내면서도 빠르고 메모리 효율이 높은 알고리즘이다.
  • 핵심은 느린 HBM 대신 훨씬 빠른 온칩 SRAM을 활용하는 IO 인지(IO-aware) 설계와 타일링(tiling)이다.
  • 온라인 소프트맥스로 세 번 훑던 계산을 한 번의 루프로 합쳐, 거대한 어텐션 행렬을 메모리에 만들지 않고도 동일한 결과를 얻는다.

쉽게 이해하기

트랜스포머는 오늘날 AI 붐을 이끄는 핵심 구조지만, 그 심장인 어텐션 메커니즘은 느리고 메모리를 많이 먹는다는 약점이 있다. 그동안 어텐션 행렬을 근사해 속도를 높이려는 시도가 많았지만, 대개 정확도를 희생하면서도 실제 체감 속도(월 클록) 향상은 얻지 못했다. 플래시 어텐션은 이 한계를 넘어, 근사가 아니라 정확한 계산을 유지하면서 빠르고 메모리 효율적인 결과를 낸다.

어텐션이 느린 이유는 GPU 메모리 구조에 있다. 쿼리·키·값 행렬은 GPU 코어 바깥의 고대역폭 메모리(HBM)에 저장되는데, 점곱→소프트맥스→가중 평균을 계산하는 과정에서 중간 결과인 N×N 어텐션 행렬을 HBM에 썼다가 다시 읽기를 반복한다. 이 잦은 글로벌 메모리 접근이 큰 지연을 만든다. 메모리 계층은 느리지만 큰 DRAM, 그보다 빠른 HBM, 그리고 훨씬 빠르지만 아주 작은 온칩 SRAM으로 나뉜다.

플래시 어텐션의 첫 번째 아이디어는 타일링이다. 행렬 곱셈을 작은 블록(타일)으로 쪼개 SRAM에서 처리하면, 같은 결과를 얻으면서 글로벌 메모리 접근 횟수를 블록 크기에 비례해 줄일 수 있다. 영상은 4×4 행렬 곱을 2×2 블록으로 나누면 32번이던 메모리 접근이 16번으로 절반이 되는 예를 든다.

문제는 중간에 끼어 있는 소프트맥스다. 소프트맥스는 수치 안정성을 위해 먼저 최댓값을 빼는 안전 소프트맥스(safe softmax)를 쓰는데, 이는 최댓값 찾기·지수 합 구하기·정규화의 세 번 순회가 필요하다. 여기서 부분 수열의 최댓값을 이용한 점화식을 세우면 순회를 두 번으로 줄이는 온라인 소프트맥스가 되고, 같은 트릭을 출력 계산에까지 확장하면 모든 계산을 단 하나의 루프로 융합할 수 있다. 이렇게 거대한 어텐션 행렬을 한 번도 통째로 만들지 않고도 정확한 어텐션을 얻는 것이 플래시 어텐션의 핵심이다.

동작은 쿼리·키·값 행렬을 타일로 나눠 SRAM에 차례로 올려 부분 결과를 갱신하는 방식으로 진행된다. 모든 타일을 돌고 나면 완전한 어텐션 결과가 완성되며, 그동안 전체 행렬을 메모리에 만들지 않아 글로벌 메모리 접근이 크게 줄어든다. 이후 플래시 어텐션 2, 3으로 효율을 더 끌어올리는 후속 연구도 이어지고 있다.

주요 인사이트

  • "빠르게"의 비결이 더 영리한 수학이 아니라 하드웨어의 메모리 계층을 이해한 데이터 이동 최소화라는 점이 인상적이다. 알고리즘 설계에서 연산량만큼 메모리 접근 비용이 중요함을 보여준다.
  • 근사 기법은 정확도를 깎고도 실제 속도 이득이 작을 수 있는 반면, 플래시 어텐션은 정확성을 그대로 두고 속도를 얻는다는 점에서 결이 다르다.
  • 온라인 소프트맥스의 점화식은 "전체 최댓값에 대한 의존성"을 "부분 최댓값"으로 대체해 순차 의존을 끊는 트릭으로, 여러 패스를 한 루프로 합치는 일반적 최적화 발상의 좋은 사례다.

자주 묻는 질문

플래시 어텐션은 어텐션을 근사해서 빠른 것인가요?

아닙니다. 근사 기법과 달리 플래시 어텐션은 정확히 같은 결과를 내면서 빠르고 메모리 효율이 높습니다. 속도는 계산을 줄여서가 아니라 느린 메모리 접근을 줄여서 얻습니다.

어텐션이 느린 근본 원인은 무엇인가요?

중간 결과인 N×N 어텐션 행렬을 느린 고대역폭 메모리(HBM)에 썼다가 다시 읽는 과정이 반복되기 때문입니다. 토큰 수가 수만 개로 커지면 이 메모리 왕복 비용이 지연의 큰 원인이 됩니다.

타일링과 온라인 소프트맥스는 어떤 역할을 하나요?

타일링은 행렬을 작은 블록으로 나눠 빠른 온칩 SRAM에서 계산해 글로벌 메모리 접근을 줄입니다. 온라인 소프트맥스는 세 번 순회하던 소프트맥스를 점화식으로 한 번에 융합해, 큰 어텐션 행렬을 만들지 않고도 계산을 합칠 수 있게 합니다.

원문과 출처

이 글은 원본 영상의 자막을 바탕으로 한국어 독자를 위해 요약했습니다. 전체 맥락과 최신 정보는 원문에서 확인하세요.

YouTube 원본 영상 보기 ↗

관련 AI 소식

#플래시어텐션#트랜스포머#어텐션#GPU최적화#온라인소프트맥스