
Commit 6d5970e

henrylhtsang authored and facebook-github-bot committed
Refactor passing over cache params (pytorch#2155)
Summary:
Pull Request resolved: pytorch#2155

Refactor how cache params are passed from the dataclass into the fused_params dict.

Motivation: I am trying to add KeyValueParams.

Differential Revision: D58886177
1 parent a117ae2 commit 6d5970e

File tree

1 file changed: +39 −37 lines


torchrec/distributed/utils.py (+39 −37)
@@ -13,6 +13,7 @@
 import sys
 
 from collections import OrderedDict
+from dataclasses import asdict
 from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union
 
 import torch
@@ -377,45 +378,46 @@ def add_params_from_parameter_sharding(
     # update fused_params using params from parameter_sharding
     # this will take precidence over the fused_params provided from sharders
     if parameter_sharding.cache_params is not None:
-        cache_params = parameter_sharding.cache_params
-        if cache_params.algorithm is not None:
-            fused_params["cache_algorithm"] = cache_params.algorithm
-        if cache_params.load_factor is not None:
-            fused_params["cache_load_factor"] = cache_params.load_factor
-        if cache_params.reserved_memory is not None:
-            fused_params["cache_reserved_memory"] = cache_params.reserved_memory
-        if cache_params.precision is not None:
-            fused_params["cache_precision"] = cache_params.precision
-        if cache_params.prefetch_pipeline is not None:
-            fused_params["prefetch_pipeline"] = cache_params.prefetch_pipeline
-        if cache_params.multipass_prefetch_config is not None:
-            fused_params["multipass_prefetch_config"] = (
-                cache_params.multipass_prefetch_config
-            )
-
-    if parameter_sharding.enforce_hbm is not None:
-        fused_params["enforce_hbm"] = parameter_sharding.enforce_hbm
-
-    if parameter_sharding.stochastic_rounding is not None:
-        fused_params["stochastic_rounding"] = parameter_sharding.stochastic_rounding
-
-    if parameter_sharding.bounds_check_mode is not None:
-        fused_params["bounds_check_mode"] = parameter_sharding.bounds_check_mode
-
-    if parameter_sharding.output_dtype is not None:
-        fused_params["output_dtype"] = parameter_sharding.output_dtype
+        cache_params_dict = asdict(parameter_sharding.cache_params)
+
+        def _add_cache_prefix(key: str) -> str:
+            if key in {"algorithm", "load_factor", "reserved_memory", "precision"}:
+                return f"cache_{key}"
+            return key
+
+        cache_params_dict = {
+            _add_cache_prefix(k): v
+            for k, v in cache_params_dict.items()
+            if v is not None and k not in {"stats"}
+        }
+        fused_params.update(cache_params_dict)
+
+    parameter_sharding_dict = asdict(parameter_sharding)
+    params_to_fused_tbe: Set[str] = {
+        "enforce_hbm",
+        "stochastic_rounding",
+        "bounds_check_mode",
+        "output_dtype",
+    }
+    parameter_sharding_dict = {
+        k: v
+        for k, v in parameter_sharding_dict.items()
+        if v is not None and k in params_to_fused_tbe
+    }
+    fused_params.update(parameter_sharding_dict)
 
     # print warning if sharding_type is data_parallel or kernel is dense
-    if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
-        logger.warning(
-            f"Sharding Type is {parameter_sharding.sharding_type}, "
-            "caching params will be ignored"
-        )
-    elif parameter_sharding.compute_kernel == EmbeddingComputeKernel.DENSE.value:
-        logger.warning(
-            f"Compute Kernel is {parameter_sharding.compute_kernel}, "
-            "caching params will be ignored"
-        )
+    if parameter_sharding.cache_params is not None:
+        if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
+            logger.warning(
+                f"Sharding Type is {parameter_sharding.sharding_type}, "
+                "caching params will be ignored"
+            )
+        elif parameter_sharding.compute_kernel == EmbeddingComputeKernel.DENSE.value:
+            logger.warning(
+                f"Compute Kernel is {parameter_sharding.compute_kernel}, "
+                "caching params will be ignored"
+            )
 
     return fused_params
 
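For readers skimming the diff, here is a minimal, self-contained sketch of the dict-based mapping this commit switches to: convert the dataclass to a plain dict with asdict, prefix the cache-specific keys, drop None values and the stats field, and merge the result into fused_params. The CacheParams and ParameterSharding dataclasses below are simplified stand-ins with only the fields touched by this diff, not the real TorchRec definitions, and the function mirrors only the mapping logic shown above.

    # Sketch only: CacheParams/ParameterSharding here are simplified stand-ins,
    # not the TorchRec dataclasses.
    from dataclasses import asdict, dataclass
    from typing import Any, Dict, Optional, Set

    @dataclass
    class CacheParams:
        algorithm: Optional[str] = None
        load_factor: Optional[float] = None
        reserved_memory: Optional[float] = None
        precision: Optional[str] = None
        prefetch_pipeline: Optional[bool] = None
        multipass_prefetch_config: Optional[Any] = None
        stats: Optional[Any] = None  # deliberately excluded from fused_params

    @dataclass
    class ParameterSharding:
        cache_params: Optional[CacheParams] = None
        enforce_hbm: Optional[bool] = None
        stochastic_rounding: Optional[bool] = None
        bounds_check_mode: Optional[Any] = None
        output_dtype: Optional[Any] = None

    def add_params_from_parameter_sharding(
        fused_params: Optional[Dict[str, Any]],
        parameter_sharding: ParameterSharding,
    ) -> Dict[str, Any]:
        fused_params = fused_params or {}

        if parameter_sharding.cache_params is not None:
            cache_params_dict = asdict(parameter_sharding.cache_params)

            def _add_cache_prefix(key: str) -> str:
                # Only these fields are exposed under a "cache_" prefix.
                if key in {"algorithm", "load_factor", "reserved_memory", "precision"}:
                    return f"cache_{key}"
                return key

            # Keep non-None fields, rename where needed, skip "stats".
            fused_params.update(
                {
                    _add_cache_prefix(k): v
                    for k, v in cache_params_dict.items()
                    if v is not None and k not in {"stats"}
                }
            )

        # Top-level ParameterSharding fields forwarded to the fused TBE.
        params_to_fused_tbe: Set[str] = {
            "enforce_hbm",
            "stochastic_rounding",
            "bounds_check_mode",
            "output_dtype",
        }
        parameter_sharding_dict = asdict(parameter_sharding)
        fused_params.update(
            {
                k: v
                for k, v in parameter_sharding_dict.items()
                if v is not None and k in params_to_fused_tbe
            }
        )
        return fused_params

    if __name__ == "__main__":
        sharding = ParameterSharding(
            cache_params=CacheParams(algorithm="lru", load_factor=0.2),
            enforce_hbm=True,
        )
        print(add_params_from_parameter_sharding({}, sharding))
        # {'cache_algorithm': 'lru', 'cache_load_factor': 0.2, 'enforce_hbm': True}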

0 commit comments
