refine device-related code (#9791)
The DTR design intends to use different devices to distinguish tensors with recomputation enabled from tensors without it (the same approach torch/xla takes). While implementing this, we found several places where the device-related code could be improved:

1. A device cannot be used as an op attr; on master this is worked around by setting two separate attrs, device_type and device_id, which produces a lot of needless code (see the first sketch after this list):
    ```c++
    Device a;
    op.SetAttr(a.device_type(), a.device_id());
    Device b = Device::New(op.attr("device_type"), op.attr("device_id"));
    // Device b = a; would clearly be simpler
    ```

    ```c++
    inline Maybe<bool> device_equal(const std::string& device_name, const int device_id,
                                    Symbol<Device> device) {
      return (device_name == device->type() && device_id == device->device_id());
    }
    Device a;
    op.SetAttr(a.device_type(), a.device_id());
    if (device_equal(op.attr("device_type"), op.attr("device_id"), b))
    // if (a == b) would clearly be simpler
    ```
    This redundant code also causes extra churn whenever a new field is added to the Device class.
2. Some places misuse Optional::value_or (see the second sketch after this list), e.g.
    ```c++
    auto device =
        device_.has_value() ? device_.value_or(Symbol<Device>()) : JUST(input->device());
    ```
3. Some naming issues: `ParsingDeviceTag` does not start with a verb (renamed to `ParseDeviceString`), and `Device::ThreadLocalGetOrNew` does the same thing as `Device::New`, so the two meanings of "New" conflicted (removed `Device::ThreadLocalGetOrNew`).
4. operator== and operator!= duplicated the same logic.
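For illustration, here is a small, self-contained C++ sketch of point 1. The `Device`, `MockOp`, and attr map below are mock types invented for this example (they are not OneFlow's real attr API); the point is only how a single device-valued attr removes the rebuild-from-two-attrs step.

```c++
#include <iostream>
#include <map>
#include <string>
#include <variant>

// Mock device with value semantics and comparison, like the real Device class.
struct Device {
  std::string type;
  int id = 0;
  bool operator==(const Device& other) const { return type == other.type && id == other.id; }
};

// Mock op whose attrs can hold strings, ints, or (after the change) a Device directly.
struct MockOp {
  std::map<std::string, std::variant<std::string, int, Device>> attrs;
};

int main() {
  Device a{"cuda", 0};
  MockOp op;

  // Before: two attrs, then rebuild the Device on the reader side.
  op.attrs["device_type"] = a.type;
  op.attrs["device_id"] = a.id;
  Device rebuilt{std::get<std::string>(op.attrs["device_type"]),
                 std::get<int>(op.attrs["device_id"])};

  // After: one attr, read back directly.
  op.attrs["device"] = a;
  Device b = std::get<Device>(op.attrs["device"]);

  std::cout << (rebuilt == b) << "\n";  // 1: both paths end at the same device
}
```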

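And a self-contained sketch of point 2, using std::optional as a stand-in for oneflow::Optional: once has_value() has been checked, the default passed to value_or can never be used, so it is dead code that only obscures intent.

```c++
#include <iostream>
#include <optional>

int FallbackDevice() { return 42; }  // stands in for JUST(input->device())

int main() {
  std::optional<int> device_ = 7;

  // The pattern the commit message calls out: value_or's default (0) is unreachable
  // because the branch is only taken when a value is present.
  int confusing = device_.has_value() ? device_.value_or(0) : FallbackDevice();

  // Equivalent but direct: read the value when present, otherwise compute the fallback.
  int clear = device_.has_value() ? *device_ : FallbackDevice();

  std::cout << confusing << " " << clear << "\n";  // prints "7 7"
}
```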
---------

Signed-off-by: daquexian <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <[email protected]>
3 people authored Feb 1, 2023
1 parent aef9981 commit 4bfef84
Showing 14 changed files with 132 additions and 120 deletions.
8 changes: 4 additions & 4 deletions oneflow/api/python/framework/random_generator.cpp
@@ -23,10 +23,10 @@ namespace py = pybind11;

namespace oneflow {

Maybe<one::Generator> CreateGenerator(const std::string& device_tag) {
Maybe<one::Generator> CreateGenerator(const std::string& device_str) {
std::string device_name = "";
int device_index = -1;
JUST(ParsingDeviceTag(device_tag, &device_name, &device_index));
JUST(ParseDeviceString(device_str, &device_name, &device_index));
return one::MakeGenerator(device_name, device_index);
}

@@ -58,10 +58,10 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
return one::ManualSeed(seed_val, device, device_index);
});
m.def("create_generator", &CreateGenerator);
m.def("default_generator", [](const std::string& device_tag) -> Maybe<one::Generator> {
m.def("default_generator", [](const std::string& device_str) -> Maybe<one::Generator> {
std::string device_name = "";
int device_index = -1;
JUST(ParsingDeviceTag(device_tag, &device_name, &device_index));
JUST(ParseDeviceString(device_str, &device_name, &device_index));
return one::DefaultGenerator(device_name, device_index);
});
m.def("ManualSeedAllCudaGenerator", [](const py::object& seed) -> Maybe<void> {
9 changes: 9 additions & 0 deletions oneflow/core/common/device.proto
@@ -0,0 +1,9 @@
syntax = "proto2";
package oneflow;

import "oneflow/core/common/device_type.proto";

message DeviceProto {
required DeviceType device_type = 1;
required int64 device_id = 2;
}
6 changes: 5 additions & 1 deletion oneflow/core/framework/attr_value.h
@@ -17,6 +17,7 @@ limitations under the License.
#define ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_H_

#include "fmt/core.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/user_op_attr.pb.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/shape.h"
@@ -61,14 +62,17 @@ namespace user_op {
#define LIST_STRING_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_list_string, std::vector<std::string>, AttrType::kAtListString)

#define DEVICE_ATTR_SEQ OF_PP_MAKE_TUPLE_SEQ(at_device, Symbol<Device>, AttrType::kAtDevice)

#define ATTR_SEQ \
BASIC_ATTR_SEQ \
ENUM_ATTR_SEQ \
MESSAGE_ATTR_SEQ \
LIST_BASIC_ATTR_SEQ \
LIST_ENUM_ATTR_SEQ \
LIST_MESSAGE_ATTR_SEQ \
LIST_STRING_ATTR_SEQ
LIST_STRING_ATTR_SEQ \
DEVICE_ATTR_SEQ

// Type Trait: GetAttrType, GetCppType

15 changes: 15 additions & 0 deletions oneflow/core/framework/attr_value_accessor.cpp
@@ -19,6 +19,8 @@ limitations under the License.
#include "oneflow/core/common/stride.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/framework/attr_value.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/framework/user_op_conf.h"

namespace oneflow {
@@ -67,6 +69,19 @@ void AttrValueAccessor<Stride>::Attr(const Stride& cpp_val, AttrValue* attr_val)
cpp_val.ToProto(attr_val->mutable_at_stride());
}

template<>
Symbol<Device> AttrValueAccessor<Symbol<Device>>::Attr(const AttrValue& val) {
auto pb_device = val.at_device();
return CHECK_JUST(Device::New(*CHECK_JUST(DeviceTag4DeviceType(pb_device.device_type())),
pb_device.device_id()));
}

template<>
void AttrValueAccessor<Symbol<Device>>::Attr(const Symbol<Device>& cpp_val, AttrValue* attr_val) {
attr_val->mutable_at_device()->set_device_type(cpp_val->enum_type());
attr_val->mutable_at_device()->set_device_id(cpp_val->device_id());
}

// List of Basic Attr
#define LIST_BASIC_ATTR_SEQ_ENTRY(field, cpp_type, attr_type) \
template<> \
52 changes: 21 additions & 31 deletions oneflow/core/framework/device.cpp
@@ -30,10 +30,6 @@ namespace oneflow {

namespace {

inline size_t HashDevice(const std::string& type, int64_t device_id) {
return Hash(type, device_id);
}

void CheckDeviceType(const std::string& type) {
if (!TRY(DeviceType4DeviceTag(type)).IsOk()) {
std::string error_msg = "Expected one of " + PrintAvailableDevices()
@@ -48,7 +44,7 @@ Device::Device(const std::string& type, int64_t device_id)
: type_(type),
enum_type_(kInvalidDevice),
device_id_(device_id),
hash_value_(HashDevice(type, device_id)) {}
hash_value_(Hash(type, device_id)) {}

Maybe<void> Device::Init() {
if (type_ == "auto") { return Maybe<void>::Ok(); }
@@ -62,50 +58,39 @@ Maybe<void> Device::Init() {
}

/* static */ Maybe<Symbol<Device>> Device::New(const std::string& type, int64_t device_id) {
return ThreadLocalGetOrNew(type, device_id);
}

/* static */ Maybe<Symbol<Device>> Device::ThreadLocalGetOrNew(const std::string& type,
int64_t device_id) {
CHECK_GE_OR_RETURN(device_id, 0)
<< Error::InvalidValueError() << "Device ID should be non-negative";
static thread_local HashMap<std::string, HashMap<int64_t, Symbol<Device>>> map;
auto* device_id2symbol = &map[type];
auto iter = device_id2symbol->find(device_id);
if (iter == device_id2symbol->end()) {
static thread_local HashMap<std::tuple<std::string, int>, Symbol<Device>> map;
auto key = std::make_tuple(type, device_id);
auto iter = map.find(key);
if (iter == map.end()) {
Device device(type, device_id);
JUST(device.Init());
iter = device_id2symbol->emplace(device_id, SymbolOf(device)).first;
iter = map.emplace(key, SymbolOf(device)).first;
}
return iter->second;
}

/* static */ Maybe<Symbol<Device>> Device::ThreadLocalGetOrNew(
const std::string& type_or_type_with_device_id) {
/* static */ Maybe<Symbol<Device>> Device::New(const std::string& type) {
return New(type, GlobalProcessCtx::LocalRank());
}

/* static */ Maybe<Symbol<Device>> Device::ParseAndNew(const std::string& device_str) {
static thread_local HashMap<std::string, Symbol<Device>> map;
auto iter = map.find(type_or_type_with_device_id);
auto iter = map.find(device_str);
if (iter == map.end()) {
std::string type;
int device_id = -1;
JUST(ParsingDeviceTag(type_or_type_with_device_id, &type, &device_id));
JUST(ParseDeviceString(device_str, &type, &device_id));
CheckDeviceType(type);
if (device_id == -1) { device_id = GlobalProcessCtx::LocalRank(); }
Device device(type, device_id);
JUST(device.Init());
iter = map.emplace(type_or_type_with_device_id, SymbolOf(device)).first;
iter = map.emplace(device_str, SymbolOf(device)).first;
}
return iter->second;
}

/* static */ Maybe<Symbol<Device>> Device::New(const std::string& type) {
return New(type, GlobalProcessCtx::LocalRank());
}

/* static */ Maybe<Symbol<Device>> Device::ParseAndNew(
const std::string& type_or_type_with_device_id) {
return ThreadLocalGetOrNew(type_or_type_with_device_id);
}

std::string Device::ToRepr() const {
std::stringstream ss;
ss << "device(type='";
@@ -116,6 +101,11 @@ std::string Device::ToRepr() const {
return ss.str();
}

std::ostream& operator<<(std::ostream& os, Symbol<Device> device) {
os << device->ToRepr();
return os;
}

std::string Device::ToString() const {
std::stringstream ss;
ss << type_;
@@ -165,8 +155,8 @@ decltype(Device::GetPlacement) Device::GetPlacement =
DECORATE(&RawGetPlacement, ThreadLocalCopiable);
decltype(Placement4Device) Placement4Device = DECORATE(&RawPlacement4Device, ThreadLocal);

Maybe<void> ParsingDeviceTag(const std::string& device_tag, std::string* device_name,
int* device_index) {
Maybe<void> ParseDeviceString(const std::string& device_tag, std::string* device_name,
int* device_index) {
std::string::size_type pos = device_tag.find(':');
if (pos == std::string::npos) {
*device_name = device_tag;
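As background on the device.cpp hunks above: with ThreadLocalGetOrNew removed, Device::New itself memoizes constructed devices in a single thread-local map keyed by a (type, device_id) tuple instead of a nested map of maps. Below is a minimal standalone sketch of that pattern with mock types (not OneFlow code); note that std::unordered_map needs an explicit tuple hash, whereas OneFlow's HashMap provides one.

```c++
#include <iostream>
#include <string>
#include <tuple>
#include <unordered_map>

// Hash for the (type, device_id) key; only needed because std::unordered_map
// has no built-in hash for std::tuple.
struct TupleHash {
  size_t operator()(const std::tuple<std::string, int>& key) const {
    return std::hash<std::string>()(std::get<0>(key)) ^ (std::hash<int>()(std::get<1>(key)) << 1);
  }
};

// Construct a value at most once per thread for each (type, device_id) pair.
const std::string& GetOrCreate(const std::string& type, int device_id) {
  static thread_local std::unordered_map<std::tuple<std::string, int>, std::string, TupleHash> cache;
  auto key = std::make_tuple(type, device_id);
  auto it = cache.find(key);
  if (it == cache.end()) {
    it = cache.emplace(key, type + ":" + std::to_string(device_id)).first;  // constructed once
  }
  return it->second;
}

int main() {
  std::cout << GetOrCreate("cuda", 0) << "\n";  // constructed
  std::cout << GetOrCreate("cuda", 0) << "\n";  // served from the thread-local cache
}
```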
17 changes: 10 additions & 7 deletions oneflow/core/framework/device.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef ONEFLOW_CORE_FRAMEWORK_DEVICE_H_
#define ONEFLOW_CORE_FRAMEWORK_DEVICE_H_

#include <fmt/core.h>
#include <fmt/ostream.h>
#include <memory>
#include <string>
#include <unordered_set>
@@ -47,13 +49,9 @@ class Device final {
bool operator==(const Device& device) const {
return type_ == device.type() && device_id_ == device.device_id();
}
bool operator!=(const Device& device) const {
return !(type_ == device.type() && device_id_ == device.device_id());
}
bool operator!=(const Device& device) const { return !operator==(device); }
const std::shared_ptr<MemoryCase>& mem_case() const { return mem_case_; }

static Maybe<Symbol<Device>> ThreadLocalGetOrNew(const std::string& type, int64_t device_id);
static Maybe<Symbol<Device>> ThreadLocalGetOrNew(const std::string& type_or_type_with_device_id);
static Maybe<Symbol<Device>> New(const std::string& type, int64_t device_id);
static Maybe<Symbol<Device>> New(const std::string& type);
static Maybe<Symbol<Device>> ParseAndNew(const std::string& type_or_type_with_device_id);
@@ -73,13 +71,18 @@ class Device final {
std::shared_ptr<MemoryCase> mem_case_;
};

std::ostream& operator<<(std::ostream& os, Symbol<Device> device);

extern Maybe<Symbol<ParallelDesc>> (*Placement4Device)(Symbol<Device> device);

Maybe<void> ParsingDeviceTag(const std::string& device_tag, std::string* device_name,
int* device_index);
Maybe<void> ParseDeviceString(const std::string& device_tag, std::string* device_name,
int* device_index);

} // namespace oneflow

template<>
struct fmt::formatter<oneflow::Symbol<oneflow::Device>> : ostream_formatter {};

namespace std {
template<>
struct hash<oneflow::Device> final {
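The device.h hunk above also registers a fmt formatter for Symbol<Device> that reuses its stream operator. Here is a standalone sketch of that fmt pattern, assuming fmt >= 9 (where ostream_formatter lives in fmt/ostream.h) and using a hypothetical FakeDevice type rather than OneFlow's classes:

```c++
#include <fmt/core.h>
#include <fmt/ostream.h>
#include <iostream>
#include <string>

// Hypothetical stand-in for Symbol<Device>: any type with an operator<< can be
// made fmt-formattable by inheriting from fmt::ostream_formatter.
struct FakeDevice {
  std::string type;
  int id = 0;
};

std::ostream& operator<<(std::ostream& os, const FakeDevice& d) {
  return os << "device(type='" << d.type << "', index=" << d.id << ")";
}

template<>
struct fmt::formatter<FakeDevice> : fmt::ostream_formatter {};

int main() {
  fmt::print("{}\n", FakeDevice{"cuda", 0});  // prints: device(type='cuda', index=0)
}
```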
15 changes: 7 additions & 8 deletions oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp
@@ -715,20 +715,19 @@ Maybe<void> LazyInterpreterApplyImplForCopyUserOpExpr(const UserOpExpr& op_expr,
input_lbn = TensorNameScope::Global()->Lookup(input_tensor);
}
CHECK_OR_RETURN(!input_lbn.empty()); // NOLINT(maybe-need-error-msg)
std::string device_type = JUST(ctx.attrs.GetAttr<std::string>("device_type"));
int64_t device_id = JUST(ctx.attrs.GetAttr<int64_t>("device_id"));
auto device = JUST(ctx.attrs.GetAttr<Symbol<Device>>("device"));

CHECK_EQ_OR_RETURN(outputs->size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(op_expr.output_size(), 1); // NOLINT(maybe-need-error-msg)
if (input_tensor->is_local()) {
(*outputs)[0] = JUST(LocalTensor::MakeTensor(
input_tensor->shape(), JUST(input_tensor->stride()), input_tensor->dtype()->data_type(),
JUST(Device::New(device_type, device_id)),
/* is_lazy= */ true,
/*requires_grad=*/false, /*is_leaf=*/true));
(*outputs)[0] =
JUST(LocalTensor::MakeTensor(input_tensor->shape(), JUST(input_tensor->stride()),
input_tensor->dtype()->data_type(), device,
/* is_lazy= */ true,
/*requires_grad=*/false, /*is_leaf=*/true));
} else {
ParallelConf parallel_conf = JUST(input_tensor->parallel_desc())->parallel_conf();
parallel_conf.set_device_tag(device_type);
parallel_conf.set_device_tag(device->type());
ParallelDesc parallel_desc(parallel_conf);
(*outputs)[0] =
JUST(GlobalTensor::MakeTensor(input_tensor->shape(), input_tensor->dtype()->data_type(),
20 changes: 9 additions & 11 deletions oneflow/core/framework/random_generator_impl.cpp
@@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/framework/random_generator_impl.h"

#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/cpp_attribute.h"
#include "oneflow/core/common/str_util.h"
@@ -76,7 +77,7 @@ Maybe<Tensor> CPUGeneratorImpl::GetState() const {
<< splits.size() - 1;
}
for (int i = 0; i < CPUGeneratorState::state_size; ++i) {
state.states[i] = std::atoll(splits.at(i).data());
state.states[i] = std::atoll(JUST(VectorAt(splits, i)).data());
}
state.seed = current_seed();

@@ -144,7 +145,7 @@ int GetThreadNum(const cudaDeviceProp& prop) {

CUDAGeneratorImpl::CUDAGeneratorImpl(uint64_t seed, int device_index)
: DeviceGeneratorImpl(seed, DeviceType::kCUDA, device_index), philox_offset_per_thread_(0) {
cudaDeviceProp prop;
cudaDeviceProp prop; // NOLINT
OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_index));
max_block_num_ = prop.multiProcessorCount;
max_thread_num_ = GetThreadNum(prop);
@@ -313,14 +314,14 @@ Maybe<Tensor> AutoGeneratorImpl::GetState() const {
memcpy(data, device_tags.data(), state.device_tag_length);
data += state.device_tag_length;
for (int i = 0; i < tensor_states.size(); ++i) {
const auto& tensor = tensor_states.at(i);
const auto& tensor = tensor_states[i];
const auto& callback = [&data, &state_sizes, i](
ep::Stream*,
const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {
memcpy(data, eager_blob_object->dptr(), state_sizes.at(i));
};
JUST(SyncAccessTensorWithTimeOut(tensor, callback, "const"));
data += state_sizes.at(i);
data += JUST(VectorAt(state_sizes, i));
}
}
const auto& device = JUST(Device::New("cpu"));
@@ -370,7 +371,7 @@ Maybe<void> AutoGeneratorImpl::SetState(const std::shared_ptr<Tensor>& tensor_st
data += state.device_tag_length;
std::vector<std::shared_ptr<Tensor>> tensor_states(state.num);
for (int i = 0; i < state.num; ++i) {
int64_t state_size = state_sizes.at(i);
int64_t state_size = JUST(VectorAt(state_sizes, i));
tensor_states[i] =
JUST(functional::Empty(Shape{state_size}, DType::UInt8(), device, /*pin_memory=*/false));
const auto& callback = [&data, &state_size](
@@ -395,21 +396,18 @@ Maybe<void> AutoGeneratorImpl::SetState(const std::shared_ptr<Tensor>& tensor_st
std::lock_guard<std::mutex> lock(mutex_);

for (int i = 0; i < splits.size(); ++i) {
std::string device_name;
int device_index = -1;
JUST(ParsingDeviceTag(splits.at(i), &device_name, &device_index));
const auto& device = JUST(Device::ParseAndNew(splits[i]));
detail::DeviceKey device_key;
const auto& device = JUST(Device::New(device_name, device_index));
device_key.device_type = JUST(DeviceType4DeviceTag(device->type()));
device_key.device_index = device_index;
device_key.device_index = device->device_id();
auto it = generators_.find(device_key);
if (it == generators_.end()) {
it = generators_
.emplace(device_key, JUST(detail::MakeGeneratorImpl(seed_, device_key.device_type,
device_key.device_index)))
.first;
}
JUST(it->second->SetState(tensor_states.at(i)));
JUST(it->second->SetState(JUST(VectorAt(tensor_states, i))));
}
return Maybe<void>::Ok();
}
4 changes: 1 addition & 3 deletions oneflow/core/framework/tensor.cpp
@@ -99,11 +99,9 @@ std::shared_ptr<Tensor> LocalTensor::pin_memory() const {
}

Maybe<Tensor> LocalTensor::clone() const {
const auto& device_type = JUST(this->device())->type();
int64_t device_id = JUST(this->device())->device_id();
std::shared_ptr<Tensor> input = std::const_pointer_cast<Tensor>(shared_from_this());
const bool pin_memory = JUST(JUST(input->AsLocalTensor())->is_pinned());
return JUST(functional::Copy(input, device_type, device_id, /*pin_memory=*/pin_memory));
return JUST(functional::Copy(input, JUST(this->device()), /*pin_memory=*/pin_memory));
}

Maybe<void> LocalTensor::set_data(const std::shared_ptr<Tensor>& other) {
3 changes: 3 additions & 0 deletions oneflow/core/framework/user_op_attr.proto
@@ -4,6 +4,7 @@ package oneflow;
import "oneflow/core/common/shape.proto";
import "oneflow/core/common/sequential.proto";
import "oneflow/core/common/data_type.proto";
import "oneflow/core/common/device.proto";

enum AttrType {
kAtInt32 = 1;
@@ -22,6 +23,7 @@ enum AttrType {
kAtListString = 14;
kAtStride = 15;
kAtListStride = 16;
kAtDevice = 17;
}

message AttrValue {
@@ -64,6 +66,7 @@ message AttrValue {
ListString at_list_string = 14;
Int64ListProto at_stride = 15;
ListStride at_list_stride = 16;
DeviceProto at_device = 17;
}
}

5 changes: 4 additions & 1 deletion oneflow/core/functional/functional_api.yaml
@@ -1603,7 +1603,10 @@
bind_python: False

- name: "copy"
signature: "Tensor (Tensor x, String device_type, Int64 device_id, Bool pin_memory=False) => Copy"
signature: [
"Tensor (Tensor x, String device_type, Int64 device_id, Bool pin_memory=False) => Copy",
"Tensor (Tensor x, Device device, Bool pin_memory=False) => Copy"
]
bind_python: True

- name: "to"