Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for variable sequence length in fmha #9991

Merged
merged 11 commits into from
Mar 16, 2023

Conversation

liujuncheng
Copy link
Collaborator

No description provided.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个接口加个readme,中文也可以

} else {
UNIMPLEMENTED_THEN_RETURN();
}
} else if (shape.NumAxes() == 3) {
if (layout == "BM(HK)" || layout == "MB(HK)" || layout == "BM(H2K)" || layout == "MB(H2K)"
Copy link
Collaborator

@jackalcooper jackalcooper Mar 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

把这些字符串改为一个ENUM或者一系列static变量?这样万一出错了不需要肉眼看的很辛苦

@github-actions
Copy link
Contributor

Speed stats:

UNIMPLEMENTED_THEN_RETURN() << name
<< "_layout should be '(BM)(HK)', '(BM)(H2K)', or '(BM)(H3K)' "
"when the number of dimensions of "
<< name << " tensor is 3.";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为:tensor is 3. -> tensor is 2.

*b = JUST(batch_size);
*m = JUST(seq_len);
*h = shape.At(1);
*k = shape.At(2);
} else {
UNIMPLEMENTED_THEN_RETURN()
<< name
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里漏了"(BM)HK"

@github-actions
Copy link
Contributor

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 141.1ms (= 14109.6ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 141.7ms (= 14172.9ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.00 (= 141.7ms / 141.1ms)

OneFlow resnet50 time: 80.6ms (= 8063.1ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 84.2ms (= 8423.8ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.04 (= 84.2ms / 80.6ms)

OneFlow resnet50 time: 49.3ms (= 9852.8ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 57.9ms (= 11574.1ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.17 (= 57.9ms / 49.3ms)

OneFlow resnet50 time: 32.4ms (= 6487.3ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 44.3ms (= 8850.8ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.36 (= 44.3ms / 32.4ms)

OneFlow resnet50 time: 25.7ms (= 5133.5ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 34.0ms (= 6796.5ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.32 (= 34.0ms / 25.7ms)

OneFlow swin dataloader time: 0.239s (= 47.864s / 200, num_workers=1)
PyTorch swin dataloader time: 0.147s (= 29.363s / 200, num_workers=1)
Relative speed: 0.613 (= 0.147s / 0.239s)

OneFlow swin dataloader time: 0.070s (= 13.964s / 200, num_workers=4)
PyTorch swin dataloader time: 0.041s (= 8.203s / 200, num_workers=4)
Relative speed: 0.587 (= 0.041s / 0.070s)

OneFlow swin dataloader time: 0.046s (= 9.200s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.473s / 200, num_workers=8)
Relative speed: 0.486 (= 0.022s / 0.046s)

❌ OneFlow resnet50 time: 152.2ms (= 15217.7ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 161.6ms (= 16160.1ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.06 (= 161.6ms / 152.2ms)

OneFlow resnet50 time: 90.9ms (= 9094.3ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 103.2ms (= 10317.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.13 (= 103.2ms / 90.9ms)

OneFlow resnet50 time: 58.9ms (= 11775.3ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 79.1ms (= 15823.5ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.34 (= 79.1ms / 58.9ms)

OneFlow resnet50 time: 41.5ms (= 8290.7ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 78.9ms (= 15783.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.90 (= 78.9ms / 41.5ms)

OneFlow resnet50 time: 36.8ms (= 7367.2ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 75.2ms (= 15045.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 2.04 (= 75.2ms / 36.8ms)

@github-actions
Copy link
Contributor

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 141.0ms (= 14098.3ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 143.9ms (= 14394.2ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.02 (= 143.9ms / 141.0ms)

OneFlow resnet50 time: 80.6ms (= 8058.4ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 84.0ms (= 8396.9ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.04 (= 84.0ms / 80.6ms)

OneFlow resnet50 time: 48.9ms (= 9771.3ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 58.7ms (= 11735.3ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.20 (= 58.7ms / 48.9ms)

OneFlow resnet50 time: 32.4ms (= 6475.1ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 45.2ms (= 9038.0ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.40 (= 45.2ms / 32.4ms)

OneFlow resnet50 time: 25.7ms (= 5145.6ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 37.6ms (= 7529.7ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.46 (= 37.6ms / 25.7ms)

OneFlow swin dataloader time: 0.238s (= 47.674s / 200, num_workers=1)
PyTorch swin dataloader time: 0.158s (= 31.624s / 200, num_workers=1)
Relative speed: 0.663 (= 0.158s / 0.238s)

OneFlow swin dataloader time: 0.069s (= 13.786s / 200, num_workers=4)
PyTorch swin dataloader time: 0.042s (= 8.317s / 200, num_workers=4)
Relative speed: 0.603 (= 0.042s / 0.069s)

OneFlow swin dataloader time: 0.041s (= 8.135s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.475s / 200, num_workers=8)
Relative speed: 0.550 (= 0.022s / 0.041s)

❌ OneFlow resnet50 time: 152.3ms (= 15228.7ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 160.8ms (= 16080.2ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.06 (= 160.8ms / 152.3ms)

OneFlow resnet50 time: 90.9ms (= 9088.1ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 101.0ms (= 10100.5ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.11 (= 101.0ms / 90.9ms)

OneFlow resnet50 time: 58.9ms (= 11778.6ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 79.8ms (= 15962.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.36 (= 79.8ms / 58.9ms)

OneFlow resnet50 time: 42.7ms (= 8530.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 78.7ms (= 15748.8ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.85 (= 78.7ms / 42.7ms)

OneFlow resnet50 time: 37.1ms (= 7429.8ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 67.9ms (= 13572.2ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.83 (= 67.9ms / 37.1ms)

@github-actions
Copy link
Contributor

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9991/

@liujuncheng liujuncheng enabled auto-merge (squash) March 16, 2023 14:33
@github-actions
Copy link
Contributor

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 140.7ms (= 14074.6ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 143.4ms (= 14344.7ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.02 (= 143.4ms / 140.7ms)

OneFlow resnet50 time: 80.4ms (= 8041.4ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 85.1ms (= 8505.7ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.06 (= 85.1ms / 80.4ms)

OneFlow resnet50 time: 48.5ms (= 9693.7ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 58.4ms (= 11670.7ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.20 (= 58.4ms / 48.5ms)

OneFlow resnet50 time: 32.2ms (= 6448.3ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 44.6ms (= 8924.0ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.38 (= 44.6ms / 32.2ms)

OneFlow resnet50 time: 25.7ms (= 5138.7ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 40.6ms (= 8113.5ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.58 (= 40.6ms / 25.7ms)

OneFlow swin dataloader time: 0.235s (= 46.960s / 200, num_workers=1)
PyTorch swin dataloader time: 0.150s (= 30.066s / 200, num_workers=1)
Relative speed: 0.640 (= 0.150s / 0.235s)

OneFlow swin dataloader time: 0.067s (= 13.478s / 200, num_workers=4)
PyTorch swin dataloader time: 0.043s (= 8.602s / 200, num_workers=4)
Relative speed: 0.638 (= 0.043s / 0.067s)

OneFlow swin dataloader time: 0.041s (= 8.181s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.497s / 200, num_workers=8)
Relative speed: 0.550 (= 0.022s / 0.041s)

❌ OneFlow resnet50 time: 152.7ms (= 15265.1ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 162.0ms (= 16199.3ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.06 (= 162.0ms / 152.7ms)

OneFlow resnet50 time: 91.0ms (= 9097.4ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 101.4ms (= 10144.2ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.12 (= 101.4ms / 91.0ms)

OneFlow resnet50 time: 58.8ms (= 11761.2ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 87.3ms (= 17469.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.49 (= 87.3ms / 58.8ms)

OneFlow resnet50 time: 41.9ms (= 8387.8ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 74.1ms (= 14827.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.77 (= 74.1ms / 41.9ms)

OneFlow resnet50 time: 37.4ms (= 7477.5ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 68.0ms (= 13594.7ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.82 (= 68.0ms / 37.4ms)

@github-actions
Copy link
Contributor

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9991/

@liujuncheng liujuncheng merged commit def0c7e into master Mar 16, 2023
@liujuncheng liujuncheng deleted the dev_att_var_seq_len branch March 16, 2023 16:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants