Commit 619fc34

Merge pull request #69 from wenscarl:fp8_direct_quant

Committed by pax authors
PiperOrigin-RevId: 682969173
2 parents: d7ad153 + ce9243b

File tree

3 files changed: +170 −26 lines

praxis/layers/injection/BUILD
praxis/layers/injection/fp8_nvidia_gpu.py
praxis/layers/injection/fp8_nvidia_gpu_test.py


praxis/layers/injection/BUILD (+1)

@@ -42,6 +42,7 @@ pytype_strict_test(
     deps = [
         ":fp8_nvidia_gpu",
         # Implicit absl.testing.absltest dependency.
+        # Implicit absl.testing.parameterized dependency.
         # Implicit flax.core dependency.
         # Implicit upb python proto dependency.
         # Implicit jax dependency.

praxis/layers/injection/fp8_nvidia_gpu.py (+162 −24)

@@ -70,6 +70,7 @@ class Fp8EinsumOp(base_layer.BaseLayer):
   """Wrapper around jnp.einsum used in standard Pax layers."""

   amax_history_length: int = 1024
+  use_direct_quant: bool = True

   def setup(self) -> None:
     scale_args, amax_history_args = _get_fp8_args(
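The new `use_direct_quant` attribute defaults to True, so existing configs pick up the direct-quantization path unless they opt out. A minimal sketch of toggling it through pax_fiddle, mirroring what the updated test below does (the `fp8_ops` alias for this module is assumed from the test file's imports):

```python
from praxis import pax_fiddle
from praxis.layers.injection import fp8_nvidia_gpu as fp8_ops

# New default: quantize operands once and contract directly in FP8.
direct_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp, use_direct_quant=True)

# Previous quantize-dequantize (QDQ) behavior, kept as a fallback.
qdq_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp, use_direct_quant=False)
```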
@@ -130,7 +131,7 @@ def quantized_einsum(

   def __call__(
       self, equation: str, *args: JTensor
-  ) -> Union[JTensor, tuple[JTensor, JTensor]]:
+  ) -> JTensor | tuple[JTensor, JTensor]:
     assert len(args) == 2
     x = args[0]
     k = args[1]
@@ -141,11 +142,75 @@ def __call__(
     ), f'k dtype has to be {comp_dtype}, but got {k.dtype}'
     x = jnp.asarray(x, comp_dtype)

-    y = self.quantized_einsum(equation, x, k, return_quantized_x=False)
+    if self.use_direct_quant:
+
+      def _quantized_dot_general(
+          lhs,
+          rhs,
+          dimension_numbers,
+          precision=None,
+          preferred_element_type=None,
+      ):
+        theta = self.theta
+        return fp8_ops.q_dot_dq(
+            lhs,
+            rhs,
+            lhs_scale=theta.input_scale,
+            rhs_scale=theta.kernel_scale,
+            out_grad_scale=theta.output_grad_scale,
+            lhs_amax_history=theta.input_amax_history,
+            rhs_amax_history=theta.kernel_amax_history,
+            out_grad_amax_history=theta.output_grad_amax_history,
+            compute_dtype=comp_dtype,
+            dimension_numbers=dimension_numbers,
+            precision=precision,
+            preferred_element_type=preferred_element_type,
+        )
+
+      y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general)
+    else:
+      y = self.quantized_einsum(equation, x, k, return_quantized_x=False)

     return y


+# Wraps a function into a dot_general-compatible callable that performs a
+# quantized dot product. It binds the arguments for fp8_ops.quantized_dot,
+# including the pre-quantized operands, their scales, and the amax histories,
+# enabling efficient FP8 matrix multiplication while the quantization
+# parameters are managed by the caller.
+def quantized_dot_config(
+    compute_dtype,
+    q_lhs,
+    lhs_scale,
+    q_rhs,
+    rhs_scale,
+    out_grad_scale,
+    out_grad_amax_history,
+):
+  def decorator(func):
+    def wrapper(
+        lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None
+    ):
+      return fp8_ops.quantized_dot(
+          lhs=lhs,
+          q_lhs=q_lhs,
+          lhs_scale=lhs_scale,
+          rhs=rhs,
+          q_rhs=q_rhs,
+          rhs_scale=rhs_scale,
+          out_grad_scale=out_grad_scale,
+          out_grad_amax_history=out_grad_amax_history,
+          compute_dtype=compute_dtype,
+          dimension_numbers=dimension_numbers,
+          precision=precision,
+          preferred_element_type=preferred_element_type,
+      )
+
+    return wrapper
+
+  return decorator
+
+
 class Fp8EinsumGatedOp(Fp8EinsumOp):
   """Wrapper around two jnp.einsum for gated FFN."""

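Both code paths hinge on jnp.einsum's private `_dot_general` hook, which substitutes the contraction primitive while leaving the einsum bookkeeping untouched. A minimal, self-contained sketch of that injection mechanism (the pass-through `custom_dot_general` below is illustrative only; a real FP8 path would quantize before and dequantize after the contraction):

```python
import jax.numpy as jnp
from jax import lax


def custom_dot_general(lhs, rhs, dimension_numbers, precision=None,
                       preferred_element_type=None):
  # Stand-in for fp8_ops.q_dot_dq: any callable with the lax.dot_general
  # signature can be injected into jnp.einsum via `_dot_general`.
  return lax.dot_general(lhs, rhs, dimension_numbers, precision=precision,
                         preferred_element_type=preferred_element_type)


x = jnp.ones((4, 8), jnp.bfloat16)
k = jnp.ones((8, 16), jnp.bfloat16)
y = jnp.einsum('bd,df->bf', x, k, _dot_general=custom_dot_general)
```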
@@ -181,29 +246,102 @@ def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]:
     ), f'k dtype has to be {comp_dtype}, but got {k.dtype} and {k_gated.dtype}'
     x = jnp.asarray(x, comp_dtype)

-    y, x_qdq = self.quantized_einsum(equation, x, k, return_quantized_x=True)
-
     theta = self.theta

-    k_gated_qdq = fp8_ops.in_qdq(
-        comp_dtype,
-        jnp.float8_e4m3fn,
-        k_gated,
-        theta.kernel_scale_gated,
-        theta.kernel_amax_history_gated,
-    )
-    y_gated_qdq = jnp.einsum(
-        equation,
-        x_qdq,
-        k_gated_qdq,
-        _dot_general=fp8_ops.dot_general_with_precision,
-    )
-    y_gated = fp8_ops.out_qdq(
-        comp_dtype,
-        jnp.float8_e5m2,
-        y_gated_qdq,
-        theta.output_grad_scale_gated,
-        theta.output_grad_amax_history_gated,
-    )
+    if self.use_direct_quant:
+      q_x, new_input_scale = fp8_ops.in_q(
+          comp_dtype,
+          jnp.float8_e4m3fn,
+          x,
+          theta.input_scale,
+          theta.input_amax_history,
+      )
+      q_k, new_kernel_scale = fp8_ops.in_q(
+          comp_dtype,
+          jnp.float8_e4m3fn,
+          k,
+          theta.kernel_scale,
+          theta.kernel_amax_history,
+      )
+      q_k_gated, new_kernel_scale_gated = fp8_ops.in_q(
+          comp_dtype,
+          jnp.float8_e4m3fn,
+          k_gated,
+          theta.kernel_scale_gated,
+          theta.kernel_amax_history_gated,
+      )
+      common_args = (comp_dtype, q_x, new_input_scale)
+      main_fp8_metas = (
+          q_k,
+          new_kernel_scale,
+          theta.output_grad_scale,
+          theta.output_grad_amax_history,
+      )
+      gated_fp8_metas = (
+          q_k_gated,
+          new_kernel_scale_gated,
+          theta.output_grad_scale_gated,
+          theta.output_grad_amax_history_gated,
+      )
+
+      @quantized_dot_config(*common_args, *main_fp8_metas)
+      def _quantized_dot_general(
+          lhs,
+          rhs,
+          dimension_numbers,
+          precision=None,
+          preferred_element_type=None,
+      ):
+        pass
+
+      @quantized_dot_config(*common_args, *gated_fp8_metas)
+      def _quantized_dot_general_gated(
+          lhs,
+          rhs,
+          dimension_numbers,
+          precision=None,
+          preferred_element_type=None,
+      ):
+        pass
+
+      y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general)
+      y_gated = jnp.einsum(
+          equation, x, k_gated, _dot_general=_quantized_dot_general_gated
+      )
+
+      y = fp8_ops.out_dq(
+          dq_type=x.dtype,
+          lhs_scale=new_input_scale,
+          rhs_scale=new_kernel_scale,
+          out=y,
+      )
+      # Dequantize the gated branch with its own kernel scale.
+      y_gated = fp8_ops.out_dq(
+          dq_type=x.dtype,
+          lhs_scale=new_input_scale,
+          rhs_scale=new_kernel_scale_gated,
+          out=y_gated,
+      )
+    else:
+      y, x_qdq = self.quantized_einsum(equation, x, k, return_quantized_x=True)
+      k_gated_qdq = fp8_ops.in_qdq(
+          comp_dtype,
+          jnp.float8_e4m3fn,
+          k_gated,
+          theta.kernel_scale_gated,
+          theta.kernel_amax_history_gated,
+      )
+      y_gated_qdq = jnp.einsum(
+          equation,
+          x_qdq,
+          k_gated_qdq,
+          _dot_general=fp8_ops.dot_general_with_precision,
+      )
+      y_gated = fp8_ops.out_qdq(
+          comp_dtype,
+          jnp.float8_e5m2,
+          y_gated_qdq,
+          theta.output_grad_scale_gated,
+          theta.output_grad_amax_history_gated,
+      )

     return y, y_gated
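For intuition about the helpers used above: in the direct path each operand is quantized once (`in_q`), the contraction runs in FP8, and `out_dq` applies both operand scales to the accumulated output, whereas the QDQ path fake-quantizes values in the compute dtype. A toy sketch of the scaling arithmetic only, not the actual flax.linen fp8_ops implementation (names and the simplified scale handling are illustrative):

```python
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite magnitude in float8_e4m3fn


def toy_in_q(x, scale):
  # Scale into the FP8 range and cast; the real in_q also updates the
  # scale from the amax history (delayed scaling).
  q = jnp.clip(x / scale, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)
  return q, scale


def toy_out_dq(out, lhs_scale, rhs_scale):
  # (x / sx) @ (k / sk), rescaled by sx * sk, recovers x @ k up to
  # quantization error.
  return out * (lhs_scale * rhs_scale)
```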

praxis/layers/injection/fp8_nvidia_gpu_test.py (+7 −2)

@@ -18,6 +18,7 @@
 from functools import partial

 from absl.testing import absltest
+from absl.testing import parameterized
 from flax.linen.fp8_ops import qdq
 import jax
 from jax import numpy as jnp
@@ -30,9 +31,11 @@

 PARAMS = base_layer.PARAMS

+
 class Fp8LinearsTest(test_utils.TestCase):

-  def test_fp8_einsum_injection(self):
+  @parameterized.parameters([True, False])
+  def test_fp8_einsum_injection(self, use_direct_quant):
     # Used to cast the inputs to be representable in FP8, so that the difference
     # of the results from the original gemm and fp8 gemm is small.
     cast_to_representable = partial(
@@ -100,7 +103,9 @@ def _train(variables, x):
     }

     output1a, output1b = run(None, expected_shapes_original)
-    einsum_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
+    einsum_tpl = pax_fiddle.Config(
+        fp8_ops.Fp8EinsumOp, use_direct_quant=use_direct_quant
+    )
     output2a, output2b = run(einsum_tpl, expected_shapes_new)
     dw1, dw2 = output1b[0][PARAMS]['w'], output2b[0][PARAMS]['w']
     dx1, dx2 = output1b[1], output2b[1]
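`@parameterized.parameters([True, False])` makes the single test method run once per flag value, so both the direct-quant and QDQ paths stay covered. A standalone illustration of the expansion (hypothetical test, not part of this change):

```python
from absl.testing import absltest
from absl.testing import parameterized


class FlagTest(parameterized.TestCase):

  @parameterized.parameters([True, False])
  def test_flag(self, use_direct_quant):
    # Expands into test_flag0 (True) and test_flag1 (False).
    self.assertIsInstance(use_direct_quant, bool)


if __name__ == '__main__':
  absltest.main()
```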
