Skip to content
2 changes: 1 addition & 1 deletion docs/getting_started/explanation_concepts.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ dt = np.timedelta64(5, "m")
runtime = np.timedelta64(1, "D")

# Run the simulation
pset.execute(pyfunc=kernels, dt=dt, runtime=runtime)
pset.execute(kernels=kernels, dt=dt, runtime=runtime)
```

### Output
Expand Down
4 changes: 2 additions & 2 deletions docs/user_guide/examples/tutorial_interaction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@
"]\n",
"\n",
"pset.execute(\n",
" pyfunc=kernels,\n",
" kernels=kernels,\n",
" runtime=np.timedelta64(60, \"s\"),\n",
" dt=np.timedelta64(1, \"s\"),\n",
" output_file=output_file,\n",
Expand Down Expand Up @@ -331,7 +331,7 @@
"]\n",
"\n",
"pset.execute(\n",
" pyfunc=kernels,\n",
" kernels=kernels,\n",
" runtime=np.timedelta64(60, \"s\"),\n",
" dt=np.timedelta64(1, \"s\"),\n",
" output_file=output_file,\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guide/examples_v3/tutorial_stommel_uxarray.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@
" pset.execute(\n",
" endtime=endtime,\n",
" dt=timedelta(seconds=60),\n",
" pyfunc=AdvectionEE,\n",
" kernels=AdvectionEE,\n",
" verbose_progress=False,\n",
" )\n",
" except FieldOutOfBoundError:\n",
Expand Down
1 change: 1 addition & 0 deletions docs/user_guide/v4-migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Version 4 of Parcels is unreleased at the moment. The information in this migrat
- The `InteractionKernel` class has been removed. Since normal Kernels now have access to _all_ particles, particle-particle interaction can be performed within normal Kernels.
- Users need to explicitly use `convert_z_to_sigma_croco` in sampling kernels (such as the `AdvectionRK4_3D_CROCO` or `SampleOMegaCroco` kernels) when working with CROCO data, as the automatic conversion from depth to sigma grids under the hood has been removed.
- We added a new AdvectionRK2 Kernel. The AdvectionRK4 kernel is still available, but RK2 is now the recommended default advection scheme as it is faster while the accuracy is comparable for most applications. See also the Choosing an integration method tutorial.
- Functions shouldn't be converted to Kernels before adding to a pset.execute() call. Instead, simply pass the function(s) as a list to pset.execute().

## FieldSet

Expand Down
2 changes: 0 additions & 2 deletions src/parcels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from parcels._core.fieldset import FieldSet
from parcels._core.particleset import ParticleSet
from parcels._core.kernel import Kernel
from parcels._core.particlefile import ParticleFile
from parcels._core.particle import (
Variable,
Expand Down Expand Up @@ -45,7 +44,6 @@
# Core classes
"FieldSet",
"ParticleSet",
"Kernel",
"ParticleFile",
"Variable",
"Particle",
Expand Down
92 changes: 28 additions & 64 deletions src/parcels/_core/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
if TYPE_CHECKING:
from collections.abc import Callable

__all__ = ["Kernel"]


ErrorsToThrow = {
StatusCode.ErrorOutsideTimeInterval: _raise_outside_time_interval_error,
Expand All @@ -45,12 +43,12 @@ class Kernel:

Parameters
----------
kernels :
list of Kernel functions
fieldset : parcels.Fieldset
FieldSet object providing the field information (possibly None)
ptype :
PType object for the kernel particle
pyfunc :
(aggregated) Kernel function

Notes
-----
Expand All @@ -60,32 +58,35 @@ class Kernel:

def __init__(
self,
fieldset,
ptype,
pyfuncs: list[types.FunctionType],
kernels: list[types.FunctionType],
pset,
):
for f in pyfuncs:
if not isinstance(kernels, list):
raise ValueError(f"kernels must be a list. Got {kernels=!r}")

for f in kernels:
if not isinstance(f, types.FunctionType):
raise TypeError(f"Argument pyfunc should be a function or list of functions. Got {type(f)}")
raise TypeError(f"Argument `kernels` should be a function or list of functions. Got {type(f)}")
assert_same_function_signature(f, ref=AdvectionRK4, context="Kernel")

if len(pyfuncs) == 0:
raise ValueError("List of `pyfuncs` should have at least one function.")
if len(kernels) == 0:
raise ValueError("List of `kernels` should have at least one function.")

self._fieldset = fieldset
self._ptype = ptype
self._fieldset = pset.fieldset
self._ptype = pset._ptype

self._positionupdate_kernel_added = False

for f in pyfuncs:
for f in kernels:
self.check_fieldsets_in_kernels(f)

self._pyfuncs: list[Callable] = pyfuncs
self._kernels: list[Callable] = kernels

if pset._positionupdate_kernel_added:
self.add_positionupdate_kernel()

@property #! Ported from v3. To be removed in v4? (/find another way to name kernels in output file)
def funcname(self):
ret = ""
for f in self._pyfuncs:
for f in self._kernels:
ret += f.__name__
return ret

Expand Down Expand Up @@ -123,21 +124,21 @@ def PositionUpdate(particles, fieldset): # pragma: no cover
# Update dt in case it's increased in RK45 kernel
particles.dt = particles.next_dt

self._pyfuncs = (PositionUpdate + self)._pyfuncs
self._kernels = [PositionUpdate] + self._kernels

def check_fieldsets_in_kernels(self, pyfunc): # TODO v4: this can go into another method? assert_is_compatible()?
def check_fieldsets_in_kernels(self, kernel): # TODO v4: this can go into another method? assert_is_compatible()?
"""
Checks the integrity of the fieldset with the kernels.

This function is to be called from the derived class when setting up the 'pyfunc'.
This function is to be called from the derived class when setting up the 'kernel'.
"""
if self.fieldset is not None:
if pyfunc is AdvectionAnalytical:
if kernel is AdvectionAnalytical:
if self._fieldset.U.interp_method != "cgrid_velocity":
raise NotImplementedError("Analytical Advection only works with C-grids")
if self._fieldset.U.grid._gtype not in [GridType.CurvilinearZGrid, GridType.RectilinearZGrid]:
raise NotImplementedError("Analytical Advection only works with Z-grids in the vertical")
elif pyfunc is AdvectionRK45:
elif kernel is AdvectionRK45:
if "next_dt" not in [v.name for v in self.ptype.variables]:
raise ValueError('ParticleClass requires a "next_dt" for AdvectionRK45 Kernel.')
if not hasattr(self.fieldset, "RK45_tol"):
Expand Down Expand Up @@ -174,48 +175,11 @@ def merge(self, kernel):
assert self.ptype == kernel.ptype, "Cannot merge kernels with different particle types"

return type(self)(
self._kernels + kernel._kernels,
self.fieldset,
self.ptype,
pyfuncs=self._pyfuncs + kernel._pyfuncs,
)

def __add__(self, kernel):
if isinstance(kernel, types.FunctionType):
kernel = type(self)(self.fieldset, self.ptype, pyfuncs=[kernel])
return self.merge(kernel)

def __radd__(self, kernel):
if isinstance(kernel, types.FunctionType):
kernel = type(self)(self.fieldset, self.ptype, pyfuncs=[kernel])
return kernel.merge(self)

@classmethod
def from_list(cls, fieldset, ptype, pyfunc_list):
"""Create a combined kernel from a list of functions.

Takes a list of functions, converts them to kernels, and joins them
together.

Parameters
----------
fieldset : parcels.Fieldset
FieldSet object providing the field information (possibly None)
ptype :
PType object for the kernel particle
pyfunc_list : list of functions
List of functions to be combined into a single kernel.
*args :
Additional arguments passed to first kernel during construction.
**kwargs :
Additional keyword arguments passed to first kernel during construction.
"""
if not isinstance(pyfunc_list, list):
raise TypeError(f"Argument `pyfunc_list` should be a list of functions. Got {type(pyfunc_list)}")
if not all([isinstance(f, types.FunctionType) for f in pyfunc_list]):
raise ValueError("Argument `pyfunc_list` should be a list of functions.")

return cls(fieldset, ptype, pyfunc_list)

def execute(self, pset, endtime, dt):
"""Execute this Kernel over a ParticleSet for several timesteps.

Expand Down Expand Up @@ -248,7 +212,7 @@ def execute(self, pset, endtime, dt):
pset.dt = np.minimum(np.maximum(pset.dt, -time_to_endtime), 0)

# run kernels for all particles that need to be evaluated
for f in self._pyfuncs:
for f in self._kernels:
f(pset[evaluate_particles], self._fieldset)

# check for particles that have to be repeated
Expand Down Expand Up @@ -281,8 +245,8 @@ def execute(self, pset, endtime, dt):
error_func(pset[inds].z, pset[inds].lat, pset[inds].lon)

# Only add PositionUpdate kernel at the end of the first execute call to avoid adding dt to time too early
if not self._positionupdate_kernel_added:
if not pset._positionupdate_kernel_added:
self.add_positionupdate_kernel()
self._positionupdate_kernel_added = True
pset._positionupdate_kernel_added = True

return pset
39 changes: 8 additions & 31 deletions src/parcels/_core/particleset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import sys
import types
import warnings
from collections.abc import Iterable
from typing import Literal
Expand Down Expand Up @@ -134,6 +135,7 @@ def __init__(
self._data[kwvar][:] = kwval

self._kernel = None
self._positionupdate_kernel_added = False

def __del__(self):
if self._data is not None and isinstance(self._data, xr.Dataset):
Expand Down Expand Up @@ -290,29 +292,6 @@ def from_particlefile(cls, fieldset, pclass, filename, restart=True, restarttime
"ParticleSet.from_particlefile is not yet implemented in v4."
) # TODO implement this when ParticleFile is implemented in v4

def Kernel(self, pyfunc):
"""Wrapper method to convert a `pyfunc` into a :class:`parcels.kernel.Kernel` object.

Conversion is based on `fieldset` and `ptype` of the ParticleSet.

Parameters
----------
pyfunc : function or list of functions
Python function to convert into kernel. If a list of functions is provided,
the functions will be converted to kernels and combined into a single kernel.
"""
if isinstance(pyfunc, list):
return Kernel.from_list(
self.fieldset,
self._ptype,
pyfunc,
)
return Kernel(
self.fieldset,
self._ptype,
pyfuncs=[pyfunc],
)

def data_indices(self, variable_name, compare_values, invert=False):
"""Get the indices of all particles where the value of `variable_name` equals (one of) `compare_values`.

Expand Down Expand Up @@ -376,7 +355,7 @@ def set_variable_write_status(self, var, write_status):

def execute(
self,
pyfunc,
kernels,
dt: datetime.timedelta | np.timedelta64 | float,
endtime: np.timedelta64 | np.datetime64 | None = None,
runtime: datetime.timedelta | np.timedelta64 | float | None = None,
Expand All @@ -390,10 +369,9 @@ def execute(

Parameters
----------
pyfunc :
Kernel function to execute. This can be the name of a
kernels :
List of Kernel functions to execute. This can be the name of a
defined Python function or a :class:`parcels.kernel.Kernel` object.
Kernels can be concatenated using the + operator.
dt (np.timedelta64 or float):
Timestep interval (as a np.timedelta64 object of float in seconds) to be passed to the kernel.
Use a negative value for a backward-in-time simulation.
Expand All @@ -417,10 +395,9 @@ def execute(
if len(self) == 0:
return

if not isinstance(pyfunc, Kernel):
pyfunc = self.Kernel(pyfunc)

self._kernel = pyfunc
if isinstance(kernels, types.FunctionType):
kernels = [kernels]
self._kernel = Kernel(kernels, self)

if output_file is not None:
output_file.set_metadata(self.fieldset.gridset[0]._mesh)
Expand Down
2 changes: 1 addition & 1 deletion tests-v3/test_kernel_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def expr_kernel(name, pset, expr):
pycode = (f"def {name}(particle, fieldset, time):\n"
f" particle.p = {expr}") # fmt: skip
return Kernel(pset.fieldset, pset.particledata.ptype, pyfunc=None, funccode=pycode, funcname=name)
return Kernel(kernels=None, fieldset=pset.fieldset, ptype=pset._ptype, funccode=pycode, funcname=name)


@pytest.fixture
Expand Down
4 changes: 2 additions & 2 deletions tests/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_fieldKh_Brownian(mesh):

np.random.seed(1234)
pset = ParticleSet(fieldset=fieldset, lon=np.zeros(npart), lat=np.zeros(npart))
pset.execute(pset.Kernel(DiffusionUniformKh), runtime=runtime, dt=np.timedelta64(1, "h"))
pset.execute(DiffusionUniformKh, runtime=runtime, dt=np.timedelta64(1, "h"))

expected_std_lon = np.sqrt(2 * kh_zonal * mesh_conversion**2 * timedelta_to_float(runtime))
expected_std_lat = np.sqrt(2 * kh_meridional * mesh_conversion**2 * timedelta_to_float(runtime))
Expand Down Expand Up @@ -70,7 +70,7 @@ def test_fieldKh_SpatiallyVaryingDiffusion(mesh, kernel):

np.random.seed(1636)
pset = ParticleSet(fieldset=fieldset, lon=np.zeros(npart), lat=np.zeros(npart))
pset.execute(pset.Kernel(kernel), runtime=np.timedelta64(3, "h"), dt=np.timedelta64(1, "h"))
pset.execute(kernel, runtime=np.timedelta64(3, "h"), dt=np.timedelta64(1, "h"))

tol = 2000 * mesh_conversion # effectively 2000 m errors (because of low numbers of particles)
assert np.allclose(np.mean(pset.lon), 0, atol=tol)
Expand Down
Loading
Loading