96 changes: 71 additions & 25 deletions torchtitan/experiments/deterministic_vllm_rl/README.md
@@ -28,7 +28,7 @@ Note: Currently supports single-device training only.
- Implements custom backward pass for gradient computation
- Uses `num_splits=1` for deterministic behavior

2. `models/qwen3/model_vllm_compat.py`: Qwen3VLLMCompatModel
2. `models/qwen3/model_batch_invariant.py`: Qwen3VLLMCompatModel
- Qwen3 model with merged gate/up projections matching vLLM format
- Uses VLLMRMSNorm with gradient support

@@ -211,37 +211,83 @@ This implementation uses the same kernels for both rollouts (vLLM) and training
2. Only causal attention is supported
3. Requires NVIDIA GPUs with Flash Attention support

## Project Structure

```
deterministic_vllm_rl/
├── README.md # Documentation
├── __init__.py # Package initialization
├── batch_invariant_backward.py # Backward passes for vLLM ops
├── weights_vllm_compat.py # Weight conversion utilities
├── simple_rl.py # RL training loop
├── models/
│ ├── __init__.py
│ ├── attention.py # VLLMCompatibleFlashAttention
│ └── qwen3/
│ ├── __init__.py
│ └── model_vllm_compat.py # vLLM-compatible Qwen3 model
├── weights/
│ ├── __init__.py
│ ├── converter.py # Weight conversion script
│ └── README.md # Weight conversion documentation
└── tests/
├── __init__.py
├── test_batch_invariant_backward.py # Test backward passes
└── test_exact_determinism.py # Test determinism
```

## TODO

- `FlashAttnWithBackward` will need to become more composable and should not live exclusively within this directory.
- vLLM integration will need to become more generic with a provided Attention operator that is KV-cache compatible.
- vLLM parallelism will need to add generic parallelism initialization to support Monarch managed TP/DP.

# Run vLLM inference with TorchTitan Qwen3 Model

This directory contains code to run a single canonical model definition (the TorchTitan model definition) with the vLLM inference engine (not yet batch-invariant; work in progress).
This work is inspired by https://github.com/vllm-project/vllm/pull/28685.

## Overview
The integration consists of two main components:

1. **Model Adapter** (`model/qwen3.py`): A custom model class that extends vLLM's `Qwen3ForCausalLM` to handle TorchTitan checkpoint naming conventions
2. **Inference Script** (`infer.py`): A simple script to register the model and run inference

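Put together, the flow looks roughly like the following compact sketch, which mirrors `infer.py` below (the checkpoint path is the example default used later in this README):

```python
from vllm import LLM, SamplingParams

# Importing the models package auto-registers Qwen3TorchTitanForCausalLM with vLLM.
from torchtitan.experiments.deterministic_vllm_rl import models  # noqa: F401

ckpt = "torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B"
llm = LLM(
    model=ckpt,
    hf_overrides={"checkpoint_dir": ckpt},
    dtype="bfloat16",
    trust_remote_code=True,
    enforce_eager=True,  # eager mode, matching infer.py
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)
```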

## Quick Start
### Prerequisites

1. Install PyTorch nightly for torchtitan:
```bash
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall
```


2. Install vLLM from source (see [Use an existing PyTorch installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation)):
```bash
# install PyTorch first, either from PyPI or from source
git clone https://github.com/vllm-project/vllm.git
cd vllm
python use_existing_torch.py
uv pip install -r requirements/build.txt
uv pip install --no-build-isolation -e .
```


NOTE: If `flash_attn_varlen_func` hits the error "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain" during the forward pass, the GPU driver version is not compatible with the toolchain that vLLM/PyTorch were compiled with. Use the following commands to recompile vLLM.

```bash
# Set CUDA version environment variable
export CUDA_HOME=/usr/local/cuda-12.4
export PATH=/usr/local/cuda-12.4/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH

# Clean previous build
rm -rf build dist *.egg-info
uv pip uninstall -y vllm

# Rebuild vLLM from source with CUDA 12.4
pip install -e .

```
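
Before rebuilding, it can help to confirm which CUDA toolkit PyTorch was built against — a minimal check, assuming it is run in the same environment that vLLM will use:

```python
# Minimal sanity check of the CUDA toolchain seen by PyTorch
# (assumption: same environment/venv as the one vLLM will run in).
import torch

print("torch version:", torch.__version__)
print("built with CUDA:", torch.version.cuda)
print("GPU available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
```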

3. Download the Qwen3/Qwen3-0.6B checkpoint from Hugging Face and place it in the `example_checkpoint` folder. Make sure to change the `architectures` field in `config.json` to `Qwen3TorchTitanForCausalLM` so the vLLM engine uses the TorchTitan model; a sketch of this edit follows.

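A minimal sketch of that `config.json` edit (assumptions: the checkpoint sits at the default `example_checkpoint` path used below, and the file follows the standard Hugging Face layout with an `architectures` list):

```python
# Sketch: point the checkpoint's config.json at the TorchTitan model class.
# Assumes the default example path and a standard Hugging Face "architectures" list.
import json
from pathlib import Path

config_path = Path(
    "torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B/config.json"
)
config = json.loads(config_path.read_text())
config["architectures"] = ["Qwen3TorchTitanForCausalLM"]  # originally ["Qwen3ForCausalLM"]
config_path.write_text(json.dumps(config, indent=2) + "\n")
```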

4. Run inference:
```bash
python torchtitan/experiments/deterministic_vllm_rl/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B
```

Run with TP (work in progress):
```bash
python torchtitan/experiments/deterministic_vllm_rl/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B --tensor-parallel-size 2
```

## TODO
1. Rewrite the attention path to use `vllm.Attention()` (with backward support) as the only attention path.
2. Integrate with `simple_rl.py` to run end-to-end RL with one canonical model definition.
3. Integrate the batch-invariant kernels into the model definition.



## Contributing

This experiment is part of TorchTitan. To contribute:
9 changes: 9 additions & 0 deletions torchtitan/experiments/deterministic_vllm_rl/__init__.py
@@ -15,6 +15,12 @@
- Qwen3VLLMCompatModel: vLLM-compatible model with merged projections
- batch_invariant_backward: Gradient support for vLLM's deterministic operations
- simple_rl: End-to-end RL training loop
- TorchTitanVLLMModel: Generic wrapper for TorchTitan models with vLLM

For vLLM inference with TorchTitan models, see:
- models/vllm_wrapper.py: Core vLLM wrapper
- models/__init__.py: Auto-registration with vLLM
- infer.py: Example inference script
"""

from .batch_invariant_backward import (
@@ -24,11 +30,14 @@
)
from .models import VLLMCompatibleFlashAttention
from .models.qwen3 import Qwen3VLLMCompatModel
from .models.vllm_wrapper import TorchTitanVLLMModel


__all__ = [
"VLLMCompatibleFlashAttention",
"Qwen3VLLMCompatModel",
"enable_batch_invariant_backward_mode",
"rms_norm_with_gradients",
"silu_and_mul_with_gradients",
"TorchTitanVLLMModel",
]
114 changes: 114 additions & 0 deletions torchtitan/experiments/deterministic_vllm_rl/infer.py
@@ -0,0 +1,114 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse

from vllm import LLM, SamplingParams
from vllm.logger import init_logger

# Import models module - this automatically registers TorchTitan models with vLLM
from torchtitan.experiments.deterministic_vllm_rl import models # noqa: F401


logger = init_logger(__name__)


def parse_args():
**Contributor:** Not urgent, but we should use "our" config system in the long term.

**Contributor (Author):** Currently the entry point is the vLLM engine, so we are taking the config from whatever the vLLM engine passes to us. Let me check the vLLM engine to see if there is anything we can do.

**Contributor:** Wait, how is this related to the vLLM config system? You are just using the args as-is in `args = parse_args()`.

**Contributor (Author):** These args are only for the `infer.py` script; they are passed into the vLLM engine `LLM()`, and the vLLM engine creates a `VLLMConfig` instance internally and passes it to our model wrapper.

**Contributor:** > it will pass args into vllm engine LLM()

I don't think it's passing the args to `LLM()`. What would be different if we used our config manager to construct the args?

parser = argparse.ArgumentParser(
description="Run TorchTitan model inference with vLLM Engine",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--model_ckpt_path",
type=str,
default="torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B",
help="Path to TorchTitan checkpoint directory",
)
parser.add_argument(
"--prompt",
type=str,
default="Hello, my name is",
help="Prompt text for generation",
)
parser.add_argument(
"--max-tokens",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temperature",
type=float,
default=0.8,
help="Sampling temperature",
)
parser.add_argument(
"--tensor-parallel-size",
type=int,
default=1,
help="Number of GPUs for tensor parallelism (default: 1 for single GPU)",
)
return parser.parse_args()


def main():
args = parse_args()

logger.info("Initializing vLLM with TorchTitan model")
logger.info(f"Model: {args.model_ckpt_path}")
logger.info(f"Tensor Parallel Size: {args.tensor_parallel_size}")

# Initialize vLLM with custom TorchTitan model
# The LLM initialization will internally:
# 1. Load TrainSpec for Qwen3 (from models/__init__.py register())
# 2. Create TorchTitanVLLMModel instance
# 3. Create JobConfig and ParallelDims from vLLM config
# 4. Apply parallelization using parallelize_qwen3
# 5. Load model weights and prepare for inference
logger.info("Creating vLLM LLM engine...")

llm = LLM(
model=args.model_ckpt_path, # Model checkpoint path
hf_overrides={
"checkpoint_dir": args.model_ckpt_path,
},
dtype="bfloat16",
trust_remote_code=True,
enforce_eager=True, # Use eager mode
tensor_parallel_size=args.tensor_parallel_size,
)

logger.info("vLLM engine initialized successfully")
logger.info(f"Prompt: {args.prompt}")

# Prepare prompt and sampling parameters
prompts = [args.prompt]
sampling_params = SamplingParams(
temperature=args.temperature,
top_p=0.95,
max_tokens=args.max_tokens,
)

# Generate text
logger.info("Generating text...")
outputs = llm.generate(
prompts=prompts,
sampling_params=sampling_params,
)

# Print results
logger.info("Generation complete")
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text

print(f"\nPrompt: {prompt}")
print(f"Generated text: {generated_text!r}\n")


if __name__ == "__main__":
main()
73 changes: 71 additions & 2 deletions torchtitan/experiments/deterministic_vllm_rl/models/__init__.py
@@ -6,8 +6,77 @@

"""
Models for deterministic vLLM RL training.

This module automatically registers TorchTitan models with vLLM when imported.
"""

from .attention import VLLMCompatibleFlashAttention
from vllm.logger import init_logger

from torchtitan.protocols.train_spec import get_train_spec, TrainSpec
from .attention import VLLMCompatibleFlashAttention, VLLMPagedFlashAttention
from .vllm_wrapper import TorchTitanVLLMModel


logger = init_logger(__name__)


def register_torchtitan_model_from_train_spec(
train_spec: TrainSpec,
model_name: str,
) -> None:
"""
Register a TorchTitan model with vLLM using a TrainSpec.

Args:
train_spec: TorchTitan TrainSpec containing model components
model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM")

"""
from vllm.model_executor.models.registry import ModelRegistry

# Extract model_args from TrainSpec
# TrainSpec has model_args as a Mapping, get the first value
if isinstance(train_spec.model_args, dict):
model_args_cls = type(next(iter(train_spec.model_args.values())))
else:
model_args_cls = train_spec.model_args

# Create dynamic model class directly from TrainSpec components
class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModel):
"""Dynamically created vLLM model from TrainSpec."""

def __init__(self, *, vllm_config, prefix=""):
super().__init__(
model_cls=train_spec.model_cls,
model_args_cls=model_args_cls,
state_dict_adapter=train_spec.state_dict_adapter,
parallelize_fn=train_spec.parallelize_fn,
Comment on lines +50 to +53

**Contributor:** It seems we need these fields and the wrappers `TorchTitanVLLMModel` / `TorchTitanVLLMModelFromSpec` because we rely on vLLM's `LLM()` API to create the model. This is hacky and makes things complicated, as we are dumping a lot of logic (originally in `train.py` and `checkpoint.py`) into the model code itself. I feel this is unnecessary if our end goal is to use the engine part of vLLM, not the model-init part.

**Contributor (Author):** > dumping a lot of logic (originally in train.py and checkpoint.py) to the model code itself

Agreed; the main blocker is that we need control over how the Worker instantiates a model. By vLLM's design, this class is not only a model `nn.Module` but also a model runner (https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/worker_base.py#L85), which is why it has a `load_weights` function.

vllm_config=vllm_config,
prefix=prefix,
)

# Set the class name
TorchTitanVLLMModelFromSpec.__name__ = model_name
TorchTitanVLLMModelFromSpec.__qualname__ = model_name

# Register with vLLM
ModelRegistry.register_model(model_name, TorchTitanVLLMModelFromSpec)

logger.info(
f"Successfully registered {model_name} with vLLM using TrainSpec "
f"(model_cls={train_spec.model_cls.__name__})"
)


# Auto-register TorchTitan models with vLLM when this module is imported
register_torchtitan_model_from_train_spec(
train_spec=get_train_spec("qwen3"),
model_name="Qwen3TorchTitanForCausalLM",
)


__all__ = ["VLLMCompatibleFlashAttention"]
__all__ = [
"VLLMCompatibleFlashAttention",
"VLLMPagedFlashAttention",
"TorchTitanVLLMModel",
]