Fused attention: Switch to Flash Decoding #656

Merged (5 commits) on Nov 26, 2024
208 changes: 82 additions & 126 deletions awq/modules/fused/attn.py
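At a high level, the PR replaces the `awq_ft_ext` single-query decode kernel with flash-attn's flash-decoding path. A rough sketch of the pattern the new forward follows is below; the helper name and exact call sites are illustrative, not the module's actual code, and it assumes flash-attn >= 2.6 with tensors laid out as (batch, seqlen, n_heads, head_dim).

# Illustrative sketch, not part of this diff.
import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache

def fused_attend(xq, xk, xv, k_cache, v_cache, start_pos):
    bsz, seqlen = xq.shape[0], xq.shape[1]
    if seqlen > 1:
        # Prefill: attend over the freshly projected keys/values.
        out = flash_attn_func(q=xq, k=xk, v=xv, causal=True)
    else:
        # Decode ("flash decoding"): cache_seqlens holds the number of tokens
        # already in the cache per batch element; the kernel writes the new
        # K/V there and attends over the whole prefix in one call.
        cache_seqlens = torch.full(
            (bsz,), start_pos, dtype=torch.int32, device=xq.device
        )
        out = flash_attn_with_kvcache(
            q=xq, k_cache=k_cache, v_cache=v_cache, k=xk, v=xv,
            cache_seqlens=cache_seqlens, causal=True,
        )
    return out.reshape(bsz, seqlen, -1)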
@@ -8,11 +8,11 @@


try:
import awq_ft_ext
from flash_attn import flash_attn_func, flash_attn_with_kvcache

FT_INSTALLED = True
FA_INSTALLED = True
except:
FT_INSTALLED = False
FA_INSTALLED = False

HF_NEW_CACHE_FORMAT = False

@@ -28,6 +28,7 @@ class RoPE(nn.Module):
def __init__(self, head_dim, max_seq_len, device, rope_theta):
super(RoPE, self).__init__()

self.head_dim = head_dim
self.freqs_cis = nn.Parameter(
self.precompute_freqs_cis(head_dim, max_seq_len, rope_theta).to(device),
requires_grad=False,
@@ -49,7 +50,23 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)

def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
def forward(
self,
xq: torch.Tensor,
xk: torch.Tensor,
start_pos: int,
seqlen: int,
partial: bool = False,
):
if partial:
xq, xq_pass = (
xq[..., : self.head_dim],
xq[..., self.head_dim :],
)
xk, xk_pass = (
xk[..., : self.head_dim],
xk[..., self.head_dim :],
)
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
@@ -62,6 +79,10 @@ def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)

        if partial:
            xq_out = torch.cat((xq_out, xq_pass), dim=-1)
            xk_out = torch.cat((xk_out, xk_pass), dim=-1)

return xq_out.type_as(xq), xk_out.type_as(xk)
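The new `partial` flag folds partial-rotary handling into `RoPE.forward` itself: only the first `head_dim` features of each head (the rotary slice) are rotated, and the remainder passes through unchanged. A minimal sketch of that split/rotate/concat, reusing the complex-multiplication formulation above; `freqs_cis` is assumed to be already broadcast-shaped:

# Illustrative sketch, not part of this diff.
import torch

def apply_partial_rope(x, freqs_cis, rotary_dim):
    # x: (batch, seqlen, n_heads, head_dim); rotate only the first rotary_dim features.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    # Pair features into complex numbers and rotate by the precomputed angles.
    x_ = torch.view_as_complex(
        x_rot.float().reshape(*x_rot.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    x_out = torch.view_as_real(x_ * freqs_cis).transpose(-2, -1).flatten(3)
    # Re-attach the untouched features so head_dim is preserved.
    return torch.cat((x_out.type_as(x), x_pass), dim=-1)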


@@ -118,7 +139,7 @@ def __init__(
rope_theta=10000,
partial_rotary_factor=1.0,
head_dim=None,
attn_logit_softcapping=None,
attn_logit_softcapping=0.0,
**kwargs
):
super().__init__()
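The `attn_logit_softcapping` default changes from None to 0.0 because flash-attn encodes "no softcapping" as `softcap=0.0` rather than None. When non-zero (as in Gemma2), the attention scores are squashed before the softmax; a minimal reference version of the transform that flash-attn's `softcap` argument applies inside the kernel:

# Illustrative sketch, not part of this diff.
import torch

def softcap_logits(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # softcap == 0.0 means "disabled", matching flash-attn's convention.
    if softcap == 0.0:
        return scores
    return softcap * torch.tanh(scores / softcap)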
@@ -147,18 +168,18 @@ def __init__(
# attention shapes for self attention
self.attention_shapes = get_attention_shapes(
attention_shapes,
max_seq_len,
self.cache_batch_size,
n_heads,
n_kv_heads,
self.head_dim,
)
# cache store that rolls cache
self.cache = WindowedCache(
self.attention_shapes["cache_v"],
self.attention_shapes["cache_k"],
self.max_seq_len,
dev,
cache_batch_size=self.cache_batch_size,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
head_dim=self.head_dim,
max_seq_len=self.max_seq_len,
device=dev,
)
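`WindowedCache` is now built from the individual dimensions rather than from precomputed cache_k/cache_v shapes. Its implementation is not part of this diff, but `flash_attn_with_kvcache` expects caches laid out as (batch, max_seq_len, n_kv_heads, head_dim); a hedged sketch of what such a preallocation could look like:

# Illustrative sketch; WindowedCache's real implementation is not shown in this diff.
import torch

def allocate_kv_cache(cache_batch_size, n_kv_heads, head_dim, max_seq_len,
                      device, dtype=torch.float16):
    # flash_attn_with_kvcache reads and writes caches shaped
    # (batch, max_seq_len, n_kv_heads, head_dim).
    k_cache = torch.zeros(cache_batch_size, max_seq_len, n_kv_heads, head_dim,
                          device=device, dtype=dtype)
    v_cache = torch.zeros_like(k_cache)
    return k_cache, v_cache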

if use_alibi:
@@ -174,13 +195,10 @@

if kwargs.get("is_neox") is not None:
self.is_neox = kwargs["is_neox"]

self.attn_logit_softcapping = attn_logit_softcapping
self.use_sdpa = kwargs.get("use_sdpa", False)

def forward(
self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
):
def forward(self, hidden_states: torch.Tensor, *args, **kwargs):
bsz, seqlen, _ = hidden_states.shape

# Reallocate cache if batch size changes
@@ -196,21 +214,27 @@ def forward(
self.start_pos = 0

hf_is_generating = False
hf_is_first_forward = "past_key_value" in kwargs and kwargs["past_key_value"] is None
hf_is_new_cache_first_forward = "past_key_value" in kwargs and isinstance(kwargs["past_key_value"], DynamicCache) and kwargs["past_key_value"].get_seq_length() == 0
hf_is_first_forward = (
"past_key_value" in kwargs and kwargs["past_key_value"] is None
)
hf_is_new_cache_first_forward = (
"past_key_value" in kwargs
and isinstance(kwargs["past_key_value"], DynamicCache)
and kwargs["past_key_value"].get_seq_length() == 0
)

if self.is_hf_transformers and "use_cache" in kwargs:
hf_is_generating = kwargs["use_cache"]

# print(kwargs["past_key_value"].get_seq_length())

# In case we re-generate, we need to refresh the starting position
# to 0. We detect it by checking if `past_key_values` is set to None,
# which indicates that we are on the first step of `generate()`.
# This is only applicable for `transformers` integration
if (self.is_hf_transformers and (hf_is_first_forward or hf_is_new_cache_first_forward)) or (self.is_hf_transformers and not hf_is_generating):
if (
self.is_hf_transformers
and (hf_is_first_forward or hf_is_new_cache_first_forward)
) or (self.is_hf_transformers and not hf_is_generating):
self.start_pos = 0


xqkv = self.qkv_proj(hidden_states)
xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
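The fused `qkv_proj` output is still carved into separate query/key/value tensors through the precomputed `attention_shapes` views. For a grouped-query layout the split is roughly the following; the helper and argument names are illustrative, not the module's actual slicing helpers:

# Illustrative sketch, not part of this diff.
import torch

def split_qkv(xqkv, bsz, seqlen, n_heads, n_kv_heads, head_dim):
    # The fused projection packs [q | k | v] along the hidden dimension.
    q_dim = n_heads * head_dim
    kv_dim = n_kv_heads * head_dim
    xq, xk, xv = xqkv.split([q_dim, kv_dim, kv_dim], dim=-1)
    xq = xq.view(bsz, seqlen, n_heads, head_dim)
    xk = xk.view(bsz, seqlen, n_kv_heads, head_dim)
    xv = xv.view(bsz, seqlen, n_kv_heads, head_dim)
    return xq, xk, xv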
Expand All @@ -219,114 +243,47 @@ def forward(
xk = self.attention_shapes["xk_slice"](xqkv)
xv = self.attention_shapes["xv_slice"](xqkv)

if seqlen > 1 or self.partial_rotary_factor < 1 or not FT_INSTALLED:
xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

if not self.use_alibi:
# Partial rotary embedding
if self.partial_rotary_factor < 1:
xq_rot, xq_pass = (
xq[..., : self.rotary_dim],
xq[..., self.rotary_dim :],
)
xk_rot, xk_pass = (
xk[..., : self.rotary_dim],
xk[..., self.rotary_dim :],
)
xq_rot, xk_rot = self.rope.forward(xq_rot, xk_rot, self.start_pos, seqlen)
xq = torch.cat((xq_rot, xq_pass), dim=-1)
xk = torch.cat((xk_rot, xk_pass), dim=-1)
else:
xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)

values_store = xv.transpose(2, 1)
keys_store = (
xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
.permute(0, 2, 3, 1, 4)
.contiguous()
if not self.use_alibi:
xq, xk = self.rope.forward(
xq, xk, self.start_pos, seqlen, partial=self.partial_rotary_factor < 1
)

self.cache.to(xq)
self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

# Only necessary to retrieve from cache when we are not processing context
if seqlen == 1:
xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

keys = xk
values = xv

if self.n_kv_groups != 0:
keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
values = torch.repeat_interleave(
values, dim=2, repeats=self.n_kv_groups
)

xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)

# Used in Gemma2
if self.attn_logit_softcapping is not None:
scores = scores / self.attn_logit_softcapping
scores = torch.tanh(scores)
scores = scores * self.attn_logit_softcapping

if self.use_sdpa:
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : keys.shape[-2]]
is_causal = True if causal_mask is None and seqlen > 1 else False
output = torch.nn.functional.scaled_dot_product_attention(
xq,
keys,
values,
attn_mask=causal_mask,
dropout_p=0.0,
is_causal=is_causal,
)
else:
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if self.use_alibi:
scores = self.alibi.forward(scores, seqlen)

# When seqlen is 1, there is nothing else to attend to
if attention_mask is not None and seqlen > 1:
# For llama-arch, the causal mask is preallocated with bsz x 1 x max_seq_len x max_seq_len, thus we
# need to slice it
if attention_mask.shape[-1] != seqlen:
attention_mask = attention_mask[:, :, :seqlen, :seqlen]

scores = (
scores + attention_mask
) # (bs, n_local_heads, slen, cache_len + slen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)

attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
self.cache.to(xq)
self.cache.update_kv(
values_store=xv,
keys_store=xk,
batch_size=bsz,
start_pos=self.start_pos,
seqlen=seqlen,
)

if seqlen > 1:
output = flash_attn_func(
q=xq,
k=xk,
v=xv,
causal=True,
alibi_slopes=self.alibi.slopes if self.alibi is not None else None,
softcap=self.attn_logit_softcapping,
)
else:
xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

alibi_slopes = self.alibi.slopes if self.alibi is not None else None
attention_weight = awq_ft_ext.single_query_attention(
xq, # query
xk, # key
xv, # value
self.cache.k, # key cache
self.cache.v, # value cache
None, # length per sample
alibi_slopes, # alibi slopes
self.start_pos, # timestep
self.rotary_dim, # rotary embedding dimension
self.rope_theta, # rotary embedding base
self.is_neox, # is neox
cache_seqlens = torch.full(
(bsz,), self.start_pos + seqlen, dtype=torch.int32, device=xq.device
)

output = flash_attn_with_kvcache(
q=xq,
k=xk,
k_cache=self.cache.k,
v=xv,
v_cache=self.cache.v,
cache_seqlens=cache_seqlens,
causal=True,
alibi_slopes=self.alibi.slopes if self.alibi is not None else None,
softcap=self.attn_logit_softcapping,
)
attention_weight = attention_weight.reshape(bsz, 1, -1)

attention_weight = output.view(bsz, seqlen, -1)
attn_output = self.o_proj(attention_weight)
self.start_pos += seqlen
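A side effect visible in this hunk: the explicit `torch.repeat_interleave` expansion of K/V heads from the old path is gone, because the flash-attn kernels support grouped-query attention natively whenever n_heads is a multiple of n_kv_heads. A small hedged check of that behaviour (requires a CUDA device, flash-attn installed, and fp16 or bf16 inputs):

# Illustrative sketch, not part of this diff.
import torch
from flash_attn import flash_attn_func

bsz, seqlen, n_heads, n_kv_heads, head_dim = 1, 8, 32, 8, 128
q = torch.randn(bsz, seqlen, n_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, seqlen, n_kv_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)
out = flash_attn_func(q=q, k=k, v=v, causal=True)  # no repeat_interleave needed
print(out.shape)  # torch.Size([1, 8, 32, 128])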

Expand All @@ -338,7 +295,6 @@ def forward(
# about past key length
past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]


if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
new_cache = DynamicCache()
new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)
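Since the fused module manages its own KV cache, the `past_key_value` handed back to transformers is only a placeholder whose sequence-length dimension advertises how many tokens have been processed; with the newer cache format it is wrapped in a DynamicCache so that `get_seq_length()` reports `start_pos`. A hedged sketch of that shim, with `layer_idx=0` as in the diff:

# Illustrative sketch, not part of this diff.
import torch
from transformers import DynamicCache

def dummy_past_key_value(start_pos: int) -> DynamicCache:
    # Shape (1, 1, start_pos, 1): only the third (sequence-length) dimension
    # matters to transformers' generation bookkeeping.
    dummy = torch.zeros(1, 1, start_pos, 1)
    cache = DynamicCache()
    cache.update(dummy, dummy, layer_idx=0)
    return cache  # cache.get_seq_length() == start_pos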