sparse.nn.trainer.TrainLISTA
- class sparse.nn.trainer.TrainLISTA(model: Module, model_reference: Module, criterion: Module, data_loader: DataLoader, optimizer, scheduler=None, **kwargs)
Bases: TrainMatchingPursuit
Train LISTA with the original loss, defined in the paper as the MSE between the latent vector Z (the output of the LISTA network's forward pass) and the best possible latent vector Z*, which is obtained by running Basis Pursuit ADMM on the input X. Using ADMM for the ground truth shows better results than using the original ISTA.
\[L(W, X) = \frac{1}{2} \left\| Z^* - Z \right\|^2\]
TrainLISTA performs worse than TrainMatchingPursuit.
- __init__(model: Module, model_reference: Module, criterion: Module, data_loader: DataLoader, optimizer, scheduler=None, **kwargs)
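The loss L(W, X) can be illustrated with a minimal NumPy sketch. It assumes that "Basis Pursuit ADMM" means the LASSO (basis pursuit denoising) problem min_Z 1/2 ||X - W Z||^2 + lam ||Z||_1 solved with ADMM, and that the LISTA encoder is initialized from ISTA, so its untrained weights are W^T / L and I - W^T W / L with threshold lam / L. The dictionary W, the penalty lam, the ADMM parameter rho and every function name below are hypothetical and not part of the sparse API; in TrainLISTA itself, Z* comes from model_reference and Z from model.

```python
import numpy as np


def soft_threshold(v, thresh):
    """Elementwise shrinkage: sign(v) * max(|v| - thresh, 0)."""
    return np.sign(v) * np.maximum(np.abs(v) - thresh, 0.0)


def basis_pursuit_admm(W, x, lam, rho=1.0, n_iter=500):
    """Hypothetical reference solver (assumed LASSO form of basis pursuit):
    min_z 0.5 * ||x - W z||^2 + lam * ||z||_1, solved with scaled-dual ADMM."""
    n = W.shape[1]
    # The z-update solves (W^T W + rho I) z = W^T x + rho (v - y); invert once.
    A_inv = np.linalg.inv(W.T @ W + rho * np.eye(n))
    Wtx = W.T @ x
    z = np.zeros(n)  # primal variable for the quadratic term
    v = np.zeros(n)  # split variable that carries the L1 penalty
    y = np.zeros(n)  # scaled dual variable
    for _ in range(n_iter):
        z = A_inv @ (Wtx + rho * (v - y))
        v = soft_threshold(z + y, lam / rho)
        y = y + z - v
    return v  # the sparse copy of the solution, playing the role of Z*


def lista_forward(W, x, lam, n_steps):
    """Unrolled LISTA encoder with (assumed) ISTA initialization:
    z_0 = soft(W_e x), z_{k+1} = soft(W_e x + S z_k), threshold lam / L."""
    L = np.linalg.norm(W, ord=2) ** 2  # Lipschitz constant of the data-term gradient
    W_e = W.T / L
    S = np.eye(W.shape[1]) - (W.T @ W) / L
    b = W_e @ x
    z = soft_threshold(b, lam / L)
    for _ in range(n_steps - 1):
        z = soft_threshold(b + S @ z, lam / L)
    return z


def lista_loss(z_star, z):
    """L(W, X) = 0.5 * ||Z* - Z||^2 for a single sample."""
    return 0.5 * float(np.sum((z_star - z) ** 2))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    W = rng.standard_normal((20, 50))
    W /= np.linalg.norm(W, axis=0)  # unit-norm dictionary atoms
    x = rng.standard_normal(20)
    lam = 0.1
    z_star = basis_pursuit_admm(W, x, lam)
    # More unrolled steps bring the untrained encoder closer to Z*.
    for steps in (1, 10, 100):
        print(steps, lista_loss(z_star, lista_forward(W, x, lam, steps)))
```

Training LISTA then means learning W_e, S and the threshold so that a small, fixed number of steps already yields a small loss. If criterion is torch.nn.MSELoss with its default reduction="mean", the loss is averaged over all latent elements instead of halved and summed; this rescales L(W, X) by a constant factor and leaves its minimizer unchanged.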
Methods
- __init__(model, model_reference, criterion, ...)
- checkpoint_path([best]): Get the checkpoint path, given the mode.
- full_forward_pass([train]): Fixes the model weights, evaluates the epoch score and updates the monitor.
- is_unsupervised(): Returns
- log_trainer(): Logs the trainer in a Visdom text field.
- monitor_functions(): Override this method to register Visdom callbacks on each epoch.
- open_monitor([offline]): Opens a Visdom monitor.
- plot_autoencoder(): Plots the AutoEncoder reconstruction.
- restore([checkpoint_path, best, strict]): Restores the trainer progress and the model from the path.
- save([best]): Saves the trainer and the model parameters to self.checkpoint_path(best).
- state_dict(): Returns
- train([n_epochs, mutual_info_layers, ...]): User-entry function to train the model for n_epochs.
- train_batch(batch): The core function of a trainer to update the model parameters, given a batch.
- train_epoch(epoch): Trains an epoch.
- train_mask([mask_explain_params]): Trains a mask to see which part of an image is crucial from the network's perspective (saliency map).
- training_finished(): Callback invoked when training is finished.
- training_started(): Callback invoked when training is started.
- update_accuracy([train]): Updates the accuracy of the model.
- update_best_score(score): If score is greater than self.best_score, saves the model.
Attributes
- best_score_type
- epoch: The current epoch, int.
- watch_modules
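The checkpoint bookkeeping described by checkpoint_path, save, restore and update_best_score can be sketched in plain Python. This is a hypothetical stand-in, not the sparse implementation: the class name BestScoreTracker, the JSON file layout, the file names best.json and latest.json, and the state dictionary are all assumptions made for illustration. It only mirrors the documented rules that save writes to checkpoint_path(best) and that a model is saved when score is greater than best_score.

```python
import json
import os
import tempfile


class BestScoreTracker:
    """Hypothetical stand-in for the trainer's checkpoint bookkeeping."""

    def __init__(self, checkpoint_dir):
        self.checkpoint_dir = checkpoint_dir
        self.best_score = float("-inf")
        self.epoch = 0
        self.state = {}  # placeholder for model/optimizer parameters

    def checkpoint_path(self, best=False):
        # Assumed layout: separate files for the latest and the best checkpoint.
        name = "best.json" if best else "latest.json"
        return os.path.join(self.checkpoint_dir, name)

    def save(self, best=False):
        payload = {"epoch": self.epoch, "best_score": self.best_score, "state": self.state}
        with open(self.checkpoint_path(best), "w") as f:
            json.dump(payload, f)

    def restore(self, best=False):
        with open(self.checkpoint_path(best)) as f:
            payload = json.load(f)
        self.epoch = payload["epoch"]
        self.best_score = payload["best_score"]
        self.state = payload["state"]

    def update_best_score(self, score):
        # Save only when the new score beats the best score seen so far.
        if score > self.best_score:
            self.best_score = score
            self.save(best=True)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        tracker = BestScoreTracker(d)
        for epoch, score in enumerate([0.2, 0.5, 0.4], start=1):
            tracker.epoch = epoch
            tracker.update_best_score(score)
        tracker.restore(best=True)
        print(tracker.epoch, tracker.best_score)
```

Because only improvements trigger save(best=True), restoring with best=True returns the epoch that achieved the highest score rather than the most recent one.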