13
13
import sys
14
14
15
15
from collections import OrderedDict
16
+ from dataclasses import asdict
16
17
from typing import Any , Dict , List , Optional , Set , Type , TypeVar , Union
17
18
18
19
import torch
@@ -377,21 +378,19 @@ def add_params_from_parameter_sharding(
377
378
# update fused_params using params from parameter_sharding
378
379
# this will take precidence over the fused_params provided from sharders
379
380
if parameter_sharding .cache_params is not None :
380
- cache_params = parameter_sharding .cache_params
381
- if cache_params .algorithm is not None :
382
- fused_params ["cache_algorithm" ] = cache_params .algorithm
383
- if cache_params .load_factor is not None :
384
- fused_params ["cache_load_factor" ] = cache_params .load_factor
385
- if cache_params .reserved_memory is not None :
386
- fused_params ["cache_reserved_memory" ] = cache_params .reserved_memory
387
- if cache_params .precision is not None :
388
- fused_params ["cache_precision" ] = cache_params .precision
389
- if cache_params .prefetch_pipeline is not None :
390
- fused_params ["prefetch_pipeline" ] = cache_params .prefetch_pipeline
391
- if cache_params .multipass_prefetch_config is not None :
392
- fused_params ["multipass_prefetch_config" ] = (
393
- cache_params .multipass_prefetch_config
394
- )
381
+ cache_params_dict = asdict (parameter_sharding .cache_params )
382
+
383
+ def _add_cache_prefix (key : str ) -> str :
384
+ if key in {"algorithm" , "load_factor" , "reserved_memory" , "precision" }:
385
+ return f"cache_{ key } "
386
+ return key
387
+
388
+ cache_params_dict = {
389
+ _add_cache_prefix (k ): v
390
+ for k , v in cache_params_dict .items ()
391
+ if v is not None
392
+ }
393
+ fused_params .update (cache_params_dict )
395
394
396
395
if parameter_sharding .enforce_hbm is not None :
397
396
fused_params ["enforce_hbm" ] = parameter_sharding .enforce_hbm
@@ -406,16 +405,17 @@ def add_params_from_parameter_sharding(
406
405
fused_params ["output_dtype" ] = parameter_sharding .output_dtype
407
406
408
407
# print warning if sharding_type is data_parallel or kernel is dense
409
- if parameter_sharding .sharding_type == ShardingType .DATA_PARALLEL .value :
410
- logger .warning (
411
- f"Sharding Type is { parameter_sharding .sharding_type } , "
412
- "caching params will be ignored"
413
- )
414
- elif parameter_sharding .compute_kernel == EmbeddingComputeKernel .DENSE .value :
415
- logger .warning (
416
- f"Compute Kernel is { parameter_sharding .compute_kernel } , "
417
- "caching params will be ignored"
418
- )
408
+ if parameter_sharding .cache_params is not None :
409
+ if parameter_sharding .sharding_type == ShardingType .DATA_PARALLEL .value :
410
+ logger .warning (
411
+ f"Sharding Type is { parameter_sharding .sharding_type } , "
412
+ "caching params will be ignored"
413
+ )
414
+ elif parameter_sharding .compute_kernel == EmbeddingComputeKernel .DENSE .value :
415
+ logger .warning (
416
+ f"Compute Kernel is { parameter_sharding .compute_kernel } , "
417
+ "caching params will be ignored"
418
+ )
419
419
420
420
return fused_params
421
421
0 commit comments