Skip to content

Commit

Permalink
Address skorch-dev#925
Browse files Browse the repository at this point in the history
Changing the `_get_param_names` method to return a list instead of a
generator to fix the exception error message when passing unknown
parameters to `set_params`. Before the error message just included
the generator `repr`-string as the list of possible parameters.
Now the string contains the possible parameter names instead.
  • Loading branch information
nemo committed Dec 20, 2022
1 parent af18eea commit f9e611f
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

### Fixed
- `_get_param_names` returns a list instead of a generator so that subsequent
error messages return useful information instead of a generator `repr`
string (#925)

## [0.12.1] - 2022-11-18

Expand Down
2 changes: 1 addition & 1 deletion skorch/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def on_grad_computed(
"""

def _get_param_names(self):
return (key for key in self.__dict__ if not key.endswith('_'))
return [key for key in self.__dict__ if not key.endswith('_')]

def get_params(self, deep=True):
return BaseEstimator.get_params(self, deep=deep)
Expand Down
2 changes: 1 addition & 1 deletion skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -1857,7 +1857,7 @@ def get_params_for_optimizer(self, prefix, named_parameters):
return args, kwargs

def _get_param_names(self):
return (k for k in self.__dict__ if not k.endswith('_'))
return [k for k in self.__dict__ if not k.endswith('_')]

def _get_params_callbacks(self, deep=True):
"""sklearn's .get_params checks for `hasattr(value,
Expand Down

0 comments on commit f9e611f

Please sign in to comment.