Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup bincount #10308

Merged
merged 6 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 68 additions & 23 deletions oneflow/user/kernels/bincount_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,67 @@ namespace oneflow {
namespace user_op {
namespace {

// clang-format off
template<typename IDX, typename T>
__global__ static void BinCountCompute(const IDX* in_ptr, const T* weight, T* out_ptr, int64_t size) {
CUDA_1D_KERNEL_LOOP(i, size) {
IDX idx = *(in_ptr + i);
cuda::atomic::Add(out_ptr + idx, weight[i]);
template<typename IDX, typename T, bool UseGlobalMem>
__global__ static void BinCountCompute(const IDX* in_ptr, const T* weight, T* out_ptr,
int64_t in_size, int64_t out_size) {
if constexpr (UseGlobalMem) {
CUDA_1D_KERNEL_LOOP(i, in_size) {
IDX idx = *(in_ptr + i);
cuda::atomic::Add(out_ptr + idx, weight[i]);
}
} else {
__shared__ T shm[kCudaThreadsNumPerBlock];
T zero = GetZeroVal<T>();
shm[threadIdx.x] = zero;
__syncthreads();
CUDA_1D_KERNEL_LOOP(i, in_size) {
IDX idx = *(in_ptr + i);
cuda::atomic::Add(shm + idx, weight[i]);
}
__syncthreads();
if (threadIdx.x < out_size) { cuda::atomic::Add(out_ptr + threadIdx.x, shm[threadIdx.x]); }
}
};
// clang-format on

template<typename IDX, typename T>
__global__ static void BinCountCompute(const IDX* in_ptr, T* out_ptr, int64_t size) {
template<typename IDX, typename T, bool UseGlobalMem>
__global__ static void BinCountCompute(const IDX* in_ptr, T* out_ptr, int64_t in_size,
int64_t out_size) {
T one = GetOneVal<T>();
CUDA_1D_KERNEL_LOOP(i, size) {
IDX idx = *(in_ptr + i);
cuda::atomic::Add(out_ptr + idx, one);
if constexpr (UseGlobalMem) {
CUDA_1D_KERNEL_LOOP(i, in_size) {
IDX idx = *(in_ptr + i);
cuda::atomic::Add(out_ptr + idx, one);
}
} else {
__shared__ T shm[kCudaThreadsNumPerBlock];
T zero = GetZeroVal<T>();
shm[threadIdx.x] = zero;
__syncthreads();
CUDA_1D_KERNEL_LOOP(i, in_size) {
IDX idx = *(in_ptr + i);
cuda::atomic::Add(shm + idx, one);
}
__syncthreads();
if (threadIdx.x < out_size) { cuda::atomic::Add(out_ptr + threadIdx.x, shm[threadIdx.x]); }
}
};

template<typename IDX, typename T, bool UseGlobalMem>
static void BinCountDispatch(user_op::KernelComputeContext* ctx, const IDX* in_ptr,
const T* weight_ptr, T* out_ptr, int64_t in_size, int64_t out_size) {
if (weight_ptr) {
BinCountCompute<IDX, T, UseGlobalMem>
<<<BlocksNum4ThreadsNum(in_size), kCudaThreadsNumPerBlock, 0,
ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(in_ptr, weight_ptr, out_ptr,
in_size, out_size);
} else {
BinCountCompute<IDX, T, UseGlobalMem>
<<<BlocksNum4ThreadsNum(in_size), kCudaThreadsNumPerBlock, 0,
ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(in_ptr, out_ptr, in_size,
out_size);
}
}

template<typename IDX, typename T>
class CUDABinCountKernel final : public user_op::OpKernel {
public:
Expand All @@ -52,31 +94,34 @@ class CUDABinCountKernel final : public user_op::OpKernel {
using user_op::OpKernel::Compute;
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
size_t out_size = ctx->Attr<int64_t>("size") * sizeof(T);
size_t out_size = ctx->Attr<int64_t>("size");
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
const IDX* in_ptr = in->dptr<IDX>();
T* out_ptr = out->mut_dptr<T>();

std::unique_ptr<ep::primitive::Memset> memset_primitive =
ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->device_type());
CHECK(memset_primitive);
memset_primitive->Launch(ctx->stream(), out_ptr, 0, out_size);
int64_t in_size = in->shape_view().elem_cnt();
memset_primitive->Launch(ctx->stream(), out_ptr, 0, out_size * sizeof(T));

const int64_t in_size = in->shape_view().elem_cnt();
if (in_size == 0) { return; }

const T* weight_ptr = nullptr;
if (ctx->has_input("weight", 0)) {
const T* weight_ptr = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr<T>();
BinCountCompute<IDX, T><<<BlocksNum4ThreadsNum(in_size), kCudaThreadsNumPerBlock, 0,
ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(
in_ptr, weight_ptr, out_ptr, in_size);
weight_ptr = ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr<T>();
};

if (out_size > kCudaThreadsNumPerBlock) {
BinCountDispatch<IDX, T, true>(ctx, in_ptr, weight_ptr, out_ptr, in_size, out_size);
} else {
BinCountCompute<IDX, T>
<<<BlocksNum4ThreadsNum(in_size), kCudaThreadsNumPerBlock, 0,
ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(in_ptr, out_ptr, in_size);
BinCountDispatch<IDX, T, false>(ctx, in_ptr, weight_ptr, out_ptr, in_size, out_size);
}
};
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

} // namespace oneflow
} // namespace

#define REGISTER_CUDA_BINCOUNT_KERNEL(idx_type, dtype) \
REGISTER_USER_KERNEL("bincount") \
Expand Down
14 changes: 10 additions & 4 deletions python/oneflow/test/modules/test_bincount.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,39 @@ class TestBinCount(flow.unittest.TestCase):
@autotest(n=5, auto_backward=False, check_graph=False)
def test_bincount(test_case):
device = random_device()
x = random_tensor(1, 100, low=0, dtype=int).to(device)
x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device)
result = torch.bincount(x)
return result

@autotest(n=5, auto_backward=False, check_graph=False)
def test_bincount_weight(test_case):
device = random_device()
x = random_tensor(1, 100, low=0, dtype=int).to(device)
x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device)
weight = random_tensor(1, 100).to(device)
return torch.bincount(x, weights=weight)

@autotest(n=5, auto_backward=False, check_graph=False)
def test_bincount_minlength(test_case):
device = random_device()
x = random_tensor(1, 100, low=0, dtype=int).to(device)
x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device)
weight = random_tensor(1, 100).to(device)
minlength = random(1, 200).to(int)
return torch.bincount(x, weights=weight, minlength=minlength)

@autotest(n=5, auto_backward=False, check_graph=False)
def test_bincount_0element(test_case):
device = random_device()
x = random_tensor(1, 0, low=0, dtype=int).to(device)
x = random_tensor(1, 0, low=0, high=65536, dtype=int).to(device)
weight = random_tensor(1, 0).to(device)
minlength = random(1, 200).to(int)
return torch.bincount(x, weights=weight, minlength=minlength)

@profile(torch.bincount)
def profile_bincount(test_case):
torch.bincount(torch.ones(4096).int())
torch.bincount(torch.ones(65536).int())
torch.bincount(torch.arange(4096).int())


if __name__ == "__main__":
unittest.main()