From 672770e060cec1f78c39a89724857f9150d77c20 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Thu, 12 Dec 2024 23:35:46 -0500
Subject: [PATCH 1/4] Support loading from dduf files

---
 .gitignore                                    |   3 +-
 Cargo.lock                                    | 254 +++++++++++++++++-
 Cargo.toml                                    |   1 +
 diffusers_common/Cargo.toml                   |   4 +-
 diffusers_common/src/lib.rs                   |   3 +
 diffusers_common/src/model_source.rs          | 218 +++++++++++++++
 diffusers_common/src/safetensors.rs           |  27 ++
 diffusers_common/src/varbuilder_loading.rs    |  44 ++-
 diffusers_core/src/lib.rs                     |   4 +-
 diffusers_core/src/models/vaes/mod.rs         |  14 +-
 diffusers_core/src/pipelines/flux/mod.rs      |  69 ++---
 diffusers_core/src/pipelines/mod.rs           | 117 ++------
 .../examples/bitsandbytes/main.rs             |   6 +-
 diffusers_examples/examples/dduf/main.rs      |  18 ++
 diffusers_examples/examples/flux_dev/main.rs  |   4 +-
 .../examples/flux_schnell/main.rs             |   4 +-
 16 files changed, 618 insertions(+), 172 deletions(-)
 create mode 100644 diffusers_common/src/model_source.rs
 create mode 100644 diffusers_common/src/safetensors.rs
 create mode 100644 diffusers_examples/examples/dduf/main.rs

diff --git a/.gitignore b/.gitignore
index ccb5166..e0eeee4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 /target
-.vscode
\ No newline at end of file
+.vscode
+*.dduf
\ No newline at end of file
diff --git a/Cargo.lock b/Cargo.lock
index e9b367b..033f20a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -14,6 +14,17 @@ version = "2.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
 
+[[package]]
+name = "aes"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
+dependencies = [
+ "cfg-if",
+ "cipher",
+ "cpufeatures",
+]
+
 [[package]]
 name = "aho-corasick"
 version = "1.1.3"
@@ -217,6 +228,27 @@ version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
 
+[[package]]
+name = "bzip2"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8"
+dependencies = [
+ "bzip2-sys",
+ "libc",
+]
+
+[[package]]
+name = "bzip2-sys"
+version = "0.1.11+1.0.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc"
+dependencies = [
+ "cc",
+ "libc",
+ "pkg-config",
+]
+
 [[package]]
 name = "candle-core"
 version = "0.8.1"
@@ -245,7 +277,7 @@ dependencies = [
  "ug-cuda",
  "ug-metal",
  "yoke",
- "zip",
+ "zip 1.1.4",
 ]
 
 [[package]]
@@ -329,6 +361,16 @@ dependencies = [
  "windows-targets 0.52.6",
 ]
 
+[[package]]
+name = "cipher"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
+dependencies = [
+ "crypto-common",
+ "inout",
+]
+
 [[package]]
 name = "color_quant"
 version = "1.1.0"
@@ -348,6 +390,12 @@ dependencies = [
  "windows-sys 0.52.0",
 ]
 
+[[package]]
+name = "constant_time_eq"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6"
+
 [[package]]
 name = "core-foundation"
 version = "0.9.4"
@@ -384,6 +432,21 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "crc"
+version = "3.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636"
+dependencies = [
+ "crc-catalog",
+]
+
+[[package]]
+name = "crc-catalog"
+version = "2.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5"
+
 [[package]]
 name = "crc32fast"
 version = "1.4.2"
@@ -504,6 +567,21 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "deflate64"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "da692b8d1080ea3045efaab14434d40468c3d8657e42abddfffca87b428f4c1b"
+
+[[package]]
+name = "deranged"
+version = "0.3.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4"
+dependencies = [
+ "powerfmt",
+]
+
 [[package]]
 name = "derive_arbitrary"
 version = "1.4.1"
@@ -575,10 +653,12 @@ dependencies = [
  "candle-core",
  "candle-nn",
  "dirs",
+ "hf-hub",
  "indicatif",
  "safetensors",
  "thiserror 2.0.4",
  "tqdm",
+ "zip 2.2.1",
 ]
 
 [[package]]
@@ -619,6 +699,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
 dependencies = [
  "block-buffer",
  "crypto-common",
+ "subtle",
 ]
 
 [[package]]
@@ -1055,6 +1136,15 @@ dependencies = [
  "ureq",
 ]
 
+[[package]]
+name = "hmac"
+version = "0.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
+dependencies = [
+ "digest",
+]
+
 [[package]]
 name = "iana-time-zone"
 version = "0.1.61"
@@ -1285,6 +1375,15 @@ dependencies = [
  "web-time",
 ]
 
+[[package]]
+name = "inout"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5"
+dependencies = [
+ "generic-array",
+]
+
 [[package]]
 name = "intel-mkl-src"
 version = "0.8.1"
@@ -1444,6 +1543,12 @@ dependencies = [
  "scopeguard",
 ]
 
+[[package]]
+name = "lockfree-object-pool"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e"
+
 [[package]]
 name = "log"
 version = "0.4.22"
@@ -1459,6 +1564,16 @@ dependencies = [
  "imgref",
 ]
 
+[[package]]
+name = "lzma-rs"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "297e814c836ae64db86b36cf2a557ba54368d03f6afcd7d947c266692f71115e"
+dependencies = [
+ "byteorder",
+ "crc",
+]
+
 [[package]]
 name = "macro_rules_attribute"
 version = "0.2.0"
@@ -1662,6 +1777,12 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "num-conv"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
+
 [[package]]
 name = "num-derive"
 version = "0.4.2"
@@ -1919,6 +2040,16 @@ version = "1.0.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
 
+[[package]]
+name = "pbkdf2"
+version = "0.12.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2"
+dependencies = [
+ "digest",
+ "hmac",
+]
+
 [[package]]
 name = "percent-encoding"
 version = "2.3.1"
@@ -1956,6 +2087,12 @@ version = "1.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6"
 
+[[package]]
+name = "powerfmt"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
+
 [[package]]
 name = "ppv-lite86"
 version = "0.2.20"
@@ -2436,6 +2573,17 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "sha1"
+version = "0.10.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
+dependencies = [
+ "cfg-if",
+ "cpufeatures",
+ "digest",
+]
+
 [[package]]
 name = "sha2"
 version = "0.10.8"
@@ -2689,6 +2837,25 @@ dependencies = [
  "weezl",
 ]
 
+[[package]]
+name = "time"
+version = "0.3.37"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21"
+dependencies = [
+ "deranged",
+ "num-conv",
+ "powerfmt",
+ "serde",
+ "time-core",
+]
+
+[[package]]
+name = "time-core"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
+
 [[package]]
 name = "tinystr"
 version = "0.7.6"
@@ -3367,6 +3534,20 @@ name = "zeroize"
 version = "1.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
+dependencies = [
+ "zeroize_derive",
+]
+
+[[package]]
+name = "zeroize_derive"
+version = "1.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
 
 [[package]]
 name = "zerovec"
@@ -3405,6 +3586,77 @@ dependencies = [
  "thiserror 1.0.69",
 ]
 
+[[package]]
+name = "zip"
+version = "2.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "99d52293fc86ea7cf13971b3bb81eb21683636e7ae24c729cdaf1b7c4157a352"
+dependencies = [
+ "aes",
+ "arbitrary",
+ "bzip2",
+ "constant_time_eq",
+ "crc32fast",
+ "crossbeam-utils",
+ "deflate64",
+ "displaydoc",
+ "flate2",
+ "hmac",
+ "indexmap",
+ "lzma-rs",
+ "memchr",
+ "pbkdf2",
+ "rand",
+ "sha1",
+ "thiserror 2.0.4",
+ "time",
+ "zeroize",
+ "zopfli",
+ "zstd",
+]
+
+[[package]]
+name = "zopfli"
+version = "0.8.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e5019f391bac5cf252e93bbcc53d039ffd62c7bfb7c150414d61369afe57e946"
+dependencies = [
+ "bumpalo",
+ "crc32fast",
+ "lockfree-object-pool",
+ "log",
+ "once_cell",
+ "simd-adler32",
+]
+
+[[package]]
+name = "zstd"
+version = "0.13.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9"
+dependencies = [
+ "zstd-safe",
+]
+
+[[package]]
+name = "zstd-safe"
+version = "7.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059"
+dependencies = [
+ "zstd-sys",
+]
+
+[[package]]
+name = "zstd-sys"
+version = "2.0.13+zstd.1.5.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa"
+dependencies = [
+ "cc",
+ "pkg-config",
+]
+
 [[package]]
 name = "zune-core"
 version = "0.4.12"
diff --git a/Cargo.toml b/Cargo.toml
index 9529b4d..94b3ea6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
lazy_static = "1.4" paste = "1.0.15" byteorder = "1.5.0" safetensors = "0.4.1" +zip = "2.2.1" \ No newline at end of file diff --git a/diffusers_common/Cargo.toml b/diffusers_common/Cargo.toml index 5e03d5b..ebdbb0e 100644 --- a/diffusers_common/Cargo.toml +++ b/diffusers_common/Cargo.toml @@ -19,4 +19,6 @@ anyhow.workspace = true dirs.workspace = true indicatif.workspace = true tqdm.workspace = true -safetensors.workspace = true \ No newline at end of file +safetensors.workspace = true +hf-hub.workspace = true +zip.workspace = true diff --git a/diffusers_common/src/lib.rs b/diffusers_common/src/lib.rs index ef936d3..55aa0e5 100644 --- a/diffusers_common/src/lib.rs +++ b/diffusers_common/src/lib.rs @@ -1,9 +1,12 @@ +mod model_source; mod nn; mod progress; +mod safetensors; mod tokens; mod varbuilder; mod varbuilder_loading; +pub use model_source::*; pub use nn::*; pub use progress::NiceProgressBar; pub use tokens::get_token; diff --git a/diffusers_common/src/model_source.rs b/diffusers_common/src/model_source.rs new file mode 100644 index 0000000..4f40f78 --- /dev/null +++ b/diffusers_common/src/model_source.rs @@ -0,0 +1,218 @@ +use std::{ + ffi::OsStr, + fmt::Debug, + fs::{self, File}, + io::Read, + path::PathBuf, +}; + +use crate::{get_token, TokenSource}; +use hf_hub::{ + api::sync::{ApiBuilder, ApiRepo}, + Repo, RepoType, +}; +use zip::ZipArchive; + +pub enum ModelSource { + ModelId(String), + ModelIdWithTransformer { + model_id: String, + transformer_model_id: String, + }, + Dduf { + file: File, + }, +} + +impl ModelSource { + pub fn from_model_id(model_id: S) -> Self { + Self::ModelId(model_id.to_string()) + } + + pub fn override_transformer_model_id(self, model_id: S) -> anyhow::Result { + let Self::ModelId(base_id) = self else { + anyhow::bail!("Expected model ID for the model source") + }; + Ok(Self::ModelIdWithTransformer { + model_id: base_id, + transformer_model_id: model_id.to_string(), + }) + } + + pub fn dduf(filename: S) -> anyhow::Result { + Ok(Self::Dduf { + file: File::open(filename.to_string())?, + }) + } +} + +pub enum FileLoader { + Api(ApiRepo), + ApiWithTransformer { base: ApiRepo, transformer: ApiRepo }, + Dduf(ZipArchive), +} + +impl FileLoader { + pub fn from_model_source( + source: ModelSource, + silent: bool, + token: TokenSource, + revision: Option, + ) -> anyhow::Result { + match source { + ModelSource::ModelId(model_id) => { + let api_builder = ApiBuilder::new() + .with_progress(!silent) + .with_token(get_token(&token)?) + .build()?; + let revision = revision.unwrap_or("main".to_string()); + let api = api_builder.repo(Repo::with_revision( + model_id, + RepoType::Model, + revision.clone(), + )); + + Ok(Self::Api(api)) + } + ModelSource::Dduf { file } => Ok(Self::Dduf(ZipArchive::new(file)?)), + ModelSource::ModelIdWithTransformer { + model_id, + transformer_model_id, + } => { + let api_builder = ApiBuilder::new() + .with_progress(!silent) + .with_token(get_token(&token)?) 
+ .build()?; + let revision = revision.unwrap_or("main".to_string()); + let api = api_builder.repo(Repo::with_revision( + model_id, + RepoType::Model, + revision.clone(), + )); + let transformer_api = api_builder.repo(Repo::with_revision( + transformer_model_id, + RepoType::Model, + revision.clone(), + )); + + Ok(Self::ApiWithTransformer { + base: api, + transformer: transformer_api, + }) + } + } + } + + pub fn list_files(&mut self) -> anyhow::Result> { + match self { + Self::Api(api) + | Self::ApiWithTransformer { + base: api, + transformer: _, + } => api + .info() + .map(|repo| { + repo.siblings + .iter() + .map(|x| x.rfilename.clone()) + .collect::>() + }) + .map_err(|e| anyhow::Error::msg(e.to_string())), + Self::Dduf(dduf) => (0..dduf.len()) + .map(|i| { + dduf.by_index(i) + .map(|x| x.name().to_string()) + .map_err(|e| anyhow::Error::msg(e.to_string())) + }) + .collect::>>(), + } + } + + pub fn list_transformer_files(&self) -> anyhow::Result>> { + match self { + Self::Api(_) | Self::Dduf(_) => Ok(None), + + Self::ApiWithTransformer { + base: _, + transformer: api, + } => api + .info() + .map(|repo| { + repo.siblings + .iter() + .map(|x| x.rfilename.clone()) + .collect::>() + }) + .map(Some) + .map_err(|e| anyhow::Error::msg(e.to_string())), + } + } + + pub fn read_file(&mut self, name: &str, from_transformer: bool) -> anyhow::Result { + if from_transformer && !matches!(self, Self::ApiWithTransformer { .. }) { + anyhow::bail!("This model source has no transformer files.") + } + + match (self, from_transformer) { + (Self::Api(api), false) + | ( + Self::ApiWithTransformer { + base: api, + transformer: _, + }, + false, + ) => Ok(FileData::Path( + api.get(name) + .map_err(|e| anyhow::Error::msg(e.to_string()))?, + )), + ( + Self::ApiWithTransformer { + base: api, + transformer: _, + }, + true, + ) => Ok(FileData::Path( + api.get(name) + .map_err(|e| anyhow::Error::msg(e.to_string()))?, + )), + (Self::Api(_), true) => anyhow::bail!("This model source has no transformer files."), + (Self::Dduf(dduf), _) => { + let file = dduf.by_name(name)?; + Ok(FileData::Dduf { + name: PathBuf::from(file.name().to_string()), + data: file.bytes().collect::>>()?, + }) + } + } + } +} + +pub enum FileData { + Path(PathBuf), + Dduf { name: PathBuf, data: Vec }, +} + +impl Debug for FileData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Path(p) => write!(f, "path: {}", p.display()), + Self::Dduf { name, data: _ } => write!(f, "dduf: {}", name.display()), + } + } +} + +impl FileData { + pub fn read_to_string(&self) -> anyhow::Result { + match self { + Self::Path(p) => Ok(fs::read_to_string(p)?), + Self::Dduf { name: _, data } => Ok(String::from_utf8(data.to_vec())?), + } + } + + pub fn extension(&self) -> Option<&OsStr> { + match self { + Self::Path(p) => p.extension(), + Self::Dduf { name, data: _ } => name.extension(), + } + } +} diff --git a/diffusers_common/src/safetensors.rs b/diffusers_common/src/safetensors.rs new file mode 100644 index 0000000..39e3f57 --- /dev/null +++ b/diffusers_common/src/safetensors.rs @@ -0,0 +1,27 @@ +use candle_core::safetensors::Load; +use candle_core::{Device, Error, Result, Tensor}; +use safetensors::tensor as st; +use safetensors::tensor::SafeTensors; + +pub struct BytesSafetensors<'a> { + safetensors: SafeTensors<'a>, +} + +impl<'a> BytesSafetensors<'a> { + pub fn new(bytes: &'a [u8]) -> Result> { + let st = safetensors::SafeTensors::deserialize(bytes).map_err(|e| Error::from(e))?; + Ok(Self { safetensors: st }) + } + + pub 
fn load(&self, name: &str, dev: &Device) -> Result { + self.get(name)?.load(dev) + } + + pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> { + self.safetensors.tensors() + } + + pub fn get(&self, name: &str) -> Result> { + Ok(self.safetensors.tensor(name)?) + } +} diff --git a/diffusers_common/src/varbuilder_loading.rs b/diffusers_common/src/varbuilder_loading.rs index 2e2117f..8c52967 100644 --- a/diffusers_common/src/varbuilder_loading.rs +++ b/diffusers_common/src/varbuilder_loading.rs @@ -2,17 +2,15 @@ use std::{ collections::HashMap, - path::PathBuf, thread::{self, JoinHandle}, }; use crate::{ + safetensors::BytesSafetensors, varbuilder::{SimpleBackend, VarBuilderArgs}, - VarBuilder, -}; -use candle_core::{ - pickle::PthTensors, safetensors::MmapedSafetensors, DType, Device, Result, Tensor, + FileData, VarBuilder, }; +use candle_core::{safetensors::MmapedSafetensors, DType, Device, Result, Tensor}; use super::progress::IterWithProgress; @@ -41,20 +39,18 @@ impl TensorLoaderBackend for SafetensorBackend { } } -struct PickleBackend(PthTensors); +struct BytesSafetensorBackend<'a>(BytesSafetensors<'a>); -impl TensorLoaderBackend for PickleBackend { +impl<'a> TensorLoaderBackend for BytesSafetensorBackend<'a> { fn get_names(&self) -> Vec { - self.0.tensor_infos().keys().cloned().collect::>() + self.0 + .tensors() + .into_iter() + .map(|(name, _)| name) + .collect::>() } fn load_name(&self, name: &str, device: &Device, dtype: Option) -> Result { - let t = self - .0 - .get(name)? - .ok_or(candle_core::Error::Msg(format!( - "Could not load tensor {name}" - )))? - .to_device(device)?; + let t = self.0.load(name, device)?; if let Some(dtype) = dtype { t.to_dtype(dtype) } else { @@ -70,7 +66,7 @@ impl TensorLoaderBackend for PickleBackend { /// - If `regexes` is specified, this will be used in `make_dummy_predicate` based on `.any` /// - Otherwise, only include keys for which predicate evaluates to true. pub fn from_mmaped_safetensors<'a>( - paths: Vec, + paths: Vec, dtype: Option, device: &Device, silent: bool, @@ -104,7 +100,7 @@ pub fn from_mmaped_safetensors<'a>( trait LoadTensors { fn load_tensors_from_path( &self, - path: &PathBuf, + path: &FileData, device: &Device, dtype: Option, silent: bool, @@ -115,12 +111,14 @@ trait LoadTensors { .to_str() .expect("Expected to convert") { - "safetensors" => Box::new(SafetensorBackend(unsafe { - candle_core::safetensors::MmapedSafetensors::new(path)? - })), - "pth" | "pt" | "bin" => Box::new(PickleBackend( - candle_core::pickle::PthTensors::new(path, None)? - )), + "safetensors" => match path { + FileData::Dduf { name: _, data } => { + Box::new(BytesSafetensorBackend(BytesSafetensors::new(data)?)) + } + FileData::Path(path) => {Box::new(SafetensorBackend(unsafe { + candle_core::safetensors::MmapedSafetensors::new(path)? 
+ }))} + }, other => candle_core::bail!("Unexpected extension `{other}`, this should have been handles by `get_model_paths`."), }; diff --git a/diffusers_core/src/lib.rs b/diffusers_core/src/lib.rs index 4ff081d..a658893 100644 --- a/diffusers_core/src/lib.rs +++ b/diffusers_core/src/lib.rs @@ -1,5 +1,5 @@ mod models; mod pipelines; -pub use diffusers_common::TokenSource; -pub use pipelines::{DiffusionGenerationParams, ModelPaths, Pipeline}; +pub use diffusers_common::{ModelSource, TokenSource}; +pub use pipelines::{DiffusionGenerationParams, Pipeline}; diff --git a/diffusers_core/src/models/vaes/mod.rs b/diffusers_core/src/models/vaes/mod.rs index 0a7443a..b315c0e 100644 --- a/diffusers_core/src/models/vaes/mod.rs +++ b/diffusers_core/src/models/vaes/mod.rs @@ -1,10 +1,10 @@ -use std::{fs, path::PathBuf, sync::Arc}; +use std::sync::Arc; use autoencoder_kl::{AutencoderKlConfig, AutoEncoderKl}; use candle_core::{Device, Result, Tensor}; use serde::Deserialize; -use diffusers_common::{from_mmaped_safetensors, VarBuilder}; +use diffusers_common::{from_mmaped_safetensors, FileData, VarBuilder}; mod autoencoder_kl; mod vae; @@ -30,19 +30,19 @@ struct VaeConfigShim { name: String, } -fn load_autoencoder_kl(cfg_json: &PathBuf, vb: VarBuilder) -> anyhow::Result> { - let cfg: AutencoderKlConfig = serde_json::from_str(&fs::read_to_string(cfg_json)?)?; +fn load_autoencoder_kl(cfg_json: &FileData, vb: VarBuilder) -> anyhow::Result> { + let cfg: AutencoderKlConfig = serde_json::from_str(&cfg_json.read_to_string()?)?; Ok(Arc::new(AutoEncoderKl::new(&cfg, vb)?)) } pub(crate) fn dispatch_load_vae_model( - cfg_json: &PathBuf, - safetensor_files: Vec, + cfg_json: &FileData, + safetensor_files: Vec, device: &Device, ) -> anyhow::Result> { let vb = from_mmaped_safetensors(safetensor_files, None, device, false)?; - let VaeConfigShim { name } = serde_json::from_str(&fs::read_to_string(cfg_json)?)?; + let VaeConfigShim { name } = serde_json::from_str(&cfg_json.read_to_string()?)?; match name.as_str() { "AutoencoderKL" => load_autoencoder_kl(cfg_json, vb), other => anyhow::bail!("Unexpected VAE type `{other:?}`."), diff --git a/diffusers_core/src/pipelines/flux/mod.rs b/diffusers_core/src/pipelines/flux/mod.rs index 36445a4..6af1be0 100644 --- a/diffusers_core/src/pipelines/flux/mod.rs +++ b/diffusers_core/src/pipelines/flux/mod.rs @@ -1,4 +1,4 @@ -use std::{cmp::Ordering, collections::HashMap, fs, sync::Arc}; +use std::{cmp::Ordering, collections::HashMap, sync::Arc}; use anyhow::Result; use candle_core::{DType, Device, Tensor, D}; @@ -46,30 +46,30 @@ impl Loader for FluxLoader { fn load_from_components( &self, - components: HashMap, + mut components: HashMap, device: &Device, ) -> Result> { - let scheduler = components[&ComponentName::Scheduler].clone(); - let clip_component = components[&ComponentName::TextEncoder(1)].clone(); - let t5_component = components[&ComponentName::TextEncoder(2)].clone(); - let clip_tok_component = components[&ComponentName::Tokenizer(1)].clone(); - let t5_tok_component = components[&ComponentName::Tokenizer(2)].clone(); - let flux_component = components[&ComponentName::Transformer].clone(); - let vae_component = components[&ComponentName::Vae].clone(); + let scheduler = components.remove(&ComponentName::Scheduler).unwrap(); + let clip_component = components.remove(&ComponentName::TextEncoder(1)).unwrap(); + let t5_component = components.remove(&ComponentName::TextEncoder(2)).unwrap(); + let clip_tok_component = components.remove(&ComponentName::Tokenizer(1)).unwrap(); + let 
+        let t5_tok_component = components.remove(&ComponentName::Tokenizer(2)).unwrap();
+        let flux_component = components.remove(&ComponentName::Transformer).unwrap();
+        let vae_component = components.remove(&ComponentName::Vae).unwrap();
 
         let scheduler_config = if let ComponentElem::Config { files } = scheduler {
-            serde_json::from_str::<SchedulerConfig>(&fs::read_to_string(
-                files["scheduler/scheduler_config.json"].clone(),
-            )?)?
+            serde_json::from_str::<SchedulerConfig>(
+                &files["scheduler/scheduler_config.json"].read_to_string()?,
+            )?
         } else {
             anyhow::bail!("expected scheduler config")
         };
         let clip_tokenizer = if let ComponentElem::Other { files } = clip_tok_component {
-            let vocab_file = files["tokenizer/vocab.json"].clone();
-            let merges_file = files["tokenizer/merges.txt"].clone();
-            let vocab: HashMap<String, u32> =
-                serde_json::from_str(&fs::read_to_string(vocab_file)?)?;
-            let merges: Vec<(String, String)> = fs::read_to_string(merges_file)?
+            let vocab_file = &files["tokenizer/vocab.json"];
+            let merges_file = &files["tokenizer/merges.txt"];
+            let vocab: HashMap<String, u32> = serde_json::from_str(&vocab_file.read_to_string()?)?;
+            let merges: Vec<(String, String)> = merges_file
+                .read_to_string()?
                 .split('\n')
                 .skip(1)
                 .map(|x| x.split(' ').collect::<Vec<_>>())
@@ -82,7 +82,7 @@
             anyhow::bail!("incorrect storage of clip tokenizer")
         };
         let t5_tokenizer = if let ComponentElem::Other { files } = t5_tok_component {
-            Tokenizer::from_file(files["tokenizer_2/tokenizer.json"].clone())
+            Tokenizer::from_bytes(files["tokenizer_2/tokenizer.json"].read_to_string()?)
                 .map_err(anyhow::Error::msg)?
         } else {
             anyhow::bail!("incorrect storage of t5 tokenizer")
@@ -92,13 +92,10 @@ config,
         } = clip_component
         {
-            let cfg: ClipTextConfig = serde_json::from_str(&fs::read_to_string(config)?)?;
-            let vb = from_mmaped_safetensors(
-                safetensors.values().cloned().collect(),
-                None,
-                device,
-                false,
-            )?;
+            let cfg: ClipTextConfig = serde_json::from_str(&config.read_to_string()?)?;
+
+            let vb =
+                from_mmaped_safetensors(safetensors.into_values().collect(), None, device, false)?;
             ClipTextTransformer::new(vb.pp("text_model"), &cfg)?
         } else {
             anyhow::bail!("incorrect storage of clip model")
@@ -108,13 +105,9 @@ config,
         } = t5_component
         {
-            let cfg: T5Config = serde_json::from_str(&fs::read_to_string(config)?)?;
-            let vb = from_mmaped_safetensors(
-                safetensors.values().cloned().collect(),
-                None,
-                device,
-                false,
-            )?;
+            let cfg: T5Config = serde_json::from_str(&config.read_to_string()?)?;
+            let vb =
+                from_mmaped_safetensors(safetensors.into_values().collect(), None, device, false)?;
             T5EncoderModel::new(vb, &cfg, device)?
         } else {
             anyhow::bail!("incorrect storage of t5 model")
@@ -124,7 +117,7 @@ config,
         } = vae_component
         {
-            dispatch_load_vae_model(&config, safetensors.values().cloned().collect(), device)?
+            dispatch_load_vae_model(&config, safetensors.into_values().collect(), device)?
         } else {
             anyhow::bail!("incorrect storage of vae model")
         };
@@ -133,13 +126,9 @@ config,
         } = flux_component
         {
-            let cfg: FluxConfig = serde_json::from_str(&fs::read_to_string(config)?)?;
-            let vb = from_mmaped_safetensors(
-                safetensors.values().cloned().collect(),
-                None,
-                device,
-                false,
-            )?;
+            let cfg: FluxConfig = serde_json::from_str(&config.read_to_string()?)?;
+            let vb =
+                from_mmaped_safetensors(safetensors.into_values().collect(), None, device, false)?;
             FluxModel::new(&cfg, vb)?
         } else {
             anyhow::bail!("incorrect storage of flux model")
diff --git a/diffusers_core/src/pipelines/mod.rs b/diffusers_core/src/pipelines/mod.rs
index 0b23bc3..a1f5d9d 100644
--- a/diffusers_core/src/pipelines/mod.rs
+++ b/diffusers_core/src/pipelines/mod.rs
@@ -1,15 +1,14 @@
 mod flux;
 
-use std::{collections::HashMap, fmt::Display, fs, path::PathBuf, sync::Arc};
+use std::{collections::HashMap, fmt::Display, sync::Arc};
 
 use anyhow::Result;
 use candle_core::{Device, Tensor};
 use flux::FluxLoader;
-use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use image::{DynamicImage, RgbImage};
 use serde::Deserialize;
 
-use diffusers_common::{get_token, TokenSource};
+use diffusers_common::{FileData, FileLoader, ModelSource, TokenSource};
 
 #[derive(Debug, Clone)]
 pub struct DiffusionGenerationParams {
@@ -35,17 +34,17 @@ impl Default for DiffusionGenerationParams {
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub(crate) enum ComponentElem {
     Model {
-        safetensors: HashMap<String, PathBuf>,
-        config: PathBuf,
+        safetensors: HashMap<String, FileData>,
+        config: FileData,
     },
     Config {
-        files: HashMap<String, PathBuf>,
+        files: HashMap<String, FileData>,
     },
     Other {
-        files: HashMap<String, PathBuf>,
+        files: HashMap<String, FileData>,
     },
 }
@@ -95,103 +94,41 @@ struct ModelIndex {
     name: String,
 }
 
-pub struct ModelPaths {
-    pub model_id: String,
-    pub transformer_model_id: Option<String>,
-}
-
-impl ModelPaths {
-    pub fn from_model_id<S: ToString>(model_id: S) -> Self {
-        Self {
-            model_id: model_id.to_string(),
-            transformer_model_id: None,
-        }
-    }
-    pub fn override_transformer_model_id<S: ToString>(mut self, model_id: S) -> Self {
-        self.transformer_model_id = Some(model_id.to_string());
-        self
-    }
-}
-
 pub struct Pipeline(Arc<dyn ModelPipeline>);
 
 impl Pipeline {
     pub fn load(
-        model_paths: ModelPaths,
+        source: ModelSource,
         silent: bool,
         token: TokenSource,
         revision: Option<String>,
     ) -> Result<Self> {
-        let api_builder = ApiBuilder::new()
-            .with_progress(!silent)
-            .with_token(get_token(&token)?)
-            .build()?;
-        let revision = revision.unwrap_or("main".to_string());
-        let api = api_builder.repo(Repo::with_revision(
-            model_paths.model_id,
-            RepoType::Model,
-            revision.clone(),
-        ));
-        let transformer_api =
-            model_paths
-                .transformer_model_id
-                .as_ref()
-                .map(|transformer_model_id| {
-                    api_builder.repo(Repo::with_revision(
-                        transformer_model_id.clone(),
-                        RepoType::Model,
-                        revision.clone(),
-                    ))
-                });
-
-        let files = api
-            .info()
-            .map(|repo| {
-                repo.siblings
-                    .iter()
-                    .map(|x| x.rfilename.clone())
-                    .collect::<Vec<String>>()
-            })
-            .map_err(|e| anyhow::Error::msg(e.to_string()))?;
-
-        let transformer_files = if let Some(transformer_api) = &transformer_api {
-            Some(
-                transformer_api
-                    .info()
-                    .map(|repo| {
-                        repo.siblings
-                            .iter()
-                            .map(|x| x.rfilename.clone())
-                            .collect::<Vec<String>>()
-                    })
-                    .map_err(|e| anyhow::Error::msg(e.to_string()))?,
-            )
-        } else {
-            None
-        };
+        let mut loader = FileLoader::from_model_source(source, silent, token, revision)?;
+        let files = loader.list_files()?;
+        let transformer_files = loader.list_transformer_files()?;
 
         if !files.contains(&"model_index.json".to_string()) {
             anyhow::bail!("Expected `model_index.json` file present.");
         }
 
-        let ModelIndex { name } =
-            serde_json::from_str(&fs::read_to_string(api.get("model_index.json")?)?)?;
+        let ModelIndex { name } = serde_json::from_str(
+            &loader
+                .read_file("model_index.json", false)?
+                .read_to_string()?,
+        )?;
 
-        let loader = match name.as_str() {
+        let model_loader = match name.as_str() {
             "FluxPipeline" => Box::new(FluxLoader),
             other => anyhow::bail!("Unexpected loader type `{other:?}`."),
         };
 
         let mut components = HashMap::new();
-        for component in loader.required_component_names() {
-            let (files, api_to_use, dir) =
+        for component in model_loader.required_component_names() {
+            dbg!(&component);
+            let (files, from_transformer, dir) =
                 if component == ComponentName::Transformer && transformer_files.is_some() {
-                    (
-                        transformer_files.clone().unwrap(),
-                        transformer_api.as_ref().unwrap(),
-                        "".to_string(),
-                    )
+                    (transformer_files.clone().unwrap(), true, "".to_string())
                 } else {
-                    (files.clone(), &api, format!("{component}/"))
+                    (files.clone(), false, format!("{component}/"))
                 };
             let files_for_component = files
                 .iter()
@@ -212,11 +149,11 @@
                 .iter()
                 .filter(|file| file.ends_with(".safetensors"))
             {
-                safetensors.insert(file.clone(), api_to_use.get(file)?);
+                safetensors.insert(file.clone(), loader.read_file(file, from_transformer)?);
             }
             ComponentElem::Model {
                 safetensors,
-                config: api_to_use.get(&format!("{dir}config.json"))?,
+                config: loader.read_file(&format!("{dir}config.json"), from_transformer)?,
             }
         } else if files_for_component
             .iter()
@@ -227,13 +164,13 @@
                 .iter()
                 .filter(|file| file.ends_with(".json"))
             {
-                files.insert(file.clone(), api_to_use.get(file)?);
+                files.insert(file.clone(), loader.read_file(file, from_transformer)?);
             }
             ComponentElem::Config { files }
         } else {
             let mut files = HashMap::new();
             for file in files_for_component {
-                files.insert(file.clone(), api_to_use.get(&file)?);
+                files.insert(file.clone(), loader.read_file(&file, from_transformer)?);
             }
             ComponentElem::Other { files }
         };
@@ -245,7 +182,7 @@
         #[cfg(feature = "metal")]
         let device = Device::new_metal(0)?;
 
-        let model = loader.load_from_components(components, &device)?;
+        let model = model_loader.load_from_components(components, &device)?;
 
         Ok(Self(model))
     }
diff --git a/diffusers_examples/examples/bitsandbytes/main.rs b/diffusers_examples/examples/bitsandbytes/main.rs
index 4a967a8..c43f4b6 100644
--- a/diffusers_examples/examples/bitsandbytes/main.rs
+++ b/diffusers_examples/examples/bitsandbytes/main.rs
@@ -1,9 +1,9 @@
-use diffusers_core::{DiffusionGenerationParams, ModelPaths, Pipeline, TokenSource};
+use diffusers_core::{DiffusionGenerationParams, ModelSource, Pipeline, TokenSource};
 
 fn main() -> anyhow::Result<()> {
     let pipeline = Pipeline::load(
-        ModelPaths::from_model_id("black-forest-labs/FLUX.1-dev")
-            .override_transformer_model_id("sayakpaul/flux.1-dev-nf4-with-bnb-integration"),
+        ModelSource::from_model_id("black-forest-labs/FLUX.1-dev")
+            .override_transformer_model_id("sayakpaul/flux.1-dev-nf4-with-bnb-integration")?,
         false,
         TokenSource::CacheToken,
         None,
diff --git a/diffusers_examples/examples/dduf/main.rs b/diffusers_examples/examples/dduf/main.rs
new file mode 100644
index 0000000..89e8052
--- /dev/null
+++ b/diffusers_examples/examples/dduf/main.rs
@@ -0,0 +1,18 @@
+use diffusers_core::{DiffusionGenerationParams, ModelSource, Pipeline, TokenSource};
+
+fn main() -> anyhow::Result<()> {
+    let pipeline = Pipeline::load(
+        ModelSource::dduf("FLUX.1-dev-Q4-bnb.dduf")?,
+        false,
+        TokenSource::CacheToken,
+        None,
+    )?;
+
+    let images = pipeline.forward(
+        vec!["Draw a picture of a beautiful sunset in the winter in the mountains, 4k, high quality.".to_string()],
+        DiffusionGenerationParams::default(),
+    )?;
+    images[0].save("image_schnell.png")?;
+
+    Ok(())
+}
diff --git a/diffusers_examples/examples/flux_dev/main.rs b/diffusers_examples/examples/flux_dev/main.rs
index 0618411..c3cf37f 100644
--- a/diffusers_examples/examples/flux_dev/main.rs
+++ b/diffusers_examples/examples/flux_dev/main.rs
@@ -1,8 +1,8 @@
-use diffusers_core::{DiffusionGenerationParams, ModelPaths, Pipeline, TokenSource};
+use diffusers_core::{DiffusionGenerationParams, ModelSource, Pipeline, TokenSource};
 
 fn main() -> anyhow::Result<()> {
     let pipeline = Pipeline::load(
-        ModelPaths::from_model_id("black-forest-labs/FLUX.1-dev"),
+        ModelSource::from_model_id("black-forest-labs/FLUX.1-dev"),
         false,
         TokenSource::CacheToken,
         None,
diff --git a/diffusers_examples/examples/flux_schnell/main.rs b/diffusers_examples/examples/flux_schnell/main.rs
index 22ff993..94014d4 100644
--- a/diffusers_examples/examples/flux_schnell/main.rs
+++ b/diffusers_examples/examples/flux_schnell/main.rs
@@ -1,8 +1,8 @@
-use diffusers_core::{DiffusionGenerationParams, ModelPaths, Pipeline, TokenSource};
+use diffusers_core::{DiffusionGenerationParams, ModelSource, Pipeline, TokenSource};
 
 fn main() -> anyhow::Result<()> {
     let pipeline = Pipeline::load(
-        ModelPaths::from_model_id("black-forest-labs/FLUX.1-schnell"),
+        ModelSource::from_model_id("black-forest-labs/FLUX.1-schnell"),
         false,
         TokenSource::CacheToken,
         None,

From 9b25b9fdc97cab46385c1971330d2e3d61cd4791 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Fri, 13 Dec 2024 06:13:56 -0500
Subject: [PATCH 2/4] Use mmap for much faster loading

---
 Cargo.lock                           |  1 +
 Cargo.toml                           |  3 ++-
 diffusers_common/Cargo.toml          |  1 +
 diffusers_common/src/model_source.rs | 21 +++++++++++++--------
 4 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 033f20a..0dad322 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -655,6 +655,7 @@ dependencies = [
  "dirs",
  "hf-hub",
  "indicatif",
+ "memmap2",
  "safetensors",
  "thiserror 2.0.4",
  "tqdm",
diff --git a/Cargo.toml b/Cargo.toml
index 94b3ea6..c4429f4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -42,4 +42,5 @@ lazy_static = "1.4"
 paste = "1.0.15"
 byteorder = "1.5.0"
 safetensors = "0.4.1"
-zip = "2.2.1"
\ No newline at end of file
+zip = "2.2.1"
+memmap2 = "0.9.5"
\ No newline at end of file
diff --git a/diffusers_common/Cargo.toml b/diffusers_common/Cargo.toml
index ebdbb0e..c90eaec 100644
--- a/diffusers_common/Cargo.toml
+++ b/diffusers_common/Cargo.toml
@@ -22,3 +22,4 @@ tqdm.workspace = true
 safetensors.workspace = true
 hf-hub.workspace = true
 zip.workspace = true
+memmap2.workspace = true
diff --git a/diffusers_common/src/model_source.rs b/diffusers_common/src/model_source.rs
index 4f40f78..d8b574a 100644
--- a/diffusers_common/src/model_source.rs
+++ b/diffusers_common/src/model_source.rs
@@ -2,7 +2,7 @@ use std::{
     ffi::OsStr,
     fmt::Debug,
     fs::{self, File},
-    io::Read,
+    io::Cursor,
     path::PathBuf,
 };
 
@@ -11,6 +11,7 @@ use hf_hub::{
     api::sync::{ApiBuilder, ApiRepo},
     Repo, RepoType,
 };
+use memmap2::Mmap;
 use zip::ZipArchive;
 
 pub enum ModelSource {
@@ -49,7 +50,7 @@ impl ModelSource {
 pub enum FileLoader {
     Api(ApiRepo),
     ApiWithTransformer { base: ApiRepo, transformer: ApiRepo },
-    Dduf(ZipArchive<File>),
+    Dduf(ZipArchive<Cursor<Mmap>>),
 }
 
 impl FileLoader {
@@ -74,7 +75,11 @@ impl FileLoader {
 
                 Ok(Self::Api(api))
             }
-            ModelSource::Dduf { file } => Ok(Self::Dduf(ZipArchive::new(file)?)),
+            ModelSource::Dduf { file } => {
+                let mmap = unsafe { Mmap::map(&file)? };
+                let cursor = Cursor::new(mmap);
+                Ok(Self::Dduf(ZipArchive::new(cursor)?))
+            }
             ModelSource::ModelIdWithTransformer {
                 model_id,
                 transformer_model_id,
@@ -177,11 +182,11 @@ impl FileLoader {
             )),
             (Self::Api(_), true) => anyhow::bail!("This model source has no transformer files."),
             (Self::Dduf(dduf), _) => {
-                let file = dduf.by_name(name)?;
-                Ok(FileData::Dduf {
-                    name: PathBuf::from(file.name().to_string()),
-                    data: file.bytes().collect::<Result<Vec<u8>, _>>()?,
-                })
+                let mut file = dduf.by_name(name)?;
+                let mut data = Vec::new();
+                std::io::copy(&mut file, &mut data)?;
+                let name = PathBuf::from(file.name().to_string());
+                Ok(FileData::Dduf { name, data })
             }
         }
     }

From b1dbf9e8812892250eaa370048570af2f057db2d Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Fri, 13 Dec 2024 07:11:31 -0500
Subject: [PATCH 3/4] Fix loading

---
 diffusers_backend/src/lib.rs               |  60 +++++----
 diffusers_common/src/varbuilder_loading.rs |   1 +
 diffusers_core/src/models/t5/mod.rs        | 150 +++++++++++----------
 3 files changed, 113 insertions(+), 98 deletions(-)

diff --git a/diffusers_backend/src/lib.rs b/diffusers_backend/src/lib.rs
index 6a4b69d..ce2ac2d 100644
--- a/diffusers_backend/src/lib.rs
+++ b/diffusers_backend/src/lib.rs
@@ -189,26 +189,33 @@ impl Module for dyn QuantMethod {
     }
 }
 
+fn vb_contains_quant(vb: &VarBuilder) -> bool {
+    vb.contains_tensor("weight.absmax")
+}
+
 pub fn linear_no_bias(
     in_dim: usize,
     out_dim: usize,
     config: &Option<QuantizedConfig>,
     vb: VarBuilder,
 ) -> Result<Arc<dyn QuantMethod>> {
-    let layer = if let Some(quant_conf) = &config {
-        match quant_conf.quant_method {
-            QuantMethodType::Bitsandbytes => {
-                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
-            }
-            QuantMethodType::Unreachable => unreachable!(),
+    if vb_contains_quant(&vb) {
+        if let Some(quant_conf) = &config {
+            let layer = match quant_conf.quant_method {
+                QuantMethodType::Bitsandbytes => {
+                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
+                }
+                QuantMethodType::Unreachable => unreachable!(),
+            };
+            return Ok(layer);
         }
-    } else {
-        let ws = vb.get((out_dim, in_dim), "weight")?;
-        let layer = Linear::new(ws, None);
+    }
 
-        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(layer))?;
-        Arc::new(layer) as Arc<dyn QuantMethod>
-    };
+    let ws = vb.get((out_dim, in_dim), "weight")?;
+    let layer = Linear::new(ws, None);
+
+    let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(layer))?;
+    let layer = Arc::new(layer) as Arc<dyn QuantMethod>;
     Ok(layer)
 }
 
 pub fn linear(
     in_dim: usize,
     out_dim: usize,
     config: &Option<QuantizedConfig>,
     vb: VarBuilder,
 ) -> Result<Arc<dyn QuantMethod>> {
-    let layer = if let Some(quant_conf) = &config {
-        match quant_conf.quant_method {
-            QuantMethodType::Bitsandbytes => {
-                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
-            }
-            QuantMethodType::Unreachable => unreachable!(),
+    if vb_contains_quant(&vb) {
+        if let Some(quant_conf) = &config {
+            let layer = match quant_conf.quant_method {
+                QuantMethodType::Bitsandbytes => {
+                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
+                }
+                QuantMethodType::Unreachable => unreachable!(),
+            };
+            return Ok(layer);
         }
-    } else {
-        let ws = vb.get((out_dim, in_dim), "weight")?;
-        let bs = vb.get(out_dim, "bias")?;
-        let layer = Linear::new(ws, Some(bs));
+    }
+
+    let ws = vb.get((out_dim, in_dim), "weight")?;
+    let bs = vb.get(out_dim, "bias")?;
+    let layer = Linear::new(ws, Some(bs));
 
-        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(layer))?;
-        Arc::new(layer) as Arc<dyn QuantMethod>
-    };
+    let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(layer))?;
+    let layer = Arc::new(layer) as Arc<dyn QuantMethod>;
     Ok(layer)
 }
diff --git a/diffusers_common/src/varbuilder_loading.rs b/diffusers_common/src/varbuilder_loading.rs
index 8c52967..51b3685 100644
--- a/diffusers_common/src/varbuilder_loading.rs
+++ b/diffusers_common/src/varbuilder_loading.rs
@@ -90,6 +90,7 @@ pub fn from_mmaped_safetensors<'a>(
     }
 
     let first_dtype = DType::BF16; //ws.values().next().unwrap().dtype();
+    dbg!(&ws.keys().collect::<Vec<_>>());
 
     Ok(VarBuilder::from_tensors(
         ws,
         dtype.unwrap_or(first_dtype),
diff --git a/diffusers_core/src/models/t5/mod.rs b/diffusers_core/src/models/t5/mod.rs
index 387452b..e6a9f6a 100644
--- a/diffusers_core/src/models/t5/mod.rs
+++ b/diffusers_core/src/models/t5/mod.rs
@@ -4,8 +4,9 @@
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
 use candle_core::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{Activation, Embedding, Linear};
-use diffusers_common::{embedding, linear_no_bias, VarBuilder};
+use candle_nn::{Activation, Embedding};
+use diffusers_backend::{linear_no_bias, QuantMethod, QuantizedConfig};
+use diffusers_common::{embedding, VarBuilder};
 use serde::Deserialize;
 use std::sync::Arc;
 
@@ -21,10 +22,6 @@ fn default_use_cache() -> bool {
     true
 }
 
-fn default_tie_word_embeddings() -> bool {
-    true
-}
-
 fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
     let mask: Vec<_> = (0..size)
         .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
         .collect();
@@ -70,63 +67,25 @@ where
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Deserialize)]
+#[derive(Debug, Clone, Deserialize)]
 pub struct T5Config {
     pub vocab_size: usize,
     pub d_model: usize,
     pub d_kv: usize,
     pub d_ff: usize,
     pub num_layers: usize,
-    pub num_decoder_layers: Option<usize>,
     pub num_heads: usize,
     pub relative_attention_num_buckets: usize,
     #[serde(default = "default_relative_attention_max_distance")]
     pub relative_attention_max_distance: usize,
-    pub dropout_rate: f64,
     pub layer_norm_epsilon: f64,
-    pub initializer_factor: f64,
     #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
     pub feed_forward_proj: ActivationWithOptionalGating,
-    #[serde(default = "default_tie_word_embeddings")]
-    pub tie_word_embeddings: bool,
     #[serde(default = "default_is_decoder")]
     pub is_decoder: bool,
-    pub is_encoder_decoder: bool,
     #[serde(default = "default_use_cache")]
     pub use_cache: bool,
-    pub pad_token_id: usize,
-    pub eos_token_id: usize,
-    pub decoder_start_token_id: Option<usize>,
-}
-
-impl Default for T5Config {
-    fn default() -> Self {
-        Self {
-            vocab_size: 32128,
-            d_model: 512,
-            d_kv: 64,
-            d_ff: 2048,
-            num_layers: 6,
-            num_decoder_layers: None,
-            num_heads: 8,
-            relative_attention_num_buckets: 32,
-            relative_attention_max_distance: 128,
-            dropout_rate: 0.1,
-            layer_norm_epsilon: 1e-6,
-            initializer_factor: 1.0,
-            feed_forward_proj: ActivationWithOptionalGating {
-                gated: false,
-                activation: Activation::Relu,
-            },
-            tie_word_embeddings: true,
-            is_decoder: false,
-            is_encoder_decoder: true,
-            use_cache: true,
-            pad_token_id: 0,
-            eos_token_id: 1,
-            decoder_start_token_id: Some(0),
-        }
-    }
+    pub quantization_config: Option<QuantizedConfig>,
 }
 
 #[derive(Debug, Clone)]
@@ -160,15 +119,25 @@ impl Module for T5LayerNorm {
 
 #[derive(Debug, Clone)]
 struct T5DenseActDense {
-    wi: Linear,
-    wo: Linear,
+    wi: Arc<dyn QuantMethod>,
+    wo: Arc<dyn QuantMethod>,
     act: Activation,
 }
 
 impl T5DenseActDense {
     fn load(vb: VarBuilder, cfg: &T5Config) -> Result<Self> {
-        let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
-        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        let wi = linear_no_bias(
+            cfg.d_model,
+            cfg.d_ff,
+            &&cfg.quantization_config,
+            vb.pp("wi"),
+        )?;
+        let wo = linear_no_bias(
+            cfg.d_ff,
+            cfg.d_model,
+            &&cfg.quantization_config,
+            vb.pp("wo"),
+        )?;
         Ok(Self {
             wi,
             wo,
@@ -179,26 +148,41 @@ impl T5DenseActDense {
 
 impl Module for T5DenseActDense {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let xs = self.wi.forward(xs)?;
+        let xs = self.wi.forward_autocast(xs)?;
         let xs = self.act.forward(&xs)?;
-        let xs = self.wo.forward(&xs)?;
+        let xs = self.wo.forward_autocast(&xs)?;
         Ok(xs)
     }
 }
 
 #[derive(Debug, Clone)]
 struct T5DenseGatedActDense {
-    wi_0: Linear,
-    wi_1: Linear,
-    wo: Linear,
+    wi_0: Arc<dyn QuantMethod>,
+    wi_1: Arc<dyn QuantMethod>,
+    wo: Arc<dyn QuantMethod>,
     act: Activation,
 }
 
 impl T5DenseGatedActDense {
     fn load(vb: VarBuilder, cfg: &T5Config) -> Result<Self> {
-        let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
-        let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
-        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        let wi_0 = linear_no_bias(
+            cfg.d_model,
+            cfg.d_ff,
+            &&cfg.quantization_config,
+            vb.pp("wi_0"),
+        )?;
+        let wi_1 = linear_no_bias(
+            cfg.d_model,
+            cfg.d_ff,
+            &&cfg.quantization_config,
+            vb.pp("wi_1"),
+        )?;
+        let wo = linear_no_bias(
+            cfg.d_ff,
+            cfg.d_model,
+            &&cfg.quantization_config,
+            vb.pp("wo"),
+        )?;
         Ok(Self {
             wi_0,
             wi_1,
@@ -210,10 +194,10 @@ impl T5DenseGatedActDense {
 
 impl Module for T5DenseGatedActDense {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
-        let hidden_linear = self.wi_1.forward(xs)?;
+        let hidden_gelu = self.act.forward(&self.wi_0.forward_autocast(xs)?)?;
+        let hidden_linear = self.wi_1.forward_autocast(xs)?;
         let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
-        let xs = self.wo.forward(&xs)?;
+        let xs = self.wo.forward_autocast(&xs)?;
         Ok(xs)
     }
 }
@@ -262,10 +246,10 @@ impl Module for T5LayerFF {
 
 #[derive(Debug, Clone)]
 struct T5Attention {
-    q: Linear,
-    k: Linear,
-    v: Linear,
-    o: Linear,
+    q: Arc<dyn QuantMethod>,
+    k: Arc<dyn QuantMethod>,
+    v: Arc<dyn QuantMethod>,
+    o: Arc<dyn QuantMethod>,
     n_heads: usize,
     d_kv: usize,
     relative_attention_bias: Option<Embedding>,
@@ -283,10 +267,30 @@ impl T5Attention {
         cfg: &T5Config,
     ) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
-        let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
-        let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
-        let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
-        let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
+        let q = linear_no_bias(
+            cfg.d_model,
+            inner_dim,
+            &&cfg.quantization_config,
+            vb.pp("q"),
+        )?;
+        let k = linear_no_bias(
+            cfg.d_model,
+            inner_dim,
+            &&cfg.quantization_config,
+            vb.pp("k"),
+        )?;
+        let v = linear_no_bias(
+            cfg.d_model,
+            inner_dim,
+            &&cfg.quantization_config,
+            vb.pp("v"),
+        )?;
+        let o = linear_no_bias(
+            inner_dim,
+            cfg.d_model,
+            &&cfg.quantization_config,
+            vb.pp("o"),
+        )?;
         let relative_attention_bias = if has_relative_attention_bias {
             let emb = embedding(
                 cfg.relative_attention_num_buckets,
@@ -327,9 +331,9 @@ impl T5Attention {
         };
         let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
         let kv_len = kv_input.dim(1)?;
-        let q = self.q.forward(xs)?;
-        let k = self.k.forward(kv_input)?;
-        let v = self.v.forward(kv_input)?;
+        let q = self.q.forward_autocast(xs)?;
+        let k = self.k.forward_autocast(kv_input)?;
+        let v = self.v.forward_autocast(kv_input)?;
         let q = q
             .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
@@ -421,7 +425,7 @@ impl T5Attention {
         let attn_output = attn_output
             .transpose(1, 2)?
             .reshape((b_sz, q_len, self.inner_dim))?;
-        let attn_output = self.o.forward(&attn_output)?;
+        let attn_output = self.o.forward_autocast(&attn_output)?;
         Ok((attn_output, position_bias))
     }
 }

From 1c0e75ff33b7ac1d0a060ac542365441d0c7613e Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Fri, 13 Dec 2024 07:13:29 -0500
Subject: [PATCH 4/4] Clippy

---
 diffusers_backend/src/unquantized/mod.rs   |  2 +-
 diffusers_common/src/model_source.rs       | 13 ++++--
 diffusers_common/src/safetensors.rs        |  2 +-
 diffusers_common/src/varbuilder_loading.rs |  2 +-
 diffusers_core/src/models/t5/mod.rs        | 53 ++++------------------
 5 files changed, 20 insertions(+), 52 deletions(-)

diff --git a/diffusers_backend/src/unquantized/mod.rs b/diffusers_backend/src/unquantized/mod.rs
index e469c65..8659a32 100644
--- a/diffusers_backend/src/unquantized/mod.rs
+++ b/diffusers_backend/src/unquantized/mod.rs
@@ -54,7 +54,7 @@ impl QuantMethod for UnquantLinear {
             {
                 cublaslt
                     .batch_matmul(
-                        &a,
+                        a,
                         &w,
                         Some(&b.t()?.contiguous()?),
                         None,
diff --git a/diffusers_common/src/model_source.rs b/diffusers_common/src/model_source.rs
index d8b574a..1ab55fa 100644
--- a/diffusers_common/src/model_source.rs
+++ b/diffusers_common/src/model_source.rs
@@ -48,8 +48,11 @@ impl ModelSource {
 }
 
 pub enum FileLoader {
-    Api(ApiRepo),
-    ApiWithTransformer { base: ApiRepo, transformer: ApiRepo },
+    Api(Box<ApiRepo>),
+    ApiWithTransformer {
+        base: Box<ApiRepo>,
+        transformer: Box<ApiRepo>,
+    },
     Dduf(ZipArchive<Cursor<Mmap>>),
 }
 
@@ -73,7 +76,7 @@ impl FileLoader {
                     revision.clone(),
                 ));
 
-                Ok(Self::Api(api))
+                Ok(Self::Api(Box::new(api)))
             }
             ModelSource::Dduf { file } => {
                 let mmap = unsafe { Mmap::map(&file)? };
@@ -101,8 +104,8 @@ impl FileLoader {
                 ));
 
                 Ok(Self::ApiWithTransformer {
-                    base: api,
-                    transformer: transformer_api,
+                    base: Box::new(api),
+                    transformer: Box::new(transformer_api),
                 })
             }
         }
diff --git a/diffusers_common/src/safetensors.rs b/diffusers_common/src/safetensors.rs
index 39e3f57..4889997 100644
--- a/diffusers_common/src/safetensors.rs
+++ b/diffusers_common/src/safetensors.rs
@@ -9,7 +9,7 @@ pub struct BytesSafetensors<'a> {
 
 impl<'a> BytesSafetensors<'a> {
     pub fn new(bytes: &'a [u8]) -> Result<BytesSafetensors<'a>> {
-        let st = safetensors::SafeTensors::deserialize(bytes).map_err(|e| Error::from(e))?;
+        let st = safetensors::SafeTensors::deserialize(bytes).map_err(Error::from)?;
         Ok(Self { safetensors: st })
     }
 
diff --git a/diffusers_common/src/varbuilder_loading.rs b/diffusers_common/src/varbuilder_loading.rs
index 51b3685..4be081c 100644
--- a/diffusers_common/src/varbuilder_loading.rs
+++ b/diffusers_common/src/varbuilder_loading.rs
@@ -41,7 +41,7 @@ impl TensorLoaderBackend for SafetensorBackend {
 
 struct BytesSafetensorBackend<'a>(BytesSafetensors<'a>);
 
-impl<'a> TensorLoaderBackend for BytesSafetensorBackend<'a> {
+impl TensorLoaderBackend for BytesSafetensorBackend<'_> {
     fn get_names(&self) -> Vec<String> {
         self.0
             .tensors()
diff --git a/diffusers_core/src/models/t5/mod.rs b/diffusers_core/src/models/t5/mod.rs
index e6a9f6a..cedf6b0 100644
--- a/diffusers_core/src/models/t5/mod.rs
+++ b/diffusers_core/src/models/t5/mod.rs
@@ -126,18 +126,8 @@ struct T5DenseActDense {
 
 impl T5DenseActDense {
     fn load(vb: VarBuilder, cfg: &T5Config) -> Result<Self> {
-        let wi = linear_no_bias(
-            cfg.d_model,
-            cfg.d_ff,
-            &&cfg.quantization_config,
-            vb.pp("wi"),
-        )?;
-        let wo = linear_no_bias(
-            cfg.d_ff,
-            cfg.d_model,
-            &&cfg.quantization_config,
-            vb.pp("wo"),
-        )?;
+        let wi = linear_no_bias(cfg.d_model, cfg.d_ff, &cfg.quantization_config, vb.pp("wi"))?;
+        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, &cfg.quantization_config, vb.pp("wo"))?;
         Ok(Self {
             wi,
             wo,
@@ -168,21 +158,16 @@ impl T5DenseGatedActDense {
         let wi_0 = linear_no_bias(
             cfg.d_model,
             cfg.d_ff,
-            &&cfg.quantization_config,
+            &cfg.quantization_config,
             vb.pp("wi_0"),
         )?;
         let wi_1 = linear_no_bias(
             cfg.d_model,
             cfg.d_ff,
-            &&cfg.quantization_config,
+            &cfg.quantization_config,
             vb.pp("wi_1"),
         )?;
-        let wo = linear_no_bias(
-            cfg.d_ff,
-            cfg.d_model,
-            &&cfg.quantization_config,
-            vb.pp("wo"),
-        )?;
+        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, &cfg.quantization_config, vb.pp("wo"))?;
         Ok(Self {
             wi_0,
             wi_1,
@@ -267,30 +252,10 @@ impl T5Attention {
         cfg: &T5Config,
     ) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
-        let q = linear_no_bias(
-            cfg.d_model,
-            inner_dim,
-            &&cfg.quantization_config,
-            vb.pp("q"),
-        )?;
-        let k = linear_no_bias(
-            cfg.d_model,
-            inner_dim,
-            &&cfg.quantization_config,
-            vb.pp("k"),
-        )?;
-        let v = linear_no_bias(
-            cfg.d_model,
-            inner_dim,
-            &&cfg.quantization_config,
-            vb.pp("v"),
-        )?;
-        let o = linear_no_bias(
-            inner_dim,
-            cfg.d_model,
-            &&cfg.quantization_config,
-            vb.pp("o"),
-        )?;
+        let q = linear_no_bias(cfg.d_model, inner_dim, &cfg.quantization_config, vb.pp("q"))?;
+        let k = linear_no_bias(cfg.d_model, inner_dim, &cfg.quantization_config, vb.pp("k"))?;
+        let v = linear_no_bias(cfg.d_model, inner_dim, &cfg.quantization_config, vb.pp("v"))?;
+        let o = linear_no_bias(inner_dim, cfg.d_model, &cfg.quantization_config, vb.pp("o"))?;
         let relative_attention_bias = if has_relative_attention_bias {
             let emb = embedding(
                 cfg.relative_attention_num_buckets,