sparse.nn.trainer.TrainMatchingPursuit

class sparse.nn.trainer.TrainMatchingPursuit(model: Module, criterion: Module, data_loader: DataLoader, optimizer: Optimizer, scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None, **kwargs)

Bases: TrainerAutoencoder

Trains a MatchingPursuit or LISTA AutoEncoder with the LossPenalty loss function, defined as

\[L(\boldsymbol{W}, x_i) = \frac{1}{2} \left|\left| x_i - \boldsymbol{W} z_i \right|\right|_2^2 + \lambda \left|\left| z_i \right|\right|_1\]

where \(\boldsymbol{W} z_i\) is the reconstruction of an input vector \(x_i\) from its coefficients \(z_i\), and \(\lambda\) weights the sparsity penalty.
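As a quick numeric check of the objective, the sketch below evaluates it for a single sample in NumPy rather than PyTorch. It assumes the standard sparse-coding form: half the squared L2 reconstruction error plus \(\lambda\) times the L1 norm of \(z_i\). `loss_penalty` is a hypothetical helper, not the library's LossPenalty class, whose reduction over a batch may differ.

```python
import numpy as np


def loss_penalty(W: np.ndarray, x: np.ndarray, z: np.ndarray, lambd: float) -> float:
    # Hypothetical helper (not the library's LossPenalty).
    # Assumed form: 0.5 * ||x - W z||_2^2 + lambd * ||z||_1 for one sample.
    # W: dictionary of shape (n_features, n_atoms); z: coefficients of shape (n_atoms,).
    residual = x - W @ z                           # reconstruction error x - W z
    reconstruction = 0.5 * np.sum(residual ** 2)   # half the squared L2 norm
    sparsity = lambd * np.sum(np.abs(z))           # L1 penalty on the coefficients
    return float(reconstruction + sparsity)


W = np.eye(2)
x = np.array([1.0, 2.0])
z = np.array([1.0, 0.0])
print(loss_penalty(W, x, z, lambd=0.1))  # 2.1
```

Here the reconstruction misses the second component of x by 2, contributing 0.5 * 4 = 2, and the single active coefficient adds 0.1 * 1 = 0.1.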

The training process alternates between two steps:

1) fix the dictionary matrix \(\boldsymbol{W}\) and find the coefficients \(z_i\) with Basis Pursuit;

2) fix the coefficients \(z_i\) and update the dictionary \(\boldsymbol{W}\) with gradient descent.
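The two alternating steps can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation. Step 1 uses ISTA (iterative soft thresholding), swapped in here as a simple solver for the L1-penalized basis pursuit problem. Step 2 takes one plain gradient step on W instead of going through a torch optimizer. Renormalizing dictionary columns to unit length is a common convention assumed here. All function names are hypothetical, and NumPy stands in for PyTorch.

```python
import numpy as np


def soft_threshold(v: np.ndarray, t: float) -> np.ndarray:
    # Proximal operator of t * ||.||_1: shrink every entry towards zero by t.
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def objective(W: np.ndarray, X: np.ndarray, Z: np.ndarray, lambd: float) -> float:
    # Sum over samples (columns) of 0.5 * ||x_i - W z_i||_2^2 + lambd * ||z_i||_1.
    return float(0.5 * np.sum((X - W @ Z) ** 2) + lambd * np.sum(np.abs(Z)))


def sparse_code_ista(W: np.ndarray, X: np.ndarray, lambd: float, n_iter: int = 100) -> np.ndarray:
    # Step 1 (W fixed): approximate the coefficients Z with ISTA (assumed solver).
    # X has shape (n_features, n_samples); each column is one input x_i.
    L = np.linalg.norm(W, ord=2) ** 2  # Lipschitz constant of the smooth term's gradient
    Z = np.zeros((W.shape[1], X.shape[1]))
    for _ in range(n_iter):
        grad = W.T @ (W @ Z - X)
        Z = soft_threshold(Z - grad / L, lambd / L)
    return Z


def dictionary_step(W: np.ndarray, X: np.ndarray, Z: np.ndarray, lr: float) -> np.ndarray:
    # Step 2 (Z fixed): one gradient-descent step on 0.5 * ||X - W Z||^2 w.r.t. W.
    grad_W = (W @ Z - X) @ Z.T
    W = W - lr * grad_W
    # Assumed convention: rescale each atom (column) to unit L2 norm.
    return W / np.maximum(np.linalg.norm(W, axis=0, keepdims=True), 1e-12)


rng = np.random.default_rng(0)
X = rng.standard_normal((8, 50))   # 50 toy samples with 8 features each
W = rng.standard_normal((8, 16))   # overcomplete dictionary with 16 atoms
W /= np.linalg.norm(W, axis=0, keepdims=True)
lambd = 0.1

for _ in range(20):
    Z = sparse_code_ista(W, X, lambd)          # step 1: coefficients with W fixed
    W = dictionary_step(W, X, Z, lr=1e-3)      # step 2: dictionary with Z fixed

print(objective(W, X, sparse_code_ista(W, X, lambd), lambd))
```

Fixing the ISTA step size at 1/L, with L the squared spectral norm of W, keeps each coefficient update from increasing the objective, so the codes found in step 1 never do worse than the all-zero code.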

__init__(model: Module, criterion: Module, data_loader: DataLoader, optimizer: Optimizer, scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None, **kwargs)

Methods

__init__(model, criterion, data_loader, ...)

checkpoint_path([best])

Get the checkpoint path, given the mode.

full_forward_pass([train])

Fixes the model weights, evaluates the epoch score and updates the monitor.

is_unsupervised()

Returns whether the training is unsupervised (no labels).

log_trainer()

Logs the trainer in a Visdom text field.

monitor_functions()

Override this method to register Visdom callbacks on each epoch.

open_monitor([offline])

Opens a Visdom monitor.

plot_autoencoder()

Plots the AutoEncoder reconstruction.

restore([checkpoint_path, best, strict])

Restores the trainer progress and the model from the path.

save([best])

Saves the trainer and the model parameters to self.checkpoint_path(best).

state_dict()

Returns the trainer state as a dict.

train([n_epochs, mutual_info_layers, ...])

User entry point to train the model for n_epochs.

train_batch(batch)

The core function of a trainer to update the model parameters, given a batch.

train_epoch(epoch)

Trains an epoch.

train_mask([mask_explain_params])

Trains a mask to see which part of an image is crucial from the network's perspective (saliency map).

training_finished()

Callback invoked when training is finished.

training_started()

Callback invoked when training is started.

update_accuracy([train])

Updates the accuracy of the model.

update_best_score(score)

If score is greater than self.best_score, saves the model.
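The rule above is the usual keep-the-best checkpoint pattern. A minimal sketch, assuming higher scores are better and that saving is an arbitrary callback; `BestScoreTracker` is hypothetical and does not reflect how the trainer actually stores best_score or interprets best_score_type.

```python
from typing import Callable


class BestScoreTracker:
    # Hypothetical helper illustrating "save only when the score improves".

    def __init__(self, save: Callable[[], None]) -> None:
        self.best_score = float("-inf")  # assumed initial value: any real score beats it
        self._save = save

    def update_best_score(self, score: float) -> bool:
        # Strictly greater than the best so far: remember it and checkpoint.
        if score > self.best_score:
            self.best_score = score
            self._save()
            return True
        return False


tracker = BestScoreTracker(save=lambda: print("checkpoint saved"))
tracker.update_best_score(0.8)  # improves on -inf, so it saves
tracker.update_best_score(0.7)  # worse than 0.8, nothing is saved
```

Using a strict comparison means a tie with the current best does not overwrite the existing checkpoint.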

Attributes

best_score_type

epoch

The current epoch, int.

watch_modules