import torch
import torch.nn as nn

from mighty.trainer import TrainerAutoencoder
from mighty.trainer.trainer import Trainer
from mighty.utils.common import input_from_batch, batch_to_cuda, find_layers
# NOTE(review): the module path of this import was lost when the file was
# extracted (`from import DataLoader`); `mighty.utils.data` is the presumed
# location of DataLoader -- confirm against the installed `mighty` package.
from mighty.utils.data import DataLoader
from mighty.utils.signal import peak_to_signal_noise_ratio, compute_sparsity
from mighty.utils.stub import OptimizerStub
from mighty.utils.var_online import MeanOnline
from sparse.nn.model import Softshrink, MatchingPursuit, LISTA

# Public API of this module.
# NOTE(review): the original list contents were lost in extraction; the names
# below are the public trainer classes defined in this file -- confirm that
# nothing else was exported.
__all__ = [
    "TestMatchingPursuitParameters",
    "TestMatchingPursuit",
    "TrainMatchingPursuit",
    "TrainLISTA",
]

class TestMatchingPursuitParameters(TrainerAutoencoder):
    r"""
    Sweep through a range of Softshrink threshold parameters
    :code:`bmp_params_range` and **show** the best one.

    The user then can pick the "best" parameter by looking at the PSNR and
    sparsity plots - the choice is not trivial and depends on the application.

    Parameters
    ----------
    model : nn.Module
        A model that is called as ``model(input, bmp_param)`` and returns a
        ``(latent, reconstructed)`` pair.
    criterion : nn.Module
        The loss function; must be an instance of :code:`nn.MSELoss`.
    data_loader : DataLoader
        The data loader whose ``eval()`` batches are swept over.
    bmp_params_range : torch.Tensor
        The parameter values to try; each one is passed to the model as its
        second positional argument.
    param_name : str, optional
        The parameter name shown on the X axis of the plots.
        Default: "param"
    **kwargs
        Passed to the base class.
    """

    def __init__(self, model: nn.Module, criterion: nn.Module,
                 data_loader: DataLoader, bmp_params_range: torch.Tensor,
                 param_name="param", **kwargs):
        # Nothing is optimized in this trainer, hence the optimizer stub.
        super().__init__(model, criterion=criterion, data_loader=data_loader,
                         optimizer=OptimizerStub(), **kwargs)
        self.bmp_params = bmp_params_range
        self.param_name = param_name

    def train_epoch(self, epoch):
        """
        Skip training: only advance the timer by one epoch worth of batches.
        """
        self.timer.batch_id += self.timer.batches_in_epoch

    def full_forward_pass(self, train=True):
        """
        Evaluate the model for each parameter in :code:`self.bmp_params` and
        plot the mean loss, PSNR and sparsity as functions of the parameter.

        Parameters
        ----------
        train : bool, optional
            When False, nothing is computed and None is returned.

        Returns
        -------
        loss : torch.Tensor or None
            The loss averaged over the batches, one entry per parameter
            value, or None if `train` is False.
        """
        if not train:
            return None
        assert isinstance(self.criterion, nn.MSELoss), \
            "BMP can work only with MSE loss"

        # `nn.Module.train()` returns the module itself rather than the
        # previous mode, so the mode to restore is read from `.training`.
        mode_saved = self.model.training
        self.model.train(False)
        use_cuda = torch.cuda.is_available()

        loss_online = MeanOnline()
        psnr_online = MeanOnline()
        sparsity_online = MeanOnline()

        try:
            with torch.no_grad():
                for batch in self.data_loader.eval(
                        description="Full forward pass (eval)"):
                    if use_cuda:
                        batch = batch_to_cuda(batch)
                    inputs = input_from_batch(batch)
                    # per-parameter metrics of the current batch
                    loss, psnr, sparsity = [], [], []
                    for bmp_param in self.bmp_params:
                        outputs = self.model(inputs, bmp_param)
                        latent, reconstructed = outputs
                        loss_param = self._get_loss(batch, outputs)
                        psnr_param = peak_to_signal_noise_ratio(
                            inputs, reconstructed)
                        sparsity_param = compute_sparsity(latent)
                        loss.append(loss_param.cpu())
                        psnr.append(psnr_param.cpu())
                        sparsity.append(sparsity_param.cpu())
                    loss_online.update(torch.stack(loss))
                    psnr_online.update(torch.stack(psnr))
                    sparsity_online.update(torch.stack(sparsity))
        finally:
            # Restore the mode even if the sweep fails half-way.
            self.model.train(mode_saved)

        loss = loss_online.get_mean()
        psnr = psnr_online.get_mean()
        sparsity = sparsity_online.get_mean()

        # (values, window, Y label, title) of each curve to draw
        curves = (
            (loss, 'Loss', 'Loss', 'Loss'),
            (psnr, 'PSNR', 'Peak signal-to-noise ratio', 'PSNR'),
            (sparsity, 'Sparsity', 'sparsity', 'L1 output sparsity'),
        )
        for values, win, ylabel, title in curves:
            self.monitor.viz.line(Y=values, X=self.bmp_params, win=win,
                                  opts=dict(
                                      xlabel=f'BMP {self.param_name}',
                                      ylabel=ylabel,
                                      title=title,
                                  ))

        self.monitor.viz.close(win='Accuracy')

        return loss

    def _epoch_finished(self, loss):
        # Call the `Trainer` implementation directly, bypassing whatever
        # `TrainerAutoencoder` adds on top of it.
        Trainer._epoch_finished(self, loss)

class TestMatchingPursuit(TrainerAutoencoder):
    r"""
    Test Matching Pursuit with the fixed Softshrink threshold (embedded in a
    model) and trained weights.
    """

    def train_epoch(self, epoch):
        """
        Skip training: only advance the timer by one epoch worth of batches.
        """
        self.timer.batch_id += self.timer.batches_in_epoch

class TrainMatchingPursuit(TrainerAutoencoder):
    r"""
    Train :code:`MatchingPursuit` or :code:`LISTA` AutoEncoder with
    :code:`LossPenalty` loss function, defined as

    .. math::
        L(\boldsymbol{W}, x_i) = \frac{1}{2} \left|\left| x_i -
        \boldsymbol{W} z_i \right|\right|_2^2 + \lambda \left|\left| z_i
        \right|\right|_1

    where :math:`\boldsymbol{W} z_i` is a reconstruction of an input vector
    :math:`x_i`.

    The training process alternates between two steps:

    1) fix the dictionary matrix :math:`\boldsymbol{W}` and find the
       coefficients :math:`z_i` with Basis Pursuit;
    2) fix the coefficients :math:`z_i` and update the dictionary
       :math:`\boldsymbol{W}` with gradient descent.
    """
    # NOTE(review): the formula above previously carried the exponents on the
    # wrong terms (unsquared reconstruction norm, squared L1 norm); it was
    # rewritten as the usual "squared L2 reconstruction + L1 penalty" form --
    # confirm against the `LossPenalty` implementation.

    watch_modules = TrainerAutoencoder.watch_modules + (
        Softshrink, MatchingPursuit, LISTA)

    def monitor_functions(self):
        """
        Register the monitor plots of this trainer: the mean Softshrink
        threshold (only if the model has a :code:`Softshrink` layer) and the
        solver convergence statistics.
        """
        super().monitor_functions()

        try:
            softshrink = next(find_layers(self.model, layer_class=Softshrink))
        except StopIteration:
            # the model has no Softshrink layer
            softshrink = None

        def lambda_mean(viz):
            # effective (positive) lambda
            # NOTE(review): the right-hand side of this assignment was lost
            # in extraction; clamping the `lambd` parameter of the layer at
            # zero is a reconstruction -- confirm against the repository.
            lambd = softshrink.lambd.data.clamp(min=0).cpu()
            viz.line_update(y=[lambd.mean()], opts=dict(
                xlabel='Epoch',
                ylabel='Lambda mean',
                title='Softshrink lambda threshold',
            ))

        if softshrink is not None:
            self.monitor.register_func(lambda_mean)

        # NOTE(review): the right-hand side of this assignment was lost in
        # extraction; presumably the solver keeps its online statistics in
        # a dict-like `online` attribute -- confirm against the repository.
        solver_online = self.model.solver.online

        def plot_dv_norm(viz):
            dv_norm = solver_online['dv_norm'].get_mean()
            viz.line_update(y=[dv_norm], opts=dict(
                xlabel='Epoch',
                ylabel='dv_norm (final improvement)',
                title='Solver convergence',
            ))

        def plot_iterations(viz):
            iterations = solver_online['iterations'].get_mean()
            viz.line_update(y=[iterations], opts=dict(
                xlabel='Epoch',
                ylabel='solver iterations',
                title='Solver iterations run',
            ))

        self.monitor.register_func(plot_dv_norm)
        self.monitor.register_func(plot_iterations)

    def _epoch_finished(self, loss):
        # Start each epoch with fresh solver statistics.
        self.model.solver.reset_statistics()
        super()._epoch_finished(loss)

    def full_forward_pass(self, train=True):
        """
        Run the full forward pass of the base class, collecting the solver
        convergence statistics during the train pass only.

        Parameters
        ----------
        train : bool, optional
            Passed to the base class; also enables the solver statistics.

        Returns
        -------
        loss
            The value returned by the base class.
        """
        # Save convergence statistics for train forward pass only
        self.model.solver.save_stats = train
        try:
            loss = super().full_forward_pass(train=train)
        finally:
            # Never leave the statistics switched on, even on failure.
            self.model.solver.save_stats = False
        return loss

class TrainLISTA(TrainMatchingPursuit):
    r"""
    Train LISTA with the original loss, defined in the paper as MSE between
    the latent vector Z (forward pass of LISTA NN) and the best possible
    latent vector Z*, obtained by running Basis Pursuit ADMM (shows better
    results than using original ISTA as the ground truth) on input X.

    .. math::
        L(W, X) = \frac{1}{2} \left|\left| Z^* - Z \right|\right|^2

    :code:`TrainLISTA` performs worse than :code:`TrainMatchingPursuit`.

    Parameters
    ----------
    model : nn.Module
        The model to train; must be a :code:`LISTA` instance by the time
        the loss is computed.
    model_reference : nn.Module
        The model that provides the target latent vectors. It is switched
        to the eval mode, its parameters are frozen, and its solver is
        shared with `model`.
    criterion : nn.Module
        The loss between the latent vectors of `model` and
        `model_reference`.
    data_loader : DataLoader
        The data loader, passed to the base class.
    optimizer
        The optimizer, passed to the base class.
    scheduler : optional
        The learning rate scheduler, passed to the base class.
        Default: None
    **kwargs
        Passed to the base class.
    """

    def __init__(self, model: nn.Module, model_reference: nn.Module,
                 criterion: nn.Module, data_loader: DataLoader, optimizer,
                 scheduler=None, **kwargs):
        super().__init__(model, criterion=criterion, data_loader=data_loader,
                         optimizer=optimizer, scheduler=scheduler, **kwargs)
        # The reference model only provides targets and is never trained.
        model_reference.train(False)
        for param in model_reference.parameters():
            param.requires_grad_(False)
        self.model_reference = model_reference
        # Both models share one solver object; the parent class reads and
        # resets the statistics through `self.model.solver`.
        self.model.solver = self.model_reference.solver

    def log_trainer(self):
        """
        Log the trainer as the base class does, followed by the reference
        model.
        """
        super().log_trainer()
        self.monitor.log("Reference model:")
        self.monitor.log_model(self.model_reference)

    def _get_loss(self, batch, output):
        """
        Compute the loss between the latent vector of the trained model and
        the latent vector of the reference model on the same input.

        Parameters
        ----------
        batch
            A batch of data, from which the input is extracted.
        output : tuple
            The ``(latent, reconstructed)`` output of the trained model;
            only the latent part is used.

        Returns
        -------
        loss
            The value of ``criterion(latent, latent_best)``.
        """
        assert isinstance(self.model, LISTA)
        inputs = input_from_batch(batch)
        latent, _ = output
        latent_best, _ = self.model_reference(inputs)
        loss = self.criterion(latent, latent_best)
        return loss

    def _epoch_finished(self, loss):
        # NOTE(review): the right-hand side of this assignment was lost in
        # extraction; presumably the solver keeps its online statistics in
        # a dict-like `online` attribute -- confirm against the repository.
        solver_online = self.model.solver.online
        # Plot the solver statistics before the parent implementation
        # resets them.
        self.monitor.plot_psnr(solver_online['psnr'].get_mean(),
                               mode='solver')
        self.monitor.update_sparsity(solver_online['sparsity'].get_mean(),
                                     mode='solver')
        super()._epoch_finished(loss)