|
1 |
| -from typing import Any, Iterable, Literal |
| 1 | +from typing import Any, Iterable, Literal, TypedDict |
2 | 2 | from app.lm.models.chat_completion import TokenLogprob
|
3 | 3 | from app.lm.models import ChatCompletionResponse
|
4 | 4 | from fastapi import APIRouter, HTTPException, status, UploadFile
|
@@ -463,56 +463,106 @@ async def extract_from_file(
|
463 | 463 | full_system_content = f"{system_prompt}\n{examples_text}"
|
464 | 464 |
|
465 | 465 | messages = [
|
466 |
| - ChatCompletionMessage(role="system", content=full_system_content), |
467 |
| - ChatCompletionMessage( |
468 |
| - role="user", |
469 |
| - content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}", |
470 |
| - ), |
471 |
| - ] |
472 |
| - pydantic_reponse = create_pydantic_model( |
473 |
| - json.loads(document_data_extractor.response_template) |
474 |
| - ) |
475 |
| - format_response = { |
476 |
| - "type": "json_schema", |
477 |
| - "json_schema": { |
478 |
| - "schema": to_strict_json_schema(pydantic_reponse), |
479 |
| - "name": "response", |
480 |
| - "strict": True, |
481 |
| - }, |
482 |
| - } |
483 |
| - |
484 |
| - chat_completion_request = ChatCompletionRequest( |
485 |
| - model="gpt-4o-2024-08-06", |
486 |
| - messages=messages, |
487 |
| - max_tokens=2000, |
488 |
| - temperature=0.1, |
489 |
| - logprobs=True, |
490 |
| - top_logprobs=5, |
491 |
| - response_format=format_response, |
492 |
| - ).model_dump(exclude_unset=True) |
493 |
| - |
494 |
| - chat_completion_response = await ArenaHandler( |
495 |
| - session, document_data_extractor.owner, chat_completion_request |
496 |
| - ).process_request() |
497 |
| - extracted_info = chat_completion_response.choices[0].message.content |
498 |
| - # TODO: handle refusal or case in which content was not correctly done |
| 466 | + ChatCompletionMessage(role="system", content=full_system_content), |
| 467 | + ChatCompletionMessage(role="user", content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}") |
| 468 | + ] |
| 469 | + |
| 470 | + pydantic_reponse=create_pydantic_model(json.loads(document_data_extractor.response_template)) |
| 471 | + format_response={"type": "json_schema", |
| 472 | + "json_schema":{ |
| 473 | + "schema":to_strict_json_schema(pydantic_reponse), |
| 474 | + "name":'response', |
| 475 | + 'strict':True}} |
| 476 | + |
| 477 | + chat_completion_request = ChatCompletionRequest( |
| 478 | + model='gpt-4o-2024-08-06', |
| 479 | + messages=messages, |
| 480 | + max_tokens=2000, |
| 481 | + temperature=0.1, |
| 482 | + logprobs=True, |
| 483 | + top_logprobs= 5, |
| 484 | + response_format=format_response |
| 485 | + |
| 486 | + ).model_dump(exclude_unset=True) |
| 487 | + |
| 488 | + chat_completion_response = await ArenaHandler(session, document_data_extractor.owner, chat_completion_request).process_request() |
| 489 | + extracted_data=chat_completion_response.choices[0].message.content |
| 490 | + extracted_data_token = chat_completion_response.choices[0].logprobs.content |
| 491 | + #TODO: handle refusal or case in which content was not correctly done |
499 | 492 | # TODO: Improve the prompt to ensure the output is always a valid JSON
|
500 |
| - json_string = extracted_info[ |
501 |
| - extracted_info.find("{") : extracted_info.rfind("}") + 1 |
502 |
| - ] |
503 |
| - extracted_data = { |
504 |
| - k: v |
505 |
| - for k, v in json.loads(json_string).items() |
506 |
| - if k not in ("source", "year") |
507 |
| - } |
508 |
| - logprob_data = extract_logprobs_from_response( |
509 |
| - chat_completion_response, extracted_data |
510 |
| - ) |
511 |
| - return { |
512 |
| - "extracted_info": json.loads(json_string), |
513 |
| - "logprob_data": logprob_data, |
514 |
| - } |
| 493 | + json_string = extracted_data[extracted_data.find('{'):extracted_data.rfind('}')+1] |
| 494 | + keys = list(pydantic_reponse.__fields__.keys()) |
| 495 | + value_indices = extract_tokens_indices_for_each_key(keys, extracted_data_token) |
| 496 | + logprobs = extract_logprobs_from_indices(value_indices, extracted_data_token) |
| 497 | + return {'extracted_data': json.loads(json_string), 'logprobs': logprobs} |
| 498 | + |
class Token(TypedDict):
    """Minimal declared shape of one logprob entry: just the token text.

    NOTE(review): the helpers below annotated with ``list[Token]`` read
    ``.token`` (and ``.top_logprobs``) as *attributes*, which a TypedDict —
    a plain dict at runtime — does not provide. The objects actually passed
    in appear to be ``TokenLogprob`` models from the completion response;
    confirm, and consider annotating with ``TokenLogprob`` or a
    ``typing.Protocol`` instead.
    """
    token: str
def extract_tokens_indices_for_each_key(keys: list[str], token_list: list[Token]) -> dict[str, list[int]]:
    """Map each expected JSON key to the indices of the tokens holding its value.

    Scans the model's token stream (the serialized JSON object it produced)
    and records, for every key in ``keys``, the token-index range of the
    value emitted for that key.

    Matching heuristics:
      - Key names may be split across several tokens; fragments accumulate in
        ``current_key`` until they spell out one of the remaining keys.
      - Value indices start right after a ``'":'`` or ``'":"'`` token.
      - Recording stops at a ``'","'`` / ``',"'`` token whose successor begins
        another expected key, or when the *next* token is ``'}'``.

    Args:
        keys: JSON keys to locate values for.
        token_list: Token objects exposing a ``.token`` attribute (e.g. the
            TokenLogprob entries of the completion's logprobs).

    Returns:
        Dict mapping each key to its value's token indices — typically
        ``[start, end]`` (inclusive), ``[start]`` if the stream ended before a
        terminator was seen, or ``[]`` if the key never appeared.
    """
    value_indices: dict[str, list[int]] = {key: [] for key in keys}
    current_key = ""
    matched_key = None
    remaining_keys = keys.copy()
    saving_indices = False
    n = len(token_list)
    for i, token_object in enumerate(token_list):
        token = token_object.token
        if matched_key is not None:
            if saving_indices:
                if token == '","' or token == ',"':
                    next_token = token_list[i + 1].token if i + 1 < n else None
                    if next_token is not None and any(key.startswith(next_token) for key in remaining_keys):
                        # Separator followed by the start of another expected
                        # key: the value ended at the previous token.
                        value_indices[matched_key].append(i - 1)
                        matched_key = None
                        saving_indices = False
                        current_key = ""
                elif i + 1 < n and token_list[i + 1].token == '}':
                    # Next token closes the object: this token ends the value.
                    # Bug fix: the lookahead was previously unguarded and
                    # raised IndexError when i was the last token.
                    value_indices[matched_key].append(i)
                    matched_key = None
                    saving_indices = False
                    current_key = ""
                continue
            elif token == '":' or token == '":"':
                # Colon reached: the value starts at the next token.
                value_indices[matched_key].append(i + 1)
                saving_indices = True
        else:
            current_key += token
            for key in remaining_keys:
                if key.startswith(current_key):
                    if current_key == key:
                        matched_key = key  # full key matched
                        remaining_keys.remove(key)
                    break
            else:
                # Accumulated fragment is not a prefix of any key; reset.
                current_key = ""
    return value_indices
| 557 | + |
def extract_logprobs_from_indices(value_indices: dict[str, list[int]], token_list: list[Token]) -> dict[str, list[Any]]:
    """Collect the top-choice logprob of every token spanning each key's value.

    Args:
        value_indices: Output of ``extract_tokens_indices_for_each_key`` —
            per key, the inclusive ``[start, end]`` token index range (or a
            single ``[start]``, or ``[]`` when the key was never located).
        token_list: Token objects exposing ``.top_logprobs`` whose first
            entry has a ``.logprob`` — assumes top_logprobs is non-empty,
            which holds when the request set ``top_logprobs`` > 0.

    Returns:
        Dict mapping each key to the logprobs of its value tokens; an empty
        list when no indices were recorded for the key.
    """
    logprobs: dict[str, list[Any]] = {key: [] for key in value_indices}
    for key, indices in value_indices.items():
        if not indices:
            # Bug fix: the original indexed indices[0]/indices[-1] and raised
            # IndexError whenever a key was never matched in the stream.
            continue
        start_idx = indices[0]
        end_idx = indices[-1]
        for i in range(start_idx, end_idx + 1):
            logprobs[key].append(token_list[i].top_logprobs[0].logprob)
    return logprobs
516 | 566 |
|
517 | 567 | def create_pydantic_model(
|
518 | 568 | schema: dict[
|
@@ -607,35 +657,25 @@ def extract_logprobs_from_response(
|
607 | 657 | response: ChatCompletionResponse, extracted_data: dict[str, Any]
|
608 | 658 | ) -> dict[str, float | list[float]]:
|
609 | 659 | logprob_data = {}
|
610 |
| - tokens_info = response.choices[0].logprobs.content |
611 |
| - |
612 |
| - def process_numeric_values(extracted_data: dict[str, Any], path=""): |
613 |
| - for i in range(len(tokens_info) - 1): |
614 |
| - token = tokens_info[i].token |
615 |
| - |
616 |
| - if token.isdigit(): # Only process tokens that are numeric |
617 |
| - combined_token, combined_logprob = combine_tokens( |
618 |
| - tokens_info, i |
619 |
| - ) |
620 |
| - if combined_token_in_extracted_data( |
621 |
| - combined_token, extracted_data.values() |
622 |
| - ): # Checks if a combined token matches any numeric values in the extracted data. |
| 660 | + extracted_data_token = response.choices[0].logprobs.content |
| 661 | + |
| 662 | + def process_numeric_values(extracted_data: dict[str, Any], path=''): |
| 663 | + |
| 664 | + for i in range(len(extracted_data_token)-1): |
| 665 | + token = extracted_data_token[i].token |
| 666 | + if token.isdigit(): # Only process tokens that are numeric |
| 667 | + combined_token, combined_logprob = combine_tokens(extracted_data_token, i) |
| 668 | + if combined_token_in_extracted_data(combined_token, extracted_data.values()): #Checks if a combined token matches any numeric values in the extracted data. |
623 | 669 | key = find_key_by_value(
|
624 | 670 | combined_token, extracted_data
|
625 | 671 | ) # Finds the key in 'extracted_data' corresponding to a numeric value that matches the combined token.
|
626 | 672 | if key:
|
627 |
| - full_key = path + key |
628 |
| - logprob_data[full_key + "_prob_first_token"] = ( |
629 |
| - math.exp(tokens_info[i].logprob) |
630 |
| - ) |
631 |
| - logprob_data[full_key + "_prob_second_token"] = ( |
632 |
| - math.exp(tokens_info[i + 1].logprob) |
633 |
| - ) |
| 673 | + full_key = path + key |
| 674 | + logprob_data[full_key + '_prob_first_token'] = math.exp(extracted_data_token[i].logprob) |
| 675 | + logprob_data[full_key + '_prob_second_token'] = math.exp(extracted_data_token[i+1].logprob) |
634 | 676 |
|
635 |
| - toplogprobs_firsttoken = tokens_info[i].top_logprobs |
636 |
| - toplogprobs_secondtoken = tokens_info[ |
637 |
| - i + 1 |
638 |
| - ].top_logprobs |
| 677 | + toplogprobs_firsttoken = extracted_data_token[i].top_logprobs |
| 678 | + toplogprobs_secondtoken = extracted_data_token[i+1].top_logprobs |
639 | 679 |
|
640 | 680 | logprobs_first = [
|
641 | 681 | top_logprob.logprob
|
@@ -666,17 +706,15 @@ def traverse_and_extract(data: dict, path=""):
|
666 | 706 | return logprob_data
|
667 | 707 |
|
668 | 708 |
|
669 |
| -def combine_tokens( |
670 |
| - tokens_info: list[TokenLogprob], start_index: int |
671 |
| -) -> tuple[str, float]: |
672 |
| - combined_token = tokens_info[start_index].token |
673 |
| - combined_logprob = tokens_info[start_index].logprob |
| 709 | +def combine_tokens(extracted_data_token: list[TokenLogprob], start_index: int) -> tuple[str, float]: |
| 710 | + combined_token = extracted_data_token[start_index].token |
| 711 | + combined_logprob = extracted_data_token[start_index].logprob |
674 | 712 |
|
675 | 713 | # Keep combining tokens as long as the next token is a digit
|
676 |
| - for i in range(start_index + 1, len(tokens_info)): |
677 |
| - if not tokens_info[i].token.isdigit(): |
| 714 | + for i in range(start_index + 1, len(extracted_data_token)): |
| 715 | + if not extracted_data_token[i].token.isdigit(): |
678 | 716 | break
|
679 |
| - combined_token += tokens_info[i].token |
680 |
| - combined_logprob += tokens_info[i].logprob |
681 |
| - |
| 717 | + combined_token += extracted_data_token[i].token |
| 718 | + combined_logprob += extracted_data_token[i].logprob |
| 719 | + |
682 | 720 | return combined_token, combined_logprob
|
0 commit comments