Skip to content

Commit e68c860

Browse files
henrylhtsang authored and facebook-github-bot committed
Refactor passing over cache params
Summary: Refactor the passing over cache params from dataclass to fused_params dict a bit. Differential Revision: D58886177
1 parent a117ae2 commit e68c860

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

torchrec/distributed/utils.py

+25-25
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sys
1414

1515
from collections import OrderedDict
16+
from dataclasses import asdict
1617
from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union
1718

1819
import torch
@@ -377,21 +378,19 @@ def add_params_from_parameter_sharding(
377378
# update fused_params using params from parameter_sharding
378379
# this will take precidence over the fused_params provided from sharders
379380
if parameter_sharding.cache_params is not None:
380-
cache_params = parameter_sharding.cache_params
381-
if cache_params.algorithm is not None:
382-
fused_params["cache_algorithm"] = cache_params.algorithm
383-
if cache_params.load_factor is not None:
384-
fused_params["cache_load_factor"] = cache_params.load_factor
385-
if cache_params.reserved_memory is not None:
386-
fused_params["cache_reserved_memory"] = cache_params.reserved_memory
387-
if cache_params.precision is not None:
388-
fused_params["cache_precision"] = cache_params.precision
389-
if cache_params.prefetch_pipeline is not None:
390-
fused_params["prefetch_pipeline"] = cache_params.prefetch_pipeline
391-
if cache_params.multipass_prefetch_config is not None:
392-
fused_params["multipass_prefetch_config"] = (
393-
cache_params.multipass_prefetch_config
394-
)
381+
cache_params_dict = asdict(parameter_sharding.cache_params)
382+
383+
def _add_cache_prefix(key: str) -> str:
384+
if key in {"algorithm", "load_factor", "reserved_memory", "precision"}:
385+
return f"cache_{key}"
386+
return key
387+
388+
cache_params_dict = {
389+
_add_cache_prefix(k): v
390+
for k, v in cache_params_dict.items()
391+
if v is not None
392+
}
393+
fused_params.update(cache_params_dict)
395394

396395
if parameter_sharding.enforce_hbm is not None:
397396
fused_params["enforce_hbm"] = parameter_sharding.enforce_hbm
@@ -406,16 +405,17 @@ def add_params_from_parameter_sharding(
406405
fused_params["output_dtype"] = parameter_sharding.output_dtype
407406

408407
# print warning if sharding_type is data_parallel or kernel is dense
409-
if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
410-
logger.warning(
411-
f"Sharding Type is {parameter_sharding.sharding_type}, "
412-
"caching params will be ignored"
413-
)
414-
elif parameter_sharding.compute_kernel == EmbeddingComputeKernel.DENSE.value:
415-
logger.warning(
416-
f"Compute Kernel is {parameter_sharding.compute_kernel}, "
417-
"caching params will be ignored"
418-
)
408+
if parameter_sharding.cache_params is not None:
409+
if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
410+
logger.warning(
411+
f"Sharding Type is {parameter_sharding.sharding_type}, "
412+
"caching params will be ignored"
413+
)
414+
elif parameter_sharding.compute_kernel == EmbeddingComputeKernel.DENSE.value:
415+
logger.warning(
416+
f"Compute Kernel is {parameter_sharding.compute_kernel}, "
417+
"caching params will be ignored"
418+
)
419419

420420
return fused_params
421421

0 commit comments

Comments (0)