Channel-Attention

A plug-and-play channel attention mechanism module implemented in PyTorch.


Installation

You can install the package via pip:

pip install channel-attention==0.0.1

The package is developed and tested against PyTorch only. Please install PyTorch from the official PyTorch website according to your system configuration.

Usage

A key property of the channel attention mechanism is that its output has the same shape as its input. This makes it easy to embed the module at any point in a neural network to further improve the model's performance.

import torch
from channel_attention import SEAttention

# 1D Time Series Data with (batch_size, channels, seq_len)
inputs = torch.rand(8, 16, 128)
attn = SEAttention(n_dims=1, n_channels=16, reduction=4)
print(attn(inputs).shape)

# 2D Image Data with (batch_size, channels, height, width)
inputs_2d = torch.rand(8, 16, 64, 64)
attn_2d = SEAttention(n_dims=2, n_channels=16, reduction=4)
print(attn_2d(inputs_2d).shape)
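
Because the module preserves the input shape, it can be dropped between existing layers of a model. Below is a minimal sketch of wrapping a convolution with SEAttention; the surrounding layers are illustrative only and not part of this package.

import torch
import torch.nn as nn
from channel_attention import SEAttention

class ConvSEBlock(nn.Module):
    # A toy conv block with SEAttention inserted after the convolution.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        # SEAttention keeps the tensor shape unchanged, so it can sit
        # anywhere between layers with a matching channel count.
        self.attn = SEAttention(n_dims=2, n_channels=out_channels, reduction=4)

    def forward(self, x):
        x = torch.relu(self.bn(self.conv(x)))
        return self.attn(x)

block = ConvSEBlock(3, 16)
print(block(torch.rand(8, 3, 64, 64)).shape)  # torch.Size([8, 16, 64, 64])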

When the number of input channels is small, the channel attention module is very lightweight and adds little computational cost.
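
One way to check the overhead is to count the module's parameters. For a standard squeeze-and-excitation design with two linear layers (C -> C/r -> C), the count is on the order of 2·C²/r; the exact number depends on the implementation details of this package.

from channel_attention import SEAttention

attn = SEAttention(n_dims=2, n_channels=16, reduction=4)
# For a squeeze-and-excitation block with two linear layers
# (C -> C/r -> C), this is roughly 2 * C^2 / r parameters,
# i.e. only a few hundred for C = 16 and r = 4.
print(sum(p.numel() for p in attn.parameters()))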

Modules

1. SEAttention: [paper] Squeeze-and-Excitation attention, using global average pooling followed by a feed-forward excitation network.

2. ChannelAttention: [paper] Channel attention using both global average pooling and global max pooling.

3. SpatialAttention: [paper] Spatial attention using both global average pooling and global max pooling.

4. ConvBlockAttention: [paper] The Convolutional Block Attention Module (CBAM), combining channel attention and spatial attention.
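
The other modules are intended to be used in the same plug-and-play way. The sketch below assumes ConvBlockAttention can be imported from channel_attention and accepts arguments analogous to SEAttention; check the package source for the exact constructor signature.

import torch
from channel_attention import ConvBlockAttention  # import path assumed

# Assumption: the constructor mirrors SEAttention's arguments;
# consult the package source for the exact signature.
inputs = torch.rand(8, 16, 64, 64)
cbam = ConvBlockAttention(n_dims=2, n_channels=16, reduction=4)
print(cbam(inputs).shape)  # expected: torch.Size([8, 16, 64, 64])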

Experiments
