Add type_alias to import framework into ops
Make implementing an operator less noisy.
reyoung committed Jul 25, 2017
1 parent 0c2790f commit efc119b
Showing 23 changed files with 205 additions and 232 deletions.
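
The new paddle/operators/type_alias.h header is not among the files shown below, but every changed file relies on it for the short names (Tensor, OperatorWithKernel, OpKernel, KernelContext, OpProto, OpAttrChecker, OpProtoAndCheckerMaker, PlainNet, OpRegistry, EigenVector, EigenMatrix, CPUPlace, GPUPlace) and for the ops namespace alias used in the REGISTER_* macros. A minimal sketch of what such a header might contain, reconstructed only from those usages and not copied from the commit, is:

// Hypothetical sketch of paddle/operators/type_alias.h.
// Reconstructed from the aliases the changed files use; not copied from the commit.
#pragma once

#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {

// Framework types, usable without the framework:: prefix inside operators.
using OpKernel = framework::OpKernel;
using KernelContext = framework::KernelContext;
using OperatorBase = framework::OperatorBase;
using OperatorWithKernel = framework::OperatorWithKernel;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpRegistry = framework::OpRegistry;
using PlainNet = framework::PlainNet;
using Tensor = framework::Tensor;

// Eigen helpers; the kernels below only ever pass the element type T.
template <typename T>
using EigenVector = framework::EigenVector<T>;
template <typename T>
using EigenMatrix = framework::EigenMatrix<T>;

// Device places, so kernels can be registered as ops::CPUPlace / ops::GPUPlace.
using CPUPlace = platform::CPUPlace;
using GPUPlace = platform::GPUPlace;

}  // namespace operators
}  // namespace paddle

// Short namespace alias used by the REGISTER_* macros in the .cc/.cu files.
namespace ops = paddle::operators;

With these aliases in scope, an operator only needs to include paddle/operators/type_alias.h (directly, or through its own header such as add_op.h) to write OperatorWithKernel, Tensor, or ops::AddKernel<ops::CPUPlace, float> instead of the fully qualified paddle::framework:: and paddle::platform:: names, which is why the diffs below are mostly mechanical renames.
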
29 changes: 12 additions & 17 deletions paddle/operators/add_op.cc
@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {

class AddOp : public framework::OperatorWithKernel {
class AddOp : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(
@@ -35,10 +32,10 @@ class AddOp : public framework::OperatorWithKernel {
}
};

class AddOpMaker : public framework::OpProtoAndCheckerMaker {
class AddOpMaker : public OpProtoAndCheckerMaker {
public:
AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op");
AddOutput("Out", "The output of add op");
@@ -50,11 +47,10 @@ The equation is: Out = X + Y
}
};

class AddOpGrad : public framework::OperatorWithKernel {
class AddOpGrad : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {}
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override {
LOG(INFO) << "AddOpGrad";
return "";
@@ -64,7 +60,6 @@ class AddOpGrad : public framework::OperatorWithKernel {
} // namespace operators
} // namespace paddle

REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);
REGISTER_OP_CPU_KERNEL(
add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel<ops::CPUPlace, float>);
5 changes: 2 additions & 3 deletions paddle/operators/add_op.cu
@@ -1,5 +1,4 @@
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/add_op.h"

REGISTER_OP_GPU_KERNEL(add_two,
paddle::operators::AddKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(add_two, ops::AddKernel<ops::GPUPlace, float>);
19 changes: 8 additions & 11 deletions paddle/operators/add_op.h
@@ -13,27 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class AddKernel : public framework::OpKernel {
class AddKernel : public OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto input0 = context.Input(0)->Get<framework::Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
void Compute(const KernelContext& context) const override {
auto input0 = context.Input(0)->Get<Tensor>();
auto input1 = context.Input(1)->Get<Tensor>();
auto output = context.Output(0)->GetMutable<Tensor>();

output->mutable_data<T>(context.GetPlace());

framework::EigenVector<T>::Flatten(*output).device(
EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
framework::EigenVector<T>::Flatten(input0) +
framework::EigenVector<T>::Flatten(input1);
EigenVector<T>::Flatten(input0) + EigenVector<T>::Flatten(input1);
}
};

28 changes: 11 additions & 17 deletions paddle/operators/cross_entropy_op.cc
@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {

class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
class OnehotCrossEntropyOp : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2,
"Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE(outputs.size() == 1,
@@ -35,15 +32,14 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
PADDLE_ENFORCE(outputs[0]->dims().size() == 1,
"label's dimension must be 1.");
outputs[0]->Resize(framework::make_ddim({inputs[0]->dims()[0]}));
outputs[0]->Resize({inputs[0]->dims()[0]});
}
};

class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker {
public:
OnehotCrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of OnehotCrossEntropyOp");
AddInput("label", "The second input of OnehotCrossEntropyOp");
AddOutput("Y", "The output of OnehotCrossEntropyOp");
@@ -59,9 +55,7 @@ OnehotCrossEntropy Operator.
} // namespace paddle

REGISTER_OP(onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOp,
paddle::operators::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOpKernel<::paddle::platform::CPUPlace,
float>);
ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>);
4 changes: 1 addition & 3 deletions paddle/operators/cross_entropy_op.cu
@@ -1,6 +1,4 @@
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"

REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOpKernel<
::paddle::platform::GPUPlace, float>);
ops::OnehotCrossEntropyOpKernel<ops::GPUPlace, float>);
14 changes: 6 additions & 8 deletions paddle/operators/cross_entropy_op.h
@@ -13,23 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "glog/logging.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public framework::OpKernel {
class OnehotCrossEntropyOpKernel : public OpKernel {
public:
constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }

void Compute(const framework::KernelContext& context) const override {
auto X = context.Input(0)->Get<framework::Tensor>();
void Compute(const KernelContext& context) const override {
auto X = context.Input(0)->Get<Tensor>();
const T* X_data = X.data<T>();
const int* label_data =
context.Input(1)->Get<framework::Tensor>().data<int>();
auto* Y = context.Output(0)->GetMutable<framework::Tensor>();
const int* label_data = context.Input(1)->Get<Tensor>().data<int>();
auto* Y = context.Output(0)->GetMutable<Tensor>();

Y->mutable_data<T>(context.GetPlace());

39 changes: 17 additions & 22 deletions paddle/operators/fc_op.cc
@@ -12,41 +12,38 @@
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "type_alias.h"

namespace paddle {
namespace operators {

class FullyConnectedOp : public framework::PlainNet {
class FullyConnectedOp : public PlainNet {
public:
void Init() override {
AddOp(framework::OpRegistry::CreateOp("mul",
{
Input("X"), Input("W"),
},
{Output("before_act")},
{}));
AddOp(OpRegistry::CreateOp("mul",
{
Input("X"), Input("W"),
},
{Output("before_act")},
{}));
auto b = Input("b");
if (b != framework::OperatorBase::EMPTY_VAR_NAME()) {
AddOp(framework::OpRegistry::CreateOp("rowwise_add",
{Output("before_act"), Input("b")},
{Output("before_act")},
{}));
if (b != EMPTY_VAR_NAME()) {
AddOp(OpRegistry::CreateOp("rowwise_add",
{Output("before_act"), Input("b")},
{Output("before_act")},
{}));
}

auto activation = GetAttr<std::string>("activation");
AddOp(framework::OpRegistry::CreateOp(
AddOp(OpRegistry::CreateOp(
activation, {Output("before_act")}, {Output("Y")}, {}));
CompleteAddOp(false);
}
};

class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker {
class FullyConnectedOpMaker : public OpProtoAndCheckerMaker {
public:
FullyConnectedOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
FullyConnectedOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input of fc operator");
AddInput("W", "the weight of fc operator");
@@ -71,6 +68,4 @@ USE_OP(rowwise_add);
USE_OP(sigmoid);
USE_OP(softmax);

REGISTER_OP(fc,
paddle::operators::FullyConnectedOp,
paddle::operators::FullyConnectedOpMaker);
REGISTER_OP(fc, ops::FullyConnectedOp, ops::FullyConnectedOpMaker);
29 changes: 12 additions & 17 deletions paddle/operators/mul_op.cc
@@ -13,17 +13,14 @@
limitations under the License. */

#include "paddle/operators/mul_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {

class MulOp : public framework::OperatorWithKernel {
class MulOp : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "The mul op must take two inputs");
auto dim0 = inputs[0]->dims();
auto dim1 = inputs[1]->dims();
@@ -37,10 +34,10 @@ class MulOp : public framework::OperatorWithKernel {
}
};

class MulOpMaker : public framework::OpProtoAndCheckerMaker {
class MulOpMaker : public OpProtoAndCheckerMaker {
public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op");
@@ -52,11 +49,10 @@ The equation is: Out = X * Y
}
};

class MulOpGrad : public framework::OperatorWithKernel {
class MulOpGrad : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {}
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override {
LOG(INFO) << "MulGrad";
return "";
@@ -66,8 +62,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
} // namespace operators
} // namespace paddle

REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, paddle::operators::MulOpGrad);
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad);

REGISTER_OP_CPU_KERNEL(
mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<ops::CPUPlace, float>);
5 changes: 1 addition & 4 deletions paddle/operators/mul_op.cu
@@ -13,8 +13,5 @@
limitations under the License. */

#include "paddle/operators/mul_op.h"
#include "paddle/framework/op_registry.h"

REGISTER_OP_GPU_KERNEL(mul,
paddle::operators::MulKernel<paddle::platform
::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
21 changes: 9 additions & 12 deletions paddle/operators/mul_op.h
@@ -14,30 +14,27 @@

#pragma once

#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
class MulKernel : public OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
void Compute(const KernelContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};

auto input0 = context.Input(0)->Get<framework::Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
auto input0 = context.Input(0)->Get<Tensor>();
auto input1 = context.Input(1)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<Tensor>();

output->mutable_data<T>(context.GetPlace());

framework::EigenMatrix<T>::From(*output).device(
*(context.GetEigenDevice<Place>())) =
framework::EigenMatrix<T>::From(input0).contract(
framework::EigenMatrix<T>::From(input1), dim_pair);
EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
EigenMatrix<T>::From(input0).contract(EigenMatrix<T>::From(input1),
dim_pair);
}
};
} // namespace operators