[train] Support logprobs, fix generation config defaults and add more generation tests for the new HTTP inference pathway #1038
Conversation
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
```python
finally:
    ray.shutdown()
```
Tests already use ray_init_fixture, which handles cleanup
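For context, the fixture-based cleanup referred to above typically looks something like the sketch below (names and shape assumed for illustration; this is not the repo's actual fixture):

```python
# A minimal pytest fixture that owns Ray startup and teardown, so individual
# tests do not need their own try/finally around ray.shutdown().
import pytest
import ray


@pytest.fixture
def ray_init_fixture():
    ray.init(ignore_reinit_error=True)
    yield
    ray.shutdown()
```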
Code Review
This pull request primarily refactors the inference engine interaction to consistently use token-based generation and to properly handle log probabilities, especially for Token Importance Sampling (TIS). Key changes:

- Updates various configuration files and example scripts to set `logprobs=1` instead of `logprobs=0` for sampling parameters.
- Modifies the `RemoteInferenceClient` to use a new `/inference/v1/generate` endpoint that operates with `token_ids` and returns `response_logprobs`.
- Updates `skyrl_gym_generator.py` to explicitly use token-in-token-out for consistency.
- Updates the GPU CI script to include a new test for the `skyrl_gym_generator`.
- Refactors several test files (`test_engine_generation.py`, `test_inference_engine_client_http_endpoint.py`, `test_lora.py`, `test_megatron_worker.py`, `test_pause_and_continue_generation.py`, `test_policy_local_engines_e2e.py`, `test_save_weights_for_sampler.py`, `test_skyrl_gym_generator.py`, `test_verifiers_generator.py`) to use a new `InferenceEngineState` context manager for managing engine lifecycle and to align with the token-based generation approach.
- A minor change also strips trailing newlines from action strings in `skyrl_gym/envs/search/env.py`.

Review comments suggest simplifying a conditional expression for `response_logprobs` and addressing an inconsistency in `tokenizer.apply_chat_template` regarding `add_special_tokens`.
| "stop_reason": stop_reason, | ||
| "response_ids": final_token_ids, | ||
| "response_ids": accum_token_ids, | ||
| "response_logprobs": response_logprobs if len(response_logprobs) > 0 else None, |
There was a problem hiding this comment.
The expression `response_logprobs if len(response_logprobs) > 0 else None` can be simplified. Since an empty list `[]` evaluates to `False` in a boolean context, you can use a more concise and Pythonic expression.
| "response_logprobs": response_logprobs if len(response_logprobs) > 0 else None, | |
| "response_logprobs": response_logprobs if response_logprobs else None, |
kouroshHakha left a comment:
Overall this is good. I wish you had broken the PR down into individual smaller PRs for each part, but it's OK for now. For the next PRs, let's make sure orthogonal features are kept in separate PRs.
```python
prompt_token_ids = self.tokenizer.apply_chat_template(
    prompts,
    add_generation_prompt=True,
    add_special_tokens=False,
```
was this intentional?
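For context on why `add_special_tokens=False` is commonly paired with `apply_chat_template`: many chat templates already emit special tokens such as BOS, so tokenizing the rendered text with special tokens enabled can duplicate them. A rough, model-dependent check (the checkpoint name is only a placeholder, not one used by this PR):

```python
# Model-dependent illustration: if the chat template already inserts BOS,
# re-tokenizing the rendered text with add_special_tokens=True may prepend a
# second BOS token. Behavior varies by tokenizer.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
messages = [{"role": "user", "content": "hi"}]
text = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
with_special = tok(text, add_special_tokens=True).input_ids
without_special = tok(text, add_special_tokens=False).input_ids
print(with_special[:3])     # may start with two BOS tokens
print(without_special[:3])  # starts with the template's single BOS
```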
```bash
# Run tests for new inference layer
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra vllm pytest -s tests/gpu/gpu_ci/test_policy_local_engines_e2e.py -m "vllm"
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra vllm pytest -s tests/gpu/gpu_ci/test_engine_generation.py -m "vllm"
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra vllm pytest -s tests/gpu/gpu_ci/test_skyrl_gym_generator.py
```
Can we create a list here of which tests would fail and need work if we switch to `_SKYRL_USE_NEW_INFERENCE=1 uv run <options> pytest -s tests/gpu/gpu_ci/`?
In the next PRs, as we fix those tests, we will add more lines here.
Eventually we will just run the tests in the gpu_ci folder with the new inference flag (some tests will be skipped, namely those that test `InferenceEngineClient`).
What does this PR do?
This PR migrates more tests for the new HTTP inference pathway and adds some missing features, like rollout logprobs support, along the way. It also fixes some test failures on `main`. The changes are as follows:

**Test improvements**
Introduces a new `InferenceEngineState` class to manage inference engine lifecycle in tests. With better state management, this fixes some cleanup issues for the existing `test_policy_local_engines_e2e` test in CI.
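The rough shape of the pattern is sketched below (illustrative only; this is not the actual class added in the PR):

```python
# Sketch of an InferenceEngineState-style helper: a context manager that owns
# engine startup and teardown so tests cannot leak engine state between cases.
class InferenceEngineStateSketch:
    def __init__(self, create_engine, teardown_engine):
        self._create = create_engine
        self._teardown = teardown_engine
        self._engine = None

    def __enter__(self):
        self._engine = self._create()
        return self._engine

    def __exit__(self, exc_type, exc, tb):
        if self._engine is not None:
            self._teardown(self._engine)
            self._engine = None
        return False  # never swallow test failures
```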
**Configuration fix for vLLM server actor**

The vLLM server can have different generation quality from `AsyncLLMEngine.generate`. I noticed this while going over generations in the weight sync tests, comparing `/v1/completions` against `AsyncLLMEngine.generate`. More details here: https://gist.github.com/SumanthRH/847a328c121c1463b8b8aca6d548224f
The reason is that the vLLM server's generation config defaults are different. Passing `--generation-config vllm` fixes the issue.
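As a sketch of where the flag goes when launching vLLM's standard OpenAI-compatible server (the launch method and model name here are placeholders, not how the repo's server actor is started):

```python
# Start a vLLM OpenAI-compatible server using vLLM's own sampling defaults
# rather than the model's generation_config.json. Adapt paths/model as needed.
import subprocess

server = subprocess.Popen([
    "python", "-m", "vllm.entrypoints.openai.api_server",
    "--model", "Qwen/Qwen2.5-0.5B-Instruct",   # placeholder model
    "--generation-config", "vllm",             # use vLLM defaults, not the model's
])
```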
**Switch to `/inference/v1/generate` for `RemoteInferenceClient.generate`**

For `RemoteInferenceClient.generate`, I noticed that we were re-tokenizing intermediate tokens (on abort), which can cause small drifts since tokenization is not invertible. The solution is to not rely on `/v1/completions` and instead use the token-in-token-out endpoint `/inference/v1/generate`; this also makes it compatible with accumulating logprobs returned from the server. There can also be silent issues with the completions API, as noted above. For RL, it is best to use the `/generate` endpoint.
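To illustrate the drift concern (placeholder model; not the repo's code):

```python
# Decoding token ids to text and re-encoding is not guaranteed to round-trip:
# merges, whitespace handling, and byte fallbacks can produce different ids.
# That is the motivation for a token-in-token-out endpoint on aborts.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
ids = tok("hello   world\n", add_special_tokens=False).input_ids
roundtrip = tok(tok.decode(ids), add_special_tokens=False).input_ids
# Often equal for simple strings, but equality is not guaranteed in general.
print(ids == roundtrip)
```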
**Support response logprobs for `RemoteInferenceClient`**

Adds support for `response_logprobs` in `RemoteInferenceClient`. Note that there are some slight differences in `sampling_params` between `/inference/v1/generate` and `AsyncLLMEngine.generate`. As per the OpenAI completions API, `logprobs=0` is meant to return the logprob for the chosen token (same as `logprobs=1`). However, `/inference/v1/generate` treats `logprobs=0` as `logprobs=null` and doesn't return any logprobs. This is a vLLM issue; I have created a PR: vllm-project/vllm#34010. While we wait for it to land, I believe it is overall better to rely on `logprobs=1` for getting logprobs for the chosen token. It also lends itself better to truthy checks like `if logprobs:`.
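For reference, here is a minimal sketch using vLLM's offline `LLM`/`SamplingParams` API (not the HTTP pathway touched in this PR; the model name is a placeholder) showing that `logprobs=1` is enough to recover the chosen token's logprob at each position:

```python
# With logprobs=1 the sampled token is always included in the returned logprobs,
# so its logprob can be read out per generated position.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")
params = SamplingParams(max_tokens=8, logprobs=1)
out = llm.generate(["The capital of France is"], params)[0].outputs[0]
chosen = [pos[tid].logprob for pos, tid in zip(out.logprobs, out.token_ids)]
print(list(out.token_ids), chosen)
```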
test_skyrl_gym_generatorfor_SKYRL_USE_NEW_INFERENCE=1SkyRLGymGeneratorto provide input tokens over text forgenerate-> this is because the new pathway only supports tokensSearchEnv.validate_actionto strip newlines: With/inference/v1/generate, only output tokens are provided (unlikeAsyncLLMEngine.generatewhere output text is also available). With output tokens, there can be a case where the LLM generates a trailing newline - generating [<search,>\n] as opposed to [<search>]. One would need to postprocess the output text after detokenization to ensure that strings end exactly with the stop string. There are two fixes here:RemoteInferenceClientdo custom postprocessing forgeneratebased on stop stringsSearchEnv(It is the only Env with this strict parsing)I prefer 2. because
RemoteInferenceClientlayer should be pretty much pass-through and operate in token space.
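A minimal sketch of option 2 (the function name and stop string are illustrative, not the exact `SearchEnv.validate_action` code):

```python
# Strip trailing newlines before checking that the action ends with the stop
# string, so a detokenized "</search>\n" still parses as a valid search action.
def ends_with_stop_string(action: str, stop_string: str = "</search>") -> bool:
    return action.rstrip("\n").endswith(stop_string)

print(ends_with_stop_string("<search>weather today</search>\n"))  # True
print(ends_with_stop_string("<search>weather today</search>"))    # True
print(ends_with_stop_string("<search>weather today"))             # False
```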