Commit a244453

Authored Jul 22, 2024
[Feature] Support inference ppl datasets (#1315)
* commit inference ppl datasets * revised format * revise * revise * revise * revise * revise * revise
1 parent e938482 commit a244453

File tree

12 files changed: +662 −0 lines changed
 
+26
# Inference-PPL Datasets

- **Description**: Compute the loss only on the labeled positions; mainly intended for reasoning corpora.
- **Datasets**: cn-reasoning-val.jsonl (an example dataset; inference-ppl can be generalized to other corpora).

# PPL Computation

$$ \text{ppl} = - \frac{1}{n} \sum_{i=1}^{n} \sum_{c=1}^{\text{vocab\_size}} y_{i,c} \log p_{i,c} \tag{1} $$

Eq. (1) is the standard mean PPL formula over all $n$ tokens, where $y_{i,c}$ is the one-hot label and $p_{i,c}$ the predicted probability. For inference-ppl, the average is computed only over the pre-labeled positions.
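To make the labeled-position averaging concrete, here is a minimal NumPy sketch (illustrative only; `token_losses` and `labeled_positions` are hypothetical names, not part of OpenCompass):

```python
import numpy as np

# Per-token cross-entropy losses for one sequence (hypothetical values)
token_losses = np.array([2.1, 0.7, 1.3, 0.2, 3.0, 0.9])

# Indices of the pre-labeled (reasoning) positions we want to score
labeled_positions = [2, 3, 4]

# Standard mean PPL averages over every token ...
mean_ppl = token_losses.mean()

# ... whereas inference-ppl averages only over the labeled positions
inference_ppl = token_losses[labeled_positions].mean()

print(mean_ppl, inference_ppl)
```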
# Quick Start

```shell
cd opencompass
python run.py configs/eval_inference_ppl.py
```
# Some results

| Model       | Inference-PPL |
| ----------- | ------------- |
| Qwen1.5-7b  | 0.59          |
| Qwen1.5-14b | 0.54          |
| Llama2-7b   | 0.49          |
| Llama2-13b  | 0.43          |
```python
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import InferencePPLOnlyInferencer
from opencompass.openicl.icl_evaluator import AverageInferencePPLEvaluator

from opencompass.datasets import InferencePPLDataset

# Build InferencePPLDataset
inference_ppl_datasets = []

llm_cmp_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template='{text}',
    ),
    # No in-context example, using ZeroRetriever
    retriever=dict(type=ZeroRetriever),
    # Compute inference-ppl
    inferencer=dict(type=InferencePPLOnlyInferencer),
)

# Average the inference-ppl scores
llm_cmp_eval_cfg = dict(evaluator=dict(type=AverageInferencePPLEvaluator))

inference_ppl_datasets.append(
    dict(
        abbr='inference-ppl',
        type=InferencePPLDataset,
        path='./data/inference_ppl',
        name='cn-reasoning-val',
        samples=None,  # Set a small number of samples for quick testing
        reader_cfg=dict(
            input_columns=['text'],
            output_column=None,
        ),
        infer_cfg=llm_cmp_infer_cfg,
        eval_cfg=llm_cmp_eval_cfg,
    ))
```
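For reference, each line of the jsonl corpus is expected to carry the raw text together with the pre-labeled spans that the inferencer scores. The field layout below is inferred from the reader and retriever code (`text` is the declared input column; `label` holds `(start, end, 1)` character spans), and the sample values are made up:

```python
import json

# One hypothetical record of cn-reasoning-val.jsonl: loss is scored only on the
# character span(s) listed in `label`.
record = {
    'text': 'Question ... Reasoning: since A > B, it follows that ...',
    'label': [[13, 57, 1]],  # (start, end, 1) character span over the reasoning part
}
print(json.dumps(record, ensure_ascii=False))
```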

‎configs/eval_inference_ppl.py

+62
```python
from mmengine.config import read_base

with read_base():
    # Inference PPL datasets
    from .datasets.inference_ppl.inference_ppl import inference_ppl_datasets

    # Model configs
    from .models.qwen.hf_qwen1_5_7b import models as qwen1_5_7b
    from .models.qwen.hf_qwen1_5_14b import models as qwen1_5_14b
    from .models.hf_llama.hf_llama2_7b import models as llama2_7b
    from .models.hf_llama.hf_llama2_13b import models as llama2_13b


from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask


# -------------Inference Stage ----------------------------------------

datasets = [*inference_ppl_datasets]
workdir = 'outputs/inference_ppl'

models = [
    *qwen1_5_7b,
    *qwen1_5_14b,
    *llama2_7b,
    *llama2_13b,
]


# Set custom batch_size and num_gpus for faster loss calculation
# Smaller batch_size should give more precise results, at the cost of worse efficiency
model_cfg = dict(
    batch_size=8,
    run_cfg=dict(num_gpus=4, num_procs=1)
)

for mdl in models:
    mdl.update(model_cfg)


infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        task=dict(type=OpenICLInferTask),
        max_num_workers=256,  # Maximum concurrent evaluation task count
    ),
)


# -------------Evaluation Stage ----------------------------------------
eval = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        task=dict(type=OpenICLEvalTask),
        max_num_workers=256,
    )
)
```
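To smoke-test the pipeline before launching the full model matrix, the `models` list can be trimmed in place; a minimal sketch against the same config file (the resource numbers are only an example):

```python
# Evaluate only Qwen1.5-7b on a single GPU as a quick sanity check
models = [*qwen1_5_7b]
for mdl in models:
    mdl.update(dict(batch_size=8, run_cfg=dict(num_gpus=1, num_procs=1)))
```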

‎opencompass/datasets/__init__.py

+1
```diff
@@ -53,6 +53,7 @@
 from .humanevalx import *  # noqa: F401, F403
 from .hungarian_math import *  # noqa: F401, F403
 from .IFEval.ifeval import IFEvalDataset, IFEvaluator  # noqa: F401, F403
+from .inference_ppl import InferencePPLDataset  # noqa: F401, F403
 from .infinitebench import *  # noqa: F401, F403
 from .iwslt2017 import *  # noqa: F401, F403
 from .jigsawmultilingual import *  # noqa: F401, F403
```
‎opencompass/datasets/inference_ppl.py

+37
```python
import os.path as osp
from typing import Optional

from datasets import load_dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class InferencePPLDataset(BaseDataset):

    @staticmethod
    def load(path: str, name: Optional[str] = None, samples: Optional[int] = None):

        # Check if the file exists in the given path
        supported_extensions = ['jsonl']
        for ext in supported_extensions:
            # name refers to the data subset name
            filename = osp.join(path, f'{name}.{ext}')
            if osp.exists(filename):
                break
        else:
            raise FileNotFoundError(f'{filename} not found.')

        samples = 'test' if samples is None else f'test[:{samples}]'
        data_files = {'test': filename}

        dataset = load_dataset('json', data_files=data_files, split=samples)

        # Filter out empty samples
        dataset = dataset.filter(lambda example: len(example['text']) > 0)

        return dataset
```
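A quick way to sanity-check the loader outside a full run; a minimal sketch, assuming `./data/inference_ppl/cn-reasoning-val.jsonl` is in place as in the config above:

```python
from opencompass.datasets import InferencePPLDataset

# Load only the first four records for a quick look
ds = InferencePPLDataset.load(path='./data/inference_ppl',
                              name='cn-reasoning-val',
                              samples=4)
print(len(ds), ds.column_names)
```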

‎opencompass/models/base.py

+36
```python
    # New abstract hook on BaseModel (added alongside get_ppl)
    @abstractmethod
    def get_ppl_tokenwise(
            self,
            inputs: List[str],
            label: List[List[int]],
            mask_length: Optional[List[int]] = None) -> List[float]:
        """Get tokenwise perplexity scores given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            label (List[List[int]]): Labeled spans for each input; the loss is
                scored only on these positions.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer are
                not needed.

        Returns:
            List[float]: A list of perplexity scores.
        """
        raise NotImplementedError(f'{self.__class__.__name__} does not support'
                                  ' ppl-based evaluation yet, try gen-based '
                                  'instead.')

    # Template-level wrapper (added alongside get_ppl_from_template)
    def get_ppl_tokenwise_from_template(self,
                                        templates: List[PromptType],
                                        label: List[List[int]],
                                        mask_length=None):
        """Get token-wise perplexity given a list of templates.

        Args:
            templates (List[PromptType]): A list of templates.
            label (List[List[int]]): Labeled spans for each template.
            mask_length (List[int]): A list of mask lengths. If provided, the
                perplexity will be calculated only on the unmasked tokens.
        """
        inputs = self.parse_template(templates, mode='ppl')
        return self.get_ppl_tokenwise(inputs, label, mask_length)
```

‎opencompass/models/huggingface_above_v4_33.py

+159
```python
    # New method, added after _load_model
    def get_ppl_tokenwise(self,
                          inputs: List[str],
                          label: List[List[int]],
                          mask_length: Optional[List[int]] = None) -> List[float]:
        """Get inference-ppl per token given a list of inputs and labels.

        Args:
            inputs (List[str]): A list of strings.
            label (List[List[int]]): A list of label lists; each label is a
                tuple of (start, end, 1).
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer are
                not needed.

        Returns:
            List[float]: A list of perplexity scores.
        """
        assert self.tokenizer.pad_token
        import torch
        import torch.nn.functional as F
        pad_token_id = self.tokenizer.pad_token_id
        messages = _convert_base_messages(inputs)

        tokenize_kwargs = dict(
            return_tensors='pt',
            padding=True,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_seq_len,
        )

        self.tokenizer.padding_side = 'right'
        self.tokenizer.truncation_side = 'right'

        tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)

        tokens = {k: v.to(self.model.device) for k, v in tokens.items()}
        outputs = self.model(**tokens)[0]

        batch_size, seq_len, vocab_size = outputs.shape
        shift_logits = outputs[:, :-1, :].contiguous().float()
        shift_labels = tokens['input_ids'][:, 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, vocab_size),
            shift_labels.view(-1),
            ignore_index=pad_token_id,
            reduction='none').view(batch_size, seq_len - 1)
        lens = (tokens['input_ids'] != pad_token_id).sum(-1).cpu().numpy()

        if mask_length is not None:
            import numpy as np
            mask = torch.zeros_like(shift_labels)  # [batch, seqlen]
            for i in range(len(mask)):
                for j in range(mask_length[i] - 1, len(mask[i])):
                    mask[i][j] = 1
            loss = loss * mask
            lens -= np.array(mask_length)

        loss = loss.cpu().numpy()

        # Decode tokens one by one so token positions can be aligned with
        # character positions of the original text
        decode_messages = [[self.tokenizer.decode([input_id]) for input_id in token] for token in tokens['input_ids']]
        char_messages = [[ch for ch in message] for message in messages]

        # shifted to align label and loss
        for i in range(len(decode_messages)):
            decode_messages[i] = decode_messages[i][1:]

        aggregated_label_list = [[] for _ in range(len(decode_messages))]

        tag_list = [[] for _ in range(len(decode_messages))]

        for tmp_index, label_list in enumerate(label):
            for single_label in label_list:
                left = single_label[0]
                right = single_label[1]
                for i in range(left, right):
                    aggregated_label_list[tmp_index].append(i)

        def align_sequences(seq1, seq2, sep_len):
            """
            seq1: decoded sequence from token, one token may contain multiple characters
            seq2: original separate character sequence
            """
            i, j = 0, 0
            matched_pairs = []
            while i < len(seq1) and j < len(seq2):
                word = seq1[i]
                if len(word) == 0:
                    matched_pairs.append((word, []))
                    i += 1
                    continue

                if '\ufffd' in word:
                    for _ in range(sep_len):
                        matched_pairs.append((word, [j]))
                        i += 1
                    j += 1
                    continue

                char_sequence = ''
                while j < len(seq2) and (char_sequence != word):
                    char_sequence += seq2[j]
                    if char_sequence == word:
                        matched_pairs.append((word, [k for k in range(j - len(word) + 1, j + 1)]))
                        j += 1
                        break
                    elif len(char_sequence) > len(word):
                        if word == char_sequence[-len(word):]:
                            matched_pairs.append((word, [k for k in range(j - len(word) + 1, j + 1)]))
                            j += 1
                            break
                        else:
                            j += 1
                    else:
                        j += 1
                i += 1

            return matched_pairs

        # Tokenizer-dependent length of the byte-fallback ('\ufffd') token run
        # that makes up one character
        if 'qwen' in self.path or 'Qwen' in self.path:
            sep_len = 2
        elif 'Llama-3' in self.path:
            sep_len = 2
        elif 'Yi' in self.path:
            sep_len = 3
        elif 'Llama-2' in self.path:
            sep_len = 3
        elif 'deepseek' in self.path:
            sep_len = 2
        else:
            sep_len = 3

        # Map character-level label spans onto token indices for each sample
        matched_pairs_list = [align_sequences(decode_messages[i], char_messages[i], sep_len) for i in range(len(decode_messages))]
        for match_index, matched_pairs in enumerate(matched_pairs_list):
            for i, (word, indices) in enumerate(matched_pairs):
                for j in indices:
                    if j in aggregated_label_list[match_index]:
                        tag_list[match_index].append(i)
                        break

        inference_loss_list = []
        token_len_list = []
        for i in range(len(loss)):
            inference_loss = 0
            token_len = 0
            for j in range(len(loss[i])):
                if j in tag_list[i]:
                    inference_loss += loss[i][j]
                    token_len += 1
            inference_loss_list.append(inference_loss)
            token_len_list.append(token_len)

        return inference_loss_list, token_len_list
```
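The character-to-token alignment above is the subtle part: labels are character spans over the raw text, while losses are per token. The toy sketch below illustrates the same idea in isolation (it does not call the inner `align_sequences`; the token boundaries and span are made up):

```python
# Toy example: map character-span labels onto token indices.
tokens = ['The ', 'answer', ' is', ' 42', '.']   # decoded tokens (hypothetical)
spans = [(4, 10)]                                # character span covering 'answer'

# Character range covered by each token
ranges, pos = [], 0
for tok in tokens:
    ranges.append((pos, pos + len(tok)))
    pos += len(tok)

# A token is selected if any of its characters falls inside a labeled span
selected = [i for i, (s, e) in enumerate(ranges)
            if any(s < right and e > left for (left, right) in spans)]
print(selected)  # -> [1]
```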

‎opencompass/openicl/icl_evaluator/__init__.py

+1
```diff
@@ -6,6 +6,7 @@
 from .icl_em_evaluator import EMEvaluator  # noqa
 from .icl_hf_evaluator import *  # noqa
 from .icl_jieba_rouge_evaluator import JiebaRougeEvaluator  # noqa
+from .icl_misc_evaluator import AverageInferencePPLEvaluator  # noqa
 from .icl_misc_evaluator import AverageMinKEvaluator  # noqa
 from .icl_misc_evaluator import AveragePPLEvaluator  # noqa
 from .icl_plugin_evaluator import TEvalEvaluator  # noqa
```

‎opencompass/openicl/icl_evaluator/icl_misc_evaluator.py

+8
```diff
@@ -17,3 +17,11 @@ class AverageMinKEvaluator(BaseEvaluator):
     def score(self, mink):
         average_mink = sum(mink) / len(mink)
         return {'average_mink': average_mink}
+
+
+@ICL_EVALUATORS.register_module()
+class AverageInferencePPLEvaluator(BaseEvaluator):
+
+    def score(self, ppl, token_len):
+        # Sum of per-sample labeled-token losses divided by the total number
+        # of labeled tokens
+        average_ppl = sum(ppl) / sum(token_len)
+        return {'average_ppl': average_ppl}
```
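A quick check of what the evaluator reports; the loss sums and labeled-token counts below are made up:

```python
evaluator = AverageInferencePPLEvaluator()
# `ppl` holds the summed labeled-token loss per sample,
# `token_len` the number of labeled tokens per sample.
print(evaluator.score(ppl=[10.4, 6.2], token_len=[20, 16]))
# -> {'average_ppl': 0.4611...}
```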

‎opencompass/openicl/icl_inferencer/__init__.py

+2
```diff
@@ -4,6 +4,8 @@
 from .icl_chat_inferencer import ChatInferencer  # noqa
 from .icl_clp_inferencer import CLPInferencer  # noqa
 from .icl_gen_inferencer import GenInferencer  # noqa
+from .icl_inference_ppl_only_inferencer import \
+    InferencePPLOnlyInferencer  # noqa
 from .icl_ll_inferencer import LLInferencer  # noqa
 from .icl_mink_percent_inferencer import MinKPercentInferencer  # noqa
 from .icl_ppl_inferencer import PPLInferencer  # noqa
```
```python
"""Inference-PPL Only Inferencer."""

import os
from typing import List, Optional

import mmengine
import torch
from tqdm import tqdm

from opencompass.models.base import BaseModel
from opencompass.registry import ICL_INFERENCERS

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict

logger = get_logger(__name__)


@ICL_INFERENCERS.register_module()
class InferencePPLOnlyInferencer(BaseInferencer):
    """InferencePPLOnlyInferencer class to calculate Inference-PPL only, no
    choice is made. This Inferencer is usually used along with
    AverageInferencePPLEvaluator.

    Attributes:
        model (:obj:`BaseModel`, optional): The module to inference.
        max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by
            the LM.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): File path for output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for output
            `JSON` file.
        save_every (:obj:`int`, optional): Save intermediate results every
            `save_every` iterations.
    """

    def __init__(
            self,
            model: BaseModel,
            max_seq_len: Optional[int] = None,
            batch_size: Optional[int] = 1,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            save_every: Optional[int] = 1,
            **kwargs) -> None:
        super().__init__(
            model=model,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )

        self.save_every = save_every

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> List:
        # 1. Preparation for output logs
        output_handler = InferencePPLOnlyInferencerOutputHandler()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Generate prompts for testing input
        prompt_list, label_list = self.get_generation_prompt_list_and_label(
            ice_idx_list,
            retriever,
            max_seq_len=self.max_seq_len,
            ice_template=ice_template,
            prompt_template=prompt_template)

        prompt_list = [{
            'prompt': prompt,
            'label': label
        } for prompt, label in zip(prompt_list, label_list)]

        # 3.1 Fetch and zip prompt & gold answer if output column exists
        ds_reader = retriever.dataset_reader

        assert ds_reader.output_column is None, (
            'InferencePPLOnlyInferencer supports `output_column=None` only.')

        # Create tmp json file for saving intermediate results and future
        # resuming
        index = 0
        tmp_json_filepath = os.path.join(output_json_filepath,
                                         'tmp_' + output_json_filename)
        if os.path.exists(tmp_json_filepath):
            # TODO: move resume to output handler
            try:
                tmp_result_dict = mmengine.load(tmp_json_filepath)
            except Exception:
                pass
            else:
                output_handler.results_dict = tmp_result_dict
                index = len(tmp_result_dict)

        # 4. Wrap prompts with Dataloader
        dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)

        # 5. Inference for prompts in each batch
        logger.info('Starting inference process...')
        for datum in tqdm(dataloader, disable=not self.is_main_process):
            entry = [datum_single['prompt'] for datum_single in datum]
            label = [datum_single['label'] for datum_single in datum]

            # 5-1. Inference with local model
            with torch.no_grad():
                (inference_loss_list,
                 token_len_list) = self.model.get_ppl_tokenwise_from_template(
                     entry, label)

            parsed_entries = self.model.parse_template(entry, mode='gen')
            # 5-2. Save current output
            for prompt, inference_loss, token_len in zip(
                    parsed_entries, inference_loss_list, token_len_list):
                output_handler.save_results(prompt, inference_loss, token_len,
                                            index)
                index = index + 1

            # 5-3. Save intermediate results
            if (self.save_every is not None and index % self.save_every == 0
                    and self.is_main_process):
                output_handler.write_to_json(output_json_filepath,
                                             'tmp_' + output_json_filename)

        # 6. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)
            if os.path.exists(tmp_json_filepath):
                os.remove(tmp_json_filepath)

        return [
            sample['ppl'] for sample in output_handler.results_dict.values()
        ]

    def get_generation_prompt_list_from_retriever_indices(
            self,
            ice_idx_list: List[List[int]],
            retriever: BaseRetriever,
            max_seq_len: Optional[int] = None,
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        prompt_list = []
        for idx, ice_idx in enumerate(ice_idx_list):
            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)

            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                ice_template=ice_template,
                prompt_template=prompt_template)

            if max_seq_len is not None:
                prompt_token_num = self.model.get_token_len_from_template(
                    prompt, mode='gen')
                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
                    ice_idx = ice_idx[:-1]
                    ice = retriever.generate_ice(ice_idx,
                                                 ice_template=ice_template)
                    prompt = retriever.generate_prompt_for_generate_task(
                        idx,
                        ice,
                        ice_template=ice_template,
                        prompt_template=prompt_template)
                    prompt_token_num = self.model.get_token_len_from_template(
                        prompt, mode='gen')
            prompt_list.append(prompt)
        return prompt_list

    def get_generation_prompt_list_and_label(
            self,
            ice_idx_list: List[List[int]],
            retriever: BaseRetriever,
            max_seq_len: Optional[int] = None,
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        prompt_list = []
        label_list = []
        for idx, ice_idx in enumerate(ice_idx_list):
            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)

            prompt, label = retriever.generate_prompt_and_label_for_generate_task(  # noqa
                idx,
                ice,
                ice_template=ice_template,
                prompt_template=prompt_template)

            if max_seq_len is not None:
                prompt_token_num = self.model.get_token_len_from_template(
                    prompt, mode='gen')
                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
                    ice_idx = ice_idx[:-1]
                    ice = retriever.generate_ice(ice_idx,
                                                 ice_template=ice_template)
                    prompt, label = retriever.generate_prompt_and_label_for_generate_task(  # noqa
                        idx,
                        ice,
                        ice_template=ice_template,
                        prompt_template=prompt_template)
                    prompt_token_num = self.model.get_token_len_from_template(
                        prompt, mode='gen')
            prompt_list.append(prompt)
            label_list.append(label)
        return prompt_list, label_list


class InferencePPLOnlyInferencerOutputHandler:
    origin_prompt_dict = {}
    output_dict = {}
    results_dict = {}

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, os.path.join(save_dir, filename))

    def save_results(self, origin_prompt, ppl, token_len, idx):
        self.results_dict[str(idx)] = {
            'origin_prompt': origin_prompt,
            'ppl': ppl,
            'token_len': token_len,
        }
```
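Each dumped prediction record then looks roughly as follows (keys follow `save_results` above; the values are made up):

```python
# One entry of the predictions JSON written by write_to_json
{
    '0': {
        'origin_prompt': '...',   # parsed prompt text
        'ppl': 10.4,              # summed loss over the labeled tokens of this sample
        'token_len': 20,          # number of labeled tokens
    }
}
```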

‎opencompass/openicl/icl_retriever/icl_base_retriever.py

+53
```python
    # Added to BaseRetriever, alongside generate_prompt_for_generate_task
    def generate_prompt_and_label_for_generate_task(
            self,
            idx,
            ice,
            gen_field_replace_token='',
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        """Generate the prompt and the label info for one test example in
        generative evaluation with `prompt_template`. If `prompt_template` is
        not provided, the `ice_template` will be used to generate the prompt.
        The token represented by `gen_field_replace_token` will not be replaced
        by the generated text, or it would leak the answer.

        Args:
            idx (`int`): The index of the test example.
            ice (`str`): The in-context example for the test example.
            gen_field_replace_token (`str`): The token of the answer in the
                prompt. Defaults to ''.
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context example. Defaults to None.
            prompt_template (`Optional[PromptTemplate]`): The template for
                prompt. Defaults to None.
        """
        if prompt_template is not None and ice_template is not None:
            if prompt_template.ice_token is not None:
                return prompt_template.generate_item(
                    self.test_ds[idx],
                    output_field=self.dataset_reader.output_column,
                    output_field_replace_token=gen_field_replace_token,
                    ice_field_replace_token=ice), self.test_ds[idx]['label']
            else:
                raise NotImplementedError(
                    'ice_token of prompt_template is not provided')
        elif ice_template is not None and prompt_template is None:
            if ice_template.ice_token is not None:
                return ice_template.generate_item(
                    self.test_ds[idx],
                    output_field=self.dataset_reader.output_column,
                    output_field_replace_token=gen_field_replace_token,
                    ice_field_replace_token=ice), self.test_ds[idx]['label']
            else:
                raise NotImplementedError(
                    'ice_token of ice_template is not provided')
        elif ice_template is None and prompt_template is not None:
            return prompt_template.generate_item(
                self.test_ds[idx],
                output_field=self.dataset_reader.output_column,
                output_field_replace_token=gen_field_replace_token,
                ice_field_replace_token=ice), self.test_ds[idx]['label']
        else:
            raise NotImplementedError(
                'Leaving prompt as empty is not supported')
```
idx,
