22
22
RBFInterpolator ,
23
23
)
24
24
25
+ from ..plots .plot_helpers import show_or_save_plot
26
+
25
27
# Numpy 1.x compatibility,
26
28
# TODO: remove these lines when all dependencies support numpy>=2.0.0
27
29
if np .lib .NumpyVersion (np .__version__ ) >= "2.0.0b1" :
@@ -1378,7 +1380,7 @@ def remove_outliers_iqr(self, threshold=1.5):
1378
1380
)
1379
1381
1380
1382
# Define all presentation methods
1381
- def __call__ (self , * args ):
1383
+ def __call__ (self , * args , filename = None ):
1382
1384
"""Plot the Function if no argument is given. If an
1383
1385
argument is given, return the value of the function at the desired
1384
1386
point.
@@ -1392,13 +1394,18 @@ def __call__(self, *args):
1392
1394
evaluated at all points in the list and a list of floats will be
1393
1395
returned. If the function is N-D, N arguments must be given, each
1394
1396
one being an scalar or list.
1397
+ filename : str | None, optional
1398
+ The path the plot should be saved to. By default None, in which case
1399
+ the plot will be shown instead of saved. Supported file endings are:
1400
+ eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff
1401
+ and webp (these are the formats supported by matplotlib).
1395
1402
1396
1403
Returns
1397
1404
-------
1398
1405
ans : None, scalar, list
1399
1406
"""
1400
1407
if len (args ) == 0 :
1401
- return self .plot ()
1408
+ return self .plot (filename = filename )
1402
1409
else :
1403
1410
return self .get_value (* args )
1404
1411
@@ -1459,8 +1466,11 @@ def plot(self, *args, **kwargs):
1459
1466
Function.plot_2d if Function is 2-Dimensional and forward arguments
1460
1467
and key-word arguments."""
1461
1468
if isinstance (self , list ):
1469
+ # Extract filename from kwargs
1470
+ filename = kwargs .get ("filename" , None )
1471
+
1462
1472
# Compare multiple plots
1463
- Function .compare_plots (self )
1473
+ Function .compare_plots (self , filename )
1464
1474
else :
1465
1475
if self .__dom_dim__ == 1 :
1466
1476
self .plot_1d (* args , ** kwargs )
@@ -1488,6 +1498,7 @@ def plot_1d( # pylint: disable=too-many-statements
1488
1498
force_points = False ,
1489
1499
return_object = False ,
1490
1500
equal_axis = False ,
1501
+ filename = None ,
1491
1502
):
1492
1503
"""Plot 1-Dimensional Function, from a lower limit to an upper limit,
1493
1504
by sampling the Function several times in the interval. The title of
@@ -1518,6 +1529,11 @@ def plot_1d( # pylint: disable=too-many-statements
1518
1529
Setting force_points to True will plot all points, as a scatter, in
1519
1530
which the Function was evaluated in the dataset. Default value is
1520
1531
False.
1532
+ filename : str | None, optional
1533
+ The path the plot should be saved to. By default None, in which case
1534
+ the plot will be shown instead of saved. Supported file endings are:
1535
+ eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff
1536
+ and webp (these are the formats supported by matplotlib).
1521
1537
1522
1538
Returns
1523
1539
-------
@@ -1558,7 +1574,7 @@ def plot_1d( # pylint: disable=too-many-statements
1558
1574
plt .title (self .title )
1559
1575
plt .xlabel (self .__inputs__ [0 ].title ())
1560
1576
plt .ylabel (self .__outputs__ [0 ].title ())
1561
- plt . show ( )
1577
+ show_or_save_plot ( filename )
1562
1578
if return_object :
1563
1579
return fig , ax
1564
1580
@@ -1581,6 +1597,7 @@ def plot_2d( # pylint: disable=too-many-statements
1581
1597
disp_type = "surface" ,
1582
1598
alpha = 0.6 ,
1583
1599
cmap = "viridis" ,
1600
+ filename = None ,
1584
1601
):
1585
1602
"""Plot 2-Dimensional Function, from a lower limit to an upper limit,
1586
1603
by sampling the Function several times in the interval. The title of
@@ -1620,6 +1637,11 @@ def plot_2d( # pylint: disable=too-many-statements
1620
1637
cmap : string, optional
1621
1638
Colormap of plotted graph, which can be any of the color maps
1622
1639
available in matplotlib. Default value is viridis.
1640
+ filename : str | None, optional
1641
+ The path the plot should be saved to. By default None, in which case
1642
+ the plot will be shown instead of saved. Supported file endings are:
1643
+ eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff
1644
+ and webp (these are the formats supported by matplotlib).
1623
1645
1624
1646
Returns
1625
1647
-------
@@ -1692,7 +1714,7 @@ def plot_2d( # pylint: disable=too-many-statements
1692
1714
axes .set_xlabel (self .__inputs__ [0 ].title ())
1693
1715
axes .set_ylabel (self .__inputs__ [1 ].title ())
1694
1716
axes .set_zlabel (self .__outputs__ [0 ].title ())
1695
- plt . show ( )
1717
+ show_or_save_plot ( filename )
1696
1718
1697
1719
@staticmethod
1698
1720
def compare_plots ( # pylint: disable=too-many-statements
@@ -1707,6 +1729,7 @@ def compare_plots( # pylint: disable=too-many-statements
1707
1729
force_points = False ,
1708
1730
return_object = False ,
1709
1731
show = True ,
1732
+ filename = None ,
1710
1733
):
1711
1734
"""Plots N 1-Dimensional Functions in the same plot, from a lower
1712
1735
limit to an upper limit, by sampling the Functions several times in
@@ -1751,6 +1774,11 @@ def compare_plots( # pylint: disable=too-many-statements
1751
1774
False.
1752
1775
show : bool, optional
1753
1776
If True, shows the plot. Default value is True.
1777
+ filename : str | None, optional
1778
+ The path the plot should be saved to. By default None, in which case
1779
+ the plot will be shown instead of saved. Supported file endings are:
1780
+ eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff
1781
+ and webp (these are the formats supported by matplotlib).
1754
1782
1755
1783
Returns
1756
1784
-------
@@ -1826,7 +1854,7 @@ def compare_plots( # pylint: disable=too-many-statements
1826
1854
plt .ylabel (ylabel )
1827
1855
1828
1856
if show :
1829
- plt . show ( )
1857
+ show_or_save_plot ( filename )
1830
1858
1831
1859
if return_object :
1832
1860
return fig , ax
0 commit comments