4
4
from app .models import DocumentDataExtractorOut
5
5
from app .services .object_store import Documents
6
6
import pytest
7
+ import json
7
8
from typing import Generator , Any
8
9
from app .api .routes .dde import extract_tokens_indices_for_each_key , extract_logprobs_from_indices
9
10
12
13
def document_data_extractor (
13
14
client : TestClient , superuser_token_headers : dict [str , str ]
14
15
) -> Generator [dict [str , Any ], None , None ]:
15
- fake_name = "Test dde"
16
- fake_prompt = "Extract the name from document"
16
+ fake_name = "Test_dde"
17
+ fake_prompt = "Extract the adresse from document"
18
+ response_template = {
19
+ "adresse" : [
20
+ "str" ,
21
+ "required"
22
+ ]
23
+ }
17
24
18
- payload = {"name" : fake_name , "prompt" : fake_prompt }
25
+ payload = {"name" : fake_name , "prompt" : fake_prompt , 'response_template' : response_template }
19
26
20
27
headers = superuser_token_headers
28
+
29
+ response = client .post (
30
+ f"{ settings .API_V1_STR } /dde" ,
31
+ headers = headers ,
32
+ json = payload ,
33
+ )
21
34
22
35
fake_dde = DocumentDataExtractorOut (
23
36
id = 1 ,
@@ -26,11 +39,7 @@ def document_data_extractor(
26
39
timestamp = "2024-10-03T09:31:33.748765" ,
27
40
owner_id = 1 ,
28
41
document_data_examples = [],
29
- )
30
- response = client .post (
31
- f"{ settings .API_V1_STR } /dde" ,
32
- headers = headers ,
33
- json = payload ,
42
+ response_template = json .dumps (response_template )
34
43
)
35
44
36
45
assert response .status_code == 200
@@ -39,6 +48,7 @@ def document_data_extractor(
39
48
assert response_data ["name" ] == fake_dde .name
40
49
assert response_data ["prompt" ] == fake_dde .prompt
41
50
assert response_data ["owner_id" ] == fake_dde .owner_id
51
+ assert response_data ["response_template" ] == fake_dde .response_template
42
52
assert len (response_data ["document_data_examples" ]) == 0
43
53
44
54
yield response_data
@@ -73,14 +83,15 @@ def test_update_document_data_extractor(
73
83
superuser_token_headers : dict [str , str ],
74
84
document_data_extractor : dict [str , Any ],
75
85
):
76
- updated_name = "Updated dde "
86
+ updated_name = "Updated_dde "
77
87
dde_id = document_data_extractor ["id" ]
78
88
79
89
document_data_extractor ["name" ] = updated_name
80
90
81
91
update_payload = {
82
92
"name" : updated_name ,
83
93
"prompt" : document_data_extractor ["prompt" ],
94
+ "response_template" : json .loads (document_data_extractor ["response_template" ])
84
95
}
85
96
86
97
response = client .put (
@@ -97,6 +108,7 @@ def test_update_document_data_extractor(
97
108
assert response_data ["prompt" ] == document_data_extractor ["prompt" ]
98
109
assert response_data ["timestamp" ] == document_data_extractor ["timestamp" ]
99
110
assert response_data ["owner_id" ] == document_data_extractor ["owner_id" ]
111
+ assert response_data ["response_template" ] == document_data_extractor ["response_template" ]
100
112
assert len (response_data ["document_data_examples" ]) == 0
101
113
102
114
@@ -106,11 +118,11 @@ def test_create_document_data_example(client: TestClient, superuser_token_header
106
118
with patch .object (Documents , "exists" , return_value = True ):
107
119
start_page = 0
108
120
end_page = 2
109
- info_to_extract = {"name " : "Marta " }
121
+ info_to_extract = {"adresse " : "3 RUE DE ROUVRAY " }
110
122
111
123
data_doc = {
112
124
"document_id" : "abc" ,
113
- "data" : str ( info_to_extract ) ,
125
+ "data" : info_to_extract ,
114
126
"document_data_extractor_id" : document_data_extractor ["id" ],
115
127
"start_page" : start_page ,
116
128
"end_page" : end_page ,
@@ -128,7 +140,7 @@ def test_create_document_data_example(client: TestClient, superuser_token_header
128
140
assert response .status_code == 200
129
141
response_data = response .json ()
130
142
assert response_data ["document_id" ] == data_doc ["document_id" ]
131
- assert response_data ["data" ] == data_doc ["data" ]
143
+ assert response_data ["data" ] == json . dumps ( data_doc ["data" ])
132
144
assert (
133
145
response_data ["document_data_extractor_id" ]
134
146
== data_doc ["document_data_extractor_id" ]
@@ -143,7 +155,7 @@ def test_update_document_data_example(
143
155
):
144
156
name_dde = document_data_extractor ["name" ]
145
157
id_example = document_data_extractor ["document_data_examples" ][0 ]["id" ]
146
- updated_data = "{'name ': 'Sarah'}"
158
+ updated_data = { 'adresse ' : '2 ALLEE DES HORTENSIAS' }
147
159
148
160
update_payload = {
149
161
"document_id" : document_data_extractor ["document_data_examples" ][0 ][
@@ -153,6 +165,8 @@ def test_update_document_data_example(
153
165
"document_data_extractor_id" : document_data_extractor [
154
166
"document_data_examples"
155
167
][0 ]["document_data_extractor_id" ],
168
+ "start_page" :document_data_extractor ["document_data_examples" ][0 ]['start_page' ],
169
+ "end_page" :document_data_extractor ["document_data_examples" ][0 ]['end_page' ]
156
170
}
157
171
158
172
document_data_extractor ["document_data_examples" ][0 ]["data" ] = updated_data
@@ -168,7 +182,7 @@ def test_update_document_data_example(
168
182
169
183
response_data = response .json ()
170
184
assert response_data ["document_id" ] == "abc"
171
- assert response_data ["data" ] == updated_data
185
+ assert response_data ["data" ] == json . dumps ( updated_data )
172
186
assert response_data ["document_data_extractor_id" ] == 1
173
187
assert response_data ["id" ] == 1
174
188
0 commit comments