r"""
PyTorch implementation of AutoEncoders that form sparse representations.
.. autosummary::
:toctree: toctree/include/
MatchingPursuit
LISTA
"""
import math
import torch
import torch.nn as nn
from mighty.models.autoencoder import AutoencoderOutput
from sparse.nn.solver import BasisPursuitADMM
__all__ = [
"MatchingPursuit",
"Softshrink",
"LISTA"
]
class MatchingPursuit(nn.Module):
r"""
    Basis Pursuit (ADMM) AutoEncoder neural network for sparse coding.
Parameters
----------
in_features : int
        The number of input features (X dimension).
out_features : int
The dimensionality of the embedding vector Z.
solver : BasisPursuitADMM
        Basis Pursuit solver for the :math:`Q_1^\epsilon` problem (see
:func:`sparse.nn.solver.basis_pursuit_admm`).
Notes
-----
In overcomplete coding, where sparse representations emerge,
:code:`out_features >> in_features`. If :code:`out_features ≲ in_features`,
the encoding representation will be dense.
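
    In its standard formulation, the :math:`Q_1^\epsilon` problem solved by
    the encoder is

    .. math::
        \min_z \|z\|_1 \quad \text{s.t.} \quad \|x - W^\top z\|_2 \le \epsilon,

    where the rows of the weight matrix :math:`W` are normalized to unit
    :math:`\ell_2` norm before each forward pass (see
    :meth:`normalize_weight`).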
See Also
--------
    sparse.nn.solver.basis_pursuit_admm : Basis Pursuit (ADMM) solver used
        in this model.
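
    Examples
    --------
    A minimal usage sketch (the shapes are illustrative only, and
    ``AutoencoderOutput`` is assumed to unpack as a ``(latent, reconstructed)``
    named tuple):

    >>> import torch
    >>> model = MatchingPursuit(in_features=784, out_features=2048)
    >>> x = torch.randn(16, 1, 28, 28)  # a batch of 28x28 images
    >>> z, decoded = model(x)
    >>> tuple(z.shape), tuple(decoded.shape)
    ((16, 2048), (16, 1, 28, 28))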
"""
def __init__(self, in_features, out_features, solver=BasisPursuitADMM()):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.solver = solver
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
@property
def lambd(self):
r"""
Solver hard/soft threshold :math:`\lambda`.
"""
return self.solver.lambd
    def forward(self, x, lambd=None):
"""
AutoEncoder forward pass.
Parameters
----------
x : (B, C, H, W) torch.Tensor
A batch of input images.
lambd : float or None
If not None, a new solver is created with the given `lambd` to
solve this batch `x`. In this case, the solver statistics won't
be tracked.
Returns
-------
z : (B, Z) torch.Tensor
Embedding vectors: sparse representation of `x`.
decoded : (B, C, H, W) torch.Tensor
Reconstructed `x` from `z`.
"""
input_shape = x.shape
if lambd is None:
solver = self.solver
else:
# the statistics won't be tracked
solver = BasisPursuitADMM(lambd=lambd, tol=self.solver.tol,
max_iters=self.solver.max_iters)
x = x.flatten(start_dim=1)
with torch.no_grad():
self.normalize_weight()
# save the statistics during testing only
z = solver.solve(A=self.weight.t(), b=x)
decoded = z.matmul(self.weight)
return AutoencoderOutput(z, decoded.view(*input_shape))
    def normalize_weight(self):
        r"""
        Normalize each pre-synaptic weight vector (a row of ``self.weight``)
        to unit :math:`\ell_2` norm.
        """
w_norm = self.weight.norm(p=2, dim=1, keepdim=True)
self.weight.div_(w_norm)
def extra_repr(self):
return f"in_features={self.in_features}, " \
f"out_features={self.out_features}"
class Softshrink(nn.Module):
    r"""
    Soft shrinkage :math:`h_\lambda(x) = \mathrm{sign}(x)\max(|x| - \lambda, 0)`
    with a learnable per-feature threshold vector :math:`\lambda` of size
    ``n_features``.
    """

    def __init__(self, n_features: int):
        super().__init__()
        self.lambd = nn.Parameter(torch.rand(n_features))
        self.relu = nn.ReLU()
    def forward(self, x):
        lambd = self.relu(self.lambd)  # the lambda threshold must be non-negative
        mask_pos = x > lambd
        mask_neg = x < -lambd
        # shift the suprathreshold entries towards zero; zero out |x| <= lambd
        out = mask_pos.float() * (x - lambd) + mask_neg.float() * (x + lambd)
        return out
def extra_repr(self):
return f"n_features={self.lambd.nelement()}"
class LISTA(nn.Module):
r"""
Learned Iterative Shrinkage-Thresholding Algorithm [1]_ AutoEncoder neural
network for sparse coding.
Parameters
----------
in_features : int
        The number of input features (X dimension).
out_features : int
The dimensionality of the embedding vector Z.
n_folds : int
        The number of recursions (unrolled iterations) applied to improve
        the convergence of Z. Must be greater than or equal to 1.
solver : BasisPursuitADMM
        Basis Pursuit solver for the :math:`Q_1^\epsilon` problem (see
        :func:`sparse.nn.solver.basis_pursuit_admm`). Used only in the
        :meth:`forward_best` method.
Notes
-----
In overcomplete coding, where sparse representations emerge,
:code:`out_features >> in_features`. If :code:`out_features ≲ in_features`,
the encoding representation will be dense.
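
    The encoder first computes :math:`z_1 = h_\lambda(W_e x)` and then
    unrolls

    .. math::
        z_{k+1} = h_\lambda(W_e x + S z_k)

    for ``n_folds - 1`` additional steps, where :math:`h_\lambda` is the
    learned :class:`Softshrink` nonlinearity, :math:`W_e` is
    ``weight_input``, and :math:`S` is ``weight_lateral``.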
References
----------
.. [1] Gregor, K., & LeCun, Y. (2010, June). Learning fast approximations
of sparse coding. In Proceedings of the 27th international conference on
international conference on machine learning (pp. 399-406).
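
    Examples
    --------
    A minimal usage sketch (the shapes are illustrative only, and
    ``AutoencoderOutput`` is assumed to unpack as a ``(latent, reconstructed)``
    named tuple):

    >>> import torch
    >>> model = LISTA(in_features=784, out_features=2048, n_folds=3)
    >>> x = torch.randn(16, 1, 28, 28)  # a batch of 28x28 images
    >>> z, decoded = model(x)
    >>> tuple(z.shape), tuple(decoded.shape)
    ((16, 2048), (16, 1, 28, 28))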
"""
def __init__(self, in_features, out_features, n_folds=2,
solver=BasisPursuitADMM()):
super().__init__()
assert n_folds >= 1
self.in_features = in_features
self.out_features = out_features
self.n_folds = n_folds
self.solver = solver
self.weight_input = nn.Parameter(
torch.Tensor(out_features, in_features)) # W_e matrix
self.weight_lateral = nn.Parameter(
torch.Tensor(out_features, out_features)) # S matrix
self.soft_shrink = Softshrink(out_features)
self.reset_parameters()
@property
def lambd(self):
r"""
Learned Softshrink threshold vector of size :code:`out_features`.
"""
return self.soft_shrink.lambd
    def reset_parameters(self):
        # Kaiming initialization preserves the variance of activations,
        # in contrast to plain randn()
nn.init.kaiming_uniform_(self.weight_input, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.weight_lateral, a=math.sqrt(5))
nn.init.uniform_(self.soft_shrink.lambd, a=0.01, b=0.1)
    def forward(self, x):
"""
AutoEncoder forward pass.
Parameters
----------
x : (B, C, H, W) torch.Tensor
            A batch of input images.
Returns
-------
z : (B, Z) torch.Tensor
Embedding vectors: sparse representation of `x`.
decoded : (B, C, H, W) torch.Tensor
Reconstructed `x` from `z`.
"""
input_shape = x.shape
x = x.flatten(start_dim=1) # (B, In)
b = x.matmul(self.weight_input.t()) # (B, Out)
z = self.soft_shrink(b) # (B, Out)
        for _ in range(self.n_folds - 1):
z = self.soft_shrink(b + z.matmul(self.weight_lateral.t()))
decoded = z.matmul(self.weight_input) # (B, In)
return AutoencoderOutput(z, decoded.view(*input_shape))
    def forward_best(self, x):
"""
Test function to match the output of the :class:`MatchingPursuit`.
Parameters
----------
x : (B, C, H, W) torch.Tensor
            A batch of input images.
Returns
-------
z : (B, Z) torch.Tensor
Embedding vectors: sparse representation of `x`.
decoded : (B, C, H, W) torch.Tensor
Reconstructed `x` from `z`.
"""
input_shape = x.shape
with torch.no_grad():
x = x.flatten(start_dim=1)
w_norm = self.weight_input.norm(p=2, dim=1, keepdim=True)
weight = self.weight_input / w_norm
z = self.solver.solve(A=weight.t(), b=x)
decoded = z.matmul(weight) # (B, In)
return AutoencoderOutput(z, decoded.view(*input_shape))
def extra_repr(self):
return f"in_features={self.in_features}, " \
f"out_features={self.out_features}, " \
f"n_folds={self.n_folds}"