 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include <stdint.h>
+
 #include <exception>
+
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
 #include "triton/backend/backend_input_collector.h"
@@ -103,6 +105,7 @@ class ModelState : public BackendModel {
 
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
   const std::vector<std::string>& ModelOutputs() { return output_names_; }
+  const std::string& MethodToCall() { return method_to_call_; }
 
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
@@ -145,6 +148,10 @@ class ModelState : public BackendModel {
   // List of all the outputs specified in the output section of model
   // configuration.
   std::vector<std::string> output_names_;
+
+  // Method to call on the PyTorch Module.
+  // Defaults to "forward".
+  std::string method_to_call_;
 };
 
 TRITONSERVER_Error*
@@ -180,7 +187,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true}),
-      enable_nvfuser_pair_({false, false})
+      enable_nvfuser_pair_({false, false}), method_to_call_("forward")
 {
   output_names_.clear();
 
@@ -454,6 +461,30 @@ ModelState::ParseParameters()
            " for model instance '" + Name() + "'")
                .c_str());
     }
+
+    // If 'METHOD_TO_CALL' is not present in 'parameters' then no update is
+    // made to 'method_to_call_' and the default "forward" is used.
+    std::string method_to_call = "forward";
+    err = GetParameterValue(params, "METHOD_TO_CALL", &method_to_call);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_INFO,
+            (std::string("method_to_call is not specified") +
+             " for model instance '" + Name() + "'")
+                .c_str());
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      method_to_call_ = method_to_call;
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("method_to_call is ") + method_to_call_ +
+           " for model instance '" + Name() + "'")
+              .c_str());
+    }
   }
 
   return nullptr;
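
For context, METHOD_TO_CALL is read from the model's config.pbtxt parameters block, the same mechanism used for the other backend options parsed above; if it is absent, the "forward" default set in the constructor is kept. A minimal sketch of how a model might opt in (the method name "predict" is purely illustrative):

    parameters: {
      key: "METHOD_TO_CALL"
      value: {
        string_value: "predict"
      }
    }
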
@@ -764,7 +795,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   // configuration specifies only those.
   std::vector<std::string> allowed_inputs;
 
-  const torch::jit::Method& method = torch_model_->get_method("forward");
+  const torch::jit::Method& method =
+      torch_model_->get_method(model_state_->MethodToCall());
   const auto& schema = method.function().getSchema();
   const std::vector<c10::Argument>& arguments = schema.arguments();
 
@@ -1312,28 +1344,32 @@ ModelInstanceState::Execute(
       torch::jit::overrideCanFuseOnCPU(false);
       torch::jit::overrideCanFuseOnGPU(false);
       torch::jit::setTensorExprFuserEnabled(false);
-      torch::jit::fuser::cuda::setEnabled(true);
+      torch::jit::fuser::cuda::setEnabled(true);
     } else {
       torch::jit::overrideCanFuseOnCPU(true);
       torch::jit::overrideCanFuseOnGPU(true);
       torch::jit::setTensorExprFuserEnabled(true);
-      torch::jit::fuser::cuda::setEnabled(false);
+      torch::jit::fuser::cuda::setEnabled(false);
     }
   }
 
   torch::NoGradGuard no_grad;
 
   // If input is a dictionary, prepare dictionary from 'input_tensors'.
+  std::string method_to_call = model_state_->MethodToCall();
   if (is_dict_input_) {
-    torch::Dict<std::string, torch::Tensor> input_dict;
+    c10::Dict<std::string, at::Tensor> dict;
     for (auto& input_index : input_index_map_) {
       torch::jit::IValue ival = (*input_tensors)[input_index.second];
-      input_dict.insert(input_index.first, ival.toTensor());
+      dict.insert(input_index.first, ival.toTensor());
     }
-    std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
-    model_outputs_ = torch_model_->forward(input_dict_ivalue);
+    model_outputs_ = torch_model_->run_method(method_to_call, dict);
   } else {
-    model_outputs_ = torch_model_->forward(*input_tensors);
+    auto inp = c10::impl::GenericList(c10::TensorType::get());
+    for (auto& input_tensor : *input_tensors) {
+      inp.emplace_back(input_tensor.toTensor());
+    }
+    model_outputs_ = torch_model_->run_method(method_to_call, inp);
   }
 
   if (model_outputs_.isTuple()) {
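
To make the new call pattern concrete, the standalone libtorch sketch below mirrors what the non-dict branch of Execute() now does: load a TorchScript module and dispatch to an exported method by name via run_method() instead of the hard-coded forward(). This is only an illustration, not backend code; "model.pt" and "predict" are assumed names, and the named method is assumed to accept a single List[Tensor] argument, matching how the inputs are packed into a GenericList above.

    // Standalone sketch (assumptions: "model.pt" exists and exports a method
    // "predict" taking a List[Tensor]); not part of the backend itself.
    #include <torch/script.h>

    #include <iostream>

    int main() {
      torch::jit::script::Module module = torch::jit::load("model.pt");
      torch::NoGradGuard no_grad;

      // Pack the inputs into a generic tensor list, as the non-dict branch
      // does, then invoke the exported method by name.
      auto inputs = c10::impl::GenericList(c10::TensorType::get());
      inputs.emplace_back(torch::ones({1, 3}));

      torch::jit::IValue outputs = module.run_method("predict", inputs);
      std::cout << outputs << std::endl;
      return 0;
    }
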
@@ -1761,9 +1797,9 @@ ModelInstanceState::SetInputTensors(
 
         batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
       }
-    }
-    else {
-      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+    } else {
+      batchn_shape =
+          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
       if (supports_batching_) {
         batchn_shape[0] = total_batch_size;
       }
@@ -1772,8 +1808,8 @@ ModelInstanceState::SetInputTensors(
   // The input must be in contiguous CPU/GPU memory.
   std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
   if (device_.is_cpu()) {
-    alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                        {TRITONSERVER_MEMORY_CPU, 0}};
+    alloc_perference = {
+        {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
   } else {
     alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
   }
@@ -1887,9 +1923,11 @@ ModelInstanceState::ReadOutputTensors(
 
       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
-      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                                     : TRITONSERVER_MEMORY_GPU;
-      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+      const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                   ? TRITONSERVER_MEMORY_CPU
+                                   : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id =
+          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
 
       // Batch output doesn't support string data type yet, as it is not trivial
       // to parse string output
@@ -1906,16 +1944,16 @@ ModelInstanceState::ReadOutputTensors(
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INVALID_ARG,
               (std::string("output '") + name +
-              "' is a scalar which is not supported.")
+               "' is a scalar which is not supported.")
                   .c_str());
         }
 
         responder.ProcessTensor(
-            name, output_dtype, batchn_shape, output_buffer,
-            memory_type, memory_id);
+            name, output_dtype, batchn_shape, output_buffer, memory_type,
+            memory_id);
       } else {
         responder.ProcessBatchOutput(
-             name, *batch_output, output_buffer, memory_type, memory_id);
+            name, *batch_output, output_buffer, memory_type, memory_id);
       }
     } else if (output_tensors[op_index].isList()) {
       // Custom handling for string/bytes tensor...