@@ -535,24 +535,32 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const
535
535
shader.AddInput (" input_a" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
536
536
shader.AddOutput (" output" , ShaderUsage::UseUniform);
537
537
shader.AddOutput (" scales" , ShaderUsage::UseUniform);
538
-
538
+ shader.AdditionalImplementation () << R"ADDNL_FN(
539
+ fn readInput(offset: u32) -> input_a_value_t
540
+ {
541
+ if (offset > uniforms.input_size) {
542
+ return input_a_value_t(0);
543
+ }
544
+ return input_a[offset];
545
+ }
546
+ )ADDNL_FN" ;
539
547
shader.MainFunctionBody () << R"MAIN_FN(
540
548
var local_a : array<vec4<input_a_element_t>, 32>;
541
549
var max_value:vec4<input_a_element_t> = vec4<input_a_element_t>(0);
542
550
for (var idx:u32=0;idx<32;idx+=1)
543
551
{
544
- local_a[idx] = input_a[workgroup_id.x *32 + idx] ;
552
+ local_a[idx] = readInput(workgroup_idx *32 + idx) ;
545
553
max_value = max(max_value, abs(local_a[idx]));
546
554
}
547
555
var scale = max(max_value.x, max_value.y);
548
556
scale = max(scale, max_value.z);
549
557
scale = max(scale, max_value.w);
550
558
for (var idx:u32=0;idx<32;idx+=1)
551
559
{
552
- output[workgroup_id.x *32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
560
+ output[workgroup_idx *32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
553
561
}
554
562
// 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
555
- scales[workgroup_id.x ] = scale/127;
563
+ scales[workgroup_idx ] = scale/127;
556
564
)MAIN_FN" ;
557
565
return Status::OK ();
558
566
}
@@ -828,7 +836,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
828
836
Tensor a_scale = context.CreateGPUTensor (a->DataType (), a_scales_dims);
829
837
quantize_program.AddInputs ({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int >(kVec4Components )}})
830
838
.AddOutputs ({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape (), gsl::narrow<int >(1 )},
831
- {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape (), gsl::narrow<int >(1 )}});
839
+ {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape (), gsl::narrow<int >(1 )}})
840
+ .AddUniformVariable ({static_cast <uint32_t >(M * K / kVec4Components )});
832
841
ORT_RETURN_IF_ERROR (context.RunProgram (quantize_program));
833
842
834
843
constexpr uint32_t kTileSize = 64 ;
0 commit comments