From 0f0223dfdf57d2e28167e08106ed03e07826d05b Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tgymnich@icloud.com>
Date: Tue, 5 Sep 2023 16:02:51 +0200
Subject: [PATCH 1/4] ldiv

---
 lib/mps/linalg.jl | 142 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 142 insertions(+)

diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl
index dd70f5f29..67ad33d93 100644
--- a/lib/mps/linalg.jl
+++ b/lib/mps/linalg.jl
@@ -231,5 +231,147 @@ end
 
     commit!(cmdbuf)
 
+    wait_completed(cmdbuf)
+
+    return B
+end
+
+
+function LinearAlgebra.:(\)(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
+    # `ldiv!` overwrites its second argument, so solve on a copy and leave `B` intact.
+    C = copy(B)
+    return LinearAlgebra.ldiv!(A, C)
+end
+
+
+function LinearAlgebra.ldiv!(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
+    M, N = size(B, 1), size(B, 2)
+    dev = current_device()
+    queue = global_queue(dev)
+
+    At = similar(A.factors)
+    Bt = similar(B, (N, M))
+    P = reshape((A.ipiv .- UInt32(1)), (1, M))  # pivots shifted down by one; presumably MPS wants 0-based indices
+    X = similar(B, (N, M))
+    # In-place solve for B from LU factors; the solver is fed transposed copies of the factors and of B.
+    transpose!(At, A.factors)
+    transpose!(Bt, B)
+
+    mps_a = MPSMatrix(At)
+    mps_b = MPSMatrix(Bt)
+    mps_p = MPSMatrix(P)
+    mps_x = MPSMatrix(X)
+    # Kernel arguments after the device: transpose = false, order = M, right-hand sides = N.
+    MTLCommandBuffer(queue) do cmdbuf
+        kernel = MPSMatrixSolveLU(dev, false, M, N)
+        encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
+    end
+    # NOTE(review): unlike the triangular solves, this buffer is not waited on; confirm that is intended.
+    transpose!(B, X)
     return B
 end
+
+
+function LinearAlgebra.ldiv!(A::UpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
+    M, N = size(B, 1), size(B, 2)
+    dev = current_device()
+    queue = global_queue(dev)
+
+    Ad = MtlMatrix(A')
+    Br = similar(B, (N, M))
+    X = similar(Br)
+    # Solve A * X = B in place. Ad, Br and X hold transposed data, as in the LU solve above.
+    transpose!(Br, B)
+
+    mps_a = MPSMatrix(Ad)
+    mps_b = MPSMatrix(Br)
+    mps_x = MPSMatrix(X)
+    # Left solve with an upper factor: order M (rows of B), N right-hand sides (columns of B).
+    buf = MTLCommandBuffer(queue) do cmdbuf
+        kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0)
+        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
+    end
+
+    wait_completed(buf)
+    # X is N×M (the solution transposed); a plain copy would leave B transposed or mis-shaped.
+    transpose!(B, X)
+    return B
+end
+
+
+function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
+    M, N = size(B, 1), size(B, 2)
+    dev = current_device()
+    queue = global_queue(dev)
+
+    Ad = MtlMatrix(A)
+    Br = reshape(B, (M, N))
+    X = similar(Br)
+    # In-place solve for B: Br is B reshaped to M×N and X is scratch space for the result.
+    mps_a = MPSMatrix(Ad)
+    mps_b = MPSMatrix(Br)
+    mps_x = MPSMatrix(X)
+
+    # Flags after dev are presumably right, upper, transpose, unit; then order, right-hand sides, alpha.
+    buf = MTLCommandBuffer(queue) do cmdbuf
+        kernel = MPSMatrixSolveTriangular(dev, true, false, false, true, M, N, 1.0)
+        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
+    end
+
+    wait_completed(buf)
+    # Br is a reshape of B, so copying X into Br is what updates B.
+    copy!(Br, X)
+    return B
+end
+
+
+function LinearAlgebra.ldiv!(A::LowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
+    M, N = size(B, 1), size(B, 2)
+    dev = current_device()
+    queue = global_queue(dev)
+
+    Ad = MtlMatrix(A)
+    Br = reshape(B, (M, N))
+    X = similar(Br)
+    # In-place solve for B: Br is B reshaped to M×N and X is scratch space for the result.
+    mps_a = MPSMatrix(Ad)
+    mps_b = MPSMatrix(Br)
+    mps_x = MPSMatrix(X)
+
+    # Flags after dev are presumably right, upper, transpose, unit; then order, right-hand sides, alpha.
+    buf = MTLCommandBuffer(queue) do cmdbuf
+        kernel = MPSMatrixSolveTriangular(dev, true, true, false, false, M, N, 1.0)
+        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
+    end
+
+    wait_completed(buf)
+    # Br is a reshape of B, so copying X into Br is what updates B.
+    copy!(Br, X)
+    return B
+end
+
+
+function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
+    M, N = size(B, 1), size(B, 2)
+    dev = current_device()
+    queue = global_queue(dev)
+
+    Ad = MtlMatrix(A)
+    Br = reshape(B, (M, N))
+    X = similar(Br)
+    # In-place solve for B: Br is B reshaped to M×N and X is scratch space for the result.
+    mps_a = MPSMatrix(Ad)
+    mps_b = MPSMatrix(Br)
+    mps_x = MPSMatrix(X)
+
+    # Flags after dev are presumably right, upper, transpose, unit; then order, right-hand sides, alpha.
+    buf = MTLCommandBuffer(queue) do cmdbuf
+        kernel = MPSMatrixSolveTriangular(dev, true, true, false, true, M, N, 1.0)
+        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
+    end
+
+    wait_completed(buf)
+    # Br is a reshape of B, so copying X into Br is what updates B.
+    copy!(Br, X)
+    return B
+end
\ No newline at end of file

From 5440fbe831805207bd80cac06a141a23f677df64 Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tgymnich@icloud.com>
Date: Tue, 5 Sep 2023 16:03:46 +0200
Subject: [PATCH 2/4] add test

---
 test/mps/linalg.jl | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/test/mps/linalg.jl b/test/mps/linalg.jl
index 82e7c4fb5..71d222677 100644
--- a/test/mps/linalg.jl
+++ b/test/mps/linalg.jl
@@ -211,6 +211,41 @@ using Metal: storagemode
     @test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A)
 end
 
+@testset "solves" begin
+    b = MtlVector(rand(Float32, 1024))
+    B = MtlMatrix(rand(Float32, 1024, 1024))
+    # Square, diagonally dominant (or small off-diagonal) inputs keep every system well conditioned.
+    A = MtlMatrix(rand(Float32, 1024, 1024) + 1024I)
+    x = lu(A) \ b
+    @test A * x ≈ b
+    X = lu(A) \ B
+    @test A * X ≈ B
+
+    A = UpperTriangular(MtlMatrix(rand(Float32, 1024, 1024) + 1024I))
+    x = A \ b
+    @test A * x ≈ b
+    X = A \ B
+    @test A * X ≈ B
+
+    A = UnitUpperTriangular(MtlMatrix(rand(Float32, 1024, 1024) ./ 1024))
+    x = A \ b
+    @test A * x ≈ b
+    X = A \ B
+    @test A * X ≈ B
+
+    A = LowerTriangular(MtlMatrix(rand(Float32, 1024, 1024) + 1024I))
+    x = A \ b
+    @test A * x ≈ b
+    X = A \ B
+    @test A * X ≈ B
+
+    A = UnitLowerTriangular(MtlMatrix(rand(Float32, 1024, 1024) ./ 1024))
+    x = A \ b
+    @test A * x ≈ b
+    X = A \ B
+    @test A * X ≈ B
+end
+
 using .MPS: MPSMatrixSoftMax, MPSMatrixLogSoftMax
 @testset "MPSMatrixSoftMax" begin
     cols = rand(Int)

From d566a49632308f14b0d0d672545783445ee39685 Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tgymnich@icloud.com>
Date: Mon, 15 Apr 2024 13:51:54 +0200
Subject: [PATCH 3/4] add solvers

---
 lib/mps/MPS.jl       |  2 +-
 lib/mps/libmps.jl    |  6 ++--
 lib/mps/solve.jl     | 77 ++++++++++++++++++++++++++++++++++++++++++++
 res/wrap/libmps.toml |  9 ++++++
 4 files changed, 90 insertions(+), 4 deletions(-)
 create mode 100644 lib/mps/solve.jl

diff --git a/lib/mps/MPS.jl b/lib/mps/MPS.jl
index e357c43c2..667683aae 100644
--- a/lib/mps/MPS.jl
+++ b/lib/mps/MPS.jl
@@ -39,10 +39,10 @@ include("vector.jl")
 include("matrixrandom.jl")
 include("ndarray.jl")
 include("decomposition.jl")
+include("solve.jl")
 include("copy.jl")
 
 # integrations
 include("random.jl")
 include("linalg.jl")
-
 end
diff --git a/lib/mps/libmps.jl b/lib/mps/libmps.jl
index fcb2b4048..ccb2be5ca 100644
--- a/lib/mps/libmps.jl
+++ b/lib/mps/libmps.jl
@@ -535,11 +535,11 @@ end
 
 @objcwrapper immutable = false MPSMatrixVectorMultiplication <: MPSMatrixBinaryKernel
 
-@objcwrapper immutable = true MPSMatrixSolveTriangular <: MPSMatrixBinaryKernel
+@objcwrapper immutable = false MPSMatrixSolveTriangular <: MPSMatrixBinaryKernel
 
-@objcwrapper immutable = true MPSMatrixSolveLU <: MPSMatrixBinaryKernel
+@objcwrapper immutable = false MPSMatrixSolveLU <: MPSMatrixBinaryKernel
 
-@objcwrapper immutable = true MPSMatrixSolveCholesky <: MPSMatrixBinaryKernel
+@objcwrapper immutable = false MPSMatrixSolveCholesky <: MPSMatrixBinaryKernel
 
 @cenum MPSMatrixDecompositionStatus::Int32 begin
     MPSMatrixDecompositionStatusSuccess = 0
diff --git a/lib/mps/solve.jl b/lib/mps/solve.jl
new file mode 100644
index 000000000..080fc19d5
--- /dev/null
+++ b/lib/mps/solve.jl
@@ -0,0 +1,77 @@
+
+export MPSMatrixSolveTriangular
+
+# Build an `MPSMatrixSolveTriangular` kernel. The arguments are forwarded, in order, to
+# `initWithDevice:right:upper:transpose:unit:order:numberOfRightHandSides:alpha:`.
+function MPSMatrixSolveTriangular(device, right, upper, transpose, unit, order, numberOfRightHandSides, alpha)
+    kernel = @objc [MPSMatrixSolveTriangular alloc]::id{MPSMatrixSolveTriangular}
+    obj = MPSMatrixSolveTriangular(kernel)
+    finalizer(release, obj)
+    @objc [obj::id{MPSMatrixSolveTriangular} initWithDevice:device::id{MTLDevice}
+                                             right:right::Bool
+                                             upper:upper::Bool
+                                             transpose:transpose::Bool
+                                             unit:unit::Bool
+                                             order:order::NSUInteger
+                                             numberOfRightHandSides:numberOfRightHandSides::NSUInteger
+                                             alpha:alpha::Float64]::id{MPSMatrixSolveTriangular}
+    return obj
+end
+
+function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveTriangular, sourceMatrix, rightHandSideMatrix, solutionMatrix)
+    # A triangular solve takes no pivot or status matrices: only source, right-hand side, solution.
+    @objc [kernel::id{MPSMatrixSolveTriangular} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
+                                                sourceMatrix:sourceMatrix::id{MPSMatrix}
+                                                rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix}
+                                                solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing
+end
+
+
+export MPSMatrixSolveLU
+
+# Build an `MPSMatrixSolveLU` kernel. The arguments are forwarded, in order, to
+# `initWithDevice:transpose:order:numberOfRightHandSides:`.
+function MPSMatrixSolveLU(device, transpose, order, numberOfRightHandSides)
+    kernel = @objc [MPSMatrixSolveLU alloc]::id{MPSMatrixSolveLU}
+    obj = MPSMatrixSolveLU(kernel)
+    finalizer(release, obj)  # registers `release` to run when `obj` is garbage-collected
+    @objc [obj::id{MPSMatrixSolveLU} initWithDevice:device::id{MTLDevice}
+                                             transpose:transpose::Bool
+                                             order:order::NSUInteger
+                                             numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveLU}
+    return obj
+end
+
+# Encode an LU solve into `cmdbuf`. The four matrices are forwarded, in order, as the
+# `sourceMatrix:`, `rightHandSideMatrix:`, `pivotIndices:` and `solutionMatrix:` arguments.
+function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveLU, sourceMatrix, rightHandSideMatrix, pivotIndices, solutionMatrix)
+    @objc [kernel::id{MPSMatrixSolveLU} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
+                                                sourceMatrix:sourceMatrix::id{MPSMatrix}
+                                                rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix}
+                                                pivotIndices:pivotIndices::id{MPSMatrix}
+                                                solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing
+end
+
+
+export MPSMatrixSolveCholesky
+
+# Build an `MPSMatrixSolveCholesky` kernel. The arguments are forwarded, in order, to
+# `initWithDevice:upper:order:numberOfRightHandSides:`.
+function MPSMatrixSolveCholesky(device, upper, order, numberOfRightHandSides)
+    kernel = @objc [MPSMatrixSolveCholesky alloc]::id{MPSMatrixSolveCholesky}
+    obj = MPSMatrixSolveCholesky(kernel)
+    finalizer(release, obj)  # registers `release` to run when `obj` is garbage-collected
+    @objc [obj::id{MPSMatrixSolveCholesky} initWithDevice:device::id{MTLDevice}
+                                             upper:upper::Bool
+                                             order:order::NSUInteger
+                                             numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveCholesky}
+    return obj
+end
+
+function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveCholesky, sourceMatrix, rightHandSideMatrix, solutionMatrix)
+    # A Cholesky solve uses no pivoting, so the selector carries no `pivotIndices:` part.
+    @objc [kernel::id{MPSMatrixSolveCholesky} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
+                                                sourceMatrix:sourceMatrix::id{MPSMatrix}
+                                                rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix}
+                                                solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing
+end
diff --git a/res/wrap/libmps.toml b/res/wrap/libmps.toml
index 668b52c63..ec4c8e622 100644
--- a/res/wrap/libmps.toml
+++ b/res/wrap/libmps.toml
@@ -102,6 +102,15 @@ immutable=false
 [api.MPSMatrixSoftMax]
 immutable=false
 
+[api.MPSMatrixSolveTriangular]
+immutable=false
+
+[api.MPSMatrixSolveLU]
+immutable=false
+
+[api.MPSMatrixSolveCholesky]
+immutable=false
+
 [api.MPSMatrixUnaryKernel]
 immutable=false
 

From b49a90ed3ee77f4969bc2eca520c3752d4aa3174 Mon Sep 17 00:00:00 2001
From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com>
Date: Sat, 8 Feb 2025 19:21:17 -0400
Subject: [PATCH 4/4] `current_device()` -> `device()`

---
 lib/mps/linalg.jl | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl
index 67ad33d93..e24c7c38b 100644
--- a/lib/mps/linalg.jl
+++ b/lib/mps/linalg.jl
@@ -246,7 +246,7 @@ end
 
 function LinearAlgebra.ldiv!(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
     M, N = size(B, 1), size(B, 2)
-    dev = current_device()
+    dev = device()
     queue = global_queue(dev)
 
     At = similar(A.factors)
@@ -274,7 +274,7 @@ end
 
 function LinearAlgebra.ldiv!(A::UpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
     M, N = size(B, 1), size(B, 2)
-    dev = current_device()
+    dev = device()
     queue = global_queue(dev)
 
     Ad = MtlMatrix(A')
@@ -301,7 +301,7 @@ end
 
 function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
     M, N = size(B, 1), size(B, 2)
-    dev = current_device()
+    dev = device()
     queue = global_queue(dev)
 
     Ad = MtlMatrix(A)
@@ -327,7 +327,7 @@ end
 
 function LinearAlgebra.ldiv!(A::LowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
     M, N = size(B, 1), size(B, 2)
-    dev = current_device()
+    dev = device()
     queue = global_queue(dev)
 
     Ad = MtlMatrix(A)
@@ -353,7 +353,7 @@ end
 
 function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
     M, N = size(B, 1), size(B, 2)
-    dev = current_device()
+    dev = device()
     queue = global_queue(dev)
 
     Ad = MtlMatrix(A)