From 1acad21bbf7a7eea1dc5cb9a68057d35210f7cdb Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 30 Jan 2018 20:27:38 +0800 Subject: [PATCH 01/12] init reader.h and reader.cc files --- paddle/framework/reader.cc | 51 ++++++++++++++++++++++++++++++ paddle/framework/reader.h | 65 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 paddle/framework/reader.cc create mode 100644 paddle/framework/reader.h diff --git a/paddle/framework/reader.cc b/paddle/framework/reader.cc new file mode 100644 index 00000000000000..7f80dd7fc10a67 --- /dev/null +++ b/paddle/framework/reader.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/framework/reader.h" + +namespace paddle { +namespace framework { + +DDim Reader::shape(int idx) const { + PADDLE_ENFORCE_LT( + idx, shapes_.size(), + "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx, + shapes_.size()); +} + +int RandomReader::ReadNext(std::vector* outs) { + PADDLE_ENFORCE_EQ( + shapes_.size(), outs.size(), + "shapes_.size() is %d, while outs.size() is %d. They are not equal.", + shapes_.size(), outs.size()); + std::minstd_rand engine; + unsigned int seed = std::random_device()(); + engine.seed(seed); + std::uniform_real_distribution dist(min_, max_); + for (int idx = 0; idx < shapes_.size(); ++idx) { + DDim shape = shapes_[idx]; + LoDTensor* out = outs[idx]; + int64_t numel = out->numel(); + PADDLE_ENFORCE_EQ(product(shape), numel, + "The product of %d'th shape is %lld, while the " + "corresponding out's numel is %lld. They are not equal.", + idx, product(shape), numel); + for (int64_t i = 0; i < numel, ++i) { + out[i] = dist(engine); + } + } + return 0; +} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h new file mode 100644 index 00000000000000..eed9c18d0877d1 --- /dev/null +++ b/paddle/framework/reader.h @@ -0,0 +1,65 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/framework/ddim.h" +#include "paddle/framework/lod_tensor.h" + +namespace paddle { +namespace framework { + +class Reader { + public: + virtual int ReadNext(std::vector* outs) = 0; + DDim shape(int idx) const; + + private: + std::vector shapes_; +}; + +// file readers + +class RandomReader : public Reader { + public: + RandomReader(const std::vector& shapes, float min, float max) + : shapes_(shapes), min_(min), max_(max) {} + int ReadNext(std::vector* outs) override; + + private: + float min_; + float max_; +}; + +// decorators + +class BatchReader : public Reader { + public: + BatchReader(const Reader* reader) : reader_(reader) {} + int ReadNext(std::vector* outs) override; + + private: + const Reader* reader_; +}; + +class ShuffleReader : public Reader { + public: + ShuffleReader(const Reader* reader) : reader_(reader) {} + int ReadNext(std::vector* outs) override; + + private: + const Reader* reader_; +}; +} // namespace framework +} // namespace paddle From f32ca6369099f5d3776ae87d431b9b39ea8eba3e Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 31 Jan 2018 18:46:45 +0800 Subject: [PATCH 02/12] draft of Reader classes --- paddle/framework/CMakeLists.txt | 2 + paddle/framework/reader.cc | 107 +++++++++++++++++++++++++------- paddle/framework/reader.h | 83 +++++++++++++++++++++---- 3 files changed, 159 insertions(+), 33 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 8c28709a68bec4..7eec91f9070de6 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -24,6 +24,8 @@ cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) +cc_library(reader SRCS reader.cc DEPS lod_tensor ddim) + cc_test(variable_test SRCS variable_test.cc) cc_library(threadpool SRCS threadpool.cc DEPS enforce) diff --git a/paddle/framework/reader.cc b/paddle/framework/reader.cc index 7f80dd7fc10a67..e11662166c6100 100644 --- a/paddle/framework/reader.cc +++ b/paddle/framework/reader.cc @@ -17,35 +17,100 @@ namespace paddle { namespace framework { -DDim Reader::shape(int idx) const { +DDim Reader::shape(size_t idx) const { PADDLE_ENFORCE_LT( idx, shapes_.size(), "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx, shapes_.size()); + return shapes_[idx]; } -int RandomReader::ReadNext(std::vector* outs) { - PADDLE_ENFORCE_EQ( - shapes_.size(), outs.size(), - "shapes_.size() is %d, while outs.size() is %d. They are not equal.", - shapes_.size(), outs.size()); - std::minstd_rand engine; - unsigned int seed = std::random_device()(); - engine.seed(seed); - std::uniform_real_distribution dist(min_, max_); - for (int idx = 0; idx < shapes_.size(); ++idx) { - DDim shape = shapes_[idx]; - LoDTensor* out = outs[idx]; - int64_t numel = out->numel(); - PADDLE_ENFORCE_EQ(product(shape), numel, - "The product of %d'th shape is %lld, while the " - "corresponding out's numel is %lld. They are not equal.", - idx, product(shape), numel); - for (int64_t i = 0; i < numel, ++i) { - out[i] = dist(engine); +std::vector ShuffleReader::ReadNext() { + if (iteration_pos_ >= buffer_.size()) { + // Reload buffer with new data + buffer_.clear(); + for (int i = 0; i < buffer_size_; ++i) { + if (reader_->HasNext()) { + buffer_.push_back(reader_->ReadNext()); + } else { + break; + } } + std::random_shuffle(buffer_.begin(), buffer_.end()); + iteration_pos_ = 0; } - return 0; + if (buffer_.empty()) { + std::vector empty_res; + return empty_res; + } + return buffer_[iteration_pos_++]; +} + +std::vector BatchReader::ReadNext() { + buffer_.clear(); + for (int i = 0; i < batch_size_; ++i) { + if (reader_->HasNext()) { + buffer_.push_back(reader_->ReadNext()); + } else { + break; + } + } + // Concat instances + std::vector res; + if (buffer_.empty()) { + return res; + } + int out_num = buffer_[0].size(); + res.reserve(out_num); + for (int j = 0; j < out_num; ++j) { + // Merge shape and check date type + std::type_index batch_type = buffer_[0][j].type(); + DDim batch_shape = buffer_[0][j].dims(); + for (size_t i = 1; i < buffer_.size(); ++i) { + std::type_index ins_type = buffer_[i][j].type(); + DDim ins_shape = buffer_[i][j].dims(); + PADDLE_ENFORCE_EQ(batch_type, ins_type); + PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()), + slice_ddim(ins_shape, 1, ins_shape.size())); + PADDLE_ENFORCE_GT(ins_shape[0], 0); + batch_shape[0] += ins_shape[0]; + } + + LoDTensor out; + out.Resize(batch_shape); + out.mutable_data(platform::CPUPlace(), batch_type); + int64_t dst_offset = 0; + + // Merge lod and data + LoD batch_lod; + std::vector top_level_lod({0}); + for (size_t i = 0; i < buffer_.size(); ++i) { + DDim ins_shape = buffer_[i][j].dims(); + LoD ins_lod = buffer_[i][j].lod(); + if (i == 0) { + batch_lod = ins_lod; + } else { + PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size()); + for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) { + auto& lod_level = batch_lod[level_idx]; + for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) { + lod_level.push_back(ins_lod[level_idx][k] + lod_level.back()); + } + } + } + top_level_lod.push_back( + top_level_lod.back() + + (ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1))); + + Tensor dst = out.Slice(dst_offset, dst_offset + ins_shape[0]); + Copy(buffer_[i][j], platform::CPUPlace(), &dst); + dst_offset += ins_shape[0]; + } + batch_lod.insert(batch_lod.begin(), top_level_lod); + out.set_lod(batch_lod); + res.push_back(out); + } + return res; } } // namespace framework } // namespace paddle diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index eed9c18d0877d1..58675863e56d94 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -22,20 +22,61 @@ namespace framework { class Reader { public: - virtual int ReadNext(std::vector* outs) = 0; - DDim shape(int idx) const; + Reader() {} + explicit Reader(const std::vector& shapes) : shapes_(shapes) {} + + virtual std::vector ReadNext() = 0; + virtual bool HasNext() const = 0; + + virtual DDim shape(size_t idx) const; + virtual std::vector shapes() const { return shapes_; } + + virtual ~Reader() {} private: + // set private to prevent directly access in decorators + // a decorator should access its underlying reader_'s shape, not its own. std::vector shapes_; }; // file readers +template class RandomReader : public Reader { public: RandomReader(const std::vector& shapes, float min, float max) - : shapes_(shapes), min_(min), max_(max) {} - int ReadNext(std::vector* outs) override; + : Reader(shapes), min_(min), max_(max) { + PADDLE_ENFORCE_LE(min, max, + "'min' should be less than or equal to 'max'.(%f vs %f)", + min, max); + } + + std::vector ReadNext() override { + std::minstd_rand engine; + unsigned int seed = std::random_device()(); + engine.seed(seed); + std::uniform_real_distribution dist(min_, max_); + + std::vector res; + res.reserve(shapes().size()); + for (const DDim& shape : shapes()) { + PADDLE_ENFORCE_GE( + shape.size(), 2, + "The rank of input data should be 2 at least.(Now it's %d)", + shape.size()); + LoDTensor out; + out.Resize(shape); + T* data = out.mutable_data(platform::CPUPlace()); + int64_t numel = product(shape); + for (int64_t i = 0; i < numel; ++i) { + data[i] = dist(engine); + } + res.push_back(out); + } + return res; + } + + bool HasNext() const override { return true; } private: float min_; @@ -44,22 +85,40 @@ class RandomReader : public Reader { // decorators -class BatchReader : public Reader { +class ShuffleReader : public Reader { public: - BatchReader(const Reader* reader) : reader_(reader) {} - int ReadNext(std::vector* outs) override; + ShuffleReader(Reader* reader, int buffer_size) + : reader_(reader), buffer_size_(buffer_size), iteration_pos_(0) { + buffer_.reserve(buffer_size); + } + std::vector ReadNext() override; + bool HasNext() const override { return reader_->HasNext(); } + + DDim shape(size_t idx) const override { return reader_->shape(idx); } + std::vector shapes() const override { return reader_->shapes(); } private: - const Reader* reader_; + Reader* reader_; + int buffer_size_; + std::vector> buffer_; + size_t iteration_pos_; }; -class ShuffleReader : public Reader { +class BatchReader : public Reader { public: - ShuffleReader(const Reader* reader) : reader_(reader) {} - int ReadNext(std::vector* outs) override; + BatchReader(Reader* reader, int batch_size) + : reader_(reader), batch_size_(batch_size) {} + std::vector ReadNext() override; + bool HasNext() const override { return reader_->HasNext(); }; + + DDim shape(size_t idx) const override { return reader_->shape(idx); } + std::vector shapes() const override { return reader_->shapes(); } private: - const Reader* reader_; + Reader* reader_; + int batch_size_; + std::vector> buffer_; }; + } // namespace framework } // namespace paddle From d8cc21da53e1113aaee3b43ea77d136bbbd204bb Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 1 Feb 2018 12:58:14 +0800 Subject: [PATCH 03/12] refine inheritance relationship --- paddle/framework/reader.cc | 2 +- paddle/framework/reader.h | 66 +++++++++++++++++++++----------------- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/paddle/framework/reader.cc b/paddle/framework/reader.cc index e11662166c6100..a05bef42ffa0c1 100644 --- a/paddle/framework/reader.cc +++ b/paddle/framework/reader.cc @@ -17,7 +17,7 @@ namespace paddle { namespace framework { -DDim Reader::shape(size_t idx) const { +DDim FileReader::shape(size_t idx) const { PADDLE_ENFORCE_LT( idx, shapes_.size(), "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx, diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index 58675863e56d94..3954a1bea8a00c 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -20,32 +20,48 @@ namespace paddle { namespace framework { -class Reader { +class ReaderBase { public: - Reader() {} - explicit Reader(const std::vector& shapes) : shapes_(shapes) {} - virtual std::vector ReadNext() = 0; virtual bool HasNext() const = 0; - virtual DDim shape(size_t idx) const; - virtual std::vector shapes() const { return shapes_; } + virtual DDim shape(size_t idx) const = 0; + virtual std::vector shapes() const = 0; - virtual ~Reader() {} + virtual ~ReaderBase() {} +}; - private: - // set private to prevent directly access in decorators - // a decorator should access its underlying reader_'s shape, not its own. +class FileReader : public ReaderBase { + public: + explicit FileReader(const std::vector& shapes) : shapes_(shapes) {} + + DDim shape(size_t idx) const override; + std::vector shapes() const override { return shapes_; } + + protected: std::vector shapes_; }; +class ReaderDecorator : public ReaderBase { + public: + explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) {} + + bool HasNext() const override { return reader_->HasNext(); } + + DDim shape(size_t idx) const override { return reader_->shape(idx); } + std::vector shapes() const override { return reader_->shapes(); } + + protected: + ReaderBase* reader_; +}; + // file readers template -class RandomReader : public Reader { +class RandomReader : public FileReader { public: RandomReader(const std::vector& shapes, float min, float max) - : Reader(shapes), min_(min), max_(max) { + : FileReader(shapes), min_(min), max_(max) { PADDLE_ENFORCE_LE(min, max, "'min' should be less than or equal to 'max'.(%f vs %f)", min, max); @@ -58,8 +74,8 @@ class RandomReader : public Reader { std::uniform_real_distribution dist(min_, max_); std::vector res; - res.reserve(shapes().size()); - for (const DDim& shape : shapes()) { + res.reserve(shapes_.size()); + for (const DDim& shape : shapes_) { PADDLE_ENFORCE_GE( shape.size(), 2, "The rank of input data should be 2 at least.(Now it's %d)", @@ -85,37 +101,27 @@ class RandomReader : public Reader { // decorators -class ShuffleReader : public Reader { +class ShuffleReader : public ReaderDecorator { public: - ShuffleReader(Reader* reader, int buffer_size) - : reader_(reader), buffer_size_(buffer_size), iteration_pos_(0) { + ShuffleReader(ReaderBase* reader, int buffer_size) + : ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) { buffer_.reserve(buffer_size); } std::vector ReadNext() override; - bool HasNext() const override { return reader_->HasNext(); } - - DDim shape(size_t idx) const override { return reader_->shape(idx); } - std::vector shapes() const override { return reader_->shapes(); } private: - Reader* reader_; int buffer_size_; std::vector> buffer_; size_t iteration_pos_; }; -class BatchReader : public Reader { +class BatchReader : public ReaderDecorator { public: - BatchReader(Reader* reader, int batch_size) - : reader_(reader), batch_size_(batch_size) {} + BatchReader(ReaderBase* reader, int batch_size) + : ReaderDecorator(reader), batch_size_(batch_size) {} std::vector ReadNext() override; - bool HasNext() const override { return reader_->HasNext(); }; - - DDim shape(size_t idx) const override { return reader_->shape(idx); } - std::vector shapes() const override { return reader_->shapes(); } private: - Reader* reader_; int batch_size_; std::vector> buffer_; }; From 93cab64185edf722dc493d1a00db5032014d836e Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 1 Feb 2018 17:38:57 +0800 Subject: [PATCH 04/12] Complete CreateRandomReaderOp --- paddle/framework/reader.h | 37 +++++++----- paddle/operators/create_reader_op.cc | 90 ++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 16 deletions(-) create mode 100644 paddle/operators/create_reader_op.cc diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index 3954a1bea8a00c..0669a7c7c75245 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -33,8 +33,6 @@ class ReaderBase { class FileReader : public ReaderBase { public: - explicit FileReader(const std::vector& shapes) : shapes_(shapes) {} - DDim shape(size_t idx) const override; std::vector shapes() const override { return shapes_; } @@ -44,8 +42,6 @@ class FileReader : public ReaderBase { class ReaderDecorator : public ReaderBase { public: - explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) {} - bool HasNext() const override { return reader_->HasNext(); } DDim shape(size_t idx) const override { return reader_->shape(idx); } @@ -60,19 +56,19 @@ class ReaderDecorator : public ReaderBase { template class RandomReader : public FileReader { public: - RandomReader(const std::vector& shapes, float min, float max) - : FileReader(shapes), min_(min), max_(max) { + void Initialize(const std::vector& shapes, float min, float max) { PADDLE_ENFORCE_LE(min, max, "'min' should be less than or equal to 'max'.(%f vs %f)", min, max); + shapes_ = shapes; + min_ = min; + max_ = max; + unsigned int seed = std::random_device()(); + engine_.seed(seed); + dist_ = std::uniform_real_distribution(min_, max_); } std::vector ReadNext() override { - std::minstd_rand engine; - unsigned int seed = std::random_device()(); - engine.seed(seed); - std::uniform_real_distribution dist(min_, max_); - std::vector res; res.reserve(shapes_.size()); for (const DDim& shape : shapes_) { @@ -85,7 +81,7 @@ class RandomReader : public FileReader { T* data = out.mutable_data(platform::CPUPlace()); int64_t numel = product(shape); for (int64_t i = 0; i < numel; ++i) { - data[i] = dist(engine); + data[i] = dist_(engine_); } res.push_back(out); } @@ -97,16 +93,21 @@ class RandomReader : public FileReader { private: float min_; float max_; + std::minstd_rand engine_; + std::uniform_real_distribution dist_; }; // decorators class ShuffleReader : public ReaderDecorator { public: - ShuffleReader(ReaderBase* reader, int buffer_size) - : ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) { + void Initialize(ReaderBase* reader, int buffer_size) { + reader_ = reader; + buffer_size_ = buffer_size; + iteration_pos_ = 0; buffer_.reserve(buffer_size); } + std::vector ReadNext() override; private: @@ -117,8 +118,12 @@ class ShuffleReader : public ReaderDecorator { class BatchReader : public ReaderDecorator { public: - BatchReader(ReaderBase* reader, int batch_size) - : ReaderDecorator(reader), batch_size_(batch_size) {} + void Initialize(ReaderBase* reader, int batch_size) { + reader_ = reader; + batch_size_ = batch_size; + buffer_.reserve(batch_size_); + } + std::vector ReadNext() override; private: diff --git a/paddle/operators/create_reader_op.cc b/paddle/operators/create_reader_op.cc new file mode 100644 index 00000000000000..abdc12087e04b7 --- /dev/null +++ b/paddle/operators/create_reader_op.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/framework/op_registry.h" +#include "paddle/framework/reader.h" + +namespace paddle { +namespace operators { + +// general infershape +class CreateReaderInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of CreateReaderOp should not be null."); + } +}; + +template +class CreateRandomReaderOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + void Run(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const auto& shape_concat = Attr>("shape_concat"); + const auto& ranks = Attr>("ranks"); + PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), + int(shape_concat.size()), + "The accumulate of all ranks should be equal to the " + "shape concat's length."); + std::vector shapes; + int offset = 0; + for (int len : ranks) { + auto start_it = shape_concat.begin() + offset; + auto end_it = start_it + len; + shapes.push_back( + framework::make_ddim(std::vector(start_it, end_it))); + offset += len; + } + auto* out = scope.FindVar(Output("Out")) + ->template GetMutable>(); + out->Initialize(shapes, Attr("min"), Attr("max")); + } +}; + +class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(op_proto, op_checker) { + AddOutput("Out", "(RandomReader) The created random reader."); + AddAttr>("shape_concat", + "The concat of all data's shapes."); + AddAttr>( + "ranks", + "The ranks of each data." + "e.g." + "shape_concat = [2,3,4,5,6]" + "ranks = [3,2]" + "It means the reader will generate two data each time," + "whose shapes are [2,3,4] and [5,6] respectively."); + AddAttr("min", "The lower bound of reader's uniform distribution."); + AddAttr("max", "The upper bound of reader's uniform distribution."); + AddComment(R"DOC( + CreateRandomReader Operator + + This Op creates a random reader. + The reader generates random data instead of really reading from files. + Generated data follow an uniform distribution between 'min' and 'max'. + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp, + ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker, + paddle::framework::EmptyGradOpMaker); \ No newline at end of file From 1696cb0e510a8d52427b6ca96900bab4e03b5af1 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 1 Feb 2018 21:10:16 +0800 Subject: [PATCH 05/12] Complete CreateShuffleReaderOp --- paddle/framework/reader.h | 41 +++++++++++++------ paddle/operators/CMakeLists.txt | 5 ++- paddle/operators/create_reader_op.cc | 59 +++++++++++++++++++++++++--- 3 files changed, 87 insertions(+), 18 deletions(-) diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index 0669a7c7c75245..18a34bfd170116 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -33,6 +33,10 @@ class ReaderBase { class FileReader : public ReaderBase { public: + explicit FileReader(const std::vector& shapes) : shapes_(shapes) { + PADDLE_ENFORCE(!shapes_.empty()); + } + DDim shape(size_t idx) const override; std::vector shapes() const override { return shapes_; } @@ -42,6 +46,10 @@ class FileReader : public ReaderBase { class ReaderDecorator : public ReaderBase { public: + explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) { + PADDLE_ENFORCE_NOT_NULL(reader_); + } + bool HasNext() const override { return reader_->HasNext(); } DDim shape(size_t idx) const override { return reader_->shape(idx); } @@ -56,13 +64,11 @@ class ReaderDecorator : public ReaderBase { template class RandomReader : public FileReader { public: - void Initialize(const std::vector& shapes, float min, float max) { + RandomReader(const std::vector& shapes, float min, float max) + : FileReader(shapes), min_(min), max_(max) { PADDLE_ENFORCE_LE(min, max, "'min' should be less than or equal to 'max'.(%f vs %f)", min, max); - shapes_ = shapes; - min_ = min; - max_ = max; unsigned int seed = std::random_device()(); engine_.seed(seed); dist_ = std::uniform_real_distribution(min_, max_); @@ -101,10 +107,8 @@ class RandomReader : public FileReader { class ShuffleReader : public ReaderDecorator { public: - void Initialize(ReaderBase* reader, int buffer_size) { - reader_ = reader; - buffer_size_ = buffer_size; - iteration_pos_ = 0; + ShuffleReader(ReaderBase* reader, int buffer_size) + : ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) { buffer_.reserve(buffer_size); } @@ -118,9 +122,8 @@ class ShuffleReader : public ReaderDecorator { class BatchReader : public ReaderDecorator { public: - void Initialize(ReaderBase* reader, int batch_size) { - reader_ = reader; - batch_size_ = batch_size; + BatchReader(ReaderBase* reader, int batch_size) + : ReaderDecorator(reader), batch_size_(batch_size) { buffer_.reserve(batch_size_); } @@ -131,5 +134,21 @@ class BatchReader : public ReaderDecorator { std::vector> buffer_; }; +class ReaderHolder { + public: + void Reset(ReaderBase* reader) { reader_.reset(reader); } + + ReaderBase* Get() const { return reader_.get(); } + + std::vector ReadNext() { return reader_->ReadNext(); } + bool HasNext() const { return reader_->HasNext(); } + + DDim shape(size_t idx) const { return reader_->shape(idx); } + std::vector shapes() const { return reader_->shapes(); } + + private: + std::unique_ptr reader_; +}; + } // namespace framework } // namespace paddle diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 48cf5816cce4bb..3684eb0dcca759 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -62,7 +62,7 @@ function(op_library TARGET) endif() # Define operators that don't need pybind here. - foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op") + foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op" "create_reader_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() @@ -153,6 +153,7 @@ op_library(recurrent_op DEPS executor) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function) op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) +op_library(create_reader_op DEPS reader) # Regist multiple Kernel to pybind if (WITH_GPU) @@ -178,7 +179,7 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) op_library(${src}) endforeach() -file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") +file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_reader);\n") set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") diff --git a/paddle/operators/create_reader_op.cc b/paddle/operators/create_reader_op.cc index abdc12087e04b7..29b487e10b5c67 100644 --- a/paddle/operators/create_reader_op.cc +++ b/paddle/operators/create_reader_op.cc @@ -18,7 +18,7 @@ namespace paddle { namespace operators { -// general infershape +// general infershape for file readers class CreateReaderInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override { @@ -35,6 +35,7 @@ class CreateRandomReaderOp : public framework::OperatorBase { const platform::Place& dev_place) const override { const auto& shape_concat = Attr>("shape_concat"); const auto& ranks = Attr>("ranks"); + PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), int(shape_concat.size()), "The accumulate of all ranks should be equal to the " @@ -49,8 +50,9 @@ class CreateRandomReaderOp : public framework::OperatorBase { offset += len; } auto* out = scope.FindVar(Output("Out")) - ->template GetMutable>(); - out->Initialize(shapes, Attr("min"), Attr("max")); + ->template GetMutable(); + out->Reset(new framework::RandomReader(shapes, Attr("min"), + Attr("max"))); } }; @@ -58,7 +60,7 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { public: CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(op_proto, op_checker) { - AddOutput("Out", "(RandomReader) The created random reader."); + AddOutput("Out", "(ReaderHolder) The created random reader."); AddAttr>("shape_concat", "The concat of all data's shapes."); AddAttr>( @@ -81,10 +83,57 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { } }; +class CreateShuffleReaderInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"), + "Input(Underlying_reader) of CreateShuffleReaderOp should " + "not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of CreateShuffleReaderOp should not be null."); + } +}; + +class CreateShuffleReaderOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + void Run(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const auto& underlying_reader = scope.FindVar(Input("Underlying_reader")) + ->Get(); + auto* out = scope.FindVar(Output("Out")) + ->template GetMutable(); + out->Reset(new framework::ShuffleReader(underlying_reader.Get(), + Attr("buffer_size"))); + } +}; + +class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(op_proto, op_checker) { + AddInput( + "Underlying_reader", + "(ReaderHolder) The underlying reader for creating a shuffle reader."); + AddOutput("Out", "(ReaderHolder) The created shuffle reader."); + AddAttr("buffer_size", "The shuffle buffer size.").GreaterThan(0); + AddComment(R"DOC( + CreateShuffleReader Operator + + A shuffle reader takes another reader as its 'underlying reader' + and output the underlying reader's outputs in a shuffled order. + )DOC"); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp, ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker, - paddle::framework::EmptyGradOpMaker); \ No newline at end of file + paddle::framework::EmptyGradOpMaker); +REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp, + ops::CreateShuffleReaderInferShape, + ops::CreateShuffleReaderOpMaker, + paddle::framework::EmptyGradOpMaker); From 3dfd1da138805e0c98be4c57f3ea73d62865cd18 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 1 Feb 2018 23:43:33 +0800 Subject: [PATCH 06/12] Complete CreateBatchReaderOp --- paddle/framework/reader.h | 12 ++--- paddle/operators/create_reader_op.cc | 71 +++++++++++++++++++++------- 2 files changed, 61 insertions(+), 22 deletions(-) diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index 18a34bfd170116..8275ea474b41d7 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -44,9 +44,9 @@ class FileReader : public ReaderBase { std::vector shapes_; }; -class ReaderDecorator : public ReaderBase { +class DecoratedReader : public ReaderBase { public: - explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) { + explicit DecoratedReader(ReaderBase* reader) : reader_(reader) { PADDLE_ENFORCE_NOT_NULL(reader_); } @@ -105,10 +105,10 @@ class RandomReader : public FileReader { // decorators -class ShuffleReader : public ReaderDecorator { +class ShuffleReader : public DecoratedReader { public: ShuffleReader(ReaderBase* reader, int buffer_size) - : ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) { + : DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) { buffer_.reserve(buffer_size); } @@ -120,10 +120,10 @@ class ShuffleReader : public ReaderDecorator { size_t iteration_pos_; }; -class BatchReader : public ReaderDecorator { +class BatchReader : public DecoratedReader { public: BatchReader(ReaderBase* reader, int batch_size) - : ReaderDecorator(reader), batch_size_(batch_size) { + : DecoratedReader(reader), batch_size_(batch_size) { buffer_.reserve(batch_size_); } diff --git a/paddle/operators/create_reader_op.cc b/paddle/operators/create_reader_op.cc index 29b487e10b5c67..9cf27bbfc694b7 100644 --- a/paddle/operators/create_reader_op.cc +++ b/paddle/operators/create_reader_op.cc @@ -19,11 +19,22 @@ namespace paddle { namespace operators { // general infershape for file readers -class CreateReaderInferShape : public framework::InferShapeBase { +class CreateFileReaderInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of CreateReaderOp should not be null."); + "The output file reader should not be null."); + } +}; + +// general infershape for decorated readers +class CreateDecoratedReaderInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"), + "Input(Underlying_reader) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "The output decorated reader should not be null."); } }; @@ -83,17 +94,6 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { } }; -class CreateShuffleReaderInferShape : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"), - "Input(Underlying_reader) of CreateShuffleReaderOp should " - "not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of CreateShuffleReaderOp should not be null."); - } -}; - class CreateShuffleReaderOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; @@ -121,7 +121,41 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { CreateShuffleReader Operator A shuffle reader takes another reader as its 'underlying reader' - and output the underlying reader's outputs in a shuffled order. + and yields the underlying reader's outputs in a shuffled order. + )DOC"); + } +}; + +class CreateBatchReaderOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + void Run(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const auto& underlying_reader = scope.FindVar(Input("Underlying_reader")) + ->Get(); + auto* out = scope.FindVar(Output("Out")) + ->template GetMutable(); + out->Reset(new framework::BatchReader(underlying_reader.Get(), + Attr("batch_size"))); + } +}; + +class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(op_proto, op_checker) { + AddInput( + "Underlying_reader", + "(ReaderHolder) The underlying reader for creating a batch reader."); + AddOutput("Out", "(ReaderHolder) The created batch reader."); + AddAttr("batch_size", + "How many instances the batch reader yields each time.") + .GreaterThan(0); + AddComment(R"DOC( + CreateBatchReader Operator + + A batch reader takes another reader as its 'underlying reader', + gathers the underlying reader's outputs and then yields them in batches. )DOC"); } }; @@ -131,9 +165,14 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp, - ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker, + ops::CreateFileReaderInferShape, + ops::CreateRandomReaderOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp, - ops::CreateShuffleReaderInferShape, + ops::CreateDecoratedReaderInferShape, ops::CreateShuffleReaderOpMaker, paddle::framework::EmptyGradOpMaker); +REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp, + ops::CreateDecoratedReaderInferShape, + ops::CreateBatchReaderOpMaker, + paddle::framework::EmptyGradOpMaker); From 53e697c11d30a84e59fab7d1c1d54718eed14f66 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 2 Feb 2018 00:06:46 +0800 Subject: [PATCH 07/12] refine code --- paddle/framework/reader.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index 8275ea474b41d7..f450e67689a774 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -66,9 +66,8 @@ class RandomReader : public FileReader { public: RandomReader(const std::vector& shapes, float min, float max) : FileReader(shapes), min_(min), max_(max) { - PADDLE_ENFORCE_LE(min, max, - "'min' should be less than or equal to 'max'.(%f vs %f)", - min, max); + PADDLE_ENFORCE_LE( + min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max); unsigned int seed = std::random_device()(); engine_.seed(seed); dist_ = std::uniform_real_distribution(min_, max_); @@ -103,7 +102,7 @@ class RandomReader : public FileReader { std::uniform_real_distribution dist_; }; -// decorators +// decorated readers class ShuffleReader : public DecoratedReader { public: @@ -134,6 +133,8 @@ class BatchReader : public DecoratedReader { std::vector> buffer_; }; +// The ReaderHolder is used as readers' unified wrapper, +// making it easier to access different type readers in Variables. class ReaderHolder { public: void Reset(ReaderBase* reader) { reader_.reset(reader); } From 1010e39bdf738029fcb78b0d388a91dfdebdda2f Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 6 Feb 2018 12:39:51 +0800 Subject: [PATCH 08/12] Add ReadOp --- paddle/framework/framework.proto | 4 +- paddle/framework/op_desc.cc | 29 +++++++-- paddle/framework/operator.cc | 26 ++++++-- paddle/framework/reader.cc | 40 ++++++------ paddle/framework/reader.h | 32 +++++----- paddle/framework/shape_inference.cc | 14 +++++ paddle/framework/shape_inference.h | 3 +- paddle/operators/read_op.cc | 94 +++++++++++++++++++++++++++++ 8 files changed, 193 insertions(+), 49 deletions(-) create mode 100644 paddle/operators/read_op.cc diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index f65ccae6e6a4df..d7be1a7352da56 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -116,7 +116,7 @@ message LoDTensorArrayDesc { optional int32 lod_level = 2 [ default = 0 ]; } -message Reader { repeated LoDTensorDesc lod_tensor = 1; } +message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; } message VarDesc { enum VarType { @@ -136,7 +136,7 @@ message VarDesc { optional LoDTensorDesc lod_tensor = 4; optional TensorDesc selected_rows = 5; optional LoDTensorArrayDesc tensor_array = 6; - optional Reader reader = 7; + optional ReaderDesc reader = 7; } message BlockDesc { diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index ad361852ec9f2b..772ec26895e945 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -72,6 +72,8 @@ class CompileTimeInferShapeContext : public InferShapeContext { void SetDim(const std::string &name, const DDim &dim) override; + std::vector GetRepeatedDim(const std::string &name) const override; + const OpDesc &op_; const BlockDesc &block_; }; @@ -457,22 +459,37 @@ const std::vector &CompileTimeInferShapeContext::Outputs( DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const { auto var = block_.FindVarRecursive(name); PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); + DDim res; try { auto shape = var->GetShape(); - if (shape.empty()) { - return framework::make_ddim({0UL}); - } else { - return framework::make_ddim(var->GetShape()); - } + res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape); } catch (...) { VLOG(5) << "GetDim of variable " << name << " error"; std::rethrow_exception(std::current_exception()); } + return res; +} + +std::vector CompileTimeInferShapeContext::GetRepeatedDim( + const std::string &name) const { + auto var = block_.FindVarRecursive(name); + PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); + std::vector res; + try { + auto shapes = var->GetShapes(); + for (const auto &s : shapes) { + res.push_back(s.empty() ? make_ddim({0UL}) : make_ddim(s)); + } + } catch (...) { + VLOG(5) << "GetRepeatedDim of variable " << name << " error."; + std::rethrow_exception(std::current_exception()); + } + return res; } void CompileTimeInferShapeContext::SetDim(const std::string &name, const DDim &dim) { - block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim)); + block_.FindVarRecursive(name)->SetShape(vectorize(dim)); } bool CompileTimeInferShapeContext::IsRuntime() const { return false; } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 81fa8cf477423f..1aa111dc76de10 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -320,8 +320,8 @@ class RuntimeInferShapeContext : public InferShapeContext { if (length == 0) { return false; } - PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs", - name); + PADDLE_ENFORCE_EQ(length, 1UL, + "Input %s should not have more than one inputs", name); auto ipt = ins[0]; auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); return var != nullptr; @@ -333,8 +333,8 @@ class RuntimeInferShapeContext : public InferShapeContext { if (length == 0) { return false; } - PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs", - name); + PADDLE_ENFORCE_EQ(length, 1UL, + "Output %s should not have more than one inputs", name); auto ipt = outs[0]; auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); return var != nullptr; @@ -421,8 +421,22 @@ class RuntimeInferShapeContext : public InferShapeContext { } else if (var->IsType()) { return var->Get().GetCompleteDims(); } else { - PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", - name, var->Type().name()); + PADDLE_THROW( + "Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's " + "type_id is %s.", + name, var->Type().name()); + } + } + + std::vector GetRepeatedDim(const std::string& name) const override { + Variable* var = scope_.FindVar(name); + if (var->IsType()) { + return var->Get().shapes(); + } else { + PADDLE_THROW( + "Only ReaderHolder support 'GetRepeatedDim', but Variable %s's " + "type_id is %s.", + name, var->Type().name()); } } diff --git a/paddle/framework/reader.cc b/paddle/framework/reader.cc index a05bef42ffa0c1..76cbc827ba5e84 100644 --- a/paddle/framework/reader.cc +++ b/paddle/framework/reader.cc @@ -25,13 +25,15 @@ DDim FileReader::shape(size_t idx) const { return shapes_[idx]; } -std::vector ShuffleReader::ReadNext() { +void ShuffleReader::ReadNext(std::vector* out) { if (iteration_pos_ >= buffer_.size()) { // Reload buffer with new data buffer_.clear(); + buffer_.reverse(buffer_size_); for (int i = 0; i < buffer_size_; ++i) { if (reader_->HasNext()) { - buffer_.push_back(reader_->ReadNext()); + buffer.push_back(std::vector()); + reader_->ReadNext(&buffer.back()); } else { break; } @@ -39,29 +41,32 @@ std::vector ShuffleReader::ReadNext() { std::random_shuffle(buffer_.begin(), buffer_.end()); iteration_pos_ = 0; } - if (buffer_.empty()) { - std::vector empty_res; - return empty_res; + out->clear(); + if (!buffer_.empty()) { + std::swap(*out, buffer_[iteration_pos_++]); } - return buffer_[iteration_pos_++]; + // if buffer_ is empty, the 'out' will return as an empty vector. } -std::vector BatchReader::ReadNext() { +void BatchReader::ReadNext(std::vector* out) { buffer_.clear(); + buffer_.reserve(batch_size_); for (int i = 0; i < batch_size_; ++i) { if (reader_->HasNext()) { - buffer_.push_back(reader_->ReadNext()); + buffer_.push_back(std::vector()); + reader_->ReadNext(&buffer_.back()); } else { break; } } // Concat instances - std::vector res; + out.clear(); if (buffer_.empty()) { - return res; + // if buffer_ is empty, the 'out' will return as an empty vector. + return; } int out_num = buffer_[0].size(); - res.reserve(out_num); + out->reserve(out_num); for (int j = 0; j < out_num; ++j) { // Merge shape and check date type std::type_index batch_type = buffer_[0][j].type(); @@ -76,9 +81,9 @@ std::vector BatchReader::ReadNext() { batch_shape[0] += ins_shape[0]; } - LoDTensor out; - out.Resize(batch_shape); - out.mutable_data(platform::CPUPlace(), batch_type); + LoDTensor out_tensor; + out_tensor.Resize(batch_shape); + out_tensor.mutable_data(platform::CPUPlace(), batch_type); int64_t dst_offset = 0; // Merge lod and data @@ -102,15 +107,14 @@ std::vector BatchReader::ReadNext() { top_level_lod.back() + (ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1))); - Tensor dst = out.Slice(dst_offset, dst_offset + ins_shape[0]); + Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]); Copy(buffer_[i][j], platform::CPUPlace(), &dst); dst_offset += ins_shape[0]; } batch_lod.insert(batch_lod.begin(), top_level_lod); - out.set_lod(batch_lod); - res.push_back(out); + out_tensor.set_lod(batch_lod); + out->push_back(out_tensor); } - return res; } } // namespace framework } // namespace paddle diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index f450e67689a774..523ff28c990799 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -15,14 +15,14 @@ #pragma once #include "paddle/framework/ddim.h" -#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/lod_tensor_array.h" namespace paddle { namespace framework { class ReaderBase { public: - virtual std::vector ReadNext() = 0; + virtual void ReadNext(std::vector* out) = 0; virtual bool HasNext() const = 0; virtual DDim shape(size_t idx) const = 0; @@ -73,24 +73,24 @@ class RandomReader : public FileReader { dist_ = std::uniform_real_distribution(min_, max_); } - std::vector ReadNext() override { - std::vector res; - res.reserve(shapes_.size()); + void ReadNext(std::vector* out) override { + out.clear(); + out.reserve(shapes_.size()); for (const DDim& shape : shapes_) { PADDLE_ENFORCE_GE( shape.size(), 2, - "The rank of input data should be 2 at least.(Now it's %d)", + "The rank of reader's output data should be 2 at least.(Now it's %d)", shape.size()); - LoDTensor out; - out.Resize(shape); - T* data = out.mutable_data(platform::CPUPlace()); + LoDTensor out_tensor; + out_tensor.Resize(shape); + T* data = out_tensor.mutable_data(platform::CPUPlace()); int64_t numel = product(shape); for (int64_t i = 0; i < numel; ++i) { data[i] = dist_(engine_); } - res.push_back(out); + out.push_back(out_tensor); } - return res; + return out; } bool HasNext() const override { return true; } @@ -111,11 +111,11 @@ class ShuffleReader : public DecoratedReader { buffer_.reserve(buffer_size); } - std::vector ReadNext() override; + void ReadNext(std::vector* out) override; private: int buffer_size_; - std::vector> buffer_; + std::vector> buffer_; size_t iteration_pos_; }; @@ -126,11 +126,11 @@ class BatchReader : public DecoratedReader { buffer_.reserve(batch_size_); } - std::vector ReadNext() override; + void ReadNext(std::vector* out) override; private: int batch_size_; - std::vector> buffer_; + std::vector> buffer_; }; // The ReaderHolder is used as readers' unified wrapper, @@ -141,7 +141,7 @@ class ReaderHolder { ReaderBase* Get() const { return reader_.get(); } - std::vector ReadNext() { return reader_->ReadNext(); } + void ReadNext(std::vector* out) { reader_->ReadNext(out); } bool HasNext() const { return reader_->HasNext(); } DDim shape(size_t idx) const { return reader_->shape(idx); } diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc index a0fa467291bb42..4a8acfb87ff122 100644 --- a/paddle/framework/shape_inference.cc +++ b/paddle/framework/shape_inference.cc @@ -32,6 +32,16 @@ std::vector InferShapeContext::GetInputsDim( return GetDims(arg_names); } +std::vector InferShapeContext::GetReaderDims( + const std::string &name) const { + const std::vector &arg_names = Inputs(name); + PADDLE_ENFORCE_EQ( + arg_names.size(), 1UL, + "Reader input '%s' should hold one element, but now it holds %d", name, + arg_names.size()); + return this->GetRepeatedDims(arg_names[0]); +} + DDim InferShapeContext::GetInputsElementDim(const std::string &name, int idx) const { const std::vector &names = Inputs(name); @@ -61,6 +71,7 @@ std::vector InferShapeContext::GetDims( [this](const std::string &name) { return this->GetDim(name); }); return ret; } + void InferShapeContext::SetDims(const std::vector &names, const std::vector &dims) { size_t length = names.size(); @@ -72,14 +83,17 @@ void InferShapeContext::SetDims(const std::vector &names, SetDim(names[i], dims[i]); } } + std::vector InferShapeContext::GetInputsVarType( const std::string &name) const { return GetVarTypes(Inputs(name)); } + std::vector InferShapeContext::GetOutputsVarType( const std::string &name) const { return GetVarTypes(Outputs(name)); } + std::vector InferShapeContext::GetVarTypes( const std::vector &names) const { std::vector retv; diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index 830f199ed14515..f1a64e9024beb8 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -36,8 +36,8 @@ class InferShapeContext { virtual bool HasOutputs(const std::string &name) const = 0; DDim GetInputDim(const std::string &name) const; - std::vector GetInputsDim(const std::string &name) const; + std::vector GetReaderDims(const std::string &name) const DDim; DDim GetInputsElementDim(const std::string &name, int idx) const; void SetOutputDim(const std::string &name, const DDim &dim); @@ -61,6 +61,7 @@ class InferShapeContext { protected: virtual DDim GetDim(const std::string &name) const = 0; virtual void SetDim(const std::string &name, const DDim &dim) = 0; + std::vector GetRepeatedDim(const std::string &name) const = 0; std::vector GetDims(const std::vector &names) const; std::vector GetVarTypes( diff --git a/paddle/operators/read_op.cc b/paddle/operators/read_op.cc new file mode 100644 index 00000000000000..c6ff4ba8fee502 --- /dev/null +++ b/paddle/operators/read_op.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/framework/op_registry.h" +#include "paddle/framework/reader.h" + +namespace paddle { +namespace operators { + +class ReadInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Reader"), + "The ReadOp must take a reader as input."); + PADDLE_ENFORCE(ctx->HasOutputs("Out"), + "The ReadOp should be assigned with output."); + std::vector reader_dims = ctx->GetReaderDims("Reader"); + std::vector out_names = ctx->Outputs("Out"); + PADDLE_ENFORCE_EQ( + reader_dims.size(), out_names.size(), + "The reader's dim number doesn't match the output number."); + ctx->SetOutputsDim("Out", reader_dims); + } +}; + +class ReadInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + std::string reader_name = op_desc.Input("Reader")[0]; + std::vector out_names = op_desc.Output("Out"); + framework::VarDesc reader = block.FindVarRecursive(reader_name); + auto dtypes = reader.GetDataTypes(); + PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size()); + for (size_t i = 0; i < dtypes.size(); ++i) { + faremwork::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]); + out.SetType(framework::proto::DataType::LOD_TENSOR); + out.SetDataType(dtypes[i]); + } + } +}; + +class ReadOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + void Run(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const framework::ReaderHolder& reader = + scope.FindVar(Input("Reader"))->Get(); + if (!reader.HasNext()) { + // what shall we do??? + return; + } + std::vector out_arg_names = Outputs("Out"); + std::vector ins; + reader.ReadNext(&ins); + PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); + for (size_t i = 0; i < ins.size(); ++i) { + auto* out = + scope.FindVar(out_arg_names[i])->GetMutable(); + PADDLE_ENFORCE_EQ(ins[i].dims(), out->dims()); + out->ShareDataWith(ins[i]); + out->set_lod(ins[i].lod()); + } + } +}; + +class ReadOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ReadOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(op_proto, op_checker) { + AddInput("Reader", "(ReaderHolder) The executed reader."); + AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable(); + AddComment(R"DOC( + Read Operator + + Execute a given reader once and output data. + )DOC") + } +}; + +} // namespace operators +} // namespace paddle \ No newline at end of file From 0bb9c80ef960d777c5937f8fed8ddf75f2ac6a18 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 6 Feb 2018 23:46:18 +0800 Subject: [PATCH 09/12] refine code and add unit tests --- paddle/framework/executor.cc | 7 +- paddle/framework/op_desc.cc | 17 ++++- paddle/framework/operator.cc | 17 ++++- paddle/framework/reader.cc | 16 ++--- paddle/framework/reader.h | 51 +++++++------ paddle/framework/shape_inference.cc | 10 +++ paddle/framework/shape_inference.h | 7 +- paddle/framework/var_desc.cc | 35 +++++---- paddle/framework/var_type.h | 8 ++- paddle/operators/create_reader_op.cc | 61 +++++++++++----- paddle/operators/read_op.cc | 28 ++++---- paddle/pybind/protobuf.cc | 2 - python/paddle/v2/fluid/executor.py | 3 +- .../paddle/v2/fluid/tests/test_cpp_reader.py | 71 +++++++++++++++++++ 14 files changed, 244 insertions(+), 89 deletions(-) create mode 100644 python/paddle/v2/fluid/tests/test_cpp_reader.py diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index 9a232b08434d29..2a88e5a92985fa 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/framework/lod_rank_table.h" #include "paddle/framework/lod_tensor_array.h" #include "paddle/framework/op_registry.h" +#include "paddle/framework/reader.h" #include "paddle/platform/place.h" #include "paddle/platform/profiler.h" @@ -52,11 +53,13 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { var->GetMutable(); } else if (var_type == proto::VarDesc::PLACE_LIST) { var->GetMutable(); + } else if (var_type == proto::VarDesc::READER) { + var->GetMutable(); } else { PADDLE_THROW( "Variable type %d is not in " - "[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST, LOD_RANK_TABLE," - " PLACE_LIST]", + "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, " + "LOD_RANK_TABLE, PLACE_LIST, READER]", var_type); } } diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index 772ec26895e945..ea4028750248ec 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -72,7 +72,10 @@ class CompileTimeInferShapeContext : public InferShapeContext { void SetDim(const std::string &name, const DDim &dim) override; - std::vector GetRepeatedDim(const std::string &name) const override; + std::vector GetRepeatedDims(const std::string &name) const override; + + void SetRepeatedDims(const std::string &name, + const std::vector &dims) override; const OpDesc &op_; const BlockDesc &block_; @@ -470,7 +473,7 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const { return res; } -std::vector CompileTimeInferShapeContext::GetRepeatedDim( +std::vector CompileTimeInferShapeContext::GetRepeatedDims( const std::string &name) const { auto var = block_.FindVarRecursive(name); PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); @@ -491,6 +494,16 @@ void CompileTimeInferShapeContext::SetDim(const std::string &name, const DDim &dim) { block_.FindVarRecursive(name)->SetShape(vectorize(dim)); } + +void CompileTimeInferShapeContext::SetRepeatedDims( + const std::string &name, const std::vector &dims) { + auto var = block_.FindVarRecursive(name); + PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); + std::vector> dim_vec(dims.size()); + std::transform(dims.begin(), dims.end(), dim_vec.begin(), vectorize); + var->SetShapes(dim_vec); +} + bool CompileTimeInferShapeContext::IsRuntime() const { return false; } proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType( diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 1aa111dc76de10..52387aabd9d0b4 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -428,13 +428,13 @@ class RuntimeInferShapeContext : public InferShapeContext { } } - std::vector GetRepeatedDim(const std::string& name) const override { + std::vector GetRepeatedDims(const std::string& name) const override { Variable* var = scope_.FindVar(name); if (var->IsType()) { return var->Get().shapes(); } else { PADDLE_THROW( - "Only ReaderHolder support 'GetRepeatedDim', but Variable %s's " + "Only ReaderHolder support 'GetRepeatedDims', but Variable %s's " "type_id is %s.", name, var->Type().name()); } @@ -452,6 +452,19 @@ class RuntimeInferShapeContext : public InferShapeContext { } } + void SetRepeatedDims(const std::string& name, + const std::vector& dims) override { + Variable* var = scope_.FindVar(name); + if (var->IsType()) { + var->GetMutable()->set_shapes(dims); + } else { + PADDLE_THROW( + "Only ReaderHolder support 'SetRepeatedDims', but Variable %s's " + "type_id is %s.", + name, var->Type().name()); + } + } + proto::VarDesc::VarType GetVarType(const std::string& name) const override { auto* var = scope_.FindVar(name); return ToVarType(var->Type()); diff --git a/paddle/framework/reader.cc b/paddle/framework/reader.cc index 76cbc827ba5e84..86220cd0bbaf07 100644 --- a/paddle/framework/reader.cc +++ b/paddle/framework/reader.cc @@ -17,7 +17,7 @@ namespace paddle { namespace framework { -DDim FileReader::shape(size_t idx) const { +DDim ReaderBase::shape(size_t idx) const { PADDLE_ENFORCE_LT( idx, shapes_.size(), "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx, @@ -25,15 +25,15 @@ DDim FileReader::shape(size_t idx) const { return shapes_[idx]; } -void ShuffleReader::ReadNext(std::vector* out) { +void ShuffleReader::ReadNext(std::vector* out) { if (iteration_pos_ >= buffer_.size()) { // Reload buffer with new data buffer_.clear(); - buffer_.reverse(buffer_size_); + buffer_.reserve(buffer_size_); for (int i = 0; i < buffer_size_; ++i) { if (reader_->HasNext()) { - buffer.push_back(std::vector()); - reader_->ReadNext(&buffer.back()); + buffer_.push_back(std::vector()); + reader_->ReadNext(&buffer_.back()); } else { break; } @@ -48,19 +48,19 @@ void ShuffleReader::ReadNext(std::vector* out) { // if buffer_ is empty, the 'out' will return as an empty vector. } -void BatchReader::ReadNext(std::vector* out) { +void BatchReader::ReadNext(std::vector* out) { buffer_.clear(); buffer_.reserve(batch_size_); for (int i = 0; i < batch_size_; ++i) { if (reader_->HasNext()) { - buffer_.push_back(std::vector()); + buffer_.push_back(std::vector()); reader_->ReadNext(&buffer_.back()); } else { break; } } // Concat instances - out.clear(); + out->clear(); if (buffer_.empty()) { // if buffer_ is empty, the 'out' will return as an empty vector. return; diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index 523ff28c990799..ff7153bc7bfb65 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -22,39 +22,36 @@ namespace framework { class ReaderBase { public: - virtual void ReadNext(std::vector* out) = 0; + explicit ReaderBase(const std::vector& shapes) : shapes_(shapes) { + PADDLE_ENFORCE(!shapes_.empty()); + } + virtual void ReadNext(std::vector* out) = 0; virtual bool HasNext() const = 0; - virtual DDim shape(size_t idx) const = 0; - virtual std::vector shapes() const = 0; + DDim shape(size_t idx) const; + std::vector shapes() const { return shapes_; } + void set_shapes(const std::vector& shapes) { shapes_ = shapes; } virtual ~ReaderBase() {} + + protected: + std::vector shapes_; }; class FileReader : public ReaderBase { public: - explicit FileReader(const std::vector& shapes) : shapes_(shapes) { - PADDLE_ENFORCE(!shapes_.empty()); - } - - DDim shape(size_t idx) const override; - std::vector shapes() const override { return shapes_; } - - protected: - std::vector shapes_; + explicit FileReader(const std::vector& shapes) : ReaderBase(shapes) {} }; class DecoratedReader : public ReaderBase { public: - explicit DecoratedReader(ReaderBase* reader) : reader_(reader) { + explicit DecoratedReader(ReaderBase* reader) + : ReaderBase(reader->shapes()), reader_(reader) { PADDLE_ENFORCE_NOT_NULL(reader_); } bool HasNext() const override { return reader_->HasNext(); } - DDim shape(size_t idx) const override { return reader_->shape(idx); } - std::vector shapes() const override { return reader_->shapes(); } - protected: ReaderBase* reader_; }; @@ -73,9 +70,9 @@ class RandomReader : public FileReader { dist_ = std::uniform_real_distribution(min_, max_); } - void ReadNext(std::vector* out) override { - out.clear(); - out.reserve(shapes_.size()); + void ReadNext(std::vector* out) override { + out->clear(); + out->reserve(shapes_.size()); for (const DDim& shape : shapes_) { PADDLE_ENFORCE_GE( shape.size(), 2, @@ -88,9 +85,8 @@ class RandomReader : public FileReader { for (int64_t i = 0; i < numel; ++i) { data[i] = dist_(engine_); } - out.push_back(out_tensor); + out->push_back(out_tensor); } - return out; } bool HasNext() const override { return true; } @@ -111,11 +107,11 @@ class ShuffleReader : public DecoratedReader { buffer_.reserve(buffer_size); } - void ReadNext(std::vector* out) override; + void ReadNext(std::vector* out) override; private: int buffer_size_; - std::vector> buffer_; + std::vector> buffer_; size_t iteration_pos_; }; @@ -126,11 +122,11 @@ class BatchReader : public DecoratedReader { buffer_.reserve(batch_size_); } - void ReadNext(std::vector* out) override; + void ReadNext(std::vector* out) override; private: int batch_size_; - std::vector> buffer_; + std::vector> buffer_; }; // The ReaderHolder is used as readers' unified wrapper, @@ -141,11 +137,14 @@ class ReaderHolder { ReaderBase* Get() const { return reader_.get(); } - void ReadNext(std::vector* out) { reader_->ReadNext(out); } + void ReadNext(std::vector* out) { reader_->ReadNext(out); } bool HasNext() const { return reader_->HasNext(); } DDim shape(size_t idx) const { return reader_->shape(idx); } std::vector shapes() const { return reader_->shapes(); } + void set_shapes(const std::vector& shapes) { + reader_->set_shapes(shapes); + } private: std::unique_ptr reader_; diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc index 4a8acfb87ff122..2f4d45057715d2 100644 --- a/paddle/framework/shape_inference.cc +++ b/paddle/framework/shape_inference.cc @@ -62,6 +62,16 @@ void InferShapeContext::SetOutputsDim(const std::string &name, SetDims(names, dims); } +void InferShapeContext::SetReaderDims(const std::string &name, + const std::vector &dims) { + const std::vector &arg_names = Outputs(name); + PADDLE_ENFORCE_EQ( + arg_names.size(), 1UL, + "Reader output '%s' should hold one element, but now it holds %d", name, + arg_names.size()); + return this->SetRepeatedDims(arg_names[0], dims); +} + std::vector InferShapeContext::GetDims( const std::vector &names) const { std::vector ret; diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index f1a64e9024beb8..7bee86985239de 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -37,11 +37,12 @@ class InferShapeContext { DDim GetInputDim(const std::string &name) const; std::vector GetInputsDim(const std::string &name) const; - std::vector GetReaderDims(const std::string &name) const DDim; + std::vector GetReaderDims(const std::string &name) const; DDim GetInputsElementDim(const std::string &name, int idx) const; void SetOutputDim(const std::string &name, const DDim &dim); void SetOutputsDim(const std::string &name, const std::vector &dims); + void SetReaderDims(const std::string &name, const std::vector &dims); virtual AttrReader Attrs() const = 0; virtual const std::vector &Inputs( @@ -61,7 +62,9 @@ class InferShapeContext { protected: virtual DDim GetDim(const std::string &name) const = 0; virtual void SetDim(const std::string &name, const DDim &dim) = 0; - std::vector GetRepeatedDim(const std::string &name) const = 0; + virtual std::vector GetRepeatedDims(const std::string &name) const = 0; + virtual void SetRepeatedDims(const std::string &name, + const std::vector &dims) = 0; std::vector GetDims(const std::vector &names) const; std::vector GetVarTypes( diff --git a/paddle/framework/var_desc.cc b/paddle/framework/var_desc.cc index 6d83e2e41126db..11a4daf2c991fc 100644 --- a/paddle/framework/var_desc.cc +++ b/paddle/framework/var_desc.cc @@ -57,10 +57,13 @@ size_t VarDesc::GetTensorDescNum() const { void VarDesc::SetShapes( const std::vector> &multiple_dims) { - PADDLE_ENFORCE_EQ(multiple_dims.size(), GetTensorDescNum(), - "The number of given shapes(%d) doesn't equal to the " - "number of sub tensor.", - multiple_dims.size(), GetTensorDescNum()); + if (multiple_dims.size() != GetTensorDescNum()) { + VLOG(3) << "WARNING: The number of given shapes(" << multiple_dims.size() + << ") doesn't match the existing tensor number(" + << GetTensorDescNum() + << "). The Reader is going to be reinitialized."; + SetTensorDescNum(multiple_dims.size()); + } std::vector tensors = mutable_tensor_descs(); for (size_t i = 0; i < multiple_dims.size(); ++i) { VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims()); @@ -87,10 +90,14 @@ void VarDesc::SetDataType(proto::DataType data_type) { void VarDesc::SetDataTypes( const std::vector &multiple_data_type) { - PADDLE_ENFORCE_EQ(multiple_data_type.size(), GetTensorDescNum(), - "The number of given data types(%d) doesn't equal to the " - "number of sub tensor.", - multiple_data_type.size(), GetTensorDescNum()); + if (multiple_data_type.size() != GetTensorDescNum()) { + VLOG(3) << "WARNING: The number of given data types(" + << multiple_data_type.size() + << ") doesn't match the existing tensor number(" + << GetTensorDescNum() + << "). The Reader is going to be reinitialized."; + SetTensorDescNum(multiple_data_type.size()); + } std::vector tensor_descs = mutable_tensor_descs(); for (size_t i = 0; i < multiple_data_type.size(); ++i) { tensor_descs[i]->set_data_type(multiple_data_type[i]); @@ -127,10 +134,14 @@ void VarDesc::SetLoDLevel(int32_t lod_level) { } void VarDesc::SetLoDLevels(const std::vector &multiple_lod_level) { - PADDLE_ENFORCE_EQ(multiple_lod_level.size(), GetTensorDescNum(), - "The number of given data types(%d) doesn't equal to the " - "number of sub tensor.", - multiple_lod_level.size(), GetTensorDescNum()); + if (multiple_lod_level.size() != GetTensorDescNum()) { + VLOG(3) << "WARNING: The number of given lod_levels(" + << multiple_lod_level.size() + << ") doesn't match the existing tensor number(" + << GetTensorDescNum() + << "). The Reader is going to be reinitialized."; + SetTensorDescNum(multiple_lod_level.size()); + } switch (desc_.type()) { case proto::VarDesc::READER: { size_t i = 0; diff --git a/paddle/framework/var_type.h b/paddle/framework/var_type.h index 5b7a08a08732a6..599d45149024ca 100644 --- a/paddle/framework/var_type.h +++ b/paddle/framework/var_type.h @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/framework/lod_rank_table.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor_array.h" +#include "paddle/framework/reader.h" #include "paddle/framework/selected_rows.h" #include "paddle/framework/variable.h" @@ -31,6 +32,8 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) { return proto::VarDesc_VarType_LOD_TENSOR_ARRAY; } else if (type.hash_code() == typeid(SelectedRows).hash_code()) { return proto::VarDesc_VarType_SELECTED_ROWS; + } else if (type.hash_code() == typeid(ReaderHolder).hash_code()) { + return proto::VarDesc_VarType_READER; } else { PADDLE_THROW("ToVarType:Unsupported type %s", type.name()); } @@ -40,7 +43,7 @@ template inline void VisitVarType(const framework::Variable& var, Visitor visitor) { switch (ToVarType(var.Type())) { case proto::VarDesc_VarType_LOD_TENSOR: - visitor(var.Get()); + visitor(var.Get()); return; case proto::VarDesc_VarType_LOD_RANK_TABLE: visitor(var.Get()); @@ -51,6 +54,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) { case proto::VarDesc_VarType_SELECTED_ROWS: visitor(var.Get()); return; + case proto::VarDesc_VarType_READER: + visitor(var.Get()); + return; default: PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type())); } diff --git a/paddle/operators/create_reader_op.cc b/paddle/operators/create_reader_op.cc index 9cf27bbfc694b7..11c77a06032de2 100644 --- a/paddle/operators/create_reader_op.cc +++ b/paddle/operators/create_reader_op.cc @@ -18,12 +18,30 @@ namespace paddle { namespace operators { +std::vector RestoreShapes(const std::vector& shape_concat, + const std::vector& ranks) { + std::vector res; + int offset = 0; + for (int len : ranks) { + auto start_it = shape_concat.begin() + offset; + auto end_it = start_it + len; + res.push_back(framework::make_ddim(std::vector(start_it, end_it))); + offset += len; + } + return res; +} + // general infershape for file readers class CreateFileReaderInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "The output file reader should not be null."); + const auto shape_concat = + ctx->Attrs().Get>("shape_concat"); + const auto ranks = ctx->Attrs().Get>("ranks"); + std::vector shapes = RestoreShapes(shape_concat, ranks); + ctx->SetReaderDims("Out", shapes); } }; @@ -31,10 +49,22 @@ class CreateFileReaderInferShape : public framework::InferShapeBase { class CreateDecoratedReaderInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"), - "Input(Underlying_reader) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"), + "Input(UnderlyingReader) should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "The output decorated reader should not be null."); + ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader")); + } +}; + +// general var type inference for all readers +class CreateReaderInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + std::string reader_name = op_desc.Output("Out")[0]; + framework::VarDesc* reader = block->FindVarRecursive(reader_name); + reader->SetType(framework::proto::VarDesc::READER); } }; @@ -51,15 +81,7 @@ class CreateRandomReaderOp : public framework::OperatorBase { int(shape_concat.size()), "The accumulate of all ranks should be equal to the " "shape concat's length."); - std::vector shapes; - int offset = 0; - for (int len : ranks) { - auto start_it = shape_concat.begin() + offset; - auto end_it = start_it + len; - shapes.push_back( - framework::make_ddim(std::vector(start_it, end_it))); - offset += len; - } + std::vector shapes = RestoreShapes(shape_concat, ranks); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); out->Reset(new framework::RandomReader(shapes, Attr("min"), @@ -99,7 +121,7 @@ class CreateShuffleReaderOp : public framework::OperatorBase { using framework::OperatorBase::OperatorBase; void Run(const framework::Scope& scope, const platform::Place& dev_place) const override { - const auto& underlying_reader = scope.FindVar(Input("Underlying_reader")) + const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); @@ -113,7 +135,7 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(op_proto, op_checker) { AddInput( - "Underlying_reader", + "UnderlyingReader", "(ReaderHolder) The underlying reader for creating a shuffle reader."); AddOutput("Out", "(ReaderHolder) The created shuffle reader."); AddAttr("buffer_size", "The shuffle buffer size.").GreaterThan(0); @@ -131,7 +153,7 @@ class CreateBatchReaderOp : public framework::OperatorBase { using framework::OperatorBase::OperatorBase; void Run(const framework::Scope& scope, const platform::Place& dev_place) const override { - const auto& underlying_reader = scope.FindVar(Input("Underlying_reader")) + const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); @@ -145,7 +167,7 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker { CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(op_proto, op_checker) { AddInput( - "Underlying_reader", + "UnderlyingReader", "(ReaderHolder) The underlying reader for creating a batch reader."); AddOutput("Out", "(ReaderHolder) The created batch reader."); AddAttr("batch_size", @@ -167,12 +189,15 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp, ops::CreateFileReaderInferShape, ops::CreateRandomReaderOpMaker, - paddle::framework::EmptyGradOpMaker); + paddle::framework::EmptyGradOpMaker, + ops::CreateReaderInferVarType); REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp, ops::CreateDecoratedReaderInferShape, ops::CreateShuffleReaderOpMaker, - paddle::framework::EmptyGradOpMaker); + paddle::framework::EmptyGradOpMaker, + ops::CreateReaderInferVarType); REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp, ops::CreateDecoratedReaderInferShape, ops::CreateBatchReaderOpMaker, - paddle::framework::EmptyGradOpMaker); + paddle::framework::EmptyGradOpMaker, + ops::CreateReaderInferVarType); diff --git a/paddle/operators/read_op.cc b/paddle/operators/read_op.cc index c6ff4ba8fee502..3d17b26c998fc2 100644 --- a/paddle/operators/read_op.cc +++ b/paddle/operators/read_op.cc @@ -25,7 +25,7 @@ class ReadInferShape : public framework::InferShapeBase { "The ReadOp must take a reader as input."); PADDLE_ENFORCE(ctx->HasOutputs("Out"), "The ReadOp should be assigned with output."); - std::vector reader_dims = ctx->GetReaderDims("Reader"); + std::vector reader_dims = ctx->GetReaderDims("Reader"); std::vector out_names = ctx->Outputs("Out"); PADDLE_ENFORCE_EQ( reader_dims.size(), out_names.size(), @@ -40,12 +40,12 @@ class ReadInferVarType : public framework::VarTypeInference { framework::BlockDesc* block) const override { std::string reader_name = op_desc.Input("Reader")[0]; std::vector out_names = op_desc.Output("Out"); - framework::VarDesc reader = block.FindVarRecursive(reader_name); - auto dtypes = reader.GetDataTypes(); + framework::VarDesc* reader = block->FindVarRecursive(reader_name); + auto dtypes = reader->GetDataTypes(); PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size()); for (size_t i = 0; i < dtypes.size(); ++i) { - faremwork::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]); - out.SetType(framework::proto::DataType::LOD_TENSOR); + framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]); + out.SetType(framework::proto::VarDesc::LOD_TENSOR); out.SetDataType(dtypes[i]); } } @@ -56,20 +56,18 @@ class ReadOp : public framework::OperatorBase { using framework::OperatorBase::OperatorBase; void Run(const framework::Scope& scope, const platform::Place& dev_place) const override { - const framework::ReaderHolder& reader = - scope.FindVar(Input("Reader"))->Get(); - if (!reader.HasNext()) { - // what shall we do??? + framework::ReaderHolder* reader = + scope.FindVar(Input("Reader"))->GetMutable(); + if (!reader->HasNext()) { return; } std::vector out_arg_names = Outputs("Out"); std::vector ins; - reader.ReadNext(&ins); + reader->ReadNext(&ins); PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); for (size_t i = 0; i < ins.size(); ++i) { auto* out = scope.FindVar(out_arg_names[i])->GetMutable(); - PADDLE_ENFORCE_EQ(ins[i].dims(), out->dims()); out->ShareDataWith(ins[i]); out->set_lod(ins[i].lod()); } @@ -86,9 +84,13 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker { Read Operator Execute a given reader once and output data. - )DOC") + )DOC"); } }; } // namespace operators -} // namespace paddle \ No newline at end of file +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(read, ops::ReadOp, ops::ReadInferShape, ops::ReadOpMaker, + paddle::framework::EmptyGradOpMaker, ops::ReadInferVarType); diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 0f1953abe0864c..0a92e10927caf0 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -217,8 +217,6 @@ void BindVarDsec(py::module &m) { .def("set_shapes", &VarDesc::SetShapes) .def("set_dtype", &VarDesc::SetDataType) .def("set_dtypes", &VarDesc::SetDataTypes) - .def("set_tensor_num", &VarDesc::SetTensorDescNum) - .def("tensor_num", &VarDesc::GetTensorDescNum) .def("shape", &VarDesc::GetShape, py::return_value_policy::reference) .def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference) .def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference) diff --git a/python/paddle/v2/fluid/executor.py b/python/paddle/v2/fluid/executor.py index 0eddcc3a5ab6f7..1bc3423f10cf09 100644 --- a/python/paddle/v2/fluid/executor.py +++ b/python/paddle/v2/fluid/executor.py @@ -51,7 +51,8 @@ def as_numpy(tensor): if len(lod) == 0: ans = tensor_data else: - raise RuntimeError("LoD Calculate lacks unit tests and buggy") + #raise RuntimeError("LoD Calculate lacks unit tests and buggy") + ans = tensor_data # elif len(lod) == 1: # ans = [] # idx = 0 diff --git a/python/paddle/v2/fluid/tests/test_cpp_reader.py b/python/paddle/v2/fluid/tests/test_cpp_reader.py new file mode 100644 index 00000000000000..cd5fff9425cb34 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_cpp_reader.py @@ -0,0 +1,71 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.v2 as paddle +import paddle.v2.fluid as fluid +import numpy as np + +prog = fluid.framework.Program() +block = prog.current_block() + +random_reader = block.create_var( + type=fluid.core.VarDesc.VarType.READER, name="RandomReader") +random_reader.desc.set_lod_levels([0, 0]) + +create_random_reader_op = block.append_op( + type="create_random_reader", + outputs={"Out": random_reader}, + attrs={ + "shape_concat": [1, 2, 1, 1], + "ranks": [2, 2], + "min": 0.0, + "max": 1.0 + }) + +batch_reader = block.create_var( + type=fluid.core.VarDesc.VarType.READER, name=("BatchReader")) +batch_reader.desc.set_lod_levels([0, 0]) + +create_batch_reader_op = block.append_op( + type="create_batch_reader", + inputs={"UnderlyingReader": random_reader}, + outputs={"Out": batch_reader}, + attrs={"batch_size": 10}) + +out1 = block.create_var( + type=fluid.core.VarDesc.VarType.LOD_TENSOR, + name="Out1", + shape=[10, 2], + dtype="float32", + lod_level=1) +out2 = block.create_var( + type=fluid.core.VarDesc.VarType.LOD_TENSOR, + name="Out2", + shape=[10, 1], + dtype="float32", + lod_level=1) + +read_op = block.append_op( + type="read", inputs={"Reader": batch_reader}, + outputs={"Out": [out1, out2]}) + +place = fluid.CPUPlace() +exe = fluid.Executor(place) + +[res1, res2] = exe.run(prog, fetch_list=[out1, out2]) + +if len(res1) == 0 or len(res2) == 0: + exit(1) + +exit(0) From 542bdef7a5142bbfebafc327ff393a8c1aa62214 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 7 Feb 2018 10:17:31 +0800 Subject: [PATCH 10/12] fix a unit test --- python/paddle/v2/fluid/tests/test_protobuf_descs.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_protobuf_descs.py b/python/paddle/v2/fluid/tests/test_protobuf_descs.py index 8f335d13db3ddb..c590bf1c6570a2 100644 --- a/python/paddle/v2/fluid/tests/test_protobuf_descs.py +++ b/python/paddle/v2/fluid/tests/test_protobuf_descs.py @@ -120,7 +120,6 @@ def test_multiple_shape(self): block = program_desc.block(0) var = block.var('my_reader') var.set_type(core.VarDesc.VarType.READER) - var.set_tensor_num(3) src_shapes = [[2, 3, 3], [4, 5], [6, 7, 8, 9]] var.set_shapes(src_shapes) res_shapes = var.shapes() @@ -141,7 +140,6 @@ def test_multiple_dtype(self): block = program_desc.block(0) var = block.var('my_reader') var.set_type(core.VarDesc.VarType.READER) - var.set_tensor_num(3) src_types = [ core.DataType.INT32, core.DataType.FP64, core.DataType.FP32 ] @@ -154,7 +152,6 @@ def test_multiple_lod_level(self): block = program_desc.block(0) var = block.var('my_reader') var.set_type(core.VarDesc.VarType.READER) - var.set_tensor_num(3) src_types = [3, 1, 2] var.set_lod_levels(src_types) self.assertEqual(src_types, var.lod_levels()) From b00cae60abdea7402baf70798885f9634b8eb0b0 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 7 Feb 2018 10:59:21 +0800 Subject: [PATCH 11/12] refine code --- python/paddle/v2/fluid/executor.py | 3 +-- python/paddle/v2/fluid/tests/test_cpp_reader.py | 13 ++----------- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/python/paddle/v2/fluid/executor.py b/python/paddle/v2/fluid/executor.py index 1bc3423f10cf09..0eddcc3a5ab6f7 100644 --- a/python/paddle/v2/fluid/executor.py +++ b/python/paddle/v2/fluid/executor.py @@ -51,8 +51,7 @@ def as_numpy(tensor): if len(lod) == 0: ans = tensor_data else: - #raise RuntimeError("LoD Calculate lacks unit tests and buggy") - ans = tensor_data + raise RuntimeError("LoD Calculate lacks unit tests and buggy") # elif len(lod) == 1: # ans = [] # idx = 0 diff --git a/python/paddle/v2/fluid/tests/test_cpp_reader.py b/python/paddle/v2/fluid/tests/test_cpp_reader.py index cd5fff9425cb34..7efcb0c46d2a7e 100644 --- a/python/paddle/v2/fluid/tests/test_cpp_reader.py +++ b/python/paddle/v2/fluid/tests/test_cpp_reader.py @@ -33,16 +33,6 @@ "max": 1.0 }) -batch_reader = block.create_var( - type=fluid.core.VarDesc.VarType.READER, name=("BatchReader")) -batch_reader.desc.set_lod_levels([0, 0]) - -create_batch_reader_op = block.append_op( - type="create_batch_reader", - inputs={"UnderlyingReader": random_reader}, - outputs={"Out": batch_reader}, - attrs={"batch_size": 10}) - out1 = block.create_var( type=fluid.core.VarDesc.VarType.LOD_TENSOR, name="Out1", @@ -57,7 +47,8 @@ lod_level=1) read_op = block.append_op( - type="read", inputs={"Reader": batch_reader}, + type="read", + inputs={"Reader": random_reader}, outputs={"Out": [out1, out2]}) place = fluid.CPUPlace() From c1349d98aa48060b449c4eea4dfc95a2989ad203 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 7 Feb 2018 14:43:14 +0800 Subject: [PATCH 12/12] fix compile errors --- paddle/framework/reader.cc | 2 ++ paddle/framework/reader.h | 11 ++++++++-- paddle/operators/CMakeLists.txt | 2 +- paddle/operators/create_reader_op.cc | 22 ++++++++++--------- paddle/operators/read_op.cc | 5 ++++- .../paddle/v2/fluid/tests/test_cpp_reader.py | 6 ++--- 6 files changed, 31 insertions(+), 17 deletions(-) diff --git a/paddle/framework/reader.cc b/paddle/framework/reader.cc index 86220cd0bbaf07..928b661aaadb4a 100644 --- a/paddle/framework/reader.cc +++ b/paddle/framework/reader.cc @@ -38,6 +38,8 @@ void ShuffleReader::ReadNext(std::vector* out) { break; } } + // TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be + // optimize. std::random_shuffle(buffer_.begin(), buffer_.end()); iteration_pos_ = 0; } diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index ff7153bc7bfb65..534894cfbd6668 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -28,6 +28,8 @@ class ReaderBase { virtual void ReadNext(std::vector* out) = 0; virtual bool HasNext() const = 0; + virtual void ReInit() = 0; + DDim shape(size_t idx) const; std::vector shapes() const { return shapes_; } void set_shapes(const std::vector& shapes) { shapes_ = shapes; } @@ -52,6 +54,8 @@ class DecoratedReader : public ReaderBase { bool HasNext() const override { return reader_->HasNext(); } + void ReInit() override { reader_->ReInit(); } + protected: ReaderBase* reader_; }; @@ -59,9 +63,9 @@ class DecoratedReader : public ReaderBase { // file readers template -class RandomReader : public FileReader { +class RandomDataGenerator : public FileReader { public: - RandomReader(const std::vector& shapes, float min, float max) + RandomDataGenerator(const std::vector& shapes, float min, float max) : FileReader(shapes), min_(min), max_(max) { PADDLE_ENFORCE_LE( min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max); @@ -91,6 +95,8 @@ class RandomReader : public FileReader { bool HasNext() const override { return true; } + void ReInit() override { return; } + private: float min_; float max_; @@ -139,6 +145,7 @@ class ReaderHolder { void ReadNext(std::vector* out) { reader_->ReadNext(out); } bool HasNext() const { return reader_->HasNext(); } + void ReInit() { reader_->ReInit(); } DDim shape(size_t idx) const { return reader_->shape(idx); } std::vector shapes() const { return reader_->shapes(); } diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index e1dba8bb3f9a72..25bb7187d36c5f 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -186,7 +186,7 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) op_library(${src}) endforeach() -file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_reader);\n") +file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_data_generator);\n") set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") diff --git a/paddle/operators/create_reader_op.cc b/paddle/operators/create_reader_op.cc index 11c77a06032de2..5ba2a25ab4c679 100644 --- a/paddle/operators/create_reader_op.cc +++ b/paddle/operators/create_reader_op.cc @@ -18,8 +18,8 @@ namespace paddle { namespace operators { -std::vector RestoreShapes(const std::vector& shape_concat, - const std::vector& ranks) { +static std::vector RestoreShapes( + const std::vector& shape_concat, const std::vector& ranks) { std::vector res; int offset = 0; for (int len : ranks) { @@ -69,7 +69,7 @@ class CreateReaderInferVarType : public framework::VarTypeInference { }; template -class CreateRandomReaderOp : public framework::OperatorBase { +class CreateRandomDataGeneratorOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; void Run(const framework::Scope& scope, @@ -84,14 +84,15 @@ class CreateRandomReaderOp : public framework::OperatorBase { std::vector shapes = RestoreShapes(shape_concat, ranks); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new framework::RandomReader(shapes, Attr("min"), - Attr("max"))); + out->Reset(new framework::RandomDataGenerator(shapes, Attr("min"), + Attr("max"))); } }; -class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { +class CreateRandomDataGeneratorOpMaker + : public framework::OpProtoAndCheckerMaker { public: - CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(op_proto, op_checker) { AddOutput("Out", "(ReaderHolder) The created random reader."); AddAttr>("shape_concat", @@ -107,7 +108,7 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("min", "The lower bound of reader's uniform distribution."); AddAttr("max", "The upper bound of reader's uniform distribution."); AddComment(R"DOC( - CreateRandomReader Operator + CreateRandomDataGenerator Operator This Op creates a random reader. The reader generates random data instead of really reading from files. @@ -186,9 +187,10 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp, +REGISTER_OPERATOR(create_random_data_generator, + ops::CreateRandomDataGeneratorOp, ops::CreateFileReaderInferShape, - ops::CreateRandomReaderOpMaker, + ops::CreateRandomDataGeneratorOpMaker, paddle::framework::EmptyGradOpMaker, ops::CreateReaderInferVarType); REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp, diff --git a/paddle/operators/read_op.cc b/paddle/operators/read_op.cc index 3d17b26c998fc2..3ae454101f585c 100644 --- a/paddle/operators/read_op.cc +++ b/paddle/operators/read_op.cc @@ -59,7 +59,10 @@ class ReadOp : public framework::OperatorBase { framework::ReaderHolder* reader = scope.FindVar(Input("Reader"))->GetMutable(); if (!reader->HasNext()) { - return; + reader->ReInit(); + PADDLE_ENFORCE( + reader->HasNext(), + "Reader can not read the next data even it has been re-initialized."); } std::vector out_arg_names = Outputs("Out"); std::vector ins; diff --git a/python/paddle/v2/fluid/tests/test_cpp_reader.py b/python/paddle/v2/fluid/tests/test_cpp_reader.py index 7efcb0c46d2a7e..e71c3a290c9b12 100644 --- a/python/paddle/v2/fluid/tests/test_cpp_reader.py +++ b/python/paddle/v2/fluid/tests/test_cpp_reader.py @@ -20,11 +20,11 @@ block = prog.current_block() random_reader = block.create_var( - type=fluid.core.VarDesc.VarType.READER, name="RandomReader") + type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator") random_reader.desc.set_lod_levels([0, 0]) -create_random_reader_op = block.append_op( - type="create_random_reader", +create_random_data_generator_op = block.append_op( + type="create_random_data_generator", outputs={"Out": random_reader}, attrs={ "shape_concat": [1, 2, 1, 1],