@@ -70,6 +70,7 @@ class Fp8EinsumOp(base_layer.BaseLayer):
   """Wrapper around jnp.einsum used in standard Pax layers."""

   amax_history_length: int = 1024
+  use_direct_quant: bool = True

   def setup(self) -> None:
     scale_args, amax_history_args = _get_fp8_args(
@@ -130,7 +131,7 @@ def quantized_einsum(

   def __call__(
       self, equation: str, *args: JTensor
-  ) -> Union[JTensor, tuple[JTensor, JTensor]]:
+  ) -> JTensor | tuple[JTensor, JTensor]:
     assert len(args) == 2
     x = args[0]
     k = args[1]
@@ -141,11 +142,75 @@ def __call__(
     ), f'k dtype has to be {comp_dtype}, but got {k.dtype}'
     x = jnp.asarray(x, comp_dtype)

-    y = self.quantized_einsum(equation, x, k, return_quantized_x=False)
+    if self.use_direct_quant:
+
+      def _quantized_dot_general(
+          lhs,
+          rhs,
+          dimension_numbers,
+          precision=None,
+          preferred_element_type=None,
+      ):
+        theta = self.theta
+        return fp8_ops.q_dot_dq(
+            lhs,
+            rhs,
+            lhs_scale=theta.input_scale,
+            rhs_scale=theta.kernel_scale,
+            out_grad_scale=theta.output_grad_scale,
+            lhs_amax_history=theta.input_amax_history,
+            rhs_amax_history=theta.kernel_amax_history,
+            out_grad_amax_history=theta.output_grad_amax_history,
+            compute_dtype=comp_dtype,
+            dimension_numbers=dimension_numbers,
+            precision=precision,
+            preferred_element_type=preferred_element_type,
+        )
+
+      y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general)
+    else:
+      y = self.quantized_einsum(equation, x, k, return_quantized_x=False)

     return y


+# This is a decorator factory: it returns a decorator whose wrapper replaces
+# the decorated function with a call to fp8_ops.quantized_dot, supplying the
+# pre-quantized input, scales, and amax histories. This allows efficient FP8
+# matrix multiplication while managing the quantization parameters.
+def quantized_dot_config(
+    compute_dtype,
+    q_lhs,
+    lhs_scale,
+    q_rhs,
+    rhs_scale,
+    out_grad_scale,
+    out_grad_amax_history,
+):
+  def decorator(func):
+    def wrapper(
+        lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None
+    ):
+      return fp8_ops.quantized_dot(
+          lhs=lhs,
+          q_lhs=q_lhs,
+          lhs_scale=lhs_scale,
+          rhs=rhs,
+          q_rhs=q_rhs,
+          rhs_scale=rhs_scale,
+          out_grad_scale=out_grad_scale,
+          out_grad_amax_history=out_grad_amax_history,
+          compute_dtype=compute_dtype,
+          dimension_numbers=dimension_numbers,
+          precision=precision,
+          preferred_element_type=preferred_element_type,
+      )
+
+    return wrapper
+
+  return decorator
+
+
 class Fp8EinsumGatedOp(Fp8EinsumOp):
   """Wrapper around two jnp.einsum for gated FFN."""

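For reference, a minimal sketch of how the quantized_dot_config factory above is meant to be used (it mirrors its use in Fp8EinsumGatedOp.__call__ below): it builds a _dot_general replacement that jnp.einsum calls instead of lax.dot_general, so the contraction is routed through fp8_ops.quantized_dot with the captured FP8 metadata. The names q_x, x_scale, q_k, k_scale, og_scale, and og_amax stand in for the pre-quantized operands and metadata computed earlier in __call__; they are placeholders for this illustration, not part of the change.

# Illustration only: configure a _dot_general that dispatches to
# fp8_ops.quantized_dot with pre-quantized operands and FP8 metadata.
@quantized_dot_config(comp_dtype, q_x, x_scale, q_k, k_scale, og_scale, og_amax)
def my_dot_general(lhs, rhs, dimension_numbers, precision=None,
                   preferred_element_type=None):
  pass  # The body is never executed; the wrapper ignores the decorated function.

# jnp.einsum hands its contraction to the configured wrapper.
y = jnp.einsum('...d,df->...f', x, k, _dot_general=my_dot_general)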
@@ -181,29 +246,102 @@ def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]:
     ), f'k dtype has to be {comp_dtype}, but got {k.dtype} and {k_gated.dtype}'
     x = jnp.asarray(x, comp_dtype)

-    y, x_qdq = self.quantized_einsum(equation, x, k, return_quantized_x=True)
-
     theta = self.theta

-    k_gated_qdq = fp8_ops.in_qdq(
-        comp_dtype,
-        jnp.float8_e4m3fn,
-        k_gated,
-        theta.kernel_scale_gated,
-        theta.kernel_amax_history_gated,
-    )
-    y_gated_qdq = jnp.einsum(
-        equation,
-        x_qdq,
-        k_gated_qdq,
-        _dot_general=fp8_ops.dot_general_with_precision,
-    )
-    y_gated = fp8_ops.out_qdq(
-        comp_dtype,
-        jnp.float8_e5m2,
-        y_gated_qdq,
-        theta.output_grad_scale_gated,
-        theta.output_grad_amax_history_gated,
-    )
+    if self.use_direct_quant:
+      q_x, new_input_scale = fp8_ops.in_q(
+          comp_dtype,
+          jnp.float8_e4m3fn,
+          x,
+          theta.input_scale,
+          theta.input_amax_history,
+      )
+      q_k, new_kernel_scale = fp8_ops.in_q(
+          comp_dtype,
+          jnp.float8_e4m3fn,
+          k,
+          theta.kernel_scale,
+          theta.kernel_amax_history,
+      )
+      q_k_gated, new_kernel_scale_gated = fp8_ops.in_q(
+          comp_dtype,
+          jnp.float8_e4m3fn,
+          k_gated,
+          theta.kernel_scale_gated,
+          theta.kernel_amax_history_gated,
+      )
+      common_args = (comp_dtype, q_x, new_input_scale)
+      main_fp8_metas = (
+          q_k,
+          new_kernel_scale,
+          theta.output_grad_scale,
+          theta.output_grad_amax_history,
+      )
+      gated_fp8_metas = (
+          q_k_gated,
+          new_kernel_scale_gated,
+          theta.output_grad_scale_gated,
+          theta.output_grad_amax_history_gated,
+      )
+
+      @quantized_dot_config(*common_args, *main_fp8_metas)
+      def _quantized_dot_general(
+          lhs,
+          rhs,
+          dimension_numbers,
+          precision=None,
+          preferred_element_type=None,
+      ):
+        pass
+
+      @quantized_dot_config(*common_args, *gated_fp8_metas)
+      def _quantized_dot_general_gated(
+          lhs,
+          rhs,
+          dimension_numbers,
+          precision=None,
+          preferred_element_type=None,
+      ):
+        pass
+
+      y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general)
+      y_gated = jnp.einsum(
+          equation, x, k_gated, _dot_general=_quantized_dot_general_gated
+      )
+
+      y = fp8_ops.out_dq(
+          dq_type=x.dtype,
+          lhs_scale=new_input_scale,
+          rhs_scale=new_kernel_scale,
+          out=y,
+      )
+      y_gated = fp8_ops.out_dq(
+          dq_type=x.dtype,
+          lhs_scale=new_input_scale,
+          rhs_scale=new_kernel_scale_gated,
+          out=y_gated,
+      )
+    else:
+      y, x_qdq = self.quantized_einsum(equation, x, k, return_quantized_x=True)
+      k_gated_qdq = fp8_ops.in_qdq(
+          comp_dtype,
+          jnp.float8_e4m3fn,
+          k_gated,
+          theta.kernel_scale_gated,
+          theta.kernel_amax_history_gated,
+      )
+      y_gated_qdq = jnp.einsum(
+          equation,
+          x_qdq,
+          k_gated_qdq,
+          _dot_general=fp8_ops.dot_general_with_precision,
+      )
+      y_gated = fp8_ops.out_qdq(
+          comp_dtype,
+          jnp.float8_e5m2,
+          y_gated_qdq,
+          theta.output_grad_scale_gated,
+          theta.output_grad_amax_history_gated,
+      )

     return y, y_gated
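For context, a sketch of how this op is typically wired into a Pax/Praxis model, assuming the usual pattern where a layer exposes an einsum template field (for example einsum_tpl on praxis.layers.linears.Linear); the module path and field names here are illustrative assumptions, not part of this change. Setting use_direct_quant=True selects the new direct-quantization path (q_dot_dq / quantized_dot) added above, while False keeps the previous fake-quantization (qdq) path.

from praxis import pax_fiddle
from praxis.layers import linears
# Assumed module path for the file modified in this diff.
from praxis.layers.injection import fp8_nvidia_gpu as fp8_injection

# Route a Linear layer's einsum through the FP8 op.
linear_p = pax_fiddle.Config(
    linears.Linear, name='ffn', input_dims=1024, output_dims=4096
)
linear_p.einsum_tpl = pax_fiddle.Config(
    fp8_injection.Fp8EinsumOp,
    amax_history_length=1024,
    use_direct_quant=True,  # False falls back to the qdq path.
)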