[TaskFlow] Fix pir for taskflow #9822

Merged

Changes from all commits (20 commits)
8fc77d1
fix pir for taskflow
DrownFish19 Jan 24, 2025
5696e9b
update suffix for PIR(ON/OFF)
DrownFish19 Feb 6, 2025
236a2ec
Merge remote-tracking branch 'paddlenlp/develop' into dev_20250124_fi…
DrownFish19 Feb 6, 2025
46c9a6a
fix and revert skip
DrownFish19 Feb 6, 2025
20fa6a3
Merge branch 'PaddlePaddle:develop' into dev_20250124_fix_taskflow_infer
DrownFish19 Feb 8, 2025
ebdca6f
update dependency_parsing
DrownFish19 Feb 10, 2025
7be9e06
Merge branch 'PaddlePaddle:develop' into dev_20250124_fix_taskflow_infer
DrownFish19 Feb 10, 2025
7e0f0af
Merge branch 'PaddlePaddle:develop' into dev_20250124_fix_taskflow_infer
DrownFish19 Feb 11, 2025
fc00ddd
Merge branch 'dev_20250124_fix_taskflow_infer' of github.com:DrownFis…
DrownFish19 Feb 12, 2025
69281aa
Merge branch 'PaddlePaddle:develop' into dev_20250124_fix_taskflow_infer
DrownFish19 Feb 12, 2025
f841ab4
Merge branch 'develop' into dev_20250124_fix_taskflow_infer
ZHUI Feb 13, 2025
7c17c2a
fix pir bug
Fantasy-02 Feb 14, 2025
f0a16b3
fix pir bug
Fantasy-02 Feb 14, 2025
5ccdc73
fix lint
Fantasy-02 Feb 14, 2025
a6f2a23
Merge branch 'PaddlePaddle:develop' into dev_20250124_fix_taskflow_infer
DrownFish19 Feb 14, 2025
57969f4
fix ci
Fantasy-02 Feb 17, 2025
0b7c924
Merge remote-tracking branch 'paddlenlp/develop' into dev_20250124_fi…
DrownFish19 Feb 18, 2025
c32bee7
Merge branch 'dev_20250124_fix_taskflow_infer' of github.com:DrownFis…
DrownFish19 Feb 18, 2025
933ff21
disable test case
DrownFish19 Feb 18, 2025
a687b6f
Merge remote-tracking branch 'paddlenlp/develop' into dev_20250124_fi…
DrownFish19 Feb 19, 2025
34 changes: 20 additions & 14 deletions llm/predict/predictor.py
@@ -24,7 +24,7 @@
import numpy as np
import paddle
import paddle.incubate.multiprocessing as mp
from paddle.base.framework import in_cinn_mode, in_pir_executor_mode, use_pir_api
from paddle.base.framework import in_cinn_mode, in_pir_executor_mode
from paddle.distributed import fleet

try:
@@ -51,7 +51,12 @@
PretrainedTokenizer,
)
from paddlenlp.trl import llm_utils
from paddlenlp.utils.env import MAX_BSZ, MAX_DRAFT_TOKENS
from paddlenlp.utils.env import (
MAX_BSZ,
MAX_DRAFT_TOKENS,
PADDLE_INFERENCE_MODEL_SUFFIX,
PADDLE_INFERENCE_WEIGHTS_SUFFIX,
)
from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
from paddlenlp.utils.log import logger

@@ -670,10 +675,11 @@ def _create_predictor(self, predictor_args: PredictorArgument):
infer_model_path = llm_utils.get_infer_model_path(
predictor_args.model_name_or_path, predictor_args.model_prefix
)
if use_pir_api():
config = paddle.inference.Config(infer_model_path + ".json", infer_model_path + ".pdiparams")
else:
config = paddle.inference.Config(infer_model_path + ".pdmodel", infer_model_path + ".pdiparams")

config = paddle.inference.Config(
infer_model_path + PADDLE_INFERENCE_MODEL_SUFFIX,
infer_model_path + PADDLE_INFERENCE_WEIGHTS_SUFFIX,
)

config.switch_ir_optim(True)
# remove `gpu_cpu_map_matmul_v2_to_matmul_pass` to avoid mapping matmul_v2 -> matmul op
@@ -1103,7 +1109,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
self.full_hidden_states = self._infer(self.model_inputs)
else:
self._infer(self.model_inputs)
logger.info(f"running spend {time.time() - s_time}")
logger.info(f"running spend {time.time() - s_time}")

if self.proposer is not None:
self.proposer.postprocess(base_model_inputs=self.model_inputs)
@@ -1190,10 +1196,10 @@ def _create_predictor(self, predictor_args: PredictorArgument):
predictor_args.model_name_or_path, predictor_args.model_prefix
)

if use_pir_api():
config = paddle.inference.Config(infer_model_path + ".json", infer_model_path + ".pdiparams")
else:
config = paddle.inference.Config(infer_model_path + ".pdmodel", infer_model_path + ".pdiparams")
config = paddle.inference.Config(
infer_model_path + PADDLE_INFERENCE_MODEL_SUFFIX,
infer_model_path + PADDLE_INFERENCE_WEIGHTS_SUFFIX,
)

config.switch_ir_optim(False)
if predictor_args.device in paddle.device.get_all_custom_device_type():
@@ -1230,7 +1236,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
self.proposer.insert_query(
base_model_inputs=self.model_inputs, real_bs=len(input_texts), seq_lens=self.seq_lens
)
logger.info(f"preprocess spend {time.time() - s_time}")
logger.info(f"preprocess spend {time.time() - s_time}")

result_queue = mp.Queue()
tensor_queue = mp.Queue()
@@ -1269,7 +1275,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
self.full_hidden_states = self.predictor.run(list(self.model_inputs.values()))[0]
else:
self.predictor.run(list(self.model_inputs.values()))
logger.info(f"running spend {time.time() - s_time}")
logger.info(f"running spend {time.time() - s_time}")

if self.proposer is not None:
self.proposer.postprocess(base_model_inputs=self.model_inputs)
@@ -1303,7 +1309,7 @@ def create_predictor(
config: PretrainedConfig,
model_args: ModelArgument,
tokenizer: PretrainedTokenizer = None,
**kwargs
**kwargs,
):
"""
Create a predictor
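Note on the predictor.py changes above: the per-call `use_pir_api()` branch is replaced by two suffix constants imported from `paddlenlp.utils.env`. As a minimal sketch of the idea (the authoritative definitions live in `paddlenlp/utils/env.py` and may differ in detail), the constants can be derived once from the PIR flag and reused wherever a `paddle.inference.Config` is built:

```python
# Sketch only: how PADDLE_INFERENCE_MODEL_SUFFIX / PADDLE_INFERENCE_WEIGHTS_SUFFIX
# could be derived. The real definitions are in paddlenlp/utils/env.py.
from paddle.base.framework import use_pir_api

# Under PIR the exported program is a .json file; the legacy format is .pdmodel.
PADDLE_INFERENCE_MODEL_SUFFIX = ".json" if use_pir_api() else ".pdmodel"
PADDLE_INFERENCE_WEIGHTS_SUFFIX = ".pdiparams"

# Callers then build the inference config without branching on the PIR flag:
# config = paddle.inference.Config(
#     infer_model_path + PADDLE_INFERENCE_MODEL_SUFFIX,
#     infer_model_path + PADDLE_INFERENCE_WEIGHTS_SUFFIX,
# )
```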
2 changes: 1 addition & 1 deletion llm/server/docs/deploy_usage_tutorial.md
@@ -99,7 +99,7 @@ cd /home/workspace/models_dir
# ├── rank_mapping.csv # Present for multi-GPU models; absent for single-GPU models (optional, only needed for multi-GPU deployment)
# └── rank_0 # Directory that stores the model structure and weight files
# ├── model.pdiparams
# └── model.pdmodel
# └── model.pdmodel or model.json # model.json for Paddle 3.0 models, model.pdmodel for Paddle 2.x models
```

### Create the container
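To make the doc change above concrete: a deployment directory's `rank_0` folder now holds either `model.json` (Paddle 3.0 / PIR export) or `model.pdmodel` (Paddle 2.x export) next to `model.pdiparams`. A small, hedged helper for detecting which format is present (helper name and error handling are illustrative, not part of this PR):

```python
import os


def detect_exported_model(rank_dir: str) -> str:
    """Return the path of the exported model file inside rank_dir.

    Paddle 3.0 (PIR) exports model.json; Paddle 2.x exports model.pdmodel.
    Illustrative helper only, not code from this PR.
    """
    for candidate in ("model.json", "model.pdmodel"):
        path = os.path.join(rank_dir, candidate)
        if os.path.exists(path):
            return path
    raise FileNotFoundError(f"No exported model file (model.json / model.pdmodel) found in {rank_dir}")
```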
454 changes: 233 additions & 221 deletions llm/server/server/server/engine/infer.py

Large diffs are not rendered by default.

17 changes: 9 additions & 8 deletions paddlenlp/experimental/autonlp/text_classification.py
@@ -47,6 +47,7 @@
from ...utils.log import logger
from .auto_trainer_base import AutoTrainerBase
from .utils import UTCLoss
from .utils.env import PADDLE_INFERENCE_MODEL_SUFFIX, PADDLE_INFERENCE_WEIGHTS_SUFFIX

Check warning (Codecov/patch): added line #L50 in paddlenlp/experimental/autonlp/text_classification.py was not covered by tests.


class AutoTrainerForTextClassification(AutoTrainerBase):
@@ -560,16 +561,16 @@
if os.path.exists(default_export_path):
if "utc" in model_config["model_name_or_path"]:
files = [
"model.pdiparams",
"model.pdmodel",
f"model{PADDLE_INFERENCE_WEIGHTS_SUFFIX}",
f"model{PADDLE_INFERENCE_MODEL_SUFFIX}",
"tokenizer_config.json",
"vocab.txt",
"taskflow_config.json",
]
else:
files = [
"model.pdiparams",
"model.pdmodel",
f"model{PADDLE_INFERENCE_WEIGHTS_SUFFIX}",
f"model{PADDLE_INFERENCE_MODEL_SUFFIX}",
"tokenizer_config.json",
"vocab.txt",
"taskflow_config.json",
@@ -735,8 +736,8 @@
executor=exe,
batch_generator=_batch_generator_func,
model_dir=export_path,
model_filename="model.pdmodel",
params_filename="model.pdiparams",
model_filename=f"model{PADDLE_INFERENCE_MODEL_SUFFIX}",
params_filename=f"model{PADDLE_INFERENCE_WEIGHTS_SUFFIX}",
batch_size=batch_size,
batch_nums=batch_nums,
scope=None,
@@ -757,8 +758,8 @@
post_training_quantization.quantize()
post_training_quantization.save_quantized_model(
save_model_path=compress_path,
model_filename="model.pdmodel",
params_filename="model.pdiparams",
model_filename=f"model{PADDLE_INFERENCE_MODEL_SUFFIX}",
params_filename=f"model{PADDLE_INFERENCE_WEIGHTS_SUFFIX}",
)

paddle.disable_static()
57 changes: 15 additions & 42 deletions paddlenlp/generation/utils.py
@@ -1437,55 +1437,28 @@
outputs, input_ids, cur_len_gpu, origin_len_gpu, scores, unfinished_flag, model_kwargs, pad_token_id
)

if hasattr(paddle.framework, "_no_check_dy2st_diff"):
# TODO(daisiming): _no_check_dy2st_diff is used to turn off the checking of behavior
# inconsistency between dynamic graph and static graph. _no_check_dy2st_diff should be
# removed after static graphs support inplace and stride.
with paddle.framework._no_check_dy2st_diff():
paddle.increment(cur_len)
paddle.increment(cur_len_gpu)
else:
paddle.increment(cur_len)
paddle.increment(cur_len_gpu)
cur_len += 1
cur_len_gpu += 1

Check warning (Codecov/patch): added lines #L1440-L1441 in paddlenlp/generation/utils.py were not covered by tests.

attn_mask = model_kwargs["attention_mask"]
# make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
model_kwargs["attention_mask"] = paddle.reshape(attn_mask, attn_mask.shape)
model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
max_new_tokens = paddle.full([1], max_new_tokens + cur_len - 1, dtype="int64")

if hasattr(paddle.framework, "_no_check_dy2st_diff"):
# TODO(daisiming): _no_check_dy2st_diff is used to turn off the checking of behavior
# inconsistency between dynamic graph and static graph. _no_check_dy2st_diff should be
# removed after static graphs support inplace and stride.
with paddle.framework._no_check_dy2st_diff():
while cur_len < max_new_tokens and paddle.any(unfinished_flag):
input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
_forward_(**model_kwargs),
input_ids,
cur_len_gpu,
origin_len_gpu,
scores,
unfinished_flag,
model_kwargs,
pad_token_id,
)
paddle.increment(cur_len)
paddle.increment(cur_len_gpu)
else:
while cur_len < max_new_tokens and paddle.any(unfinished_flag):
input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
_forward_(**model_kwargs),
input_ids,
cur_len_gpu,
origin_len_gpu,
scores,
unfinished_flag,
model_kwargs,
pad_token_id,
)
paddle.increment(cur_len)
paddle.increment(cur_len_gpu)
while cur_len < max_new_tokens and paddle.any(unfinished_flag):
input_ids, scores, unfinished_flag, model_kwargs = _post_process_(

Check warning (Codecov/patch): added lines #L1449-L1450 in paddlenlp/generation/utils.py were not covered by tests.
_forward_(**model_kwargs),
input_ids,
cur_len_gpu,
origin_len_gpu,
scores,
unfinished_flag,
model_kwargs,
pad_token_id,
)
cur_len += 1
cur_len_gpu += 1

Check warning (Codecov/patch): added lines #L1460-L1461 in paddlenlp/generation/utils.py were not covered by tests.

return input_ids[:, origin_len:], scores

16 changes: 11 additions & 5 deletions paddlenlp/server/predictor.py
@@ -21,6 +21,7 @@

import paddle

from ..utils.env import PADDLE_INFERENCE_MODEL_SUFFIX, PADDLE_INFERENCE_WEIGHTS_SUFFIX
from ..utils.log import logger


@@ -40,13 +41,15 @@

def _get_default_static_model_path(self):
# The model path had the static_model_path
static_model_path = os.path.join(self._model_path, self._default_static_model_path, "inference.pdmodel")
static_model_path = os.path.join(

Check warning (Codecov/patch): added line #L44 in paddlenlp/server/predictor.py was not covered by tests.
self._model_path, self._default_static_model_path, f"inference{PADDLE_INFERENCE_MODEL_SUFFIX}"
)
if os.path.exists(static_model_path):
return os.path.join(self._model_path, self._default_static_model_path, "inference")
for file_name in os.listdir(self._model_path):
# FIXME(wawltor) The path maybe not correct
if file_name.count(".pdmodel"):
return os.path.join(self._model_path, file_name[:-8])
if file_name.count(PADDLE_INFERENCE_MODEL_SUFFIX):
return os.path.join(self._model_path, file_name[: -len(PADDLE_INFERENCE_MODEL_SUFFIX)])

Check warning (Codecov/patch): added lines #L51-L52 in paddlenlp/server/predictor.py were not covered by tests.
return None

def _is_int8_model(self, model_path):
@@ -110,7 +113,10 @@
"""
Construct the input data and predictor in the PaddlePaddele static mode.
"""
self._config = paddle.inference.Config(static_model_path + ".pdmodel", static_model_path + ".pdiparams")
self._config = paddle.inference.Config(

Check warning (Codecov/patch): added line #L116 in paddlenlp/server/predictor.py was not covered by tests.
static_model_path + PADDLE_INFERENCE_MODEL_SUFFIX,
static_model_path + PADDLE_INFERENCE_WEIGHTS_SUFFIX,
)
self._config.disable_glog_info()
if paddle.get_device() == "cpu":
self._config.disable_gpu()
@@ -146,7 +152,7 @@
os.mkdir(onnx_dir)
float_onnx_file = os.path.join(onnx_dir, "model.onnx")
if not os.path.exists(float_onnx_file):
model_path = static_model_path + ".pdmodel"
model_path = static_model_path + PADDLE_INFERENCE_MODEL_SUFFIX

Check warning (Codecov/patch): added line #L155 in paddlenlp/server/predictor.py was not covered by tests.
params_file = static_model_path + ".pdiparams"
onnx_model = paddle2onnx.command.c_paddle_to_onnx(
model_file=model_path, params_file=params_file, opset_version=13, enable_onnx_checker=True
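The server predictor above now resolves the static-model prefix by scanning for the inference model suffix instead of hard-coding ".pdmodel". A standalone sketch of that lookup (function name is hypothetical; the in-tree logic lives in `_get_default_static_model_path`):

```python
import os
from typing import Optional


def find_inference_prefix(model_dir: str, model_suffix: str) -> Optional[str]:
    """Return the path prefix (suffix stripped) of the first exported model in model_dir.

    Hypothetical helper mirroring the suffix-stripping shown in the diff,
    e.g. "inference.json" with suffix ".json" -> ".../inference".
    """
    for file_name in os.listdir(model_dir):
        if file_name.endswith(model_suffix):
            return os.path.join(model_dir, file_name[: -len(model_suffix)])
    return None


# Usage sketch: combine the prefix with either suffix constant, e.g.
# prefix + PADDLE_INFERENCE_MODEL_SUFFIX and prefix + PADDLE_INFERENCE_WEIGHTS_SUFFIX.
```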
21 changes: 11 additions & 10 deletions paddlenlp/taskflow/multimodal_feature_extraction.py
@@ -19,6 +19,7 @@
from PIL import Image

from ..transformers import AutoModel, AutoProcessor
from ..utils.env import PADDLE_INFERENCE_MODEL_SUFFIX, PADDLE_INFERENCE_WEIGHTS_SUFFIX
from ..utils.log import logger
from .task import Task
from .utils import dygraph_mode_guard, static_mode_guard
@@ -411,9 +412,9 @@
self.inference_image_model_path = os.path.join(_base_path, "static", "get_image_features")
self.inference_text_model_path = os.path.join(_base_path, "static", "get_text_features")
if (
not os.path.exists(self.inference_image_model_path + ".pdiparams")
not os.path.exists(self.inference_image_model_path + PADDLE_INFERENCE_WEIGHTS_SUFFIX)
or self._param_updated
or not os.path.exists(self.inference_text_model_path + ".pdiparams")
or not os.path.exists(self.inference_text_model_path + PADDLE_INFERENCE_WEIGHTS_SUFFIX)
):
with dygraph_mode_guard():
self._construct_model(self.model)
@@ -422,8 +423,8 @@
if self._predictor_type == "paddle-inference":
# Get text inference model
self.inference_model_path = self.inference_text_model_path
self._static_model_file = self.inference_model_path + ".pdmodel"
self._static_params_file = self.inference_model_path + ".pdiparams"
self._static_model_file = self.inference_model_path + PADDLE_INFERENCE_MODEL_SUFFIX
self._static_params_file = self.inference_model_path + PADDLE_INFERENCE_WEIGHTS_SUFFIX

Check warning (Codecov/patch): added lines #L426-L427 in paddlenlp/taskflow/multimodal_feature_extraction.py were not covered by tests.
self._config = paddle.inference.Config(self._static_model_file, self._static_params_file)
self._prepare_static_mode()

@@ -435,8 +436,8 @@

# Get image inference model
self.inference_model_path = self.inference_image_model_path
self._static_model_file = self.inference_model_path + ".pdmodel"
self._static_params_file = self.inference_model_path + ".pdiparams"
self._static_model_file = self.inference_model_path + PADDLE_INFERENCE_MODEL_SUFFIX
self._static_params_file = self.inference_model_path + PADDLE_INFERENCE_WEIGHTS_SUFFIX

Check warning (Codecov/patch): added lines #L439-L440 in paddlenlp/taskflow/multimodal_feature_extraction.py were not covered by tests.
self._config = paddle.inference.Config(self._static_model_file, self._static_params_file)
self._prepare_static_mode()

@@ -449,15 +450,15 @@
# Get text onnx model
self.export_type = "text"
self.inference_model_path = self.inference_text_model_path
self._static_model_file = self.inference_model_path + ".pdmodel"
self._static_params_file = self.inference_model_path + ".pdiparams"
self._static_model_file = self.inference_model_path + PADDLE_INFERENCE_MODEL_SUFFIX
self._static_params_file = self.inference_model_path + PADDLE_INFERENCE_WEIGHTS_SUFFIX

Check warning (Codecov/patch): added lines #L453-L454 in paddlenlp/taskflow/multimodal_feature_extraction.py were not covered by tests.
self._prepare_onnx_mode()
self.predictor_map["text"] = self.predictor

# Get image onnx model
self.export_type = "image"
self.inference_model_path = self.inference_image_model_path
self._static_model_file = self.inference_model_path + ".pdmodel"
self._static_params_file = self.inference_model_path + ".pdiparams"
self._static_model_file = self.inference_model_path + PADDLE_INFERENCE_MODEL_SUFFIX
self._static_params_file = self.inference_model_path + PADDLE_INFERENCE_WEIGHTS_SUFFIX

Check warning (Codecov/patch): added lines #L461-L462 in paddlenlp/taskflow/multimodal_feature_extraction.py were not covered by tests.
self._prepare_onnx_mode()
self.predictor_map["image"] = self.predictor