Skip to content

Commit 7263dab

Browse files
authored
Merge pull request #37 from arena-ai/26-refine-logprobs-handling
Refine logprobs handling
2 parents d4d9abb + cba95a2 commit 7263dab

File tree

2 files changed

+206
-114
lines changed

2 files changed

+206
-114
lines changed

backend/app/api/routes/dde.py

+121-83
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Iterable, Literal
1+
from typing import Any, Iterable, Literal, TypedDict
22
from app.lm.models.chat_completion import TokenLogprob
33
from app.lm.models import ChatCompletionResponse
44
from fastapi import APIRouter, HTTPException, status, UploadFile
@@ -463,56 +463,106 @@ async def extract_from_file(
463463
full_system_content = f"{system_prompt}\n{examples_text}"
464464

465465
messages = [
466-
ChatCompletionMessage(role="system", content=full_system_content),
467-
ChatCompletionMessage(
468-
role="user",
469-
content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}",
470-
),
471-
]
472-
pydantic_reponse = create_pydantic_model(
473-
json.loads(document_data_extractor.response_template)
474-
)
475-
format_response = {
476-
"type": "json_schema",
477-
"json_schema": {
478-
"schema": to_strict_json_schema(pydantic_reponse),
479-
"name": "response",
480-
"strict": True,
481-
},
482-
}
483-
484-
chat_completion_request = ChatCompletionRequest(
485-
model="gpt-4o-2024-08-06",
486-
messages=messages,
487-
max_tokens=2000,
488-
temperature=0.1,
489-
logprobs=True,
490-
top_logprobs=5,
491-
response_format=format_response,
492-
).model_dump(exclude_unset=True)
493-
494-
chat_completion_response = await ArenaHandler(
495-
session, document_data_extractor.owner, chat_completion_request
496-
).process_request()
497-
extracted_info = chat_completion_response.choices[0].message.content
498-
# TODO: handle refusal or case in which content was not correctly done
466+
ChatCompletionMessage(role="system", content=full_system_content),
467+
ChatCompletionMessage(role="user", content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}")
468+
]
469+
470+
pydantic_reponse=create_pydantic_model(json.loads(document_data_extractor.response_template))
471+
format_response={"type": "json_schema",
472+
"json_schema":{
473+
"schema":to_strict_json_schema(pydantic_reponse),
474+
"name":'response',
475+
'strict':True}}
476+
477+
chat_completion_request = ChatCompletionRequest(
478+
model='gpt-4o-2024-08-06',
479+
messages=messages,
480+
max_tokens=2000,
481+
temperature=0.1,
482+
logprobs=True,
483+
top_logprobs= 5,
484+
response_format=format_response
485+
486+
).model_dump(exclude_unset=True)
487+
488+
chat_completion_response = await ArenaHandler(session, document_data_extractor.owner, chat_completion_request).process_request()
489+
extracted_data=chat_completion_response.choices[0].message.content
490+
extracted_data_token = chat_completion_response.choices[0].logprobs.content
491+
#TODO: handle refusal or case in which content was not correctly done
499492
# TODO: Improve the prompt to ensure the output is always a valid JSON
500-
json_string = extracted_info[
501-
extracted_info.find("{") : extracted_info.rfind("}") + 1
502-
]
503-
extracted_data = {
504-
k: v
505-
for k, v in json.loads(json_string).items()
506-
if k not in ("source", "year")
507-
}
508-
logprob_data = extract_logprobs_from_response(
509-
chat_completion_response, extracted_data
510-
)
511-
return {
512-
"extracted_info": json.loads(json_string),
513-
"logprob_data": logprob_data,
514-
}
493+
json_string = extracted_data[extracted_data.find('{'):extracted_data.rfind('}')+1]
494+
keys = list(pydantic_reponse.__fields__.keys())
495+
value_indices = extract_tokens_indices_for_each_key(keys, extracted_data_token)
496+
logprobs = extract_logprobs_from_indices(value_indices, extracted_data_token)
497+
return {'extracted_data': json.loads(json_string), 'logprobs': logprobs}
498+
499+
# Minimal shape of a token object as consumed by the index/logprob helpers
# below: only the raw token text is required here.
Token = TypedDict("Token", {"token": str})
502+
def extract_tokens_indices_for_each_key(keys: list[str], token_list: "list[Token]") -> dict[str, list[int]]:
    """
    Extract the indices of the tokens that encode the value produced for each key.

    The scan walks the raw token stream of a JSON object:
    - Consecutive tokens are accumulated until they spell out one of the keys.
    - Once a key is matched, index saving starts right after a '":' or '":"'
      separator token.
    - Saving stops when a '","'/',"' token is followed by the start of another
      key, when the next token is '}', or when the token list ends (the latter
      guards truncated model output, which previously raised IndexError).

    Args:
        keys (list[str]): Keys for which to find corresponding token indices.
        token_list (list[Token]): Token objects exposing a ``.token`` string.

    Returns:
        dict[str, list[int]]: Mapping from each key to the [start, end] indices
        of the tokens holding its value; a key that is never matched maps to
        an empty list.
    """
    value_indices: dict[str, list[int]] = {key: [] for key in keys}
    current_key = ""
    matched_key = None
    remaining_keys = keys.copy()
    saving_indices = False
    last = len(token_list) - 1
    for i, token_object in enumerate(token_list):
        token = token_object.token
        if matched_key is not None:
            if saving_indices:
                if token == '","' or token == ',"':
                    next_token = token_list[i + 1].token if i < last else None
                    if next_token is not None and any(
                        key.startswith(next_token) for key in remaining_keys
                    ):
                        # Stop saving: the separator is followed by the start
                        # of one of the remaining keys, so the value ended on
                        # the previous token.
                        value_indices[matched_key].append(i - 1)
                        matched_key = None
                        saving_indices = False
                        current_key = ""
                        continue
                elif i == last or token_list[i + 1].token == "}":
                    # Stop saving at the object end; treating end-of-list the
                    # same way avoids an IndexError on truncated streams.
                    value_indices[matched_key].append(i)
                    matched_key = None
                    saving_indices = False
                    current_key = ""
                    continue
                continue
            elif token == '":' or token == '":"':
                # Value tokens begin right after the key/value separator.
                value_indices[matched_key].append(i + 1)
                saving_indices = True
        else:
            current_key += token
            for key in remaining_keys:
                if key.startswith(current_key):
                    if current_key == key:
                        matched_key = key  # full key matched
                        remaining_keys.remove(key)
                    break
            else:
                # Accumulated fragment cannot start any remaining key: reset.
                current_key = ""
    return value_indices
557+
558+
def extract_logprobs_from_indices(value_indices: dict[str, list[int]], token_list: "list[Token]") -> dict[str, list[Any]]:
    """
    Collect the top log-probability of every token spanning each extracted value.

    Args:
        value_indices (dict[str, list[int]]): Mapping key -> [first, last]
            token indices, as produced by extract_tokens_indices_for_each_key.
        token_list (list[Token]): Token objects exposing
            ``.top_logprobs[0].logprob``.

    Returns:
        dict[str, list[Any]]: Mapping key -> list of the first top-logprob of
        each token in the value span. Keys with no recorded indices map to an
        empty list (previously this raised IndexError on ``indices[0]``).
    """
    logprobs: dict[str, list[Any]] = {key: [] for key in value_indices}
    for key, indices in value_indices.items():
        if not indices:
            # Key was never located in the token stream: leave it empty
            # instead of crashing.
            continue
        start_idx = indices[0]
        end_idx = indices[-1]
        for i in range(start_idx, end_idx + 1):
            logprobs[key].append(token_list[i].top_logprobs[0].logprob)
    return logprobs
516566

517567
def create_pydantic_model(
518568
schema: dict[
@@ -607,35 +657,25 @@ def extract_logprobs_from_response(
607657
response: ChatCompletionResponse, extracted_data: dict[str, Any]
608658
) -> dict[str, float | list[float]]:
609659
logprob_data = {}
610-
tokens_info = response.choices[0].logprobs.content
611-
612-
def process_numeric_values(extracted_data: dict[str, Any], path=""):
613-
for i in range(len(tokens_info) - 1):
614-
token = tokens_info[i].token
615-
616-
if token.isdigit(): # Only process tokens that are numeric
617-
combined_token, combined_logprob = combine_tokens(
618-
tokens_info, i
619-
)
620-
if combined_token_in_extracted_data(
621-
combined_token, extracted_data.values()
622-
): # Checks if a combined token matches any numeric values in the extracted data.
660+
extracted_data_token = response.choices[0].logprobs.content
661+
662+
def process_numeric_values(extracted_data: dict[str, Any], path=''):
663+
664+
for i in range(len(extracted_data_token)-1):
665+
token = extracted_data_token[i].token
666+
if token.isdigit(): # Only process tokens that are numeric
667+
combined_token, combined_logprob = combine_tokens(extracted_data_token, i)
668+
if combined_token_in_extracted_data(combined_token, extracted_data.values()): #Checks if a combined token matches any numeric values in the extracted data.
623669
key = find_key_by_value(
624670
combined_token, extracted_data
625671
) # Finds the key in 'extracted_data' corresponding to a numeric value that matches the combined token.
626672
if key:
627-
full_key = path + key
628-
logprob_data[full_key + "_prob_first_token"] = (
629-
math.exp(tokens_info[i].logprob)
630-
)
631-
logprob_data[full_key + "_prob_second_token"] = (
632-
math.exp(tokens_info[i + 1].logprob)
633-
)
673+
full_key = path + key
674+
logprob_data[full_key + '_prob_first_token'] = math.exp(extracted_data_token[i].logprob)
675+
logprob_data[full_key + '_prob_second_token'] = math.exp(extracted_data_token[i+1].logprob)
634676

635-
toplogprobs_firsttoken = tokens_info[i].top_logprobs
636-
toplogprobs_secondtoken = tokens_info[
637-
i + 1
638-
].top_logprobs
677+
toplogprobs_firsttoken = extracted_data_token[i].top_logprobs
678+
toplogprobs_secondtoken = extracted_data_token[i+1].top_logprobs
639679

640680
logprobs_first = [
641681
top_logprob.logprob
@@ -666,17 +706,15 @@ def traverse_and_extract(data: dict, path=""):
666706
return logprob_data
667707

668708

669-
def combine_tokens(
670-
tokens_info: list[TokenLogprob], start_index: int
671-
) -> tuple[str, float]:
672-
combined_token = tokens_info[start_index].token
673-
combined_logprob = tokens_info[start_index].logprob
709+
def combine_tokens(extracted_data_token: "list[TokenLogprob]", start_index: int) -> tuple[str, float]:
    """
    Merge a run of consecutive digit-only tokens starting at ``start_index``.

    The token at ``start_index`` is always included; subsequent tokens are
    appended as long as each one consists only of digits. Log-probabilities
    are summed, which yields the log-probability of the whole combined string.

    Args:
        extracted_data_token (list[TokenLogprob]): Token objects exposing
            ``.token`` and ``.logprob``.
        start_index (int): Index of the first token to combine.

    Returns:
        tuple[str, float]: The concatenated token text and the summed
        log-probability of its constituent tokens.
    """
    pieces = [extracted_data_token[start_index].token]
    combined_logprob = extracted_data_token[start_index].logprob

    # Keep combining tokens as long as the next token is a digit
    for i in range(start_index + 1, len(extracted_data_token)):
        next_token = extracted_data_token[i].token
        if not next_token.isdigit():
            break
        pieces.append(next_token)
        combined_logprob += extracted_data_token[i].logprob

    # "".join avoids quadratic string concatenation on long digit runs.
    return "".join(pieces), combined_logprob

0 commit comments

Comments
 (0)