7
7
import warnings
8
8
import logging
9
9
import functools
10
+
10
11
from tabpfn_client .client import ServiceClient
12
+ from tabpfn_client .tabpfn_common_utils .expense_estimation import estimate_duration
13
+
14
+
15
+ COST_ESTIMATION_LATENCY_OFFSET = 1.0
16
+
11
17
12
18
# For seamlessly switching between a mock mode for simulating prediction
13
19
# costs and real prediction, use thread-local variables to keep track of the
@@ -49,53 +55,6 @@ def increment_mock_time(seconds: float):
49
55
set_mock_time (get_mock_time () + seconds )
50
56
51
57
52
- def estimate_duration (
53
- num_rows : int ,
54
- num_features : int ,
55
- task : Literal ["classification" , "regression" ],
56
- tabpfn_config : dict = {},
57
- ) -> float :
58
- """
59
- Estimates the duration of a prediction task.
60
- """
61
- # Logic comes from _estimate_model_usage in base.py of the TabPFN codebase.
62
- CONSTANT_COMPUTE_OVERHEAD = 8000
63
- NUM_SAMPLES_FACTOR = 4
64
- NUM_SAMPLES_PLUS_FEATURES = 6.5
65
- CELLS_FACTOR = 0.25
66
- CELLS_SQUARED_FACTOR = 1.3e-7
67
- EMBEDDING_SIZE = 192
68
- NUM_HEADS = 6
69
- NUM_LAYERS = 12
70
- FEATURES_PER_GROUP = 2
71
- GPU_FACTOR = 1e-11
72
- LATENCY_OFFSET = 1.0
73
-
74
- n_estimators = tabpfn_config .get (
75
- "n_estimators" , 4 if task == "classification" else 8
76
- )
77
-
78
- num_samples = num_rows
79
- num_feature_groups = int (np .ceil (num_features / FEATURES_PER_GROUP ))
80
-
81
- num_cells = (num_feature_groups + 1 ) * num_samples
82
- compute_cost = (EMBEDDING_SIZE ** 2 ) * NUM_HEADS * NUM_LAYERS
83
-
84
- base_duration = (
85
- n_estimators
86
- * compute_cost
87
- * (
88
- CONSTANT_COMPUTE_OVERHEAD
89
- + num_samples * NUM_SAMPLES_FACTOR
90
- + (num_samples + num_feature_groups ) * NUM_SAMPLES_PLUS_FEATURES
91
- + num_cells * CELLS_FACTOR
92
- + num_cells ** 2 * CELLS_SQUARED_FACTOR
93
- )
94
- )
95
-
96
- return round (base_duration * GPU_FACTOR + LATENCY_OFFSET , 3 )
97
-
98
-
99
58
def mock_predict (
100
59
X_test ,
101
60
task : Literal ["classification" , "regression" ],
@@ -116,7 +75,11 @@ def mock_predict(
116
75
)
117
76
118
77
duration = estimate_duration (
119
- X_train .shape [0 ] + X_test .shape [0 ], X_test .shape [1 ], task , config
78
+ num_rows = X_train .shape [0 ] + X_test .shape [0 ],
79
+ num_features = X_test .shape [1 ],
80
+ task = task ,
81
+ tabpfn_config = config ,
82
+ latency_offset = COST_ESTIMATION_LATENCY_OFFSET , # To slightly overestimate (safer)
120
83
)
121
84
increment_mock_time (duration )
122
85
0 commit comments