Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Oct 9, 2024
1 parent e7da8e6 commit b93ef60
Showing 1 changed file with 54 additions and 13 deletions.
67 changes: 54 additions & 13 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,11 +497,13 @@ def all_nvlink_connected_output():


@mock.patch("subprocess.run")
def test_all_nvlink_connected(mock_run, all_nvlink_connected_output):
def test_all_nvlink_connected(
mock_run, all_nvlink_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
):
mock_run.return_value = all_nvlink_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")


@pytest.fixture
Expand All @@ -519,14 +521,16 @@ def nvlink_partially_connected_output():


@mock.patch("subprocess.run")
def test_nvlink_partially_connected_output(mock_run, nvlink_partially_connected_output):
def test_nvlink_partially_connected_output(
mock_run, nvlink_partially_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
):
mock_run.return_value = nvlink_partially_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)
mock_print.assert_any_call(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)


@pytest.fixture
Expand All @@ -549,14 +553,16 @@ def nvlink_not_connected_output():


@mock.patch("subprocess.run")
def test_nvlink_not_connected_output(mock_run, nvlink_not_connected_output):
def test_nvlink_not_connected_output(
mock_run, nvlink_not_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
):
mock_run.return_value = nvlink_not_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)
mock_print.assert_any_call(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)


@pytest.fixture
Expand Down Expand Up @@ -608,8 +614,43 @@ def nvlink_all_gpu_connected_but_other_connected_output():


@mock.patch("subprocess.run")
def test_nvlink_all_gpu_connected_but_other_connected_output(mock_run, nvlink_all_gpu_connected_but_other_connected_output):
def test_nvlink_all_gpu_connected_but_other_connected_output(
mock_run,
nvlink_all_gpu_connected_but_other_connected_output,
mock_cuda_is_available_true,
mock_nvidia_device_properties,
):
mock_run.return_value = nvlink_all_gpu_connected_but_other_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")


@pytest.fixture
def nvidia_smi_nvlink_output_dual_gpu_no_numa():
return mock.MagicMock(
stdout="""
GPU0 GPU1 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV1 0-15 0 N/A
GPU1 NV1 X 0-15 0 N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
""",
returncode=0,
)


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_fully_connected_when_nvidia_all_nvlink_two_gpus(
mock_run, all_nvlink_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
):
mock_run.return_value = all_nvlink_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")
Expand Down

0 comments on commit b93ef60

Please sign in to comment.