Skip to content

Commit a081497

Browse files
committed
add solvers
1 parent 15d58e4 commit a081497

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

lib/mps/MPS.jl

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ include("linalg.jl")
2727
# decompositions
2828
include("decomposition.jl")
2929

30+
include("solve.jl")
31+
3032
# matrix copy
3133
include("copy.jl")
3234

lib/mps/solve.jl

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
2+
export MPSMatrixSolveTriangular
3+
4+
@objcwrapper immutable=false MPSMatrixSolveTriangular <: MPSMatrixUnaryKernel
5+
6+
function MPSMatrixSolveTriangular(device, right, upper, unit, order, numberOfRightHandSides, alpha)
7+
kernel = @objc [MPSMatrixSolveTriangular alloc]::id{MPSMatrixSolveTriangular}
8+
obj = MPSMatrixSolveTriangular(kernel)
9+
finalizer(release, obj)
10+
@objc [obj::id{MPSMatrixSolveTriangular} initWithDevice:device::id{MTLDevice}
11+
right:right::Bool
12+
upper:upper::Bool
13+
transpose:transpose::Bool
14+
unit:unit::Bool
15+
order:order::NSUInteger
16+
numberOfRightHandSides:numberOfRightHandSides::NSUInteger
17+
alpha:alpha::Float64]::id{MPSMatrixSolveTriangular}
18+
return obj
19+
end
20+
21+
function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveTriangular, sourceMatrix, resultMatrix, pivotIndices, status)
22+
@objc [kernel::id{MPSMatrixSolveTriangular} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
23+
sourceMatrix:sourceMatrix::id{MPSMatrix}
24+
resultMatrix:resultMatrix::id{MPSMatrix}
25+
pivotIndices:pivotIndices::id{MPSMatrix}
26+
status:status::id{MPSMatrix}]::Nothing
27+
end
28+
29+
30+
export MPSMatrixSolveLU
31+
32+
@objcwrapper immutable=false MPSMatrixSolveLU <: MPSMatrixUnaryKernel
33+
34+
function MPSMatrixSolveLU(device, transpose, order, numberOfRightHandSides)
35+
kernel = @objc [MPSMatrixSolveLU alloc]::id{MPSMatrixSolveLU}
36+
obj = MPSMatrixSolveLU(kernel)
37+
finalizer(release, obj)
38+
@objc [obj::id{MPSMatrixSolveLU} initWithDevice:device::id{MTLDevice}
39+
transpose:transpose::Bool
40+
order:order::NSUInteger
41+
numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveLU}
42+
return obj
43+
end
44+
45+
function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveLU, sourceMatrix, rightHandSideMatrix, pivotIndices, solutionMatrix)
46+
@objc [kernel::id{MPSMatrixSolveLU} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
47+
sourceMatrix:sourceMatrix::id{MPSMatrix}
48+
rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix}
49+
pivotIndices:pivotIndices::id{MPSMatrix}
50+
solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing
51+
end
52+
53+
54+
55+
56+
export MPSMatrixSolveCholesky
57+
58+
@objcwrapper immutable=false MPSMatrixSolveCholesky <: MPSMatrixUnaryKernel
59+
60+
function MPSMatrixSolveCholesky(device, upper, order, numberOfRightHandSides)
61+
kernel = @objc [MPSMatrixSolveCholesky alloc]::id{MPSMatrixSolveCholesky}
62+
obj = MPSMatrixSolveCholesky(kernel)
63+
finalizer(release, obj)
64+
@objc [obj::id{MPSMatrixSolveCholesky} initWithDevice:device::id{MTLDevice}
65+
upper:upper::Bool
66+
order:order::NSUInteger
67+
numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveCholesky}
68+
return obj
69+
end
70+
71+
function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveCholesky, sourceMatrix, rightHandSideMatrix, solutionMatrix)
72+
@objc [kernel::id{MPSMatrixSolveCholesky} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
73+
sourceMatrix:sourceMatrix::id{MPSMatrix}
74+
rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix}
75+
pivotIndices:pivotIndices::id{MPSMatrix}
76+
solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing
77+
end

0 commit comments

Comments
 (0)