Conversation

@wwwjn (Contributor) commented Dec 5, 2025

As titled, put it in the deterministic RL folder.

@meta-cla bot added the CLA Signed label on Dec 5, 2025
@wwwjn requested review from acisseJZhong and bwasti on December 5, 2025 22:22
@tianyu-l (Contributor) left a comment

Left some comments; ignore if you still plan to make changes.

return output


class VLLMPagedFlashAttention(torch.nn.Module):
Contributor:

What's this for? I thought we only run inference.

Contributor Author (@wwwjn):

By design this class should be able to run both training and inference in the future, so that we have a single attention class. For now it is only used for inference; I will clean up and remove the non-inference-related parts. cc @zhxchen17
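For reference, a minimal sketch (not this PR's code; the class name and constructor arguments are assumptions) of how a single attention module could serve both paths: vLLM's paged Attention when serving under the engine, and plain SDPA standing in for the training path.

import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.attention import Attention  # vLLM's paged-KV attention layer

class UnifiedAttention(nn.Module):  # hypothetical name
    def __init__(self, n_heads: int, n_kv_heads: int, head_dim: int, prefix: str = ""):
        super().__init__()
        # Inference path: vLLM manages the paged KV cache behind this layer.
        self.vllm_attn = Attention(
            num_heads=n_heads,
            head_size=head_dim,
            scale=head_dim ** -0.5,
            num_kv_heads=n_kv_heads,
            prefix=prefix,
        )

    def forward(self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, *, inference: bool) -> torch.Tensor:
        if inference:
            # vLLM reads block tables / slot mappings from its forward context.
            return self.vllm_attn(xq, xk, xv)
        # Training path: no KV cache involved.
        return F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)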



class FeedForwardVLLMCompat(nn.Module):
class TorchTitanQwen3ForCausalLM(nn.Module):
Contributor:

We should aim for at most one general wrapper for all models -- we shouldn't have one wrapper per model.

Contributor Author (@wwwjn):

Yes, that's the goal! The current wrapper is specific to the Qwen3 model since we access each model layer by name; let me think about how to design the interface.

Contributor (@zhxchen17):

Suppose each model author defines the model in slightly different ways; then it seems to be up to the people who do RL to make the adaptation work, like we prototyped here.

So I guess we need some sort of contract baked in at authoring time, e.g. models should annotate/implement their attention layers in certain ways (via some sort of base class or special methods?).
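One possible shape for that contract, sketched under assumptions (the Protocol name and members are hypothetical, not anything torchtitan defines today): attention modules opt into a small interface at authoring time, so RL/inference code can locate and swap them without knowing per-model layer names.

from typing import Protocol, runtime_checkable

import torch

@runtime_checkable
class AttentionLayer(Protocol):
    """Authoring-time contract an attention module opts into, so generic
    code can find it via isinstance() and swap in a paged variant."""

    n_heads: int
    n_kv_heads: int
    head_dim: int

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: ...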

Contributor:

Agreed with @zhxchen17.
I think we can start by adding a BaseTransformer class (in torchtitan/protocols/model.py) with the standard layers defined, so that other models (the text ones) can inherit from it, e.g. https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama4/model/model.py#L425

But this won't stop a model from not using those layers (e.g. tok_embeddings). We can make subclasses less error-prone by removing their forward function, since the implementations are identical.
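A rough sketch of that direction, with assumed names (the standard components mirror the existing torchtitan text models; the RoPE buffer name is exactly the rope_cache vs freqs_cis gap discussed below):

import torch
import torch.nn as nn

class BaseTransformer(nn.Module):
    """Hypothetical base class for torchtitan/protocols/model.py: subclasses
    only build the components, the shared forward stays here."""

    def __init__(self) -> None:
        super().__init__()
        # Standard components every text model is expected to populate.
        self.tok_embeddings: nn.Module | None = None
        self.layers = nn.ModuleDict()
        self.norm: nn.Module | None = None
        self.output: nn.Module | None = None
        self.rope_cache: torch.Tensor | None = None  # a.k.a. freqs_cis -- needs one agreed name

    def forward(self, tokens, attention_masks=None, positions=None):
        h = self.tok_embeddings(tokens)
        for layer in self.layers.values():
            h = layer(h, self.rope_cache, attention_masks, positions)
        return self.output(self.norm(h))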

Contributor Author (@wwwjn) commented Dec 9, 2025:

Cool! The refactor to create a BaseTransformer class can be handled in a separate PR; it would also need the rope_cache and freqs_cis naming to be aligned.

Also, I was thinking about how to design the single general wrapper for the vLLM model.
There are four things that are model specific and need to be plugged into the general wrapper class before registering the model with vLLM:

  1. model class itself
  2. model_args class
  3. model's parallel plan (mainly TP; PP needs to be considered separately)
  4. state dict adapter

I refactored the code into a basic wrapper plus plug-in components, supplied by inheriting when registering the model (sketched below). Wdyt?
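Roughly, the idea looks like this (illustrative sketch only; the plug-in names follow the TrainSpec fields, and the constructor is heavily simplified):

import torch.nn as nn

class TorchTitanVLLMModel(nn.Module):
    """Generic wrapper registered with vLLM; everything model-specific is
    supplied as a class attribute by a thin per-model subclass."""

    model_cls = None           # 1. the torchtitan model class
    model_args_cls = None      # 2. its model args class
    parallelize_fn = None      # 3. parallel plan (mainly TP; PP handled separately)
    state_dict_adapter = None  # 4. HF <-> torchtitan state dict adapter

    def __init__(self, *, vllm_config, prefix: str = ""):
        super().__init__()
        # In practice model_args would be derived from vllm_config's HF config.
        model_args = self.model_args_cls()
        self.model = self.model_cls(model_args)

# Registration-time plug-in: binds the Qwen3-specific pieces by inheriting.
class Qwen3VLLMModel(TorchTitanVLLMModel):
    model_cls = ...           # torchtitan's Qwen3 model class
    model_args_cls = ...      # its model args class
    parallelize_fn = ...      # e.g. parallelize_qwen3
    state_dict_adapter = ...  # Qwen3 state dict adapter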

Contributor:

> it would also need the rope_cache and freqs_cis naming to be aligned

Why? Is it because we put freqs_cis generation outside the model for CP? I think @fegin has a plan to move it back into the model.

Contributor Author (@wwwjn):

No, it's not related to CP. I was thinking that we need to access self.rope_cache / self.freqs_cis during BaseTransformer's forward:

h = layer(h, self.rope_cache, attention_masks, positions)

positions_2d = positions.unsqueeze(0) # [total_tokens] -> [1, total_tokens]

# Get embeddings from 2D tokens
h = self.model.tok_embeddings(tokens_2d) # [1, total_tokens, hidden_size]
Contributor:

qq: curious, do the input tokens always have batch size 1?

Contributor Author (@wwwjn):

I'm not 100% sure about this part, but I guess not? I've only tried single prompts so far; batched inference may change this, so I will leave a TODO here for now.

register()


def parse_args():
Contributor:

Not urgent, but we should use "our" config system in the long term.

Contributor Author (@wwwjn):

Currently the entry point is the vLLM engine, so we are taking the config from whatever the vLLM engine passes to us. Let me check the vLLM engine to see if there's anything we could do.

Contributor:

Wait, how is this related to the vLLM config system? You are just using the args as-is in args = parse_args().

Contributor Author (@wwwjn):

These args are only for the infer.py script; the script passes them into the vLLM engine's LLM(), and the vLLM engine creates a VllmConfig instance internally and passes that to our model wrapper.
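For context, the rough shape of that flow (flag names are illustrative, not the exact infer.py): parse_args() only drives the script, and LLM() builds a VllmConfig internally that gets handed to the registered model wrapper.

import argparse

from vllm import LLM, SamplingParams

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--prompt", default="Hello, world")
    return parser.parse_args()

def main():
    args = parse_args()
    # LLM() turns these kwargs into a VllmConfig; the registered wrapper
    # receives that config, never the argparse namespace itself.
    llm = LLM(model=args.model, trust_remote_code=True)
    outputs = llm.generate([args.prompt], SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)

if __name__ == "__main__":
    main()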

Contributor:

> the script passes them into the vLLM engine's LLM()

I don't think it's passing the args to LLM(). What would be different if we used our config manager to construct the args?

model_args.n_heads
if model_args.n_kv_heads is None
else model_args.n_kv_heads
def _replice_with_vllm_paged_attention(self, model_args):
Contributor:

Can we expose this as a util function? It could also be reused for the trainer model.
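A hedged sketch of what that util could look like (the replacement class is passed in, since VLLMPagedFlashAttention's constructor arguments belong to this PR; only the n_kv_heads fallback comes from the diff above):

import torch.nn as nn

def replace_attention(model: nn.Module, model_args, attn_cls) -> nn.Module:
    """Swap every torchtitan Attention module for attn_cls (e.g. this PR's
    VLLMPagedFlashAttention), reusable for both the vLLM wrapper and the trainer."""
    n_kv_heads = (
        model_args.n_heads
        if model_args.n_kv_heads is None
        else model_args.n_kv_heads
    )
    head_dim = model_args.dim // model_args.n_heads
    # Collect target module names first so we don't mutate while iterating.
    targets = [
        name for name, m in model.named_modules()
        if m.__class__.__name__ == "Attention"  # torchtitan's attention class
    ]
    for name in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        # Constructor argument order is assumed for illustration.
        setattr(parent, child_name, attn_cls(model_args.n_heads, n_kv_heads, head_dim))
    return model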

@wwwjn force-pushed the vllm-infer branch 3 times, most recently from d7b714b to 226150d on December 9, 2025 18:52
@wwwjn marked this pull request as ready for review on December 9, 2025 18:57
return parallel_dims


def build_device_mesh_and_parallelize(
Contributor:

Curious, how will this function be used? Won't the vLLM engine handle TP=2 for us?

Contributor Author (@wwwjn):

> Won't the vLLM engine handle TP=2 for us?

vLLM applies TP by patching each module (https://docs.vllm.ai/en/latest/contributing/model/basic/#3-optional-implement-tensor-parallelism-and-quantization-support), which changes the model definition and happens during model initialization.

This function calls the parallelize_qwen3 function from qwen3/infra/parallelize.py; according to parallel_dims, it applies the corresponding parallelisms.
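A condensed sketch of that flow; the ParallelDims constructor fields and the parallelize_fn signature are assumptions (they vary across torchtitan versions), while the tp_enabled / world_mesh usage mirrors the diff further down.

from torchtitan.distributed import ParallelDims  # import path assumed

def build_device_mesh_and_parallelize(model, parallelize_fn, job_config, world_size):
    # TP-only ParallelDims; everything else stays at 1 for inference.
    parallel_dims = ParallelDims(
        dp_replicate=1, dp_shard=1, cp=1, tp=world_size, pp=1, ep=1,
        world_size=world_size,
    )
    if parallel_dims.tp_enabled:
        # Build the device mesh and apply torchtitan's own TP plan
        # (parallelize_qwen3 here) instead of vLLM's module patching.
        world_mesh = parallel_dims.world_mesh  # DeviceMesh used by parallelize_fn
        parallelize_fn(model, parallel_dims=parallel_dims, job_config=job_config)
    return parallel_dims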

logits = self.model.output(h)

if isinstance(logits, DTensor):
    logits = logits.full_tensor()
Contributor:

Have you verified this for TP, or does it only work for single GPU? If it's the latter, add a TODO around such conversions.

Contributor Author (@wwwjn):

Currently this PR only works on a single GPU; I will add TP in a follow-up PR.

layer_idx = next(VLLMPagedFlashAttention._layer_counter)
prefix = f"layers.{layer_idx}"

self.vllm_attn = Attention(
Contributor:

Why can't we always use VLLMPagedFlashAttention?

Comment on lines +50 to +53
model_cls=train_spec.model_cls,
model_args_cls=model_args_cls,
state_dict_adapter=train_spec.state_dict_adapter,
parallelize_fn=train_spec.parallelize_fn,
Contributor:

It seems we need these fields and the TorchTitanVLLMModel / TorchTitanVLLMModelFromSpec wrappers because we rely on vLLM's LLM() API to create the model.

This is hacky and makes things complicated, as we are dumping a lot of logic (originally in train.py and checkpoint.py) into the model code itself.

I feel this is unnecessary if our end goal is to use the engine part of vLLM, not the model init part.

Contributor Author (@wwwjn):

> dumping a lot of logic (originally in train.py and checkpoint.py) into the model code itself

Agreed; the main blocker is that we need control over how the Worker instantiates a model.

According to vLLM's design, this class is not only a model nn.Module but also acts like a model runner (https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/worker_base.py#L85); that's why it has a load_weights function.
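For illustration, the part of the wrapper this design forces on us; the adapter call and strict=False are assumptions, and the point is that load_weights is the hook the vLLM worker invokes on the model object.

from typing import Iterable

import torch
import torch.nn as nn

class TorchTitanVLLMModel(nn.Module):
    # ... __init__ / forward as sketched earlier ...

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Called by the vLLM worker, replacing the usual checkpoint.py path."""
        hf_state_dict = dict(weights)
        # Convert HF-format names/shapes into torchtitan's via the plug-in adapter.
        tt_state_dict = self.state_dict_adapter.from_hf(hf_state_dict)
        self.model.load_state_dict(tt_state_dict, strict=False)
        return set(hf_state_dict.keys())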

if parallel_dims.tp_enabled:
    self.world_mesh = parallel_dims.world_mesh
    tp_mesh = self.world_mesh["tp"]
    parallelize_fn(
Contributor:

Wondering why we parallelize the model during init? This model is used during vLLM inference, and I thought vLLM has its own TP implementation?

Contributor Author (@wwwjn):

We want to apply our TP instead of using vLLM's TP implementation. We don't have direct access to the model once LLM() is initialized, so we cannot apply TP later; it has to happen during init.
