Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hvncat: Better handling of 0- and 1-length dims/shape args #41197

Merged
merged 26 commits into from
Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1b35e95
Better handling of 0- and 1-length dims/shape args
BioTurboNick Jun 11, 2021
b25b530
Fixed git lines screwup
BioTurboNick Jun 11, 2021
ec59c12
Restored missing bits
BioTurboNick Jun 11, 2021
a80a522
T.() => convert.()
BioTurboNick Jun 11, 2021
e73244c
error suggestion
BioTurboNick Jun 11, 2021
c5f3497
style suggestion
BioTurboNick Jun 11, 2021
c724b93
Removed duplicated methods
BioTurboNick Jun 12, 2021
cc6cd7e
T => convert
BioTurboNick Jun 12, 2021
af6b44c
Fix parentheses
BioTurboNick Jun 12, 2021
a8d71bc
fix function
BioTurboNick Jun 12, 2021
b4cfd43
Fix ambiguity
BioTurboNick Jun 12, 2021
632ef9b
whitespace fix
BioTurboNick Jun 12, 2021
a8c5d6e
Ambiguity and more T => convert
BioTurboNick Jun 12, 2021
cd55f21
fix type parameter
BioTurboNick Jun 12, 2021
415fe84
fix using
BioTurboNick Jun 12, 2021
be7d439
Belonged in other PR
BioTurboNick Jun 12, 2021
e114402
Removed method leading to unintended recursion.
BioTurboNick Jun 12, 2021
aa80b40
Pass through to later checks
BioTurboNick Jun 12, 2021
77c078d
pass through first function
BioTurboNick Jun 12, 2021
19530be
Modify behavior of 0-argument int-form to return N-dim array or error
BioTurboNick Jun 12, 2021
52166b3
Remove stray using
BioTurboNick Jun 15, 2021
046752c
Consolidate 0-dimension case
BioTurboNick Jul 1, 2021
f730a60
Update test/abstractarray.jl
BioTurboNick Jul 1, 2021
bc53fc7
Test eltype and size
BioTurboNick Jul 1, 2021
264f3bc
Fix stack overflow when dim < ndims
BioTurboNick Jul 1, 2021
b565092
Test fix
BioTurboNick Jul 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 53 additions & 14 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2126,28 +2126,34 @@ julia> hvncat(((3, 3), (3, 3), (6,)), true, a, b, c, d, e, f)
4 = elements in each 4d slice (4,)
=> shape = ((2, 1, 1), (3, 1), (4,), (4,)) with `rowfirst` = true
"""
function hvncat(::Tuple{}, ::Bool)
    return Any[]
end

# an empty dims/shape tuple yields an empty `Any` vector no matter what is passed along
function hvncat(::Tuple{}, ::Bool, xs...)
    return Any[]
end

# methods assume 2+ dimensions, so a one-entry dims/shape tuple is handled by `vcat`
function hvncat(::Tuple{Vararg{Any, 1}}, ::Bool, xs...)
    return vcat(xs...)
end

# every other tuple form goes to the `_hvncat` element-type promoters
function hvncat(dimsshape::Tuple, row_first::Bool, xs...)
    return _hvncat(dimsshape, row_first, xs...)
end

# integer form: forwards with `row_first` fixed to `true`
function hvncat(dim::Int, xs...)
    return _hvncat(dim, true, xs...)
end

# Element-type promoters: each method picks an element type and forwards to `_typed_hvncat`.
# No elements: returns an empty untyped array.
_hvncat(::Union{Tuple, Int}, ::Bool) = []
# NOTE(review): same signature as the method directly above, so this later definition is the
# one that remains in effect; it forwards with element type `Any`.
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool) = _typed_hvncat(Any, dimsshape, row_first)
# Generic fallback: element type is taken from `promote_eltypeof` over all arguments.
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs...) = _typed_hvncat(promote_eltypeof(xs...), dimsshape, row_first, xs...)
# All arguments are numbers of one shared type `T`: `T` is used directly.
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::T...) where T<:Number = _typed_hvncat(T, dimsshape, row_first, xs...)
# Numbers only: element type is taken from `promote_typeof`.
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::Number...) = _typed_hvncat(promote_typeof(xs...), dimsshape, row_first, xs...)
# Arrays only: element type is taken from `promote_eltype`.
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::AbstractArray...) = _typed_hvncat(promote_eltype(xs...), dimsshape, row_first, xs...)
# Arrays that all share element type `T`: `T` is used directly.
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::AbstractArray{T}...) where T = _typed_hvncat(T, dimsshape, row_first, xs...)

# an empty dims/shape tuple yields an empty vector of the requested element type,
# with or without further arguments
function typed_hvncat(::Type{T}, ::Tuple{}, ::Bool) where T
    return T[]
end
function typed_hvncat(::Type{T}, ::Tuple{}, ::Bool, xs...) where T
    return T[]
end

# methods assume 2+ dimensions, so a one-entry dims/shape tuple is handled by `typed_vcat`
function typed_hvncat(T::Type, ::Tuple{Vararg{Any, 1}}, ::Bool, xs...)
    return typed_vcat(T, xs...)
end

# every other tuple form goes straight to `_typed_hvncat`
function typed_hvncat(T::Type, dimsshape::Tuple, row_first::Bool, xs...)
    return _typed_hvncat(T, dimsshape, row_first, xs...)
end

# integer form: the dimension is lifted into a `Val` before forwarding
function typed_hvncat(T::Type, dim::Int, xs...)
    return _typed_hvncat(T, Val(dim), xs...)
end

# An empty dims/shape tuple always yields an empty `Vector{T}`; the arguments are not inspected.
# The separate `Number...` method returns the same thing as the generic vararg method.
function _typed_hvncat(::Type{T}, ::Tuple{}, ::Bool) where T
    return T[]
end
function _typed_hvncat(::Type{T}, ::Tuple{}, ::Bool, xs...) where T
    return T[]
end
function _typed_hvncat(::Type{T}, ::Tuple{}, ::Bool, xs::Number...) where T
    return T[]
end
# 1-dimensional hvncat methods

# Shared failure path for 0-dimensional requests that do not receive exactly one argument.
function _typed_hvncat_0d_only_one()
    throw(ArgumentError("a 0-dimensional array may only contain exactly one element"))
end

# exactly one non-array argument: convert it to `T` and wrap it with `fill`
function _typed_hvncat(T::Type, ::Val{0}, x)
    return fill(convert(T, x))
end
function _typed_hvncat(T::Type, ::Val{0}, x::Number)
    return fill(convert(T, x))
end

# exactly one array argument: convert its elements to `T` by broadcasting
function _typed_hvncat(T::Type, ::Val{0}, x::AbstractArray)
    return convert.(T, x)
end

# no argument at all, or any other argument list, is rejected
function _typed_hvncat(::Type, ::Val{0})
    return _typed_hvncat_0d_only_one()
end
function _typed_hvncat(::Type, ::Val{0}, ::Any...)
    return _typed_hvncat_0d_only_one()
end
function _typed_hvncat(::Type, ::Val{0}, ::Number...)
    return _typed_hvncat_0d_only_one()
end
function _typed_hvncat(::Type, ::Val{0}, ::AbstractArray...)
    return _typed_hvncat_0d_only_one()
end

# No elements supplied: build an `N`-dimensional `Array{T}` whose size is 0 along every dimension.
function _typed_hvncat(::Type{T}, ::Val{N}) where {T, N}
    emptydims = ntuple(_ -> 0, Val(N))
    return Array{T, N}(undef, emptydims)
end

function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool, xs::Number...) where {T, N}
A = Array{T, N}(undef, dims...)
lengtha = length(A) # Necessary to store result because throw blocks are being deoptimized right now, which leads to excessive allocations
Expand Down Expand Up @@ -2185,14 +2191,13 @@ function hvncat_fill!(A::Array, row_first::Bool, xs::Tuple)
end

# catches calls from the `_hvncat` type promoters: the integer dimension is lifted into a
# `Val` and the `Bool` argument is dropped
function _typed_hvncat(T::Type, dim::Int, ::Bool, xs...)
    return _typed_hvncat(T, Val(dim), xs...)
end

# no elements: an empty vector of the requested element type
function _typed_hvncat(::Type{T}, ::Val) where T
    return T[]
end

# numbers only: forward to the dims-tuple form with size 1 along the first N - 1 dimensions
# and `length(xs)` along dimension N
function _typed_hvncat(T::Type, ::Val{N}, xs::Number...) where N
    dims = (ntuple(_ -> 1, N - 1)..., length(xs))
    return _typed_hvncat(T, dims, false, xs...)
end
function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N}
# optimization for arrays that can be concatenated by copying them linearly into the destination
# conditions: the elements must all have 1- or 0-length dimensions above N
for a ∈ as
ndims(a) <= N || all(x -> size(a, x) == 1, (N + 1):ndims(a)) ||
return _typed_hvncat(T, (ntuple(x -> 1, N - 1)..., length(as)), false, as...)
return _typed_hvncat(T, (ntuple(x -> 1, N - 1)..., length(as), 1), false, as...)
# the extra 1 is to avoid an infinite cycle
end

nd = max(N, ndims(as[1]))
Expand Down Expand Up @@ -2246,6 +2251,31 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as...) where {T, N}
return A
end


# 0-dimensional cases for balanced and unbalanced hvncat methods

# An empty dims/shape tuple is handled by the `Val(0)` methods; the `Bool` argument is dropped.
# The `Number...` method forwards exactly like the generic one.
function _typed_hvncat(T::Type, ::Tuple{}, ::Bool, x...)
    return _typed_hvncat(T, Val(0), x...)
end
function _typed_hvncat(T::Type, ::Tuple{}, ::Bool, x::Number...)
    return _typed_hvncat(T, Val(0), x...)
end


# balanced dimensions hvncat methods

# A one-entry `dims` tuple goes to `_typed_hvncat_1d`. The `Bool` argument is ignored and
# `Val(false)` is always passed on. The `Number...` method forwards exactly like the generic one.
function _typed_hvncat(T::Type, dims::Tuple{Int}, ::Bool, as...)
    return _typed_hvncat_1d(T, dims[1], Val(false), as...)
end
function _typed_hvncat(T::Type, dims::Tuple{Int}, ::Bool, as::Number...)
    return _typed_hvncat_1d(T, dims[1], Val(false), as...)
end

# Concatenate `as` along a single dimension after validating the requested length `ds`.
#
# `ds` must be positive and equal to the number of arguments; otherwise an `ArgumentError`
# is thrown. `row_first` is carried in the type domain: when `true` the work is forwarded to
# the `Val(2)` method, otherwise to the `Val(1)` method.
function _typed_hvncat_1d(::Type{T}, ds::Int, ::Val{row_first}, as...) where {T, row_first}
    lengthas = length(as)
    ds > 0 ||
        throw(ArgumentError("`dimsshape` argument must consist of positive integers"))
    # the message names the argument `dimsshape`, matching the public signature and the check above
    lengthas == ds ||
        throw(ArgumentError("number of elements does not match `dimsshape` argument; expected $ds, got $lengthas"))
    return _typed_hvncat(T, Val(row_first ? 2 : 1), as...)
end

function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool, as...) where {T, N}
d1 = row_first ? 2 : 1
d2 = row_first ? 1 : 2
Expand Down Expand Up @@ -2308,7 +2338,16 @@ function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool,
return A
end

function _typed_hvncat(::Type{T}, shape::Tuple{Vararg{Tuple, N}}, row_first::Bool, as...) where {T, N}

# unbalanced dimensions hvncat methods

# `shape` with a single level: that level must be non-empty, and its first entry is passed
# to `_typed_hvncat_1d` as the expected element count, with `row_first` lifted into a `Val`.
function _typed_hvncat(T::Type, shape::Tuple{Tuple}, row_first::Bool, xs...)
    level = first(shape)
    isempty(level) &&
        throw(ArgumentError("each level of `shape` argument must have at least one value"))
    return _typed_hvncat_1d(T, first(level), Val(row_first), xs...)
end

function _typed_hvncat(T::Type, shape::NTuple{N, Tuple}, row_first::Bool, as...) where {N}
d1 = row_first ? 2 : 1
d2 = row_first ? 1 : 2
shape = collect(shape) # saves allocations later
Expand Down
64 changes: 63 additions & 1 deletion test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1342,6 +1342,7 @@ end
end
end

using Base: typed_hvncat
@testset "hvncat" begin
a = fill(1, (2,3,2,4,5))
b = fill(2, (1,1,2,4,5))
Expand Down Expand Up @@ -1389,7 +1390,68 @@ end
@test [v v;;; fill(v, 1, 2)] == fill(v, 1, 2, 2)
end

@test_throws BoundsError hvncat(((1, 2), (3,)), false, zeros(Int, 0, 0, 0), 7, 8)
# 0-dimension behaviors
# exactly one argument, placed in an array
# if already an array, copy, with type conversion as necessary
@test_throws ArgumentError hvncat(0)
@test hvncat(0, 1) == fill(1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to allow this? Passing () for 0-d makes sense, but this form doesn't to me, since there is no dimension 0.

Copy link
Contributor Author

@BioTurboNick BioTurboNick Jun 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hvncat(::Int, args...) is just a specialization for hvncat(::Tuple{Vararg{Int}}, true, args...), so it continues this pattern:

hvncat(3, ...) == hvncat((1, 1, n), ...).
hvncat(2, ...) == hvncat((1, n), ...)
hvncat(1, ...) == hvncat((n,), ...)
hvncat(0, ...) == hvncat((), ...)

That does mean, though, I could actually consolidate the _typed_hvncat methods for the 0-d cases to all refer to _typed_hvncat(::Type{T}, ::Val{0}, args...), or vice versa.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I guess the way to understand it is that it refers to the number of dimensions of the block array?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, and it's the special case where there is only one dimension involved: [a ;;; b ;;; c] => hvncat(3, a, b, c) or [a ;;;] => hvncat(3, a), which just bumps up the dimensions if ndims(a) < 3.

# int form: a single array argument is returned with matching contents; two arguments are rejected
@test hvncat(0, [1]) == [1]
@test_throws ArgumentError hvncat(0, 1, 1)
@test_throws ArgumentError typed_hvncat(Float64, 0)
@test typed_hvncat(Float64, 0, 1) == fill(1.0)
@test typed_hvncat(Float64, 0, [1]) == Float64[1.0]
@test_throws ArgumentError typed_hvncat(Float64, 0, 1, 1)
# empty-tuple form mirrors the int form; the zero-argument calls throw, so there is no
# result to compare against
@test_throws ArgumentError hvncat((), true)
@test hvncat((), true, 1) == fill(1)
@test hvncat((), true, [1]) == [1]
@test_throws ArgumentError hvncat((), true, 1, 1)
@test_throws ArgumentError typed_hvncat(Float64, (), true)
@test typed_hvncat(Float64, (), true, 1) == fill(1.0)
@test typed_hvncat(Float64, (), true, [1]) == [1.0]
@test_throws ArgumentError typed_hvncat(Float64, (), true, 1, 1)

# 1-dimension behaviors
# int form
@test hvncat(1) == []
@test hvncat(1, 1) == [1]
@test hvncat(1, [1]) == [1]
@test hvncat(1, [1 2; 3 4]) == [1 2; 3 4]
@test hvncat(1, 1, 1) == [1 ; 1]
@test typed_hvncat(Float64, 1) == Float64[]
@test typed_hvncat(Float64, 1, 1) == Float64[1.0]
@test typed_hvncat(Float64, 1, [1]) == Float64[1.0]
@test typed_hvncat(Float64, 1, 1, 1) == Float64[1.0 ; 1.0]
# dims form
@test_throws ArgumentError hvncat((1,), true)
@test hvncat((2,), true, 1, 1) == [1; 1]
@test hvncat((2,), true, [1], [1]) == [1; 1]
@test_throws ArgumentError hvncat((2,), true, 1)
@test typed_hvncat(Float64, (2,), true, 1, 1) == Float64[1.0; 1.0]
@test typed_hvncat(Float64, (2,), true, [1], [1]) == Float64[1.0; 1.0]
@test_throws ArgumentError typed_hvncat(Float64, (2,), true, 1)
# row_first has no effect with just one dimension of the dims form
@test hvncat((2,), false, 1, 1) == [1; 1]
@test typed_hvncat(Float64, (2,), false, 1, 1) == Float64[1.0; 1.0]
# shape form
@test hvncat(((2,),), true, 1, 1) == [1 1]
@test hvncat(((2,),), true, [1], [1]) == [1 1]
@test_throws ArgumentError hvncat(((2,),), true, 1)
@test hvncat(((2,),), false, 1, 1) == [1; 1]
@test hvncat(((2,),), false, [1], [1]) == [1; 1]
@test typed_hvncat(Float64, ((2,),), true, 1, 1) == Float64[1.0 1.0]
@test typed_hvncat(Float64, ((2,),), true, [1], [1]) == Float64[1.0 1.0]
@test_throws ArgumentError typed_hvncat(Float64, ((2,),), true, 1)
@test typed_hvncat(Float64, ((2,),), false, 1, 1) == Float64[1.0; 1.0]
@test typed_hvncat(Float64, ((2,),), false, [1], [1]) == Float64[1.0; 1.0]

# zero-value behaviors for int form above dimension zero
# e.g. [;;], [;;;], though that isn't valid syntax
@test [] == hvncat(1) isa Array{Any, 1}
@test Array{Any, 2}(undef, 0, 0) == hvncat(2) isa Array{Any, 2}
@test Array{Any, 3}(undef, 0, 0, 0) == hvncat(3) isa Array{Any, 3}
@test Int[] == typed_hvncat(Int, 1) isa Array{Int, 1}
@test Array{Int, 2}(undef, 0, 0) == typed_hvncat(Int, 2) isa Array{Int, 2}
@test Array{Int, 3}(undef, 0, 0, 0) == typed_hvncat(Int, 3) isa Array{Int, 3}
end

@testset "keepat!" begin
Expand Down