diff --git a/.gitignore b/.gitignore
index 1e0b66480b..31f5d9f5f4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,4 +16,5 @@ cprofile
 # Test exclusions
 qa/L0_openai/openai
 tensorrtllm_models
+tensorrtllm_mistral_models/
 custom_tokenizer
diff --git a/python/openai/README.md b/python/openai/README.md
index 71244b373e..8c4906faf0 100644
--- a/python/openai/README.md
+++ b/python/openai/README.md
@@ -241,8 +241,8 @@ pytest -v tests/
 ### LoRA Adapters
 
 If the command line argument `--lora-separator=<separator_string>` is provided
-when starting the OpenAI Frontend, a vLLM LoRA adaptor listed on the
-`multi_lora.json` may be selected by appending the LoRA name to the model name,
+when starting the OpenAI Frontend, a LoRA adapter listed in `multi_lora.json`
+may be selected by appending the LoRA name to the model name,
 separated by the LoRA separator, on the inference request in
 `<model_name><separator_string><lora_name>` format.
 
@@ -297,9 +297,56 @@ the same `<model_name><separator_string><lora_name>` format for each LoRA
 adapter listed on the `multi_lora.json`. Note: The LoRA name inclusion is
 limited to locally stored models, inference requests are not limited though.
 
+#### vLLM
 See the
 [vLLM documentation](https://github.com/triton-inference-server/vllm_backend/blob/main/docs/llama_multi_lora_tutorial.md)
-on how to serve a model with LoRA adapters.
+on how to serve a vLLM model with LoRA adapters.
+
+#### TensorRT-LLM
+Similarly, see the [TensorRT-LLM documentation](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/docs/lora.md)
+on how to prepare LoRA-enabled TensorRT-LLM engines and generate LoRA tensors.
+The path of each LoRA adapter in `multi_lora.json` is the directory containing
+the `model.lora_config.npy` and `model.lora_weights.npy` tensors.
+
+For example + +model repository +``` +inflight_batcher_llm +├── postprocessing +| ├── 1 +| | └── model.py +| └── config.pbtxt +├── preprocessing +| ├── 1 +| | └── model.py +| └── config.pbtxt +├── tensorrt_llm +| ├── 1 +| | └── model.py +| └── config.pbtxt +└── tensorrt_llm_bls + ├── 1 + | ├── Japanese-Alpaca-LoRA-7b-v0-weights + | | ├── model.lora_config.npy + | | └── model.lora_weights.npy + | ├── luotuo-lora-7b-0.1-weights + | | ├── model.lora_config.npy + | | └── model.lora_weights.npy + | ├── model.py + | └── multi_lora.json + └── config.pbtxt +``` + +multi_lora.json +``` +{ + "doll": "inflight_batcher_llm/tensorrt_llm_bls/1/luotuo-lora-7b-0.1-weights", + "sheep": "inflight_batcher_llm/tensorrt_llm_bls/1/Japanese-Alpaca-LoRA-7b-v0-weights" +} +``` +
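+
+As a rough usage illustration (the separator value `_lora_`, the port, and the
+adapter name below are examples, not defaults), a request that selects the
+`doll` adapter from the `multi_lora.json` above could look like:
+
+```python
+from openai import OpenAI
+
+# Assumes the frontend was started with --lora-separator=_lora_ and listens on port 9000.
+client = OpenAI(base_url="http://localhost:9000/v1", api_key="EMPTY")
+
+completion = client.completions.create(
+    # <model_name><separator_string><lora_name>
+    model="tensorrt_llm_bls_lora_doll",
+    prompt="Instruction: tell me a story about a doll.\nResponse:",
+    max_tokens=64,
+)
+print(completion.choices[0].text)
+```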
### Embedding Models Currently, OpenAI-Compatible Frontend supports loading embedding models and embeddings endpoints via vLLM backend. Check [vLLM supported models](https://docs.vllm.ai/en/latest/models/supported_models.html#embedding) for all supported embedding models from vLLM. diff --git a/python/openai/openai_frontend/engine/triton_engine.py b/python/openai/openai_frontend/engine/triton_engine.py index 9d515079b9..0d735be5c0 100644 --- a/python/openai/openai_frontend/engine/triton_engine.py +++ b/python/openai/openai_frontend/engine/triton_engine.py @@ -53,6 +53,7 @@ from engine.utils.tool_call_parsers import ToolCallParser, ToolParserManager from engine.utils.triton import ( RequestKind, + TritonLoraConfig, _create_trtllm_embedding_request, _create_trtllm_generate_request, _create_vllm_embedding_request, @@ -61,7 +62,7 @@ _get_openai_completion_format_logprobs_from_vllm_response, _get_output, _get_usage_from_response, - _get_vllm_lora_names, + _parse_lora_configs, _StreamingUsageAccumulator, _validate_triton_responses_non_streaming, ) @@ -107,7 +108,7 @@ class TritonModelMetadata: # Tokenizers used for chat templates tokenizer: Optional[Any] # LoRA names supported by the backend - lora_names: Optional[List[str]] + lora_configs: Optional[List[TritonLoraConfig]] # Name of the input tensor enabling "echo" parameter in /v1/completions endpoint echo_tensor_name: Optional[str] # Time that model was loaded by Triton @@ -160,11 +161,11 @@ def models(self) -> List[Model]: if ( self.lora_separator is not None and len(self.lora_separator) > 0 - and metadata.lora_names is not None + and metadata.lora_configs is not None ): - for lora_name in metadata.lora_names: + for lora_config in metadata.lora_configs: model_names.append( - f"{metadata.name}{self.lora_separator}{lora_name}" + f"{metadata.name}{self.lora_separator}{lora_config.name}" ) for model_name in model_names: @@ -210,7 +211,7 @@ async def chat( metadata.model, prompt, request, - lora_name, + self._get_lora_config(model_name, lora_name), metadata.echo_tensor_name, self.default_max_tokens, ) @@ -348,7 +349,7 @@ async def completion( metadata.model, request.prompt, request, - lora_name, + self._get_lora_config(model_name, lora_name), metadata.echo_tensor_name, self.default_max_tokens, ) @@ -505,11 +506,12 @@ def _get_model_metadata(self) -> Dict[str, TritonModelMetadata]: backend = "ensemble" print(f"Found model: {name=}, {backend=}") - lora_names = None - if self.backend == "vllm" or backend == "vllm": - lora_names = _get_vllm_lora_names( - self.server.options.model_repository, name, model.version - ) + lora_configs = _parse_lora_configs( + self.server.options.model_repository, + name, + model.version, + backend if self.backend is None else self.backend, + ) echo_tensor_name = None for input in model.config()["input"]: @@ -525,7 +527,7 @@ def _get_model_metadata(self) -> Dict[str, TritonModelMetadata]: backend=backend, model=model, tokenizer=self.tokenizer, - lora_names=lora_names, + lora_configs=lora_configs, echo_tensor_name=echo_tensor_name, create_time=self.create_time, inference_request_converter=self._determine_request_converter( @@ -807,9 +809,10 @@ def _validate_chat_request( ) if ( - metadata.lora_names is not None + metadata.lora_configs is not None and lora_name is not None - and lora_name not in metadata.lora_names + and lora_name + not in [lora_config.name for lora_config in metadata.lora_configs] ): raise ClientError(f"Unknown LoRA: {lora_name}; for model: {request.model}") @@ -970,9 +973,10 @@ def 
_validate_completion_request( ) if ( - metadata.lora_names is not None + metadata.lora_configs is not None and lora_name is not None - and lora_name not in metadata.lora_names + and lora_name + not in [lora_config.name for lora_config in metadata.lora_configs] ): raise ClientError(f"Unknown LoRA: {lora_name}; for model: {request.model}") @@ -1081,3 +1085,14 @@ def _get_named_function_name( tool_choice_required_function_name = None return tool_choice_function_name or tool_choice_required_function_name + + def _get_lora_config( + self, model_name: str, lora_name: Optional[str] + ) -> TritonLoraConfig: + model_metadata = self.model_metadata.get(model_name) + if lora_name is None or model_metadata.lora_configs is None: + return None + for lora_config in model_metadata.lora_configs: + if lora_config.name == lora_name: + return lora_config + raise ClientError(f"Unknown LoRA: {lora_name}; for model: {model_name}") diff --git a/python/openai/openai_frontend/engine/utils/triton.py b/python/openai/openai_frontend/engine/utils/triton.py index e2c4cf92c9..344ad374c9 100644 --- a/python/openai/openai_frontend/engine/utils/triton.py +++ b/python/openai/openai_frontend/engine/utils/triton.py @@ -28,6 +28,7 @@ import os import re import sys +import traceback from dataclasses import asdict, dataclass, field from enum import Enum from pathlib import Path @@ -56,11 +57,21 @@ class RequestKind(Enum): EMBEDDING = 2 +@dataclass +class TritonLoraConfig: + name: str + + # Unique fields for TensorRT-LLM backend + task_id: Optional[int] = None + path: Optional[str] = None + is_registered: Optional[bool] = False + + def _create_vllm_generate_request( model, prompt, request: CreateChatCompletionRequest | CreateCompletionRequest, - lora_name: str | None, + lora_config: TritonLoraConfig | None, echo_tensor_name: str | None, default_max_tokens: int, ): @@ -135,8 +146,8 @@ def _create_vllm_generate_request( request_logprobs = True inputs["return_logprobs"] = np.bool_([request_logprobs]) - if lora_name is not None: - sampling_parameters["lora_name"] = lora_name + if lora_config is not None: + sampling_parameters["lora_name"] = lora_config.name guided_json = _get_guided_json_from_tool(request) if guided_json is not None: @@ -167,15 +178,10 @@ def _create_trtllm_generate_request( model, prompt, request: CreateChatCompletionRequest | CreateCompletionRequest, - lora_name: str | None, + lora_config: TritonLoraConfig | None, echo_tensor_name: str | None, default_max_tokens: int, ): - if lora_name is not None: - raise ClientError( - "LoRA selection is currently not supported for TRT-LLM backend" - ) - inputs = {} inputs["text_input"] = [[prompt]] inputs["stream"] = np.bool_([[request.stream]]) @@ -221,6 +227,21 @@ def _create_trtllm_generate_request( inputs["guided_decoding_guide_type"] = [["json_schema"]] inputs["guided_decoding_guide"] = [[guided_json]] + if lora_config is not None: + # To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given. + # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`. 
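+        # e.g. the first "doll" request sends lora_task_id together with the two numpy
+        # tensors loaded from `lora_config.path`; once `is_registered` is set below,
+        # later "doll" requests send only lora_task_id and rely on the cached LoRA.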
+ inputs["lora_task_id"] = np.uint64([[lora_config.task_id]]) + if not lora_config.is_registered: + lora_weights_data = np.load( + os.path.join(lora_config.path, "model.lora_weights.npy") + ) + lora_config_data = np.load( + os.path.join(lora_config.path, "model.lora_config.npy") + ) + inputs["lora_weights"] = lora_weights_data + inputs["lora_config"] = lora_config_data + lora_config.is_registered = True + inputs["return_num_input_tokens"] = np.bool_([[True]]) inputs["return_num_output_tokens"] = np.bool_([[True]]) return model.create_request(inputs=inputs) @@ -594,9 +615,9 @@ def _get_guided_json_from_tool( return None -def _get_vllm_lora_names( - model_repository: str | list[str], model_name: str, model_version: int -) -> None | List[str]: +def _parse_lora_configs( + model_repository: str | list[str], model_name: str, model_version: int, backend: str +) -> None | List[tuple[str, str]]: if ( len(model_name) == 0 or model_name.isspace() @@ -606,7 +627,9 @@ def _get_vllm_lora_names( raise ValueError( f"Invalid model name: '{model_name}'. Model names must be valid file-system-path segment names." ) - lora_names = [] + + lora_configs = [] + lora_task_id = 1 repo_paths = model_repository if isinstance(repo_paths, str): repo_paths = [repo_paths] @@ -618,6 +641,7 @@ def _get_vllm_lora_names( raise ValueError( f"Invalid model name: '{model_name}'. Model names must be valid file-system-path segment names." ) + model_path = os.path.normpath(model_path) if not os.path.isdir(model_path): # Cloud path? @@ -632,26 +656,60 @@ def _get_vllm_lora_names( # Model directory is malformed? return None version_path = os.path.join(model_path, str(model_version)) - is_lora_enabled = False - model_file_path = os.path.join(version_path, "model.json") - try: - with open(model_file_path, "r") as f: - config = json.load(f) - if "enable_lora" in config: - # The value could be a string or a bool. - is_lora_enabled = str(config["enable_lora"]).lower() == "true" - except Exception: - # Model directory or model.json is malformed? - return None - if is_lora_enabled != True: - continue lora_config_path = os.path.join(version_path, "multi_lora.json") + + if backend == "vllm": + is_lora_enabled = False + model_file_path = os.path.join(version_path, "model.json") + try: + with open(model_file_path, "r") as f: + config = json.load(f) + if "enable_lora" in config: + # The value could be a string or a bool. + is_lora_enabled = str(config["enable_lora"]).lower() == "true" + except Exception: + # Model directory or model.json is malformed? 
+ return None + if is_lora_enabled != True: + continue + else: + # TRT-LLM backend does not use model.json + if not os.path.exists(lora_config_path): + continue + try: with open(lora_config_path, "r") as f: lora_config = json.load(f) - for lora_name in lora_config.keys(): - lora_names.append(lora_name) - except Exception: + for lora_name, lora_path in lora_config.items(): + print(f"backend: {backend}") + if backend == "vllm": + lora_configs.append(TritonLoraConfig(name=lora_name)) + else: + lora_weights_path = os.path.join( + lora_path, "model.lora_weights.npy" + ) + lora_config_path = os.path.join( + lora_path, "model.lora_config.npy" + ) + if not os.path.exists(lora_weights_path): + raise ServerError( + f"LoRA weights file not found for '{lora_name}' at path: {lora_weights_path}" + ) + if not os.path.exists(lora_config_path): + raise ServerError( + f"LoRA config file not found for '{lora_name}' at path: {lora_config_path}" + ) + + lora_configs.append( + TritonLoraConfig( + name=lora_name, path=lora_path, task_id=lora_task_id + ) + ) + lora_task_id += 1 + except ServerError as e: + raise e + except Exception as e: # LoRA is enabled but its list is not provided or malformed? + print(traceback.format_exc()) return None - return lora_names + return lora_configs diff --git a/python/openai/openai_frontend/frontend/fastapi/routers/chat.py b/python/openai/openai_frontend/frontend/fastapi/routers/chat.py index b7dee0c20c..49d1c5f23d 100644 --- a/python/openai/openai_frontend/frontend/fastapi/routers/chat.py +++ b/python/openai/openai_frontend/frontend/fastapi/routers/chat.py @@ -24,6 +24,8 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import traceback + from fastapi import APIRouter, HTTPException, Request from fastapi.responses import StreamingResponse from schemas.openai import CreateChatCompletionRequest, CreateChatCompletionResponse @@ -55,6 +57,8 @@ async def create_chat_completion( except ClientError as e: raise HTTPException(status_code=StatusCode.CLIENT_ERROR, detail=f"{e}") except ServerError as e: + print(traceback.format_exc()) raise HTTPException(status_code=StatusCode.SERVER_ERROR, detail=f"{e}") except Exception as e: + print(traceback.format_exc()) raise HTTPException(status_code=StatusCode.SERVER_ERROR, detail=f"{e}") diff --git a/python/openai/openai_frontend/frontend/fastapi/routers/completions.py b/python/openai/openai_frontend/frontend/fastapi/routers/completions.py index 2aa962923e..642bc117d0 100644 --- a/python/openai/openai_frontend/frontend/fastapi/routers/completions.py +++ b/python/openai/openai_frontend/frontend/fastapi/routers/completions.py @@ -24,6 +24,8 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import traceback + from fastapi import APIRouter, HTTPException, Request from fastapi.responses import StreamingResponse from schemas.openai import CreateCompletionRequest, CreateCompletionResponse @@ -54,6 +56,8 @@ async def create_completion( except ClientError as e: raise HTTPException(status_code=StatusCode.CLIENT_ERROR, detail=f"{e}") except ServerError as e: + print(traceback.format_exc()) raise HTTPException(status_code=StatusCode.SERVER_ERROR, detail=f"{e}") except Exception as e: + print(traceback.format_exc()) raise HTTPException(status_code=StatusCode.SERVER_ERROR, detail=f"{e}") diff --git a/python/openai/openai_frontend/frontend/fastapi/routers/embeddings.py b/python/openai/openai_frontend/frontend/fastapi/routers/embeddings.py index 8f0bfe6771..eb2ea5d9da 100644 --- a/python/openai/openai_frontend/frontend/fastapi/routers/embeddings.py +++ b/python/openai/openai_frontend/frontend/fastapi/routers/embeddings.py @@ -24,6 +24,8 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import traceback + from fastapi import APIRouter, HTTPException, Request from fastapi.responses import StreamingResponse from schemas.openai import CreateEmbeddingRequest, CreateEmbeddingResponse @@ -52,6 +54,8 @@ async def create_embedding( except ClientError as e: raise HTTPException(status_code=StatusCode.CLIENT_ERROR, detail=f"{e}") except ServerError as e: + print(traceback.format_exc()) raise HTTPException(status_code=StatusCode.SERVER_ERROR, detail=f"{e}") except Exception as e: + print(traceback.format_exc()) raise HTTPException(status_code=StatusCode.SERVER_ERROR, detail=f"{e}") diff --git a/python/openai/tests/test_lora.py b/python/openai/tests/test_lora.py index d6322ee4b2..a073f4b93e 100644 --- a/python/openai/tests/test_lora.py +++ b/python/openai/tests/test_lora.py @@ -33,7 +33,7 @@ from huggingface_hub import snapshot_download from openai import BadRequestError, NotFoundError from openai_frontend.engine.utils.triton import ( - _get_vllm_lora_names as get_vllm_lora_names, + _parse_lora_configs as parse_lora_configs, ) from .utils import OpenAIServer @@ -53,9 +53,10 @@ ("test_models", "mock_llm", False), ], ) -def test_get_vllm_lora_name(model_repository: str, model_name: str, expect_error: bool): +def test_parse_lora_configs(model_repository: str, model_name: str, expect_error: bool): try: - get_vllm_lora_names(model_repository, model_name, 1) + parse_lora_configs(model_repository, model_name, 1, "vllm") + parse_lora_configs(model_repository, model_name, 1, "tensorrtllm") except ValueError as e: if expect_error: assert ( @@ -83,7 +84,8 @@ def is_vllm_installed(): class LoRATest(unittest.TestCase): - _model_name = "gemma-2b" + _backend = "vllm" if is_vllm_installed() else "tensorrtllm" + _model_name = "gemma-2b" if _backend == "vllm" else "tensorrt_llm_bls" # TODO: Find a LoRA model that has its own tokenizer. 
_tokenizer = "meta-llama/Meta-Llama-3.1-8B-Instruct" _lora_separator = "_lora_" @@ -99,7 +101,7 @@ def setUp(self): self._completions_outputs = {} self._chat_completion_outputs = {} - def _create_model_repository_with_lora(self): + def _create_vllm_model_repository_with_lora(self): shutil.rmtree("models", ignore_errors=True) os.makedirs(f"models/{self._model_name}/1", exist_ok=True) with open(f"models/{self._model_name}/config.pbtxt", "w") as f: @@ -132,7 +134,20 @@ def _create_model_repository_with_lora(self): local_dir=f"models/{self._model_name}/1/GemmaSheep", ) - def _create_model_repository_without_lora(self): + def _create_trtllm_model_repository_with_lora(self): + shutil.rmtree("models", ignore_errors=True) + shutil.copytree("tests/tensorrtllm_models", "models") + with open(f"models/{self._model_name}/1/multi_lora.json", "w") as f: + f.write( + json.dumps( + { + "doll": f"models/{self._model_name}/1/luotuo-lora-7b-0.1-weights", + "sheep": f"models/{self._model_name}/1/Japanese-Alpaca-LoRA-7b-v0-weights", + } + ) + ) + + def _create_vllm_model_repository_without_lora(self): shutil.rmtree("models", ignore_errors=True) os.makedirs(f"models/{self._model_name}/1", exist_ok=True) with open(f"models/{self._model_name}/config.pbtxt", "w") as f: @@ -140,6 +155,10 @@ def _create_model_repository_without_lora(self): with open(f"models/{self._model_name}/1/model.json", "w") as f: f.write(json.dumps({"model": "unsloth/gemma-2b"})) + def _create_trtllm_model_repository_without_lora(self): + shutil.rmtree("models", ignore_errors=True) + shutil.copytree("tests/tensorrtllm_models", "models") + def _create_model_repository_mock_llm(self): shutil.rmtree("models", ignore_errors=True) os.makedirs(f"models/{self._model_name}/1", exist_ok=True) @@ -214,9 +233,17 @@ def _test_list_models(self, client, expected_lora_names): expected_model_names.append(self._get_model_name(lora_name)) models = client.models.list() for model in models: + if self._backend == "tensorrtllm" and not model.id.startswith( + "tensorrt_llm_bls" + ): + continue self.assertIn(model.id, expected_model_names) expected_model_names.remove(model.id) - self.assertEqual(len(expected_model_names), 0) + self.assertEqual( + len(expected_model_names), + 0, + f"expected_model_names: {expected_model_names}", + ) def _test_retrieve_model(self, client, lora_name): model_name = self._get_model_name(lora_name) @@ -260,9 +287,14 @@ def _test_chat_completion(self, client, lora_name): ) self._chat_completion_outputs[lora_name] = output - @unittest.skipUnless(is_vllm_installed(), "vLLM not installed") def test_lora_separator_not_set(self): - self._create_model_repository_with_lora() + if self._backend == "vllm": + self._create_vllm_model_repository_with_lora() + elif self._backend == "tensorrtllm": + self._create_trtllm_model_repository_with_lora() + else: + raise Exception(f"Unexpected backend {self._backend=}") + with OpenAIServer( cli_args=[ "--model-repository", @@ -296,9 +328,14 @@ def test_lora_separator_not_set(self): expected_error = f"Error code: 400 - {{'detail': 'Unknown model: {self._model_name}{self._lora_separator}sheep'}}" self.assertEqual(str(e.exception), expected_error) - @unittest.skipUnless(is_vllm_installed(), "vLLM not installed") def test_lora_separator_set(self): - self._create_model_repository_with_lora() + if self._backend == "vllm": + self._create_vllm_model_repository_with_lora() + elif self._backend == "tensorrtllm": + self._create_trtllm_model_repository_with_lora() + else: + raise Exception(f"Unexpected backend 
{self._backend=}") + with OpenAIServer( cli_args=[ "--model-repository", @@ -316,11 +353,13 @@ def test_lora_separator_set(self): self._test_retrieve_model(client, "") self._test_retrieve_model(client, "doll") self._test_retrieve_model(client, "sheep") + # Test retrieving LoRAs unknown to the backend with self.assertRaises(NotFoundError) as e: self._test_retrieve_model(client, "unknown") expected_error = f"Error code: 404 - {{'detail': 'Unknown model: {self._model_name}{self._lora_separator}unknown'}}" self.assertEqual(str(e.exception), expected_error) + # Test selecting LoRAs self._test_completions(client, "") self._test_completions(client, "doll") @@ -328,6 +367,7 @@ def test_lora_separator_set(self): self._test_chat_completion(client, "") self._test_chat_completion(client, "doll") self._test_chat_completion(client, "sheep") + # Test selecting LoRAs unknown to the backend expected_error = f"Error code: 400 - {{'detail': 'Unknown LoRA: unknown; for model: {self._model_name}{self._lora_separator}unknown'}}" with self.assertRaises(BadRequestError) as e: @@ -337,9 +377,14 @@ def test_lora_separator_set(self): self._test_chat_completion(client, "unknown") self.assertEqual(str(e.exception), expected_error) - @unittest.skipUnless(is_vllm_installed(), "vLLM not installed") def test_lora_separator_set_for_lora_off_model(self): - self._create_model_repository_without_lora() + if self._backend == "vllm": + self._create_vllm_model_repository_without_lora() + elif self._backend == "tensorrtllm": + self._create_trtllm_model_repository_without_lora() + else: + raise Exception(f"Unexpected backend {self._backend=}") + with OpenAIServer( cli_args=[ "--model-repository", diff --git a/qa/L0_openai/generate_engine.py b/qa/L0_openai/generate_engine.py index 83ea35a88d..b454896cfb 100644 --- a/qa/L0_openai/generate_engine.py +++ b/qa/L0_openai/generate_engine.py @@ -27,18 +27,27 @@ from tensorrt_llm import BuildConfig from tensorrt_llm._tensorrt_engine import LLM +from tensorrt_llm.lora_manager import LoraConfig from tensorrt_llm.plugin import PluginConfig def generate_model_engine(model: str, engines_path: str): config = BuildConfig(plugin_config=PluginConfig.from_dict({"_gemm_plugin": "auto"})) + lora_config = LoraConfig( + lora_target_modules=["attn_q", "attn_k", "attn_v"], + max_lora_rank=8, + max_loras=4, + max_cpu_loras=8, + ) + engine = LLM( model, dtype="float16", max_batch_size=128, build_config=config, guided_decoding_backend="xgrammar", + lora_config=lora_config, ) engine.save(engines_path) diff --git a/qa/L0_openai/test.sh b/qa/L0_openai/test.sh index a1db293436..5820b75dd0 100755 --- a/qa/L0_openai/test.sh +++ b/qa/L0_openai/test.sh @@ -97,6 +97,19 @@ function prepare_tensorrtllm() { python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/postprocessing/config.pbtxt tokenizer_dir:${ENGINE_PATH},triton_max_batch_size:64,postprocessing_instance_count:1 python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,bls_instance_count:1,accumulate_tokens:False,logits_datatype:TYPE_FP32,prompt_embedding_table_data_type:TYPE_FP16 python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/tensorrt_llm/config.pbtxt 
triton_backend:${TRITON_BACKEND},triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:${ENGINE_PATH},batching_strategy:inflight_fused_batching,max_queue_size:0,max_queue_delay_microseconds:1000,encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32,exclude_input_in_output:True,prompt_embedding_table_data_type:TYPE_FP16,guided_decoding_backend:${GUIDED_DECODING_BACKEND},xgrammar_tokenizer_info_path:${XGRAMMAR_TOKENIZER_INFO_PATH}
+
+    # 4. Prepare LoRA adapters
+    # FIXME: Remove this WAR once it is fixed in a future stable version of TRT-LLM.
+    sed -i 's/dims: \[ -1, 3 \]/dims: \[ -1, 4 \]/' ${MODEL_REPO}/tensorrt_llm/config.pbtxt
+    sed -i 's/dims: \[ -1, 3 \]/dims: \[ -1, 4 \]/' ${MODEL_REPO}/tensorrt_llm_bls/config.pbtxt
+    pushd ${MODEL_REPO}/tensorrt_llm_bls/1
+    for lora_name in silk-road/luotuo-lora-7b-0.1 kunishou/Japanese-Alpaca-LoRA-7b-v0; do
+        name=$(basename $lora_name)
+        git clone https://huggingface.co/$lora_name
+        python3 /app/examples/hf_lora_convert.py -i $name -o $name-weights --storage-type float16
+        rm -rf $name
+    done
+    popd
 }
 
 function pre_test() {
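Below is a minimal, hypothetical sanity check (not part of the change itself) for the
converted adapter directories referenced by `multi_lora.json`; the path constant follows
the example layout above and should be adjusted to the actual model repository. It mirrors
the existence checks that `_parse_lora_configs` performs at startup.

```python
import json
import os

import numpy as np

# Hypothetical path; matches the example layout above (adjust to your model repository).
MULTI_LORA_JSON = "inflight_batcher_llm/tensorrt_llm_bls/1/multi_lora.json"

with open(MULTI_LORA_JSON) as f:
    adapters = json.load(f)

for name, path in adapters.items():
    weights = os.path.join(path, "model.lora_weights.npy")
    config = os.path.join(path, "model.lora_config.npy")
    for tensor_file in (weights, config):
        if not os.path.exists(tensor_file):
            raise FileNotFoundError(f"LoRA '{name}' is missing {tensor_file}")
    # np.load also catches truncated or corrupt outputs from hf_lora_convert.py.
    print(name, np.load(weights).shape, np.load(config).shape)
```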