Skip to content

Commit fc77ef3

Browse files
committed
Merge pull request #2836 from erictzeng/hdf5_snapshot
Snapshot model weights/solver state to HDF5 files
2 parents 4a667d1 + c9b333e commit fc77ef3

20 files changed

+662
-151
lines changed

.gitignore

+2
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,9 @@ Makefile.config
6161
data/*
6262
models/*
6363
*.caffemodel
64+
*.caffemodel.h5
6465
*.solverstate
66+
*.solverstate.h5
6567
*.binaryproto
6668
*leveldb
6769
*lmdb

examples/cifar10/train_full.sh

+2 -2
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,9 @@ $TOOLS/caffe train \
88
# reduce learning rate by factor of 10
99
$TOOLS/caffe train \
1010
--solver=examples/cifar10/cifar10_full_solver_lr1.prototxt \
11-
--snapshot=examples/cifar10/cifar10_full_iter_60000.solverstate
11+
--snapshot=examples/cifar10/cifar10_full_iter_60000.solverstate.h5
1212

1313
# reduce learning rate by factor of 10
1414
$TOOLS/caffe train \
1515
--solver=examples/cifar10/cifar10_full_solver_lr2.prototxt \
16-
--snapshot=examples/cifar10/cifar10_full_iter_65000.solverstate
16+
--snapshot=examples/cifar10/cifar10_full_iter_65000.solverstate.h5

examples/cifar10/train_quick.sh

+1 -1
Original file line number | Diff line number | Diff line change
@@ -8,4 +8,4 @@ $TOOLS/caffe train \
88
# reduce learning rate by factor of 10 after 8 epochs
99
$TOOLS/caffe train \
1010
--solver=examples/cifar10/cifar10_quick_solver_lr1.prototxt \
11-
--snapshot=examples/cifar10/cifar10_quick_iter_4000.solverstate
11+
--snapshot=examples/cifar10/cifar10_quick_iter_4000.solverstate.h5

examples/imagenet/resume_training.sh

+1 -1
Original file line number | Diff line number | Diff line change
@@ -2,4 +2,4 @@
22

33
./build/tools/caffe train \
44
--solver=models/bvlc_reference_caffenet/solver.prototxt \
5-
--snapshot=models/bvlc_reference_caffenet/caffenet_train_10000.solverstate
5+
--snapshot=models/bvlc_reference_caffenet/caffenet_train_10000.solverstate.h5

include/caffe/blob.hpp

+1 -1
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
#include "caffe/syncedmem.hpp"
1111
#include "caffe/util/math_functions.hpp"
1212

13-
const int kMaxBlobAxes = INT_MAX;
13+
const int kMaxBlobAxes = 32;
1414

1515
namespace caffe {
1616

include/caffe/net.hpp

+4
Original file line number | Diff line number | Diff line change
@@ -98,8 +98,12 @@ class Net {
9898
*/
9999
void CopyTrainedLayersFrom(const NetParameter& param);
100100
void CopyTrainedLayersFrom(const string trained_filename);
101+
void CopyTrainedLayersFromBinaryProto(const string trained_filename);
102+
void CopyTrainedLayersFromHDF5(const string trained_filename);
101103
/// @brief Writes the net to a proto.
102104
void ToProto(NetParameter* param, bool write_diff = false) const;
105+
/// @brief Writes the net to an HDF5 file.
106+
void ToHDF5(const string& filename, bool write_diff = false) const;
103107

104108
/// @brief returns the network name.
105109
inline const string& name() const { return name_; }

include/caffe/solver.hpp

+14 -7
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,9 @@ class Solver {
2727
virtual void Solve(const char* resume_file = NULL);
2828
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
2929
void Step(int iters);
30-
// The Restore function implements how one should restore the solver to a
31-
// previously snapshotted state. You should implement the RestoreSolverState()
32-
// function that restores the state from a SolverState protocol buffer.
30+
// The Restore method simply dispatches to one of the
31+
// RestoreSolverStateFrom___ protected methods. You should implement these
32+
// methods to restore the state from the appropriate snapshot type.
3333
void Restore(const char* resume_file);
3434
virtual ~Solver() {}
3535
inline shared_ptr<Net<Dtype> > net() { return net_; }
@@ -46,11 +46,15 @@ class Solver {
4646
// function that produces a SolverState protocol buffer that needs to be
4747
// written to disk together with the learned net.
4848
void Snapshot();
49+
string SnapshotFilename(const string extension);
50+
string SnapshotToBinaryProto();
51+
string SnapshotToHDF5();
4952
// The test routine
5053
void TestAll();
5154
void Test(const int test_net_id = 0);
52-
virtual void SnapshotSolverState(SolverState* state) = 0;
53-
virtual void RestoreSolverState(const SolverState& state) = 0;
55+
virtual void SnapshotSolverState(const string& model_filename) = 0;
56+
virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
57+
virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
5458
void DisplayOutputBlobs(const int net_id);
5559

5660
SolverParameter param_;
@@ -85,8 +89,11 @@ class SGDSolver : public Solver<Dtype> {
8589
virtual void Regularize(int param_id);
8690
virtual void ComputeUpdateValue(int param_id, Dtype rate);
8791
virtual void ClipGradients();
88-
virtual void SnapshotSolverState(SolverState * state);
89-
virtual void RestoreSolverState(const SolverState& state);
92+
virtual void SnapshotSolverState(const string& model_filename);
93+
virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
94+
virtual void SnapshotSolverStateToHDF5(const string& model_filename);
95+
virtual void RestoreSolverStateFromHDF5(const string& state_file);
96+
virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
9097
// history maintains the historical momentum data.
9198
// update maintains update related data and is not needed in snapshots.
9299
// temp maintains other information that might be needed in computation

include/caffe/util/hdf5.hpp

+39
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,39 @@
1+
#ifndef CAFFE_UTIL_HDF5_H_
2+
#define CAFFE_UTIL_HDF5_H_
3+
4+
#include <string>
5+
6+
#include "hdf5.h"
7+
#include "hdf5_hl.h"
8+
9+
#include "caffe/blob.hpp"
10+
11+
namespace caffe {
12+
13+
template <typename Dtype>
14+
void hdf5_load_nd_dataset_helper(
15+
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
16+
Blob<Dtype>* blob);
17+
18+
template <typename Dtype>
19+
void hdf5_load_nd_dataset(
20+
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
21+
Blob<Dtype>* blob);
22+
23+
template <typename Dtype>
24+
void hdf5_save_nd_dataset(
25+
const hid_t file_id, const string& dataset_name, const Blob<Dtype>& blob,
26+
bool write_diff = false);
27+
28+
int hdf5_load_int(hid_t loc_id, const string& dataset_name);
29+
void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i);
30+
string hdf5_load_string(hid_t loc_id, const string& dataset_name);
31+
void hdf5_save_string(hid_t loc_id, const string& dataset_name,
32+
const string& s);
33+
34+
int hdf5_get_num_links(hid_t loc_id);
35+
string hdf5_get_name_by_idx(hid_t loc_id, int idx);
36+
37+
} // namespace caffe
38+
39+
#endif // CAFFE_UTIL_HDF5_H_

include/caffe/util/io.hpp

-18
Original file line number | Diff line number | Diff line change
@@ -5,15 +5,11 @@
55
#include <string>
66

77
#include "google/protobuf/message.h"
8-
#include "hdf5.h"
9-
#include "hdf5_hl.h"
108

119
#include "caffe/blob.hpp"
1210
#include "caffe/common.hpp"
1311
#include "caffe/proto/caffe.pb.h"
1412

15-
#define HDF5_NUM_DIMS 4
16-
1713
namespace caffe {
1814

1915
using ::google::protobuf::Message;
@@ -140,20 +136,6 @@ cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color);
140136

141137
void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);
142138

143-
template <typename Dtype>
144-
void hdf5_load_nd_dataset_helper(
145-
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
146-
Blob<Dtype>* blob);
147-
148-
template <typename Dtype>
149-
void hdf5_load_nd_dataset(
150-
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
151-
Blob<Dtype>* blob);
152-
153-
template <typename Dtype>
154-
void hdf5_save_nd_dataset(
155-
const hid_t file_id, const string& dataset_name, const Blob<Dtype>& blob);
156-
157139
} // namespace caffe
158140

159141
#endif // CAFFE_UTIL_IO_H_

src/caffe/blob.cpp

+42 -7
Original file line number | Diff line number | Diff line change
@@ -456,31 +456,66 @@ void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) {
456456
}
457457
// copy data
458458
Dtype* data_vec = mutable_cpu_data();
459-
for (int i = 0; i < count_; ++i) {
460-
data_vec[i] = proto.data(i);
459+
if (proto.double_data_size() > 0) {
460+
CHECK_EQ(count_, proto.double_data_size());
461+
for (int i = 0; i < count_; ++i) {
462+
data_vec[i] = proto.double_data(i);
463+
}
464+
} else {
465+
CHECK_EQ(count_, proto.data_size());
466+
for (int i = 0; i < count_; ++i) {
467+
data_vec[i] = proto.data(i);
468+
}
461469
}
462-
if (proto.diff_size() > 0) {
470+
if (proto.double_diff_size() > 0) {
471+
CHECK_EQ(count_, proto.double_diff_size());
472+
Dtype* diff_vec = mutable_cpu_diff();
473+
for (int i = 0; i < count_; ++i) {
474+
diff_vec[i] = proto.double_diff(i);
475+
}
476+
} else if (proto.diff_size() > 0) {
477+
CHECK_EQ(count_, proto.diff_size());
463478
Dtype* diff_vec = mutable_cpu_diff();
464479
for (int i = 0; i < count_; ++i) {
465480
diff_vec[i] = proto.diff(i);
466481
}
467482
}
468483
}
469484

470-
template <typename Dtype>
471-
void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const {
485+
template <>
486+
void Blob<double>::ToProto(BlobProto* proto, bool write_diff) const {
487+
proto->clear_shape();
488+
for (int i = 0; i < shape_.size(); ++i) {
489+
proto->mutable_shape()->add_dim(shape_[i]);
490+
}
491+
proto->clear_double_data();
492+
proto->clear_double_diff();
493+
const double* data_vec = cpu_data();
494+
for (int i = 0; i < count_; ++i) {
495+
proto->add_double_data(data_vec[i]);
496+
}
497+
if (write_diff) {
498+
const double* diff_vec = cpu_diff();
499+
for (int i = 0; i < count_; ++i) {
500+
proto->add_double_diff(diff_vec[i]);
501+
}
502+
}
503+
}
504+
505+
template <>
506+
void Blob<float>::ToProto(BlobProto* proto, bool write_diff) const {
472507
proto->clear_shape();
473508
for (int i = 0; i < shape_.size(); ++i) {
474509
proto->mutable_shape()->add_dim(shape_[i]);
475510
}
476511
proto->clear_data();
477512
proto->clear_diff();
478-
const Dtype* data_vec = cpu_data();
513+
const float* data_vec = cpu_data();
479514
for (int i = 0; i < count_; ++i) {
480515
proto->add_data(data_vec[i]);
481516
}
482517
if (write_diff) {
483-
const Dtype* diff_vec = cpu_diff();
518+
const float* diff_vec = cpu_diff();
484519
for (int i = 0; i < count_; ++i) {
485520
proto->add_diff(diff_vec[i]);
486521
}

src/caffe/layers/hdf5_data_layer.cpp

+1 -1
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616

1717
#include "caffe/data_layers.hpp"
1818
#include "caffe/layer.hpp"
19-
#include "caffe/util/io.hpp"
19+
#include "caffe/util/hdf5.hpp"
2020

2121
namespace caffe {
2222

src/caffe/layers/hdf5_output_layer.cpp

+1 -1
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
#include "caffe/blob.hpp"
77
#include "caffe/common.hpp"
88
#include "caffe/layer.hpp"
9-
#include "caffe/util/io.hpp"
9+
#include "caffe/util/hdf5.hpp"
1010
#include "caffe/vision_layers.hpp"
1111

1212
namespace caffe {

src/caffe/layers/hdf5_output_layer.cu

-1
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
#include "caffe/blob.hpp"
77
#include "caffe/common.hpp"
88
#include "caffe/layer.hpp"
9-
#include "caffe/util/io.hpp"
109
#include "caffe/vision_layers.hpp"
1110

1211
namespace caffe {

0 commit comments

Comments
 (0)