zamba.pytorch.finetuning
BackboneFinetuning
Bases: BackboneFinetuning
Derived from PTL's built-in BackboneFinetuning, but during the backbone freeze phase, lets you choose whether to freeze batch norm layers, even if train_bn is True (i.e., even if we train them during the backbone unfreeze phase).
Finetune a backbone model based on a user-defined learning rate schedule.
When the backbone learning rate reaches the current model learning rate and should_align is set to True, it will align with it for the rest of the training.
Args:
unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfrozen.
lambda_func: Scheduling function for increasing backbone learning rate.
backbone_initial_ratio_lr: Used to scale down the backbone learning rate compared to rest of model.
backbone_initial_lr: Optional, initial learning rate for the backbone. By default, we will use current_learning / backbone_initial_ratio_lr.
should_align: Whether to align with current learning rate when the backbone learning rate reaches it.
initial_denom_lr: When unfreezing the backbone, the initial learning rate will be current_learning_rate / initial_denom_lr.
train_bn: Whether to make Batch Normalization trainable.
verbose: Display current learning rate for model and backbone.
round: Precision for displaying learning rate.
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import BackboneFinetuning
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
>>> trainer = Trainer(callbacks=[backbone_finetuning])
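The scheduling behaviour described above can be illustrated with a small, framework-free simulation. This is only a sketch of the documented idea, not the zamba or PyTorch Lightning implementation: the helper name simulate_backbone_lr is hypothetical, the initial backbone learning rate is passed in explicitly rather than derived from backbone_initial_ratio_lr, and it assumes the multiplier returned by lambda_func is applied once per epoch after unfreezing.

```python
from typing import Callable, List


def simulate_backbone_lr(
    model_lr: float,
    backbone_initial_lr: float,
    lambda_func: Callable[[int], float],
    unfreeze_backbone_at_epoch: int,
    num_epochs: int,
    should_align: bool = True,
) -> List[float]:
    """Hypothetical helper: backbone learning rate per epoch, from unfreezing onward."""
    lrs = []
    backbone_lr = backbone_initial_lr
    for epoch in range(unfreeze_backbone_at_epoch, num_epochs):
        if epoch > unfreeze_backbone_at_epoch:
            # Assumption: scale the previous backbone rate by the user-defined multiplier.
            backbone_lr = backbone_lr * lambda_func(epoch)
            # With should_align, the backbone rate never exceeds the model rate.
            if should_align and backbone_lr > model_lr:
                backbone_lr = model_lr
        lrs.append(backbone_lr)
    return lrs


# Constant multiplier of 2.0, backbone unfrozen at epoch 3, training for 9 epochs.
schedule = simulate_backbone_lr(
    model_lr=1e-3,
    backbone_initial_lr=1e-4,
    lambda_func=lambda epoch: 2.0,
    unfreeze_backbone_at_epoch=3,
    num_epochs=9,
)
for offset, lr in enumerate(schedule):
    print(f"epoch {3 + offset}: backbone lr = {lr:.6f}")
```

In this sketch the backbone rate grows geometrically until it would exceed the model rate, then stays pinned to it, which is the alignment behaviour that should_align controls.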
Source code in zamba/pytorch/finetuning.py
multiplier_factory(rate)
Returns a function that returns a constant value for use in computing a constant learning rate multiplier.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rate | float | Constant multiplier. | required |
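A function matching this description could look like the following. This is a minimal sketch inferred from the description above, not a copy of the zamba source; the inner function name and the type hints are assumptions.

```python
from typing import Callable


def multiplier_factory(rate: float) -> Callable[[int], float]:
    """Return a function that ignores the epoch and always returns `rate`."""

    def multiplier(epoch: int) -> float:
        # The epoch argument is accepted so the function fits the
        # lambda_func(epoch) calling convention, but it is not used.
        return rate

    return multiplier


constant = multiplier_factory(1.5)
values = [constant(epoch) for epoch in (0, 10, 200)]
print(values)
```

The returned function can be passed wherever a lambda_func scheduling function is expected, which is equivalent to writing lambda epoch: 1.5 by hand.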
Source code in zamba/pytorch/finetuning.py