A novel approach to clustering that combines deep learning, reinforcement learning, and Markov Decision Processes to learn adaptive distance metrics for improved clustering performance under local autonomy. The system uses an Adaptive Distance Estimation Network (ADEN) that learns context-aware distance functions through interaction with parametrized clustering environments.
- Adaptive Distance Learning: Neural network learns optimal distance metrics rather than using fixed Euclidean distances
- Reinforcement Learning Framework: Models clustering as a Markov Decision Process with transition probabilities between cluster assignments
- Annealing Optimization: Two-phase training with β-annealing for progressive refinement
- GPU Acceleration: Full CUDA support for large-scale clustering tasks
- Comprehensive Benchmarking: Systematic comparison against analytical ground truth solutions
Example: Phase transition behavior during β-annealing, showing cluster formation and refinement
- **ADEN** (`ADEN.py`): Adaptive Distance Estimation Network
  - Multi-head attention mechanism for learning context-aware distances
  - Combines base Euclidean distances with learned adaptive deviations
  - Temperature-scaled distance predictions with ReLU activation
- **Clustering Environments** (`Env.py`):
  - `ClusteringEnvNumpy`: CPU-based environment for ground truth computation
  - `ClusteringEnvTorch`: GPU-accelerated environment for neural network training
  - Parametrized transition probabilities p(k|j,i) based on utility functions
- **Training System** (`ADENTrain.py`):
  - `TrainDbar`: Neural network training on expected distances via Monte Carlo sampling
  - `TrainY`: Cluster centroid optimization using gradient descent on free energy
  - `TrainAnneal`: Coordinated annealing schedule with β parameter growth
- **Ground Truth Solver** (`ClusteringGroundTruth.py`):
  - Analytical solutions for clustering optimization when local autonomy is known
  - Reference implementations for benchmarking
  - Free energy minimization with scipy optimization
- Python 3.8+
- CUDA-capable GPU (recommended)
- PyTorch with CUDA support
```bash
git clone https://github.com/salar96/AutonomyAwareClustering.git
cd AutonomyAwareClustering

# Install dependencies
pip install -r requirements.txt
```

The system includes synthetic data generators and supports real datasets:

- `TestCaseGenerator.py`: Multiple synthetic clustering scenarios
- `UTD19_London.mat`: Real-world sensor location data (included)
- Custom datasets via CSV import
```python
import torch
import numpy as np
from ADEN import ADEN
from Env import ClusteringEnvTorch
from ADENTrain import TrainAnneal
from TestCaseGenerator import data_RLClustering
import utils
# Set device and seed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
utils.set_seed(0)
# Load synthetic dataset
X, M, T_P, N, d = data_RLClustering(4) # 4-cluster 2D dataset
X = torch.tensor(X).float().to(device)
Y = torch.mean(X, dim=0, keepdim=True).to(device) + 0.01 * torch.randn(M, d).to(device)
# Create parametrized environment
env = ClusteringEnvTorch(
    n_data=N, n_clusters=M, n_features=d,
    parametrized=True, kappa=0.4, gamma=0.0, zeta=1.0, T=0.01,
    device=device
)
# Initialize ADEN model
model = ADEN(input_dim=d, d_model=64, n_layers=4, n_heads=8, d_ff=128, dropout=0.01)
# Train with annealing
Y_opt, pi_opt, _, _, _ = TrainAnneal(
    model, X, Y, env, device,
    epochs_dbar=1000, epochs_train_y=100,
    beta_init=10.0, beta_final=10000.0, beta_growth_rate=1.1
)
```

```bash
# Full benchmark suite across parameter combinations
python benchmark.py
# Single scenario focused testing
python benchmark_UDT.py
# Results saved to Benchmark/ directory with timestamps
```

Use the provided Jupyter notebooks for experimentation:

```bash
# Main training notebook with synthetic and real-world data
jupyter notebook DeepClusteringParametrized.ipynb
# Classical RL comparison
jupyter notebook TabularRL_Clustering.ipynb
# Ground truth analysis
jupyter notebook Clustering_GT.ipynb
```

Environment parameters (passed to `ClusteringEnvTorch`):

- `kappa`: Exploration probability (0.1-0.5); controls transition randomness
- `gamma`: Weight for data-cluster distances d(i,k)
- `zeta`: Weight for cluster-cluster distances d(j,k)
- `T`: Softmax temperature; lower values give sharper transitions
- `parametrized`: Boolean; use distance-based rather than fixed transition probabilities
ADEN model parameters:

- `d_model`: Internal embedding dimension (default: 64)
- `n_layers`: Number of attention blocks (default: 4)
- `n_heads`: Multi-head attention heads (default: 8)
- `d_ff`: Feed-forward network dimension (default: 128)
Training parameters:

- `epochs_dbar`: ADEN training epochs per annealing step (1000-2000)
- `epochs_train_y`: Centroid optimization epochs per step (100)
- `beta_init` / `beta_final`: Annealing schedule bounds (10.0 to 10000.0)
- `beta_growth_rate`: Multiplicative growth factor (1.1)
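For orientation, a geometric β schedule implies a fixed number of annealing steps; a minimal sketch (the multiplicative schedule is assumed from the growth factor above):

```python
import math

beta_init, beta_final, beta_growth_rate = 10.0, 10000.0, 1.1

# Geometric schedule: beta is multiplied by the growth rate at each annealing step
n_steps = math.ceil(math.log(beta_final / beta_init) / math.log(beta_growth_rate))
print(f"annealing steps: {n_steps}")  # about 73 steps for the defaults above
```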
The environment computes cluster transition probabilities p(k|j,i) from the utility parameters described above; the exact expression is implemented in `Env.py`.
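As a rough sketch only (an assumption based on the parameter descriptions, where κ mixes a "stay with the current cluster" term and a temperature-T softmax over the weighted distances):

```math
p(k \mid j, i) \;\approx\; (1-\kappa)\,\delta_{kj} \;+\; \kappa\,\frac{\exp\!\big(-\big[\gamma\, d(i,k) + \zeta\, d(j,k)\big]/T\big)}{\sum_{k'} \exp\!\big(-\big[\gamma\, d(i,k') + \zeta\, d(j,k')\big]/T\big)}
```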
When the transition model (local autonomy) is known, the optimal assignments and centroids can be derived analytically by free energy minimization.
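A sketch of the underlying free-energy formulation, assuming the standard deterministic-annealing form applied to the expected distances under the transition model:

```math
\bar d(i,j) = \sum_{k} p(k \mid j, i)\, d(x_i, y_k), \qquad
p(j \mid i) = \frac{e^{-\beta\, \bar d(i,j)}}{\sum_{j'} e^{-\beta\, \bar d(i,j')}}, \qquad
F = -\frac{1}{\beta} \sum_i \log \sum_j e^{-\beta\, \bar d(i,j)}
```

The optimal centroids are stationary points of F, which `ClusteringGroundTruth.py` locates with scipy optimization.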
ADEN augments the base Euclidean distance with a learned, context-aware deviation.
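A minimal sketch of the idea (the exact form, including the temperature scaling, is defined in `ADEN.py`):

```math
\hat d_\theta(x_i, y_k) = \lVert x_i - y_k \rVert_2 + \mathrm{ReLU}\big(\Delta_\theta(x_i, y_k)\big)
```

where Δ_θ is the adaptive deviation predicted by the multi-head attention network.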
As β increases over the annealing schedule, cluster assignments sharpen progressively from soft to nearly hard.
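Following the Gibbs form above, the soft assignments collapse to hard ones in the large-β limit:

```math
p_\beta(j \mid i) = \frac{e^{-\beta\, \bar d(i,j)}}{\sum_{j'} e^{-\beta\, \bar d(i,j')}}
\;\xrightarrow[\beta \to \infty]{}\;
\begin{cases} 1 & \text{if } j = \arg\min_{j'} \bar d(i,j') \\ 0 & \text{otherwise} \end{cases}
```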
```
AutonomyAwareClustering/
├── ADEN.py                          # Adaptive Distance Estimation Network
├── ADENTrain.py                     # Training algorithms (TrainDbar, TrainY, TrainAnneal)
├── Env.py                           # Clustering environments (NumPy/PyTorch)
├── ClusteringGroundTruth.py         # Analytical ground truth solvers
├── TestCaseGenerator.py             # Synthetic dataset generation
├── benchmark.py                     # Comprehensive benchmarking suite
├── benchmark_UDT.py                 # Focused benchmark scenarios
├── utils.py                         # Utility functions (distances, seeding)
├── Plotter.py                       # Visualization utilities
├── animator.py                      # GIF animation generation
├── ReinforcementClustering.py       # Classical tabular RL approach
├── DeepClusteringParametrized.ipynb # Main experiment notebook
├── TabularRL_Clustering.ipynb       # Classical RL experiments
├── Clustering_GT.ipynb              # Ground truth analysis
├── Benchmark/                       # Benchmark results (timestamped)
├── BenchmarkUDT/                    # UDT-specific results
├── Results/                         # Visualization outputs
└── animations/                      # Generated GIF animations
```
The system provides comprehensive visualization capabilities:
- Static Plots: `Plotter.py` generates publication-ready clustering visualizations
- Animations: `animator.py` creates GIF animations showing clustering evolution
- Real-time Monitoring: Training progress with loss curves and convergence metrics
Example visualization code:
```python
from Plotter import PlotClustering

PlotClustering(
    X.cpu().numpy(), Y_opt.cpu().numpy(), pi_opt,
    figsize=(12, 6), cmap="gist_rainbow",
    save_path="Results/clustering_result.png"
)
```

This framework has been applied to:
- Sensor Network Optimization: UTD19 London sensor placement dataset
- Synthetic Benchmark Problems: Multi-modal, multi-scale clustering scenarios
- Decentralized Systems: Autonomous agent coordination and resource allocation
The system uses multiple clustering quality metrics:
- Chamfer Distance: Bidirectional point-to-cluster matching
- Hungarian Distance: Optimal cluster center assignment cost
- Free Energy: Thermodynamic clustering objective
- Distortion: Weighted sum of within-cluster distances
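For reference, a minimal sketch of the first two metrics on raw centroid arrays (the repository's own implementations may differ in weighting and normalization):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

def chamfer_distance(Y_pred: np.ndarray, Y_true: np.ndarray) -> float:
    """Bidirectional nearest-neighbor distance between two centroid sets."""
    D = cdist(Y_pred, Y_true)              # pairwise Euclidean distances
    return D.min(axis=1).mean() + D.min(axis=0).mean()

def hungarian_distance(Y_pred: np.ndarray, Y_true: np.ndarray) -> float:
    """Mean cost of the optimal one-to-one matching between centroid sets."""
    D = cdist(Y_pred, Y_true)
    rows, cols = linear_sum_assignment(D)  # optimal assignment (Hungarian method)
    return D[rows, cols].mean()
```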
To add a new clustering environment:

- Extend `ClusteringEnvNumpy` or `ClusteringEnvTorch`
- Implement the `return_probabilities()` and `step()` methods (see the skeleton below)
- Update the benchmark configurations
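A hypothetical skeleton (the method signatures are illustrative; align them with the parent class in `Env.py`):

```python
from Env import ClusteringEnvTorch

class MyClusteringEnv(ClusteringEnvTorch):
    """Sketch of a custom environment with a modified transition model."""

    def return_probabilities(self, *args, **kwargs):
        # Compute transition probabilities p(k | j, i); keep the same tensor
        # shapes the parent class returns so TrainDbar and the benchmarks
        # continue to work.
        return super().return_probabilities(*args, **kwargs)

    def step(self, *args, **kwargs):
        # Advance the environment by sampling new cluster assignments
        # from the transition model.
        return super().step(*args, **kwargs)
```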
To modify the model architecture:

- Modify the `ADEN` class in `ADEN.py`
- Ensure compatibility with `TrainDbar` batching
- Update `reset_weights()` for proper initialization
- Test with different `d_model` configurations
To add a custom dataset:

- Add a data loading function in `TestCaseGenerator.py`
- Follow the return format: `return X, M, T_P, N, d`
- Normalize data to the [0, 1] range for stability (a loader sketch follows)
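A hypothetical CSV loader following that format (the `T_P = None` placeholder is an assumption; mirror whatever the existing generators in `TestCaseGenerator.py` return for the transition matrix):

```python
import numpy as np

def data_from_csv(path: str, n_clusters: int):
    """Load a custom dataset from CSV and return it in the expected format."""
    X = np.loadtxt(path, delimiter=",")  # shape (N, d)
    # Normalize features to [0, 1] for numerical stability
    X = (X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0) + 1e-12)
    N, d = X.shape
    M = n_clusters
    T_P = None  # supply a fixed transition-probability matrix here if needed
    return X, M, T_P, N, d
```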
If you use this code in your research, please cite:
```bibtex
@article{autonomy_aware_clustering_2024,
  title={Autonomy-Aware Clustering},
  author={[Authors]},
  journal={[Journal]},
  year={2024}
}
```

Contributions are welcome! Please:
- Fork the repository
- Create a feature branch
- Follow the existing code style
- Add tests for new functionality
- Submit a pull request
This project is licensed under the MIT License - see the LICENSE file for details.
CUDA Memory Errors: Reduce `batch_size_dbar` or `num_samples_in_batch_dbar`
Convergence Issues:
- Adjust `beta_growth_rate` (try 1.05-1.2)
- Increase `perturbation_std` to escape local minima
- Check the environment parameter ranges
Training Instability:
- Use `%env CUDA_LAUNCH_BLOCKING=1` in notebooks for debugging
- Ensure `utils.set_seed(0)` is called before training
- Monitor loss curves for numerical issues
Performance:
- Use the PyTorch environments for GPU training
- Use the NumPy environments only for ground truth computation
- Profile with `torch.profiler` to identify bottlenecks (a minimal sketch follows)
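A minimal profiling sketch (wrap whichever call dominates your run, e.g. one training step):

```python
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    pass  # run one training step / forward pass here
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```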
For more detailed troubleshooting, see the GitHub Issues page.
