diff --git a/notebooks/tutorial_RGDR.ipynb b/notebooks/tutorial_RGDR.ipynb
index 29a7389..97ae938 100644
--- a/notebooks/tutorial_RGDR.ipynb
+++ b/notebooks/tutorial_RGDR.ipynb
@@ -51,7 +51,7 @@
"outputs": [],
"source": [
"target_timeseries = target_resampled.sel(cluster=3).ts.isel(i_interval=0)\n",
- "precursor_field = field_resampled.sst.isel(i_interval=1)\n",
+ "precursor_field = field_resampled.sst.isel(i_interval=slice(1,5)) # Multiple lags: 1 through 4\n",
"\n",
"rgdr = RGDR(eps_km=600, alpha=0.05, min_area_km2=3000**2)"
]
@@ -83,7 +83,7 @@
],
"source": [
"fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(12, 2))\n",
- "_ = rgdr.plot_correlation(precursor_field, target_timeseries, ax1, ax2)"
+ "_ = rgdr.plot_correlation(precursor_field, target_timeseries, lag=1, ax1=ax1, ax2=ax2)"
]
},
{
@@ -116,9 +116,10 @@
"source": [
"fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(12, 2))\n",
"\n",
- "_ = rgdr.plot_clusters(precursor_field, target_timeseries, ax=ax1)\n",
+ "_ = rgdr.plot_clusters(precursor_field, target_timeseries, lag=1, ax=ax1)\n",
"\n",
- "_ = RGDR(min_area_km2=1000**2).plot_clusters(precursor_field, target_timeseries, ax=ax2)"
+ "_ = RGDR(eps_km=600, min_area_km2=None).plot_clusters(precursor_field, target_timeseries,\n",
+ " lag=1, ax=ax2)"
]
},
{
@@ -489,17 +490,13 @@
" stroke: currentColor;\n",
" fill: currentColor;\n",
"}\n",
- "
<xarray.DataArray 'sst' (cluster_labels: 3, anchor_year: 39)>\n",
- "290.8 291.0 290.7 290.1 291.1 291.3 ... 299.0 299.5 298.9 298.9 299.2 298.2\n",
+ "<xarray.DataArray 'sst' (cluster_labels: 6, anchor_year: 39)>\n",
+ "290.8 291.0 290.7 290.1 291.1 291.3 ... 285.6 286.3 286.2 285.5 285.0 285.1\n",
"Coordinates:\n",
" * anchor_year (anchor_year) int32 1980 1981 1982 1983 ... 2016 2017 2018\n",
- " i_interval int64 1\n",
- " index (anchor_year) int64 1 13 25 37 49 61 ... 409 421 433 445 457\n",
- " interval (anchor_year) object (1980-07-02, 1980-08-01] ... (2018-0...\n",
- " target bool False\n",
- " * cluster_labels (cluster_labels) int32 -2 0 1\n",
- " latitude (cluster_labels) float64 36.05 nan 29.44\n",
- " longitude (cluster_labels) float64 223.9 nan 185.4
290.8 291.0 290.7 290.1 291.1 291.3 ... 299.5 298.9 298.9 299.2 298.2
array([[290.79588914, 290.970545 , 290.71731703, 290.0762239 ,\n",
+ " * cluster_labels (cluster_labels) <U20 'lag:1_cluster:-2' ... 'lag:4_clust...\n",
+ " latitude (cluster_labels) float64 36.05 29.44 37.33 29.58 38.14 39.78\n",
+ " longitude (cluster_labels) float64 223.9 185.4 221.8 190.2 217.8 219.3
290.8 291.0 290.7 290.1 291.1 291.3 ... 286.3 286.2 285.5 285.0 285.1
array([[290.79588914, 290.970545 , 290.71731703, 290.0762239 ,\n",
" 291.08960917, 291.31511491, 291.11538436, 290.26277142,\n",
" 290.80443321, 290.99960169, 291.53446464, 291.36075119,\n",
" 291.85483292, 291.09343404, 291.31408735, 291.41374784,\n",
@@ -509,16 +506,6 @@
" 291.44962923, 291.91882282, 290.70922506, 290.48853941,\n",
" 290.34711093, 291.99974208, 292.3259726 , 293.32741257,\n",
" 291.91085874, 291.45966391, 291.37066195],\n",
- " [291.41316602, 290.99213881, 290.7072011 , 290.28445664,\n",
- " 291.24062074, 291.15653517, 290.90280998, 290.26394997,\n",
- " 291.24222615, 291.51096688, 291.63083812, 292.04367136,\n",
- " 291.17972744, 291.09824891, 291.49867568, 291.09047242,\n",
- " 291.78362741, 291.61656609, 291.40476286, 291.22261622,\n",
- " 291.74089263, 292.15751269, 291.59638891, 291.53674801,\n",
- " 292.04091404, 292.05970401, 291.54779359, 291.4457052 ,\n",
- " 291.83593796, 291.9013184 , 291.58505835, 292.59423736,\n",
- " 291.89844604, 292.16743467, 292.15498931, 292.00892894,\n",
- " 291.98830915, 291.94098392, 291.62096043],\n",
" [298.94548042, 298.92119198, 297.71694295, 298.98291705,\n",
" 299.35080249, 298.06875363, 298.51701671, 298.495829 ,\n",
" 299.61375876, 298.35675492, 298.46956715, 299.34460285,\n",
@@ -528,64 +515,43 @@
" 298.88763028, 299.2529288 , 299.0168395 , 298.84020348,\n",
" 298.48441327, 299.30649003, 299.69018872, 299.65241405,\n",
" 299.320106 , 299.04325189, 299.48610574, 298.91044985,\n",
- " 298.89195415, 299.19568083, 298.21053747]])
anchor_year
(anchor_year)
int32
1980 1981 1982 ... 2016 2017 2018
array([1980, 1981, 1982, 1983, 1984, 1985, 1986, 1987, 1988, 1989, 1990, 1991,\n",
" 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002, 2003,\n",
" 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015,\n",
- " 2016, 2017, 2018])
i_interval
()
int64
1
index
(anchor_year)
int64
1 13 25 37 49 ... 421 433 445 457
array([ 1, 13, 25, 37, 49, 61, 73, 85, 97, 109, 121, 133, 145,\n",
- " 157, 169, 181, 193, 205, 217, 229, 241, 253, 265, 277, 289, 301,\n",
- " 313, 325, 337, 349, 361, 373, 385, 397, 409, 421, 433, 445, 457],\n",
- " dtype=int64)
interval
(anchor_year)
object
(1980-07-02, 1980-08-01] ... (20...
array([Interval('1980-07-02', '1980-08-01', closed='right'),\n",
- " Interval('1981-07-02', '1981-08-01', closed='right'),\n",
- " Interval('1982-07-02', '1982-08-01', closed='right'),\n",
- " Interval('1983-07-02', '1983-08-01', closed='right'),\n",
- " Interval('1984-07-02', '1984-08-01', closed='right'),\n",
- " Interval('1985-07-02', '1985-08-01', closed='right'),\n",
- " Interval('1986-07-02', '1986-08-01', closed='right'),\n",
- " Interval('1987-07-02', '1987-08-01', closed='right'),\n",
- " Interval('1988-07-02', '1988-08-01', closed='right'),\n",
- " Interval('1989-07-02', '1989-08-01', closed='right'),\n",
- " Interval('1990-07-02', '1990-08-01', closed='right'),\n",
- " Interval('1991-07-02', '1991-08-01', closed='right'),\n",
- " Interval('1992-07-02', '1992-08-01', closed='right'),\n",
- " Interval('1993-07-02', '1993-08-01', closed='right'),\n",
- " Interval('1994-07-02', '1994-08-01', closed='right'),\n",
- " Interval('1995-07-02', '1995-08-01', closed='right'),\n",
- " Interval('1996-07-02', '1996-08-01', closed='right'),\n",
- " Interval('1997-07-02', '1997-08-01', closed='right'),\n",
- " Interval('1998-07-02', '1998-08-01', closed='right'),\n",
- " Interval('1999-07-02', '1999-08-01', closed='right'),\n",
- " Interval('2000-07-02', '2000-08-01', closed='right'),\n",
- " Interval('2001-07-02', '2001-08-01', closed='right'),\n",
- " Interval('2002-07-02', '2002-08-01', closed='right'),\n",
- " Interval('2003-07-02', '2003-08-01', closed='right'),\n",
- " Interval('2004-07-02', '2004-08-01', closed='right'),\n",
- " Interval('2005-07-02', '2005-08-01', closed='right'),\n",
- " Interval('2006-07-02', '2006-08-01', closed='right'),\n",
- " Interval('2007-07-02', '2007-08-01', closed='right'),\n",
- " Interval('2008-07-02', '2008-08-01', closed='right'),\n",
- " Interval('2009-07-02', '2009-08-01', closed='right'),\n",
- " Interval('2010-07-02', '2010-08-01', closed='right'),\n",
- " Interval('2011-07-02', '2011-08-01', closed='right'),\n",
- " Interval('2012-07-02', '2012-08-01', closed='right'),\n",
- " Interval('2013-07-02', '2013-08-01', closed='right'),\n",
- " Interval('2014-07-02', '2014-08-01', closed='right'),\n",
- " Interval('2015-07-02', '2015-08-01', closed='right'),\n",
- " Interval('2016-07-02', '2016-08-01', closed='right'),\n",
- " Interval('2017-07-02', '2017-08-01', closed='right'),\n",
- " Interval('2018-07-02', '2018-08-01', closed='right')], dtype=object)
target
()
bool
False
cluster_labels
(cluster_labels)
int32
-2 0 1
latitude
(cluster_labels)
float64
36.05 nan 29.44
array([36.0508552, nan, 29.4398051])
longitude
(cluster_labels)
float64
223.9 nan 185.4
array([223.86658208, nan, 185.40970765])
"
+ " 2016, 2017, 2018])
cluster_labels
(cluster_labels)
<U20
'lag:1_cluster:-2' ... 'lag:4_cl...
array(['lag:1_cluster:-2', 'lag:1_cluster:1', 'lag:2_cluster:-1',\n",
+ " 'lag:2_cluster:1', 'lag:3_cluster:-1', 'lag:4_cluster:-2'], dtype='<U20')
latitude
(cluster_labels)
float64
36.05 29.44 37.33 29.58 38.14 39.78
array([36.0508552 , 29.4398051 , 37.33257702, 29.58134561, 38.13773082,\n",
+ " 39.78162825])
longitude
(cluster_labels)
float64
223.9 185.4 221.8 190.2 217.8 219.3
array([223.86658208, 185.40970765, 221.82516648, 190.20336403,\n",
+ " 217.7810629 , 219.30300121])
"
],
"text/plain": [
- "\n",
- "290.8 291.0 290.7 290.1 291.1 291.3 ... 299.0 299.5 298.9 298.9 299.2 298.2\n",
+ "\n",
+ "290.8 291.0 290.7 290.1 291.1 291.3 ... 285.6 286.3 286.2 285.5 285.0 285.1\n",
"Coordinates:\n",
" * anchor_year (anchor_year) int32 1980 1981 1982 1983 ... 2016 2017 2018\n",
- " i_interval int64 1\n",
- " index (anchor_year) int64 1 13 25 37 49 61 ... 409 421 433 445 457\n",
- " interval (anchor_year) object (1980-07-02, 1980-08-01] ... (2018-0...\n",
- " target bool False\n",
- " * cluster_labels (cluster_labels) int32 -2 0 1\n",
- " latitude (cluster_labels) float64 36.05 nan 29.44\n",
- " longitude (cluster_labels) float64 223.9 nan 185.4"
+ " * cluster_labels (cluster_labels) "
+ ""
]
},
"execution_count": 6,
@@ -624,7 +590,7 @@
},
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
""
]
@@ -639,8 +605,8 @@
"import matplotlib.pyplot as plt\n",
"\n",
"clustered_data -= clustered_data.mean(dim='anchor_year')\n",
- "clustered_data.sel(cluster_labels=1).plot.line(x='anchor_year', label='pos. corr')\n",
- "clustered_data.sel(cluster_labels=-2).plot.line(x='anchor_year', label='neg. corr')\n",
+ "clustered_data.sel(cluster_labels=\"lag:1_cluster:1\").plot.line(x='anchor_year', label='pos. corr')\n",
+ "clustered_data.sel(cluster_labels=\"lag:1_cluster:-2\").plot.line(x='anchor_year', label='neg. corr')\n",
"plt.legend()\n"
]
},
diff --git a/s2spy/rgdr/rgdr.py b/s2spy/rgdr/rgdr.py
index 0e505ca..00edd9b 100644
--- a/s2spy/rgdr/rgdr.py
+++ b/s2spy/rgdr/rgdr.py
@@ -79,24 +79,96 @@ def remove_small_area_clusters(ds: XrType, min_area_km2: float) -> XrType:
valid_clusters = np.array([c for c, a in zip(clusters, areas) if a > min_area_km2])
ds["cluster_labels"] = ds["cluster_labels"].where(
- np.isin(ds["cluster_labels"], valid_clusters), 0
+ np.isin(ds["cluster_labels"], valid_clusters), "0"
)
return ds
-def masked_spherical_dbscan(
+def add_gridcell_area(data: xr.DataArray):
+ """Adds the area of each gridcell (latitude) in km2.
+
+ Note: Assumes an even grid (in degrees)
+
+ Args:
+ data: Data containing lat, lon coordinates in degrees.
+
+ Returns:
+ Input data with an added coordinate "area".
+ """
+ dlat = np.abs(data.latitude.values[1] - data.latitude.values[0])
+ dlon = np.abs(data.longitude.values[1] - data.longitude.values[0])
+ data["area"] = spherical_area(data.latitude, dlat, dlon)
+ return data
+
+
+def assert_clusters_present(data: xr.DataArray) -> None:
+    """Raises a ValueError if any lag (or the whole field) has only the '0' cluster."""
+
+    if "i_interval" in data.dims:
+        n_clusters = np.zeros(data["i_interval"].size)
+        for i, _ in enumerate(n_clusters):
+            n_clusters[i] = np.unique(data.isel(i_interval=i).cluster_labels).size
+
+        if np.any(n_clusters == 1):  # A single cluster is the '0' (leftovers) cluster.
+            empty_lags = data["i_interval"].values[n_clusters == 1]
+            raise ValueError(
+                f"No significant clusters found in lag(s): i_interval={empty_lags}. "
+                "Please remove these intervals from the model before continuing."
+            )
+
+    elif np.unique(data.cluster_labels).size == 1:
+        raise ValueError("No significant clusters found in the input DataArray")
+
+
+def _get_dbscan_clusters(
+ data: xr.Dataset, coords: np.ndarray, lag: int, dbscan_params: dict
+) -> np.ndarray:
+ """Generates the DBSCAN cluster labels based on the correlation and p-value.
+
+ Args:
+ data (xr.DataArray): DataArray of the precursor field, of only a single
+ i_interval. Requires the 'latitude' and 'longitude' dimensions to be stacked
+ into a "coords" dimension.
+ coords (np.ndarray): 2-D array containing the coordinates of each (lat, lon) grid
+ point, in radians.
+ lag (int): The i_interval value of the input data.
+ dbscan_params (dict): Dictionary containing the elements 'alpha', 'eps',
+ 'min_area_km2'. See the documentation of RGDR for more information.
+
+ Returns:
+ np.ndarray: 1-D array of the same length as `coords`, containing cluster labels
+ for every coordinate."""
+
+ labels = np.zeros(len(coords), dtype="= 0, data["corr"] < 0]):
+ mask = np.logical_and(data["p_val"] < dbscan_params["alpha"], sign_mask)
+
+ if np.sum(mask) > 0: # Check if the mask contains any points to cluster
+ db = DBSCAN(
+ eps=dbscan_params["eps"] / RADIUS_EARTH_KM,
+ min_samples=1,
+ algorithm="auto",
+ metric="haversine",
+ ).fit(coords[mask])
+
+ cluster_labels = sign * (db.labels_ + 1)
+ labels[mask] = [f"lag:{lag}_cluster:{int(lbl)}" for lbl in cluster_labels]
+
+ return labels
+
+
+def _find_clusters(
precursor: xr.DataArray,
corr: xr.DataArray,
p_val: xr.DataArray,
dbscan_params: dict,
) -> xr.DataArray:
+ """Computes clusters and adds their labels to the precursor dataset.
- """Determines the clusters based on sklearn's DBSCAN implementation. Alpha determines
- the mask based on the minimum p_value. Grouping can be adjusted using the `eps_km`
- parameter. Cluster labels are negative for areas with a negative correlation coefficient
- and positive for areas with a positive correlation coefficient. Areas without any
- significant correlation are put in the cluster labelled '0'.
+ For clustering the DBSCAN algorithm is used, with a Haversine distance metric.
Args:
precursor (xr.DataArray): DataArray of the precursor field, containing
@@ -109,43 +181,74 @@ def masked_spherical_dbscan(
'min_area_km2'. See the documentation of RGDR for more information.
Returns:
- xr.DataArray: Precursor data grouped by the DBSCAN clusters.
+ xr.DataArray: The input precursor data, with as extra coordinate labelled
+ clusters.
"""
- orig_name = precursor.name
data = precursor.to_dataset()
data["corr"], data["p_val"] = corr, p_val # Will require less tracking of indices
+ if "i_interval" not in data.dims:
+ data = data.expand_dims("i_interval")
+ lags = data["i_interval"].values
+
data = data.stack(coord=["latitude", "longitude"])
coords = np.asarray(data["coord"].values.tolist())
coords = np.radians(coords)
# Prepare labels, default value is 0 (not in cluster)
- labels = np.zeros(len(coords), dtype=int)
+ labels = np.zeros((len(lags), len(coords)), dtype="= 0, data["corr"] < 0]):
- mask = np.logical_and(data["p_val"] < dbscan_params["alpha"], sign_mask)
- if np.sum(mask) > 0: # Check if the mask contains any points to cluster
- db = DBSCAN(
- eps=dbscan_params["eps"] / RADIUS_EARTH_KM,
- min_samples=1,
- algorithm="auto",
- metric="haversine",
- ).fit(coords[mask])
-
- labels[mask] = sign * (db.labels_ + 1)
+ for i, lag in enumerate(lags):
+ labels[i] = _get_dbscan_clusters(
+ data.isel(i_interval=i), coords, lag, dbscan_params
+ )
precursor = precursor.stack(coord=["latitude", "longitude"])
- precursor["cluster_labels"] = ("coord", labels)
+ if "i_interval" not in precursor.dims:
+ precursor["cluster_labels"] = ("coord", labels[0])
+ else:
+ precursor["cluster_labels"] = (("i_interval", "coord"), labels)
precursor = precursor.unstack(("coord"))
- dlat = np.abs(precursor.latitude.values[1] - precursor.latitude.values[0])
- dlon = np.abs(precursor.longitude.values[1] - precursor.longitude.values[0])
- precursor["area"] = spherical_area(precursor.latitude, dlat, dlon)
+ return precursor
+
+
+def masked_spherical_dbscan(
+ precursor: xr.DataArray,
+ corr: xr.DataArray,
+ p_val: xr.DataArray,
+ dbscan_params: dict,
+) -> xr.DataArray:
+
+ """Determines the clusters based on sklearn's DBSCAN implementation. Alpha determines
+ the mask based on the minimum p_value. Grouping can be adjusted using the `eps_km`
+ parameter. Cluster labels are negative for areas with a negative correlation coefficient
+ and positive for areas with a positive correlation coefficient. Areas without any
+ significant correlation are put in the cluster labelled '0'.
+
+ Args:
+ precursor (xr.DataArray): DataArray of the precursor field, containing
+ 'latitude' and 'longitude' dimensions in degrees.
+ corr (xr.DataArray): DataArray with the correlation values, generated by
+ correlation_map()
+ p_val (xr.DataArray): DataArray with the p-values, generated by
+ correlation_map()
+ dbscan_params (dict): Dictionary containing the elements 'alpha', 'eps',
+ 'min_area_km2'. See the documentation of RGDR for more information.
+
+ Returns:
+ xr.DataArray: Precursor data grouped by the DBSCAN clusters.
+ """
+ precursor = add_gridcell_area(precursor)
+
+ precursor = _find_clusters(precursor, corr, p_val, dbscan_params)
if dbscan_params["min_area"]:
precursor = remove_small_area_clusters(precursor, dbscan_params["min_area"])
- precursor.name = orig_name
+ # Make sure a cluster is present in each lag
+ assert_clusters_present(precursor)
+
return precursor
@@ -282,10 +385,11 @@ def get_clusters(
corr, p_val = self.get_correlation(precursor, timeseries)
return masked_spherical_dbscan(precursor, corr, p_val, self._dbscan_params)
- def plot_correlation(
+ def plot_correlation( # pylint: disable=too-many-arguments
self,
precursor: xr.DataArray,
timeseries: xr.DataArray,
+ lag: Optional[int] = None,
ax1: Optional[plt.Axes] = None,
ax2: Optional[plt.Axes] = None,
) -> List[Type[mpl.collections.QuadMesh]]:
@@ -296,6 +400,8 @@ def plot_correlation(
precursor: Precursor field data with the dimensions
'latitude', 'longitude', and 'anchor_year'
timeseries: Timeseries data with only the dimension 'anchor_year'
+ lag: The i_interval which should be plotted. Required if the precursor
+ has the dimension "i_interval".
ax1: a matplotlib axis handle to plot
the correlation values into. If None, an axis handle will be created
instead.
@@ -305,6 +411,15 @@ def plot_correlation(
Returns:
List[mpl.collections.QuadMesh]: List of matplotlib artists.
"""
+
+ if "i_interval" in precursor.dims:
+ if lag is None:
+ raise ValueError(
+ "Precursor contains multiple intervals, please provide"
+ " the lag which should be plotted."
+ )
+ precursor = precursor.sel(i_interval=lag)
+
corr, p_val = self.get_correlation(precursor, timeseries)
if (ax1 is None) and (ax2 is None):
@@ -326,6 +441,7 @@ def plot_clusters(
self,
precursor: xr.DataArray,
timeseries: xr.DataArray,
+ lag: Optional[int] = None,
ax: Optional[plt.Axes] = None,
) -> Type[mpl.collections.QuadMesh]:
"""Generates a figure showing the clusters resulting from the initiated RGDR
@@ -335,18 +451,30 @@ class and input precursor field.
precursor: Precursor field data with the dimensions
'latitude', 'longitude', and 'anchor_year'
timeseries: Timeseries data with only the dimension 'anchor_year'
+ lag: The i_interval which should be plotted. Required if the precursor
+ has the dimension "i_interval".
ax (plt.Axes, optional): a matplotlib axis handle to plot the clusters
into. If None, an axis handle will be created instead.
Returns:
matplotlib.collections.QuadMesh: Matplotlib artist.
"""
- clusters = self.get_clusters(precursor, timeseries)
-
if ax is None:
_, ax = plt.subplots()
- return clusters.cluster_labels.plot(cmap="viridis", ax=ax)
+ clusters = self.get_clusters(precursor, timeseries)
+
+ if "i_interval" in precursor.dims:
+ if lag is None:
+ raise ValueError(
+ "Precursor contains multiple intervals, please provide"
+ " the lag which should be plotted."
+ )
+ clusters = clusters.sel(i_interval=lag)
+
+ clusters = utils.cluster_labels_to_ints(clusters)
+
+ return clusters["cluster_labels"].plot(cmap="viridis", ax=ax)
def fit(self, precursor: xr.DataArray, timeseries: xr.DataArray):
"""Fits RGDR clusters to precursor data.
@@ -400,12 +528,18 @@ def transform(self, data: xr.DataArray) -> xr.DataArray:
data["cluster_labels"] = self._clusters
data["area"] = self._area
- # Add the geographical centers for later alignment between, e.g., splits
reduced_data = utils.weighted_groupby(
data, groupby="cluster_labels", weight="area"
)
- return utils.geographical_cluster_center(data, reduced_data)
+ # Add the geographical centers for later alignment between, e.g., splits
+ reduced_data = utils.geographical_cluster_center(data, reduced_data)
+
+ # Remove the '0' cluster
+ reduced_data = reduced_data.where(reduced_data["cluster_labels"] != "0").dropna(
+ dim="cluster_labels"
+ )
+ return reduced_data
def fit_transform(self, precursor: xr.DataArray, timeseries: xr.DataArray):
"""Fits RGDR clusters to precursor data, and applies RGDR on the input data.
diff --git a/s2spy/rgdr/utils.py b/s2spy/rgdr/utils.py
index 92d45ae..b06dd61 100644
--- a/s2spy/rgdr/utils.py
+++ b/s2spy/rgdr/utils.py
@@ -68,25 +68,45 @@ def geographical_cluster_center(
for i, cluster in enumerate(clusters):
# Select only the grid cells within the cluster
- cluster_data = stacked_data.where(
+ cluster_area = stacked_data["area"].where(
stacked_data["cluster_labels"] == cluster
- ).dropna(dim="coords")
+ )
+
+ if "i_interval" in cluster_area.dims:
+ cluster_area = cluster_area.dropna('i_interval', how='all')
+ cluster_area = cluster_area.dropna("coords")
# Area weighted mean to get the geographical center of the cluster
- # for the 0 clusters (leftovers), set to nan as this will avoid them in e.g.
- # plots
- if cluster == 0:
- cluster_lats[i] = np.nan
- cluster_lons[i] = np.nan
- else:
- cluster_lats[i] = (
- cluster_data["latitude"].weighted(cluster_data["area"]).mean().item()
- )
- cluster_lons[i] = (
- cluster_data["longitude"].weighted(cluster_data["area"]).mean().item()
- )
+ cluster_lats[i] = (
+ cluster_area["latitude"].weighted(cluster_area).mean().item()
+ )
+ cluster_lons[i] = (
+ cluster_area["longitude"].weighted(cluster_area).mean().item()
+ )
reduced_data["latitude"] = ("cluster_labels", cluster_lats)
reduced_data["longitude"] = ("cluster_labels", cluster_lons)
return reduced_data
+
+
+def cluster_labels_to_ints(clustered_data: xr.DataArray) -> xr.DataArray:
+    """Converts the string cluster labels of clustered data to integer cluster ids.
+
+    Args:
+        clustered_data: Data with "lag:<lag>_cluster:<id>" (or "0") cluster labels.
+
+    Returns:
+        Same as input, but with the labels converted to integers (lag info is dropped).
+    """
+    un_labels = np.unique(clustered_data.cluster_labels)
+    # Take everything after the last ":" so multi-digit/negative ids parse correctly.
+    label_vals = [int(lb.rsplit(":", 1)[-1]) for lb in un_labels]
+    label_lookup = dict(zip(un_labels, label_vals))
+    clustered_data['cluster_labels'] = xr.apply_ufunc(
+        lambda val: label_lookup[val],
+        clustered_data['cluster_labels'],
+        vectorize=True
+    )
+
+    return clustered_data
diff --git a/tests/test_rgdr/test_rgdr.py b/tests/test_rgdr/test_rgdr.py
index 3cfb548..4c719eb 100644
--- a/tests/test_rgdr/test_rgdr.py
+++ b/tests/test_rgdr/test_rgdr.py
@@ -6,6 +6,7 @@
import xarray as xr
from s2spy import RGDR
from s2spy.rgdr import rgdr
+from s2spy.rgdr import utils
from s2spy.time import AdventCalendar
from s2spy.time import resample
@@ -38,6 +39,12 @@ def example_field(raw_field, dummy_calendar):
return resample(cal, raw_field).sst.isel(i_interval=1)
+@pytest.fixture(autouse=True, scope="class")
+def example_field_multiple_lags(raw_field, dummy_calendar):
+ cal = dummy_calendar.map_to_data(raw_field)
+ return resample(cal, raw_field).sst.isel(i_interval=slice(1, 4))
+
+
@pytest.fixture(autouse=True, scope="class")
def example_target(raw_target, raw_field, dummy_calendar):
cal = dummy_calendar.map_to_data(raw_field)
@@ -146,6 +153,7 @@ def test_dbscan(
clusters = rgdr.masked_spherical_dbscan(
example_field, corr, p_val, dummy_dbscan_params
)
+ clusters = utils.cluster_labels_to_ints(clusters)
np.testing.assert_array_equal(clusters["cluster_labels"], expected_labels)
@@ -158,6 +166,8 @@ def test_dbscan_min_area(
clusters = rgdr.masked_spherical_dbscan(
example_field, corr, p_val, dbscan_params
)
+ clusters = utils.cluster_labels_to_ints(clusters)
+
expected_labels[expected_labels == -1] = 0 # Small -1 cluster is missing
np.testing.assert_array_equal(clusters["cluster_labels"], expected_labels)
@@ -186,7 +196,8 @@ def test_fit(self, dummy_rgdr, example_field, example_target):
def test_transform(self, dummy_rgdr, example_field, example_target):
dummy_rgdr.fit(example_field, example_target)
clustered_data = dummy_rgdr.transform(example_field)
- cluster_labels = np.array([-2.0, -1.0, 0.0, 1.0])
+ clustered_data = utils.cluster_labels_to_ints(clustered_data)
+ cluster_labels = np.array([-1, -2, 1])
np.testing.assert_array_equal(clustered_data["cluster_labels"], cluster_labels)
def test_fit_transform_fits(self, example_field, example_target):
@@ -198,12 +209,32 @@ def test_fit_transform_fits(self, example_field, example_target):
def test_fit_transform(self, example_field, example_target):
rgdr = RGDR(min_area_km2=1000**2)
clustered_data = rgdr.fit_transform(example_field, example_target)
- cluster_labels = np.array([-2.0, -1.0, 0.0, 1.0])
+ cluster_labels = np.array(
+ ["lag:1_cluster:-1", "lag:1_cluster:-2", "lag:1_cluster:1"])
+ np.testing.assert_array_equal(clustered_data["cluster_labels"], cluster_labels)
+
+ def test_fit_transform_multiple_lags(self, example_field_multiple_lags, example_target):
+ rgdr = RGDR()
+ clustered_data = rgdr.fit_transform(example_field_multiple_lags, example_target)
+ cluster_labels = np.array(
+ ["lag:1_cluster:-2", "lag:1_cluster:1", "lag:2_cluster:-1",
+ "lag:2_cluster:1", "lag:3_cluster:-1"])
np.testing.assert_array_equal(clustered_data["cluster_labels"], cluster_labels)
def test_corr_plot(self, dummy_rgdr, example_field, example_target):
dummy_rgdr.plot_correlation(example_field, example_target)
+ def test_corr_plot_multiple_lags(
+ self, dummy_rgdr, example_field_multiple_lags, example_target
+ ):
+ dummy_rgdr.plot_correlation(example_field_multiple_lags, example_target, lag=1)
+
+ def test_corr_plot_multiple_lags_fail(
+ self, dummy_rgdr, example_field_multiple_lags, example_target
+ ):
+ with pytest.raises(ValueError):
+ dummy_rgdr.plot_correlation(example_field_multiple_lags, example_target)
+
def test_corr_plot_ax(self, dummy_rgdr, example_field, example_target):
_, (ax1, ax2) = plt.subplots(ncols=2)
dummy_rgdr.plot_correlation(example_field, example_target, ax1=ax1, ax2=ax2)
@@ -211,6 +242,17 @@ def test_corr_plot_ax(self, dummy_rgdr, example_field, example_target):
def test_cluster_plot(self, dummy_rgdr, example_field, example_target):
dummy_rgdr.plot_clusters(example_field, example_target)
+ def test_cluster_plot_multiple_lags(
+ self, dummy_rgdr, example_field_multiple_lags, example_target
+ ):
+ dummy_rgdr.plot_clusters(example_field_multiple_lags, example_target, lag=1)
+
+ def test_cluster_plot_multiple_lags_fail(
+ self, dummy_rgdr, example_field_multiple_lags, example_target
+ ):
+ with pytest.raises(ValueError):
+ dummy_rgdr.plot_clusters(example_field_multiple_lags, example_target)
+
def test_cluster_plot_ax(self, dummy_rgdr, example_field, example_target):
_, ax = plt.subplots()
dummy_rgdr.plot_clusters(example_field, example_target, ax=ax)