Skip to content

Commit ec515c9

Browse files
authoredAug 18, 2021
Fix default_model_filename (triton-inference-server#76)
1 parent 3d800b3 commit ec515c9

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed
 

‎src/pb_stub.cc

+18-5
Original file line numberDiff line numberDiff line change
@@ -488,9 +488,21 @@ Stub::Execute(ExecuteArgs* execute_args, ResponseBatch* response_batch)
488488
void
489489
Stub::Initialize(InitializeArgs* initialize_args)
490490
{
491-
py::module sys = py::module::import("sys");
491+
py::module sys = py::module_::import("sys");
492492

493-
std::string model_name = model_path_.substr(model_path_.find_last_of("/") + 1);
493+
std::string model_name =
494+
model_path_.substr(model_path_.find_last_of("/") + 1);
495+
496+
// Model name without the .py extension
497+
auto dotpy_pos = model_name.find_last_of(".py");
498+
if (dotpy_pos == std::string::npos || dotpy_pos != model_name.size() - 1) {
499+
throw PythonBackendException(
500+
"Model name must end with '.py'. Model name is \"" + model_name + "\".");
501+
}
502+
503+
// The position of last character of the string that is searched for is
504+
// returned by 'find_last_of'. Need to manually adjust the position.
505+
std::string model_name_trimmed = model_name.substr(0, dotpy_pos - 2);
494506
std::string model_path_parent =
495507
model_path_.substr(0, model_path_.find_last_of("/"));
496508
std::string model_path_parent_parent =
@@ -501,9 +513,9 @@ Stub::Initialize(InitializeArgs* initialize_args)
501513
sys.attr("path").attr("append")(python_backend_folder);
502514

503515
py::module python_backend_utils =
504-
py::module::import("triton_python_backend_utils");
516+
py::module_::import("triton_python_backend_utils");
505517
py::module c_python_backend_utils =
506-
py::module::import("c_python_backend_utils");
518+
py::module_::import("c_python_backend_utils");
507519
py::setattr(
508520
python_backend_utils, "Tensor", c_python_backend_utils.attr("Tensor"));
509521
py::setattr(
@@ -520,7 +532,8 @@ Stub::Initialize(InitializeArgs* initialize_args)
520532
c_python_backend_utils.attr("TritonModelException"));
521533

522534
py::object TritonPythonModel =
523-
py::module::import((model_version_ + std::string(".model")).c_str())
535+
py::module_::import(
536+
(std::string(model_version_) + "." + model_name_trimmed).c_str())
524537
.attr("TritonPythonModel");
525538
deserialize_bytes_ = python_backend_utils.attr("deserialize_bytes_tensor");
526539
serialize_bytes_ = python_backend_utils.attr("serialize_byte_tensor");

‎src/python.cc

+12-2
Original file line numberDiff line numberDiff line change
@@ -1231,8 +1231,18 @@ ModelInstanceState::SetupStubProcess()
12311231
const char* model_path = model_state->RepositoryPath().c_str();
12321232

12331233
std::stringstream ss;
1234-
// Use <path>/version/model.py as the model location
1235-
ss << model_path << "/" << model_version << "/model.py";
1234+
std::string artifact_name;
1235+
RETURN_IF_ERROR(model_state->ModelConfig().MemberAsString(
1236+
"default_model_filename", &artifact_name));
1237+
ss << model_path << "/" << model_version << "/";
1238+
1239+
if (artifact_name.size() > 0) {
1240+
ss << artifact_name;
1241+
} else {
1242+
// Default artifact name.
1243+
ss << "model.py";
1244+
}
1245+
12361246
model_path_ = ss.str();
12371247
struct stat buffer;
12381248

0 commit comments

Comments
 (0)
Please sign in to comment.