Commit 8ac8c0e

henrylhtsang authored and facebook-github-bot committed
Integrate SSD TBE stage 1
Summary:

# Plan

Stage 1 aims to ensure that SSD TBE can run and won't break under normal operations (e.g. checkpointing). Checkpointing (i.e. state_dict and load_state_dict) is still a work in progress. We also need to guarantee checkpointing for optimizer states.

Stage 2: save state_dict (mostly on the fbgemm side)
* current hope is that we can rely on flush to save the state dict

Stage 3: load_state_dict (needs more thought)
* the solution should be similar to that of PS

Stage 4: optimizer state checkpointing (torchrec side, should be pretty standard)
* should be straightforward
* needs fbgemm to support the split_embedding_weights API

# Outstanding issues
* init is not the same as before
* SSD TBE doesn't support mixed dims

# design doc
TODO

# tests should cover
* state dict and load state dict (done)
* should copy dense parts and not break
* deterministic output (done)
* numerical equivalence to normal TBE (done)
* changing learning rate and warmup policy (done)
* work for different sharding types (done)
* work with mixed kernels (done)
* work with mixed sharding types
* multi-GPU training (todo)

# OSS
NOTE: SSD TBE won't work in an OSS environment, due to a rocksdb problem.

# ad hoc
* the SSD kernel is guarded; the user must specify it in constraints to use it (see the sketch below)

Differential Revision: D57452256
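Since the SSD kernel is guarded, the planner only proposes it for tables whose constraints name it explicitly. A minimal, hypothetical opt-in sketch (the table name "table_0" is a placeholder, not from this commit):

```python
# Hypothetical sketch: opting a table into the guarded SSD compute kernel via
# planner constraints. Without this entry the enumerator never proposes SSD.
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.types import ParameterConstraints

constraints = {
    "table_0": ParameterConstraints(  # placeholder table name
        compute_kernels=[EmbeddingComputeKernel.SSD.value],  # "SSD"
    ),
}
# `constraints` would then be passed to the sharding planner / EmbeddingEnumerator.
```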
1 parent da49f44 commit 8ac8c0e

11 files changed: +1334 -28 lines

torchrec/distributed/batched_embedding_kernel.py

+372 (large diff not rendered by default)

torchrec/distributed/embedding.py

+8 -1

@@ -12,7 +12,6 @@
 import logging
 import warnings
 from collections import defaultdict, deque, OrderedDict
-from dataclasses import dataclass, field
 from itertools import accumulate
 from typing import Any, cast, Dict, List, MutableMapping, Optional, Tuple, Type, Union

@@ -654,6 +653,10 @@ def _initialize_torch_state(self) -> None:  # noqa
             table_name,
             local_shards,
         ) in self._model_parallel_name_to_local_shards.items():
+            if model_parallel_name_to_compute_kernel[table_name] in {
+                EmbeddingComputeKernel.SSD.value
+            }:
+                continue
             # for shards that don't exist on this rank, register with empty tensor
             if not hasattr(self.embeddings[table_name], "weight"):
                 self.embeddings[table_name].register_parameter(

@@ -702,6 +705,10 @@ def reset_parameters(self) -> None:
             return
         # Initialize embedding weights with init_fn
         for table_config in self._embedding_configs:
+            if self.module_sharding_plan[table_config.name].compute_kernel in {
+                EmbeddingComputeKernel.SSD.value,
+            }:
+                continue
             assert table_config.init_fn is not None
             param = self.embeddings[f"{table_config.name}"].weight
             # pyre-ignore

torchrec/distributed/embedding_lookup.py

+19

@@ -31,6 +31,8 @@
     BatchedDenseEmbeddingBag,
     BatchedFusedEmbedding,
     BatchedFusedEmbeddingBag,
+    KeyValueEmbedding,
+    KeyValueEmbeddingBag,
 )
 from torchrec.distributed.comm_ops import get_gradient_division
 from torchrec.distributed.composable.table_batched_embedding_slice import (

@@ -168,6 +170,14 @@ def _create_lookup(
                 pg=pg,
                 device=device,
             )
+        elif config.compute_kernel in {
+            EmbeddingComputeKernel.SSD,
+        }:
+            return KeyValueEmbedding(
+                config=config,
+                pg=pg,
+                device=device,
+            )
         else:
             raise ValueError(
                 f"Compute kernel not supported {config.compute_kernel}"

@@ -368,6 +378,15 @@ def _create_lookup(
                 device=device,
                 sharding_type=sharding_type,
             )
+        elif config.compute_kernel in {
+            EmbeddingComputeKernel.SSD,
+        }:
+            return KeyValueEmbeddingBag(
+                config=config,
+                pg=pg,
+                device=device,
+                sharding_type=sharding_type,
+            )
         else:
             raise ValueError(
                 f"Compute kernel not supported {config.compute_kernel}"

torchrec/distributed/embedding_sharding.py

-1

@@ -9,7 +9,6 @@

 import abc
 import copy
-import uuid
 from collections import defaultdict
 from dataclasses import dataclass
 from itertools import filterfalse

torchrec/distributed/embedding_types.py

+3

@@ -60,6 +60,7 @@ class EmbeddingComputeKernel(Enum):
     QUANT = "quant"
     QUANT_UVM = "quant_uvm"
     QUANT_UVM_CACHING = "quant_uvm_caching"
+    SSD = "SSD"


 def compute_kernel_to_embedding_location(

@@ -69,6 +70,7 @@ def compute_kernel_to_embedding_location(
         EmbeddingComputeKernel.DENSE,
         EmbeddingComputeKernel.FUSED,
         EmbeddingComputeKernel.QUANT,
+        EmbeddingComputeKernel.SSD,  # use hbm for cache
     ]:
         return EmbeddingLocation.DEVICE
     elif compute_kernel in [

@@ -410,6 +412,7 @@ def compute_kernels(
             ret += [
                 EmbeddingComputeKernel.FUSED_UVM.value,
                 EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.SSD.value,
             ]
         else:
             # TODO re-enable model parallel and dense

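As the hunks above show, the SSD kernel keeps its cache in HBM, so its embedding location resolves to DEVICE. A quick illustration (not part of the commit, assuming torchrec is importable):

```python
# Illustration only: the SSD kernel maps to a device (HBM) embedding location,
# matching the list it was added to above.
from torchrec.distributed.embedding_types import (
    EmbeddingComputeKernel,
    compute_kernel_to_embedding_location,
)

print(compute_kernel_to_embedding_location(EmbeddingComputeKernel.SSD))
# expected: EmbeddingLocation.DEVICE
```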
torchrec/distributed/embeddingbag.py

+8 -1

@@ -13,7 +13,6 @@
 from functools import partial
 from typing import (
     Any,
-    Callable,
     cast,
     Dict,
     Iterator,

@@ -793,6 +792,10 @@ def _initialize_torch_state(self) -> None:  # noqa
             table_name,
             local_shards,
         ) in self._model_parallel_name_to_local_shards.items():
+            if model_parallel_name_to_compute_kernel[table_name] in {
+                EmbeddingComputeKernel.SSD.value
+            }:
+                continue
             # for shards that don't exist on this rank, register with empty tensor
             if not hasattr(self.embedding_bags[table_name], "weight"):
                 self.embedding_bags[table_name].register_parameter(

@@ -841,6 +844,10 @@ def reset_parameters(self) -> None:

         # Initialize embedding bags weights with init_fn
         for table_config in self._embedding_bag_configs:
+            if self.module_sharding_plan[table_config.name].compute_kernel in {
+                EmbeddingComputeKernel.SSD.value,
+            }:
+                continue
             assert table_config.init_fn is not None
             param = self.embedding_bags[f"{table_config.name}"].weight
             # pyre-ignore

torchrec/distributed/planner/constants.py

+1

@@ -87,6 +87,7 @@ def kernel_bw_lookup(
             caching_ratio * hbm_mem_bw + (1 - caching_ratio) * ddr_mem_bw
         )
         / 10,
+        ("cuda", EmbeddingComputeKernel.SSD.value): ddr_mem_bw,
     }

     if (

torchrec/distributed/planner/enumerators.py

+25 -15

@@ -8,7 +8,7 @@
 # pyre-strict

 import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union

 from torch import nn
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel

@@ -40,6 +40,9 @@

 logger: logging.Logger = logging.getLogger(__name__)

+# compute kernels that should only be used if users specified them
+GUARDED_COMPUTE_KERNELS: Set[EmbeddingComputeKernel] = {EmbeddingComputeKernel.SSD}
+

 class EmbeddingEnumerator(Enumerator):
     """

@@ -256,22 +259,29 @@ def _filter_compute_kernels(
         allowed_compute_kernels: List[str],
         sharding_type: str,
     ) -> List[str]:
-        # for the log message only
-        constrained_compute_kernels: List[str] = [
-            compute_kernel.value for compute_kernel in EmbeddingComputeKernel
-        ]
-        if not self._constraints or not self._constraints.get(name):
-            filtered_compute_kernels = allowed_compute_kernels
+        # setup constrained_compute_kernels
+        if (
+            self._constraints
+            and self._constraints.get(name)
+            and self._constraints[name].compute_kernels
+        ):
+            # pyre-ignore
+            constrained_compute_kernels: List[str] = self._constraints[
+                name
+            ].compute_kernels
         else:
-            constraints: ParameterConstraints = self._constraints[name]
-            if not constraints.compute_kernels:
-                filtered_compute_kernels = allowed_compute_kernels
-            else:
-                constrained_compute_kernels = constraints.compute_kernels
-                filtered_compute_kernels = list(
-                    set(constrained_compute_kernels) & set(allowed_compute_kernels)
-                )
+            constrained_compute_kernels: List[str] = [
+                compute_kernel.value
+                for compute_kernel in EmbeddingComputeKernel
+                if compute_kernel not in GUARDED_COMPUTE_KERNELS
+            ]
+
+        # setup filtered_compute_kernels
+        filtered_compute_kernels = list(
+            set(constrained_compute_kernels) & set(allowed_compute_kernels)
+        )

+        # special rules
         if EmbeddingComputeKernel.DENSE.value in filtered_compute_kernels:
             if (
                 EmbeddingComputeKernel.FUSED.value in filtered_compute_kernels

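A simplified, standalone sketch of the new filtering rule (toy kernel names, not the actual method): guarded kernels are excluded from the default candidates and only survive when the user's constraints list them.

```python
# Toy sketch of the guarded-kernel filtering logic; names are illustrative.
from typing import List, Optional, Set

ALL_KERNELS: Set[str] = {"dense", "fused", "SSD"}  # toy universe of kernel names
GUARDED_KERNELS: Set[str] = {"SSD"}                # mirrors GUARDED_COMPUTE_KERNELS

def filter_compute_kernels(
    constrained: Optional[List[str]], allowed: List[str]
) -> List[str]:
    if constrained:
        candidates = list(constrained)  # user opted in explicitly
    else:
        candidates = [k for k in ALL_KERNELS if k not in GUARDED_KERNELS]
    return sorted(set(candidates) & set(allowed))

print(filter_compute_kernels(None, ["fused", "SSD"]))     # ['fused']
print(filter_compute_kernels(["SSD"], ["fused", "SSD"]))  # ['SSD']
```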
torchrec/distributed/planner/shard_estimators.py

+3

@@ -1051,9 +1051,12 @@ def calculate_shard_storages(
     if compute_kernel in {
         EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
         EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
+        EmbeddingComputeKernel.SSD.value,
     }:
         hbm_storage = round(ddr_storage * caching_ratio)
         table_cached = True
+    if compute_kernel in {EmbeddingComputeKernel.SSD.value}:
+        ddr_storage = 0

     optimizer_class = getattr(tensor, "_optimizer_class", None)

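With the SSD kernel, the estimator sizes the HBM cache like other cached kernels and reserves no DDR, since the table itself lives on SSD. With made-up numbers:

```python
# Toy numbers only: how the SSD branch above reshapes the storage estimate.
ddr_storage = 1_000_000_000   # estimated full-table size in bytes (example value)
caching_ratio = 0.2           # fraction of the table expected to be cached in HBM

hbm_storage = round(ddr_storage * caching_ratio)  # 200_000_000 bytes for the HBM cache
ddr_storage = 0                                   # nothing reserved in DDR for SSD tables

print(hbm_storage, ddr_storage)  # 200000000 0
```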
torchrec/distributed/test_utils/test_model_parallel.py

+10 -10

@@ -40,17 +40,17 @@ def setUp(self, backend: str = "nccl") -> None:

         self.tables = [
             EmbeddingBagConfig(
-                num_embeddings=(i + 1) * 10,
-                embedding_dim=(i + 2) * 8,
+                num_embeddings=(i + 1) * 1000,
+                embedding_dim=16,
                 name="table_" + str(i),
                 feature_names=["feature_" + str(i)],
             )
             for i in range(num_features)
         ]
         shared_features_tables = [
             EmbeddingBagConfig(
-                num_embeddings=(i + 1) * 10,
-                embedding_dim=(i + 2) * 8,
+                num_embeddings=(i + 1) * 1000,
+                embedding_dim=16,
                 name="table_" + str(i + num_features),
                 feature_names=["feature_" + str(i)],
             )

@@ -60,8 +60,8 @@ def setUp(self, backend: str = "nccl") -> None:

         self.mean_tables = [
             EmbeddingBagConfig(
-                num_embeddings=(i + 1) * 10,
-                embedding_dim=(i + 2) * 8,
+                num_embeddings=(i + 1) * 1000,
+                embedding_dim=16,
                 name="table_" + str(i),
                 feature_names=["feature_" + str(i)],
                 pooling=PoolingType.MEAN,

@@ -71,8 +71,8 @@ def setUp(self, backend: str = "nccl") -> None:

         shared_features_tables_mean = [
             EmbeddingBagConfig(
-                num_embeddings=(i + 1) * 10,
-                embedding_dim=(i + 2) * 8,
+                num_embeddings=(i + 1) * 1000,
+                embedding_dim=16,
                 name="table_" + str(i + num_features),
                 feature_names=["feature_" + str(i)],
                 pooling=PoolingType.MEAN,

@@ -83,8 +83,8 @@ def setUp(self, backend: str = "nccl") -> None:

         self.weighted_tables = [
             EmbeddingBagConfig(
-                num_embeddings=(i + 1) * 10,
-                embedding_dim=(i + 2) * 4,
+                num_embeddings=(i + 1) * 1000,
+                embedding_dim=16,
                 name="weighted_table_" + str(i),
                 feature_names=["weighted_feature_" + str(i)],
             )
