Commit 8ef22c3
committed Mar 25, 2021

cusolver inverse

1 parent 8a8d394

7 files changed, +621 -0 lines changed
‎linalg/cholesky-inverse/.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
!*
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
import torch
import time
import itertools
import gc
import json
import multiprocessing
import threading

from torch.testing._internal.common_utils import random_hermitian_pd_matrix

TIME_MULTIPLIER = 1e6
TIME_UNIT = 'us'

nb = 200
# nb = 1

torch.manual_seed(42)
torch.cuda.manual_seed(42)

def compare(x, y, *, rtol, atol):
    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
        if not x.is_cuda:
            x = x.cuda()
        if not y.is_cuda:
            raise RuntimeError("y tensor should be cuda, but it's not")
        return torch.testing._compare_tensors_internal(x, y, rtol=rtol, atol=atol, equal_nan=False)

    a = True
    b = {}
    for x_, y_, s_ in zip(x, y, ['U', 'S', 'V']):
        a_, b_ = compare(x_, y_, rtol=rtol, atol=atol)

        a = a and a_
        if not a_:
            b[s_] = b_

    return a, json.dumps(b, indent=2)


def main(s: str = ''):
    def prof(b_, n_, dtype=torch.float, p=None, flag=None):
        gc.collect()
        torch.cuda.empty_cache()

        if p is None:
            p = lambda x: x

        # print(b_, n_)
        # x = torch.randn(*b_, n_, n_, device='cuda', dtype=dtype)
        zo = random_hermitian_pd_matrix(n_, *b_, device='cpu', dtype=torch.float64).cuda()
        z = torch.cholesky(zo.cpu()).to(dtype=dtype, device='cuda')
        # x = torch.randn(*b_, n_, n_, device='cuda').to(dtype=dtype)
        # x = torch.randn(*b_, n_, 1, device='cuda').to(dtype=dtype)

        # xc = x.clone().cpu()
        zc = z.clone().cpu()

        # cpu timing
        t1 = time.time()
        for _ in range(nb):
            yc = p(zc)
        t2 = time.time()
        cpu_time = (t2-t1)/nb*TIME_MULTIPLIER
        # print('cpu', cpu_time, 'ms')

        if torch.isnan(yc).any() or torch.isnan(zc).any():
            print('cpu output contains nan')

        # warmup
        for _ in range(nb):
            y_warmup = p(z)
        torch.cuda.synchronize()

        # c, d = compare(xc, x, rtol=1e-7, atol=1e-7)
        # if not c:
        #     print('original matrix compare')
        #     print(d)
        #     raise RuntimeError('original value x modified')
        c1, d1 = compare(zc, z, rtol=1e-7, atol=1e-7)
        if not c1:
            print('original matrix compare')
            print(d1)
            raise RuntimeError('original value z modified')

        torch.cuda.profiler.start()
        with torch.autograd.profiler.emit_nvtx(record_shapes=True):
            y = p(z)
            torch.cuda.synchronize()
        torch.cuda.profiler.stop()

        torch.cuda.synchronize()

        # gpu timing
        t1 = time.time()
        for _ in range(nb):
            # y = torch.cholesky(x)
            y = p(z)
        torch.cuda.synchronize()
        t2 = time.time()
        gpu_time = (t2-t1)/nb*TIME_MULTIPLIER
        # print('gpu', gpu_time, 'ms')

        e, f = compare(y_warmup, y, rtol=0, atol=0)
        if not e:
            print('non-determinism: cholesky_solve value output')
            print(f)
            raise RuntimeError('non-deterministic output')

        torch.backends.cuda.matmul.allow_tf32 = False
        reconstruct = (zo @ y.double()).float()
        torch.backends.cuda.matmul.allow_tf32 = True

        a, b = compare(torch.eye(n_).expand(*b_, n_, n_), reconstruct, rtol=1e-3, atol=1e-3)
        # a, b = compare(yc, y, rtol=1e-3, atol=1e-3)
        if not a:
            print('numerical mismatch: reconstruct value compare')
            print(b)

        print(f'{b_} {n_} {dtype}'.ljust(35) + f'{cpu_time : .3f} {gpu_time : .3f}')
        # f.write(f'{b_} {n_} {dtype}; ' + f'{cpu_time : .3e}, {gpu_time : .3e}\n')
        torch.cuda.synchronize()

    print(s)
    print(torch.__version__)
    print()
    print('batch_size, matrix_size, dtype'.ljust(35) +
          f'cpu_time({TIME_UNIT}), gpu_time({TIME_UNIT})')

    for b, n in itertools.product(
            # [[]] + [[2**i] for i in range(11)],
            # [[], [1]],
            [[2**i] for i in range(1, 11)],
            [2**j for j in range(1, 12, 1)]
    ):
        if b and b[0] * n >= 2**14:
            continue
        prof(b, n, p=torch.cholesky_inverse)


if __name__ == "__main__":
    main()

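For readers skimming the diff, the correctness check in the script above boils down to the Cholesky-inverse identity: if A is Hermitian positive definite with Cholesky factor L (A = L Lᴴ), then cholesky_inverse(L) ≈ A⁻¹, so A @ cholesky_inverse(L) ≈ I. A minimal, standalone sketch of that check using the same calls as the benchmark (a CUDA-enabled build of PyTorch from this era is assumed):

import torch
from torch.testing._internal.common_utils import random_hermitian_pd_matrix

# Build a Hermitian positive-definite matrix A and its lower Cholesky factor L.
n = 64
a = random_hermitian_pd_matrix(n, device='cpu', dtype=torch.float64).cuda()
l = torch.cholesky(a.cpu()).to(dtype=torch.float32, device='cuda')

# cholesky_inverse consumes the factor L and approximates A^{-1}.
a_inv = torch.cholesky_inverse(l)

# Reconstruction test: A @ A^{-1} should be close to the identity.
# The benchmark disables TF32 for this matmul so the check is not noisy.
torch.backends.cuda.matmul.allow_tf32 = False
reconstruct = a.float() @ a_inv
print(torch.allclose(reconstruct, torch.eye(n, device='cuda'), rtol=1e-3, atol=1e-3))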
‎linalg/cholesky-inverse/parse.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
import glob
from collections import defaultdict
import json
import io
import numpy as np

BEFORE = 'before-commit'
AFTER = 'after-commit'

SORT_KEY = {
    "cpu": -1,
    "before_magma": 0,
    "after_potrs_based": 2,
    "after_heuristics": 3,
}

class Markdown:
    def __init__(self):
        self.buffer = io.BufferedRandom(io.BytesIO())
        self.enc = 'utf-8'

    def write(self, s: str):
        self.buffer.write(s.encode(self.enc))

    def read(self) -> bytes:
        self.buffer.seek(0)
        return self.buffer.read()

def main():
    profs = glob.glob('./prof*.txt')
    # profs = glob.glob('./prox*.txt')

    dt_gpu = defaultdict(dict)
    dt_cpu = defaultdict(dict)
    columns = ["cpu"]

    for prof in profs:
        impl_key = prof[7:-4]
        columns.append(impl_key)

        with open(prof, 'r') as f:
            fl = f.readlines()

        al = [line.rstrip().split(' ') for line in fl if line.startswith('[')]

        for line in al:
            shape = line[0]
            t_cpu, t_gpu = (float(x) for x in line[-2:])

            dt_gpu[shape][impl_key] = t_gpu
            dt_cpu[shape][impl_key] = t_cpu

    columns.sort(key=SORT_KEY.__getitem__)

    print(json.dumps(dt_gpu, indent=2))
    # print(dt_cpu)

    md = Markdown()
    md.write('time is in **us** (10^-6 s)\n\n')
    md.write('|shape|' + '|'.join(columns) + '|\n')
    md.write('|---:' * (len(columns)+1) + '|\n')

    for shape in dt_gpu.keys():
        t_cpu_avg = np.mean([x for x in dt_cpu[shape].values()])
        md.write(f'| {shape} | {t_cpu_avg : .3f} |')

        for column in columns[1:]:
            md.write(f' {dt_gpu[shape].get(column, -1) : .3f} |')

        md.write('\n')


    with open('readme.md', 'wb') as f:
    # with open('readme1.md', 'wb') as f:
        f.write(md.read())


if __name__ == "__main__":
    main()
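parse.py assumes each data row in a ./prof*.txt file starts with the bracketed batch shape and ends with the CPU and GPU timings, exactly as printed by the benchmark above; the column name is derived from the file name via prof[7:-4]. The profile files themselves are not part of this commit, so the snippet below is only an illustrative sketch of the per-line parse, using one row from the results file further down; if the shape column is padded with repeated spaces, the empty tokens produced by split(' ') have to be filtered out, which the filter here adds:

# One data row as it appears in the results file below (single-space separated here).
row = '[2] 64 torch.float32 69.199 202.242'

fields = [tok for tok in row.rstrip().split(' ') if tok]  # drop empty tokens from padding
shape = fields[0]                                         # '[2]': the batch shape
t_cpu, t_gpu = (float(x) for x in fields[-2:])            # last two columns are timings in us

print(shape, t_cpu, t_gpu)  # [2] 69.199 202.242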
Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@

1.9.0a0+gitf83bb72

batch_size, matrix_size, dtype     cpu_time(us), gpu_time(us)
[] 2 torch.float32 10.827 68.474
[] 4 torch.float32 10.290 68.400
[] 8 torch.float32 11.035 68.010
[] 16 torch.float32 13.291 60.335
[] 32 torch.float32 18.488 59.669
[] 64 torch.float32 40.630 59.838
[] 128 torch.float32 544.492 81.584
[] 256 torch.float32 545.278 183.367
[] 512 torch.float32 1765.593 467.857
[] 1024 torch.float32 7690.465 1377.743
[] 2048 torch.float32 50593.435 5687.451
[1] 2 torch.float32 10.318 66.886
[1] 4 torch.float32 10.408 66.848
[1] 8 torch.float32 11.165 67.350
[1] 16 torch.float32 13.280 59.938
[1] 32 torch.float32 18.553 58.808
[1] 64 torch.float32 35.299 59.521
[1] 128 torch.float32 587.937 107.318
[1] 256 torch.float32 665.547 155.816
[1] 512 torch.float32 2052.648 420.599
[1] 1024 torch.float32 7829.883 1377.732
[1] 2048 torch.float32 47425.777 5707.585
[2] 2 torch.float32 12.114 67.579
[2] 4 torch.float32 11.870 68.905
[2] 8 torch.float32 13.413 69.492
[2] 16 torch.float32 17.011 70.037
[2] 32 torch.float32 26.886 134.109
[2] 64 torch.float32 69.199 202.242
[2] 128 torch.float32 1211.519 391.432
[2] 256 torch.float32 1169.972 611.967
[2] 512 torch.float32 4103.551 1224.111
[2] 1024 torch.float32 15942.370 3580.601
[2] 2048 torch.float32 101954.410 12943.521
[4] 2 torch.float32 11.822 66.485
[4] 4 torch.float32 13.098 67.027
[4] 8 torch.float32 15.808 66.571
[4] 16 torch.float32 22.434 67.693
[4] 32 torch.float32 40.867 127.492
[4] 64 torch.float32 117.341 190.632
[4] 128 torch.float32 1923.342 343.292
[4] 256 torch.float32 2123.984 675.958
[4] 512 torch.float32 6900.259 1573.567
[4] 1024 torch.float32 33191.272 4883.418
[4] 2048 torch.float32 204610.639 21273.334
[8] 2 torch.float32 17.005 89.726
[8] 4 torch.float32 19.516 89.639
[8] 8 torch.float32 24.843 92.627
[8] 16 torch.float32 38.671 163.162
[8] 32 torch.float32 125.901 155.975
[8] 64 torch.float32 231.619 226.618
[8] 128 torch.float32 3673.230 393.283
[8] 256 torch.float32 3719.603 825.180
[8] 512 torch.float32 13651.674 2176.757
[8] 1024 torch.float32 70810.848 7957.207
[16] 2 torch.float32 17.269 90.749
[16] 4 torch.float32 21.216 90.303
[16] 8 torch.float32 30.577 90.498
[16] 16 torch.float32 54.492 90.675
[16] 32 torch.float32 127.280 161.965
[16] 64 torch.float32 496.632 241.767
[16] 128 torch.float32 7933.570 392.665
[16] 256 torch.float32 8850.931 892.092
[16] 512 torch.float32 29227.349 3448.063
[32] 2 torch.float32 17.402 69.261
[32] 4 torch.float32 24.254 69.902
[32] 8 torch.float32 42.560 70.021
[32] 16 torch.float32 89.972 70.196
[32] 32 torch.float32 241.292 140.296
[32] 64 torch.float32 923.823 228.437
[32] 128 torch.float32 14301.009 546.384
[32] 256 torch.float32 17318.430 1514.449
[64] 2 torch.float32 32.297 91.580
[64] 4 torch.float32 50.187 91.859
[64] 8 torch.float32 86.789 91.848
[64] 16 torch.float32 179.645 92.436
[64] 32 torch.float32 542.165 167.738
[64] 64 torch.float32 2208.329 276.698
[64] 128 torch.float32 30164.434 648.333
[128] 2 torch.float32 47.898 71.030
[128] 4 torch.float32 77.578 70.615
[128] 8 torch.float32 147.750 71.871
[128] 16 torch.float32 335.081 72.927
[128] 32 torch.float32 1780.635 150.491
[128] 64 torch.float32 4148.735 279.559
[256] 2 torch.float32 80.156 71.524
[256] 4 torch.float32 138.524 71.729
[256] 8 torch.float32 282.979 76.048
[256] 16 torch.float32 740.534 78.084
[256] 32 torch.float32 2066.872 178.185
[512] 2 torch.float32 152.835 76.374
[512] 4 torch.float32 267.853 77.028
[512] 8 torch.float32 560.991 78.712
[512] 16 torch.float32 1469.697 85.351
[1024] 2 torch.float32 298.941 88.992
[1024] 4 torch.float32 528.816 89.751
[1024] 8 torch.float32 1196.932 92.831
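To make the crossover in this table concrete: for small matrices the CUDA path sits at a roughly constant 60-90 us, likely dominated by fixed launch and synchronization overhead, so the CPU is faster, while for large matrices the CUDA path wins by a wide margin. A quick sketch of the CPU/GPU ratio for two rows above:

# Times taken from the '[]' rows of the table above, in microseconds.
rows = {
    ('[]', 64):   (40.630, 59.838),       # small matrix: fixed GPU overhead dominates
    ('[]', 2048): (50593.435, 5687.451),  # large matrix: GPU clearly faster
}
for (shape, n), (t_cpu, t_gpu) in rows.items():
    print(f'{shape} {n}: cpu/gpu = {t_cpu / t_gpu:.2f}x')  # 0.68x and 8.90x respectively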
