@@ -488,9 +488,21 @@ Stub::Execute(ExecuteArgs* execute_args, ResponseBatch* response_batch)
488
488
void
489
489
Stub::Initialize (InitializeArgs* initialize_args)
490
490
{
491
- py::module sys = py::module ::import (" sys" );
491
+ py::module sys = py::module_ ::import (" sys" );
492
492
493
- std::string model_name = model_path_.substr (model_path_.find_last_of (" /" ) + 1 );
493
+ std::string model_name =
494
+ model_path_.substr (model_path_.find_last_of (" /" ) + 1 );
495
+
496
+ // Model name without the .py extension
497
+ auto dotpy_pos = model_name.find_last_of (" .py" );
498
+ if (dotpy_pos == std::string::npos || dotpy_pos != model_name.size () - 1 ) {
499
+ throw PythonBackendException (
500
+ " Model name must end with '.py'. Model name is \" " + model_name + " \" ." );
501
+ }
502
+
503
+ // The position of last character of the string that is searched for is
504
+ // returned by 'find_last_of'. Need to manually adjust the position.
505
+ std::string model_name_trimmed = model_name.substr (0 , dotpy_pos - 2 );
494
506
std::string model_path_parent =
495
507
model_path_.substr (0 , model_path_.find_last_of (" /" ));
496
508
std::string model_path_parent_parent =
@@ -501,9 +513,9 @@ Stub::Initialize(InitializeArgs* initialize_args)
501
513
sys.attr (" path" ).attr (" append" )(python_backend_folder);
502
514
503
515
py::module python_backend_utils =
504
- py::module ::import (" triton_python_backend_utils" );
516
+ py::module_ ::import (" triton_python_backend_utils" );
505
517
py::module c_python_backend_utils =
506
- py::module ::import (" c_python_backend_utils" );
518
+ py::module_ ::import (" c_python_backend_utils" );
507
519
py::setattr (
508
520
python_backend_utils, " Tensor" , c_python_backend_utils.attr (" Tensor" ));
509
521
py::setattr (
@@ -520,7 +532,8 @@ Stub::Initialize(InitializeArgs* initialize_args)
520
532
c_python_backend_utils.attr (" TritonModelException" ));
521
533
522
534
py::object TritonPythonModel =
523
- py::module::import ((model_version_ + std::string (" .model" )).c_str ())
535
+ py::module_::import (
536
+ (std::string (model_version_) + " ." + model_name_trimmed).c_str ())
524
537
.attr (" TritonPythonModel" );
525
538
deserialize_bytes_ = python_backend_utils.attr (" deserialize_bytes_tensor" );
526
539
serialize_bytes_ = python_backend_utils.attr (" serialize_byte_tensor" );
0 commit comments