Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a structural loadparams! #1875

Merged
merged 19 commits into from
Apr 5, 2022
Merged

Conversation

darsnack
Copy link
Member

@darsnack darsnack commented Feb 15, 2022

This replaces loadparams! with loadmodel! which uses fmap to structurally walk the model and copy parameters over. Right now it mutates destination model, so fields like the activation are not copied.

I opted to have a more verbose implementation than the one-liner fmap(loadto!, m, mbar). It allows us to have more informative error messages for the standard layers. Custom layers will fallback to the error thrown by Functors.jl.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@ToucheSir
Copy link
Member

Because no API addition is complete without bikeshedding, might I link to the discussion in beacon-biosignals/LegolasFlux.jl#4 (comment) 😜.

(loadmodel! is a fine name, just FYI if you haven't seen it already).

src/loading.jl Outdated Show resolved Hide resolved
@mcabbott
Copy link
Member

Can you write briefly what this does & doesn't do? E.g. it copies non-trainable parameter arrays, unlike restructure, but not integer sizes/strides, not activation functions, how about dropout rate? At least as a docstring.

What's the argument for the name change?

What happens with immutable arrays? Maybe it should work like update! in that it mutates if it can but returns the actual truth.

What this custom traverse doesn't allow is recovering all the weights from something like fmapstructural(identity, m), is this a good idea? Should the source be allowed to completely omit e.g. an activation?

It seems you could still have nice errors, even for custom layers, by just keeping track of the one-layer-below name/summary. Not fmap I guess, but could still be one function. That might end up more compact & easier to test well.

function load!(dst::AbstractArray, src::AbstractArray; str=summary(dst))
    size(dst) == size(src) || throw(DimensionMismatch("wrong size within $str"))
    copyto!!(dst, src)
end

function load!(dst, src; str="")
    din, re = functor(dst)
    sin, _ = functor(typeof(dst), src)  # could try-catch this for pretty error on missing field?
    str = string(dst)
    re(map((d,s) -> load!(d,s; str), din, sin))
end

copyto!!(dst::DenseArray, src) = copyto!(dst, src)
copyto!!(dst::AbstractArray, src) = typeof(dst)(src)  # or something, to allow immutable

Do we need the xs::AbstractVector method in addition to Params?

That's a lot of questions, sorry!

@darsnack
Copy link
Member Author

Yes, this PR still needs docstrings, examples, and doc updates before it's ready.

(loadmodel! is a fine name, just FYI if you haven't seen it already).

I had not seen that! "State" might be better than "model" here. I'm open to changing it or keeping the old name.

What's the argument for the name change?

Flux uses "parameters" to mean "trainable parameters," while this function loads both the trainable parameters and the non-trainable state. I wanted the name to make clear that we are considering the model's structure when loading. My main motivation was so that a user noticed the change, but I'm happy to keep the old name.

Can you write briefly what this does & doesn't do?

What happens with immutable arrays?

Right now they would just fail (though I could give a better error message). It would be easy enough to make this behave like Optimisers.update!, and my original implementation did just that as well as included every functor'd field. We can also allow immutable arrays but not reconstruct every field. We can also have something like trainable which allows someone to specify which fields should be "loadable."

But I am waffling on what's the right behavior. Before calling loadmodel!(m, mbar), a user would have constructed m. That means setting the stride, padding, activation functions, dropout rate, etc. When they call loadmodel!, I expect that what they want is the parameters and state that came from training mbar but not everything. So, I walked back my initial implementation.

As it stands, Flux doesn't work with immutable parameters. The gradients can be immutable, but not the parameters. I figure we can make this function behave like Optimisers.update! when we have optimizers that support immutable parameters.

What this custom traverse doesn't allow is recovering all the weights from something like fmapstructural(identity, m), is this a good idea? Should the source be allowed to completely omit e.g. an activation?

Believe this is asking a similar thing to what I addressed above, but I wasn't sure. Could you rephrase if I misunderstood?

It seems you could still have nice errors, even for custom layers, by just keeping track of the one-layer-below name/summary.

Good point, I'll refactor.

Do we need the xs::AbstractVector method in addition to Params?

I'm not sure if it is used somewhere, but the old loadparams! would accept it.

@mcabbott
Copy link
Member

mcabbott commented Feb 15, 2022

but I wasn't sure. Could you rephrase

I'm not super-sure. At the moment it needs two valid Flux models, matching in most respects (but perhaps one is trained). My thought was that it could easily take a Flux model and anything else with a matching tree of structs. And this might be useful because a nested set of NamedTuples is going to be more robust to save/serialise/etc, as it can be handled without Flux's (or your model's) special structs.

In fact, I almost wonder if it should be more like restructure, and produce the minimal thing it will accept:

julia> m = Chain(Dense(2,2,tanh));

julia> m0, re = somefun!(m);  # creating m0 is approximately free.

julia> m0  # this contains exactly what will be re-loaded, and `nothing` else -- and makes it easy to inspect what "parameters and state" are going to be loaded for layer X
(layers = ((weight = Float32[-1.1553937 1.2085382; -0.27197266 0.09527644], bias = Float32[0.0, 0.0], σ = nothing),),)

julia> re(m0)  # this is like `loadparams!`, and could have a one-step method e.g. `somefun!(m, m0)`
Chain(
  Dense(2, 2, tanh),                    # 6 parameters
)

@ToucheSir
Copy link
Member

ToucheSir commented Feb 15, 2022

The two valid models requirement doesn't seem necessary for this PR. Assuming structural similarity, I don't see any obvious barriers to using the nested namedtuple as the second arg instead of an actual model struct (i.e. drop the T). Medium-long term something like you mention would be ideal, but I see this PR as a good MVP so that users can actually start getting rid of implicit params for serialization. Better to be a little conservative while we figure out the edge cases (e.g. how to exclude extraneous state like Kyle mentioned) of a more general approach.

@darsnack
Copy link
Member Author

darsnack commented Feb 15, 2022

Okay, I see what you mean. That would be a larger step than this PR, but something I am willing to do. This would basically be two pieces:

  • a "saving" step that produces the NamedTuple that has everything that can be saved
  • a "loading" step that accepts that NamedTuple

The "loading" step can already be done by this PR as Brian mentioned (some tweaks necessary of course). Ultimately, the code can be written so that the saving and loading are both just calling somefun!. I think the user-facing API should still make saving and loading be distinct. I can't think of a way to make these the same function and be sensible.

In the somefun!(m, m0) case, I don't want to use the re that produced m0. The reason I created this PR is that we don't currently have a good way of saying "m is the ground truth structure, copy the parameters into m while making sure the structure matches." So, using a re would circumvent that check.

@mcabbott
Copy link
Member

I likewise thought destructure should become something like flatten(m), remake!(m, v) for a bit. But now I think there's something neat about the "simplify" and "rebuild" living together; it makes the connection very hard to miss. You don't need to document & remember that loadparams! is the inverse of unpack not and not of flatten (or whatever), the concept has exactly one name, one symbol exported.

I also think the ability to run the "un-load" half and see exactly which bits of the model are and are not captured is a nice thing. Instead of trying to read the docs for what a parameter is, whether X is trainable, what happens if Y isn't a functor... you can just try it and get the ground truth.

using a re would circumvent that check.

But does it have to? This somefun!(m, m_trained) doesn't need be any less strict than a stand-alone loadparams!. Whether internally it makes m0 and calls re!, or takes a shortcut... IMO we should write the version as few different paths as possible first, and then benchmark to see whether shortcuts justify adding more code (& more places to have bugs).

to be a little conservative while

Maybe. But moving on from implicit params (& introducing a new name) sounds like a good point at which to figure out the right design, rather than inflict changes later.

@ericphanson
Copy link
Contributor

re-documentation, it would be great if the docs made sure to specify the interface for custom layers to participate (e.g. when to define _loadleaf etc), not just how to use it with built-in layers.

@darsnack
Copy link
Member Author

Okay let's just make sure we're on the same page, cause all the symbols are making me confused.

We have:

  • somefun!(m) -> (m0, re): take a model, return the ground truth to be serialized and the function to put it back together
  • somefun!(m, m0): effectively call somefun!(m)[2](m0)
  • somefun!(m, mtrained): effectively call somefun!(m)[2](somefun!(mtrained)[1])

@ToucheSir
Copy link
Member

ToucheSir commented Feb 15, 2022

Personally, I'm not a fan of the out, re = fn(args...) API unless it's absolutely necessary (as in destructure), as you have to keep re around and it may close over a bunch of data that could otherwise be GCed. Having an API that effectively does Fix1(restructure, model) (not literally, but you get the idea) where restructure(model, params) = ... would be fine, but it should not be the default.

@mcabbott
Copy link
Member

mcabbott commented Feb 15, 2022

somefun!(m, mtrained): effectively call somefun!(m)[2](somefun!(mtrained)[1])

This seems redundant, I would hope that somefun!(m)[2](somefun!(mtrained)[1]) == somefun!(m)[2](mtrained), since (by definition) somefun!(mtrained)[1] is a nested struct with all the same fields as mtrained.

custom layers to participate (e.g. when to define _loadleaf etc)

I think this is an argument against defining things for every built-in layer, if we can possibly avoid it. The interface is @functor. To make pretty errors, we can use the name of the containing layer. (I know my show code is a bad citizen here, sorry...)

unless it's absolutely necessary (as in destructure),

But it's not necessary for destructure. We could just have rebuild!(m, flat) and flat = flatten(m), and let you keep m around instead. Which re closes over anyway. But then you have to document which flavour of rebuilder goes with which flavour of simplification. I don't love this out, re = fn(args...) story but it does tie the two halves together nicely, and it has some precedent around here.

The difference between the two is that somefun!mutates and destructure does not. This saves 1 copy. Since the latter is more likely to be used in a loop, arguably destructure! is the one which should mutate (or have a mutating variant).

@darsnack
Copy link
Member Author

I agree that between de/restructure and loadmodel!, we are talking about the same underlying core with slightly different use cases. Sharing code paths would be good.

The fact that loadmodel! works without needing a closure suggest to me that somefun! should be something that walks 1 structure and produces the simplified form, or walks 2 structures copying from the second to the first. Then mapping the simplified form to a flat vector can be a separate piece where closing over might be required.

@ToucheSir
Copy link
Member

Let me be more precise: destructure returns both an output (flattened params) and additional bookkeeping information (re). The latter is required in order to restructure, so even if the API was remake!(m, v) one would have to keep it around. Hence it makes a lot of sense to have re be callable (as FluxML/Optimisers.jl#54 did), as it needs to be carried around anyways.

In contrast, somefun! should only have to return the serialization-friendly model state and not any additional bookkeeping info. Therefore, there's no need to construct a reconstruction function and carry that around. Restructure(m)(v) should be the same as restructure(m, v), so there's no need for the former outside of syntactic convenience (for which you can use Fix1 or define a curried version).

@mcabbott
Copy link
Member

additional bookkeeping information (re). The latter is required in order to restructure

somefun! should only have to return the serialization-friendly model state

I think you are assuming that most uses of destructure wants the re not just the vector, while most uses of somefun!(m) would want only the simplified struct, not the re.

But this "most" doesn't seem absolute. You may want flat parameters to e.g. use them for some regularisation within the loss. You may want Base structs to save them. On the other hand, you may also (as in the present PR) want only the re part, somefun!(m, mtrained), and you may similarly call _, re = destructure(m) so that you can put re(x,p) inside some loss.

The two seem more and more analogous to me, differing in what the Base-only form looks like (nested or flat) and whether non-trainable parameters are included. I suppose I'm advocating that they have similar user-facing interfaces, more or less, so that there are fewer different things to remember. (Whether they can share any code I don't know -- most of the code in destructure is to make gradients work and I think there's no intention to do that here.)

The present implementation of Optimisers.destructure in fact stores auxiliary information besides the original model (namely a nested set of offsets), but this seems like an implementation detail, it got that information from the model and remake!(m, v) could equally well make it later (modulo benchmarking).

@mcabbott
Copy link
Member

Then mapping the simplified form to a flat vector can be a separate piece

This hasn't been discussed so far. You may indeed want to turn the complete state including non-trainable parameters into a flat vector. This could be done like destructure(somefun!(m)[1])[1], since destructure is happy to walk a treee of namedtuples. Is that weird or confusing? Composing the two different res would rebuild the whole model.

One reason I asked about the ::AbstractVector method above was wondering whether loadparams!(m, rand(100)) with a vector of numbers should just automatically do that. Do we want it to? If so that may change our opinions on how the rest should work.

@darsnack
Copy link
Member Author

I guess I am in agreement at the highest level of this discussion, but I'm confused about what's actually being proposed. So, can we make somefun! more concrete? What's an actual possible name for this function, and what does it do?

@darsnack
Copy link
Member Author

One reason I asked about the ::AbstractVector method above was wondering whether loadparams!(m, rand(100)) with a vector of numbers should just automatically do that. Do we want it to? If so that may change our opinions on how the rest should work.

Ultimately, I want to deprecate that path entirely, but I plan on doing it once implicit params are gone.

@mcabbott
Copy link
Member

Maybe I should think some more and write the options somewhere. But not today. Sorry about derailing the PR!

tl;dr is that I'd vote not to introduce a new loadmodel! function until we've thought a bit more about what we want. Perhaps using something like fmap(copyto!, m, m2) for now is safest, as this won't break.

@ToucheSir
Copy link
Member

I think you are assuming that most uses of destructure wants the re not just the vector, while most uses of somefun!(m) would want only the simplified struct, not the re.

But this "most" doesn't seem absolute. You may want flat parameters to e.g. use them for some regularisation within the loss. You may want Base structs to save them. On the other hand, you may also (as in the present PR) want only the re part, somefun!(m, mtrained), and you may similarly call _, re = destructure(m) so that you can put re(x,p) inside some loss.

All this makes sense for destructure, but I'm not sure how it applies to saving and loading model state here? In fact,

you may also (as in the present PR) want only the re part, somefun!(m, mtrained)

To me is a great argument for having the two separate load/save functions, because otherwise you're incurring unnecessary work to generate both of ps, re just to throw away one.

More generally, the two return value API is honestly kinda weird for users, especially those coming from Python. My understanding has always been that it was a necessary evil in order to get acceptable performance for destructure, but were that not the case then

remake!(m, v) could equally well make it later (modulo benchmarking).

would've been a less confusing option.

Moreover, while re from destructure is somewhat easier to thread deep into a training loop by including it in a loss function, a re from this theoretical somefun would not be. e.g. instead of being able to call (save|load)model! directly for checkpoints, the loop would have to hold onto a re and update it every so often. That runs the risk of closing over possibly stale model state and makes everything from serialization to distributed training more difficult.

But my biggest concern (which I apologize for not thinking of earlier) is more basic. In PyTorch, I can do model.nested[1].path.load_state_dict(sd["nested.1.path"]) with impunity. With loadparams!, this would be a direct translation to loadparams!(model.nested[1].path, sd.nested[1].path). With a simple implementation of m0, re = somefun(...), this is not possible because re only works at the top level. One could make some fancy overloads for getindex, getproperty etc. on the type of re to enable such nested loading, but to me that seems like adding unnecessary internal complexity just to maintain an already unfamiliar user-facing API. This kind of subset loading is not a niche case either. I've seen/used it in real world code for transfer learning, self-supervised training and other forms of model surgery.

@ericphanson
Copy link
Contributor

https://github.com/beacon-biosignals/LegolasFlux.jl/blob/main/src/functors.jl has a basic implementation that has worked well and been stable, although I don't think anyone has pushed it too far in terms of variety of layers etc. In terms of

we've thought a bit more about what we want

It could be helpful to know what features are missing from a basic implementation like that.

@mcabbott
Copy link
Member

mcabbott commented Feb 15, 2022

that it was a necessary evil in order to get acceptable performance for destructure, but were that not the case then

No, I don't think this is true. The old implementation literally closes over m and calls a reconstructor function, no special work is done up front.

The separated version was what I proposed in FluxML/Functors.jl#31, but I got the impression everyone preferred the destructure form, hence FluxML/Optimisers.jl#54 . The one big upside I see to the combined form is that it involves fewer distinct names to remember, which is the inverse of which, etc.

@ToucheSir
Copy link
Member

I had a look back at the original issues, and I think I got some wires crossed reading #986 (comment). For posterity, #799 seems to have the actual design discussion.

That said, the point about performance is true now that Optimisers.destructure exists and saves offsets for faster restructuring. FWIW I thought fvec was fine, but it made a lot more sense to have the functionality in Optimisers instead of Functors.

@darsnack
Copy link
Member Author

It could be helpful to know what features are missing from a basic implementation like that.

The main missing feature is the kind of structural error checking that you get from this design. Collecting into a flat vector could always silently be wrong, and when it does catch an error, the best you can say is "some parameters are missing." Here, we can be more helpful since we know exactly what structure is being loaded.

It also allows for the convenient syntax of loadmodel!(m1[1][3], m2[4]) like Brian mentioned.


Okay, so we've had more discussion on this PR than I expected (which is good)! And I think at a high level the concept of "model -> simple structure" is shared between this and de/restructure. There's some disagreement around when/where re is necessary, and whether including in loading makes sense at all. I think we need to concretely say what's the final function(s) and how the behave to be able to move on this discussion.

For now, we can let this PR stew if needed. I will say that this function is needed to make Metalhead.jl work with pre-trained models. So, I propose the following path forward:

  • if we come to an agreement, we merge and move on
  • if we can't agree, we use the same old loadparams! name and defer this discussion to a later PR (this way we are simply extending what's already there and not introducing anything new)

@darsnack
Copy link
Member Author

I've also significantly simplified the implementation so that most custom types will participate in the thorough error checking for free.

This PR considers any type for which all(Functors.isleaf, Functors.children(x)) is true as "block" for loading. Any type that wants to be treated the same but fails that check can override isloadleaf. A type that wants custom error checks/messages can override loadto!(m::T, mbar) for their T. Neither of these overrides will be necessary for the vast majority of types. All that's required is @functor as always.

I will add documentation soon too which should hopefully help narrow the discussion as well.

@darsnack
Copy link
Member Author

darsnack commented Mar 5, 2022

What are our final thoughts here?

I see several transformations we want to do:

  1. T1: Nested structure -> simple nested structure
  2. T2: (Simple) nested structure -> flat structure
  3. T3: Flat structure -> (simple) nested structure
  4. T4: Simple nested structure -> nested structure

T2 must produce a "reconstructor" and T3 must consume it. T1 and T4 don't need a reconstructor, they just need some sense of alignment in the structures that it walks.

There are all kinds of ways to write the implementations for these transforms to share code. Ultimately, I think the code that they really share is the concept of walking a tree from Functors.jl. So my vote here is not to write restructure and co. to "just work" for loadparams! but to move as much shared code into fmap.

Given #1882, I would suggest we move forward with this PR and place either loadparams! (name change) or loadparams!(m, ::Params) on the deprecation path. We can make note in the error message that these forms don't work for bias=false. I can rebase this PR once #1882 is merged and fix the edge cases.

@ToucheSir
Copy link
Member

I'm not sure if fmap is the best place to hang all this common functionality, but agreed that we can worry about abstracting bits out after this interface has landed.

@darsnack
Copy link
Member Author

darsnack commented Apr 4, 2022

@ericphanson let me know if the new version clears things up

src/loading.jl Outdated
Comment on lines 36 to 38
Inactive parameters, encoded by `false` in place on an array,
can be copied to and from all-zero arrays.
Attempting to copy a non-zero array to/from an inactive parameter will throw an error.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this is more clear?

Suggested change
Inactive parameters, encoded by `false` in place on an array,
can be copied to and from all-zero arrays.
Attempting to copy a non-zero array to/from an inactive parameter will throw an error.
Inactive parameters can be encoded by using the boolean value `false` instead of an array.
If `src` or `dst` has `false` where the other model has an all-zero array, no error will be raised (and no values copied). However, attempting to copy a non-zero array to/from an inactive parameter will throw an error.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay I used this with a few modifications, because your comment made me realize that the behavior is not 1-1 like the docstring implies.

Copy link
Member

@mcabbott mcabbott Apr 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we shorten this somehow? It seems a very tiny edge case about layers with bias=false vs. models with an actual bias array. Most likely this will never happen in real life. Yet somehow it gets an essay describing all possible paths.

How about just "Zero bias and bias=false are considered equivalent."?

Copy link
Contributor

@ericphanson ericphanson Apr 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could put it under extended help? See point 11 of https://docs.julialang.org/en/v1/manual/documentation/.

IMO magic should at least be clearly documented...

src/loading.jl Outdated
Inactive parameters can be encoded by using the boolean value `false` instead of an array.
If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied);
however, attempting to copy a non-zero array to an inactive parameter will throw an error.
Likewise, copying `src == false` to any `dst` array is valid, but copying `src == true` will error.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tiny nitpick: src isn't == false, but rather one of its values:

Suggested change
Likewise, copying `src == false` to any `dst` array is valid, but copying `src == true` will error.
Likewise, copying a `src` value of `false` to any `dst` array is valid, but copying a `src` value of `true` will error.

src/loading.jl Outdated
Comment on lines 35 to 39
and do not need to match between `dst` and `src`.
Inactive parameters can be encoded by using the boolean value `false` instead of an array.
If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied);
however, attempting to copy a non-zero array to an inactive parameter will throw an error.
Likewise, copying `src == false` to any `dst` array is valid, but copying `src == true` will error.
Copy link
Member

@mcabbott mcabbott Apr 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possible shortening:

Suggested change
and do not need to match between `dst` and `src`.
Inactive parameters can be encoded by using the boolean value `false` instead of an array.
If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied);
however, attempting to copy a non-zero array to an inactive parameter will throw an error.
Likewise, copying `src == false` to any `dst` array is valid, but copying `src == true` will error.
and need not match.
Zero-valued arrays and boolean `false` (which is Flux's encoding of absent bias) are considered equivalent.

(edited not to be so specific to bias)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strictly speaking, the rule applies to anything not just bias. But bias should be the only occurrence of this rule in practice.

@ericphanson how do you like this shortened version?

Copy link
Contributor

@ericphanson ericphanson Apr 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really like it since like you say, it sounds like this only applies to bias (and to vectors), and it doesn't give the full semantics. Re-#1875 (comment), I think if we want a short docstring, then we should just put more of the details under extended help, so it only shows up in the online docs or if you do ?? loadparams! in the REPL.

In my view, special casing false, and allowing it to interop with zero-arrays is a bit magical, and therefore should at least be clearly documented, since it's not something you can really predict from the rest of the behavior.

Copy link
Member

@mcabbott mcabbott Apr 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. [Edit -- crossed in the mail...]

bias=false is the one official API for making variants of layers, which aims to handle. Others, like affine=false, are not -- the models must simply match.

The fact that you could, perversely, use false elsewhere, and trigger the feature, seems like we are now describing ways to hack the code to do other things. There are many others. E.g. loadleaf!(dst, src, err) = dst means that if dst has an array, and src has any other non-array (like 1.0, or Dense), then nothing will happen. Sufficiently far off the intended track, the source is the only truth.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this is supposed to work with custom layers, right? So who knows how someone is using false

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that works for me :)

my concern is about well-documented unambiguous semantics so this can be a reliable model serialization tool, including for models with custom layers etc. I.e. the Flux as a library of composable building blocks thing.

I think @mcabbott's concerns are about making it simple and keeping Flux self-consistent (but not necessarily worried about interactions outside of Flux itself). I think simple + consistent is important too, and extended help can let us achieve both, to some extent.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, didn't see your comment @mcabbott. I would be fine with removing the boolean <-> array special casing altogether.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW I think the reason allowing other mismatches is, I think, that the other half of this "model serialization tool" is something like

function simpletree(m)
    fmapstructure(m; prune=nothing) do x
        # We know isleaf(x), but further keep only values modelcopy! will accept:
        x isa AbstractArray && return x
        x === false && return x  # if we keep that...
        nothing
    end
end

which should produce a nested set of NamedTuples, with only the details this thing will load --- no layer types, no activation functions, and tied arrays appear only once. If nothing is the magic value for this, then we probably want a method to ignore it on loading:

loadleaf!(dst, src::Nothing, err) = dst

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think saving a model trained w/o bias and loading into a model w/ bias that you intend to fine-tune is a pretty reasonable/common use case. This is the pre-trained model flow, not the save my own model and load my own model flow.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've mentioned it before, but PyTorch has a pretty compelling model for dealing with these mismatches: load_state_dict errors by default, but also has a non-strict mode where it returns a symmetric diff of the source and destination model trees. All this to say that the behaviour here need not be set in stone, and that we should strive to be at least as good about telling the user about how/why loading failed when it does.

@@ -0,0 +1,92 @@
loadleaf!(dst, src, err) = dst
Copy link
Member

@mcabbott mcabbott Apr 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also wonder if there should be more errors here:

Suggested change
loadleaf!(dst, src, err) = dst
loadleaf!(dst, src, err) = dst
loadleaf!(dst::AbstractArray, src, err) = error()
loadleaf!(dst, src::AbstractArray, err) = error()

I can imagine that allowing src to have nothing means "don't change the existing weight". Which is what #1875 (comment) would generate. But it may also make truncations of branches not just leaves, which aren't allowed right now, but would I think be easy:

loadleaf!(dst, src::Nothing, err) = dst
loadleaf!(dst:: AbstractArray, src::Nothing, err) = dst

loadmodel!(dst, src::Nothing; cache = Base.IdSet()) = dst

@darsnack
Copy link
Member Author

darsnack commented Apr 4, 2022

Okay I went with the extended help suggestion, but if special casing false is the only thing holding up this PR, then let's save it for another day and merge the rest today.

src/loading.jl Outdated Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants