@@ -456,31 +456,66 @@ void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) {
456
456
}
457
457
// copy data
458
458
Dtype* data_vec = mutable_cpu_data ();
459
- for (int i = 0 ; i < count_; ++i) {
460
- data_vec[i] = proto.data (i);
459
+ if (proto.double_data_size () > 0 ) {
460
+ CHECK_EQ (count_, proto.double_data_size ());
461
+ for (int i = 0 ; i < count_; ++i) {
462
+ data_vec[i] = proto.double_data (i);
463
+ }
464
+ } else {
465
+ CHECK_EQ (count_, proto.data_size ());
466
+ for (int i = 0 ; i < count_; ++i) {
467
+ data_vec[i] = proto.data (i);
468
+ }
461
469
}
462
- if (proto.diff_size () > 0 ) {
470
+ if (proto.double_diff_size () > 0 ) {
471
+ CHECK_EQ (count_, proto.double_diff_size ());
472
+ Dtype* diff_vec = mutable_cpu_diff ();
473
+ for (int i = 0 ; i < count_; ++i) {
474
+ diff_vec[i] = proto.double_diff (i);
475
+ }
476
+ } else if (proto.diff_size () > 0 ) {
477
+ CHECK_EQ (count_, proto.diff_size ());
463
478
Dtype* diff_vec = mutable_cpu_diff ();
464
479
for (int i = 0 ; i < count_; ++i) {
465
480
diff_vec[i] = proto.diff (i);
466
481
}
467
482
}
468
483
}
469
484
470
- template <typename Dtype>
471
- void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const {
485
+ template <>
486
+ void Blob<double >::ToProto(BlobProto* proto, bool write_diff) const {
487
+ proto->clear_shape ();
488
+ for (int i = 0 ; i < shape_.size (); ++i) {
489
+ proto->mutable_shape ()->add_dim (shape_[i]);
490
+ }
491
+ proto->clear_double_data ();
492
+ proto->clear_double_diff ();
493
+ const double * data_vec = cpu_data ();
494
+ for (int i = 0 ; i < count_; ++i) {
495
+ proto->add_double_data (data_vec[i]);
496
+ }
497
+ if (write_diff) {
498
+ const double * diff_vec = cpu_diff ();
499
+ for (int i = 0 ; i < count_; ++i) {
500
+ proto->add_double_diff (diff_vec[i]);
501
+ }
502
+ }
503
+ }
504
+
505
+ template <>
506
+ void Blob<float >::ToProto(BlobProto* proto, bool write_diff) const {
472
507
proto->clear_shape ();
473
508
for (int i = 0 ; i < shape_.size (); ++i) {
474
509
proto->mutable_shape ()->add_dim (shape_[i]);
475
510
}
476
511
proto->clear_data ();
477
512
proto->clear_diff ();
478
- const Dtype * data_vec = cpu_data ();
513
+ const float * data_vec = cpu_data ();
479
514
for (int i = 0 ; i < count_; ++i) {
480
515
proto->add_data (data_vec[i]);
481
516
}
482
517
if (write_diff) {
483
- const Dtype * diff_vec = cpu_diff ();
518
+ const float * diff_vec = cpu_diff ();
484
519
for (int i = 0 ; i < count_; ++i) {
485
520
proto->add_diff (diff_vec[i]);
486
521
}
0 commit comments