Skip to content

Commit daaf4d1

Browse files
authoredJan 31, 2025··
Merge: Improve GP Fit (#472)
- scipy fit is used for all cases now - instead, the MLL type is switched based on whether TL is active or not
2 parents fb6e7d8 + 190bbe9 commit daaf4d1

File tree

3 files changed

+15
-7
lines changed

3 files changed

+15
-7
lines changed
 

‎CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7+
## [0.12.2] - 2025-01-31
8+
### Changed
9+
- More robust settings for the GP fitting
10+
711
## [0.12.1] - 2025-01-29
812
### Changed
913
- Default of `allow_recommending_already_recommended` is changed back to `False`

‎baybe/surrogates/gaussian_process/core.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -201,15 +201,19 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
201201
covar_module=covar_module,
202202
likelihood=likelihood,
203203
)
204-
mll = gpytorch.ExactMarginalLogLikelihood(self._model.likelihood, self._model)
205204

206-
# TODO: This is a simple temporary workaround to avoid model overfitting
207-
# via early stopping in the presence of task parameters, which currently
208-
# have no prior configured.
205+
# TODO: This is still a temporary workaround to avoid overfitting seen in
206+
# low-dimensional TL cases. More robust settings are being researched.
209207
if context.n_task_dimensions > 0:
210-
botorch.optim.fit.fit_gpytorch_mll_torch(mll, step_limit=200)
208+
mll = gpytorch.mlls.LeaveOneOutPseudoLikelihood(
209+
self._model.likelihood, self._model
210+
)
211211
else:
212-
botorch.fit.fit_gpytorch_mll(mll)
212+
mll = gpytorch.ExactMarginalLogLikelihood(
213+
self._model.likelihood, self._model
214+
)
215+
216+
botorch.fit.fit_gpytorch_mll(mll)
213217

214218
@override
215219
def __str__(self) -> str:

‎baybe/surrogates/naive.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _estimate_moments(
3737
import torch
3838

3939
# TODO: use target value bounds for covariance scaling when explicitly provided
40-
mean = self._model * torch.ones([len(candidates_comp_scaled)])
40+
mean = self._model * torch.ones([len(candidates_comp_scaled)]) # type: ignore[operator]
4141
var = torch.ones(len(candidates_comp_scaled))
4242
return mean, var
4343

0 commit comments

Comments
 (0)
Please sign in to comment.