Skip to content

Commit e141383

Browse files
authored
Merge pull request #305 from DedalusProject/ufunc
Broader support of custom functions
2 parents 7dc7ae6 + ba4f464 commit e141383

File tree

7 files changed

+109
-81
lines changed

7 files changed

+109
-81
lines changed

dedalus/core/arithmetic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
44
"""
55

6+
import sys
67
from functools import reduce
78
import numpy as np
89
from scipy import sparse
@@ -981,7 +982,7 @@ def sym_diff(self, var):
981982

982983
# Define aliases
983984
for key, value in aliases.items():
984-
exec(f"{key} = {value.__name__}")
985+
setattr(sys.modules[__name__], key, value)
985986

986987
# Export aliases
987988
__all__.extend(aliases.keys())

dedalus/core/field.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,20 @@ def valid_modes(self):
326326
# Return copy to avoid mangling cached result from coeff_layout
327327
return valid_modes.copy()
328328

329+
@property
330+
def real(self):
331+
if self.is_real:
332+
return self
333+
else:
334+
return (self + np.conj(self)) / 2
335+
336+
@property
337+
def imag(self):
338+
if self.is_real:
339+
return 0
340+
else:
341+
return (self - np.conj(self)) / 2j
342+
329343

330344
class Current(Operand):
331345

dedalus/core/operators.py

Lines changed: 71 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
44
"""
55

6+
import sys
67
from collections import defaultdict
78
from functools import partial, reduce
89
import numpy as np
@@ -454,7 +455,7 @@ class GeneralFunction(NonlinearOperator, FutureField):
454455
455456
Notes
456457
-----
457-
On evaluation, this wrapper evaluates the provided funciton with the given
458+
On evaluation, this wrapper evaluates the provided function with the given
458459
arguments and keywords, and takes the output to be data in the specified
459460
layout, i.e.
460461
@@ -502,27 +503,69 @@ def operate(self, out):
502503

503504

504505
class UnaryGridFunction(NonlinearOperator, FutureField):
506+
"""
507+
Wrapper for applying unary functions to fields in grid space.
508+
This can be used with arbitrary user-defined functions, but
509+
symbolic differentiation is only implemented for some scipy/numpy
510+
universal functions.
505511
506-
supported = {ufunc.__name__: ufunc for ufunc in
507-
(np.absolute, np.conj, np.exp, np.exp2, np.expm1,
508-
np.log, np.log2, np.log10, np.log1p, np.sqrt, np.square,
509-
np.sin, np.cos, np.tan, np.arcsin, np.arccos, np.arctan,
510-
np.sinh, np.cosh, np.tanh, np.arcsinh, np.arccosh, np.arctanh,
511-
scp.erf
512-
)}
513-
aliased = {'abs':np.absolute, 'conj':np.conjugate}
514-
# Add ufuncs and shortcuts to parseables
515-
parseables.update(supported)
516-
parseables.update(aliased)
517-
518-
def __init__(self, func, arg, **kw):
519-
if func not in self.supported.values():
520-
raise ValueError("Unsupported ufunc: %s" %func)
521-
#arg = Operand.cast(arg)
522-
super().__init__(arg, **kw)
512+
Parameters
513+
----------
514+
func : function
515+
Unary function acting on grid data. Must be vectorized
516+
and include an output array argument, e.g. func(x, out).
517+
arg : dedalus operand
518+
Argument field or operator.
519+
deriv : function, optional
520+
Symbolic derivative of func. Defaults are provided
521+
for some common numpy/scipy ufuncs (default: None).
522+
out : field, optional
523+
Output field (default: new field).
524+
525+
Notes
526+
-----
527+
The supplied function must support an output argument called 'out'
528+
and act in a vectorized fashion. The action is essentially:
529+
530+
func(arg['g'], out=out['g'])
531+
532+
"""
533+
534+
ufunc_derivatives = {
535+
np.absolute: lambda x: np.sign(x),
536+
np.sign: lambda x: 0,
537+
np.exp: lambda x: np.exp(x),
538+
np.exp2: lambda x: np.exp2(x) * np.log(2),
539+
np.log: lambda x: x**(-1),
540+
np.log2: lambda x: (x * np.log(2))**(-1),
541+
np.log10: lambda x: (x * np.log(10))**(-1),
542+
np.sqrt: lambda x: (1/2) * x**(-1/2),
543+
np.square: lambda x: 2*x,
544+
np.sin: lambda x: np.cos(x),
545+
np.cos: lambda x: -np.sin(x),
546+
np.tan: lambda x: np.cos(x)**(-2),
547+
np.arcsin: lambda x: (1 - x**2)**(-1/2),
548+
np.arccos: lambda x: -(1 - x**2)**(-1/2),
549+
np.arctan: lambda x: (1 + x**2)**(-1),
550+
np.sinh: lambda x: np.cosh(x),
551+
np.cosh: lambda x: np.sinh(x),
552+
np.tanh: lambda x: 1-np.tanh(x)**2,
553+
np.arcsinh: lambda x: (x**2 + 1)**(-1/2),
554+
np.arccosh: lambda x: (x**2 - 1)**(-1/2),
555+
np.arctanh: lambda x: (1 - x**2)**(-1),
556+
scp.erf: lambda x: 2*(np.pi)**(-1/2)*np.exp(-x**2)}
557+
558+
# Add ufuncs and shortcuts to aliases
559+
aliases.update({ufunc.__name__: ufunc for ufunc in ufunc_derivatives})
560+
aliases.update({'abs': np.absolute, 'conj': np.conjugate})
561+
562+
def __init__(self, func, arg, deriv=None, out=None):
563+
super().__init__(arg, out=out)
523564
self.func = func
524-
if arg.tensorsig:
525-
raise ValueError("Ufuncs not defined for non-scalar fields.")
565+
if deriv is None and func in self.ufunc_derivatives:
566+
self.deriv = self.ufunc_derivatives[func]
567+
else:
568+
self.deriv = deriv
526569
# FutureField requirements
527570
self.domain = arg.domain
528571
self.tensorsig = arg.tensorsig
@@ -538,40 +581,19 @@ def _build_bases(self, arg0):
538581
bases = arg0.domain
539582
return bases
540583

541-
def new_operands(self, arg):
542-
return UnaryGridFunction(self.func, arg)
584+
def new_operand(self, arg):
585+
return UnaryGridFunction(self.func, arg, deriv=self.deriv)
543586

544587
def reinitialize(self, **kw):
545588
arg = self.args[0].reinitialize(**kw)
546-
return self.new_operands(arg)
589+
return self.new_operand(arg)
547590

548591
def sym_diff(self, var):
549592
"""Symbolically differentiate with respect to specified operand."""
550-
diff_map = {np.absolute: lambda x: np.sign(x),
551-
np.sign: lambda x: 0,
552-
np.exp: lambda x: np.exp(x),
553-
np.exp2: lambda x: np.exp2(x) * np.log(2),
554-
np.log: lambda x: x**(-1),
555-
np.log2: lambda x: (x * np.log(2))**(-1),
556-
np.log10: lambda x: (x * np.log(10))**(-1),
557-
np.sqrt: lambda x: (1/2) * x**(-1/2),
558-
np.square: lambda x: 2*x,
559-
np.sin: lambda x: np.cos(x),
560-
np.cos: lambda x: -np.sin(x),
561-
np.tan: lambda x: np.cos(x)**(-2),
562-
np.arcsin: lambda x: (1 - x**2)**(-1/2),
563-
np.arccos: lambda x: -(1 - x**2)**(-1/2),
564-
np.arctan: lambda x: (1 + x**2)**(-1),
565-
np.sinh: lambda x: np.cosh(x),
566-
np.cosh: lambda x: np.sinh(x),
567-
np.tanh: lambda x: 1-np.tanh(x)**2,
568-
np.arcsinh: lambda x: (x**2 + 1)**(-1/2),
569-
np.arccosh: lambda x: (x**2 - 1)**(-1/2),
570-
np.arctanh: lambda x: (1 - x**2)**(-1),
571-
scp.erf: lambda x: 2*(np.pi)**(-1/2)*np.exp(-x**2)}
593+
if self.deriv is None:
594+
raise ValueError(f"Symbolic derivative not implemented for {self.func.__name__}.")
572595
arg = self.args[0]
573-
arg_diff = arg.sym_diff(var)
574-
return diff_map[self.func](arg) * arg_diff
596+
return self.deriv(arg) * arg.sym_diff(var)
575597

576598
def check_conditions(self):
577599
# Field must be in grid layout
@@ -1024,7 +1046,7 @@ def matrix_coupling(self, *vars):
10241046
return self.operand.matrix_coupling(*vars)
10251047

10261048

1027-
@parseable('interpolate', 'interp')
1049+
#@parseable('interpolate', 'interp')
10281050
def interpolate(arg, **positions):
10291051
# Identify domain
10301052
domain = unify_attributes((arg,)+tuple(positions), 'domain', require=False)
@@ -4386,7 +4408,7 @@ def compute_cfl_frequency(self, velocity, out):
43864408

43874409
# Define aliases
43884410
for key, value in aliases.items():
4389-
exec(f"{key} = {value.__name__}")
4411+
setattr(sys.modules[__name__], key, value)
43904412

43914413
# Export aliases
43924414
__all__.extend(aliases.keys())

dedalus/core/problems.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@
2828
# Build basic parsing namespace
2929
parseables = {}
3030
parseables.update({name: getattr(operators, name) for name in operators.__all__})
31-
parseables.update(operators.aliases)
3231
parseables.update({name: getattr(arithmetic, name) for name in arithmetic.__all__})
33-
parseables.update(arithmetic.aliases)
3432

3533

3634
class ProblemBase:

dedalus/tests/test_grid_operators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
N_range = [16]
1010
dealias_range = [1]
1111
dtype_range = [np.float64, np.complex128]
12-
ufuncs = d3.UnaryGridFunction.supported.values()
12+
ufuncs = d3.UnaryGridFunction.ufunc_derivatives.keys()
1313

1414

1515
@CachedMethod

docs/pages/general_functions.rst

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,36 @@
11
General Functions
22
*****************
33

4-
**Note: this documentation has not yet been updated for v3 of Dedalus.**
5-
6-
The ``GeneralFunction`` class enables users to simply define new explicit operators for the right-hand side and analysis tasks of their simulations.
4+
The ``GeneralFunction`` and ``UnaryGridFunction`` classes enables users to simply define new explicit operators for the right-hand side and analysis tasks of their simulations.
75
Such operators can be used to apply arbitrary user-defined functions to the grid values or coefficients of some set of input fields, or even do things like introduce random data or read data from an external source.
86

9-
A ``GeneralFunction`` object is instantiated with a Dedalus domain, a layout object or descriptor (e.g. ``'g'`` or ``'c'`` for grid or coefficient space), a function, a list of arguments, and a dictionary of keywords.
7+
A ``GeneralFunction`` object is instantiated with a Dedalus distributor, domain, tensor signature, dtype, layout object or descriptor (e.g. ``'g'`` or ``'c'`` for grid or coefficient space), function, list of arguments, and dictionary of keywords.
108
The resulting object is a Dedalus operator that can be evaluated and composed like other Dedalus operators.
119
It operates by first ensuring that any arguments that are Dedalus field objects are in the specified layout, then calling the function with the specified arguments and keywords, and finally setting the result as the output data in the specified layout.
1210

13-
Here's an example how you can use this class to apply a nonlinear function to the grid data of a single Dedalus field.
14-
First, we define the underlying function we want to apply to the field data -- say the error function from scipy:
15-
16-
.. code-block:: python
17-
18-
from scipy import special
11+
A simpler option that should work for many use cases is the ``UnaryGridFunction`` class, which specifically applies a function to the grid data of a single field.
12+
The output field's distributor, domain/bases, tensor signature, and dtype are all taken to be idential to those of the input field.
13+
Only the function and input field need to be specified.
14+
The function must be vectorized, take a single Numpy array as input, and include an ``out`` argument that specifies the output array.
15+
Applying most Numpy or Scipy universal functions to a Dedalus field will automatically produce the corresponding ``UnaryGridFunction`` operator.
1916

20-
def erf_func(field):
21-
# Call scipy erf function on the field's data
22-
return special.erf(field.data)
23-
24-
Second, we make a wrapper that returns a ``GeneralFunction`` instance that applies ``erf_func`` to a provided field in grid space.
25-
This function produces a Dedalus operator, so it's what we want to use on the RHS or in analysis tasks:
17+
Here's an example of using the ``UnaryGridFunction`` class to apply a custom function to the grid data of a single Dedalus field.
18+
First, we define the underlying function we want to apply to the field data:
2619

2720
.. code-block:: python
2821
29-
import dedalus.public as de
30-
31-
def erf_operator(field):
32-
# Return GeneralFunction instance that applies erf_func in grid space
33-
return de.operators.GeneralFunction(
34-
field.domain,
35-
layout = 'g',
36-
func = erf_func,
37-
args = (field,)
38-
)
22+
# Custom function acting on grid data
23+
def custom_grid_function(x, out):
24+
out[:] = (x + np.abs(x)) / 2
25+
return out
3926
40-
Finally, we add this wrapper to the parsing namespace to make it available in string-specified equations and analysis tasks:
27+
Second, we make a wrapper that returns a ``UnaryGridFunction`` instance that applies ``custom_grid_function`` to a specified field.
28+
This wrapper produces a Dedalus operator, so it's what we want to use on the RHS or in analysis tasks:
4129

4230
.. code-block:: python
4331
44-
de.operators.parseables['erf'] = erf_operator
32+
# Operator wrapper for custom function
33+
custom_grid_operator = lambda field: d3.UnaryGridFunction(custom_grid_function, field)
34+
35+
# Analysis task applying custom operator to a field
36+
snapshots.add_task(custom_grid_operator(u), name="custom(u)")

docs/pages/user_guide.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@ Specific how-to's:
2020
gauge_conditions
2121
tau_method
2222
half_dimensions
23+
general_functions
2324

0 commit comments

Comments
 (0)