
Make bias optional #873

Merged (41 commits) on May 1, 2020
Conversation

DhairyaLGandhi (Member)

Addresses #868

pshashk (Contributor) commented Sep 27, 2019

Thank you for your prompt PR. For completeness' sake, this option could also be added to the dense, recurrent, and normalization layers (to disable the affine transformation after normalization).
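A minimal sketch of the normalization case (an editor's illustration rather than anything in this PR; toy_norm, affine, and ϵ are made-up names): the learned scale/shift after normalization sits behind a flag and is simply skipped when disabled.

function toy_norm(x, γ, β; affine = true, ϵ = 1f-5)
    μ  = sum(x) / length(x)
    σ² = sum(abs2, x .- μ) / length(x)
    x̂  = (x .- μ) ./ sqrt(σ² + ϵ)
    affine ? γ .* x̂ .+ β : x̂        # affine = false drops the learned scale/shift
end

toy_norm(randn(Float32, 8), 1f0, 0f0; affine = false)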

end

function Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
-             stride = 1, pad = 0, dilation = 1) where {T,N}
+             stride = 1, pad = 0, dilation = 1, use_bias = true) where {T,N}
Member:

It is weird though that when calling this constructor with use_bias=false one has to pass a vector b as well.
I would suggest the following non-breaking change instead of the use_bias flag:

  • relax the signature to
Conv(w::AbstractArray{T,N}, b::Union{Nothing,AbstractVector{T}}, σ = identity; ...)

and have a call to

Conv(w, nothing)

construct a Conv layer with no bias
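To make the suggestion concrete, here is a rough stand-in (MiniConv is an illustrative toy, not Flux's Conv): relaxing the positional b to Union{Nothing,AbstractVector{T}} keeps existing calls working while letting nothing mean "no bias".

struct MiniConv{W,B,F}
    weight::W
    bias::B      # an AbstractVector, or nothing for a bias-free layer
    σ::F
end

MiniConv(w::AbstractArray{T,N}, b::Union{Nothing,AbstractVector{T}}, σ = identity) where {T,N} =
    MiniConv{typeof(w),typeof(b),typeof(σ)}(w, b, σ)

MiniConv(randn(Float32, 3, 3, 1, 4), nothing)    # constructs a layer with no bias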

DhairyaLGandhi (Member, Author):

Should be able to support Conv(w, nothing) and make it a bit more extensible now, I think

MikeInnes (Member)

How about if we did this by setting the bias to nothing?

We could even use e.g. some kind of Zero type that does the right thing when you broadcast over it; then we can automatically use that in a bunch of layers, rather than modifying each one.
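For concreteness, one way such a type could behave under broadcasting (an editor's sketch of the idea, not code from this PR; the name Zero is illustrative only):

struct Zero end

Base.broadcastable(::Zero) = Ref(Zero())   # behave like a scalar inside broadcasts
Base.:+(a::Number, ::Zero) = a
Base.:+(::Zero, a::Number) = a

[1.0f0, 2.0f0] .+ Zero()                   # returns the left operand unchanged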

pshashk (Contributor) commented Sep 27, 2019

Passing nothing as weight initializer (stuff like initW, initb, initβ, initγ) seems like the most flexible solution to me.

DhairyaLGandhi (Member, Author)

Yeah, holding the bias felt awkward, to say the least. I had wanted to keep the signature more or less the same to avoid breaking code that depended on that behavior. I had also previously tried just setting the bias to zero(eltype(T)), but I felt the extra compute could be avoided.

I'll modify this to use nothing instead; that looks clean.

mcabbott (Member) commented Sep 27, 2019

I was going to suggest setting the bias to false, as this ought to drop out of σ.(conv(x, c.weight, cdims) .+ b) when applying the layer, without needing further checking. Constructing this as Conv(…, bias=false) would then be both nice and accurate.

Would there be much extra compute, given that you are already broadcasting σ = identity? Or perhaps, if we're going to add logic here, skipping the broadcast when σ === identity && b === false would be the more important case.

Some commit of #856 had bias=false for Dense, but I didn’t get around to actually testing that on GPU.
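The false trick is easy to check in plain Julia (nothing Flux-specific assumed): adding false is exact and never widens the element type, so the bias term really does drop out of the broadcast.

x = randn(Float32, 4)
x .+ false == x        # true: adding `false` leaves every element unchanged
eltype(x .+ false)     # Float32: no promotion to a wider type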

DhairyaLGandhi (Member, Author)

We could introduce something like

import Base: +, reshape

struct ZeroType <: Number end
+(a::Number, ::ZeroType) = a
reshape(::ZeroType, args...) = ZeroType()

And use this whenever we need it for similar use cases. The constructor can just accept this type or a Nothing and we can get the bias (or any switch) that way
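Assuming the ZeroType sketch above, the existing forward-pass pattern would go through unchanged; a quick illustrative check:

y = randn(Float32, 5, 5, 3, 1)                # stand-in for a convolution output
y .+ reshape(ZeroType(), 1, 1, 3, 1) == y     # true: the reshape and the add are both no-ops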

DhairyaLGandhi (Member, Author)

With the false approach, how do we handle the call to reshape in the forward pass?

mcabbott (Member)

Good point about reshape, and [false] is ugly, so I change my vote to branching on b === nothing.
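A sketch of that branching in the forward pass (an editor's illustration with a made-up helper, assuming a 4-D WHCN convolution output y and a channel bias vector b):

function apply_bias(y, b, σ)
    b === nothing && return σ.(y)               # no bias: skip the reshape and the add
    σ.(y .+ reshape(b, 1, 1, length(b), 1))     # otherwise add the bias channel-wise
end

apply_bias(randn(Float32, 5, 5, 3, 1), nothing, tanh)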

DhairyaLGandhi (Member, Author)

Now we should be able to construct Conv(w, nothing), Conv(w, ZeroType(...)) and Conv((2,2), 1=>3, use_bias = false)

MikeInnes (Member)

The Conv(weight, bias) should be fine as is; how about the following setup for the convenience constructor:

Conv((k, k), in=>out; weight = convweight((k, k), in=>out), bias = convbias(out)) = ...

Then we can do Conv((2, 2), 3=>6, bias = zero). And as a bonus it's easy to generate the weight matrix outside of the conv layer, which is a nice utility to have.
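A sketch of that keyword pattern (editor's illustration; convweight and convbias are the hypothetical helper names from the comment, and the returned named tuple stands in for building the real layer):

convweight((k1, k2), ch) = 0.1f0 .* randn(Float32, k1, k2, ch.first, ch.second)
convbias(out) = zeros(Float32, out)

toy_conv(k, ch; weight = convweight(k, ch), bias = convbias(ch.second), σ = identity) =
    (weight = weight, bias = bias, σ = σ)

toy_conv((2, 2), 3 => 6)                   # default weight and bias
toy_conv((2, 2), 3 => 6, bias = nothing)   # opt out of the bias entirely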

DhairyaLGandhi (Member, Author)

For people who want to provide their own weight matrices, Conv(weight, bias) is the more intuitive constructor, and I like having convweight to help that constructor more than the convenience one, as it makes the API a bit symmetric. Conv(convweight((k,k), in=>out), convbias(out))

A weight kwarg might make it harder for folks to find Conv(weight, bias), so I will document that constructor explicitly.

I also worry that if provided, it effectively makes the first few args to the convenience constructor redundant.

janEbert (Contributor) commented Oct 7, 2019

This ought to be documented in the "convenience" constructor as well! Great change, too; I always wrote silly extra constructors until now.

MikeInnes (Member) left a review comment:

Looking good!

op = bias(ip)
@test sum(op) == prod(size(op))

bias = Conv((2,2), 1=>3, bias = zero(3))
Member:

Why use the zero function here? Also, are we still interested in a zero type?

DhairyaLGandhi (Member, Author):

Ah, was passing bias = zero (from #873 (comment)) meant to point to the zero type?

MikeInnes (Member) commented on Oct 7, 2019:

yeah, sorry, I was imagining a definition like zero = Zero(). That name might not work given that zero exists as a function, but I think we'll need a new type to guarantee that it won't be updated by optimisers.

DhairyaLGandhi (Member, Author):

Right, in that case I'll bring the definition of ZeroType back and we can have the guarantee baked into that.

src/utils.jl
res .= zero(S)
end

function *(a::Zeros{T,2}, b::AbstractArray{S,2}) where {T,S}
Member:

What's the motivation for this method?

DhairyaLGandhi (Member, Author):

Largely to define matmul here to avoid scalar operations, iirc

DhairyaLGandhi (Member, Author):

Flux.Zeros(3000,3000) * rand(3000,3000)

would hit generic_matmul! here, which is slow too:

julia> @btime $(zeros(3000,3000)) * $(rand(3000,3000));
  564.237 ms (2 allocations: 68.66 MiB)

julia> @btime $(Flux.Zeros(3000,3000)) * $(rand(3000,3000));
  7.953 ms (6 allocations: 68.66 MiB)
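For readers without the diff open, the gist of the specialised method can be reproduced with a stand-in type (MyZeros is illustrative, not the PR's Flux.Zeros): the product of a zero matrix with anything is known up front, so the result can be filled directly instead of going through generic_matmul!.

struct MyZeros{T,N} <: AbstractArray{T,N}
    dims::NTuple{N,Int}
end
MyZeros(dims::Int...) = MyZeros{Float32,length(dims)}(dims)

Base.size(z::MyZeros) = z.dims
Base.getindex(z::MyZeros{T}, ::Int...) where {T} = zero(T)

# Zero matrix times anything is zero: allocate the zero result and skip the multiply.
Base.:*(a::MyZeros{T,2}, b::AbstractMatrix{S}) where {T,S} =
    zeros(promote_type(T, S), size(a, 1), size(b, 2))

MyZeros(3000, 3000) * rand(3000, 3000)     # avoids the O(n^3) fallback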

Member:

That makes sense, I was just surprised that this method came up at all since Zeros() usually gets used as a bias. Did this come up from trying to support Dense layers with Zeros() for the weight?

If so it'd be good to discuss whether that's something we want to support anyway; it seems like an odd thing to want to do.

DhairyaLGandhi (Member, Author):

Partly to cover our ground in case it ends up being used outside of the context of a bias, and to cover the basic ops so it doesn't end up being accidentally slow

Member:

There's an unlimited number of methods we could special case like this though, right? Given that we can't add everything I think we need to decide what we're supporting and why. For broadcasting + and - that's obvious enough (it covers the bias case) but I don't know that we want to implement every matmul method without a good use case for it.

DhairyaLGandhi (Member, Author):

Fair point

DhairyaLGandhi (Member, Author)

@MikeInnes I added the kwarg-only constructors, and made some docs corrections too. Perhaps good to get this in for now?

DhairyaLGandhi (Member, Author) commented Feb 26, 2020

The error happens with or without the bias machinery; perhaps it's Zygote-related.

src/utils.jl

function broadcasted(::typeof(*), a::AbstractArray, b::Zeros)
  sz = similar(a, Broadcast.broadcast_shape(size(a), size(b)))
  sz .= zero(a)
Member:

I have the same concern with this as with the methods above, i.e. if we're going to allocate anyway we should just be able to use the built-in fallback.

DhairyaLGandhi (Member, Author):

We could also potentially just have it return a correctly shaped Zeros object to avoid allocating.

DhairyaLGandhi (Member, Author):

But I understand the concern, and I agree that that shouldn't be necessary

Member:

Yeah, that seems like the best route for this particular method.
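Building on the stand-in MyZeros type sketched earlier in this thread, the non-allocating route agreed on here could look roughly like this (editor's illustration, not the merged code):

import Base.Broadcast: broadcasted, broadcast_shape

function broadcasted(::typeof(*), a::AbstractArray{T}, b::MyZeros) where {T}
    sz = broadcast_shape(size(a), size(b))    # shape the broadcast would produce
    MyZeros{T,length(sz)}(sz)                 # return a lazy zero of that shape, no allocation
end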


CarloLucibello (Member)

Why do we need to define our own types when we can use FillArrays.Zeros? FillArrays is already used by Zygote.

It's very hard to review this PR, since it comes with a lot of unrelated doc additions.

MikeInnes (Member)

FillArrays is already used by Zygote.

I think the reasoning is that we want to represent not just an array of zeros, but an array which is held fixed with respect to gradient descent. There are a couple of ways to do this; either Zygote or FillArrays could decide that the gradient of a Fill is zero, but it's a bit hard to justify that outside of Flux (and probably breaks other uses of Fill). Alternatively Flux/Optimisers.jl could decide that Fill is a special case and doesn't get updated regardless of its gradient. This could again be a bit surprising and it's not clear how we'd document it (the type itself is the obvious choice, but we don't own that).

import Flux: Zeros sends a clear signal that this is a zero type that Flux knows about, which hopefully makes its behaviour feel more intuitively obvious, if not quite self-documenting.
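To make "held fixed with respect to gradient descent" concrete, here is a toy illustration only (not how this PR or Flux's optimisers actually implement it): a dedicated zero type gives a natural place to attach a do-nothing update rule.

import Flux   # assumes the Flux.Zeros type introduced by this PR

sgd_step!(x::AbstractArray, x̄; η = 0.1f0) = (x .-= η .* x̄; x)   # toy gradient-descent step
sgd_step!(x::Flux.Zeros, x̄; η = 0.1f0) = x                      # a Zeros bias is never updated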

MikeInnes (Member)

bors r+

bors bot (Contributor) commented May 1, 2020

Build succeeded:
