17
17
18
18
import numpy as np
19
19
import pandas as pd
20
- from pydantic import Field , validator
20
+ from pydantic import Field , field_validator , model_validator
21
21
22
22
from bofire .data_models .base import BaseModel
23
23
from bofire .data_models .constraints .api import (
@@ -57,7 +57,6 @@ class Domain(BaseModel):
57
57
58
58
inputs : Inputs = Field (default_factory = lambda : Inputs ())
59
59
outputs : Outputs = Field (default_factory = lambda : Outputs ())
60
-
61
60
constraints : Constraints = Field (default_factory = lambda : Constraints ())
62
61
63
62
"""Representation of the optimization problem/domain
@@ -84,8 +83,9 @@ def from_lists(
84
83
constraints = Constraints (constraints = constraints ),
85
84
)
86
85
87
- @validator ("inputs" , always = True , pre = True )
88
- def validate_inputs_list (cls , v , values ):
86
+ @field_validator ("inputs" , mode = "before" )
87
+ @classmethod
88
+ def validate_inputs_list (cls , v ):
89
89
if isinstance (v , collections .abc .Sequence ):
90
90
v = Inputs (features = v )
91
91
return v
@@ -94,26 +94,28 @@ def validate_inputs_list(cls, v, values):
94
94
else :
95
95
return v
96
96
97
- @validator ("outputs" , always = True , pre = True )
98
- def validate_outputs_list (cls , v , values ):
97
+ @field_validator ("outputs" , mode = "before" )
98
+ @classmethod
99
+ def validate_outputs_list (cls , v ):
99
100
if isinstance (v , collections .abc .Sequence ):
100
101
return Outputs (features = v )
101
102
if isinstance_or_union (v , AnyOutput ):
102
103
return Outputs (features = [v ])
103
104
else :
104
105
return v
105
106
106
- @validator ("constraints" , always = True , pre = True )
107
- def validate_constraints_list (cls , v , values ):
107
+ @field_validator ("constraints" , mode = "before" )
108
+ @classmethod
109
+ def validate_constraints_list (cls , v ):
108
110
if isinstance (v , list ):
109
111
return Constraints (constraints = v )
110
112
if isinstance_or_union (v , AnyConstraint ):
111
113
return Constraints (constraints = [v ])
112
114
else :
113
115
return v
114
116
115
- @validator ( "outputs" , always = True )
116
- def validate_unique_feature_keys (cls , v : Outputs , values ) -> Outputs :
117
+ @model_validator ( mode = "after" )
118
+ def validate_unique_feature_keys (self ) :
117
119
"""Validates if provided input and output feature keys are unique
118
120
119
121
Args:
@@ -126,16 +128,14 @@ def validate_unique_feature_keys(cls, v: Outputs, values) -> Outputs:
126
128
Returns:
127
129
Outputs: Keeps output features as given.
128
130
"""
129
- if "inputs" not in values :
130
- return v
131
- features = v + values ["inputs" ]
132
- keys = [f .key for f in features ]
131
+
132
+ keys = self .outputs .get_keys () + self .inputs .get_keys ()
133
133
if len (set (keys )) != len (keys ):
134
- raise ValueError ("feature keys are not unique" )
135
- return v
134
+ raise ValueError ("Feature keys are not unique" )
135
+ return self
136
136
137
- @validator ( "constraints" , always = True )
138
- def validate_constraints (cls , v , values ):
137
+ @model_validator ( mode = "after" )
138
+ def validate_constraints (self ):
139
139
"""Validate if all features included in the constraints are also defined as features for the domain.
140
140
141
141
Args:
@@ -148,18 +148,17 @@ def validate_constraints(cls, v, values):
148
148
Returns:
149
149
List[Constraint]: List of constraints defined for the domain
150
150
"""
151
- if "inputs" not in values :
152
- return v
153
- keys = [f .key for f in values ["inputs" ]]
154
- for c in v :
151
+
152
+ keys = self .inputs .get_keys ()
153
+ for c in self .constraints :
155
154
if isinstance (c , LinearConstraint ) or isinstance (c , NChooseKConstraint ):
156
155
for f in c .features :
157
156
if f not in keys :
158
157
raise ValueError (f"feature { f } in constraint unknown ({ keys } )" )
159
- return v
158
+ return self
160
159
161
- @validator ( "constraints" , always = True )
162
- def validate_linear_constraints ( cls , v , values ):
160
+ @model_validator ( mode = "after" )
161
+ def validate_linear_constraints_and_nchoosek ( self ):
163
162
"""Validate if all features included in linear constraints are continuous ones.
164
163
165
164
Args:
@@ -173,21 +172,13 @@ def validate_linear_constraints(cls, v, values):
173
172
Returns:
174
173
List[Constraint]: List of constraints defined for the domain
175
174
"""
176
- if "inputs" not in values :
177
- return v
178
-
179
- # gather continuous inputs in dictionary
180
- continuous_inputs_dict = {}
181
- for f in values ["inputs" ]:
182
- if isinstance (f , ContinuousInput ):
183
- continuous_inputs_dict [f .key ] = f
175
+ keys = self .inputs .get_keys (ContinuousInput )
184
176
185
177
# check if non continuous input features appear in linear constraints
186
- for c in v :
187
- if isinstance (c , LinearConstraint ):
188
- for f in c .features :
189
- assert f in continuous_inputs_dict , f"{ f } must be continuous."
190
- return v
178
+ for c in self .constraints .get (includes = [LinearConstraint , NChooseKConstraint ]):
179
+ for f in c .features : # type: ignore
180
+ assert f in keys , f"{ f } must be continuous."
181
+ return self
191
182
192
183
def get_feature_reps_df (self ) -> pd .DataFrame :
193
184
"""Returns a pandas dataframe describing the features contained in the optimization domain."""
@@ -617,11 +608,6 @@ def candidate_column_names(self):
617
608
]
618
609
)
619
610
620
- def _set_constraints_unvalidated (
621
- self , constraints : Union [Sequence [AnyConstraint ], Constraints ]
622
- ):
623
- """Hack for reduce_domain"""
624
- self .constraints = Constraints (constraints = [])
625
- if isinstance (constraints , Constraints ):
626
- constraints = constraints .constraints
627
- self .constraints .constraints = constraints
611
+
612
+ if __name__ == "__main__" :
613
+ pass
0 commit comments