Skip to content

Commit 2218f0f

Browse files
yogirajgutteLucas-PratesGui-FernandesBR
authored
MNT: move piecewise functions to separate file (#746)
* MNT: move piecewise functions to separate file closes #667 * improved import for linting * MNT: applying code formaters * ENH: simplifying and optimizing the function, implementing tests. * MNT: update changelog and apply changes suggested in review --------- Co-authored-by: Lucas Prates <[email protected]> Co-authored-by: Lucas de Oliveira Prates <[email protected]> Co-authored-by: Gui-FernandesBR <[email protected]>
1 parent 7a122ad commit 2218f0f

File tree

6 files changed

+134
-110
lines changed

6 files changed

+134
-110
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Attention: The newest changes should be on top -->
4242

4343
### Changed
4444

45+
- MNT: move piecewise functions to separate file [#746](https://github.com/RocketPy-Team/RocketPy/pull/746)
4546
- DOC: flight comparison improvements [#755](https://github.com/RocketPy-Team/RocketPy/pull/755)
4647

4748
### Fixed

rocketpy/mathutils/__init__.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
from .function import (
2-
Function,
3-
PiecewiseFunction,
4-
funcify_method,
5-
reset_funcified_methods,
6-
)
1+
from .function import Function, funcify_method, reset_funcified_methods
2+
from .piecewise_function import PiecewiseFunction
73
from .vector_matrix import Matrix, Vector

rocketpy/mathutils/function.py

-103
Original file line numberDiff line numberDiff line change
@@ -3419,109 +3419,6 @@ def __validate_extrapolation(self, extrapolation):
34193419
return extrapolation
34203420

34213421

3422-
class PiecewiseFunction(Function):
3423-
"""Class for creating piecewise functions. These kind of functions are
3424-
defined by a dictionary of functions, where the keys are tuples that
3425-
represent the domain of the function. The domains must be disjoint.
3426-
"""
3427-
3428-
def __new__(
3429-
cls,
3430-
source,
3431-
inputs=None,
3432-
outputs=None,
3433-
interpolation="spline",
3434-
extrapolation=None,
3435-
datapoints=100,
3436-
):
3437-
"""
3438-
Creates a piecewise function from a dictionary of functions. The keys of
3439-
the dictionary must be tuples that represent the domain of the function.
3440-
The domains must be disjoint. The piecewise function will be evaluated
3441-
at datapoints points to create Function object.
3442-
3443-
Parameters
3444-
----------
3445-
source: dictionary
3446-
A dictionary of Function objects, where the keys are the domains.
3447-
inputs : list of strings
3448-
A list of strings that represent the inputs of the function.
3449-
outputs: list of strings
3450-
A list of strings that represent the outputs of the function.
3451-
interpolation: str
3452-
The type of interpolation to use. The default value is 'spline'.
3453-
extrapolation: str
3454-
The type of extrapolation to use. The default value is None.
3455-
datapoints: int
3456-
The number of points in which the piecewise function will be
3457-
evaluated to create a base function. The default value is 100.
3458-
"""
3459-
if inputs is None:
3460-
inputs = ["Scalar"]
3461-
if outputs is None:
3462-
outputs = ["Scalar"]
3463-
# Check if source is a dictionary
3464-
if not isinstance(source, dict):
3465-
raise TypeError("source must be a dictionary")
3466-
# Check if all keys are tuples
3467-
for key in source.keys():
3468-
if not isinstance(key, tuple):
3469-
raise TypeError("keys of source must be tuples")
3470-
# Check if all domains are disjoint
3471-
for key1 in source.keys():
3472-
for key2 in source.keys():
3473-
if key1 != key2:
3474-
if key1[0] < key2[1] and key1[1] > key2[0]:
3475-
raise ValueError("domains must be disjoint")
3476-
3477-
# Crate Function
3478-
def calc_output(func, inputs):
3479-
"""Receives a list of inputs value and a function, populates another
3480-
list with the results corresponding to the same results.
3481-
3482-
Parameters
3483-
----------
3484-
func : Function
3485-
The Function object to be
3486-
inputs : list, tuple, np.array
3487-
The array of points to applied the func to.
3488-
3489-
Examples
3490-
--------
3491-
>>> inputs = [0, 1, 2, 3, 4, 5]
3492-
>>> def func(x):
3493-
... return x*10
3494-
>>> calc_output(func, inputs)
3495-
[0, 10, 20, 30, 40, 50]
3496-
3497-
Notes
3498-
-----
3499-
In the future, consider using the built-in map function from python.
3500-
"""
3501-
output = np.zeros(len(inputs))
3502-
for j, value in enumerate(inputs):
3503-
output[j] = func.get_value_opt(value)
3504-
return output
3505-
3506-
input_data = []
3507-
output_data = []
3508-
for key in sorted(source.keys()):
3509-
i = np.linspace(key[0], key[1], datapoints)
3510-
i = i[~np.isin(i, input_data)]
3511-
input_data = np.concatenate((input_data, i))
3512-
3513-
f = Function(source[key])
3514-
output_data = np.concatenate((output_data, calc_output(f, i)))
3515-
3516-
return Function(
3517-
np.concatenate(([input_data], [output_data])).T,
3518-
inputs=inputs,
3519-
outputs=outputs,
3520-
interpolation=interpolation,
3521-
extrapolation=extrapolation,
3522-
)
3523-
3524-
35253422
def funcify_method(*args, **kwargs): # pylint: disable=too-many-statements
35263423
"""Decorator factory to wrap methods as Function objects and save them as
35273424
cached properties.
+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import numpy as np
2+
3+
from rocketpy.mathutils.function import Function
4+
5+
6+
class PiecewiseFunction(Function):
7+
"""Class for creating piecewise functions. These kind of functions are
8+
defined by a dictionary of functions, where the keys are tuples that
9+
represent the domain of the function. The domains must be disjoint.
10+
"""
11+
12+
def __new__(
13+
cls,
14+
source,
15+
inputs=None,
16+
outputs=None,
17+
interpolation="spline",
18+
extrapolation=None,
19+
datapoints=100,
20+
):
21+
"""
22+
Creates a piecewise function from a dictionary of functions. The keys of
23+
the dictionary must be tuples that represent the domain of the function.
24+
The domains must be disjoint. The piecewise function will be evaluated
25+
at datapoints points to create Function object.
26+
27+
Parameters
28+
----------
29+
source: dictionary
30+
A dictionary of Function objects, where the keys are the domains.
31+
inputs : list of strings
32+
A list of strings that represent the inputs of the function.
33+
outputs: list of strings
34+
A list of strings that represent the outputs of the function.
35+
interpolation: str
36+
The type of interpolation to use. The default value is 'spline'.
37+
extrapolation: str
38+
The type of extrapolation to use. The default value is None.
39+
datapoints: int
40+
The number of points in which the piecewise function will be
41+
evaluated to create a base function. The default value is 100.
42+
"""
43+
cls.__validate__source(source)
44+
if inputs is None:
45+
inputs = ["Scalar"]
46+
if outputs is None:
47+
outputs = ["Scalar"]
48+
49+
input_data = np.array([])
50+
output_data = np.array([])
51+
for lower, upper in sorted(source.keys()):
52+
grid = np.linspace(lower, upper, datapoints)
53+
54+
# since intervals are disjoint and sorted, we only need to check
55+
# if the first point is already included
56+
if input_data.size != 0:
57+
if lower == input_data[-1]:
58+
grid = np.delete(grid, 0)
59+
input_data = np.concatenate((input_data, grid))
60+
61+
f = Function(source[(lower, upper)])
62+
output_data = np.concatenate((output_data, f.get_value(grid)))
63+
64+
return Function(
65+
np.concatenate(([input_data], [output_data])).T,
66+
inputs=inputs,
67+
outputs=outputs,
68+
interpolation=interpolation,
69+
extrapolation=extrapolation,
70+
)
71+
72+
@staticmethod
73+
def __validate__source(source):
74+
"""Validates that source is dictionary with non-overlapping
75+
intervals
76+
77+
Parameters
78+
----------
79+
source : dict
80+
A dictionary of Function objects, where the keys are the domains.
81+
"""
82+
# Check if source is a dictionary
83+
if not isinstance(source, dict):
84+
raise TypeError("source must be a dictionary")
85+
# Check if all keys are tuples
86+
for key in source.keys():
87+
if not isinstance(key, tuple):
88+
raise TypeError("keys of source must be tuples")
89+
# Check if all domains are disjoint
90+
for lower1, upper1 in source.keys():
91+
for lower2, upper2 in source.keys():
92+
if (lower1, upper1) != (lower2, upper2):
93+
if lower1 < upper2 and upper1 > lower2:
94+
raise ValueError("domains must be disjoint")

rocketpy/motors/tank_geometry.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import numpy as np
44

5-
from ..mathutils.function import Function, PiecewiseFunction, funcify_method
5+
from ..mathutils.function import Function, funcify_method
6+
from ..mathutils.piecewise_function import PiecewiseFunction
67
from ..plots.tank_geometry_plots import _TankGeometryPlots
78
from ..prints.tank_geometry_prints import _TankGeometryPrints
89

tests/unit/test_piecewise_function.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import pytest
2+
3+
from rocketpy import PiecewiseFunction
4+
5+
6+
@pytest.mark.parametrize(
7+
"source",
8+
[
9+
((0, 4), lambda x: x),
10+
{"0-4": lambda x: x},
11+
{(0, 4): lambda x: x, (3, 5): lambda x: 2 * x},
12+
],
13+
)
14+
def test_invalid_source(source):
15+
"""Test an error is raised when the source parameter is invalid"""
16+
with pytest.raises((TypeError, ValueError)):
17+
PiecewiseFunction(source)
18+
19+
20+
@pytest.mark.parametrize(
21+
"source",
22+
[
23+
{(-1, 0): lambda x: -x, (0, 1): lambda x: x},
24+
{
25+
(0, 1): lambda x: x,
26+
(1, 2): lambda x: 1,
27+
(2, 3): lambda x: 3 - x,
28+
},
29+
],
30+
)
31+
@pytest.mark.parametrize("inputs", [None, "X"])
32+
@pytest.mark.parametrize("outputs", [None, "Y"])
33+
def test_new(source, inputs, outputs):
34+
"""Test if PiecewiseFunction.__new__ runs correctly"""
35+
PiecewiseFunction(source, inputs, outputs)

0 commit comments

Comments
 (0)