1 change: 1 addition & 0 deletions .gitignore
@@ -16,4 +16,5 @@ cprofile
# Test exclusions
qa/L0_openai/openai
tensorrtllm_models
tensorrtllm_mistral_models/
custom_tokenizer
53 changes: 50 additions & 3 deletions python/openai/README.md
@@ -241,8 +241,8 @@ pytest -v tests/
### LoRA Adapters

If the command line argument `--lora-separator=<separator_string>` is provided
when starting the OpenAI Frontend, a vLLM LoRA adaptor listed on the
`multi_lora.json` may be selected by appending the LoRA name to the model name,
when starting the OpenAI Frontend, a LoRA adapter listed in `multi_lora.json`
may be selected by appending the LoRA name to the model name,
separated by the LoRA separator, on the inference request in
`<model_name><separator_string><lora_name>` format.
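
For example, a minimal request sketch using the `openai` Python client. The
model name `my_model`, LoRA name `my_adapter`, separator `_lora_`, and the
`localhost:9000` address are placeholders; substitute the values from your own
deployment:

```python
from openai import OpenAI

# Point the standard OpenAI client at the locally running frontend.
client = OpenAI(base_url="http://localhost:9000/v1", api_key="unused")

# Select the "my_adapter" LoRA of "my_model" by joining the two names with the
# string that was passed to --lora-separator (here assumed to be "_lora_").
response = client.chat.completions.create(
    model="my_model_lora_my_adapter",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```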

@@ -297,9 +297,56 @@ the same `<model_name><separator_string><lora_name>` format for each LoRA
adapter listed in `multi_lora.json`. Note: LoRA names are only included in the
listing for locally stored models; inference requests are not limited in this
way, though.
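
As an illustration (using the same placeholder names and separator as above),
the listing can be inspected with the `openai` client; the base model and each
of its LoRA adapters appear as separate entries:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9000/v1", api_key="unused")

# Prints e.g. "my_model" followed by "my_model_lora_my_adapter".
for model in client.models.list():
    print(model.id)
```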

#### vLLM
See the
[vLLM documentation](https://github.com/triton-inference-server/vllm_backend/blob/main/docs/llama_multi_lora_tutorial.md)
on how to serve a model with LoRA adapters.
on how to serve a vLLM model with LoRA adapters.

#### TensorRT-LLM
Similarly, see the [TensorRT-LLM documentation](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/docs/lora.md)
on how to prepare LoRA-enabled TensorRT-LLM engines and generate LoRA tensors.
The path of each LoRA adapter in `multi_lora.json` is the directory containing
its `model.lora_config.npy` and `model.lora_weights.npy` tensors.

<details>
<summary>For example</summary>

model repository
```
inflight_batcher_llm
├── postprocessing
| ├── 1
| | └── model.py
| └── config.pbtxt
├── preprocessing
| ├── 1
| | └── model.py
| └── config.pbtxt
├── tensorrt_llm
| ├── 1
| | └── model.py
| └── config.pbtxt
└── tensorrt_llm_bls
├── 1
| ├── Japanese-Alpaca-LoRA-7b-v0-weights
| | ├── model.lora_config.npy
| | └── model.lora_weights.npy
| ├── luotuo-lora-7b-0.1-weights
| | ├── model.lora_config.npy
| | └── model.lora_weights.npy
| ├── model.py
| └── multi_lora.json
└── config.pbtxt
```

multi_lora.json
```
{
"doll": "inflight_batcher_llm/tensorrt_llm_bls/1/luotuo-lora-7b-0.1-weights",
"sheep": "inflight_batcher_llm/tensorrt_llm_bls/1/Japanese-Alpaca-LoRA-7b-v0-weights"
}
```
</details>
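
Given the repository layout above, a request sketch for the `doll` adapter
might look like the following. It assumes the frontend serves the
`tensorrt_llm_bls` model, was started with `--lora-separator=_lora_`, and
listens on `localhost:9000`; adjust these to your setup:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9000/v1", api_key="unused")

# On the first request for a given LoRA, the frontend loads the adapter's
# model.lora_weights.npy / model.lora_config.npy tensors and passes them to
# TensorRT-LLM; later requests for the same adapter reuse the cached weights.
response = client.completions.create(
    model="tensorrt_llm_bls_lora_doll",
    prompt="What is machine learning?",
    max_tokens=64,
)
print(response.choices[0].text)
```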

### Embedding Models
Currently, the OpenAI-Compatible Frontend supports loading embedding models and serving the embeddings endpoint via the vLLM backend. Check [vLLM supported models](https://docs.vllm.ai/en/latest/models/supported_models.html#embedding) for all embedding models supported by vLLM.
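
A minimal sketch of calling the embeddings endpoint is shown below; the model
name is a placeholder for whichever vLLM embedding model you have deployed,
and the frontend is assumed to be reachable at `localhost:9000`:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9000/v1", api_key="unused")

# Request embeddings for a batch of input strings from the deployed model.
response = client.embeddings.create(
    model="my_embedding_model",
    input=["Hello, world!", "Triton OpenAI-compatible frontend"],
)
print(len(response.data), len(response.data[0].embedding))
```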
49 changes: 32 additions & 17 deletions python/openai/openai_frontend/engine/triton_engine.py
@@ -53,6 +53,7 @@
from engine.utils.tool_call_parsers import ToolCallParser, ToolParserManager
from engine.utils.triton import (
RequestKind,
TritonLoraConfig,
_create_trtllm_embedding_request,
_create_trtllm_generate_request,
_create_vllm_embedding_request,
@@ -61,7 +62,7 @@
_get_openai_completion_format_logprobs_from_vllm_response,
_get_output,
_get_usage_from_response,
_get_vllm_lora_names,
_parse_lora_configs,
_StreamingUsageAccumulator,
_validate_triton_responses_non_streaming,
)
@@ -107,7 +108,7 @@ class TritonModelMetadata:
# Tokenizers used for chat templates
tokenizer: Optional[Any]
# LoRA names supported by the backend
lora_names: Optional[List[str]]
lora_configs: Optional[List[TritonLoraConfig]]
# Name of the input tensor enabling "echo" parameter in /v1/completions endpoint
echo_tensor_name: Optional[str]
# Time that model was loaded by Triton
@@ -160,11 +161,11 @@ def models(self) -> List[Model]:
if (
self.lora_separator is not None
and len(self.lora_separator) > 0
and metadata.lora_names is not None
and metadata.lora_configs is not None
):
for lora_name in metadata.lora_names:
for lora_config in metadata.lora_configs:
model_names.append(
f"{metadata.name}{self.lora_separator}{lora_name}"
f"{metadata.name}{self.lora_separator}{lora_config.name}"
)

for model_name in model_names:
@@ -210,7 +211,7 @@ async def chat(
metadata.model,
prompt,
request,
lora_name,
self._get_lora_config(model_name, lora_name),
metadata.echo_tensor_name,
self.default_max_tokens,
)
@@ -348,7 +349,7 @@ async def completion(
metadata.model,
request.prompt,
request,
lora_name,
self._get_lora_config(model_name, lora_name),
metadata.echo_tensor_name,
self.default_max_tokens,
)
@@ -505,11 +506,12 @@ def _get_model_metadata(self) -> Dict[str, TritonModelMetadata]:
backend = "ensemble"
print(f"Found model: {name=}, {backend=}")

lora_names = None
if self.backend == "vllm" or backend == "vllm":
lora_names = _get_vllm_lora_names(
self.server.options.model_repository, name, model.version
)
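# Use the frontend-wide backend override when one is set; otherwise fall
# back to the backend reported by the model itself.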
lora_configs = _parse_lora_configs(
self.server.options.model_repository,
name,
model.version,
backend if self.backend is None else self.backend,
)

echo_tensor_name = None
for input in model.config()["input"]:
@@ -525,7 +527,7 @@ def _get_model_metadata(self) -> Dict[str, TritonModelMetadata]:
backend=backend,
model=model,
tokenizer=self.tokenizer,
lora_names=lora_names,
lora_configs=lora_configs,
echo_tensor_name=echo_tensor_name,
create_time=self.create_time,
inference_request_converter=self._determine_request_converter(
@@ -807,9 +809,10 @@ def _validate_chat_request(
)

if (
metadata.lora_names is not None
metadata.lora_configs is not None
and lora_name is not None
and lora_name not in metadata.lora_names
and lora_name
not in [lora_config.name for lora_config in metadata.lora_configs]
):
raise ClientError(f"Unknown LoRA: {lora_name}; for model: {request.model}")

@@ -970,9 +973,10 @@ def _validate_completion_request(
)

if (
metadata.lora_names is not None
metadata.lora_configs is not None
and lora_name is not None
and lora_name not in metadata.lora_names
and lora_name
not in [lora_config.name for lora_config in metadata.lora_configs]
):
raise ClientError(f"Unknown LoRA: {lora_name}; for model: {request.model}")

@@ -1081,3 +1085,14 @@ def _get_named_function_name(
tool_choice_required_function_name = None

return tool_choice_function_name or tool_choice_required_function_name

def _get_lora_config(
self, model_name: str, lora_name: Optional[str]
) -> Optional[TritonLoraConfig]:
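# Look up the TritonLoraConfig for the requested LoRA name on this model.
# Returns None when no LoRA was requested or the model lists no adapters,
# and raises ClientError for an unknown LoRA name.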
model_metadata = self.model_metadata.get(model_name)
if lora_name is None or model_metadata.lora_configs is None:
return None
for lora_config in model_metadata.lora_configs:
if lora_config.name == lora_name:
return lora_config
raise ClientError(f"Unknown LoRA: {lora_name}; for model: {model_name}")
118 changes: 88 additions & 30 deletions python/openai/openai_frontend/engine/utils/triton.py
@@ -28,6 +28,7 @@
import os
import re
import sys
import traceback
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
@@ -56,11 +57,21 @@ class RequestKind(Enum):
EMBEDDING = 2


@dataclass
class TritonLoraConfig:
name: str

# Unique fields for TensorRT-LLM backend
task_id: Optional[int] = None
path: Optional[str] = None
is_registered: Optional[bool] = False


def _create_vllm_generate_request(
model,
prompt,
request: CreateChatCompletionRequest | CreateCompletionRequest,
lora_name: str | None,
lora_config: TritonLoraConfig | None,
echo_tensor_name: str | None,
default_max_tokens: int,
):
@@ -135,8 +146,8 @@ def _create_vllm_generate_request(
request_logprobs = True
inputs["return_logprobs"] = np.bool_([request_logprobs])

if lora_name is not None:
sampling_parameters["lora_name"] = lora_name
if lora_config is not None:
sampling_parameters["lora_name"] = lora_config.name

guided_json = _get_guided_json_from_tool(request)
if guided_json is not None:
@@ -167,15 +178,10 @@ def _create_trtllm_generate_request(
model,
prompt,
request: CreateChatCompletionRequest | CreateCompletionRequest,
lora_name: str | None,
lora_config: TritonLoraConfig | None,
echo_tensor_name: str | None,
default_max_tokens: int,
):
if lora_name is not None:
raise ClientError(
"LoRA selection is currently not supported for TRT-LLM backend"
)

inputs = {}
inputs["text_input"] = [[prompt]]
inputs["stream"] = np.bool_([[request.stream]])
@@ -221,6 +227,21 @@
inputs["guided_decoding_guide_type"] = [["json_schema"]]
inputs["guided_decoding_guide"] = [[guided_json]]

if lora_config is not None:
# To perform inference with a specific LoRA for the first time, `lora_task_id`, `lora_weights`, and `lora_config` must all be given.
# The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
inputs["lora_task_id"] = np.uint64([[lora_config.task_id]])
if not lora_config.is_registered:
lora_weights_data = np.load(
os.path.join(lora_config.path, "model.lora_weights.npy")
)
lora_config_data = np.load(
os.path.join(lora_config.path, "model.lora_config.npy")
)
inputs["lora_weights"] = lora_weights_data
inputs["lora_config"] = lora_config_data
lora_config.is_registered = True

inputs["return_num_input_tokens"] = np.bool_([[True]])
inputs["return_num_output_tokens"] = np.bool_([[True]])
return model.create_request(inputs=inputs)
@@ -594,9 +615,9 @@ def _get_guided_json_from_tool(
return None


def _get_vllm_lora_names(
model_repository: str | list[str], model_name: str, model_version: int
) -> None | List[str]:
def _parse_lora_configs(
model_repository: str | list[str], model_name: str, model_version: int, backend: str
) -> None | List[TritonLoraConfig]:
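# Scan the model repository for LoRA adapters listed in each model version's
# multi_lora.json. For vLLM, LoRA must also be enabled in model.json; for
# TensorRT-LLM, each adapter entry must point at a directory containing the
# model.lora_weights.npy and model.lora_config.npy tensors. Returns None if
# LoRA support is not enabled or the configuration cannot be read.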
if (
len(model_name) == 0
or model_name.isspace()
@@ -606,7 +627,9 @@ def _get_vllm_lora_names(
raise ValueError(
f"Invalid model name: '{model_name}'. Model names must be valid file-system-path segment names."
)
lora_names = []

lora_configs = []
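# Task IDs are assigned sequentially; the TensorRT-LLM backend uses the task
# ID to look up an adapter whose weights have already been registered.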
lora_task_id = 1
repo_paths = model_repository
if isinstance(repo_paths, str):
repo_paths = [repo_paths]
@@ -618,6 +641,7 @@
raise ValueError(
f"Invalid model name: '{model_name}'. Model names must be valid file-system-path segment names."
)

model_path = os.path.normpath(model_path)
if not os.path.isdir(model_path):
# Cloud path?
@@ -632,26 +656,60 @@
# Model directory is malformed?
return None
version_path = os.path.join(model_path, str(model_version))
is_lora_enabled = False
model_file_path = os.path.join(version_path, "model.json")
try:
with open(model_file_path, "r") as f:
config = json.load(f)
if "enable_lora" in config:
# The value could be a string or a bool.
is_lora_enabled = str(config["enable_lora"]).lower() == "true"
except Exception:
# Model directory or model.json is malformed?
return None
if is_lora_enabled != True:
continue
lora_config_path = os.path.join(version_path, "multi_lora.json")

if backend == "vllm":
is_lora_enabled = False
model_file_path = os.path.join(version_path, "model.json")
try:
with open(model_file_path, "r") as f:
config = json.load(f)
if "enable_lora" in config:
# The value could be a string or a bool.
is_lora_enabled = str(config["enable_lora"]).lower() == "true"
except Exception:
# Model directory or model.json is malformed?
return None
if is_lora_enabled != True:
continue
else:
# TRT-LLM backend does not use model.json
if not os.path.exists(lora_config_path):
continue

try:
with open(lora_config_path, "r") as f:
lora_config = json.load(f)
for lora_name in lora_config.keys():
lora_names.append(lora_name)
except Exception:
for lora_name, lora_path in lora_config.items():
print(f"backend: {backend}")
if backend == "vllm":
lora_configs.append(TritonLoraConfig(name=lora_name))
else:
lora_weights_path = os.path.join(
lora_path, "model.lora_weights.npy"
)
lora_config_path = os.path.join(
lora_path, "model.lora_config.npy"
)
if not os.path.exists(lora_weights_path):
raise ServerError(
f"LoRA weights file not found for '{lora_name}' at path: {lora_weights_path}"
)
if not os.path.exists(lora_config_path):
raise ServerError(
f"LoRA config file not found for '{lora_name}' at path: {lora_config_path}"
)

lora_configs.append(
TritonLoraConfig(
name=lora_name, path=lora_path, task_id=lora_task_id
)
)
lora_task_id += 1
except ServerError as e:
raise e
except Exception as e:
# LoRA is enabled but its list is not provided or malformed?
print(traceback.format_exc())
return None
return lora_names
return lora_configs