@@ -394,6 +394,18 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
394
394
}
395
395
}
396
396
397
+ // Save the solver history
398
+ vector<shared_ptr<Blob<Dtype> > > history_copies;
399
+ const vector<shared_ptr<Blob<Dtype> > >& orig_history = solver_->history ();
400
+ history_copies.resize (orig_history.size ());
401
+ for (int i = 0 ; i < orig_history.size (); ++i) {
402
+ history_copies[i].reset (new Blob<Dtype>());
403
+ const bool kReshape = true ;
404
+ for (int copy_diff = false ; copy_diff <= true ; ++copy_diff) {
405
+ history_copies[i]->CopyFrom (*orig_history[i], copy_diff, kReshape );
406
+ }
407
+ }
408
+
397
409
// Run the solver for num_iters iterations and snapshot.
398
410
snapshot = true ;
399
411
string snapshot_name = RunLeastSquaresSolver (learning_rate, weight_decay,
@@ -414,6 +426,17 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
414
426
<< " param " << i << " diff differed at dim " << j;
415
427
}
416
428
}
429
+
430
+ // Check that history now matches.
431
+ const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history ();
432
+ for (int i = 0 ; i < history.size (); ++i) {
433
+ for (int j = 0 ; j < history[i]->count (); ++j) {
434
+ EXPECT_EQ (history_copies[i]->cpu_data ()[j], history[i]->cpu_data ()[j])
435
+ << " history blob " << i << " data differed at dim " << j;
436
+ EXPECT_EQ (history_copies[i]->cpu_diff ()[j], history[i]->cpu_diff ()[j])
437
+ << " history blob " << i << " diff differed at dim " << j;
438
+ }
439
+ }
417
440
}
418
441
};
419
442
0 commit comments