Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
a48aa01
initial class
leo-collins Jul 3, 2025
13711c8
add test
leo-collins Jul 3, 2025
14e250e
add function
leo-collins Jul 3, 2025
2cf60e4
working for simple expressions
leo-collins Jul 3, 2025
8937403
fix
leo-collins Jul 3, 2025
d8d2f6f
add form test
leo-collins Jul 3, 2025
5b20d6d
add more tests
leo-collins Jul 3, 2025
062a779
change `split` to get domain from function space
leo-collins Jul 3, 2025
e37249d
add interpolate test
leo-collins Jul 3, 2025
85de70d
add test
leo-collins Jul 3, 2025
df1d911
add case for `Interpolate`
leo-collins Jul 4, 2025
957adaa
cyclic import goose chase
leo-collins Jul 4, 2025
838ae0b
notes
leo-collins Jul 4, 2025
7dffe54
fix imports
leo-collins Jul 8, 2025
2de0be5
fixes
leo-collins Jul 9, 2025
fa52bf3
update `measure.py`
leo-collins Jul 15, 2025
7befed8
rename to `unique_domain_extractor`
leo-collins Jul 15, 2025
b3770df
remove relevant uses of `expand_mesh_sequence`
leo-collins Jul 15, 2025
f5490ba
tidy up
leo-collins Jul 15, 2025
8555dfa
lint
leo-collins Jul 17, 2025
5cb1c57
base form
leo-collins Jul 17, 2025
1c77207
update tests
leo-collins Jul 17, 2025
c68448a
fix imports
leo-collins Jul 17, 2025
a55ebe7
cyclic import goose chase
leo-collins Jul 4, 2025
3c81eb7
add test for interpolate with mesh sequence
leo-collins Jul 30, 2025
f6bd995
fix type check
leo-collins Jul 30, 2025
c3933b3
remove interpolate tests
leo-collins Aug 28, 2025
283eb2f
tidy
leo-collins Aug 28, 2025
71be2d4
remove interpolate case
leo-collins Aug 28, 2025
4075d9a
tidy
leo-collins Aug 28, 2025
05f620c
`BaseForm` case
leo-collins Sep 3, 2025
43ff190
`BaseFormOperator` case
leo-collins Sep 3, 2025
75bcbfa
lint
leo-collins Sep 3, 2025
e296ddc
tidy
leo-collins Sep 9, 2025
6a0ccd0
lint
leo-collins Sep 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 301 additions & 0 deletions test/test_unique_domain_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@
import pytest
from utils import FiniteElement, LagrangeElement, MixedElement

from ufl import (
Action,
Adjoint,
Coefficient,
Constant,
FacetNormal,
FunctionSpace,
Interpolate,
Matrix,
Measure,
Mesh,
MeshSequence,
SpatialCoordinate,
TestFunction,
TrialFunction,
cos,
div,
grad,
inner,
split,
triangle,
)
from ufl.domain import extract_unique_domain
from ufl.pullback import contravariant_piola, identity_pullback
from ufl.sobolevspace import L2, HDiv


def test_extract_unique_domain():
cell = triangle
elem0 = LagrangeElement(cell, 1)
elem1 = FiniteElement("Brezzi-Douglas-Marini", cell, 2, (2,), contravariant_piola, HDiv)
elem2 = FiniteElement("Discontinuous Lagrange", cell, 1, (), identity_pullback, L2)
elem = MixedElement([elem0, elem1, elem2])
mesh1 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=100)
mesh2 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=101)
mesh3 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=102)
domain = MeshSequence([mesh1, mesh2, mesh3])
V = FunctionSpace(domain, elem)

u = TrialFunction(V)
u1, u2, u3 = split(u)
for i, u_i in enumerate((u1, u2, u3)):
assert extract_unique_domain(u_i) == domain[i]

f = Coefficient(V)
f1, f2, f3 = split(f)
for i, f_i in enumerate((f1, f2, f3)):
assert extract_unique_domain(f_i) == domain[i]

x1, y1 = SpatialCoordinate(mesh1)
expr = u1 + x1 * cos(x1)
assert extract_unique_domain(expr) == mesh1

expr2 = u1 * Constant(mesh1) + x1
assert extract_unique_domain(expr2) == mesh1

x2, y2 = SpatialCoordinate(mesh2)
with pytest.raises(ValueError):
_ = extract_unique_domain(u1 + u2)
_ = extract_unique_domain(u1 + u2 + x2 * cos(x2 * u1))


def test_extract_unique_domain_form():
cell = triangle
elem0 = LagrangeElement(cell, 1)
elem1 = FiniteElement("Brezzi-Douglas-Marini", cell, 2, (2,), contravariant_piola, HDiv)
elem2 = FiniteElement("Discontinuous Lagrange", cell, 1, (), identity_pullback, L2)
elem = MixedElement([elem0, elem1, elem2])
mesh1 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=100)
mesh2 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=101)
mesh3 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=102)
domain = MeshSequence([mesh1, mesh2, mesh3])
V = FunctionSpace(domain, elem)

u = TrialFunction(V)
u1, u2, u3 = split(u)
v = TestFunction(V)
v1, v2, v3 = split(v)

f = Coefficient(V)
f1, f2, f3 = split(f)

n = FacetNormal(mesh1)
dx1 = Measure("dx", mesh1)
ds1 = Measure("ds", mesh1)
dx2 = Measure("dx", mesh2)

form1 = inner(grad(u1), grad(v1)) * dx1 - inner(grad(u1), n) * v1 * ds1
assert extract_unique_domain(form1) == mesh1

form2 = inner(u1, f1) * dx1
assert extract_unique_domain(form2) == mesh1

form3 = inner(u1, v1) * dx1 + inner(u2, v2) * dx2
with pytest.raises(ValueError):
extract_unique_domain(form3)


def test_extract_unique_domain_single_mesh():
"""Test domain extraction for standard function spaces on a single mesh."""
cell = triangle
mesh = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=200)

# Test scalar elements
P1 = LagrangeElement(cell, 1)
V_scalar = FunctionSpace(mesh, P1)
u_scalar = TrialFunction(V_scalar)
f_scalar = Coefficient(V_scalar)

assert extract_unique_domain(u_scalar) == mesh
assert extract_unique_domain(f_scalar) == mesh

P1_vec = LagrangeElement(cell, 1, (2,))
V_vector = FunctionSpace(mesh, P1_vec)
u_vector = TrialFunction(V_vector)
f_vector = Coefficient(V_vector)

assert extract_unique_domain(u_vector) == mesh
assert extract_unique_domain(f_vector) == mesh

assert extract_unique_domain(u_vector[0]) == mesh
assert extract_unique_domain(u_vector[1]) == mesh
assert extract_unique_domain(f_vector[0]) == mesh
assert extract_unique_domain(f_vector[1]) == mesh

P1_tensor = LagrangeElement(cell, 1, (2, 2))
V_tensor = FunctionSpace(mesh, P1_tensor)
u_tensor = TrialFunction(V_tensor)
f_tensor = Coefficient(V_tensor)

assert extract_unique_domain(u_tensor) == mesh
assert extract_unique_domain(f_tensor) == mesh
assert extract_unique_domain(u_tensor[0, 0]) == mesh
assert extract_unique_domain(u_tensor[1, 1]) == mesh
assert extract_unique_domain(f_tensor[0, 1]) == mesh

x, y = SpatialCoordinate(mesh)
expr1 = u_scalar + f_scalar
expr2 = u_vector[0] + x
expr3 = inner(u_vector, f_vector)

assert extract_unique_domain(expr1) == mesh
assert extract_unique_domain(expr2) == mesh
assert extract_unique_domain(expr3) == mesh

# Test forms
dx = Measure("dx", mesh)
form = inner(u_scalar, f_scalar) * dx
assert extract_unique_domain(form) == mesh


def test_extract_unique_domain_mixed_scalar_vector_tensor():
"""Test domain extraction for mixed function spaces
with scalar, vector, and tensor elements."""
cell = triangle
mesh1 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=400)
mesh2 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=401)
mesh3 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=402)
domain = MeshSequence([mesh1, mesh2, mesh3])

scalar_elem = LagrangeElement(cell, 1)
vector_elem = LagrangeElement(cell, 1, (2,))
tensor_elem = LagrangeElement(cell, 1, (2, 2))
mixed_elem = MixedElement([scalar_elem, vector_elem, tensor_elem])

V = FunctionSpace(domain, mixed_elem)
u = TrialFunction(V)
f = Coefficient(V)

u_scalar, u_vector, u_tensor = split(u)
f_scalar, f_vector, f_tensor = split(f)

for i, u_i in enumerate((u_scalar, u_vector, u_tensor)):
assert extract_unique_domain(u_i) == domain[i]
for i, f_i in enumerate((f_scalar, f_vector, f_tensor)):
assert extract_unique_domain(f_i) == domain[i]

for i in range(2):
assert extract_unique_domain(u_vector[i]) == mesh2
assert extract_unique_domain(f_vector[i]) == mesh2

for i in range(2):
for j in range(2):
assert extract_unique_domain(u_tensor[i, j]) == mesh3
assert extract_unique_domain(f_tensor[i, j]) == mesh3

x1, y1 = SpatialCoordinate(mesh1)
x2, y2 = SpatialCoordinate(mesh2)
x3, y3 = SpatialCoordinate(mesh3)

expr_scalar = u_scalar * y1 + f_scalar + x1
assert extract_unique_domain(expr_scalar) == mesh1

expr_vector = inner(u_vector * y2, f_vector) + x2
assert extract_unique_domain(expr_vector) == mesh2

expr_vec_comp = u_vector[0] + f_vector[1] * y2 + x2
assert extract_unique_domain(expr_vec_comp) == mesh2

expr_tensor = y3 * u_tensor[0, 0] + f_tensor[1, 1] + x3
assert extract_unique_domain(expr_tensor) == mesh3

with pytest.raises(ValueError):
extract_unique_domain(u_scalar + u_vector[0])

with pytest.raises(ValueError):
extract_unique_domain(u_vector[0] + u_tensor[0, 0])

with pytest.raises(ValueError):
extract_unique_domain(f_scalar + f_tensor[1, 1])

with pytest.raises(ValueError):
extract_unique_domain(u_scalar + x2)

with pytest.raises(ValueError):
extract_unique_domain(u_vector[0] + x3)

dx1 = Measure("dx", mesh1)
dx2 = Measure("dx", mesh2)
dx3 = Measure("dx", mesh3)

form_scalar = u_scalar * f_scalar * dx1
form_vector = inner(u_vector, f_vector) * dx2
form_tensor = u_tensor[0, 0] * f_tensor[1, 1] * dx3

assert extract_unique_domain(form_scalar) == mesh1
assert extract_unique_domain(form_vector) == mesh2
assert extract_unique_domain(form_tensor) == mesh3

div_expr = div(u_vector) * f_scalar
with pytest.raises(ValueError):
extract_unique_domain(div_expr)


def test_extract_unique_domain_repeated_meshes():
cell = triangle
mesh1 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=500)
mesh2 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=501)

# MeshSequence with repeated meshes
domain_repeated = MeshSequence([mesh1, mesh2, mesh1])

scalar_elem = LagrangeElement(cell, 1, shape=())
mixed_elem = MixedElement([scalar_elem, scalar_elem, scalar_elem])
V = FunctionSpace(domain_repeated, mixed_elem)
u = TrialFunction(V)

u1, u2, u3 = split(u)

assert extract_unique_domain(u1) == mesh1
assert extract_unique_domain(u2) == mesh2
assert extract_unique_domain(u3) == mesh1

expr_same = u1 + u3
assert extract_unique_domain(expr_same) == mesh1

with pytest.raises(ValueError):
extract_unique_domain(u1 + u2)


def test_extract_unique_domain_baseform():
cell = triangle
mesh1 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=400)
mesh2 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=401)
scalar_elem = LagrangeElement(cell, 1)

V1 = FunctionSpace(mesh1, scalar_elem)
V2 = FunctionSpace(mesh2, scalar_elem)

A = Matrix(V1, V2)
assert extract_unique_domain(A) == mesh1
Comment on lines +270 to +274
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this what we expect?


v = Coefficient(V2)
action_Au = Action(A, v)
assert extract_unique_domain(action_Au) == mesh1

Astar = Adjoint(A)
assert extract_unique_domain(Astar) == mesh2

v1 = TrialFunction(V1)
v2star = TestFunction(V2.dual())
interp = Interpolate(v1, v2star) # V1 x V2^* -> R, equiv V1 -> V2
assert extract_unique_domain(interp) == mesh2
adjoint_interp = Adjoint(interp) # V2^* x V1 -> R, equiv V2^* -> V1^*
assert extract_unique_domain(adjoint_interp) == mesh1

cofunc = Coefficient(V2.dual())
scalar = Action(cofunc, v)
assert extract_unique_domain(scalar) is None

v = TestFunction(V2)
dx = Measure("dx", mesh2)
one_form = v * dx
formsum = cofunc + one_form
assert extract_unique_domain(formsum) is mesh2

two_form = interp * v * dx
assert extract_unique_domain(two_form) is mesh2
4 changes: 2 additions & 2 deletions ufl/algorithms/apply_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ def _(self, o: ReferenceValue) -> Expr:
f = o.ufl_operands[0]
if not f._ufl_is_terminal_:
raise ValueError("ReferenceValue can only wrap a terminal")
domain = extract_unique_domain(f, expand_mesh_sequence=False)
domain = extract_unique_domain(f)
if isinstance(domain, MeshSequence):
element = f.ufl_function_space().ufl_element() # type: ignore
if element.num_sub_elements != len(domain):
Expand Down Expand Up @@ -897,7 +897,7 @@ def _(self, o: Expr) -> Expr:
)
if not valid_operand:
raise ValueError("ReferenceGrad can only wrap a reference frame type!")
domain = extract_unique_domain(f, expand_mesh_sequence=False)
domain = extract_unique_domain(f)
if isinstance(domain, MeshSequence):
if not f._ufl_is_in_reference_frame_:
raise RuntimeError("Expecting a reference frame type")
Expand Down
4 changes: 2 additions & 2 deletions ufl/algorithms/compute_form_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _build_coefficient_replace_map(coefficients, element_mapping=None):
# coefficient had a domain, the new one does too.
# This should be overhauled with requirement that Expressions
# always have a domain.
domain = extract_unique_domain(f, expand_mesh_sequence=False)
domain = extract_unique_domain(f)
if domain is not None:
new_e = FunctionSpace(domain, new_e)
new_f = Coefficient(new_e, count=i)
Expand Down Expand Up @@ -454,7 +454,7 @@ def compute_form_data(
for o in self.reduced_coefficients:
if o in coefficients_to_split:
c = self.function_replace_map[o]
mesh = extract_unique_domain(c, expand_mesh_sequence=False)
mesh = extract_unique_domain(c)
elem = c.ufl_element()
coefficient_split[c] = [
Coefficient(FunctionSpace(m, e))
Expand Down
3 changes: 2 additions & 1 deletion ufl/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from ufl.core.expr import Expr
from ufl.core.terminal import FormArgument
from ufl.corealg.traversal import traverse_unique_terminals
from ufl.geometry import GeometricQuantity
from ufl.sobolevspace import H1


Expand Down Expand Up @@ -40,6 +39,8 @@ def is_cellwise_constant(expr):

def is_scalar_constant_expression(expr):
"""Check if an expression is a globally constant scalar expression."""
from ufl.geometry import GeometricQuantity

if is_python_scalar(expr):
return True
if expr.ufl_shape:
Expand Down
9 changes: 6 additions & 3 deletions ufl/corealg/dag_traverser.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Base class for dag traversers."""

from __future__ import annotations

from functools import singledispatchmethod, wraps
from typing import overload
from typing import TYPE_CHECKING, overload

from ufl.classes import Expr
from ufl.form import BaseForm
if TYPE_CHECKING:
from ufl.classes import Expr
from ufl.form import BaseForm


class DAGTraverser:
Expand Down
4 changes: 2 additions & 2 deletions ufl/differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,15 +318,15 @@ def __new__(cls, f):
# Return zero if expression is trivially constant
if is_cellwise_constant(f):
# TODO: Use max topological dimension if there are multiple topological dimensions.
dim = extract_unique_domain(f, expand_mesh_sequence=False).topological_dimension()
dim = extract_unique_domain(f).topological_dimension()
return Zero(f.ufl_shape + (dim,), f.ufl_free_indices, f.ufl_index_dimensions)
return CompoundDerivative.__new__(cls)

def __init__(self, f):
"""Initalise."""
CompoundDerivative.__init__(self, (f,))
# TODO: Use max topological dimension if there are multiple topological dimensions.
self._dim = extract_unique_domain(f, expand_mesh_sequence=False).topological_dimension()
self._dim = extract_unique_domain(f).topological_dimension()

def _ufl_expr_reconstruct_(self, op):
"""Return a new object of the same type with new operands."""
Expand Down
Loading
Loading