diff --git a/.github/workflows/analysis.yaml b/.github/workflows/analysis.yaml new file mode 100644 index 0000000..52a5fe0 --- /dev/null +++ b/.github/workflows/analysis.yaml @@ -0,0 +1,64 @@ +name: Analysis +on: + pull_request_target + +jobs: + comment: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install Rust and Cargo + run: | + curl -sSf https://sh.rustup.rs | sh -s -- -y + source $HOME/.cargo/env + + - name: Install Tokei + run: cargo install tokei + + - name: Run Tokei and get the lines of code + run: tokei . > tokei_output.txt + + - name: Comment or Update PR + uses: actions/github-script@v7 + with: + script: | + const fs = require('fs'); + const tokeiOutput = fs.readFileSync('tokei_output.txt', 'utf8'); + const uniqueIdentifier = 'Code Metrics Report'; + const codeReport = ` +
+ ${uniqueIdentifier} +
+              ${tokeiOutput}
+              
+
+ `; + + const issue_number = context.issue.number; + const { owner, repo } = context.repo; + + const comments = await github.rest.issues.listComments({ + issue_number, + owner, + repo + }); + + const existingComment = comments.data.find(comment => comment.body.includes(uniqueIdentifier)); + + if (existingComment) { + await github.rest.issues.updateComment({ + owner, + repo, + comment_id: existingComment.id, + body: codeReport + }); + } else { + await github.rest.issues.createComment({ + issue_number, + owner, + repo, + body: codeReport + }); + } diff --git a/README.md b/README.md index 835a259..7fd0217 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Blazingly fast inference of diffusion models.

-| Rust Documentation | Python Documentation | Discord | +| Latest Rust Documentation | Latest Python Documentation | Discord |

@@ -72,7 +72,7 @@ Examples with the Rust crate: [here](diffusion_rs_examples/examples). ```rust use std::time::Instant; -use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource}; +use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, ModelDType, Offloading, Pipeline, TokenSource}; use tracing::level_filters::LevelFilter; use tracing_subscriber::EnvFilter; @@ -87,6 +87,7 @@ let pipeline = Pipeline::load( TokenSource::CacheToken, None, None, + &ModelDType::Auto, )?; let start = Instant::now(); diff --git a/diffusion_rs_cli/src/main.rs b/diffusion_rs_cli/src/main.rs index 9b0cb7d..ffe9516 100644 --- a/diffusion_rs_cli/src/main.rs +++ b/diffusion_rs_cli/src/main.rs @@ -3,7 +3,7 @@ use std::{path::PathBuf, time::Instant}; use clap::{Parser, Subcommand}; use diffusion_rs_core::{ - DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource, + DiffusionGenerationParams, ModelDType, ModelSource, Offloading, Pipeline, TokenSource, }; use tracing::level_filters::LevelFilter; use tracing_subscriber::EnvFilter; @@ -48,6 +48,10 @@ struct Args { /// Offloading setting to use for this model #[arg(short, long)] offloading: Option, + + /// DType for the model. The default is to use an automatic strategy with a fallback pattern: BF16 -> F16 -> F32 + #[arg(short, long, default_value = "auto")] + dtype: ModelDType, } fn main() -> anyhow::Result<()> { @@ -67,7 +71,7 @@ fn main() -> anyhow::Result<()> { .map(TokenSource::Literal) .unwrap_or(TokenSource::CacheToken); - let pipeline = Pipeline::load(source, false, token, None, args.offloading)?; + let pipeline = Pipeline::load(source, false, token, None, args.offloading, &args.dtype)?; let height: usize = input("Height:") .default_input("720") diff --git a/diffusion_rs_common/src/cuda_kernels/cast.cu b/diffusion_rs_common/src/cuda_kernels/cast.cu index 07b079a..7262525 100644 --- a/diffusion_rs_common/src/cuda_kernels/cast.cu +++ b/diffusion_rs_common/src/cuda_kernels/cast.cu @@ -1,5 +1,6 @@ #include "cuda_utils.cuh" #include +#include "dummy_bf16.cuh" template __device__ void cast_( @@ -96,6 +97,29 @@ __device__ void cast_through( } } +template +__device__ void cast_bf16_dummy( + const size_t numel, + const size_t num_dims, + const size_t *info, + const uint16_t *inp, + T *out +) { + const size_t *dims = info; + const size_t *strides = info + num_dims; + if (info == nullptr || is_contiguous(num_dims, dims, strides)) { + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + out[i] = static_cast(bf16_to_f32(inp[i])); + } + } + else { + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); + out[i] = static_cast(bf16_to_f32(inp[strided_i])); + } + } +} + #define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \ extern "C" __global__ void FN_NAME( \ @@ -143,6 +167,17 @@ extern "C" __global__ void FN_NAME( \ cast_through(numel, num_dims, info, inp, out); \ } \ +#define CAST_BF16_FALLBACK_OP(DST_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *info, \ + const uint16_t *inp, \ + DST_TYPENAME *out \ +) { \ + cast_bf16_dummy(numel, num_dims, info, inp, out); \ +} \ + #if __CUDA_ARCH__ >= 800 #include "cuda_fp8.h" #include "cuda_bf16.h" @@ -195,6 +230,14 @@ CAST_OP(float, __half, cast_f32_f16) CAST_OP(double, __half, cast_f64_f16) CAST_OP(int32_t, __half, cast_i32_f16 ) CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32) + +#if __CUDA_ARCH__ < 800 +CAST_BF16_FALLBACK_OP(uint32_t, cast_bf16_u32) +CAST_BF16_FALLBACK_OP(float, cast_bf16_f32) +CAST_BF16_FALLBACK_OP(double, cast_bf16_f64) +CAST_BF16_FALLBACK_OP(__half, cast_bf16_f16) +CAST_BF16_FALLBACK_OP(int32_t, cast_bf16_i32) +#endif #endif CAST_OP(uint32_t, uint32_t, cast_u32_u32) diff --git a/diffusion_rs_common/src/cuda_kernels/dummy_bf16.cuh b/diffusion_rs_common/src/cuda_kernels/dummy_bf16.cuh new file mode 100644 index 0000000..ea0f8b3 --- /dev/null +++ b/diffusion_rs_common/src/cuda_kernels/dummy_bf16.cuh @@ -0,0 +1,56 @@ +#include + +__device__ __forceinline__ float bf16_to_f32(const uint16_t i) +{ + // If NaN, keep current mantissa but also set most significant mantissa bit + if ((i & 0x7FFFu) > 0x7F80u) { + // NaN path + uint32_t tmp = ((static_cast(i) | 0x0040u) << 16); + union { + uint32_t as_int; + float as_float; + } u; + u.as_int = tmp; + return u.as_float; + // Alternatively: + // return __int_as_float(((static_cast(i) | 0x0040u) << 16)); + } else { + // Normal path + uint32_t tmp = (static_cast(i) << 16); + union { + uint32_t as_int; + float as_float; + } u; + u.as_int = tmp; + return u.as_float; + // Alternatively: + // return __int_as_float(static_cast(i) << 16); + } +} + +// Convert FP32 (float) to BF16 (unsigned short) +__device__ __forceinline__ uint16_t f32_to_bf16(const float value) +{ + // Reinterpret float bits as uint32_t + union { + float as_float; + uint32_t as_int; + } u; + u.as_float = value; + uint32_t x = u.as_int; + + // Check for NaN + if ((x & 0x7FFF'FFFFu) > 0x7F80'0000u) { + // Keep high part of current mantissa but also set most significant mantissa bit + return static_cast((x >> 16) | 0x0040u); + } + + // Round and shift + constexpr uint32_t round_bit = 0x0000'8000u; // bit 15 + if (((x & round_bit) != 0) && ((x & (3 * round_bit - 1)) != 0)) { + // Round half to even (or to odd) depends on your preference + return static_cast((x >> 16) + 1); + } else { + return static_cast(x >> 16); + } +} diff --git a/diffusion_rs_common/src/varbuilder_loading.rs b/diffusion_rs_common/src/varbuilder_loading.rs index 8ee7fa8..8f0bf1a 100644 --- a/diffusion_rs_common/src/varbuilder_loading.rs +++ b/diffusion_rs_common/src/varbuilder_loading.rs @@ -33,13 +33,8 @@ impl TensorLoaderBackend for SafetensorBackend { .map(|(name, _)| name) .collect::>() } - fn load_name(&self, name: &str, device: &Device, dtype: Option) -> Result { - let t = self.0.load(name, device)?; - if let Some(dtype) = dtype { - t.to_dtype(dtype) - } else { - Ok(t) - } + fn load_name(&self, name: &str, device: &Device, _dtype: Option) -> Result { + self.0.load(name, device) } } @@ -53,13 +48,8 @@ impl TensorLoaderBackend for BytesSafetensorBackend<'_> { .map(|(name, _)| name) .collect::>() } - fn load_name(&self, name: &str, device: &Device, dtype: Option) -> Result { - let t = self.0.load(name, device)?; - if let Some(dtype) = dtype { - t.to_dtype(dtype) - } else { - Ok(t) - } + fn load_name(&self, name: &str, device: &Device, _dtype: Option) -> Result { + self.0.load(name, device) } } diff --git a/diffusion_rs_core/src/lib.rs b/diffusion_rs_core/src/lib.rs index 0575831..a3968bb 100644 --- a/diffusion_rs_core/src/lib.rs +++ b/diffusion_rs_core/src/lib.rs @@ -5,7 +5,7 @@ //! ```rust,no_run //! use std::time::Instant; //! -//! use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource}; +//! use diffusion_rs_core::{DiffusionGenerationParams, ModelSource, ModelDType, Offloading, Pipeline, TokenSource}; //! //! let pipeline = Pipeline::load( //! ModelSource::dduf("FLUX.1-dev-Q4-bnb.dduf")?, @@ -13,6 +13,7 @@ //! TokenSource::CacheToken, //! None, //! None, +//! &ModelDType::Auto, //! )?; //! //! let start = Instant::now(); @@ -37,6 +38,8 @@ mod models; mod pipelines; +mod util; pub use diffusion_rs_common::{ModelSource, TokenSource}; pub use pipelines::{DiffusionGenerationParams, Offloading, Pipeline}; +pub use util::{ModelDType, TryIntoDType}; diff --git a/diffusion_rs_core/src/models/vaes/mod.rs b/diffusion_rs_core/src/models/vaes/mod.rs index 1a8c3d0..823b4a2 100644 --- a/diffusion_rs_core/src/models/vaes/mod.rs +++ b/diffusion_rs_core/src/models/vaes/mod.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use autoencoder_kl::{AutencoderKlConfig, AutoEncoderKl}; use diffusion_rs_common::{ - core::{Device, Result, Tensor}, + core::{DType, Device, Result, Tensor}, ModelSource, }; use serde::Deserialize; @@ -46,10 +46,17 @@ pub(crate) fn dispatch_load_vae_model( cfg_json: &FileData, safetensor_files: Vec, device: &Device, + dtype: DType, silent: bool, source: Arc, ) -> anyhow::Result> { - let vb = from_mmaped_safetensors(safetensor_files, None, device, silent, source.clone())?; + let vb = from_mmaped_safetensors( + safetensor_files, + Some(dtype), + device, + silent, + source.clone(), + )?; let VaeConfigShim { name } = serde_json::from_str(&cfg_json.read_to_string(&source)?)?; match name.as_str() { diff --git a/diffusion_rs_core/src/pipelines/flux/mod.rs b/diffusion_rs_core/src/pipelines/flux/mod.rs index 1fecdaf..8db8de2 100644 --- a/diffusion_rs_core/src/pipelines/flux/mod.rs +++ b/diffusion_rs_core/src/pipelines/flux/mod.rs @@ -46,6 +46,7 @@ impl Loader for FluxLoader { &self, mut components: HashMap, device: &Device, + dtype: DType, silent: bool, offloading_type: Option, source: Arc, @@ -96,7 +97,7 @@ impl Loader for FluxLoader { let vb = from_mmaped_safetensors( safetensors.into_values().collect(), - None, + Some(dtype), device, silent, source.clone(), @@ -116,7 +117,7 @@ impl Loader for FluxLoader { let cfg: T5Config = serde_json::from_str(&config.read_to_string(&source)?)?; let vb = from_mmaped_safetensors( safetensors.into_values().collect(), - None, + Some(dtype), &t5_flux_device, silent, source.clone(), @@ -137,6 +138,7 @@ impl Loader for FluxLoader { &config, safetensors.into_values().collect(), device, + dtype, silent, source.clone(), )? @@ -154,7 +156,7 @@ impl Loader for FluxLoader { let cfg: FluxConfig = serde_json::from_str(&config.read_to_string(&source)?)?; let vb = from_mmaped_safetensors( safetensors.into_values().collect(), - None, + Some(dtype), &t5_flux_device, silent, source, diff --git a/diffusion_rs_core/src/pipelines/mod.rs b/diffusion_rs_core/src/pipelines/mod.rs index ed5073b..9dcec12 100644 --- a/diffusion_rs_core/src/pipelines/mod.rs +++ b/diffusion_rs_core/src/pipelines/mod.rs @@ -9,7 +9,7 @@ use std::{ }; use anyhow::Result; -use diffusion_rs_common::core::{Device, Tensor}; +use diffusion_rs_common::core::{DType, Device, Tensor}; use flux::FluxLoader; use image::{DynamicImage, RgbImage}; use serde::Deserialize; @@ -17,6 +17,8 @@ use serde::Deserialize; use diffusion_rs_common::{FileData, FileLoader, ModelSource, NiceProgressBar, TokenSource}; use tracing::info; +use crate::TryIntoDType; + /// Generation parameters. #[derive(Debug, Clone)] pub struct DiffusionGenerationParams { @@ -82,6 +84,7 @@ pub(crate) trait Loader { &self, components: HashMap, device: &Device, + dtype: DType, silent: bool, offloading_type: Option, source: Arc, @@ -120,6 +123,7 @@ impl Pipeline { token: TokenSource, revision: Option, offloading_type: Option, + dtype: &dyn TryIntoDType, ) -> Result { info!("loading from source: {source}."); @@ -212,9 +216,14 @@ impl Pipeline { #[cfg(feature = "metal")] let device = Device::new_metal(0)?; + // NOTE: we can set the device to be just the primary even in the offloading case. + // This will need to be updated! + let dtype = dtype.try_into_dtype(&[&device], silent)?; + let model = model_loader.load_from_components( components, &device, + dtype, silent, offloading_type, Arc::new(source), diff --git a/diffusion_rs_core/src/util/auto_dtype.rs b/diffusion_rs_core/src/util/auto_dtype.rs new file mode 100644 index 0000000..1ac933e --- /dev/null +++ b/diffusion_rs_core/src/util/auto_dtype.rs @@ -0,0 +1,161 @@ +use std::fmt::Display; + +use anyhow::Result; +use diffusion_rs_common::core::{DType, Device, Tensor}; +use serde::Deserialize; +use tracing::info; + +#[derive(Clone, Copy, Default, Debug, Deserialize, PartialEq, clap::ValueEnum)] +/// DType for the model. +/// +/// Note: When using `Auto`, fallback pattern is: BF16 -> F16 -> F32 +pub enum ModelDType { + #[default] + #[serde(rename = "auto")] + Auto, + #[serde(rename = "bf16")] + BF16, + #[serde(rename = "f16")] + F16, + #[serde(rename = "f32")] + F32, +} + +impl Display for ModelDType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Auto => write!(f, "auto"), + Self::BF16 => write!(f, "bf16"), + Self::F16 => write!(f, "f16"), + Self::F32 => write!(f, "f32"), + } + } +} + +/// Type which can be converted to a DType +pub trait TryIntoDType { + fn try_into_dtype(&self, devices: &[&Device], silent: bool) -> Result; +} + +impl TryIntoDType for DType { + fn try_into_dtype(&self, _: &[&Device], silent: bool) -> Result { + if !silent { + info!("dtype selected is {self:?}."); + } + if !matches!(self, DType::BF16 | DType::F32 | DType::F64 | DType::F16) { + anyhow::bail!("DType must be one of BF16, F16, F32, F64"); + } + Ok(*self) + } +} + +#[cfg(feature = "cuda")] +fn get_dtypes(silent: bool) -> Vec { + use std::process::Command; + + // >= is supported + const MIN_BF16_CC: usize = 800; + // >= is supported + const MIN_F16_CC: usize = 530; + + let raw_out = Command::new("nvidia-smi") + .arg("--query-gpu=compute_cap") + .arg("--format=csv") + .output() + .expect("Failed to run `nvidia-smi` but CUDA is selected.") + .stdout; + let out = String::from_utf8(raw_out).expect("`nvidia-smi` did not return valid utf8"); + // This reduce-min will always return at least one value so unwrap is OK. + let min_cc = out + .split('\n') + .skip(1) + .filter(|cc| !cc.trim().is_empty()) + .map(|cc| cc.trim().parse::().unwrap()) + .reduce(|a, b| if a < b { a } else { b }) + .unwrap(); + if !silent { + info!("detected minimum CUDA compute capability {min_cc}"); + } + // 7.5 -> 750 + #[allow(clippy::cast_possible_truncation)] + let min_cc = (min_cc * 100.) as usize; + + let mut dtypes = Vec::new(); + if min_cc >= MIN_BF16_CC { + dtypes.push(DType::BF16); + } else if !silent { + info!("skipping BF16 because CC < 8.0"); + } + if min_cc >= MIN_F16_CC { + dtypes.push(DType::F16); + } else if !silent { + info!("skipping F16 because CC < 5.3"); + } + dtypes +} + +fn get_dtypes_non_cuda() -> Vec { + vec![DType::BF16, DType::F16] +} + +#[cfg(not(feature = "cuda"))] +fn get_dtypes(_silent: bool) -> Vec { + get_dtypes_non_cuda() +} + +fn determine_auto_dtype_all( + devices: &[&Device], + silent: bool, +) -> diffusion_rs_common::core::Result { + let dev_dtypes = get_dtypes(silent); + for dtype in get_dtypes_non_cuda() + .iter() + .filter(|x| dev_dtypes.contains(x)) + { + let mut results = Vec::new(); + for device in devices { + // Try a matmul + let x = Tensor::zeros((2, 2), *dtype, device)?; + results.push(x.matmul(&x)); + } + if results.iter().all(|x| x.is_ok()) { + return Ok(*dtype); + } else { + for result in results { + match result { + Ok(_) => (), + Err(e) => match e { + // For CUDA + diffusion_rs_common::core::Error::UnsupportedDTypeForOp(_, _) => continue, + // Accelerate backend doesn't support f16/bf16 + // Metal backend doesn't support f16 + diffusion_rs_common::core::Error::Msg(_) => continue, + // This is when the metal backend doesn't support bf16 + diffusion_rs_common::core::Error::Metal(_) => continue, + // If running with RUST_BACKTRACE=1 + diffusion_rs_common::core::Error::WithBacktrace { .. } => continue, + other => return Err(other), + }, + } + } + } + } + Ok(DType::F32) +} + +impl TryIntoDType for ModelDType { + fn try_into_dtype(&self, devices: &[&Device], silent: bool) -> Result { + let dtype = match self { + Self::Auto => { + Ok(determine_auto_dtype_all(devices, silent).map_err(anyhow::Error::msg)?) + } + Self::BF16 => Ok(DType::BF16), + Self::F16 => Ok(DType::F16), + Self::F32 => Ok(DType::F32), + }; + if !silent { + info!("dtype selected is {:?}.", dtype.as_ref().unwrap()); + } + dtype + } +} diff --git a/diffusion_rs_core/src/util/mod.rs b/diffusion_rs_core/src/util/mod.rs new file mode 100644 index 0000000..b794ef6 --- /dev/null +++ b/diffusion_rs_core/src/util/mod.rs @@ -0,0 +1,3 @@ +mod auto_dtype; + +pub use auto_dtype::{ModelDType, TryIntoDType}; diff --git a/diffusion_rs_examples/Cargo.toml b/diffusion_rs_examples/Cargo.toml index 2455d60..5278e15 100644 --- a/diffusion_rs_examples/Cargo.toml +++ b/diffusion_rs_examples/Cargo.toml @@ -17,3 +17,10 @@ anyhow.workspace = true clap.workspace = true tracing.workspace = true tracing-subscriber.workspace = true + +[features] +cuda = ["diffusion_rs_core/cuda"] +cudnn = ["diffusion_rs_core/cudnn"] +metal = ["diffusion_rs_core/metal"] +accelerate = ["diffusion_rs_core/accelerate"] +mkl = ["diffusion_rs_core/mkl"] diff --git a/diffusion_rs_examples/examples/dduf/main.rs b/diffusion_rs_examples/examples/dduf/main.rs index 431c753..dafd62a 100644 --- a/diffusion_rs_examples/examples/dduf/main.rs +++ b/diffusion_rs_examples/examples/dduf/main.rs @@ -2,7 +2,7 @@ use std::time::Instant; use clap::Parser; use diffusion_rs_core::{ - DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource, + DiffusionGenerationParams, ModelDType, ModelSource, Offloading, Pipeline, TokenSource, }; use tracing::level_filters::LevelFilter; use tracing_subscriber::EnvFilter; @@ -44,6 +44,7 @@ fn main() -> anyhow::Result<()> { TokenSource::CacheToken, None, args.offloading, + &ModelDType::Auto, )?; let start = Instant::now(); diff --git a/diffusion_rs_examples/examples/flux/main.rs b/diffusion_rs_examples/examples/flux/main.rs index 81f1665..e1850a5 100644 --- a/diffusion_rs_examples/examples/flux/main.rs +++ b/diffusion_rs_examples/examples/flux/main.rs @@ -1,7 +1,7 @@ use std::time::Instant; use diffusion_rs_core::{ - DiffusionGenerationParams, ModelSource, Offloading, Pipeline, TokenSource, + DiffusionGenerationParams, ModelDType, ModelSource, Offloading, Pipeline, TokenSource, }; use clap::{Parser, ValueEnum}; @@ -50,6 +50,7 @@ fn main() -> anyhow::Result<()> { TokenSource::CacheToken, None, args.offloading, + &ModelDType::Auto, )?; let num_steps = match args.which { Which::Dev => 50, diff --git a/diffusion_rs_py/diffuse_rs.pyi b/diffusion_rs_py/diffuse_rs.pyi index ffeeb3c..2592e71 100644 --- a/diffusion_rs_py/diffuse_rs.pyi +++ b/diffusion_rs_py/diffuse_rs.pyi @@ -1,6 +1,18 @@ from dataclasses import dataclass from enum import Enum +@dataclass +class ModelDType(Enum): + """ + DType for the model. + Note: When using `Auto`, fallback pattern is: BF16 -> F16 -> F32 + """ + + Auto = 0 + BF16 = 1 + F16 = 2 + F32 = 2 + @dataclass class Offloading(Enum): """ @@ -40,6 +52,8 @@ class Pipeline: silent: bool = False, token: str | None = None, revision: str | None = None, + offloading: Offloading | None = None, + ModelDType: ModelDType = ModelDType.Auto, ) -> None: """ Load a model. @@ -47,7 +61,10 @@ class Pipeline: - `source`: the source of the model - `silent`: silent loading, defaults to `False`. - `token`: specifies a literal Hugging Face token for accessing gated models. - - `revision`: specifies a specific Hugging Face model revision, otherwise the default is used. - `token_source` specifies where to load the HF token from. + - `revision`: specifies a specific Hugging Face model revision, otherwise the default is used. + - `token_source` specifies where to load the HF token from. + - `offloading`: offloading setting for the model. + - `dtype`: dtype selection for the model. The default is to use an automatic strategy with a fallback pattern: BF16 -> F16 -> F32 """ ... diff --git a/diffusion_rs_py/src/lib.rs b/diffusion_rs_py/src/lib.rs index 2e9f039..f0c6a82 100644 --- a/diffusion_rs_py/src/lib.rs +++ b/diffusion_rs_py/src/lib.rs @@ -33,6 +33,16 @@ pub struct DiffusionGenerationParams { pub guidance_scale: f64, } +#[pyclass(eq, eq_int)] +#[pyo3(get_all)] +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ModelDType { + Auto, + BF16, + F16, + F32, +} + #[pymethods] impl DiffusionGenerationParams { #[new] @@ -77,6 +87,7 @@ impl Pipeline { token = None, revision = None, offloading = None, + dtype = ModelDType::Auto, ))] pub fn new( source: ModelSource, @@ -84,6 +95,7 @@ impl Pipeline { token: Option, revision: Option, offloading: Option, + dtype: ModelDType, ) -> PyResult { let token = token .map(diffusion_rs_core::TokenSource::Literal) @@ -99,8 +111,14 @@ impl Pipeline { let offloading = offloading.map(|offloading| match offloading { Offloading::Full => diffusion_rs_core::Offloading::Full, }); + let dtype = match dtype { + ModelDType::Auto => diffusion_rs_core::ModelDType::Auto, + ModelDType::F16 => diffusion_rs_core::ModelDType::F16, + ModelDType::BF16 => diffusion_rs_core::ModelDType::BF16, + ModelDType::F32 => diffusion_rs_core::ModelDType::F32, + }; Ok(Self( - diffusion_rs_core::Pipeline::load(source, silent, token, revision, offloading) + diffusion_rs_core::Pipeline::load(source, silent, token, revision, offloading, &dtype) .map_err(wrap_anyhow_error)?, )) }