
just-image-transformer


Unofficial JAX implementation of "Back to Basics: Let Denoising Generative Models Denoise" a.k.a. "Just Image Transformers" (JiT), by Tianhong Li and Kaiming He.
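
The paper's central idea is for the network to predict the clean image directly, rather than noise or velocity. As a rough orientation, here is a minimal JAX sketch of that x-prediction objective, assuming a flow-matching-style linear interpolation x_t = (1 - t) * x + t * eps; all names here are illustrative, not this repo's API:

import jax
import jax.numpy as jnp

def x_prediction_loss(params, model_apply, x, key):
    # Hypothetical training loss: corrupt x along a linear path,
    # then regress the model output directly onto the clean image.
    k_t, k_eps = jax.random.split(key)
    t = jax.random.uniform(k_t, (x.shape[0], 1, 1, 1))
    eps = jax.random.normal(k_eps, x.shape)
    x_t = (1.0 - t) * x + t * eps
    x_pred = model_apply(params, x_t, t)  # the model denoises: its output is an image
    return jnp.mean((x_pred - x) ** 2)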

Note

NVIDIA GPUs (Ampere or newer) are supported by default; I developed this code on a PC with a 3090 under WSL2. TPU support can be enabled by setting implementation="xla" in the default attention function in model.py.
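
For concreteness, the switch likely amounts to changing one keyword argument of jax.nn.dot_product_attention, which accepts implementation="cudnn" (Ampere or newer GPUs) or implementation="xla" (portable, including TPUs). A minimal sketch; the actual wrapper in model.py may use different shape conventions:

import jax

def attention(q, k, v, implementation="cudnn"):
    # q, k, v: (batch, seq_len, num_heads, head_dim)
    # Set implementation="xla" here for TPU support.
    return jax.nn.dot_product_attention(q, k, v, implementation=implementation)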

Quick start

# Install dependencies
uv sync

# Train JiT-L/32 on ImageNet
./scripts/train.sh --notes="my first run"

# Generate a single image
./scripts/inference.sh model.npz output.png

# Start the inference server
./scripts/server.sh model.npz

# Lint and type check
make

Inference server

./scripts/server.sh starts a Flask app with a simple web UI for generating images. You can adjust the seed, class label, CFG strength, number of steps, and schedule. Progress streams in real time.
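
The CFG strength setting follows the usual classifier-free guidance recipe: at each sampling step, the conditional and unconditional model predictions are blended. A minimal sketch with illustrative names (not this repo's API):

def cfg_predict(model_apply, params, x_t, t, label, null_label, cfg_scale):
    # Hypothetical guidance step: cfg_scale = 1.0 recovers the purely
    # conditional prediction; larger values sharpen class adherence.
    x_cond = model_apply(params, x_t, t, label)
    x_uncond = model_apply(params, x_t, t, null_label)
    return x_uncond + cfg_scale * (x_cond - x_uncond)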

(Demo video of the web UI: Screen.Recording.2025-11-29.at.3.57.28.PM.mov)

References

@article{li2025backtobasics,
  title={Back to Basics: Let Denoising Generative Models Denoise},
  author={Li, Tianhong and He, Kaiming},
  journal={arXiv preprint arXiv:2511.13720},
  year={2025}
}
