Commit
feat: better support devices which do not have bf16 (#34)
* Fix .any

* Add cast_f32_bf16

* Fix cast_f32_bf16

* Remove analysis

* Add some more fns
EricLBuehler authored Jan 8, 2025
1 parent ac4423f commit 1d1f0bf
Showing 4 changed files with 58 additions and 77 deletions.
64 changes: 0 additions & 64 deletions .github/workflows/analysis.yaml

This file was deleted.
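The cast.cu changes below extend an existing software-fallback scheme: on devices without native bf16 support (compute capability below 8.0), bf16 tensors are carried as raw uint16_t and converted through float. Previously only the bf16-to-other-type direction had fallback kernels; this commit adds the other-type-to-bf16 direction. A rough sketch of the pattern (illustrative outline only, not verbatim file contents; the full context is in the diff below):

// Illustrative only: fallbacks compiled when native bf16 is unavailable.
#if __CUDA_ARCH__ < 800
// Kernels keep their exported names but take uint16_t and convert
// element-wise through float using the helpers in dummy_bf16.cuh.
CAST_FROM_BF16_FALLBACK_OP(float, cast_bf16_f32)  // bf16 -> f32 (already existed)
CAST_TO_BF16_FALLBACK_OP(float, cast_f32_bf16)    // f32 -> bf16 (added by this commit)
#endif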

58 changes: 48 additions & 10 deletions diffusion_rs_common/src/cuda_kernels/cast.cu
@@ -98,7 +98,30 @@ __device__ void cast_through(
}

template <typename T>
__device__ void cast_bf16_dummy(
__device__ void cast_to_bf16_dummy(
const size_t numel,
const size_t num_dims,
const size_t *info,
const T *inp,
uint16_t *out
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = fallback_f32_to_bf16(static_cast<float>(inp[i]));
}
}
else {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
out[i] = fallback_f32_to_bf16(static_cast<float>(inp[strided_i]));
}
}
}

template <typename T>
__device__ void cast_from_bf16_dummy(
const size_t numel,
const size_t num_dims,
const size_t *info,
@@ -109,13 +132,13 @@ __device__ void cast_bf16_dummy(
const size_t *strides = info + num_dims;
if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = static_cast<T>(bf16_to_f32(inp[i]));
out[i] = static_cast<T>(fallback_bf16_to_f32(inp[i]));
}
}
else {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
out[i] = static_cast<T>(bf16_to_f32(inp[strided_i]));
out[i] = static_cast<T>(fallback_bf16_to_f32(inp[strided_i]));
}
}
}
@@ -167,15 +190,26 @@ extern "C" __global__ void FN_NAME( \
cast_through<SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME>(numel, num_dims, info, inp, out); \
} \

#define CAST_BF16_FALLBACK_OP(DST_TYPENAME, FN_NAME) \
#define CAST_TO_BF16_FALLBACK_OP(SRC_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
const size_t *info, \
const SRC_TYPENAME *inp, \
uint16_t *out \
) { \
cast_to_bf16_dummy<SRC_TYPENAME>(numel, num_dims, info, inp, out); \
} \

#define CAST_FROM_BF16_FALLBACK_OP(DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
const size_t *info, \
const uint16_t *inp, \
DST_TYPENAME *out \
) { \
cast_bf16_dummy<DST_TYPENAME>(numel, num_dims, info, inp, out); \
cast_from_bf16_dummy<DST_TYPENAME>(numel, num_dims, info, inp, out); \
} \

#if __CUDA_ARCH__ >= 800
@@ -232,11 +266,15 @@ CAST_OP(int32_t, __half, cast_i32_f16 )
CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32)

#if __CUDA_ARCH__ < 800
CAST_BF16_FALLBACK_OP(uint32_t, cast_bf16_u32)
CAST_BF16_FALLBACK_OP(float, cast_bf16_f32)
CAST_BF16_FALLBACK_OP(double, cast_bf16_f64)
CAST_BF16_FALLBACK_OP(__half, cast_bf16_f16)
CAST_BF16_FALLBACK_OP(int32_t, cast_bf16_i32)
CAST_FROM_BF16_FALLBACK_OP(uint32_t, cast_bf16_u32)
CAST_FROM_BF16_FALLBACK_OP(float, cast_bf16_f32)
CAST_FROM_BF16_FALLBACK_OP(double, cast_bf16_f64)
CAST_FROM_BF16_FALLBACK_OP(__half, cast_bf16_f16)
CAST_FROM_BF16_FALLBACK_OP(int32_t, cast_bf16_i32)

CAST_TO_BF16_FALLBACK_OP(float, cast_f32_bf16)
CAST_TO_BF16_FALLBACK_OP(__half, cast_f16_bf16)
CAST_TO_BF16_FALLBACK_OP(double, cast_f64_bf16)
#endif
#endif

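To make the macro plumbing concrete, expanding CAST_TO_BF16_FALLBACK_OP(float, cast_f32_bf16) by hand gives roughly the following kernel (reconstructed from the macro above, not copied from the source). The exported name matches the native-bf16 build, but the output is plain uint16_t, so pre-Ampere devices never touch __nv_bfloat16:

extern "C" __global__ void cast_f32_bf16(
    const size_t numel,
    const size_t num_dims,
    const size_t *info,
    const float *inp,
    uint16_t *out
) {
    // Element-wise f32 -> bf16 via the software helper; handles both
    // contiguous and strided inputs using the dims/strides packed in `info`.
    cast_to_bf16_dummy<float>(numel, num_dims, info, inp, out);
}

The CAST_FROM_BF16_FALLBACK_OP expansion is symmetric, with uint16_t input and the destination type as output.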
4 changes: 2 additions & 2 deletions diffusion_rs_common/src/cuda_kernels/dummy_bf16.cuh
@@ -1,6 +1,6 @@
#include<stdint.h>

__device__ __forceinline__ float bf16_to_f32(const uint16_t i)
__device__ __forceinline__ float fallback_bf16_to_f32(const uint16_t i)
{
// If NaN, keep current mantissa but also set most significant mantissa bit
if ((i & 0x7FFFu) > 0x7F80u) {
@@ -29,7 +29,7 @@ __device__ __forceinline__ float bf16_to_f32(const uint16_t i)
}

// Convert FP32 (float) to BF16 (unsigned short)
__device__ __forceinline__ uint16_t f32_to_bf16(const float value)
__device__ __forceinline__ uint16_t fallback_f32_to_bf16(const float value)
{
// Reinterpret float bits as uint32_t
union {
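The conversion bodies are collapsed in the diff above. As a rough sketch of what such fallbacks typically do (illustrative, with hypothetical sketch_* names, and not necessarily matching dummy_bf16.cuh line for line): bf16 is the upper 16 bits of an IEEE-754 f32, so widening is a left shift and narrowing is a truncation with round-to-nearest-even:

__device__ __forceinline__ float sketch_bf16_to_f32(uint16_t h) {
    // Place the bf16 bits in the high half of an f32 bit pattern.
    // Example: 0x3F80 -> 0x3F800000 -> 1.0f.
    union { uint32_t u; float f; } v;
    v.u = (uint32_t)h << 16;
    return v.f;
}

__device__ __forceinline__ uint16_t sketch_f32_to_bf16(float f) {
    union { float f; uint32_t u; } v;
    v.f = f;
    // Round to nearest, ties to even: add 0x7FFF plus the LSB of the kept
    // half, then drop the low 16 bits. NaN handling (which the real file's
    // comment hints at) is omitted here for brevity.
    uint32_t rounding = 0x7FFFu + ((v.u >> 16) & 1u);
    return (uint16_t)((v.u + rounding) >> 16);
}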
9 changes: 8 additions & 1 deletion diffusion_rs_core/src/models/t5/mod.rs
@@ -7,6 +7,7 @@ use diffusion_rs_backend::{linear_no_bias, QuantMethod, QuantizedConfig};
use diffusion_rs_common::core::{DType, Device, Module, Result, Tensor, D};
use diffusion_rs_common::nn::{Activation, Embedding};
use diffusion_rs_common::{embedding, VarBuilder};
use float8::F8E4M3;
use serde::Deserialize;
use std::sync::Arc;

@@ -479,11 +480,17 @@ impl TensorInfExtend for Tensor {
fn any(&self) -> Result<bool> {
let sum = self.sum_all()?;
match self.dtype() {
DType::I8 => Ok(sum.to_scalar::<u8>()? == 0),
DType::U8 => Ok(sum.to_scalar::<u8>()? == 0),
DType::U32 => Ok(sum.to_scalar::<u32>()? == 0),
DType::I16 => Ok(sum.to_scalar::<i16>()? == 0),
DType::I32 => Ok(sum.to_scalar::<i32>()? == 0),
DType::I64 => Ok(sum.to_scalar::<i64>()? == 0),
DType::F16 => Ok(sum.to_scalar::<half::f16>()? == half::f16::from_f32_const(0.)),
DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? == half::bf16::from_f32_const(0.)),
DType::F32 => Ok(sum.to_scalar::<f32>()? == 0.),
DType::F64 => Ok(sum.to_scalar::<f64>()? == 0.),
_ => unreachable!(),
DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? == F8E4M3::ZERO),
}
}
}
