Support auto dtype selection and cast from bf16 weights on cuda #33

Merged 6 commits on Jan 8, 2025
64 changes: 64 additions & 0 deletions .github/workflows/analysis.yaml
@@ -0,0 +1,64 @@
name: Analysis
on:
  pull_request_target

jobs:
  comment:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install Rust and Cargo
        run: |
          curl -sSf https://sh.rustup.rs | sh -s -- -y
          source $HOME/.cargo/env

      - name: Install Tokei
        run: cargo install tokei

      - name: Run Tokei and get the lines of code
        run: tokei . > tokei_output.txt

      - name: Comment or Update PR
        uses: actions/github-script@v7
        with:
          script: |
            const fs = require('fs');
            const tokeiOutput = fs.readFileSync('tokei_output.txt', 'utf8');
            const uniqueIdentifier = 'Code Metrics Report';
            const codeReport = `
            <details>
            <summary>${uniqueIdentifier}</summary>
            <pre>
            ${tokeiOutput}
            </pre>
            </details>
            `;

            const issue_number = context.issue.number;
            const { owner, repo } = context.repo;

            const comments = await github.rest.issues.listComments({
              issue_number,
              owner,
              repo
            });

            const existingComment = comments.data.find(comment => comment.body.includes(uniqueIdentifier));

            if (existingComment) {
              await github.rest.issues.updateComment({
                owner,
                repo,
                comment_id: existingComment.id,
                body: codeReport
              });
            } else {
              await github.rest.issues.createComment({
                issue_number,
                owner,
                repo,
                body: codeReport
              });
            }
5 changes: 3 additions & 2 deletions README.md
@@ -8,7 +8,7 @@ Blazingly fast inference of diffusion models.
</h3>

<p align="center">
| <a href="https://ericlbuehler.github.io/diffusion-rs/diffusion_rs_core/"><b>Rust Documentation</b></a> | <a href="https://ericlbuehler.github.io/diffusion-rs/pyo3/diffusion_rs.html"><b>Python Documentation</b></a> | <a href="https://discord.gg/DRcvs6z5vu"><b>Discord</b></a> |
| <a href="https://ericlbuehler.github.io/diffusion-rs/diffusion_rs_core/"><b>Latest Rust Documentation</b></a> | <a href="https://ericlbuehler.github.io/diffusion-rs/pyo3/diffusion_rs.html"><b>Latest Python Documentation</b></a> | <a href="https://discord.gg/DRcvs6z5vu"><b>Discord</b></a> |
</p>


@@ -72,7 +72,7 @@ Examples with the Rust crate: [here](diffusion_rs_examples/examples).
```rust
use std::time::Instant;

use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource};
use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, ModelDType, Offloading, Pipeline, TokenSource};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::EnvFilter;

@@ -87,6 +87,7 @@ let pipeline = Pipeline::load(
    TokenSource::CacheToken,
    None,
    None,
    &ModelDType::Auto,
)?;

let start = Instant::now();
8 changes: 6 additions & 2 deletions diffusion_rs_cli/src/main.rs
@@ -3,7 +3,7 @@ use std::{path::PathBuf, time::Instant};

use clap::{Parser, Subcommand};
use diffusion_rs_core::{
    DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource,
    DiffusionGenerationParams, ModelDType, ModelSource, Offloading, Pipeline, TokenSource,
};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::EnvFilter;
@@ -48,6 +48,10 @@ struct Args {
    /// Offloading setting to use for this model
    #[arg(short, long)]
    offloading: Option<Offloading>,

    /// DType for the model. The default is to use an automatic strategy with a fallback pattern: BF16 -> F16 -> F32
    #[arg(short, long, default_value = "auto")]
    dtype: ModelDType,
}

fn main() -> anyhow::Result<()> {
@@ -67,7 +71,7 @@
        .map(TokenSource::Literal)
        .unwrap_or(TokenSource::CacheToken);

    let pipeline = Pipeline::load(source, false, token, None, args.offloading)?;
    let pipeline = Pipeline::load(source, false, token, None, args.offloading, &args.dtype)?;

    let height: usize = input("Height:")
        .default_input("720")
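The `--dtype` flag's automatic strategy picks the first dtype the device can handle, in the order BF16 -> F16 -> F32. A minimal sketch of how `ModelDType::Auto` could resolve to a concrete `DType` (the real logic lives in the new `diffusion_rs_core` `util` module behind the exported `TryIntoDType` trait and may differ; the `supports_*` probes here are hypothetical stand-ins):

```rust
use diffusion_rs_common::core::{DType, Device, Result};

/// Sketch of the user-facing dtype choice; `Auto` defers to the device.
#[derive(Clone, Copy, Debug)]
pub enum ModelDType {
    Auto,
    BF16,
    F16,
    F32,
}

// Hypothetical capability probes; on CUDA these would boil down to
// compute-capability checks (native BF16 needs Ampere, sm_80, or newer).
fn supports_bf16(device: &Device) -> bool {
    device.is_cuda()
}
fn supports_f16(device: &Device) -> bool {
    device.is_cuda()
}

impl ModelDType {
    /// Resolve to a concrete `DType`, falling back BF16 -> F16 -> F32.
    pub fn resolve(self, device: &Device) -> Result<DType> {
        Ok(match self {
            ModelDType::Auto => {
                if supports_bf16(device) {
                    DType::BF16
                } else if supports_f16(device) {
                    DType::F16
                } else {
                    DType::F32
                }
            }
            ModelDType::BF16 => DType::BF16,
            ModelDType::F16 => DType::F16,
            ModelDType::F32 => DType::F32,
        })
    }
}
```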
43 changes: 43 additions & 0 deletions diffusion_rs_common/src/cuda_kernels/cast.cu
@@ -1,5 +1,6 @@
#include "cuda_utils.cuh"
#include<stdint.h>
#include "dummy_bf16.cuh"

template <typename S, typename T>
__device__ void cast_(
@@ -96,6 +97,29 @@ __device__ void cast_through(
    }
}

template <typename T>
__device__ void cast_bf16_dummy(
    const size_t numel,
    const size_t num_dims,
    const size_t *info,
    const uint16_t *inp,
    T *out
) {
    const size_t *dims = info;
    const size_t *strides = info + num_dims;
    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
            out[i] = static_cast<T>(bf16_to_f32(inp[i]));
        }
    }
    else {
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
            unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
            out[i] = static_cast<T>(bf16_to_f32(inp[strided_i]));
        }
    }
}


#define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
@@ -143,6 +167,17 @@ extern "C" __global__ void FN_NAME( \
    cast_through<SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME>(numel, num_dims, info, inp, out); \
} \

#define CAST_BF16_FALLBACK_OP(DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *info, \
    const uint16_t *inp, \
    DST_TYPENAME *out \
) { \
    cast_bf16_dummy<DST_TYPENAME>(numel, num_dims, info, inp, out); \
} \

#if __CUDA_ARCH__ >= 800
#include "cuda_fp8.h"
#include "cuda_bf16.h"
@@ -195,6 +230,14 @@ CAST_OP(float, __half, cast_f32_f16)
CAST_OP(double, __half, cast_f64_f16)
CAST_OP(int32_t, __half, cast_i32_f16 )
CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32)

#if __CUDA_ARCH__ < 800
CAST_BF16_FALLBACK_OP(uint32_t, cast_bf16_u32)
CAST_BF16_FALLBACK_OP(float, cast_bf16_f32)
CAST_BF16_FALLBACK_OP(double, cast_bf16_f64)
CAST_BF16_FALLBACK_OP(__half, cast_bf16_f16)
CAST_BF16_FALLBACK_OP(int32_t, cast_bf16_i32)
#endif
#endif

CAST_OP(uint32_t, uint32_t, cast_u32_u32)
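Both loops in `cast_bf16_dummy` treat the bf16 input as raw `uint16_t` and convert through `bf16_to_f32`; the non-contiguous path additionally maps each flat index through `get_strided_index` from `cuda_utils.cuh`, which is not part of this diff. For reference, a Rust sketch of that index mapping under the usual dims/strides layout convention:

```rust
/// Map a flat element index to an offset in a strided buffer (a sketch of
/// the `get_strided_index` device helper that `cast_bf16_dummy` relies on).
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided = 0;
    // Peel coordinates off the flat index from the innermost dimension out.
    for (&dim, &stride) in dims.iter().zip(strides.iter()).rev() {
        strided += (idx % dim) * stride;
        idx /= dim;
    }
    strided
}

fn main() {
    // A 2x3 view over column-major storage (strides [1, 2]): flat index 1,
    // i.e. row 0 / col 1, lands at offset 2 in the underlying buffer.
    assert_eq!(get_strided_index(1, &[2, 3], &[1, 2]), 2);
}
```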
56 changes: 56 additions & 0 deletions diffusion_rs_common/src/cuda_kernels/dummy_bf16.cuh
@@ -0,0 +1,56 @@
#include<stdint.h>

__device__ __forceinline__ float bf16_to_f32(const uint16_t i)
{
    // If NaN, keep current mantissa but also set most significant mantissa bit
    if ((i & 0x7FFFu) > 0x7F80u) {
        // NaN path
        uint32_t tmp = ((static_cast<uint32_t>(i) | 0x0040u) << 16);
        union {
            uint32_t as_int;
            float as_float;
        } u;
        u.as_int = tmp;
        return u.as_float;
        // Alternatively:
        // return __int_as_float(((static_cast<uint32_t>(i) | 0x0040u) << 16));
    } else {
        // Normal path
        uint32_t tmp = (static_cast<uint32_t>(i) << 16);
        union {
            uint32_t as_int;
            float as_float;
        } u;
        u.as_int = tmp;
        return u.as_float;
        // Alternatively:
        // return __int_as_float(static_cast<uint32_t>(i) << 16);
    }
}

// Convert FP32 (float) to BF16 (unsigned short)
__device__ __forceinline__ uint16_t f32_to_bf16(const float value)
{
    // Reinterpret float bits as uint32_t
    union {
        float as_float;
        uint32_t as_int;
    } u;
    u.as_float = value;
    uint32_t x = u.as_int;

    // Check for NaN
    if ((x & 0x7FFF'FFFFu) > 0x7F80'0000u) {
        // Keep high part of current mantissa but also set most significant mantissa bit
        return static_cast<uint16_t>((x >> 16) | 0x0040u);
    }

    // Round and shift
    constexpr uint32_t round_bit = 0x0000'8000u; // bit 15
    if (((x & round_bit) != 0) && ((x & (3 * round_bit - 1)) != 0)) {
        // Round half to even: round up when the round bit is set and either a
        // sticky bit below it or the result's least significant bit is also set
        return static_cast<uint16_t>((x >> 16) + 1);
    } else {
        return static_cast<uint16_t>(x >> 16);
    }
}
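Since bf16 is simply the upper 16 bits of an IEEE-754 f32, both helpers are plain bit manipulation: widen-and-shift on the way up, round-half-to-even on the way down, with NaNs forced to keep a set mantissa bit so they stay NaN. A host-side Rust mirror of the same logic, handy for sanity-checking the kernels (a sketch, not part of this PR):

```rust
fn bf16_to_f32(i: u16) -> f32 {
    if (i & 0x7FFF) > 0x7F80 {
        // NaN: keep the mantissa but force the most significant mantissa bit.
        f32::from_bits(((i as u32) | 0x0040) << 16)
    } else {
        f32::from_bits((i as u32) << 16)
    }
}

fn f32_to_bf16(value: f32) -> u16 {
    let x = value.to_bits();
    if (x & 0x7FFF_FFFF) > 0x7F80_0000 {
        // NaN: keep the high mantissa bits, force the most significant one.
        return ((x >> 16) | 0x0040) as u16;
    }
    const ROUND_BIT: u32 = 0x0000_8000; // bit 15, the first bit dropped
    // Round half to even: the mask 3 * ROUND_BIT - 1 = 0x17FFF covers the
    // sticky bits below the round bit plus the result's least significant bit.
    if (x & ROUND_BIT) != 0 && (x & (3 * ROUND_BIT - 1)) != 0 {
        ((x >> 16) + 1) as u16
    } else {
        (x >> 16) as u16
    }
}

fn main() {
    // Exactly representable values round-trip.
    assert_eq!(bf16_to_f32(f32_to_bf16(1.0)), 1.0);
    // NaN survives the round trip.
    assert!(bf16_to_f32(f32_to_bf16(f32::NAN)).is_nan());
    // A halfway case (1.0 + 2^-8) ties to the even result 0x3F80.
    assert_eq!(f32_to_bf16(f32::from_bits(0x3F80_8000)), 0x3F80);
}
```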
18 changes: 4 additions & 14 deletions diffusion_rs_common/src/varbuilder_loading.rs
@@ -33,13 +33,8 @@ impl TensorLoaderBackend for SafetensorBackend {
            .map(|(name, _)| name)
            .collect::<Vec<_>>()
    }
    fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
        let t = self.0.load(name, device)?;
        if let Some(dtype) = dtype {
            t.to_dtype(dtype)
        } else {
            Ok(t)
        }
    fn load_name(&self, name: &str, device: &Device, _dtype: Option<DType>) -> Result<Tensor> {
        self.0.load(name, device)
    }
}

@@ -53,13 +48,8 @@ impl TensorLoaderBackend for BytesSafetensorBackend<'_> {
            .map(|(name, _)| name)
            .collect::<Vec<_>>()
    }
    fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
        let t = self.0.load(name, device)?;
        if let Some(dtype) = dtype {
            t.to_dtype(dtype)
        } else {
            Ok(t)
        }
    fn load_name(&self, name: &str, device: &Device, _dtype: Option<DType>) -> Result<Tensor> {
        self.0.load(name, device)
    }
}

5 changes: 4 additions & 1 deletion diffusion_rs_core/src/lib.rs
@@ -5,14 +5,15 @@
//! ```rust,no_run
//! use std::time::Instant;
//!
//! use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource};
//! use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, ModelDType, Offloading, Pipeline, TokenSource};
//!
//! let pipeline = Pipeline::load(
//!     ModelSource::dduf("FLUX.1-dev-Q4-bnb.dduf")?,
//!     true,
//!     TokenSource::CacheToken,
//!     None,
//!     None,
//!     &ModelDType::Auto,
//! )?;
//!
//! let start = Instant::now();
@@ -37,6 +38,8 @@

mod models;
mod pipelines;
mod util;

pub use diffusion_rs_common::{ModelSource, TokenSource};
pub use pipelines::{DiffusionGenerationParams, Offloading, Pipeline};
pub use util::{ModelDType, TryIntoDType};
11 changes: 9 additions & 2 deletions diffusion_rs_core/src/models/vaes/mod.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;

use autoencoder_kl::{AutencoderKlConfig, AutoEncoderKl};
use diffusion_rs_common::{
    core::{Device, Result, Tensor},
    core::{DType, Device, Result, Tensor},
    ModelSource,
};
use serde::Deserialize;
@@ -46,10 +46,17 @@ pub(crate) fn dispatch_load_vae_model(
    cfg_json: &FileData,
    safetensor_files: Vec<FileData>,
    device: &Device,
    dtype: DType,
    silent: bool,
    source: Arc<ModelSource>,
) -> anyhow::Result<Arc<dyn VAEModel>> {
    let vb = from_mmaped_safetensors(safetensor_files, None, device, silent, source.clone())?;
    let vb = from_mmaped_safetensors(
        safetensor_files,
        Some(dtype),
        device,
        silent,
        source.clone(),
    )?;

    let VaeConfigShim { name } = serde_json::from_str(&cfg_json.read_to_string(&source)?)?;
    match name.as_str() {
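With the new `dtype` parameter, the caller resolves the user-facing `ModelDType` once against the target device and threads the concrete `DType` into every component loader, so the VAE weights are cast as they are memory-mapped. A hypothetical call-site sketch (`try_into_dtype` is an assumed name for the exported `TryIntoDType` trait's method):

```rust
use diffusion_rs_common::core::{DType, Device};
use diffusion_rs_core::{ModelDType, TryIntoDType};

fn main() -> anyhow::Result<()> {
    // Resolve Auto -> BF16/F16/F32 against the device a single time.
    let device = Device::cuda_if_available(0)?;
    let dtype: DType = ModelDType::Auto.try_into_dtype(&device)?; // assumed signature
    // The concrete dtype is then passed through, e.g.
    // dispatch_load_vae_model(&cfg_json, safetensor_files, &device, dtype, silent, source)?;
    println!("VAE weights will be loaded as {dtype:?}");
    Ok(())
}
```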
8 changes: 5 additions & 3 deletions diffusion_rs_core/src/pipelines/flux/mod.rs
@@ -46,6 +46,7 @@ impl Loader for FluxLoader {
        &self,
        mut components: HashMap<ComponentName, ComponentElem>,
        device: &Device,
        dtype: DType,
        silent: bool,
        offloading_type: Option<Offloading>,
        source: Arc<ModelSource>,
@@ -96,7 +97,7 @@

        let vb = from_mmaped_safetensors(
            safetensors.into_values().collect(),
            None,
            Some(dtype),
            device,
            silent,
            source.clone(),
@@ -116,7 +117,7 @@
        let cfg: T5Config = serde_json::from_str(&config.read_to_string(&source)?)?;
        let vb = from_mmaped_safetensors(
            safetensors.into_values().collect(),
            None,
            Some(dtype),
            &t5_flux_device,
            silent,
            source.clone(),
@@ -137,6 +138,7 @@
            &config,
            safetensors.into_values().collect(),
            device,
            dtype,
            silent,
            source.clone(),
        )?
@@ -154,7 +156,7 @@
        let cfg: FluxConfig = serde_json::from_str(&config.read_to_string(&source)?)?;
        let vb = from_mmaped_safetensors(
            safetensors.into_values().collect(),
            None,
            Some(dtype),
            &t5_flux_device,
            silent,
            source,