Skip to content

Commit fcb2edf

Browse files
FantasyVRpre-commit-ci[bot]
authored andcommitted
[Lang] Support LU sparse solver on CUDA backend (taichi-dev#6967)
Issue: taichi-dev#2906 ### Brief Summary To be consistent with API on CPU backend, this pr provides LU sparse solver on CUDA backend. CuSolver just provides a CPU version API of LU sparse solver which is used in this PR. The cuSolverRF provides a GPU version LU solve, but it only supports `double` datatype. Thus, it's not used in this PR. Besides, the `print_triplets` is refactored to resolve the ndarray `read` constraints (the `read` and `write` data should be the same datatype). Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e2da2ac commit fcb2edf

10 files changed

+352
-285
lines changed

python/taichi/linalg/sparse_matrix.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,11 @@ def _get_ndarray_addr(self):
243243

244244
def print_triplets(self):
245245
"""Print the triplets stored in the builder"""
246-
self.ptr.print_triplets()
246+
taichi_arch = get_runtime().prog.config().arch
247+
if taichi_arch == _ti_core.Arch.x64 or taichi_arch == _ti_core.Arch.arm64:
248+
self.ptr.print_triplets_eigen()
249+
elif taichi_arch == _ti_core.Arch.cuda:
250+
self.ptr.print_triplets_cuda()
247251

248252
def build(self, dtype=f32, _format='CSR'):
249253
"""Create a sparse matrix using the triplets"""

python/taichi/linalg/sparse_solver.py

+2-21
Original file line numberDiff line numberDiff line change
@@ -100,32 +100,13 @@ def solve(self, b): # pylint: disable=R1710
100100
return self.solver.solve(b)
101101
if isinstance(b, Ndarray):
102102
x = ScalarNdarray(b.dtype, [self.matrix.m])
103-
self.solve_rf(self.matrix, b, x)
103+
self.solver.solve_rf(get_runtime().prog, self.matrix.matrix, b.arr,
104+
x.arr)
104105
return x
105106
raise TaichiRuntimeError(
106107
f"The parameter type: {type(b)} is not supported in linear solvers for now."
107108
)
108109

109-
def solve_cu(self, sparse_matrix, b):
110-
if isinstance(sparse_matrix, SparseMatrix) and isinstance(b, Ndarray):
111-
x = ScalarNdarray(b.dtype, [sparse_matrix.m])
112-
self.solver.solve_cu(get_runtime().prog, sparse_matrix.matrix,
113-
b.arr, x.arr)
114-
return x
115-
raise TaichiRuntimeError(
116-
f"The parameter type: {type(sparse_matrix)}, {type(b)} and {type(x)} is not supported in linear solvers for now."
117-
)
118-
119-
def solve_rf(self, sparse_matrix, b, x):
120-
if isinstance(sparse_matrix, SparseMatrix) and isinstance(
121-
b, Ndarray) and isinstance(x, Ndarray):
122-
self.solver.solve_rf(get_runtime().prog, sparse_matrix.matrix,
123-
b.arr, x.arr)
124-
else:
125-
raise TaichiRuntimeError(
126-
f"The parameter type: {type(sparse_matrix)}, {type(b)} and {type(x)} is not supported in linear solvers for now."
127-
)
128-
129110
def info(self):
130111
"""Check if the linear systems are solved successfully.
131112

taichi/program/sparse_matrix.cpp

+42-6
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,53 @@ SparseMatrixBuilder::SparseMatrixBuilder(int rows,
9999
prog_, dtype_, std::vector<int>{3 * (int)max_num_triplets_ + 1});
100100
}
101101

102-
void SparseMatrixBuilder::print_triplets() {
103-
num_triplets_ = ndarray_data_base_ptr_->read_int(std::vector<int>{0});
102+
template <typename T, typename G>
103+
void SparseMatrixBuilder::print_triplets_template() {
104+
auto ptr = get_ndarray_data_ptr();
105+
G *data = reinterpret_cast<G *>(ptr);
106+
num_triplets_ = data[0];
104107
fmt::print("n={}, m={}, num_triplets={} (max={})\n", rows_, cols_,
105108
num_triplets_, max_num_triplets_);
109+
data += 1;
106110
for (int i = 0; i < num_triplets_; i++) {
107-
auto idx = 3 * i + 1;
108-
auto row = ndarray_data_base_ptr_->read_int(std::vector<int>{idx});
109-
auto col = ndarray_data_base_ptr_->read_int(std::vector<int>{idx + 1});
110-
auto val = ndarray_data_base_ptr_->read_float(std::vector<int>{idx + 2});
111+
fmt::print("[{}, {}] = {}\n", data[i * 3], data[i * 3 + 1],
112+
taichi_union_cast<T>(data[i * 3 + 2]));
113+
}
114+
}
115+
116+
void SparseMatrixBuilder::print_triplets_eigen() {
117+
auto element_size = data_type_size(dtype_);
118+
switch (element_size) {
119+
case 4:
120+
print_triplets_template<float32, int32>();
121+
break;
122+
case 8:
123+
print_triplets_template<float64, int64>();
124+
break;
125+
default:
126+
TI_ERROR("Unsupported sparse matrix data type!");
127+
break;
128+
}
129+
}
130+
131+
void SparseMatrixBuilder::print_triplets_cuda() {
132+
#ifdef TI_WITH_CUDA
133+
CUDADriver::get_instance().memcpy_device_to_host(
134+
&num_triplets_, (void *)get_ndarray_data_ptr(), sizeof(int));
135+
fmt::print("n={}, m={}, num_triplets={} (max={})\n", rows_, cols_,
136+
num_triplets_, max_num_triplets_);
137+
auto len = 3 * num_triplets_ + 1;
138+
std::vector<float32> trips(len);
139+
CUDADriver::get_instance().memcpy_device_to_host(
140+
(void *)trips.data(), (void *)get_ndarray_data_ptr(),
141+
len * sizeof(float32));
142+
for (auto i = 0; i < num_triplets_; i++) {
143+
int row = taichi_union_cast<int>(trips[3 * i + 1]);
144+
int col = taichi_union_cast<int>(trips[3 * i + 2]);
145+
auto val = trips[i * 3 + 3];
111146
fmt::print("[{}, {}] = {}\n", row, col, val);
112147
}
148+
#endif
113149
}
114150

115151
intptr_t SparseMatrixBuilder::get_ndarray_data_ptr() const {

taichi/program/sparse_matrix.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ class SparseMatrixBuilder {
2222
const std::string &storage_format,
2323
Program *prog);
2424

25-
void print_triplets();
25+
void print_triplets_eigen();
26+
void print_triplets_cuda();
2627

2728
intptr_t get_ndarray_data_ptr() const;
2829

@@ -36,6 +37,9 @@ class SparseMatrixBuilder {
3637
template <typename T, typename G>
3738
void build_template(std::unique_ptr<SparseMatrix> &);
3839

40+
template <typename T, typename G>
41+
void print_triplets_template();
42+
3943
private:
4044
uint64 num_triplets_{0};
4145
std::unique_ptr<Ndarray> ndarray_data_base_ptr_{nullptr};

0 commit comments

Comments
 (0)