
Commit 002af6e

qjia7 authored and guschmue committed
[webgpu] Use workgroup_idx instead of workgroup_id.x (#23696)
We should always use workgroup_idx instead of workgroup_id.x because the dispatched workgroups are normalized. When the input is large enough, the 1D workgroups are normalized to a 2D/3D grid, and indexing with workgroup_id.x alone produces incorrect results.
1 parent 612e42a commit 002af6e
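
For context, a minimal WGSL sketch (not part of this commit, and not necessarily how ORT's ShaderHelper generates workgroup_idx) of why a flattened index is needed once a 1D dispatch has been normalized to a 2D/3D grid:

    @compute @workgroup_size(64)
    fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
            @builtin(num_workgroups) num_workgroups : vec3<u32>) {
      // Hypothetical illustration: after normalization, a logically 1D dispatch
      // arrives as a 2D/3D grid, so workgroup_id.x only covers the first
      // dimension. A linear index must fold in y and z as well.
      let flat_workgroup_idx = workgroup_id.z * num_workgroups.x * num_workgroups.y +
                               workgroup_id.y * num_workgroups.x +
                               workgroup_id.x;
      // flat_workgroup_idx plays the role that workgroup_idx plays in this diff.
    }

Using the helper-provided workgroup_idx keeps the shader correct regardless of how the dispatch is reshaped.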


2 files changed: +15 -5 lines changed


onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

+14 -5
@@ -535,24 +535,32 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const
   shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   shader.AddOutput("output", ShaderUsage::UseUniform);
   shader.AddOutput("scales", ShaderUsage::UseUniform);
-
+  shader.AdditionalImplementation() << R"ADDNL_FN(
+  fn readInput(offset: u32) -> input_a_value_t
+  {
+    if (offset > uniforms.input_size) {
+      return input_a_value_t(0);
+    }
+    return input_a[offset];
+  }
+)ADDNL_FN";
   shader.MainFunctionBody() << R"MAIN_FN(
   var local_a : array<vec4<input_a_element_t>, 32>;
   var max_value:vec4<input_a_element_t> = vec4<input_a_element_t>(0);
   for (var idx:u32=0;idx<32;idx+=1)
   {
-    local_a[idx] = input_a[workgroup_id.x*32 + idx];
+    local_a[idx] = readInput(workgroup_idx*32 + idx);
     max_value = max(max_value, abs(local_a[idx]));
   }
   var scale = max(max_value.x, max_value.y);
   scale = max(scale, max_value.z);
   scale = max(scale, max_value.w);
   for (var idx:u32=0;idx<32;idx+=1)
   {
-    output[workgroup_id.x*32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
+    output[workgroup_idx*32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
   }
   // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
-  scales[workgroup_id.x] = scale/127;
+  scales[workgroup_idx] = scale/127;
 )MAIN_FN";
   return Status::OK();
 }
@@ -828,7 +836,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
   Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims);
   quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)}})
       .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow<int>(1)},
-                   {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow<int>(1)}});
+                   {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow<int>(1)}})
+      .AddUniformVariable({static_cast<uint32_t>(M * K / kVec4Components)});
   ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));

   constexpr uint32_t kTileSize = 64;

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h

+1
@@ -39,6 +39,7 @@ class DP4AMatMulQuantizeProgram final : public Program<DP4AMatMulQuantizeProgram
  public:
   DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {}
   Status GenerateShaderCode(ShaderHelper& sh) const override;
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32});
 };

 class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {

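Note on the new input_size uniform: when a large 1D dispatch is normalized to a 2D/3D grid (for example, to stay under WebGPU's default limit of 65,535 workgroups per dimension), the total number of workgroups may be rounded up past the exact count needed, so the largest workgroup_idx values can address elements past the end of input_a; the readInput helper returns zero for such offsets instead of reading out of bounds. As a purely illustrative example, if 70,000 workgroups were required but the grid were padded to 265 x 265 = 70,225, the last 225 workgroups would otherwise read past the buffer.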