Skip to content

Commit

Permalink
feat: support model dtype selection and cuda fallback cast for bf16 (#33
Browse files Browse the repository at this point in the history
)

* Support cast from bf16 weights on cuda

* Add automatic dtype selection

* Avoid double encoding of metal command buffer

* Add default value for cli

* Fix test and python api

* Update readme
  • Loading branch information
EricLBuehler authored Jan 8, 2025
1 parent 9da4ea6 commit ac4423f
Show file tree
Hide file tree
Showing 17 changed files with 416 additions and 29 deletions.
64 changes: 64 additions & 0 deletions .github/workflows/analysis.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Posts (or refreshes) a pull-request comment containing a tokei
# lines-of-code report for the checked-out tree.
name: Analysis
on:
  pull_request_target

# Least privilege for the job token: read the repository, write PR comments.
permissions:
  contents: read
  pull-requests: write

jobs:
  comment:
    runs-on: ubuntu-latest
    steps:
      # NOTE(review): with pull_request_target, actions/checkout checks out the
      # base branch by default, so the report below describes the base tree
      # rather than the PR head -- confirm this is intended.
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install Rust and Cargo
        run: |
          curl -sSf https://sh.rustup.rs | sh -s -- -y
          # `source $HOME/.cargo/env` would only affect this step's shell;
          # GITHUB_PATH makes cargo available to every later step.
          echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"

      - name: Install Tokei
        run: cargo install tokei

      - name: Run Tokei and get the lines of code
        run: tokei . > tokei_output.txt

      - name: Comment or Update PR
        uses: actions/github-script@v7
        with:
          script: |
            const fs = require('fs');
            const tokeiOutput = fs.readFileSync('tokei_output.txt', 'utf8');
            // Marker used to recognise a report posted by a previous run.
            const uniqueIdentifier = 'Code Metrics Report';
            const codeReport = `
              <details>
              <summary>${uniqueIdentifier}</summary>
              <pre>
              ${tokeiOutput}
              </pre>
              </details>
            `;
            const issue_number = context.issue.number;
            const { owner, repo } = context.repo;
            // Walk every page: a single listComments call returns only the
            // first page, so an older report on a busy PR would be missed
            // and a duplicate comment created.
            const comments = await github.paginate(github.rest.issues.listComments, {
              issue_number,
              owner,
              repo,
              per_page: 100
            });
            const existingComment = comments.find(
              comment => comment.body && comment.body.includes(uniqueIdentifier)
            );
            if (existingComment) {
              await github.rest.issues.updateComment({
                owner,
                repo,
                comment_id: existingComment.id,
                body: codeReport
              });
            } else {
              await github.rest.issues.createComment({
                issue_number,
                owner,
                repo,
                body: codeReport
              });
            }
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Blazingly fast inference of diffusion models.
</h3>

<p align="center">
| <a href="https://ericlbuehler.github.io/diffusion-rs/diffusion_rs_core/"><b>Rust Documentation</b></a> | <a href="https://ericlbuehler.github.io/diffusion-rs/pyo3/diffusion_rs.html"><b>Python Documentation</b></a> | <a href="https://discord.gg/DRcvs6z5vu"><b>Discord</b></a> |
| <a href="https://ericlbuehler.github.io/diffusion-rs/diffusion_rs_core/"><b>Latest Rust Documentation</b></a> | <a href="https://ericlbuehler.github.io/diffusion-rs/pyo3/diffusion_rs.html"><b>Latest Python Documentation</b></a> | <a href="https://discord.gg/DRcvs6z5vu"><b>Discord</b></a> |
</p>


Expand Down Expand Up @@ -72,7 +72,7 @@ Examples with the Rust crate: [here](diffusion_rs_examples/examples).
```rust
use std::time::Instant;

use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource};
use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, ModelDType, Offloading, Pipeline, TokenSource};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::EnvFilter;

Expand All @@ -87,6 +87,7 @@ let pipeline = Pipeline::load(
TokenSource::CacheToken,
None,
None,
&ModelDType::Auto,
)?;

let start = Instant::now();
Expand Down
8 changes: 6 additions & 2 deletions diffusion_rs_cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{path::PathBuf, time::Instant};

use clap::{Parser, Subcommand};
use diffusion_rs_core::{
DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource,
DiffusionGenerationParams, ModelDType, ModelSource, Offloading, Pipeline, TokenSource,
};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::EnvFilter;
Expand Down Expand Up @@ -48,6 +48,10 @@ struct Args {
/// Offloading setting to use for this model
#[arg(short, long)]
offloading: Option<Offloading>,

/// DType for the model. The default is to use an automatic strategy with a fallback pattern: BF16 -> F16 -> F32
#[arg(short, long, default_value = "auto")]
dtype: ModelDType,
}

fn main() -> anyhow::Result<()> {
Expand All @@ -67,7 +71,7 @@ fn main() -> anyhow::Result<()> {
.map(TokenSource::Literal)
.unwrap_or(TokenSource::CacheToken);

let pipeline = Pipeline::load(source, false, token, None, args.offloading)?;
let pipeline = Pipeline::load(source, false, token, None, args.offloading, &args.dtype)?;

let height: usize = input("Height:")
.default_input("720")
Expand Down
43 changes: 43 additions & 0 deletions diffusion_rs_common/src/cuda_kernels/cast.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "cuda_utils.cuh"
#include<stdint.h>
#include "dummy_bf16.cuh"

template <typename S, typename T>
__device__ void cast_(
Expand Down Expand Up @@ -96,6 +97,29 @@ __device__ void cast_through(
}
}

// Element-wise cast of `numel` raw bf16 bit patterns (`inp`, stored as
// uint16_t) into `out` of type T. Each element is widened to float with
// bf16_to_f32 and then static_cast to T.
//
// `info` is read as `num_dims` dims followed by `num_dims` strides. When
// `info` is null or describes a contiguous layout the input is indexed
// linearly; otherwise the source index comes from get_strided_index. The
// output is always written at the linear index. Threads cover the range
// with a grid-stride loop.
template <typename T>
__device__ void cast_bf16_dummy(
    const size_t numel,
    const size_t num_dims,
    const size_t *info,
    const uint16_t *inp,
    T *out
) {
    const size_t *dims = info;
    const size_t *strides = info + num_dims;
    const bool linear = info == nullptr || is_contiguous(num_dims, dims, strides);

    const unsigned int first = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int step = blockDim.x * gridDim.x;

    if (!linear) {
        // Strided input: translate each linear position to its source offset.
        for (unsigned int idx = first; idx < numel; idx += step) {
            const unsigned src = get_strided_index(idx, num_dims, dims, strides);
            out[idx] = static_cast<T>(bf16_to_f32(inp[src]));
        }
        return;
    }

    // Contiguous input: source and destination share the same index.
    for (unsigned int idx = first; idx < numel; idx += step) {
        out[idx] = static_cast<T>(bf16_to_f32(inp[idx]));
    }
}


#define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
Expand Down Expand Up @@ -143,6 +167,17 @@ extern "C" __global__ void FN_NAME( \
cast_through<SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME>(numel, num_dims, info, inp, out); \
} \

// Defines an `extern "C"` kernel named FN_NAME that casts raw bf16 bit
// patterns (passed as uint16_t) to DST_TYPENAME by forwarding all of its
// arguments to cast_bf16_dummy<DST_TYPENAME>.
#define CAST_BF16_FALLBACK_OP(DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
const size_t *info, \
const uint16_t *inp, \
DST_TYPENAME *out \
) { \
cast_bf16_dummy<DST_TYPENAME>(numel, num_dims, info, inp, out); \
} \

#if __CUDA_ARCH__ >= 800
#include "cuda_fp8.h"
#include "cuda_bf16.h"
Expand Down Expand Up @@ -195,6 +230,14 @@ CAST_OP(float, __half, cast_f32_f16)
CAST_OP(double, __half, cast_f64_f16)
CAST_OP(int32_t, __half, cast_i32_f16 )
CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32)

#if __CUDA_ARCH__ < 800
CAST_BF16_FALLBACK_OP(uint32_t, cast_bf16_u32)
CAST_BF16_FALLBACK_OP(float, cast_bf16_f32)
CAST_BF16_FALLBACK_OP(double, cast_bf16_f64)
CAST_BF16_FALLBACK_OP(__half, cast_bf16_f16)
CAST_BF16_FALLBACK_OP(int32_t, cast_bf16_i32)
#endif
#endif

CAST_OP(uint32_t, uint32_t, cast_u32_u32)
Expand Down
56 changes: 56 additions & 0 deletions diffusion_rs_common/src/cuda_kernels/dummy_bf16.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#include<stdint.h>

// Widen a raw bfloat16 bit pattern to float.
//
// The 16 input bits become the upper 16 bits of the binary32 result and the
// lower 16 bits are zero. NaN inputs (exponent all ones, non-zero mantissa)
// additionally get the most significant mantissa bit set, so they are
// returned as quiet NaNs with the rest of their payload preserved.
__device__ __forceinline__ float bf16_to_f32(const uint16_t i)
{
    uint32_t bits = static_cast<uint32_t>(i);

    // Ignoring the sign bit, anything above the infinity pattern is a NaN.
    if ((bits & 0x7FFFu) > 0x7F80u) {
        bits |= 0x0040u;
    }

    // Reinterpret via the CUDA intrinsic instead of reading an inactive
    // union member, which is undefined behaviour in C++.
    return __uint_as_float(bits << 16);
}

// Convert FP32 (float) to a raw BF16 bit pattern (uint16_t), rounding to
// nearest with ties to even.
//
// NaN inputs keep the high part of their mantissa and get the most
// significant mantissa bit set, so truncation can never turn a NaN into an
// infinity.
__device__ __forceinline__ uint16_t f32_to_bf16(const float value)
{
    // Reinterpret via the CUDA intrinsic instead of reading an inactive
    // union member, which is undefined behaviour in C++.
    const uint32_t x = __float_as_uint(value);

    // Ignoring the sign bit, anything above the infinity pattern is a NaN.
    if ((x & 0x7FFFFFFFu) > 0x7F800000u) {
        return static_cast<uint16_t>((x >> 16) | 0x0040u);
    }

    // Bit 15 is the first discarded bit.
    constexpr uint32_t round_bit = 0x00008000u;
    // 0x17FFF: bits 0..14 (the remaining discarded bits) plus bit 16 (the
    // least significant bit that is kept).
    constexpr uint32_t sticky_or_odd = 3 * round_bit - 1;

    // Round up when the discarded part is above one half, or exactly one
    // half while the kept value is odd (ties to even). The increment is
    // applied to the whole bit field, so a mantissa carry moves into the
    // exponent.
    if ((x & round_bit) != 0 && (x & sticky_or_odd) != 0) {
        return static_cast<uint16_t>((x >> 16) + 1);
    }
    return static_cast<uint16_t>(x >> 16);
}
18 changes: 4 additions & 14 deletions diffusion_rs_common/src/varbuilder_loading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,8 @@ impl TensorLoaderBackend for SafetensorBackend {
.map(|(name, _)| name)
.collect::<Vec<_>>()
}
fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
let t = self.0.load(name, device)?;
if let Some(dtype) = dtype {
t.to_dtype(dtype)
} else {
Ok(t)
}
fn load_name(&self, name: &str, device: &Device, _dtype: Option<DType>) -> Result<Tensor> {
self.0.load(name, device)
}
}

Expand All @@ -53,13 +48,8 @@ impl TensorLoaderBackend for BytesSafetensorBackend<'_> {
.map(|(name, _)| name)
.collect::<Vec<_>>()
}
fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
let t = self.0.load(name, device)?;
if let Some(dtype) = dtype {
t.to_dtype(dtype)
} else {
Ok(t)
}
fn load_name(&self, name: &str, device: &Device, _dtype: Option<DType>) -> Result<Tensor> {
self.0.load(name, device)
}
}

Expand Down
5 changes: 4 additions & 1 deletion diffusion_rs_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
//! ```rust,no_run
//! use std::time::Instant;
//!
//! use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource};
//! use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, ModelDType, Offloading, Pipeline, TokenSource};
//!
//! let pipeline = Pipeline::load(
//! ModelSource::dduf("FLUX.1-dev-Q4-bnb.dduf")?,
//! true,
//! TokenSource::CacheToken,
//! None,
//! None,
//! &ModelDType::Auto,
//! )?;
//!
//! let start = Instant::now();
Expand All @@ -37,6 +38,8 @@
mod models;
mod pipelines;
mod util;

pub use diffusion_rs_common::{ModelSource, TokenSource};
pub use pipelines::{DiffusionGenerationParams, Offloading, Pipeline};
pub use util::{ModelDType, TryIntoDType};
11 changes: 9 additions & 2 deletions diffusion_rs_core/src/models/vaes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use autoencoder_kl::{AutencoderKlConfig, AutoEncoderKl};
use diffusion_rs_common::{
core::{Device, Result, Tensor},
core::{DType, Device, Result, Tensor},
ModelSource,
};
use serde::Deserialize;
Expand Down Expand Up @@ -46,10 +46,17 @@ pub(crate) fn dispatch_load_vae_model(
cfg_json: &FileData,
safetensor_files: Vec<FileData>,
device: &Device,
dtype: DType,
silent: bool,
source: Arc<ModelSource>,
) -> anyhow::Result<Arc<dyn VAEModel>> {
let vb = from_mmaped_safetensors(safetensor_files, None, device, silent, source.clone())?;
let vb = from_mmaped_safetensors(
safetensor_files,
Some(dtype),
device,
silent,
source.clone(),
)?;

let VaeConfigShim { name } = serde_json::from_str(&cfg_json.read_to_string(&source)?)?;
match name.as_str() {
Expand Down
8 changes: 5 additions & 3 deletions diffusion_rs_core/src/pipelines/flux/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ impl Loader for FluxLoader {
&self,
mut components: HashMap<ComponentName, ComponentElem>,
device: &Device,
dtype: DType,
silent: bool,
offloading_type: Option<Offloading>,
source: Arc<ModelSource>,
Expand Down Expand Up @@ -96,7 +97,7 @@ impl Loader for FluxLoader {

let vb = from_mmaped_safetensors(
safetensors.into_values().collect(),
None,
Some(dtype),
device,
silent,
source.clone(),
Expand All @@ -116,7 +117,7 @@ impl Loader for FluxLoader {
let cfg: T5Config = serde_json::from_str(&config.read_to_string(&source)?)?;
let vb = from_mmaped_safetensors(
safetensors.into_values().collect(),
None,
Some(dtype),
&t5_flux_device,
silent,
source.clone(),
Expand All @@ -137,6 +138,7 @@ impl Loader for FluxLoader {
&config,
safetensors.into_values().collect(),
device,
dtype,
silent,
source.clone(),
)?
Expand All @@ -154,7 +156,7 @@ impl Loader for FluxLoader {
let cfg: FluxConfig = serde_json::from_str(&config.read_to_string(&source)?)?;
let vb = from_mmaped_safetensors(
safetensors.into_values().collect(),
None,
Some(dtype),
&t5_flux_device,
silent,
source,
Expand Down
Loading

0 comments on commit ac4423f

Please sign in to comment.