Question: how can one use TrainState.cache? #893
From the `TrainState` docs, it seems that the struct has a field called `cache`. What is the right way to initialize `TrainState.cache`? For example:

```julia
state = TrainState(params, opt_state, Dict(:best_params => deepcopy(params), :min_loss => Inf))

# and later in the training loop ...
# Update the minimum loss and best parameters in the cache
if loss < state.cache[:min_loss]
    state.cache[:min_loss] = loss
    state.cache[:best_params] = deepcopy(new_params)
end
```
`cache` is an internal field. It is used by the backend implementations, i.e. the ones defined in extensions such as `Zygote` and `Enzyme`, to cache intermediate buffers and similar data. It is not meant to be exposed to the user.

For your case, you should simply create an external cache and update it as necessary, without involving the `TrainState`.