From c456273a5ce43183ae1ea31be348d132c706cd7e Mon Sep 17 00:00:00 2001 From: chengchingwen Date: Fri, 19 Aug 2022 02:43:05 +0800 Subject: [PATCH 1/5] specialize diag * mat * diag --- stdlib/LinearAlgebra/src/diagonal.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 8895e4cf5e892..5829385129160 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -368,6 +368,15 @@ end return out end +function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal) + _muldiag_size_check(Da, A) + _muldiag_size_check(A, Db) + dda = Da.diag + ddb = Db.diag + n = length(dda) + return broadcast(*, reshape(dda, n, 1), A, reshape(ddb, 1, n)) +end + # Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat @inline mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) = _muldiag!(out, D, V, alpha, beta) From c8c7ec6048c7b7c37e2808ddf33a01cae4c97f53 Mon Sep 17 00:00:00 2001 From: Peter Date: Fri, 19 Aug 2022 03:18:46 +0800 Subject: [PATCH 2/5] Update stdlib/LinearAlgebra/src/diagonal.jl Co-authored-by: N5N3 <2642243996@qq.com> --- stdlib/LinearAlgebra/src/diagonal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 5829385129160..3cae185aa730c 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -374,7 +374,7 @@ function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal) dda = Da.diag ddb = Db.diag n = length(dda) - return broadcast(*, reshape(dda, n, 1), A, reshape(ddb, 1, n)) + return broadcast(*, dda, A, reshape(ddb, 1, n)) end # Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat From ec3bc768dc57a81df7522eebf7af9d3893950e86 Mon Sep 17 00:00:00 2001 From: chengchingwen Date: Fri, 19 Aug 2022 03:33:18 +0800 Subject: [PATCH 3/5] add test for diag * mat * diag --- stdlib/LinearAlgebra/test/diagonal.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 2cbedf49440ea..25e81f50384fb 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -653,6 +653,14 @@ end @test D2 == D * D end +@testset "multiplication of 2 Diagonal and a Matix (#46400)" begin + A = randn(10, 10) + D = Diagonal(randn(10)) + D2 = Diagonal(randn(10)) + @test D * A * D2 ≈ D * (A * D2) + @test D * A * D2 ≈ (D * A) * D2 +end + @testset "multiplication of QR Q-factor and Diagonal (#16615 spot test)" begin D = Diagonal(randn(5)) Q = qr(randn(5, 5)).Q From 857c30e16608629e081e43d5a981c05c0325f300 Mon Sep 17 00:00:00 2001 From: Peter Date: Fri, 19 Aug 2022 21:47:34 +0800 Subject: [PATCH 4/5] Update stdlib/LinearAlgebra/src/diagonal.jl Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- stdlib/LinearAlgebra/src/diagonal.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 3cae185aa730c..2c87e63c4a7c1 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -371,10 +371,7 @@ end function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal) _muldiag_size_check(Da, A) _muldiag_size_check(A, Db) - dda = Da.diag - ddb = Db.diag - n = length(dda) - return broadcast(*, dda, A, reshape(ddb, 1, n)) + return broadcast(*, Da.diag, A, permutedims(Db.diag)) end # Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat From 348e7afa5cb4e7a55cede9eaa1b1e79b538c6b21 Mon Sep 17 00:00:00 2001 From: Peter Date: Mon, 22 Aug 2022 22:01:30 +0800 Subject: [PATCH 5/5] Update stdlib/LinearAlgebra/test/diagonal.jl Co-authored-by: Daniel Karrasch --- stdlib/LinearAlgebra/test/diagonal.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 25e81f50384fb..64e3fdbc2a6f1 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -659,6 +659,8 @@ end D2 = Diagonal(randn(10)) @test D * A * D2 ≈ D * (A * D2) @test D * A * D2 ≈ (D * A) * D2 + @test_throws DimensionMismatch Diagonal(ones(9)) * A * D2 + @test_throws DimensionMismatch D * A * Diagonal(ones(9)) end @testset "multiplication of QR Q-factor and Diagonal (#16615 spot test)" begin