Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
616 changes: 616 additions & 0 deletions examples/adaptive_grid.ipynb

Large diffs are not rendered by default.

Binary file added examples/result_kl_fpi_stype_only.npz
Binary file not shown.
55 changes: 45 additions & 10 deletions src/grid/cubic.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@ def get_points_along_axes(self):
x = self.points[coords_x, 0]
return x, y

def interpolate(self, points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, method="cubic"):
r"""Interpolate function and its derivatives on cubic grid.
def interpolate(
self, points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, method="cubic", grid_pts=None
):
r"""Interpolate function value at a given point.

Only implemented in three-dimensions.

Expand Down Expand Up @@ -130,17 +132,33 @@ def interpolate(self, points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, met
If zero, then the function in z-direction is interpolated.
If greater than zero, then the "nu_z"th-order derivative in the z-direction is
interpolated.
method : str, optional
method: str, optional
The method of interpolation to perform. Supported are "cubic" (most accurate but
computationally expensive), "linear", or "nearest" (least accurate but cheap
computationally). The last two methods use SciPy's RegularGridInterpolator function.
grid_pts: list[OneDGrids], optional
If provided, then uses `grid_pts` rather than the points of the HyperRectangle class
`self.points` to construct interpolation. Useful when doing a promolecular
transformation.

Returns
-------
float :
The interpolation of a function (or of it's derivatives) at a :math:`M` point.

"""
# Needed because CubicProTransform is a subclass of this method and has its own
# interpolate function. Since interpolate references itself, it chooses
# CubicProTransform rather than _HyperRectangleGrid class.
return self._interpolate(
points, values, use_log, nu_x, nu_y, nu_z, method, grid_pts
)

def _interpolate(
self, points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, method="cubic", grid_pts=None
):
r"""Core of the Interpolate Algorithm."""

if method not in ["cubic", "linear", "nearest"]:
raise ValueError(
f"Argument method should be either cubic, linear, or nearest , got {method}"
Expand All @@ -154,6 +172,10 @@ def interpolate(self, points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, met
f"Number of function values {values.shape[0]} does not match number of "
f"grid points {np.prod(self.shape)}."
)
if grid_pts is not None and not isinstance(grid_pts, np.ndarray):
raise TypeError(
f"The grid points {type(grid_pts)} should have type None or numpy array."
)

if use_log:
values = np.log(values)
Expand All @@ -165,6 +187,10 @@ def interpolate(self, points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, met
interpolate = RegularGridInterpolator((x, y, z), values, method=method)
return interpolate(points)

# If grid_pts isn't specified, then use the grid stored as the class attribute.
if grid_pts is None:
grid_pts = self.points

# Interpolate the Z-Axis.
def z_spline(z, x_index, y_index, nu_z=nu_z):
# x_index, y_index is assumed to be in the grid while z is not assumed.
Expand All @@ -173,7 +199,7 @@ def z_spline(z, x_index, y_index, nu_z=nu_z):
small_index = self.coordinates_to_index((x_index, y_index, 1))
large_index = self.coordinates_to_index((x_index, y_index, self.shape[2] - 2))
val = CubicSpline(
self.points[small_index:large_index, 2],
grid_pts[small_index:large_index, 2],
values[small_index:large_index],
)(z, nu_z)
return val
Expand All @@ -183,7 +209,7 @@ def y_splines(y, x_index, z, nu_y=nu_y):
# The `1` and `self.num_puts[1] - 2` is needed because I don't want the boundary.
# Assumes x_index is in the grid while y, z may not be.
val = CubicSpline(
self.points[np.arange(1, self.shape[1] - 2) * self.shape[2], 1],
grid_pts[np.arange(1, self.shape[1] - 2) * self.shape[2], 1],
[z_spline(z, x_index, y_index, nu_z) for y_index in range(1, self.shape[1] - 2)],
)(y, nu_y)
# Trying to vectorize over z-axis and y-axis, this computes the interpolation for every
Expand All @@ -195,7 +221,7 @@ def y_splines(y, x_index, z, nu_y=nu_y):
# Interpolate the point (x, y, z) from a list of interpolated points on x,y-axis.
def x_spline(x, y, z, nu_x):
val = CubicSpline(
self.points[np.arange(1, self.shape[0] - 2) * self.shape[1] * self.shape[2], 0],
grid_pts[np.arange(1, self.shape[0] - 2) * self.shape[1] * self.shape[2], 0],
[y_splines(y, x_index, z, nu_y) for x_index in range(1, self.shape[0] - 2)],
)(x, nu_x)
# Trying to vectorize over x-axis, this computes the interpolation for every
Expand All @@ -207,7 +233,9 @@ def x_spline(x, y, z, nu_x):
if use_log:
# All derivatives require the interpolation of f at (x,y,z)
interpolated = np.exp(
self.interpolate(points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0)
self._interpolate(
points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, grid_pts=grid_pts
)
)
# Only consider taking the derivative in only one direction
one_var_deriv = sum([nu_x == 0, nu_y == 0, nu_z == 0]) == 2
Expand All @@ -218,21 +246,28 @@ def x_spline(x, y, z, nu_x):
elif one_var_deriv:
# Taking the k-th derivative wrt to only one variable (x, y, z)
# Interpolate d^k ln(f) d"deriv_var" for all k from 1 to "deriv_var"
# Each entry of `derivs` is the interpolation of the derivative eval on points.
if nu_x > 0:
derivs = [
self.interpolate(points, values, use_log=False, nu_x=i, nu_y=0, nu_z=0)
self._interpolate(
points, values, use_log=False, nu_x=i, nu_y=0, nu_z=0, grid_pts=grid_pts
)
for i in range(1, nu_x + 1)
]
deriv_var = nu_x
elif nu_y > 0:
derivs = [
self.interpolate(points, values, use_log=False, nu_x=0, nu_y=i, nu_z=0)
self._interpolate(
points, values, use_log=False, nu_x=0, nu_y=i, nu_z=0, grid_pts=grid_pts
)
for i in range(1, nu_y + 1)
]
deriv_var = nu_y
else:
derivs = [
self.interpolate(points, values, use_log=False, nu_x=0, nu_y=0, nu_z=i)
self._interpolate(
points, values, use_log=False, nu_x=0, nu_y=0, nu_z=i, grid_pts=grid_pts
)
for i in range(1, nu_z + 1)
]
deriv_var = nu_z
Expand Down
Loading