Source code for kauri.gentrees

# Copyright 2025 Daniil Shmelev
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

"""
Functions for generating rooted trees in lexicographic order, based on the algorithms of :cite:`beyer1980constant`.
"""
from typing import Generator
from functools import lru_cache, cache
from itertools import product

from .trees import Tree, EMPTY_TREE
from .utils import _level_sequence_to_list_repr, _apply_color_sequence

[docs] def trees_up_to_order(order : int) -> Generator[Tree, None, None]: """ Yields the trees up to and including order :math:`n`, ordered by the lexicographic order. :param order: Maximum order :type order: int :yields: The next tree in lexicographic order, as long as the order of the tree does not exceed :math:`n`. :rtype: Tree **Example usage:** .. kauri-exec:: trees = list(trees_up_to_order(5)) for i in range(0, len(trees), 10): kr.display(*trees[i:i+10]) """ if not isinstance(order, int): raise TypeError("order must be an int, not " + str(type(order))) if order < 0: raise ValueError("order must be non-negative") t = Tree(None) while t.nodes() <= order: yield t t = next(t)
[docs] def trees_of_order(order : int) -> Generator[Tree, None, None]: """ Yields the trees of order :math:`n`, ordered by the lexicographic order. :param order: Order :type order: int :yields: The next tree in lexicographic order, as long as the order of the tree is :math:`n`. :rtype: Tree **Example usage:** .. kauri-exec:: kr.display(*trees_of_order(5)) """ if not isinstance(order, int): raise TypeError("order must be an int, not " + str(type(order))) if order < 0: raise ValueError("order must be non-negative") t = Tree(_level_sequence_to_list_repr(list(range(order)))) while t.nodes() == order: yield t t = next(t)
def _ordered_compositions(total: int) -> Generator[tuple[int, ...], None, None]: """Yields ordered tuples of positive integers summing to total.""" if total == 0: yield tuple() return for first in range(1, total + 1): for rest in _ordered_compositions(total - first): yield (first, *rest) @lru_cache(maxsize=None) def _planar_repr_of_order(order: int) -> tuple[tuple, ...]: """Canonical tuple representations for planar trees of fixed order.""" if order == 1: return (tuple(),) out: list[tuple] = [] for child_orders in _ordered_compositions(order - 1): per_child = tuple(_planar_repr_of_order(child_order) for child_order in child_orders) for child_tuple in product(*per_child): out.append(tuple(child_tuple)) return tuple(sorted(out))
[docs] def planar_trees_of_order(order: int): """ Yields planar rooted trees of fixed order. Order 0 contains only the empty planar tree. **Example usage:** .. kauri-exec:: trees = list(planar_trees_of_order(5)) for i in range(0, len(trees), 10): kr.display(*trees[i:i+10]) """ from .trees import EMPTY_PLANAR_TREE, PlanarTree, validate_order validate_order(order) if order == 0: yield EMPTY_PLANAR_TREE return for list_repr in _planar_repr_of_order(order): yield PlanarTree(list_repr)
[docs] def planar_trees_up_to_order(order: int): """Yields planar rooted trees of all orders from 0 through order. **Example usage:** .. kauri-exec:: trees = list(planar_trees_up_to_order(5)) for i in range(0, len(trees), 10): kr.display(*trees[i:i+10]) """ from .trees import validate_order validate_order(order) for current_order in range(order + 1): yield from planar_trees_of_order(current_order)
def _validate_num_colors(d): if not isinstance(d, int): raise TypeError("number of colors d must be an int, not " + str(type(d))) if d < 0: raise ValueError("number of colors d must be non-negative") def _all_colorings(unlabelled, d: int, n: int, cls=Tree): """Yields a tree (of type *cls*) for every coloring of an unlabelled shape.""" for coloring in product(range(d), repeat=n): yield cls(_apply_color_sequence(unlabelled, iter(coloring))) def _color_all_variants(shape: Tree, d: int): """Yields all distinct colorings of an unlabelled tree shape with d colors.""" n = shape.nodes() if n == 0: yield shape return unlabelled = shape.unlabelled_repr if shape.sigma() == 1: yield from _all_colorings(unlabelled, d, n) else: seen = set() for t in _all_colorings(unlabelled, d, n): if t not in seen: seen.add(t) yield t
[docs] def colored_trees_of_order(order: int, d: int): """ Yields all distinct colored rooted trees of a given order with *d* colors. Each node is decorated with a color from {0, ..., d-1}. :param order: Number of nodes :type order: int :param d: Number of colors :type d: int :yields: Colored trees :rtype: Tree **Example usage:** .. kauri-exec:: trees = list(colored_trees_of_order(4, 2)) for i in range(0, len(trees), 10): kr.display(*trees[i:i+10]) """ _validate_num_colors(d) for shape in trees_of_order(order): yield from _color_all_variants(shape, d)
[docs] def colored_trees_up_to_order(order: int, d: int): """ Yields all distinct colored rooted trees up to and including a given order with *d* colors. :param order: Maximum number of nodes :type order: int :param d: Number of colors :type d: int :yields: Colored trees :rtype: Tree **Example usage:** .. kauri-exec:: trees = list(colored_trees_up_to_order(4, 2)) for i in range(0, len(trees), 10): kr.display(*trees[i:i+10]) """ _validate_num_colors(d) for shape in trees_up_to_order(order): yield from _color_all_variants(shape, d)
[docs] def colored_planar_trees_of_order(order: int, d: int): """ Yields all colored planar rooted trees of a given order with *d* colors. Each node is decorated with a color from {0, ..., d-1}. Planar trees have no symmetry, so every coloring is distinct. :param order: Number of nodes :type order: int :param d: Number of colors :type d: int :yields: Colored planar trees :rtype: PlanarTree **Example usage:** .. kauri-exec:: trees = list(colored_planar_trees_of_order(4, 2)) for i in range(0, len(trees), 10): kr.display(*trees[i:i+10]) """ from .trees import PlanarTree, EMPTY_PLANAR_TREE, validate_order validate_order(order) _validate_num_colors(d) if order == 0: yield EMPTY_PLANAR_TREE return for shape in planar_trees_of_order(order): yield from _all_colorings(shape.unlabelled_repr, d, order, PlanarTree)
[docs] def colored_planar_trees_up_to_order(order: int, d: int): """ Yields all colored planar rooted trees up to and including a given order with *d* colors. :param order: Maximum number of nodes :type order: int :param d: Number of colors :type d: int :yields: Colored planar trees :rtype: PlanarTree **Example usage:** .. kauri-exec:: trees = list(colored_planar_trees_up_to_order(4, 2)) for i in range(0, len(trees), 10): kr.display(*trees[i:i+10]) """ from .trees import validate_order validate_order(order) _validate_num_colors(d) for current_order in range(order + 1): yield from colored_planar_trees_of_order(current_order, d)
[docs] def ordered_forests_of_order(order: int): """ Yields ordered forests of planar rooted trees with a fixed total order. Order 0 contains only the empty ordered forest. """ from .trees import validate_order validate_order(order) yield from _ordered_forest_list_exact_cached(order)
@cache def _planar_tree_node_pairs(max_order: int) -> tuple: return tuple( (tree, tree.nodes()) for tree in planar_trees_up_to_order(max_order) if tree.nodes() != 0 ) @cache def _ordered_forest_list_exact_cached(order: int) -> tuple: return _ordered_forest_strata_cached(order)[order] @cache def _ordered_forest_strata_cached(max_order: int) -> tuple: from .trees import EMPTY_ORDERED_FOREST, OrderedForest tree_pairs = _planar_tree_node_pairs(max_order) strata = [(EMPTY_ORDERED_FOREST,)] for order in range(1, max_order + 1): out = [] for tree, nodes in tree_pairs: if nodes > order: break remaining = order - nodes if remaining == 0: out.append(OrderedForest((tree,))) else: for suffix in strata[remaining]: out.append(OrderedForest((tree,) + suffix.tree_list)) strata.append(tuple(out)) return tuple(strata)
[docs] def ordered_forests_up_to_order(order: int): """ Yields ordered forests of planar rooted trees up to a given total order. """ from .trees import validate_order validate_order(order) yield from _ordered_forest_list_cached(order)
@cache def _ordered_forest_list_cached(max_order: int) -> tuple: return tuple( forest for stratum in _ordered_forest_strata_cached(max_order) for forest in stratum )
[docs] def colored_ordered_forests_of_order(order: int, d: int): """ Yields colored ordered forests with a fixed total order and *d* colors. Each node is decorated with a color from {0, ..., d-1}. """ from .trees import validate_order validate_order(order) _validate_num_colors(d) yield from _colored_ordered_forest_list_exact_cached(order, d)
@cache def _colored_planar_tree_node_pairs(max_order: int, d: int) -> tuple: return tuple( (tree, tree.nodes()) for tree in _colored_planar_tree_list_cached(max_order, d) if tree.nodes() != 0 ) @cache def _colored_ordered_forest_list_exact_cached(order: int, d: int) -> tuple: return _colored_ordered_forest_strata_cached(order, d)[order] @cache def _colored_ordered_forest_strata_cached(max_order: int, d: int) -> tuple: from .trees import EMPTY_ORDERED_FOREST, OrderedForest tree_pairs = _colored_planar_tree_node_pairs(max_order, d) strata = [(EMPTY_ORDERED_FOREST,)] for order in range(1, max_order + 1): out = [] for tree, nodes in tree_pairs: if nodes > order: break remaining = order - nodes if remaining == 0: out.append(OrderedForest((tree,))) else: for suffix in strata[remaining]: out.append(OrderedForest((tree,) + suffix.tree_list)) strata.append(tuple(out)) return tuple(strata)
[docs] def colored_ordered_forests_up_to_order(order: int, d: int): """ Yields colored ordered forests up to a given total order with *d* colors. """ from .trees import validate_order validate_order(order) _validate_num_colors(d) yield from _colored_ordered_forest_list_cached(order, d)
# --------------------------------------------------------------------------- # Colored tree indexing # --------------------------------------------------------------------------- @cache def _colored_tree_list_cached(max_order: int, d: int) -> tuple: """Cached tuple of all colored trees up to max_order with d colors.""" return tuple(colored_trees_up_to_order(max_order, d)) @cache def _colored_tree_lookup_cached(max_order: int, d: int) -> dict: """Cached dict mapping Tree -> index.""" trees = _colored_tree_list_cached(max_order, d) return {t: i for i, t in enumerate(trees)} @cache def _colored_planar_tree_list_cached(max_order: int, d: int) -> tuple: """Cached tuple of all colored planar trees up to max_order with d colors.""" return tuple(colored_planar_trees_up_to_order(max_order, d)) @cache def _colored_planar_tree_lookup_cached(max_order: int, d: int) -> dict: """Cached dict mapping PlanarTree -> index.""" trees = _colored_planar_tree_list_cached(max_order, d) return {t: i for i, t in enumerate(trees)} @cache def _colored_ordered_forest_list_cached(max_order: int, d: int) -> tuple: """Cached tuple of all colored ordered forests up to max_order with d colors.""" return tuple( forest for stratum in _colored_ordered_forest_strata_cached(max_order, d) for forest in stratum ) @cache def _colored_ordered_forest_lookup_cached(max_order: int, d: int) -> dict: """Cached dict mapping OrderedForest -> index.""" forests = _colored_ordered_forest_list_cached(max_order, d) return {f: i for i, f in enumerate(forests)}
[docs] def colored_trees(d: int, max_order: int) -> list[Tree]: """ Returns all distinct colored rooted trees up to a given order with *d* colors, starting with the empty tree. :param d: Number of colors (path dimension). :type d: int :param max_order: Maximum number of nodes. :type max_order: int :return: List of colored trees. :rtype: list[Tree] """ _validate_num_colors(d) return list(_colored_tree_list_cached(max_order, d))
[docs] def colored_tree_to_idx(tree: Tree, d: int, max_order: int) -> int: """ Returns the index of a colored tree in the canonical enumeration. Index 0 is the empty tree. Non-empty trees are enumerated by order, then by shape, then by coloring. :param tree: A colored rooted tree. :type tree: Tree :param d: Number of colors (path dimension). :type d: int :param max_order: Maximum number of nodes. :type max_order: int :return: Index in the enumeration. :rtype: int """ _validate_num_colors(d) lookup = _colored_tree_lookup_cached(max_order, d) if tree not in lookup: raise ValueError(f"Tree {tree} not found in enumeration for d={d}, max_order={max_order}") return lookup[tree]
[docs] def idx_to_colored_tree(idx: int, d: int, max_order: int) -> Tree: """ Returns the colored tree at a given index in the canonical enumeration. :param idx: Index (0 = empty tree). :type idx: int :param d: Number of colors (path dimension). :type d: int :param max_order: Maximum number of nodes. :type max_order: int :return: The colored tree at the given index. :rtype: Tree """ _validate_num_colors(d) trees = _colored_tree_list_cached(max_order, d) if idx < 0 or idx >= len(trees): raise ValueError(f"idx {idx} out of range [0, {len(trees)}) for d={d}, max_order={max_order}") return trees[idx]
[docs] def colored_planar_tree_to_idx(tree, d: int, max_order: int) -> int: """ Returns the index of a colored planar tree in the canonical enumeration. Planar analogue of :func:`colored_tree_to_idx`. Index 0 is the empty tree; non-empty trees are enumerated in the order emitted by :func:`colored_planar_trees_up_to_order`. :param tree: A colored planar rooted tree. :type tree: PlanarTree :param d: Number of colors (path dimension). :type d: int :param max_order: Maximum number of nodes. :type max_order: int :return: Index in the enumeration. :rtype: int """ _validate_num_colors(d) lookup = _colored_planar_tree_lookup_cached(max_order, d) if tree not in lookup: raise ValueError(f"Planar tree {tree} not found in enumeration for d={d}, max_order={max_order}") return lookup[tree]
[docs] def idx_to_colored_planar_tree(idx: int, d: int, max_order: int): """ Returns the colored planar tree at a given index in the canonical enumeration. Planar analogue of :func:`idx_to_colored_tree`. Index 0 is the empty planar tree. :param idx: Index (0 = empty planar tree). :type idx: int :param d: Number of colors (path dimension). :type d: int :param max_order: Maximum number of nodes. :type max_order: int :return: The colored planar tree at the given index. :rtype: PlanarTree """ _validate_num_colors(d) trees = _colored_planar_tree_list_cached(max_order, d) if idx < 0 or idx >= len(trees): raise ValueError(f"idx {idx} out of range [0, {len(trees)}) for d={d}, max_order={max_order}") return trees[idx]
[docs] def colored_ordered_forests(d: int, max_order: int) -> list: """ Returns all colored ordered forests up to a given total order with *d* colors, starting with the empty ordered forest. :param d: Number of colors. :type d: int :param max_order: Maximum total number of nodes. :type max_order: int :return: List of colored ordered forests. :rtype: list[OrderedForest] """ _validate_num_colors(d) return list(_colored_ordered_forest_list_cached(max_order, d))
[docs] def colored_ordered_forest_to_idx(forest, d: int, max_order: int) -> int: """ Returns the index of a colored ordered forest in the canonical enumeration. :param forest: A colored ordered forest. :type forest: OrderedForest :param d: Number of colors. :type d: int :param max_order: Maximum total number of nodes. :type max_order: int :return: Index in the enumeration. :rtype: int """ _validate_num_colors(d) lookup = _colored_ordered_forest_lookup_cached(max_order, d) if forest not in lookup: raise ValueError(f"Ordered forest {forest} not found in enumeration for d={d}, max_order={max_order}") return lookup[forest]
[docs] def idx_to_colored_ordered_forest(idx: int, d: int, max_order: int): """ Returns the colored ordered forest at a given index in the canonical enumeration. :param idx: Index, with 0 the empty ordered forest. :type idx: int :param d: Number of colors. :type d: int :param max_order: Maximum total number of nodes. :type max_order: int :return: The colored ordered forest at the given index. :rtype: OrderedForest """ _validate_num_colors(d) forests = _colored_ordered_forest_list_cached(max_order, d) if idx < 0 or idx >= len(forests): raise ValueError(f"idx {idx} out of range [0, {len(forests)}) for d={d}, max_order={max_order}") return forests[idx]
# --------------------------------------------------------------------------- # Recursive tree ordering and canonical-recursive permutation # # The "recursive" ordering enumerates decorated trees bottom-up, cycling root # labels innermost. Non-planar uses child multisets; planar uses ordered child # sequences. Both orderings match the C++ enumeration in pySigLib's # siglib/shared/branched_trees.h. # # The "canonical" ordering (used by colored_trees_of_order etc.) enumerates # by shape first, then colorings. # --------------------------------------------------------------------------- def _enumerate_child_indices(target_nodes, min_idx, tree_nodes, total_count, *, ordered): """Enumerate tuples of child tree indices whose node counts sum to target_nodes. When ``ordered=False`` the tuples are non-decreasing in index (multiset semantics); when ``ordered=True`` any left-to-right ordering is allowed. """ if target_nodes == 0: yield () return start = 0 if ordered else min_idx for idx in range(start, total_count): n = tree_nodes[idx] if n > target_nodes: break for rest in _enumerate_child_indices( target_nodes - n, idx, tree_nodes, total_count, ordered=ordered, ): yield (idx,) + rest @cache def _enumerate_trees_recursive(d: int, max_order: int, planar: bool = False) -> tuple: """Enumerate decorated trees in recursive ordering (child-group first, root label innermost).""" trees = [] tree_nodes = [] for order in range(1, max_order + 1): if order == 1: for label in range(d): trees.append((1, label, ())) tree_nodes.append(1) else: current_count = len(trees) for children in _enumerate_child_indices( order - 1, 0, tree_nodes, current_count, ordered=planar, ): if not children: continue for label in range(d): trees.append((order, label, children)) tree_nodes.append(order) return tuple(trees) def _all_recursive_kauri_trees(d: int, max_order: int, planar: bool): """Build every recursive-order tree as a kauri Tree/PlanarTree in one bottom-up pass. Because each tree's ``child_ids`` reference strictly smaller indices, we can build the parent directly from cached children, avoiding the O(n^2) blow-up of the naive per-tree recursion. """ from .trees import Tree, Forest, PlanarTree, NoncommutativeForest rec = _enumerate_trees_recursive(d, max_order, planar) out = [None] * len(rec) for i, (_n, label, child_ids) in enumerate(rec): if not child_ids: out[i] = PlanarTree([label]) if planar else Tree([label]) elif planar: out[i] = NoncommutativeForest(tuple(out[c] for c in child_ids)).join(root_color=label) else: out[i] = Forest(tuple(out[c] for c in child_ids)).join(root_color=label) return out def _canonical_to_recursive_permutation_impl(d: int, max_order: int, planar: bool): try: import numpy as np except ImportError: raise ImportError("Permutation functions require numpy. Install with: pip install kauri[full]") _validate_num_colors(d) rec_kauri = _all_recursive_kauri_trees(d, max_order, planar) rec_lookup = {t: i for i, t in enumerate(rec_kauri)} canonical = (_colored_planar_tree_list_cached if planar else _colored_tree_list_cached)(max_order, d) perm = [rec_lookup[kt] for kt in canonical[1:]] return np.array(perm, dtype=np.int64) def _recursive_to_canonical_permutation_impl(d: int, max_order: int, planar: bool): try: import numpy as np except ImportError: raise ImportError("Permutation functions require numpy. Install with: pip install kauri[full]") perm = (planar_canonical_to_recursive_permutation if planar else canonical_to_recursive_permutation)(d, max_order) inv = np.empty_like(perm) inv[perm] = np.arange(len(perm)) return inv
[docs] @cache def canonical_to_recursive_permutation(d: int, max_order: int): """ Compute the permutation mapping canonical tree indices to recursive tree indices. ``perm[i] = j`` means the tree at canonical position ``i`` is at recursive position ``j``. Both are 0-indexed and exclude the empty tree. :param d: Number of colors (path dimension). :type d: int :param max_order: Maximum number of nodes. :type max_order: int :return: Permutation array of shape ``(num_trees,)``. :rtype: numpy.ndarray """ return _canonical_to_recursive_permutation_impl(d, max_order, planar=False)
[docs] @cache def recursive_to_canonical_permutation(d: int, max_order: int): """ Compute the permutation mapping recursive tree indices to canonical tree indices. Inverse of :func:`canonical_to_recursive_permutation`. :param d: Number of colors (path dimension). :type d: int :param max_order: Maximum number of nodes. :type max_order: int :return: Inverse permutation array of shape ``(num_trees,)``. :rtype: numpy.ndarray """ return _recursive_to_canonical_permutation_impl(d, max_order, planar=False)
[docs] @cache def planar_canonical_to_recursive_permutation(d: int, max_order: int): """ Compute the permutation mapping canonical planar-tree indices to recursive planar-tree indices. ``perm[i] = j`` means the planar tree at canonical position ``i`` is at recursive position ``j``. Both are 0-indexed and exclude the empty tree. The canonical ordering is the one emitted by :func:`colored_planar_trees_up_to_order`; the recursive ordering matches the bottom-up planar enumeration used internally by pySigLib. :param d: Number of colors (path dimension). :type d: int :param max_order: Maximum number of nodes. :type max_order: int :return: Permutation array of shape ``(num_trees,)``. :rtype: numpy.ndarray """ return _canonical_to_recursive_permutation_impl(d, max_order, planar=True)
[docs] @cache def planar_recursive_to_canonical_permutation(d: int, max_order: int): """ Compute the permutation mapping recursive planar-tree indices to canonical planar-tree indices. Inverse of :func:`planar_canonical_to_recursive_permutation`. :param d: Number of colors (path dimension). :type d: int :param max_order: Maximum number of nodes. :type max_order: int :return: Inverse permutation array of shape ``(num_trees,)``. :rtype: numpy.ndarray """ return _recursive_to_canonical_permutation_impl(d, max_order, planar=True)