diff --git a/src/ewatercycle/util.py b/src/ewatercycle/util.py index b5117645..a774ff30 100644 --- a/src/ewatercycle/util.py +++ b/src/ewatercycle/util.py @@ -5,12 +5,18 @@ from datetime import datetime from importlib.metadata import entry_points, version from pathlib import Path -from typing import Any +from typing import Any, overload +import cartopy.crs +import cartopy.feature +import cartopy.io.shapereader +import cartopy.mpl.geoaxes import fiona +import matplotlib.figure import numpy as np import xarray as xr from dateutil.parser import parse +from matplotlib import pyplot as plt from shapely import geometry @@ -353,3 +359,94 @@ def get_package_versions() -> dict[str, str]: package_versions["remotebmi"] = version("remotebmi") package_versions.update({pkg: version(pkg) for pkg in packages}) return package_versions + + +@overload +def plot_catchment( + shapefile: str | Path, + axis: None = None, + lat_bounds: tuple[float, float] | None = None, + lon_bounds: tuple[float, float] | None = None, + figsize: tuple[float, float] | None = None, + color: str | None = None, +) -> tuple[matplotlib.figure.Figure, cartopy.mpl.geoaxes.GeoAxes]: ... + + +@overload +def plot_catchment( + shapefile: str | Path, + axis: cartopy.mpl.geoaxes.GeoAxes, + lat_bounds: tuple[float, float] | None = None, + lon_bounds: tuple[float, float] | None = None, + figsize: tuple[float, float] | None = None, + color: str | None = None, +) -> None: ... + + +def plot_catchment( + shapefile: str | Path, + axis: cartopy.mpl.geoaxes.GeoAxes | None = None, + lat_bounds: tuple[float, float] | None = None, + lon_bounds: tuple[float, float] | None = None, + figsize: tuple[float, float] | None = None, + color: str | None = None, +) -> tuple[matplotlib.figure.Figure, cartopy.mpl.geoaxes.GeoAxes] | None: + """Plot a catchment shapefile. + + Args: + shapefile: Path to the shapefile. First geometry will be displayed. + lat_bounds: The latitude bounds of the plot. Defaults to None, which will make + the bounds be infered automatically. + lon_bounds: The longitude bounds of the plot. Defaults to None, which will make + the bounds be infered automatically. + axis: Existing GeoAxes object to plot the shapefile's geometry in. If not + provided, a new plot with the Plate Carree projection, coastlines, rivers + and oceans will be generated. + figsize: Desired size of the figure (if axis is not provided). Defaults to None. + color: Face color of the catchment. Defaults to None, which will fall back to + the Cartopy default color. Can be a name or hexadecimal color code + (starting with #). + + Returns: + The generated Figure and GeoAxes objects, or None in case the axis argument was + provided. + """ + shape = cartopy.io.shapereader.Reader(shapefile) + shape_geo = next(shape.records()).geometry + if shape_geo is None: + msg = "Could not read shapefile: geometry is undefined." + raise ValueError(msg) + + if axis is not None: + ax = axis + if not hasattr(ax, "projection"): + msg = "Axis is missing a CRS/projection. Cannot plot shapefile." + raise ValueError(msg) + else: + fig = plt.figure(figsize=figsize) + ax = plt.axes(projection=cartopy.crs.PlateCarree()) + ax.add_feature(cartopy.feature.COASTLINE, linewidth=1) # type: ignore[attr-defined] + ax.add_feature(cartopy.feature.RIVERS, linewidth=1) # type: ignore[attr-defined] + ax.add_feature(cartopy.feature.OCEAN, edgecolor="none", facecolor="#a2daff") # type: ignore[attr-defined] + + ax.add_geometries( + shape_geo, crs=ax.projection, edgecolor="black", facecolor=color, alpha=0.8 + ) + + if lat_bounds is None or lon_bounds is None: + minx, miny, maxx, maxy = shape_geo.bounds + xpad = (maxx - minx) / 10 + ypad = (maxy - miny) / 10 + if lon_bounds is None and axis is None: + lon_bounds = (minx - xpad, maxx + xpad) + if lat_bounds is None and axis is None: + lat_bounds = (miny - ypad, maxy + ypad) + + if lon_bounds is not None: + ax.set_xlim(lon_bounds) + if lat_bounds is not None: + ax.set_ylim(lat_bounds) + + if axis is None: + return fig, ax + return None diff --git a/tests/src/test_util.py b/tests/src/test_util.py index a4db9c97..92ee20da 100644 --- a/tests/src/test_util.py +++ b/tests/src/test_util.py @@ -1,17 +1,21 @@ from datetime import datetime, timezone from pathlib import Path +import cartopy.crs import pytest import xarray as xr +from matplotlib import pyplot as plt from numpy.testing import assert_array_equal import ewatercycle +from ewatercycle.testing.fixtures import rhine_shape from ewatercycle.util import ( find_closest_point, fit_extents_to_grid, get_package_versions, get_time, merge_esvmaltool_datasets, + plot_catchment, reindex, to_absolute_path, ) @@ -259,3 +263,23 @@ def test_version_getter(): assert versions["ewatercycle"] == ewatercycle.__version__ assert "grpc4bmi" in versions assert "remotebmi" in versions + + +def test_plot_catchment(): + shp = rhine_shape() + _ = plot_catchment(shp) + + _, ax = plt.subplots() + with pytest.raises(ValueError, match="Axis is missing a CRS"): + plot_catchment(shp, axis=ax) + + _ = plt.figure() + ax = plt.axes(projection=cartopy.crs.PlateCarree()) + plot_catchment( + shp, + axis=ax, + lat_bounds=(40.0, 60.0), + lon_bounds=(0.0, 20.0), + figsize=(5, 5), + color="black", + )