diff --git a/tests/test_utils.py b/tests/test_utils.py
index 26e7de02a3..2fc0217013 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -497,11 +497,14 @@ def all_nvlink_connected_output():
 
 
 @mock.patch("subprocess.run")
-def test_all_nvlink_connected(mock_run, all_nvlink_connected_output):
+def test_all_nvlink_connected(
+    mock_run, all_nvlink_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
+):
+    """Expect the fully-connected message when subprocess.run yields all_nvlink_connected_output."""
     mock_run.return_value = all_nvlink_connected_output
     with mock.patch("builtins.print") as mock_print:
         check_nvlink_connectivity()
-    mock_print.assert_any_call("All GPUs are fully connected via NVLink.")
+        mock_print.assert_any_call("All GPUs are fully connected via NVLink.")
 
 
 @pytest.fixture
@@ -519,14 +522,17 @@ def nvlink_partially_connected_output():
 
 
 @mock.patch("subprocess.run")
-def test_nvlink_partially_connected_output(mock_run, nvlink_partially_connected_output):
+def test_nvlink_partially_connected_output(
+    mock_run, nvlink_partially_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
+):
+    """Expect the slow-interconnect warning when subprocess.run yields nvlink_partially_connected_output."""
     mock_run.return_value = nvlink_partially_connected_output
     with mock.patch("builtins.print") as mock_print:
         check_nvlink_connectivity()
-    mock_print.assert_any_call(
-        "Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
-        "It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
-    )
+        mock_print.assert_any_call(
+            "Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
+            "It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
+        )
 
 
 @pytest.fixture
@@ -549,14 +555,17 @@ def nvlink_not_connected_output():
 
 
 @mock.patch("subprocess.run")
-def test_nvlink_not_connected_output(mock_run, nvlink_not_connected_output):
+def test_nvlink_not_connected_output(
+    mock_run, nvlink_not_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
+):
+    """Expect the slow-interconnect warning when subprocess.run yields nvlink_not_connected_output."""
     mock_run.return_value = nvlink_not_connected_output
     with mock.patch("builtins.print") as mock_print:
         check_nvlink_connectivity()
-    mock_print.assert_any_call(
-        "Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
-        "It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
-    )
+        mock_print.assert_any_call(
+            "Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
+            "It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
+        )
 
 
 @pytest.fixture
@@ -608,8 +617,46 @@ def nvlink_all_gpu_connected_but_other_connected_output():
 
 
 @mock.patch("subprocess.run")
-def test_nvlink_all_gpu_connected_but_other_connected_output(mock_run, nvlink_all_gpu_connected_but_other_connected_output):
+def test_nvlink_all_gpu_connected_but_other_connected_output(
+    mock_run,
+    nvlink_all_gpu_connected_but_other_connected_output,
+    mock_cuda_is_available_true,
+    mock_nvidia_device_properties,
+):
+    """Expect the fully-connected message for nvlink_all_gpu_connected_but_other_connected_output."""
     mock_run.return_value = nvlink_all_gpu_connected_but_other_connected_output
+    with mock.patch("builtins.print") as mock_print:
+        check_nvlink_connectivity()
+        mock_print.assert_any_call("All GPUs are fully connected via NVLink.")
+
+
+@pytest.fixture
+def nvidia_smi_nvlink_output_dual_gpu_no_numa():
+    """Mock subprocess result (returncode 0): two GPUs joined by NV1, GPU NUMA ID reported as N/A."""
+    return mock.MagicMock(
+        stdout="""
+        GPU0 GPU1 CPU Affinity NUMA Affinity GPU NUMA ID
+GPU0 X NV1 0-15 0 N/A
+GPU1 NV1 X 0-15 0 N/A
+Legend:
+  X = Self
+  SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
+  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
+  PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
+  PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
+  PIX = Connection traversing at most a single PCIe bridge
+  NV# = Connection traversing a bonded set of # NVLinks
+        """,
+        returncode=0,
+    )
+
+
+@mock.patch("subprocess.run")
+def test_check_nvlink_connectivity__returns_fully_connected_when_nvidia_all_nvlink_two_gpus(
+    mock_run, nvidia_smi_nvlink_output_dual_gpu_no_numa, mock_cuda_is_available_true, mock_nvidia_device_properties
+):
+    """Expect the fully-connected message for the two-GPU, all-NVLink topology fixture."""
+    mock_run.return_value = nvidia_smi_nvlink_output_dual_gpu_no_numa
     with mock.patch("builtins.print") as mock_print:
         check_nvlink_connectivity()
         mock_print.assert_any_call("All GPUs are fully connected via NVLink.")