"""
Clipping a grid to a polygon
============================

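The :func:`pygmt.grdcut` function allows you to extract a subregion from a
grid. In this example we use a polygon stored in a
:class:`geopandas.GeoDataFrame` to mask the grid: all nodes outside the polygon
are set to NaN. Optionally, the grid region can also be cropped to the
polygon's bounding box, the mask can be inverted (setting the nodes inside the
polygon to NaN instead), or both options can be combined.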
"""

# %%
import geopandas as gpd
import pygmt
from shapely.geometry import Polygon

fig = pygmt.Figure()

# Define region of interest around Iceland
region = [-28, -10, 62, 68]  # xmin, xmax, ymin, ymax

# Load sample grid (3 arc-minutes global relief) in target area
grid = pygmt.datasets.load_earth_relief(resolution="03m", region=region)

# Create a more complex polygon (irregular shape) around a smaller region of
# interest inside the grid. Vertices are (longitude, latitude) pairs, matching
# the OGC:CRS84 axis order used below; the last vertex repeats the first one
# to close the ring explicitly.
poly = Polygon(
    [
        (-26, 63),
        (-23, 63.5),
        (-20, 64),
        (-18, 65),
        (-19, 66),
        (-22, 66.5),
        (-25, 66),
        (-27, 65),
        (-26, 63),
    ]
)
# Wrap the polygon in a GeoDataFrame so it can be passed to pygmt.grdcut
gdf = gpd.GeoDataFrame({"geometry": [poly]}, crs="OGC:CRS84")

# Build a single CPT spanning the full elevation range of the original grid and
# make it the current CPT of the session. All panels and the colorbar below
# then share the same color-to-elevation mapping. Passing ``cmap="oleron"`` to
# each ``grdimage`` call instead would stretch the colormap to each masked
# grid's own data range, so equal colors could mean different elevations and
# the colorbar would only describe the last panel.
pygmt.makecpt(cmap="oleron", series=[float(grid.min()), float(grid.max())])

# Original grid
fig.basemap(
    region=region,
    projection="M12c",
    frame=["WSne+toriginal grid", "xa5f1", "ya2f1"],
)
fig.grdimage(grid=grid)

# Cropped grid: nodes outside the polygon become NaN and the grid region is
# cropped to the polygon's bounding box
fig.shift_origin(xshift="w+0.5c")
cropped_grid = pygmt.grdcut(grid=grid, polygon=gdf, crop=True)
fig.basemap(
    region=region,
    projection="M12c",
    frame=["WSne+tcropped", "xa5f1", "ya2f1"],
)
fig.grdimage(grid=cropped_grid)

# Inverted grid: nodes inside the polygon become NaN, the grid region is kept
fig.shift_origin(xshift="w+0.5c")
inverted_grid = pygmt.grdcut(grid=grid, polygon=gdf, invert=True)
fig.basemap(
    region=region,
    projection="M12c",
    frame=["WSne+tinverted", "xa5f1", "ya2f1"],
)
fig.grdimage(grid=inverted_grid)

# Cropped + inverted grid: nodes inside the polygon become NaN and the grid
# region is cropped to the polygon's bounding box
fig.shift_origin(xshift="w+0.5c")
cropped_inverted_grid = pygmt.grdcut(grid=grid, polygon=gdf, crop=True, invert=True)
fig.basemap(
    region=region,
    projection="M12c",
    frame=["WSne+tcropped and inverted", "xa5f1", "ya2f1"],
)
fig.grdimage(grid=cropped_inverted_grid)

# Colorbar for the shared CPT, placed outside the right edge of the last panel
fig.colorbar(frame=["x+lElevation", "y+lm"], position="JMR+o0.5c/0c+w8c")

fig.show()
