1
1
import json
2
2
import logging
3
3
from collections .abc import Hashable
4
- from typing import Any , Literal , Optional , Protocol
4
+ from typing import Any , Literal , Protocol
5
5
6
6
from pydantic import BaseModel
7
7
from torch import Tensor
8
8
from transformers import pipeline # type: ignore
9
9
10
- from narrative_llm_tools .rest_api_client .types import RestApiResponse
10
+ from narrative_llm_tools .rest_api_client .types import RestApiResponse , ReturnToLlmBehavior
11
11
from narrative_llm_tools .state .conversation_state import (
12
12
ConversationMessage ,
13
13
ConversationState ,
17
17
from narrative_llm_tools .tools import Tool
18
18
from narrative_llm_tools .utils .format_enforcer import get_format_enforcer
19
19
20
- logger = logging .getLogger (__name__ )
20
+ logger = logging .getLogger ("narrative-llm-tools" )
21
21
logger .setLevel (logging .WARNING )
22
22
23
+
23
24
class HandlerResponse (BaseModel ):
24
25
"""Response from the handler."""
25
26
26
27
tool_calls : list [dict [str , Any ]]
27
- warnings : Optional [ list [str ]]
28
+ warnings : list [str ] | None
28
29
29
30
30
31
class ModelConfig (BaseModel ):
@@ -34,7 +35,6 @@ class ModelConfig(BaseModel):
34
35
path : str
35
36
max_new_tokens : int = 4096
36
37
device_map : str = "auto"
37
- low_cpu_mem_usage : bool = False
38
38
begin_token : str = "<|begin_of_text|>"
39
39
eot_token : str = "<|eot_id|>"
40
40
@@ -111,14 +111,14 @@ class AuthenticationError(EndpointError):
111
111
112
112
113
113
class EndpointHandler :
114
- def __init__ (self , path : str = "" , low_cpu_mem_usage : bool = False ) -> None :
114
+ def __init__ (self , path : str = "" ) -> None :
115
115
"""
116
116
Initialize the EndpointHandler with the provided model path.
117
117
118
118
Args:
119
119
path (str, optional): The path or identifier of the model. Defaults to "".
120
120
"""
121
- self .config = ModelConfig (path = path , low_cpu_mem_usage = low_cpu_mem_usage )
121
+ self .config = ModelConfig (path = path )
122
122
123
123
try :
124
124
self .pipeline : Pipeline = self ._create_pipeline ()
@@ -135,11 +135,10 @@ def _create_pipeline(self) -> Pipeline:
135
135
model = self .config .path ,
136
136
max_new_tokens = self .config .max_new_tokens ,
137
137
device_map = self .config .device_map ,
138
- low_cpu_mem_usage = self .config .low_cpu_mem_usage ,
139
138
)
140
139
return pipe # type: ignore
141
140
142
- def __call__ (self , data : dict [str , Any ]) -> HandlerResponse :
141
+ def __call__ (self , data : dict [str , Any ]) -> dict [ str , Any ] :
143
142
"""
144
143
Generate model output given a conversation and optional tools/parameters.
145
144
@@ -295,7 +294,7 @@ def __call__(self, data: dict[str, Any]) -> HandlerResponse:
295
294
if not isinstance (tool_call , dict ):
296
295
raise ModelOutputError ("Model output is not a list of tool calls." )
297
296
298
- return HandlerResponse (tool_calls = return_msg , warnings = None )
297
+ return HandlerResponse (tool_calls = return_msg , warnings = None ). model_dump ( exclude_none = True )
299
298
300
299
except (
301
300
ValidationError ,
@@ -331,6 +330,7 @@ def _execute_tool_calls(self, tool_calls: list[Tool], state: ConversationState)
331
330
"""Execute tool calls and update conversation state."""
332
331
logger .debug (f"Executing tool calls: { tool_calls } " )
333
332
rest_api_catalog = state .get_rest_api_catalog ()
333
+ logger .info (f"Rest API catalog: { rest_api_catalog } " )
334
334
335
335
if not rest_api_catalog :
336
336
logger .info ("No rest API catalog is available, skipping all tool calls." )
@@ -344,21 +344,37 @@ def _execute_tool_calls(self, tool_calls: list[Tool], state: ConversationState)
344
344
api_client = rest_api_catalog [tool .name ]
345
345
api_response : RestApiResponse = api_client .call (tool .parameters )
346
346
api_client_behavior = (
347
- api_client .config .response_behavior .get (api_response .status )
348
- if api_client .config .response_behavior .get (api_response .status )
347
+ api_client .config .response_behavior .get (str ( api_response .status ) )
348
+ if api_client .config .response_behavior .get (str ( api_response .status ) )
349
349
else api_client .config .response_behavior .get ("default" )
350
350
)
351
351
352
- if api_response .type == "json" and api_client_behavior == "return_to_llm" :
353
- tool_responses .append (ToolResponse (name = tool .name , content = api_response .body ))
352
+ logger .info (f"API response: { api_response } , behavior: { api_client_behavior } " )
353
+ behavior_type = api_client_behavior .behavior_type if api_client_behavior else None
354
+
355
+ if (
356
+ behavior_type
357
+ and behavior_type == "return_to_llm"
358
+ ):
359
+ llm_response_behavior : ReturnToLlmBehavior = api_client_behavior # type: ignore
360
+
361
+ response = (
362
+ llm_response_behavior .response
363
+ if llm_response_behavior .response
364
+ else api_response .body
365
+ )
366
+ tool_responses .append (ToolResponse (name = tool .name , content = response ))
354
367
elif (
355
- api_response .type == "json" and api_client_behavior == "return_response_to_user"
368
+ api_response .type == "json"
369
+ and behavior_type
370
+ and behavior_type == "return_response_to_user"
356
371
):
357
372
tool_responses .append (ToolResponse (name = tool .name , content = api_response .body ))
358
373
return_to_user = True
359
374
elif (
360
375
api_response .type == "json"
361
- and api_client_behavior == "return_request_to_user"
376
+ and behavior_type
377
+ and behavior_type == "return_request_to_user"
362
378
and api_response .request
363
379
):
364
380
tool_responses .append (
0 commit comments