Unofficial JAX implementation of "Back to Basics: Let Denoising Generative Models Denoise" a.k.a. "Just Image Transformers" (JiT), by Tianhong Li and Kaiming He.
> [!NOTE]
> NVIDIA GPUs (Ampere or newer) are supported by default. I developed this code on a PC with a 3090 under WSL2. TPU support can be enabled by setting `implementation="xla"` in the default attention function in `model.py`.
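For reference, here is a minimal sketch of what that switch might look like, assuming the attention wrapper routes through `jax.nn.dot_product_attention` (the wrapper name and the cuDNN default are assumptions, not the repo's exact code):

```python
import jax

# Hypothetical attention wrapper (a sketch; the real one lives in model.py).
def attention(q, k, v):
    # "cudnn" uses flash attention and requires an Ampere-or-newer NVIDIA GPU;
    # switch to implementation="xla" to run on TPUs (or as a portable fallback).
    return jax.nn.dot_product_attention(q, k, v, implementation="cudnn")
```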
```sh
# Install dependencies
uv sync

# Train JiT-L/32 on ImageNet
./scripts/train.sh --notes="my first run"

# Generate a single image
./scripts/inference.sh model.npz output.png

# Start the inference server
./scripts/server.sh model.npz

# Lint and type check
make
```

`./scripts/server.sh` starts a Flask app with a simple web UI for generating images. You can adjust the seed, class label, CFG strength, number of steps, and sampling schedule. Progress streams in real time.
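For intuition about what those knobs control, here is a minimal sketch of a CFG-guided Euler sampling loop. It assumes an x-prediction model (the paper's "let the model denoise" setup), a linear schedule, and a hypothetical `apply_fn(params, x, t, label)`; none of these names are the repo's actual API:

```python
import jax
import jax.numpy as jnp

# Hypothetical sampler sketch; `apply_fn(params, x, t, label)` is assumed to
# return a clean-image (x) prediction, as in the JiT paper. Not the repo's API.
def sample(params, apply_fn, label, *, seed=0, cfg_scale=1.5, num_steps=50,
           shape=(1, 256, 256, 3)):
    x = jax.random.normal(jax.random.PRNGKey(seed), shape)  # pure noise at t=1
    ts = jnp.linspace(1.0, 0.0, num_steps + 1)              # linear schedule
    for t, t_next in zip(ts[:-1], ts[1:]):
        # Classifier-free guidance: blend conditional and unconditional x-predictions.
        x_cond = apply_fn(params, x, t, label)
        x_uncond = apply_fn(params, x, t, None)
        x_pred = x_uncond + cfg_scale * (x_cond - x_uncond)
        # Under x_t = (1 - t) * x0 + t * eps, the implied velocity is (x - x0) / t.
        v = (x - x_pred) / t
        x = x - (t - t_next) * v  # Euler step toward the data at t=0
    return x
```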
*(Demo: screen recording of the web UI.)*
```bibtex
@article{li2025backtobasics,
  title={Back to Basics: Let Denoising Generative Models Denoise},
  author={Li, Tianhong and He, Kaiming},
  journal={arXiv preprint arXiv:2511.13720},
  year={2025}
}
```