
Commit 301c3cb

Merge pull request #54 from arena-ai/46-update-log-probs-for-the-new-response_template
update log probs for the new response template
2 parents 3493c85 + 3055fd1 commit 301c3cb

File tree

7 files changed: +221 -178 lines changed


backend/app/api/routes/dde.py
+4 -105
@@ -32,6 +32,7 @@
 from openai.lib._pydantic import to_strict_json_schema
 from app.handlers.prompt_for_image import full_prompt_from_image
 from app.handlers.prompt_for_text import full_prompt_from_text
+from app.handlers.logprobs import map_characters_to_token_indices, extract_json_data

 from app.models import ContentType

@@ -484,109 +485,7 @@ async def extract_from_file(
         json_string = extracted_data[
             extracted_data.find("{") : extracted_data.rfind("}") + 1
         ]
-        #token_indices=map_characters_to_token_indices(extracted_data_token)
-        #regex_spans=find_value_spans(extracted_data)
-        #logprobs_sum=get_token_spans_and_logprobs(token_indices, regex_spans, extracted_data_token)
-        return {"extracted_data": json.loads(json_string), "extracted_logprobs": {}, "identifier": identifier}
-
-def map_characters_to_token_indices(extracted_data_token: list[TokenLogprob]) -> list[int]:
-    """
-    Maps each character in the JSON string output to its corresponding token index.
-
-    Args:
-        extracted_data_token : A list of `TokenLogprob` objects, where each object represents a token and its data (such as the logprobs).
-
-    Returns:
-        A list of integers where each position corresponds to a character in the concatenated JSON string,
-        and the integer at each position is the index of the token responsible for generating that character.
-
-    Example:
-    --------
-    Given `extracted_data_token = [TokenLogprob(token='{'), TokenLogprob(token='"key1"'), TokenLogprob(token=': '), TokenLogprob(token='"value1"'), TokenLogprob(token='}')]`,
-    the JSON output is '{"key1": "value1"}' and the function returns [0, 1, 1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4].
-    """
-    json_output = "".join(token_data.token for token_data in extracted_data_token)
-
-    token_indices = [None] * len(json_output)
-    current_char_pos = 0
-
-    for token_idx, token_data in enumerate(extracted_data_token):
-        token_text = token_data.token
-        for char_pos in range(len(token_text)):
-            token_indices[current_char_pos] = token_idx
-            current_char_pos += 1
-
-    return token_indices
-
-def find_value_spans(json_string: str) -> list[tuple[str, tuple[int, int]]]:
-    """
-    Extracts the spans (start and end positions) of values (strings or numbers) within a JSON-formatted string.
-
-    Args:
-        json_string : A JSON-formatted string where values are paired with keys and separated by colons.
-
-    Returns:
-        A list of tuples, each containing a matched key and a (start, end) pair of integers giving the character span of its value within `json_string`.
-
-    Example:
-    --------
-    Given `json_string = '{"key1": "value1"}'`, the function will return:
-    [("key1", (9, 17))]
-    """
-    pattern = r'"([^"\n}]+)"\s*:\s*("[^"\n]+"|[-0-9.eE]+)\s*'
-
-    matches = []
-    for match in re.finditer(pattern, json_string):
-        value = match.group(1)
-        start = match.start(2)
-        end = match.end(2)
-        matches.append((value, (start, end)))
-    return matches
-
-
-def get_token_spans_and_logprobs(
-    token_indices: list[int],
-    value_spans: list[tuple[str, tuple[int, int]]],
-    extracted_data_token: list[TokenLogprob]
-) -> dict[str, float]:
-    """
-    Identifies the token indices covering each value span and sums the log probabilities of those tokens, giving an overall log probability per value span.
-
-    Args:
-        token_indices : A list mapping each character in the JSON string to a token index.
-        value_spans : A list of tuples, each containing a key and the character span of its value within the JSON string.
-        extracted_data_token : A list of `TokenLogprob` objects, each containing a token and its log probability data, where the index of each item corresponds to its token index.
-
-    Returns:
-        A dictionary mapping each key to the summed log probability of all the tokens that contain part of its value.
-
-    Example:
-    --------
-    Given:
-    - `token_indices = [0, 1, 1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4]`, which maps each character to a token index.
-    - `value_spans = [("key1", (9, 17))]`.
-    - `extracted_data_token = [TokenLogprob(token="{", logprob=-1.5), TokenLogprob(token="key1", logprob=-1), TokenLogprob(token=": ", logprob=-1), TokenLogprob(token="value1", logprob=-1.5), TokenLogprob(token="}", logprob=-0.8)]`
-
-    The function will return:
-    {"key1": -1.5}
-    """
-    logprobs_for_values = {}
-
-    for value, (start, end) in value_spans:
-        token_start = token_indices[start]
-        token_end = token_indices[end]
-        logprobs = [
-            extracted_data_token[token_idx].logprob
-            for token_idx in range(token_start, token_end)
-        ]
-        logprobs_for_values[value] = sum(logprobs)
-
-    return logprobs_for_values
-
-
+        token_indices = map_characters_to_token_indices(extracted_data_token)
+        extracted_data = extract_json_data(json_string, extracted_data_token, token_indices)
+        return {"extracted_data": extracted_data, "identifier": identifier}
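The route now delegates log-prob handling to the new handler module: it maps every character of the model's JSON output to the token that produced it, then hands the string, tokens, and mapping to `extract_json_data`. Below is a minimal, self-contained sketch of that character-to-token mapping, using a hypothetical dataclass stand-in for `TokenLogprob` (the real model lives in `app.lm.models.chat_completion`); the extend-based loop is an equivalent rewrite of the committed version.

from dataclasses import dataclass

@dataclass
class TokenLogprob:  # hypothetical stand-in, not the app's Pydantic model
    token: str
    logprob: float = 0.0

def map_characters_to_token_indices(tokens: list[TokenLogprob]) -> list[int]:
    # One entry per character of the concatenated output; each entry is the
    # index of the token that produced that character.
    indices: list[int] = []
    for token_idx, token_data in enumerate(tokens):
        indices.extend([token_idx] * len(token_data.token))
    return indices

tokens = [TokenLogprob('{'), TokenLogprob('"key1"'), TokenLogprob(': '),
          TokenLogprob('"value1"'), TokenLogprob('}')]
assert map_characters_to_token_indices(tokens) == [
    0, 1, 1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4
]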

backend/app/handlers/logprobs.py
+140
@@ -0,0 +1,140 @@
+from lark import Lark, Transformer, v_args, Tree, Token
+from lark.tree import Meta
+from pydantic import BaseModel
+from typing import Any, Optional
+import math
+from app.lm.models.chat_completion import TokenLogprob
+
+class HasProb(BaseModel):
+    value: Any
+    start: int
+    end: int
+    logprob: float
+    prob: float
+
+def map_characters_to_token_indices(extracted_data_token: list[TokenLogprob]) -> list[int]:
+    """
+    Maps each character in the JSON string output to its corresponding token index.
+
+    Args:
+        extracted_data_token : A list of `TokenLogprob` objects, where each object represents a token and its data (such as the logprobs).
+
+    Returns:
+        A list of integers where each position corresponds to a character in the concatenated JSON string,
+        and the integer at each position is the index of the token responsible for generating that character.
+
+    Example:
+    --------
+    Given `extracted_data_token = [TokenLogprob(token='{'), TokenLogprob(token='"key1"'), TokenLogprob(token=': '), TokenLogprob(token='"value1"'), TokenLogprob(token='}')]`,
+    the JSON output is '{"key1": "value1"}' and the function returns [0, 1, 1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4].
+    """
+    json_output = "".join(token_data.token for token_data in extracted_data_token)
+
+    token_indices = [None] * len(json_output)
+    current_char_pos = 0
+
+    for token_idx, token_data in enumerate(extracted_data_token):
+        token_text = token_data.token
+        for char_pos in range(len(token_text)):
+            token_indices[current_char_pos] = token_idx
+            current_char_pos += 1
+
+    return token_indices
+
+# Define a grammar for JSON
+json_grammar = r"""
+    start: value
+
+    ?value: object          # '?' is a Lark convention: the rule is inlined instead of creating a separate parse tree node
+          | array
+          | string
+          | SIGNED_NUMBER -> number   # '-> number' specifies an alias for the rule
+          | "true"
+          | "false"
+          | "null"
+
+    array  : "[" [value ("," value)*] "]"
+    object : "{" [pair ("," pair)*] "}"
+    pair   : key ":" value
+    key    : ESCAPED_STRING
+
+    string : ESCAPED_STRING
+
+    %import common.ESCAPED_STRING
+    %import common.SIGNED_NUMBER
+    %import common.WS
+    %ignore WS
+"""
+
+@v_args(meta=True)
+class Extractor(Transformer):
+    def __init__(self, tokens: list[TokenLogprob], token_indices: list[int]):
+        super().__init__()
+        self.tokens = tokens
+        self.token_indices = token_indices
+
+    def _compute_logprob_sum(self, start: int, end: int) -> float:
+        token_start = self.token_indices[start]
+        token_end = self.token_indices[end]
+        sum_logprob = sum(self.tokens[i].logprob for i in range(token_start, token_end))
+        return sum_logprob
+
+    def number(self, meta: Meta, children: list[Token]) -> HasProb:
+        logprob_sum = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
+        prob = math.exp(logprob_sum) * 100
+        return HasProb(value=float(children[0]), start=meta.start_pos, end=meta.end_pos, logprob=logprob_sum, prob=prob)
+
+    def string(self, meta: Meta, children: list[Token]) -> HasProb:
+        logprob_sum = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
+        prob = math.exp(logprob_sum) * 100
+        return HasProb(value=children[0][1:-1], start=meta.start_pos, end=meta.end_pos, logprob=logprob_sum, prob=prob)
+
+    def true(self, meta: Meta, children: list[Token]) -> HasProb:
+        logprob_sum = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
+        prob = math.exp(logprob_sum) * 100
+        return HasProb(value=True, start=meta.start_pos, end=meta.end_pos, logprob=logprob_sum, prob=prob)
+
+    def false(self, meta: Meta, children: list[Token]) -> HasProb:
+        logprob_sum = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
+        prob = math.exp(logprob_sum) * 100
+        return HasProb(value=False, start=meta.start_pos, end=meta.end_pos, logprob=logprob_sum, prob=prob)
+
+    def null(self, meta: Meta, children: list[Token]) -> None:
+        return None
+
+    def array(self, meta: Meta, children: list[Any]) -> list[Any]:
+        return [child.value if isinstance(child, HasProb) else child for child in children]
+
+    def object(self, meta: Meta, children: list[tuple[str, Any]]) -> dict[str, Any]:
+        result = {}
+        for key, value in children:
+            if isinstance(value, HasProb):
+                result[key] = value.value
+                result[f"{key}_logprob"] = value.logprob
+                result[f"{key}_probability"] = value.prob
+            else:
+                result[key] = value
+        return result
+
+    def pair(self, meta: Meta, children: list[Any]) -> tuple[str, Any]:
+        key = children[0]
+        value = children[1]
+        if isinstance(value, Tree) and not value.children:  # e.g. ['b', Tree(Token('RULE', 'value'), [])] for bare true/false/null
+            value = None
+        return key, value
+
+    def key(self, meta: Meta, children: list[Token]) -> str:
+        return children[0][1:-1]
+
+    def start(self, meta: Meta, children: list[dict[str, Any]]) -> dict[str, Any]:
+        return children[0]
+
+json_parser = Lark(json_grammar, parser="lalr", propagate_positions=True, maybe_placeholders=False)
+
+def extract_json_data(json_string: str, tokens: list[TokenLogprob], token_indices: list[int]) -> dict[str, Any]:
+    tree = json_parser.parse(json_string)
+    extractor = Extractor(tokens, token_indices)
+    return extractor.transform(tree)
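Taken together, the module parses the model's JSON output with Lark (LALR, with source positions propagated) and annotates each key with the summed log probability of the tokens spanning its value, plus that sum exponentiated into a percentage. A usage sketch, assuming the `app` package is importable and `lark`/`pydantic` are installed; the token split and logprob values are illustrative:

from app.handlers.logprobs import map_characters_to_token_indices, extract_json_data
from app.lm.models.chat_completion import TokenLogprob

tokens = [
    TokenLogprob(token='{', logprob=-1.9365e-07),
    TokenLogprob(token='"key1"', logprob=-0.01117),
    TokenLogprob(token=': "', logprob=-0.00279),
    TokenLogprob(token='val', logprob=-1.1472e-06),
    TokenLogprob(token='ue1"', logprob=-0.00851),
    TokenLogprob(token='}', logprob=-1.265e-07),
]
json_string = "".join(t.token for t in tokens)       # '{"key1": "value1"}'
token_indices = map_characters_to_token_indices(tokens)

result = extract_json_data(json_string, tokens, token_indices)
# The '"value1"' span covers tokens 2-4 (': "', 'val', 'ue1"'), so roughly:
# {'key1': 'value1',
#  'key1_logprob': -0.0113,      # -0.00279 + -1.1472e-06 + -0.00851
#  'key1_probability': 98.88}    # math.exp(-0.0113) * 100

Note that the span sum starts at the token containing the value's opening quote, which is why the ': "' token contributes here.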

backend/app/tests/api/routes/test_dde.py
+1 -71
@@ -6,12 +6,6 @@
 import pytest
 import json
 from typing import Generator, Any
-from app.api.routes.dde import (
-    map_characters_to_token_indices,
-    find_value_spans,
-    get_token_spans_and_logprobs,
-)
-

 @pytest.fixture(scope="module")
 def document_data_extractor(
@@ -201,69 +195,5 @@ def test_update_document_data_example(
     assert response_data["data"] == json.dumps(updated_data)
     assert response_data["document_data_extractor_id"] == 1
     assert response_data["id"] == 1
-
-
-class TokenLogprob:
-    def __init__(self, token: str, logprob: float):
-        self.token = token
-        self.logprob = logprob
-
-@pytest.fixture
-def data_token():
-    return [
-        TokenLogprob(token='{', logprob=-1.9365e-07),    # Token index 0
-        TokenLogprob(token='"key1"', logprob=-0.01117),  # Token index 1
-        TokenLogprob(token=': "', logprob=-0.00279),     # Token index 2
-        TokenLogprob(token='val', logprob=-1.1472e-06),  # Token index 3
-        TokenLogprob(token='ue1"', logprob=-0.00851),    # Token index 4
-        TokenLogprob(token=', "', logprob=-0.00851),     # Token index 5
-        TokenLogprob(token='key2', logprob=-0.00851),    # Token index 6
-        TokenLogprob(token='": ', logprob=-0.00851),     # Token index 7
-        TokenLogprob(token='42', logprob=-0.00851),      # Token index 8
-        TokenLogprob(token='}', logprob=-1.265e-07)      # Token index 9
-    ]
-
-@pytest.fixture
-def token_indices():
-    return [0,
-            1, 1, 1, 1, 1, 1,
-            2, 2, 2,
-            3, 3, 3,
-            4, 4, 4, 4,
-            5, 5, 5,
-            6, 6, 6, 6,
-            7, 7, 7,
-            8, 8,
-            9]
-
-@pytest.fixture
-def sample_json_string():
-    return '{"key1": "value1", "key2": 42}'
-
-@pytest.fixture
-def value_spans():
-    return [
-        ("key1", (9, 17)),
-        ("key2", (27, 29))
-    ]
-
-def test_map_characters_to_token_indices(data_token, token_indices):
-    result = map_characters_to_token_indices(data_token)
-
-    assert result == token_indices
-    assert result.count(1) == len(data_token[1].token)
-
-def test_find_value_spans(sample_json_string, value_spans):
-    result = find_value_spans(sample_json_string)
-
-    assert result == value_spans
-    assert sample_json_string[9:17] == '"value1"'
-    assert sample_json_string[27:29] == '42'
-
-def test_get_token_spans_and_logprobs(token_indices, value_spans, data_token):
-    expected_output = {"key1": -0.0113011472, "key2": -0.00851}
-    result = get_token_spans_and_logprobs(token_indices, value_spans, data_token)
-
-    assert result == expected_output
+
 # TODO: test extract_from_file
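Until that TODO is resolved, a plausible replacement for the removed span-based tests would exercise the new handler module directly. This is a hypothetical sketch, not part of this commit; it assumes the `app` package is importable, as in the rest of this test suite:

import pytest
from app.handlers.logprobs import map_characters_to_token_indices, extract_json_data
from app.lm.models.chat_completion import TokenLogprob

def test_extract_json_data_annotates_keys():
    tokens = [
        TokenLogprob(token='{"a"', logprob=-0.1),
        TokenLogprob(token=': ', logprob=-0.2),
        TokenLogprob(token='42', logprob=-0.3),
        TokenLogprob(token='}', logprob=-0.4),
    ]
    json_string = "".join(t.token for t in tokens)   # '{"a": 42}'
    token_indices = map_characters_to_token_indices(tokens)

    result = extract_json_data(json_string, tokens, token_indices)

    assert result["a"] == 42.0                         # SIGNED_NUMBER comes back as float
    assert result["a_logprob"] == pytest.approx(-0.3)  # only the '42' token spans the value
    assert 0 < result["a_probability"] < 100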
