29
29
from torch .testing ._internal .common_distributed import MultiProcessTestCase , \
30
30
requires_gloo , requires_nccl , requires_nccl_version , \
31
31
skip_if_not_multigpu , skip_if_lt_x_gpu , get_timeout , skip_if_rocm , \
32
- simple_sparse_reduce_tests
32
+ simple_sparse_reduce_tests , skip_if_win32 , create_device
33
33
34
34
from torch .testing ._internal .common_utils import TestCase , load_tests , run_tests , \
35
35
retry_on_connect_failures , ADDRESS_IN_USE , CONNECT_TIMEOUT , TEST_WITH_TSAN
@@ -255,6 +255,7 @@ def create_tcp_store(addr):
255
255
raise RuntimeError ("Unable to find free port (tried %s)" % ", " .join (ports ))
256
256
257
257
258
+ @skip_if_win32 ()
258
259
class TCPStoreTest (TestCase , StoreTestBase ):
259
260
def _create_store (self ):
260
261
store = create_tcp_store ('localhost' )
@@ -273,6 +274,7 @@ def test_address_already_in_use(self):
273
274
store2 = c10d .TCPStore (addr , port , 1 , True ) # noqa: F841
274
275
275
276
277
+ @skip_if_win32 ()
276
278
class PrefixTCPStoreTest (TestCase , StoreTestBase ):
277
279
def setUp (self ):
278
280
super (PrefixTCPStoreTest , self ).setUp ()
@@ -329,6 +331,7 @@ def test_unknown_handler(self):
329
331
c10d .rendezvous ('invalid://' )
330
332
331
333
334
+ @skip_if_win32 ()
332
335
class RendezvousEnvTest (TestCase ):
333
336
@retry_on_connect_failures
334
337
def test_common_errors (self ):
@@ -455,7 +458,7 @@ def test_common_errors(self):
455
458
456
459
def test_nominal (self ):
457
460
with tempfile .NamedTemporaryFile (delete = False ) as file :
458
- url = 'file://%s?world_size=%d' % ( file .name , 2 )
461
+ url = f 'file:/// { file .name . replace ( os . path . sep , "/" ) } ?world_size=2'
459
462
gen0 = c10d .rendezvous (url + "&rank=0" )
460
463
store0 , rank0 , size0 = next (gen0 )
461
464
self .assertEqual (0 , rank0 )
@@ -474,6 +477,7 @@ def test_nominal(self):
474
477
self .assertEqual (b"value1" , store0 .get ("key1" ))
475
478
476
479
480
+ @skip_if_win32 ()
477
481
class RendezvousTCPTest (TestCase ):
478
482
479
483
def create_tcp_url (self ):
@@ -544,9 +548,13 @@ def _test_store_timeout(self, backend, init_method, c2p):
544
548
545
549
def _init_methods (self ):
546
550
f = tempfile .NamedTemporaryFile (delete = False )
547
- yield "file://%s" % f .name
548
- f .close ()
549
- yield "tcp://127.0.0.1:%d" % common .find_free_port ()
551
+ if sys .platform == 'win32' :
552
+ yield "file:///%s" % f .name .replace ("\\ " , "/" )
553
+ f .close ()
554
+ else :
555
+ yield "file://%s" % f .name
556
+ f .close ()
557
+ yield "tcp://127.0.0.1:%d" % common .find_free_port ()
550
558
551
559
def _test_default_store_timeout (self , backend ):
552
560
for init_method in self ._init_methods ():
@@ -584,11 +592,16 @@ def test_default_store_timeout_gloo(self):
584
592
class ProcessGroupGlooTest (MultiProcessTestCase ):
585
593
def setUp (self ):
586
594
super (ProcessGroupGlooTest , self ).setUp ()
587
- self ._fork_processes ()
595
+
596
+ # For Windows platform, Python does not support fork, change it to spawn here.
597
+ if sys .platform == 'win32' :
598
+ self ._spawn_processes ()
599
+ else :
600
+ self ._fork_processes ()
588
601
589
602
def opts (self , threads = 2 ):
590
603
opts = c10d .ProcessGroupGloo .Options ()
591
- opts .devices = [c10d . ProcessGroupGloo . create_device (interface = LOOPBACK )]
604
+ opts .devices = [create_device (interface = LOOPBACK )]
592
605
opts .timeout = 5.0
593
606
opts .threads = threads
594
607
return opts
@@ -598,8 +611,8 @@ def test_multi_device_constructor(self):
598
611
opts = c10d .ProcessGroupGloo .Options ()
599
612
opts .timeout = 5.0
600
613
opts .devices = [
601
- c10d . ProcessGroupGloo . create_device (interface = LOOPBACK ),
602
- c10d . ProcessGroupGloo . create_device (interface = LOOPBACK ),
614
+ create_device (interface = LOOPBACK ),
615
+ create_device (interface = LOOPBACK ),
603
616
]
604
617
pg = c10d .ProcessGroupGloo (store , self .rank , self .world_size , opts )
605
618
@@ -1514,6 +1527,7 @@ def test_barrier_implies_wait(self):
1514
1527
for i , tensor in enumerate (tensors ):
1515
1528
self .assertEqual (torch .full (size , float (i * self .world_size )), tensor )
1516
1529
1530
+ @skip_if_win32 ()
1517
1531
def test_round_robin (self ):
1518
1532
num_process_groups = 2
1519
1533
store = c10d .FileStore (self .file_name , self .world_size )
@@ -1531,6 +1545,7 @@ def test_round_robin(self):
1531
1545
pg .broadcast (tensor , root = 0 ).wait ()
1532
1546
self .assertEqual (torch .full ([100 , 100 ], 0. ), tensor )
1533
1547
1548
+ @skip_if_win32 ()
1534
1549
def test_round_robin_create_destroy (self ):
1535
1550
store = c10d .FileStore (self .file_name , self .world_size )
1536
1551
@@ -1959,7 +1974,10 @@ def forward(self, x):
1959
1974
class DistributedDataParallelTest (MultiProcessTestCase ):
1960
1975
def setUp (self ):
1961
1976
super (DistributedDataParallelTest , self ).setUp ()
1962
- self ._fork_processes ()
1977
+ if sys .platform == 'win32' :
1978
+ self ._spawn_processes ()
1979
+ else :
1980
+ self ._fork_processes ()
1963
1981
1964
1982
def tearDown (self ):
1965
1983
# DistributedDataParallel test doesn't seem to call FileStore destructor
@@ -2068,7 +2086,7 @@ def update_parameters(model):
2068
2086
def _test_gloo_backend (self , devices , device_ids , multi_device = False , gradient_as_bucket_view = False ):
2069
2087
store = c10d .FileStore (self .file_name , self .world_size )
2070
2088
options = c10d .ProcessGroupGloo .Options ()
2071
- options .devices = [c10d . ProcessGroupGloo . create_device (interface = LOOPBACK )]
2089
+ options .devices = [create_device (interface = LOOPBACK )]
2072
2090
process_group = c10d .ProcessGroupGloo (store , self .rank , self .world_size , options )
2073
2091
self ._test_ddp_with_process_group (process_group , devices , device_ids , multi_device , gradient_as_bucket_view )
2074
2092
@@ -3947,7 +3965,10 @@ def test_nccl_timeout(self):
3947
3965
class CommTest (MultiProcessTestCase ):
3948
3966
def setUp (self ):
3949
3967
super (CommTest , self ).setUp ()
3950
- self ._fork_processes ()
3968
+ if sys .platform == 'win32' :
3969
+ self ._spawn_processes ()
3970
+ else :
3971
+ self ._fork_processes ()
3951
3972
3952
3973
def tearDown (self ):
3953
3974
super (CommTest , self ).tearDown ()
@@ -4013,7 +4034,7 @@ def test_broadcast_coalesced_nccl(self):
4013
4034
def test_broadcast_coalesced_gloo_cuda (self ):
4014
4035
store = c10d .FileStore (self .file_name , self .world_size )
4015
4036
options = c10d .ProcessGroupGloo .Options ()
4016
- options .devices = [c10d . ProcessGroupGloo . create_device (interface = LOOPBACK )]
4037
+ options .devices = [create_device (interface = LOOPBACK )]
4017
4038
process_group = c10d .ProcessGroupGloo (store , self .rank , self .world_size , options )
4018
4039
device = torch .device ("cuda:%d" % self .rank )
4019
4040
ranks = list (range (self .world_size ))
@@ -4024,7 +4045,7 @@ def test_broadcast_coalesced_gloo_cuda(self):
4024
4045
def test_broadcast_coalesced_gloo_cpu (self ):
4025
4046
store = c10d .FileStore (self .file_name , self .world_size )
4026
4047
options = c10d .ProcessGroupGloo .Options ()
4027
- options .devices = [c10d . ProcessGroupGloo . create_device (interface = LOOPBACK )]
4048
+ options .devices = [create_device (interface = LOOPBACK )]
4028
4049
process_group = c10d .ProcessGroupGloo (store , self .rank , self .world_size , options )
4029
4050
device = torch .device ("cpu" )
4030
4051
ranks = list (range (self .world_size ))
0 commit comments