# Copyright 2025 Daniil Shmelev
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
"""
Runge-Kutta Schemes
"""
from __future__ import annotations
import copy
from typing import Union, Callable, Tuple
import warnings
try:
import numpy as np
import sympy
from scipy.optimize import root as _scipy_root
import matplotlib.pyplot as plt
from tqdm import tqdm
except ImportError:
np = sympy = _scipy_root = plt = tqdm = None
def _require_full():
if np is None:
raise ImportError(
"This feature requires additional dependencies. "
"Install with: pip install kauri[full]"
)
from .gentrees import trees_of_order, planar_trees_of_order
from .trees import Tree, Forest, ForestSum, PlanarTree, _is_scalar
from ._protocols import TreeLike, ForestLike, ForestSumLike
from .maps import Map, sign
from .generic_algebra import apply_map, sign_factor
from .bck import counit
def _check_planar_order(char_func, tol, limit):
    """
    Return the classical order of *char_func* on ordered (planar) trees.

    Orders are examined from 0 upwards. At each order every planar tree ``t`` is
    tested against the condition ``char_func(t) == 1 / t!`` (within *tol*). The
    result is one less than the first order containing a violated condition.

    :param char_func: Callable evaluated on planar trees.
    :param tol: Absolute tolerance for each condition.
    :param limit: Highest order examined.
    :raises RuntimeError: If every condition up to and including order *limit* holds.
    """
    current = 0
    while True:
        # any() stops at the first violated condition, mirroring an early return.
        violated = any(abs(char_func(tree) - 1. / tree.factorial()) > tol
                       for tree in planar_trees_of_order(current))
        if violated:
            return current - 1
        if current >= limit:
            raise RuntimeError(f"Order equals or exceeds limit of {limit}")
        current += 1
def _check_planar_antisymmetric_order(defect_func, tol, limit):
    """
    Return the antisymmetric order of *defect_func* on ordered (planar) trees.

    The target value at each tree is the counit: 1 on the empty tree (the one
    whose ``list_repr`` is ``None``) and 0 on every other tree. The result is one
    less than the first order at which some planar tree misses its target by more
    than *tol*.

    :param defect_func: Callable evaluated on planar trees.
    :param tol: Absolute tolerance.
    :param limit: Highest order examined.
    :raises RuntimeError: If every tree up to and including order *limit* meets its target.
    """
    def misses_target(tree):
        target = 0 if tree.list_repr is not None else 1
        return abs(defect_func(tree) - target) > tol

    current = 0
    while True:
        if any(misses_target(tree) for tree in planar_trees_of_order(current)):
            return current - 1
        if current >= limit:
            raise RuntimeError(f"Order equals or exceeds limit of {limit}")
        current += 1
def _internal_symbolic(i, t_rep, a, b, s):
    """Symbolic stage contribution ``sum_j a[i, j] * _derivative_symbolic(j, t_rep)``."""
    total = 0
    for j in range(s):
        total += a[i, j] * _derivative_symbolic(j, t_rep, a, b, s)
    return total
def _derivative_symbolic(i, t_rep, a, b, s):
if t_rep is None or len(t_rep) == 1:
return 1
out = 1
for subtree in t_rep[:-1]:
out *= _internal_symbolic(i, subtree, a, b, s)
return out
def _elementary_symbolic(t_rep, a, b, s):
if t_rep is None: # Empty tree
return 1
if len(t_rep) == 1: # Singleton tree
return sum(b)
return sum(b[i] * _derivative_symbolic(i, t_rep, a, b, s) for i in range(s))
def _rk_symbolic_weight(t, s, explicit = False, a_mask = None, b_mask = None):
    """
    Return the elementary weight of the single tree *t* for a symbolic ``s``-stage RK method.

    The coefficients are SymPy symbols ``a{i}{j}`` and ``b{j}`` (0-indexed). For
    ``s > 10`` the ``a`` symbols are named ``a{i}_{j}`` instead. Without a
    separator, two-digit indices collide: ``a111`` would denote both ``(1, 11)``
    and ``(11, 1)``, and SymPy would treat the two coefficients as one symbol.

    :param t: A tree; only its ``list_repr`` is used.
    :param s: Number of stages.
    :param explicit: If true, also set ``a[i, j] = 0`` for ``j >= i``.
    :param a_mask: Optional ``s x s`` nested list. A falsy entry sets the matching
        ``a[i, j]`` to 0. The caller's list is copied and never modified.
    :param b_mask: Optional length-``s`` list. A falsy entry sets the matching ``b[i]`` to 0.
    :returns: A SymPy expression (or the integer 1 for the empty tree).
    """
    # Build or copy the mask so the explicit-scheme zeroing below never mutates
    # a list owned by the caller.
    if a_mask is None:
        a_mask = [[1] * s for _ in range(s)]
    else:
        a_mask = [list(row) for row in a_mask]
    if b_mask is None:
        b_mask = [1] * s
    if explicit:
        for i in range(s):
            for j in range(i, s):
                a_mask[i][j] = 0
    # Single-digit indices are unambiguous; keep the historical names for s <= 10.
    sep = '' if s <= 10 else '_'
    a = sympy.Matrix(s, s, lambda i, j: sympy.symbols(f'a{i}{sep}{j}'))
    b = sympy.Matrix(1, s, lambda i, j: sympy.symbols(f'b{j}'))
    # Zero the coefficients excluded by the masks.
    for i in range(s):
        for j in range(s):
            if not a_mask[i][j]:
                a[i, j] = 0
    for i in range(s):
        if not b_mask[i]:
            b[i] = 0
    return _elementary_symbolic(t.list_repr, a, b, s)
def rk_symbolic_weight(
        t : Union[Tree, Forest, ForestSum],
        s : int,
        explicit : bool = False,
        a_mask : list = None,
        b_mask : list = None,
        mathematica_code : bool = False,
        rationalise : bool = True
) -> Union[sympy.core.add.Add, str]:
    """
    Returns the elementary weight of a Tree, Forest or ForestSum :math:`t` as a SymPy symbolic expression.

    :param t: A Tree, Forest or ForestSum (or planar equivalent), or an int/float, which is
        interpreted as that multiple of the empty tree.
    :param s: The number of Runge--Kutta stages
    :type s: int
    :param explicit: If true, assumes the Runge--Kutta scheme is explicit, i.e. :math:`a_{ij} = 0` for :math:`i \\leq j`.
    :type explicit: bool
    :param a_mask: A two-dimensional list specifying which coefficients of the Runge--Kutta parameter matrix :math:`A`
        are non-zero. If not None, sets :math:`a_{ij} = 0` for all :math:`i,j` such that ``a_mask[i][j] = 0``.
        The list is not modified.
    :param b_mask: A one-dimensional list specifying which coefficients of the Runge--Kutta parameter vector :math:`b`
        are non-zero. If not None, sets :math:`b_i = 0` for all :math:`i` such that ``b_mask[i] = 0``.
        The list is not modified.
    :param mathematica_code: If true, outputs the expression as mathematica code.
    :type mathematica_code: bool
    :param rationalise: If true, will attempt to rationalise the coefficients in the expression
    :type rationalise: bool
    :returns: The elementary weight of :math:`t`, as a SymPy symbolic expression if `mathematica_code` is False or as a
        string if `mathematica_code` is True.
    :rtype: sympy.core.add.Add | string
    :raises TypeError: If an argument has the wrong type.
    :raises ValueError: If ``a_mask`` has fewer than ``s`` rows or a row shorter than ``s``,
        or ``b_mask`` has fewer than ``s`` entries.

    **Example usage:**

    .. kauri-exec::

        t = Tree([[],[]])
        print(rk_symbolic_weight(t, 2))
        print(rk_symbolic_weight(t, 2, explicit = True))
        a_mask = [[1,0],[0,1]]
        b_mask = [0,1]
        print(rk_symbolic_weight(t, 2, a_mask = a_mask, b_mask = b_mask))
    """
    if not isinstance(t, (int, float, TreeLike, ForestLike, ForestSumLike)):
        raise TypeError("t must be a Tree, Forest, ForestSum (or planar equivalent), int or float, not " + str(type(t)))
    if not isinstance(s, int):
        raise TypeError("Number of stages s must be an int, not " + str(type(s)))
    if not isinstance(explicit, bool):
        raise TypeError("explicit must be a bool, not " + str(type(explicit)))
    if not (isinstance(a_mask, list) or a_mask is None):
        raise TypeError("a_mask must be a list, not " + str(type(a_mask)))
    if not (isinstance(b_mask, list) or b_mask is None):
        raise TypeError("b_mask must be a list, not " + str(type(b_mask)))
    if not isinstance(mathematica_code, bool):
        raise TypeError("mathematica_code must be a bool, not " + str(type(mathematica_code)))
    if not isinstance(rationalise, bool):
        raise TypeError("rationalise must be a bool, not " + str(type(rationalise)))
    _require_full()
    # Reject undersized masks up front; otherwise they fail with an opaque
    # IndexError deep inside the symbolic recursion. Oversized masks keep working
    # (extra entries are ignored), as before.
    if a_mask is not None and (len(a_mask) < s or any(len(a_mask[i]) < s for i in range(s))):
        raise ValueError(f"a_mask must have at least {s} rows of at least {s} entries")
    if b_mask is not None and len(b_mask) < s:
        raise ValueError(f"b_mask must have at least {s} entries")
    # Work on copies so the caller's masks are never modified downstream
    # (the explicit-scheme handling zeroes mask entries).
    if a_mask is not None:
        a_mask = [list(row) for row in a_mask]
    if b_mask is not None:
        b_mask = list(b_mask)
    # A scalar c stands for c times the empty tree.
    t_ = t * Tree(None).as_forest_sum() if _is_scalar(t) else t
    out = apply_map(t_, lambda x: _rk_symbolic_weight(x, s, explicit, a_mask, b_mask))
    if rationalise:
        out = sympy.nsimplify(out, tolerance=1e-10, rational = True)
    if mathematica_code:
        out = sympy.mathematica_code(out)
    return out
def rk_order_cond(
        t : Union[Tree, Forest, ForestSum],
        s : int,
        explicit : bool = False,
        a_mask : list = None,
        b_mask : list = None,
        mathematica_code : bool = False,
        rationalise : bool = True
) -> Union[sympy.core.add.Add, str]:
    """
    Returns the Runge--Kutta order condition associated with tree :math:`t` as a SymPy symbolic expression,
    i.e. the elementary weight of :math:`t - 1/t!`.

    :param t: A Tree. Scalars are also accepted and denote that multiple of the empty tree; since the
        empty tree has factorial 1, their order condition is identically zero. Other tree-like inputs
        must provide a ``factorial()`` method.
    :param s: The number of Runge--Kutta stages
    :type s: int
    :param explicit: If true, assumes the Runge--Kutta scheme is explicit, i.e. :math:`a_{ij} = 0` for :math:`i \\leq j`.
    :type explicit: bool
    :param a_mask: A two-dimensional array specifying which coefficients of the Runge--Kutta parameter matrix :math:`A`
        are non-zero. If not None, sets :math:`a_{ij} = 0` for all :math:`i,j` such that ``a_mask[i][j] = 0``.
    :param b_mask: A one-dimensional array or list specifying which coefficients of the Runge--Kutta parameter vector :math:`b`
        are non-zero. If not None, sets :math:`b_i = 0` for all :math:`i` such that ``b_mask[i] = 0``.
    :param mathematica_code: If true, outputs the expression as mathematica code.
    :type mathematica_code: bool
    :param rationalise: If true, will attempt to rationalise the coefficients in the expression
    :type rationalise: bool
    :returns: The order condition associated with the tree :math:`t`, as a SymPy symbolic expression if `mathematica_code` is False or as a
        string if `mathematica_code` is True.
    :rtype: sympy.core.add.Add | string

    **Example usage:**

    .. kauri-exec::

        t = Tree([[],[]])
        print(rk_order_cond(t, 2))
        print(rk_order_cond(t, 2, explicit = True))
        a_mask = [[1,0],[0,1]]
        b_mask = [0,1]
        print(rk_order_cond(t, 2, a_mask = a_mask, b_mask = b_mask))
    """
    # Validate the argument before checking optional dependencies, matching rk_symbolic_weight.
    if not isinstance(t, (int, float, TreeLike, ForestLike, ForestSumLike)):
        raise TypeError("t must be a Tree, Forest, ForestSum (or planar equivalent), int or float, not " + str(type(t)))
    _require_full()
    # The exact weight of a tree is 1/t!. Scalars have no factorial() method; a scalar c
    # is c times the empty tree (factorial 1), so its exact weight is c itself.
    exact = t if _is_scalar(t) else 1. / t.factorial()
    return rk_symbolic_weight(t - exact, s, explicit, a_mask, b_mask, mathematica_code, rationalise)
class RK:
    """
    A Runge--Kutta method with the Butcher tableau:

    .. math::

        \\begin{array}{c|c}
        c & A \\\\
        \\hline
        & b^T
        \\end{array}

    where :math:`c_i = \\sum_{j=1}^s a_{ij}`.

    :param a: The Runge--Kutta parameter matrix :math:`A`, an ``s x s`` list of lists or array.
    :param b: The Runge--Kutta parameter vector :math:`b`, a list or array of length ``s``.
    :param name: Optional label for the scheme, stored as ``self.name``.
    :raises TypeError: If ``a`` or ``b`` is not a list (or a NumPy array, when NumPy is available).
    :raises ValueError: If ``b`` is empty, or ``a`` is not a square ``s x s`` matrix with ``s = len(b)``.
    """

    def __init__(self, a, b, name = None):
        _valid_types = (list, np.ndarray) if np is not None else (list,)
        if not isinstance(a, _valid_types):
            raise TypeError("a must be a list or array, not " + str(type(a)))
        if not isinstance(b, _valid_types):
            raise TypeError("b must be a list or array, not " + str(type(b)))
        self.name = name
        self.s = len(b)
        if self.s == 0:
            raise ValueError("An RK scheme must have at least one stage, but b is empty")
        # Check every row (not only the first) so ragged matrices are rejected here
        # instead of failing later with an IndexError.
        if len(a) != self.s or any(len(row) != self.s for row in a):
            raise ValueError("Parameter 'a' must be a square s x s matrix and b a vector of length s")
        self.a = a
        self.b = b
        # Abscissae c_i = sum_j a_ij.
        self.c = [sum(a[i][j] for j in range(self.s)) for i in range(self.s)]
        # True when A is strictly lower triangular.
        self.explicit = self._check_explicit()
        # Memo for _derivative_weights, keyed by (stage index, tree list_repr).
        # The empty tree (list_repr None) has derivative weight 1 at every stage.
        self.deriv_dict = {}
        for i in range(self.s):
            self.deriv_dict[(i, None)] = 1

    def __repr__(self):
        # Rows of A (one per line) followed by b.
        out = "["
        for i in range(self.s - 1):
            out += repr(self.a[i]) + ",\n"
        out += repr(self.a[-1]) + "]\n"
        out += repr(self.b)
        return out

    def _check_explicit(self):
        """Return True iff every entry on or above the diagonal of A is zero."""
        for i in range(self.s):
            for j in range(i, self.s):
                if self.a[i][j]:
                    return False
        return True

    def _inverse(self):
        """Return the scheme with tableau (A - e b^T, -b), the compositional inverse used by ``__pow__``."""
        b_inv = [-self.b[i] for i in range(self.s)]
        a_inv = [[self.a[i][j] - self.b[j] for j in range(self.s)] for i in range(self.s)]
        return RK(a_inv, b_inv)

    def reverse(self) -> 'RK':
        """
        Returns the RK scheme given by reversing the step size h to -h, with Butcher tableau:

        .. math::

            \\begin{array}{c|c}
            -c & -A \\\\
            \\hline
            & -b^T
            \\end{array}

        :rtype: RK
        """
        return RK([[-self.a[i][j] for j in range(self.s)] for i in range(self.s)], [-self.b[i] for i in range(self.s)])

    def adjoint(self) -> 'RK':
        """
        Returns the adjoint Runge--Kutta method, given by the Butcher tableau:

        .. math::

            \\begin{array}{c|c}
            \\widetilde{c} & e \\widetilde{b}^T - \\widetilde{A} \\\\
            \\hline
            & \\widetilde{b}^T
            \\end{array}

        where :math:`\\widetilde{b}_i := b_{s+1-i}` and :math:`\\widetilde{A}_{ij} := A_{s+1 - i, s+ 1 - j}` for all
        :math:`1 \\leq i,j \\leq s`.

        :rtype: RK
        """
        b_adj = [self.b[self.s - 1 - j] for j in range(self.s)]
        a_adj = [[self.b[self.s - 1 - j] - self.a[self.s - 1 - i][self.s - j - 1] for j in range(self.s)] for i in range(self.s)]
        return RK(a_adj, b_adj)

    def _explicit_step(self, y0, t0, f, h):
        """One explicit step: each stage uses only the previously computed stages."""
        k = [None] * self.s
        for i in range(self.s):
            y_stage = y0 + h * sum(self.a[i][j] * k[j] for j in range(i))
            k[i] = f(t0 + self.c[i] * h, y_stage)
        y_next = y0 + h * sum(self.b[i] * k[i] for i in range(self.s))
        return y_next

    def _implicit_step(self, y0, t0, f, h, tol = 1e-10, max_iter = 100):
        """
        One implicit step: solve for all stage derivatives simultaneously with
        ``scipy.optimize.root`` (method ``'hybr'``). A warning is emitted if the
        solver reports failure; its last iterate is used regardless.
        """
        y0 = np.array(y0)
        dim = len(y0)
        # Initial guess: every stage derivative equals f(t_n, y_n).
        k0 = np.tile(f(t0, y0), self.s)

        def G(K_flat):
            # Residual K_i - f(t0 + c_i h, y0 + h sum_j a_ij K_j), stacked over stages.
            K = K_flat.reshape((self.s, dim))
            G_vec = []
            for i in range(self.s):
                y_stage = y0 + h * sum(self.a[i][j] * K[j] for j in range(self.s))
                t_stage = t0 + self.c[i] * h
                G_i = K[i] - f(t_stage, y_stage)
                G_vec.append(G_i)
            return np.concatenate(G_vec)

        sol = _scipy_root(G, k0, method='hybr', tol=tol, options={'maxfev': max_iter})
        if not sol.success:
            warnings.warn(f"Implicit RK solver failed: {sol.message}")
        K = sol.x.reshape((self.s, dim))
        y_next = y0 + h * sum(self.b[i] * K[i] for i in range(self.s))
        return y_next

    def step(self,
             y0 : Union[list, np.ndarray],
             t0 : float,
             f : Callable[[float, np.ndarray], Union[list, np.ndarray]],
             h : float,
             tol : float = 1e-10,
             max_iter : int = 100
             ) -> Union[list, np.ndarray]:
        """
        Runs one step of the Runge--Kutta method.

        :param y0: Initial condition for y
        :type y0: list | array
        :param t0: Initial condition for t
        :type t0: float
        :param f: Function defining the ODE :math:`dy / dt = f(t,y)`; ``y`` is passed as a NumPy array.
        :type f: callable
        :param h: Step size
        :type h: float
        :param tol: Tolerance for the root solving algorithm. Only applicable if the scheme is implicit.
        :type tol: float
        :param max_iter: Maximum number of function evaluations for the root solving algorithm.
            Only applicable if the scheme is implicit.
        :type max_iter: int
        :return: Next point, y1
        :rtype: array
        """
        _require_full()
        if not isinstance(y0, (list, np.ndarray)):
            raise TypeError("y0 must be a list or array, not " + str(type(y0)))
        if not isinstance(t0, float):
            raise TypeError("t0 must be a float, not " + str(type(t0)))
        if not callable(f):
            raise TypeError("f must be callable")
        if not isinstance(h, float):
            raise TypeError("h must be a float, not " + str(type(h)))
        if not isinstance(tol, float):
            raise TypeError("tol must be a float, not " + str(type(tol)))
        if not isinstance(max_iter, int):
            raise TypeError("max_iter must be an int, not " + str(type(max_iter)))

        def f_(t_, y_):
            return np.array(f(t_, y_))

        y0_ = np.array(y0).copy()
        if self.explicit:
            return self._explicit_step(y0_, t0, f_, h)
        return self._implicit_step(y0_, t0, f_, h, tol, max_iter)

    def run(self,
            y0 : Union[list, np.ndarray],
            t0 : float,
            t_end : float,
            f : Callable[[float, np.ndarray], Union[list, np.ndarray]],
            n : int,
            tol : float = 1e-10,
            max_iter : int = 100,
            plot : bool = False,
            plot_dims : Union[list, np.ndarray] = None,
            plot_kwargs : dict = None
            ) -> Tuple[list, list]:
        """
        Runs the Runge--Kutta method with ``n`` equal steps from ``t0`` to ``t_end``.

        :param y0: Initial condition for y
        :type y0: list | array
        :param t0: Initial condition for t
        :type t0: float
        :param t_end: End point for t
        :type t_end: float
        :param f: Function defining the ODE :math:`dy / dt = f(t,y)`; ``y`` is passed as a NumPy array.
        :type f: callable
        :param n: Number of steps
        :type n: int
        :param tol: Tolerance for the root solving algorithm. Only applicable if the scheme is implicit.
        :type tol: float
        :param max_iter: Maximum number of function evaluations for the root solving algorithm.
            Only applicable if the scheme is implicit.
        :type max_iter: int
        :param plot: If true, will plot the solution
        :type plot: bool
        :param plot_dims: List of dimensions of the solution to plot
        :type plot_dims: list | array
        :param plot_kwargs: kwargs to pass to pyplot.plot() if plotting the solution.
        :type plot_kwargs: dict
        :return: t_vals, y_vals - the lists of values of t and y respectively (``n + 1`` entries each)
        :rtype: tuple[list, list]
        """
        _require_full()
        if not isinstance(y0, (list, np.ndarray)):
            raise TypeError("y0 must be a list or array, not " + str(type(y0)))
        if not isinstance(t0, float):
            raise TypeError("t0 must be a float, not " + str(type(t0)))
        if not isinstance(t_end, float):
            raise TypeError("t_end must be a float, not " + str(type(t_end)))
        if not callable(f):
            raise TypeError("f must be callable")
        if not isinstance(n, int):
            raise TypeError("n must be an int, not " + str(type(n)))
        if n <= 0:
            raise ValueError("n must be a positive integer, got " + str(n))
        if not isinstance(tol, float):
            raise TypeError("tol must be a float, not " + str(type(tol)))
        if not isinstance(max_iter, int):
            raise TypeError("max_iter must be an int, not " + str(type(max_iter)))
        if not (isinstance(plot, bool) or plot is None):
            raise TypeError("plot must be a bool, not " + str(type(plot)))
        if not (isinstance(plot_dims, (list, np.ndarray)) or plot_dims is None):
            raise TypeError("plot_dims must be a list or array, not " + str(type(plot_dims)))
        if not (isinstance(plot_kwargs, dict) or plot_kwargs is None):
            raise TypeError("plot_kwargs must be a dict, not " + str(type(plot_kwargs)))
        if plot_kwargs is None:
            plot_kwargs = {}
        if plot_dims is None:
            plot_dims = list(range(len(y0)))

        def f_(t_, y_):
            return np.array(f(t_, y_))

        y0_ = np.array(y0).copy()
        t_vals = [t0]
        y_vals = [y0_]
        y = y0_.copy()
        h = (t_end - t0) / n

        def step_func(y_, t_):
            if self.explicit:
                return self._explicit_step(y_, t_, f_, h)
            return self._implicit_step(y_, t_, f_, h, tol, max_iter)

        t = t0
        for k in tqdm(range(1, n + 1)):
            y = step_func(y, t)
            # Recompute t from t0 rather than accumulating t += h, which drifts by
            # O(n * machine epsilon) over many steps.
            t = t0 + k * h
            t_vals.append(t)
            y_vals.append(copy.deepcopy(y))
        if plot:
            plt.plot(t_vals, np.array(y_vals)[:, plot_dims], **plot_kwargs)
        return t_vals, y_vals

    def __mul__(self, other : 'RK') -> 'RK':
        """
        Returns the composition of two RK schemes, :math:`(A_1, b_1)` and :math:`(A_2, b_2)`, with Butcher tableau:

        .. math::

            \\begin{array}{c|cc}
            c_1 & A_1 & 0 \\\\
            c_2 & e b_1^T & A_2\\\\
            \\hline
            & b_1 & b_2
            \\end{array}

        where :math:`e` is the vector of 1s. Here :math:`(A_1, b_1)` is ``other`` and :math:`(A_2, b_2)` is ``self``,
        so ``self * other`` performs a step of ``other`` followed by a step of ``self``.

        :rtype: RK
        """
        if not isinstance(other, RK):
            raise TypeError("Cannot compose RK and object of type " + str(type(other)))
        s1 = other.s
        a1 = other.a
        b1 = other.b
        s2 = self.s
        a2 = self.a
        b2 = self.b
        a = [[a1[i][j] for j in range(s1)] + [0 for _ in range(s2)] for i in range(s1)]
        a += [[b1[j] for j in range(s1)] + [a2[i][j] for j in range(s2)] for i in range(s2)]
        b = list(b1) + list(b2)
        return RK(a, b)

    def __pow__(self, exponent : int) -> 'RK':
        """
        Returns the compositional power of the Runge--Kutta scheme. ``self ** 0`` is the one-stage scheme with
        :math:`A = 0, b = 0` (the identity map). In particular, ``self ** (-1)`` returns the scheme
        with Butcher tableau:

        .. math::

            \\begin{array}{c|c}
            & A - e b^T \\\\
            \\hline
            & -b^T
            \\end{array}

        where :math:`e` is the vector of 1s.

        :param exponent: Exponent
        :type exponent: int
        :rtype: RK
        """
        if not isinstance(exponent, int):
            raise TypeError("Exponent in RK power must be int, got " + str(type(exponent)) + " instead.")
        if exponent == 0:
            return RK([[0]], [0])
        if exponent < 0:
            base = self._inverse()
            expn_ = -exponent
        else:
            base = copy.deepcopy(self)
            expn_ = exponent
        out = base
        for _ in range(expn_ - 1):
            out = out * base
        return out

    def _internal_weights(self, i, t_rep):
        """Numeric stage contribution ``sum_j a_ij * _derivative_weights(j, t_rep)``."""
        return sum(self.a[i][j] * self._derivative_weights(j, t_rep) for j in range(self.s))

    def _derivative_weights(self, i, t_rep):
        """
        Numeric derivative weight of the tree ``t_rep`` at stage ``i``: the product of
        ``_internal_weights(i, subtree)`` over all entries of ``t_rep`` except the last.
        Results are memoised in ``self.deriv_dict`` (``t_rep`` must therefore be hashable).
        """
        key = (i, t_rep)
        if key in self.deriv_dict:
            return self.deriv_dict[key]
        out = 1
        for subtree in t_rep[:-1]:
            out *= self._internal_weights(i, subtree)
        self.deriv_dict[key] = out
        return out

    def _elementary_weights(self, t_rep):
        """Numeric elementary weight: 1 for the empty tree, else ``sum_i b_i * _derivative_weights(i, t_rep)``."""
        if t_rep is None:
            return 1
        return sum(self.b[i] * self._derivative_weights(i, t_rep) for i in range(self.s))

    def elementary_weights_map(self) -> Map:
        """
        Returns the elementary weight function of the Runge-Kutta method as an instance of the Map class.

        :rtype: Map
        """
        def f_(x):
            return self._elementary_weights(x.list_repr)
        return Map(f_)

    def modified_equation_map(self) -> Map:
        """
        Returns the map corresponding to the elementary weights function of the
        modified (B-series) vector field, :math:`\\widetilde{\\phi}`, defined by

        .. math::

            (\\widetilde{\\phi} \\star e)(t) = \\phi(t)

        where :math:`\\phi` is the elementary weights function of the Runge-Kutta
        scheme and :math:`e(t) = 1 / t!` is the elementary weights function of
        the exact solution. Equivalently,

        .. math::

            \\widetilde{\\phi}(t) = (\\phi \\star e^{\\star (-1)})(t).

        :return: Elementary weights map of the modified vector field
        :rtype: Map
        """
        return self.elementary_weights_map().modified_equation()

    def order(self, tol : float = 1e-10, limit : int = 10) -> int:
        """
        Returns the order of the RK scheme.

        :param tol: Tolerance for evaluating order conditions. The order condition
            :math:`\\phi(t) = 1/t!` (with :math:`\\phi` the elementary weights function) is considered
            satisfied if ``abs(phi(t) - 1./t.factorial()) <= tol``.
        :type tol: float
        :param limit: Highest admissible order. If the order equals or exceeds this limit, a runtime error
            will be raised.
        :type limit: int
        :rtype: int
        """
        if not isinstance(tol, float):
            raise TypeError("tol must be a float, not " + str(type(tol)))
        n = 0
        while True:
            for t in trees_of_order(n):
                if abs(self._elementary_weights(t.list_repr) - 1. / t.factorial()) > tol:
                    return n - 1
            if n >= limit:
                raise RuntimeError("Order equals or exceeds limit of " + str(limit))
            n += 1

    def antisymmetric_order(self, tol : float = 1e-10, limit : int = 10) -> int:
        """
        Returns the antisymmetric order of the RK scheme. See :cite:`shmelev2025ees`
        for details.

        :param tol: Tolerance for evaluating order conditions.
        :type tol: float
        :param limit: Highest admissible order. If the order equals or exceeds this limit, a runtime error
            will be raised.
        :type limit: int
        :rtype: int
        """
        if not isinstance(tol, float):
            raise TypeError("tol must be a float, not " + str(type(tol)))
        ew = self.elementary_weights_map()
        m = (ew & sign) * ew
        n = 0
        while True:
            for t in trees_of_order(n):
                if abs(m(t) - counit(t)) > tol:
                    return n - 1
            if n >= limit:
                raise RuntimeError("Order equals or exceeds limit of " + str(limit))
            n += 1

    def planar_order(self, tol: float = 1e-10, limit: int = 10) -> int:
        """
        Returns the order of the RK scheme on ordered (planar) trees.

        Checks ``Phi(tau) = 1/gamma(tau)`` for all ordered trees tau.

        :param tol: Tolerance for evaluating order conditions.
        :type tol: float
        :param limit: Highest admissible order.
        :type limit: int
        :rtype: int
        """
        if not isinstance(tol, float):
            raise TypeError("tol must be a float, not " + str(type(tol)))
        return _check_planar_order(
            lambda t: self._elementary_weights(t.list_repr), tol, limit)

    def planar_antisymmetric_order(self, tol: float = 1e-10, limit: int = 10) -> int:
        """
        Returns the antisymmetric order of the RK scheme on ordered (planar)
        trees, using the NCK Hopf algebra.

        Checks ``D(tau) = ((sign . Phi) *_nck Phi)(tau) - epsilon(tau) = 0``
        for all ordered trees tau.

        :param tol: Tolerance for evaluating order conditions.
        :type tol: float
        :param limit: Highest admissible order.
        :type limit: int
        :rtype: int
        """
        if not isinstance(tol, float):
            raise TypeError("tol must be a float, not " + str(type(tol)))
        from .nck.nck import map_product as nck_map_product
        ew = self.elementary_weights_map()
        sign_ew = Map(lambda t: sign_factor(t) * ew(t))
        m = nck_map_product(sign_ew, ew)
        return _check_planar_antisymmetric_order(m, tol, limit)