Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DeepseekMoE support with Fused MoE kernel #2453

Merged
merged 20 commits into from
Jan 30, 2024
11 changes: 11 additions & 0 deletions csrc/dispatch_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,14 @@
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
108 changes: 108 additions & 0 deletions csrc/moe_align_block_size_kernels.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

#include "cuda_compat.h"
#include "dispatch_utils.h"

const static size_t NUM_MAX_EXPERTS = 64;
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))

namespace vllm {
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
int32_t *sorted_token_ids,
int32_t *expert_ids,
int32_t *total_tokens_post_pad,
int32_t num_experts,
int32_t block_size,
size_t numel) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
__shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
Copy link
Collaborator

@WoosukKwon WoosukKwon Jan 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Actually, this part doesn't need to be static. Shared memory size can be configured dynamically at the kernel launch time. However, I think we can fix this in a later PR.

__shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[threadIdx.x + 1][i] = 0;
}

/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are assigned
* to expert expert_index.
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
}

__syncthreads();

// For each expert we accumulate the token counts from the different threads.
tokens_cnts[0][threadIdx.x] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
}

__syncthreads();

// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
}
*total_tokens_post_pad = cumsum[num_experts];
}

__syncthreads();

/**
* For each expert, each thread processes the tokens of the corresponding blocks
* and stores the corresponding expert_id for each block.
*/
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}

/**
* Each thread processes a token shard, calculating the index of each token after
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
* where * represents a padding value(preset in python).
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
* stores the indices of the tokens processed by the expert with expert_id within
* the current thread's token shard.
*/
int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[threadIdx.x][expert_id];
}
}
}

void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
assert(num_experts <= NUM_MAX_EXPERTS);
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] {
vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(),
num_experts,
block_size,
topk_ids.numel());
});
}
9 changes: 9 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,12 @@ std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
const std::vector<std::vector<int64_t>> &offsets);
#endif

void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad
);
4 changes: 4 additions & 0 deletions csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");

// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def get_torch_arch_list() -> Set[str]:
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
"csrc/quantization/gptq/q_gemm.cu",
"csrc/cuda_utils_kernels.cu",
"csrc/moe_align_block_size_kernels.cu",
"csrc/pybind.cpp",
]

Expand Down
49 changes: 49 additions & 0 deletions tests/kernels/test_fused_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.activation import SiluAndMul
import pytest


def torch_moe(a, w1, w2, topk_weight, topk_ids):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
out = torch.zeros(B * topk_ids.shape[1],
w2.shape[1],
dtype=a.dtype,
device=a.device)
topk_ids = topk_ids.view(-1)
topk_weight = topk_weight.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1)).sum(dim=1)


@pytest.mark.parametrize("m", [512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
):
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10

score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.softmax(score, dim=-1)
topk_weight, topk_ids = torch.topk(score, topk)

triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
Loading
Loading