Flash Attention(高速アテンション)FA
Flash Attention は、GPU の SRAM とタイリングを活用してアテンション行列をメモリに具体化せず計算する I/O 最適化アルゴリズムです。メモリ消費を線形化し2-4倍の高速化を実現します。
詳細解説
Flash Attention(フラッシュアテンション)は、Tri Dao らが2022年に提案した Self-Attention の I/O-aware な厳密計算アルゴリズムで、GPU の HBM(High Bandwidth Memory)と SRAM(オンチップ・シェアードメモリ)の帯域差を最大限に活用するタイリング戦略により、N×N のアテンション行列を一度に具体化することなくブロック単位で逐次計算します。これにより、メモリ消費が系列長 N に対して O(N^2) から O(N) に改善され、A100 GPU で2-4倍、長文ではさらに大きな高速化が得られます。重要なのは数値的に厳密(exact)であり近似ではない点で、ドロップアウトや任意のマスクとも互換性があります。Online Softmax の手法を再帰的に適用することで、すべてのスコアを保持せずに正規化を実現しています。PyTorch 2.0 の scaled_dot_product_attention で透過的に呼ばれるようになり、Hugging Face Transformers の attn_implementation="flash_attention_2" 指定で簡単に使えます。GPT-4 を含む主要 LLM の事前学習・ファインチューニング・推論で標準的に採用され、Flash Attention 2(2023)、Flash Attention 3(2024、H100 最適化)へと進化しています。
実装例 / 使い方
- 01PyTorch 2.0 以降は torch.nn.functional.scaled_dot_product_attention で自動利用されます
- 02GPT-4・Llama・Mistral などの大規模学習で標準採用されています
- 03A100 80GB で系列長32Kでも OOM せず学習できます
関連する用語
Self-Attention(自己注意機構)
SASelf-Attention は、入力系列の各トークンが同じ系列内の他のトークンとの関連度を計算する仕組みです。Query・Key・Value の3行列に射影し...
Flash Attention 2
FA2Flash Attention 2 は2023年に発表された Flash Attention の改良版です。並列化粒度を見直し非行列演算のオーバーヘッドを削減し...
Flash Attention 3
FA3Flash Attention 3 は2024年7月に発表された H100 GPU に最適化された改良版です。ワープ特殊化と非同期 TMA、FP8 サポートによ...
KV Cache(KVキャッシュ)
KV CacheKV Cache は、自己回帰生成中に過去トークンの Key・Value を保存して再利用する仕組みです。1トークン生成あたりの計算量が O(N^2) から O...
Transformer
Transformer (トランスフォーマー) は 2017 年 Google 論文で発表された深層学習アーキテクチャで、Self-Attention 機構を核...
Flash Attention(高速アテンション)を、実際に活用する
用語の意味は分かった。次は実装。EXBANK の無料診断で、貴社で具体的にどう活用できるかをご提案します。
営業時間 平日10-18時 / 通常24時間以内に返信
