@@ -466,25 +466,28 @@ def bin_trim(self, prompt: str, num_token: int) -> str:
 
 class OpenAISDK(OpenAI):
 
-    def __init__(self,
-                 path: str = 'gpt-3.5-turbo',
-                 max_seq_len: int = 4096,
-                 query_per_second: int = 1,
-                 rpm_verbose: bool = False,
-                 retry: int = 2,
-                 key: str | List[str] = 'ENV',
-                 org: str | List[str] | None = None,
-                 meta_template: Dict | None = None,
-                 openai_api_base: str = OPENAI_API_BASE,
-                 openai_proxy_url: Optional[str] = None,
-                 mode: str = 'none',
-                 logprobs: bool | None = False,
-                 top_logprobs: int | None = None,
-                 temperature: float | None = None,
-                 tokenizer_path: str | None = None,
-                 extra_body: Dict | None = None,
-                 max_completion_tokens: int = 16384,
-                 verbose: bool = False):
+    def __init__(
+        self,
+        path: str = 'gpt-3.5-turbo',
+        max_seq_len: int = 4096,
+        query_per_second: int = 1,
+        rpm_verbose: bool = False,
+        retry: int = 2,
+        key: str | List[str] = 'ENV',
+        org: str | List[str] | None = None,
+        meta_template: Dict | None = None,
+        openai_api_base: str = OPENAI_API_BASE,
+        openai_proxy_url: Optional[str] = None,
+        mode: str = 'none',
+        logprobs: bool | None = False,
+        top_logprobs: int | None = None,
+        temperature: float | None = None,
+        tokenizer_path: str | None = None,
+        extra_body: Dict | None = None,
+        max_completion_tokens: int = 16384,
+        verbose: bool = False,
+        status_code_mappings: dict = {},
+    ):
         super().__init__(path,
                          max_seq_len,
                          query_per_second,
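
Aside from reflowing the signature, the only functional change in this hunk is the new trailing status_code_mappings parameter. A minimal usage sketch, assuming OpenAISDK is exported from opencompass.models (an assumption; adjust the import to wherever the class lives in your checkout):

# Illustrative only, not part of the diff: map HTTP 400 responses to a
# fixed string that _generate() returns instead of raising.
from opencompass.models import OpenAISDK

model = OpenAISDK(
    path='gpt-3.5-turbo',
    key='ENV',  # resolve the API key from the environment
    status_code_mappings={
        400: 'Input data may contain inappropriate content.',
    },
)
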
@@ -519,9 +522,11 @@ def __init__(self,
                 http_client=httpx.Client(proxies=proxies))
         if self.verbose:
             self.logger.info(f'Used openai_client: {self.openai_client}')
+        self.status_code_mappings = status_code_mappings
 
     def _generate(self, input: PromptList | str, max_out_len: int,
                   temperature: float) -> str:
+        from openai import BadRequestError
         assert isinstance(input, (str, PromptList))
 
         # max num token for gpt-3.5-turbo is 4097
@@ -605,7 +610,30 @@ def _generate(self, input: PromptList | str, max_out_len: int,
                         self.logger.info(responses)
                     except Exception as e:  # noqa F841
                         pass
+                if not responses.choices:
+                    self.logger.error(
+                        'Response is empty; it is an internal server error '
+                        'from the API provider.')
                 return responses.choices[0].message.content
+
+            except BadRequestError as e:
+                # Handle the BadRequest status. You can specify
+                # self.status_code_mappings to bypass API sensitivity
+                # blocks. For example:
+                # status_code_mappings={
+                #     400: 'Input data may contain inappropriate content.'}
+                status_code = e.status_code
+                if (status_code is not None
+                        and status_code in self.status_code_mappings):
+                    original_error_message = e.body.get('message')
+                    error_message = self.status_code_mappings[status_code]
+                    self.logger.info(
+                        f'Status Code: {status_code}, '
+                        f'Original Error Message: {original_error_message}, '
+                        f'Return Message: {error_message}')
+                    return error_message
+                else:
+                    self.logger.error(e)
             except Exception as e:
                 self.logger.error(e)
             num_retries += 1
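
The new handler reads only two attributes of openai.BadRequestError: status_code and body, both exposed by the openai v1 SDK's APIStatusError base class. A standalone sketch of the lookup logic, using a hypothetical stand-in error so the behaviour can be seen without a live API call:

# Illustrative stand-in; mirrors only the fields the diff reads.
status_code_mappings = {400: 'Input data may contain inappropriate content.'}

class FakeBadRequestError(Exception):  # hypothetical, for demonstration
    status_code = 400
    body = {'message': 'content filter triggered'}

e = FakeBadRequestError()
if e.status_code is not None and e.status_code in status_code_mappings:
    print(status_code_mappings[e.status_code])
    # prints: Input data may contain inappropriate content.

When the status code is not in the mapping, the handler only logs the exception; control then falls through to num_retries += 1 at the bottom of the retry loop, so unmapped BadRequestErrors are retried like any other failure, while mapped ones short-circuit with the configured message.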