AMD (MI250X) support (#1775)
TensorTemplar authored Oct 10, 2024
1 parent ad57435 commit 46c4337
Showing 3 changed files with 226 additions and 59 deletions.
34 changes: 12 additions & 22 deletions litgpt/model.py
@@ -445,11 +445,8 @@ def build_rope_cache(
condense_ratio: int = 1,
extra_config: Optional[dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Enhanced Transformer with Rotary Position Embedding.
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
"""
Enhanced Transformer with Rotary Position Embedding.
Args:
seq_len (int): Sequence length.
@@ -463,7 +460,7 @@ def build_rope_cache(
Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE.
"""

# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
assert n_elem % 2 == 0, "n_elem (head dimension) must be even"
# Compute the inverse frequencies theta
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))

if extra_config is not None:
@@ -480,26 +477,19 @@ def build_rope_cache(
# Compute wavelengths corresponding to the inverse frequencies
wavelen = 2 * torch.pi / theta

# Initialize adjusted inverse frequencies
adjusted_theta = theta.clone()

# Low Frequency Region: wavelen > low_freq_wavelen
mask_low_freq = wavelen > low_freq_wavelen
adjusted_theta[mask_low_freq] = theta[mask_low_freq] / factor
# Compute ratio across all elements
ratio = orig_context_len / wavelen

# Medium Frequency Region: high_freq_wavelen ≤ wavelen ≤ low_freq_wavelen
mask_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
# Compute smooth factor for medium frequencies
ratio = orig_context_len / wavelen[mask_medium_freq]
# Compute smooth_factor and clamp between 0 and 1
smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
# Interpolate inverse frequencies
adjusted_theta[mask_medium_freq] = (
(1 - smooth_factor) * (theta[mask_medium_freq] / factor)
+ smooth_factor * theta[mask_medium_freq]
)
smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)

# Compute adjusted_theta without masked indexing
adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta

theta = adjusted_theta

# Create position indexes `[0, 1, ..., seq_len - 1]`
# Create position indices `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=device) / condense_ratio

# Calculate the product of position index and $\theta_i$
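Note on the hunk above: the masked three-region frequency adjustment is collapsed into a single clamped interpolation. The standalone sketch below is not part of the commit; it assumes the usual Llama-3.1-style cutoffs (low_freq_wavelen = orig_context_len / low_freq_factor, high_freq_wavelen = orig_context_len / high_freq_factor) and illustrative factor values, and checks that the old and new formulations produce the same adjusted frequencies.

import torch

def adjust_masked(theta, factor, low_freq_factor, high_freq_factor, orig_context_len):
    # Previous approach: three wavelength regions handled via boolean masks
    wavelen = 2 * torch.pi / theta
    low_freq_wavelen = orig_context_len / low_freq_factor
    high_freq_wavelen = orig_context_len / high_freq_factor
    adjusted = theta.clone()
    low = wavelen > low_freq_wavelen
    adjusted[low] = theta[low] / factor  # low-frequency region: scale down
    mid = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    smooth = (orig_context_len / wavelen[mid] - low_freq_factor) / (high_freq_factor - low_freq_factor)
    adjusted[mid] = (1 - smooth) * (theta[mid] / factor) + smooth * theta[mid]
    return adjusted

def adjust_clamped(theta, factor, low_freq_factor, high_freq_factor, orig_context_len):
    # New approach: one expression, with the smoothing factor clamped to [0, 1]
    ratio = orig_context_len / (2 * torch.pi / theta)
    smooth = torch.clamp((ratio - low_freq_factor) / (high_freq_factor - low_freq_factor), 0.0, 1.0)
    # smooth == 0 reproduces the low-frequency branch; smooth == 1 leaves theta unchanged
    return (1 - smooth) * (theta / factor) + smooth * theta

theta = 1.0 / (10000 ** (torch.arange(0, 128, 2).float() / 128))
args = (theta, 8.0, 1.0, 4.0, 8192)  # factor, low_freq_factor, high_freq_factor, orig_context_len (illustrative)
print(torch.allclose(adjust_masked(*args), adjust_clamped(*args)))  # expected: True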
131 changes: 98 additions & 33 deletions litgpt/utils.py
@@ -637,47 +637,112 @@ def auto_download_checkpoint(model_name, access_token=None, ignore_tokenizer_fil


def check_nvlink_connectivity(fabric=None):
"""Checks GPU connectivity for both NVIDIA and AMD GPUs.
This function delegates to vendor-specific implementations based on
the detected GPU vendor.
"""
if fabric is not None:
custom_print = fabric.print
else:
custom_print = print

if os.getenv("RANK", "0") == "0":
try:
result = subprocess.run(["nvidia-smi", "topo", "-m"], stdout=subprocess.PIPE, text=True)

if result.returncode != 0:
custom_print("Failed to run nvidia-smi")
return

lines = result.stdout.split('\n')
gpu_matrix = []

start_index = next((i for i, line in enumerate(lines) if "GPU0" in line), None) + 1
headers_line = lines[start_index - 1]
headers = headers_line.split()
# The regex is to avoid counting the "GPU NUMA ID" header as a GPU
# in headers like ['\x1b[4mGPU0', 'GPU1', 'GPU2', 'GPU3', 'GPU4', 'GPU5', 'GPU6', 'GPU7', 'NIC0', 'NIC1', 'NIC2', 'NIC3', 'NIC4', 'NIC5', 'NIC6', 'NIC7', 'NIC8', 'NIC9', 'CPU', 'Affinity', 'NUMA', 'Affinity', 'GPU', 'NUMA', 'ID\x1b[0m']
gpu_regex = re.compile(r'^GPU\d+$')
gpu_count = len([header for header in headers if gpu_regex.match(header)])

all_nvlink = True
for line in lines[start_index:start_index + gpu_count]:
gpu_matrix.append(line.strip())
connections = line.split()[1:1 + gpu_count]
if not all("NV" in conn for conn in connections if conn != "X"):
all_nvlink = False
break

if all_nvlink:
custom_print("All GPUs are fully connected via NVLink.")
if torch.cuda.is_available():
device_properties = torch.cuda.get_device_properties(0)
gpu_name = device_properties.name.lower()
if "nvidia" in gpu_name:
_check_nvidia_connectivity(custom_print)
elif "advanced micro devices" in gpu_name or "amd" in gpu_name:
_check_amd_connectivity(custom_print)
else:
custom_print(f"Unrecognized GPU vendor: {device_properties.name}")
else:
custom_print(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)

custom_print("No GPUs available")
except Exception as e:
custom_print(f"An error occurred: {e}")
custom_print(f"An error occurred while checking GPU connectivity: {e}")


def _check_nvidia_connectivity(custom_print):
"""Checks NVLink connectivity on NVIDIA GPUs."""
result = subprocess.run(["nvidia-smi", "topo", "-m"], stdout=subprocess.PIPE, text=True)
if result.returncode != 0:
custom_print("Failed to run nvidia-smi")
return

lines = result.stdout.strip().split("\n")
start_index = next((i for i, line in enumerate(lines) if "GPU0" in line), None)
if start_index is None:
custom_print("Failed to parse nvidia-smi output")
return

headers_line = lines[start_index]
headers = headers_line.split()
gpu_regex = re.compile(r"^GPU\d+$")
gpu_count = len([header for header in headers if gpu_regex.match(header)])

all_nvlink = True
for line in lines[start_index + 1 : start_index + 1 + gpu_count]:
columns = line.split()
connections = columns[1 : 1 + gpu_count]
if not all("NV" in conn for conn in connections if conn != "X"):
all_nvlink = False
break

if all_nvlink:
custom_print("All GPUs are fully connected via NVLink.")
else:
custom_print(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)


def _check_amd_connectivity(custom_print):
"""Checks XGMI connectivity on AMD GPUs."""
result = subprocess.run(["rocm-smi", "--showtopotype"], stdout=subprocess.PIPE, text=True)
if result.returncode != 0:
custom_print("Failed to run rocm-smi")
return

lines = result.stdout.strip().split("\n")
gpu_header_index = next((i for i, line in enumerate(lines) if re.match(r"^\s*GPU0", line)), None)
if gpu_header_index is None or gpu_header_index == 0:
custom_print("Failed to parse rocm-smi output (no GPU headers found)")
return

header_line = lines[gpu_header_index - 1]
headers = header_line.strip().split()
gpu_regex = re.compile(r"^GPU\d+$")
gpu_count = len([header for header in headers if gpu_regex.match(header)])

gpu_lines = []
for line in lines[gpu_header_index : gpu_header_index + gpu_count]:
if re.match(r"^\s*GPU\d+", line):
gpu_lines.append(line.strip())
if len(gpu_lines) != gpu_count:
custom_print("Mismatch in GPU count when parsing rocm-smi output")
return

all_xgmi = True
for line in gpu_lines:
columns = line.split()
connections = columns[1 : 1 + gpu_count]
for conn in connections:
if conn not in ("XGMI", "0"):
all_xgmi = False
break
if not all_xgmi:
break

if all_xgmi:
custom_print("All GPUs are fully connected via XGMI.")
else:
custom_print(
"Warning: Not all GPUs are fully connected via XGMI. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)


def fix_and_load_json(s):
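A minimal usage sketch for the refactored helper (not part of the diff). It assumes litgpt is installed so the function is importable from litgpt.utils and that Lightning Fabric is available; passing a Fabric instance is optional and only changes which print function is used.

import lightning as L
from litgpt.utils import check_nvlink_connectivity

fabric = L.Fabric(accelerator="cuda", devices="auto")
# On the rank-0 process this inspects torch.cuda.get_device_properties(0).name to pick the
# vendor, then shells out to `nvidia-smi topo -m` (NVIDIA) or `rocm-smi --showtopotype` (AMD)
# and reports whether every GPU pair is connected via NVLink / XGMI.
check_nvlink_connectivity(fabric)

# Without a Fabric instance the helper falls back to the built-in print():
# check_nvlink_connectivity()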
120 changes: 116 additions & 4 deletions tests/test_utils.py
@@ -465,6 +465,29 @@ def test_file_size_above_limit_on_gpu():
assert size == 4_600_000_000


@pytest.fixture
def mock_cuda_is_available_true(monkeypatch):
"""Fixture to mock torch.cuda.is_available() to return True."""
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)


@pytest.fixture
def mock_nvidia_device_properties(monkeypatch):
"""Fixture to mock torch.cuda.get_device_properties() for NVIDIA GPUs."""
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "NVIDIA RTX A6000"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)


@pytest.fixture
def mock_amd_device_properties(monkeypatch):
"""Fixture to mock torch.cuda.get_device_properties() for AMD GPUs."""
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "AMD Instinct MI250X"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)



@pytest.fixture
def all_nvlink_connected_output():
return mock.MagicMock(stdout=""" GPU0 GPU1 GPU2 GPU3
@@ -475,7 +498,7 @@ def all_nvlink_connected_output():


@mock.patch("subprocess.run")
def test_all_nvlink_connected(mock_run, all_nvlink_connected_output):
def test_all_nvlink_connected(mock_run, all_nvlink_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties):
mock_run.return_value = all_nvlink_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
@@ -497,7 +520,7 @@ def nvlink_partially_connected_output():


@mock.patch("subprocess.run")
def test_nvlink_partially_connected_output(mock_run, nvlink_partially_connected_output):
def test_nvlink_partially_connected_output(mock_run, nvlink_partially_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties):
mock_run.return_value = nvlink_partially_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
@@ -527,7 +550,7 @@ def nvlink_not_connected_output():


@mock.patch("subprocess.run")
def test_nvlink_not_connected_output(mock_run, nvlink_not_connected_output):
def test_nvlink_not_connected_output(mock_run, nvlink_not_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties):
mock_run.return_value = nvlink_not_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
@@ -586,13 +609,102 @@ def nvlink_all_gpu_connected_but_other_connected_output():


@mock.patch("subprocess.run")
def test_nvlink_all_gpu_connected_but_other_connected_output(mock_run, nvlink_all_gpu_connected_but_other_connected_output):
def test_nvlink_all_gpu_connected_but_other_connected_output(
mock_run,
nvlink_all_gpu_connected_but_other_connected_output,
mock_cuda_is_available_true,
mock_nvidia_device_properties,
):
mock_run.return_value = nvlink_all_gpu_connected_but_other_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")


@pytest.fixture
def nvidia_smi_nvlink_output_dual_gpu_no_numa():
return mock.MagicMock(
stdout="""
GPU0 GPU1 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV1 0-15 0 N/A
GPU1 NV1 X 0-15 0 N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
""",
returncode=0,
)


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_fully_connected_when_nvidia_all_nvlink_two_gpus(
mock_run, nvidia_smi_nvlink_output_dual_gpu_no_numa, mock_cuda_is_available_true, mock_nvidia_device_properties
):
mock_run.return_value = nvidia_smi_nvlink_output_dual_gpu_no_numa
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")


@pytest.fixture
def rocm_smi_xgmi_output_multi_gpu():
"""
rocm-smi --showtopotype on ROCm 6.0.3+
"""
return mock.MagicMock(
stdout="""
=============================== ROCm System Management Interface ============================
=============================== Link Type between two GPUs ===============================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 XGMI XGMI XGMI XGMI XGMI XGMI XGMI
GPU1 XGMI 0 XGMI XGMI XGMI XGMI XGMI XGMI
GPU2 XGMI XGMI 0 XGMI XGMI XGMI XGMI XGMI
GPU3 XGMI XGMI XGMI 0 XGMI XGMI XGMI XGMI
GPU4 XGMI XGMI XGMI XGMI 0 XGMI XGMI XGMI
GPU5 XGMI XGMI XGMI XGMI XGMI 0 XGMI XGMI
GPU6 XGMI XGMI XGMI XGMI XGMI XGMI 0 XGMI
GPU7 XGMI XGMI XGMI XGMI XGMI XGMI XGMI 0
================================== End of ROCm SMI Log ===================================
""",
returncode=0,
)


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_fully_connected_when_amd_all_xgmi_8_gpus(
mock_run, rocm_smi_xgmi_output_multi_gpu, mock_cuda_is_available_true, mock_amd_device_properties
):
mock_run.return_value = rocm_smi_xgmi_output_multi_gpu
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via XGMI.")


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_no_gpus_when_no_gpus(mock_run, monkeypatch):
monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("No GPUs available")


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_unrecognized_vendor_when_unrecognized_vendor(mock_run, monkeypatch, mock_cuda_is_available_true):
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "GARAGE DIY HYPERSCALER GPU"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("Unrecognized GPU vendor: GARAGE DIY HYPERSCALER GPU")


def test_fix_and_load_json():
# Test 1: Invalid JSON string with a trailing comma
invalid_json_trailing_comma = '''
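The new connectivity tests mock subprocess.run and the torch.cuda queries, so they can run on a machine without any GPU. A small convenience sketch for running just this group (not part of the commit; the -k filter is an assumption based on the test names above):

import pytest

if __name__ == "__main__":
    # Run only the GPU-connectivity tests; everything GPU-related is mocked above.
    raise SystemExit(pytest.main(["tests/test_utils.py", "-k", "nvlink"]))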
