Skip to content

Commit 9e4a6a7

Browse files
ys97529 and facebook-github-bot
authored and committed
Support weighted_bwd_compute_multiplier in sharding estimators
Differential Revision: D53550851
1 parent df78731 commit 9e4a6a7

File tree

3 files changed

+16
-0
lines changed

3 files changed

+16
-0
lines changed

torchrec/distributed/planner/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
HALF_BLOCK_PENALTY: float = 1.15 # empirical studies
3535
QUARTER_BLOCK_PENALTY: float = 1.75 # empirical studies
3636
BWD_COMPUTE_MULTIPLIER: float = 2 # empirical studies
37+
WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER: float = 1 # empirical studies
3738
WEIGHTED_KERNEL_MULTIPLIER: float = 1.1 # empirical studies
3839
DP_ELEMENTWISE_KERNELS_PERF_FACTOR: float = 9.22 # empirical studies
3940

torchrec/distributed/planner/shard_estimators.py

+6
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def estimate(
216216
intra_host_bw=self._topology.intra_host_bw,
217217
inter_host_bw=self._topology.inter_host_bw,
218218
bwd_compute_multiplier=self._topology.bwd_compute_multiplier,
219+
weighted_feature_bwd_compute_multiplier=self._topology.weighted_feature_bwd_compute_multiplier,
219220
is_pooled=sharding_option.is_pooled,
220221
is_weighted=is_weighted,
221222
is_inference=self._is_inference,
@@ -251,6 +252,7 @@ def perf_func_emb_wall_time(
251252
intra_host_bw: float,
252253
inter_host_bw: float,
253254
bwd_compute_multiplier: float,
255+
weighted_feature_bwd_compute_multiplier: float,
254256
is_pooled: bool,
255257
is_weighted: bool = False,
256258
caching_ratio: Optional[float] = None,
@@ -336,6 +338,7 @@ def perf_func_emb_wall_time(
336338
inter_host_bw=inter_host_bw,
337339
intra_host_bw=intra_host_bw,
338340
bwd_compute_multiplier=bwd_compute_multiplier,
341+
weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
339342
is_pooled=is_pooled,
340343
is_weighted=is_weighted,
341344
is_inference=is_inference,
@@ -447,6 +450,7 @@ def _get_tw_sharding_perf(
447450
inter_host_bw: float,
448451
intra_host_bw: float,
449452
bwd_compute_multiplier: float,
453+
weighted_feature_bwd_compute_multiplier: float,
450454
is_pooled: bool,
451455
is_weighted: bool = False,
452456
is_inference: bool = False,
@@ -507,6 +511,8 @@ def _get_tw_sharding_perf(
507511

508512
# includes fused optimizers
509513
bwd_compute = fwd_compute * bwd_compute_multiplier
514+
if is_weighted:
515+
bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier
510516

511517
prefetch_compute = cls._get_expected_cache_prefetch_time(
512518
ddr_mem_bw, expected_cache_fetches, emb_dim, table_data_type_size

torchrec/distributed/planner/types.py

+9
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
HBM_MEM_BW,
2626
INTRA_NODE_BANDWIDTH,
2727
POOLING_FACTOR,
28+
WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER,
2829
)
2930
from torchrec.distributed.types import (
3031
BoundsCheckMode,
@@ -186,6 +187,7 @@ def __init__(
186187
inter_host_bw: float = CROSS_NODE_BANDWIDTH,
187188
bwd_compute_multiplier: float = BWD_COMPUTE_MULTIPLIER,
188189
custom_topology_data: Optional[CustomTopologyData] = None,
190+
weighted_feature_bwd_compute_multiplier: float = WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER,
189191
) -> None:
190192
"""
191193
Representation of a network of devices in a cluster.
@@ -238,6 +240,9 @@ def __init__(
238240
self._inter_host_bw = inter_host_bw
239241
self._bwd_compute_multiplier = bwd_compute_multiplier
240242
self._custom_topology_data = custom_topology_data
243+
self._weighted_feature_bwd_compute_multiplier = (
244+
weighted_feature_bwd_compute_multiplier
245+
)
241246

242247
@property
243248
def compute_device(self) -> str:
@@ -275,6 +280,10 @@ def inter_host_bw(self) -> float:
275280
def bwd_compute_multiplier(self) -> float:
276281
return self._bwd_compute_multiplier
277282

283+
@property
284+
def weighted_feature_bwd_compute_multiplier(self) -> float:
285+
return self._weighted_feature_bwd_compute_multiplier
286+
278287
def __repr__(self) -> str:
279288
topology_repr: str = f"world_size={self._world_size} \n"
280289
topology_repr += f"compute_device={self._compute_device}\n"

0 commit comments

Comments
 (0)