Source code for torchact.nn.ab_relu

import torch
import torch.nn as nn
import torch.nn.functional as F


class ABReLU(nn.Module):
    r"""
    Implementation of Average-Biased Rectified Linear Unit

    https://arxiv.org/abs/1804.02051

    :math:`A_{i}^{n}=\frac{\sum_{\rho_{1}=1}^{D_{1}}{\sum_{\rho_{2}=1}^{D_{2}}\cdots\sum_{\rho_{d}=1}^{D_{d}}{I_{i}^{n}(\rho_{1},\rho_{2},\cdots,\rho_{d})}}}{D_{1}\times D_{2}\times\cdots\times D_{d}}`

    :math:`\beta=\alpha\times A_{i}^{n}`

    :math:`I_{o}^{n}(\rho)=\begin{cases}I_{i}^{n}(\rho)-\beta,&\text{if }I_{i}^{n}(\rho)-\beta>0\\0,&\text{otherwise}\end{cases}`

    In this implementation the average :math:`A` is taken over *every*
    element of the input tensor (all dimensions, including any batch
    dimension), so a single bias :math:`\beta` is subtracted from the
    whole input.

    :param float alpha: parameter to be set empirically. Default: 1.0
    :param bool inplace: In-place operation. Default: False

    Examples::

        >>> import torch, torchact
        >>> m = torchact.nn.ABReLU()
        >>> input = torch.tensor([1.0, -2.0, 0.0, 3.0])
        >>> output = m(input)
        >>> print(output)
        tensor([0.5000, 0.0000, 0.0000, 2.5000])
    """

    def __init__(self, alpha: float = 1.0, inplace: bool = False):
        super().__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subtract ``alpha * mean(x)`` from ``x`` and zero the negatives.

        :param x: input tensor of any shape.
        :return: tensor of the same shape as ``x``. When ``inplace`` is set,
            ``x`` itself is modified and returned.
        """
        # A single global mean over all elements. Tensor.mean rejects
        # integer/bool tensors, so divide the sum explicitly for those; this
        # promotes them to a floating result, as the previous code did.
        if x.is_floating_point():
            mean = x.mean()
        else:
            mean = x.sum() / x.numel()
        beta = self.alpha * mean

        if self.inplace:
            # beta is its own tensor, so mutating x does not alter it.
            # F.relu(inplace=True) mirrors nn.ReLU(inplace=True) semantics.
            return F.relu(x.sub_(beta), inplace=True)
        return torch.clip(x - beta, min=0)

    def extra_repr(self) -> str:
        """Show the module configuration in ``repr(module)``."""
        inplace_str = ", inplace=True" if self.inplace else ""
        return f"alpha={self.alpha}{inplace_str}"