Skip to content

Commit

Permalink
Change default precision on macOS (#1720)
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt authored Sep 11, 2024
1 parent df5b273 commit 456a32c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
1 change: 1 addition & 0 deletions litgpt/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def distribute(
accelerator = "cuda"
elif torch.backends.mps.is_available():
# accelerator = "mps"
accelerator = "cpu"
warnings.warn("MPS is currently not supported. Using CPU instead.", UserWarning)
else:
accelerator = "cpu"
Expand Down
20 changes: 15 additions & 5 deletions litgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,19 +348,29 @@ def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str)
return state_dict


def get_default_supported_precision(training: bool) -> str:
"""Return default precision that is supported by the hardware: either `bf16` or `16`.
def get_default_supported_precision(training: bool, use_mps: bool = False) -> str:
"""
Return the default precision that is supported by the hardware: either `bf16` or `16`.
Args:
training: `-mixed` or `-true` version of the precision to use
training: If True, returns '-mixed' version of the precision; if False, returns '-true' version.
use_mps: Flag to determine if MPS should be used when available.
Returns:
default precision that is suitable for the task and is supported by the hardware
The default precision that is suitable for the task and is supported by the hardware.
"""
from lightning.fabric.accelerators import MPSAccelerator
import torch

if MPSAccelerator.is_available() or (torch.cuda.is_available() and not torch.cuda.is_bf16_supported()):
if use_mps and MPSAccelerator.is_available():
return "16-mixed" if training else "16-true"

if torch.cuda.is_available():
if torch.cuda.is_bf16_supported():
return "bf16-mixed" if training else "bf16-true"
else:
return "16-mixed" if training else "16-true"

return "bf16-mixed" if training else "bf16-true"


Expand Down

0 comments on commit 456a32c

Please sign in to comment.