
Fix and refine graph save runtime_state_dict #10016

Merged — 6 commits merged into master Mar 22, 2023

Conversation

@strint (Contributor) commented Mar 21, 2023

@strint requested review from BBuf and daquexian as code owners March 21, 2023 12:20
return new_fn


def _test_linear_multi_graph_share(test_case, device, with_reshape):
@strint (Contributor, Author): Outdated.

@@ -154,7 +154,7 @@ def forward(self, x):
linear_reshape = LinearReshapeModule()

class LinearGraph(flow.nn.Graph):
-@flow.nn.Graph.with_dynamic_input_shape()
+@flow.nn.Graph.with_dynamic_input_shape(size=3)
@strint (Contributor, Author): Verifies that when the number of cached graphs exceeds `size` and eviction occurs, save and load still work correctly.

@@ -226,7 +251,7 @@ def forward(self, x):
linear_reshape = LinearReshapeModule()

class LinearGraph(flow.nn.Graph):
-@flow.nn.Graph.with_dynamic_input_shape(size=4)
+@flow.nn.Graph.with_dynamic_input_shape(size=2)
@strint (Contributor, Author): Verifies that eviction also works correctly during load.

assert not self.is_full()

if not (self.keep_the_1st and self.is_empty()):
    self.queue.appendleft(key)
@strint (Contributor, Author): Prevents the first (base) graph from being evicted.

Reviewer (Contributor): How does this prevent the first base graph from being evicted?

@strint (Contributor, Author), replying:

> How does this prevent the first base graph from being evicted?

self.queue holds the cache keys that are eligible for eviction, kept in the order LRU needs: the top is the oldest, the end is the newest, and positions are adjusted as keys are used.

The base graph is protected here because it is always the first graph added to the cache; since it is never added to self.queue, it can never be evicted.
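The policy described in this thread can be sketched as a minimal LRU cache. This is a hypothetical illustration, not the actual OneFlow implementation: the first ("base") entry never joins the eviction queue, so eviction can never touch it.

```python
from collections import deque

class LRUCache:
    """Sketch of the eviction policy discussed above (hypothetical names).

    The first key inserted skips the eviction queue entirely, so it is
    invisible to eviction; every later key follows normal LRU order.
    """

    def __init__(self, capacity, keep_the_1st=True):
        self.capacity = capacity
        self.keep_the_1st = keep_the_1st
        self.queue = deque()  # left = newest, right = oldest evictable key
        self.store = {}

    def is_empty(self):
        return len(self.store) == 0

    def is_full(self):
        return len(self.store) >= self.capacity

    def set(self, key, value):
        if self.is_full():
            old_key = self.queue.pop()  # evict the oldest evictable key
            del self.store[old_key]
        # The very first insertion skips the queue, so the base graph
        # can never be chosen for eviction.
        if not (self.keep_the_1st and self.is_empty()):
            self.queue.appendleft(key)
        self.store[key] = value

    def get(self, key):
        value = self.store[key]
        if key in self.queue:  # refresh LRU order on use
            self.queue.remove(key)
            self.queue.appendleft(key)
        return value
```

With a capacity of 2, inserting "base", "a", then "b" evicts "a" while "base" survives, matching the behavior the author describes.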

@strint added labels: bug, graph (graph mode) Mar 21, 2023
@strint requested a review from oneflow-ci-bot March 21, 2023 12:26
@github-actions (Contributor):

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 141.0ms (= 14099.4ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 146.1ms (= 14612.8ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.04 (= 146.1ms / 141.0ms)

OneFlow resnet50 time: 80.6ms (= 8064.4ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 85.2ms (= 8517.8ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.06 (= 85.2ms / 80.6ms)

OneFlow resnet50 time: 48.5ms (= 9696.3ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 57.9ms (= 11575.8ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.19 (= 57.9ms / 48.5ms)

OneFlow resnet50 time: 32.2ms (= 6440.3ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 41.4ms (= 8287.8ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.29 (= 41.4ms / 32.2ms)

OneFlow resnet50 time: 25.2ms (= 5032.1ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 38.9ms (= 7785.5ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.55 (= 38.9ms / 25.2ms)

OneFlow swin dataloader time: 0.237s (= 47.489s / 200, num_workers=1)
PyTorch swin dataloader time: 0.150s (= 29.998s / 200, num_workers=1)
Relative speed: 0.632 (= 0.150s / 0.237s)

OneFlow swin dataloader time: 0.070s (= 13.930s / 200, num_workers=4)
PyTorch swin dataloader time: 0.040s (= 8.090s / 200, num_workers=4)
Relative speed: 0.581 (= 0.040s / 0.070s)

OneFlow swin dataloader time: 0.041s (= 8.245s / 200, num_workers=8)
PyTorch swin dataloader time: 0.023s (= 4.586s / 200, num_workers=8)
Relative speed: 0.556 (= 0.023s / 0.041s)

❌ OneFlow resnet50 time: 152.4ms (= 15241.2ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 161.7ms (= 16169.5ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.06 (= 161.7ms / 152.4ms)

OneFlow resnet50 time: 91.1ms (= 9114.3ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 103.1ms (= 10307.7ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.13 (= 103.1ms / 91.1ms)

OneFlow resnet50 time: 58.9ms (= 11785.1ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 78.5ms (= 15693.5ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.33 (= 78.5ms / 58.9ms)

OneFlow resnet50 time: 42.0ms (= 8409.3ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 73.4ms (= 14689.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.75 (= 73.4ms / 42.0ms)

OneFlow resnet50 time: 38.0ms (= 7598.6ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 75.6ms (= 15111.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.99 (= 75.6ms / 38.0ms)

@github-actions (Contributor):

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10016/

# This state tensor is from eager module. Just save a dummy tensor here.
_state_tensor_tuple4save.append(
    oneflow.Tensor().to(
        self._state_tensor_tuple[state_idx].device
@strint (Contributor, Author): If the tensors on the eager module do not need to be saved (tensors on the eager module are usually saved and loaded through eager itself), we can store just an empty tensor instead, which reduces the size of runtime_state_dict.
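The space saving described above can be sketched in plain Python. The helper and names below are hypothetical (the actual code appends an empty oneflow.Tensor for eager-held states); the point is only that eager-owned entries are replaced by empty placeholders.

```python
# Hypothetical sketch: when building the state to save, tensors already
# owned by the eager module are swapped for empty placeholders, since
# eager saves and loads them separately.

def build_state_for_save(state_op_names, state_tensors, eager_op_names,
                         with_eager=False):
    """state_tensors maps op name -> tensor (modeled here as a list of floats)."""
    saved = {}
    for name in state_op_names:
        if not with_eager and name in eager_op_names:
            saved[name] = []  # dummy empty "tensor"; costs almost nothing to store
        else:
            saved[name] = state_tensors[name]
    return saved

# Only the non-eager state keeps its payload:
tensors = {"w": [1.0, 2.0, 3.0], "free_var": [0.5]}
print(build_state_for_save(["w", "free_var"], tensors, {"w"}))
# -> {'w': [], 'free_var': [0.5]}
```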

test_case1 = np.array_equal(of_lazy_out1.numpy(), of_eager_out1.numpy())
return_dict["save4"] = test_case1

state_dict = linear_g.runtime_state_dict(with_eager=with_eager)
@strint (Contributor, Author): Tests not saving the eager module tensors.

@@ -1011,13 +1012,15 @@ def enable_save_runtime_state_dict(self, mode: bool = True):
self._enable_save_runtime_state_dict = False

def runtime_state_dict(
-    self, destination=None
+    self, destination=None, with_eager=False
@strint (Contributor, Author): By default, tensors on the eager module are not saved.


@strint renamed the PR from "keep the 1st graph for save" to "Fix and refine graph save runtime_state_dict" Mar 22, 2023
else:
    assert len(self._state_tensor_tuple) == len(self._state_op_names)
    for state_idx in range(len(self._state_tensor_tuple)):
        if self._state_op_names[state_idx] in self._eager_state_op_names:
Reviewer (Contributor): In what cases does a tensor exist that the eager module holds but the graph state does not? Is that a tensor eliminated by constant folding?

Also, is an eager free tensor a graph state tensor?

@strint (Contributor, Author), replying:

> In what cases does a tensor exist that the eager module holds but the graph state does not? Is that a tensor eliminated by constant folding?

Yes.

> Is an eager free tensor a graph state tensor?

Yes, it is. But it is not on the eager module.


-while self.is_queue_full():
-    pop_key = self.pop()
+if self.is_full():
+    old_key = self.pop()
Reviewer (Contributor): Doesn't this pop evict the key of the oldest graph?

@strint (Contributor, Author), replying:

> Doesn't this pop evict the key of the oldest graph?

Yes, it does.
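A two-line illustration of the queue ordering behind that answer: new keys enter via appendleft(), so pop() takes from the right end, which always holds the oldest evictable key.

```python
from collections import deque

# New keys go in on the left, so the right end is the oldest key,
# and pop() evicts in oldest-first order.
q = deque()
for key in ["g1", "g2", "g3"]:
    q.appendleft(key)

evicted = q.pop()
print(evicted)  # g1, the oldest key
```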

@github-actions (Contributor):

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 141.1ms (= 14106.9ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 142.7ms (= 14271.6ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.01 (= 142.7ms / 141.1ms)

OneFlow resnet50 time: 80.6ms (= 8061.9ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 84.4ms (= 8438.0ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.05 (= 84.4ms / 80.6ms)

OneFlow resnet50 time: 49.9ms (= 9977.3ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 64.7ms (= 12945.4ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.30 (= 64.7ms / 49.9ms)

OneFlow resnet50 time: 33.1ms (= 6616.4ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 43.3ms (= 8669.4ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.31 (= 43.3ms / 33.1ms)

OneFlow resnet50 time: 26.5ms (= 5304.9ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 43.3ms (= 8651.9ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.63 (= 43.3ms / 26.5ms)

OneFlow swin dataloader time: 0.236s (= 47.185s / 200, num_workers=1)
PyTorch swin dataloader time: 0.150s (= 29.931s / 200, num_workers=1)
Relative speed: 0.634 (= 0.150s / 0.236s)

OneFlow swin dataloader time: 0.068s (= 13.675s / 200, num_workers=4)
PyTorch swin dataloader time: 0.042s (= 8.337s / 200, num_workers=4)
Relative speed: 0.610 (= 0.042s / 0.068s)

OneFlow swin dataloader time: 0.040s (= 8.058s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.409s / 200, num_workers=8)
Relative speed: 0.547 (= 0.022s / 0.040s)

❌ OneFlow resnet50 time: 153.0ms (= 15297.1ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 164.0ms (= 16397.7ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.07 (= 164.0ms / 153.0ms)

OneFlow resnet50 time: 91.6ms (= 9160.1ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 103.3ms (= 10330.1ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.13 (= 103.3ms / 91.6ms)

OneFlow resnet50 time: 59.5ms (= 11909.6ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 79.2ms (= 15841.5ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.33 (= 79.2ms / 59.5ms)

OneFlow resnet50 time: 42.9ms (= 8577.2ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 72.1ms (= 14418.5ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.68 (= 72.1ms / 42.9ms)

OneFlow resnet50 time: 36.1ms (= 7212.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 68.9ms (= 13771.2ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.91 (= 68.9ms / 36.1ms)

@strint (Contributor, Author) commented Mar 22, 2023

sd unet graph save/load test: https://github.com/Oneflow-Inc/diffusers/blob/main/examples/unet_torch_interplay.py

## with_eager True

saving graphs...
get state dict time:  0.33900022506713867
state_dict(with_eager=True) tensors size  2768.363235473633  MB.
save state dict time:  13.160083293914795
loading graphs...
load state dict time:  2.118417739868164
state_dict tensors size  2768.363235473633  MB.
load into graph time:  3.463979482650757

## with_eager False

Set with_eager to False:

runtime_state_dict(with_eager=False)

saving graphs...
get state dict time:  0.3210592269897461
state_dict(with_eager=False) tensors size  1128.9570999145508  MB.
save state dict time:  5.13447380065918
loading graphs...
load state dict time:  0.9304521083831787
state_dict tensors size  1128.9570999145508  MB.
load into graph time:  3.212466239929199

**with_eager=False cuts the state_dict size and the save/load time by roughly 56%.**
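As a quick sanity check, the reductions implied by the numbers logged above can be computed directly: the size drop is about 59%, the save-time drop about 61%, and the load-time drop about 56% (the quoted figure matches the load-state-dict time most closely).

```python
# Sanity-check of the reductions implied by the timings logged above.
size_eager, size_no_eager = 2768.363235473633, 1128.9570999145508  # MB
save_eager, save_no_eager = 13.160083293914795, 5.13447380065918   # s
load_eager, load_no_eager = 2.118417739868164, 0.9304521083831787  # s

def reduction_pct(before, after):
    """Percentage reduction going from `before` to `after`."""
    return (before - after) / before * 100.0

print(f"state_dict size: -{reduction_pct(size_eager, size_no_eager):.0f}%")  # -59%
print(f"save time:       -{reduction_pct(save_eager, save_no_eager):.0f}%")  # -61%
print(f"load time:       -{reduction_pct(load_eager, load_no_eager):.0f}%")  # -56%
```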

@strint (Contributor, Author) commented Mar 22, 2023

CI error: FAILED python/oneflow/test/modules/test_one_embedding_adam.py::TestOptimizers::test_one_embedding_adam

https://github.com/Oneflow-Inc/oneflow/actions/runs/4485531702/jobs/7888144384?pr=10016

An unrelated CI failure.

@strint requested review from oneflow-ci-bot and removed the previous request March 22, 2023 03:40
@github-actions (Contributor):

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 141.0ms (= 14097.8ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 141.4ms (= 14137.5ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.00 (= 141.4ms / 141.0ms)

OneFlow resnet50 time: 80.7ms (= 8074.3ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 84.3ms (= 8425.5ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.04 (= 84.3ms / 80.7ms)

OneFlow resnet50 time: 49.8ms (= 9958.2ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 57.1ms (= 11413.4ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.15 (= 57.1ms / 49.8ms)

OneFlow resnet50 time: 33.4ms (= 6687.3ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 43.6ms (= 8722.9ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.30 (= 43.6ms / 33.4ms)

OneFlow resnet50 time: 25.2ms (= 5043.2ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 42.6ms (= 8523.6ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.69 (= 42.6ms / 25.2ms)

OneFlow swin dataloader time: 0.243s (= 48.519s / 200, num_workers=1)
PyTorch swin dataloader time: 0.152s (= 30.355s / 200, num_workers=1)
Relative speed: 0.626 (= 0.152s / 0.243s)

OneFlow swin dataloader time: 0.071s (= 14.191s / 200, num_workers=4)
PyTorch swin dataloader time: 0.043s (= 8.535s / 200, num_workers=4)
Relative speed: 0.601 (= 0.043s / 0.071s)

OneFlow swin dataloader time: 0.042s (= 8.385s / 200, num_workers=8)
PyTorch swin dataloader time: 0.023s (= 4.535s / 200, num_workers=8)
Relative speed: 0.541 (= 0.023s / 0.042s)

❌ OneFlow resnet50 time: 152.7ms (= 15274.3ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 162.7ms (= 16267.3ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.07 (= 162.7ms / 152.7ms)

OneFlow resnet50 time: 91.6ms (= 9163.9ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 102.2ms (= 10221.8ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.12 (= 102.2ms / 91.6ms)

OneFlow resnet50 time: 59.5ms (= 11909.9ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 83.7ms (= 16747.5ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.41 (= 83.7ms / 59.5ms)

OneFlow resnet50 time: 41.5ms (= 8293.7ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 70.8ms (= 14163.4ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.71 (= 70.8ms / 41.5ms)

OneFlow resnet50 time: 37.9ms (= 7587.2ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 68.9ms (= 13772.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.82 (= 68.9ms / 37.9ms)

@github-actions (Contributor):

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10016/

@strint requested review from oneflow-ci-bot and removed the previous request March 22, 2023 04:08
@github-actions (Contributor):

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 141.0ms (= 14104.8ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 145.9ms (= 14585.1ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.03 (= 145.9ms / 141.0ms)

OneFlow resnet50 time: 80.8ms (= 8084.7ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 84.7ms (= 8465.5ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.05 (= 84.7ms / 80.8ms)

OneFlow resnet50 time: 50.3ms (= 10060.6ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 55.4ms (= 11071.3ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.10 (= 55.4ms / 50.3ms)

OneFlow resnet50 time: 33.2ms (= 6645.6ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 42.3ms (= 8458.9ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.27 (= 42.3ms / 33.2ms)

OneFlow resnet50 time: 26.3ms (= 5256.0ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 38.7ms (= 7730.6ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.47 (= 38.7ms / 26.3ms)

OneFlow swin dataloader time: 0.241s (= 48.147s / 200, num_workers=1)
PyTorch swin dataloader time: 0.149s (= 29.755s / 200, num_workers=1)
Relative speed: 0.618 (= 0.149s / 0.241s)

OneFlow swin dataloader time: 0.071s (= 14.120s / 200, num_workers=4)
PyTorch swin dataloader time: 0.041s (= 8.204s / 200, num_workers=4)
Relative speed: 0.581 (= 0.041s / 0.071s)

OneFlow swin dataloader time: 0.042s (= 8.380s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.395s / 200, num_workers=8)
Relative speed: 0.524 (= 0.022s / 0.042s)

❌ OneFlow resnet50 time: 152.8ms (= 15281.8ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 165.8ms (= 16576.2ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.08 (= 165.8ms / 152.8ms)

OneFlow resnet50 time: 91.6ms (= 9157.4ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 109.4ms (= 10935.6ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.19 (= 109.4ms / 91.6ms)

OneFlow resnet50 time: 59.3ms (= 11852.2ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 78.8ms (= 15762.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.33 (= 78.8ms / 59.3ms)

OneFlow resnet50 time: 42.3ms (= 8460.0ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 72.6ms (= 14517.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.72 (= 72.6ms / 42.3ms)

OneFlow resnet50 time: 36.5ms (= 7290.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 70.0ms (= 14008.5ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.92 (= 70.0ms / 36.5ms)

@github-actions (Contributor):

CI failed when running job: cuda-speed-test. PR label automerge has been removed

@strint requested review from oneflow-ci-bot and removed the previous request March 22, 2023 08:27
@github-actions (Contributor):

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 141.0ms (= 14098.2ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 149.1ms (= 14908.5ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.06 (= 149.1ms / 141.0ms)

OneFlow resnet50 time: 80.9ms (= 8087.8ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 86.7ms (= 8669.2ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.07 (= 86.7ms / 80.9ms)

OneFlow resnet50 time: 50.5ms (= 10094.5ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 60.7ms (= 12140.0ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.20 (= 60.7ms / 50.5ms)

OneFlow resnet50 time: 33.4ms (= 6682.2ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 42.8ms (= 8551.1ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.28 (= 42.8ms / 33.4ms)

OneFlow resnet50 time: 25.0ms (= 5004.1ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 39.9ms (= 7984.3ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.60 (= 39.9ms / 25.0ms)

OneFlow swin dataloader time: 0.236s (= 47.190s / 200, num_workers=1)
PyTorch swin dataloader time: 0.150s (= 30.013s / 200, num_workers=1)
Relative speed: 0.636 (= 0.150s / 0.236s)

OneFlow swin dataloader time: 0.071s (= 14.272s / 200, num_workers=4)
PyTorch swin dataloader time: 0.042s (= 8.417s / 200, num_workers=4)
Relative speed: 0.590 (= 0.042s / 0.071s)

OneFlow swin dataloader time: 0.043s (= 8.582s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.428s / 200, num_workers=8)
Relative speed: 0.516 (= 0.022s / 0.043s)

❌ OneFlow resnet50 time: 152.8ms (= 15276.6ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 164.9ms (= 16488.5ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.08 (= 164.9ms / 152.8ms)

OneFlow resnet50 time: 92.0ms (= 9203.2ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 105.7ms (= 10565.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.15 (= 105.7ms / 92.0ms)

OneFlow resnet50 time: 60.4ms (= 12085.2ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 81.0ms (= 16202.3ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.34 (= 81.0ms / 60.4ms)

OneFlow resnet50 time: 42.0ms (= 8394.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 73.3ms (= 14657.8ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.75 (= 73.3ms / 42.0ms)

OneFlow resnet50 time: 37.0ms (= 7394.7ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 66.9ms (= 13386.0ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.81 (= 66.9ms / 37.0ms)

@github-actions (Contributor):

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10016/

@mergify[bot] merged commit 2d85205 into master Mar 22, 2023
@mergify[bot] deleted the fix_graph_cache_out_of_size branch March 22, 2023 11:21
3 participants