From 62dee178fa4e789e68522158c9a1858550e66cdc Mon Sep 17 00:00:00 2001 From: strint Date: Wed, 12 Apr 2023 14:22:21 +0000 Subject: [PATCH 01/18] add task to/from proto --- .../core/graph/boxing_identity_task_node.cpp | 14 +++ .../core/graph/boxing_identity_task_node.h | 4 + oneflow/core/graph/boxing_task_graph.proto | 110 ++++++++++++++++++ oneflow/core/graph/boxing_zeros_task_node.cpp | 20 ++++ oneflow/core/graph/boxing_zeros_task_node.h | 4 + .../collective_boxing_pack_task_node.cpp | 24 ++++ .../graph/collective_boxing_pack_task_node.h | 4 + .../graph/collective_boxing_task_node.cpp | 16 +++ .../core/graph/collective_boxing_task_node.h | 4 + .../collective_boxing_unpack_task_node.cpp | 24 ++++ .../collective_boxing_unpack_task_node.h | 4 + oneflow/core/graph/compute_task_node.cpp | 4 +- oneflow/core/graph/compute_task_node.h | 2 +- oneflow/core/graph/copy_task_node.cpp | 26 +++++ oneflow/core/graph/copy_task_node.h | 12 +- .../graph/nccl_send_recv_boxing_task_node.cpp | 39 +++++++ .../graph/nccl_send_recv_boxing_task_node.h | 6 + oneflow/core/graph/slice_boxing_task_node.cpp | 36 ++++++ oneflow/core/graph/slice_boxing_task_node.h | 11 +- oneflow/core/graph/task_edge.proto | 12 ++ oneflow/core/graph/task_graph_rebuild_ctx.cpp | 67 +++++++++++ oneflow/core/graph/task_graph_rebuild_ctx.h | 60 ++++++++++ oneflow/core/graph/task_node.cpp | 30 ++++- oneflow/core/graph/task_node.h | 11 +- oneflow/core/graph/transport_task_node.cpp | 46 ++++++++ oneflow/core/graph/transport_task_node.h | 10 ++ oneflow/core/register/register_desc.cpp | 14 ++- oneflow/core/register/register_desc.h | 3 +- 28 files changed, 596 insertions(+), 21 deletions(-) create mode 100644 oneflow/core/graph/boxing_task_graph.proto create mode 100644 oneflow/core/graph/task_edge.proto create mode 100644 oneflow/core/graph/task_graph_rebuild_ctx.cpp create mode 100644 oneflow/core/graph/task_graph_rebuild_ctx.h create mode 100644 oneflow/core/graph/transport_task_node.cpp diff --git a/oneflow/core/graph/boxing_identity_task_node.cpp b/oneflow/core/graph/boxing_identity_task_node.cpp index 0420af0e552..cbe7969e888 100644 --- a/oneflow/core/graph/boxing_identity_task_node.cpp +++ b/oneflow/core/graph/boxing_identity_task_node.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/to_string.h" #include "oneflow/core/graph/boxing_identity_task_node.h" +#include "oneflow/core/graph/boxing_task_graph.pb.h" namespace oneflow { @@ -53,4 +54,17 @@ void BoxingIdentityTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } +Maybe BoxingIdentityTaskNode::InitTransportTaskFromProto( + const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { + CHECK_OR_RETURN(transport_task_proto.has_boxing_identity_task()) + << "not a serialized BoxingIdentityTaskNode. debug string: " + << transport_task_proto.DebugString(); + return Maybe::Ok(); +} + +void BoxingIdentityTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const { + ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); + transport_task_proto->mutable_boxing_identity_task(); +} + } // namespace oneflow diff --git a/oneflow/core/graph/boxing_identity_task_node.h b/oneflow/core/graph/boxing_identity_task_node.h index 4a1935a71cd..34fd026e4bd 100644 --- a/oneflow/core/graph/boxing_identity_task_node.h +++ b/oneflow/core/graph/boxing_identity_task_node.h @@ -29,6 +29,10 @@ class BoxingIdentityTaskNode : public TransportTaskNode { void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi); TaskType GetTaskType() const override { return TaskType::kBoxingIdentity; } + Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, + const TaskGraphRebuildCtx& ctx) override; + void ToTransportTaskProto(TransportTaskProto*) const override; + private: void BuildExecGphAndRegst() override; void ProduceAllRegstsAndBindEdges() override; diff --git a/oneflow/core/graph/boxing_task_graph.proto b/oneflow/core/graph/boxing_task_graph.proto new file mode 100644 index 00000000000..e732bdf82f0 --- /dev/null +++ b/oneflow/core/graph/boxing_task_graph.proto @@ -0,0 +1,110 @@ +syntax = "proto2"; +package oneflow; + +import "oneflow/core/register/logical_blob_id.proto"; +import "oneflow/core/common/shape.proto"; +import "oneflow/core/common/data_type.proto"; +import "oneflow/core/job/sbp_parallel.proto"; +import "oneflow/core/job/task.proto"; +import "oneflow/core/job/placement.proto"; +import "oneflow/core/graph/task_edge.proto"; +import "oneflow/core/operator/op_conf.proto"; +import "oneflow/core/register/tensor_slice_view.proto"; + +message ComputeTasksProto { + map parallel_id2task = 2; +} + +message CollectiveBoxingGenericTaskProto { + required OperatorConf op_conf = 1; +} + +message NcclSendRecvBoxingTaskProto { + required ShapeProto logical_shape = 1; + required DataType data_type = 2; + required NdSbp src_nd_sbp = 3; + required NdSbp dst_nd_sbp = 4; + required ParallelConf src_parallel_conf = 5; + required ParallelConf dst_parallel_conf = 6; + required ParallelConf parallel_conf = 7; + required ParallelContext parallel_ctx = 8; + required bool has_input = 9; + required bool has_output = 10; + required string stream_name = 11; +} + +enum CopyHdType { + H2D = 0; + D2H = 1; +} + +message CopyHdTaskProto { + required CopyHdType copy_type = 1; +} + +message CopyCommNetTaskProto { +} + +message BoxingZerosTaskProto { + required ShapeProto shape = 1; + required DataType data_type = 2; + required ShapeProto time_shape = 3; +} + +enum SliceBoxingTaskMode { + kSliceBoxingTaskModeInvalid = 0; + kSliceBoxingTaskModeCopy = 1; + kSliceBoxingTaskModeAdd = 2; +} + +message SliceBoxingTaskProto { + map in_data_edge_uid2slice = 1; + repeated int64 ordered_in_data_edge_uid = 2; + required TensorSliceViewProto out_slice = 3; + required ShapeProto out_shape = 4; + required SliceBoxingTaskMode mode = 5; +} + +message CollectiveBoxingPackTaskProto { + required ShapeProto logical_shape = 1; + required SbpParallel src_sbp_parallel = 2; + required SbpParallel dst_sbp_parallel = 3; + required int64 parallel_num = 4; +} + +message CollectiveBoxingUnpackTaskProto { + required ShapeProto logical_shape = 1; + required SbpParallel src_sbp_parallel = 2; + required SbpParallel dst_sbp_parallel = 3; + required int64 parallel_num = 4; +} + +message BoxingIdentityTaskProto { +} + +message TransportTaskProto { + required TaskProto task_proto = 1; + required LogicalBlobId lbi = 11; + oneof transport_task_type { + CollectiveBoxingGenericTaskProto collective_boxing_generic_task = 2; + NcclSendRecvBoxingTaskProto nccl_send_recv_boxing_task = 3; + CopyHdTaskProto copy_hd_task = 4; + CopyCommNetTaskProto copy_comm_net_task = 5; + BoxingZerosTaskProto boxing_zeros_task = 6; + SliceBoxingTaskProto slice_boxing_task = 7; + CollectiveBoxingPackTaskProto collective_boxing_pack_task = 8; + CollectiveBoxingUnpackTaskProto collective_boxing_unpack_task = 9; + BoxingIdentityTaskProto boxing_identity_task = 10; + } +} + +message TaskIdsProto { + repeated int64 task_id = 1; +} + +message BoxingTaskGraphProto { + map boxing_related_op_name2compute_tasks = 1; + repeated TransportTaskProto transport_task = 2; + repeated TaskEdgeProto task_edge = 3; + map boxing_unrelated_op_name2task_ids = 4; +} \ No newline at end of file diff --git a/oneflow/core/graph/boxing_zeros_task_node.cpp b/oneflow/core/graph/boxing_zeros_task_node.cpp index 3984f21d7ee..ce7f2e65e9d 100644 --- a/oneflow/core/graph/boxing_zeros_task_node.cpp +++ b/oneflow/core/graph/boxing_zeros_task_node.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/to_string.h" #include "oneflow/core/graph/boxing_zeros_task_node.h" +#include "oneflow/core/graph/boxing_task_graph.pb.h" namespace oneflow { @@ -56,5 +57,24 @@ void BoxingZerosTaskNode::BuildExecGphAndRegst() { void BoxingZerosTaskNode::InferProducedDataRegstTimeShape() { GetProducedRegst("out")->mut_data_regst_time_shape()->reset(new Shape(time_shape_)); } +Maybe BoxingZerosTaskNode::InitTransportTaskFromProto( + const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { + CHECK_OR_RETURN(transport_task_proto.has_boxing_zeros_task()) + << "not a serialized BoxingZerosTaskNode. debug string: " + << transport_task_proto.DebugString(); + const auto& proto = transport_task_proto.boxing_zeros_task(); + shape_ = Shape(proto.shape()); + data_type_ = proto.data_type(); + time_shape_ = Shape(proto.time_shape()); + return Maybe::Ok(); +} + +void BoxingZerosTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const { + ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); + auto* proto = transport_task_proto->mutable_boxing_zeros_task(); + shape_.ToProto(proto->mutable_shape()); + proto->set_data_type(data_type_); + time_shape_.ToProto(proto->mutable_time_shape()); +} } // namespace oneflow diff --git a/oneflow/core/graph/boxing_zeros_task_node.h b/oneflow/core/graph/boxing_zeros_task_node.h index 5a99df1d5ea..9594fddd367 100644 --- a/oneflow/core/graph/boxing_zeros_task_node.h +++ b/oneflow/core/graph/boxing_zeros_task_node.h @@ -30,6 +30,10 @@ class BoxingZerosTaskNode : public TransportTaskNode { DataType data_type, const Shape& time_shape); TaskType GetTaskType() const override { return TaskType::kBoxingZeros; } + Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, + const TaskGraphRebuildCtx& ctx) override; + void ToTransportTaskProto(TransportTaskProto*) const override; + private: void BuildExecGphAndRegst() override; void ProduceAllRegstsAndBindEdges() override; diff --git a/oneflow/core/graph/collective_boxing_pack_task_node.cpp b/oneflow/core/graph/collective_boxing_pack_task_node.cpp index 6ee44f90d62..5225591b036 100644 --- a/oneflow/core/graph/collective_boxing_pack_task_node.cpp +++ b/oneflow/core/graph/collective_boxing_pack_task_node.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/to_string.h" #include "oneflow/core/graph/collective_boxing_pack_task_node.h" +#include "oneflow/core/graph/boxing_task_graph.pb.h" namespace oneflow { @@ -66,4 +67,27 @@ void CollectiveBoxingPackTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } +Maybe CollectiveBoxingPackTaskNode::InitTransportTaskFromProto( + const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { + CHECK_OR_RETURN(transport_task_proto.has_collective_boxing_pack_task()) + << "not a serialized CollectiveBoxingPackTaskNode. debug string: " + << transport_task_proto.DebugString(); + const auto& proto = transport_task_proto.collective_boxing_pack_task(); + logical_shape_ = Shape(proto.logical_shape()); + src_sbp_parallel_ = proto.src_sbp_parallel(); + dst_sbp_parallel_ = proto.dst_sbp_parallel(); + parallel_num_ = proto.parallel_num(); + return Maybe::Ok(); +} + +void CollectiveBoxingPackTaskNode::ToTransportTaskProto( + TransportTaskProto* transport_task_proto) const { + ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); + auto* proto = transport_task_proto->mutable_collective_boxing_pack_task(); + logical_shape_.ToProto(proto->mutable_logical_shape()); + *proto->mutable_src_sbp_parallel() = src_sbp_parallel_; + *proto->mutable_dst_sbp_parallel() = dst_sbp_parallel_; + proto->set_parallel_num(parallel_num_); +} + } // namespace oneflow diff --git a/oneflow/core/graph/collective_boxing_pack_task_node.h b/oneflow/core/graph/collective_boxing_pack_task_node.h index 9230019e4f7..e3318bd4b9e 100644 --- a/oneflow/core/graph/collective_boxing_pack_task_node.h +++ b/oneflow/core/graph/collective_boxing_pack_task_node.h @@ -31,6 +31,10 @@ class CollectiveBoxingPackTaskNode : public TransportTaskNode { const SbpParallel& dst_sbp_parallel, const int64_t parallel_num); TaskType GetTaskType() const override { return TaskType::kCollectiveBoxingPack; } + Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, + const TaskGraphRebuildCtx& ctx) override; + void ToTransportTaskProto(TransportTaskProto*) const override; + private: void BuildExecGphAndRegst() override; void ProduceAllRegstsAndBindEdges() override; diff --git a/oneflow/core/graph/collective_boxing_task_node.cpp b/oneflow/core/graph/collective_boxing_task_node.cpp index 08d3ee37d8f..75b2459a582 100644 --- a/oneflow/core/graph/collective_boxing_task_node.cpp +++ b/oneflow/core/graph/collective_boxing_task_node.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/graph/boxing_task_graph.pb.h" #include "oneflow/core/graph/collective_boxing_task_node.h" #include "oneflow/core/graph/boxing/collective_boxing_util.h" @@ -62,4 +63,19 @@ void CollectiveBoxingGenericTaskNode::InferProducedDataRegstTimeShape() { if (out_regst != nullptr) { out_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1})); } } +Maybe CollectiveBoxingGenericTaskNode::InitTransportTaskFromProto( + const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { + CHECK_OR_RETURN(transport_task_proto.has_collective_boxing_generic_task()) + << "not a serialized CollectiveBoxingGenericTaskNode. debug string: " + << transport_task_proto.DebugString(); + op_conf_ = transport_task_proto.collective_boxing_generic_task().op_conf(); + return Maybe::Ok(); +} + +void CollectiveBoxingGenericTaskNode::ToTransportTaskProto( + TransportTaskProto* transport_task_proto) const { + ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); + *transport_task_proto->mutable_collective_boxing_generic_task()->mutable_op_conf() = op_conf_; +} + } // namespace oneflow diff --git a/oneflow/core/graph/collective_boxing_task_node.h b/oneflow/core/graph/collective_boxing_task_node.h index 12ee316f22c..b3eb91a8a04 100644 --- a/oneflow/core/graph/collective_boxing_task_node.h +++ b/oneflow/core/graph/collective_boxing_task_node.h @@ -29,6 +29,10 @@ class CollectiveBoxingGenericTaskNode : public TransportTaskNode { void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, const OperatorConf& op_conf); + Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, + const TaskGraphRebuildCtx& ctx) override; + void ToTransportTaskProto(TransportTaskProto*) const override; + private: void BuildExecGphAndRegst() override; void ProduceAllRegstsAndBindEdges() override; diff --git a/oneflow/core/graph/collective_boxing_unpack_task_node.cpp b/oneflow/core/graph/collective_boxing_unpack_task_node.cpp index 1a8dc4127a5..624c959679b 100644 --- a/oneflow/core/graph/collective_boxing_unpack_task_node.cpp +++ b/oneflow/core/graph/collective_boxing_unpack_task_node.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/to_string.h" +#include "oneflow/core/graph/boxing_task_graph.pb.h" #include "oneflow/core/graph/collective_boxing_unpack_task_node.h" namespace oneflow { @@ -66,4 +67,27 @@ void CollectiveBoxingUnpackTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } +Maybe CollectiveBoxingUnpackTaskNode::InitTransportTaskFromProto( + const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { + CHECK_OR_RETURN(transport_task_proto.has_collective_boxing_unpack_task()) + << "not a serialized CollectiveBoxingUnpackTaskNode. debug string: " + << transport_task_proto.DebugString(); + const auto& proto = transport_task_proto.collective_boxing_unpack_task(); + logical_shape_ = Shape(proto.logical_shape()); + src_sbp_parallel_ = proto.src_sbp_parallel(); + dst_sbp_parallel_ = proto.dst_sbp_parallel(); + parallel_num_ = proto.parallel_num(); + return Maybe::Ok(); +} + +void CollectiveBoxingUnpackTaskNode::ToTransportTaskProto( + TransportTaskProto* transport_task_proto) const { + ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); + auto* proto = transport_task_proto->mutable_collective_boxing_unpack_task(); + logical_shape_.ToProto(proto->mutable_logical_shape()); + *proto->mutable_src_sbp_parallel() = src_sbp_parallel_; + *proto->mutable_dst_sbp_parallel() = dst_sbp_parallel_; + proto->set_parallel_num(parallel_num_); +} + } // namespace oneflow diff --git a/oneflow/core/graph/collective_boxing_unpack_task_node.h b/oneflow/core/graph/collective_boxing_unpack_task_node.h index 4d74b6952c7..c6738459785 100644 --- a/oneflow/core/graph/collective_boxing_unpack_task_node.h +++ b/oneflow/core/graph/collective_boxing_unpack_task_node.h @@ -32,6 +32,10 @@ class CollectiveBoxingUnpackTaskNode : public TransportTaskNode { TaskType GetTaskType() const override { return TaskType::kCollectiveBoxingUnpack; } + Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, + const TaskGraphRebuildCtx& ctx) override; + void ToTransportTaskProto(TransportTaskProto*) const override; + private: void BuildExecGphAndRegst() override; void ProduceAllRegstsAndBindEdges() override; diff --git a/oneflow/core/graph/compute_task_node.cpp b/oneflow/core/graph/compute_task_node.cpp index c6bff0f3f8b..f6547053970 100644 --- a/oneflow/core/graph/compute_task_node.cpp +++ b/oneflow/core/graph/compute_task_node.cpp @@ -67,8 +67,8 @@ std::vector GetCompTaskNodesOnEdge( std::string CompTaskNode::VisualStr() const { return op_node_->op().op_name(); } -void CompTaskNode::ToProto(TaskProto* task_proto) const { - TaskNode::ToProto(task_proto); +void CompTaskNode::ToProto(TaskProto* task_proto, bool check) const { + TaskNode::ToProto(task_proto, check); *(task_proto->mutable_parallel_ctx()) = parallel_ctx_; } diff --git a/oneflow/core/graph/compute_task_node.h b/oneflow/core/graph/compute_task_node.h index 5c78a487f53..05b7cb5c190 100644 --- a/oneflow/core/graph/compute_task_node.h +++ b/oneflow/core/graph/compute_task_node.h @@ -28,7 +28,7 @@ class CompTaskNode : public TaskNode { CompTaskNode() = default; virtual ~CompTaskNode() = default; - virtual void ToProto(TaskProto*) const override; + virtual void ToProto(TaskProto*, bool check) const override; // parallel_ctx_ int64_t parallel_id() const { return parallel_ctx_.parallel_id(); } diff --git a/oneflow/core/graph/copy_task_node.cpp b/oneflow/core/graph/copy_task_node.cpp index 6863845e576..490fb6643f0 100644 --- a/oneflow/core/graph/copy_task_node.cpp +++ b/oneflow/core/graph/copy_task_node.cpp @@ -122,4 +122,30 @@ OperatorConf CopyCommNetTaskNode::NewCopyOpConf() { return conf; } +Maybe CopyHdTaskNode::InitTransportTaskFromProto( + const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { + CHECK_OR_RETURN(transport_task_proto.has_copy_hd_task()) + << "not a serialized CopyHdTaskNode. debug string: " << transport_task_proto.DebugString(); + copy_type_ = transport_task_proto.copy_hd_task().copy_type(); + return Maybe::Ok(); +} + +void CopyHdTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const { + ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); + transport_task_proto->mutable_copy_hd_task()->set_copy_type(copy_type_); +} + +Maybe CopyCommNetTaskNode::InitTransportTaskFromProto( + const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { + CHECK_OR_RETURN(transport_task_proto.has_copy_comm_net_task()) + << "not a serialized CopyCommNetTaskNode. debug string: " + << transport_task_proto.DebugString(); + return Maybe::Ok(); +} + +void CopyCommNetTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const { + ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); + transport_task_proto->mutable_copy_comm_net_task(); +} + } // namespace oneflow diff --git a/oneflow/core/graph/copy_task_node.h b/oneflow/core/graph/copy_task_node.h index a768f65dd3f..ffce2dead94 100644 --- a/oneflow/core/graph/copy_task_node.h +++ b/oneflow/core/graph/copy_task_node.h @@ -17,9 +17,9 @@ limitations under the License. #define ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_ #include "oneflow/core/graph/transport_task_node.h" +#include "oneflow/core/graph/boxing_task_graph.pb.h" namespace oneflow { - class CopyTaskNode : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CopyTaskNode); @@ -37,8 +37,6 @@ class CopyTaskNode : public TransportTaskNode { void InferProducedDataRegstTimeShape() final; }; -enum CopyHdType { H2D = 0, D2H = 1 }; - class CopyHdTaskNode final : public CopyTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CopyHdTaskNode); @@ -63,6 +61,10 @@ class CopyHdTaskNode final : public CopyTaskNode { return kInvalidMemZoneId; } + Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, + const TaskGraphRebuildCtx& ctx) override; + void ToTransportTaskProto(TransportTaskProto*) const override; + private: void InitProducedRegstMemCase(MemoryCase*) override; OperatorConf NewCopyOpConf() override; @@ -80,6 +82,10 @@ class CopyCommNetTaskNode final : public CopyTaskNode { void Init(int64_t machine_id, const LogicalBlobId& lbi); + Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, + const TaskGraphRebuildCtx& ctx) override; + void ToTransportTaskProto(TransportTaskProto*) const override; + private: OperatorConf NewCopyOpConf() override; }; diff --git a/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp b/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp index e6ab2530c36..85466658e3f 100644 --- a/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp +++ b/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp @@ -15,6 +15,8 @@ limitations under the License. */ #include "oneflow/core/framework/to_string.h" #include "oneflow/core/graph/nccl_send_recv_boxing_task_node.h" +#include "oneflow/core/graph/boxing_task_graph.pb.h" +#include "oneflow/core/job/placement.pb.h" namespace oneflow { @@ -93,4 +95,41 @@ void NcclSendRecvBoxingTaskNode::InferProducedDataRegstTimeShape() { tmp_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1})); } +Maybe NcclSendRecvBoxingTaskNode::InitTransportTaskFromProto( + const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { + CHECK_OR_RETURN(transport_task_proto.has_nccl_send_recv_boxing_task()) + << "not a serialized NcclSendRecvBoxingTaskNode. debug string: " + << transport_task_proto.DebugString(); + const auto& proto = transport_task_proto.nccl_send_recv_boxing_task(); + logical_shape_ = Shape(proto.logical_shape()); + data_type_ = proto.data_type(); + src_nd_sbp_ = proto.src_nd_sbp(); + dst_nd_sbp_ = proto.dst_nd_sbp(); + src_parallel_conf_ = proto.src_parallel_conf(); + dst_parallel_conf_ = proto.dst_parallel_conf(); + parallel_conf_ = proto.parallel_conf(); + parallel_ctx_ = proto.parallel_ctx(); + has_input_ = proto.has_input(); + has_output_ = proto.has_output(); + stream_name_ = proto.stream_name(); + return Maybe::Ok(); +} + +void NcclSendRecvBoxingTaskNode::ToTransportTaskProto( + TransportTaskProto* transport_task_proto) const { + ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); + auto* proto = transport_task_proto->mutable_nccl_send_recv_boxing_task(); + logical_shape_.ToProto(proto->mutable_logical_shape()); + proto->set_data_type(data_type_); + *proto->mutable_src_nd_sbp() = src_nd_sbp_; + *proto->mutable_dst_nd_sbp() = dst_nd_sbp_; + *proto->mutable_src_parallel_conf() = src_parallel_conf_; + *proto->mutable_dst_parallel_conf() = dst_parallel_conf_; + *proto->mutable_parallel_conf() = parallel_conf_; + *proto->mutable_parallel_ctx() = parallel_ctx_; + proto->set_has_input(has_input_); + proto->set_has_output(has_output_); + proto->set_stream_name(stream_name_); +} + } // namespace oneflow diff --git a/oneflow/core/graph/nccl_send_recv_boxing_task_node.h b/oneflow/core/graph/nccl_send_recv_boxing_task_node.h index 1fcc4482f0e..fb14dac8761 100644 --- a/oneflow/core/graph/nccl_send_recv_boxing_task_node.h +++ b/oneflow/core/graph/nccl_send_recv_boxing_task_node.h @@ -17,6 +17,8 @@ limitations under the License. #define ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_ #include "oneflow/core/graph/transport_task_node.h" +#include "oneflow/core/graph/boxing_task_graph.pb.h" +#include "oneflow/core/job/placement.pb.h" namespace oneflow { @@ -35,6 +37,10 @@ class NcclSendRecvBoxingTaskNode : public TransportTaskNode { TaskType GetTaskType() const override { return TaskType::kNcclSendRecvBoxing; } const ParallelContext* parallel_ctx() const override { return ¶llel_ctx_; } + Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, + const TaskGraphRebuildCtx& ctx) override; + void ToTransportTaskProto(TransportTaskProto*) const override; + private: void BuildExecGphAndRegst() override; void ProduceAllRegstsAndBindEdges() override; diff --git a/oneflow/core/graph/slice_boxing_task_node.cpp b/oneflow/core/graph/slice_boxing_task_node.cpp index 2bb283c1765..14dd4e29ffb 100644 --- a/oneflow/core/graph/slice_boxing_task_node.cpp +++ b/oneflow/core/graph/slice_boxing_task_node.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/to_string.h" #include "oneflow/core/graph/slice_boxing_task_node.h" +#include "oneflow/core/graph/task_graph_rebuild_ctx.h" namespace oneflow { @@ -108,4 +109,39 @@ OperatorConf SliceBoxingTaskNode::GetBoxingOpConf() { return op_conf; } +Maybe SliceBoxingTaskNode::InitTransportTaskFromProto( + const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { + CHECK_OR_RETURN(transport_task_proto.has_slice_boxing_task()) + << "not a serialized SliceBoxingTaskNode. debug string: " + << transport_task_proto.DebugString(); + const auto& proto = transport_task_proto.slice_boxing_task(); + for (const auto& pair : proto.in_data_edge_uid2slice()) { + const auto* edge = JUST(ctx.TaskEdge4Uid(pair.first)); + CHECK_OR_RETURN(in_data_edge2slice_.emplace(edge, pair.second).second) + << "redundant edge found. edge_uid: " << pair.first; + } + for (int64_t edge_uid : proto.ordered_in_data_edge_uid()) { + ordered_in_data_edges_.push_back(JUST(ctx.TaskEdge4Uid(edge_uid))); + } + out_slice_ = TensorSliceView(proto.out_slice()); + out_shape_ = Shape(proto.out_shape()); + mode_ = proto.mode(); + return Maybe::Ok(); +} + +void SliceBoxingTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const { + ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false); + auto* proto = transport_task_proto->mutable_slice_boxing_task(); + for (const auto& pair : in_data_edge2slice_) { + int64_t edge_uid = reinterpret_cast(pair.first); + pair.second.ToProto(&(*proto->mutable_in_data_edge_uid2slice())[edge_uid]); + } + for (const auto* edge : ordered_in_data_edges_) { + proto->add_ordered_in_data_edge_uid(reinterpret_cast(edge)); + } + out_slice_.ToProto(proto->mutable_out_slice()); + out_shape_.ToProto(proto->mutable_out_shape()); + proto->set_mode(mode_); +} + } // namespace oneflow diff --git a/oneflow/core/graph/slice_boxing_task_node.h b/oneflow/core/graph/slice_boxing_task_node.h index b05fdc7bc1a..a7f8f8f5817 100644 --- a/oneflow/core/graph/slice_boxing_task_node.h +++ b/oneflow/core/graph/slice_boxing_task_node.h @@ -16,18 +16,13 @@ limitations under the License. #ifndef ONEFLOW_CORE_GRAPH_SLICE_BOXING_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_SLICE_BOXING_TASK_NODE_H_ +#include "oneflow/core/graph/boxing_task_graph.pb.h" #include "oneflow/core/graph/transport_task_node.h" #include "oneflow/core/register/tensor_slice_view.h" #include "oneflow/core/memory/memory_zone.h" namespace oneflow { -enum SliceBoxingTaskMode { - kSliceBoxingTaskModeInvalid, - kSliceBoxingTaskModeCopy, - kSliceBoxingTaskModeAdd, -}; - class SliceBoxingTaskNode final : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(SliceBoxingTaskNode); @@ -43,6 +38,10 @@ class SliceBoxingTaskNode final : public TransportTaskNode { void ConnectToSrcNodeWithSlice(TaskNode* src, TaskEdge* edge, const TensorSliceView& slice); void SetOutShape(const Shape& shape); + Maybe InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto, + const TaskGraphRebuildCtx& ctx) override; + void ToTransportTaskProto(TransportTaskProto*) const override; + private: void BuildExecGphAndRegst() override; void InferProducedDataRegstTimeShape() override; diff --git a/oneflow/core/graph/task_edge.proto b/oneflow/core/graph/task_edge.proto new file mode 100644 index 00000000000..e61740f6a3a --- /dev/null +++ b/oneflow/core/graph/task_edge.proto @@ -0,0 +1,12 @@ +syntax = "proto2"; +package oneflow; + +import "oneflow/core/register/logical_blob_id.proto"; + +message TaskEdgeProto { + required int64 task_edge_uid = 1; + required int64 src_task_id = 2; + required int64 dst_task_id = 3; + repeated LogicalBlobId lbi = 4; + map name_in_producer2regst_desc_id = 5; +}; \ No newline at end of file diff --git a/oneflow/core/graph/task_graph_rebuild_ctx.cpp b/oneflow/core/graph/task_graph_rebuild_ctx.cpp new file mode 100644 index 00000000000..28984657252 --- /dev/null +++ b/oneflow/core/graph/task_graph_rebuild_ctx.cpp @@ -0,0 +1,67 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/common/container_util.h" +#include "oneflow/core/graph/task_node.h" +#include "oneflow/core/graph/task_graph_rebuild_ctx.h" + +namespace oneflow { + +Maybe TaskGraphRebuildCtx::TaskNode4Id(int64_t task_id) const { + auto* task_node = JUST(MapAt(id2task_node_, task_id)); + CHECK_EQ_OR_RETURN(task_node->task_id(), task_id); // NOLINT + return task_node; +} + +Maybe TaskGraphRebuildCtx::TaskEdge4Uid(int64_t task_edge_uid) const { + return JUST(MapAt(uid2task_edge_, task_edge_uid)); +} + +Maybe TaskGraphRebuildCtx::RegstDesc4Id(int64_t regst_desc_id) const { + return JUST(MapAt(id2regst_desc_, regst_desc_id)); +} + +Maybe TaskGraphRebuildCtx::AddTaskNode(TaskNode* task_node) { + CHECK_OR_RETURN(id2task_node_.emplace(task_node->task_id(), task_node).second) + << "redundant task id found. value: " << task_node->task_id(); + for (const auto& pair : task_node->produced_regsts()) { JUST(AddRegstDesc(pair.second)); } + return Maybe::Ok(); +} + +Maybe TaskGraphRebuildCtx::AddTaskEdge(TaskEdge* task_edge, int64_t task_edge_uid) { + CHECK_OR_RETURN(uid2task_edge_.emplace(task_edge_uid, task_edge).second) + << "redundant task edge uid found. value: " << task_edge_uid; + return Maybe::Ok(); +} + +Maybe TaskGraphRebuildCtx::AddRegstDesc(const std::shared_ptr& regst_desc) { + CHECK_OR_RETURN(id2regst_desc_.emplace(regst_desc->regst_desc_id(), regst_desc).second) + << "redundant register descriptor id found. value: " << regst_desc->regst_desc_id(); + return Maybe::Ok(); +} + +} // namespace oneflow \ No newline at end of file diff --git a/oneflow/core/graph/task_graph_rebuild_ctx.h b/oneflow/core/graph/task_graph_rebuild_ctx.h new file mode 100644 index 00000000000..53b3557b877 --- /dev/null +++ b/oneflow/core/graph/task_graph_rebuild_ctx.h @@ -0,0 +1,60 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_GRAPH_TASK_GRAPH_REBUILD_CTX_H_ +#define ONEFLOW_CORE_GRAPH_TASK_GRAPH_REBUILD_CTX_H_ + +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/register/register_desc.h" + +namespace oneflow { + +class TaskNode; +class TaskEdge; + +class TaskGraphRebuildCtx { + public: + TaskGraphRebuildCtx() = default; + ~TaskGraphRebuildCtx() = default; + + Maybe TaskNode4Id(int64_t task_id) const; + Maybe TaskEdge4Uid(int64_t task_edge_uid) const; + Maybe RegstDesc4Id(int64_t regst_desc_id) const; + + Maybe AddTaskNode(TaskNode* task_node); + Maybe AddTaskEdge(TaskEdge* task_edge, int64_t task_edge_uid); + Maybe AddRegstDesc(const std::shared_ptr& regst_desc); + + private: + HashMap id2task_node_; + HashMap uid2task_edge_; + HashMap> id2regst_desc_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_TASK_GRAPH_REBUILD_CTX_H_ \ No newline at end of file diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index cd29a5e15b8..ed25fa43cc4 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/core/graph/task_node.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/memory/memory_case_util.h" +#include "oneflow/core/graph/task_graph_rebuild_ctx.h" namespace oneflow { @@ -202,7 +203,7 @@ std::string TaskNode::VisualStr() const { bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); } -void TaskNode::ToProto(TaskProto* task_proto) const { +void TaskNode::ToProto(TaskProto* task_proto, bool check) const { // Step1: process some scalar items. task_proto->set_task_type(GetTaskType()); task_proto->set_machine_id(machine_id_); @@ -219,7 +220,7 @@ void TaskNode::ToProto(TaskProto* task_proto) const { auto* produced_regst_proto = task_proto->mutable_produced_regst_desc(); for (auto& pair : produced_regsts_) { RegstDescProto regst_desc_proto; - pair.second->ToProto(®st_desc_proto); + pair.second->ToProto(®st_desc_proto, check); CHECK(produced_regst_proto->insert({pair.first, regst_desc_proto}).second); } @@ -460,4 +461,29 @@ size_t TaskNode::in_data_edges_size() const { return GetEdgesSize(&TaskNode::For size_t TaskNode::out_data_edges_size() const { return GetEdgesSize(&TaskNode::ForEachOutDataEdge); } +Maybe TaskEdge::InitFromProto(const TaskEdgeProto& proto, + const TaskGraphRebuildCtx& task_graph_rebuild_ctx) { + CHECK_NE_OR_RETURN(proto.src_task_id(), proto.dst_task_id()) << "self-loop are not supported"; + JUST(task_graph_rebuild_ctx.TaskNode4Id(proto.src_task_id())); + JUST(task_graph_rebuild_ctx.TaskNode4Id(proto.dst_task_id())); + // Note that edge id from proto is ignored. + lbis_.insert(proto.lbi().begin(), proto.lbi().end()); + for (const auto& pair : proto.name_in_producer2regst_desc_id()) { + AddRegst(pair.first, JUST(task_graph_rebuild_ctx.RegstDesc4Id(pair.second))); + } + return Maybe::Ok(); +} + +void TaskEdge::ToProto(TaskEdgeProto* proto) const { + // proto->set_task_edge_uid(edge_id()); + proto->set_task_edge_uid(reinterpret_cast(this)); + proto->set_src_task_id(src_node()->task_id()); + proto->set_dst_task_id(dst_node()->task_id()); + *proto->mutable_lbi() = {lbis_.begin(), lbis_.end()}; + auto* map = proto->mutable_name_in_producer2regst_desc_id(); + for (const auto& pair : name_in_producer2regst_) { + CHECK(map->insert({pair.first, pair.second->regst_desc_id()}).second); + } +} + } // namespace oneflow diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index beb21e5d208..81102754bff 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -18,6 +18,7 @@ limitations under the License. #include "oneflow/core/graph/exec_graph.h" #include "oneflow/core/job/task.pb.h" +#include "oneflow/core/graph/task_edge.pb.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/auto_registration_factory.h" #include "oneflow/core/memory/memory_zone.h" @@ -95,7 +96,8 @@ class TaskNode : public Node { virtual TaskType GetTaskType() const { return TaskType::kInvalid; } std::string VisualStr() const override; virtual bool IsMeaningLess(); - virtual void ToProto(TaskProto*) const; + void ToProto(TaskProto* task_proto) const { ToProto(task_proto, /*check*/ true); } + virtual void ToProto(TaskProto* task_proto, bool check) const; void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name); virtual MemZoneId MemZoneId121() const; bool BuildCtrlRegstDescIfNeed(TaskNode* dst_node, std::string* name); @@ -114,6 +116,7 @@ class TaskNode : public Node { TaskEdge* SoleOutDataEdge() const; size_t in_data_edges_size() const; size_t out_data_edges_size() const; + bool has_new_task_id() const { return static_cast(new_task_id_); } protected: std::shared_ptr ProduceRegst(const std::string& name, bool enable_reuse_mem); @@ -157,6 +160,8 @@ class TaskNode : public Node { HashMap>> consumed_regsts_; }; +class TaskGraphRebuildCtx; + class TaskEdge final : public Edge { public: OF_DISALLOW_COPY_AND_MOVE(TaskEdge); @@ -174,6 +179,10 @@ class TaskEdge final : public Edge { void CheckRegstLbiValid() const; + Maybe InitFromProto(const TaskEdgeProto& proto, + const TaskGraphRebuildCtx& task_graph_rebuild_ctx); + void ToProto(TaskEdgeProto* proto) const; + private: HashSet lbis_; HashMap> name_in_producer2regst_; diff --git a/oneflow/core/graph/transport_task_node.cpp b/oneflow/core/graph/transport_task_node.cpp new file mode 100644 index 00000000000..a976a19fa61 --- /dev/null +++ b/oneflow/core/graph/transport_task_node.cpp @@ -0,0 +1,46 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/graph/transport_task_node.h" +#include "oneflow/core/graph/boxing_task_graph.pb.h" + +namespace oneflow { + +Maybe TransportTaskNode::InitTransportTaskFromProtoIf( + const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) { + CHECK(has_new_task_id()); + JUST(InitTransportTaskFromProto(transport_task_proto, ctx)); + lbi_ = transport_task_proto.lbi(); + return Maybe::Ok(); +} + +void TransportTaskNode::ToTransportTaskProtoIf(TransportTaskProto* transport_task_proto) const { + ToTransportTaskProto(transport_task_proto); + *transport_task_proto->mutable_lbi() = lbi_; +} + +} // namespace oneflow \ No newline at end of file diff --git a/oneflow/core/graph/transport_task_node.h b/oneflow/core/graph/transport_task_node.h index 75dbc6b7221..4a2662bafb5 100644 --- a/oneflow/core/graph/transport_task_node.h +++ b/oneflow/core/graph/transport_task_node.h @@ -21,6 +21,9 @@ limitations under the License. namespace oneflow { +class TransportTaskProto; +class TaskGraphRebuildCtx; + class TransportTaskNode : public TaskNode { public: OF_DISALLOW_COPY_AND_MOVE(TransportTaskNode); @@ -30,7 +33,14 @@ class TransportTaskNode : public TaskNode { void set_lbi(const LogicalBlobId& lbi) { lbi_ = lbi; } LogicalBlobId lbi() const { return lbi_; } + Maybe InitTransportTaskFromProtoIf(const TransportTaskProto& transport_task_proto, + const TaskGraphRebuildCtx& ctx); + void ToTransportTaskProtoIf(TransportTaskProto*) const; + private: + virtual Maybe InitTransportTaskFromProto(const TransportTaskProto&, + const TaskGraphRebuildCtx& ctx) = 0; + virtual void ToTransportTaskProto(TransportTaskProto*) const = 0; LogicalBlobId lbi_; }; diff --git a/oneflow/core/register/register_desc.cpp b/oneflow/core/register/register_desc.cpp index b1defff2ebd..ca5ac440556 100644 --- a/oneflow/core/register/register_desc.cpp +++ b/oneflow/core/register/register_desc.cpp @@ -120,7 +120,7 @@ void RegstDesc::EraseUninitializedShapeBlob() { }); } -void RegstDesc::ToProto(RegstDescProto* ret) const { +void RegstDesc::ToProto(RegstDescProto* ret, bool check) const { ret->set_regst_desc_id(regst_desc_id_); ret->set_producer_task_id(producer_->task_id()); for (const TaskNode* consumer : consumers_) { ret->add_consumer_task_id(consumer->task_id()); } @@ -133,8 +133,10 @@ void RegstDesc::ToProto(RegstDescProto* ret) const { *(pb_pair->mutable_lbi()) = pair.first; pair.second->ToProto(pb_pair->mutable_blob_desc()); } - CHECK(data_regst_time_shape_); - data_regst_time_shape_->ToProto(data_regst_desc_proto->mutable_time_shape()); + if (check) { CHECK(data_regst_time_shape_); } + if (data_regst_time_shape_) { + data_regst_time_shape_->ToProto(data_regst_desc_proto->mutable_time_shape()); + } } else if (regst_desc_type_.has_ctrl_regst_desc()) { // do nothing } else { @@ -147,8 +149,10 @@ void RegstDesc::ToProto(RegstDescProto* ret) const { ret->set_enable_reuse_mem(enable_reuse_mem_); ret->set_mem_block_id(mem_block_id_); ret->set_mem_block_offset(mem_block_offset_); - CHECK(hint_inplace_consumed_regst_desc_id_ == -1 || force_inplace_consumed_regst_desc_id_ == -1) - << "They are oneof fields"; + if (check) { + CHECK(hint_inplace_consumed_regst_desc_id_ == -1 || force_inplace_consumed_regst_desc_id_ == -1) + << "They are oneof fields"; + } if (hint_inplace_consumed_regst_desc_id_ != -1) { ret->set_hint_inplace_consumed_regst_desc_id(hint_inplace_consumed_regst_desc_id_); } else if (force_inplace_consumed_regst_desc_id_ != -1) { diff --git a/oneflow/core/register/register_desc.h b/oneflow/core/register/register_desc.h index 1a57895dd16..faee94ca856 100644 --- a/oneflow/core/register/register_desc.h +++ b/oneflow/core/register/register_desc.h @@ -99,7 +99,8 @@ class RegstDesc final { // util void EraseUninitializedShapeBlob(); - void ToProto(RegstDescProto*) const; + void ToProto(RegstDescProto* proto) const { ToProto(proto, /*check*/ true); } + void ToProto(RegstDescProto*, bool check) const; bool HasSameBlobDescs(const RegstDesc*); private: From db84632cd8c8d734edc49f54267940d07749b922 Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Wed, 12 Apr 2023 22:56:58 +0800 Subject: [PATCH 02/18] Update boxing_task_graph.proto --- oneflow/core/graph/boxing_task_graph.proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/graph/boxing_task_graph.proto b/oneflow/core/graph/boxing_task_graph.proto index e732bdf82f0..7ee6a8901fa 100644 --- a/oneflow/core/graph/boxing_task_graph.proto +++ b/oneflow/core/graph/boxing_task_graph.proto @@ -107,4 +107,4 @@ message BoxingTaskGraphProto { repeated TransportTaskProto transport_task = 2; repeated TaskEdgeProto task_edge = 3; map boxing_unrelated_op_name2task_ids = 4; -} \ No newline at end of file +} From 49c8d182cafcda4621b92d51148090b922fe6ecc Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Wed, 12 Apr 2023 22:57:32 +0800 Subject: [PATCH 03/18] Update task_edge.proto --- oneflow/core/graph/task_edge.proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/graph/task_edge.proto b/oneflow/core/graph/task_edge.proto index e61740f6a3a..e43331d9365 100644 --- a/oneflow/core/graph/task_edge.proto +++ b/oneflow/core/graph/task_edge.proto @@ -9,4 +9,4 @@ message TaskEdgeProto { required int64 dst_task_id = 3; repeated LogicalBlobId lbi = 4; map name_in_producer2regst_desc_id = 5; -}; \ No newline at end of file +}; From 041a4b3344ff959cf31df6f86297e22949ce7683 Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Wed, 12 Apr 2023 22:58:00 +0800 Subject: [PATCH 04/18] Update task_graph_rebuild_ctx.cpp --- oneflow/core/graph/task_graph_rebuild_ctx.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/graph/task_graph_rebuild_ctx.cpp b/oneflow/core/graph/task_graph_rebuild_ctx.cpp index 28984657252..cfd947f6823 100644 --- a/oneflow/core/graph/task_graph_rebuild_ctx.cpp +++ b/oneflow/core/graph/task_graph_rebuild_ctx.cpp @@ -64,4 +64,4 @@ Maybe TaskGraphRebuildCtx::AddRegstDesc(const std::shared_ptr& return Maybe::Ok(); } -} // namespace oneflow \ No newline at end of file +} // namespace oneflow From 3e089442305c3cd28dbec94b12adc16828a4c21e Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Wed, 12 Apr 2023 22:58:29 +0800 Subject: [PATCH 05/18] Update task_graph_rebuild_ctx.h --- oneflow/core/graph/task_graph_rebuild_ctx.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/graph/task_graph_rebuild_ctx.h b/oneflow/core/graph/task_graph_rebuild_ctx.h index 53b3557b877..c0e6bc12751 100644 --- a/oneflow/core/graph/task_graph_rebuild_ctx.h +++ b/oneflow/core/graph/task_graph_rebuild_ctx.h @@ -57,4 +57,4 @@ class TaskGraphRebuildCtx { } // namespace oneflow -#endif // ONEFLOW_CORE_GRAPH_TASK_GRAPH_REBUILD_CTX_H_ \ No newline at end of file +#endif // ONEFLOW_CORE_GRAPH_TASK_GRAPH_REBUILD_CTX_H_ From e0cf92b33e847a06a46eb2453b50dd66b8803c95 Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Wed, 12 Apr 2023 22:58:53 +0800 Subject: [PATCH 06/18] Update transport_task_node.cpp --- oneflow/core/graph/transport_task_node.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/graph/transport_task_node.cpp b/oneflow/core/graph/transport_task_node.cpp index a976a19fa61..1e3031e0003 100644 --- a/oneflow/core/graph/transport_task_node.cpp +++ b/oneflow/core/graph/transport_task_node.cpp @@ -43,4 +43,4 @@ void TransportTaskNode::ToTransportTaskProtoIf(TransportTaskProto* transport_tas *transport_task_proto->mutable_lbi() = lbi_; } -} // namespace oneflow \ No newline at end of file +} // namespace oneflow From 008239ed5538fc2d77f922da8c6442917c8a8c94 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 13 Apr 2023 04:40:46 +0000 Subject: [PATCH 07/18] support infer desc choose method --- oneflow/core/graph/boxing_identity_task_node.cpp | 2 +- oneflow/core/graph/boxing_zeros_task_node.cpp | 2 +- oneflow/core/graph/collective_boxing_pack_task_node.cpp | 2 +- oneflow/core/graph/collective_boxing_task_node.cpp | 2 +- oneflow/core/graph/collective_boxing_unpack_task_node.cpp | 2 +- oneflow/core/graph/compute_task_node.h | 4 ++++ oneflow/core/graph/exec_graph.cpp | 4 ++-- oneflow/core/graph/exec_graph.h | 3 ++- oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp | 2 +- oneflow/core/graph/slice_boxing_task_node.cpp | 2 +- oneflow/core/graph/task_node.h | 3 +++ oneflow/core/graph/transport_task_node.h | 7 +++++++ oneflow/core/graph_impl/acc_compute_task_node.cpp | 2 +- .../core/graph_impl/acc_ctrl_tick_compute_task_node.cpp | 2 +- oneflow/core/graph_impl/acc_tick_compute_task_node.cpp | 2 +- oneflow/core/graph_impl/case_compute_task_node.cpp | 2 +- .../graph_impl/critical_section_wait_compute_task_node.cpp | 2 +- oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp | 2 +- oneflow/core/graph_impl/device_tick_compute_task_node.cpp | 2 +- .../graph_impl/distribute_concat_compute_task_node.cpp | 3 ++- .../core/graph_impl/distribute_split_compute_task_node.cpp | 3 ++- .../core/graph_impl/dst_subset_tick_compute_task_node.cpp | 2 +- oneflow/core/graph_impl/esac_compute_task_node.cpp | 2 +- .../core/graph_impl/normal_forward_compute_task_node.cpp | 3 ++- oneflow/core/graph_impl/pack_compute_task_node.cpp | 2 +- .../core/graph_impl/reentrant_lock_compute_task_node.cpp | 2 +- oneflow/core/graph_impl/repeat_compute_task_node.cpp | 2 +- oneflow/core/graph_impl/source_tick_compute_task_node.cpp | 2 +- .../core/graph_impl/src_subset_tick_compute_task_node.cpp | 2 +- oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp | 2 +- oneflow/core/graph_impl/tick_compute_task_node.cpp | 2 +- oneflow/core/graph_impl/unpack_compute_task_node.cpp | 2 +- .../graph_impl/wait_and_send_ids_compute_task_node.cpp | 2 +- 33 files changed, 49 insertions(+), 31 deletions(-) diff --git a/oneflow/core/graph/boxing_identity_task_node.cpp b/oneflow/core/graph/boxing_identity_task_node.cpp index cbe7969e888..5d16f73c172 100644 --- a/oneflow/core/graph/boxing_identity_task_node.cpp +++ b/oneflow/core/graph/boxing_identity_task_node.cpp @@ -47,7 +47,7 @@ void BoxingIdentityTaskNode::BuildExecGphAndRegst() { std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); - node->InferBlobDescs(nullptr); + (node->*GetInferBlobDescsMethod())(nullptr); } void BoxingIdentityTaskNode::InferProducedDataRegstTimeShape() { diff --git a/oneflow/core/graph/boxing_zeros_task_node.cpp b/oneflow/core/graph/boxing_zeros_task_node.cpp index ce7f2e65e9d..8861db9528e 100644 --- a/oneflow/core/graph/boxing_zeros_task_node.cpp +++ b/oneflow/core/graph/boxing_zeros_task_node.cpp @@ -51,7 +51,7 @@ void BoxingZerosTaskNode::BuildExecGphAndRegst() { std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); - node->InferBlobDescs(nullptr); + (node->*GetInferBlobDescsMethod())(nullptr); } void BoxingZerosTaskNode::InferProducedDataRegstTimeShape() { diff --git a/oneflow/core/graph/collective_boxing_pack_task_node.cpp b/oneflow/core/graph/collective_boxing_pack_task_node.cpp index 5225591b036..ad3a7f3031c 100644 --- a/oneflow/core/graph/collective_boxing_pack_task_node.cpp +++ b/oneflow/core/graph/collective_boxing_pack_task_node.cpp @@ -60,7 +60,7 @@ void CollectiveBoxingPackTaskNode::BuildExecGphAndRegst() { std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); - node->InferBlobDescs(nullptr); + (node->*GetInferBlobDescsMethod())(nullptr); } void CollectiveBoxingPackTaskNode::InferProducedDataRegstTimeShape() { diff --git a/oneflow/core/graph/collective_boxing_task_node.cpp b/oneflow/core/graph/collective_boxing_task_node.cpp index 75b2459a582..ea44eeb751c 100644 --- a/oneflow/core/graph/collective_boxing_task_node.cpp +++ b/oneflow/core/graph/collective_boxing_task_node.cpp @@ -55,7 +55,7 @@ void CollectiveBoxingGenericTaskNode::BuildExecGphAndRegst() { node->BindBnWithRegst(obn, out_regst); out_regst->AddLbi(boxing_op->BnInOp2Lbi(obn)); } - node->InferBlobDescs(nullptr); + (node->*GetInferBlobDescsMethod())(nullptr); } void CollectiveBoxingGenericTaskNode::InferProducedDataRegstTimeShape() { diff --git a/oneflow/core/graph/collective_boxing_unpack_task_node.cpp b/oneflow/core/graph/collective_boxing_unpack_task_node.cpp index 624c959679b..f9c9fbf3f1c 100644 --- a/oneflow/core/graph/collective_boxing_unpack_task_node.cpp +++ b/oneflow/core/graph/collective_boxing_unpack_task_node.cpp @@ -60,7 +60,7 @@ void CollectiveBoxingUnpackTaskNode::BuildExecGphAndRegst() { std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); - node->InferBlobDescs(nullptr); + (node->*GetInferBlobDescsMethod())(nullptr); } void CollectiveBoxingUnpackTaskNode::InferProducedDataRegstTimeShape() { diff --git a/oneflow/core/graph/compute_task_node.h b/oneflow/core/graph/compute_task_node.h index 05b7cb5c190..bba6d957668 100644 --- a/oneflow/core/graph/compute_task_node.h +++ b/oneflow/core/graph/compute_task_node.h @@ -43,6 +43,10 @@ class CompTaskNode : public TaskNode { // op std::shared_ptr op() const { return op_node_->shared_op(); } + ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const override { + return &ExecNode::InferBlobDescsByInputs; + } + protected: const OpNode* GetOneSuccOpNodeOnEdge(TaskEdge* edge); const OpNode* GetOnePredOpNodeOnEdge(TaskEdge* edge); diff --git a/oneflow/core/graph/exec_graph.cpp b/oneflow/core/graph/exec_graph.cpp index 2c47076abc7..a164a06e493 100644 --- a/oneflow/core/graph/exec_graph.cpp +++ b/oneflow/core/graph/exec_graph.cpp @@ -105,7 +105,7 @@ Maybe CheckPhysicalBlobDesc( } // namespace -void ExecNode::InferBlobDescs(const ParallelContext* parallel_ctx) { +void ExecNode::InferBlobDescsByInputs(const ParallelContext* parallel_ctx) { auto GetBlobDesc4BnInOp = GetBlobDesc4BnInOpFunc(); const OpNode* op_node = Singleton::Get()->OpNode4OpName(op()->op_name()); const NdSbpSignature* nd_sbp_signature = nullptr; @@ -128,7 +128,7 @@ void ExecNode::InferBlobDescs(const ParallelContext* parallel_ctx) { CHECK_JUST_MSG(op_->InferInplaceObn2IbnIf(&mut_inplace_obn2ibn_, &con_inplace_obn2ibn_, GetBlobDesc4BnInOp, parallel_ctx), std::stringstream() - << " infer inplace obn to ibn if failed, op name " << op_->op_loc()); + << " infer inplace obn to ibn is failed, op name " << op_->op_loc()); } std::function ExecNode::GetBlobDesc4BnInOpFunc() const { diff --git a/oneflow/core/graph/exec_graph.h b/oneflow/core/graph/exec_graph.h index 55728b5088f..c8a14abae9a 100644 --- a/oneflow/core/graph/exec_graph.h +++ b/oneflow/core/graph/exec_graph.h @@ -72,7 +72,8 @@ class ExecNode final : public Node { std::string VisualStr() const override { return op_->op_name(); } void ToProto(const ParallelContext*, ExecNodeProto*) const; - void InferBlobDescs(const ParallelContext* parallel_ctx); + typedef void (ExecNode::*InferBlobDescsMethod)(const ParallelContext*); + void InferBlobDescsByInputs(const ParallelContext* parallel_ctx); const HashMap& mut_inplace_obn2ibn() const { return mut_inplace_obn2ibn_; diff --git a/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp b/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp index 85466658e3f..9afbd3ba711 100644 --- a/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp +++ b/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp @@ -85,7 +85,7 @@ void NcclSendRecvBoxingTaskNode::BuildExecGphAndRegst() { node->BindBnWithRegst(sole_op->SoleObn(), out_regst); } node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp")); - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); } void NcclSendRecvBoxingTaskNode::InferProducedDataRegstTimeShape() { diff --git a/oneflow/core/graph/slice_boxing_task_node.cpp b/oneflow/core/graph/slice_boxing_task_node.cpp index 14dd4e29ffb..038e1d70868 100644 --- a/oneflow/core/graph/slice_boxing_task_node.cpp +++ b/oneflow/core/graph/slice_boxing_task_node.cpp @@ -64,7 +64,7 @@ void SliceBoxingTaskNode::BuildExecGphAndRegst() { out_regst->AddLbi(lbi()); node->BindBnWithRegst(op->SoleObn(), out_regst); node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp")); - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); } void SliceBoxingTaskNode::InferProducedDataRegstTimeShape() { diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index 81102754bff..25924b42457 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -71,6 +71,9 @@ class TaskNode : public Node { DeviceType device_type() const; virtual const ParallelContext* parallel_ctx() const { return nullptr; } + // Different types of ExecNode choose different output BlobDesc inference methods + virtual ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const = 0; + // Setters void set_machine_id(int64_t val); void set_thrd_id(int64_t val); diff --git a/oneflow/core/graph/transport_task_node.h b/oneflow/core/graph/transport_task_node.h index 4a2662bafb5..58d2298459b 100644 --- a/oneflow/core/graph/transport_task_node.h +++ b/oneflow/core/graph/transport_task_node.h @@ -37,9 +37,16 @@ class TransportTaskNode : public TaskNode { const TaskGraphRebuildCtx& ctx); void ToTransportTaskProtoIf(TransportTaskProto*) const; + ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const override { + // TransportTaskNode infers output BlobDesc based on input BlobDesc, because it can't infers + // output BlobDesc with SBP. + return &ExecNode::InferBlobDescsByInputs; + } + private: virtual Maybe InitTransportTaskFromProto(const TransportTaskProto&, const TaskGraphRebuildCtx& ctx) = 0; + virtual void ToTransportTaskProto(TransportTaskProto*) const = 0; LogicalBlobId lbi_; }; diff --git a/oneflow/core/graph_impl/acc_compute_task_node.cpp b/oneflow/core/graph_impl/acc_compute_task_node.cpp index c14d85367d5..17695dc3f68 100644 --- a/oneflow/core/graph_impl/acc_compute_task_node.cpp +++ b/oneflow/core/graph_impl/acc_compute_task_node.cpp @@ -44,7 +44,7 @@ void AccCompTaskNode::BuildExecGphAndRegst() { exec_node->BindBnWithRegst(op()->SoleIbn(), in_regst); out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn())); exec_node->BindBnWithRegst(op()->SoleObn(), out_regst); - exec_node->InferBlobDescs(parallel_ctx()); + (exec_node->*GetInferBlobDescsMethod())(parallel_ctx()); out_regst->ForEachLbi([out_regst](const LogicalBlobId& lbi) { const BlobDesc* blob_desc = out_regst->GetBlobDesc(lbi); CHECK_EQ(blob_desc->is_dynamic(), false); diff --git a/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp b/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp index adf7b2b89d1..7de08f51154 100644 --- a/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp @@ -47,7 +47,7 @@ void AccCtrlTickCompTaskNode::BuildExecGphAndRegst() { exec_node->BindBnWithRegst(op->SoleIbn(), in_regst); out_regst->AddLbi(op->BnInOp2Lbi(op->SoleObn())); exec_node->BindBnWithRegst(op->SoleObn(), out_regst); - exec_node->InferBlobDescs(parallel_ctx()); + (exec_node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kAccCtrlTick); diff --git a/oneflow/core/graph_impl/acc_tick_compute_task_node.cpp b/oneflow/core/graph_impl/acc_tick_compute_task_node.cpp index f43a4ceed6f..47f93a81c2c 100644 --- a/oneflow/core/graph_impl/acc_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/acc_tick_compute_task_node.cpp @@ -47,7 +47,7 @@ void AccTickCompTaskNode::BuildExecGphAndRegst() { exec_node->BindBnWithRegst(op->SoleIbn(), in_regst); out_regst->AddLbi(op->BnInOp2Lbi(op->SoleObn())); exec_node->BindBnWithRegst(op->SoleObn(), out_regst); - exec_node->InferBlobDescs(parallel_ctx()); + (exec_node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kAccTick); diff --git a/oneflow/core/graph_impl/case_compute_task_node.cpp b/oneflow/core/graph_impl/case_compute_task_node.cpp index a52fa4adb5a..803b5271cc7 100644 --- a/oneflow/core/graph_impl/case_compute_task_node.cpp +++ b/oneflow/core/graph_impl/case_compute_task_node.cpp @@ -69,7 +69,7 @@ void CaseCompTaskNode::BuildExecGphAndRegst() { out_regst->AddLbi(sole_op->BnInOp2Lbi(name)); node->BindBnWithRegst(name, out_regst); } - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); } void CaseCompTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } diff --git a/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp b/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp index 51647e00bdf..d2a4a6775b6 100644 --- a/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp +++ b/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp @@ -56,7 +56,7 @@ void CriticalSectionWaitTickCompTaskNode::BuildExecGphAndRegst() { out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_INDEPENDENT_TASK_STREAM_INDEX_GETTER(TaskType::kCriticalSectionWaitTick); diff --git a/oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp b/oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp index e24b67bd309..5d7f94da6d2 100644 --- a/oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp +++ b/oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp @@ -54,7 +54,7 @@ void DecodeH2DCompTaskNode::BuildExecGphAndRegst() { out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp")); - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_NAMED_TASK_STREAM_INDEX_GETTER(DeviceType::kCUDA, TaskType::kDecodeH2D, "DECODE_H2D") diff --git a/oneflow/core/graph_impl/device_tick_compute_task_node.cpp b/oneflow/core/graph_impl/device_tick_compute_task_node.cpp index 0e2728e717a..1a2a53c0b50 100644 --- a/oneflow/core/graph_impl/device_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/device_tick_compute_task_node.cpp @@ -56,7 +56,7 @@ void DeviceTickCompTaskNode::BuildExecGphAndRegst() { out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kDeviceTick); diff --git a/oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp b/oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp index 4c9915b7b5f..bc8a9ace7f3 100644 --- a/oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp +++ b/oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp @@ -52,7 +52,8 @@ void DistributeConcatCompTaskNode::ConsumeAllRegsts() { void DistributeConcatCompTaskNode::BuildExecGphAndRegst() { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); - mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); }); + mut_exec_gph().TopoForEachNode( + [this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); }); } void DistributeConcatCompTaskNode::BuildExecGphStructAndBindInRegst() { diff --git a/oneflow/core/graph_impl/distribute_split_compute_task_node.cpp b/oneflow/core/graph_impl/distribute_split_compute_task_node.cpp index 36952138fb8..fd3e0334888 100644 --- a/oneflow/core/graph_impl/distribute_split_compute_task_node.cpp +++ b/oneflow/core/graph_impl/distribute_split_compute_task_node.cpp @@ -47,7 +47,8 @@ void DistributeSplitCompTaskNode::ConsumeAllRegsts() { void DistributeSplitCompTaskNode::BuildExecGphAndRegst() { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); - mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); }); + mut_exec_gph().TopoForEachNode( + [this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); }); } void DistributeSplitCompTaskNode::BuildExecGphStructAndBindInRegst() { diff --git a/oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp b/oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp index e2d95f0b8b0..901a4b10092 100644 --- a/oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp @@ -56,7 +56,7 @@ void DstSubsetTickCompTaskNode::BuildExecGphAndRegst() { out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kDstSubsetTick); diff --git a/oneflow/core/graph_impl/esac_compute_task_node.cpp b/oneflow/core/graph_impl/esac_compute_task_node.cpp index 45e60378dee..8595c22ce0f 100644 --- a/oneflow/core/graph_impl/esac_compute_task_node.cpp +++ b/oneflow/core/graph_impl/esac_compute_task_node.cpp @@ -70,7 +70,7 @@ void EsacCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi("out")); node->BindBnWithRegst("out", out_regst); - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); } void EsacCompTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } diff --git a/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp b/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp index 16b3a2b63d7..8f71a3ac335 100644 --- a/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp +++ b/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp @@ -99,7 +99,8 @@ void NormalForwardCompTaskNode::BuildExecGphAndRegst() { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); BuildTmp7BufRegsts(); - mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); }); + mut_exec_gph().TopoForEachNode( + [this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); }); } void NormalForwardCompTaskNode::BuildExecGphStructAndBindInRegst() { diff --git a/oneflow/core/graph_impl/pack_compute_task_node.cpp b/oneflow/core/graph_impl/pack_compute_task_node.cpp index 9aec8613853..88ec0ce6834 100644 --- a/oneflow/core/graph_impl/pack_compute_task_node.cpp +++ b/oneflow/core/graph_impl/pack_compute_task_node.cpp @@ -50,7 +50,7 @@ void PackCompTaskNode::BuildExecGphAndRegst() { out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn())); exec_node->BindBnWithRegst(op()->SoleObn(), out_regst); - exec_node->InferBlobDescs(parallel_ctx()); + (exec_node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kPack); diff --git a/oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp b/oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp index bb9f0368fcb..5045b65dc23 100644 --- a/oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp +++ b/oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp @@ -56,7 +56,7 @@ void ReentrantLockCompTaskNode::BuildExecGphAndRegst() { out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); } void ReentrantLockCompTaskNode::InferProducedDataRegstTimeShape() { diff --git a/oneflow/core/graph_impl/repeat_compute_task_node.cpp b/oneflow/core/graph_impl/repeat_compute_task_node.cpp index 8c887688d75..6385ac8dc37 100644 --- a/oneflow/core/graph_impl/repeat_compute_task_node.cpp +++ b/oneflow/core/graph_impl/repeat_compute_task_node.cpp @@ -51,7 +51,7 @@ void RepeatCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); // NOTE(chengcheng): force inplace CHECK_EQ(in_regst->NumOfLbi(), 1); diff --git a/oneflow/core/graph_impl/source_tick_compute_task_node.cpp b/oneflow/core/graph_impl/source_tick_compute_task_node.cpp index f0bc0d850d3..93393a89145 100644 --- a/oneflow/core/graph_impl/source_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/source_tick_compute_task_node.cpp @@ -46,7 +46,7 @@ void SourceTickCompTaskNode::BuildExecGphAndRegst() { out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kSourceTick); diff --git a/oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp b/oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp index dad25139730..e1fa7349553 100644 --- a/oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp @@ -56,7 +56,7 @@ void SrcSubsetTickCompTaskNode::BuildExecGphAndRegst() { out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kSrcSubsetTick); diff --git a/oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp b/oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp index eb04236a537..aa6d46729b0 100644 --- a/oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp +++ b/oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp @@ -74,7 +74,7 @@ class SspVariableProxyCompTaskNode final : public CompTaskNode { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); mut_exec_gph().TopoForEachNode( - [this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); }); + [this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); }); } void BuildExecGphStructAndBindInRegst() { diff --git a/oneflow/core/graph_impl/tick_compute_task_node.cpp b/oneflow/core/graph_impl/tick_compute_task_node.cpp index 34b1ca3fd64..052b72841a7 100644 --- a/oneflow/core/graph_impl/tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/tick_compute_task_node.cpp @@ -56,7 +56,7 @@ void TickCompTaskNode::BuildExecGphAndRegst() { out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kTick); diff --git a/oneflow/core/graph_impl/unpack_compute_task_node.cpp b/oneflow/core/graph_impl/unpack_compute_task_node.cpp index 2f6bbee7ff7..2916ee2fbb5 100644 --- a/oneflow/core/graph_impl/unpack_compute_task_node.cpp +++ b/oneflow/core/graph_impl/unpack_compute_task_node.cpp @@ -50,7 +50,7 @@ void UnpackCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn())); exec_node->BindBnWithRegst(op()->SoleObn(), out_regst); - exec_node->InferBlobDescs(parallel_ctx()); + (exec_node->*GetInferBlobDescsMethod())(parallel_ctx()); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kUnpack); diff --git a/oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp b/oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp index 68449dfc860..ba79fca0c8d 100644 --- a/oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp +++ b/oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp @@ -49,7 +49,7 @@ void WaitAndSendIdsCompTaskNode::BuildExecGphAndRegst() { out_regst->AddLbi(lbi); node->BindBnWithRegst(obn, out_regst); } - node->InferBlobDescs(parallel_ctx()); + (node->*GetInferBlobDescsMethod())(parallel_ctx()); } void WaitAndSendIdsCompTaskNode::InferProducedDataRegstTimeShape() { From be2987d4972c237c44af00c17670c223cbb3ba4d Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 13 Apr 2023 10:19:02 +0000 Subject: [PATCH 08/18] refine comment --- oneflow/core/graph/task_node.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index 25924b42457..7856c0d08d1 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -71,7 +71,7 @@ class TaskNode : public Node { DeviceType device_type() const; virtual const ParallelContext* parallel_ctx() const { return nullptr; } - // Different types of ExecNode choose different output BlobDesc inference methods + // Different types of TaskNode/Compile Mode choose different output BlobDesc inference methods virtual ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const = 0; // Setters From 9acbc799f2b3069374e906572364b4b4f445943b Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 14 Apr 2023 09:46:57 +0000 Subject: [PATCH 09/18] rm useless --- oneflow/core/graph/task_graph_rebuild_ctx.cpp | 12 ------------ oneflow/core/graph/task_graph_rebuild_ctx.h | 12 ------------ 2 files changed, 24 deletions(-) diff --git a/oneflow/core/graph/task_graph_rebuild_ctx.cpp b/oneflow/core/graph/task_graph_rebuild_ctx.cpp index cfd947f6823..224da06be90 100644 --- a/oneflow/core/graph/task_graph_rebuild_ctx.cpp +++ b/oneflow/core/graph/task_graph_rebuild_ctx.cpp @@ -13,18 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/graph/task_node.h" #include "oneflow/core/graph/task_graph_rebuild_ctx.h" diff --git a/oneflow/core/graph/task_graph_rebuild_ctx.h b/oneflow/core/graph/task_graph_rebuild_ctx.h index c0e6bc12751..8bbd5b7516d 100644 --- a/oneflow/core/graph/task_graph_rebuild_ctx.h +++ b/oneflow/core/graph/task_graph_rebuild_ctx.h @@ -13,18 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ #ifndef ONEFLOW_CORE_GRAPH_TASK_GRAPH_REBUILD_CTX_H_ #define ONEFLOW_CORE_GRAPH_TASK_GRAPH_REBUILD_CTX_H_ From 34b313392313a8248816a9259d18c75aa0f55e93 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 14 Apr 2023 15:16:31 +0000 Subject: [PATCH 10/18] add comsume fake regst --- oneflow/core/graph/compute_task_node.cpp | 45 +++++++++++++++ oneflow/core/graph/compute_task_node.h | 11 +++- .../core/graph/fake_consumed_regst_provider.h | 35 ++++++++++++ .../graph/normal_forward_compute_task_node.h | 1 + oneflow/core/graph/task_node.cpp | 57 ++++++++++++++++++- oneflow/core/graph/task_node.h | 9 +++ oneflow/core/graph/transport_task_node.cpp | 12 ---- .../core/graph_impl/acc_compute_task_node.cpp | 2 + .../acc_ctrl_tick_compute_task_node.cpp | 3 + .../graph_impl/acc_tick_compute_task_node.cpp | 2 + .../callback_notify_compute_task_node.cpp | 3 + .../graph_impl/case_compute_task_node.cpp | 3 + ...ritical_section_wait_compute_task_node.cpp | 3 + .../decode_h2d_compute_task_node.cpp | 3 + .../device_tick_compute_task_node.cpp | 3 + .../distribute_concat_compute_task_node.cpp | 3 + .../distribute_split_compute_task_node.cpp | 3 + .../dst_subset_tick_compute_task_node.cpp | 3 + .../graph_impl/esac_compute_task_node.cpp | 1 + .../normal_forward_compute_task_node.cpp | 2 + .../graph_impl/pack_compute_task_node.cpp | 3 + .../reentrant_lock_compute_task_node.cpp | 3 + .../graph_impl/repeat_compute_task_node.cpp | 3 + .../source_tick_compute_task_node.cpp | 1 + .../src_subset_tick_compute_task_node.cpp | 3 + .../ssp_variable_proxy_task_node.cpp | 2 + .../graph_impl/tick_compute_task_node.cpp | 3 + .../graph_impl/unpack_compute_task_node.cpp | 3 + .../wait_and_send_ids_compute_task_node.cpp | 1 + oneflow/core/register/register_desc.cpp | 26 +++++++++ oneflow/core/register/register_desc.h | 5 +- 31 files changed, 241 insertions(+), 16 deletions(-) create mode 100644 oneflow/core/graph/fake_consumed_regst_provider.h diff --git a/oneflow/core/graph/compute_task_node.cpp b/oneflow/core/graph/compute_task_node.cpp index f6547053970..1cec2eb4582 100644 --- a/oneflow/core/graph/compute_task_node.cpp +++ b/oneflow/core/graph/compute_task_node.cpp @@ -63,10 +63,55 @@ std::vector GetCompTaskNodesOnEdge( return comp_task_nodes; } +std::shared_ptr NewFakeDataRegstDesc() { + auto regst_desc = std::make_shared(); + regst_desc->mut_regst_desc_type()->mutable_data_regst_desc(); + return regst_desc; +} + } // namespace +void CompTaskNode::ConsumeFakeRegst(const std::string& regst_name) { + ConsumeRegst(regst_name, NewFakeDataRegstDesc()); + fake_consumed_regst_names_.insert(regst_name); +} + +void CompTaskNode::ConsumeFakeRegstsIf() { + ConsumeFakeRegsts(); + RegstDesc* data_regst_desc = nullptr; + for (const auto& pair : consumed_regsts()) { + for (const auto& regst_desc : pair.second) { + if (regst_desc->regst_desc_type().has_data_regst_desc()) { + CHECK(data_regst_desc == nullptr); + data_regst_desc = CHECK_NOTNULL(regst_desc.get()); + } else if (regst_desc->regst_desc_type().has_ctrl_regst_desc()) { + // do nothing. + } else { + UNIMPLEMENTED(); + } + } + } + if (data_regst_desc != nullptr) { + for (const auto& ibn : op_node()->op().input_bns()) { + data_regst_desc->AddLbi(op_node()->op().BnInOp2Lbi(ibn)); + } + } +} + +void CompTaskNode::EraseFakeRegstsIf() { + for (const auto& fake_consumed_regst_name : fake_consumed_regst_names_) { + EraseConsumedRegstsByName(fake_consumed_regst_name); + } + fake_consumed_regst_names_.clear(); +} + std::string CompTaskNode::VisualStr() const { return op_node_->op().op_name(); } +void CompTaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& proto) { + TaskNode::InitFromProtoExceptConsumedRegsts(proto); + parallel_ctx_ = proto.parallel_ctx(); +} + void CompTaskNode::ToProto(TaskProto* task_proto, bool check) const { TaskNode::ToProto(task_proto, check); *(task_proto->mutable_parallel_ctx()) = parallel_ctx_; diff --git a/oneflow/core/graph/compute_task_node.h b/oneflow/core/graph/compute_task_node.h index bba6d957668..7d6cb03bc47 100644 --- a/oneflow/core/graph/compute_task_node.h +++ b/oneflow/core/graph/compute_task_node.h @@ -18,17 +18,25 @@ limitations under the License. #include "oneflow/core/graph/task_node.h" #include "oneflow/core/graph/op_graph.h" +#include "oneflow/core/graph/fake_consumed_regst_provider.h" #include "oneflow/core/device/cuda_util.h" namespace oneflow { -class CompTaskNode : public TaskNode { +class CompTaskNode : public TaskNode, public FakeConsumedRegstProvider { public: OF_DISALLOW_COPY_AND_MOVE(CompTaskNode); CompTaskNode() = default; virtual ~CompTaskNode() = default; virtual void ToProto(TaskProto*, bool check) const override; + virtual void InitFromProtoExceptConsumedRegsts(const TaskProto&) override; + void ConsumeFakeRegstsIf() override; + void EraseFakeRegstsIf() override; + + // ConsumeFakeRegsts is used for initializing CompTaskNode.consumed_regsts_ on the other ranks. + virtual void ConsumeFakeRegsts() = 0; + void ConsumeFakeRegst(const std::string& regst_name); // parallel_ctx_ int64_t parallel_id() const { return parallel_ctx_.parallel_id(); } @@ -58,6 +66,7 @@ class CompTaskNode : public TaskNode { private: ParallelContext parallel_ctx_; const OpNode* op_node_; + HashSet fake_consumed_regst_names_; }; class OpCompTaskNodeCreator { diff --git a/oneflow/core/graph/fake_consumed_regst_provider.h b/oneflow/core/graph/fake_consumed_regst_provider.h new file mode 100644 index 00000000000..8bd8dc26326 --- /dev/null +++ b/oneflow/core/graph/fake_consumed_regst_provider.h @@ -0,0 +1,35 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ +#define ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ + +namespace oneflow { + +// Provide a compute task node with a fake input regst, and its output regst can be inferred using +// SBP + Placement. The fake compute task node can help the task graph of one rank to infer blob +// desc, mainly to ensure that the transport task node has the correct input blob desc. +class FakeConsumedRegstProvider { + public: + FakeConsumedRegstProvider() = default; + virtual ~FakeConsumedRegstProvider() = default; + + virtual void ConsumeFakeRegstsIf() = 0; + virtual void EraseFakeRegstsIf() = 0; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ \ No newline at end of file diff --git a/oneflow/core/graph/normal_forward_compute_task_node.h b/oneflow/core/graph/normal_forward_compute_task_node.h index 53d1807c5a7..b3da9f36edb 100644 --- a/oneflow/core/graph/normal_forward_compute_task_node.h +++ b/oneflow/core/graph/normal_forward_compute_task_node.h @@ -30,6 +30,7 @@ class NormalForwardCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kNormalForward; } diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index ed25fa43cc4..684cc483d5f 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -203,6 +203,38 @@ std::string TaskNode::VisualStr() const { bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); } +void TaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto) { + // Step1: init some scalar items. + CHECK(task_proto.task_type() == GetTaskType()); + machine_id_ = task_proto.machine_id(); + thrd_id_ = task_proto.thrd_id(); + task_id_ = task_proto.task_id(); + new_task_id_.reset(new TaskId(DecodeTaskIdFromInt64(task_id_))); + CHECK(task_proto.job_id() == GlobalJobDesc().job_id()); + chain_id_ = task_proto.chain_id(); + order_in_chain_ = task_proto.order_in_chain(); + // Step2: check exec_gph empty. + CHECK(task_proto.exec_sequence().exec_node().empty()); + // Step3: init produced_regst. + for (const auto& pair : task_proto.produced_regst_desc()) { + const auto& regst_desc = ProduceRegst(pair.first, pair.second.enable_reuse_mem()); + // regst_desc->consumers_ will be initialized by RegstDesc::InitConsumersFromProto. + regst_desc->InitFromProtoExceptConsumers(pair.second); + } +} + +Maybe TaskNode::InitConsumedRegstsFromProto( + const TaskProto& task_proto, + const std::function(int64_t regst_desc_id)>& RegstDesc4Id) { + // Step3: init consumed_regst. + for (const auto& pair : task_proto.consumed_regst_desc_id()) { + for (int64_t regst_desc_id : pair.second.regst_desc_id()) { + ConsumeRegst(pair.first, JUST(RegstDesc4Id(regst_desc_id))); + } + } + return Maybe::Ok(); +} + void TaskNode::ToProto(TaskProto* task_proto, bool check) const { // Step1: process some scalar items. task_proto->set_task_type(GetTaskType()); @@ -265,9 +297,22 @@ RegstDesc* TaskNode::BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name) { } void TaskNode::BindEdgeWithProducedRegst(TaskEdge* edge, const std::string& name) { + if (edge->HasRegst(name)) { return; } edge->AddRegst(name, GetProducedRegst(name)); } +std::shared_ptr TaskNode::GetOrCheckRegst(const std::string& name, bool enable_reuse_mem, + int32_t min_register_num, + int32_t max_register_num) const { + auto iter = produced_regsts_.find(name); + if (iter == produced_regsts_.end()) { return nullptr; } + const auto& regst = (iter->second); + CHECK_EQ(regst->min_register_num(), min_register_num); + CHECK_EQ(regst->max_register_num(), max_register_num); + CHECK_EQ(regst->enable_reuse_mem(), enable_reuse_mem); + return regst; +} + std::shared_ptr TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem) { return ProduceRegst(name, enable_reuse_mem, 1, kMaxRegisterNum); } @@ -275,6 +320,10 @@ std::shared_ptr TaskNode::ProduceRegst(const std::string& name, bool std::shared_ptr TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem, int32_t min_register_num, int32_t max_register_num) { + // Because the Regst of separate compilation is not created in order, some Regst may have been + // built. This implementation can avoid ProduceRegst being called multiple times. + const auto& regst = GetOrCheckRegst(name, enable_reuse_mem, min_register_num, max_register_num); + if (regst) { return regst; } RegstDescTypeProto regst_desc_type; regst_desc_type.mutable_data_regst_desc(); return ProduceRegst(name, enable_reuse_mem, min_register_num, max_register_num, regst_desc_type); @@ -353,8 +402,13 @@ std::shared_ptr TaskEdge::GetRegst(const std::string& name_in_produce return name_in_producer2regst_.at(name_in_producer); } +bool TaskEdge::HasRegst(const std::string& name_in_producer) const { + return (name_in_producer2regst_.find(name_in_producer) != name_in_producer2regst_.end()); +} + std::shared_ptr TaskEdge::GetSoleRegst() const { - CHECK_EQ(name_in_producer2regst_.size(), 1); + CHECK_EQ(name_in_producer2regst_.size(), 1) + << "edge: " << this << ", src: " << src_node()->task_id(); return name_in_producer2regst_.begin()->second; } @@ -367,6 +421,7 @@ std::vector> TaskEdge::GetRegsts() const { void TaskEdge::AddRegst(const std::string& name_in_producer, const std::shared_ptr& regst) { + if (HasRegst(name_in_producer)) { return; } CHECK(name_in_producer2regst_.emplace(name_in_producer, regst).second); } diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index 7856c0d08d1..65ca8181a2e 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -100,6 +100,10 @@ class TaskNode : public Node { std::string VisualStr() const override; virtual bool IsMeaningLess(); void ToProto(TaskProto* task_proto) const { ToProto(task_proto, /*check*/ true); } + virtual void InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto); + Maybe InitConsumedRegstsFromProto( + const TaskProto& task_proto, + const std::function(int64_t regst_desc_id)>& RegstDesc4Id); virtual void ToProto(TaskProto* task_proto, bool check) const; void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name); virtual MemZoneId MemZoneId121() const; @@ -150,6 +154,9 @@ class TaskNode : public Node { private: void UpdateTaskId(); + std::shared_ptr GetOrCheckRegst(const std::string& name, bool enable_reuse_mem, + int32_t min_register_num, + int32_t max_register_num) const; int64_t machine_id_; int64_t thrd_id_; @@ -172,6 +179,7 @@ class TaskEdge final : public Edge { ~TaskEdge() override = default; std::shared_ptr GetRegst(const std::string& name_in_producer) const; + bool HasRegst(const std::string& name_in_producer) const; std::shared_ptr GetSoleRegst() const; std::vector> GetRegsts() const; const HashSet& GetLbis() const { return lbis_; } @@ -181,6 +189,7 @@ class TaskEdge final : public Edge { void AddLbis(const std::vector& lbis) { lbis_.insert(lbis.begin(), lbis.end()); } void CheckRegstLbiValid() const; + bool OutHasBindRegst() const { return !name_in_producer2regst_.empty(); } Maybe InitFromProto(const TaskEdgeProto& proto, const TaskGraphRebuildCtx& task_graph_rebuild_ctx); diff --git a/oneflow/core/graph/transport_task_node.cpp b/oneflow/core/graph/transport_task_node.cpp index 1e3031e0003..53315281453 100644 --- a/oneflow/core/graph/transport_task_node.cpp +++ b/oneflow/core/graph/transport_task_node.cpp @@ -13,18 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ #include "oneflow/core/graph/transport_task_node.h" #include "oneflow/core/graph/boxing_task_graph.pb.h" diff --git a/oneflow/core/graph_impl/acc_compute_task_node.cpp b/oneflow/core/graph_impl/acc_compute_task_node.cpp index 17695dc3f68..016bec8cffb 100644 --- a/oneflow/core/graph_impl/acc_compute_task_node.cpp +++ b/oneflow/core/graph_impl/acc_compute_task_node.cpp @@ -27,6 +27,7 @@ class AccCompTaskNode final : public CompTaskNode { void BuildExecGphAndRegst() override; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; }; void AccCompTaskNode::ProduceAllRegstsAndBindEdges() { @@ -35,6 +36,7 @@ void AccCompTaskNode::ProduceAllRegstsAndBindEdges() { } void AccCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void AccCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void AccCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr in_regst = GetSoleConsumedRegst("in"); diff --git a/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp b/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp index 7de08f51154..b6dee6102bb 100644 --- a/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp @@ -27,6 +27,7 @@ class AccCtrlTickCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void BuildExecGphAndRegst() override; + void ConsumeFakeRegsts() override; }; void AccCtrlTickCompTaskNode::ProduceAllRegstsAndBindEdges() { @@ -38,6 +39,8 @@ void AccCtrlTickCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void AccCtrlTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void AccCtrlTickCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr in_regst = GetSoleConsumedRegst("in"); std::shared_ptr out_regst = GetProducedRegst("out"); diff --git a/oneflow/core/graph_impl/acc_tick_compute_task_node.cpp b/oneflow/core/graph_impl/acc_tick_compute_task_node.cpp index 47f93a81c2c..83b23730c49 100644 --- a/oneflow/core/graph_impl/acc_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/acc_tick_compute_task_node.cpp @@ -26,6 +26,7 @@ class AccTickCompTaskNode final : public CompTaskNode { TaskType GetTaskType() const override { return TaskType::kAccTick; } void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -37,6 +38,7 @@ void AccTickCompTaskNode::ProduceAllRegstsAndBindEdges() { void AccTickCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void AccTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void AccTickCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr in_regst = GetSoleConsumedRegst("in"); diff --git a/oneflow/core/graph_impl/callback_notify_compute_task_node.cpp b/oneflow/core/graph_impl/callback_notify_compute_task_node.cpp index d4326fc9588..80baabc6321 100644 --- a/oneflow/core/graph_impl/callback_notify_compute_task_node.cpp +++ b/oneflow/core/graph_impl/callback_notify_compute_task_node.cpp @@ -29,6 +29,7 @@ class CallbackNotifyCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -38,6 +39,8 @@ void CallbackNotifyCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void CallbackNotifyCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void CallbackNotifyCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = this->op(); diff --git a/oneflow/core/graph_impl/case_compute_task_node.cpp b/oneflow/core/graph_impl/case_compute_task_node.cpp index 803b5271cc7..fc5cc90b52b 100644 --- a/oneflow/core/graph_impl/case_compute_task_node.cpp +++ b/oneflow/core/graph_impl/case_compute_task_node.cpp @@ -26,6 +26,7 @@ class CaseCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kCase; } @@ -36,6 +37,8 @@ class CaseCompTaskNode final : public CompTaskNode { void CaseCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void CaseCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void CaseCompTaskNode::ProduceAllRegstsAndBindEdges() { HashMap lbi2obn_id; FOR_RANGE(int64_t, obn_id, 0, op()->output_bns().size()) { diff --git a/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp b/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp index d2a4a6775b6..65341724df9 100644 --- a/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp +++ b/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp @@ -30,6 +30,7 @@ class CriticalSectionWaitTickCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -43,6 +44,8 @@ void CriticalSectionWaitTickCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void CriticalSectionWaitTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void CriticalSectionWaitTickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp b/oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp index 5d7f94da6d2..c548272c4cb 100644 --- a/oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp +++ b/oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp @@ -27,6 +27,7 @@ class DecodeH2DCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kDecodeH2D; } @@ -38,6 +39,8 @@ void DecodeH2DCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void DecodeH2DCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void DecodeH2DCompTaskNode::ProduceAllRegstsAndBindEdges() { auto regst_num = ParseIntegerFromEnv("ONEFLOW_DECODE_H2D_REGST_NUM", 2); std::shared_ptr out_regst = ProduceRegst("out", false, regst_num, regst_num); diff --git a/oneflow/core/graph_impl/device_tick_compute_task_node.cpp b/oneflow/core/graph_impl/device_tick_compute_task_node.cpp index 1a2a53c0b50..1ffcfdcca95 100644 --- a/oneflow/core/graph_impl/device_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/device_tick_compute_task_node.cpp @@ -30,6 +30,7 @@ class DeviceTickCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -43,6 +44,8 @@ void DeviceTickCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void DeviceTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void DeviceTickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp b/oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp index bc8a9ace7f3..5c3b5e9b069 100644 --- a/oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp +++ b/oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp @@ -26,6 +26,7 @@ class DistributeConcatCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kDistributeConcat; } @@ -49,6 +50,8 @@ void DistributeConcatCompTaskNode::ConsumeAllRegsts() { CHECK_EQ(cnt, 1); } +void DistributeConcatCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void DistributeConcatCompTaskNode::BuildExecGphAndRegst() { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); diff --git a/oneflow/core/graph_impl/distribute_split_compute_task_node.cpp b/oneflow/core/graph_impl/distribute_split_compute_task_node.cpp index fd3e0334888..f387719a17f 100644 --- a/oneflow/core/graph_impl/distribute_split_compute_task_node.cpp +++ b/oneflow/core/graph_impl/distribute_split_compute_task_node.cpp @@ -26,6 +26,7 @@ class DistributeSplitCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kDistributeSplit; } @@ -44,6 +45,8 @@ void DistributeSplitCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void DistributeSplitCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void DistributeSplitCompTaskNode::BuildExecGphAndRegst() { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); diff --git a/oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp b/oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp index 901a4b10092..e587ced5e02 100644 --- a/oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp @@ -30,6 +30,7 @@ class DstSubsetTickCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -43,6 +44,8 @@ void DstSubsetTickCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void DstSubsetTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void DstSubsetTickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/esac_compute_task_node.cpp b/oneflow/core/graph_impl/esac_compute_task_node.cpp index 8595c22ce0f..ac43f5a2c91 100644 --- a/oneflow/core/graph_impl/esac_compute_task_node.cpp +++ b/oneflow/core/graph_impl/esac_compute_task_node.cpp @@ -26,6 +26,7 @@ class EsacCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override { UNIMPLEMENTED() << "EsacCompTaskNode is deprecated"; } TaskType GetTaskType() const override { return TaskType::kEsac; } diff --git a/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp b/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp index 8f71a3ac335..a93742b76a2 100644 --- a/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp +++ b/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp @@ -95,6 +95,8 @@ void NormalForwardCompTaskNode::ConsumeAllRegsts() { }); } +void NormalForwardCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void NormalForwardCompTaskNode::BuildExecGphAndRegst() { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); diff --git a/oneflow/core/graph_impl/pack_compute_task_node.cpp b/oneflow/core/graph_impl/pack_compute_task_node.cpp index 88ec0ce6834..88f27c6527b 100644 --- a/oneflow/core/graph_impl/pack_compute_task_node.cpp +++ b/oneflow/core/graph_impl/pack_compute_task_node.cpp @@ -28,6 +28,7 @@ class PackCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; private: void BuildExecGphAndRegst() override; @@ -40,6 +41,8 @@ void PackCompTaskNode::ProduceAllRegstsAndBindEdges() { void PackCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void PackCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void PackCompTaskNode::BuildExecGphAndRegst() { ExecNode* exec_node = mut_exec_gph().NewNode(); exec_node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp b/oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp index 5045b65dc23..8b0cfef852b 100644 --- a/oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp +++ b/oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp @@ -30,6 +30,7 @@ class ReentrantLockCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; void InferProducedDataRegstTimeShape() override; }; @@ -44,6 +45,8 @@ void ReentrantLockCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void ReentrantLockCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void ReentrantLockCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/repeat_compute_task_node.cpp b/oneflow/core/graph_impl/repeat_compute_task_node.cpp index 6385ac8dc37..582d61dadef 100644 --- a/oneflow/core/graph_impl/repeat_compute_task_node.cpp +++ b/oneflow/core/graph_impl/repeat_compute_task_node.cpp @@ -26,6 +26,7 @@ class RepeatCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kRepeat; } @@ -37,6 +38,8 @@ void RepeatCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void RepeatCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void RepeatCompTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr out_regst = ProduceRegst("out", false, 1, 1); ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", out_regst); }); diff --git a/oneflow/core/graph_impl/source_tick_compute_task_node.cpp b/oneflow/core/graph_impl/source_tick_compute_task_node.cpp index 93393a89145..93cbb1dda7d 100644 --- a/oneflow/core/graph_impl/source_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/source_tick_compute_task_node.cpp @@ -26,6 +26,7 @@ class SourceTickCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override {} + void ConsumeFakeRegsts() override {} void BuildExecGphAndRegst() override; bool IsMeaningLess() override { return false; } diff --git a/oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp b/oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp index e1fa7349553..30639d07abe 100644 --- a/oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp @@ -30,6 +30,7 @@ class SrcSubsetTickCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -43,6 +44,8 @@ void SrcSubsetTickCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void SrcSubsetTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void SrcSubsetTickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp b/oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp index aa6d46729b0..4c8e2743ab8 100644 --- a/oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp +++ b/oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp @@ -67,6 +67,8 @@ class SspVariableProxyCompTaskNode final : public CompTaskNode { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("var", edge->GetSoleRegst()); }); } + void ConsumeFakeRegsts() override { ConsumeFakeRegst("var"); } + TaskType GetTaskType() const override { return TaskType::kSspVariableProxy; } private: diff --git a/oneflow/core/graph_impl/tick_compute_task_node.cpp b/oneflow/core/graph_impl/tick_compute_task_node.cpp index 052b72841a7..306f78874e6 100644 --- a/oneflow/core/graph_impl/tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/tick_compute_task_node.cpp @@ -30,6 +30,7 @@ class TickCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -43,6 +44,8 @@ void TickCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void TickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void TickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/unpack_compute_task_node.cpp b/oneflow/core/graph_impl/unpack_compute_task_node.cpp index 2916ee2fbb5..49a4dfa3e2d 100644 --- a/oneflow/core/graph_impl/unpack_compute_task_node.cpp +++ b/oneflow/core/graph_impl/unpack_compute_task_node.cpp @@ -28,6 +28,7 @@ class UnpackCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; private: void BuildExecGphAndRegst() override; @@ -42,6 +43,8 @@ void UnpackCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void UnpackCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void UnpackCompTaskNode::BuildExecGphAndRegst() { ExecNode* exec_node = mut_exec_gph().NewNode(); exec_node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp b/oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp index ba79fca0c8d..9d3e0626b9e 100644 --- a/oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp +++ b/oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp @@ -26,6 +26,7 @@ class WaitAndSendIdsCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override {} + void ConsumeFakeRegsts() override {} void BuildExecGphAndRegst() override; bool IsMeaningLess() override { return false; } diff --git a/oneflow/core/register/register_desc.cpp b/oneflow/core/register/register_desc.cpp index ca5ac440556..d7b3e75efb1 100644 --- a/oneflow/core/register/register_desc.cpp +++ b/oneflow/core/register/register_desc.cpp @@ -120,6 +120,32 @@ void RegstDesc::EraseUninitializedShapeBlob() { }); } +void RegstDesc::InitFromProtoExceptConsumers(const RegstDescProto& proto) { + regst_desc_id_ = proto.regst_desc_id(); + CHECK_EQ(proto.producer_task_id(), producer_->task_id()); + regst_desc_type_ = proto.regst_desc_type(); + if (regst_desc_type_.has_data_regst_desc()) { + const DataRegstDesc& data_regst_desc_proto = proto.regst_desc_type().data_regst_desc(); + for (const auto& pair : data_regst_desc_proto.lbi2blob_desc()) { + *AddLbi(pair.lbi()) = BlobDesc(pair.blob_desc()); + } + CHECK(!data_regst_desc_proto.has_time_shape()); + } else if (regst_desc_type_.has_ctrl_regst_desc()) { + // do nothing + } else { + UNIMPLEMENTED(); + } + min_register_num_ = proto.min_register_num(); + max_register_num_ = proto.max_register_num(); + min_register_num_ = proto.register_num(); + mem_case_ = proto.mem_case(); + enable_reuse_mem_ = proto.enable_reuse_mem(); + mem_block_id_ = proto.mem_block_id(); + mem_block_offset_ = proto.mem_block_offset(); + hint_inplace_consumed_regst_desc_id_ = proto.hint_inplace_consumed_regst_desc_id(); + force_inplace_consumed_regst_desc_id_ = proto.force_inplace_consumed_regst_desc_id(); +} + void RegstDesc::ToProto(RegstDescProto* ret, bool check) const { ret->set_regst_desc_id(regst_desc_id_); ret->set_producer_task_id(producer_->task_id()); diff --git a/oneflow/core/register/register_desc.h b/oneflow/core/register/register_desc.h index faee94ca856..1da2e267325 100644 --- a/oneflow/core/register/register_desc.h +++ b/oneflow/core/register/register_desc.h @@ -99,6 +99,7 @@ class RegstDesc final { // util void EraseUninitializedShapeBlob(); + void InitFromProtoExceptConsumers(const RegstDescProto& proto); void ToProto(RegstDescProto* proto) const { ToProto(proto, /*check*/ true); } void ToProto(RegstDescProto*, bool check) const; bool HasSameBlobDescs(const RegstDesc*); @@ -117,8 +118,8 @@ class RegstDesc final { bool enable_reuse_mem_; int32_t mem_block_id_; int64_t mem_block_offset_; - int32_t hint_inplace_consumed_regst_desc_id_; - int32_t force_inplace_consumed_regst_desc_id_; + int64_t hint_inplace_consumed_regst_desc_id_; + int64_t force_inplace_consumed_regst_desc_id_; std::shared_ptr data_regst_time_shape_; }; From acff92c1724ab3fb876c40fb875c84a2c862ec68 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 14 Apr 2023 15:21:09 +0000 Subject: [PATCH 11/18] fix typo --- oneflow/core/graph/fake_consumed_regst_provider.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/graph/fake_consumed_regst_provider.h b/oneflow/core/graph/fake_consumed_regst_provider.h index 8bd8dc26326..2de5a194924 100644 --- a/oneflow/core/graph/fake_consumed_regst_provider.h +++ b/oneflow/core/graph/fake_consumed_regst_provider.h @@ -32,4 +32,4 @@ class FakeConsumedRegstProvider { } // namespace oneflow -#endif // ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ \ No newline at end of file +#endif // ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ From 17203d092f7a8853c4e0540696d9683c6c0b60f2 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 14 Apr 2023 15:42:04 +0000 Subject: [PATCH 12/18] add task factory to create new task node --- oneflow/core/graph/task_type_visitor.h | 150 +++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 oneflow/core/graph/task_type_visitor.h diff --git a/oneflow/core/graph/task_type_visitor.h b/oneflow/core/graph/task_type_visitor.h new file mode 100644 index 00000000000..532d48bf308 --- /dev/null +++ b/oneflow/core/graph/task_type_visitor.h @@ -0,0 +1,150 @@ +#include +#include "oneflow/core/job/task.pb.h" +#include "oneflow/core/graph/collective_boxing_task_node.h" +#include "oneflow/core/graph/nccl_send_recv_boxing_task_node.h" +#include "oneflow/core/graph/copy_task_node.h" +#include "oneflow/core/graph/boxing_zeros_task_node.h" +#include "oneflow/core/graph/slice_boxing_task_node.h" +#include "oneflow/core/graph/collective_boxing_pack_task_node.h" +#include "oneflow/core/graph/collective_boxing_unpack_task_node.h" +#include "oneflow/core/graph/boxing_identity_task_node.h" + +namespace oneflow { + +template +struct TaskTypeVisitor { + template + static auto Visit(TaskType task_type, Args&&... args) { + switch (task_type) { + case TaskType::kInvalid: LOG(FATAL) << "invalid task type"; + case TaskType::kNormalForward: + return DerivedT::VisitNormalForward(std::forward(args)...); + case TaskType::kCopyHd: return DerivedT::VisitCopyHd(std::forward(args)...); + case TaskType::kCopyCommNet: return DerivedT::VisitCopyCommNet(std::forward(args)...); + case TaskType::kDeviceTick: return DerivedT::VisitDeviceTick(std::forward(args)...); + case TaskType::kPack: return DerivedT::VisitPack(std::forward(args)...); + case TaskType::kUnpack: return DerivedT::VisitUnpack(std::forward(args)...); + case TaskType::kRepeat: return DerivedT::VisitRepeat(std::forward(args)...); + case TaskType::kAcc: return DerivedT::VisitAcc(std::forward(args)...); + case TaskType::kAccCtrlTick: return DerivedT::VisitAccCtrlTick(std::forward(args)...); + case TaskType::kSrcSubsetTick: + return DerivedT::VisitSrcSubsetTick(std::forward(args)...); + case TaskType::kDstSubsetTick: + return DerivedT::VisitDstSubsetTick(std::forward(args)...); + case TaskType::kSourceTick: return DerivedT::VisitSourceTick(std::forward(args)...); + case TaskType::kTick: return DerivedT::VisitTick(std::forward(args)...); + case TaskType::kAccTick: return DerivedT::VisitAccTick(std::forward(args)...); + case TaskType::kCase: return DerivedT::VisitCase(std::forward(args)...); + case TaskType::kEsac: return DerivedT::VisitEsac(std::forward(args)...); + case TaskType::kWaitAndSendIds: + return DerivedT::VisitWaitAndSendIds(std::forward(args)...); + case TaskType::kReentrantLock: + return DerivedT::VisitReentrantLock(std::forward(args)...); + case TaskType::kCallbackNotify: + return DerivedT::VisitCallbackNotify(std::forward(args)...); + case TaskType::kDistributeConcat: + return DerivedT::VisitDistributeConcat(std::forward(args)...); + case TaskType::kDistributeSplit: + return DerivedT::VisitDistributeSplit(std::forward(args)...); + case TaskType::kSliceBoxing: return DerivedT::VisitSliceBoxing(std::forward(args)...); + case TaskType::kCollectiveBoxingGeneric: + return DerivedT::VisitCollectiveBoxingGeneric(std::forward(args)...); + case TaskType::kBoxingIdentity: + return DerivedT::VisitBoxingIdentity(std::forward(args)...); + case TaskType::kDecodeH2D: return DerivedT::VisitDecodeH2D(std::forward(args)...); + case TaskType::kCollectiveBoxingPack: + return DerivedT::VisitCollectiveBoxingPack(std::forward(args)...); + case TaskType::kCollectiveBoxingUnpack: + return DerivedT::VisitCollectiveBoxingUnpack(std::forward(args)...); + case TaskType::kSspVariableProxy: + return DerivedT::VisitSspVariableProxy(std::forward(args)...); + case TaskType::kBoxingZeros: return DerivedT::VisitBoxingZeros(std::forward(args)...); + case TaskType::kCriticalSectionWaitTick: + return DerivedT::VisitCriticalSectionWaitTick(std::forward(args)...); + case TaskType::kNcclSendRecvBoxing: + return DerivedT::VisitNcclSendRecvBoxing(std::forward(args)...); + } + LOG(FATAL) << "invalid task type"; + } +}; + +struct IsTransportTaskType final : public TaskTypeVisitor { + static bool VisitCopyHd() { return true; } + static bool VisitCopyCommNet() { return true; } + static bool VisitSliceBoxing() { return true; } + static bool VisitCollectiveBoxingGeneric() { return true; } + static bool VisitBoxingIdentity() { return true; } + static bool VisitCollectiveBoxingPack() { return true; } + static bool VisitCollectiveBoxingUnpack() { return true; } + static bool VisitNcclSendRecvBoxing() { return true; } + static bool VisitBoxingZeros() { return true; } + + static bool VisitNormalForward() { return false; } + static bool VisitDeviceTick() { return false; } + static bool VisitPack() { return false; } + static bool VisitUnpack() { return false; } + static bool VisitRepeat() { return false; } + static bool VisitAcc() { return false; } + static bool VisitSrcSubsetTick() { return false; } + static bool VisitDstSubsetTick() { return false; } + static bool VisitSourceTick() { return false; } + static bool VisitTick() { return false; } + static bool VisitAccTick() { return false; } + static bool VisitCase() { return false; } + static bool VisitEsac() { return false; } + static bool VisitWaitAndSendIds() { return false; } + static bool VisitReentrantLock() { return false; } + static bool VisitCallbackNotify() { return false; } + static bool VisitDistributeConcat() { return false; } + static bool VisitDistributeSplit() { return false; } + static bool VisitDecodeH2D() { return false; } + static bool VisitSspVariableProxy() { return false; } + static bool VisitCriticalSectionWaitTick() { return false; } +}; + +template +struct TransportTaskTypeVisitor { + template + static auto Visit(TaskType task_type, Args&&... args) { + switch (task_type) { + case TaskType::kInvalid: LOG(FATAL) << "invalid task type"; + case TaskType::kCopyHd: return DerivedT::VisitCopyHd(std::forward(args)...); + case TaskType::kCopyCommNet: return DerivedT::VisitCopyCommNet(std::forward(args)...); + case TaskType::kSliceBoxing: return DerivedT::VisitSliceBoxing(std::forward(args)...); + case TaskType::kCollectiveBoxingGeneric: + return DerivedT::VisitCollectiveBoxingGeneric(std::forward(args)...); + case TaskType::kBoxingIdentity: + return DerivedT::VisitBoxingIdentity(std::forward(args)...); + case TaskType::kNcclSendRecvBoxing: + return DerivedT::VisitNcclSendRecvBoxing(std::forward(args)...); + case TaskType::kBoxingZeros: return DerivedT::VisitBoxingZeros(std::forward(args)...); + case TaskType::kCollectiveBoxingPack: + return DerivedT::VisitCollectiveBoxingPack(std::forward(args)...); + case TaskType::kCollectiveBoxingUnpack: + return DerivedT::VisitCollectiveBoxingUnpack(std::forward(args)...); + default: LOG(FATAL) << "invalid task type"; + } + } +}; + +struct CreateTransportTask final : public TransportTaskTypeVisitor { + static Maybe VisitCopyHd() { return new CopyHdTaskNode(); } + static Maybe VisitCopyCommNet() { return new CopyCommNetTaskNode(); } + static Maybe VisitSliceBoxing() { return new SliceBoxingTaskNode(); } + static Maybe VisitCollectiveBoxingGeneric() { + return new CollectiveBoxingGenericTaskNode(); + } + static Maybe VisitBoxingIdentity() { return new BoxingIdentityTaskNode(); } + static Maybe VisitCollectiveBoxingPack() { + return new CollectiveBoxingPackTaskNode(); + } + static Maybe VisitCollectiveBoxingUnpack() { + return new CollectiveBoxingUnpackTaskNode(); + } + static Maybe VisitBoxingZeros() { return new BoxingZerosTaskNode(); } + static Maybe VisitNcclSendRecvBoxing() { + return new NcclSendRecvBoxingTaskNode(); + } +}; + +} // namespace oneflow From 88f429769761472da575576d4a78e6baf1fb0a10 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 14 Apr 2023 15:48:47 +0000 Subject: [PATCH 13/18] add infer from ndsbp --- oneflow/core/graph/exec_graph.cpp | 65 +++++++++++++++++++++++++++++++ oneflow/core/graph/exec_graph.h | 1 + 2 files changed, 66 insertions(+) diff --git a/oneflow/core/graph/exec_graph.cpp b/oneflow/core/graph/exec_graph.cpp index a164a06e493..99715f4c575 100644 --- a/oneflow/core/graph/exec_graph.cpp +++ b/oneflow/core/graph/exec_graph.cpp @@ -103,6 +103,28 @@ Maybe CheckPhysicalBlobDesc( return Maybe::Ok(); } +// A helper function to infer blob's physical shape with ND SBP. +Maybe InferPhysicalBlobDesc( + const Operator& op, const PbRpf& bns, + const std::function(const std::string&)>& GetLogicalBlobDesc, + const NdSbpSignature* nd_sbp_signature, const ParallelContext* parallel_ctx, + const std::function& GetPhysicalBlobDesc) { + const std::shared_ptr op_parallel_desc = JUST(op.GetOpParallelDesc()); + for (const auto& bn : bns) { + BlobDesc* physical_blob_desc = GetPhysicalBlobDesc(bn); + const auto& logical_blob_desc = *JUST(GetLogicalBlobDesc(bn)); + CHECK_NOTNULL_OR_RETURN(physical_blob_desc) + << "physical_blob_desc should not be nullptr. op location: " << op.op_loc(); + *physical_blob_desc = logical_blob_desc; + const auto& physical_shape = JUST_MSG( + GetPhysicalShape(logical_blob_desc.shape(), nd_sbp_signature->bn_in_op2nd_sbp().at(bn), + *op_parallel_desc, *parallel_ctx), + std::stringstream() << " check physical shape failed, op name " << op.op_loc()); + physical_blob_desc->set_shape(*physical_shape); + } + return Maybe::Ok(); +} + } // namespace void ExecNode::InferBlobDescsByInputs(const ParallelContext* parallel_ctx) { @@ -131,6 +153,49 @@ void ExecNode::InferBlobDescsByInputs(const ParallelContext* parallel_ctx) { << " infer inplace obn to ibn is failed, op name " << op_->op_loc()); } +void ExecNode::InferBlobDescsByNdSbp(const ParallelContext* parallel_ctx) { + const HashSet ibns{op()->input_bns().begin(), op()->input_bns().end()}; + HashMap ibn2blob_desc{}; + const auto& GetBlobDesc4BnInOp = [&](const std::string& bn_in_op) -> BlobDesc* { + // Generate temp regst to store input blob desc, and will be released after infer output blob + // desc. + if (ibns.count(bn_in_op) > 0) { + auto iter = ibn2blob_desc.find(bn_in_op); + if (iter == ibn2blob_desc.end()) { + iter = ibn2blob_desc.emplace(bn_in_op, kInvalidDataType).first; + } + return &iter->second; + } + auto it = bn_in_op2regst_.find(bn_in_op); + if (it == bn_in_op2regst_.end()) { return nullptr; } + std::shared_ptr regst = it->second; + CHECK(regst); + return regst->MutBlobDesc(op()->BnInOp2Lbi(bn_in_op)); + }; + const OpNode* op_node = Singleton::Get()->OpNode4OpName(op()->op_name()); + const NdSbpSignature* nd_sbp_signature = &CHECK_NOTNULL(op_node)->nd_sbp_signature(); + + // TODO(strint): user op can infer output with SBP, so there is no need to infer the input. + // Reference: https://github.com/Oneflow-Inc/oneflow/pull/8971 + // Infer input blob desc with SBP, the infer results are set intuo the temp input blob desc. + CHECK_JUST(InferPhysicalBlobDesc( + *op(), op()->input_bns(), + std::bind(&Operator::GetLogicalBlobDesc4Ibn, op().get(), std::placeholders::_1), + nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp)); + + // Infer output blob desc with input. + CHECK_JUST_MSG(op_->InferBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, &GlobalJobDesc()), + std::stringstream() << " infer blob descs is failed, op name " << op_->op_loc()); + CHECK_JUST(CheckPhysicalBlobDesc( + *op(), op()->output_bns(), + std::bind(&Operator::GetLogicalBlobDesc4Obn, op().get(), std::placeholders::_1), + nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp)); + CHECK_JUST_MSG(op_->InferInplaceObn2IbnIf(&mut_inplace_obn2ibn_, &con_inplace_obn2ibn_, + GetBlobDesc4BnInOp, parallel_ctx), + std::stringstream() + << " infer inplace obn to ibn is failed, op name " << op_->op_loc()); +} + std::function ExecNode::GetBlobDesc4BnInOpFunc() const { return [this](const std::string& bn_in_op) -> BlobDesc* { auto it = bn_in_op2regst_.find(bn_in_op); diff --git a/oneflow/core/graph/exec_graph.h b/oneflow/core/graph/exec_graph.h index c8a14abae9a..4851759a1a6 100644 --- a/oneflow/core/graph/exec_graph.h +++ b/oneflow/core/graph/exec_graph.h @@ -74,6 +74,7 @@ class ExecNode final : public Node { typedef void (ExecNode::*InferBlobDescsMethod)(const ParallelContext*); void InferBlobDescsByInputs(const ParallelContext* parallel_ctx); + void InferBlobDescsByNdSbp(const ParallelContext* parallel_ctx); const HashMap& mut_inplace_obn2ibn() const { return mut_inplace_obn2ibn_; From 266c3882124df2a0d38f85e12a4c192a02b4cea7 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 14 Apr 2023 16:23:15 +0000 Subject: [PATCH 14/18] rm useless --- oneflow/core/graph/task_type_visitor.h | 91 -------------------------- 1 file changed, 91 deletions(-) diff --git a/oneflow/core/graph/task_type_visitor.h b/oneflow/core/graph/task_type_visitor.h index 532d48bf308..5d0fc7fd676 100644 --- a/oneflow/core/graph/task_type_visitor.h +++ b/oneflow/core/graph/task_type_visitor.h @@ -11,97 +11,6 @@ namespace oneflow { -template -struct TaskTypeVisitor { - template - static auto Visit(TaskType task_type, Args&&... args) { - switch (task_type) { - case TaskType::kInvalid: LOG(FATAL) << "invalid task type"; - case TaskType::kNormalForward: - return DerivedT::VisitNormalForward(std::forward(args)...); - case TaskType::kCopyHd: return DerivedT::VisitCopyHd(std::forward(args)...); - case TaskType::kCopyCommNet: return DerivedT::VisitCopyCommNet(std::forward(args)...); - case TaskType::kDeviceTick: return DerivedT::VisitDeviceTick(std::forward(args)...); - case TaskType::kPack: return DerivedT::VisitPack(std::forward(args)...); - case TaskType::kUnpack: return DerivedT::VisitUnpack(std::forward(args)...); - case TaskType::kRepeat: return DerivedT::VisitRepeat(std::forward(args)...); - case TaskType::kAcc: return DerivedT::VisitAcc(std::forward(args)...); - case TaskType::kAccCtrlTick: return DerivedT::VisitAccCtrlTick(std::forward(args)...); - case TaskType::kSrcSubsetTick: - return DerivedT::VisitSrcSubsetTick(std::forward(args)...); - case TaskType::kDstSubsetTick: - return DerivedT::VisitDstSubsetTick(std::forward(args)...); - case TaskType::kSourceTick: return DerivedT::VisitSourceTick(std::forward(args)...); - case TaskType::kTick: return DerivedT::VisitTick(std::forward(args)...); - case TaskType::kAccTick: return DerivedT::VisitAccTick(std::forward(args)...); - case TaskType::kCase: return DerivedT::VisitCase(std::forward(args)...); - case TaskType::kEsac: return DerivedT::VisitEsac(std::forward(args)...); - case TaskType::kWaitAndSendIds: - return DerivedT::VisitWaitAndSendIds(std::forward(args)...); - case TaskType::kReentrantLock: - return DerivedT::VisitReentrantLock(std::forward(args)...); - case TaskType::kCallbackNotify: - return DerivedT::VisitCallbackNotify(std::forward(args)...); - case TaskType::kDistributeConcat: - return DerivedT::VisitDistributeConcat(std::forward(args)...); - case TaskType::kDistributeSplit: - return DerivedT::VisitDistributeSplit(std::forward(args)...); - case TaskType::kSliceBoxing: return DerivedT::VisitSliceBoxing(std::forward(args)...); - case TaskType::kCollectiveBoxingGeneric: - return DerivedT::VisitCollectiveBoxingGeneric(std::forward(args)...); - case TaskType::kBoxingIdentity: - return DerivedT::VisitBoxingIdentity(std::forward(args)...); - case TaskType::kDecodeH2D: return DerivedT::VisitDecodeH2D(std::forward(args)...); - case TaskType::kCollectiveBoxingPack: - return DerivedT::VisitCollectiveBoxingPack(std::forward(args)...); - case TaskType::kCollectiveBoxingUnpack: - return DerivedT::VisitCollectiveBoxingUnpack(std::forward(args)...); - case TaskType::kSspVariableProxy: - return DerivedT::VisitSspVariableProxy(std::forward(args)...); - case TaskType::kBoxingZeros: return DerivedT::VisitBoxingZeros(std::forward(args)...); - case TaskType::kCriticalSectionWaitTick: - return DerivedT::VisitCriticalSectionWaitTick(std::forward(args)...); - case TaskType::kNcclSendRecvBoxing: - return DerivedT::VisitNcclSendRecvBoxing(std::forward(args)...); - } - LOG(FATAL) << "invalid task type"; - } -}; - -struct IsTransportTaskType final : public TaskTypeVisitor { - static bool VisitCopyHd() { return true; } - static bool VisitCopyCommNet() { return true; } - static bool VisitSliceBoxing() { return true; } - static bool VisitCollectiveBoxingGeneric() { return true; } - static bool VisitBoxingIdentity() { return true; } - static bool VisitCollectiveBoxingPack() { return true; } - static bool VisitCollectiveBoxingUnpack() { return true; } - static bool VisitNcclSendRecvBoxing() { return true; } - static bool VisitBoxingZeros() { return true; } - - static bool VisitNormalForward() { return false; } - static bool VisitDeviceTick() { return false; } - static bool VisitPack() { return false; } - static bool VisitUnpack() { return false; } - static bool VisitRepeat() { return false; } - static bool VisitAcc() { return false; } - static bool VisitSrcSubsetTick() { return false; } - static bool VisitDstSubsetTick() { return false; } - static bool VisitSourceTick() { return false; } - static bool VisitTick() { return false; } - static bool VisitAccTick() { return false; } - static bool VisitCase() { return false; } - static bool VisitEsac() { return false; } - static bool VisitWaitAndSendIds() { return false; } - static bool VisitReentrantLock() { return false; } - static bool VisitCallbackNotify() { return false; } - static bool VisitDistributeConcat() { return false; } - static bool VisitDistributeSplit() { return false; } - static bool VisitDecodeH2D() { return false; } - static bool VisitSspVariableProxy() { return false; } - static bool VisitCriticalSectionWaitTick() { return false; } -}; - template struct TransportTaskTypeVisitor { template From 38bf3662eedc2c388dce0627b924b528f83da098 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 20 Apr 2023 07:48:27 +0000 Subject: [PATCH 15/18] fix merge --- oneflow/core/register/register_desc.cpp | 3 --- oneflow/core/register/register_desc.h | 3 --- 2 files changed, 6 deletions(-) diff --git a/oneflow/core/register/register_desc.cpp b/oneflow/core/register/register_desc.cpp index 7310ca27524..d7b3e75efb1 100644 --- a/oneflow/core/register/register_desc.cpp +++ b/oneflow/core/register/register_desc.cpp @@ -120,7 +120,6 @@ void RegstDesc::EraseUninitializedShapeBlob() { }); } -<<<<<<< HEAD void RegstDesc::InitFromProtoExceptConsumers(const RegstDescProto& proto) { regst_desc_id_ = proto.regst_desc_id(); CHECK_EQ(proto.producer_task_id(), producer_->task_id()); @@ -147,8 +146,6 @@ void RegstDesc::InitFromProtoExceptConsumers(const RegstDescProto& proto) { force_inplace_consumed_regst_desc_id_ = proto.force_inplace_consumed_regst_desc_id(); } -======= ->>>>>>> 2d5436543068a2667ac229393a425980d473932b void RegstDesc::ToProto(RegstDescProto* ret, bool check) const { ret->set_regst_desc_id(regst_desc_id_); ret->set_producer_task_id(producer_->task_id()); diff --git a/oneflow/core/register/register_desc.h b/oneflow/core/register/register_desc.h index e2bcc6a0fcd..1da2e267325 100644 --- a/oneflow/core/register/register_desc.h +++ b/oneflow/core/register/register_desc.h @@ -99,10 +99,7 @@ class RegstDesc final { // util void EraseUninitializedShapeBlob(); -<<<<<<< HEAD void InitFromProtoExceptConsumers(const RegstDescProto& proto); -======= ->>>>>>> 2d5436543068a2667ac229393a425980d473932b void ToProto(RegstDescProto* proto) const { ToProto(proto, /*check*/ true); } void ToProto(RegstDescProto*, bool check) const; bool HasSameBlobDescs(const RegstDesc*); From a3b9798bb1730ac890daf1c272b33e03cd2e2dcd Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Wed, 26 Apr 2023 14:42:19 +0800 Subject: [PATCH 16/18] Update task_node.cpp --- oneflow/core/graph/task_node.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index 684cc483d5f..76cde3aea2b 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -226,7 +226,7 @@ void TaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto) { Maybe TaskNode::InitConsumedRegstsFromProto( const TaskProto& task_proto, const std::function(int64_t regst_desc_id)>& RegstDesc4Id) { - // Step3: init consumed_regst. + // init consumed_regst. for (const auto& pair : task_proto.consumed_regst_desc_id()) { for (int64_t regst_desc_id : pair.second.regst_desc_id()) { ConsumeRegst(pair.first, JUST(RegstDesc4Id(regst_desc_id))); From 2be6d00897ab39ad402c181ed372c25af99988e9 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 5 May 2023 10:07:14 +0000 Subject: [PATCH 17/18] address review --- oneflow/core/graph/task_node.cpp | 19 +++++++++++++------ oneflow/core/graph/task_node.h | 9 +++++---- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index 76cde3aea2b..affa2818481 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/task_node.h" +#include "oneflow/core/common/container_util.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/graph/task_graph_rebuild_ctx.h" @@ -301,9 +302,10 @@ void TaskNode::BindEdgeWithProducedRegst(TaskEdge* edge, const std::string& name edge->AddRegst(name, GetProducedRegst(name)); } -std::shared_ptr TaskNode::GetOrCheckRegst(const std::string& name, bool enable_reuse_mem, - int32_t min_register_num, - int32_t max_register_num) const { +std::shared_ptr TaskNode::GetAndCheckRegst(const std::string& name, + bool enable_reuse_mem, + int32_t min_register_num, + int32_t max_register_num) const { auto iter = produced_regsts_.find(name); if (iter == produced_regsts_.end()) { return nullptr; } const auto& regst = (iter->second); @@ -322,7 +324,7 @@ std::shared_ptr TaskNode::ProduceRegst(const std::string& name, bool int32_t max_register_num) { // Because the Regst of separate compilation is not created in order, some Regst may have been // built. This implementation can avoid ProduceRegst being called multiple times. - const auto& regst = GetOrCheckRegst(name, enable_reuse_mem, min_register_num, max_register_num); + const auto& regst = GetAndCheckRegst(name, enable_reuse_mem, min_register_num, max_register_num); if (regst) { return regst; } RegstDescTypeProto regst_desc_type; regst_desc_type.mutable_data_regst_desc(); @@ -408,7 +410,8 @@ bool TaskEdge::HasRegst(const std::string& name_in_producer) const { std::shared_ptr TaskEdge::GetSoleRegst() const { CHECK_EQ(name_in_producer2regst_.size(), 1) - << "edge: " << this << ", src: " << src_node()->task_id(); + << "edge: " << this << ", src: " << src_node()->task_id() + << ", dst: " << dst_node()->task_id(); return name_in_producer2regst_.begin()->second; } @@ -421,7 +424,11 @@ std::vector> TaskEdge::GetRegsts() const { void TaskEdge::AddRegst(const std::string& name_in_producer, const std::shared_ptr& regst) { - if (HasRegst(name_in_producer)) { return; } + if (HasRegst(name_in_producer)) { + CHECK(CHECK_JUST(MapAt(name_in_producer2regst_, name_in_producer))->regst_desc_id() + == regst->regst_desc_id()); + return; + } CHECK(name_in_producer2regst_.emplace(name_in_producer, regst).second); } diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index 65ca8181a2e..81937610ae1 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -100,6 +100,7 @@ class TaskNode : public Node { std::string VisualStr() const override; virtual bool IsMeaningLess(); void ToProto(TaskProto* task_proto) const { ToProto(task_proto, /*check*/ true); } + // Used to create task node from proto in plan separation compilation. virtual void InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto); Maybe InitConsumedRegstsFromProto( const TaskProto& task_proto, @@ -154,9 +155,9 @@ class TaskNode : public Node { private: void UpdateTaskId(); - std::shared_ptr GetOrCheckRegst(const std::string& name, bool enable_reuse_mem, - int32_t min_register_num, - int32_t max_register_num) const; + std::shared_ptr GetAndCheckRegst(const std::string& name, bool enable_reuse_mem, + int32_t min_register_num, + int32_t max_register_num) const; int64_t machine_id_; int64_t thrd_id_; @@ -189,7 +190,7 @@ class TaskEdge final : public Edge { void AddLbis(const std::vector& lbis) { lbis_.insert(lbis.begin(), lbis.end()); } void CheckRegstLbiValid() const; - bool OutHasBindRegst() const { return !name_in_producer2regst_.empty(); } + bool HasRegst() const { return !name_in_producer2regst_.empty(); } Maybe InitFromProto(const TaskEdgeProto& proto, const TaskGraphRebuildCtx& task_graph_rebuild_ctx); From 014a96ca4e1af30150ac0bc5dd573d8ef8ca7d76 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 5 May 2023 13:33:24 +0000 Subject: [PATCH 18/18] add comment --- oneflow/core/graph/compute_task_node.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/oneflow/core/graph/compute_task_node.cpp b/oneflow/core/graph/compute_task_node.cpp index 1cec2eb4582..bead3d261b2 100644 --- a/oneflow/core/graph/compute_task_node.cpp +++ b/oneflow/core/graph/compute_task_node.cpp @@ -82,6 +82,7 @@ void CompTaskNode::ConsumeFakeRegstsIf() { for (const auto& pair : consumed_regsts()) { for (const auto& regst_desc : pair.second) { if (regst_desc->regst_desc_type().has_data_regst_desc()) { + // Only one fake data regst is creatd for each CompTaskNode with ConsumeFakeRegsts(). CHECK(data_regst_desc == nullptr); data_regst_desc = CHECK_NOTNULL(regst_desc.get()); } else if (regst_desc->regst_desc_type().has_ctrl_regst_desc()) { @@ -93,6 +94,7 @@ void CompTaskNode::ConsumeFakeRegstsIf() { } if (data_regst_desc != nullptr) { for (const auto& ibn : op_node()->op().input_bns()) { + // Only one fake data regst is creatd and just use it for all input_bns as a placeholder. data_regst_desc->AddLbi(op_node()->op().BnInOp2Lbi(ibn)); } }