From 3e3ed394c168e3e24575f01ae83416dc2c8e96ba Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Sat, 24 Nov 2018 15:03:15 -0800 Subject: [PATCH 1/3] Implement lazy addition --- src/LazyArrays.jl | 2 +- src/lazybroadcasting.jl | 25 +++++++++++ src/linalg/blasmul.jl | 13 ++++++ src/linalg/lazymul.jl | 3 ++ test/memorylayouttests.jl | 8 +++- test/multests.jl | 93 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 142 insertions(+), 2 deletions(-) diff --git a/src/LazyArrays.jl b/src/LazyArrays.jl index 9946d04c..d6fd3c52 100644 --- a/src/LazyArrays.jl +++ b/src/LazyArrays.jl @@ -38,7 +38,7 @@ import FillArrays: AbstractFill import StaticArrays: StaticArrayStyle -export Mul, MulArray, MulVector, MulMatrix, +export Mul, MulArray, MulVector, MulMatrix, Add, Hcat, Vcat, Kron, BroadcastArray, cache, Ldiv, Inv, PInv, Diff, Cumsum include("memorylayout.jl") diff --git a/src/lazybroadcasting.jl b/src/lazybroadcasting.jl index 01e7a5d9..8d94ed5b 100644 --- a/src/lazybroadcasting.jl +++ b/src/lazybroadcasting.jl @@ -41,6 +41,20 @@ BroadcastStyle(::Type{<:BroadcastArray{<:Any,N}}) where N = LazyArrayStyle{N}() BroadcastStyle(L::LazyArrayStyle{N}, ::StaticArrayStyle{N}) where N = L BroadcastStyle(::StaticArrayStyle{N}, L::LazyArrayStyle{N}) where N = L +""" + BroadcastLayout(f, layouts) + +is returned by `MemoryLayout(A)` if `A` is a `BroadcastArray`. +`f` is the function applied by the broadcast operation and `layouts` is +a tuple of the `MemoryLayout`s of the broadcasted arguments. 
+""" +struct BroadcastLayout{F, LAY} <: MemoryLayout + f::F + layouts::LAY +end + +MemoryLayout(A::BroadcastArray) = BroadcastLayout(A.broadcasted.f, MemoryLayout.(A.broadcasted.args)) + ## scalar-range broadcast operations ## # Ranges already support smart broadcasting for op in (+, -, big) @@ -76,3 +90,14 @@ broadcasted(::LazyArrayStyle{N}, ::typeof(*), a::AbstractArray{T,N}, b::Zeros{V, broadcast(DefaultArrayStyle{N}(), *, a, b) broadcasted(::LazyArrayStyle{N}, ::typeof(*), a::Zeros{T,N}, b::AbstractArray{V,N}) where {T,V,N} = broadcast(DefaultArrayStyle{N}(), *, a, b) + +const Add = BroadcastArray{<:Any, <:Any, <:Broadcasted{<:Any, <:Any, typeof(+)}} +const AddVector = Add{<:Any, 1} +const AddMatrix = Add{<:Any, 2} + +""" + Add(A1, A2, …, AN) + +A lazy representation of `A1 .+ A2 .+ … .+ AN`; i.e., a shorthand for `BroadcastArray(+, A1, A2, …, AN)`. +""" +Add(As...) = BroadcastArray(+, As...) diff --git a/src/linalg/blasmul.jl b/src/linalg/blasmul.jl index 6980ea1b..638817a5 100644 --- a/src/linalg/blasmul.jl +++ b/src/linalg/blasmul.jl @@ -197,6 +197,19 @@ function materialize!(M::MatMulVecAdd) default_blasmul!(α, A, B, iszero(β) ? 
false : β, C) end +for MulAdd_ in [MatMulMatAdd, MatMulVecAdd] + # `MulAdd{<:BroadcastLayout{typeof(+)}}` cannot "win" against + # `MatMulMatAdd` and `MatMulVecAdd` hence `@eval`: + @eval function materialize!(M::$MulAdd_{<:BroadcastLayout{typeof(+)}}) + α, A, B, β, C = M.α, M.A, M.B, M.β, M.C + lmul!(β, C) + for A in A.broadcasted.args + C .= α .* Mul(A, B) .+ C + end + C + end +end + # make copy to make sure always works @inline function _gemv!(tA, α, A, x, β, y) if x ≡ y diff --git a/src/linalg/lazymul.jl b/src/linalg/lazymul.jl index 5504a548..d055eea5 100644 --- a/src/linalg/lazymul.jl +++ b/src/linalg/lazymul.jl @@ -64,3 +64,6 @@ macro lazylmul(Typ) LinearAlgebra.lmul!(A::$Typ, x::StridedMatrix) = copyto!(x, LazyArrays.Mul(A,x)) end) end + + +@lazymul AddMatrix diff --git a/test/memorylayouttests.jl b/test/memorylayouttests.jl index 5a0394c4..f7d806c7 100644 --- a/test/memorylayouttests.jl +++ b/test/memorylayouttests.jl @@ -5,7 +5,7 @@ using LazyArrays, LinearAlgebra, FillArrays, Test UnitUpperTriangularLayout, LowerTriangularLayout, UnitLowerTriangularLayout, ScalarLayout, hermitiandata, symmetricdata, FillLayout, ZerosLayout, - VcatLayout + VcatLayout, BroadcastLayout struct FooBar end struct FooNumber <: Number end @@ -142,4 +142,10 @@ struct FooNumber <: Number end @test MemoryLayout(Vcat(Ones(10),Zeros(10))) == VcatLayout((FillLayout(), ZerosLayout())) @test MemoryLayout(Vcat([1.],Zeros(10))) == VcatLayout((DenseColumnMajor(), ZerosLayout())) end + + @testset "BroadcastArray" begin + A = [1.0 2; 3 4] + @test MemoryLayout(Add(A, Fill(0, (2, 2)), Zeros(2, 2))) == + BroadcastLayout(+, (DenseColumnMajor(), FillLayout(), ZerosLayout())) + end end diff --git a/test/multests.jl b/test/multests.jl index ff4977b1..fd400e7d 100644 --- a/test/multests.jl +++ b/test/multests.jl @@ -640,6 +640,9 @@ import Base.Broadcast: materialize, materialize! 
Ac = A' blasnoalloc(c, 2.0, Ac, x, 3.0, y) @test @allocated(blasnoalloc(c, 2.0, Ac, x, 3.0, y)) == 0 + Aa = Add(A, Ac) + blasnoalloc(c, 2.0, Aa, x, 3.0, y) + @test_broken @allocated(blasnoalloc(c, 2.0, Aa, x, 3.0, y)) == 0 end @testset "multi-argument mul" begin @@ -670,3 +673,93 @@ import Base.Broadcast: materialize, materialize! @test Matrix(M) ≈ A^2 end end + + +@testset "Add" begin + @testset "gemv Float64" begin + for A in (Add(randn(5,5), randn(5,5)), + Add(randn(5,5), view(randn(9, 5), 1:2:9, :))), + b in (randn(5), view(randn(5),:), view(randn(5),1:5), view(randn(9),1:2:9)) + + Ã = copy(A) + c = similar(b) + + c .= Mul(A,b) + @test c ≈ Ã*b ≈ BLAS.gemv!('N', 1.0, Ã, b, 0.0, similar(c)) + + copyto!(c, Mul(A,b)) + @test c ≈ Ã*b ≈ BLAS.gemv!('N', 1.0, Ã, b, 0.0, similar(c)) + + b̃ = copy(b) + copyto!(b̃, Mul(A,b̃)) + @test_broken c ≈ b̃ + + c .= 2.0 .* Mul(A,b) + @test c ≈ BLAS.gemv!('N', 2.0, Ã, b, 0.0, similar(c)) + + c = copy(b) + c .= Mul(A,b) .+ c + @test c ≈ BLAS.gemv!('N', 1.0, Ã, b, 1.0, copy(b)) + + c = copy(b) + c .= Mul(A,b) .+ 2.0 .* c + @test c ≈ BLAS.gemv!('N', 1.0, Ã, b, 2.0, copy(b)) + + c = copy(b) + c .= 2.0 .* Mul(A,b) .+ c + @test c ≈ BLAS.gemv!('N', 2.0, Ã, b, 1.0, copy(b)) + + c = copy(b) + c .= 3.0 .* Mul(A,b) .+ 2.0 .* c + @test c ≈ BLAS.gemv!('N', 3.0, Ã, b, 2.0, copy(b)) + + d = similar(c) + c = copy(b) + d .= 3.0 .* Mul(A,b) .+ 2.0 .* c + @test d ≈ BLAS.gemv!('N', 3.0, Ã, b, 2.0, copy(b)) + end + end + + @testset "gemm" begin + for A in (Add(randn(5,5), randn(5,5)), + Add(randn(5,5), view(randn(9, 5), 1:2:9, :))), + B in (randn(5,5), view(randn(5,5),:,:), view(randn(5,5),1:5,:), + view(randn(5,5),1:5,1:5), view(randn(5,5),:,1:5)) + + Ã = copy(A) + C = similar(B) + + C .= Mul(A,B) + @test C ≈ BLAS.gemm!('N', 'N', 1.0, Ã, B, 0.0, similar(C)) + + B .= Mul(A,B) + @test_broken C ≈ B + + C .= 2.0 .* Mul(A,B) + @test C ≈ BLAS.gemm!('N', 'N', 2.0, Ã, B, 0.0, similar(C)) + + C = copy(B) + C .= Mul(A,B) .+ C + @test C ≈ BLAS.gemm!('N', 'N', 
1.0, Ã, B, 1.0, copy(B)) + + + C = copy(B) + C .= Mul(A,B) .+ 2.0 .* C + @test C ≈ BLAS.gemm!('N', 'N', 1.0, Ã, B, 2.0, copy(B)) + + C = copy(B) + C .= 2.0 .* Mul(A,B) .+ C + @test C ≈ BLAS.gemm!('N', 'N', 2.0, Ã, B, 1.0, copy(B)) + + + C = copy(B) + C .= 3.0 .* Mul(A,B) .+ 2.0 .* C + @test C ≈ BLAS.gemm!('N', 'N', 3.0, Ã, B, 2.0, copy(B)) + + d = similar(C) + C = copy(B) + d .= 3.0 .* Mul(A,B) .+ 2.0 .* C + @test d ≈ BLAS.gemm!('N', 'N', 3.0, Ã, B, 2.0, copy(B)) + end + end +end From d0c62340dde70cfb616690bb68352a59014bcd01 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Sat, 24 Nov 2018 18:05:42 -0800 Subject: [PATCH 2/3] Support B .= Mul(Add(A1, A2), B) --- src/linalg/blasmul.jl | 3 +++ test/multests.jl | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/linalg/blasmul.jl b/src/linalg/blasmul.jl index 638817a5..adadbe0f 100644 --- a/src/linalg/blasmul.jl +++ b/src/linalg/blasmul.jl @@ -202,6 +202,9 @@ for MulAdd_ in [MatMulMatAdd, MatMulVecAdd] # `MatMulMatAdd` and `MatMulVecAdd` hence `@eval`: @eval function materialize!(M::$MulAdd_{<:BroadcastLayout{typeof(+)}}) α, A, B, β, C = M.α, M.A, M.B, M.β, M.C + if C ≡ B + B = copy(B) + end lmul!(β, C) for A in A.broadcasted.args C .= α .* Mul(A, B) .+ C diff --git a/test/multests.jl b/test/multests.jl index fd400e7d..1d3d5f0d 100644 --- a/test/multests.jl +++ b/test/multests.jl @@ -692,7 +692,7 @@ end b̃ = copy(b) copyto!(b̃, Mul(A,b̃)) - @test_broken c ≈ b̃ + @test c ≈ b̃ c .= 2.0 .* Mul(A,b) @test c ≈ BLAS.gemv!('N', 2.0, Ã, b, 0.0, similar(c)) @@ -733,7 +733,7 @@ end @test C ≈ BLAS.gemm!('N', 'N', 1.0, Ã, B, 0.0, similar(C)) B .= Mul(A,B) - @test_broken C ≈ B + @test C ≈ B C .= 2.0 .* Mul(A,B) @test C ≈ BLAS.gemm!('N', 'N', 2.0, Ã, B, 0.0, similar(C)) From 75836a8eeb9ec245131e87736f669bdceb1c5e25 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Mon, 26 Nov 2018 06:21:56 -0800 Subject: [PATCH 3/3] Unexport Add --- src/LazyArrays.jl | 2 +- test/memorylayouttests.jl | 2 +- 
test/multests.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/LazyArrays.jl b/src/LazyArrays.jl index d6fd3c52..9946d04c 100644 --- a/src/LazyArrays.jl +++ b/src/LazyArrays.jl @@ -38,7 +38,7 @@ import FillArrays: AbstractFill import StaticArrays: StaticArrayStyle -export Mul, MulArray, MulVector, MulMatrix, Add, +export Mul, MulArray, MulVector, MulMatrix, Hcat, Vcat, Kron, BroadcastArray, cache, Ldiv, Inv, PInv, Diff, Cumsum include("memorylayout.jl") diff --git a/test/memorylayouttests.jl b/test/memorylayouttests.jl index f7d806c7..1b9c5224 100644 --- a/test/memorylayouttests.jl +++ b/test/memorylayouttests.jl @@ -5,7 +5,7 @@ using LazyArrays, LinearAlgebra, FillArrays, Test UnitUpperTriangularLayout, LowerTriangularLayout, UnitLowerTriangularLayout, ScalarLayout, hermitiandata, symmetricdata, FillLayout, ZerosLayout, - VcatLayout, BroadcastLayout + VcatLayout, BroadcastLayout, Add struct FooBar end struct FooNumber <: Number end diff --git a/test/multests.jl b/test/multests.jl index 1d3d5f0d..c280ce98 100644 --- a/test/multests.jl +++ b/test/multests.jl @@ -1,5 +1,5 @@ using Test, LinearAlgebra, LazyArrays, StaticArrays, FillArrays -import LazyArrays: MulAdd, MemoryLayout, DenseColumnMajor, DiagonalLayout, SymTridiagonalLayout +import LazyArrays: MulAdd, MemoryLayout, DenseColumnMajor, DiagonalLayout, SymTridiagonalLayout, Add import Base.Broadcast: materialize, materialize! @testset "Mul" begin