1
1
import os
2
2
import httpx
3
+ from typing import Any
4
+ import logging
3
5
4
- from tabpfn_client .tabpfn_classifier_interface import TabPFNClassifierInterface
6
+ from tabpfn_client .tabpfn_classifier_interface import AbstractTabPFNClassifier
7
+ from tabpfn_client .tabpfn_common_utils import utils as common_utils
5
8
6
9
SERVER_ENDPOINTS_YAML = os .path .join (os .path .dirname (__file__ ), "server_endpoints.yaml" )
7
10
8
11
9
- class TabPFNServiceClient (TabPFNClassifierInterface ):
12
+ class TabPFNServiceClient (AbstractTabPFNClassifier ):
10
13
def __init__ (self , server_spec : dict , access_token : str ):
11
14
self .host = server_spec ["host" ]
12
15
self .port = server_spec ["port" ]
@@ -16,22 +19,66 @@ def __init__(self, server_spec: dict, access_token: str):
16
19
self .access_token = access_token
17
20
self .server_endpoints = server_spec ["endpoints" ]
18
21
22
+ self .last_per_user_train_set_id = None
23
+
19
24
def remove_models_from_memory (self ):
20
25
raise NotImplementedError
21
26
22
- def load_result_minimal (self , path , i , e ):
23
- raise NotImplementedError
27
+ def fit (self , X : Any , y : Any ):
28
+ X = common_utils .serialize_to_csv_formatted_bytes (X )
29
+ y = common_utils .serialize_to_csv_formatted_bytes (y )
24
30
25
- def fit (self , X , y ):
26
- pass
31
+ response = self .client .post (
32
+ url = self .server_endpoints ["upload_train_set" ]["path" ],
33
+ headers = {"Authorization" : f"Bearer { self .access_token } " },
34
+ files = common_utils .to_httpx_post_file_format ([
35
+ ("x_file" , X ),
36
+ ("y_file" , y )
37
+ ])
38
+ )
27
39
28
- def predict (self , X ):
29
- raise NotImplementedError
40
+ if response .status_code != 200 :
41
+ logging .error (f"Fail to call upload_train_set(), response status: { response .status_code } " )
42
+ # TODO: error probably doesn't have json() method, check in unit test
43
+ logging .error (f"Fail to call fit(), server response: { response .json ()} " )
44
+ raise RuntimeError (f"Fail to call fit(), server response: { response .json ()} " )
45
+
46
+ self .last_per_user_train_set_id = response .json ()["per_user_train_set_id" ]
47
+
48
+ return self
49
+
50
+ def predict (self , X , return_winning_class = False , normalize_with_test = False ):
51
+
52
+ # TODO: handle return_winning_class and normalize_with_test
53
+
54
+ # check if user has already called fit() before
55
+ if self .last_per_user_train_set_id is None :
56
+ raise RuntimeError ("You must call fit() before calling predict()" )
57
+
58
+ X = common_utils .serialize_to_csv_formatted_bytes (X )
59
+
60
+ response = self .client .post (
61
+ url = self .server_endpoints ["predict" ]["path" ],
62
+ headers = {"Authorization" : f"Bearer { self .access_token } " },
63
+ params = {"per_user_train_set_id" : self .last_per_user_train_set_id },
64
+ files = common_utils .to_httpx_post_file_format ([
65
+ ("x_file" , X )
66
+ ])
67
+ )
68
+
69
+ if response .status_code != 200 :
70
+ logging .error (f"Fail to call predict(), response status: { response .status_code } " )
71
+ raise RuntimeError (f"Fail to call predict(), server response: { response .json ()} " )
72
+
73
+ return response .json ()
30
74
31
75
def predict_proba (self , X , return_winning_probability = False , normalize_with_test = False ):
32
76
pass
33
77
34
78
def try_root (self ):
35
- response = self .client .get ("/" )
79
+ response = self .client .get (
80
+ self .server_endpoints ["protected_root" ]["path" ],
81
+ headers = {"Authorization" : f"Bearer { self .access_token } " },
82
+ )
36
83
print ("response:" , response .json ())
37
84
return response
0 commit comments