-
Notifications
You must be signed in to change notification settings - Fork 802
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refine device-related code #9791
Conversation
Signed-off-by: daquexian <[email protected]>
…resolve conflict Signed-off-by: daquexian <[email protected]>
oneflow/core/framework/device.cpp
Outdated
@@ -62,37 +58,36 @@ Maybe<void> Device::Init() { | |||
} | |||
|
|||
/* static */ Maybe<Symbol<Device>> Device::New(const std::string& type, int64_t device_id) { | |||
return ThreadLocalGetOrNew(type, device_id); | |||
return ThreadLocalGetOrCreate(type, device_id); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Device::New最好改成直接构造一个新的device,而不是调用ThreadLocalGetOrCreate
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是什么原因呢,如果这样改的话有一个问题是和 Symbol<Device>
不兼容,symbol 是全局唯一的只读对象
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
也没有不兼容,构造device之后调用一下SymbolOf(device)就可以了。New不会缓存symbol,ThreadLocalGetOrCreate才缓存,这和这两个接口的命名也比较吻合,否则这两个接口就只需要保留一个就行了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
更新:讨论后决定把两个接口只保留一个
} | ||
|
||
template<> | ||
void AttrValueAccessor<Symbol<Device>>::Attr(const Symbol<Device>& cpp_val, AttrValue* attr_val) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个接口的第二个参数感觉可以换成引用,另外这个接口感觉改成SetAttr或者AsignAttr会更合适点(当然这个接口在这个PR里改动也是不必要的,以后再说吧
oneflow/core/framework/device.cpp
Outdated
} | ||
return iter->second; | ||
} | ||
|
||
/* static */ Maybe<Symbol<Device>> Device::ThreadLocalGetOrNew( | ||
const std::string& type_or_type_with_device_id) { | ||
/* static */ Maybe<Symbol<Device>> Device::ThreadLocalGetOrCreate(const std::string& str) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
str这个变量名可以换一个更有意义的名字
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,改成 device_str 了,这个和 PyTorch 的命名对齐 https://pytorch.org/cppdocs/api/structc10_1_1_device.html#_CPPv4N3c106Device6DeviceERKNSt6stringE ,相应的 ParseDeviceTag 方法也改成了 ParseDeviceString
Signed-off-by: daquexian <[email protected]>
Signed-off-by: daquexian <[email protected]>
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
Static analysis with clang failed. PR label automerge has been removed |
Signed-off-by: daquexian <[email protected]>
Speed stats:
|
CI failed when running job: cuda-speed-test. PR label automerge has been removed |
Signed-off-by: daquexian <[email protected]>
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
Static analysis with clang failed. PR label automerge has been removed |
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9791/ |
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9791/ |
核心的逻辑是: 1. 用不同的 device 区分支持/不支持重计算的 tensor 2. 在 remat::Allocator 里实现了选择 cost 最低的 tensor 并 evict 的逻辑(对内存布局和 evict 方式的优化就是在这里) 3. 在 OpCallInstructionUtil::Compute 里实现了重新计算出被用到但已被 evict 的 tensor 的逻辑 其他的都是一些周边改动 使用方式: ```python x1 = flow.ones(3).to('cuda+remat') # 移动到支持重计算的 device 上 x2 = flow.ones(3).to('cuda') # 移动到不支持重计算的 device 上 x3 = x2 + x3 # 报错:device 不同 # ----- model = ResNet50() model.to('cuda+remat') data, label = dataloader() data, label = data.to('cuda+remat'), label.to('cuda+remat') loss = model(data) # 如果过程中显存满了,会自动丢弃一些 tensor loss.backward() # 如果接下来又用到被丢弃的 tensor,会自动把它们重新计算出来 ``` ---- 一部分通用的改动已经在前置 PR 里被合并: * #9698 * #9791 * #9850 * #9851 --------- Signed-off-by: daquexian <[email protected]> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: oneflow-ci-bot <[email protected]> Co-authored-by: Peihong Liu <[email protected]>
DTR 的设计中打算用不同的 device 区分开启/不开启重计算的 tensor(和 torch/xla 做法相同),实现过程中发现 device 相关的代码有些可改进的地方
device 无法作为 op 的 attr,master 里用分别设置 device_type 和 device_id 两个 attr 来代替,因此产生了很多无中生有的代码:
这些冗余代码在给 Device 类增加新参数时也会引起额外的改动量
一些地方错误地使用了 Optional::value_or,如
一些命名问题,如
ParsingDeviceTag
没有用动词(改为ParseDeviceTag
)、Device::ThreadLocalGetOrNew
和Device::New
功能相同,"New" 的含义却互相冲突(删掉了Device::ThreadLocalGetOrNew
)operator== 和 operator!= 逻辑重复