diff --git a/test/test_unique_domain_extractor.py b/test/test_unique_domain_extractor.py new file mode 100644 index 000000000..8d9c097c7 --- /dev/null +++ b/test/test_unique_domain_extractor.py @@ -0,0 +1,301 @@ +import pytest +from utils import FiniteElement, LagrangeElement, MixedElement + +from ufl import ( + Action, + Adjoint, + Coefficient, + Constant, + FacetNormal, + FunctionSpace, + Interpolate, + Matrix, + Measure, + Mesh, + MeshSequence, + SpatialCoordinate, + TestFunction, + TrialFunction, + cos, + div, + grad, + inner, + split, + triangle, +) +from ufl.domain import extract_unique_domain +from ufl.pullback import contravariant_piola, identity_pullback +from ufl.sobolevspace import L2, HDiv + + +def test_extract_unique_domain(): + cell = triangle + elem0 = LagrangeElement(cell, 1) + elem1 = FiniteElement("Brezzi-Douglas-Marini", cell, 2, (2,), contravariant_piola, HDiv) + elem2 = FiniteElement("Discontinuous Lagrange", cell, 1, (), identity_pullback, L2) + elem = MixedElement([elem0, elem1, elem2]) + mesh1 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=100) + mesh2 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=101) + mesh3 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=102) + domain = MeshSequence([mesh1, mesh2, mesh3]) + V = FunctionSpace(domain, elem) + + u = TrialFunction(V) + u1, u2, u3 = split(u) + for i, u_i in enumerate((u1, u2, u3)): + assert extract_unique_domain(u_i) == domain[i] + + f = Coefficient(V) + f1, f2, f3 = split(f) + for i, f_i in enumerate((f1, f2, f3)): + assert extract_unique_domain(f_i) == domain[i] + + x1, y1 = SpatialCoordinate(mesh1) + expr = u1 + x1 * cos(x1) + assert extract_unique_domain(expr) == mesh1 + + expr2 = u1 * Constant(mesh1) + x1 + assert extract_unique_domain(expr2) == mesh1 + + x2, y2 = SpatialCoordinate(mesh2) + with pytest.raises(ValueError): + _ = extract_unique_domain(u1 + u2) + _ = extract_unique_domain(u1 + u2 + x2 * cos(x2 * u1)) + + +def test_extract_unique_domain_form(): + cell = triangle + elem0 = LagrangeElement(cell, 1) + elem1 = FiniteElement("Brezzi-Douglas-Marini", cell, 2, (2,), contravariant_piola, HDiv) + elem2 = FiniteElement("Discontinuous Lagrange", cell, 1, (), identity_pullback, L2) + elem = MixedElement([elem0, elem1, elem2]) + mesh1 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=100) + mesh2 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=101) + mesh3 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=102) + domain = MeshSequence([mesh1, mesh2, mesh3]) + V = FunctionSpace(domain, elem) + + u = TrialFunction(V) + u1, u2, u3 = split(u) + v = TestFunction(V) + v1, v2, v3 = split(v) + + f = Coefficient(V) + f1, f2, f3 = split(f) + + n = FacetNormal(mesh1) + dx1 = Measure("dx", mesh1) + ds1 = Measure("ds", mesh1) + dx2 = Measure("dx", mesh2) + + form1 = inner(grad(u1), grad(v1)) * dx1 - inner(grad(u1), n) * v1 * ds1 + assert extract_unique_domain(form1) == mesh1 + + form2 = inner(u1, f1) * dx1 + assert extract_unique_domain(form2) == mesh1 + + form3 = inner(u1, v1) * dx1 + inner(u2, v2) * dx2 + with pytest.raises(ValueError): + extract_unique_domain(form3) + + +def test_extract_unique_domain_single_mesh(): + """Test domain extraction for standard function spaces on a single mesh.""" + cell = triangle + mesh = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=200) + + # Test scalar elements + P1 = LagrangeElement(cell, 1) + V_scalar = FunctionSpace(mesh, P1) + u_scalar = TrialFunction(V_scalar) + f_scalar = Coefficient(V_scalar) + + assert extract_unique_domain(u_scalar) == mesh + assert extract_unique_domain(f_scalar) == mesh + + P1_vec = LagrangeElement(cell, 1, (2,)) + V_vector = FunctionSpace(mesh, P1_vec) + u_vector = TrialFunction(V_vector) + f_vector = Coefficient(V_vector) + + assert extract_unique_domain(u_vector) == mesh + assert extract_unique_domain(f_vector) == mesh + + assert extract_unique_domain(u_vector[0]) == mesh + assert extract_unique_domain(u_vector[1]) == mesh + assert extract_unique_domain(f_vector[0]) == mesh + assert extract_unique_domain(f_vector[1]) == mesh + + P1_tensor = LagrangeElement(cell, 1, (2, 2)) + V_tensor = FunctionSpace(mesh, P1_tensor) + u_tensor = TrialFunction(V_tensor) + f_tensor = Coefficient(V_tensor) + + assert extract_unique_domain(u_tensor) == mesh + assert extract_unique_domain(f_tensor) == mesh + assert extract_unique_domain(u_tensor[0, 0]) == mesh + assert extract_unique_domain(u_tensor[1, 1]) == mesh + assert extract_unique_domain(f_tensor[0, 1]) == mesh + + x, y = SpatialCoordinate(mesh) + expr1 = u_scalar + f_scalar + expr2 = u_vector[0] + x + expr3 = inner(u_vector, f_vector) + + assert extract_unique_domain(expr1) == mesh + assert extract_unique_domain(expr2) == mesh + assert extract_unique_domain(expr3) == mesh + + # Test forms + dx = Measure("dx", mesh) + form = inner(u_scalar, f_scalar) * dx + assert extract_unique_domain(form) == mesh + + +def test_extract_unique_domain_mixed_scalar_vector_tensor(): + """Test domain extraction for mixed function spaces + with scalar, vector, and tensor elements.""" + cell = triangle + mesh1 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=400) + mesh2 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=401) + mesh3 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=402) + domain = MeshSequence([mesh1, mesh2, mesh3]) + + scalar_elem = LagrangeElement(cell, 1) + vector_elem = LagrangeElement(cell, 1, (2,)) + tensor_elem = LagrangeElement(cell, 1, (2, 2)) + mixed_elem = MixedElement([scalar_elem, vector_elem, tensor_elem]) + + V = FunctionSpace(domain, mixed_elem) + u = TrialFunction(V) + f = Coefficient(V) + + u_scalar, u_vector, u_tensor = split(u) + f_scalar, f_vector, f_tensor = split(f) + + for i, u_i in enumerate((u_scalar, u_vector, u_tensor)): + assert extract_unique_domain(u_i) == domain[i] + for i, f_i in enumerate((f_scalar, f_vector, f_tensor)): + assert extract_unique_domain(f_i) == domain[i] + + for i in range(2): + assert extract_unique_domain(u_vector[i]) == mesh2 + assert extract_unique_domain(f_vector[i]) == mesh2 + + for i in range(2): + for j in range(2): + assert extract_unique_domain(u_tensor[i, j]) == mesh3 + assert extract_unique_domain(f_tensor[i, j]) == mesh3 + + x1, y1 = SpatialCoordinate(mesh1) + x2, y2 = SpatialCoordinate(mesh2) + x3, y3 = SpatialCoordinate(mesh3) + + expr_scalar = u_scalar * y1 + f_scalar + x1 + assert extract_unique_domain(expr_scalar) == mesh1 + + expr_vector = inner(u_vector * y2, f_vector) + x2 + assert extract_unique_domain(expr_vector) == mesh2 + + expr_vec_comp = u_vector[0] + f_vector[1] * y2 + x2 + assert extract_unique_domain(expr_vec_comp) == mesh2 + + expr_tensor = y3 * u_tensor[0, 0] + f_tensor[1, 1] + x3 + assert extract_unique_domain(expr_tensor) == mesh3 + + with pytest.raises(ValueError): + extract_unique_domain(u_scalar + u_vector[0]) + + with pytest.raises(ValueError): + extract_unique_domain(u_vector[0] + u_tensor[0, 0]) + + with pytest.raises(ValueError): + extract_unique_domain(f_scalar + f_tensor[1, 1]) + + with pytest.raises(ValueError): + extract_unique_domain(u_scalar + x2) + + with pytest.raises(ValueError): + extract_unique_domain(u_vector[0] + x3) + + dx1 = Measure("dx", mesh1) + dx2 = Measure("dx", mesh2) + dx3 = Measure("dx", mesh3) + + form_scalar = u_scalar * f_scalar * dx1 + form_vector = inner(u_vector, f_vector) * dx2 + form_tensor = u_tensor[0, 0] * f_tensor[1, 1] * dx3 + + assert extract_unique_domain(form_scalar) == mesh1 + assert extract_unique_domain(form_vector) == mesh2 + assert extract_unique_domain(form_tensor) == mesh3 + + div_expr = div(u_vector) * f_scalar + with pytest.raises(ValueError): + extract_unique_domain(div_expr) + + +def test_extract_unique_domain_repeated_meshes(): + cell = triangle + mesh1 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=500) + mesh2 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=501) + + # MeshSequence with repeated meshes + domain_repeated = MeshSequence([mesh1, mesh2, mesh1]) + + scalar_elem = LagrangeElement(cell, 1, shape=()) + mixed_elem = MixedElement([scalar_elem, scalar_elem, scalar_elem]) + V = FunctionSpace(domain_repeated, mixed_elem) + u = TrialFunction(V) + + u1, u2, u3 = split(u) + + assert extract_unique_domain(u1) == mesh1 + assert extract_unique_domain(u2) == mesh2 + assert extract_unique_domain(u3) == mesh1 + + expr_same = u1 + u3 + assert extract_unique_domain(expr_same) == mesh1 + + with pytest.raises(ValueError): + extract_unique_domain(u1 + u2) + + +def test_extract_unique_domain_baseform(): + cell = triangle + mesh1 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=400) + mesh2 = Mesh(LagrangeElement(cell, 1, (2,)), ufl_id=401) + scalar_elem = LagrangeElement(cell, 1) + + V1 = FunctionSpace(mesh1, scalar_elem) + V2 = FunctionSpace(mesh2, scalar_elem) + + A = Matrix(V1, V2) + assert extract_unique_domain(A) == mesh1 + + v = Coefficient(V2) + action_Au = Action(A, v) + assert extract_unique_domain(action_Au) == mesh1 + + Astar = Adjoint(A) + assert extract_unique_domain(Astar) == mesh2 + + v1 = TrialFunction(V1) + v2star = TestFunction(V2.dual()) + interp = Interpolate(v1, v2star) # V1 x V2^* -> R, equiv V1 -> V2 + assert extract_unique_domain(interp) == mesh2 + adjoint_interp = Adjoint(interp) # V2^* x V1 -> R, equiv V2^* -> V1^* + assert extract_unique_domain(adjoint_interp) == mesh1 + + cofunc = Coefficient(V2.dual()) + scalar = Action(cofunc, v) + assert extract_unique_domain(scalar) is None + + v = TestFunction(V2) + dx = Measure("dx", mesh2) + one_form = v * dx + formsum = cofunc + one_form + assert extract_unique_domain(formsum) is mesh2 + + two_form = interp * v * dx + assert extract_unique_domain(two_form) is mesh2 diff --git a/ufl/algorithms/apply_derivatives.py b/ufl/algorithms/apply_derivatives.py index 3c5a07919..f29eae9ac 100644 --- a/ufl/algorithms/apply_derivatives.py +++ b/ufl/algorithms/apply_derivatives.py @@ -834,7 +834,7 @@ def _(self, o: ReferenceValue) -> Expr: f = o.ufl_operands[0] if not f._ufl_is_terminal_: raise ValueError("ReferenceValue can only wrap a terminal") - domain = extract_unique_domain(f, expand_mesh_sequence=False) + domain = extract_unique_domain(f) if isinstance(domain, MeshSequence): element = f.ufl_function_space().ufl_element() # type: ignore if element.num_sub_elements != len(domain): @@ -897,7 +897,7 @@ def _(self, o: Expr) -> Expr: ) if not valid_operand: raise ValueError("ReferenceGrad can only wrap a reference frame type!") - domain = extract_unique_domain(f, expand_mesh_sequence=False) + domain = extract_unique_domain(f) if isinstance(domain, MeshSequence): if not f._ufl_is_in_reference_frame_: raise RuntimeError("Expecting a reference frame type") diff --git a/ufl/algorithms/compute_form_data.py b/ufl/algorithms/compute_form_data.py index 09f98af4e..10829eeed 100644 --- a/ufl/algorithms/compute_form_data.py +++ b/ufl/algorithms/compute_form_data.py @@ -188,7 +188,7 @@ def _build_coefficient_replace_map(coefficients, element_mapping=None): # coefficient had a domain, the new one does too. # This should be overhauled with requirement that Expressions # always have a domain. - domain = extract_unique_domain(f, expand_mesh_sequence=False) + domain = extract_unique_domain(f) if domain is not None: new_e = FunctionSpace(domain, new_e) new_f = Coefficient(new_e, count=i) @@ -454,7 +454,7 @@ def compute_form_data( for o in self.reduced_coefficients: if o in coefficients_to_split: c = self.function_replace_map[o] - mesh = extract_unique_domain(c, expand_mesh_sequence=False) + mesh = extract_unique_domain(c) elem = c.ufl_element() coefficient_split[c] = [ Coefficient(FunctionSpace(m, e)) diff --git a/ufl/checks.py b/ufl/checks.py index 132bb7559..d4c7a51df 100644 --- a/ufl/checks.py +++ b/ufl/checks.py @@ -11,7 +11,6 @@ from ufl.core.expr import Expr from ufl.core.terminal import FormArgument from ufl.corealg.traversal import traverse_unique_terminals -from ufl.geometry import GeometricQuantity from ufl.sobolevspace import H1 @@ -40,6 +39,8 @@ def is_cellwise_constant(expr): def is_scalar_constant_expression(expr): """Check if an expression is a globally constant scalar expression.""" + from ufl.geometry import GeometricQuantity + if is_python_scalar(expr): return True if expr.ufl_shape: diff --git a/ufl/corealg/dag_traverser.py b/ufl/corealg/dag_traverser.py index 285e14e53..f066afa4f 100644 --- a/ufl/corealg/dag_traverser.py +++ b/ufl/corealg/dag_traverser.py @@ -1,10 +1,13 @@ """Base class for dag traversers.""" +from __future__ import annotations + from functools import singledispatchmethod, wraps -from typing import overload +from typing import TYPE_CHECKING, overload -from ufl.classes import Expr -from ufl.form import BaseForm +if TYPE_CHECKING: + from ufl.classes import Expr + from ufl.form import BaseForm class DAGTraverser: diff --git a/ufl/differentiation.py b/ufl/differentiation.py index fd4a66bb7..76d327f8e 100644 --- a/ufl/differentiation.py +++ b/ufl/differentiation.py @@ -318,7 +318,7 @@ def __new__(cls, f): # Return zero if expression is trivially constant if is_cellwise_constant(f): # TODO: Use max topological dimension if there are multiple topological dimensions. - dim = extract_unique_domain(f, expand_mesh_sequence=False).topological_dimension() + dim = extract_unique_domain(f).topological_dimension() return Zero(f.ufl_shape + (dim,), f.ufl_free_indices, f.ufl_index_dimensions) return CompoundDerivative.__new__(cls) @@ -326,7 +326,7 @@ def __init__(self, f): """Initalise.""" CompoundDerivative.__init__(self, (f,)) # TODO: Use max topological dimension if there are multiple topological dimensions. - self._dim = extract_unique_domain(f, expand_mesh_sequence=False).topological_dimension() + self._dim = extract_unique_domain(f).topological_dimension() def _ufl_expr_reconstruct_(self, op): """Return a new object of the same type with new operands.""" diff --git a/ufl/domain.py b/ufl/domain.py index 9ba4372e2..8514c4bbb 100644 --- a/ufl/domain.py +++ b/ufl/domain.py @@ -10,6 +10,7 @@ import numbers from collections.abc import Iterable, Sequence +from functools import singledispatchmethod from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -17,9 +18,14 @@ from ufl.finiteelement import AbstractFiniteElement # To avoid cyclic import when type-hinting. from ufl.form import Form from ufl.cell import AbstractCell +from ufl.core.base_form_operator import BaseFormOperator +from ufl.core.operator import Operator +from ufl.core.terminal import Terminal from ufl.core.ufl_id import attach_ufl_id from ufl.core.ufl_type import UFLObject +from ufl.corealg.dag_traverser import DAGTraverser from ufl.corealg.traversal import traverse_unique_terminals +from ufl.indexed import Indexed from ufl.sobolevspace import H1 # Export list for ufl.classes @@ -330,7 +336,7 @@ def as_domain(domain): return domain try: return extract_unique_domain(domain) - except AttributeError: + except (AttributeError, TypeError): domain = domain.ufl_domain() (domain,) = set(domain.meshes) return domain @@ -410,26 +416,6 @@ def extract_domains(expr: Expr | Form, expand_mesh_sequence: bool = True): return sort_domains(join_domains(domainlist, expand_mesh_sequence=expand_mesh_sequence)) -def extract_unique_domain(expr, expand_mesh_sequence: bool = True): - """Return the single unique domain expression is defined on or throw an error. - - Args: - expr: Expr or Form. - expand_mesh_sequence: If True, MeshSequence components are expanded. - - Returns: - domain. - - """ - domains = extract_domains(expr, expand_mesh_sequence=expand_mesh_sequence) - if len(domains) == 1: - return domains[0] - elif domains: - raise ValueError("Found multiple domains, cannot return just one.") - else: - return None - - def find_geometric_dimension(expr): """Find the geometric dimension of an expression.""" gdims = set() @@ -444,3 +430,145 @@ def find_geometric_dimension(expr): raise ValueError("Cannot determine geometric dimension from expression.") (gdim,) = gdims return gdim + + +class UniqueDomainExtractor(DAGTraverser): + """Extract unique domain from an expression or BaseForm.""" + + def __init__( + self, + compress: bool | None = True, + visited_cache: dict[tuple, Expr] | None = None, + result_cache: dict[Expr, Expr] | None = None, + ) -> None: + """Initialise.""" + self._compress = compress + self._visited_cache = {} if visited_cache is None else visited_cache + self._result_cache = {} if result_cache is None else result_cache + super().__init__(compress=compress, visited_cache=visited_cache, result_cache=result_cache) + + @singledispatchmethod + def process(self, o: Expr) -> Expr: + """Process ``o``. + + Args: + o: `Expr` to be processed. + + Returns: + Processed object. + + """ + return super().process(o) + + @process.register(Indexed) + @DAGTraverser.postorder + def _(self, o: Expr, *operand_results) -> AbstractDomain: + """Process Indexed object by extracting the domain corresponding to the index.""" + from ufl.functionspace import FunctionSpace + + expression, multiindex = o.ufl_operands + expression_domain = operand_results[0] + + if isinstance(expression_domain, MeshSequence): + index = multiindex[0]._value + element = expression.ufl_element() + if hasattr(element, "sub_elements"): + # Need to do this in case we have sub elements which are vector or tensor valued + j = 0 + for i, sub_element in enumerate(element.sub_elements): + # Get the value size for this sub-element on its corresponding mesh + sub_element_mesh = expression_domain.meshes[i] + sub_element_fs = FunctionSpace(sub_element_mesh, sub_element) + sub_element_size = sub_element_fs.value_size + + if index < j + sub_element_size: + return sub_element_mesh + j += sub_element_size + raise ValueError(f"Index {index} out of range for mixed function space") + else: + return expression_domain.meshes[index] + return expression_domain + + @process.register(Terminal) + @DAGTraverser.postorder + def _(self, o: Expr) -> AbstractDomain: + from ufl.argument import Argument + from ufl.coefficient import Coefficient + from ufl.constant import Constant + from ufl.geometry import GeometricQuantity + + if isinstance(o, Coefficient | Argument): + fs = o.ufl_function_space() + return fs.ufl_domain() + elif isinstance(o, GeometricQuantity): + return o._domain + elif isinstance(o, Constant): + return o._ufl_domain + else: + return None + + @process.register(Operator) + @DAGTraverser.postorder + def _(self, o: Expr, *operand_results) -> AbstractDomain: + """Process Operator.""" + domains = [d for d in operand_results if d is not None] + + if not domains: + return None + elif len(domains) == 1: + return domains[0] + else: + # Multiple operands have domains - they should all be the same + first_domain = domains[0] + if all(d == first_domain for d in domains): + return first_domain + else: + raise ValueError( + f"Expression {o!r} has differing domains: {domains!r}" + ) + + @process.register(BaseFormOperator) + @DAGTraverser.postorder + def _(self, o: Expr, *operand_results) -> AbstractDomain: + fs = o.ufl_function_space() + return fs.ufl_domain() + + +def extract_unique_domain(expr: Expr | Form) -> AbstractDomain: + """Extract the single unique domain from an expression. + + This works for expressions containing Indexed Arguments and Coefficients from + split functions on mixed function spaces. + + Args: + expr: Expr or Form to extract domain from + + Returns: + AbstractDomain: The unique domain extracted from the expression. + """ + from ufl.core.expr import Expr + from ufl.form import BaseForm, Form + + if isinstance(expr, Form): + domains = set() + for integral in expr.integrals(): + domain = extract_unique_domain(integral.integrand()) + if domain is not None: + domains.add(domain) + + if len(domains) == 0: + return None + elif len(domains) == 1: + return domains.pop() + else: + raise ValueError(f"Form has multiple domains: {domains}") + elif isinstance(expr, BaseForm): + if not expr.arguments(): + return None + else: + return extract_unique_domain(expr.arguments()[0]) + elif isinstance(expr, Expr): + extractor = UniqueDomainExtractor() + return extractor(expr) + else: + raise TypeError(f"Expected an Expr or Form, not a {type(expr).__name__}.") diff --git a/ufl/form.py b/ufl/form.py index b7ff23532..ba2e07a74 100644 --- a/ufl/form.py +++ b/ufl/form.py @@ -17,12 +17,10 @@ from itertools import chain from ufl.checks import is_scalar_constant_expression -from ufl.constant import Constant from ufl.constantvalue import Zero from ufl.core.expr import Expr, ufl_err_str from ufl.core.terminal import FormArgument from ufl.core.ufl_type import UFLType, ufl_type -from ufl.domain import extract_unique_domain, sort_domains from ufl.equation import Equation from ufl.integral import Integral from ufl.utils.counted import Counted @@ -40,6 +38,8 @@ def _sorted_integrals(integrals): Sort integrals by domain id, integral type, subdomain id for a more stable signature computation. """ + from ufl.domain import sort_domains + # Group integrals in multilevel dict by keys # [domain][integral_type][subdomain_id] integrals_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) @@ -399,6 +399,8 @@ def constants(self): def constant_numbering(self): """Return a contiguous numbering of constants in a mapping ``{constant:number}``.""" + from ufl.constant import Constant + if self._constant_numbering is None: self._constant_numbering = { expr: num @@ -594,7 +596,7 @@ def __repr__(self): def _analyze_domains(self): """Analyze domains.""" - from ufl.domain import join_domains, sort_domains + from ufl.domain import extract_unique_domain, join_domains, sort_domains # Collect integration domains. self._integration_domains = sort_domains( @@ -605,7 +607,7 @@ def _analyze_domains(self): for o in chain( self.arguments(), self.coefficients(), self.constants(), self.geometric_quantities() ): - domain = extract_unique_domain(o, expand_mesh_sequence=False) + domain = extract_unique_domain(o) domains_in_integrands.update(domain.meshes) domains_in_integrands -= set(self._integration_domains) all_domains = self._integration_domains + sort_domains(join_domains(domains_in_integrands)) diff --git a/ufl/functionspace.py b/ufl/functionspace.py index 95757046e..797d649b4 100644 --- a/ufl/functionspace.py +++ b/ufl/functionspace.py @@ -12,7 +12,6 @@ import numpy as np from ufl.core.ufl_type import UFLObject -from ufl.domain import join_domains from ufl.duals import is_dual, is_primal from ufl.utils.sequences import product @@ -311,6 +310,8 @@ def ufl_element(self): def ufl_domains(self): """Return ufl domains.""" + from ufl.domain import join_domains + domainlist = [] for s in self._ufl_function_spaces: domainlist.extend(s.ufl_domains()) diff --git a/ufl/matrix.py b/ufl/matrix.py index d548c0abd..42b009e58 100644 --- a/ufl/matrix.py +++ b/ufl/matrix.py @@ -7,7 +7,6 @@ # # Modified by Nacime Bouziani, 2021-2022. -from ufl.argument import Argument from ufl.core.ufl_type import ufl_type from ufl.form import BaseForm from ufl.functionspace import AbstractFunctionSpace @@ -64,6 +63,8 @@ def ufl_function_spaces(self): def _analyze_form_arguments(self): """Define arguments of a matrix when considered as a form.""" + from ufl.argument import Argument + self._arguments = ( Argument(self._ufl_function_spaces[0], 0), Argument(self._ufl_function_spaces[1], 1), diff --git a/ufl/measure.py b/ufl/measure.py index cb9493c64..27415b85e 100644 --- a/ufl/measure.py +++ b/ufl/measure.py @@ -15,7 +15,7 @@ from ufl.checks import is_true_ufl_scalar from ufl.constantvalue import as_ufl from ufl.core.expr import Expr -from ufl.domain import AbstractDomain, as_domain, extract_domains +from ufl.domain import AbstractDomain, as_domain, extract_unique_domain from ufl.protocols import id_or_none # Export list for ufl.classes @@ -417,15 +417,7 @@ def __rmul__(self, integrand): # integrand domain = self.ufl_domain() if domain is None: - domains = extract_domains(integrand) - if len(domains) == 1: - (domain,) = domains - elif len(domains) == 0: - raise ValueError("This integral is missing an integration domain.") - else: - raise ValueError( - "Multiple domains found, making the choice of integration domain ambiguous." - ) + domain = extract_unique_domain(integrand) # Otherwise create and return a one-integral form integral = Integral( diff --git a/ufl/pullback.py b/ufl/pullback.py index 4f4bacdc6..fa2f1362b 100644 --- a/ufl/pullback.py +++ b/ufl/pullback.py @@ -414,7 +414,7 @@ def apply(self, expr, domain=None): g_components = [] offset = 0 # For each unique piece in reference space, apply the appropriate pullback - domain = domain or extract_unique_domain(expr, expand_mesh_sequence=False) + domain = domain or extract_unique_domain(expr) if isinstance(domain, MeshSequence): if len(domain) != self._element.num_sub_elements: raise ValueError(f"""num. component meshes ({len(domain)}) != @@ -504,7 +504,7 @@ def apply(self, expr, domain=None): for subelem in self._element.sub_elements: offsets.append(offsets[-1] + subelem.reference_value_size) # For each unique piece in reference space, apply the appropriate pullback - domain = domain or extract_unique_domain(expr, expand_mesh_sequence=False) + domain = domain or extract_unique_domain(expr) if isinstance(domain, MeshSequence): if len(domain) != self._element.num_sub_elements: raise ValueError(f"""num. component meshes ({len(domain)}) != diff --git a/ufl/split_functions.py b/ufl/split_functions.py index 69be05e92..5ece9d262 100644 --- a/ufl/split_functions.py +++ b/ufl/split_functions.py @@ -7,7 +7,6 @@ # # Modified by Anders Logg, 2008 -from ufl.domain import extract_unique_domain from ufl.functionspace import FunctionSpace from ufl.indexed import Indexed from ufl.permutation import compute_indices @@ -60,7 +59,7 @@ def split(v): "Don't know how to split tensor valued mixed functions without flattened index space." ) - domain = extract_unique_domain(v, expand_mesh_sequence=False) + domain = v.ufl_function_space().ufl_domain() # Compute value size and set default range end value_size = v.ufl_function_space().value_size