Commit e08e93f

albanD authored and facebook-github-bot committed on Aug 24, 2020

Reland of benchmark code (pytorch#43428)

Summary: Reland of the benchmark code that broke the slow tests because the GPUs were running out of memory.

Pull Request resolved: pytorch#43428
Reviewed By: ngimel
Differential Revision: D23296136
Pulled By: albanD
fbshipit-source-id: 0002ae23dc82f401604e33d0905d6b9eedebc851

1 parent 4cfac34 · commit e08e93f
11 files changed: +2078 −0 lines
benchmarks/functional_autograd_benchmark/README.md (new file, +48 lines)

# Benchmarking tool for the autograd API

This folder contains a set of self-contained scripts that allow you to benchmark autograd with different common models.
It is designed to run the benchmark before and after your change and will generate a table to share on the PR.

To do so, you can use `functional_autograd_benchmark.py` to run the benchmarks before your change (writing the output to `before.txt`) and after your change (writing the output to `after.txt`).
You can then use `compare.py` to get a markdown table comparing the two runs.

The default arguments of `functional_autograd_benchmark.py` should be used in general. You can change them, though, to force a given device or to run even the (very) slow settings.

### Sample usage

```bash
# Make sure you compile pytorch in release mode and with the same flags before/after
export DEBUG=0
# When running on CPU, it might be required to limit the number of cores to avoid oversubscription
export OMP_NUM_THREADS=10

# Compile pytorch with the base revision
git checkout master
python setup.py develop

# Run the benchmark for the base
# This will use the GPU if available.
pushd benchmarks/functional_autograd_benchmark
python functional_autograd_benchmark.py --output before.txt

# Compile pytorch with your change
popd
git checkout your_feature_branch
python setup.py develop

# Run the benchmark for the new version
pushd benchmarks/functional_autograd_benchmark
python functional_autograd_benchmark.py --output after.txt

# Get the markdown table that you can paste in your github PR
python compare.py

popd
```

### Files in this folder:
- `functional_autograd_benchmark.py` is the main entry point to run the benchmark.
- `compare.py` is the entry point to run the comparison script that generates a markdown table (see the sample table below).
- `torchaudio_models.py` and `torchvision_models.py` contain code extracted from torchaudio and torchvision so the models can run without a specific version of those libraries installed.
- `ppl_models.py`, `vision_models.py` and `audio_text_models.py` contain all the getter functions used for the benchmark.
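
For reference, the table `compare.py` emits has the shape below; the numbers here are invented purely to show the format (the header columns are defined in `compare.py`):

```
| model | task | speedup | mean (before) | var (before) | mean (after) | var (after) |
| -- | -- | -- | -- | -- | -- | -- |
| resnet18 | vjp | 1.05 | 0.21 | 1e-05 | 0.20 | 1e-05 |
```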
benchmarks/functional_autograd_benchmark/audio_text_models.py (new file, +122 lines)

import torch
from torch import nn, Tensor

import torchaudio_models as models

from utils import extract_weights, load_weights, GetterReturnType

def get_wav2letter(device: torch.device) -> GetterReturnType:
    N = 10
    input_frames = 700
    vocab_size = 28
    model = models.Wav2Letter(num_classes=vocab_size)
    criterion = torch.nn.NLLLoss()
    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([N, 1, input_frames], device=device)
    labels = torch.rand(N, 3, device=device).mul(vocab_size).long()

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        return loss

    return forward, params

def get_deepspeech(device: torch.device) -> GetterReturnType:
    sample_rate = 16000
    window_size = 0.02
    window = "hamming"
    audio_conf = dict(sample_rate=sample_rate,
                      window_size=window_size,
                      window=window,
                      noise_dir=None)

    N = 10
    num_classes = 10
    spectrogram_size = 161
    # Commented are the original sizes in the code
    seq_length = 500  # 1343
    target_length = 10  # 50
    labels = torch.rand(num_classes, device=device)
    inputs = torch.rand(N, 1, spectrogram_size, seq_length, device=device)
    # Sequence length for each input
    inputs_sizes = torch.rand(N, device=device).mul(seq_length * 0.1).add(seq_length * 0.8)
    targets = torch.rand(N, target_length, device=device)
    targets_sizes = torch.full((N,), target_length, dtype=torch.int, device=device)

    model = models.DeepSpeech(rnn_type=nn.LSTM, labels=labels, rnn_hidden_size=1024, nb_layers=5,
                              audio_conf=audio_conf, bidirectional=True)
    model = model.to(device)
    criterion = nn.CTCLoss()
    params, names = extract_weights(model)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out, out_sizes = model(inputs, inputs_sizes)
        out = out.transpose(0, 1)  # For ctc loss

        loss = criterion(out, targets, out_sizes, targets_sizes)
        return loss

    return forward, params

def get_transformer(device: torch.device) -> GetterReturnType:
    # For most SOTA research, you would like to have embed to 720, nhead to 12, bsz to 64, tgt_len/src_len to 128.
    N = 64
    seq_length = 128
    ntoken = 50
    model = models.TransformerModel(ntoken=ntoken, ninp=720, nhead=12, nhid=2048, nlayers=2)
    model.to(device)
    criterion = nn.NLLLoss()
    params, names = extract_weights(model)

    data = torch.rand(N, seq_length + 1, device=device).mul(ntoken).long()
    inputs = data.narrow(1, 0, seq_length)
    targets = data.narrow(1, 1, seq_length)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out.reshape(N * seq_length, ntoken), targets.reshape(N * seq_length))
        return loss

    return forward, params

def get_multiheadattn(device: torch.device) -> GetterReturnType:
    # From https://github.com/pytorch/text/blob/master/test/data/test_modules.py#L10
    embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
    # Build torchtext MultiheadAttention module
    in_proj = models.InProjContainer(torch.nn.Linear(embed_dim, embed_dim, bias=False),
                                     torch.nn.Linear(embed_dim, embed_dim, bias=False),
                                     torch.nn.Linear(embed_dim, embed_dim, bias=False))

    model = models.MultiheadAttentionContainer(nhead, in_proj,
                                               models.ScaledDotProduct(),
                                               torch.nn.Linear(embed_dim, embed_dim, bias=False))
    model.to(device)
    params, names = extract_weights(model)

    query = torch.rand((tgt_len, bsz, embed_dim), device=device)
    key = value = torch.rand((src_len, bsz, embed_dim), device=device)
    attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len), device=device).to(torch.bool)
    bias_k = bias_v = torch.rand((1, 1, embed_dim), device=device)

    attn_mask = torch.stack([attn_mask_2D] * (bsz * nhead))
    bias_k = bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1)
    bias_v = bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        mha_output, attn_weights = model(query, key, value, attn_mask=attn_mask, bias_k=bias_k, bias_v=bias_v)

        # Don't test any specific loss, just backprop ones for both outputs
        loss = mha_output.sum() + attn_weights.sum()

        return loss

    return forward, params
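
Each getter above returns a pair `(forward, params)` that the main benchmark script feeds to `torch.autograd.functional`. As a minimal sketch of that contract (not part of this commit; the wav2letter getter and CPU device are arbitrary choices):

```python
import torch
from torch.autograd import functional

import audio_text_models

# A getter returns a pure function of the parameters plus the parameter tensors.
forward, params = audio_text_models.get_wav2letter(torch.device("cpu"))

# For the "vjp" task, v must match the output shape (a scalar loss here).
v = torch.rand_like(forward(*params))
loss, grads = functional.vjp(forward, params, v=v, strict=True)
print(loss, grads[0].shape)
```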
benchmarks/functional_autograd_benchmark/compare.py (new file, +45 lines)

import argparse
from collections import defaultdict

from utils import to_markdown_table, from_markdown_table

def main():
    parser = argparse.ArgumentParser("Main script to compare results from the benchmarks")
    parser.add_argument("--before", type=str, default="before.txt", help="Text file containing the times to use as base")
    parser.add_argument("--after", type=str, default="after.txt", help="Text file containing the times to use as new version")
    parser.add_argument("--output", type=str, default="", help="Text file where to write the output")
    args = parser.parse_args()

    with open(args.before, "r") as f:
        content = f.read()
    res_before = from_markdown_table(content)

    with open(args.after, "r") as f:
        content = f.read()
    res_after = from_markdown_table(content)

    diff = defaultdict(defaultdict)
    for model in res_before:
        for task in res_before[model]:
            mean_before, var_before = res_before[model][task]
            if task not in res_after[model]:
                diff[model][task] = (None, mean_before, var_before, None, None)
            else:
                mean_after, var_after = res_after[model][task]
                diff[model][task] = (mean_before / mean_after, mean_before, var_before, mean_after, var_after)
    for model in res_after:
        for task in res_after[model]:
            if task not in res_before[model]:
                mean_after, var_after = res_after[model][task]
                diff[model][task] = (None, None, None, mean_after, var_after)

    header = ("model", "task", "speedup", "mean (before)", "var (before)", "mean (after)", "var (after)")
    out = to_markdown_table(diff, header=header)

    print(out)
    if args.output:
        with open(args.output, "w") as f:
            f.write(out)

if __name__ == "__main__":
    main()
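
With the defaults above, a typical comparison run is simply the following (`compare.md` is an arbitrary output name; all flags come from the argparse definitions above):

```bash
python compare.py --before before.txt --after after.txt --output compare.md
```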
benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py (new file, +153 lines)

import torch
from torch.autograd import functional

import time
from argparse import ArgumentParser
from collections import defaultdict
from typing import NamedTuple, Callable, List, Any

import ppl_models
import vision_models
import audio_text_models

from utils import to_markdown_table, TimingResultType, InputsType, GetterType, VType

# Listing of the different tasks
FAST_TASKS_NO_DOUBLE_BACK = [
    "vjp",
]

FAST_TASKS = FAST_TASKS_NO_DOUBLE_BACK + [
    "vhp",
    "jvp",
]

ALL_TASKS = FAST_TASKS + [
    "hvp",
    "jacobian",
    "hessian"
]

DOUBLE_BACKWARD_TASKS = ["jvp", "hvp", "vhp", "hessian"]

# Model definition which contains:
# - name: a string with the model name.
# - getter: a function to get the model. It takes as input the device on which the model
#     will run. It should return the forward function and the parameters (Tensors) used as
#     input for the forward function. Note that the forward must *not* have any side effect.
# - tasks: the list of recommended tasks that can run in a reasonable amount of time with this model.
# - unsupported: the list of tasks that this model cannot run.
class ModelDef(NamedTuple):
    name: str
    getter: GetterType
    tasks: List[str]
    unsupported: List[str]

MODELS = [
    ModelDef("resnet18", vision_models.get_resnet18, FAST_TASKS, []),
    ModelDef("fcn_resnet", vision_models.get_fcn_resnet, FAST_TASKS, []),
    ModelDef("detr", vision_models.get_detr, FAST_TASKS, []),
    ModelDef("ppl_simple_reg", ppl_models.get_simple_regression, ALL_TASKS, []),
    ModelDef("ppl_robust_reg", ppl_models.get_robust_regression, ALL_TASKS, []),
    ModelDef("wav2letter", audio_text_models.get_wav2letter, FAST_TASKS, []),
    ModelDef("deepspeech", audio_text_models.get_deepspeech, FAST_TASKS_NO_DOUBLE_BACK, DOUBLE_BACKWARD_TASKS),
    ModelDef("transformer", audio_text_models.get_transformer, FAST_TASKS, []),
    ModelDef("multiheadattn", audio_text_models.get_multiheadattn, FAST_TASKS, []),
]

def get_v_for(model: Callable, inp: InputsType, task: str) -> VType:
    v: VType

    if task in ["vjp"]:
        out = model(*inp)
        v = torch.rand_like(out)
    elif task in ["jvp", "hvp", "vhp"]:
        if isinstance(inp, tuple):
            v = tuple(torch.rand_like(i) for i in inp)
        else:
            v = torch.rand_like(inp)
    else:
        v = None

    return v

def run_once(model: Callable, inp: InputsType, task: str, v: VType) -> None:
    func = getattr(functional, task)

    if v is not None:
        res = func(model, inp, v=v, strict=True)
    else:
        res = func(model, inp, strict=True)

def run_model(model_getter: GetterType, args: Any, task: str) -> List[float]:
    if args.gpu == -1:
        device = torch.device("cpu")

        def noop():
            pass
        do_sync = noop
    else:
        device = torch.device("cuda:{}".format(args.gpu))
        do_sync = torch.cuda.synchronize

    model, inp = model_getter(device)

    v = get_v_for(model, inp, task)
    # Warmup
    run_once(model, inp, task, v)

    elapsed = []
    for it in range(args.num_iters):
        do_sync()
        start = time.time()
        run_once(model, inp, task, v)
        do_sync()
        elapsed.append(time.time() - start)

    return elapsed

def main():
    parser = ArgumentParser("Main script to benchmark functional API of the autograd.")
    parser.add_argument("--output", type=str, default="", help="Text file where to write the output")
    parser.add_argument("--num-iters", type=int, default=10)
    parser.add_argument("--gpu", type=int, default=-2, help="GPU to use, -1 for CPU and -2 for auto-detect")
    parser.add_argument("--run-slow-tasks", action="store_true", help="Run even the slow tasks")
    parser.add_argument("--model-filter", type=str, default="", help="Only run the models in this filter")
    parser.add_argument("--task-filter", type=str, default="", help="Only run the tasks in this filter")
    parser.add_argument("--num-threads", type=int, default=10,
                        help="Number of concurrent threads to use when running on cpu")
    parser.add_argument("--seed", type=int, default=0, help="The random seed to use.")
    args = parser.parse_args()

    results: TimingResultType = defaultdict(defaultdict)
    torch.set_num_threads(args.num_threads)
    torch.set_num_interop_threads(args.num_threads)

    # This automatically seeds cuda if it is available
    torch.manual_seed(args.seed)

    if args.gpu == -2:
        args.gpu = 0 if torch.cuda.is_available() else -1

    for name, model_getter, recommended_tasks, unsupported_tasks in MODELS:
        if args.model_filter and name not in args.model_filter:
            continue
        tasks = ALL_TASKS if args.run_slow_tasks else recommended_tasks
        for task in tasks:
            if task in unsupported_tasks:
                continue
            if args.task_filter and task not in args.task_filter:
                continue
            runtimes = run_model(model_getter, args, task)

            runtimes = torch.tensor(runtimes)
            mean, var = runtimes.mean(), runtimes.var()
            results[name][task] = (mean.item(), var.item())
            print("Results for model {} on task {}: {}s (var: {})".format(name, task, mean, var))

    if args.output:
        with open(args.output, "w") as f:
            f.write(to_markdown_table(results))

if __name__ == "__main__":
    main()
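
For example, to time a single model/task pair on CPU (all flags are defined in `main()` above; the values are illustrative):

```bash
python functional_autograd_benchmark.py --gpu -1 --num-iters 10 \
    --model-filter resnet18 --task-filter vjp --output before.txt
```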
benchmarks/functional_autograd_benchmark/ppl_models.py (new file, +93 lines)

import torch
from torch import Tensor
import torch.distributions as dist

from utils import GetterReturnType

def get_simple_regression(device: torch.device) -> GetterReturnType:
    N = 10
    K = 10

    loc_beta = 0.
    scale_beta = 1.

    beta_prior = dist.Normal(loc_beta, scale_beta)

    X = torch.rand(N, K + 1, device=device)
    Y = torch.rand(N, 1, device=device)

    # X.shape: (N, K + 1), Y.shape: (N, 1), beta_value.shape: (K + 1, 1)
    beta_value = beta_prior.sample((K + 1, 1))
    beta_value.requires_grad_(True)

    def forward(beta_value: Tensor) -> Tensor:
        mu = X.mm(beta_value)

        # We need to compute the first and second gradient of this score with respect
        # to beta_value.
        score = dist.Bernoulli(logits=mu).log_prob(Y).sum() + beta_prior.log_prob(beta_value).sum()
        return score

    return forward, (beta_value.to(device),)


def get_robust_regression(device: torch.device) -> GetterReturnType:
    N = 10
    K = 10

    # X.shape: (N, K + 1), Y.shape: (N, 1)
    X = torch.rand(N, K + 1, device=device)
    Y = torch.rand(N, 1, device=device)

    # Predefined nu_alpha and nu_beta, nu_alpha.shape: (1, 1), nu_beta.shape: (1, 1)
    nu_alpha = torch.randn(1, 1, device=device)
    nu_beta = torch.rand(1, 1, device=device)
    nu = dist.Gamma(nu_alpha, nu_beta)

    # Predefined sigma_rate: sigma_rate.shape: (N, 1)
    sigma_rate = torch.rand(N, 1, device=device)
    sigma = dist.Exponential(sigma_rate)

    # Predefined beta_mean and beta_sigma: beta_mean.shape: (K + 1, 1), beta_sigma.shape: (K + 1, 1)
    beta_mean = torch.rand(K + 1, 1, device=device)
    beta_sigma = torch.rand(K + 1, 1, device=device)
    beta = dist.Normal(beta_mean, beta_sigma)

    nu_value = nu.sample()
    nu_value.requires_grad_(True)

    sigma_value = sigma.sample()
    sigma_unconstrained_value = sigma_value.log()
    sigma_unconstrained_value.requires_grad_(True)

    beta_value = beta.sample()
    beta_value.requires_grad_(True)

    def forward(nu_value: Tensor, sigma_unconstrained_value: Tensor, beta_value: Tensor) -> Tensor:
        sigma_constrained_value = sigma_unconstrained_value.exp()
        mu = X.mm(beta_value)

        # For this model, we need to compute the following three scores:
        # We need to compute the first and second gradient of this score with respect
        # to nu_value.
        nu_score = dist.StudentT(nu_value, mu, sigma_constrained_value).log_prob(Y).sum() \
            + nu.log_prob(nu_value)

        # We need to compute the first and second gradient of this score with respect
        # to sigma_unconstrained_value.
        sigma_score = dist.StudentT(nu_value, mu, sigma_constrained_value).log_prob(Y).sum() \
            + sigma.log_prob(sigma_constrained_value) \
            + sigma_unconstrained_value

        # We need to compute the first and second gradient of this score with respect
        # to beta_value.
        beta_score = dist.StudentT(nu_value, mu, sigma_constrained_value).log_prob(Y).sum() \
            + beta.log_prob(beta_value)

        return nu_score.sum() + sigma_score.sum() + beta_score.sum()

    return forward, (nu_value.to(device), sigma_unconstrained_value.to(device), beta_value.to(device))
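
These two models are the only ones registered for the second-order tasks (`hvp`, `jacobian`, `hessian`) in the main script. As a minimal sketch of why that works (not part of this commit), the full Hessian of the simple regression score can be taken directly:

```python
import torch
from torch.autograd import functional

import ppl_models

forward, params = ppl_models.get_simple_regression(torch.device("cpu"))

# params is a 1-tuple holding beta_value of shape (K + 1, 1) = (11, 1), so the
# Hessian of the scalar score with respect to it has shape (11, 1, 11, 1).
hess = functional.hessian(forward, params, strict=True)
print(hess[0][0].shape)  # torch.Size([11, 1, 11, 1])
```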

benchmarks/functional_autograd_benchmark/torchaudio_models.py (+556 lines)
Large diffs are not rendered by default.

benchmarks/functional_autograd_benchmark/torchvision_models.py (+803 lines)
Large diffs are not rendered by default.
benchmarks/functional_autograd_benchmark/utils.py (new file, +103 lines)

import torch

from collections import defaultdict

from torch import nn, Tensor
from typing import List, Tuple, Dict, Union, Callable

# Type helpers
InputsType = Union[Tensor, Tuple[Tensor, ...]]
# A Getter takes in a device and returns a callable and the inputs to that callable
GetterReturnType = Tuple[Callable[..., Tensor], InputsType]
GetterType = Callable[[torch.device], GetterReturnType]
# V here refers to the v in either vjp, jvp, vhp or hvp
VType = Union[None, Tensor, Tuple[Tensor, ...]]
# Type used to store timing results. The first key is the model name, the second key
# is the task name, the result is a Tuple of: speedup, mean_before, var_before, mean_after, var_after.
TimingResultType = Dict[str, Dict[str, Tuple[float, ...]]]

# Utilities to make nn.Module "functional"
# In particular the goal is to be able to provide a function that takes as input
# the parameters and evaluates the nn.Module using fixed inputs.
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
    """
    Deletes the attribute specified by the given list of names.
    For example, to delete the attribute obj.conv.weight,
    use _del_nested_attr(obj, ['conv', 'weight'])
    """
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        _del_nested_attr(getattr(obj, names[0]), names[1:])

def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
    """
    Set the attribute specified by the given list of names to value.
    For example, to set the attribute obj.conv.weight,
    use _set_nested_attr(obj, ['conv', 'weight'], value)
    """
    if len(names) == 1:
        setattr(obj, names[0], value)
    else:
        _set_nested_attr(getattr(obj, names[0]), names[1:], value)

def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
    """
    This function removes all the Parameters from the model and
    returns them as a tuple as well as their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place and after this
    call, mod.parameters() will be empty.
    """
    orig_params = tuple(mod.parameters())
    # Remove all the parameters in the model
    names = []
    for name, p in list(mod.named_parameters()):
        _del_nested_attr(mod, name.split("."))
        names.append(name)

    # Make params regular Tensors instead of nn.Parameter
    params = tuple(p.detach().requires_grad_() for p in orig_params)
    return params, names

def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) -> None:
    """
    Reload a set of weights so that `mod` can be used again to perform a forward pass.
    Note that the `params` are regular Tensors (that can have history) and so are left
    as Tensors. This means that mod.parameters() will still be empty after this call.
    """
    for name, p in zip(names, params):
        _set_nested_attr(mod, name.split("."), p)

# Utilities to read/write markdown table-like content.
def to_markdown_table(res: TimingResultType, header: Tuple[str, ...] = None) -> str:
    if header is None:
        header = ("model", "task", "mean", "var")
    out = ""

    def write_line(*args):
        nonlocal out
        out += "| {} |\n".format(" | ".join(str(a) for a in args))

    # Make it a markdown table
    write_line(*header)
    write_line(*["--"] * len(header))
    for model, tasks in res.items():
        for task, line in tasks.items():
            write_line(*(model, task) + line)

    return out

def from_markdown_table(data: str) -> TimingResultType:
    out = data.strip().split("\n")
    out = out[2:]  # Ignore the header lines

    res: TimingResultType
    res = defaultdict(defaultdict)

    for line in out:
        model, task, mean, var = [f.strip() for f in line.strip().split("|") if f]
        res[model][task] = (float(mean), float(var))

    return res
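
`extract_weights`/`load_weights` are what let a stateful `nn.Module` masquerade as a pure function of its parameters. A minimal round-trip on a toy module (not part of this commit) looks like:

```python
import torch
from torch import nn

from utils import extract_weights, load_weights

mod = nn.Linear(3, 2)
params, names = extract_weights(mod)
assert len(list(mod.parameters())) == 0  # parameters were stripped

x = torch.rand(4, 3)

def forward(*new_params):
    # Re-attach plain tensors under the original attribute names
    load_weights(mod, names, new_params)
    return mod(x).sum()

# forward is now a function of plain tensors, so autograd.functional applies:
_, grads = torch.autograd.functional.vjp(forward, params, strict=True)
print([g.shape for g in grads])  # [torch.Size([2, 3]), torch.Size([2])]
```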
benchmarks/functional_autograd_benchmark/vision_models.py (new file, +97 lines)

import torch
from torch import Tensor
import torchvision_models as models

from utils import extract_weights, load_weights, GetterReturnType

from typing import cast

def get_resnet18(device: torch.device) -> GetterReturnType:
    N = 32
    model = models.resnet18(pretrained=False)
    criterion = torch.nn.CrossEntropyLoss()
    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([N, 3, 224, 224], device=device)
    labels = torch.rand(N, device=device).mul(10).long()

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        return loss

    return forward, params

def get_fcn_resnet(device: torch.device) -> GetterReturnType:
    N = 8
    criterion = torch.nn.MSELoss()
    model = models.fcn_resnet50(pretrained=False, pretrained_backbone=False)
    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([N, 3, 480, 480], device=device)
    # Given model has 21 classes
    labels = torch.rand([N, 21, 480, 480], device=device)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)['out']

        loss = criterion(out, labels)
        return loss

    return forward, params

def get_detr(device: torch.device) -> GetterReturnType:
    # All values below are from CLI defaults in https://github.com/facebookresearch/detr
    N = 2
    num_classes = 91
    hidden_dim = 256
    nheads = 8
    num_encoder_layers = 6
    num_decoder_layers = 6

    model = models.DETR(num_classes=num_classes, hidden_dim=hidden_dim, nheads=nheads,
                        num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)
    losses = ['labels', 'boxes', 'cardinality']
    eos_coef = 0.1
    bbox_loss_coef = 5
    giou_loss_coef = 2
    weight_dict = {'loss_ce': 1, 'loss_bbox': bbox_loss_coef, 'loss_giou': giou_loss_coef}
    matcher = models.HungarianMatcher(1, 5, 2)
    criterion = models.SetCriterion(num_classes=num_classes, matcher=matcher, weight_dict=weight_dict,
                                    eos_coef=eos_coef, losses=losses)

    model = model.to(device)
    criterion = criterion.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand(N, 3, 800, 1200, device=device)
    labels = []
    for idx in range(N):
        targets = {}
        n_targets: int = int(torch.randint(5, 10, size=tuple()).item())
        label = torch.randint(5, 10, size=(n_targets,))
        targets["labels"] = label
        boxes = torch.randint(100, 800, size=(n_targets, 4))
        for t in range(n_targets):
            if boxes[t, 0] > boxes[t, 2]:
                boxes[t, 0], boxes[t, 2] = boxes[t, 2], boxes[t, 0]
            if boxes[t, 1] > boxes[t, 3]:
                boxes[t, 1], boxes[t, 3] = boxes[t, 3], boxes[t, 1]
        targets["boxes"] = boxes.float()
        labels.append(targets)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        weight_dict = criterion.weight_dict
        final_loss = cast(Tensor, sum(loss[k] * weight_dict[k] for k in loss.keys() if k in weight_dict))
        return final_loss

    return forward, params

test/run_test.py (+1 line)

@@ -87,6 +87,7 @@
     'test_determination',
     'test_futures',
     'test_fx',
+    'test_functional_autograd_benchmark'
 ]

 WINDOWS_BLOCKLIST = [
test/test_functional_autograd_benchmark.py (new file, +57 lines)

from torch.testing._internal.common_utils import TestCase, run_tests, slowTest, IS_WINDOWS

import subprocess
import tempfile
import os
import unittest

# This is a very simple smoke test for the functional autograd benchmarking script.
class TestFunctionalAutogradBenchmark(TestCase):
    def _test_runner(self, model, disable_gpu=False):
        # Note about windows:
        # The temporary file is exclusively opened by this process and the child process
        # is not allowed to open it again. As this is a simple smoke test, we choose for now
        # not to run this on windows and keep the code here simple.
        with tempfile.NamedTemporaryFile() as out_file:
            cmd = ['python', '../benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py']
            # Only run the warmup
            cmd += ['--num-iters', '0']
            # Only run the vjp task (fastest one)
            cmd += ['--task-filter', 'vjp']
            # Only run the specified model
            cmd += ['--model-filter', model]
            # Output file
            cmd += ['--output', out_file.name]
            if disable_gpu:
                cmd += ['--gpu', '-1']

            res = subprocess.run(cmd)

            self.assertTrue(res.returncode == 0)
            # Check that something was written to the file
            out_file.seek(0, os.SEEK_END)
            self.assertTrue(out_file.tell() > 0)

    @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows does not have all the features we need.")
    def test_fast_tasks(self):
        fast_tasks = ['resnet18', 'ppl_simple_reg', 'ppl_robust_reg', 'wav2letter',
                      'transformer', 'multiheadattn']

        for task in fast_tasks:
            self._test_runner(task)

    @slowTest
    @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows does not have all the features we need.")
    def test_slow_tasks(self):
        slow_tasks = ['fcn_resnet', 'detr']
        # deepspeech is voluntarily excluded as it takes too long to run without
        # proper tuning of the number of threads it should use.

        for task in slow_tasks:
            # Disable GPU for slow tests as the CI GPUs don't have enough memory
            self._test_runner(task, disable_gpu=True)


if __name__ == '__main__':
    run_tests()
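
Assuming the usual PyTorch test workflow, the smoke test can then be invoked like any other test module (one possible invocation; flags may differ across versions):

```bash
cd test
python test_functional_autograd_benchmark.py TestFunctionalAutogradBenchmark.test_fast_tasks
```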
