Run vLLM inference using torchtitan model definition (single GPU) #2119
base: main
Conversation
tianyu-l left a comment
Left some comments; ignore if you still plan to make changes.
| return output |
| class VLLMPagedFlashAttention(torch.nn.Module): |
What's this for? I thought we only run inference.
By design, this class should be able to run both training and inference in the future, so we have one single attention class. For now it is only used for inference. I will clean up and remove the non-inference-related parts. cc @zhxchen17
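Roughly what I have in mind (a minimal sketch, not this PR's actual code; the class name, attribute names, and the SDPA fallback are assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DualModeAttention(nn.Module):
    """One attention module for both modes: vLLM paged attention at inference
    time, plain SDPA at training time (no paged KV cache)."""

    def __init__(self, n_heads: int, head_dim: int, vllm_attn=None):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        # Populated only when the model is constructed under the vLLM engine.
        self.vllm_attn = vllm_attn

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        if not self.training and self.vllm_attn is not None:
            # Inference path: vLLM manages the paged KV cache internally.
            return self.vllm_attn(q, k, v)
        # Training path: standard causal scaled dot-product attention.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```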
| class FeedForwardVLLMCompat(nn.Module): |
| class TorchTitanQwen3ForCausalLM(nn.Module): |
we should aim for at most 1 general wrapper for all models -- we shouldn't have 1 wrapper for each model.
Yes, that's the goal! The current wrapper is specific to the Qwen3 model, since we access each model layer by name. Let me think about how to design the interface.
Suppose each model author defines the model in slightly different ways; then it seems to be up to the people doing RL to make the adaptation work, like we prototyped here.
So I guess we need some sort of contract baked in at authoring time, e.g. models should annotate/implement their attention layers in certain ways (having some sort of base class or special methods?).
Agreed with @zhxchen17.
I think we can start by adding a BaseTransformer class (in torchtitan/protocols/model.py) with the standard layers defined, so that other (text) models can inherit from it, e.g. https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama4/model/model.py#L425
But this won't stop a model from not using some of the layers (e.g. tok_embeddings). We can make subclasses less error-prone by removing their forward functions, since they are identical.
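A rough sketch of what such a contract could look like (a sketch only; the attribute names follow this thread, but the forward signature and the ModuleDict assumption are mine):

```python
import torch
import torch.nn as nn


class BaseTransformer(nn.Module):
    """Fixes the standard layer names and the forward pass; per-model
    subclasses only construct the submodules."""

    # Standard layers every text model is expected to provide.
    tok_embeddings: nn.Embedding
    layers: nn.ModuleDict
    norm: nn.Module
    output: nn.Module
    rope_cache: torch.Tensor  # rope_cache vs. freqs_cis naming still to be aligned

    def forward(self, tokens, attention_masks=None, positions=None):
        h = self.tok_embeddings(tokens)
        for layer in self.layers.values():
            h = layer(h, self.rope_cache, attention_masks, positions)
        return self.output(self.norm(h))
```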
Cool! The refactor of adding a BaseTransformer class can be handled in a separate PR; it would also need the rope_cache and freqs_cis naming to be aligned.
Also, I was thinking about how to design the single general wrapper for the vLLM model.
There are 4 things that are model specific and need to be plugged into the general wrapper class before registering the model with vLLM:
- model class itself
- model_args class
- model's parallel plan (mainly TP; PP needs to be considered separately)
- state dict adapter
I refactored the code into a basic wrapper plus plug-in components, which get wired in by inheriting from the wrapper when registering the model. A sketch is below. Wdyt?
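Something along these lines (a sketch under my assumptions; `register_titan_model`, the class body, and the constructor signature are illustrative, only `ModelRegistry.register_model` is the real vLLM API):

```python
import torch.nn as nn
from vllm import ModelRegistry


class TorchTitanVLLMModel(nn.Module):
    """Generic vLLM-facing wrapper; concrete subclasses only supply the plug-ins."""

    model_cls = None            # torchtitan model class
    model_args_cls = None       # its model args class
    parallelize_fn = None       # parallel plan (TP for now; PP handled separately)
    state_dict_adapter = None   # HF <-> torchtitan state dict conversion

    def __init__(self, *, vllm_config, prefix: str = ""):
        super().__init__()
        # In practice model_args would be derived from vllm_config.
        model_args = self.model_args_cls()
        self.model = self.model_cls(model_args)


def register_titan_model(arch_name, model_cls, model_args_cls, parallelize_fn, state_dict_adapter):
    # Create a concrete subclass at registration time ("inheriting when
    # registering") and hand it to vLLM's model registry.
    wrapper = type(
        f"{arch_name}TitanWrapper",
        (TorchTitanVLLMModel,),
        dict(
            model_cls=model_cls,
            model_args_cls=model_args_cls,
            parallelize_fn=staticmethod(parallelize_fn),
            state_dict_adapter=state_dict_adapter,
        ),
    )
    ModelRegistry.register_model(arch_name, wrapper)
    return wrapper
```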
> this would need an alignment between rope_cache and freqs_cis naming as well
Why? Is it because we put freqs_cis generation outside model for CP? I think @fegin has a plan to move it back into the model.
No, it's not related to CP. I was thinking we'd need to access self.rope_cache / self.freqs_cis during the forward of BaseTransformer:
| h = layer(h, self.rope_cache, attention_masks, positions) |
| positions_2d = positions.unsqueeze(0)  # [total_tokens] -> [1, total_tokens] |
| # Get embeddings from 2D tokens |
| h = self.model.tok_embeddings(tokens_2d)  # [1, total_tokens, hidden_size] |
qq: curious, do the input tokens always have batch size 1?
I'm not 100% sure about this part, but I'd guess not. I've only tried single prompts for now; I'll leave a TODO here for batched inference. A shape sketch of my current understanding is below.
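(An assumption about how the vLLM v1 engine batches requests; the numbers are made up:)

```python
import torch

# The engine packs all in-flight sequences into one flat tensor of token ids
# rather than a [batch, seq] matrix; e.g. two prompts of lengths 3 and 2:
tokens = torch.tensor([11, 12, 13, 21, 22])   # [total_tokens] = [5]
positions = torch.tensor([0, 1, 2, 0, 1])     # per-token positions

# The wrapper adds a leading dim to match torchtitan's [batch, seq] convention:
tokens_2d = tokens.unsqueeze(0)               # [1, total_tokens]
positions_2d = positions.unsqueeze(0)         # [1, total_tokens]

# The attention metadata (block tables, sequence lengths) keeps the sequences
# separate, so this leading "batch" dim stays 1 even with many prompts in flight.
```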
| register() |
| def parse_args(): |
not urgent, but we should use "our" config system in the long term
Currently the entry point is the vLLM engine, so we are taking the config from whatever the vLLM engine passes to us. Let me check the vLLM engine to see if there's anything we could do.
Wait, how is this related to the vLLM config system? You are just using them as-is in args = parse_args().
These args are only for the infer.py script; the script passes them into the vLLM engine's LLM(), the engine creates a VllmConfig instance internally, and that config is what gets passed to our model wrapper. A sketch of the flow is below.
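(A minimal sketch of the flow, not this PR's infer.py; the flag names and model id are placeholders:)

```python
import argparse

from vllm import LLM, SamplingParams


def parse_args():
    # These flags only configure this script, not the model wrapper.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--prompt", default="Hello, world")
    parser.add_argument("--max-tokens", type=int, default=64)
    return parser.parse_args()


def main():
    args = parse_args()
    # LLM() builds a VllmConfig internally and hands it to the registered wrapper.
    llm = LLM(model=args.model, enforce_eager=True)
    outputs = llm.generate([args.prompt], SamplingParams(max_tokens=args.max_tokens))
    print(outputs[0].outputs[0].text)


if __name__ == "__main__":
    main()
```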
> the script passes them into the vLLM engine's LLM()
I don't think it's passing the args to LLM(). What would be different if we used our config manager to construct args?
| model_args.n_heads |
| if model_args.n_kv_heads is None |
| else model_args.n_kv_heads |
| def _replice_with_vllm_paged_attention(self, model_args): |
Can we expose this as a util function? It could also be reused for the trainer model, e.g. something like the sketch below.
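(One possible shape for such a util, under my assumptions about the attribute names in torchtitan's Attention module and the constructor of the paged-attention class; not this PR's actual code:)

```python
import torch.nn as nn


def replace_with_vllm_paged_attention(model: nn.Module, model_args, attn_cls) -> nn.Module:
    """Swap each transformer block's core attention op for `attn_cls`
    (e.g. VLLMPagedFlashAttention from this PR), keeping the projections."""
    n_kv_heads = (
        model_args.n_heads if model_args.n_kv_heads is None else model_args.n_kv_heads
    )
    head_dim = model_args.dim // model_args.n_heads
    for layer in model.layers.values():
        # 'inner_attention' is an assumed attribute name for the SDPA call site.
        layer.attention.inner_attention = attn_cls(
            num_heads=model_args.n_heads,
            num_kv_heads=n_kv_heads,
            head_dim=head_dim,
        )
    return model
```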
Force-pushed from d7b714b to 226150d (Compare)
torchtitan/experiments/deterministic_vllm_rl/models/base_wrapper.py (outdated; conversation resolved)
| return parallel_dims |
| def build_device_mesh_and_parallelize( |
Curious, how will this function be used? Won't the vLLM engine handle TP=2 for us?
> Won't the vLLM engine handle TP=2 for us?
vLLM applies TP by patching each module (https://docs.vllm.ai/en/latest/contributing/model/basic/#3-optional-implement-tensor-parallelism-and-quantization-support), which also changes the model definition and happens during model initialization.
This function calls the parallelize_qwen3 function from qwen3/infra/parallelize.py; depending on parallel_dims, it applies the different parallelisms. A sketch is below.
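(Roughly this, mirroring the training-side flow; the exact signatures are assumptions:)

```python
def build_device_mesh_and_parallelize(model, parallel_dims, parallelize_fn, job_config):
    """Apply torchtitan's own parallelism plan to the model during init.

    parallel_dims owns the DeviceMesh; the per-model parallelize function
    (e.g. parallelize_qwen3) reads the "tp" sub-mesh from it and applies the
    TP plan to the module in place.
    """
    if parallel_dims.tp_enabled:
        model = parallelize_fn(model, parallel_dims=parallel_dims, job_config=job_config)
    return model
```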
| logits = self.model.output(h) |
| if isinstance(logits, DTensor): |
| logits = logits.full_tensor() |
Have you verified this for TP, or does it only work on a single GPU? If it's the latter, add a TODO around such conversions.
Currently this PR only works on a single GPU; I will add TP in a follow-up PR.
| layer_idx = next(VLLMPagedFlashAttention._layer_counter) |
| prefix = f"layers.{layer_idx}" |
| self.vllm_attn = Attention( |
why can't we always use VLLMPagedFlashAttention?
| model_cls=train_spec.model_cls, |
| model_args_cls=model_args_cls, |
| state_dict_adapter=train_spec.state_dict_adapter, |
| parallelize_fn=train_spec.parallelize_fn, |
It seems we need these fields and the TorchTitanVLLMModel / TorchTitanVLLMModelFromSpec wrappers because we rely on vLLM's LLM() API to create the model.
This is hacky and makes things complicated, as we are dumping a lot of logic (originally in train.py and checkpoint.py) into the model code itself.
I feel this is unnecessary if our end goal is to use the engine part of vLLM, not the model-init part.
> dumping a lot of logic (originally in train.py and checkpoint.py) into the model code itself
Agreed. The main blocker is that we need control over how the Worker instantiates a model.
According to the vLLM design, this class is not only a model nn.Module but also acts as a model runner (https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/worker_base.py#L85), which is why it has a load_weights function. Roughly, the expected interface looks like the sketch below.
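(Approximate interface only; the from_hf call and the attributes built in __init__ are placeholders for what the wrapper actually does:)

```python
from typing import Iterable

import torch
import torch.nn as nn


class TorchTitanVLLMModel(nn.Module):
    """Shape of the interface vLLM's worker expects from a registered model."""

    def __init__(self, *, vllm_config, prefix: str = ""):
        super().__init__()
        # The torchtitan module and the state dict adapter are built here from
        # the plug-ins; construction details omitted in this sketch.
        self.model: nn.Module = nn.Identity()
        self.state_dict_adapter = None

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, **kwargs) -> torch.Tensor:
        # Called by the engine with flattened token ids and positions.
        ...

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
        # The worker streams HF-format (name, tensor) pairs; convert them with
        # the state dict adapter and load into the torchtitan module.
        hf_state_dict = dict(weights)
        titan_state_dict = self.state_dict_adapter.from_hf(hf_state_dict)
        self.model.load_state_dict(titan_state_dict, strict=False)
```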
| if parallel_dims.tp_enabled: |
| self.world_mesh = parallel_dims.world_mesh |
| tp_mesh = self.world_mesh["tp"] |
| parallelize_fn( |
Wondering why we parallelize the model during init? This model is used during vLLM inference, and I thought vLLM has its own TP impl?
We want to apply our own TP instead of using vLLM's TP implementation. We won't have direct access to the model once LLM() is initialized, so we can't apply TP later.
As titled; put it in the deterministic RL folder.