diff --git a/examples/PyMPDATA_examples/burgers_equation/__init__.py b/examples/PyMPDATA_examples/burgers_equation/__init__.py
index 7908197af..8a8876563 100644
--- a/examples/PyMPDATA_examples/burgers_equation/__init__.py
+++ b/examples/PyMPDATA_examples/burgers_equation/__init__.py
@@ -4,5 +4,3 @@
burgers-equation.ipynb:
.. include:: ./burgers_equation.ipynb.badges.md
"""
-
-from .burgers_equation import run_numerical_simulation
diff --git a/examples/PyMPDATA_examples/burgers_equation/burgers_equation.ipynb b/examples/PyMPDATA_examples/burgers_equation/burgers_equation.ipynb
index 135767d5c..e54095e5c 100644
--- a/examples/PyMPDATA_examples/burgers_equation/burgers_equation.ipynb
+++ b/examples/PyMPDATA_examples/burgers_equation/burgers_equation.ipynb
@@ -16,14 +16,19 @@
"metadata": {},
"source": [
"Solution to the [Burgers' equation](https://en.wikipedia.org/wiki/Burgers%27_equation) using MPDATA compared against analytic results \n",
- "(students' project by: Wojciech Neuman, Paulina Pojda, Michał Szczygieł, Joanna Wójcicka & Antoni Zięciak)\n",
+ "(based on students' project by: Wojciech Neuman, Paulina Pojda, Michał Szczygieł, Joanna Wójcicka & Antoni Zięciak)\n",
+ "$$ \\partial_t u + u\\partial_x u = \\partial_t u + \\frac{1}{2} \\partial_x u^2 = \\frac{1}{\\text{Re}} \\partial^2_x u $$\n",
+ "where Re is the Reynolds number, $u$ is the velocity, $x$ is the spatial coordinate and $t$ is time.\n",
"\n",
- "$$ \\frac{\\partial u}{\\partial t} = -\\frac{1}{2} \\frac{\\partial u^2}{\\partial x} $$"
+ "Initial and boundary conditions:\n",
+ "- $-1 \\le x \\le 1$\n",
+ "- $u(x, 0) = -\\sin(\\pi * x)$\n",
+ "- $u(-1, t) = u(1, t) = 0$"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 1,
"id": "745955e4-6f48-44ba-a9f4-bab1493a996e",
"metadata": {},
"outputs": [],
@@ -37,7 +42,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 2,
"id": "b6651fc5",
"metadata": {
"ExecuteTime": {
@@ -47,2053 +52,168 @@
},
"outputs": [],
"source": [
+ "import os\n",
"import numpy as np\n",
"from functools import partial\n",
"from matplotlib import pyplot\n",
- "from open_atmos_jupyter_utils import show_anim, show_plot\n",
+ "from open_atmos_jupyter_utils import show_anim\n",
"from scipy.optimize import root_scalar\n",
- "from PyMPDATA_examples.burgers_equation import run_numerical_simulation"
+ "from PyMPDATA import Options, ScalarField, Solver, Stepper, VectorField\n",
+ "from PyMPDATA.boundary_conditions import Constant"
]
},
{
- "cell_type": "code",
- "execution_count": 11,
- "id": "a786e106-86c1-4b06-b474-198f709520fc",
+ "cell_type": "markdown",
+ "id": "b0d6cec1-e9ee-4130-9303-30734ea602c2",
"metadata": {},
- "outputs": [],
"source": [
- "T_MAX = 1\n",
- "T_SHOCK = 1 / np.pi\n",
- "T_RANGE = [0, 0.1, 0.3, 0.5, 0.7, 1]\n",
- "\n",
- "NT = 400\n",
- "NX = 100\n",
- "\n",
- "X_ANALYTIC = np.linspace(-1, 1, NX)"
+ "## Numerical solver"
]
},
{
"cell_type": "code",
- "execution_count": 12,
- "id": "227d698d-0533-472c-9be1-bec14e6dabff",
+ "execution_count": 3,
+ "id": "da9db73a-5c0d-460a-b283-9c2649ea64d0",
"metadata": {},
"outputs": [],
"source": [
- "def f(x0, t, xi):\n",
- " \"\"\"\n",
- " The function to solve: x0 - sin(pi*x0)*t - xi = 0\n",
- " where xi is the initial condition at x0.\n",
- " \"\"\"\n",
- " return x0 - np.sin(np.pi * x0) * t - xi\n",
- "\n",
- "\n",
- "def df(x0, t, _):\n",
- " \"\"\"\n",
- " The derivative of the function f with respect to x0.\n",
- " \"\"\"\n",
- " return 1 - np.cos(np.pi * x0) * np.pi * t\n",
+ "def interpolate_in_space_and_multiply(*, vector_out, scalar_in, multiplier):\n",
+ " vector_out[:] = multiplier * (scalar_in[1:] + scalar_in[:-1]) / 2\n",
"\n",
+ "def extrapolate_in_time(*, vectors_in, vector_out):\n",
+ " vector_out[:] = 0.5 * (3 * vectors_in[0] - vectors_in[1])\n",
"\n",
- "def df2(x0, t, _):\n",
- " \"\"\"\n",
- " The 2nd derivative of the function f with respect to x0.\n",
- " \"\"\"\n",
- " return np.sin(np.pi * x0) * np.pi**2 * t\n",
+ "def advector_view_no_edges(solver):\n",
+ " return solver.advector.get_component(0)[1:-1]\n",
"\n",
+ "def check_cfl_condition(solver):\n",
+ " assert np.all(abs(solver.advector.get_component(0)) <= 1), np.amax(abs(solver.advector.get_component(0)))\n",
"\n",
- "def find_root(x0, t, xi):\n",
- " \"\"\"Find the root of the equation f(x0, t, xi) = 0 \"\"\"\n",
- " return root_scalar(f, args=(t, xi), x0=x0, fprime=df, fprime2=df2, method='halley', maxiter=100).root\n",
+ "def run_numerical_simulation(*, nt, t_max, dx, psi0, reynolds_number, output_interval):\n",
+ " dt = t_max / nt\n",
+ " boundary_conditions = Constant(0),\n",
+ " options = Options(nonoscillatory=True, infinite_gauge=True, n_iters=3, non_zero_mu_coeff=reynolds_number != np.inf,)\n",
+ " solver = Solver(\n",
+ " stepper=Stepper(\n",
+ " options=options,\n",
+ " grid=psi0.shape\n",
+ " ),\n",
+ " advectee=ScalarField(\n",
+ " data=psi0,\n",
+ " halo=options.n_halo,\n",
+ " boundary_conditions=boundary_conditions,\n",
+ " ),\n",
+ " advector=VectorField(\n",
+ " data=(np.zeros(len(psi0)+1),),\n",
+ " halo=options.n_halo,\n",
+ " boundary_conditions=boundary_conditions,\n",
+ " ), \n",
+ " )\n",
"\n",
+ " compute_advector = partial(\n",
+ " interpolate_in_space_and_multiply,\n",
+ " scalar_in=solver.advectee.get(),\n",
+ " multiplier=.5 * dt / dx\n",
+ " )\n",
+ " \n",
+ " interpolated_advectors = tuple(\n",
+ " np.empty_like(advector_view_no_edges(solver))\n",
+ " for _ in (0,1)\n",
+ " )\n",
+ " \n",
+ " states = []\n",
+ " compute_advector(vector_out=interpolated_advectors[1])\n",
+ " for step in range(nt + 1):\n",
+ " if step != 0:\n",
+ " compute_advector(vector_out=interpolated_advectors[0])\n",
+ " extrapolate_in_time(vector_out=advector_view_no_edges(solver), vectors_in=interpolated_advectors)\n",
+ " check_cfl_condition(solver)\n",
+ " solver.advance(n_steps=1, mu_coeff=(1 / reynolds_number,) if reynolds_number != np.inf else None)\n",
+ " interpolated_advectors = interpolated_advectors[::-1]\n",
+ " if step % output_interval == 0:\n",
+ " states.append(solver.advectee.get().copy())\n",
"\n",
- "def analytical_solution(x, t):\n",
- " \"\"\"\n",
- " Analytical solution for the wave equation\n",
- " \"\"\"\n",
- " u = np.zeros(len(x))\n",
- " for i, xi in enumerate(x):\n",
- " if t < T_SHOCK:\n",
- " x0 = find_root(x0=0, t=t, xi=xi)\n",
- " u[i] = -np.sin(np.pi * x0)\n",
- " else:\n",
- " if xi == 0:\n",
- " u[i] = 0\n",
- " else:\n",
- " # After the schock occurs, we have discontinuity at the x=0\n",
- " # so we have to start finding roots from some other arbitraty point\n",
- " # from which we have continuous function, we are starting from the -1\n",
- " # for the negative x values and from the 1 for the positive x values\n",
- " x0 = find_root(x0=xi / abs(xi), t=t, xi=xi)\n",
- " u[i] = -np.sin(np.pi * x0)\n",
- " return u\n",
- "\n",
- "def calculate_analytical_solutions():\n",
- " \"\"\" \n",
- " Calculate the analytical solutions for the given time range.\n",
- " Initial and boundary conditions:\n",
- " - -1 <= x <= 1\n",
- " - u(x, 0) = -sin(pi * x)\n",
- " - u(-1, t) = u(1, t) = 0\n",
- " \"\"\"\n",
- " solutions = np.zeros((len(X_ANALYTIC), len(T_RANGE)))\n",
- "\n",
- " for j, t in enumerate(T_RANGE):\n",
- " solutions[:, j] = analytical_solution(X_ANALYTIC, t)\n",
- "\n",
- " return solutions"
+ " return states, x, dt"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6f39dffb-6735-46f2-b257-5824ab0bb239",
+ "metadata": {},
+ "source": [
+ "## Numerical solution"
]
},
{
"cell_type": "code",
- "execution_count": 16,
- "id": "ce086840-57f2-41f8-8654-8cca1a29067d",
+ "execution_count": 4,
+ "id": "132a1186-f51b-4fb1-923d-4bce4528650c",
"metadata": {},
"outputs": [],
"source": [
- "def plot_commons():\n",
- " pyplot.xlabel(\"x\")\n",
- " pyplot.ylabel(\"u\")\n",
- " pyplot.ylim([-1.05, 1.05])\n",
- " pyplot.grid()\n",
+ "def initial_condition(x):\n",
+ " return -np.sin(np.pi * x)\n",
"\n",
- "def plot_numerical_vs_analytical(states, x, t, t_max, nt):\n",
- " analytical = analytical_solution(x, t)\n",
- " time_index = int((t / t_max) * nt) \n",
- " time_index = min(time_index, len(states) - 1)\n",
- " numerical = states[time_index, :]\n",
- " pyplot.step(x, numerical, label=\"Numerical\", where='mid')\n",
- " pyplot.plot(x, analytical, label=\"Analytical\")\n",
- " pyplot.title(f\"t={t:.3f}\")\n",
- " plot_commons()\n",
- " pyplot.legend()\n",
- " show_plot(filename=\"numeric\")\n",
+ "reynolds_numbers = np.inf, 1e16\n",
"\n",
- "def plot_gif(step, states, x, dt):\n",
- " fig = pyplot.figure()\n",
- " pyplot.plot(x, analytical_solution(x, 0), label=\"Initial condition\")\n",
- " pyplot.step(x, states[step], label=\"Numerical\", where='mid')\n",
- " pyplot.plot(x, analytical_solution(x, step * dt), label=\"Analytical\")\n",
- " pyplot.title(f\"t={step * dt:.3f}\")\n",
- " plot_commons()\n",
- " pyplot.legend()\n",
- " return fig \n",
+ "T_MAX = 1\n",
+ "NT = 4000 if 'CI' not in os.environ else 800\n",
+ "NX = 100 if 'CI' not in os.environ else 20\n",
+ "OUTPUT_INTERVAL = 10\n",
"\n",
- "def plot_analytical_solutions(solutions, t_range):\n",
- " pyplot.plot(X_ANALYTIC, solutions)\n",
- " pyplot.xlim([-1, 1]) \n",
- " pyplot.title(\"Analytical Solution to the Burgers Equation\")\n",
- " pyplot.legend([f\"t={t}\" for t in t_range])\n",
- " plot_commons()\n",
- " show_plot(filename=\"analytical\")"
+ "x, dx = np.linspace(-1 + 1/NX, 1 - 1/NX, NX, retstep=True)\n",
+ "\n",
+ "states_num = {}\n",
+ "for reynolds_number in reynolds_numbers:\n",
+ " states_num[reynolds_number], x_num, dt_num = run_numerical_simulation(\n",
+ " nt=NT,\n",
+ " t_max=T_MAX,\n",
+ " dx=dx,\n",
+ " psi0=initial_condition(x),\n",
+ " reynolds_number=reynolds_number,\n",
+ " output_interval=OUTPUT_INTERVAL,\n",
+ " )"
]
},
{
- "cell_type": "code",
- "execution_count": 17,
- "id": "2ec8b760606c6a18",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2025-06-08T10:14:02.211986Z",
- "start_time": "2025-06-08T10:14:02.013055Z"
- },
- "collapsed": false,
- "jupyter": {
- "outputs_hidden": false
- }
- },
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "0e184eabb630493cadba6c0dd22f1ea0",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(HTML(value=\"./analytical.pdf
\"), HTML(value=\"…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "cell_type": "markdown",
+ "id": "43f93c5e-3bd2-4225-8c65-d400050aa663",
+ "metadata": {},
"source": [
- "analytical_solutions = calculate_analytical_solutions()\n",
- "plot_analytical_solutions(analytical_solutions, T_RANGE)"
+ "## Semi-analytic solver"
]
},
{
"cell_type": "code",
- "execution_count": 19,
- "id": "e1bc2d61d23d9b3d",
+ "execution_count": 5,
+ "id": "227d698d-0533-472c-9be1-bec14e6dabff",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "function = lambda x, t, initial: x - np.sin(np.pi * x) * t - initial\n",
+ "fprime = lambda x, t, _: 1 - np.cos(np.pi * x) * np.pi * t\n",
+ "fprime2 = lambda x, t, _: np.sin(np.pi * x) * np.pi**2 * t\n",
+ "\n",
+ "def analytical_solution(x, t): \n",
+ " u = np.zeros_like(x)\n",
+ " for i, xi in enumerate(x):\n",
+ " res = root_scalar(\n",
+ " f=function,\n",
+ " args=(t, xi),\n",
+ " x0=xi / abs(xi),\n",
+ " fprime=fprime,\n",
+ " fprime2=fprime2,\n",
+ " method='halley',\n",
+ " maxiter=10000\n",
+ " ) \n",
+ " assert res.converged\n",
+ " u[i] = initial_condition(res.root)\n",
+ " return u"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e6d18b6c-32d0-4f57-9001-69b29eb259de",
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-08T10:14:03.595403Z",
@@ -2104,1372 +224,13 @@
"outputs_hidden": false
}
},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "05b6f66bff5c4169a6649a44c21c2ca9",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(HTML(value=\"./numeric.pdf
\"), HTML(value=\""
+ "
"
],
"text/plain": [
""
@@ -3493,7 +254,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "3713f3d551fb4e29af8f87661cdaa174",
+ "model_id": "b27ae42804a24f9fbd629ee6220ef18f",
"version_major": 2,
"version_minor": 0
},
@@ -3506,13 +267,36 @@
}
],
"source": [
- "show_anim(plot_partial, range(0, int(T_SHOCK/(T_MAX/NT)) - 10), gif_file=\"burgers.gif\")"
+ "sci_label = lambda v: \"\\infty\" if v == np.inf else rf\"{float(f'{v:.2e}'.split('e')[0])}\\times10^{{{int(f'{v:.2e}'.split('e')[1])}}}\"\n",
+ "\n",
+ "def plot_gif(frame, states, x, dt):\n",
+ " time = frame * OUTPUT_INTERVAL * dt\n",
+ " fig, ax = pyplot.subplots(figsize=(12,6))\n",
+ " for t, y, label, style in [\n",
+ " (0, analytical_solution(x, 0), \"Initial condition\", '-'),\n",
+ " *[\n",
+ " (time, states[reynolds_number][frame], f\"Numerical (Re=${sci_label(reynolds_number)}$)\", 'step') \n",
+ " for reynolds_number in reynolds_numbers\n",
+ " ],\n",
+ " (time, analytical_solution(x, time), r\"Analytical inviscid (Re$\\rightarrow\\infty$)\", '-'),\n",
+ " ]:\n",
+ " if style == 'step':\n",
+ " ax.step(x, y, where='mid', label=label, linewidth=2.5)\n",
+ " else:\n",
+ " ax.plot(x, y, label=label)\n",
+ " ax.set(xlabel=\"x\", ylabel=\"u\", ylim=(-1.35, 1.35), title=f\"t={time:.3f}\")\n",
+ " ax.grid();\n",
+ " ax.legend()\n",
+ " pyplot.savefig(f\"{frame:03d}.pdf\")\n",
+ " return fig\n",
+ "\n",
+ "show_anim(partial(plot_gif, states=states_num, x=x_num, dt=dt_num), range(NT // OUTPUT_INTERVAL + 1), gif_file=\"burgers.gif\")"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "8ffa5c6d-ef5f-4015-9c7c-034aa0e99e70",
+ "id": "0a7ab539-b6fe-43dc-abf2-82b409aeb219",
"metadata": {},
"outputs": [],
"source": []
diff --git a/examples/PyMPDATA_examples/burgers_equation/burgers_equation.py b/examples/PyMPDATA_examples/burgers_equation/burgers_equation.py
deleted file mode 100644
index 312d17b2f..000000000
--- a/examples/PyMPDATA_examples/burgers_equation/burgers_equation.py
+++ /dev/null
@@ -1,64 +0,0 @@
-"""
-Solution for the Burgers equation solution with MPDATA
-"""
-
-import numpy as np
-
-from PyMPDATA import Options, ScalarField, Solver, Stepper, VectorField
-from PyMPDATA.boundary_conditions import Constant
-
-OPTIONS = Options(nonoscillatory=False, infinite_gauge=True)
-
-
-def initialize_simulation(nt, nx, t_max):
- """
- Initializes simulation variables and returns them.
- """
- dt = t_max / nt
- courants_x, dx = np.linspace(-1, 1, nx + 1, endpoint=True, retstep=True)
- x = courants_x[:-1] + dx / 2
- u0 = -np.sin(np.pi * x)
-
- stepper = Stepper(options=OPTIONS, n_dims=1)
- advectee = ScalarField(
- data=u0, halo=OPTIONS.n_halo, boundary_conditions=(Constant(0), Constant(0))
- )
- advector = VectorField(
- data=(np.full(courants_x.shape, 0.0),),
- halo=OPTIONS.n_halo,
- boundary_conditions=(Constant(0), Constant(0)),
- )
- solver = Solver(stepper=stepper, advectee=advectee, advector=advector)
- return dt, dx, x, advectee, advector, solver
-
-
-def update_advector_n(vel, dt, dx, slice_idx):
- """
- Computes and returns the updated advector_n.
- """
- indices = np.arange(slice_idx.start, slice_idx.stop)
- return 0.5 * ((vel[indices] - vel[indices - 1]) / 2 + vel[:-1]) * dt / dx
-
-
-def run_numerical_simulation(*, nt, nx, t_max):
- """
- Runs the numerical simulation and returns (states, x, dt, dx).
- """
- dt, dx, x, advectee, advector, solver = initialize_simulation(nt, nx, t_max)
- states = []
- vel = advectee.get()
- advector_n_1 = 0.5 * (vel[:-1] + np.diff(vel) / 2) * dt / dx
- assert np.all(advector_n_1 <= 1)
- i = slice(1, len(vel))
-
- for _ in range(nt):
- vel = advectee.get()
- advector_n = update_advector_n(vel, dt, dx, i)
- advector.get_component(0)[1:-1] = 0.5 * (3 * advector_n - advector_n_1)
- assert np.all(advector.get_component(0) <= 1)
-
- solver.advance(n_steps=1)
- advector_n_1 = advector_n.copy()
- states.append(solver.advectee.get().copy())
-
- return np.array(states), x, dt, dx
diff --git a/tests/smoke_tests/burgers_equation/test_burgers_equation.py b/tests/smoke_tests/burgers_equation/test_burgers_equation.py
index 3b4a8df6b..c8f9ac9a6 100644
--- a/tests/smoke_tests/burgers_equation/test_burgers_equation.py
+++ b/tests/smoke_tests/burgers_equation/test_burgers_equation.py
@@ -1,58 +1,26 @@
-"""Unit tests for the Burgers' equation numerical simulation."""
+"""smoke tests for the Burgers' equation numerical simulation."""
+
+from pathlib import Path
import numpy as np
import pytest
-from PyMPDATA_examples.burgers_equation import run_numerical_simulation
-
+from open_atmos_jupyter_utils import notebook_vars
+from PyMPDATA_examples import burgers_equation
-@pytest.fixture(name="states")
-def states_fixture():
- """Run the simulation once for all tests."""
- return run_numerical_simulation(nt=400, nx=100, t_max=1 / np.pi)[0]
+PLOT = False
-class TestBurgersEquation:
- """Test suite for general numerical verification of Burgers' equation simulation."""
+@pytest.fixture(scope="session", name="variables")
+def _variables_fixture():
+ return notebook_vars(
+ file=Path(burgers_equation.__file__).parent / "burgers_equation.ipynb",
+ plot=PLOT,
+ )
- @staticmethod
- def test_total_momentum_conservation(states):
- """Verify total momentum remains approximately constant over time."""
- sum_initial_state = np.sum(states[0])
- eps = 1e-5
-
- for state in states:
- sum_state = np.sum(state)
- np.testing.assert_allclose(
- desired=sum_initial_state,
- actual=sum_state,
- atol=eps,
- )
- @staticmethod
- def test_solution_within_bounds(states):
- """Ensure numerical solution u(i, j) stays within expected bounds."""
- eps = 1e-2
- min_val = np.min(states) + eps
- max_val = np.max(states) - eps
- assert min_val >= -1.0
- assert max_val <= 1.0
+class TestBurgersEquation:
+ """assertions on the final notebook state"""
@staticmethod
- def test_zero_constant_boundary_conditions(states):
- """Verify zero-constant boundary conditions are satisfied at all time steps."""
- eps = 5e-2
-
- for state in states:
- left_boundary = state[0]
- right_boundary = state[-1]
-
- np.testing.assert_allclose(
- desired=0,
- actual=left_boundary,
- atol=eps,
- )
- np.testing.assert_allclose(
- desired=0,
- actual=right_boundary,
- atol=eps,
- )
+ def test_vs_analytic(variables):
+ pass