@@ -13,6 +13,7 @@
 import sys
 
 from collections import OrderedDict
+from dataclasses import asdict
 from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union
 
 import torch
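The refactor below leans on `dataclasses.asdict`, newly imported above, which converts a dataclass instance into a plain dict (recursing into nested dataclasses). Unset `Optional` fields come through as `None`, which is what lets the new code filter them out generically. A minimal illustration, using a made-up dataclass rather than anything from this repo:

    from dataclasses import asdict, dataclass
    from typing import Optional

    @dataclass
    class Example:  # hypothetical dataclass, for illustration only
        algorithm: Optional[str] = None
        load_factor: Optional[float] = None

    # Unset fields surface as None, so callers can filter them out.
    print(asdict(Example(algorithm="lru")))
    # {'algorithm': 'lru', 'load_factor': None}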
@@ -377,45 +378,46 @@ def add_params_from_parameter_sharding(
     # update fused_params using params from parameter_sharding
     # this will take precedence over the fused_params provided from sharders
     if parameter_sharding.cache_params is not None:
-        cache_params = parameter_sharding.cache_params
-        if cache_params.algorithm is not None:
-            fused_params["cache_algorithm"] = cache_params.algorithm
-        if cache_params.load_factor is not None:
-            fused_params["cache_load_factor"] = cache_params.load_factor
-        if cache_params.reserved_memory is not None:
-            fused_params["cache_reserved_memory"] = cache_params.reserved_memory
-        if cache_params.precision is not None:
-            fused_params["cache_precision"] = cache_params.precision
-        if cache_params.prefetch_pipeline is not None:
-            fused_params["prefetch_pipeline"] = cache_params.prefetch_pipeline
-        if cache_params.multipass_prefetch_config is not None:
-            fused_params["multipass_prefetch_config"] = (
-                cache_params.multipass_prefetch_config
-            )
-
-    if parameter_sharding.enforce_hbm is not None:
-        fused_params["enforce_hbm"] = parameter_sharding.enforce_hbm
-
-    if parameter_sharding.stochastic_rounding is not None:
-        fused_params["stochastic_rounding"] = parameter_sharding.stochastic_rounding
-
-    if parameter_sharding.bounds_check_mode is not None:
-        fused_params["bounds_check_mode"] = parameter_sharding.bounds_check_mode
-
-    if parameter_sharding.output_dtype is not None:
-        fused_params["output_dtype"] = parameter_sharding.output_dtype
+        cache_params_dict = asdict(parameter_sharding.cache_params)
+
+        def _add_cache_prefix(key: str) -> str:
+            if key in {"algorithm", "load_factor", "reserved_memory", "precision"}:
+                return f"cache_{key}"
+            return key
+
+        cache_params_dict = {
+            _add_cache_prefix(k): v
+            for k, v in cache_params_dict.items()
+            if v is not None and k not in {"stats"}
+        }
+        fused_params.update(cache_params_dict)
+
+    parameter_sharding_dict = asdict(parameter_sharding)
+    params_to_fused_tbe: Set[str] = {
+        "enforce_hbm",
+        "stochastic_rounding",
+        "bounds_check_mode",
+        "output_dtype",
+    }
+    parameter_sharding_dict = {
+        k: v
+        for k, v in parameter_sharding_dict.items()
+        if v is not None and k in params_to_fused_tbe
+    }
+    fused_params.update(parameter_sharding_dict)
 
     # print warning if sharding_type is data_parallel or kernel is dense
-    if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
-        logger.warning(
-            f"Sharding Type is {parameter_sharding.sharding_type}, "
-            "caching params will be ignored"
-        )
-    elif parameter_sharding.compute_kernel == EmbeddingComputeKernel.DENSE.value:
-        logger.warning(
-            f"Compute Kernel is {parameter_sharding.compute_kernel}, "
-            "caching params will be ignored"
-        )
+    if parameter_sharding.cache_params is not None:
+        if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
+            logger.warning(
+                f"Sharding Type is {parameter_sharding.sharding_type}, "
+                "caching params will be ignored"
+            )
+        elif parameter_sharding.compute_kernel == EmbeddingComputeKernel.DENSE.value:
+            logger.warning(
+                f"Compute Kernel is {parameter_sharding.compute_kernel}, "
+                "caching params will be ignored"
+            )
 
     return fused_params
 
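Taken together, the new hunk swaps the per-field `if ... is not None` chain for one generic pattern: dump the dataclass with `asdict`, drop `None` values and excluded keys, apply a key-prefix rule, and merge the result into `fused_params`. Below is a self-contained sketch of that pattern; the `CacheParams` stand-in and the `cache_params_to_fused` helper are illustrative only, though the field names and prefix rule mirror the diff:

    from dataclasses import asdict, dataclass
    from typing import Any, Dict, Optional

    @dataclass
    class CacheParams:  # stand-in dataclass; fields mirror the diff
        algorithm: Optional[str] = None
        load_factor: Optional[float] = None
        reserved_memory: Optional[float] = None
        precision: Optional[str] = None
        prefetch_pipeline: Optional[bool] = None
        stats: Optional[object] = None  # excluded from fused_params, per the diff

    def _add_cache_prefix(key: str) -> str:
        # Only these four keys are namespaced with a "cache_" prefix.
        if key in {"algorithm", "load_factor", "reserved_memory", "precision"}:
            return f"cache_{key}"
        return key

    def cache_params_to_fused(cache_params: CacheParams) -> Dict[str, Any]:
        # Drop unset (None) fields and the excluded "stats" key, prefix the rest.
        return {
            _add_cache_prefix(k): v
            for k, v in asdict(cache_params).items()
            if v is not None and k != "stats"
        }

    fused_params: Dict[str, Any] = {}
    fused_params.update(
        cache_params_to_fused(CacheParams(algorithm="lru", load_factor=0.2))
    )
    print(fused_params)  # {'cache_algorithm': 'lru', 'cache_load_factor': 0.2}

One property of `asdict` worth noting: it is recursive, so any dataclass-valued field (if `multipass_prefetch_config` is one, for example) comes back as a plain dict rather than the original object, whereas the removed code stored the object itself.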