diff --git a/oneflow/api/python/framework/random_generator.cpp b/oneflow/api/python/framework/random_generator.cpp
index b7b5f8937d5..cda2a53598a 100644
--- a/oneflow/api/python/framework/random_generator.cpp
+++ b/oneflow/api/python/framework/random_generator.cpp
@@ -23,10 +23,10 @@ namespace py = pybind11;
 
 namespace oneflow {
 
-Maybe<one::Generator> CreateGenerator(const std::string& device_tag) {
+Maybe<one::Generator> CreateGenerator(const std::string& device_str) {
   std::string device_name = "";
   int device_index = -1;
-  JUST(ParsingDeviceTag(device_tag, &device_name, &device_index));
+  JUST(ParseDeviceString(device_str, &device_name, &device_index));
   return one::MakeGenerator(device_name, device_index);
 }
 
@@ -58,10 +58,10 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
     return one::ManualSeed(seed_val, device, device_index);
   });
   m.def("create_generator", &CreateGenerator);
-  m.def("default_generator", [](const std::string& device_tag) -> Maybe<one::Generator> {
+  m.def("default_generator", [](const std::string& device_str) -> Maybe<one::Generator> {
     std::string device_name = "";
     int device_index = -1;
-    JUST(ParsingDeviceTag(device_tag, &device_name, &device_index));
+    JUST(ParseDeviceString(device_str, &device_name, &device_index));
     return one::DefaultGenerator(device_name, device_index);
   });
   m.def("ManualSeedAllCudaGenerator", [](const py::object& seed) -> Maybe<void> {
diff --git a/oneflow/core/common/device.proto b/oneflow/core/common/device.proto
new file mode 100644
index 00000000000..91976eac581
--- /dev/null
+++ b/oneflow/core/common/device.proto
@@ -0,0 +1,9 @@
+syntax = "proto2";
+package oneflow;
+
+import "oneflow/core/common/device_type.proto";
+
+message DeviceProto {
+  required DeviceType device_type = 1;
+  required int64 device_id = 2;
+}
diff --git a/oneflow/core/framework/attr_value.h b/oneflow/core/framework/attr_value.h
index 90372c3f537..db2cdc615e0 100644
--- a/oneflow/core/framework/attr_value.h
+++ b/oneflow/core/framework/attr_value.h
@@ -17,6 +17,7 @@ limitations under the License.
 #define ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_H_
 
 #include "fmt/core.h"
+#include "oneflow/core/framework/device.h"
 #include "oneflow/core/framework/user_op_attr.pb.h"
 #include "oneflow/core/common/util.h"
 #include "oneflow/core/common/shape.h"
@@ -61,6 +62,8 @@ namespace user_op {
 #define LIST_STRING_ATTR_SEQ \
   OF_PP_MAKE_TUPLE_SEQ(at_list_string, std::vector<std::string>, AttrType::kAtListString)
 
+#define DEVICE_ATTR_SEQ OF_PP_MAKE_TUPLE_SEQ(at_device, Symbol<Device>, AttrType::kAtDevice)
+
 #define ATTR_SEQ      \
   BASIC_ATTR_SEQ      \
   ENUM_ATTR_SEQ       \
@@ -68,7 +71,8 @@ namespace user_op {
   LIST_BASIC_ATTR_SEQ \
   LIST_ENUM_ATTR_SEQ  \
   LIST_MESSAGE_ATTR_SEQ \
-  LIST_STRING_ATTR_SEQ
+  LIST_STRING_ATTR_SEQ  \
+  DEVICE_ATTR_SEQ
 
 // Type Trait: GetAttrType, GetCppType
 
diff --git a/oneflow/core/framework/attr_value_accessor.cpp b/oneflow/core/framework/attr_value_accessor.cpp
index 1cc2eeb78fc..454b84eb3ba 100644
--- a/oneflow/core/framework/attr_value_accessor.cpp
+++ b/oneflow/core/framework/attr_value_accessor.cpp
@@ -19,6 +19,8 @@ limitations under the License.
#include "oneflow/core/common/stride.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/framework/attr_value.h" +#include "oneflow/core/framework/device.h" +#include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/user_op_conf.h" namespace oneflow { @@ -67,6 +69,19 @@ void AttrValueAccessor::Attr(const Stride& cpp_val, AttrValue* attr_val) cpp_val.ToProto(attr_val->mutable_at_stride()); } +template<> +Symbol AttrValueAccessor>::Attr(const AttrValue& val) { + auto pb_device = val.at_device(); + return CHECK_JUST(Device::New(*CHECK_JUST(DeviceTag4DeviceType(pb_device.device_type())), + pb_device.device_id())); +} + +template<> +void AttrValueAccessor>::Attr(const Symbol& cpp_val, AttrValue* attr_val) { + attr_val->mutable_at_device()->set_device_type(cpp_val->enum_type()); + attr_val->mutable_at_device()->set_device_id(cpp_val->device_id()); +} + // List of Basic Attr #define LIST_BASIC_ATTR_SEQ_ENTRY(field, cpp_type, attr_type) \ template<> \ diff --git a/oneflow/core/framework/device.cpp b/oneflow/core/framework/device.cpp index 1304578095d..797376367d8 100644 --- a/oneflow/core/framework/device.cpp +++ b/oneflow/core/framework/device.cpp @@ -30,10 +30,6 @@ namespace oneflow { namespace { -inline size_t HashDevice(const std::string& type, int64_t device_id) { - return Hash(type, device_id); -} - void CheckDeviceType(const std::string& type) { if (!TRY(DeviceType4DeviceTag(type)).IsOk()) { std::string error_msg = "Expected one of " + PrintAvailableDevices() @@ -48,7 +44,7 @@ Device::Device(const std::string& type, int64_t device_id) : type_(type), enum_type_(kInvalidDevice), device_id_(device_id), - hash_value_(HashDevice(type, device_id)) {} + hash_value_(Hash(type, device_id)) {} Maybe Device::Init() { if (type_ == "auto") { return Maybe::Ok(); } @@ -62,50 +58,39 @@ Maybe Device::Init() { } /* static */ Maybe> Device::New(const std::string& type, int64_t device_id) { - return ThreadLocalGetOrNew(type, device_id); -} - -/* static */ Maybe> Device::ThreadLocalGetOrNew(const std::string& type, - int64_t device_id) { CHECK_GE_OR_RETURN(device_id, 0) << Error::InvalidValueError() << "Device ID should be non-negative"; - static thread_local HashMap>> map; - auto* device_id2symbol = &map[type]; - auto iter = device_id2symbol->find(device_id); - if (iter == device_id2symbol->end()) { + static thread_local HashMap, Symbol> map; + auto key = std::make_tuple(type, device_id); + auto iter = map.find(key); + if (iter == map.end()) { Device device(type, device_id); JUST(device.Init()); - iter = device_id2symbol->emplace(device_id, SymbolOf(device)).first; + iter = map.emplace(key, SymbolOf(device)).first; } return iter->second; } -/* static */ Maybe> Device::ThreadLocalGetOrNew( - const std::string& type_or_type_with_device_id) { +/* static */ Maybe> Device::New(const std::string& type) { + return New(type, GlobalProcessCtx::LocalRank()); +} + +/* static */ Maybe> Device::ParseAndNew(const std::string& device_str) { static thread_local HashMap> map; - auto iter = map.find(type_or_type_with_device_id); + auto iter = map.find(device_str); if (iter == map.end()) { std::string type; int device_id = -1; - JUST(ParsingDeviceTag(type_or_type_with_device_id, &type, &device_id)); + JUST(ParseDeviceString(device_str, &type, &device_id)); CheckDeviceType(type); if (device_id == -1) { device_id = GlobalProcessCtx::LocalRank(); } Device device(type, device_id); JUST(device.Init()); - iter = map.emplace(type_or_type_with_device_id, SymbolOf(device)).first; + iter = 
map.emplace(device_str, SymbolOf(device)).first; } return iter->second; } -/* static */ Maybe> Device::New(const std::string& type) { - return New(type, GlobalProcessCtx::LocalRank()); -} - -/* static */ Maybe> Device::ParseAndNew( - const std::string& type_or_type_with_device_id) { - return ThreadLocalGetOrNew(type_or_type_with_device_id); -} - std::string Device::ToRepr() const { std::stringstream ss; ss << "device(type='"; @@ -116,6 +101,11 @@ std::string Device::ToRepr() const { return ss.str(); } +std::ostream& operator<<(std::ostream& os, Symbol device) { + os << device->ToRepr(); + return os; +} + std::string Device::ToString() const { std::stringstream ss; ss << type_; @@ -165,8 +155,8 @@ decltype(Device::GetPlacement) Device::GetPlacement = DECORATE(&RawGetPlacement, ThreadLocalCopiable); decltype(Placement4Device) Placement4Device = DECORATE(&RawPlacement4Device, ThreadLocal); -Maybe ParsingDeviceTag(const std::string& device_tag, std::string* device_name, - int* device_index) { +Maybe ParseDeviceString(const std::string& device_tag, std::string* device_name, + int* device_index) { std::string::size_type pos = device_tag.find(':'); if (pos == std::string::npos) { *device_name = device_tag; diff --git a/oneflow/core/framework/device.h b/oneflow/core/framework/device.h index 028d5501960..a107f59d8dd 100644 --- a/oneflow/core/framework/device.h +++ b/oneflow/core/framework/device.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef ONEFLOW_CORE_FRAMEWORK_DEVICE_H_ #define ONEFLOW_CORE_FRAMEWORK_DEVICE_H_ +#include +#include #include #include #include @@ -47,13 +49,9 @@ class Device final { bool operator==(const Device& device) const { return type_ == device.type() && device_id_ == device.device_id(); } - bool operator!=(const Device& device) const { - return !(type_ == device.type() && device_id_ == device.device_id()); - } + bool operator!=(const Device& device) const { return !operator==(device); } const std::shared_ptr& mem_case() const { return mem_case_; } - static Maybe> ThreadLocalGetOrNew(const std::string& type, int64_t device_id); - static Maybe> ThreadLocalGetOrNew(const std::string& type_or_type_with_device_id); static Maybe> New(const std::string& type, int64_t device_id); static Maybe> New(const std::string& type); static Maybe> ParseAndNew(const std::string& type_or_type_with_device_id); @@ -73,13 +71,18 @@ class Device final { std::shared_ptr mem_case_; }; +std::ostream& operator<<(std::ostream& os, Symbol device); + extern Maybe> (*Placement4Device)(Symbol device); -Maybe ParsingDeviceTag(const std::string& device_tag, std::string* device_name, - int* device_index); +Maybe ParseDeviceString(const std::string& device_tag, std::string* device_name, + int* device_index); } // namespace oneflow +template<> +struct fmt::formatter> : ostream_formatter {}; + namespace std { template<> struct hash final { diff --git a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp index ab00fd922bd..08725c5ad3b 100644 --- a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp @@ -715,20 +715,19 @@ Maybe LazyInterpreterApplyImplForCopyUserOpExpr(const UserOpExpr& op_expr, input_lbn = TensorNameScope::Global()->Lookup(input_tensor); } CHECK_OR_RETURN(!input_lbn.empty()); // NOLINT(maybe-need-error-msg) - std::string device_type = JUST(ctx.attrs.GetAttr("device_type")); - int64_t device_id = JUST(ctx.attrs.GetAttr("device_id")); + 
auto device = JUST(ctx.attrs.GetAttr>("device")); CHECK_EQ_OR_RETURN(outputs->size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(op_expr.output_size(), 1); // NOLINT(maybe-need-error-msg) if (input_tensor->is_local()) { - (*outputs)[0] = JUST(LocalTensor::MakeTensor( - input_tensor->shape(), JUST(input_tensor->stride()), input_tensor->dtype()->data_type(), - JUST(Device::New(device_type, device_id)), - /* is_lazy= */ true, - /*requires_grad=*/false, /*is_leaf=*/true)); + (*outputs)[0] = + JUST(LocalTensor::MakeTensor(input_tensor->shape(), JUST(input_tensor->stride()), + input_tensor->dtype()->data_type(), device, + /* is_lazy= */ true, + /*requires_grad=*/false, /*is_leaf=*/true)); } else { ParallelConf parallel_conf = JUST(input_tensor->parallel_desc())->parallel_conf(); - parallel_conf.set_device_tag(device_type); + parallel_conf.set_device_tag(device->type()); ParallelDesc parallel_desc(parallel_conf); (*outputs)[0] = JUST(GlobalTensor::MakeTensor(input_tensor->shape(), input_tensor->dtype()->data_type(), diff --git a/oneflow/core/framework/random_generator_impl.cpp b/oneflow/core/framework/random_generator_impl.cpp index de5cbd910a8..dc9d528d0fd 100644 --- a/oneflow/core/framework/random_generator_impl.cpp +++ b/oneflow/core/framework/random_generator_impl.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/random_generator_impl.h" +#include "oneflow/core/common/container_util.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/common/str_util.h" @@ -76,7 +77,7 @@ Maybe CPUGeneratorImpl::GetState() const { << splits.size() - 1; } for (int i = 0; i < CPUGeneratorState::state_size; ++i) { - state.states[i] = std::atoll(splits.at(i).data()); + state.states[i] = std::atoll(JUST(VectorAt(splits, i)).data()); } state.seed = current_seed(); @@ -144,7 +145,7 @@ int GetThreadNum(const cudaDeviceProp& prop) { CUDAGeneratorImpl::CUDAGeneratorImpl(uint64_t seed, int device_index) : DeviceGeneratorImpl(seed, DeviceType::kCUDA, device_index), philox_offset_per_thread_(0) { - cudaDeviceProp prop; + cudaDeviceProp prop; // NOLINT OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_index)); max_block_num_ = prop.multiProcessorCount; max_thread_num_ = GetThreadNum(prop); @@ -313,14 +314,14 @@ Maybe AutoGeneratorImpl::GetState() const { memcpy(data, device_tags.data(), state.device_tag_length); data += state.device_tag_length; for (int i = 0; i < tensor_states.size(); ++i) { - const auto& tensor = tensor_states.at(i); + const auto& tensor = tensor_states[i]; const auto& callback = [&data, &state_sizes, i]( ep::Stream*, const std::shared_ptr& eager_blob_object) { memcpy(data, eager_blob_object->dptr(), state_sizes.at(i)); }; JUST(SyncAccessTensorWithTimeOut(tensor, callback, "const")); - data += state_sizes.at(i); + data += JUST(VectorAt(state_sizes, i)); } } const auto& device = JUST(Device::New("cpu")); @@ -370,7 +371,7 @@ Maybe AutoGeneratorImpl::SetState(const std::shared_ptr& tensor_st data += state.device_tag_length; std::vector> tensor_states(state.num); for (int i = 0; i < state.num; ++i) { - int64_t state_size = state_sizes.at(i); + int64_t state_size = JUST(VectorAt(state_sizes, i)); tensor_states[i] = JUST(functional::Empty(Shape{state_size}, DType::UInt8(), device, /*pin_memory=*/false)); const auto& callback = [&data, &state_size]( @@ -395,13 +396,10 @@ Maybe AutoGeneratorImpl::SetState(const std::shared_ptr& tensor_st std::lock_guard lock(mutex_); for (int i = 0; i < splits.size(); 
++i) { - std::string device_name; - int device_index = -1; - JUST(ParsingDeviceTag(splits.at(i), &device_name, &device_index)); + const auto& device = JUST(Device::ParseAndNew(splits[i])); detail::DeviceKey device_key; - const auto& device = JUST(Device::New(device_name, device_index)); device_key.device_type = JUST(DeviceType4DeviceTag(device->type())); - device_key.device_index = device_index; + device_key.device_index = device->device_id(); auto it = generators_.find(device_key); if (it == generators_.end()) { it = generators_ @@ -409,7 +407,7 @@ Maybe AutoGeneratorImpl::SetState(const std::shared_ptr& tensor_st device_key.device_index))) .first; } - JUST(it->second->SetState(tensor_states.at(i))); + JUST(it->second->SetState(JUST(VectorAt(tensor_states, i)))); } return Maybe::Ok(); } diff --git a/oneflow/core/framework/tensor.cpp b/oneflow/core/framework/tensor.cpp index 64f3a6a9bc4..4ae50e66e08 100644 --- a/oneflow/core/framework/tensor.cpp +++ b/oneflow/core/framework/tensor.cpp @@ -99,11 +99,9 @@ std::shared_ptr LocalTensor::pin_memory() const { } Maybe LocalTensor::clone() const { - const auto& device_type = JUST(this->device())->type(); - int64_t device_id = JUST(this->device())->device_id(); std::shared_ptr input = std::const_pointer_cast(shared_from_this()); const bool pin_memory = JUST(JUST(input->AsLocalTensor())->is_pinned()); - return JUST(functional::Copy(input, device_type, device_id, /*pin_memory=*/pin_memory)); + return JUST(functional::Copy(input, JUST(this->device()), /*pin_memory=*/pin_memory)); } Maybe LocalTensor::set_data(const std::shared_ptr& other) { diff --git a/oneflow/core/framework/user_op_attr.proto b/oneflow/core/framework/user_op_attr.proto index 033f290c364..be56cfb1a97 100644 --- a/oneflow/core/framework/user_op_attr.proto +++ b/oneflow/core/framework/user_op_attr.proto @@ -4,6 +4,7 @@ package oneflow; import "oneflow/core/common/shape.proto"; import "oneflow/core/common/sequential.proto"; import "oneflow/core/common/data_type.proto"; +import "oneflow/core/common/device.proto"; enum AttrType { kAtInt32 = 1; @@ -22,6 +23,7 @@ enum AttrType { kAtListString = 14; kAtStride = 15; kAtListStride = 16; + kAtDevice = 17; } message AttrValue { @@ -64,6 +66,7 @@ message AttrValue { ListString at_list_string = 14; Int64ListProto at_stride = 15; ListStride at_list_stride = 16; + DeviceProto at_device = 17; } } diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index c3f2efc2e31..edd9e194c9e 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -1603,7 +1603,10 @@ bind_python: False - name: "copy" - signature: "Tensor (Tensor x, String device_type, Int64 device_id, Bool pin_memory=False) => Copy" + signature: [ + "Tensor (Tensor x, String device_type, Int64 device_id, Bool pin_memory=False) => Copy", + "Tensor (Tensor x, Device device, Bool pin_memory=False) => Copy" + ] bind_python: True - name: "to" diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index abf82c42c38..5a624efa58f 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -1720,16 +1720,18 @@ class UpsampleGradFunctor { std::shared_ptr op_; }; -class CopyFunctor { +class CopyToDeviceFunctor { public: - CopyFunctor() { op_ = CHECK_JUST(one::OpBuilder("copy").Input("in").Output("out").Build()); } - Maybe operator()(const std::shared_ptr& x, const std::string& device_type, - const 
int64_t& device_id, const bool pin_memory) const { - auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("device_type", "device_id", "pin_memory"); - attrs.SetAllAttrs(device_type, device_id, pin_memory); + CopyToDeviceFunctor() { + op_ = CHECK_JUST(one::OpBuilder("copy").Input("in").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, Symbol device, + const bool pin_memory) const { + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("device", "pin_memory"); + attrs.SetAllAttrs(device, pin_memory); #ifdef WITH_CUDA - if (device_type == "cuda") { InitCudaContextOnce(device_id); } + if (device->enum_type() == DeviceType::kCUDA) { InitCudaContextOnce(device->device_id()); } #endif return OpInterpUtil::Dispatch(*op_, {x}, attrs); } @@ -1738,6 +1740,14 @@ class CopyFunctor { std::shared_ptr op_; }; +class CopyFunctor { + public: + Maybe operator()(const std::shared_ptr& x, const std::string& device_type, + const int64_t& device_id, const bool pin_memory) const { + return functional::Copy(x, JUST(Device::New(device_type, device_id)), pin_memory); + } +}; + class FlipFunctor { public: FlipFunctor() { op_ = CHECK_JUST(one::OpBuilder("flip").Input("x").Output("y").Build()); } @@ -2982,21 +2992,13 @@ class IndexSelectFunctor { }; namespace { -inline Maybe device_equal(const std::string& device_name, const int device_id, - Symbol device) { - return (device_name == device->type() && device_id == device->device_id()); -} -Maybe LocalTensorTo(const std::shared_ptr& x, const std::string& device_name, - const int device_id, const Symbol& dtype, const bool& copy) { +Maybe LocalTensorTo(const std::shared_ptr& x, Symbol device, + const Symbol& dtype, const bool& copy) { std::shared_ptr tensor = x; - if (!JUST(device_equal(device_name, device_id, JUST(x->device())))) { - tensor = JUST(Copy(tensor, device_name, device_id, /*pin_memory=*/false)); - } + if (device != JUST(x->device())) { tensor = JUST(Copy(tensor, device, /*pin_memory=*/false)); } if (dtype != x->dtype()) { tensor = JUST(Cast(tensor, dtype, /*pin_memory=*/false)); } - if (copy && tensor == x) { - tensor = JUST(Copy(tensor, device_name, device_id, /*pin_memory=*/false)); - } + if (copy && tensor == x) { tensor = JUST(Copy(tensor, device, /*pin_memory=*/false)); } return tensor; } @@ -3027,7 +3029,7 @@ Maybe GlobalTensorTo(const std::shared_ptr& x, const std::string for (int i = 0; i < sbp_tuple.size(); ++i) { sbp_tuple[i] = nd_sbp->sbp_parallel().Get(i); } tensor = JUST(GlobalToLocal(x, /*copy=*/false)); Symbol device = JUST(Device::New(device_type)); - tensor = JUST(LocalTensorTo(tensor, device->type(), device->device_id(), dtype, copy)); + tensor = JUST(LocalTensorTo(tensor, device, dtype, copy)); JUST(tensor->set_requires_grad(x->requires_grad())); return JUST(LocalToGlobal(tensor, placement, sbp_tuple, *(x->shape()), dtype, /* sync_data */ true, /*copy=*/false)); @@ -3051,17 +3053,13 @@ class ToFunctor { << "for global tensor, but got " << device_.value_or(""); return JUST(GlobalTensorTo(input, device_type, dtype, copy)); } else { - std::string device_name = ""; - int device_id = 0; - if (device_.has_value()) { - JUST(ParsingDeviceTag(device_.value_or(""), &device_name, &device_id)); - if (device_id == -1) { device_id = GlobalProcessCtx::LocalRank(); } - } else { - Symbol device = JUST(input->device()); - device_name = device->type(); - device_id = device->device_id(); - } - return JUST(LocalTensorTo(input, device_name, device_id, dtype, copy)); + Symbol device = + device_ + .map([](const std::shared_ptr& str) -> Symbol { + return 
CHECK_JUST(Device::ParseAndNew(*str)); + }) + .value_or(JUST(input->device())); + return JUST(LocalTensorTo(input, device, dtype, copy)); } } }; @@ -3087,9 +3085,8 @@ class To2Functor { } } else { auto dtype = dtype_.value_or(input->dtype()); - auto device = - device_.has_value() ? device_.value_or(Symbol()) : JUST(input->device()); - return JUST(LocalTensorTo(input, device->type(), device->device_id(), dtype, copy)); + auto device = device_.value_or(JUST(input->device())); + return JUST(LocalTensorTo(input, device, dtype, copy)); } } }; @@ -3103,7 +3100,7 @@ class To3Functor { return GlobalTensorTo(input, JUST(input->parallel_desc())->device_tag(), dtype, copy); } else { auto device = JUST(input->device()); - return LocalTensorTo(input, device->type(), device->device_id(), dtype, copy); + return LocalTensorTo(input, device, dtype, copy); } } }; @@ -3117,9 +3114,7 @@ class To4Functor { << "tensor.to(other) can only be called when tensor and other are local tensors"; Symbol dtype = other->dtype(); Symbol device = JUST(other->device()); - std::string device_name = device->type(); - int device_id = device->device_id(); - return LocalTensorTo(input, device_name, device_id, dtype, copy); + return LocalTensorTo(input, device, dtype, copy); } }; @@ -3138,17 +3133,13 @@ class ToDeviceFunctor { << "for global tensor, but got " << device_.value_or(""); return JUST(GlobalTensorTo(input, device_type, dtype, copy)); } else { - std::string device_name = ""; - int device_id = 0; - if (device_.has_value()) { - JUST(ParsingDeviceTag(device_.value_or(""), &device_name, &device_id)); - if (device_id == -1) { device_id = GlobalProcessCtx::LocalRank(); } - } else { - Symbol device = JUST(input->device()); - device_name = device->type(); - device_id = device->device_id(); - } - return JUST(LocalTensorTo(input, device_name, device_id, dtype, copy)); + Symbol device = + device_ + .map([](const std::shared_ptr& str) -> Symbol { + return CHECK_JUST(Device::ParseAndNew(*str)); + }) + .value_or(JUST(input->device())); + return JUST(LocalTensorTo(input, device, dtype, copy)); } } }; @@ -3987,7 +3978,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Slice"); m.add_functor("SliceGrad"); m.add_functor("SliceView1dContiguous"); - m.add_functor("Copy"); + m.add_functor("Copy"); m.add_functor("Flip"); m.add_functor("UnfoldTensor"); m.add_functor("UnfoldTensorGrad"); diff --git a/oneflow/core/job/parallel_desc.cpp b/oneflow/core/job/parallel_desc.cpp index 574bdafab4f..1f0792917fd 100644 --- a/oneflow/core/job/parallel_desc.cpp +++ b/oneflow/core/job/parallel_desc.cpp @@ -159,7 +159,7 @@ Maybe> ParallelDesc::GetTensorDevice4CurrentProcessCtx( int64_t machine_id = 0; int64_t device_id = 0; GlobalProcessCtx::GetCurrentMachineIdAndDeviceId(&machine_id, &device_id); - const auto& device = JUST(Device::ThreadLocalGetOrNew(device_tag(), device_id)); + const auto& device = JUST(Device::New(device_tag(), device_id)); int64_t parallel_id_val = -1; if (TryGetParallelId(machine_id, device_id, ¶llel_id_val)) { *parallel_id = parallel_id_val; @@ -474,7 +474,7 @@ Maybe> RawGetTensorDevice(Symbol parallel_desc) { int64_t device_id = 0; GlobalProcessCtx::GetCurrentMachineIdAndDeviceId(&machine_id, &device_id); const auto& type = parallel_desc->device_tag(); - return JUST(Device::ThreadLocalGetOrNew(type, device_id)); + return JUST(Device::New(type, device_id)); } Maybe> RawTxtStringToPlacement(const std::string& parallel_conf_str) { diff --git a/oneflow/user/ops/copy_op.cpp b/oneflow/user/ops/copy_op.cpp index 434ddf4b63b..dea6063d5d0 
100644 --- a/oneflow/user/ops/copy_op.cpp +++ b/oneflow/user/ops/copy_op.cpp @@ -80,8 +80,7 @@ Maybe> MakeCopyStream(const Symbol& in_device, /* static */ Maybe> CopyOp::InferDeviceAndStream( user_op::DeviceAndStreamInferContext* ctx) { - Symbol out_device = - JUST(Device::New(ctx->Attr("device_type"), ctx->Attr("device_id"))); + Symbol out_device = ctx->Attr>("device"); *ctx->OutputTensorDevice4ArgNameAndIndex("out", 0) = out_device; const Symbol& in_device = ctx->InputTensorDevice4ArgNameAndIndex("in", 0); const bool pin_memory = ctx->Attr("pin_memory");
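
Taken together, the hunks above move the copy path from a (device_type, device_id) string/int pair to a single Symbol<Device> attribute: functional::Copy gains a "(Tensor x, Device device, Bool pin_memory)" overload, the copy user op reads the new attribute via ctx->Attr<Symbol<Device>>("device"), and Device::ParseAndNew replaces the string-parsing call sites of ParsingDeviceTag. The snippet below is a minimal caller-side sketch of the new overload; it is not part of the patch, and the helper name CopyToDevice and the exact include set are assumptions.

// Illustrative sketch only: uses the APIs introduced in this diff
// (Device::ParseAndNew and the Symbol<Device>-based Copy overload).
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {

// Hypothetical helper: copy a local tensor onto the device described by a
// "type[:index]" string such as "cpu" or "cuda:1".
Maybe<one::Tensor> CopyToDevice(const std::shared_ptr<one::Tensor>& x,
                                const std::string& device_str) {
  // ParseAndNew parses the string and caches the resulting Symbol<Device>
  // thread-locally (see the device.cpp hunk).
  const auto& device = JUST(Device::ParseAndNew(device_str));
  // New overload: pass the device symbol instead of (device_type, device_id).
  return one::functional::Copy(x, device, /*pin_memory=*/false);
}

}  // namespace oneflow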