diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c8ad938032..41a24d7ab1 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -42,6 +42,7 @@ ) from torch_tensorrt.dynamo.utils import ( deallocate_module, + get_cpu_memory_usage, get_flat_args_with_check, get_output_metadata, parse_graph_io, @@ -104,6 +105,7 @@ def cross_compile_for_windows( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows @@ -178,6 +180,7 @@ def cross_compile_for_windows( tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + cpu_memory_budget (int): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. If set to -1, the compilation will use all available CPU memory. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -333,6 +336,7 @@ def cross_compile_for_windows( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, "use_distributed_mode_trace": use_distributed_mode_trace, + "cpu_memory_budget": cpu_memory_budget, } # disable the following settings is not supported for cross compilation for windows feature @@ -434,6 +438,7 @@ def compile( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -680,8 +685,9 @@ def compile( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, + "cpu_memory_budget": cpu_memory_budget, } - + logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB") settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) @@ -695,14 +701,17 @@ def compile( # Apply lowering on the graph module gm = post_lowering(gm, settings) + logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB") logger.debug("Lowered Input graph: " + str(gm.graph)) # Move the weights in the state_dict to CPU if offload_module_to_cpu: + deallocate_module(gm, delete_module=False) deallocate_module(exported_program.module(), delete_module=False) logger.info( "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. 
To retain the model on the GPU, set offload_module_to_cpu=False" ) + logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB") else: remaining_memory, total_memory = torch.cuda.mem_get_info() if remaining_memory < total_memory // 2: @@ -829,6 +838,7 @@ def preserve_module_specs( torch_executed_ops=settings.torch_executed_ops, require_full_compilation=settings.require_full_compilation, skip_fusion=(num_supported_ops == total_ops), + cpu_memory_budget=settings.cpu_memory_budget, ) except torch.fx.passes.splitter_base.FxNetSplitterInternalError: @@ -857,10 +867,10 @@ def preserve_module_specs( dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module)) submodule_node_dict = {} - for node in partitioned_module.graph.nodes: - if "_run_on_acc" not in node.name: + for name, node in partitioned_module.named_children(): + if "_run_on_acc" not in name: continue - submodule_node_dict[node.name] = node + submodule_node_dict[name] = node preserve_module_specs(original_in_spec, original_out_spec, partitioned_module) # Store TRT replicas of Torch subgraphs @@ -868,6 +878,16 @@ def preserve_module_specs( # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those + # Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function. + # This is done to release CPU memory. + for attr in dir(gm): + if attr.startswith("_frozen_param"): + delattr(gm, attr) + + from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS + + DYNAMO_CONVERTERS.disallowed_targets = set() + for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule @@ -1056,6 +1076,7 @@ def convert_exported_program_to_serialized_trt_engine( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, **kwargs: Any, ) -> bytes: """Convert an ExportedProgram to a serialized TensorRT engine @@ -1243,7 +1264,7 @@ def convert_exported_program_to_serialized_trt_engine( # Prepare torch_trt inputs trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) - trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) + trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs) device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index de970ecd81..712eeb1ba0 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -57,6 +57,7 @@ L2_LIMIT_FOR_TILING = -1 USE_DISTRIBUTED_MODE_TRACE = False OFFLOAD_MODULE_TO_CPU = False +CPU_MEMORY_BUDGET = -1 if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index d8f6809eae..52ac86012c 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -7,6 +7,7 @@ from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, CACHE_BUILT_ENGINES, + CPU_MEMORY_BUDGET, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, @@ -140,6 +141,7 @@ class CompilationSettings: l2_limit_for_tiling: int = 
L2_LIMIT_FOR_TILING use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU + cpu_memory_budget: int = CPU_MEMORY_BUDGET def __getstate__(self) -> dict[str, Any]: from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 73af09448e..2542d652bd 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -50,7 +50,12 @@ from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger from torch_tensorrt.dynamo.observer import Observer -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device +from torch_tensorrt.dynamo.utils import ( + DYNAMIC_DIM, + deallocate_module, + get_cpu_memory_usage, + to_torch_device, +) from torch_tensorrt.logging import TRT_LOGGER _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -65,7 +70,7 @@ class UnsupportedOperatorException(RuntimeError): class TRTInterpreterResult(NamedTuple): - serialized_engine: bytes + engine: trt.ICudaEngine input_names: Sequence[str] output_names: Sequence[str] weight_name_map: Optional[dict[Any, Any]] @@ -512,8 +517,7 @@ def _save_weight_mapping(self) -> None: _LOGGER.info("Building weight name mapping...") # Stage 1: Name mapping torch_device = to_torch_device(self.compilation_settings.device) - self.module.to(torch_device) - sd = self.module.state_dict() + sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()} weight_name_map: dict[str, Any] = {} weight_refit_map = self.ctx.weight_refit_map constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1} @@ -592,13 +596,11 @@ def _save_weight_mapping(self) -> None: torch.cuda.empty_cache() @needs_refit # type: ignore[misc] - def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None: + def _insert_engine_to_cache(self, hash_val: str, engine: trt.ICudaEngine) -> None: + serialized_engine = engine.serialize() # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine # if not self.compilation_settings.strip_engine_weights: # # set EXCLUDE_WEIGHTS flag to strip weights - # runtime = trt.Runtime(TRT_LOGGER) - # engine = runtime.deserialize_cuda_engine(serialized_engine) - # serialization_config = engine.create_serialization_config() # serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) # serialized_engine = engine.serialize_with_config( @@ -733,6 +735,9 @@ def run( return interpreter_result # type: ignore[no-any-return] self._construct_trt_network_def() + _LOGGER.debug( + f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB" + ) if not self.compilation_settings.immutable_weights: self._save_weight_mapping() @@ -750,16 +755,19 @@ def run( self._create_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - serialized_engine = self.builder.build_serialized_network( + + cuda_engine = self.builder.build_engine_with_config( self.ctx.net, builder_config ) - assert serialized_engine + assert cuda_engine + + _LOGGER.debug( + f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB" + ) _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) - _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") - 
self.ctx.clear_cpu_weights_reference_holder() self._save_timing_cache( @@ -772,14 +780,10 @@ def run( and self.compilation_settings.cache_built_engines and self.engine_cache is not None ): - self._insert_engine_to_cache(hash_val, serialized_engine) - - with io.BytesIO() as engine_bytes: - engine_bytes.write(serialized_engine) - engine_str = engine_bytes.getvalue() + self._insert_engine_to_cache(hash_val, cuda_engine) return TRTInterpreterResult( - engine_str, + cuda_engine, self._input_names, self._output_names, self.weight_name_map, diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 35b6c26617..aaec8d3be8 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -1,7 +1,8 @@ from __future__ import annotations +import io import logging -from typing import Any, List, Optional, Sequence +from typing import Any, List, NamedTuple, Optional, Sequence import torch from torch_tensorrt._enums import dtype @@ -9,16 +10,25 @@ from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._engine_cache import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion._TRTInterpreter import ( - TRTInterpreter, - TRTInterpreterResult, -) +from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -from torch_tensorrt.dynamo.utils import get_output_dtypes +from torch_tensorrt.dynamo.utils import ( + get_cpu_memory_usage, + get_output_dtypes, + release_memory, +) logger = logging.getLogger(__name__) +class SerializedInterpreterResult(NamedTuple): + serialized_engine: bytes + input_names: Sequence[str] + output_names: Sequence[str] + weight_name_map: Optional[dict[Any, Any]] + requires_output_allocator: bool + + def infer_module_output_dtypes( module: torch.fx.GraphModule, truncate_double: bool = False, @@ -29,7 +39,7 @@ def infer_module_output_dtypes( """ outputs = [node for node in module.graph.nodes if node.op == "output"] outputs = outputs[0].args - return get_output_dtypes(outputs, truncate_double) # type: ignore[no-any-return] + return get_output_dtypes(outputs, truncate_double) def interpret_module_to_result( @@ -39,7 +49,7 @@ def interpret_module_to_result( arg_inputs: Optional[Sequence[Input]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, engine_cache: Optional[BaseEngineCache] = None, -) -> TRTInterpreterResult: +) -> SerializedInterpreterResult: """Interpret an FX module to a TRTInterpreterResult Args: module: FX GraphModule to interpret @@ -65,7 +75,32 @@ def interpret_module_to_result( ) interpreter_result = interpreter.run() - return interpreter_result + # Delete the frozen parameters from the module to release CPU memory + del interpreter + for attr in dir(module): + if attr.startswith("_frozen_param"): + delattr(module, attr) + release_memory() + logger.debug( + f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB" + ) + + serialized_engine = interpreter_result.engine.serialize() + with io.BytesIO() as engine_bytes: + engine_bytes.write(serialized_engine) + serialized_engine = engine_bytes.getvalue() + logger.debug( + f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB" + ) + serialized_interpreter_result = SerializedInterpreterResult( + serialized_engine=serialized_engine, + 
input_names=interpreter_result.input_names, + output_names=interpreter_result.output_names, + weight_name_map=interpreter_result.weight_name_map, + requires_output_allocator=interpreter_result.requires_output_allocator, + ) + + return serialized_interpreter_result def convert_module( diff --git a/py/torch_tensorrt/dynamo/debug/_Debugger.py b/py/torch_tensorrt/dynamo/debug/_Debugger.py index 39e4217f73..e565929861 100644 --- a/py/torch_tensorrt/dynamo/debug/_Debugger.py +++ b/py/torch_tensorrt/dynamo/debug/_Debugger.py @@ -220,6 +220,7 @@ def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]: "class": "logging.FileHandler", "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log", "formatter": "standard", + "mode": "w", # This will clear the previous content } config["loggers"][""]["handlers"].append("file") return config diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 5ba84b09b0..9b821df906 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -37,7 +37,9 @@ def constant_fold( # For TRT INetwork construction the constants are moved to CPU in get_attr call. for node, constant in cf.node_replacements.items(): replace_node_with_constant( - gm, node, torch.nn.Parameter(constant, requires_grad=False) + gm, + node, + torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False), ) erased_params = [] diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index e2f544c2a7..ec7c3e6e16 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -1,6 +1,7 @@ import logging from typing import Collection, Dict, List, Optional, Tuple +import psutil import torch import torch.fx.passes.operator_support as ops from torch.fx.node import Target @@ -24,6 +25,10 @@ ) logger = logging.getLogger(__name__) +NON_BREAKABLE_OP_LISTS = [ + ["addmm", "addmm"], + ["conv2d", "batch_norm2d", "relu"], +] class OpSupportTester(ops.OperatorSupportBase): # type: ignore @@ -113,6 +118,7 @@ def __init__( require_full_compilation: bool = REQUIRE_FULL_COMPILATION, return_tuple: bool = False, skip_fusion: bool = False, + cpu_memory_budget: int = -1, ): """ Preprocesses graph before splitting: @@ -132,6 +138,7 @@ def __init__( skip_fusion=skip_fusion, ) self.operator_support = operator_support + self.cpu_memory_budget = cpu_memory_budget # Get all accelerated nodes based on operator support conditions self.acc_nodes = FxNetAccNodesFinder( @@ -225,12 +232,294 @@ def partition_graph(self) -> torch.fx.GraphModule: # Remove segments smaller than the block size (with exceptions) subgraphs = self.remove_small_acc_subgraphs(subgraphs) + subgraphs = self.break_subgraphs( + subgraphs, subgraph_size_budget=self.calculate_size_budget() + ) + # Set the number of TRT engines to be generated self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc]) # Tag the accelerated nodes and split the graph accordingly self.tag(subgraphs) - return self.split() + + gm = self.split() + + return gm + + def calculate_size_budget( + self, engine_compilation_memory_usage_multiplier: int = 4 + ) -> int: + """ + This function calculates the size budget based on the available RSS. 
We assume that TRT compilation needs at most 4x the memory of the model.
+        """
+        if self.cpu_memory_budget == -1:
+            available_rss: int = psutil.virtual_memory().available
+        else:
+            used_rss: int = psutil.virtual_memory().used
+            available_rss = self.cpu_memory_budget - used_rss
+        return available_rss // engine_compilation_memory_usage_multiplier
+
+    def break_subgraphs_by_node(
+        self, subgraphs: List[Subgraph], num_of_break: int = 1
+    ) -> List[Subgraph]:
+        """
+        This function breaks each accelerated subgraph into smaller subgraphs at every Nth occurrence of the target op to save CPU memory.
+        """
+        op_to_break = "addmm."
+        num_of_break_nodes = len(
+            [node for node in self.acc_nodes if op_to_break in str(node.target)]
+        )
+        break_period = num_of_break_nodes // num_of_break + 1
+        current_break_idx = 0
+        current_num_break = 0
+        new_subgraphs = []
+        for subgraph in subgraphs:
+            if subgraph.is_acc:
+                for i, node in enumerate(subgraph.nodes):
+                    if op_to_break in str(node.target):
+                        current_num_break += 1
+                        if current_num_break % break_period != 0:
+                            continue
+                        new_subgraphs.append(
+                            Subgraph(
+                                is_acc=True,
+                                nodes=subgraph.nodes[current_break_idx : i + 1],
+                                device_ordinal=subgraph.device_ordinal,
+                            )
+                        )
+                        current_break_idx = i + 1
+                new_subgraphs.append(
+                    Subgraph(
+                        is_acc=True,
+                        nodes=subgraph.nodes[current_break_idx:],
+                        device_ordinal=subgraph.device_ordinal,
+                    )
+                )
+            else:
+                new_subgraphs.append(subgraph)
+
+        new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs)
+
+        return new_subgraphs
+
+    def break_subgraphs(
+        self, subgraphs: List[Subgraph], subgraph_size_budget: int
+    ) -> List[Subgraph]:
+        """
+        This function breaks accelerated subgraphs that exceed the per-subgraph size budget into smaller subgraphs to save CPU memory.
+        """
+        new_subgraphs = []
+        # We raise an error if the remaining memory is nearly exhausted relative to the model size,
+        # e.g. if only 4 GB remain (a 1 GB budget after the 4x multiplier) and the model is larger than 40 GB, we stop the compilation.
+        sizes = self.size_of_subgraphs(subgraphs)
+        if sum(sizes) > subgraph_size_budget * 40:
+            raise ValueError(
+                f"CPU memory budget or available memory is too small to compile the model. CPU memory budget: {self.cpu_memory_budget // (1024 * 1024) if self.cpu_memory_budget != -1 else 'All available memory'} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. "
+                + "Consider setting cpu_memory_budget to a larger value or disabling offload_module_to_cpu to save more CPU memory."
+            )
+        for subgraph, size in zip(subgraphs, sizes):
+
+            while size > subgraph_size_budget:
+                broken_subgraphs, size_0, size_1 = self.break_subgraph_by_size(
+                    subgraph, subgraph_size_budget
+                )
+                size = size_1
+                new_subgraphs.append(broken_subgraphs[0])
+                subgraph = broken_subgraphs[1]
+            new_subgraphs.append(subgraph)
+
+        return new_subgraphs
+
+    def break_subgraph_by_size(
+        self, subgraph: Subgraph, size_to_break: int
+    ) -> Tuple[List[Subgraph], int, int]:
+        """
+        This function splits a subgraph in two, growing the first piece node by node until it reaches the given size threshold, to save CPU memory.
+ """ + all_nodes = subgraph.nodes + device_ordinal = subgraph.device_ordinal + new_subgraphs = [ + Subgraph( + is_acc=True, + nodes=[], + device_ordinal=device_ordinal, + ), + Subgraph( + is_acc=True, + nodes=all_nodes, + device_ordinal=device_ordinal, + ), + ] + + while True: + new_subgraphs = self.step_and_validate(new_subgraphs) + size_0, size_1 = self.size_of_subgraphs(new_subgraphs) + if size_0 > size_to_break: + break + + if len(new_subgraphs[1].nodes) == 0: + new_subgraphs.pop(1) + return new_subgraphs, size_0, size_1 + + def step_and_validate( + self, new_subgraphs: List[Subgraph], step_size: int = 1 + ) -> List[Subgraph]: + + # TODO: We can change it to binary search to find the optimal break point + for _ in range(step_size): + new_subgraphs[0].nodes.append(new_subgraphs[1].nodes.pop(0)) + + while True: + new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs) + nodes_in_first_subgraph = set(new_subgraphs[0].nodes) + leaf_node = self.get_leaf_node(nodes_in_first_subgraph) + broken_fusion = self.step_if_break_fusion( + new_subgraphs, leaf_node, nodes_in_first_subgraph + ) + if not broken_fusion or len(new_subgraphs[1].nodes) == 0: + break + + return new_subgraphs + + def step_if_break_fusion( + self, + subgraphs: List[Subgraph], + leaf_nodes: set[torch.fx.Node], + nodes_in_first_subgraph: set[torch.fx.Node], + ) -> bool: + + def add_nodes(node: torch.fx.Node) -> None: + """ + This function adds a node and all its previous nodes to the first subgraph and removes it from the second subgraph in post order. + """ + if node.op in CALLABLE_NODE_OPS and node not in nodes_in_first_subgraph: + nodes_in_first_subgraph.add(node) + for input_node in node._input_nodes: + add_nodes(input_node) + subgraphs[0].nodes.append(node) + subgraphs[1].nodes.remove(node) + + def match_subgraph_and_step(node: torch.fx.Node) -> bool: + added_nodes = False + for op_list in NON_BREAKABLE_OP_LISTS: + for i, op in enumerate(op_list): + if i != len(op_list) - 1 and op in str(node.target): + # Search following ops forward using BFS. We skip search previous ops because + # even if it's just a subset of fusion graph, we still want it to be fused. + + users = node.users.keys() + matching_nodes: set[torch.fx.Node] = set() + for following_op_idx in range(i + 1, len(op_list)): + matching_nodes = set() + for user in users: + if op_list[following_op_idx] in str(user.target): + matching_nodes.add(user) + if not matching_nodes: + break + users = set() + for matching_node in matching_nodes: + for next_user in matching_node.users: + users.add(next_user) + + for matching_node in matching_nodes: + added_nodes = True + add_nodes(matching_node) + + if added_nodes: + # Early terminate the search if we have found a match because preceeding matches can cover following matches + break + + return True if added_nodes else False + + found_match = False + for leaf in leaf_nodes: + if match_subgraph_and_step(leaf): + found_match = True + + return found_match + + def get_leaf_node( + self, nodes_in_first_subgraph: set[torch.fx.Node] + ) -> set[torch.fx.Node]: + leaf_node = set() + + for node in nodes_in_first_subgraph: + for user in node.users: + if user not in nodes_in_first_subgraph: + leaf_node.add(node) + break + return leaf_node + + def size_of_subgraphs(self, subgraphs: List[Subgraph]) -> List[int]: + """ + This function calculates the size of the subgraph. 
+ """ + state_dict = self.module.state_dict(keep_vars=True) + sizes = [] + weight_visited_nodes = set() + for subgraph in subgraphs: + nodes_in_subgraph = set(subgraph.nodes) + stack = subgraph.nodes.copy() + size = 0 + while stack: + node = stack.pop() + if node in weight_visited_nodes: + continue + weight_visited_nodes.add(node) + if node.op == "get_attr": + weight = state_dict[node.target] + size += weight.numel() * weight.element_size() + continue + if node not in nodes_in_subgraph: + # Trace to other subgraphs + continue + for input_node in node._input_nodes: + if input_node not in weight_visited_nodes: + stack.append(input_node) + sizes.append(size) + return sizes + + def validate_and_correct_subgraphs( + self, subgraphs: List[Subgraph] + ) -> List[Subgraph]: + """ + This function validates the subgraphs by checking if the subgraphs are valid, and corrects the subgraphs if they are not valid. + """ + visited_nodes = ( + {} + ) # a map from a node to the index of the subgraph it's user should belong to + for i, subgraph in enumerate(subgraphs): + if i == 0: + for node in subgraph.nodes: + visited_nodes[node] = i + visited_nodes[subgraph.nodes[-1]] = i + 1 + continue + + elif not subgraph.is_acc: + for node in subgraph.nodes: + visited_nodes[subgraph.nodes[-1]] = i + 1 + continue + + else: + to_remove_nodes = [] + for j, node in enumerate(subgraph.nodes): + if j == len(subgraph.nodes) - 1: + visited_nodes[node] = i + 1 + continue + subgraph_idx = 0 + for dep in self.deps[node]: + if dep in visited_nodes: + subgraph_idx = max(subgraph_idx, visited_nodes[dep]) + + if subgraph_idx != i: + subgraphs[subgraph_idx].nodes.append(node) + to_remove_nodes.append(node) + visited_nodes[node] = subgraph_idx + for node in to_remove_nodes: + subgraph.nodes.remove(node) + + return subgraphs def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: """Generates starter nodes for partitioning + segmentation""" @@ -255,6 +544,7 @@ def partition( torch_executed_ops: Collection[Target] = set(), require_full_compilation: bool = REQUIRE_FULL_COMPILATION, skip_fusion: bool = False, + cpu_memory_budget: int = -1, ) -> Tuple[torch.fx.GraphModule, OpSupportTester]: """Partition an FX GraphModule with aten ops into TRT engines Partitioning is based on converter operator support @@ -281,6 +571,7 @@ def partition( min_block_size=min_block_size, require_full_compilation=require_full_compilation, skip_fusion=skip_fusion, + cpu_memory_budget=cpu_memory_budget, ) partitioned_graph = partitioner.partition_graph() diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 97328acd6d..ba80ec4221 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -1,7 +1,9 @@ from __future__ import annotations +import ctypes import gc import logging +import platform import warnings from dataclasses import fields, replace from enum import Enum @@ -17,6 +19,7 @@ ) import numpy as np +import psutil import sympy import tensorrt as trt import torch @@ -853,3 +856,36 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype] f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node" ) return output_dtypes + + +def is_tegra_platform() -> bool: + if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]: + return True + return False + + +def is_thor() -> bool: + if torch.cuda.get_device_capability() in [(11, 0)]: + return True + return False + + +def get_cpu_memory_usage() -> Any: + return 
psutil.Process().memory_info().rss / 1024 / 1024 + + +def release_memory() -> None: + gc.collect() + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + torch.cuda.synchronize() + + if platform.system() == "Linux": + try: + libc = ctypes.CDLL("libc.so.6") + if libc.malloc_trim(0) != 1: + logger.warning("Failed to release CPU memory.") + except Exception: + logger.warning("Failed to release CPU memory.") diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index d4133ff4b4..a1600e46eb 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -55,6 +55,52 @@ def test_resnet18(ir): torch._dynamo.reset() +def compile_one(idx: int, ir: str): + model = models.resnet18(pretrained=True).eval().to("cuda") + input = torch.randn((idx + 1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + "pass_through_build_failures": True, + "optimization_level": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"In multiprocess compilation test, process {idx} failed: Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) +def test_resnet18_multiprocess(ir): + import torch.multiprocessing as mp + + mp.set_start_method("spawn", force=True) + procs = [] + for i in range(3): + p = mp.Process(target=compile_one, args=(i, ir)) + p.start() + procs.append(p) + for p in procs: + p.join() + torch._dynamo.reset() + + @pytest.mark.unit @unittest.skipIf( not importlib.util.find_spec("torchvision"), diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index a82384fda9..c86ee6f3a4 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -23,6 +23,7 @@ torch.ops.aten.scaled_dot_product_attention.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, ) @@ -43,6 +44,7 @@ def _remove_decompositions(): REPLACEABLE_ATEN_OPS = { torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, } from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( @@ -79,7 +81,10 @@ def _process_sdpa_node( ValueError: If the SDPA node has an unexpected number of arguments """ - if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default: + if node.target in [ + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + ]: if len(node.args) == 7: ( query,
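
A minimal usage sketch for the options this patch adds to compile(): the keyword names (offload_module_to_cpu, cpu_memory_budget, ir="dynamo") come from the updated signature, while the model, the input shape, and the 16 GiB figure are illustrative assumptions only. The budget is interpreted in bytes; the partitioner subtracts the host's current memory usage from it and divides by the 4x multiplier to get the per-subgraph size budget.

    import torch
    import torch_tensorrt

    # Any exportable module works here; a small MLP keeps the sketch self-contained (assumed model).
    model = (
        torch.nn.Sequential(
            torch.nn.Linear(1024, 4096),
            torch.nn.ReLU(),
            torch.nn.Linear(4096, 1024),
        )
        .eval()
        .cuda()
    )
    inputs = [torch.randn(8, 1024).cuda()]

    trt_module = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        offload_module_to_cpu=True,    # move weights to the CPU so TensorRT gets the GPU memory
        cpu_memory_budget=16 * 2**30,  # cap compilation-time CPU usage at ~16 GiB (bytes; -1 uses all available memory)
    )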
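
For reference, the arithmetic that drives the subgraph splitting can be reasoned about with a small sketch that mirrors the partitioner's calculate_size_budget and its 4x engine_compilation_memory_usage_multiplier; the function name and the concrete numbers below are assumptions for illustration, not part of the patch's public API.

    import psutil

    def per_subgraph_budget(cpu_memory_budget: int = -1, multiplier: int = 4) -> int:
        # With no explicit budget (-1), use the free RSS on the host;
        # otherwise use whatever part of the budget is not already consumed.
        if cpu_memory_budget == -1:
            headroom = psutil.virtual_memory().available
        else:
            headroom = cpu_memory_budget - psutil.virtual_memory().used
        return headroom // multiplier

    # Example: a 64 GiB budget on a host already using 24 GiB leaves 40 GiB of headroom,
    # so accelerated subgraphs holding more than ~10 GiB of weights are split before conversion,
    # and compilation aborts outright if the model exceeds 40x that 10 GiB budget.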