WeightAveraging
- class lightning.pytorch.callbacks.WeightAveraging(device=None, use_buffers=True, **kwargs)
Bases: Callback
A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average (EMA) after each training step.
Arguments given to the constructor will be passed to the AveragedModel constructor. If no device is specified, the device of the original model will be used. Contrary to AveragedModel, use_buffers is set to True by default. That is, by default the callback computes running averages for both the parameters and the buffers of the model. Setting use_buffers to False causes only the model parameters to be averaged, leaving it to the user to update the batch normalization statistics (using torch.optim.swa_utils.update_bn()).
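For instance, with use_buffers=False the batch normalization statistics can be recomputed after training with one pass over the training data. A minimal sketch, where model and train_loader are placeholder names:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import WeightAveraging
    from torch.optim.swa_utils import update_bn

    # Average only the parameters; buffers such as BatchNorm running statistics
    # are left untouched during training.
    trainer = Trainer(callbacks=WeightAveraging(use_buffers=False), max_epochs=10)
    trainer.fit(model, train_loader)  # "model" and "train_loader" are placeholders

    # After training, the averaged parameters have been transferred to the model,
    # so recompute the BatchNorm statistics with one pass over the training data.
    update_bn(train_loader, model)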
You can provide a custom averaging function with the avg_fn or multi_avg_fn parameter. See the AveragedModel class for details. If no averaging function is provided, the default is to compute the equally weighted average of the weights (SWA).

You can customize when the average model is updated by overriding the should_update() method. The callback calls it with either step_idx or epoch_idx, and the method returns a boolean indicating whether to update after the given step or epoch. The default is to update after every step.

During validation and after training finishes, the current model parameters will be replaced with the averaged values.
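As a minimal usage sketch (model and train_loader are placeholder names), the defaults maintain an equally weighted SWA average that is updated after every step and copied back into the model for validation and at the end of training:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import WeightAveraging

    # Defaults: equally weighted (SWA) average of parameters and buffers,
    # updated after every optimizer step.
    trainer = Trainer(callbacks=WeightAveraging(), max_epochs=10)
    trainer.fit(model, train_loader)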
See also the documentation on the weight averaging callbacks provided by Lightning.
Note
To ensure that the AveragedModel will contain all layers, setup() will call configure_model() before instantiating the AveragedModel. However, that hook is not called in a strategy-aware context, so sharded models do not work with weight averaging, and a warning will be issued.

Example:
    from lightning.pytorch.callbacks import WeightAveraging
    from torch.optim.swa_utils import get_ema_avg_fn

    class EMAWeightAveraging(WeightAveraging):
        def __init__(self):
            super().__init__(avg_fn=get_ema_avg_fn())

        def should_update(self, step_idx=None, epoch_idx=None):
            # Start after 100 steps.
            return (step_idx is not None) and (step_idx >= 100)

    trainer = Trainer(callbacks=EMAWeightAveraging(), max_epochs=10)
    trainer.fit(model, dataloader)
- Parameters:
  - device (Union[device, str, int, None]) – By default, the AveragedModel will be stored on the same device as the original model. If the device argument is provided, the AveragedModel will be stored on this device instead. If you run out of GPU memory, you might want to use "cpu".
  - use_buffers (bool) – If False, the buffers of the model will not be averaged.
  - kwargs (Any) – Additional keyword arguments to be passed to the AveragedModel constructor, such as avg_fn or multi_avg_fn.
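For illustration, these arguments can be combined to keep an EMA copy on the CPU. A sketch, where the decay value is an arbitrary example:

    from lightning.pytorch.callbacks import WeightAveraging
    from torch.optim.swa_utils import get_ema_avg_fn

    # Keep the averaged copy on the CPU to save GPU memory, and average with an
    # exponential moving average instead of the equally weighted SWA default.
    callback = WeightAveraging(device="cpu", avg_fn=get_ema_avg_fn(decay=0.999))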
- load_state_dict(state_dict)
Called when loading a checkpoint.
Reloads the callback state given a state_dict.
- on_load_checkpoint(trainer, pl_module, checkpoint)
Called when loading a model checkpoint.
Loads the current model and the AveragedModel parameters from the checkpoint.
- on_save_checkpoint(trainer, pl_module, checkpoint)
Called when saving a checkpoint.
Moves the current model state to the key current_model_state, and places the average model state in state_dict instead. Any other state variables of the AveragedModel will be saved in averaging_state.
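A practical consequence is that a checkpoint saved with this callback carries the averaged weights under the regular state_dict key, so it loads through the normal Lightning path. A sketch, where MyLightningModule and the path are hypothetical placeholders:

    # The averaged parameters live under the regular "state_dict" key, so loading
    # the checkpoint yields the averaged model rather than the raw training weights.
    model = MyLightningModule.load_from_checkpoint("path/to/checkpoint.ckpt")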
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Called when a training batch ends.
Updates the AveragedModel parameters, if requested by self.should_update().
- Return type: None
- on_train_end(trainer, pl_module)
Called when training ends.
Transfers parameters from the AveragedModel to the current model.
- Parameters:
  - pl_module (LightningModule) – The current LightningModule instance.
- Return type: None
- on_train_epoch_end(trainer, pl_module)
Called when a training epoch ends.
Updates the AveragedModel parameters, if requested by self.should_update().
- Parameters:
  - pl_module (LightningModule) – The current LightningModule instance.
- Return type: None
- on_validation_epoch_end(trainer, pl_module)
Called when a validation epoch ends.
Recovers the current model parameters from the AveragedModel.
- Parameters:
  - pl_module (LightningModule) – The current LightningModule instance.
- Return type: None
- on_validation_epoch_start(trainer, pl_module)
Called when a validation epoch begins.
Transfers parameter values from the AveragedModel to the current model.
- Parameters:
  - pl_module (LightningModule) – The current LightningModule instance.
- Return type: None
- setup(trainer, pl_module, stage)
Called when fit, validate, test, predict, or tune begins.
Creates an AveragedModel when fit begins.
- Parameters:
  - pl_module (LightningModule) – The current LightningModule instance.
- Return type: None
- should_update(step_idx=None, epoch_idx=None)
Called after every optimizer step and after every training epoch to check whether the average model should be updated.
One of the arguments is set to the zero-based index of the last training step or epoch. The default implementation returns True when any step_idx is provided. The user can customize when the average model gets updated by overriding this method.
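For example, to update the average once per epoch rather than after every step, a subclass could look like this (a sketch):

    from lightning.pytorch.callbacks import WeightAveraging

    class EpochWeightAveraging(WeightAveraging):
        def should_update(self, step_idx=None, epoch_idx=None):
            # Update the averaged model only at the end of each training epoch.
            return epoch_idx is not None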