 import torch
 from torch._six import inf
 import torch.optim as optim
-import torch.optim._multi_tensor as optim_mt
 import torch.nn.functional as F
 from torch.optim import SGD
 from torch.autograd import Variable
@@ -250,55 +249,49 @@ def _build_params_dict_single(self, weight, bias, **kwargs):
         return [dict(params=bias, **kwargs)]
 
     def test_sgd(self):
-        for optimizer in [optim.SGD, optim_mt.SGD]:
-            self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-3)
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict(weight, bias, lr=1e-2),
-                    lr=1e-3)
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict_single(weight, bias, lr=1e-2),
-                    lr=1e-3)
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict_single(weight, bias, lr=1e-2))
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-3),
-                [lambda opt: StepLR(opt, gamma=0.9, step_size=10)]
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-3),
-                [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
-                 lambda opt: ReduceLROnPlateau(opt)]
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-3),
-                [lambda opt: StepLR(opt, gamma=0.99, step_size=10),
-                 lambda opt: ExponentialLR(opt, gamma=0.99),
-                 lambda opt: ReduceLROnPlateau(opt)]
-            )
-            with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"):
-                optimizer(None, lr=1e-2, momentum=-0.5)
+        self._test_basic_cases(
+            lambda weight, bias: optim.SGD([weight, bias], lr=1e-3)
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.SGD(
+                self._build_params_dict(weight, bias, lr=1e-2),
+                lr=1e-3)
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.SGD(
+                self._build_params_dict_single(weight, bias, lr=1e-2),
+                lr=1e-3)
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.SGD(
+                self._build_params_dict_single(weight, bias, lr=1e-2))
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.SGD([weight, bias], lr=1e-3),
+            [lambda opt: StepLR(opt, gamma=0.9, step_size=10)]
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.SGD([weight, bias], lr=1e-3),
+            [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+             lambda opt: ReduceLROnPlateau(opt)]
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.SGD([weight, bias], lr=1e-3),
+            [lambda opt: StepLR(opt, gamma=0.99, step_size=10),
+             lambda opt: ExponentialLR(opt, gamma=0.99),
+             lambda opt: ReduceLROnPlateau(opt)]
+        )
+        with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"):
+            optim.SGD(None, lr=1e-2, momentum=-0.5)
 
     def test_sgd_sparse(self):
-        for optimizer in [optim.SGD, optim_mt.SGD]:
-            self._test_rosenbrock_sparse(
-                lambda params: optimizer(params, lr=5e-3)
-            )
-            self._test_rosenbrock_sparse(
-                lambda params: optimizer(params, lr=0.005),
-                [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)]
-            )
-
-    def test_multi_tensor_optimizers(self):
-        if not torch.cuda.is_available():
-            return
+        self._test_rosenbrock_sparse(
+            lambda params: optim.SGD(params, lr=5e-3)
+        )
+        self._test_rosenbrock_sparse(
+            lambda params: optim.SGD(params, lr=0.005),
+            [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)]
+        )
 
     def test_adam(self):
         self._test_basic_cases(
@@ -344,22 +337,21 @@ def test_adam(self):
         with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
             optim.Adam(None, lr=1e-2, betas=(1.0, 0.0))
 
-        with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
+        with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
             optim.Adam(None, lr=1e-2, weight_decay=-1)
 
     def test_adamw(self):
-        for optimizer in [optim.AdamW, optim_mt.AdamW]:
-            self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-3)
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict(weight, bias, lr=1e-2),
-                    lr=1e-3)
-            )
-
-            with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
-                optimizer(None, lr=1e-2, weight_decay=-1)
+        self._test_basic_cases(
+            lambda weight, bias: optim.AdamW([weight, bias], lr=1e-3)
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.AdamW(
+                self._build_params_dict(weight, bias, lr=1e-2),
+                lr=1e-3)
+        )
+
+        with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
+            optim.AdamW(None, lr=1e-2, weight_decay=-1)
 
     def test_sparse_adam(self):
         self._test_rosenbrock_sparse(
@@ -377,22 +369,21 @@ def test_sparse_adam(self):
     # ROCm precision is too low to pass this test
     @skipIfRocm
     def test_adadelta(self):
-        for optimizer in [optim.Adadelta, optim_mt.Adadelta]:
-            self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias])
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict(weight, bias, rho=0.95))
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict(weight, bias, rho=0.95)),
-                [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
-                 lambda opt: ReduceLROnPlateau(opt)]
-            )
-            with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"):
-                optimizer(None, lr=1e-2, rho=1.1)
+        self._test_basic_cases(
+            lambda weight, bias: optim.Adadelta([weight, bias])
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.Adadelta(
+                self._build_params_dict(weight, bias, rho=0.95))
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.Adadelta(
+                self._build_params_dict(weight, bias, rho=0.95)),
+            [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+             lambda opt: ReduceLROnPlateau(opt)]
+        )
+        with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"):
+            optim.Adadelta(None, lr=1e-2, rho=1.1)
 
     def test_adagrad(self):
         self._test_basic_cases(
@@ -434,71 +425,52 @@ def test_adagrad_sparse(self):
         )
 
     def test_adamax(self):
-        for optimizer in [optim.Adamax, optim_mt.Adamax]:
-            self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-1)
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict(weight, bias, lr=1e-2),
-                    lr=1e-1)
-            )
-            with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 1: 1.0"):
-                optimizer(None, lr=1e-2, betas=(0.0, 1.0))
+        self._test_basic_cases(
+            lambda weight, bias: optim.Adamax([weight, bias], lr=1e-1)
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.Adamax(
+                self._build_params_dict(weight, bias, lr=1e-2),
+                lr=1e-1)
+        )
+        with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 1: 1.0"):
+            optim.Adamax(None, lr=1e-2, betas=(0.0, 1.0))
 
     def test_rmsprop(self):
-        for optimizer in [optim.RMSprop, optim_mt.RMSprop]:
-            self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-2)
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict(weight, bias, lr=1e-3),
-                    lr=1e-2)
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict(weight, bias, lr=1e-3),
-                    lr=1e-2, centered=True)
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict(weight, bias, lr=1e-3),
-                    lr=1e-2, centered=True, momentum=0.1)
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict(weight, bias, lr=1e-3),
-                    lr=1e-2, momentum=0.1)
-            )
-            with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"):
-                optimizer(None, lr=1e-2, momentum=-1.0)
+        self._test_basic_cases(
+            lambda weight, bias: optim.RMSprop([weight, bias], lr=1e-2)
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.RMSprop(
+                self._build_params_dict(weight, bias, lr=1e-3),
+                lr=1e-2)
+        )
+        with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"):
+            optim.RMSprop(None, lr=1e-2, momentum=-1.0)
 
     def test_asgd(self):
-        for optimizer in [optim.ASGD, optim_mt.ASGD]:
-            self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-3, t0=100)
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict(weight, bias, lr=1e-2),
-                    lr=1e-3, t0=100)
-            )
-            with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"):
-                optimizer(None, lr=1e-2, weight_decay=-0.5)
+        self._test_basic_cases(
+            lambda weight, bias: optim.ASGD([weight, bias], lr=1e-3, t0=100)
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.ASGD(
+                self._build_params_dict(weight, bias, lr=1e-2),
+                lr=1e-3, t0=100)
+        )
+        with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"):
+            optim.ASGD(None, lr=1e-2, weight_decay=-0.5)
 
     def test_rprop(self):
-        for optimizer in [optim.Rprop, optim_mt.Rprop]:
-            self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-3)
-            )
-            self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict(weight, bias, lr=1e-2),
-                    lr=1e-3)
-            )
-            with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"):
-                optimizer(None, lr=1e-2, etas=(1.0, 0.5))
+        self._test_basic_cases(
+            lambda weight, bias: optim.Rprop([weight, bias], lr=1e-3)
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.Rprop(
+                self._build_params_dict(weight, bias, lr=1e-2),
+                lr=1e-3)
+        )
+        with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"):
+            optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5))
 
     def test_lbfgs(self):
         self._test_basic_cases(