Skip to content

Commit 021589c

Browse files
authored
Merge pull request #310 from firedrakeproject/rckirby/feature/macro
Support FIAT macroelements
2 parents 90c20c5 + 916e773 commit 021589c

File tree

3 files changed

+36
-37
lines changed

3 files changed

+36
-37
lines changed

tests/test_create_fiat_element.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33
import FIAT
4-
from FIAT.discontinuous_lagrange import HigherOrderDiscontinuousLagrange as FIAT_DiscontinuousLagrange
4+
from FIAT.discontinuous_lagrange import DiscontinuousLagrange as FIAT_DiscontinuousLagrange
55

66
import ufl
77
import finat.ufl

tsfc/finatinterface.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919
# You should have received a copy of the GNU Lesser General Public License
2020
# along with FFC. If not, see <http://www.gnu.org/licenses/>.
2121

22-
from functools import singledispatch, partial
2322
import weakref
23+
from functools import partial, singledispatch
2424

2525
import FIAT
2626
import finat
27-
import ufl
2827
import finat.ufl
29-
28+
import ufl
3029

3130
__all__ = ("as_fiat_cell", "create_base_element",
3231
"create_element", "supported_elements")
@@ -52,6 +51,8 @@
5251
"Hermite": finat.Hermite,
5352
"Kong-Mulder-Veldhuizen": finat.KongMulderVeldhuizen,
5453
"Argyris": finat.Argyris,
54+
"Hsieh-Clough-Tocher": finat.HsiehCloughTocher,
55+
"Reduced-Hsieh-Clough-Tocher": finat.ReducedHsiehCloughTocher,
5556
"Mardal-Tai-Winther": finat.MardalTaiWinther,
5657
"Morley": finat.Morley,
5758
"Bell": finat.Bell,
@@ -144,12 +145,10 @@ def convert_finiteelement(element, **kwargs):
144145
kind = 'spectral' # default variant
145146

146147
if element.family() == "Lagrange":
147-
if kind == 'equispaced':
148-
lmbda = finat.Lagrange
149-
elif kind == 'spectral':
148+
if kind == 'spectral':
150149
lmbda = finat.GaussLobattoLegendre
151-
elif kind == 'integral':
152-
lmbda = finat.IntegratedLegendre
150+
elif kind.startswith('integral'):
151+
lmbda = partial(finat.IntegratedLegendre, variant=kind)
153152
elif kind in ['fdm', 'fdm_ipdg'] and is_interval:
154153
lmbda = finat.FDMLagrange
155154
elif kind == 'fdm_quadrature' and is_interval:
@@ -167,17 +166,16 @@ def convert_finiteelement(element, **kwargs):
167166
deps = {"shift_axes", "restriction"}
168167
return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction), deps
169168
else:
170-
raise ValueError("Variant %r not supported on %s" % (kind, element.cell))
169+
# Let FIAT handle the general case
170+
lmbda = partial(finat.Lagrange, variant=kind)
171171
elif element.family() in {"Raviart-Thomas", "Nedelec 1st kind H(curl)",
172172
"Brezzi-Douglas-Marini", "Nedelec 2nd kind H(curl)"}:
173173
lmbda = partial(lmbda, variant=element.variant())
174174
elif element.family() in ["Discontinuous Lagrange", "Discontinuous Lagrange L2"]:
175-
if kind == 'equispaced':
176-
lmbda = finat.DiscontinuousLagrange
177-
elif kind == 'spectral':
175+
if kind == 'spectral':
178176
lmbda = finat.GaussLegendre
179-
elif kind == 'integral':
180-
lmbda = finat.Legendre
177+
elif kind.startswith('integral'):
178+
lmbda = partial(finat.Legendre, variant=kind)
181179
elif kind in ['fdm', 'fdm_quadrature'] and is_interval:
182180
lmbda = finat.FDMDiscontinuousLagrange
183181
elif kind == 'fdm_ipdg' and is_interval:
@@ -191,7 +189,8 @@ def convert_finiteelement(element, **kwargs):
191189
deps = {"shift_axes", "restriction"}
192190
return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction, continuous=False), deps
193191
else:
194-
raise ValueError("Variant %r not supported on %s" % (kind, element.cell))
192+
# Let FIAT handle the general case
193+
lmbda = partial(finat.DiscontinuousLagrange, variant=kind)
195194
elif element.family() == ["DPC", "DPC L2"]:
196195
if element.cell.geometric_dimension() == 2:
197196
element = element.reconstruct(cell=ufl.cell.hypercube(2))

tsfc/kernel_interface/common.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,25 @@
11
import collections
2-
import string
32
import operator
3+
import string
44
from functools import reduce
55
from itertools import chain, product
66

7+
import gem
8+
import gem.impero_utils as impero_utils
79
import numpy
8-
from numpy import asarray
9-
10-
from ufl.utils.sequences import max_degree
11-
1210
from FIAT.reference_element import TensorProductCell
13-
11+
from finat.cell_tools import max_complex
1412
from finat.quadrature import AbstractQuadratureRule, make_quadrature
15-
16-
import gem
17-
1813
from gem.node import traversal
14+
from gem.optimise import constant_fold_zero
15+
from gem.optimise import remove_componenttensors as prune
1916
from gem.utils import cached_property
20-
import gem.impero_utils as impero_utils
21-
from gem.optimise import remove_componenttensors as prune, constant_fold_zero
22-
17+
from numpy import asarray
2318
from tsfc import fem, ufl_utils
24-
from tsfc.kernel_interface import KernelInterface
2519
from tsfc.finatinterface import as_fiat_cell, create_element
20+
from tsfc.kernel_interface import KernelInterface
2621
from tsfc.logging import logger
22+
from ufl.utils.sequences import max_degree
2723

2824

2925
class KernelBuilderBase(KernelInterface):
@@ -301,22 +297,26 @@ def set_quad_rule(params, cell, integral_type, functions):
301297
quadrature_degree = params["quadrature_degree"]
302298
except KeyError:
303299
quadrature_degree = params["estimated_polynomial_degree"]
304-
function_degrees = [f.ufl_function_space().ufl_element().degree() for f in functions]
300+
function_degrees = [f.ufl_function_space().ufl_element().degree()
301+
for f in functions]
305302
if all((asarray(quadrature_degree) > 10 * asarray(degree)).all()
306303
for degree in function_degrees):
307304
logger.warning("Estimated quadrature degree %s more "
308305
"than tenfold greater than any "
309306
"argument/coefficient degree (max %s)",
310307
quadrature_degree, max_degree(function_degrees))
311-
if params.get("quadrature_rule") == "default":
312-
del params["quadrature_rule"]
313-
try:
314-
quad_rule = params["quadrature_rule"]
315-
except KeyError:
308+
quad_rule = params.get("quadrature_rule", "default")
309+
if isinstance(quad_rule, str):
310+
scheme = quad_rule
316311
fiat_cell = as_fiat_cell(cell)
312+
finat_elements = set(create_element(f.ufl_element()) for f in functions
313+
if f.ufl_element().family() != "Real")
314+
fiat_cells = [fiat_cell] + [finat_el.complex for finat_el in finat_elements]
315+
fiat_cell = max_complex(fiat_cells)
316+
317317
integration_dim, _ = lower_integral_type(fiat_cell, integral_type)
318-
integration_cell = fiat_cell.construct_subelement(integration_dim)
319-
quad_rule = make_quadrature(integration_cell, quadrature_degree)
318+
integration_cell = fiat_cell.construct_subcomplex(integration_dim)
319+
quad_rule = make_quadrature(integration_cell, quadrature_degree, scheme=scheme)
320320
params["quadrature_rule"] = quad_rule
321321

322322
if not isinstance(quad_rule, AbstractQuadratureRule):

0 commit comments

Comments
 (0)