# Source code for sparse.nn.solver

r"""
PyTorch API of the Basis Pursuit (BP) solvers (refer to :ref:`relaxation`).

.. autosummary::
   :toctree: toctree/nn/

   basis_pursuit_admm

"""

from collections import defaultdict

import torch
import torch.nn.functional as F
import warnings

from mighty.models.serialize import SerializableModule
from mighty.utils.signal import peak_to_signal_noise_ratio, compute_sparsity
from mighty.utils.var_online import MeanOnline

# Names exported by ``from sparse.nn.solver import *``.
__all__ = ["basis_pursuit_admm", "BasisPursuitADMM"]


def _reduce(solved, solved_batch, x_solution, x, *args):
    """
    Retire the samples that converged at the current ADMM iteration.

    The i-th row of `x` (and of every tensor in `args`) corresponds to the
    i-th ``False`` entry of `solved`, i.e. to the i-th still-unsolved sample.

    Parameters
    ----------
    solved : torch.Tensor
        Boolean mask over the whole batch of already-solved samples.
        Updated in place.
    solved_batch : torch.Tensor
        Boolean mask over the rows of `x` marking the samples that
        converged at this iteration.
    x_solution : torch.Tensor
        Solution buffer for the whole batch. The rows of newly converged
        samples are filled in place from `x`.
    x : torch.Tensor
        Current iterate of the still-unsolved samples.
    *args : torch.Tensor
        Extra per-sample tensors whose rows are aligned with `x`.

    Returns
    -------
    list of torch.Tensor
        ``[x, *args]``, with the newly converged rows removed. When nothing
        converged, the input tensors themselves are returned.
    """
    tensors = [x, *args]
    if not solved_batch.any():
        return tensors

    # Whole-batch indices of the samples that were pending before this step.
    pending_ids = (~solved).nonzero(as_tuple=True)[0]
    # Row indices (within the reduced batch) of the samples that just converged.
    local_ids = solved_batch.nonzero(as_tuple=True)[0]
    global_ids = pending_ids[local_ids]

    x_solution[global_ids] = x[local_ids]
    solved[global_ids] = True

    keep = ~solved_batch
    return [tensor[keep] for tensor in tensors]


def basis_pursuit_admm(A, b, lambd, M_inv=None, tol=1e-4, max_iters=100,
                       return_stats=False):
    r"""
    Basis Pursuit solver for the :math:`Q_1^\epsilon` problem

    .. math::
        \min_x \frac{1}{2} \left|\left| \boldsymbol{A}\vec{x} - \vec{b}
        \right|\right|_2^2 + \lambda \|x\|_1

    via the alternating direction method of multipliers (ADMM).

    Parameters
    ----------
    A : (N, M) torch.Tensor
        The input weight matrix :math:`\boldsymbol{A}`.
    b : (B, N) torch.Tensor
        The right side of the equation :math:`\boldsymbol{A}\vec{x} = \vec{b}`.
    lambd : float
        :math:`\lambda`, controls the sparsity of :math:`\vec{x}`.
    M_inv : (M, M) torch.Tensor or None, optional
        Precomputed transposed inverse of
        :math:`\boldsymbol{A}^T \boldsymbol{A} + \boldsymbol{I}`. Pass it to
        reuse the inverse across calls with the same `A`. If None
        (default), it is computed here.
    tol : float
        The accuracy tolerance of ADMM. A sample is considered solved once
        the relative change of its iterate between two consecutive
        iterations drops below `tol`.
    max_iters : int
        Run for at most `max_iters` iterations. Must be positive.
    return_stats : bool, optional
        If True, also return convergence statistics (see below).

    Returns
    -------
    v_solution : (B, M) torch.Tensor
        The solution vector batch :math:`\vec{x}`. It has the dtype and
        device of `A`.
    dv_norm : torch.Tensor
        Only if `return_stats` is True. The mean relative change of the
        iterate at the last iteration performed, averaged over the samples
        that were still unsolved at that iteration.
    iter_id : int
        Only if `return_stats` is True. The zero-based index of the last
        iteration performed.

    Raises
    ------
    ValueError
        If `max_iters` is not positive.
    """
    if max_iters < 1:
        raise ValueError(f"max_iters must be positive, got {max_iters}")

    A_dot_b = b.matmul(A)
    if M_inv is None:
        eye = torch.eye(A.shape[1], dtype=A.dtype, device=A.device)
        M = A.t().matmul(A) + eye
        M_inv = M.inverse().t()
        del M, eye

    batch_size = b.shape[0]
    # Keep the working tensors in A's dtype so float64 inputs are not
    # silently downcast to the default dtype.
    v = torch.zeros(batch_size, A.shape[1], dtype=A.dtype, device=A.device)
    u = torch.zeros_like(v)
    v_prev = v.clone()
    v_solution = v.clone()
    # `solved` must live on the same device as `v`. `_reduce` indexes the
    # indices derived from `solved` with indices derived from the
    # per-iteration convergence mask (computed on `v`'s device).
    solved = torch.zeros(batch_size, dtype=torch.bool, device=A.device)

    dv_norm = None
    iter_id = 0
    for iter_id in range(max_iters):
        b_eff = A_dot_b + v - u
        x = b_eff.matmul(M_inv)  # M_inv is already transposed
        # x is of shape (<=B, m_atoms)
        v = F.softshrink(x + u, lambd)
        u = u + x - v
        v_norm = v.norm(dim=1)
        if (v_norm == 0).any():
            warnings.warn(f"Lambda ({lambd}) is set too large: "
                          f"the output vector is zero-valued.")
        dv_norm = (v - v_prev).norm(dim=1) / (v_norm + 1e-9)
        solved_batch = dv_norm < tol
        v, u, A_dot_b = _reduce(solved, solved_batch, v_solution, v, u,
                                A_dot_b)
        if v.shape[0] == 0:
            # all solved
            break
        v_prev = v.clone()

    # Dump the last iterate of the samples that did not converge within
    # max_iters. This is a no-op (empty v) when every sample converged.
    v_solution[~solved] = v
    if return_stats:
        return v_solution, dv_norm.mean(), iter_id
    return v_solution
class BasisPursuitADMM(SerializableModule):
    r"""
    Stateful wrapper around :func:`basis_pursuit_admm`. It can optionally
    accumulate running statistics of the solutions it produces.

    Parameters
    ----------
    lambd : float, optional
        :math:`\lambda`, controls the sparsity of the solution.
    tol : float, optional
        The accuracy tolerance of ADMM.
    max_iters : int, optional
        Run for at most `max_iters` iterations.
    """

    state_attr = ['lambd', 'tol', 'max_iters']

    def __init__(self, lambd=0.1, tol=1e-4, max_iters=100):
        super().__init__()
        self.lambd = lambd
        self.tol = tol
        self.max_iters = max_iters
        # Running means keyed by statistic name; each one is created lazily
        # on its first update.
        self.online = defaultdict(MeanOnline)
        # When True, `solve` records dv_norm, iterations, psnr and sparsity.
        self.save_stats = False

    def solve(self, A, b, M_inv=None):
        """
        Solve the sparse coding problem for the batch `b` with dictionary `A`.

        Parameters
        ----------
        A : torch.Tensor
            The weight matrix, forwarded to :func:`basis_pursuit_admm`.
        b : torch.Tensor
            The batch of right-hand sides, forwarded to
            :func:`basis_pursuit_admm`.
        M_inv : torch.Tensor or None, optional
            Optional precomputed inverse, forwarded to
            :func:`basis_pursuit_admm`.

        Returns
        -------
        torch.Tensor
            The solution batch returned by :func:`basis_pursuit_admm`.
        """
        solution, dv_norm, last_iter = basis_pursuit_admm(
            A=A, b=b, lambd=self.lambd, M_inv=M_inv, tol=self.tol,
            max_iters=self.max_iters, return_stats=True)
        if not self.save_stats:
            return solution

        stats = self.online
        n_iters = torch.tensor(last_iter + 1, dtype=torch.float32)
        stats['dv_norm'].update(dv_norm.cpu())
        stats['iterations'].update(n_iters)
        # Reconstruct b from the sparse code to measure the fidelity.
        reconstructed = solution.matmul(A.t())
        stats['psnr'].update(
            peak_to_signal_noise_ratio(b, reconstructed).cpu())
        stats['sparsity'].update(compute_sparsity(solution).cpu())
        return solution

    def reset_statistics(self):
        """Reset every accumulated running statistic."""
        for meter in self.online.values():
            meter.reset()

    def extra_repr(self):
        return (f"lambd={self.lambd}, tol={self.tol}, "
                f"max_iters={self.max_iters}")