 from typing import List, Union
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
+from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node
 
 logger = init_logger(__name__)
 
@@ -103,50 +104,99 @@ def send_to_decode_node_p2p(
         """
         Implementation that uses a p2p triton kernel to copy and transfer the data.
         """
-        assert dp_size_in_node == 1
+        if not hasattr(self, "mem_ptrs_dict"):
+            self.mem_ptrs_dict = {}
+            for layer_index in range(self.layer_num):
+                mems_ptr = []
+                for i in range(0, len(mem_managers), len(mem_managers) // dp_size_in_node):
+                    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
+                mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")
+                self.mem_ptrs_dict[layer_index] = mems_ptr
 
         move_token_indexes = []
+        token_dp_indexes = []
         for task in move_tasks:
             if task.move_kv_len != 0:
                 move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
+                token_dp_indexes.extend([task.prefill_dp_index for _ in range(task.move_kv_len)])
 
         move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+        token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda")
         for layer_index in range(self.layer_num):
-            move_buffer = self._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
+            move_buffer = self._get_kv_move_data_p2p(
+                move_token_indexes, token_dp_indexes, layer_index, self.kv_move_buffer, dp_size_in_node
+            )
             dist.send(move_buffer, dst=1)
         return
 
-    def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
+    def _get_kv_move_data_p2p(
+        self,
+        token_indexes: torch.Tensor,
+        token_dp_indexes: torch.Tensor,
+        layer_index: int,
+        kv_move_buffer: torch.Tensor,
+        dp_size_in_node: int,
+    ):
         move_token_num = len(token_indexes)
         move_size = self.kv_buffer.numel() // self.layer_num // self.size * move_token_num
         move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, self.head_num, self.head_dim)
-        kv_trans(
-            self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num]
+        kv_trans_v2_for_p_node(
+            input_mems=self.mem_ptrs_dict[layer_index],
+            input_idx=token_indexes,
+            input_dp_idx=token_dp_indexes,
+            output=move_buffer,
+            output_idx=self.kv_move_buf_indexes[0:move_token_num],
+            dp_size_in_node=dp_size_in_node,
         )
         return move_buffer
 
     def receive_from_prefill_node_p2p(
         self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
     ):
-        assert dp_size_in_node == 1
+        if not hasattr(self, "mem_ptrs_dict"):
+            self.mem_ptrs_dict = {}
+            for layer_index in range(self.layer_num):
+                mems_ptr = []
+                for i in range(0, len(mem_managers)):
+                    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
+                mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")
+                self.mem_ptrs_dict[layer_index] = mems_ptr
 
         move_token_indexes = []
+        token_dp_indexes = []
         for task in move_tasks:
             if task.move_kv_len != 0:
                 move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
+                token_dp_indexes.extend([task.decode_dp_index for _ in range(task.move_kv_len)])
 
         move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+        token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda")
 
         token_num = len(move_token_indexes)
         move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim)
         for layer_index in range(self.layer_num):
             dist.recv(recive_buffer, src=0)
-            for i, mem in enumerate(mem_managers):
-                mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
+            self._write_kv_move_data_p2p(
+                move_token_indexes, token_dp_indexes, recive_buffer, layer_index, dp_size_in_node
+            )
         return
 
-    def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):
+    def _write_kv_move_data_p2p(
+        self,
+        token_indexes: torch.Tensor,
+        token_dp_indexes: torch.Tensor,
+        buffer_tensor: torch.Tensor,
+        layer_index,
+        dp_size_in_node: int,
+    ):
         move_token_num = len(token_indexes)
-        kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes)
+        kv_trans_v2_for_d_node(
+            output_mems=self.mem_ptrs_dict[layer_index],
+            output_idx=token_indexes,
+            output_dp_idx=token_dp_indexes,
+            input=buffer_tensor,
+            input_idx=self.kv_move_buf_indexes[0:move_token_num],
+            dp_size_in_node=dp_size_in_node,
+        )
         return
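
On the prefill side, the diff replaces the single-pool `kv_trans` copy with `kv_trans_v2_for_p_node`: each layer keeps one base pointer per dp group in `mem_ptrs_dict[layer_index]`, and every moved token carries a `prefill_dp_index` so the kernel can read its KV row from the pool of the group that owns it. The Triton kernel itself is not part of this diff; the snippet below is only a plain-PyTorch sketch of the gather it is expected to perform, with a hypothetical helper name and tensor handles standing in for the raw `data_ptr()` addresses.

```python
import torch

def gather_kv_for_p_node_sketch(dp_kv_buffers, input_idx, input_dp_idx, output, output_idx):
    # dp_kv_buffers: one [size, head_num, head_dim] layer slice per dp group
    #                (stand-ins for the uint64 pointers in mem_ptrs_dict[layer_index])
    # input_idx    : int64 token rows inside the owning group's KV pool
    # input_dp_idx : int32 dp group of each token (from prefill_dp_index)
    # output       : [move_token_num, head_num, head_dim] staging buffer passed to dist.send
    # output_idx   : destination rows, i.e. kv_move_buf_indexes[:move_token_num]
    for i in range(input_idx.numel()):
        dp = int(input_dp_idx[i])
        output[output_idx[i]] = dp_kv_buffers[dp][input_idx[i]]
    return output
```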
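On the decode side, the old code looped over every memory manager and wrote the received buffer into each of them; `kv_trans_v2_for_d_node` folds that loop into one launch and adds dp routing, with every manager on the node registered in `mem_ptrs_dict`. The sketch below is again a hedged plain-PyTorch stand-in, not the kernel: it assumes the registered managers are laid out group by group, `len(mem_managers) // dp_size_in_node` per dp group, and writes each received row into the pools of the group named by `output_dp_idx`.

```python
import torch

def scatter_kv_for_d_node_sketch(node_kv_buffers, output_idx, output_dp_idx,
                                 input, input_idx, dp_size_in_node):
    # node_kv_buffers: one [size, head_num, head_dim] layer slice per memory manager on the
    #                  node (stand-ins for the pointers in mem_ptrs_dict[layer_index])
    # output_idx     : int64 destination rows inside the owning dp group's pools
    # output_dp_idx  : int32 dp group of each received token (from decode_dp_index)
    # input          : [token_num, head_num, head_dim] buffer filled by dist.recv
    # input_idx      : source rows, i.e. kv_move_buf_indexes[:move_token_num]
    mems_per_group = len(node_kv_buffers) // dp_size_in_node
    for i in range(output_idx.numel()):
        dp = int(output_dp_idx[i])
        # assumed fan-out: copy the row into every manager belonging to this dp group
        for m in range(dp * mems_per_group, (dp + 1) * mems_per_group):
            node_kv_buffers[m][output_idx[i]] = input[input_idx[i]]
    return node_kv_buffers
```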