sparse.nn.trainer.TrainMatchingPursuit
- class sparse.nn.trainer.TrainMatchingPursuit(model: Module, criterion: Module, data_loader: DataLoader, optimizer: Optimizer, scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None, **kwargs)
Bases: TrainerAutoencoder

Trains a MatchingPursuit or LISTA AutoEncoder with the LossPenalty loss function, defined as

\[L(\boldsymbol{W}, x_i) = \frac{1}{2} \left|\left| x_i - \boldsymbol{W} z_i \right|\right|_2^2 + \lambda \left|\left| z_i \right|\right|_1\]

where \(\boldsymbol{W} z_i\) is the reconstruction of an input vector \(x_i\) from its sparse code \(z_i\), and \(\lambda\) weights the sparsity penalty.
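To make the per-sample loss concrete, here is a framework-free sketch that evaluates it directly, assuming the standard sparse-coding form \(\frac{1}{2}\|x_i - \boldsymbol{W} z_i\|_2^2 + \lambda \|z_i\|_1\). The function name `sparse_coding_loss` and the list-of-rows matrix layout are illustrative assumptions; the library itself works on PyTorch tensors through its `LossPenalty` criterion.

```python
def sparse_coding_loss(W, x, z, lam):
    """Hypothetical helper: 0.5 * ||x - W z||_2^2 + lam * ||z||_1 for one sample.

    W is stored as a list of rows (len(x) rows, len(z) columns).
    """
    # reconstruction W z
    recon = [sum(w_jk * z_k for w_jk, z_k in zip(row, z)) for row in W]
    # squared l2 norm of the reconstruction error
    residual_sq = sum((x_j - r_j) ** 2 for x_j, r_j in zip(x, recon))
    # l1 norm of the sparse code
    l1 = sum(abs(z_k) for z_k in z)
    return 0.5 * residual_sq + lam * l1


W = [[1, 0], [0, 1], [1, 1]]
x = [1, -1, 0]
z = [1, -2]
print(sparse_coding_loss(W, x, z, lam=0.5))  # 2.5
```

Here the reconstruction is `[1, -2, -1]`, so the squared error is 2 and \(\|z\|_1 = 3\), giving \(0.5 \cdot 2 + 0.5 \cdot 3 = 2.5\).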
- The training process alternates between two steps:
1) fix the dictionary matrix \(\boldsymbol{W}\) and find the coefficients \(z_i\) with Basis Pursuit;
2) fix the coefficients \(z_i\) and update the dictionary \(\boldsymbol{W}\) with gradient descent.
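The two alternating steps can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: step 1 uses ISTA (iterative soft thresholding) as a stand-in for the library's sparse encoder, warm-started from the previous codes, and step 2 takes one full-batch gradient-descent step on \(\boldsymbol{W}\). All function names (`encode`, `update_dictionary`, `train`) are hypothetical.

```python
import math
import random


def matvec(W, z):
    """Return W z, with W stored as a list of rows."""
    return [sum(w * zk for w, zk in zip(row, z)) for row in W]


def rmatvec(W, r):
    """Return W^T r."""
    return [sum(W[j][k] * r[j] for j in range(len(W))) for k in range(len(W[0]))]


def soft_threshold(v, t):
    """Elementwise proximal operator of t * ||.||_1."""
    return [math.copysign(max(abs(vk) - t, 0.0), vk) for vk in v]


def mean_loss(W, X, Z, lam):
    """Mean of 0.5 * ||x_i - W z_i||_2^2 + lam * ||z_i||_1 over the batch."""
    total = 0.0
    for x, z in zip(X, Z):
        r = [xj - rj for xj, rj in zip(x, matvec(W, z))]
        total += 0.5 * sum(rj * rj for rj in r) + lam * sum(abs(zk) for zk in z)
    return total / len(X)


def encode(W, x, lam, z0, n_iter=50):
    """Step 1 (W fixed): ISTA on the l1-penalized problem, warm-started at z0."""
    # 1 / ||W||_F^2 is a safe step size because ||W||_F^2 >= ||W^T W||_2
    step = 1.0 / (sum(w * w for row in W for w in row) or 1.0)
    z = list(z0)
    for _ in range(n_iter):
        r = [xj - rj for xj, rj in zip(x, matvec(W, z))]
        g = rmatvec(W, r)  # negative gradient of 0.5 * ||x - W z||^2 w.r.t. z
        z = soft_threshold([zk + step * gk for zk, gk in zip(z, g)], step * lam)
    return z


def update_dictionary(W, X, Z):
    """Step 2 (Z fixed): one gradient-descent step on W for the mean loss."""
    m, n, N = len(W), len(W[0]), len(X)
    # mean ||z_i||^2 bounds the Lipschitz constant of the gradient w.r.t. W
    scale = sum(zk * zk for z in Z for zk in z) / N
    if scale == 0.0:
        return W
    lr = 1.0 / scale
    grad = [[0.0] * n for _ in range(m)]
    for x, z in zip(X, Z):
        r = [xj - rj for xj, rj in zip(x, matvec(W, z))]
        for j in range(m):
            for k in range(n):
                grad[j][k] -= r[j] * z[k] / N  # d/dW of 0.5||x - Wz||^2 is -r z^T
    return [[W[j][k] - lr * grad[j][k] for k in range(n)] for j in range(m)]


def train(X, n_atoms, lam, n_epochs=10, seed=0):
    """Alternate step 1 and step 2, recording the mean loss after each round."""
    rng = random.Random(seed)
    W = [[rng.uniform(-1.0, 1.0) for _ in range(n_atoms)] for _ in range(len(X[0]))]
    Z = [[0.0] * n_atoms for _ in X]
    history = []
    for _ in range(n_epochs):
        Z = [encode(W, x, lam, z) for x, z in zip(X, Z)]  # step 1: codes
        W = update_dictionary(W, X, Z)                     # step 2: dictionary
        history.append(mean_loss(W, X, Z, lam))
    return W, Z, history


rng = random.Random(1)
X = [[rng.uniform(-1.0, 1.0) for _ in range(4)] for _ in range(8)]
W, Z, history = train(X, n_atoms=6, lam=0.05)
print(history[0], history[-1])
```

Because a proximal-gradient step with step size at most \(1/L\) and a gradient step on \(\boldsymbol{W}\) with learning rate at most \(1/L\) both never increase the objective, `history` should be non-increasing. Practical dictionary-learning code often also renormalizes the columns of \(\boldsymbol{W}\) to keep the \(\ell_1\) penalty from being bypassed by rescaling; this sketch omits that.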
- __init__(model: Module, criterion: Module, data_loader: DataLoader, optimizer: Optimizer, scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None, **kwargs)
Methods
- __init__(model, criterion, data_loader, ...)
- checkpoint_path([best]): Gets the checkpoint path for the given mode.
- full_forward_pass([train]): Fixes the model weights, evaluates the epoch score and updates the monitor.
- is_unsupervised()
- log_trainer(): Logs the trainer in the Visdom text field.
- monitor_functions(): Override this method to register Visdom callbacks on each epoch.
- open_monitor([offline]): Opens a Visdom monitor.
- plot_autoencoder(): Plots the AutoEncoder reconstruction.
- restore([checkpoint_path, best, strict]): Restores the trainer progress and the model from the path.
- save([best]): Saves the trainer and the model parameters to self.checkpoint_path(best).
- state_dict()
- train([n_epochs, mutual_info_layers, ...]): User-entry function to train the model for n_epochs.
- train_batch(batch): The core function of a trainer; updates the model parameters, given a batch.
- train_epoch(epoch): Trains an epoch.
- train_mask([mask_explain_params]): Trains a mask to see which part of an image is crucial from the network's perspective (saliency map).
- training_finished(): Callback invoked when training finishes.
- training_started(): Callback invoked when training starts.
- update_accuracy([train]): Updates the accuracy of the model.
- update_best_score(score): If score is greater than self.best_score, saves the model.

Attributes
- best_score_type
- epoch: The current epoch, int.
- watch_modules
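The hook methods listed above can be pictured with a toy skeleton. The class below is hypothetical and is not the library's code: it reuses the documented method names, but the call order (training_started, then per epoch train_epoch calling train_batch, full_forward_pass, update_best_score, and finally training_finished) is an assumption inferred from the method summaries. The per-epoch scores are fed in by hand so that the "save only on improvement" rule of update_best_score is visible.

```python
class HypotheticalTrainer:
    """Hypothetical skeleton mirroring the documented hook names; not library code."""

    def __init__(self, batches, scores):
        self.batches = batches            # stand-in for a DataLoader
        self.scores = scores              # pretend per-epoch scores
        self.epoch = 0                    # the current epoch, int
        self.best_score = float("-inf")
        self.saved_at = []                # epochs at which save() was called
        self.events = []
        self.n_batches = 0

    def training_started(self):
        self.events.append("started")

    def training_finished(self):
        self.events.append("finished")

    def train_batch(self, batch):
        # a real trainer would update the model parameters here
        self.n_batches += 1

    def train_epoch(self, epoch):
        for batch in self.batches:
            self.train_batch(batch)

    def full_forward_pass(self, train=True):
        # a real trainer would evaluate the epoch score with fixed weights
        return self.scores[self.epoch]

    def save(self, best=False):
        self.saved_at.append(self.epoch)

    def update_best_score(self, score):
        # save only when the new score beats the best one seen so far
        if score > self.best_score:
            self.best_score = score
            self.save(best=True)

    def train(self, n_epochs):
        self.training_started()
        for _ in range(n_epochs):
            self.train_epoch(self.epoch)
            self.update_best_score(self.full_forward_pass(train=True))
            self.epoch += 1
        self.training_finished()


trainer = HypotheticalTrainer(batches=[[1.0, 2.0], [3.0]], scores=[0.5, 0.7, 0.6, 0.9])
trainer.train(n_epochs=4)
print(trainer.saved_at)  # [0, 1, 3]
```

Epoch 2 is skipped because its score (0.6) does not beat the best so far (0.7).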