
Commit c16e7b8

deepseek2 support pd multi dp kv trans. (#765)
1 parent 5f26114 commit c16e7b8

3 files changed (+201 −16)


lightllm/common/deepseek2_mem_manager.py

+60 −10
```diff
@@ -6,6 +6,7 @@
 from typing import List, Union
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
+from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node

 logger = init_logger(__name__)

@@ -103,50 +104,99 @@ def send_to_decode_node_p2p(
         """
         Implementation that copies and transfers the data with a p2p triton kernel.
         """
-        assert dp_size_in_node == 1
+        if not hasattr(self, "mem_ptrs_dict"):
+            self.mem_ptrs_dict = {}
+            for layer_index in range(self.layer_num):
+                mems_ptr = []
+                for i in range(0, len(mem_managers), len(mem_managers) // dp_size_in_node):
+                    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
+                mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")
+                self.mem_ptrs_dict[layer_index] = mems_ptr

         move_token_indexes = []
+        token_dp_indexes = []
         for task in move_tasks:
             if task.move_kv_len != 0:
                 move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
+                token_dp_indexes.extend([task.prefill_dp_index for _ in range(task.move_kv_len)])

         move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+        token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda")
         for layer_index in range(self.layer_num):
-            move_buffer = self._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
+            move_buffer = self._get_kv_move_data_p2p(
+                move_token_indexes, token_dp_indexes, layer_index, self.kv_move_buffer, dp_size_in_node
+            )
             dist.send(move_buffer, dst=1)
         return

-    def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
+    def _get_kv_move_data_p2p(
+        self,
+        token_indexes: torch.Tensor,
+        token_dp_indexes: torch.Tensor,
+        layer_index: int,
+        kv_move_buffer: torch.Tensor,
+        dp_size_in_node: int,
+    ):
         move_token_num = len(token_indexes)
         move_size = self.kv_buffer.numel() // self.layer_num // self.size * move_token_num
         move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, self.head_num, self.head_dim)
-        kv_trans(
-            self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num]
+        kv_trans_v2_for_p_node(
+            input_mems=self.mem_ptrs_dict[layer_index],
+            input_idx=token_indexes,
+            input_dp_idx=token_dp_indexes,
+            output=move_buffer,
+            output_idx=self.kv_move_buf_indexes[0:move_token_num],
+            dp_size_in_node=dp_size_in_node,
         )
         return move_buffer

     def receive_from_prefill_node_p2p(
         self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
     ):
-        assert dp_size_in_node == 1
+        if not hasattr(self, "mem_ptrs_dict"):
+            self.mem_ptrs_dict = {}
+            for layer_index in range(self.layer_num):
+                mems_ptr = []
+                for i in range(0, len(mem_managers)):
+                    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
+                mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")
+                self.mem_ptrs_dict[layer_index] = mems_ptr

         move_token_indexes = []
+        token_dp_indexes = []
         for task in move_tasks:
             if task.move_kv_len != 0:
                 move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
+                token_dp_indexes.extend([task.decode_dp_index for _ in range(task.move_kv_len)])

         move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+        token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda")

         token_num = len(move_token_indexes)
         move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim)
         for layer_index in range(self.layer_num):
             dist.recv(recive_buffer, src=0)
-            for i, mem in enumerate(mem_managers):
-                mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
+            self._write_kv_move_data_p2p(
+                move_token_indexes, token_dp_indexes, recive_buffer, layer_index, dp_size_in_node
+            )
         return

-    def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):
+    def _write_kv_move_data_p2p(
+        self,
+        token_indexes: torch.Tensor,
+        token_dp_indexes: torch.Tensor,
+        buffer_tensor: torch.Tensor,
+        layer_index,
+        dp_size_in_node: int,
+    ):
         move_token_num = len(token_indexes)
-        kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes)
+        kv_trans_v2_for_d_node(
+            output_mems=self.mem_ptrs_dict[layer_index],
+            output_idx=token_indexes,
+            output_dp_idx=token_dp_indexes,
+            input=buffer_tensor,
+            input_idx=self.kv_move_buf_indexes[0:move_token_num],
+            dp_size_in_node=dp_size_in_node,
+        )
         return
```
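The heart of the change is `mem_ptrs_dict`: for every layer, a `torch.uint64` tensor of raw `data_ptr()` addresses into each memory manager's `kv_buffer`. The prefill side samples one representative card per DP rank (hence the stride `len(mem_managers) // dp_size_in_node`), which suggests the KV within a DP group is identical, while the decode side records every card so received KV can be fanned out to the whole group. A minimal standalone sketch of the two constructions, with `FakeMemManager` as a hypothetical stand-in for lightllm's real `MemoryManager`:

```python
import torch

# Hypothetical stand-in for lightllm's MemoryManager: each card owns a
# kv_buffer of shape (layer_num, size, head_num, head_dim).
layer_num, size, head_num, head_dim = 2, 16, 1, 512
card_num, dp_size_in_node = 8, 4


class FakeMemManager:
    def __init__(self):
        self.kv_buffer = torch.zeros((layer_num, size, head_num, head_dim), dtype=torch.float16, device="cuda")


mem_managers = [FakeMemManager() for _ in range(card_num)]

# Prefill node: one representative card per DP rank is enough to read from.
prefill_ptrs = {
    layer: torch.tensor(
        [mem_managers[i].kv_buffer[layer].data_ptr() for i in range(0, card_num, card_num // dp_size_in_node)],
        dtype=torch.uint64,
        device="cuda",
    )
    for layer in range(layer_num)
}

# Decode node: every card, so the received KV can be written group-wide.
decode_ptrs = {
    layer: torch.tensor(
        [m.kv_buffer[layer].data_ptr() for m in mem_managers],
        dtype=torch.uint64,
        device="cuda",
    )
    for layer in range(layer_num)
}
```

Building the dict lazily behind `hasattr` means the pointer tables are constructed once, on the first transfer, and reused for every later batch of move tasks.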

lightllm/common/kv_trans_kernel/kv_trans_v2.py

+102 −3
```diff
@@ -5,7 +5,7 @@


 @triton.jit
-def _kv_trans_kernel(
+def _kv_trans_prefill_node_kernel(
     input_mems_ptr,
     input_stride_0,
     input_stride_1,
@@ -48,7 +48,7 @@ def _kv_trans_kernel(
     return


-def kv_trans_v2(
+def kv_trans_v2_for_p_node(
     input_mems: torch.Tensor,
     input_idx: torch.Tensor,
     input_dp_idx: torch.Tensor,
@@ -75,7 +75,7 @@ def kv_trans_v2(
     NUM_STAGES = 3
     grid = (grid_count,)

-    _kv_trans_kernel[grid](
+    _kv_trans_prefill_node_kernel[grid](
         input_mems,
         *output.stride(),
         input_idx,
@@ -92,3 +92,102 @@ def kv_trans_v2(
         num_warps=1,
     )
     return
+
+
+@triton.jit
+def _kv_trans_decode_node_kernel(
+    output_mems_ptr,
+    output_stride_0,
+    output_stride_1,
+    output_stride_2,
+    output_token_idx_ptr,
+    output_token_dp_index_ptr,
+    input_ptr,
+    input_stride_0,
+    input_stride_1,
+    input_stride_2,
+    input_token_idx_ptr,
+    token_num: int,
+    head_num: int,
+    head_dim: int,
+    grid_count: int,
+    BLOCK_SIZE: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
+    CARD_NUM_PER_D: tl.constexpr,
+):
+    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
+    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
+    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
+    output_stride_1 = tl.cast(output_stride_1, dtype=tl.int64)
+
+    head_num_dim = head_num * head_dim
+    tid = tl.program_id(0)
+
+    offs = tl.arange(0, BLOCK_SIZE)
+    while tid < token_num:
+        dp_index = tl.load(output_token_dp_index_ptr + tid)
+        input_token_idx = tl.load(input_token_idx_ptr + tid)
+        output_token_idx = tl.load(output_token_idx_ptr + tid)
+        for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES):
+            cur_offs = block_idx * BLOCK_SIZE + offs
+            in_datas = tl.load(input_ptr + input_stride_0 * input_token_idx + cur_offs, mask=cur_offs < head_num_dim)
+            for mem_index in tl.range(
+                dp_index * CARD_NUM_PER_D, (dp_index + 1) * CARD_NUM_PER_D, num_stages=NUM_STAGES
+            ):
+                output_ptr = tl.load(output_mems_ptr + mem_index).to(tl.pointer_type(input_ptr.dtype.element_ty))
+                tl.store(
+                    output_ptr + output_stride_0 * output_token_idx + cur_offs, in_datas, mask=cur_offs < head_num_dim
+                )
+
+        tid += grid_count
+
+    return
+
+
+def kv_trans_v2_for_d_node(
+    output_mems: torch.Tensor,
+    output_idx: torch.Tensor,
+    output_dp_idx: torch.Tensor,
+    input: torch.Tensor,
+    input_idx: torch.Tensor,
+    dp_size_in_node: int,
+):
+    """
+    output_mems is a torch.uint64 tensor holding the base pointers of the kv cache
+    buffers inside the mem_manager objects currently in use.
+    """
+    assert output_mems.is_contiguous()
+    assert input.is_contiguous()
+    assert len(output_mems.shape) == 1
+    assert len(input.shape) == 3
+    assert len(input_idx) == len(output_idx)
+    assert len(input_idx) == len(output_dp_idx)
+    assert len(output_mems) % dp_size_in_node == 0
+
+    card_num_per_d = len(output_mems) // dp_size_in_node
+
+    _, head_num, head_dim = input.shape
+    token_num = len(input_idx)
+    # Use modest launch resources for the transfer to avoid occupying too many SM compute units.
+    grid_count = 20
+    BLOCK_SIZE = 256
+    NUM_STAGES = 3
+    grid = (grid_count,)
+
+    _kv_trans_decode_node_kernel[grid](
+        output_mems,
+        *input.stride(),
+        output_idx,
+        output_dp_idx,
+        input,
+        *input.stride(),
+        input_idx,
+        token_num=token_num,
+        head_num=head_num,
+        head_dim=head_dim,
+        grid_count=grid_count,
+        BLOCK_SIZE=BLOCK_SIZE,
+        NUM_STAGES=NUM_STAGES,
+        CARD_NUM_PER_D=card_num_per_d,
+        num_warps=1,
+    )
+    return
```
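To pin down the decode kernel's semantics, here is a plain-PyTorch reference of the same scatter, a sketch under the assumption that `mems` holds the per-card KV buffers themselves rather than raw pointers: each moved token's row is read from `input` at `input_idx[t]` and written at `output_idx[t]` into every one of the `card_num_per_d` buffers of that token's DP group (the prefill-node kernel performs the inverse gather into the contiguous move buffer).

```python
import torch
from typing import List


def kv_trans_for_d_node_reference(
    mems: List[torch.Tensor],     # per-card kv buffers, each (kv_token_num, head_num, head_dim)
    output_idx: torch.Tensor,     # destination slot per moved token
    output_dp_idx: torch.Tensor,  # dp group per moved token
    input: torch.Tensor,          # contiguous move buffer, (token_num, head_num, head_dim)
    input_idx: torch.Tensor,      # source slot in the move buffer per token
    dp_size_in_node: int,
):
    """Hypothetical reference of _kv_trans_decode_node_kernel's effect, for clarity only."""
    card_num_per_d = len(mems) // dp_size_in_node
    for t in range(len(input_idx)):
        dp_index = int(output_dp_idx[t])
        row = input[int(input_idx[t])]
        # Fan the token's kv row out to every card of its dp group.
        for mem_index in range(dp_index * card_num_per_d, (dp_index + 1) * card_num_per_d):
            mems[mem_index][int(output_idx[t])] = row
```

The deliberately small launch (`grid_count = 20`, `num_warps=1`) matches the in-code comment: the transfer is meant to sip SMs so concurrent compute is not starved.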

unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py

+39 −3
```diff
@@ -1,14 +1,14 @@
 import pytest
 import torch
 import random
-from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2
+from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_p_node, kv_trans_v2_for_d_node


 @pytest.mark.parametrize(
     "token_num",
     [token_num for token_num in range(5, 10)],
 )
-def test_kv_trans_v2(token_num):
+def test_kv_trans_v2_for_p_node(token_num):
     dp_size_in_node = 8
     head_num = 2
     head_dim = 512
@@ -26,7 +26,7 @@ def test_kv_trans_v2(token_num):
     test_output = torch.zeros((token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
     output_idx = torch.arange(0, token_num, 1, dtype=torch.int32, device="cuda")

-    kv_trans_v2(input_mems, input_idx, input_dp_idx, test_output, output_idx, dp_size_in_node)
+    kv_trans_v2_for_p_node(input_mems, input_idx, input_dp_idx, test_output, output_idx, dp_size_in_node)

     for dest_token_index, token_index, dp_index in zip(
         list(range(token_num)), input_idx.cpu().numpy(), input_dp_idx.cpu().numpy()
@@ -37,5 +37,41 @@ def test_kv_trans_v2(token_num):
     return


+@pytest.mark.parametrize(
+    "token_num",
+    [token_num for token_num in range(5, 10)],
+)
+def test_kv_trans_v2_for_d_node(token_num):
+    card_num = 8
+    dp_size_in_node = 4
+    head_num = 2
+    head_dim = 512
+    kv_buffer_token_num = 512
+    mems = []
+    for _ in range(card_num):
+        mems.append(torch.randn((kv_buffer_token_num, head_num, head_dim), dtype=torch.float16, device="cuda"))
+    output_mems = torch.tensor([e.data_ptr() for e in mems], dtype=torch.uint64, device="cuda")
+    output_idx = [random.randint(0, kv_buffer_token_num - 1) for _ in range(token_num)]
+    output_idx = torch.tensor(output_idx, dtype=torch.int32, device="cuda")
+    output_dp_idx = [random.randint(0, dp_size_in_node - 1) for _ in range(token_num)]
+    output_dp_idx = torch.tensor(output_dp_idx, dtype=torch.int32, device="cuda")
+
+    test_input = torch.randn((token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
+    input_idx = torch.arange(0, token_num, 1, dtype=torch.int32, device="cuda")
+
+    kv_trans_v2_for_d_node(output_mems, output_idx, output_dp_idx, test_input, input_idx, dp_size_in_node)
+
+    # Check that every card in a token's dp group received the token's kv row.
+    for token_index, dest_token_index, dp_index in zip(
+        input_idx.cpu().numpy(),
+        output_idx.cpu().numpy(),
+        output_dp_idx.cpu().numpy(),
+    ):
+        for mem_index in range(dp_index * card_num // dp_size_in_node, (dp_index + 1) * card_num // dp_size_in_node):
+            assert torch.equal(mems[mem_index][dest_token_index, :, :], test_input[token_index, :, :])
+
+    return
+
+
 if __name__ == "__main__":
     pytest.main()
```
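Both tests need a CUDA device, since the kernels dereference raw device pointers. Assuming the repository root is on `PYTHONPATH`, a direct run of just this file might look like:

```python
import pytest

# Run only the kv_trans_v2 kernel tests; -q keeps the output terse.
pytest.main(["-q", "unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py"])
```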
