Skip to content

Commit 66390dc

Browse files
authored
fix: align versions with inference endpoints (#8)
* chore: fight with macos
* chore: fight macos
* fix: make versions compatible with TGI
* fix: remove extraneous parameter
* chore: add debugging
* fix: use property, not entire object
* debug: use unified logger
* fix: return dict instead of pydantic model
* feat: handle non standard response behavior
* chore: add debug
* chore: normalize behaviors
* chore: add logging
1 parent a8893b0 commit 66390dc

File tree

12 files changed

+381
-332
lines changed

12 files changed

+381
-332
lines changed

narrative_llm_tools/cli/cli.py

+7
Original file line number · Diff line number · Diff line change
@@ -16,6 +16,7 @@ def main() -> None:
1616
)
1717
parser.add_argument("file", help="Path to the JSONL file to validate.")
1818
parser.add_argument("--threads", type=int, default=4, help="Number of threads to use")
19+
parser.add_argument("--clean", type=str, help="Output validated lines to specified file")
1920
parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress bar")
2021

2122
args = parser.parse_args()
@@ -57,6 +58,12 @@ def main() -> None:
5758
# Collect errors from results
5859
errors = [error for result in results for error in result.errors]
5960

61+
if args.clean:
62+
with open(args.clean, "w") as f:
63+
for result in results:
64+
if not result.errors:
65+
f.write(result.original_line + "\n")
66+
6067
if errors:
6168
print("Validation FAILED.\n")
6269
for err in sorted(errors, key=lambda x: int(x.split()[1].rstrip(":"))):

narrative_llm_tools/handlers/huggingface.py

+32-16
Original file line number · Diff line number · Diff line change
@@ -1,13 +1,13 @@
11
import json
22
import logging
33
from collections.abc import Hashable
4-
from typing import Any, Literal, Optional, Protocol
4+
from typing import Any, Literal, Protocol
55

66
from pydantic import BaseModel
77
from torch import Tensor
88
from transformers import pipeline # type: ignore
99

10-
from narrative_llm_tools.rest_api_client.types import RestApiResponse
10+
from narrative_llm_tools.rest_api_client.types import RestApiResponse, ReturnToLlmBehavior
1111
from narrative_llm_tools.state.conversation_state import (
1212
ConversationMessage,
1313
ConversationState,
@@ -17,14 +17,15 @@
1717
from narrative_llm_tools.tools import Tool
1818
from narrative_llm_tools.utils.format_enforcer import get_format_enforcer
1919

20-
logger = logging.getLogger(__name__)
20+
logger = logging.getLogger("narrative-llm-tools")
2121
logger.setLevel(logging.WARNING)
2222

23+
2324
class HandlerResponse(BaseModel):
2425
"""Response from the handler."""
2526

2627
tool_calls: list[dict[str, Any]]
27-
warnings: Optional[list[str]]
28+
warnings: list[str] | None
2829

2930

3031
class ModelConfig(BaseModel):
@@ -34,7 +35,6 @@ class ModelConfig(BaseModel):
3435
path: str
3536
max_new_tokens: int = 4096
3637
device_map: str = "auto"
37-
low_cpu_mem_usage: bool = False
3838
begin_token: str = "<|begin_of_text|>"
3939
eot_token: str = "<|eot_id|>"
4040

@@ -111,14 +111,14 @@ class AuthenticationError(EndpointError):
111111

112112

113113
class EndpointHandler:
114-
def __init__(self, path: str = "", low_cpu_mem_usage: bool = False) -> None:
114+
def __init__(self, path: str = "") -> None:
115115
"""
116116
Initialize the EndpointHandler with the provided model path.
117117
118118
Args:
119119
path (str, optional): The path or identifier of the model. Defaults to "".
120120
"""
121-
self.config = ModelConfig(path=path, low_cpu_mem_usage=low_cpu_mem_usage)
121+
self.config = ModelConfig(path=path)
122122

123123
try:
124124
self.pipeline: Pipeline = self._create_pipeline()
@@ -135,11 +135,10 @@ def _create_pipeline(self) -> Pipeline:
135135
model=self.config.path,
136136
max_new_tokens=self.config.max_new_tokens,
137137
device_map=self.config.device_map,
138-
low_cpu_mem_usage=self.config.low_cpu_mem_usage,
139138
)
140139
return pipe # type: ignore
141140

142-
def __call__(self, data: dict[str, Any]) -> HandlerResponse:
141+
def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
143142
"""
144143
Generate model output given a conversation and optional tools/parameters.
145144
@@ -295,7 +294,7 @@ def __call__(self, data: dict[str, Any]) -> HandlerResponse:
295294
if not isinstance(tool_call, dict):
296295
raise ModelOutputError("Model output is not a list of tool calls.")
297296

298-
return HandlerResponse(tool_calls=return_msg, warnings=None)
297+
return HandlerResponse(tool_calls=return_msg, warnings=None).model_dump(exclude_none=True)
299298

300299
except (
301300
ValidationError,
@@ -331,6 +330,7 @@ def _execute_tool_calls(self, tool_calls: list[Tool], state: ConversationState)
331330
"""Execute tool calls and update conversation state."""
332331
logger.debug(f"Executing tool calls: {tool_calls}")
333332
rest_api_catalog = state.get_rest_api_catalog()
333+
logger.info(f"Rest API catalog: {rest_api_catalog}")
334334

335335
if not rest_api_catalog:
336336
logger.info("No rest API catalog is available, skipping all tool calls.")
@@ -344,21 +344,37 @@ def _execute_tool_calls(self, tool_calls: list[Tool], state: ConversationState)
344344
api_client = rest_api_catalog[tool.name]
345345
api_response: RestApiResponse = api_client.call(tool.parameters)
346346
api_client_behavior = (
347-
api_client.config.response_behavior.get(api_response.status)
348-
if api_client.config.response_behavior.get(api_response.status)
347+
api_client.config.response_behavior.get(str(api_response.status))
348+
if api_client.config.response_behavior.get(str(api_response.status))
349349
else api_client.config.response_behavior.get("default")
350350
)
351351

352-
if api_response.type == "json" and api_client_behavior == "return_to_llm":
353-
tool_responses.append(ToolResponse(name=tool.name, content=api_response.body))
352+
logger.info(f"API response: {api_response}, behavior: {api_client_behavior}")
353+
behavior_type = api_client_behavior.behavior_type if api_client_behavior else None
354+
355+
if (
356+
behavior_type
357+
and behavior_type == "return_to_llm"
358+
):
359+
llm_response_behavior: ReturnToLlmBehavior = api_client_behavior # type: ignore
360+
361+
response = (
362+
llm_response_behavior.response
363+
if llm_response_behavior.response
364+
else api_response.body
365+
)
366+
tool_responses.append(ToolResponse(name=tool.name, content=response))
354367
elif (
355-
api_response.type == "json" and api_client_behavior == "return_response_to_user"
368+
api_response.type == "json"
369+
and behavior_type
370+
and behavior_type == "return_response_to_user"
356371
):
357372
tool_responses.append(ToolResponse(name=tool.name, content=api_response.body))
358373
return_to_user = True
359374
elif (
360375
api_response.type == "json"
361-
and api_client_behavior == "return_request_to_user"
376+
and behavior_type
377+
and behavior_type == "return_request_to_user"
362378
and api_response.request
363379
):
364380
tool_responses.append(

narrative_llm_tools/rest_api_client/rest_api_client.py

+1-1
Original file line number · Diff line number · Diff line change
@@ -17,7 +17,7 @@
1717
RestApiResponse,
1818
)
1919

20-
logger = logging.getLogger(__name__)
20+
logger = logging.getLogger("narrative-llm-tools")
2121

2222

2323
class RestApiClient(BaseModel):

narrative_llm_tools/rest_api_client/types.py

+6-5
Original file line number · Diff line number · Diff line change
@@ -28,6 +28,7 @@ def __eq__(self, other: Any) -> bool:
2828

2929
class Behavior(BaseModel):
3030
behavior_type: str
31+
response: str | None = None
3132

3233
def __hash__(self) -> int:
3334
return hash(self.behavior_type)
@@ -38,17 +39,17 @@ def __eq__(self, other: Any) -> bool:
3839

3940
class ReturnToLlmBehavior(Behavior):
4041
behavior_type: Literal["return_to_llm"] = "return_to_llm"
41-
llm_response: str | None = None
42+
response: str | None = None
4243

4344

4445
class ReturnResponseToUserBehavior(Behavior):
4546
behavior_type: Literal["return_response_to_user"] = "return_response_to_user"
46-
user_response: str | None = None
47+
response: str | None = None
4748

4849

4950
class ReturnRequestToUserBehavior(Behavior):
5051
behavior_type: Literal["return_request_to_user"] = "return_request_to_user"
51-
user_response: str | None = None
52+
response: str | None = None
5253

5354

5455
class RestApiResponse(BaseModel):
@@ -62,8 +63,8 @@ class RestApiConfig(BaseModel):
6263
url: str
6364
method: HttpMethod
6465
auth: BearerTokenAuth | None = None
65-
response_behavior: dict[int | Literal["default"], Behavior] = {
66-
"default": ReturnToLlmBehavior(llm_response=None),
66+
response_behavior: dict[str | Literal["default"], Behavior] = {
67+
"default": ReturnToLlmBehavior(response=None),
6768
}
6869
query_path: str | None = None
6970
parameter_location: ParameterLocation | None = None

narrative_llm_tools/state/conversation.py

+32-23
Original file line number · Diff line number · Diff line change
@@ -21,15 +21,13 @@ def validate_conversation_structure(self) -> "Conversation":
2121
conv = self.conversations
2222

2323
if len(conv) < 3:
24-
all_errors.append(
25-
f"'conversation' must have at least 3 messages. Found {len(conv)}."
26-
)
24+
all_errors.append(f"'conversation' must have at least 3 messages. Found {len(conv)}.")
2725

2826
system_count = 0
2927
tool_catalog_count = 0
3028
last_role = None
3129
found_system = False
32-
tool_catalog_schema = None
30+
tool_catalog_schema = None
3331
assistant_call_indices = []
3432
user_count = 0
3533

@@ -134,15 +132,22 @@ def validate_conversation_structure(self) -> "Conversation":
134132
_, prev_arr = parse_json_array(prev_content, tool_catalog_schema)
135133

136134
if len(arr) != len(prev_arr):
137-
msg_errors.append("tool_response array length must match the preceding 'assistant'/'tool_call' array.")
135+
msg_errors.append(
136+
"tool_response array length must match the preceding "
137+
"'assistant'/'tool_call' array."
138+
)
138139
else:
139-
for idx, (response, prev_call) in enumerate(zip(arr, prev_arr)):
140+
for idx, (response, prev_call) in enumerate(
141+
zip(arr, prev_arr, strict=False)
142+
):
140143
structure_errors = validate_tool_response_structure(response, idx)
141144
if structure_errors:
142145
msg_errors.extend(structure_errors)
143146
continue
144-
145-
matching_errors = validate_tool_response_matching(response, prev_call, idx)
147+
148+
matching_errors = validate_tool_response_matching(
149+
response, prev_call, idx
150+
)
146151
msg_errors.extend(matching_errors)
147152

148153
all_errors.extend(msg_errors)
@@ -218,20 +223,25 @@ def validate_conversation_object(obj: Any, line_number: int) -> list[str]:
218223
class ValidationResult:
219224
line_number: int
220225
errors: list[str]
226+
original_line: str
221227

222228

223229
def validate_line(args: tuple[str, int]) -> ValidationResult:
224230
line, line_number = args
225231
if not line.strip():
226-
return ValidationResult(line_number, [f"Line {line_number}: Empty line is not allowed."])
232+
return ValidationResult(
233+
line_number,
234+
[f"Line {line_number}: Empty line is not allowed."],
235+
line,
236+
)
227237

228238
try:
229239
conversation_obj = json.loads(line)
230240
except json.JSONDecodeError as e:
231-
return ValidationResult(line_number, [f"Line {line_number}: Invalid JSON - {str(e)}"])
241+
return ValidationResult(line_number, [f"Line {line_number}: Invalid JSON - {str(e)}"], line)
232242

233243
line_errors = validate_conversation_object(conversation_obj, line_number)
234-
return ValidationResult(line_number, line_errors)
244+
return ValidationResult(line_number, line_errors, line)
235245

236246

237247
def extract_enumerated_names(tool_catalog_schema: Mapping[str, Any]) -> set[str]:
@@ -291,33 +301,32 @@ def validate_tool_catalog_schema(schema_str: str) -> tuple[Any, list[str]]:
291301
return None, errors
292302

293303

294-
def validate_tool_response_structure(response: dict, idx: int) -> list[str]:
304+
def validate_tool_response_structure(response: dict[str, Any], idx: int) -> list[str]:
295305
"""Validate the structure of a single tool response object."""
296306
errors = []
297-
307+
298308
if not isinstance(response, dict):
299-
errors.append(f"Response at index {idx} must be an object")
309+
errors.append(f"Response at index {idx} must be an object") # type: ignore[unreachable]
300310
return errors
301311

302312
if set(response.keys()) != {"name", "content"}:
303-
errors.append(
304-
f"Response at index {idx} must have exactly 'name' and 'content' fields"
305-
)
313+
errors.append(f"Response at index {idx} must have exactly 'name' and 'content' fields")
306314
return errors
307315

308316
if not isinstance(response["name"], str) or not isinstance(response["content"], str):
309-
errors.append(
310-
f"Response at index {idx}: 'name' and 'content' must be strings"
311-
)
312-
317+
errors.append(f"Response at index {idx}: 'name' and 'content' must be strings")
318+
313319
return errors
314320

315-
def validate_tool_response_matching(response: dict, prev_call: dict, idx: int) -> list[str]:
321+
322+
def validate_tool_response_matching(
323+
response: dict[str, Any], prev_call: dict[str, Any], idx: int
324+
) -> list[str]:
316325
"""Validate that a tool response matches its corresponding tool call."""
317326
errors = []
318327
if response["name"] != prev_call.get("name"):
319328
errors.append(
320329
f"Response at index {idx}: name '{response['name']}' does not match "
321330
f"tool call name '{prev_call.get('name')}'"
322331
)
323-
return errors
332+
return errors

narrative_llm_tools/state/conversation_state.py

+3-1
Original file line number · Diff line number · Diff line change
@@ -8,7 +8,7 @@
88
from narrative_llm_tools.rest_api_client.rest_api_client import RestApiClient
99
from narrative_llm_tools.tools.json_schema_tools import JsonSchemaTools, Tool
1010

11-
logger = logging.getLogger(__name__)
11+
logger = logging.getLogger("narrative-llm-tools")
1212

1313

1414
class ConversationMessage(BaseModel):
@@ -285,6 +285,7 @@ def _handle_tool_call(self, message: ConversationMessage) -> None:
285285
Handles adding a tool_call message and performing relevant state transitions.
286286
"""
287287
tool_calls = self.parse_tool_calls_content(message.content)
288+
logger.info(f"Handling tool call: {message}")
288289
self.raw_messages.append(message)
289290

290291
if self.responded_to_user(message.content):
@@ -300,6 +301,7 @@ def _handle_tool_response(self, message: ConversationMessage) -> None:
300301
"""
301302
Handles adding a tool response message and updating state accordingly.
302303
"""
304+
logger.info(f"Handling tool response: {message}")
303305
self.raw_messages.append(message)
304306

305307
if self.status == ConversationStatus.WAITING_TOOL_RESPONSE:

narrative_llm_tools/state/messages.py

+3-7
Original file line number · Diff line number · Diff line change
@@ -45,10 +45,8 @@ def validate_value(cls, v: Any) -> str:
4545
if isinstance(v, str):
4646
return v
4747
raise ValueError(f"Message 'content' must be a string, got {type(v)}")
48-
49-
model_config = {
50-
'extra': 'forbid'
51-
}
48+
49+
model_config = {"extra": "forbid"}
5250

5351

5452
class SystemMessage(BaseMessage):
@@ -126,6 +124,4 @@ def validate_response(cls, v: Any) -> str:
126124
class MessageWrapper(BaseModel):
127125
message: Message
128126

129-
model_config = {
130-
'extra': 'forbid'
131-
}
127+
model_config = {"extra": "forbid"}

narrative_llm_tools/tools/json_schema_tools.py

+1-1
Original file line number · Diff line number · Diff line change
@@ -6,7 +6,7 @@
66
from narrative_llm_tools.rest_api_client.rest_api_client import RestApiClient
77
from narrative_llm_tools.rest_api_client.types import RestApiConfig
88

9-
logger = logging.getLogger(__name__)
9+
logger = logging.getLogger("narrative-llm-tools")
1010

1111

1212
class NameProperty(BaseModel):

narrative_llm_tools/utils/format_enforcer.py

+1-1
Original file line number · Diff line number · Diff line change
@@ -13,7 +13,7 @@
1313

1414
from narrative_llm_tools.tools.json_schema_tools import JsonSchemaTools
1515

16-
logger = logging.getLogger(__name__)
16+
logger = logging.getLogger("narrative-llm-tools")
1717

1818

1919
class TransformersPrefixAllowedTokensFn(Protocol):

0 commit comments

Comments (0)