Commit 4ea09eb

cholesky solve
1 parent fb3b7e8 commit 4ea09eb

13 files changed: +1023 -0 lines changed

linalg/cholesky-solve/.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
!*

(The !* pattern re-includes every file, overriding ignore rules inherited from parent directories.)

linalg/cholesky-solve/_prof_is_square_matrix.txt

Whitespace-only changes.

linalg/cholesky-solve/_prox_is_nrhs_eq_1.txt

Whitespace-only changes.

linalg/cholesky-solve/linalg-prof.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
import torch
import time
import itertools
import gc
import json

from torch.testing._internal.common_utils import random_hermitian_pd_matrix

TIME_MULTIPLIER = 1e6
TIME_UNIT = 'us'

nb = 200  # timed iterations per configuration
# nb = 1

torch.manual_seed(42)
torch.cuda.manual_seed(42)

def compare(x, y, *, rtol, atol):
    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
        if not x.is_cuda:
            x = x.cuda()
        if not y.is_cuda:
            raise RuntimeError("y tensor should be cuda, but it's not")
        return torch.testing._compare_tensors_internal(x, y, rtol=rtol, atol=atol, equal_nan=False)

    # tuple outputs: compare element-wise, collecting mismatch reports per element
    a = True
    b = {}
    for x_, y_, s_ in zip(x, y, ['U', 'S', 'V']):
        a_, b_ = compare(x_, y_, rtol=rtol, atol=atol)

        a = a and a_
        if not a_:
            b[s_] = b_

    return a, json.dumps(b, indent=2)


def main(s: str = ''):
    def prof(b_, n_, dtype=torch.float, p=None, flag=None):
        gc.collect()
        torch.cuda.empty_cache()

        if p is None:
            p = lambda x, z: x  # no-op fallback; p is always called with two arguments

        # print(b_, n_)
        # x = torch.randn(*b_, n_, n_, device='cuda', dtype=dtype)
        zo = random_hermitian_pd_matrix(n_, *b_, device='cuda', dtype=torch.float64)
        z = torch.cholesky(zo).to(dtype=dtype)  # lower-triangular Cholesky factor of zo
        x = torch.randn(*b_, n_, n_, device='cuda').to(dtype=dtype)  # right-hand sides
        # x = torch.randn(*b_, n_, 1, device='cuda').to(dtype=dtype)

        xc = x.clone().cpu()
        zc = z.clone().cpu()

        # cpu timing
        t1 = time.time()
        for _ in range(nb):
            yc = p(xc, zc)
        t2 = time.time()
        cpu_time = (t2 - t1) / nb * TIME_MULTIPLIER
        # print('cpu', cpu_time, TIME_UNIT)

        if torch.isnan(yc).any() or torch.isnan(zc).any():
            print('cpu output contains nan')

        # warmup
        for _ in range(nb):
            y_warmup = p(x, z)
        torch.cuda.synchronize()

        # check that the inputs were not modified in place
        c, d = compare(xc, x, rtol=1e-7, atol=1e-7)
        if not c:
            print('original matrix compare')
            print(d)
            raise RuntimeError('original value x modified')
        c1, d1 = compare(zc, z, rtol=1e-7, atol=1e-7)
        if not c1:
            print('original matrix compare')
            print(d1)
            raise RuntimeError('original value z modified')

        # one profiled call, annotated with NVTX ranges for an external profiler
        torch.cuda.profiler.start()
        with torch.autograd.profiler.emit_nvtx(record_shapes=True):
            y = p(x, z)
            torch.cuda.synchronize()
        torch.cuda.profiler.stop()

        torch.cuda.synchronize()

        # gpu timing
        t1 = time.time()
        for _ in range(nb):
            # y = torch.cholesky(x)
            y = p(x, z)
        torch.cuda.synchronize()
        t2 = time.time()
        gpu_time = (t2 - t1) / nb * TIME_MULTIPLIER
        # print('gpu', gpu_time, TIME_UNIT)

        # bitwise determinism check: repeated calls must produce identical output
        e, f = compare(y_warmup, y, rtol=0, atol=0)
        if not e:
            print('non-determinism: cholesky_solve value output')
            print(f)
            raise RuntimeError('non-deterministic output')

        # reconstruct the right-hand side: zo @ y should recover x
        torch.backends.cuda.matmul.allow_tf32 = False
        reconstruct = (zo @ y.double()).float()
        torch.backends.cuda.matmul.allow_tf32 = True

        a, b = compare(x, reconstruct, rtol=1e-3, atol=1e-3)
        # a, b = compare(yc, y, rtol=1e-3, atol=1e-3)
        if not a:
            print('numerical mismatch: reconstruct value compare')
            print(b)

        print(f'{b_} {n_} {dtype}'.ljust(35) + f'{cpu_time : .3f} {gpu_time : .3f}')
        # f.write(f'{b_} {n_} {dtype}; ' + f'{cpu_time : .3e}, {gpu_time : .3e}\n')
        torch.cuda.synchronize()

    print(s)
    print(torch.__version__)
    print()
    print('batch_size, matrix_size, dtype'.ljust(35) +
          f'cpu_time({TIME_UNIT}), gpu_time({TIME_UNIT})')

    # sweep batch sizes ([], [1], [2], ..., [1024]) and matrix sizes (2 .. 2048),
    # skipping configurations where batch * matrix size gets too large
    for b, n in itertools.product(
        [[]] + [[2**i] for i in range(11)],
        [2**j for j in range(1, 12, 1)]
    ):
        if b and b[0] * n >= 2**14:
            continue
        prof(b, n, p=torch.cholesky_solve)

if __name__ == "__main__":
    main()

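For context, a minimal sketch (not part of this commit) of the operation linalg-prof.py benchmarks: torch.cholesky_solve(b, u) solves A x = b given a precomputed Cholesky factor u of A, which is what prof invokes through p(x, z).

import torch

# Hypothetical standalone example of the benchmarked call (not from the commit).
A = torch.randn(4, 4, dtype=torch.float64)
A = A @ A.T + 4 * torch.eye(4, dtype=torch.float64)  # make A symmetric positive definite
b = torch.randn(4, 2, dtype=torch.float64)           # two right-hand sides

u = torch.cholesky(A)            # lower-triangular factor, A = u @ u.T
x = torch.cholesky_solve(b, u)   # solves A @ x = b without refactorizing A

print(torch.allclose(A @ x, b))  # True up to rounding, mirroring prof's reconstruct check

The torch.cuda.profiler.start()/stop() and emit_nvtx calls in the script are meant to be run under an external profiler with capture initially disabled (e.g. nvprof --profile-from-start off), so that only the annotated call is recorded.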
linalg/cholesky-solve/parse.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
import glob
from collections import defaultdict
import json
import io
import numpy as np

BEFORE = 'before-commit'
AFTER = 'after-commit'

# fixed column order for the markdown table
SORT_KEY = {
    "cpu": -1,
    "before_magma": 0,
    "after_potrs_64bit": 2,
    "after_heuristics": 3
}

class Markdown:
    """Accumulates markdown text in an in-memory binary buffer."""

    def __init__(self):
        self.buffer = io.BufferedRandom(io.BytesIO())
        self.enc = 'utf-8'

    def write(self, s: str):
        self.buffer.write(s.encode(self.enc))

    def read(self) -> bytes:
        self.buffer.seek(0)
        return self.buffer.read()

def main():
    profs = glob.glob('./prof*.txt')
    # profs = glob.glob('./prox*.txt')

    dt_gpu = defaultdict(dict)
    dt_cpu = defaultdict(dict)
    columns = ["cpu"]

    for prof in profs:
        impl_key = prof[7:-4]  # './prof-<impl>.txt' -> '<impl>'
        columns.append(impl_key)

        with open(prof, 'r') as f:
            fl = f.readlines()

        # result rows start with the batch shape, e.g. '[2] 64 torch.float32 ... cpu gpu'
        al = [line.rstrip().split() for line in fl if line.startswith('[')]

        for line in al:
            shape = ' '.join(line[:2])  # batch shape and matrix size, e.g. '[2] 64'
            t_cpu, t_gpu = (float(x) for x in line[-2:])

            dt_gpu[shape][impl_key] = t_gpu
            dt_cpu[shape][impl_key] = t_cpu

    columns.sort(key=SORT_KEY.__getitem__)

    print(json.dumps(dt_gpu, indent=2))
    # print(dt_cpu)

    md = Markdown()
    md.write('time is in **us** (10^-6 s)\n\n')
    md.write('|shape|' + '|'.join(columns) + '|\n')
    md.write('|---:' * (len(columns) + 1) + '|\n')

    for shape in dt_gpu.keys():
        # cpu time is averaged over all runs; gpu times are reported per implementation
        t_cpu_avg = np.mean([x for x in dt_cpu[shape].values()])
        md.write(f'| {shape} | {t_cpu_avg : .3f} |')

        for column in columns[1:]:
            md.write(f' {dt_gpu[shape].get(column, -1) : .3f} |')

        md.write('\n')


    with open('readme.md', 'wb') as f:
    # with open('readme1.md', 'wb') as f:
        f.write(md.read())


if __name__ == "__main__":
    main()
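A note on the file-name assumption baked into prof[7:-4]: it expects result files named like ./prof-<impl>.txt, where <impl> is a key of SORT_KEY. The name below is hypothetical (this commit's .txt files are whitespace-only, so the real names are not shown), but the slice arithmetic is just:

# hypothetical file name, illustrating the prof[7:-4] slice used in parse.py;
# './prof-' is 7 characters and '.txt' is the last 4
prof = './prof-before_magma.txt'
print(prof[7:-4])  # prints 'before_magma', one of the SORT_KEY columns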
Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
1.9.0a0+git2b5c5c4

batch_size, matrix_size, dtype     cpu_time(us), gpu_time(us)
[] 2 torch.float32 105.547 75.507
[] 4 torch.float32 9.592 75.125
[] 8 torch.float32 10.310 75.818
[] 16 torch.float32 10.427 68.911
[] 32 torch.float32 13.537 77.344
[] 64 torch.float32 60.569 86.546
[] 128 torch.float32 99.032 119.070
[] 256 torch.float32 280.218 201.018
[] 512 torch.float32 1089.866 490.519
[] 1024 torch.float32 6125.575 1335.486
[] 2048 torch.float32 42986.248 5497.439
[1] 2 torch.float32 9.669 73.801
[1] 4 torch.float32 9.311 73.138
[1] 8 torch.float32 10.223 73.413
[1] 16 torch.float32 10.821 67.235
[1] 32 torch.float32 13.647 69.747
[1] 64 torch.float32 56.102 83.778
[1] 128 torch.float32 164.089 109.557
[1] 256 torch.float32 300.865 185.843
[1] 512 torch.float32 835.133 427.641
[1] 1024 torch.float32 4356.145 1345.123
[1] 2048 torch.float32 26658.406 5495.042
[2] 2 torch.float32 10.254 48.923
[2] 4 torch.float32 10.238 48.424
[2] 8 torch.float32 10.865 49.670
[2] 16 torch.float32 12.029 49.565
[2] 32 torch.float32 18.553 335.974
[2] 64 torch.float32 83.658 405.704
[2] 128 torch.float32 170.118 529.372
[2] 256 torch.float32 365.396 830.517
[2] 512 torch.float32 1402.911 1562.380
[2] 1024 torch.float32 8500.582 3699.644
numerical mismatch: reconstruct value compare
With rtol=0.001 and atol=0.001, found 1 element(s) (out of 8388608) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.0010547935962677002 (0.04455813765525818 vs. 0.04350334405899048), which occurred at index (1, 452, 1011).
[2] 2048 torch.float32 60374.918 13091.600
[4] 2 torch.float32 11.771 51.131
[4] 4 torch.float32 12.223 49.632
[4] 8 torch.float32 12.293 51.562
[4] 16 torch.float32 15.020 50.697
[4] 32 torch.float32 26.133 335.603
[4] 64 torch.float32 154.521 459.424
[4] 128 torch.float32 269.843 556.146
[4] 256 torch.float32 571.958 888.574
[4] 512 torch.float32 2527.016 1773.859
[4] 1024 torch.float32 17031.137 4997.247
[4] 2048 torch.float32 119452.786 21604.799
[8] 2 torch.float32 17.611 66.310
[8] 4 torch.float32 19.614 65.430
[8] 8 torch.float32 18.976 66.751
[8] 16 torch.float32 24.600 66.377
[8] 32 torch.float32 49.813 368.210
[8] 64 torch.float32 296.102 518.253
[8] 128 torch.float32 415.326 607.669
[8] 256 torch.float32 1095.607 1049.521
[8] 512 torch.float32 5024.378 2348.893
[8] 1024 torch.float32 42197.851 7945.452
[16] 2 torch.float32 23.073 66.698
[16] 4 torch.float32 24.247 66.334
[16] 8 torch.float32 25.295 66.991
[16] 16 torch.float32 36.662 66.900
[16] 32 torch.float32 86.474 375.259
[16] 64 torch.float32 520.860 456.016
[16] 128 torch.float32 715.033 654.156
[16] 256 torch.float32 2046.187 1219.178
[16] 512 torch.float32 10900.669 3345.146
[32] 2 torch.float32 31.379 66.758
[32] 4 torch.float32 37.876 66.538
[32] 8 torch.float32 39.243 67.152
[32] 16 torch.float32 59.557 67.266
[32] 32 torch.float32 157.140 383.520
[32] 64 torch.float32 955.098 512.199
[32] 128 torch.float32 1370.115 723.370
[32] 256 torch.float32 4047.383 1559.268
[64] 2 torch.float32 49.703 67.573
[64] 4 torch.float32 59.655 67.368
[64] 8 torch.float32 63.415 67.888
[64] 16 torch.float32 104.959 68.390
[64] 32 torch.float32 294.157 381.888
[64] 64 torch.float32 1776.475 486.399
[64] 128 torch.float32 2635.866 829.155
[128] 2 torch.float32 85.740 68.507
[128] 4 torch.float32 105.935 67.955
[128] 8 torch.float32 132.358 69.039
[128] 16 torch.float32 194.751 69.127
[128] 32 torch.float32 530.604 386.889
[128] 64 torch.float32 3484.117 522.555
[256] 2 torch.float32 159.428 68.678
[256] 4 torch.float32 199.956 68.533
[256] 8 torch.float32 207.843 69.817
[256] 16 torch.float32 370.517 73.783
[256] 32 torch.float32 998.839 415.101
[512] 2 torch.float32 312.570 72.967
[512] 4 torch.float32 386.612 73.049
[512] 8 torch.float32 401.845 75.147
[512] 16 torch.float32 663.637 79.657
[1024] 2 torch.float32 599.290 85.372
[1024] 4 torch.float32 766.145 84.642
[1024] 8 torch.float32 797.913 88.762
