# Copyright 2026 Daniil Shmelev
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
"""Ordered-forest substitution for Lie--Butcher series.
The core operation implemented here is the ordered-forest coaction
``Delta_W`` from Lundervold--Munthe-Kaas: given a logarithmic linear map
``psi`` on ordered forests and a basis-aware outer character ``beta``,
``substitute(psi, beta)`` returns the substituted character
``psi star_W beta = (psi tensor beta) Delta_W``.
This is the LB-series analogue of ordinary B-series substitution used by
the reused-stage CF methods.
"""
from __future__ import annotations
from collections import Counter
from functools import lru_cache
from itertools import permutations, product
from .maps import Map
from .trees import (
EMPTY_ORDERED_FOREST,
ForestSum,
OrderedForest,
PlanarTree,
)
from .generic_algebra import mkw_apply, mkw_base_char_func
from .mkw.mkw import (
_as_basis_aware_map,
_basis_aware_func,
)
def _nonempty_trees(forest: OrderedForest) -> tuple[PlanarTree, ...]:
return tuple(t for t in forest.tree_list if t.list_repr is not None)
def _forest_from_trees(trees: tuple[PlanarTree, ...]) -> OrderedForest:
    """Wrap ``trees`` in an ``OrderedForest``.

    An empty tuple maps to the shared ``EMPTY_ORDERED_FOREST`` constant
    instead of a freshly constructed forest.
    """
    if not trees:
        return EMPTY_ORDERED_FOREST
    return OrderedForest(trees)
def _flatten_forest(forest: OrderedForest):
    """Flatten a planar forest into per-vertex records indexed in preorder.

    Vertex ids are assigned in preorder, visiting trees left to right. Each
    record is a dict with keys:

    - ``"parent"``: id of the parent vertex, or ``None`` for a tree root;
    - ``"child_index"``: position among its siblings (for tree roots, the
      position among the forest's non-empty trees);
    - ``"children"``: ids of its children, left to right.

    Returns:
        ``(records, root_ids)`` as tuples; ``root_ids`` lists the ids of the
        tree roots from left to right.
    """
    records = []
    root_ids = []
    # Explicit stack of (tree, parent id, child index). Siblings are pushed
    # right-to-left so that they pop left-to-right, reproducing the order of
    # a recursive preorder walk.
    stack = [
        (tree, None, position)
        for position, tree in reversed(tuple(enumerate(_nonempty_trees(forest))))
    ]
    while stack:
        tree, parent, child_index = stack.pop()
        node_id = len(records)
        records.append(
            {"parent": parent, "child_index": child_index, "children": []}
        )
        if parent is None:
            root_ids.append(node_id)
        else:
            records[parent]["children"].append(node_id)
        # list_repr holds the children's reprs followed by one trailing entry.
        child_reprs = tree.list_repr[:-1]
        stack.extend(
            (PlanarTree(child_repr), node_id, position)
            for position, child_repr in reversed(tuple(enumerate(child_reprs)))
        )
    return tuple(records), tuple(root_ids)
def _set_partitions(items: tuple[int, ...]):
if not items:
yield ()
return
first, rest = items[0], items[1:]
for partition in _set_partitions(rest):
yield (frozenset((first,)),) + partition
for i, block in enumerate(partition):
yield (
partition[:i]
+ (frozenset((*block, first)),)
+ partition[i + 1 :]
)
def _block_roots(block: frozenset[int], records) -> tuple[int, ...]:
return tuple(v for v in block if records[v]["parent"] not in block)
def _sibling_list(parent, records, root_ids):
return root_ids if parent is None else tuple(records[parent]["children"])
def _is_consecutive(values: list[int]) -> bool:
return bool(values) and max(values) - min(values) + 1 == len(values)
def _is_admissible_block(block: frozenset[int], records, root_ids) -> bool:
    """Return whether ``block`` may appear as one block of a ``Delta_W`` term.

    A block is admissible when:

    1. all of its roots share a single parent (possibly the forest level);
    2. those roots occupy consecutive positions among that parent's
       children;
    3. for every vertex of the block, the children that belong to the block
       form a right-hand suffix of that vertex's children (once one child is
       in the block, every child to its right is as well).
    """
    roots = _block_roots(block, records)
    shared_parents = {records[root]["parent"] for root in roots}
    if len(shared_parents) != 1:
        return False
    (parent,) = shared_parents
    siblings = _sibling_list(parent, records, root_ids)
    if not _is_consecutive([siblings.index(root) for root in roots]):
        return False
    for vertex in block:
        kids = records[vertex]["children"]
        first_inside = next(
            (position for position, kid in enumerate(kids) if kid in block),
            None,
        )
        if first_inside is None:
            continue
        # Everything from the first in-block child onwards must be in-block.
        if not all(kid in block for kid in kids[first_inside:]):
            return False
    return True
def _induced_forest(block: frozenset[int], records) -> OrderedForest:
    """Return the ordered forest spanned by the vertices of ``block``.

    Each block root (a vertex whose parent is outside ``block``) starts one
    tree. Every vertex keeps exactly those of its original children that lie
    in ``block``, in their original left-to-right order. Trees are ordered by
    ascending vertex id (preorder position) of their roots.
    """

    def _subtree(vertex: int) -> PlanarTree:
        kept = tuple(
            _subtree(kid).list_repr
            for kid in records[vertex]["children"]
            if kid in block
        )
        # Tree encoding used in this module: child reprs followed by 0.
        return PlanarTree(kept + (0,))

    ordered_roots = sorted(_block_roots(block, records))
    return _forest_from_trees(tuple(_subtree(root) for root in ordered_roots))
def _linear_extensions(items: tuple[int, ...], constraints: set[tuple[int, int]]):
for candidate in permutations(items):
positions = {item: i for i, item in enumerate(candidate)}
if all(positions[left] < positions[right] for left, right in constraints):
yield candidate
def _quotient_forests(
    partition: tuple[frozenset[int], ...],
    records,
    root_ids,
) -> Counter:
    """Contract each block of ``partition`` to a single vertex.

    ``partition`` is a tuple of vertex blocks over the records produced by
    ``_flatten_forest``. Every block becomes one vertex of the quotient. A
    block's parent is the block containing the parent vertex of its roots.
    The roots of a block are assumed to share one parent, as enforced by
    ``_is_admissible_block``; only ``roots[0]`` is inspected here.

    Planar order of the quotient:

    - children of a block that were attached at the same original vertex
      keep their left-to-right order, compared by the leftmost root
      position among that vertex's children;
    - children attached at different vertices of the parent block are
      unconstrained relative to each other, so every compatible
      interleaving is enumerated.

    Top-level blocks all attach to ``None`` and therefore keep the order of
    the forest's roots.

    Returns:
        A ``Counter`` mapping each resulting quotient forest to the number
        of child-order combinations that produce it.
    """
    # Map every vertex to the index of the block containing it.
    block_of = {
        vertex: block_index
        for block_index, block in enumerate(partition)
        for vertex in block
    }
    # Per block: parent block in the quotient (None = top level), the
    # original vertex its roots hang from (None = forest level), and its roots.
    parent_of: dict[int, int | None] = {}
    attachment_site: dict[int, int | None] = {}
    roots_by_block: dict[int, tuple[int, ...]] = {}
    for block_index, block in enumerate(partition):
        roots = _block_roots(block, records)
        roots_by_block[block_index] = roots
        # All roots are assumed to share one parent, so the first is representative.
        parent = records[roots[0]]["parent"]
        attachment_site[block_index] = parent
        parent_of[block_index] = None if parent is None else block_of[parent]
    # Quotient adjacency. Every block gets an entry (possibly empty) so that
    # build_tree below can look up leaf blocks; key None collects top-level blocks.
    children_by_parent: dict[int | None, list[int]] = {None: []}
    for block_index, parent_index in parent_of.items():
        children_by_parent.setdefault(parent_index, [])
        children_by_parent.setdefault(block_index, [])
        if parent_index is None:
            children_by_parent[None].append(block_index)
        else:
            children_by_parent[parent_index].append(block_index)
    # For each quotient vertex, collect every admissible left-to-right
    # ordering of its children.
    choices = []
    for parent_index, children in children_by_parent.items():
        child_tuple = tuple(children)
        constraints: set[tuple[int, int]] = set()
        for left in child_tuple:
            for right in child_tuple:
                if left == right:
                    continue
                # Children hanging from different original vertices may be
                # interleaved freely; only same-site children are ordered.
                if attachment_site[left] != attachment_site[right]:
                    continue
                site = attachment_site[left]
                siblings = _sibling_list(site, records, root_ids)
                # Compare blocks by their leftmost root among the site's children.
                left_pos = min(siblings.index(root) for root in roots_by_block[left])
                right_pos = min(siblings.index(root) for root in roots_by_block[right])
                if left_pos < right_pos:
                    constraints.add((left, right))
        choices.append(
            (
                parent_index,
                tuple(_linear_extensions(child_tuple, constraints)),
            )
        )
    # Each combination of per-vertex child orderings yields one planar
    # quotient forest; identical forests accumulate in the Counter.
    out = Counter()
    for selected_orders in product(*(orders for _, orders in choices)):
        order_by_parent = {
            parent: order
            for (parent, _), order in zip(choices, selected_orders)
        }
        def build_tree(block_index: int) -> PlanarTree:
            # Contracted block -> single vertex whose children follow the
            # selected ordering.
            children = [
                build_tree(child_index).list_repr
                for child_index in order_by_parent[block_index]
            ]
            return PlanarTree(tuple(children) + (0,))
        roots = tuple(build_tree(block_index) for block_index in order_by_parent[None])
        out[_forest_from_trees(roots)] += 1
    return out
[docs]
@lru_cache(maxsize=None)
def delta_w_terms(forest: OrderedForest):
    """Enumerate the terms of the ordered-forest contraction coaction ``Delta_W``.

    Every set partition of the forest's vertices whose blocks are all
    admissible contributes terms ``(coeff, left_factors, right_forest)``:

    - ``left_factors`` is the tuple of subforests induced by the blocks,
      taken in order of each block's smallest preorder vertex id. It is
      read as the symmetric product of those admissible subforests.
    - ``right_forest`` is a quotient forest obtained by contracting every
      block to one vertex.
    - ``coeff`` is the number of child orderings producing that quotient.

    A forest with no non-empty trees yields the single term
    ``(1, (), EMPTY_ORDERED_FOREST)``. Results are memoised per ``forest``
    argument in an unbounded ``lru_cache``.
    """
    simplified = forest.simplify()
    if not _nonempty_trees(simplified):
        return ((1, (), EMPTY_ORDERED_FOREST),)
    records, root_ids = _flatten_forest(simplified)
    vertex_ids = tuple(range(len(records)))
    collected = []
    for raw_partition in _set_partitions(vertex_ids):
        blocks = tuple(sorted(raw_partition, key=min))
        if any(
            not _is_admissible_block(block, records, root_ids)
            for block in blocks
        ):
            continue
        subforests = tuple(_induced_forest(block, records) for block in blocks)
        quotients = _quotient_forests(blocks, records, root_ids)
        collected.extend(
            (multiplicity, subforests, quotient)
            for quotient, multiplicity in quotients.items()
        )
    return tuple(collected)
[docs]
def substitute(logarithmic: Map, character: Map) -> Map:
    """Build the substituted character ``logarithmic star_W character``.

    For a single ordered forest ``f`` the value is the sum, over the terms
    ``(c, (l_1, ..., l_k), r)`` returned by ``delta_w_terms(f)``, of

        c * logarithmic(l_1) * ... * logarithmic(l_k) * character(r),

    where ``character`` is evaluated through ``_basis_aware_func``. A
    ``PlanarTree`` input is first converted with ``as_ordered_forest()``.
    ``ForestSum`` inputs are handed to ``mkw_apply`` together with this
    same evaluator (presumably applied term-by-term; confirm against
    ``mkw_apply``).

    Returns:
        The evaluator wrapped by ``_as_basis_aware_map``.
    """
    outer_value = _basis_aware_func(character)

    def _evaluate(element):
        if isinstance(element, ForestSum):
            return mkw_apply(element, _evaluate)
        if isinstance(element, PlanarTree):
            element = element.as_ordered_forest()
        accumulated = 0
        for multiplicity, factors, quotient in delta_w_terms(element):
            factor_product = 1
            for sub_forest in factors:
                factor_product *= logarithmic(sub_forest)
            accumulated += multiplicity * factor_product * outer_value(quotient)
        return accumulated

    return _as_basis_aware_map(_evaluate)
[docs]
def frozen_exponential_character(weight) -> Map:
    """Pullback character of a single frozen exponential ``exp(weight * F)``.

    Evaluated on ordered trees, the character is supported only on the
    empty tree and the single-vertex tree:

    - ``alpha(empty) = 1``,
    - ``alpha(bullet) = weight``,
    - ``alpha(t) = 0`` for every tree with more than one node.
    """

    def _tree_value(tree, coeff=weight):
        representation = tree.list_repr
        if representation is None:
            # Empty tree.
            return 1
        if len(representation) == 1:
            # A bare ``(0,)`` repr: the single-vertex tree.
            return coeff
        return 0

    return _as_basis_aware_map(mkw_base_char_func(_tree_value))