Skip to content

Commit 9fd2344

Browse files
committed
Refactor interpolation implementation
1 parent 497b0e0 commit 9fd2344

File tree

7 files changed

+903
-1045
lines changed

7 files changed

+903
-1045
lines changed

firedrake/assemble.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
2424
from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace
2525
from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key
26+
from firedrake.interpolation import get_interpolator
2627
from firedrake.petsc import PETSc
2728
from firedrake.slate import slac, slate
2829
from firedrake.slate.slac.kernel_builder import CellFacetKernelArg, LayerCountKernelArg
@@ -613,17 +614,8 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
613614
rank = len(expr.arguments())
614615
if rank > 2:
615616
raise ValueError("Cannot assemble an Interpolate with more than two arguments")
616-
# Get the target space
617-
V = v.function_space().dual()
618-
619-
# Get the interpolator
620-
interp_data = expr.interp_data.copy()
621-
default_missing_val = interp_data.pop('default_missing_val', None)
622-
if rank == 1 and isinstance(tensor, firedrake.Function):
623-
V = tensor
624-
interpolator = firedrake.Interpolator(expr, V, bcs=bcs, **interp_data)
625-
# Assembly
626-
return interpolator.assemble(tensor=tensor, default_missing_val=default_missing_val)
617+
interpolator = get_interpolator(expr)
618+
return interpolator.assemble(tensor=tensor, bcs=bcs)
627619
elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)):
628620
return tensor.assign(expr)
629621
elif tensor and isinstance(expr, ufl.ZeroBaseForm):

firedrake/bcs.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# A module implementing strong (Dirichlet) boundary conditions.
22
import numpy as np
33

4-
import functools
4+
from functools import partial, reduce
55
import itertools
66

77
import ufl
@@ -167,7 +167,7 @@ def hermite_stride(bcnodes):
167167
# Edge conditions have only been tested with Lagrange elements.
168168
# Need to expand the list.
169169
bcnodes1.append(hermite_stride(self._function_space.boundary_nodes(ss)))
170-
bcnodes1 = functools.reduce(np.intersect1d, bcnodes1)
170+
bcnodes1 = reduce(np.intersect1d, bcnodes1)
171171
bcnodes.append(bcnodes1)
172172
return np.concatenate(bcnodes)
173173

@@ -359,11 +359,10 @@ def function_arg(self, g):
359359
raise RuntimeError(f"Provided boundary value {g} does not match shape of space")
360360
try:
361361
self._function_arg = firedrake.Function(V)
362-
# Use `Interpolator` instead of assembling an `Interpolate` form
363-
# as the expression compilation needs to happen at this stage to
364-
# determine if we should use interpolation or projection
365-
# -> e.g. interpolation may not be supported for the element.
366-
self._function_arg_update = firedrake.Interpolator(g, self._function_arg)._interpolate
362+
interpolator = firedrake.get_interpolator(firedrake.interpolate(g, V))
363+
# Call this here to check if the element supports interpolation
364+
interpolator._build_callable()
365+
self._function_arg_update = partial(interpolator.assemble, tensor=self._function_arg)
367366
except (NotImplementedError, AttributeError):
368367
# Element doesn't implement interpolation
369368
self._function_arg = firedrake.Function(V).project(g)

0 commit comments

Comments
 (0)