Skip to content

Commit

Permalink
revert style changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Oct 9, 2024
1 parent 467fb87 commit e7da8e6
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 52 deletions.
27 changes: 8 additions & 19 deletions litgpt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,30 +473,19 @@ def build_rope_cache(
low_freq_factor = extra_config["low_freq_factor"]
high_freq_factor = extra_config["high_freq_factor"]

# Compute wavelength thresholds
low_freq_wavelen = orig_context_len / low_freq_factor
high_freq_wavelen = orig_context_len / high_freq_factor

# Compute wavelengths corresponding to the inverse frequencies
wavelen = 2 * torch.pi / theta

# Initialize adjusted inverse frequencies
adjusted_theta = theta.clone()

# Low Frequency Region: wavelen > low_freq_wavelen
mask_low_freq = wavelen > low_freq_wavelen
adjusted_theta[mask_low_freq] = theta[mask_low_freq] / factor
# Compute ratio across all elements
ratio = orig_context_len / wavelen

# Medium Frequency Region: high_freq_wavelen ≤ wavelen ≤ low_freq_wavelen
mask_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
# Compute smooth factor for medium frequencies
ratio = orig_context_len / wavelen[mask_medium_freq]
# Compute smooth_factor and clamp between 0 and 1
smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
# Interpolate inverse frequencies
adjusted_theta[mask_medium_freq] = (
(1 - smooth_factor) * (theta[mask_medium_freq] / factor)
+ smooth_factor * theta[mask_medium_freq]
)
smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)

# Compute adjusted_theta without masked indexing
adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta

theta = adjusted_theta

# Create position indexes `[0, 1, ..., seq_len - 1]`
Expand Down
131 changes: 98 additions & 33 deletions litgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,47 +637,112 @@ def auto_download_checkpoint(model_name, access_token=None, ignore_tokenizer_fil


def check_nvlink_connectivity(fabric=None):
    """Check GPU interconnect quality for NVIDIA (NVLink) and AMD (XGMI) GPUs.

    The vendor is inferred from the name of CUDA device 0 and the check is
    delegated to the matching vendor-specific helper. The check only runs in
    the process whose ``RANK`` environment variable is ``"0"`` (or unset).
    Any exception raised while checking is reported through the print
    function instead of being propagated.

    Args:
        fabric: Optional object exposing a ``print`` method that receives all
            messages. The built-in ``print`` is used when ``None``.
    """
    custom_print = print if fabric is None else fabric.print

    # Only rank 0 (or a process without RANK set) performs the check.
    if os.getenv("RANK", "0") != "0":
        return

    try:
        if not torch.cuda.is_available():
            custom_print("No GPUs available")
            return

        device_properties = torch.cuda.get_device_properties(0)
        # Vendor detection is a case-insensitive substring match on the device name.
        gpu_name = device_properties.name.lower()
        if "nvidia" in gpu_name:
            _check_nvidia_connectivity(custom_print)
        elif "advanced micro devices" in gpu_name or "amd" in gpu_name:
            _check_amd_connectivity(custom_print)
        else:
            custom_print(f"Unrecognized GPU vendor: {device_properties.name}")
    except Exception as e:
        # Best-effort diagnostic: never let a failed check abort the caller.
        custom_print(f"An error occurred while checking GPU connectivity: {e}")


def _check_nvidia_connectivity(custom_print, topo_output=None):
    """Report whether every pair of NVIDIA GPUs is connected via NVLink.

    Parses the GPU-to-GPU matrix printed by ``nvidia-smi topo -m``: the first
    line containing ``GPU0`` is taken as the column header, and the rows that
    follow it are inspected. A cell counts as NVLink when it contains ``NV``;
    the ``X`` self-cell is ignored.

    Args:
        custom_print: Callable that receives each message string.
        topo_output: Optional pre-captured text of ``nvidia-smi topo -m``.
            When ``None`` (the default) the command is executed.
    """
    if topo_output is None:
        result = subprocess.run(["nvidia-smi", "topo", "-m"], stdout=subprocess.PIPE, text=True)
        if result.returncode != 0:
            custom_print("Failed to run nvidia-smi")
            return
        topo_output = result.stdout

    # The header row can be wrapped in ANSI styling codes (ESC[4m ... ESC[0m), which
    # would glue an escape prefix onto "GPU0" and make it fail the GPU<n> match below,
    # under-counting the GPUs by one. Strip the codes before parsing.
    topo_output = re.sub(r"\x1b\[[0-9;]*[A-Za-z]", "", topo_output)

    lines = topo_output.strip().split("\n")
    start_index = next((i for i, line in enumerate(lines) if "GPU0" in line), None)
    if start_index is None:
        custom_print("Failed to parse nvidia-smi output")
        return

    # Exact GPU<n> match so that the trailing "GPU NUMA ID" header is not counted.
    gpu_regex = re.compile(r"^GPU\d+$")
    gpu_count = sum(1 for header in lines[start_index].split() if gpu_regex.match(header))

    first_row = start_index + 1
    all_nvlink = all(
        "NV" in conn
        for line in lines[first_row : first_row + gpu_count]
        # Column 0 is the row label; the next gpu_count columns are the link types.
        for conn in line.split()[1 : 1 + gpu_count]
        if conn != "X"
    )

    if all_nvlink:
        custom_print("All GPUs are fully connected via NVLink.")
    else:
        custom_print(
            "Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
            "It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
        )


def _check_amd_connectivity(custom_print, topo_output=None):
    """Report whether every pair of AMD GPUs is connected via XGMI.

    Parses the link-type matrix printed by ``rocm-smi --showtopotype``. The
    column header is the first line made up solely of ``GPU<n>`` tokens; the
    rows directly below it (one per GPU) hold the link types. A cell is
    accepted when it is ``XGMI`` or ``0`` (the self-cell).

    Args:
        custom_print: Callable that receives each message string.
        topo_output: Optional pre-captured text of ``rocm-smi --showtopotype``.
            When ``None`` (the default) the command is executed.
    """
    if topo_output is None:
        result = subprocess.run(["rocm-smi", "--showtopotype"], stdout=subprocess.PIPE, text=True)
        if result.returncode != 0:
            custom_print("Failed to run rocm-smi")
            return
        topo_output = result.stdout

    gpu_regex = re.compile(r"^GPU\d+$")
    lines = topo_output.strip().split("\n")

    # Find the header by content rather than by "first line starting with GPU0":
    # both the header and the first data row start with GPU0, so a prefix match
    # picks the header as a data row and derives a GPU count of zero from the
    # banner line above it.
    header_index = None
    headers = []
    for index, line in enumerate(lines):
        tokens = line.split()
        if tokens and all(gpu_regex.match(token) for token in tokens):
            header_index, headers = index, tokens
            break
    if header_index is None:
        custom_print("Failed to parse rocm-smi output (no GPU headers found)")
        return

    gpu_count = len(headers)
    first_row = header_index + 1
    gpu_lines = [
        line.strip()
        for line in lines[first_row : first_row + gpu_count]
        if re.match(r"^\s*GPU\d+", line)
    ]
    if len(gpu_lines) != gpu_count:
        custom_print("Mismatch in GPU count when parsing rocm-smi output")
        return

    all_xgmi = all(
        conn in ("XGMI", "0")
        for line in gpu_lines
        # Column 0 is the row label; the next gpu_count columns are the link types.
        for conn in line.split()[1 : 1 + gpu_count]
    )

    if all_xgmi:
        custom_print("All GPUs are fully connected via XGMI.")
    else:
        custom_print(
            "Warning: Not all GPUs are fully connected via XGMI. Some GPUs are connected via slower interfaces. "
            "It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
        )


def fix_and_load_json(s):
Expand Down
74 changes: 74 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,28 @@ def test_file_size_above_limit_on_gpu():
assert size == 4_600_000_000


@pytest.fixture
def mock_cuda_is_available_true(monkeypatch):
    """Patch ``torch.cuda.is_available`` so that it returns ``True``."""

    def _cuda_available():
        return True

    monkeypatch.setattr(torch.cuda, "is_available", _cuda_available)


@pytest.fixture
def mock_nvidia_device_properties(monkeypatch):
    """Patch ``torch.cuda.get_device_properties`` to describe an NVIDIA GPU."""
    # spec=["name"] restricts the mock to the single attribute set below.
    fake_properties = mock.MagicMock(name="GPU Device", spec=["name"])
    fake_properties.name = "NVIDIA RTX A6000"

    def _get_device_properties(idx):
        return fake_properties

    monkeypatch.setattr(torch.cuda, "get_device_properties", _get_device_properties)


@pytest.fixture
def mock_amd_device_properties(monkeypatch):
    """Patch ``torch.cuda.get_device_properties`` to describe an AMD GPU."""
    # spec=["name"] restricts the mock to the single attribute set below.
    fake_properties = mock.MagicMock(name="GPU Device", spec=["name"])
    fake_properties.name = "AMD Instinct MI250X"

    def _get_device_properties(idx):
        return fake_properties

    monkeypatch.setattr(torch.cuda, "get_device_properties", _get_device_properties)


@pytest.fixture
def all_nvlink_connected_output():
return mock.MagicMock(stdout=""" GPU0 GPU1 GPU2 GPU3
Expand Down Expand Up @@ -593,6 +615,58 @@ def test_nvlink_all_gpu_connected_but_other_connected_output(mock_run, nvlink_al
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")


@pytest.fixture
def rocm_smi_xgmi_output_multi_gpu():
    """Fake result of ``rocm-smi --showtopotype`` on ROCm 6.0.3+.

    The captured matrix lists eight GPUs with every pair linked via XGMI and
    the command reporting a zero return code.
    """
    topology = """
=============================== ROCm System Management Interface ============================
=============================== Link Type between two GPUs ===============================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 XGMI XGMI XGMI XGMI XGMI XGMI XGMI
GPU1 XGMI 0 XGMI XGMI XGMI XGMI XGMI XGMI
GPU2 XGMI XGMI 0 XGMI XGMI XGMI XGMI XGMI
GPU3 XGMI XGMI XGMI 0 XGMI XGMI XGMI XGMI
GPU4 XGMI XGMI XGMI XGMI 0 XGMI XGMI XGMI
GPU5 XGMI XGMI XGMI XGMI XGMI 0 XGMI XGMI
GPU6 XGMI XGMI XGMI XGMI XGMI XGMI 0 XGMI
GPU7 XGMI XGMI XGMI XGMI XGMI XGMI XGMI 0
================================== End of ROCm SMI Log ===================================
"""
    return mock.MagicMock(stdout=topology, returncode=0)


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_fully_connected_when_amd_all_xgmi_8_gpus(
    mock_run, rocm_smi_xgmi_output_multi_gpu, mock_cuda_is_available_true, mock_amd_device_properties
):
    """An AMD machine whose rocm-smi matrix is all XGMI is reported as fully connected."""
    mock_run.return_value = rocm_smi_xgmi_output_multi_gpu
    expected_message = "All GPUs are fully connected via XGMI."

    with mock.patch("builtins.print") as printed:
        check_nvlink_connectivity()

    printed.assert_any_call(expected_message)


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_no_gpus_when_no_gpus(mock_run, monkeypatch):
    """When CUDA is unavailable the check prints that no GPUs are available."""

    def _cuda_unavailable():
        return False

    monkeypatch.setattr(torch.cuda, "is_available", _cuda_unavailable)

    with mock.patch("builtins.print") as printed:
        check_nvlink_connectivity()

    printed.assert_any_call("No GPUs available")


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_unrecognized_vendor_when_unrecognized_vendor(
    mock_run, monkeypatch, mock_cuda_is_available_true
):
    """A device name matching neither NVIDIA nor AMD is reported as an unrecognized vendor."""
    fake_properties = mock.MagicMock(name="GPU Device", spec=["name"])
    fake_properties.name = "GARAGE DIY HYPERSCALER GPU"

    def _get_device_properties(idx):
        return fake_properties

    monkeypatch.setattr(torch.cuda, "get_device_properties", _get_device_properties)

    with mock.patch("builtins.print") as printed:
        check_nvlink_connectivity()

    printed.assert_any_call("Unrecognized GPU vendor: GARAGE DIY HYPERSCALER GPU")


def test_fix_and_load_json():
# Test 1: Invalid JSON string with a trailing comma
invalid_json_trailing_comma = '''
Expand Down

0 comments on commit e7da8e6

Please sign in to comment.