Fix and refine graph save runtime_state_dict #10016
Conversation
strint commented Mar 21, 2023 (edited)
- Fix: when the cache exceeded its size limit, the base graph could be evicted, causing load errors
- Reduce the size of runtime_state_dict: https://github.com/Oneflow-Inc/OneTeam/issues/1963#issuecomment-1477776137
```python
    return new_fn


def _test_linear_multi_graph_share(test_case, device, with_reshape):
```
Outdated.
```diff
@@ -154,7 +154,7 @@ def forward(self, x):
     linear_reshape = LinearReshapeModule()

-    @flow.nn.Graph.with_dynamic_input_shape()
+    @flow.nn.Graph.with_dynamic_input_shape(size=3)
     class LinearGraph(flow.nn.Graph):
```
Verifies that when the number of cached graphs exceeds `size` and eviction occurs, save and load still work correctly.
```diff
@@ -226,7 +251,7 @@ def forward(self, x):
     linear_reshape = LinearReshapeModule()

-    @flow.nn.Graph.with_dynamic_input_shape(size=4)
+    @flow.nn.Graph.with_dynamic_input_shape(size=2)
     class LinearGraph(flow.nn.Graph):
```
Verifies that eviction also works correctly during load.
```python
assert not self.is_full()

if not (self.keep_the_1st and self.is_empty()):
    self.queue.appendleft(key)
```
Prevents the first (base) graph from being evicted.
How does this prevent the first (base) graph from being evicted?
> How does this prevent the first (base) graph from being evicted?

`self.queue` holds the cache keys that are eligible for eviction, following an LRU policy: the top is the oldest key, the end is the newest, and the order is adjusted on each use to maintain the key ordering that LRU needs.
The way this prevents the base graph from being evicted: the base graph is always the first entry added to the cache, so by never adding its key to `self.queue` it can never be evicted.
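To make this concrete, here is a minimal runnable sketch of the idea. The names (`LRUKeyQueue`, `capacity`, `put`/`get`) are hypothetical, not the PR's actual class; only the `keep_the_1st` / `appendleft` bookkeeping mirrors the diff above:

```python
from collections import deque

class LRUKeyQueue:
    """Sketch of an LRU cache whose first-inserted (base) key is pinned."""

    def __init__(self, capacity, keep_the_1st=True):
        self.capacity = capacity
        self.keep_the_1st = keep_the_1st
        self.queue = deque()  # evictable keys: newest at left, oldest at right
        self.cache = {}

    def put(self, key, value):
        # Evict the least recently used *evictable* key when full; the
        # base key is never in self.queue, so it can never be chosen.
        if len(self.cache) >= self.capacity and self.queue:
            del self.cache[self.queue.pop()]
        # The very first insertion skips the queue, pinning the base key.
        if not (self.keep_the_1st and not self.cache):
            self.queue.appendleft(key)
        self.cache[key] = value

    def get(self, key):
        # A hit moves the key to the most-recently-used (left) end.
        if key in self.queue:
            self.queue.remove(key)
            self.queue.appendleft(key)
        return self.cache[key]
```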
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10016/
sd unet graph save/load test: https://github.com/Oneflow-Inc/diffusers/blob/main/examples/unet_torch_interplay.py

## with_eager True

```
saving graphs...
get state dict time: 0.33900022506713867
state_dict(with_eager=True) tensors size 2768.363235473633 MB.
save state dict time: 13.160083293914795
```

```
loading graphs...
load state dict time: 2.118417739868164
state_dict tensors size 2768.363235473633 MB.
load into graph time: 3.463979482650757
```

## with_eager False

Set with_eager to False:

```
runtime_state_dict(with_eager=False)
```

```
saving graphs...
get state dict time: 0.3210592269897461
state_dict(with_eager=False) tensors size 1128.9570999145508 MB.
save state dict time: 5.13447380065918
```

```
loading graphs...
load state dict time: 0.9304521083831787
state_dict tensors size 1128.9570999145508 MB.
load into graph time: 3.212466239929199
```

**with_eager=False cuts the state_dict size and time cost by 56%.**
```python
# This state tensor is from eager module. Just save a dummy tensor here.
_state_tensor_tuple4save.append(
    oneflow.Tensor().to(
        self._state_tensor_tuple[state_idx].device
    )
)
```
If the tensors on the eager module do not need to be saved (normally they are saved and loaded through the eager module itself), we can store just an empty tensor in their place, reducing the size of runtime_state_dict.
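The idea in isolation, as a hedged sketch: the helper name `slim_state_tensors` and its arguments are hypothetical; the real logic lives inside `runtime_state_dict`:

```python
import oneflow

def slim_state_tensors(state_op_names, state_tensor_tuple,
                       eager_state_op_names, with_eager=False):
    # Replace tensors that the eager module already persists with empty
    # placeholders on the same device, shrinking the saved state dict.
    slimmed = []
    for name, tensor in zip(state_op_names, state_tensor_tuple):
        if not with_eager and name in eager_state_op_names:
            slimmed.append(oneflow.Tensor().to(tensor.device))
        else:
            slimmed.append(tensor)
    return tuple(slimmed)
```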
```python
test_case1 = np.array_equal(of_lazy_out1.numpy(), of_eager_out1.numpy())
return_dict["save4"] = test_case1

state_dict = linear_g.runtime_state_dict(with_eager=with_eager)
```
Tests not saving eager module tensors.
```diff
@@ -1011,13 +1012,15 @@ def enable_save_runtime_state_dict(self, mode: bool = True):
         self._enable_save_runtime_state_dict = False

     def runtime_state_dict(
-        self, destination=None
+        self, destination=None, with_eager=False
```
By default, tensors on the eager module are not saved.
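A hedged usage sketch of the new default. The module and graph definitions are illustrative, modeled on this PR's tests; only `runtime_state_dict(..., with_eager=...)`, `enable_save_runtime_state_dict`, and `with_dynamic_input_shape(size=...)` come from the diffs above:

```python
import oneflow as flow

class LinearModule(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 8)

    def forward(self, x):
        return self.linear(x)

@flow.nn.Graph.with_dynamic_input_shape(size=2)
class LinearGraph(flow.nn.Graph):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def build(self, x):
        return self.model(x)

g = LinearGraph(LinearModule())
g.enable_save_runtime_state_dict()  # turn on runtime state capture
g(flow.randn(4, 3))                 # run once so runtime state exists

slim_sd = g.runtime_state_dict()                 # with_eager=False by default
full_sd = g.runtime_state_dict(with_eager=True)  # also keep eager module tensors
```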
```python
else:
    assert len(self._state_tensor_tuple) == len(self._state_op_names)
    for state_idx in range(len(self._state_tensor_tuple)):
        if self._state_op_names[state_idx] in self._eager_state_op_names:
```
In what cases is there a tensor held by the eager module but not by the graph state? Is it a tensor that was constant-folded?
Is an eager free tensor a graph state tensor?
> In what cases is there a tensor held by the eager module but not by the graph state? Is it a tensor that was constant-folded?

Yes.

> Is an eager free tensor a graph state tensor?

Yes, it is too, but it is not on the eager module.
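An illustrative sketch of the distinction (assumed example, not from the PR): the `Linear` parameters live on the eager module, while `free_scale` is an eager free tensor that gets captured into graph state without belonging to any module:

```python
import oneflow as flow

# An eager "free" tensor: created outside any nn.Module, captured by
# build() below, so it becomes graph state but not module state.
free_scale = flow.tensor([0.5])

class ScaledLinear(flow.nn.Module):
    def __init__(self):
        super().__init__()
        # These parameters are held by the eager module and are normally
        # saved/loaded through the eager module itself.
        self.linear = flow.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x) * free_scale

class ScaledGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.m = ScaledLinear()

    def build(self, x):
        return self.m(x)
```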
```python
while self.is_queue_full():
    pop_key = self.pop()
if self.is_full():
    old_key = self.pop()
```
Would this pop evict the key of the oldest base graph?
> Would this pop evict the key of the oldest base graph?

Yes.
CI error: FAILED python/oneflow/test/modules/test_one_embedding_adam.py::TestOptimizers::test_one_embedding_adam
https://github.com/Oneflow-Inc/oneflow/actions/runs/4485531702/jobs/7888144384?pr=10016
This is an unrelated CI error.
CI failed when running job: cuda-speed-test. PR label automerge has been removed |