|
716 | 716 | <a href="#715" id="715">715</a>
|
717 | 717 | <a href="#716" id="716">716</a>
|
718 | 718 | <a href="#717" id="717">717</a>
|
719 |
| -<a href="#718" id="718">718</a></pre></div><pre class="rust"><code><span class="attr">#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] |
| 719 | +<a href="#718" id="718">718</a> |
| 720 | +<a href="#719" id="719">719</a> |
| 721 | +<a href="#720" id="720">720</a> |
| 722 | +<a href="#721" id="721">721</a> |
| 723 | +<a href="#722" id="722">722</a> |
| 724 | +<a href="#723" id="723">723</a> |
| 725 | +<a href="#724" id="724">724</a> |
| 726 | +<a href="#725" id="725">725</a></pre></div><pre class="rust"><code><span class="attr">#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] |
720 | 727 |
|
721 | 728 | </span><span class="comment">// T5 Text Model
|
722 | 729 | // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
|
725 | 732 | <span class="kw">use </span>diffusion_rs_common::core::{DType, Device, Module, <span class="prelude-ty">Result</span>, Tensor, D};
|
726 | 733 | <span class="kw">use </span>diffusion_rs_common::nn::{Activation, Embedding};
|
727 | 734 | <span class="kw">use </span>diffusion_rs_common::{embedding, VarBuilder};
|
| 735 | +<span class="kw">use </span>float8::F8E4M3; |
728 | 736 | <span class="kw">use </span>serde::Deserialize;
|
729 | 737 | <span class="kw">use </span>std::sync::Arc;
|
730 | 738 |
|
|
1197 | 1205 | <span class="kw">fn </span>any(<span class="kw-2">&</span><span class="self">self</span>) -> <span class="prelude-ty">Result</span><bool> {
|
1198 | 1206 | <span class="kw">let </span>sum = <span class="self">self</span>.sum_all()<span class="question-mark">?</span>;
|
1199 | 1207 | <span class="kw">match </span><span class="self">self</span>.dtype() {
|
| 1208 | + DType::I8 => <span class="prelude-val">Ok</span>(sum.to_scalar::<u8>()<span class="question-mark">? </span>== <span class="number">0</span>), |
| 1209 | + DType::U8 => <span class="prelude-val">Ok</span>(sum.to_scalar::<u8>()<span class="question-mark">? </span>== <span class="number">0</span>), |
| 1210 | + DType::U32 => <span class="prelude-val">Ok</span>(sum.to_scalar::<u32>()<span class="question-mark">? </span>== <span class="number">0</span>), |
| 1211 | + DType::I16 => <span class="prelude-val">Ok</span>(sum.to_scalar::<i16>()<span class="question-mark">? </span>== <span class="number">0</span>), |
| 1212 | + DType::I32 => <span class="prelude-val">Ok</span>(sum.to_scalar::<i32>()<span class="question-mark">? </span>== <span class="number">0</span>), |
| 1213 | + DType::I64 => <span class="prelude-val">Ok</span>(sum.to_scalar::<i64>()<span class="question-mark">? </span>== <span class="number">0</span>), |
1200 | 1214 | DType::F16 => <span class="prelude-val">Ok</span>(sum.to_scalar::<half::f16>()<span class="question-mark">? </span>== half::f16::from_f32_const(<span class="number">0.</span>)),
|
1201 | 1215 | DType::BF16 => <span class="prelude-val">Ok</span>(sum.to_scalar::<half::bf16>()<span class="question-mark">? </span>== half::bf16::from_f32_const(<span class="number">0.</span>)),
|
1202 | 1216 | DType::F32 => <span class="prelude-val">Ok</span>(sum.to_scalar::<f32>()<span class="question-mark">? </span>== <span class="number">0.</span>),
|
1203 | 1217 | DType::F64 => <span class="prelude-val">Ok</span>(sum.to_scalar::<f64>()<span class="question-mark">? </span>== <span class="number">0.</span>),
|
1204 |
| - <span class="kw">_ </span>=> <span class="macro">unreachable!</span>(), |
| 1218 | + DType::F8E4M3 => <span class="prelude-val">Ok</span>(sum.to_scalar::<F8E4M3>()<span class="question-mark">? </span>== F8E4M3::ZERO), |
1205 | 1219 | }
|
1206 | 1220 | }
|
1207 | 1221 | }
|
|
0 commit comments