A plug-and-play channel attention module implemented in PyTorch.
Installation | Usage | Modules | Blog | Experiments
You can install the package via pip:
```bash
pip install channel-attention==0.0.1
```

We develop and test only with PyTorch. Please make sure to install PyTorch from the official website according to your system configuration.
A key property of the channel attention mechanism is that its output has the same shape as its input. This shape invariance makes it easy to embed the module at almost any point in a neural network to further improve the model's performance.
```python
import torch
from channel_attention import SEAttention

# 1D time-series data with shape (batch_size, channels, seq_len)
inputs = torch.rand(8, 16, 128)
attn = SEAttention(n_dims=1, n_channels=16, reduction=4)
print(attn(inputs).shape)  # torch.Size([8, 16, 128])

# 2D image data with shape (batch_size, channels, height, width)
inputs_2d = torch.rand(8, 16, 64, 64)
attn_2d = SEAttention(n_dims=2, n_channels=16, reduction=4)
print(attn_2d(inputs_2d).shape)  # torch.Size([8, 16, 64, 64])
```

When the number of input channels is small, the channel attention module is very lightweight and adds little computational overhead.
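Because the output shape matches the input shape, the module can be dropped between existing layers without any other changes. Below is a minimal sketch of this idea; the `SimpleCNN` architecture and its layer sizes are illustrative, not part of the package:

```python
import torch
import torch.nn as nn
from channel_attention import SEAttention

# A hypothetical CNN block that only demonstrates where the
# attention module can be inserted.
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        # Shape-invariant: SEAttention's output matches its input,
        # so it can sit directly after any layer with 16 channels.
        self.attn = SEAttention(n_dims=2, n_channels=16, reduction=4)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)  # (B, 16, H, W)
        x = self.attn(x)  # still (B, 16, H, W)
        return self.relu(x)

model = SimpleCNN()
print(model(torch.rand(8, 3, 64, 64)).shape)  # torch.Size([8, 16, 64, 64])
```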
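One way to check the overhead for your own configuration is to count the module's trainable parameters and compare them with a convolution of the same width; a quick sketch:

```python
import torch.nn as nn
from channel_attention import SEAttention

attn = SEAttention(n_dims=2, n_channels=16, reduction=4)
conv = nn.Conv2d(16, 16, kernel_size=3, padding=1)

# Count trainable parameters of each module.
n_attn = sum(p.numel() for p in attn.parameters())
n_conv = sum(p.numel() for p in conv.parameters())
print(f"SEAttention: {n_attn} params, Conv2d: {n_conv} params")
```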