Skip to content

Commit cba95a2

Browse files
committed
fix tests
1 parent d867f32 commit cba95a2

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

backend/app/tests/api/routes/test_dde.py

+28-14
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from app.models import DocumentDataExtractorOut
55
from app.services.object_store import Documents
66
import pytest
7+
import json
78
from typing import Generator, Any
89
from app.api.routes.dde import extract_tokens_indices_for_each_key, extract_logprobs_from_indices
910

@@ -12,12 +13,24 @@
1213
def document_data_extractor(
1314
client: TestClient, superuser_token_headers: dict[str, str]
1415
) -> Generator[dict[str, Any], None, None]:
15-
fake_name = "Test dde"
16-
fake_prompt = "Extract the name from document"
16+
fake_name = "Test_dde"
17+
fake_prompt = "Extract the adresse from document"
18+
response_template = {
19+
"adresse": [
20+
"str",
21+
"required"
22+
]
23+
}
1724

18-
payload = {"name": fake_name, "prompt": fake_prompt}
25+
payload = {"name": fake_name, "prompt": fake_prompt, 'response_template':response_template}
1926

2027
headers = superuser_token_headers
28+
29+
response = client.post(
30+
f"{settings.API_V1_STR}/dde",
31+
headers=headers,
32+
json=payload,
33+
)
2134

2235
fake_dde = DocumentDataExtractorOut(
2336
id=1,
@@ -26,11 +39,7 @@ def document_data_extractor(
2639
timestamp="2024-10-03T09:31:33.748765",
2740
owner_id=1,
2841
document_data_examples=[],
29-
)
30-
response = client.post(
31-
f"{settings.API_V1_STR}/dde",
32-
headers=headers,
33-
json=payload,
42+
response_template = json.dumps(response_template)
3443
)
3544

3645
assert response.status_code == 200
@@ -39,6 +48,7 @@ def document_data_extractor(
3948
assert response_data["name"] == fake_dde.name
4049
assert response_data["prompt"] == fake_dde.prompt
4150
assert response_data["owner_id"] == fake_dde.owner_id
51+
assert response_data["response_template"] == fake_dde.response_template
4252
assert len(response_data["document_data_examples"]) == 0
4353

4454
yield response_data
@@ -73,14 +83,15 @@ def test_update_document_data_extractor(
7383
superuser_token_headers: dict[str, str],
7484
document_data_extractor: dict[str, Any],
7585
):
76-
updated_name = "Updated dde"
86+
updated_name = "Updated_dde"
7787
dde_id = document_data_extractor["id"]
7888

7989
document_data_extractor["name"] = updated_name
8090

8191
update_payload = {
8292
"name": updated_name,
8393
"prompt": document_data_extractor["prompt"],
94+
"response_template": json.loads(document_data_extractor["response_template"])
8495
}
8596

8697
response = client.put(
@@ -97,6 +108,7 @@ def test_update_document_data_extractor(
97108
assert response_data["prompt"] == document_data_extractor["prompt"]
98109
assert response_data["timestamp"] == document_data_extractor["timestamp"]
99110
assert response_data["owner_id"] == document_data_extractor["owner_id"]
111+
assert response_data["response_template"] == document_data_extractor["response_template"]
100112
assert len(response_data["document_data_examples"]) == 0
101113

102114

@@ -106,11 +118,11 @@ def test_create_document_data_example(client: TestClient, superuser_token_header
106118
with patch.object(Documents, "exists", return_value=True):
107119
start_page = 0
108120
end_page = 2
109-
info_to_extract = {"name": "Marta"}
121+
info_to_extract = {"adresse": "3 RUE DE ROUVRAY"}
110122

111123
data_doc = {
112124
"document_id": "abc",
113-
"data": str(info_to_extract),
125+
"data": info_to_extract,
114126
"document_data_extractor_id": document_data_extractor["id"],
115127
"start_page": start_page,
116128
"end_page": end_page,
@@ -128,7 +140,7 @@ def test_create_document_data_example(client: TestClient, superuser_token_header
128140
assert response.status_code == 200
129141
response_data = response.json()
130142
assert response_data["document_id"] == data_doc["document_id"]
131-
assert response_data["data"] == data_doc["data"]
143+
assert response_data["data"] == json.dumps(data_doc["data"])
132144
assert (
133145
response_data["document_data_extractor_id"]
134146
== data_doc["document_data_extractor_id"]
@@ -143,7 +155,7 @@ def test_update_document_data_example(
143155
):
144156
name_dde = document_data_extractor["name"]
145157
id_example = document_data_extractor["document_data_examples"][0]["id"]
146-
updated_data = "{'name': 'Sarah'}"
158+
updated_data = {'adresse': '2 ALLEE DES HORTENSIAS'}
147159

148160
update_payload = {
149161
"document_id": document_data_extractor["document_data_examples"][0][
@@ -153,6 +165,8 @@ def test_update_document_data_example(
153165
"document_data_extractor_id": document_data_extractor[
154166
"document_data_examples"
155167
][0]["document_data_extractor_id"],
168+
"start_page":document_data_extractor["document_data_examples"][0]['start_page'],
169+
"end_page":document_data_extractor["document_data_examples"][0]['end_page']
156170
}
157171

158172
document_data_extractor["document_data_examples"][0]["data"] = updated_data
@@ -168,7 +182,7 @@ def test_update_document_data_example(
168182

169183
response_data = response.json()
170184
assert response_data["document_id"] == "abc"
171-
assert response_data["data"] == updated_data
185+
assert response_data["data"] == json.dumps(updated_data)
172186
assert response_data["document_data_extractor_id"] == 1
173187
assert response_data["id"] == 1
174188

0 commit comments

Comments
 (0)