Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add consume fake regst #10140

Merged
merged 25 commits into from
May 6, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion oneflow/core/graph/boxing_identity_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/boxing_identity_task_node.h"
#include "oneflow/core/graph/boxing_task_graph.pb.h"

namespace oneflow {

Expand Down Expand Up @@ -46,11 +47,24 @@ void BoxingIdentityTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
node->InferBlobDescs(nullptr);
(node->*GetInferBlobDescsMethod())(nullptr);
}

void BoxingIdentityTaskNode::InferProducedDataRegstTimeShape() {
NaiveInferProducedDataRegstTimeShape();
}

// Restores this node from its serialized TransportTaskProto form.
//
// The only work done here is validating that the oneof in
// `transport_task_proto` holds a boxing_identity_task; no members are
// assigned and `ctx` is not consulted.
//
// Returns Maybe<void>::Ok() on success, or the error produced by
// CHECK_OR_RETURN when the proto holds a different task type.
Maybe<void> BoxingIdentityTaskNode::InitTransportTaskFromProto(
    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {
  CHECK_OR_RETURN(transport_task_proto.has_boxing_identity_task())
      << "not a serialized BoxingIdentityTaskNode. debug string: "
      << transport_task_proto.DebugString();
  return Maybe<void>::Ok();
}

// Serializes this node into `transport_task_proto`.
//
// Writes the task via ToProto() into the `task_proto` field (with the check
// flag disabled), then touches the boxing_identity_task sub-message.
void BoxingIdentityTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const {
  TaskProto* const task_proto = transport_task_proto->mutable_task_proto();
  ToProto(task_proto, /*check=*/false);
  // No fields to fill in: requesting the mutable sub-message is what marks the
  // oneof as holding a boxing_identity_task.
  transport_task_proto->mutable_boxing_identity_task();
}

} // namespace oneflow
4 changes: 4 additions & 0 deletions oneflow/core/graph/boxing_identity_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ class BoxingIdentityTaskNode : public TransportTaskNode {
void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi);
TaskType GetTaskType() const override { return TaskType::kBoxingIdentity; }

Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,
const TaskGraphRebuildCtx& ctx) override;
void ToTransportTaskProto(TransportTaskProto*) const override;

private:
void BuildExecGphAndRegst() override;
void ProduceAllRegstsAndBindEdges() override;
Expand Down
110 changes: 110 additions & 0 deletions oneflow/core/graph/boxing_task_graph.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
syntax = "proto2";
package oneflow;

import "oneflow/core/register/logical_blob_id.proto";
import "oneflow/core/common/shape.proto";
import "oneflow/core/common/data_type.proto";
import "oneflow/core/job/sbp_parallel.proto";
import "oneflow/core/job/task.proto";
import "oneflow/core/job/placement.proto";
import "oneflow/core/graph/task_edge.proto";
import "oneflow/core/operator/op_conf.proto";
import "oneflow/core/register/tensor_slice_view.proto";

// A set of TaskProto entries keyed by an int64.
// Judging by the field name the key is a parallel id -- confirm against the producer.
message ComputeTasksProto {
// NOTE: field number 1 is not assigned in this message; the map uses number 2.
map<int64, TaskProto> parallel_id2task = 2;
}

// Payload of TransportTaskProto.collective_boxing_generic_task.
// Carries the OperatorConf that CollectiveBoxingGenericTaskNode stores in op_conf_.
message CollectiveBoxingGenericTaskProto {
required OperatorConf op_conf = 1;
}

// Payload of TransportTaskProto.nccl_send_recv_boxing_task.
// All fields are required. The per-field notes below are inferred from the
// field names only -- confirm against the task node that reads/writes them.
message NcclSendRecvBoxingTaskProto {
// Shape and element type of the (presumably logical/global) tensor.
required ShapeProto logical_shape = 1;
required DataType data_type = 2;
// Source and destination nd-sbp signatures.
required NdSbp src_nd_sbp = 3;
required NdSbp dst_nd_sbp = 4;
// Source, destination and (presumably this task's own) parallel configuration.
required ParallelConf src_parallel_conf = 5;
required ParallelConf dst_parallel_conf = 6;
required ParallelConf parallel_conf = 7;
required ParallelContext parallel_ctx = 8;
// Whether the task has an input / an output.
required bool has_input = 9;
required bool has_output = 10;
required string stream_name = 11;
}

// Direction selector used by CopyHdTaskProto.copy_type.
// NOTE(review): names presumably stand for host-to-device / device-to-host -- confirm.
enum CopyHdType {
H2D = 0;
D2H = 1;
}

// Payload of TransportTaskProto.copy_hd_task: only the copy direction is stored.
message CopyHdTaskProto {
required CopyHdType copy_type = 1;
}

// Payload of TransportTaskProto.copy_comm_net_task.
// Intentionally empty: it has no fields, so only its presence in the oneof is recorded.
message CopyCommNetTaskProto {
}

// Payload of TransportTaskProto.boxing_zeros_task.
// Mirrors the shape_, data_type_ and time_shape_ members that
// BoxingZerosTaskNode serializes and restores.
message BoxingZerosTaskProto {
required ShapeProto shape = 1;
required DataType data_type = 2;
required ShapeProto time_shape = 3;
}

// Mode selector used by SliceBoxingTaskProto.mode.
// Value 0 is the explicit "invalid" sentinel; Copy and Add are the two named modes.
enum SliceBoxingTaskMode {
kSliceBoxingTaskModeInvalid = 0;
kSliceBoxingTaskModeCopy = 1;
kSliceBoxingTaskModeAdd = 2;
}

// Payload of TransportTaskProto.slice_boxing_task.
// Per-field notes are inferred from field names -- confirm against the slice
// boxing task node.
message SliceBoxingTaskProto {
// Slice view per input data edge, keyed by an int64 edge uid.
map<int64, TensorSliceViewProto> in_data_edge_uid2slice = 1;
// Input edge uids in an explicit order (protobuf maps do not preserve ordering).
repeated int64 ordered_in_data_edge_uid = 2;
required TensorSliceViewProto out_slice = 3;
required ShapeProto out_shape = 4;
required SliceBoxingTaskMode mode = 5;
}

// Payload of TransportTaskProto.collective_boxing_pack_task.
// Mirrors the logical_shape_, src_sbp_parallel_, dst_sbp_parallel_ and
// parallel_num_ members that CollectiveBoxingPackTaskNode serializes and restores.
message CollectiveBoxingPackTaskProto {
required ShapeProto logical_shape = 1;
required SbpParallel src_sbp_parallel = 2;
required SbpParallel dst_sbp_parallel = 3;
required int64 parallel_num = 4;
}

// Payload of TransportTaskProto.collective_boxing_unpack_task.
// Same field layout as CollectiveBoxingPackTaskProto; mirrors the members that
// CollectiveBoxingUnpackTaskNode serializes and restores.
message CollectiveBoxingUnpackTaskProto {
required ShapeProto logical_shape = 1;
required SbpParallel src_sbp_parallel = 2;
required SbpParallel dst_sbp_parallel = 3;
required int64 parallel_num = 4;
}

// Payload of TransportTaskProto.boxing_identity_task.
// Intentionally empty: BoxingIdentityTaskNode only selects this oneof case and
// checks for its presence when restoring.
message BoxingIdentityTaskProto {
}

// Serialized form of one transport task node: the generic TaskProto, the
// logical blob id, and exactly one type-specific payload chosen via the oneof.
message TransportTaskProto {
required TaskProto task_proto = 1;
// NOTE(review): `lbi` is required, but the ToTransportTaskProto overrides of the
// individual task nodes only fill `task_proto` and the oneof -- presumably the
// caller or a base class sets it; confirm.
// Its field number (11) follows the oneof range 2-10.
required LogicalBlobId lbi = 11;
// Each concrete node's InitTransportTaskFromProto checks the matching has_*() case.
oneof transport_task_type {
CollectiveBoxingGenericTaskProto collective_boxing_generic_task = 2;
NcclSendRecvBoxingTaskProto nccl_send_recv_boxing_task = 3;
CopyHdTaskProto copy_hd_task = 4;
CopyCommNetTaskProto copy_comm_net_task = 5;
BoxingZerosTaskProto boxing_zeros_task = 6;
SliceBoxingTaskProto slice_boxing_task = 7;
CollectiveBoxingPackTaskProto collective_boxing_pack_task = 8;
CollectiveBoxingUnpackTaskProto collective_boxing_unpack_task = 9;
BoxingIdentityTaskProto boxing_identity_task = 10;
}
}

// A plain list of int64 task ids; used as the map value type in
// BoxingTaskGraphProto.boxing_unrelated_op_name2task_ids.
message TaskIdsProto {
repeated int64 task_id = 1;
}

// Top-level serialized boxing task graph.
// Per-field notes are inferred from field names and types -- confirm against
// the code that builds and rebuilds the graph.
message BoxingTaskGraphProto {
// Full compute TaskProtos, keyed by op name, for ops that take part in boxing.
map<string, ComputeTasksProto> boxing_related_op_name2compute_tasks = 1;
// Every transport task node, each with its type-specific payload.
repeated TransportTaskProto transport_task = 2;
repeated TaskEdgeProto task_edge = 3;
// For ops not involved in boxing only the task ids are stored, keyed by op name.
map<string, TaskIdsProto> boxing_unrelated_op_name2task_ids = 4;
}
22 changes: 21 additions & 1 deletion oneflow/core/graph/boxing_zeros_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/boxing_zeros_task_node.h"
#include "oneflow/core/graph/boxing_task_graph.pb.h"

namespace oneflow {

Expand Down Expand Up @@ -50,11 +51,30 @@ void BoxingZerosTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
node->InferBlobDescs(nullptr);
(node->*GetInferBlobDescsMethod())(nullptr);
}

// Sets the time shape of the produced "out" regst to a fresh copy of
// time_shape_.
void BoxingZerosTaskNode::InferProducedDataRegstTimeShape() {
  const auto out_regst = GetProducedRegst("out");
  out_regst->mut_data_regst_time_shape()->reset(new Shape(time_shape_));
}
// Restores this node from its serialized TransportTaskProto form.
//
// Validates that the oneof holds a boxing_zeros_task, then copies its shape,
// data type and time shape into shape_, data_type_ and time_shape_.
// `ctx` is not consulted.
//
// Returns Maybe<void>::Ok() on success, or the error produced by
// CHECK_OR_RETURN when the proto holds a different task type.
Maybe<void> BoxingZerosTaskNode::InitTransportTaskFromProto(
    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {
  CHECK_OR_RETURN(transport_task_proto.has_boxing_zeros_task())
      << "not a serialized BoxingZerosTaskNode. debug string: "
      << transport_task_proto.DebugString();
  const BoxingZerosTaskProto& zeros_conf = transport_task_proto.boxing_zeros_task();
  shape_ = Shape(zeros_conf.shape());
  data_type_ = zeros_conf.data_type();
  time_shape_ = Shape(zeros_conf.time_shape());
  return Maybe<void>::Ok();
}

// Serializes this node into `transport_task_proto`.
//
// Writes the task via ToProto() into the `task_proto` field (with the check
// flag disabled), then stores shape_, data_type_ and time_shape_ in the
// boxing_zeros_task sub-message.
void BoxingZerosTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const {
  TaskProto* const task_proto = transport_task_proto->mutable_task_proto();
  ToProto(task_proto, /*check=*/false);

  BoxingZerosTaskProto* const zeros_conf = transport_task_proto->mutable_boxing_zeros_task();
  shape_.ToProto(zeros_conf->mutable_shape());
  zeros_conf->set_data_type(data_type_);
  time_shape_.ToProto(zeros_conf->mutable_time_shape());
}

} // namespace oneflow
4 changes: 4 additions & 0 deletions oneflow/core/graph/boxing_zeros_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class BoxingZerosTaskNode : public TransportTaskNode {
DataType data_type, const Shape& time_shape);
TaskType GetTaskType() const override { return TaskType::kBoxingZeros; }

Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,
const TaskGraphRebuildCtx& ctx) override;
void ToTransportTaskProto(TransportTaskProto*) const override;

private:
void BuildExecGphAndRegst() override;
void ProduceAllRegstsAndBindEdges() override;
Expand Down
26 changes: 25 additions & 1 deletion oneflow/core/graph/collective_boxing_pack_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/collective_boxing_pack_task_node.h"
#include "oneflow/core/graph/boxing_task_graph.pb.h"

namespace oneflow {

Expand Down Expand Up @@ -59,11 +60,34 @@ void CollectiveBoxingPackTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
node->InferBlobDescs(nullptr);
(node->*GetInferBlobDescsMethod())(nullptr);
}

void CollectiveBoxingPackTaskNode::InferProducedDataRegstTimeShape() {
NaiveInferProducedDataRegstTimeShape();
}

// Restores this node from its serialized TransportTaskProto form.
//
// Validates that the oneof holds a collective_boxing_pack_task, then copies
// its logical shape, source/destination sbp and parallel num into the
// corresponding members. `ctx` is not consulted.
//
// Returns Maybe<void>::Ok() on success, or the error produced by
// CHECK_OR_RETURN when the proto holds a different task type.
Maybe<void> CollectiveBoxingPackTaskNode::InitTransportTaskFromProto(
    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {
  CHECK_OR_RETURN(transport_task_proto.has_collective_boxing_pack_task())
      << "not a serialized CollectiveBoxingPackTaskNode. debug string: "
      << transport_task_proto.DebugString();
  const CollectiveBoxingPackTaskProto& pack_conf =
      transport_task_proto.collective_boxing_pack_task();
  logical_shape_ = Shape(pack_conf.logical_shape());
  src_sbp_parallel_ = pack_conf.src_sbp_parallel();
  dst_sbp_parallel_ = pack_conf.dst_sbp_parallel();
  parallel_num_ = pack_conf.parallel_num();
  return Maybe<void>::Ok();
}

// Serializes this node into `transport_task_proto`.
//
// Writes the task via ToProto() into the `task_proto` field (with the check
// flag disabled), then stores logical_shape_, src_sbp_parallel_,
// dst_sbp_parallel_ and parallel_num_ in the collective_boxing_pack_task
// sub-message.
void CollectiveBoxingPackTaskNode::ToTransportTaskProto(
    TransportTaskProto* transport_task_proto) const {
  TaskProto* const task_proto = transport_task_proto->mutable_task_proto();
  ToProto(task_proto, /*check=*/false);

  CollectiveBoxingPackTaskProto* const pack_conf =
      transport_task_proto->mutable_collective_boxing_pack_task();
  logical_shape_.ToProto(pack_conf->mutable_logical_shape());
  *pack_conf->mutable_src_sbp_parallel() = src_sbp_parallel_;
  *pack_conf->mutable_dst_sbp_parallel() = dst_sbp_parallel_;
  pack_conf->set_parallel_num(parallel_num_);
}

} // namespace oneflow
4 changes: 4 additions & 0 deletions oneflow/core/graph/collective_boxing_pack_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class CollectiveBoxingPackTaskNode : public TransportTaskNode {
const SbpParallel& dst_sbp_parallel, const int64_t parallel_num);
TaskType GetTaskType() const override { return TaskType::kCollectiveBoxingPack; }

Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,
const TaskGraphRebuildCtx& ctx) override;
void ToTransportTaskProto(TransportTaskProto*) const override;

private:
void BuildExecGphAndRegst() override;
void ProduceAllRegstsAndBindEdges() override;
Expand Down
18 changes: 17 additions & 1 deletion oneflow/core/graph/collective_boxing_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/boxing_task_graph.pb.h"
#include "oneflow/core/graph/collective_boxing_task_node.h"
#include "oneflow/core/graph/boxing/collective_boxing_util.h"

Expand Down Expand Up @@ -54,12 +55,27 @@ void CollectiveBoxingGenericTaskNode::BuildExecGphAndRegst() {
node->BindBnWithRegst(obn, out_regst);
out_regst->AddLbi(boxing_op->BnInOp2Lbi(obn));
}
node->InferBlobDescs(nullptr);
(node->*GetInferBlobDescsMethod())(nullptr);
}

// Gives the produced "out" regst a fixed time shape of {1, 1}.
// Does nothing when GetProducedRegst("out") yields a null regst.
void CollectiveBoxingGenericTaskNode::InferProducedDataRegstTimeShape() {
  auto out_regst = GetProducedRegst("out");
  if (out_regst == nullptr) { return; }
  out_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1}));
}

// Restores this node from its serialized TransportTaskProto form.
//
// Validates that the oneof holds a collective_boxing_generic_task, then copies
// its op_conf into op_conf_. `ctx` is not consulted.
//
// Returns Maybe<void>::Ok() on success, or the error produced by
// CHECK_OR_RETURN when the proto holds a different task type.
Maybe<void> CollectiveBoxingGenericTaskNode::InitTransportTaskFromProto(
    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {
  CHECK_OR_RETURN(transport_task_proto.has_collective_boxing_generic_task())
      << "not a serialized CollectiveBoxingGenericTaskNode. debug string: "
      << transport_task_proto.DebugString();
  const CollectiveBoxingGenericTaskProto& generic_conf =
      transport_task_proto.collective_boxing_generic_task();
  op_conf_ = generic_conf.op_conf();
  return Maybe<void>::Ok();
}

// Serializes this node into `transport_task_proto`.
//
// Writes the task via ToProto() into the `task_proto` field (with the check
// flag disabled), then copies op_conf_ into the
// collective_boxing_generic_task sub-message.
void CollectiveBoxingGenericTaskNode::ToTransportTaskProto(
    TransportTaskProto* transport_task_proto) const {
  TaskProto* const task_proto = transport_task_proto->mutable_task_proto();
  ToProto(task_proto, /*check=*/false);

  CollectiveBoxingGenericTaskProto* const generic_conf =
      transport_task_proto->mutable_collective_boxing_generic_task();
  *generic_conf->mutable_op_conf() = op_conf_;
}

} // namespace oneflow
4 changes: 4 additions & 0 deletions oneflow/core/graph/collective_boxing_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ class CollectiveBoxingGenericTaskNode : public TransportTaskNode {
void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,
const OperatorConf& op_conf);

Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,
const TaskGraphRebuildCtx& ctx) override;
void ToTransportTaskProto(TransportTaskProto*) const override;

private:
void BuildExecGphAndRegst() override;
void ProduceAllRegstsAndBindEdges() override;
Expand Down
26 changes: 25 additions & 1 deletion oneflow/core/graph/collective_boxing_unpack_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/boxing_task_graph.pb.h"
#include "oneflow/core/graph/collective_boxing_unpack_task_node.h"

namespace oneflow {
Expand Down Expand Up @@ -59,11 +60,34 @@ void CollectiveBoxingUnpackTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
node->InferBlobDescs(nullptr);
(node->*GetInferBlobDescsMethod())(nullptr);
}

void CollectiveBoxingUnpackTaskNode::InferProducedDataRegstTimeShape() {
NaiveInferProducedDataRegstTimeShape();
}

// Restores this node from its serialized TransportTaskProto form.
//
// Validates that the oneof holds a collective_boxing_unpack_task, then copies
// its logical shape, source/destination sbp and parallel num into the
// corresponding members. `ctx` is not consulted.
//
// Returns Maybe<void>::Ok() on success, or the error produced by
// CHECK_OR_RETURN when the proto holds a different task type.
Maybe<void> CollectiveBoxingUnpackTaskNode::InitTransportTaskFromProto(
    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {
  CHECK_OR_RETURN(transport_task_proto.has_collective_boxing_unpack_task())
      << "not a serialized CollectiveBoxingUnpackTaskNode. debug string: "
      << transport_task_proto.DebugString();
  const CollectiveBoxingUnpackTaskProto& unpack_conf =
      transport_task_proto.collective_boxing_unpack_task();
  logical_shape_ = Shape(unpack_conf.logical_shape());
  src_sbp_parallel_ = unpack_conf.src_sbp_parallel();
  dst_sbp_parallel_ = unpack_conf.dst_sbp_parallel();
  parallel_num_ = unpack_conf.parallel_num();
  return Maybe<void>::Ok();
}

// Serializes this node into `transport_task_proto`.
//
// Writes the task via ToProto() into the `task_proto` field (with the check
// flag disabled), then stores logical_shape_, src_sbp_parallel_,
// dst_sbp_parallel_ and parallel_num_ in the collective_boxing_unpack_task
// sub-message.
void CollectiveBoxingUnpackTaskNode::ToTransportTaskProto(
    TransportTaskProto* transport_task_proto) const {
  TaskProto* const task_proto = transport_task_proto->mutable_task_proto();
  ToProto(task_proto, /*check=*/false);

  CollectiveBoxingUnpackTaskProto* const unpack_conf =
      transport_task_proto->mutable_collective_boxing_unpack_task();
  logical_shape_.ToProto(unpack_conf->mutable_logical_shape());
  *unpack_conf->mutable_src_sbp_parallel() = src_sbp_parallel_;
  *unpack_conf->mutable_dst_sbp_parallel() = dst_sbp_parallel_;
  unpack_conf->set_parallel_num(parallel_num_);
}

} // namespace oneflow
4 changes: 4 additions & 0 deletions oneflow/core/graph/collective_boxing_unpack_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class CollectiveBoxingUnpackTaskNode : public TransportTaskNode {

TaskType GetTaskType() const override { return TaskType::kCollectiveBoxingUnpack; }

Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,
const TaskGraphRebuildCtx& ctx) override;
void ToTransportTaskProto(TransportTaskProto*) const override;

private:
void BuildExecGphAndRegst() override;
void ProduceAllRegstsAndBindEdges() override;
Expand Down
Loading