Skip to content

Commit c31fb4f

Browse files
committed
Add explicit history check to SnapshotTest
1 parent dcb72d7 commit c31fb4f

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

src/caffe/test/test_gradient_based_solver.cpp

+23
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,18 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
394394
}
395395
}
396396

397+
// Save the solver history
398+
vector<shared_ptr<Blob<Dtype> > > history_copies;
399+
const vector<shared_ptr<Blob<Dtype> > >& orig_history = solver_->history();
400+
history_copies.resize(orig_history.size());
401+
for (int i = 0; i < orig_history.size(); ++i) {
402+
history_copies[i].reset(new Blob<Dtype>());
403+
const bool kReshape = true;
404+
for (int copy_diff = false; copy_diff <= true; ++copy_diff) {
405+
history_copies[i]->CopyFrom(*orig_history[i], copy_diff, kReshape);
406+
}
407+
}
408+
397409
// Run the solver for num_iters iterations and snapshot.
398410
snapshot = true;
399411
string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay,
@@ -414,6 +426,17 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
414426
<< "param " << i << " diff differed at dim " << j;
415427
}
416428
}
429+
430+
// Check that history now matches.
431+
const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
432+
for (int i = 0; i < history.size(); ++i) {
433+
for (int j = 0; j < history[i]->count(); ++j) {
434+
EXPECT_EQ(history_copies[i]->cpu_data()[j], history[i]->cpu_data()[j])
435+
<< "history blob " << i << " data differed at dim " << j;
436+
EXPECT_EQ(history_copies[i]->cpu_diff()[j], history[i]->cpu_diff()[j])
437+
<< "history blob " << i << " diff differed at dim " << j;
438+
}
439+
}
417440
}
418441
};
419442

0 commit comments

Comments
 (0)