
Commit efd5911

Authored by VHellendoorn, StellaAthena, and Quentin-Anthony

Add support for Flash attention (#725)

* Add support for Flash attention
* Fix attention type can be both sparse and flash
* Updates from running pre-commit on modified files
* Update README.md

Co-authored-by: Stella Biderman <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
1 parent 589c70a commit efd5911

7 files changed: +273, -17 lines

.pre-commit-config.yaml (+1 -1)

@@ -32,7 +32,7 @@ repos:
     hooks:
       - id: codespell
         args: [
-            '--ignore-words-list=reord', # Word used in error messages that need rewording
+            '--ignore-words-list=reord,dout', # Word used in error messages that need rewording
             --check-filenames,
             --check-hidden,
         ]

README.md (+5)

@@ -99,6 +99,11 @@ from the repository root.
 </aside>
 
 
+### Flash Attention
+
+To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
+
+
 ### Containerized Setup
 
 We also provide a Dockerfile if you prefer to run NeoX in a container. To use this option, first build an image named `gpt-neox` from the repository root directory with `docker build -t gpt-neox -f Dockerfile .`. We also host pre-built images on Docker Hub at `leogao2/gpt-neox`.
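
For reference (a sketch, not part of the diff), the new requirements file added later in this commit can typically be installed from the repository root with pip before selecting the attention type in your config:

    pip install -r requirements/requirements-flashattention.txt

The attention type itself is chosen through the per-layer attention configuration documented in configs/neox_arguments.md, shown next in this diff.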

configs/neox_arguments.md (+5 -5)

@@ -334,7 +334,7 @@ Model Arguments
     The first item in the list specifies the attention type(s), and should be a list of strings. The second item
     specifies the number of times to repeat those attention types in the full list.
 
-    attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird]
+    attention type choices: [global, local, sparse_fixed, sparse_variable, bigbird, bslongformer, gmlp, amlp, flash]
 
     So a 12 layer network with only global attention could be specified like:
         [[[`global`], 12]]
@@ -345,6 +345,8 @@ Model Arguments
     If none is specified, this defaults to
         [[[`global`], n_layers]]
 
+    "flash" attention refers to optimized global attention for Ampere (and some other) generation GPUs described here [Flash-Attention](https://github.com/HazyResearch/flash-attention).
+
 
 
 - **sparsity_config**: dict
The remaining hunks in this file are whitespace-only cleanups from running pre-commit on the modified files:

@@ -950,7 +952,7 @@ (Text Generation arguments): trailing whitespace is stripped from the `Default =` line of **eval_results_prefix** ("prefix to which to save evaluation results - final fp will be {eval_results_prefix}_eval_results_yy-mm-dd-HH-MM.json").

@@ -1538,7 +1540,7 @@ (Args for deepspeed config): trailing whitespace is stripped from a blank line following a `Default = None` entry.

@@ -1670,6 +1672,4 @@ (Args for deepspeed runner, deepspeed.launcher.runner): for **comment** (Default = None; "Adds a `--comment` to the DeepSpeed launch command. In DeeperSpeed this is passed on to the SlurmLauncher as well. Sometime necessary for cluster rules, or so I've heard."), a stray blank line before the description and the trailing blank line at the end of the file are removed.
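
As an illustration of the syntax documented above (a sketch, not part of the diff; the key name and quoting are assumed from the repository's YAML config files), a hypothetical 12-layer model could select Flash attention for every layer with a config entry along these lines:

    "attention_config": [[["flash"], 12]],

Mixing flash with the other attention choices per layer works the same way as in the [[["global"], 12]] example.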

megatron/model/flash_attention.py (new file, +185)

# Based on: https://github.com/HazyResearch/flash-attention/blob/4a6eaa9f27df6fff7ffb2c24e894938a687dd870/flash_attn/flash_attn_interface.py

import torch
import torch.nn as nn
import torch.nn.functional as F

import flash_attn_cuda


def _flash_attn_forward(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    return_softmax,
    num_splits=0,
    generator=None,
):
    """
    num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
    it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
    Don't change it unless you know what you're doing.
    """
    softmax_lse, *rest = flash_attn_cuda.fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        return_softmax,
        num_splits,
        generator,
    )
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    S_dmask = rest[0] if return_softmax else None
    return out, softmax_lse, S_dmask


def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    num_splits=0,
    generator=None,
):
    """
    num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
    not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
    Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel
    as num_splits=3), so effectively the choices are 0, 1, and 2.
    This hyperparameter can be tuned for performance, but default value (heuristic) should work fine.
    """
    _, _, _, softmax_d = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        num_splits,
        generator,
    )
    # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dq, dk, dv, softmax_d


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, softmax_lse, S_dmask = _flash_attn_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            torch.empty_like(qkv[:, 0]),
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax,
        )
        ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = torch.empty_like(qkv)
        _flash_attn_backward(
            dout,
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None


def flash_attn_unpadded_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    return FlashAttnQKVPackedFunc.apply(
        qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
    )
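
For context, here is a minimal usage sketch of the new wrapper (not part of the commit), mirroring how megatron/model/transformer.py calls it below: q, k, and v are packed into a single [total_tokens, 3, num_heads, head_dim] tensor and indexed by cumulative sequence lengths. It assumes flash-attn==0.2.2 is installed and a supported CUDA GPU (e.g. an A100) is available; the kernels expect fp16/bf16 inputs.

# Sketch (not from the commit): minimal call into the new wrapper.
import torch

from megatron.model.flash_attention import flash_attn_unpadded_qkvpacked_func

batch_size, seqlen, num_heads, head_dim = 2, 128, 8, 64

# Packed q/k/v: [batch * seqlen, 3, num_heads, head_dim], fp16 on GPU.
qkv = torch.randn(
    batch_size * seqlen,
    3,
    num_heads,
    head_dim,
    dtype=torch.float16,
    device="cuda",
    requires_grad=True,
)

# Cumulative sequence lengths for equal-length sequences: [0, seqlen, 2*seqlen, ...]
cu_seqlens = torch.arange(
    0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device="cuda"
)

out = flash_attn_unpadded_qkvpacked_func(
    qkv,
    cu_seqlens,
    seqlen,  # max_seqlen
    0.0,  # dropout_p (0.0 here, as in evaluation mode)
    softmax_scale=None,  # defaults to 1/sqrt(head_dim) inside the autograd Function
    causal=True,
)
print(out.shape)  # torch.Size([256, 8, 64]) == [batch * seqlen, num_heads, head_dim]

out.sum().backward()  # gradients flow through FlashAttnQKVPackedFunc.backward
print(qkv.grad.shape)  # same shape as qkv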

megatron/model/transformer.py (+75 -11)

@@ -259,7 +259,8 @@ def __init__(
             self.rotary_emb = None
 
         self.attention_type = neox_args.attention_config[layer_number]
-        self.sparse = self.attention_type != "global"
+        self.use_flash_attention = self.attention_type == "flash"
+        self.sparse = self.attention_type != "global" and not self.use_flash_attention
         if self.sparse:
             self.sparse_attn = configure_sparse_attention(
                 neox_args,
@@ -268,19 +269,31 @@ def __init__(
                 mpu=mpu,
             )
         else:
-            self.scale_mask_softmax = FusedScaleMaskSoftmax(
-                input_in_fp16=self.fp16,
-                input_in_bf16=self.bf16,
-                fusion_type=get_fusion_type(neox_args),
-                mask_func=self.attention_mask_func,
-                softmax_in_fp32=self.attention_softmax_in_fp32,
-                scale=coeff,
-            )
+            if self.use_flash_attention:
+                from megatron.model.flash_attention import (
+                    flash_attn_unpadded_qkvpacked_func,
+                )
+
+                self.flash_attention_function = flash_attn_unpadded_qkvpacked_func
+                if self.pos_emb == "alibi":
+                    raise ValueError(
+                        "Flash attention is currently not compatible with AliBi positional embeddings. Use sinuisoidal, learned, or rotary embeddings instead."
+                    )
+            else:
+                self.scale_mask_softmax = FusedScaleMaskSoftmax(
+                    input_in_fp16=self.fp16,
+                    input_in_bf16=self.bf16,
+                    fusion_type=get_fusion_type(neox_args),
+                    mask_func=self.attention_mask_func,
+                    softmax_in_fp32=self.attention_softmax_in_fp32,
+                    scale=coeff,
+                )
 
         # Dropout. Note that for a single iteration, this layer will generate
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
-        self.attention_dropout = nn.Dropout(neox_args.attention_dropout)
+        self.dropout_p = neox_args.attention_dropout
+        self.attention_dropout = nn.Dropout(self.dropout_p)
 
         # Output.
         self.dense = mpu.RowParallelLinear(
@@ -396,6 +409,55 @@ def attention(
             context_layer = context_layer.view(*output_size)
         return context_layer
 
+    def flash_attention(self, query_layer, key_layer, value_layer):
+        # [b, np, sq, sk]
+        output_size = (
+            query_layer.size(1),
+            query_layer.size(2),
+            query_layer.size(0),
+            key_layer.size(0),
+        )
+        # [s, b, np, hn] -> [b, s, np, hn] -> [b * s, 1, np, hn]
+        query_layer = query_layer.transpose(0, 1).reshape(
+            output_size[0] * output_size[2], 1, output_size[1], -1
+        )
+        key_layer = key_layer.transpose(0, 1).reshape(
+            output_size[0] * output_size[3], 1, output_size[1], -1
+        )
+        value_layer = value_layer.transpose(0, 1).reshape(
+            output_size[0] * output_size[3], 1, output_size[1], -1
+        )
+
+        # Combined q/k/v into [b * s, 3, np, hn].
+        qkv = torch.concat([query_layer, key_layer, value_layer], dim=1)
+
+        batch_size = output_size[0]
+        seqlen = output_size[2]
+        max_s = seqlen
+        cu_seqlens = torch.arange(
+            0,
+            (batch_size + 1) * seqlen,
+            step=seqlen,
+            dtype=torch.int32,
+            device=qkv.device,
+        )
+        output = self.flash_attention_function(
+            qkv,
+            cu_seqlens,
+            max_s,
+            self.dropout_p if self.training else 0.0,
+            softmax_scale=None,
+            causal=True,
+        )
+        # [b * sq, np, hn] -> [b, sq, np, hn]
+        matmul_result = output.view(
+            output_size[0], output_size[2], output.shape[1], output.shape[2]
+        )
+        # [b, sq, np, hn] -> [b, np, sq, hn]
+        matmul_result = matmul_result.transpose(1, 2)
+
+        return matmul_result
+
     def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
         # TODO: sparse attn dropout?
         # TODO: pad to block size
@@ -483,7 +545,9 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
         if self.use_cache:
             present = torch.stack((key_layer, value_layer))
 
-        if not self.sparse:
+        if self.use_flash_attention:
+            context_layer = self.flash_attention(query_layer, key_layer, value_layer)
+        elif not self.sparse:
             context_layer = self.attention(
                 query_layer, key_layer, value_layer, layer_past, attention_mask
            )
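
To make the tensor-layout bookkeeping in the new flash_attention() method easier to follow, here is a hedged, CPU-only shape walk-through (not part of the commit); the names sq, b, np, hn follow the Megatron layout comments above, and the final transpose reproduces what the layer does with the kernel output.

# Sketch (not from the commit): the same reshapes as ParallelSelfAttention.flash_attention(),
# on dummy CPU tensors, to show the layout the flash kernel expects.
# sq = sequence length, b = batch, np = attention heads per partition, hn = head size.
import torch

sq, b, np_, hn = 16, 2, 4, 8
query = torch.randn(sq, b, np_, hn)
key = torch.randn(sq, b, np_, hn)
value = torch.randn(sq, b, np_, hn)

# [sq, b, np, hn] -> [b, sq, np, hn] -> [b * sq, 1, np, hn]
q = query.transpose(0, 1).reshape(b * sq, 1, np_, hn)
k = key.transpose(0, 1).reshape(b * sq, 1, np_, hn)
v = value.transpose(0, 1).reshape(b * sq, 1, np_, hn)

# Pack q/k/v along dim 1 -> [b * sq, 3, np, hn]
qkv = torch.concat([q, k, v], dim=1)
assert qkv.shape == (b * sq, 3, np_, hn)

# cu_seqlens marks where each sequence starts in the flattened [b * sq] dimension
cu_seqlens = torch.arange(0, (b + 1) * sq, step=sq, dtype=torch.int32)
assert cu_seqlens.tolist() == [0, 16, 32]

# The kernel returns [b * sq, np, hn]; the layer reshapes it back to
# [b, sq, np, hn] and transposes to [b, np, sq, hn] for the rest of the block.
kernel_out = torch.randn(b * sq, np_, hn)
context = kernel_out.view(b, sq, np_, hn).transpose(1, 2)
assert context.shape == (b, np_, sq, hn)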

megatron/neox_arguments/neox_args.py (+1)

@@ -34,6 +34,7 @@
     "bslongformer",
     "gmlp",
     "amlp",
+    "flash",
 ]
 
 
requirements/requirements-flashattention.txt (new file, +1)

flash-attn==0.2.2
