Make bias optional #873
Conversation
Thank you for your prompt PR. For completeness' sake, this option could also be added to the dense, recurrent, and normalization layers (to disable the affine transformation after normalization).
src/layers/conv.jl
Outdated
```diff
 end

 function Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
-              stride = 1, pad = 0, dilation = 1) where {T,N}
+              stride = 1, pad = 0, dilation = 1, use_bias = true) where {T,N}
```
It is weird, though, that when calling this constructor with `use_bias=false` one has to pass a vector `b` as well. I would suggest the following non-breaking change instead of the `use_bias` flag:

- relax the signature to `Conv(w::AbstractArray{T,N}, b::Union{Nothing,AbstractVector{T}}, σ = identity; ...)` and have a call to `Conv(w, nothing)` construct a Conv layer with no bias
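To make that concrete, here is a minimal self-contained sketch of a layer whose bias can be an `AbstractVector` or `nothing` (the `SimpleConv` name and its forward pass are illustrative assumptions, not Flux's actual `Conv`):

```julia
using NNlib: conv

# Illustrative layer, not Flux's Conv: the bias field holds either a vector or `nothing`.
struct SimpleConv{W,B,F}
    weight::W
    bias::B
    σ::F
end

SimpleConv(w::AbstractArray, b::Union{Nothing,AbstractVector}) = SimpleConv(w, b, identity)

function (c::SimpleConv)(x)
    y = conv(x, c.weight)
    c.bias === nothing && return c.σ.(y)
    # reshape so the bias broadcasts along the channel dimension
    b = reshape(c.bias, ntuple(_ -> 1, ndims(y) - 2)..., :, 1)
    return c.σ.(y .+ b)
end

layer = SimpleConv(randn(Float32, 3, 3, 1, 2), nothing)  # no bias at all
x = randn(Float32, 28, 28, 1, 4)
size(layer(x))  # (26, 26, 2, 4)
```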
Should be able to support `Conv(w, nothing)` and make it a bit more extensible now, I think.
How about if we did this by setting the bias to `zero`? We could even use e.g. some kind of `Zero` type.
Yeah, holding the bias felt awkward to say the least. I had wanted to keep the signature more or less the same to avoid breaking code that depended on that behaviour. I had also previously tried just setting the bias to a `ZeroType`. I'll modify this to take a `nothing` instead; that looks clean.
I was going to suggest setting the bias to zero too. Would there be much extra computing, given that you are already broadcasting? Some commit of #856 had something along those lines.
We could introduce something like a dedicated `Zeros` type and use it whenever we need it for similar use cases. The constructor can just accept this type or a regular vector.
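Roughly, such a type could look like this (a sketch with assumed names and behaviour, not necessarily what the PR ends up with):

```julia
# An array-like stand-in for an all-zero bias: adding it is a no-op,
# so `σ.(y .+ bias)` costs nothing extra when the bias is "off".
struct Zeros{T,N} <: AbstractArray{T,N}
    size::NTuple{N,Int}
end
Zeros(dims::Integer...) = Zeros{Float32,length(dims)}(dims)

Base.size(z::Zeros) = z.size
Base.getindex(z::Zeros{T}, ::Integer...) where {T} = zero(T)

# Broadcasting `+` with a Zeros bias just returns the other operand untouched.
Base.Broadcast.broadcasted(::typeof(+), a::AbstractArray, ::Zeros) = a
```

Because `y .+ Zeros(...)` simply returns `y`, the forward pass keeps a single code path whether or not there is a bias.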
Good point about the reshape.
Now we should be able to construct layers without a bias.
The convenience constructor could take keyword arguments with defaults, something like `Conv((k, k), in=>out; weight = convweight((k, k), in=>out), bias = convbias(out)) = ...`. Then we can override whichever of the two we care about.
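Spelled out, the suggestion is roughly the following sketch (reusing the illustrative `SimpleConv` above; `convweight`/`convbias` are hypothetical default initialisers, not existing Flux functions):

```julia
# Hypothetical default initialisers for weight and bias.
convweight(k::NTuple{2,Integer}, ch::Pair) = randn(Float32, k..., ch.first, ch.second)
convbias(out::Integer) = zeros(Float32, out)

# Keyword-based convenience constructor: either default can be overridden independently.
SimpleConv(k::NTuple{2,Integer}, ch::Pair, σ = identity;
           weight = convweight(k, ch), bias = convbias(ch.second)) =
    SimpleConv(weight, bias, σ)

SimpleConv((2, 2), 1 => 3)                  # defaults for both
SimpleConv((2, 2), 1 => 3, bias = nothing)  # keep the default weight, drop the bias
```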
For people who want to provide their own weight matrices this could be useful. I also worry that if a weight is provided, it effectively makes the first few args to the convenience constructor redundant.
This ought to be documented in the "convenience" constructor as well! Great change, too; I always wrote silly extra constructors until now.
Looking good!
test/layers/conv.jl
Outdated
```julia
op = bias(ip)
@test sum(op) == prod(size(op))

bias = Conv((2,2), 1=>3, bias = zero(3))
```
Why use the `zero` function here? Also, are we still interested in a `zero` type?
Ah, was passing `bias = zero` (from #873 (comment)) meant to point to the `zero` type?
Yeah, sorry, I was imagining a definition like `zero = Zero()`. That name might not work given that `zero` already exists as a function, but I think we'll need a new type to guarantee that it won't be updated by optimisers.
Right, in that case I'll just get the definition of `ZeroType` back and we can have the guarantee baked into that.
src/utils.jl
Outdated
```julia
    res .= zero(S)
end

function *(a::Zeros{T,2}, b::AbstractArray{S,2}) where {T,S}
```
What's the motivation for this method?
Largely to define matmul here to avoid scalar operations, iirc
Without this method, `Flux.Zeros(3000,3000) * rand(3000,3000)` would hit `generic_matmul!` here, which is slow too:

```julia
julia> @btime $(zeros(3000,3000)) * $(rand(3000,3000));
  564.237 ms (2 allocations: 68.66 MiB)

julia> @btime $(Flux.Zeros(3000,3000)) * $(rand(3000,3000));
  7.953 ms (6 allocations: 68.66 MiB)
```
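For reference, a method along those lines can be as simple as the following (a sketch against the illustrative `Zeros` above; the PR's actual method may differ):

```julia
import Base: *

# Every entry of the result is zero, so skip the O(n³) multiply entirely
# and just allocate a zero matrix of the right shape.
function *(a::Zeros{T,2}, b::AbstractMatrix{S}) where {T,S}
    size(a, 2) == size(b, 1) || throw(DimensionMismatch("cannot multiply $(size(a)) by $(size(b))"))
    return zeros(promote_type(T, S), size(a, 1), size(b, 2))
end
```

The ~8 ms and 68.66 MiB in the second benchmark are then essentially just the cost of allocating and zeroing the 3000×3000 result.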
That makes sense, I was just surprised that this method came up at all, since `Zeros()` usually gets used as a bias. Did this come up from trying to support `Dense` layers with `Zeros()` for the weight?
If so, it'd be good to discuss whether that's something we want to support anyway; it seems like an odd thing to want to do.
Partly to cover our ground in case it ends up being used outside of the context of a bias, and to cover the basic ops so it doesn't end up being accidentally slow
There's an unlimited number of methods we could special-case like this though, right? Given that we can't add everything, I think we need to decide what we're supporting and why. For broadcasting `+` and `-` that's obvious enough (it covers the bias case), but I don't know that we want to implement every matmul method without a good use case for it.
Fair point
@MikeInnes I added the kwarg-only constructors, and made some docs corrections too. Perhaps good to get this in for now?
The error happens with or without the bias machinery; perhaps it's Zygote-related.
src/utils.jl
Outdated
```julia
function broadcasted(::typeof(*), a::AbstractArray, b::Zeros)
    sz = similar(a, Broadcast.broadcast_shape(size(a), size(b)))
    sz .= zero(a)
```
I have the same concern with this as with the methods above, i.e. if we're going to allocate anyway we should just be able to use the built-in fallback.
We could also potentially just have it return a correctly shaped `Zeros` object to avoid allocating.
But I understand the concern, and I agree that that shouldn't be necessary
Yeah, that seems like the best route for this particular method.
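Concretely, something like this (again a sketch against the illustrative `Zeros` above, not the final PR code):

```julia
import Base.Broadcast: broadcasted, broadcast_shape

# Anything times an all-zero array is all zeros, so return a correctly
# shaped Zeros instead of allocating and filling a dense result.
function broadcasted(::typeof(*), a::AbstractArray, b::Zeros)
    sh = broadcast_shape(size(a), size(b))
    return Zeros{promote_type(eltype(a), eltype(b)), length(sh)}(sh)
end
```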
Why do we need to define our own types when we can use `FillArrays.Zeros`? FillArrays is already used by Zygote. It's very hard to review this PR, since it comes with a lot of unrelated doc additions.
I think the reasoning is that we want to represent not just an array of zeros, but an array which is held fixed with respect to gradient descent. There are a couple of ways to do this; either Zygote or FillArrays could decide that the gradient of a `FillArrays.Zeros` should simply be dropped, for example.
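As a sketch of what that can look like with a dedicated type (the `sgd!` update rule here is a stand-in for illustration, not Flux's optimiser code):

```julia
# With a marker type, the parameter update can dispatch to a no-op, so the
# absent bias can never be nudged away from zero during training.
sgd!(x::AbstractArray, Δ; η = 0.1) = (x .-= η .* Δ; x)
sgd!(z::Zeros, Δ; η = 0.1) = z   # held fixed with respect to gradient descent
```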
bors r+ |
Build succeeded: |
Addresses #868