Commit 9127e30

henrylhtsang authored and facebook-github-bot committed
Add KeyValueParams (pytorch#2168)
Summary: Pull Request resolved: pytorch#2168

Add the KeyValueParams class, which holds params to pass to SSD TBE. Expectations:
* Passed to SSD TBE only when using EmbeddingComputeKernel.KEY_VALUE. This is important to ensure we can use a mix of FUSED and KEY_VALUE tables.
* Needs to be hashable.

Reviewed By: francomomo

Differential Revision: D58892592
1 parent 115895d commit 9127e30

6 files changed: +61 −0 lines changed
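As context for the per-file diffs below, here is a minimal sketch of how the new constraint is intended to be wired up, following the summary above: build a KeyValueParams and attach it to a ParameterConstraints entry for a table planned with the KEY_VALUE compute kernel. The table name, storage directory, and PS hosts are illustrative, not taken from this commit.

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import KeyValueParams

# illustrative values; "@local_rank" is expanded to the local rank by
# _populate_ssd_tbe_params (see batched_embedding_kernel.py below)
kv_params = KeyValueParams(
    ssd_storage_directory="data00_nvidia@local_rank",
    ps_hosts=(("::1", 2000), ("::1", 2001)),
)

constraints = {
    # hypothetical table name
    "large_table": ParameterConstraints(
        # key_value_params only takes effect for KEY_VALUE tables, so FUSED and
        # KEY_VALUE tables can be mixed in the same model
        compute_kernels=[EmbeddingComputeKernel.KEY_VALUE.value],
        key_value_params=kv_params,
    ),
}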

torchrec/distributed/batched_embedding_kernel.py (+6)

@@ -43,6 +43,7 @@
 )
 from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags
 from torch import nn
+from torchrec.distributed.comm import get_local_rank
 from torchrec.distributed.composable.table_batched_embedding_slice import (
     TableBatchedEmbeddingSlice,
 )
@@ -133,6 +134,11 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:

     if "ssd_storage_directory" not in ssd_tbe_params:
         ssd_tbe_params["ssd_storage_directory"] = tempfile.mkdtemp()
+    else:
+        directory = ssd_tbe_params["ssd_storage_directory"]
+        if "@local_rank" in directory:
+            # assume we have initialized a process group already
+            directory = directory.replace("@local_rank", str(get_local_rank()))

     if "weights_precision" not in ssd_tbe_params:
         weights_precision = data_type_to_sparse_type(config.data_type)
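As a standalone illustration of the "@local_rank" placeholder handling added above, the sketch below reproduces just the string substitution. The helper name and the example rank are made up; the actual code takes the rank from torchrec.distributed.comm.get_local_rank.

# hypothetical helper mirroring the substitution in _populate_ssd_tbe_params
def resolve_ssd_storage_directory(directory: str, local_rank: int) -> str:
    if "@local_rank" in directory:
        # "data00_nvidia@local_rank" -> "data00_nvidia2" on local rank 2
        directory = directory.replace("@local_rank", str(local_rank))
    return directory

assert resolve_ssd_storage_directory("data00_nvidia@local_rank", 2) == "data00_nvidia2"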

torchrec/distributed/planner/enumerators.py (+7)

@@ -31,6 +31,7 @@
 from torchrec.distributed.types import (
     BoundsCheckMode,
     CacheParams,
+    KeyValueParams,
     ModuleSharder,
     ShardingType,
 )
@@ -154,6 +155,7 @@ def enumerate(
                 feature_names,
                 output_dtype,
                 device_group,
+                key_value_params,
             ) = _extract_constraints_for_param(self._constraints, name)

             # skip for other device groups
@@ -209,6 +211,7 @@ def enumerate(
                         is_pooled=is_pooled,
                         feature_names=feature_names,
                         output_dtype=output_dtype,
+                        key_value_params=key_value_params,
                     )
                 )
             if not sharding_options_per_table:
@@ -315,6 +318,7 @@ def _extract_constraints_for_param(
     Optional[List[str]],
     Optional[DataType],
     Optional[str],
+    Optional[KeyValueParams],
 ]:
     input_lengths = [POOLING_FACTOR]
     col_wise_shard_dim = None
@@ -325,6 +329,7 @@ def _extract_constraints_for_param(
     feature_names = None
     output_dtype = None
     device_group = None
+    key_value_params = None

     if constraints and constraints.get(name):
         input_lengths = constraints[name].pooling_factors
@@ -336,6 +341,7 @@ def _extract_constraints_for_param(
         feature_names = constraints[name].feature_names
         output_dtype = constraints[name].output_dtype
         device_group = constraints[name].device_group
+        key_value_params = constraints[name].key_value_params

     return (
         input_lengths,
@@ -347,6 +353,7 @@ def _extract_constraints_for_param(
         feature_names,
         output_dtype,
         device_group,
+        key_value_params,
     )
torchrec/distributed/planner/planners.py (+1)

@@ -106,6 +106,7 @@ def _to_sharding_plan(
             stochastic_rounding=sharding_option.stochastic_rounding,
             bounds_check_mode=sharding_option.bounds_check_mode,
             output_dtype=sharding_option.output_dtype,
+            key_value_params=sharding_option.key_value_params,
         )
         plan[sharding_option.path] = module_plan
     return ShardingPlan(plan)

torchrec/distributed/planner/types.py (+8)

@@ -30,6 +30,7 @@
 from torchrec.distributed.types import (
     BoundsCheckMode,
     CacheParams,
+    KeyValueParams,
     ModuleSharder,
     ShardingPlan,
 )
@@ -368,6 +369,8 @@ class ShardingOption:
         output_dtype (Optional[DataType]): output dtype to be used by this table.
             The default is FP32. If not None, the output dtype will also be used
             by the planner to produce a more balanced plan.
+        key_value_params (Optional[KeyValueParams]): Params for SSD TBE, either
+            for SSD or PS.
     """

     def __init__(
@@ -389,6 +392,7 @@ def __init__(
         is_pooled: Optional[bool] = None,
         feature_names: Optional[List[str]] = None,
         output_dtype: Optional[DataType] = None,
+        key_value_params: Optional[KeyValueParams] = None,
     ) -> None:
         self.name = name
         self._tensor = tensor
@@ -410,6 +414,7 @@ def __init__(
         self.is_weighted: Optional[bool] = None
         self.feature_names: Optional[List[str]] = feature_names
         self.output_dtype: Optional[DataType] = output_dtype
+        self.key_value_params: Optional[KeyValueParams] = key_value_params

     @property
     def tensor(self) -> torch.Tensor:
@@ -574,6 +579,8 @@ class ParameterConstraints:
         device_group (Optional[str]): device group to be used by this table. It can be cpu
             or cuda. This specifies if the table should be placed on a cpu device
             or a gpu device.
+        key_value_params (Optional[KeyValueParams]): key value params for SSD TBE, either for
+            SSD or PS.
     """

     sharding_types: Optional[List[str]] = None
@@ -592,6 +599,7 @@ class ParameterConstraints:
     feature_names: Optional[List[str]] = None
     output_dtype: Optional[DataType] = None
     device_group: Optional[str] = None
+    key_value_params: Optional[KeyValueParams] = None


 class PlannerErrorType(Enum):

torchrec/distributed/types.py (+28)

@@ -20,6 +20,7 @@
     Iterator,
     List,
     Optional,
+    Tuple,
     Type,
     TypeVar,
     Union,
@@ -576,6 +577,31 @@ def __hash__(self) -> int:
         )


+@dataclass
+class KeyValueParams:
+    """
+    Params for SSD TBE aka SSDTableBatchedEmbeddingBags.
+
+    Attributes:
+        ssd_storage_directory (Optional[str]): Directory for SSD. If we want directory
+            to be f"data00_nvidia{local_rank}", pass in "data00_nvidia@local_rank".
+        ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses
+            and ports. Example: (("::1", 2000), ("::1", 2001), ("::1", 2002)).
+            Reason for using tuple is we want it hashable.
+    """
+
+    ssd_storage_directory: Optional[str] = None
+    ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
+
+    def __hash__(self) -> int:
+        return hash(
+            (
+                self.ssd_storage_directory,
+                self.ps_hosts,
+            )
+        )
+
+
 @dataclass
 class ParameterSharding:
     """
@@ -591,6 +617,7 @@ class ParameterSharding:
         stochastic_rounding (Optional[bool]): whether to use stochastic rounding.
         bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode.
         output_dtype (Optional[DataType]): output dtype.
+        key_value_params (Optional[KeyValueParams]): key value params for SSD TBE or PS.

     NOTE:
         ShardingType.TABLE_WISE - rank where this embedding is placed
@@ -610,6 +637,7 @@ class ParameterSharding:
     stochastic_rounding: Optional[bool] = None
     bounds_check_mode: Optional[BoundsCheckMode] = None
     output_dtype: Optional[DataType] = None
+    key_value_params: Optional[KeyValueParams] = None


 class EmbeddingModuleShardingPlan(ModuleShardingPlan, Dict[str, ParameterSharding]):
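The new class is deliberately hashable, which is why ps_hosts is a tuple of tuples rather than a list. A small sketch of what that buys, with illustrative values: the dataclass keeps its explicit __hash__ alongside the generated __eq__, so equal instances can be deduplicated in sets or used as dict keys.

from torchrec.distributed.types import KeyValueParams

a = KeyValueParams(
    ssd_storage_directory="data00_nvidia@local_rank",
    ps_hosts=(("::1", 2000), ("::1", 2001)),
)
b = KeyValueParams(
    ssd_storage_directory="data00_nvidia@local_rank",
    ps_hosts=(("::1", 2000), ("::1", 2001)),
)

assert a == b and hash(a) == hash(b)
assert len({a, b}) == 1  # equal params collapse to a single set entry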

torchrec/distributed/utils.py (+11)

@@ -13,6 +13,7 @@
 import sys

 from collections import OrderedDict
+from dataclasses import asdict
 from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union

 import torch
@@ -405,6 +406,16 @@ def add_params_from_parameter_sharding(
     if parameter_sharding.output_dtype is not None:
         fused_params["output_dtype"] = parameter_sharding.output_dtype

+    if (
+        parameter_sharding.compute_kernel in {EmbeddingComputeKernel.KEY_VALUE.value}
+        and parameter_sharding.key_value_params is not None
+    ):
+        key_value_params_dict = asdict(parameter_sharding.key_value_params)
+        key_value_params_dict = {
+            k: v for k, v in key_value_params_dict.items() if v is not None
+        }
+        fused_params.update(key_value_params_dict)
+
     # print warning if sharding_type is data_parallel or kernel is dense
     if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
         logger.warning(
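To make the new fused-params plumbing above concrete, here is a self-contained sketch of what the added block does: dump the dataclass with asdict, drop fields that were left as None, and merge the remainder into fused_params. The starting fused_params contents are placeholders for illustration.

from dataclasses import asdict
from torchrec.distributed.types import KeyValueParams

fused_params = {"output_dtype": "fp32"}  # placeholder; real code stores a DataType enum
key_value_params = KeyValueParams(ssd_storage_directory="data00_nvidia@local_rank")

key_value_params_dict = asdict(key_value_params)
# only forward fields the user actually set (ps_hosts stays out here)
key_value_params_dict = {k: v for k, v in key_value_params_dict.items() if v is not None}
fused_params.update(key_value_params_dict)

assert fused_params["ssd_storage_directory"] == "data00_nvidia@local_rank"
assert "ps_hosts" not in fused_params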
