From 3d5fa6226363b8503c297c2ce0e8692f37e274fe Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 12 Jul 2021 00:09:57 -0400 Subject: [PATCH 1/9] fix push + pop gradient for vector of arrays, add real tests --- src/lib/array.jl | 46 ++++++++++++++++++++++++++++++++-------------- test/gradcheck.jl | 27 +++++++++++++-------------- 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 9bec64b95..ff1adfd7b 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -78,25 +78,43 @@ for f in [push!, pop!, pushfirst!, popfirst!] _ -> error("Mutating arrays is not supported -- called $($f)(::$(typeof(xs)), _...)") end -# This is kind of bad, but at least we don't materialize the whole -# array. Prefer to use `Buffer` -# function _pullback(cx::Context, ::typeof(push!), xs::AbstractVector{<:AbstractArray}, x::AbstractArray{T}...) where T -@adjoint! function push!(xs::AbstractVector{<:AbstractArray}, x::AbstractArray{T}...) where T - sz_xs = size.(xs) - sz_x = size.(x) - push!(xs, x...), Δ -> begin - (Δ, map(x -> Ones{T}(x...), sz_x)...) +# Exceptions: +@adjoint! function push!(dst::AbstractVector{<:AbstractArray}, xs::AbstractArray...) + num_xs = length(xs) + push!(dst, xs...), Δ -> begin + (Δ[1:end-num_xs], Δ[end-num_xs+1:end]...) end end -@adjoint! function pop!(xs::AbstractVector{<:AbstractArray{T}}) where T - sz_xs = size.(xs) - op = pop!(xs) - op, Δ -> begin - ([Ones{T}(sz...) for sz in sz_xs], ) - end +@adjoint! function pop!(src::AbstractVector{<:AbstractArray{T}}) where T + zs = fill(nothing, length(src)-1) + pop!(src), Δ -> (vcat(zs, [Δ]),) end +#= + + +julia> gradient((xs, y) -> sum(abs2, push!(xs, y)[1]), [[1,2], [3,4]], [5,6]) +([[2, 4], nothing, nothing], 2-element Ones{Int64}) +([[2, 4], nothing], nothing) # correct + +julia> gradient((xs, y) -> sum(abs2, push!(xs, y)[2]), [[1,2], [3,4]], [5,6]) +([nothing, [6, 8], nothing], 2-element Ones{Int64}) +([nothing, [6, 8]], nothing) # correct + +julia> gradient((xs, y) -> sum(abs2, push!(xs, y)[3]), [[1,2], [3,4]], [5,6]) +([nothing, nothing, [10, 12]], 2-element Ones{Int64}) +([nothing, nothing], [10, 12]) # correct + +(jl_aA0cjy) pkg> st Zygote + Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_aA0cjy/Project.toml` + [e88e6eb3] Zygote v0.6.14 + +julia> gradient(xs -> sum(abs2, pop!(xs)), [[1,2], [3,4]]) +(Union{Nothing, Vector{Int64}}[nothing, nothing, [6, 8]],) + +=# + # General @adjoint collect(x::Array) = collect(x), Δ -> (Δ,) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index af49b7697..3a4b12ae3 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1365,20 +1365,19 @@ using Zygote: Buffer prod(copy(b)) end == (3,) - @testset "Limited Mutation" begin - p = [rand(3,3), rand(3,3)] - r = rand(5,5) - - # TODO: ngradient cannot handle Vector{Array} - gs = gradient((p,x) -> sum(sum.(push!(p,x))), p, r) - @test length(p[end]) == length(gs[1][end]) - @test gs[1] ≈ map(x -> one.(x), p) - @test gs[2] ≈ one.(r) - - # p = [rand(3,3), rand(3,3)] # redefine `p` after mutation - # gs = gradient(x -> sum(pop!(x)), p) - # @test length(gs[1]) == 2 - # @test gs[1][1] == one.(p[1]) + @testset "limited mutation" begin + # push! into vectors of arrays -- it returns the whole new vector + @test gradient((xs, y) -> sum(abs2, push!(xs, y)[1]), [[1,2], [3,4]], [5,6]) == ([[2, 4], nothing], nothing) + @test gradient((xs, y) -> sum(abs2, push!(xs, y)[2]), [[1,2], [3,4]], [5,6]) == ([nothing, [6, 8]], nothing) + @test gradient((xs, y) -> sum(abs2, push!(xs, y)[3]), [[1,2], [3,4]], [5,6]) == ([nothing, nothing], [10, 12]) + + # multiple arguments + gradient((xs, y, z) -> 3 * sum(push!(xs, y, z)[1]), [ones(2,2)], ones(2,2), ones(2,2)) == ([fill(3,2,2)], nothing, nothing) + gradient((xs, y, z) -> 4 * sum(push!(xs, y, z)[2]), [ones(2,2)], ones(2,2), ones(2,2)) == ([nothing], fill(4,2,2), nothing) + gradient((xs, y, z) -> 5 * sum(push!(xs, y, z)[3]), [ones(2,2)], ones(2,2), ones(2,2)) == ([nothing], nothing, fill(5,2,2)) + + # pop! of vectors of arrays -- it returns only the removed element + @test gradient(xs -> sum(abs2, pop!(xs)), [[1,2], [3,4]]) == ([nothing, [6, 8]],) end end From df7d0697e329b6589310d7f760647880cda0d99a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 12 Jul 2021 01:12:32 -0400 Subject: [PATCH 2/9] tweak --- src/lib/array.jl | 28 ++-------------------------- test/gradcheck.jl | 11 ++++++++--- 2 files changed, 10 insertions(+), 29 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index ff1adfd7b..ec21bb96d 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -79,42 +79,18 @@ for f in [push!, pop!, pushfirst!, popfirst!] end # Exceptions: -@adjoint! function push!(dst::AbstractVector{<:AbstractArray}, xs::AbstractArray...) +@adjoint! function push!(dst::AbstractVector, xs::AbstractArray...) num_xs = length(xs) push!(dst, xs...), Δ -> begin (Δ[1:end-num_xs], Δ[end-num_xs+1:end]...) end end -@adjoint! function pop!(src::AbstractVector{<:AbstractArray{T}}) where T +@adjoint! function pop!(src::AbstractVector{<:AbstractArray}) zs = fill(nothing, length(src)-1) pop!(src), Δ -> (vcat(zs, [Δ]),) end -#= - - -julia> gradient((xs, y) -> sum(abs2, push!(xs, y)[1]), [[1,2], [3,4]], [5,6]) -([[2, 4], nothing, nothing], 2-element Ones{Int64}) -([[2, 4], nothing], nothing) # correct - -julia> gradient((xs, y) -> sum(abs2, push!(xs, y)[2]), [[1,2], [3,4]], [5,6]) -([nothing, [6, 8], nothing], 2-element Ones{Int64}) -([nothing, [6, 8]], nothing) # correct - -julia> gradient((xs, y) -> sum(abs2, push!(xs, y)[3]), [[1,2], [3,4]], [5,6]) -([nothing, nothing, [10, 12]], 2-element Ones{Int64}) -([nothing, nothing], [10, 12]) # correct - -(jl_aA0cjy) pkg> st Zygote - Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_aA0cjy/Project.toml` - [e88e6eb3] Zygote v0.6.14 - -julia> gradient(xs -> sum(abs2, pop!(xs)), [[1,2], [3,4]]) -(Union{Nothing, Vector{Int64}}[nothing, nothing, [6, 8]],) - -=# - # General @adjoint collect(x::Array) = collect(x), Δ -> (Δ,) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 3a4b12ae3..590e2fc95 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1372,12 +1372,17 @@ using Zygote: Buffer @test gradient((xs, y) -> sum(abs2, push!(xs, y)[3]), [[1,2], [3,4]], [5,6]) == ([nothing, nothing], [10, 12]) # multiple arguments - gradient((xs, y, z) -> 3 * sum(push!(xs, y, z)[1]), [ones(2,2)], ones(2,2), ones(2,2)) == ([fill(3,2,2)], nothing, nothing) - gradient((xs, y, z) -> 4 * sum(push!(xs, y, z)[2]), [ones(2,2)], ones(2,2), ones(2,2)) == ([nothing], fill(4,2,2), nothing) - gradient((xs, y, z) -> 5 * sum(push!(xs, y, z)[3]), [ones(2,2)], ones(2,2), ones(2,2)) == ([nothing], nothing, fill(5,2,2)) + @test gradient((xs, y, z) -> 3 * sum(push!(xs, y, z)[1]), [ones(2,2)], ones(2,2), ones(2,2)) == ([fill(3,2,2)], nothing, nothing) + @test gradient((xs, y, z) -> 4 * sum(push!(xs, y, z)[2]), [ones(2,2)], ones(2,2), ones(2,2)) == ([nothing], fill(4,2,2), nothing) + @test gradient((xs, y, z) -> 5 * sum(push!(xs, y, z)[3]), [ones(2,2)], ones(2,2), ones(2,2)) == ([nothing], nothing, fill(5,2,2)) + + # Vector{Any} + @test gradient(x -> sum(abs2, only(push!([], x))), [1 2; 3 4]) == ([2 4; 6 8],) + @test_throws ErrorException gradient(x -> sum(abs2, push!([], x)), 1) # pop! of vectors of arrays -- it returns only the removed element @test gradient(xs -> sum(abs2, pop!(xs)), [[1,2], [3,4]]) == ([nothing, [6, 8]],) + @test_throws ErrorException gradient(xs -> pop!(xs), [1,2,3]) end end From 56b966f55db0adc29140eb7b75362a0e477935a4 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 12 Jul 2021 02:52:29 -0400 Subject: [PATCH 3/9] allow only trivial gradients in push!(::Params) etc. --- src/compiler/interface.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 9dc934a49..7dc19ee37 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -186,14 +186,18 @@ end @adjoint! function Base.push!(xs::IdSet, x...) l = length(x) push!(xs, x...), Δ -> begin + Δ == nothing && return nothing + println("got nontrivial gradient for push!(::IdSet, ...): Δ = ", Δ) (Δ, ntuple(_ -> nothing, l)...) end end -@adjoint! function Base.push!(xs::Params, x::AbstractArray{T}...) where T +@adjoint! function Base.push!(xs::Params, x::AbstractArray...) sz_x = size.(x) push!(xs, x...), Δ -> begin - (Δ, map(x -> Ones{T}(x...), sz_x)...) + Δ == nothing && return nothing + println("got nontrivial gradient for push!(::Params, ...): Δ = ", Δ) + # (Δ, map(x -> Ones{T}(x...), sz_x)...) # don't think this is correct end end From dc1d15ff70a235f9504df5417de408433db22986 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 12 Jul 2021 03:25:34 -0400 Subject: [PATCH 4/9] generalise, and fail --- src/lib/array.jl | 16 ++++++++++++---- test/gradcheck.jl | 28 +++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index ec21bb96d..18127adc5 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -79,17 +79,25 @@ for f in [push!, pop!, pushfirst!, popfirst!] end # Exceptions: -@adjoint! function push!(dst::AbstractVector, xs::AbstractArray...) - num_xs = length(xs) +@adjoint! function push!(dst::AbstractVector, xs...) + # num_xs = length(xs) + n = length(xs) + valn = Val(n) push!(dst, xs...), Δ -> begin - (Δ[1:end-num_xs], Δ[end-num_xs+1:end]...) + # (Δ[1:end-num_xs], Δ[end-num_xs+1:end]...) + (Δ[1:end-n], ntuple(i -> Δ[end-n+i], valn)...) end end -@adjoint! function pop!(src::AbstractVector{<:AbstractArray}) +@adjoint! function pop!(src::AbstractVector) zs = fill(nothing, length(src)-1) pop!(src), Δ -> (vcat(zs, [Δ]),) end +@adjoint! function pop!(src::AbstractVector{<:Number}) + zs = falses(length(src)-1) + pop!(src), Δ -> (vcat(zs, Δ),) +end + # General diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 590e2fc95..5ae119eac 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1365,12 +1365,26 @@ using Zygote: Buffer prod(copy(b)) end == (3,) - @testset "limited mutation" begin - # push! into vectors of arrays -- it returns the whole new vector + @testset "push!" begin # push! returns the whole new vector + # vector of numbers + @test gradient((xs, y) -> push!(xs,y)[1], [1,2,3], 4) == ([1, 0, 0], 0) + @test gradient((xs, y) -> push!(xs,y)[end], [1,2,3], 4) == ([0, 0, 0], 1) + + @test_skip gradient([1,2,3], 4) do xs, y + z = sum(xs) + z + sum(push!(xs, y)) + end # DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 3 and 4") + + # push! into vectors of arrays @test gradient((xs, y) -> sum(abs2, push!(xs, y)[1]), [[1,2], [3,4]], [5,6]) == ([[2, 4], nothing], nothing) @test gradient((xs, y) -> sum(abs2, push!(xs, y)[2]), [[1,2], [3,4]], [5,6]) == ([nothing, [6, 8]], nothing) @test gradient((xs, y) -> sum(abs2, push!(xs, y)[3]), [[1,2], [3,4]], [5,6]) == ([nothing, nothing], [10, 12]) + @test_skip gradient([[1,2], [3,4]], [5,6]) do xs, y + z = sum(sum(abs2, x) for x in xs) + z + sum(sum, push!(xs, y)) + end # DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 2 and 3") + # multiple arguments @test gradient((xs, y, z) -> 3 * sum(push!(xs, y, z)[1]), [ones(2,2)], ones(2,2), ones(2,2)) == ([fill(3,2,2)], nothing, nothing) @test gradient((xs, y, z) -> 4 * sum(push!(xs, y, z)[2]), [ones(2,2)], ones(2,2), ones(2,2)) == ([nothing], fill(4,2,2), nothing) @@ -1380,7 +1394,15 @@ using Zygote: Buffer @test gradient(x -> sum(abs2, only(push!([], x))), [1 2; 3 4]) == ([2 4; 6 8],) @test_throws ErrorException gradient(x -> sum(abs2, push!([], x)), 1) - # pop! of vectors of arrays -- it returns only the removed element + end + @testset "pop!" begin # pop! returns only the removed element + @test gradient(xs -> pop!(xs)^2, [1,2,3]) == ([0,0,6],) + @test_skip gradient([1,2,3], 4) do xs, y + z = pop!(xs) + y^2 + z + sum(abs2, xs) + end # DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 2 and 3") + + # pop! of vectors of arrays @test gradient(xs -> sum(abs2, pop!(xs)), [[1,2], [3,4]]) == ([nothing, [6, 8]],) @test_throws ErrorException gradient(xs -> pop!(xs), [1,2,3]) end From 4b3e44864c7c49d189a33fc3a859cdec2249e0fe Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 12 Jul 2021 08:34:06 -0400 Subject: [PATCH 5/9] fix --- src/lib/array.jl | 1 + test/features.jl | 8 ++++---- test/gradcheck.jl | 10 ++++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 18127adc5..3fb94919c 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -84,6 +84,7 @@ end n = length(xs) valn = Val(n) push!(dst, xs...), Δ -> begin + Δ === nothing && return nothing # (Δ[1:end-num_xs], Δ[end-num_xs+1:end]...) (Δ[1:end-n], ntuple(i -> Δ[end-n+i], valn)...) end diff --git a/test/features.jl b/test/features.jl index d683d0d94..c156b5d9c 100644 --- a/test/features.jl +++ b/test/features.jl @@ -351,10 +351,10 @@ end @test Zygote.@code_adjoint(f(1)) isa Zygote.Adjoint -@test_throws ErrorException Zygote.gradient(1) do x - push!([], x) - return x -end +# @test_throws ErrorException Zygote.gradient(1) do x +# push!([], x) +# return x +# end @test gradient(1) do x stk = [] diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 5ae119eac..b23cce231 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1371,8 +1371,10 @@ using Zygote: Buffer @test gradient((xs, y) -> push!(xs,y)[end], [1,2,3], 4) == ([0, 0, 0], 1) @test_skip gradient([1,2,3], 4) do xs, y - z = sum(xs) - z + sum(push!(xs, y)) + a = sum(xs) + b = sum(push!(xs, y)) + c = sum(xs) + a+b+c end # DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 3 and 4") # push! into vectors of arrays @@ -1392,7 +1394,7 @@ using Zygote: Buffer # Vector{Any} @test gradient(x -> sum(abs2, only(push!([], x))), [1 2; 3 4]) == ([2 4; 6 8],) - @test_throws ErrorException gradient(x -> sum(abs2, push!([], x)), 1) + # @test_throws ErrorException gradient(x -> sum(abs2, push!([], x)), 1) end @testset "pop!" begin # pop! returns only the removed element @@ -1404,7 +1406,7 @@ using Zygote: Buffer # pop! of vectors of arrays @test gradient(xs -> sum(abs2, pop!(xs)), [[1,2], [3,4]]) == ([nothing, [6, 8]],) - @test_throws ErrorException gradient(xs -> pop!(xs), [1,2,3]) + # @test_throws ErrorException gradient(xs -> pop!(xs), [1,2,3]) end end From 182a30ec0c8fc23375899fdb500ddd6b5904fe34 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 18 Jul 2021 01:08:52 -0400 Subject: [PATCH 6/9] rm gradients which don't work --- src/lib/array.jl | 22 ---------------------- test/gradcheck.jl | 44 -------------------------------------------- 2 files changed, 66 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 3fb94919c..ce871c23b 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -78,28 +78,6 @@ for f in [push!, pop!, pushfirst!, popfirst!] _ -> error("Mutating arrays is not supported -- called $($f)(::$(typeof(xs)), _...)") end -# Exceptions: -@adjoint! function push!(dst::AbstractVector, xs...) - # num_xs = length(xs) - n = length(xs) - valn = Val(n) - push!(dst, xs...), Δ -> begin - Δ === nothing && return nothing - # (Δ[1:end-num_xs], Δ[end-num_xs+1:end]...) - (Δ[1:end-n], ntuple(i -> Δ[end-n+i], valn)...) - end -end - -@adjoint! function pop!(src::AbstractVector) - zs = fill(nothing, length(src)-1) - pop!(src), Δ -> (vcat(zs, [Δ]),) -end -@adjoint! function pop!(src::AbstractVector{<:Number}) - zs = falses(length(src)-1) - pop!(src), Δ -> (vcat(zs, Δ),) -end - - # General @adjoint collect(x::Array) = collect(x), Δ -> (Δ,) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index b23cce231..87fe5f46f 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1365,50 +1365,6 @@ using Zygote: Buffer prod(copy(b)) end == (3,) - @testset "push!" begin # push! returns the whole new vector - # vector of numbers - @test gradient((xs, y) -> push!(xs,y)[1], [1,2,3], 4) == ([1, 0, 0], 0) - @test gradient((xs, y) -> push!(xs,y)[end], [1,2,3], 4) == ([0, 0, 0], 1) - - @test_skip gradient([1,2,3], 4) do xs, y - a = sum(xs) - b = sum(push!(xs, y)) - c = sum(xs) - a+b+c - end # DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 3 and 4") - - # push! into vectors of arrays - @test gradient((xs, y) -> sum(abs2, push!(xs, y)[1]), [[1,2], [3,4]], [5,6]) == ([[2, 4], nothing], nothing) - @test gradient((xs, y) -> sum(abs2, push!(xs, y)[2]), [[1,2], [3,4]], [5,6]) == ([nothing, [6, 8]], nothing) - @test gradient((xs, y) -> sum(abs2, push!(xs, y)[3]), [[1,2], [3,4]], [5,6]) == ([nothing, nothing], [10, 12]) - - @test_skip gradient([[1,2], [3,4]], [5,6]) do xs, y - z = sum(sum(abs2, x) for x in xs) - z + sum(sum, push!(xs, y)) - end # DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 2 and 3") - - # multiple arguments - @test gradient((xs, y, z) -> 3 * sum(push!(xs, y, z)[1]), [ones(2,2)], ones(2,2), ones(2,2)) == ([fill(3,2,2)], nothing, nothing) - @test gradient((xs, y, z) -> 4 * sum(push!(xs, y, z)[2]), [ones(2,2)], ones(2,2), ones(2,2)) == ([nothing], fill(4,2,2), nothing) - @test gradient((xs, y, z) -> 5 * sum(push!(xs, y, z)[3]), [ones(2,2)], ones(2,2), ones(2,2)) == ([nothing], nothing, fill(5,2,2)) - - # Vector{Any} - @test gradient(x -> sum(abs2, only(push!([], x))), [1 2; 3 4]) == ([2 4; 6 8],) - # @test_throws ErrorException gradient(x -> sum(abs2, push!([], x)), 1) - - end - @testset "pop!" begin # pop! returns only the removed element - @test gradient(xs -> pop!(xs)^2, [1,2,3]) == ([0,0,6],) - @test_skip gradient([1,2,3], 4) do xs, y - z = pop!(xs) + y^2 - z + sum(abs2, xs) - end # DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 2 and 3") - - # pop! of vectors of arrays - @test gradient(xs -> sum(abs2, pop!(xs)), [[1,2], [3,4]]) == ([nothing, [6, 8]],) - # @test_throws ErrorException gradient(xs -> pop!(xs), [1,2,3]) - end - end @testset "FillArrays" begin From 1d1fbb25e728f111fe516410e1ad09dc660ce01d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 18 Jul 2021 01:09:35 -0400 Subject: [PATCH 7/9] rm unused methods from push(IdSet) gradient --- src/compiler/interface.jl | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 7dc19ee37..f528d02f3 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -184,21 +184,15 @@ function Base.push!(ps::Params, x) end @adjoint! function Base.push!(xs::IdSet, x...) - l = length(x) - push!(xs, x...), Δ -> begin - Δ == nothing && return nothing - println("got nontrivial gradient for push!(::IdSet, ...): Δ = ", Δ) - (Δ, ntuple(_ -> nothing, l)...) - end + back(::Nothing) = nothing + back(Δ) = error("can't handle nontrivial gradient for push!(::IdSet, ...): Δ = " * repr(Δ)) + push!(xs, x...), back end @adjoint! function Base.push!(xs::Params, x::AbstractArray...) - sz_x = size.(x) - push!(xs, x...), Δ -> begin - Δ == nothing && return nothing - println("got nontrivial gradient for push!(::Params, ...): Δ = ", Δ) - # (Δ, map(x -> Ones{T}(x...), sz_x)...) # don't think this is correct - end + back(::Nothing) = nothing + back(Δ) = error("can't handle nontrivial gradient for push!(::Params, ...): Δ = " * repr(Δ)) + push!(xs, x...), back end Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps) From f906ef1b603acdf9fea472f5c94bacf19d9dd138 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 18 Jul 2021 01:42:01 -0400 Subject: [PATCH 8/9] restrict push error to arrays, rm adjoint for params --- src/compiler/interface.jl | 12 ------------ src/lib/array.jl | 4 ++-- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index f528d02f3..e210e65b6 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -183,18 +183,6 @@ function Base.push!(ps::Params, x) return ps end -@adjoint! function Base.push!(xs::IdSet, x...) - back(::Nothing) = nothing - back(Δ) = error("can't handle nontrivial gradient for push!(::IdSet, ...): Δ = " * repr(Δ)) - push!(xs, x...), back -end - -@adjoint! function Base.push!(xs::Params, x::AbstractArray...) - back(::Nothing) = nothing - back(Δ) = error("can't handle nontrivial gradient for push!(::Params, ...): Δ = " * repr(Δ)) - push!(xs, x...), back -end - Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps) function Base.delete!(ps::Params, x) diff --git a/src/lib/array.jl b/src/lib/array.jl index ce871c23b..035a1b239 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -74,8 +74,8 @@ _droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::Abstra _ -> error("Mutating arrays is not supported -- called copyto!(::$(typeof(xs)), _...)") for f in [push!, pop!, pushfirst!, popfirst!] - @eval @adjoint! $f(xs, x...) = $f(xs, x...), - _ -> error("Mutating arrays is not supported -- called $($f)(::$(typeof(xs)), _...)") + @eval @adjoint! $f(x::AbstractVector, ys...) = $f(x, ys...), + _ -> error("Mutating arrays is not supported -- called $($f)(::$(typeof(x)), _...)") end # General From f9d17a142a347a7cf99227e02091914fdd113ba3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 28 Sep 2021 19:16:57 -0400 Subject: [PATCH 9/9] Update test/features.jl Co-authored-by: Brian Chen --- test/features.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/features.jl b/test/features.jl index c156b5d9c..d683d0d94 100644 --- a/test/features.jl +++ b/test/features.jl @@ -351,10 +351,10 @@ end @test Zygote.@code_adjoint(f(1)) isa Zygote.Adjoint -# @test_throws ErrorException Zygote.gradient(1) do x -# push!([], x) -# return x -# end +@test_throws ErrorException Zygote.gradient(1) do x + push!([], x) + return x +end @test gradient(1) do x stk = []