diff --git a/lib/mps/MPS.jl b/lib/mps/MPS.jl index e357c43c..667683aa 100644 --- a/lib/mps/MPS.jl +++ b/lib/mps/MPS.jl @@ -39,10 +39,10 @@ include("vector.jl") include("matrixrandom.jl") include("ndarray.jl") include("decomposition.jl") +include("solve.jl") include("copy.jl") # integrations include("random.jl") include("linalg.jl") - end diff --git a/lib/mps/libmps.jl b/lib/mps/libmps.jl index fcb2b404..ccb2be5c 100644 --- a/lib/mps/libmps.jl +++ b/lib/mps/libmps.jl @@ -535,11 +535,11 @@ end @objcwrapper immutable = false MPSMatrixVectorMultiplication <: MPSMatrixBinaryKernel -@objcwrapper immutable = true MPSMatrixSolveTriangular <: MPSMatrixBinaryKernel +@objcwrapper immutable = false MPSMatrixSolveTriangular <: MPSMatrixBinaryKernel -@objcwrapper immutable = true MPSMatrixSolveLU <: MPSMatrixBinaryKernel +@objcwrapper immutable = false MPSMatrixSolveLU <: MPSMatrixBinaryKernel -@objcwrapper immutable = true MPSMatrixSolveCholesky <: MPSMatrixBinaryKernel +@objcwrapper immutable = false MPSMatrixSolveCholesky <: MPSMatrixBinaryKernel @cenum MPSMatrixDecompositionStatus::Int32 begin MPSMatrixDecompositionStatusSuccess = 0 diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl index dd70f5f2..e24c7c38 100644 --- a/lib/mps/linalg.jl +++ b/lib/mps/linalg.jl @@ -231,5 +231,147 @@ end commit!(cmdbuf) + wait_completed(cmdbuf) + + return B +end + + +function LinearAlgebra.:(\)(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + C = deepcopy(B) + LinearAlgebra.ldiv!(A, C) + return C +end + + +function LinearAlgebra.ldiv!(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = device() + queue = global_queue(dev) + + At = similar(A.factors) + Bt = similar(B, (N, M)) + P = reshape((A.ipiv .- UInt32(1)), (1, M)) + X = similar(B, (N, M)) + + transpose!(At, A.factors) + transpose!(Bt, B) + + mps_a = MPSMatrix(At) + mps_b = MPSMatrix(Bt) + mps_p = MPSMatrix(P) + mps_x = MPSMatrix(X) + + MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveLU(dev, false, M, N) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x) + end + + transpose!(B, X) + return B +end + + +function LinearAlgebra.ldiv!(A::UpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = device() + queue = global_queue(dev) + + Ad = MtlMatrix(A') + Br = similar(B, (M, M)) + X = similar(Br) + + transpose!(Br, B) + + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Br) + mps_x = MPSMatrix(X) + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, N, M, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end + + wait_completed(buf) + + copy!(B, X) + return B +end + + +function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = device() + queue = global_queue(dev) + + Ad = MtlMatrix(A) + Br = reshape(B, (M, N)) + X = similar(Br) + + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Br) + mps_x = MPSMatrix(X) + + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, true, false, false, true, M, N, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end + + wait_completed(buf) + + copy!(Br, X) + return B +end + + +function LinearAlgebra.ldiv!(A::LowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = device() + queue = global_queue(dev) + + Ad = MtlMatrix(A) + Br = reshape(B, (M, N)) + X = similar(Br) + + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Br) + mps_x = MPSMatrix(X) + + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, true, true, false, false, M, N, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end + + wait_completed(buf) + + copy!(Br, X) return B end + + +function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M, N = size(B, 1), size(B, 2) + dev = device() + queue = global_queue(dev) + + Ad = MtlMatrix(A) + Br = reshape(B, (M, N)) + X = similar(Br) + + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Br) + mps_x = MPSMatrix(X) + + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, true, true, false, true, M, N, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end + + wait_completed(buf) + + copy!(Br, X) + return B +end \ No newline at end of file diff --git a/lib/mps/solve.jl b/lib/mps/solve.jl new file mode 100644 index 00000000..080fc19d --- /dev/null +++ b/lib/mps/solve.jl @@ -0,0 +1,77 @@ + +export MPSMatrixSolveTriangular + +# @objcwrapper immutable=false MPSMatrixSolveTriangular <: MPSMatrixUnaryKernel + +function MPSMatrixSolveTriangular(device, right, upper, unit, order, numberOfRightHandSides, alpha) + kernel = @objc [MPSMatrixSolveTriangular alloc]::id{MPSMatrixSolveTriangular} + obj = MPSMatrixSolveTriangular(kernel) + finalizer(release, obj) + @objc [obj::id{MPSMatrixSolveTriangular} initWithDevice:device::id{MTLDevice} + right:right::Bool + upper:upper::Bool + transpose:transpose::Bool + unit:unit::Bool + order:order::NSUInteger + numberOfRightHandSides:numberOfRightHandSides::NSUInteger + alpha:alpha::Float64]::id{MPSMatrixSolveTriangular} + return obj +end + +function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveTriangular, sourceMatrix, resultMatrix, pivotIndices, status) + @objc [kernel::id{MPSMatrixSolveTriangular} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer} + sourceMatrix:sourceMatrix::id{MPSMatrix} + resultMatrix:resultMatrix::id{MPSMatrix} + pivotIndices:pivotIndices::id{MPSMatrix} + status:status::id{MPSMatrix}]::Nothing +end + + +export MPSMatrixSolveLU + +# @objcwrapper immutable=false MPSMatrixSolveLU <: MPSMatrixUnaryKernel + +function MPSMatrixSolveLU(device, transpose, order, numberOfRightHandSides) + kernel = @objc [MPSMatrixSolveLU alloc]::id{MPSMatrixSolveLU} + obj = MPSMatrixSolveLU(kernel) + finalizer(release, obj) + @objc [obj::id{MPSMatrixSolveLU} initWithDevice:device::id{MTLDevice} + transpose:transpose::Bool + order:order::NSUInteger + numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveLU} + return obj +end + +function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveLU, sourceMatrix, rightHandSideMatrix, pivotIndices, solutionMatrix) + @objc [kernel::id{MPSMatrixSolveLU} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer} + sourceMatrix:sourceMatrix::id{MPSMatrix} + rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix} + pivotIndices:pivotIndices::id{MPSMatrix} + solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing +end + + + + +export MPSMatrixSolveCholesky + +# @objcwrapper immutable=false MPSMatrixSolveCholesky <: MPSMatrixUnaryKernel + +function MPSMatrixSolveCholesky(device, upper, order, numberOfRightHandSides) + kernel = @objc [MPSMatrixSolveCholesky alloc]::id{MPSMatrixSolveCholesky} + obj = MPSMatrixSolveCholesky(kernel) + finalizer(release, obj) + @objc [obj::id{MPSMatrixSolveCholesky} initWithDevice:device::id{MTLDevice} + upper:upper::Bool + order:order::NSUInteger + numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveCholesky} + return obj +end + +function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveCholesky, sourceMatrix, rightHandSideMatrix, solutionMatrix) + @objc [kernel::id{MPSMatrixSolveCholesky} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer} + sourceMatrix:sourceMatrix::id{MPSMatrix} + rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix} + pivotIndices:pivotIndices::id{MPSMatrix} + solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing +end diff --git a/res/wrap/libmps.toml b/res/wrap/libmps.toml index 668b52c6..ec4c8e62 100644 --- a/res/wrap/libmps.toml +++ b/res/wrap/libmps.toml @@ -102,6 +102,15 @@ immutable=false [api.MPSMatrixSoftMax] immutable=false +[api.MPSMatrixSolveTriangular] +immutable=false + +[api.MPSMatrixSolveLU] +immutable=false + +[api.MPSMatrixSolveCholesky] +immutable=false + [api.MPSMatrixUnaryKernel] immutable=false diff --git a/test/mps/linalg.jl b/test/mps/linalg.jl index 82e7c4fb..71d22267 100644 --- a/test/mps/linalg.jl +++ b/test/mps/linalg.jl @@ -211,6 +211,41 @@ using Metal: storagemode @test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A) end +@testset "solves" begin + b = MtlVector(rand(Float32, 1024)) + B = MtlMatrix(rand(Float32, 1024, 1024)) + + A = MtlMatrix(rand(Float32, 1024, 512)) + x = lu(A) \ b + @test A * x ≈ b + X = lu(A) \ B + @test A * X ≈ B + + A = UpperTriangular(MtlMatrix(rand(Float32, 1024, 1024))) + x = A \ b + @test A * x ≈ b + X = A \ B + @test A * X ≈ B + + A = UnitUpperTriangular(MtlMatrix(rand(Float32, 1024, 1024))) + x = A \ b + @test A * x ≈ b + X = A \ B + @test A * X ≈ B + + A = LowerTriangular(MtlMatrix(rand(Float32, 1024, 1024))) + x = A \ b + @test A * x ≈ b + X = A \ B + @test A * X ≈ B + + A = UnitLowerTriangular(MtlMatrix(rand(Float32, 1024, 1024))) + x = A \ b + @test A * x ≈ b + X = A \ B + @test A * X ≈ B +end + using .MPS: MPSMatrixSoftMax, MPSMatrixLogSoftMax @testset "MPSMatrixSoftMax" begin cols = rand(Int)