diff --git a/src/grid/cubic.py b/src/grid/cubic.py index 1c2998f6..094b9d38 100644 --- a/src/grid/cubic.py +++ b/src/grid/cubic.py @@ -26,6 +26,10 @@ from grid.basegrid import Grid, OneDGrid +from collections import deque +from typing import Optional, Callable +from scipy.spatial import cKDTree + class _HyperRectangleGrid(Grid): def __init__(self, points, weights, shape): @@ -553,6 +557,7 @@ def __init__(self, origin, axes, shape, weight="Trapezoid"): dim = self._origin.size # Make an array to store coordinates of grid points self._points = np.zeros((np.prod(shape), dim)) + self._weight_scheme = weight if dim == 3: coords = np.array( np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2])) @@ -784,6 +789,11 @@ def origin(self): """Return the Cartesian coordinates of the uniform grid origin.""" return self._origin + @property + def weight_scheme(self): + r"""Return the weight scheme of the uniform grid.""" + return self._weight_scheme + def save(self, filename): r""" Save uniform cubic grid attributes as a npz file. @@ -1013,3 +1023,276 @@ def generate_cube(self, fname, data, atcoords, atnums, pseudo_numbers=None): row_data = data.flat[i : i + num_chunks] f.write((row_data.size * " {:12.5E}").format(*row_data)) f.write("\n") + + +class AdaptiveUniformGrid: + """ + This is a wrapper class that provides adaptive refinement for a UniformGrid instance. + + This class takes a UniformGrid object and applies a recursive subdivision + algorithm to generate a new, non-uniform grid with points concentrated in + regions of high function error, leading to more efficient and accurate integration. + + The main entry point is the `refinement` method. + """ + + def __init__(self, uniform_grid: UniformGrid, error_estimate = "quadrature"): + """Initialization. + + Parameters + ---------- + uniform_grid : UniformGrid + The coarse, uniform grid that will serve as the starting point for refinement. + error_estimate: str, optional + The type of error used to choose which points to refine. Either + 'quadrature' (suited for integration), or `gradient`. Default is 'quadrature'. + """ + if not isinstance(uniform_grid, UniformGrid): + raise ValueError("The input grid should be a UniformGrid instance.") + if uniform_grid.weight_scheme != "Rectangle": + raise ValueError(f"The weight scheme {uniform_grid.weight_scheme} should be Rectangle.") + self.grid = uniform_grid + self.ndim = uniform_grid.ndim + self.axis_spacings = np.array([np.linalg.norm(axis) for axis in self.grid.axes]) + self.axes_norm = self.grid.axes / self.axis_spacings + self.error_estimate = error_estimate + + if self.error_estimate == "gradient": + # Build flat indices for (-) and (+) along each axis, grouped together + # makes it faster to do central finite-difference + # This should match the output order of `_generate_subdivision_points`. + # in 3D- it is [[9, 18], [3, 6], [1, 2]] corresponding to each dimension + # in 2D- it is [[3, 6], [1, 2]] + idx_pairs = [] + shape = (3,) * self.ndim # grid shape per axis + for i in range(self.ndim): + neg = [0] * self.ndim + neg[i] = 1 # index 1 corresponds to -1 offset + pos = [0] * self.ndim + pos[i] = 2 # index 2 corresponds to +1 offset + idx_pairs.append(( + np.ravel_multi_index(neg, shape), + np.ravel_multi_index(pos, shape), + )) + idx_pairs = np.array(idx_pairs) + self._idx_pairs = idx_pairs + + def _get_func_values( + self, points: np.ndarray, func: Callable, evaluated_points: dict + ) -> np.ndarray: + if len(points) == 0: + return np.array([]) + + # Round points for consistent cache keys + rounded_points = np.round(points, 10) + + # Check which points need evaluation + keys = [tuple(p) for p in rounded_points] + missing_indices = [] + values = np.zeros(len(points)) + + for i, key in enumerate(keys): + if key in evaluated_points: + values[i] = evaluated_points[key] + else: + missing_indices.append(i) + + # Batch evaluate missing points + if missing_indices: + missing_points = points[missing_indices] + missing_values = func(missing_points) + + # Update cache and values array + for i, missing_idx in enumerate(missing_indices): + key = keys[missing_idx] + value = missing_values[i] + evaluated_points[key] = value + values[missing_idx] = value + + return values + + def _estimate_error(self, point, weight, func_vals, spacings, subdivision_points): + if self.error_estimate == "quadrature": + return self._estimate_error_quadrature( + point, weight, func_vals, spacings, subdivision_points + ) + elif self.error_estimate == "gradient": + return self._estimate_error_gradient( + point, weight, func_vals, spacings, subdivision_points + ) + raise ValueError(f"Could not recognize the type of error estimate {self.error_estimate}.") + + def _estimate_error_quadrature(self, point, weight, func_vals, _, subdivision_points): + child_weight = weight / len(subdivision_points) + quad_at_pt = func_vals[0] * weight # At the center point + quad_child = np.sum(func_vals * child_weight) + err = np.abs(quad_at_pt - quad_child) + return err + + def _estimate_error_gradient( + self, point, _, func_vals, spacings, __ + ) -> float: + """ + Estimate error using finite difference gradient approximation with batch evaluation. + Uses efficient batch function evaluation to minimize function call overhead. + """ + gradient_magnitude = 0.0 + for dim in range(self.ndim): + # Grabs the minus and forward index + i_minus = self._idx_pairs[dim][0] + i_plus = self._idx_pairs[dim][1] + + f_forward = func_vals[i_plus] + f_backward = func_vals[i_minus] + + # The spacing between center point and i_minus/i_plus is spacing/3.0 + grad_dim = (f_forward - f_backward) / (2 * spacings[dim] / 3.0) + gradient_magnitude += grad_dim**2 + + gradient_magnitude = np.sqrt(gradient_magnitude) + + # Error estimate: gradient magnitude times spacing + # Use geometric mean of spacings as characteristic length + characteristic_spacing = np.prod(spacings) ** (1 / self.ndim) + + return gradient_magnitude * characteristic_spacing + + def _generate_subdivision_points( + self, center_point: np.ndarray, spacings: np.ndarray + ) -> np.ndarray: + # Generate all subdivision points for uniform 3^D cube subdivision. + # The number of subdivision points is 3^D + child_spacings = spacings / 3.0 + # Generate all possible combinations of {0, +, -} such that + # the first point is the `center_point` within the subdivision. + ranges = (np.array([0.0, -1.0, 1.0])[:, None] * child_spacings).T + grids = np.meshgrid(*ranges, indexing="ij") + points_spacing = np.column_stack([grid.ravel() for grid in grids]) # Shape: (3^D, 3) + # With the spacing in each dimension, multiply it by the axes of the cubic grid. + # here k = D, j = D, and the number of subdivision is i = 3^D. + spacing_axes = points_spacing @ self.axes_norm + subdivision_points = center_point + spacing_axes + return subdivision_points + + def refinement( + self, + func: Callable, + tolerance: float = 1e-4, + min_spacing: Optional[float] = None, + max_depth: int = 10, + refine_contrib_threshold: float = 1e-4 + ) -> dict: + """ + Parameters + ---------- + func : Callable + The real-valued, scalar function to be integrated. + The function must be vectorized, and takes in ndarray(M,3) -> float. + tolerance : float, optional + The error tolerance for a local point. + min_spacing : float, optional + The minimum allowed spacing for subdivision. + max_depth : int, optional + Maximum refinement depth to prevent infinite loops. + refine_contrib_threshold: float, optional + Skip refinement for cells with |f(center)| V < refine_value_threshold, + where V is the volume element. + Prevents work in negligible regions. Units match f. Default: 0.0. + + Returns + ------- + dict + A dictionary containing the final integral value, refined grid, and statistics. + """ + if min_spacing is None: + min_spacing = np.min(self.axis_spacings) / 100 + + points = self.grid.points.copy() + weights = self.grid.weights.copy() + initial_spacings = self.axis_spacings.copy() + func_evals = func(points) + + # Refinement dequeue takes in 5 arguments: + # (index of point, point, current spacing, current weight, depth) + refinement_queue = deque() + + # Potentially do refinement on that satisfies this error criteria + # Speeds up computation + indices = np.where(np.abs(func_evals) * weights > refine_contrib_threshold)[0] + for index in indices: + refinement_queue.append( + ( + index, + points[index, :], + initial_spacings.copy(), + weights[index], + 0 + ) + ) + + # Process refinement queue + while refinement_queue: + index, point, spacings, weight, depth = refinement_queue.popleft() + + if depth > max_depth or np.any(spacings < min_spacing): + continue + + subdivision_pts = self._generate_subdivision_points(point, spacings) + + # Compute function values at the subdivision points + func_vals_center = func_evals[index] + func_vals_extra = func(subdivision_pts[1:]) + func_vals_subdiv = np.concatenate(([func_vals_center], func_vals_extra)) + + # Use the subdivision points and function values to compute the error + local_error = self._estimate_error( + point, weight, func_vals_subdiv, spacings, subdivision_pts + ) + + # Do refinement on the center point + if local_error > tolerance: + num_sub_points = len(subdivision_pts) # 3^ndim + + # Update the weight/spacing of that center point + weights[index] = weight / num_sub_points + child_spacings = spacings / 3 + + # Add center point back to the queue with updated weight and spacing + refinement_queue.append( + ( + index, + point, + child_spacings.copy(), + weights[index], + depth + 1 + ) + ) + + # Add all subdivision points back to queue for further processing + # Ignore the first point since it is the center point, and added before. + for i_subpt, sub_point in enumerate(subdivision_pts[1:]): + refinement_queue.append( + ( + len(points) + i_subpt, + sub_point, + child_spacings.copy(), + weights[index], + depth + 1 + ) + ) + # Add the points, func_evals and weights to the initial list. + points = np.vstack((points, subdivision_pts[1:])) + weights = np.append( + weights, + np.full((num_sub_points - 1), fill_value = weights[index]) + ) + func_evals = np.append(func_evals, func_vals_extra) + + final_grid = Grid(points, weights) + final_integral = final_grid.integrate(func_evals) + return { + "integral": final_integral, + "final_grid": final_grid, + "num_points": len(final_grid.points), + "num_evaluations": len(final_grid.points), # Accurate count of function evaluations + } diff --git a/src/grid/tests/test_cubic.py b/src/grid/tests/test_cubic.py index 42f3f30b..95be2521 100644 --- a/src/grid/tests/test_cubic.py +++ b/src/grid/tests/test_cubic.py @@ -28,6 +28,10 @@ from grid.cubic import Tensor1DGrids, UniformGrid, _HyperRectangleGrid from grid.onedgrid import GaussLaguerre, MidPoint +import pytest +import copy +from grid.cubic import AdaptiveUniformGrid, Grid + class TestHyperRectangleGrid(TestCase): r"""Test HyperRectangleGrid class.""" @@ -1050,3 +1054,148 @@ def test_uniformgrid_points_without_rotate(self): ] ) assert_allclose(grid.points, expected, rtol=1.0e-7, atol=1.0e-7) + + +class TestAdaptiveUniformGrid: + case_2d_gentle_wide = { + "id": "2D_Gentle_Peak_Wide_Range", + "grid_setup": { + "origin": np.array([-10.0, -10.0]), + "axes": np.diag([2.0, 2.0]), + "shape": np.array([11, 11]), + }, + "func": lambda points: np.exp(-0.1 * np.sum(points**2, axis=1)), + "analytical_integral": np.pi / 0.1, + "tolerance": 5e-3, + } + + case_3d_gentle_wide = { + "id": "3D_Gentle_Peak_Wide_Range (Lightweight)", + "grid_setup": { + "origin": np.array([-5.0, -5.0, -5.0]), + "axes": np.diag([1.0, 1.0, 1.0]), + "shape": np.array([11, 11, 11]), + }, + "func": lambda points: np.exp(-0.2 * np.sum(points**2, axis=1)), + "analytical_integral": (np.pi / 0.2) ** 1.5, + "tolerance": 5e-3, + "max_depth": 6, + } + + case_2d_moderate_wide = { + "id": "2D_Moderate_Peak_Wide_Range", + "grid_setup": { + "origin": np.array([-5.0, -5.0]), + "axes": np.diag([1.0, 1.0]), + "shape": np.array([11, 11]), + }, + "func": lambda points: np.exp(-2 * np.sum(points**2, axis=1)), + "analytical_integral": np.pi / 2, + "tolerance": 5e-3, + } + + case_2d_asymmetric_peaks = { + "id": "2D_Asymmetric_Multi_Peak", + "grid_setup": { + "origin": np.array([-6.0, -6.0]), + "axes": np.diag([1.0, 1.0]), + "shape": np.array([13, 13]), + }, + "func": lambda points: ( + np.exp(-0.5 * np.sum((points - np.array([-3.0, -3.0])) ** 2, axis=1)) + + np.exp(-3 * np.sum((points - np.array([2.0, 3.0])) ** 2, axis=1)) + ), + "analytical_integral": (np.pi / 0.5) + (np.pi / 3), + "tolerance": 1e-3, + } + + case_constant_function = { + "id": "Constant_Function_No_Refinement", + "grid_setup": { + "origin": np.array([-5.0, -5.0]), + "axes": np.diag([1.0, 1.0]), + "shape": np.array([11, 11]), + }, + "func": lambda points: np.ones(len(points)), + "analytical_integral": 100.0, + "tolerance": 1e-3, + } + + case_2d_sharp_legacy_wide = { + "id": "2D_Sharp_Peak_Legacy_Wide", + "grid_setup": { + "origin": np.array([-5.0, -5.0]), + "axes": np.diag([1.0, 1.0]), + "shape": np.array([11, 11]), + }, + "func": lambda points: np.exp(-20 * np.sum(points**2, axis=1)), + "analytical_integral": np.pi / 20, + "tolerance": 0.1, + } + + @pytest.mark.parametrize( + "test_case", + [ + # case_2d_gentle_wide, + # case_3d_gentle_wide, + case_2d_moderate_wide, + case_2d_asymmetric_peaks, + case_constant_function, + case_2d_sharp_legacy_wide, + ], + ids=lambda tc: tc["id"], + ) + def test_refinement_scenarios(self, test_case): + """Test adaptive refinement using diverse scenarios.""" + # Setup + grid_setup = test_case["grid_setup"] + uniform_grid = UniformGrid( + grid_setup["origin"], grid_setup["axes"], grid_setup["shape"], weight="Rectangle" + ) + adaptive_grid = AdaptiveUniformGrid(uniform_grid) + + analytical_integral = test_case["analytical_integral"] + + # Calculate initial error + initial_values = test_case["func"](uniform_grid.points) + initial_integral = uniform_grid.integrate(initial_values) + initial_error = abs(initial_integral - analytical_integral) + initial_num_points = uniform_grid.size + + # Perform refinement + result = adaptive_grid.refinement( + func=test_case["func"], + tolerance=test_case["tolerance"], + max_depth=test_case.get("max_depth", 10), + ) + + # Get refinement results + refined_integral = result["integral"] + refined_error = abs(refined_integral - analytical_integral) + refined_num_points = result["num_points"] + + # Detailed reporting + initial_error_rate = abs(initial_error / analytical_integral) * 100 + refined_error_rate = abs(refined_error / analytical_integral) * 100 + error_reduction = initial_error / refined_error if refined_error > 0 else float("inf") + + print(f"\n=== {test_case['id']} ===") + print(f"{'Metric':<25} | {'Initial':<20} | {'Refined':<20}") + print("-" * 70) + print(f"{'Exact integral':<25} | {analytical_integral:<20.8f} | {'-':<20}") + print(f"{'Grid points':<25} | {initial_num_points:<20} | {refined_num_points:<20}") + print(f"{'Computed integral':<25} | {initial_integral:<20.8f} | {refined_integral:<20.8f}") + print(f"{'Absolute error':<25} | {initial_error:<20.8e} | {refined_error:<20.8e}") + print(f"{'Error rate (%)':<25} | {initial_error_rate:<20.4f} | {refined_error_rate:<20.4f}") + print( + f"{'Function evaluations':<25} | {initial_num_points:<20} | {result['num_evaluations']:<20}" + ) + print(f"{'Error reduction':<25} | {'-':<20} | {error_reduction:<20.1f}x") + print("=" * 70) + + # Assertions + if test_case["id"] == "Constant_Function_No_Refinement": + assert refined_num_points == initial_num_points + else: + assert refined_error < initial_error + assert refined_num_points >= initial_num_points