diff --git a/.github/workflows/analysis.yaml b/.github/workflows/analysis.yaml
new file mode 100644
index 0000000..52a5fe0
--- /dev/null
+++ b/.github/workflows/analysis.yaml
@@ -0,0 +1,64 @@
+name: Analysis
+on:
+ pull_request_target
+
+jobs:
+ comment:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+
+      - name: Install Rust and Cargo  # rustup installs into $HOME/.cargo/bin
+        run: |
+          curl -sSf https://sh.rustup.rs | sh -s -- -y
+          echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"  # `source` would only affect this step's shell; GITHUB_PATH carries over to later steps
+
+ - name: Install Tokei
+ run: cargo install tokei
+
+ - name: Run Tokei and get the lines of code
+ run: tokei . > tokei_output.txt
+
+ - name: Comment or Update PR
+ uses: actions/github-script@v7
+ with:
+ script: |
+ const fs = require('fs');
+ const tokeiOutput = fs.readFileSync('tokei_output.txt', 'utf8');
+ const uniqueIdentifier = 'Code Metrics Report';
+ const codeReport = `
+
+ ${uniqueIdentifier}
+
+ ${tokeiOutput}
+
+
+ `;
+
+ const issue_number = context.issue.number;
+ const { owner, repo } = context.repo;
+
+ const comments = await github.rest.issues.listComments({
+ issue_number,
+ owner,
+ repo
+ });
+
+ const existingComment = comments.data.find(comment => comment.body.includes(uniqueIdentifier));
+
+ if (existingComment) {
+ await github.rest.issues.updateComment({
+ owner,
+ repo,
+ comment_id: existingComment.id,
+ body: codeReport
+ });
+ } else {
+ await github.rest.issues.createComment({
+ issue_number,
+ owner,
+ repo,
+ body: codeReport
+ });
+ }
diff --git a/README.md b/README.md
index 835a259..7fd0217 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ Blazingly fast inference of diffusion models.
-| Rust Documentation | Python Documentation | Discord |
+| Latest Rust Documentation | Latest Python Documentation | Discord |
@@ -72,7 +72,7 @@ Examples with the Rust crate: [here](diffusion_rs_examples/examples).
```rust
use std::time::Instant;
-use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource};
+use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, ModelDType, Offloading, Pipeline, TokenSource};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::EnvFilter;
@@ -87,6 +87,7 @@ let pipeline = Pipeline::load(
TokenSource::CacheToken,
None,
None,
+ &ModelDType::Auto,
)?;
let start = Instant::now();
diff --git a/diffusion_rs_cli/src/main.rs b/diffusion_rs_cli/src/main.rs
index 9b0cb7d..ffe9516 100644
--- a/diffusion_rs_cli/src/main.rs
+++ b/diffusion_rs_cli/src/main.rs
@@ -3,7 +3,7 @@ use std::{path::PathBuf, time::Instant};
use clap::{Parser, Subcommand};
use diffusion_rs_core::{
- DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource,
+ DiffusionGenerationParams, ModelDType, ModelSource, Offloading, Pipeline, TokenSource,
};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::EnvFilter;
@@ -48,6 +48,10 @@ struct Args {
/// Offloading setting to use for this model
#[arg(short, long)]
offloading: Option,
+
+ /// DType for the model. The default is to use an automatic strategy with a fallback pattern: BF16 -> F16 -> F32
+ #[arg(short, long, default_value = "auto")]
+ dtype: ModelDType,
}
fn main() -> anyhow::Result<()> {
@@ -67,7 +71,7 @@ fn main() -> anyhow::Result<()> {
.map(TokenSource::Literal)
.unwrap_or(TokenSource::CacheToken);
- let pipeline = Pipeline::load(source, false, token, None, args.offloading)?;
+ let pipeline = Pipeline::load(source, false, token, None, args.offloading, &args.dtype)?;
let height: usize = input("Height:")
.default_input("720")
diff --git a/diffusion_rs_common/src/cuda_kernels/cast.cu b/diffusion_rs_common/src/cuda_kernels/cast.cu
index 07b079a..7262525 100644
--- a/diffusion_rs_common/src/cuda_kernels/cast.cu
+++ b/diffusion_rs_common/src/cuda_kernels/cast.cu
@@ -1,5 +1,6 @@
#include "cuda_utils.cuh"
#include
+#include "dummy_bf16.cuh"
template
__device__ void cast_(
@@ -96,6 +97,29 @@ __device__ void cast_through(
}
}
+// Software BF16 -> T cast: `inp` holds raw 16-bit BF16 patterns, widened through float.
+template <typename T>
+__device__ void cast_bf16_dummy(
+    const size_t numel,
+    const size_t num_dims,
+    const size_t *info,  // `num_dims` dims followed by `num_dims` strides; may be null
+    const uint16_t *inp,
+    T *out
+) {
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            out[i] = static_cast<T>(bf16_to_f32(inp[i]));
+        }
+    } else {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            // Only the read goes through the strides; the output index stays `i`.
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
+            out[i] = static_cast<T>(bf16_to_f32(inp[strided_i]));
+        }
+    }
+}
+
#define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
@@ -143,6 +167,17 @@ extern "C" __global__ void FN_NAME( \
cast_through(numel, num_dims, info, inp, out); \
} \
+// Declares kernel FN_NAME, which casts raw BF16 bit patterns (uint16_t) to DST_TYPENAME.
+#define CAST_BF16_FALLBACK_OP(DST_TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, const size_t num_dims, \
+    const size_t *info, \
+    const uint16_t *inp, \
+    DST_TYPENAME *out \
+) { \
+    cast_bf16_dummy<DST_TYPENAME>(numel, num_dims, info, inp, out); \
+} \
+
#if __CUDA_ARCH__ >= 800
#include "cuda_fp8.h"
#include "cuda_bf16.h"
@@ -195,6 +230,14 @@ CAST_OP(float, __half, cast_f32_f16)
CAST_OP(double, __half, cast_f64_f16)
CAST_OP(int32_t, __half, cast_i32_f16 )
CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32)
+
+#if __CUDA_ARCH__ < 800
+CAST_BF16_FALLBACK_OP(uint32_t, cast_bf16_u32)
+CAST_BF16_FALLBACK_OP(float, cast_bf16_f32)
+CAST_BF16_FALLBACK_OP(double, cast_bf16_f64)
+CAST_BF16_FALLBACK_OP(__half, cast_bf16_f16)
+CAST_BF16_FALLBACK_OP(int32_t, cast_bf16_i32)
+#endif
#endif
CAST_OP(uint32_t, uint32_t, cast_u32_u32)
diff --git a/diffusion_rs_common/src/cuda_kernels/dummy_bf16.cuh b/diffusion_rs_common/src/cuda_kernels/dummy_bf16.cuh
new file mode 100644
index 0000000..ea0f8b3
--- /dev/null
+++ b/diffusion_rs_common/src/cuda_kernels/dummy_bf16.cuh
@@ -0,0 +1,56 @@
+#include
+
+// Convert BF16 (raw uint16_t bits) to FP32 (float).
+//
+// The 16 input bits become the upper 16 bits of the 32-bit float pattern and the
+// lower 16 bits are zero, so no rounding is involved.
+__device__ __forceinline__ float bf16_to_f32(const uint16_t i)
+{
+    uint32_t bits = static_cast<uint32_t>(i);
+
+    // If NaN, keep current mantissa but also set most significant mantissa bit
+    if ((i & 0x7FFFu) > 0x7F80u) {
+        bits |= 0x0040u;
+    }
+
+    // Reinterpret the widened bit pattern as a float.
+    union {
+        uint32_t as_int;
+        float as_float;
+    } u;
+    u.as_int = bits << 16;
+    return u.as_float;
+}
+
+// Convert FP32 (float) to BF16 (raw uint16_t bits), rounding to nearest with ties to even.
+__device__ __forceinline__ uint16_t f32_to_bf16(const float value)
+{
+    // Reinterpret float bits as uint32_t
+    union {
+        float as_float;
+        uint32_t as_int;
+    } u;
+    u.as_float = value;
+    const uint32_t x = u.as_int;
+
+    // Check for NaN
+    if ((x & 0x7FFF'FFFFu) > 0x7F80'0000u) {
+        // Keep high part of current mantissa but also set most significant mantissa bit
+        return static_cast<uint16_t>((x >> 16) | 0x0040u);
+    }
+
+    // Bit 15 is the first discarded bit. Round up when it is set and either a lower
+    // discarded bit (bits 0-14) or the lowest kept bit (bit 16) is set as well:
+    // the mask 3 * round_bit - 1 == 0x1'7FFF selects exactly those bits, which
+    // sends exact ties to the even result.
+    constexpr uint32_t round_bit = 0x0000'8000u; // bit 15
+    const bool round_up = ((x & round_bit) != 0) && ((x & (3 * round_bit - 1)) != 0);
+    return static_cast<uint16_t>((x >> 16) + (round_up ? 1u : 0u));
+}
diff --git a/diffusion_rs_common/src/varbuilder_loading.rs b/diffusion_rs_common/src/varbuilder_loading.rs
index 8ee7fa8..8f0bf1a 100644
--- a/diffusion_rs_common/src/varbuilder_loading.rs
+++ b/diffusion_rs_common/src/varbuilder_loading.rs
@@ -33,13 +33,8 @@ impl TensorLoaderBackend for SafetensorBackend {
.map(|(name, _)| name)
.collect::>()
}
- fn load_name(&self, name: &str, device: &Device, dtype: Option) -> Result {
- let t = self.0.load(name, device)?;
- if let Some(dtype) = dtype {
- t.to_dtype(dtype)
- } else {
- Ok(t)
- }
+ fn load_name(&self, name: &str, device: &Device, _dtype: Option) -> Result {
+ self.0.load(name, device)
}
}
@@ -53,13 +48,8 @@ impl TensorLoaderBackend for BytesSafetensorBackend<'_> {
.map(|(name, _)| name)
.collect::>()
}
- fn load_name(&self, name: &str, device: &Device, dtype: Option) -> Result {
- let t = self.0.load(name, device)?;
- if let Some(dtype) = dtype {
- t.to_dtype(dtype)
- } else {
- Ok(t)
- }
+ fn load_name(&self, name: &str, device: &Device, _dtype: Option) -> Result {
+ self.0.load(name, device)
}
}
diff --git a/diffusion_rs_core/src/lib.rs b/diffusion_rs_core/src/lib.rs
index 0575831..a3968bb 100644
--- a/diffusion_rs_core/src/lib.rs
+++ b/diffusion_rs_core/src/lib.rs
@@ -5,7 +5,7 @@
//! ```rust,no_run
//! use std::time::Instant;
//!
-//! use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource};
+//! use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, ModelDType, Offloading, Pipeline, TokenSource};
//!
//! let pipeline = Pipeline::load(
//! ModelSource::dduf("FLUX.1-dev-Q4-bnb.dduf")?,
@@ -13,6 +13,7 @@
//! TokenSource::CacheToken,
//! None,
//! None,
+//! &ModelDType::Auto,
//! )?;
//!
//! let start = Instant::now();
@@ -37,6 +38,8 @@
mod models;
mod pipelines;
+mod util;
pub use diffusion_rs_common::{ModelSource, TokenSource};
pub use pipelines::{DiffusionGenerationParams, Offloading, Pipeline};
+pub use util::{ModelDType, TryIntoDType};
diff --git a/diffusion_rs_core/src/models/vaes/mod.rs b/diffusion_rs_core/src/models/vaes/mod.rs
index 1a8c3d0..823b4a2 100644
--- a/diffusion_rs_core/src/models/vaes/mod.rs
+++ b/diffusion_rs_core/src/models/vaes/mod.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;
use autoencoder_kl::{AutencoderKlConfig, AutoEncoderKl};
use diffusion_rs_common::{
- core::{Device, Result, Tensor},
+ core::{DType, Device, Result, Tensor},
ModelSource,
};
use serde::Deserialize;
@@ -46,10 +46,17 @@ pub(crate) fn dispatch_load_vae_model(
cfg_json: &FileData,
safetensor_files: Vec,
device: &Device,
+ dtype: DType,
silent: bool,
source: Arc,
) -> anyhow::Result> {
- let vb = from_mmaped_safetensors(safetensor_files, None, device, silent, source.clone())?;
+ let vb = from_mmaped_safetensors(
+ safetensor_files,
+ Some(dtype),
+ device,
+ silent,
+ source.clone(),
+ )?;
let VaeConfigShim { name } = serde_json::from_str(&cfg_json.read_to_string(&source)?)?;
match name.as_str() {
diff --git a/diffusion_rs_core/src/pipelines/flux/mod.rs b/diffusion_rs_core/src/pipelines/flux/mod.rs
index 1fecdaf..8db8de2 100644
--- a/diffusion_rs_core/src/pipelines/flux/mod.rs
+++ b/diffusion_rs_core/src/pipelines/flux/mod.rs
@@ -46,6 +46,7 @@ impl Loader for FluxLoader {
&self,
mut components: HashMap,
device: &Device,
+ dtype: DType,
silent: bool,
offloading_type: Option,
source: Arc,
@@ -96,7 +97,7 @@ impl Loader for FluxLoader {
let vb = from_mmaped_safetensors(
safetensors.into_values().collect(),
- None,
+ Some(dtype),
device,
silent,
source.clone(),
@@ -116,7 +117,7 @@ impl Loader for FluxLoader {
let cfg: T5Config = serde_json::from_str(&config.read_to_string(&source)?)?;
let vb = from_mmaped_safetensors(
safetensors.into_values().collect(),
- None,
+ Some(dtype),
&t5_flux_device,
silent,
source.clone(),
@@ -137,6 +138,7 @@ impl Loader for FluxLoader {
&config,
safetensors.into_values().collect(),
device,
+ dtype,
silent,
source.clone(),
)?
@@ -154,7 +156,7 @@ impl Loader for FluxLoader {
let cfg: FluxConfig = serde_json::from_str(&config.read_to_string(&source)?)?;
let vb = from_mmaped_safetensors(
safetensors.into_values().collect(),
- None,
+ Some(dtype),
&t5_flux_device,
silent,
source,
diff --git a/diffusion_rs_core/src/pipelines/mod.rs b/diffusion_rs_core/src/pipelines/mod.rs
index ed5073b..9dcec12 100644
--- a/diffusion_rs_core/src/pipelines/mod.rs
+++ b/diffusion_rs_core/src/pipelines/mod.rs
@@ -9,7 +9,7 @@ use std::{
};
use anyhow::Result;
-use diffusion_rs_common::core::{Device, Tensor};
+use diffusion_rs_common::core::{DType, Device, Tensor};
use flux::FluxLoader;
use image::{DynamicImage, RgbImage};
use serde::Deserialize;
@@ -17,6 +17,8 @@ use serde::Deserialize;
use diffusion_rs_common::{FileData, FileLoader, ModelSource, NiceProgressBar, TokenSource};
use tracing::info;
+use crate::TryIntoDType;
+
/// Generation parameters.
#[derive(Debug, Clone)]
pub struct DiffusionGenerationParams {
@@ -82,6 +84,7 @@ pub(crate) trait Loader {
&self,
components: HashMap,
device: &Device,
+ dtype: DType,
silent: bool,
offloading_type: Option,
source: Arc,
@@ -120,6 +123,7 @@ impl Pipeline {
token: TokenSource,
revision: Option,
offloading_type: Option,
+ dtype: &dyn TryIntoDType,
) -> Result {
info!("loading from source: {source}.");
@@ -212,9 +216,14 @@ impl Pipeline {
#[cfg(feature = "metal")]
let device = Device::new_metal(0)?;
+ // NOTE: we can set the device to be just the primary even in the offloading case.
+ // This will need to be updated!
+ let dtype = dtype.try_into_dtype(&[&device], silent)?;
+
let model = model_loader.load_from_components(
components,
&device,
+ dtype,
silent,
offloading_type,
Arc::new(source),
diff --git a/diffusion_rs_core/src/util/auto_dtype.rs b/diffusion_rs_core/src/util/auto_dtype.rs
new file mode 100644
index 0000000..1ac933e
--- /dev/null
+++ b/diffusion_rs_core/src/util/auto_dtype.rs
@@ -0,0 +1,161 @@
+use std::fmt::Display;
+
+use anyhow::Result;
+use diffusion_rs_common::core::{DType, Device, Tensor};
+use serde::Deserialize;
+use tracing::info;
+
+/// DType for the model.
+///
+/// Note: When using `Auto`, fallback pattern is: BF16 -> F16 -> F32
+#[derive(Clone, Copy, Default, Debug, Deserialize, PartialEq, clap::ValueEnum)]
+#[serde(rename_all = "lowercase")]
+pub enum ModelDType {
+    // Variant notes are plain `//` comments on purpose: `///` on a `ValueEnum`
+    // variant would be picked up as CLI help text.
+    // `Auto` is resolved to a concrete dtype when the model is loaded.
+    #[default]
+    Auto,
+    BF16,
+    F16,
+    F32,
+}
+
+impl Display for ModelDType {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Each variant is written as its lowercase name.
+        f.write_str(match self {
+            Self::Auto => "auto",
+            Self::BF16 => "bf16",
+            Self::F16 => "f16",
+            Self::F32 => "f32",
+        })
+    }
+}
+
+/// Type which can be converted to a DType
+pub trait TryIntoDType {
+    /// Resolves `self` to a concrete [`DType`].
+    ///
+    /// `devices` are the devices the model will run on (implementations may probe them);
+    /// `silent` suppresses the informational log of the selected dtype.
+    fn try_into_dtype(&self, devices: &[&Device], silent: bool) -> Result<DType>;
+}
+
+impl TryIntoDType for DType {
+    /// Passes an explicit dtype through unchanged; the device list is not consulted.
+    ///
+    /// # Errors
+    /// Fails unless the dtype is one of BF16, F16, F32 or F64.
+    fn try_into_dtype(&self, _: &[&Device], silent: bool) -> Result<DType> {
+        // Validate first so a rejected dtype is never logged as "selected".
+        if !matches!(self, DType::BF16 | DType::F32 | DType::F64 | DType::F16) {
+            anyhow::bail!("DType must be one of BF16, F16, F32, F64");
+        }
+        if !silent {
+            info!("dtype selected is {self:?}.");
+        }
+        Ok(*self)
+    }
+}
+
+/// Reduced-precision dtypes to consider on the visible CUDA GPUs.
+///
+/// Runs `nvidia-smi --query-gpu=compute_cap --format=csv`, takes the lowest reported
+/// compute capability and keeps BF16 when it is >= 8.0 and F16 when it is >= 5.3.
+/// If the query cannot be run or its output cannot be parsed, an empty list is
+/// returned instead of panicking, which makes automatic selection end at F32.
+#[cfg(feature = "cuda")]
+fn get_dtypes(silent: bool) -> Vec<DType> {
+    use std::process::Command;
+
+    // >= is supported
+    const MIN_BF16_CC: usize = 800;
+    // >= is supported
+    const MIN_F16_CC: usize = 530;
+
+    // The first output line is skipped as the CSV header; every other non-empty line
+    // is parsed as one compute capability (e.g. `7.5`).
+    let min_cc = Command::new("nvidia-smi")
+        .arg("--query-gpu=compute_cap")
+        .arg("--format=csv")
+        .output()
+        .ok()
+        .and_then(|out| String::from_utf8(out.stdout).ok())
+        .and_then(|out| {
+            out.lines()
+                .skip(1)
+                .map(str::trim)
+                .filter(|cc| !cc.is_empty())
+                .map(|cc| cc.parse::<f32>().ok())
+                .collect::<Option<Vec<_>>>()?
+                .into_iter()
+                .reduce(f32::min)
+        });
+    let min_cc = match min_cc {
+        Some(min_cc) => min_cc,
+        None => {
+            if !silent {
+                tracing::warn!(
+                    "could not determine the CUDA compute capability via `nvidia-smi`; skipping BF16 and F16"
+                );
+            }
+            return Vec::new();
+        }
+    };
+    if !silent {
+        info!("detected minimum CUDA compute capability {min_cc}");
+    }
+    // 7.5 -> 750. Round before the cast so floating-point error in the product
+    // cannot truncate to the integer below.
+    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
+    let min_cc = (min_cc * 100.).round() as usize;
+
+    let mut dtypes = Vec::new();
+    if min_cc >= MIN_BF16_CC {
+        dtypes.push(DType::BF16);
+    } else if !silent {
+        info!("skipping BF16 because CC < 8.0");
+    }
+    if min_cc >= MIN_F16_CC {
+        dtypes.push(DType::F16);
+    } else if !silent {
+        info!("skipping F16 because CC < 5.3");
+    }
+    dtypes
+}
+
+/// Reduced-precision candidates, in the order automatic selection tries them.
+fn get_dtypes_non_cuda() -> Vec<DType> {
+    vec![DType::BF16, DType::F16]
+}
+
+/// Without the `cuda` feature no hardware query is made: every candidate is returned.
+#[cfg(not(feature = "cuda"))]
+fn get_dtypes(_silent: bool) -> Vec<DType> {
+    get_dtypes_non_cuda()
+}
+
+/// Picks the first dtype, in BF16 -> F16 order, that every device in `devices` can run.
+///
+/// Candidates are limited to those reported by `get_dtypes`, and each one is probed
+/// with a 2x2 matmul on every device. A probe error of one of the kinds matched below
+/// is treated as "unsupported here" and the next candidate is tried; any other error
+/// is returned to the caller. F32 is the result when no candidate passes.
+fn determine_auto_dtype_all(
+    devices: &[&Device],
+    silent: bool,
+) -> diffusion_rs_common::core::Result<DType> {
+    use diffusion_rs_common::core::Error;
+
+    let dev_dtypes = get_dtypes(silent);
+    for dtype in get_dtypes_non_cuda()
+        .into_iter()
+        .filter(|x| dev_dtypes.contains(x))
+    {
+        // Allocation is part of the probe: a backend that rejects the dtype while
+        // creating the tensor must not abort detection before the fallback runs.
+        let results: Vec<_> = devices
+            .iter()
+            .map(|device| Tensor::zeros((2, 2), dtype, device).and_then(|x| x.matmul(&x)))
+            .collect();
+        if results.iter().all(|x| x.is_ok()) {
+            return Ok(dtype);
+        }
+        for err in results.into_iter().filter_map(|x| x.err()) {
+            match err {
+                // For CUDA
+                Error::UnsupportedDTypeForOp(_, _) => (),
+                // Accelerate backend doesn't support f16/bf16
+                // Metal backend doesn't support f16
+                Error::Msg(_) => (),
+                // This is when the metal backend doesn't support bf16
+                Error::Metal(_) => (),
+                // If running with RUST_BACKTRACE=1
+                Error::WithBacktrace { .. } => (),
+                other => return Err(other),
+            }
+        }
+    }
+    Ok(DType::F32)
+}
+
+impl TryIntoDType for ModelDType {
+    /// Resolves the requested dtype: `Auto` probes `devices`, the explicit variants
+    /// map directly to their `DType`. The result is logged unless `silent` is set.
+    fn try_into_dtype(&self, devices: &[&Device], silent: bool) -> Result<DType> {
+        let dtype = match self {
+            Self::Auto => determine_auto_dtype_all(devices, silent).map_err(anyhow::Error::msg)?,
+            Self::BF16 => DType::BF16,
+            Self::F16 => DType::F16,
+            Self::F32 => DType::F32,
+        };
+        if !silent {
+            info!("dtype selected is {dtype:?}.");
+        }
+        Ok(dtype)
+    }
+}
diff --git a/diffusion_rs_core/src/util/mod.rs b/diffusion_rs_core/src/util/mod.rs
new file mode 100644
index 0000000..b794ef6
--- /dev/null
+++ b/diffusion_rs_core/src/util/mod.rs
@@ -0,0 +1,3 @@
+mod auto_dtype;
+
+pub use auto_dtype::{ModelDType, TryIntoDType};
diff --git a/diffusion_rs_examples/Cargo.toml b/diffusion_rs_examples/Cargo.toml
index 2455d60..5278e15 100644
--- a/diffusion_rs_examples/Cargo.toml
+++ b/diffusion_rs_examples/Cargo.toml
@@ -17,3 +17,10 @@ anyhow.workspace = true
clap.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
+
+[features]
+cuda = ["diffusion_rs_core/cuda"]
+cudnn = ["diffusion_rs_core/cudnn"]
+metal = ["diffusion_rs_core/metal"]
+accelerate = ["diffusion_rs_core/accelerate"]
+mkl = ["diffusion_rs_core/mkl"]
diff --git a/diffusion_rs_examples/examples/dduf/main.rs b/diffusion_rs_examples/examples/dduf/main.rs
index 431c753..dafd62a 100644
--- a/diffusion_rs_examples/examples/dduf/main.rs
+++ b/diffusion_rs_examples/examples/dduf/main.rs
@@ -2,7 +2,7 @@ use std::time::Instant;
use clap::Parser;
use diffusion_rs_core::{
- DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource,
+ DiffusionGenerationParams, ModelDType, ModelSource, Offloading, Pipeline, TokenSource,
};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::EnvFilter;
@@ -44,6 +44,7 @@ fn main() -> anyhow::Result<()> {
TokenSource::CacheToken,
None,
args.offloading,
+ &ModelDType::Auto,
)?;
let start = Instant::now();
diff --git a/diffusion_rs_examples/examples/flux/main.rs b/diffusion_rs_examples/examples/flux/main.rs
index 81f1665..e1850a5 100644
--- a/diffusion_rs_examples/examples/flux/main.rs
+++ b/diffusion_rs_examples/examples/flux/main.rs
@@ -1,7 +1,7 @@
use std::time::Instant;
use diffusion_rs_core::{
- DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource,
+ DiffusionGenerationParams, ModelDType, ModelSource, Offloading, Pipeline, TokenSource,
};
use clap::{Parser, ValueEnum};
@@ -50,6 +50,7 @@ fn main() -> anyhow::Result<()> {
TokenSource::CacheToken,
None,
args.offloading,
+ &ModelDType::Auto,
)?;
let num_steps = match args.which {
Which::Dev => 50,
diff --git a/diffusion_rs_py/diffuse_rs.pyi b/diffusion_rs_py/diffuse_rs.pyi
index ffeeb3c..2592e71 100644
--- a/diffusion_rs_py/diffuse_rs.pyi
+++ b/diffusion_rs_py/diffuse_rs.pyi
@@ -1,6 +1,18 @@
from dataclasses import dataclass
from enum import Enum
+@dataclass
+class ModelDType(Enum):
+ """
+ DType for the model.
+ Note: When using `Auto`, fallback pattern is: BF16 -> F16 -> F32
+ """
+
+ Auto = 0
+ BF16 = 1
+ F16 = 2
+ F32 = 2
+
@dataclass
class Offloading(Enum):
"""
@@ -40,6 +52,8 @@ class Pipeline:
silent: bool = False,
token: str | None = None,
revision: str | None = None,
+ offloading: Offloading | None = None,
+ ModelDType: ModelDType = ModelDType.Auto,
) -> None:
"""
Load a model.
@@ -47,7 +61,10 @@ class Pipeline:
- `source`: the source of the model
- `silent`: silent loading, defaults to `False`.
- `token`: specifies a literal Hugging Face token for accessing gated models.
- - `revision`: specifies a specific Hugging Face model revision, otherwise the default is used. - `token_source` specifies where to load the HF token from.
+ - `revision`: specifies a specific Hugging Face model revision, otherwise the default is used.
+ - `token_source` specifies where to load the HF token from.
+ - `offloading`: offloading setting for the model.
+ - `dtype`: dtype selection for the model. The default is to use an automatic strategy with a fallback pattern: BF16 -> F16 -> F32
"""
...
diff --git a/diffusion_rs_py/src/lib.rs b/diffusion_rs_py/src/lib.rs
index 2e9f039..f0c6a82 100644
--- a/diffusion_rs_py/src/lib.rs
+++ b/diffusion_rs_py/src/lib.rs
@@ -33,6 +33,16 @@ pub struct DiffusionGenerationParams {
pub guidance_scale: f64,
}
+#[pyclass(eq, eq_int)]
+#[pyo3(get_all)]
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum ModelDType { // Python-facing counterpart of `diffusion_rs_core::ModelDType`; converted variant by variant in `Pipeline::new`
+ Auto,
+ BF16,
+ F16,
+ F32,
+}
+
#[pymethods]
impl DiffusionGenerationParams {
#[new]
@@ -77,6 +87,7 @@ impl Pipeline {
token = None,
revision = None,
offloading = None,
+ dtype = ModelDType::Auto,
))]
pub fn new(
source: ModelSource,
@@ -84,6 +95,7 @@ impl Pipeline {
token: Option,
revision: Option,
offloading: Option,
+ dtype: ModelDType,
) -> PyResult {
let token = token
.map(diffusion_rs_core::TokenSource::Literal)
@@ -99,8 +111,14 @@ impl Pipeline {
let offloading = offloading.map(|offloading| match offloading {
Offloading::Full => diffusion_rs_core::Offloading::Full,
});
+ let dtype = match dtype {
+ ModelDType::Auto => diffusion_rs_core::ModelDType::Auto,
+ ModelDType::F16 => diffusion_rs_core::ModelDType::F16,
+ ModelDType::BF16 => diffusion_rs_core::ModelDType::BF16,
+ ModelDType::F32 => diffusion_rs_core::ModelDType::F32,
+ };
Ok(Self(
- diffusion_rs_core::Pipeline::load(source, silent, token, revision, offloading)
+ diffusion_rs_core::Pipeline::load(source, silent, token, revision, offloading, &dtype)
.map_err(wrap_anyhow_error)?,
))
}