Source code for kauri.trees

# Copyright 2025 Daniil Shmelev
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

"""
Tree, Forest, ForestSum and TensorProductSum classes.

The classes `Tree`, `Forest` and `ForestSum` are immutable and hashable.
The hash is generated in such a way that two elements of the same class which are equivalent
(e.g. two different orderings of the same tree) will have the same hash.
However, this is not the case across classes. For example, for a Tree t, `hash(t)`,
`hash(t.as_forest())` and `hash(t.as_forest_sum())` are different.

The class `Tree` is totally ordered by the lexicographic ordering. If the trees have
the same structure but are colored differently, they are ordered based on color, with
the color of the highest levels of the trees being the primary ordering.

"""

import math
import numbers
from collections import Counter
from collections.abc import Iterator
from dataclasses import dataclass
from functools import total_ordering
from typing import Union
import warnings


from .utils import (_nodes, _height, _factorial, _sigma,
                    _sorted_list_repr, _list_repr_to_level_sequence,
                    _to_list, _next_layout, _next_planar_layout, _level_sequence_to_list_repr,
                    _check_valid, _to_labelled_tuple, _get_max_color, _to_unlabelled_tuple,
                    _list_repr_to_color_sequence, LabelledReprComparison)
from ._protocols import ForestLike, TreeLike


def _frozen_copy(self):
    return self


def _frozen_deepcopy(self, memodict=None):
    if memodict is None:
        memodict = {}
    memodict[id(self)] = self
    return self


def _lazy_count(self, items_attr):
    if self.count is None:
        object.__setattr__(self, 'count', Counter(getattr(self.simplify(), items_attr)))


def _lazy_hash(self, items_attr):
    _lazy_count(self, items_attr)
    if self.hash_ is None:
        object.__setattr__(self, 'hash_', hash(frozenset(self.count.items())))
    return self.hash_

######################################
[docs] @dataclass(frozen=True) @total_ordering class Tree: """ A single non-planar (un)labelled rooted tree, initialised by its list representation. For example, the unlabelled cherry tree has the list representation [[],[]]. Noting that every list corresponds to a node, we can apply a labelling/coloring by setting the last element of the list to be a non-negative integer. For example, [[2], [1], 0] corresponds to the cherry tree, with the root node labelled by 0, the left leaf labelled by 2 and the right leaf labelled by 1. If a label is left out, it will default to 0. :param list_repr: The nested list representation of the tree .. kauri-exec:: t1 = kr.Tree([[[]],[]]) # An unlabelled tree t2 = kr.Tree([[[3],1],[2],0]) # A labelled tree t3 = kr.Tree([[[3],1],[2]]) # This is the same as t2, since the missing label defaults to 0 kr.display(t1, t2, t3) """ ###################################### list_repr: Union[tuple, list, None] = None unlabelled_repr = None _max_color = 0 def __post_init__(self): if self.list_repr is not None: if not _check_valid(self.list_repr): raise ValueError(repr(self.list_repr) + " is not a valid list representation for a tree.") tuple_repr = _to_labelled_tuple(self.list_repr) object.__setattr__(self, 'list_repr', tuple_repr) unlabelled_repr = _to_unlabelled_tuple(tuple_repr) object.__setattr__(self, 'unlabelled_repr', unlabelled_repr) object.__setattr__(self, '_max_color', _get_max_color(tuple_repr)) __copy__ = _frozen_copy __deepcopy__ = _frozen_deepcopy def __repr__(self): if self.list_repr is None: return "\u2205" if self._max_color == 0: return repr(_to_list(self.unlabelled_repr)) return repr(_to_list(self.list_repr)) def _repr_svg_(self): from .display import _to_svg return _to_svg(self) def __hash__(self): return hash(self.sorted_list_repr())
[docs] def unjoin(self) -> 'Forest': """ For a tree :math:`t = [t_1, t_2, ..., t_k]`, returns the forest :math:`t_1 t_2 \\cdots t_k`. In :cite:`connes1999hopf`, this map is denoted by :math:`B_-`. :return: :math:`t_1 t_2 \\cdots t_k` :rtype: CommutativeForest **Example usage:** .. kauri-exec:: t = Tree([[[]],[]]) kr.display(t, '\u2192', t.unjoin()) """ if self.list_repr is None: return EMPTY_FOREST return Forest(tuple(Tree(rep) for rep in self.list_repr[:-1]))
[docs] def nodes(self) -> int: """ Returns the number of nodes in a tree, :math:`|t|` :return: Number of nodes, :math:`|t|` :rtype: int **Example usage:** .. kauri-exec:: t = Tree([[[]],[]]) print(t.nodes()) """ return _nodes(self.unlabelled_repr)
[docs] def colors(self) -> int: """ Returns the number of colors/labels in a labelled tree. Since the labels are indexed starting from 0, this is equivalent to one more than the maximum label. :return: Number of colors :rtype: int **Example usage:** .. kauri-exec:: print(Tree([]).colors()) print(Tree([0]).colors()) print(Tree([[9],1]).colors()) """ if self.list_repr is None: return 0 return self._max_color + 1
[docs] def height(self) -> int: """ Returns the height of a tree, given by the number of nodes in the longest walk from the root to a leaf. :return: Height :rtype: int **Example usage:** .. kauri-exec:: t = Tree([[[]],[]]) print(t.height()) """ return _height(self.unlabelled_repr)
[docs] def factorial(self) -> int: """ Compute the tree factorial, :math:`t!` :return: Tree factorial, :math:`t!` :rtype: int **Example usage:** .. kauri-exec:: t = Tree([[[]],[]]) print(t.factorial()) """ return _factorial(self.unlabelled_repr)[0]
[docs] def sigma(self) -> int: """ Computes the symmetry factor :math:`\\sigma(t)`, the order of the symmetric group of the tree. For a tree :math:`t = [t_1^{m_1} t_2^{m_2} \\cdots t_k^{m_k}]`, the symmetry factor satisfies the recursion .. math:: \\sigma(t) = \\prod_{i=1}^k m_i! \\sigma(t_i)^{m_i}. :return: Symmetry factor, :math:`\\sigma(t)` :rtype: int **Example usage:** .. kauri-exec:: t = Tree([[[]],[]]) print(t.sigma()) """ return _sigma(self.unlabelled_repr)
[docs] def alpha(self) -> int: """ For a tree :math:`t` with :math:`n` nodes, computes the number of distinct ways of labelling the nodes of the tree with symbols :math:`\\{1, 2, \\ldots, n\\}`, such that: - Each vertex receives one and only one label, - Labellings that are equivalent under the symmetry group are counted only once, - If :math:`(i,j)` is a labelled edge, then :math:`i<j`. This number is typically denoted by :math:`\\alpha(t)` and given by .. math:: \\alpha(t) = \\frac{n!}{t! \\sigma(t)} :return: :math:`\\alpha(t)` :rtype: int **Example usage:** .. kauri-exec:: t = Tree([[[]],[]]) print(t.alpha()) """ return self.beta() // self.factorial()
[docs] def beta(self) -> int: """ For a tree :math:`t` with :math:`n` nodes, computes the number of distinct ways of labelling the nodes of the tree with symbols :math:`\\{1, 2, \\ldots, n\\}`, such that: - Each vertex receives one and only one label, - Labellings that are equivalent under the symmetry group are counted only once. This number is typically denoted by :math:`\\beta(t)` and given by .. math:: \\beta(t) = \\frac{n!}{\\sigma(t)} :return: :math:`\\beta(t)` :rtype: int **Example usage:** .. kauri-exec:: t = Tree([[[]],[]]) print(t.beta()) """ return math.factorial(self.nodes()) // self.sigma()
[docs] def density(self) -> float: """ Density of the tree, :math:`t! / |t|!`. :return: Density, :math:`t! / |t|!` :rtype: float **Example usage:** .. kauri-exec:: t = Tree([[[]],[]]) print(t.density()) """ return self.factorial() / math.factorial(self.nodes())
[docs] def sign(self) -> 'ForestSum': """ Returns the tree signed by the number of nodes, :math:`(-1)^{|t|} t`. :return: Signed tree, :math:`(-1)^{|t|} t` :rtype: ForestSum """ return self.as_forest_sum() if self.nodes() % 2 == 0 else -self
[docs] def __mul__(self, other : Union[int, float, 'Tree', 'Forest', 'ForestSum']) -> Union['Forest', 'ForestSum']: """ Multiplies a tree by a: - scalar, returning a ForestSum - Tree, returning a Forest, - Forest, returning a Forest, - ForestSum, returning a ForestSum. :param other: A scalar, Tree, Forest or ForestSum **Example usage:** .. kauri-exec:: t = 2 * Tree([[]]) * CommutativeForest([Tree([]), Tree([[],[]])]) kr.display(t) """ if _is_scalar(other): out = ForestSum(( (other,self), )) elif isinstance(other, Tree): out = Forest((self, other)) elif isinstance(other, Forest): out = Forest((self,) + other.tree_list) elif isinstance(other, ForestSum): out = ForestSum(tuple((c, self * f) for c,f in other.term_list)) else: _check_compatible(self, other) raise TypeError("Cannot multiply Tree by object of type " + str(type(other))) return out.simplify()
__rmul__ = __mul__
[docs] def __pow__(self, n : int) -> 'Forest': """ Returns the :math:`n^{th}` power of a tree for a positive integer :math:`n`, given by a forest with :math:`n` copies of the tree. :param n: Exponent, a positive integer **Example usage:** .. kauri-exec:: t = Tree([[]]) ** 3 kr.display(t) """ if not isinstance(n, int): raise TypeError("Exponent in Tree.__pow__ must be an int, not " + str(type(n))) if n < 0: raise ValueError("Cannot raise Tree to a negative power") if n == 0: return EMPTY_FOREST out = Forest((self,) * n) return out.simplify()
[docs] def __add__(self, other : Union[int, float, 'Tree', 'Forest', 'ForestSum']) -> 'ForestSum': """ Adds a tree to a scalar, Tree, Forest or ForestSum. :param other: A scalar, Tree, Forest or ForestSum :rtype: ForestSum **Example usage:** .. kauri-exec:: t = 2 + Tree([[]]) + CommutativeForest([Tree([]), Tree([[],[]])]) kr.display(t) """ if _is_scalar(other): out = ForestSum(( (1, self), (other, EMPTY_FOREST) )) elif isinstance(other, (Tree, Forest)): out = ForestSum(( (1, self), (1, other) )) elif isinstance(other, ForestSum): _check_compatible(self, other) out = ForestSum( ((1, self),) + other.term_list ) else: _check_compatible(self, other) raise TypeError("Cannot add Tree and " + str(type(other))) return out.simplify()
def __sub__(self, other): return self + (-other) def __rsub__(self, other): return (-self) + other __radd__ = __add__ def __neg__(self): return self * (-1)
[docs] def __eq__(self, other : Union['Tree', 'Forest', 'ForestSum']) -> bool: """ Compares the tree with another object and returns true if they represent the same tree, regardless of class type (Tree, Forest or ForestSum) or possible reorderings of the same tree. :param other: Tree, Forest or ForestSum :rtype: bool **Example usage:** .. kauri-exec:: print(Tree([[],[]]) == Tree([[],[]]).as_forest()) print(Tree([[],[]]) == Tree([[],[]]).as_forest_sum()) print(Tree([[[]],[]]) == Tree([[],[[]]])) """ if _is_scalar(other): return self.as_forest_sum() == other * EMPTY_TREE if isinstance(other, Tree): return self.equals(other) if isinstance(other, Forest): return self.as_forest() == other if isinstance(other, ForestSum): return self.as_forest_sum() == other return NotImplemented
def __lt__(self, other): if not isinstance(other, Tree): return NotImplemented # Deal with empty trees if self.list_repr is None: if other.list_repr is None: return False return True if other.list_repr is None: return False # If trees are non-empty if self.nodes() != other.nodes(): return self.nodes() < other.nodes() return LabelledReprComparison(self.sorted_list_repr()) < LabelledReprComparison(other.sorted_list_repr())
[docs] def sorted_list_repr(self) -> list: """ Returns the list representation of the sorted tree, where the heaviest branches are rotated to the left. :return: Sorted list representation :rtype: list **Example usage:** .. kauri-exec:: t = Tree([[],[[]]]) print(t.sorted_list_repr()) """ return _sorted_list_repr(self.list_repr)
[docs] def level_sequence(self) -> list: """ Returns the level sequence of the tree, defined as the list :math:`{\\ell_1, \\ell_2, \\cdots, \\ell_n}`, where :math:`\\ell_i` is the level of the :math:`i^{th}` node when the nodes are ordered lexicographically. :return: Level sequence :rtype: list **Example usage:** .. kauri-exec:: t = Tree([[[]],[]]) print(t.level_sequence()) """ return _list_repr_to_level_sequence(self.unlabelled_repr)
[docs] def sorted(self) -> 'Tree': """ Returns the sorted tree, where the heaviest branches are rotated to the left. :return: Sorted tree :rtype: Tree **Example usage:** .. kauri-exec:: t = Tree([[],[[]]]) kr.display(t, '\u2192', t.sorted()) """ return Tree(self.sorted_list_repr())
[docs] def equals(self, other_tree): """Two trees are equal iff their sorted (canonical) list representations match.""" return self.sorted_list_repr() == other_tree.sorted_list_repr()
[docs] def as_forest(self) -> 'Forest': """ Returns the tree t as a forest. Equivalent to ``CommutativeForest([t])``. :return: Tree as a forest :rtype: CommutativeForest **Example usage:** .. code-block:: python >>> Tree([[],[[]]]).as_forest() """ return Forest((self,))
[docs] def as_forest_sum(self) -> 'ForestSum': """ Returns the tree t as a forest sum. Equivalent to ``ForestSum([CommutativeForest([t])])``. :return: Tree as a forest sum :rtype: ForestSum **Example usage:** .. code-block:: python >>> Tree([[],[[]]]).as_forest_sum() """ return ForestSum(( (1, self), ))
[docs] def __next__(self) -> 'Tree': """ Generates the next tree with respect to the lexicographic order. If the tree is labelled, the labelling will be ignored. :return: Next tree :rtype: Tree **Example usage:** .. kauri-exec:: t = Tree([[],[]]) kr.display(t, '\u2192', next(t)) """ if self._max_color > 0: warnings.warn("Calling next() on a labelled tree will ignore the labelling.") if self.list_repr is None: return Tree([]) layout = self.level_sequence() next_ = _next_layout(layout) return Tree(_level_sequence_to_list_repr(next_))
[docs] def __matmul__(self, other : Union[int, float, 'Tree', 'Forest', 'ForestSum']) -> 'TensorProductSum': """ Returns the tensor product of a Tree and a scalar, Tree, Forest or ForestSum. :param other: Other :type other: int | float | Tree | Forest | ForestSum :return: Tensor product :rtype: TensorProductSum **Example usage:** .. kauri-exec:: t = Tree([]) @ (Tree([[]]) + Tree([]) * Tree([[],[]])) kr.display(t) """ if _is_scalar(other): return TensorProductSum(( (other, self.as_forest(), EMPTY_FOREST), )) if isinstance(other, (Tree, Forest)): return TensorProductSum(( (1, self.as_forest(), other.as_forest()), )) if isinstance(other, ForestSum): term_list = [] for c, f in other: term_list.append((c, self, f)) return TensorProductSum(term_list) raise TypeError("Cannot take tensor product of Tree and " + str(type(other)))
[docs] def unlabelled(self): """ Returns the unlabelled version of the tree. **Example usage:** .. kauri-exec:: t = Tree([[[3],1],[2],0]) kr.display(t, '\u2192', t.unlabelled()) """ return Tree(self.unlabelled_repr)
[docs] def color_sequence(self): """Returns the color (label) sequence of the tree in pre-order.""" return _list_repr_to_color_sequence(self.list_repr)
######################################
[docs] @dataclass(frozen=True) class CommutativeForest: """ A commutative product of trees. :param tree_list: A list of trees contained in the forest **Example usage:** .. kauri-exec:: t1 = Tree([]) t2 = Tree([[]]) t3 = Tree([[[]],[]]) f = CommutativeForest([t1,t2,t3]) kr.display(f) """ ###################################### tree_list : Union[tuple, list] = tuple() count : Counter = None hash_ : int = None def __post_init__(self): tuple_repr = tuple(self.tree_list) if tuple_repr == tuple(): tuple_repr = (Tree(None),) object.__setattr__(self, 'tree_list', tuple_repr) __copy__ = _frozen_copy __deepcopy__ = _frozen_deepcopy def __hash__(self): return _lazy_hash(self, 'tree_list')
[docs] def simplify(self) -> 'Forest': # Remove redundant empty trees """ Simplify the forest by removing redundant empty trees. :return: self :rtype: CommutativeForest **Example usage:** .. kauri-exec:: f1 = Tree([[],[[]]]) * Tree(None) f2 = f1.simplify() # Tree([[],[[]]]) """ if len(self.tree_list) <= 1: return self filtered = tuple(t for t in self.tree_list if t.list_repr is not None) if not filtered: return EMPTY_FOREST if len(filtered) == len(self.tree_list): return self return Forest(filtered)
def __repr__(self): if len(self.tree_list) == 0: return "\u2205" r = "" for t in self.tree_list[:-1]: r += repr(t) + " " r += repr(self.tree_list[-1]) + "" return r def _repr_svg_(self): from .display import _to_svg return _to_svg(self) def __iter__(self): yield from self.tree_list
[docs] def join(self, root_color : int = 0) -> 'Tree': """ For a forest :math:`t_1 t_2 \\cdots t_k`, returns the tree :math:`[t_1, t_2, \\cdots, t_k]`. In :cite:`connes1999hopf`, this map is denoted by :math:`B_+`. :param root_color: Color to assign to the root (default 0) :type root_color: int :return: :math:`[t_1, t_2, \\cdots, t_k]` :rtype: Tree **Example usage:** .. kauri-exec:: f = Tree([]) * Tree([[]]) kr.display(f, '\u2192', f.join()) """ if not isinstance(root_color, int): raise TypeError("root_color must be int, not " + str(type(root_color))) out = [t.list_repr for t in self.tree_list] + [root_color] out = tuple(filter(lambda x: x is not None, out)) return Tree(out)
[docs] def nodes(self) -> int: """ For a forest :math:`t_1 t_2 \\cdots t_k`, returns the number of nodes in the forest, :math:`\\sum_{i=1}^k |t_i|`. :return: Number of nodes :rtype: int **Example usage:** .. kauri-exec:: f = Tree([]) * Tree([[]]) print(f.nodes()) """ return sum(t.nodes() for t in self.tree_list)
[docs] def colors(self) -> int: """ Returns the number of colors/labels in the forest. Since the labels are indexed starting from 0, this is equivalent to one more than the maximum label. :return: Number of colors :rtype: int **Example usage:** .. kauri-exec:: print((Tree([[9],0]) * Tree([3])).colors()) """ return max((t.colors() for t in self.tree_list), default=0)
[docs] def num_trees(self) -> int: """ For a forest :math:`t_1 t_2 \\cdots t_k`, returns the number of trees in the forest, :math:`k`. :return: Number of trees :rtype: int **Example usage:** .. kauri-exec:: f = Tree([]) * Tree([[]]) print(f.num_trees()) """ return len(self.tree_list)
[docs] def factorial(self) -> int: """ Apply the tree factorial to the forest as a multiplicative map. For a forest :math:`t_1 t_2 \\cdots t_k`, returns :math:`\\prod_{i=1}^k t_i!`. :return: :math:`\\prod_{i=1}^k t_i!` :rtype: int **Example usage:** .. kauri-exec:: f = Tree([[]]) * Tree([[],[]]) print(f.factorial()) """ return math.prod(x.factorial() for x in self.tree_list)
[docs] def sign(self) -> 'ForestSum': """ Returns the forest signed by the number of nodes, :math:`(-1)^{|f|} f`. :return: Signed forest, :math:`(-1)^{|f|} f` :rtype: ForestSum **Example usage:** .. kauri-exec:: f1 = Tree([[]]) * Tree([[],[]]) kr.display(f1.sign()) f2 = Tree([]) * Tree([[],[]]) kr.display(f2.sign()) """ return self.as_forest_sum() if self.nodes() % 2 == 0 else -self
[docs] def __mul__(self, other : Union[int, float, 'Tree', 'Forest', 'ForestSum']) -> Union['Forest', 'ForestSum']: """ Multiplies a forest by a: - scalar, returning a ForestSum - Tree, returning a Forest, - Forest, returning a Forest, - ForestSum, returning a ForestSum. :param other: A scalar, Tree, Forest or ForestSum **Example usage:** .. kauri-exec:: t = 2 * Tree([[]]) * CommutativeForest([Tree([]), Tree([[],[]])]) kr.display(t) """ if _is_scalar(other): out = ForestSum(( (other, self), )) elif isinstance(other, Tree): out = Forest(self.tree_list + (other,)) elif isinstance(other, Forest): out = Forest(self.tree_list + other.tree_list) elif isinstance(other, ForestSum): out = ForestSum(tuple( (c, self * f) for c, f in other.term_list )) else: _check_compatible(self, other) raise TypeError("Cannot multiply Forest and " + str(type(other))) return out.simplify()
__rmul__ = __mul__
[docs] def __pow__(self, n : int) -> 'Forest': """ Returns the :math:`n^{th}` power of a forest for a positive integer :math:`n`, given by a forest with :math:`n` copies of the original forest. :param n: Exponent, a positive integer **Example usage:** .. kauri-exec:: t = ( Tree([]) * Tree([[]]) ) ** 3 kr.display(t) """ if not isinstance(n, int): raise TypeError("Exponent in Forest.__pow__ must be an int, not " + str(type(n))) if n < 0: raise ValueError("Cannot raise Forest to a negative power") if n == 0: return EMPTY_FOREST out = Forest(self.tree_list * n) return out.simplify()
[docs] def __add__(self, other : Union[int, float, 'Tree', 'Forest', 'ForestSum']) -> 'ForestSum': """ Adds a forest to a scalar, Tree, Forest or ForestSum. :param other: A scalar, Tree, Forest or ForestSum :rtype: ForestSum **Example usage:** .. kauri-exec:: t = 2 + Tree([[]]) + CommutativeForest([Tree([]), Tree([[],[]])]) kr.display(t) """ if _is_scalar(other): out = ForestSum(( (1, self), (other, EMPTY_FOREST) )) elif isinstance(other, (Tree, Forest)): out = ForestSum(( (1, self), (1, other) )) elif isinstance(other, ForestSum): _check_compatible(self, other) out = ForestSum( ((1, self),) + other.term_list ) else: _check_compatible(self, other) raise TypeError("Cannot add Forest and " + str(type(other))) return out.simplify()
def __sub__(self, other): return self + (-other) def __rsub__(self, other): return (-self) + other __radd__ = __add__ def __neg__(self): return self * (-1)
[docs] def equals(self, other_forest): """Two forests are equal iff they contain the same trees with the same multiplicities.""" _lazy_count(self, 'tree_list') _lazy_count(other_forest, 'tree_list') return self.count == other_forest.count
[docs] def __eq__(self, other : Union['Forest', 'ForestSum']) -> bool: """ Compares the forest with another object and returns true if they represent the same forest, regardless of class type (Forest or ForestSum) or possible reorderings of trees. :param other: Forest or ForestSum :rtype: bool **Example usage:** .. kauri-exec:: t1 = Tree([]) t2 = Tree([[]]) t3 = Tree([[[]],[]]) t4 = Tree([[],[[]]]) print(t1 * t2 == t2 * t1) print(t1 * t2 == (t1 * t2).as_forest_sum()) print(t1 * t3 == t1 * t4) """ if _is_scalar(other): return self.as_forest_sum() == other * EMPTY_TREE if isinstance(other, Tree): return self.equals(other.as_forest()) if isinstance(other, CommutativeForest): return self.equals(other) if isinstance(other, ForestSum): return self.as_forest_sum() == other return NotImplemented
def as_forest(self): return self
[docs] def as_forest_sum(self) -> 'ForestSum': """ Returns the forest f as a forest sum. Equivalent to ``ForestSum([f])``. :return: CommutativeForest as a forest sum :rtype: ForestSum **Example usage:** .. code-block:: python >>> (Tree([[],[[]]]) * Tree([[]])).as_forest_sum() """ return ForestSum(( (1,self), ))
[docs] def singleton_reduced(self) -> 'Forest': """ Removes redundant occurrences of the single-node tree in the forest. If the forest contains a tree with more than one node, removes all occurences of the single-node tree. Otherwise, returns the single-node tree. :return: Singleton-reduced forest **Example usage:** .. kauri-exec:: f1 = Tree([]) * Tree([[],[]]) f2 = Tree([]) * Tree([]) * Tree([]) kr.display(f1, '\u2192', f1.singleton_reduced()) kr.display(f2, '\u2192', f2.singleton_reduced()) """ if self.colors() > 1: warnings.warn("Singleton reduced representation will not respect colorings") out = self.simplify() if len(out.tree_list) > 1: new_tree_list = tuple(filter(lambda x: len(x.list_repr) != 1, out.tree_list)) if len(new_tree_list) == 0: new_tree_list = (Tree([]),) out = Forest(new_tree_list) return out
[docs] def __matmul__(self, other: Union[int, float, 'Tree', 'Forest', 'ForestSum']) -> 'TensorProductSum': """ Returns the tensor product of a Forest and a scalar, Tree, Forest or ForestSum. :param other: Other :type other: int | float | Tree | Forest | ForestSum :return: Tensor product :rtype: TensorProductSum **Example usage:** .. kauri-exec:: tp = Tree([]) @ (Tree([[]]) + Tree([]) * Tree([[],[]])) kr.display(tp) """ if _is_scalar(other): return TensorProductSum(( (other, self, EMPTY_FOREST), )) if isinstance(other, (Tree, Forest)): return TensorProductSum(( (1, self, other.as_forest()), )) if isinstance(other, ForestSum): term_list = [] for c, f in other: term_list.append((c, self, f)) return TensorProductSum(term_list) raise TypeError("Cannot take tensor product of Forest and " + str(type(other)))
def __getitem__(self, i): return self.tree_list[i]
Forest = CommutativeForest ######################################
[docs] @dataclass(frozen=True) class ForestSum: """ A linear combination of forests. Works with both non-planar (:class:`CommutativeForest`) and planar (:class:`NoncommutativeForest`) forests, but all terms in a single ForestSum must use the same forest type. ``OrderedForestSum`` is an alias for ``ForestSum``. :param term_list: A list or tuple containing tuples of coefficients and forests representing terms of the sum. If a term contains a tree, it will be converted to a forest on initialisation. **Example usage:** .. kauri-exec:: t1 = Tree([]) t2 = Tree([[]]) t3 = Tree([[[]],[]]) s = ForestSum([(1, t1), (-2, t1*t2), (1, t2*t3)]) kr.display(s) ForestSum also works with planar trees: .. kauri-exec:: p1 = PlanarTree([]) p2 = PlanarTree([[]]) s = ForestSum([(1, p1), (-2, p1*p2)]) kr.display(s) """ ###################################### term_list : Union[tuple, list] = tuple() count : Counter = None hash_ : int = None def __post_init__(self): new_term_list = [] for term in self.term_list: if not _is_scalar(term[0]): raise TypeError("ForestSum coefficients must be scalars, got " + str(type(term[0]))) if isinstance(term[1], ForestLike): new_term_list.append(term) elif isinstance(term[1], Tree): new_term_list.append((term[0], term[1].as_forest())) elif isinstance(term[1], TreeLike): if hasattr(term[1], 'as_ordered_forest'): new_term_list.append((term[0], term[1].as_ordered_forest())) else: new_term_list.append((term[0], CommutativeForest((term[1],)))) else: raise TypeError("Terms must be tuples of (coefficient, ForestLike | TreeLike)") new_term_list = tuple(new_term_list) # Reject mixed planar/non-planar forest types if len(new_term_list) > 1: first_type = type(new_term_list[0][1]) for _, f in new_term_list[1:]: if type(f) is not first_type: raise TypeError( f"ForestSum cannot mix forest types: got {first_type.__name__} " f"and {type(f).__name__}") object.__setattr__(self, 'term_list', new_term_list) __copy__ = _frozen_copy __deepcopy__ = _frozen_deepcopy def __hash__(self): return _lazy_hash(self, 'term_list') def __repr__(self): if len(self.term_list) == 0: return "0" parts = [str(c) + " * " + repr(f) for c, f in self.term_list] return " + ".join(parts) if parts else "0" def _repr_svg_(self): from .display import _to_svg return _to_svg(self) def __iter__(self): for c,f in self.term_list: yield c,f
[docs] def nodes(self) -> int: """ For a forest sum :math:`\\sum_{i=1}^m c_i \\prod_{j=1}^{k_i} t_{ij}`, returns the total number of nodes in the forest sum, :math:`\\sum_{i=1}^m \\sum_{j=1}^{k_i} |t_{ij}|`. :return: Number of nodes :rtype: int **Example usage:** .. kauri-exec:: f = Tree([]) * Tree([[]]) + 2 * Tree([[],[]]) print(f.nodes()) """ return sum(f.nodes() for c, f in self.term_list)
[docs] def colors(self) -> int: """ Returns the number of colors/labels in the forest sum. Since the labels are indexed starting from 0, this is equivalent to one more than the maximum label. :return: Number of colors :rtype: int **Example usage:** .. kauri-exec:: print((Tree([[9],0]) * Tree([3]) + Tree([2])).colors()) """ return max((f.colors() for _, f in self.term_list), default=0)
[docs] def num_trees(self) -> int: """ For a forest sum :math:`\\sum_{i=1}^m c_i \\prod_{j=1}^{k_i} t_{ij}`, returns the total number of trees in the forest sum, :math:`\\sum_{i=1}^m k_i`. :return: Number of trees :rtype: int **Example usage:** .. kauri-exec:: f = Tree([]) * Tree([[]]) + 2 * Tree([[],[]]) print(f.num_trees()) """ return sum(f.num_trees() for c, f in self.term_list)
[docs] def num_forests(self) -> int: """ For a forest sum :math:`\\sum_{i=1}^m c_i \\prod_{j=1}^{k_i} t_{ij}`, returns the total number of forests in the forest sum, :math:`m`. :return: Number of forests :rtype: int **Example usage:** .. kauri-exec:: f = Tree([]) * Tree([[]]) + 2 * Tree([[],[]]) print(f.num_forests()) """ return len(self.term_list)
[docs] def simplify(self) -> 'ForestSum': """ Simplify the forest sum by removing redundant empty trees and cancelling terms where applicable. :return: Reduced forest sum :rtype: ForestSum **Example usage:** .. kauri-exec:: s1 = Tree([[],[[]]]) * Tree(None) + Tree([]) + Tree([[]]) - Tree([[]]) s2 = s1.simplify() # Tree([[],[[]]]) + Tree([]) """ merged = {} for c, f in self.term_list: f_simplified = f.simplify() if f_simplified in merged: merged[f_simplified] = merged[f_simplified] + c else: merged[f_simplified] = c result = tuple((c, f) for f, c in merged.items() if c != 0) if not result: return ZERO_FOREST_SUM return ForestSum(result)
[docs] def factorial(self) -> int: """ Apply the tree factorial to the forest sum as a multiplicative linear map. For a forest sum :math:`\\sum_{i=1}^m c_i \\prod_{j=1}^{k_i} t_{ij}`, returns :math:`\\sum_{i=1}^m c_i \\prod_{j=1}^{k_i} t_{ij}!`. :return: :math:`\\sum_{i=1}^m c_i \\prod_{j=1}^{k_i} t_{ij}!` :rtype: int **Example usage:** .. kauri-exec:: s = Tree([[],[[]]]) * Tree([]) + Tree([[]]) print(s.factorial()) """ return sum(c * f.factorial() for c,f in self.term_list)
[docs] def sign(self) -> 'ForestSum': """ Returns the forest sum where every forest is replaced by its signed value, :math:`(-1)^{|f|} f`. :return: Signed forest sum :rtype: ForestSum **Example usage:** .. kauri-exec:: s = Tree([[[]],[]]) * Tree([[]]) + 2 * Tree([]) kr.display(s.sign()) """ return ForestSum(tuple((-c if f.nodes() % 2 else c, f) for c,f in self.term_list))
[docs] def __mul__(self, other : Union[int, float, 'Tree', 'Forest', 'ForestSum']) -> 'ForestSum': """ Multiplies a ForestSum by a scalar, Tree, Forest or ForestSum, returning a ForestSum. :param other: A scalar, Tree, Forest or ForestSum :rtype: ForestSum **Example usage:** .. kauri-exec:: s = ForestSum([(1, Tree([])), (-2, Tree([[],[]]))]) kr.display(2 * Tree([[]]) * s) """ return self._mul_impl(other, reverse=False)
def __rmul__(self, other): return self._mul_impl(other, reverse=True) def _mul_impl(self, other, *, reverse): if _is_scalar(other): new_term_list = tuple( (c * other, f) for c, f in self.term_list ) elif isinstance(other, (TreeLike, ForestLike)): _check_compatible(self, other) if reverse: new_term_list = tuple( (c, other * f) for c, f in self.term_list ) else: new_term_list = tuple( (c, f * other) for c, f in self.term_list ) elif isinstance(other, ForestSum): _check_compatible(self, other) left, right = (other.term_list, self.term_list) if reverse else (self.term_list, other.term_list) new_term_list = tuple( (c1 * c2, f1 * f2) for c1, f1 in left for c2, f2 in right if c1 != 0 and c2 != 0) else: raise TypeError("Cannot multiply ForestSum and " + str(type(other))) out = ForestSum(new_term_list) if new_term_list else ZERO_FOREST_SUM return out.simplify()
[docs] def __pow__(self, n : int) -> 'ForestSum': """ Returns the :math:`n^{th}` power of a forest sum for a positive integer :math:`n`. :param n: Exponent, a positive integer :rtype: ForestSum **Example usage:** .. kauri-exec:: t = ( Tree([]) * Tree([[]]) + Tree([[],[]]) ) ** 3 kr.display(t) """ if not isinstance(n, int): raise TypeError("Exponent in ForestSum.__pow__ must be an int, not " + str(type(n))) if n < 0: raise ValueError("Cannot raise ForestSum to a negative power") if n == 0: if _is_planar_obj(self): return ForestSum(((1, EMPTY_ORDERED_FOREST),)) return EMPTY_FOREST_SUM temp = self for _ in range(n-1): temp = temp * self return temp.simplify()
[docs] def __add__(self, other : Union[int, float, 'Tree', 'Forest', 'ForestSum']) -> 'ForestSum': """ Adds a ForestSum to a scalar, Tree, Forest or ForestSum. :param other: A scalar, Tree, Forest or ForestSum :rtype: ForestSum **Example usage:** .. kauri-exec:: s = ForestSum([(1, Tree([])), (-2, Tree([[],[]]))]) kr.display(2 + Tree([[]]) + s) """ if _is_scalar(other): empty = EMPTY_ORDERED_FOREST if _is_planar_obj(self) else EMPTY_FOREST new_term_list = self.term_list + ((other, empty),) elif isinstance(other, (TreeLike, ForestLike)): _check_compatible(self, other) new_term_list = self.term_list + ((1, other),) elif isinstance(other, ForestSum): _check_compatible(self, other) new_term_list = self.term_list + other.term_list else: raise TypeError("Cannot add ForestSum and " + str(type(other))) out = ForestSum(new_term_list) return out.simplify()
def __sub__(self, other): return self + (- other) def __rsub__(self, other): return (-self) + other __radd__ = __add__ def __neg__(self): return ForestSum(tuple((-c, f) for c, f in self.term_list)) def equals(self, other): _lazy_count(self, 'term_list') _lazy_count(other, 'term_list') return self.count == other.count
[docs] def __eq__(self, other : 'ForestSum') -> bool: """ Compares the forest sum with another forest sum and returns true if they represent the same forest sum, regardless of possible reorderings of trees. :param other: ForestSum :rtype: bool **Example usage:** .. kauri-exec:: t1 = Tree([]) t2 = Tree([[]]) t3 = Tree([[[]],[]]) t4 = Tree([[],[[]]]) print(t1 * t2 + t3 == t3 + t2 * t1) print(t1 * t2 + t3 == t1 * t2 + t4) """ if _is_scalar(other): return self.equals(other * EMPTY_TREE) if isinstance(other, Tree): return self.equals(other.as_forest_sum()) if isinstance(other, Forest): return self.equals(other.as_forest_sum()) if isinstance(other, ForestSum): return self.equals(other) if isinstance(other, TreeLike) and hasattr(other, 'as_forest_sum'): return self.equals(other.as_forest_sum()) if isinstance(other, ForestLike) and hasattr(other, 'as_forest_sum'): return self.equals(other.as_forest_sum()) return NotImplemented
[docs] def singleton_reduced(self) -> 'ForestSum': """ Removes redundant occurrences of the single-node tree in each forest of the forest sum. If the forest contains a tree with more than one node, removes all occurences of the single-node tree. Otherwise, replaces it with the single-node tree. :return: Singleton-reduced forest sum :rtype: ForestSum **Example usage:** .. kauri-exec:: s1 = Tree([]) * Tree([[],[]]) + Tree([]) * Tree([]) * Tree([]) kr.display(s1, '\u2192', s1.singleton_reduced()) """ return ForestSum(tuple((c, f.singleton_reduced()) for c, f in self.term_list))
def as_forest_sum(self): return self
[docs] def __matmul__(self, other: Union[int, float, 'Tree', 'Forest', 'ForestSum']) -> 'TensorProductSum': """ Returns the tensor product of a ForestSum and a scalar, Tree, Forest or ForestSum. :param other: Other :type other: int | float | Tree | Forest | ForestSum :return: Tensor product :rtype: TensorProductSum **Example usage:** .. kauri-exec:: tp = Tree([]) @ (Tree([[]]) + Tree([]) * Tree([[],[]])) kr.display(tp) """ if _is_scalar(other): empty = EMPTY_ORDERED_FOREST if _is_planar_obj(self) else EMPTY_FOREST term_list = [] for c, f in self: term_list.append((other * c, f, empty)) return TensorProductSum(term_list) if isinstance(other, (TreeLike, ForestLike)): _check_compatible(self, other) other_ = _coerce_to_forest(other) term_list = [] for c, f in self: term_list.append((c, f, other_)) return TensorProductSum(term_list) if isinstance(other, ForestSum): _check_compatible(self, other) term_list = [] for c1, f1 in self: for c2, f2 in other: term_list.append((c1 * c2, f1, f2)) return TensorProductSum(term_list) raise TypeError("Cannot take tensor product of ForestSum and " + str(type(other)))
def __getitem__(self, i): return self.term_list[i]
############################################## ############################################## try: import sympy as _sympy _SCALAR_TYPES = (numbers.Real, _sympy.Expr) except ImportError: _SCALAR_TYPES = (numbers.Real,) def _is_scalar(obj): return isinstance(obj, _SCALAR_TYPES) def _is_tree_or_forest(obj): return isinstance(obj, (TreeLike, ForestLike)) def _coerce_to_forest(obj): """Coerce a tree-like object to its corresponding forest type.""" if isinstance(obj, ForestLike): return obj if isinstance(obj, Tree): return obj.as_forest() if isinstance(obj, PlanarTree): return obj.as_ordered_forest() raise TypeError(f"Cannot coerce {type(obj)} to forest") EMPTY_TREE = Tree(None) EMPTY_FOREST = Forest((EMPTY_TREE,)) EMPTY_FOREST_SUM = ForestSum( ( (1, EMPTY_FOREST), ) ) ZERO_FOREST_SUM = ForestSum(()) ############################################## ##############################################
[docs] @dataclass(frozen=True) class TensorProductSum: """ A linear combination of tensor products of forests. Works with both non-planar (:class:`CommutativeForest`) and planar (:class:`NoncommutativeForest`) forests. :param term_list: A list of tuples representing terms in the sum. Tuples must be of the form `(c, f1, f2)`, where `c` is an `int` or `float` and `f1, f2` are Forests, representing the term :math:`c \\cdot (f1 \\otimes f2)`. **Example usage:** .. kauri-exec:: tp = Tree([]) @ Tree([[]]) - 2 * Tree([[],[]]) @ Tree(None) kr.display(tp) """ term_list: Union[tuple, list, None] #(c, f1, f2) count : Counter = None hash_ : int = None def __post_init__(self): tuple_list = [] for x in self.term_list: if not (_is_scalar(x[0]) and _is_tree_or_forest(x[1]) and _is_tree_or_forest(x[2])): raise TypeError("Terms must be tuples of type (scalar, TreeLike | ForestLike, TreeLike | ForestLike)") tuple_list.append((x[0], _coerce_to_forest(x[1]), _coerce_to_forest(x[2]))) tuple_list = tuple(tuple_list) object.__setattr__(self, 'term_list', tuple_list) __copy__ = _frozen_copy __deepcopy__ = _frozen_deepcopy def __repr__(self): if self.term_list is None or self.term_list == tuple(): return "0" parts = [str(c) + " * " + repr(f1) + " \u2297 " + repr(f2) for c, f1, f2 in self.term_list] return " + ".join(parts) def _repr_svg_(self): from .display import _to_svg return _to_svg(self)
[docs] def simplify(self) -> 'TensorProductSum': """ Simplify the tensor product sum by removing redundant empty trees and cancelling terms where applicable. :return: Reduces tensor product sum :rtype: TensorProductSum **Example usage:** .. kauri-exec:: tp1 = Tree([[],[[]]]) @ (Tree([]) * Tree(None)) + Tree([]) @ Tree([[]]) - Tree([]) @ Tree([[]]) tp2 = tp1.simplify() # Tree([[],[[]]]) @ Tree([]) """ new_term_list = [] for c, f1, f2 in self.term_list: f1_reduced = f1.simplify() f2_reduced = f2.simplify() for i, (_, f1_, f2_) in enumerate(new_term_list): if f1_reduced.equals(f1_) and f2_reduced.equals(f2_): old_term_ = new_term_list[i] new_term_list[i] = (old_term_[0] + c, old_term_[1], old_term_[2]) break else: new_term_list.append((c, f1_reduced, f2_reduced)) result = tuple(term for term in new_term_list if term[0] != 0) return TensorProductSum(result)
[docs] def singleton_reduced(self) -> 'TensorProductSum': """ Removes redundant occurrences of the single-node tree in each forest of the tensor product sum. If the forest contains a tree with more than one node, removes all occurences of the single-node tree. Otherwise, replaces it with the single-node tree. :return: Singleton-reduced tensor product sum :rtype: TensorProductSum **Example usage:** .. kauri-exec:: s1 = (Tree([]) * Tree([[],[]])) @ (Tree([]) * Tree([]) * Tree([])) kr.display(s1, '\u2192', s1.singleton_reduced()) """ return TensorProductSum(tuple((c, f1.singleton_reduced(), f2.singleton_reduced()) for c, f1, f2 in self.term_list))
[docs] def __eq__(self, other : 'TensorProductSum') -> bool: """ Compares the tensor product sum with another tensor product sum and returns true if they represent the same sum, regardless of possible reorderings of trees within forests or reorderings of terms. :param other: TensorProductSum :rtype: bool **Example usage:** .. kauri-exec:: t1 = Tree([]) t2 = Tree([[]]) t3 = Tree([[[]],[]]) t4 = Tree([[],[[]]]) print(t1 @ t2 + t2 @ t3 == t2 @ t3 + t1 @ t2) print(t1 @ (t2 * t3) == t1 @ (t3 * t2)) print(t1 @ t3 == t1 @ t4) """ if _is_scalar(other): if other == 0: return not self.simplify().term_list return NotImplemented if not isinstance(other, TensorProductSum): return NotImplemented _lazy_count(self, 'term_list') _lazy_count(other, 'term_list') return self.count == other.count
def __hash__(self): return _lazy_hash(self, 'term_list')
[docs] def __add__(self, other : 'TensorProductSum') -> 'TensorProductSum': """ Adds two tensor product sums, or adds a scalar (treated as scalar * (empty ⊗ empty)). :param other: Other tensor product sum or scalar :type other: TensorProductSum | int | float """ if isinstance(other, TensorProductSum): return TensorProductSum(self.term_list + other.term_list).simplify() if _is_scalar(other): if other == 0: return self return TensorProductSum(self.term_list + ((other, EMPTY_FOREST, EMPTY_FOREST),)).simplify() raise TypeError("Cannot add TensorProductSum and " + str(type(other)))
def __neg__(self): return TensorProductSum(tuple((-x[0], x[1], x[2]) for x in self.term_list)) def __sub__(self, other): return self + (-other)
[docs] def __mul__(self, other : Union[int, float, 'TensorProductSum']) -> 'TensorProductSum': """ Multiplies a tensor product sum by a scalar or tensor product sum. :param other: Other :type other: int | float | TensorProductSum """ if isinstance(other, TensorProductSum): new_term_list = [] for c1, f11, f12 in self: for c2, f21, f22 in other: new_term_list.append((c1 * c2, f11 * f21, f12 * f22)) return TensorProductSum(tuple(new_term_list)).simplify() if _is_scalar(other): if other == 0: return TensorProductSum(()) return TensorProductSum(tuple((other * x[0], x[1], x[2]) for x in self.term_list)) raise TypeError("Cannot multiply TensorSum by " + str(type(other)))
def __rsub__(self, other): return (-self) + other __radd__ = __add__ __rmul__ = __mul__ def __iter__(self): for c, f1, f2 in self.term_list: yield c, f1, f2 def __len__(self): return len(self.term_list) def __getitem__(self, i): return self.term_list[i]
[docs] def colors(self): """ Returns the number of colors/labels in the tensor product sum. Since the labels are indexed starting from 0, this is equivalent to one more than the maximum label. :return: Number of colors :rtype: int **Example usage:** .. kauri-exec:: print((Tree([[9],0]) @ Tree([3]) + Tree([2]) @ Tree([4])).colors()) """ return max((max(f1.colors(), f2.colors()) for _, f1, f2 in self.term_list), default=0)
###################################### @total_ordering @dataclass(frozen=True) class PlanarTree: """Ordered rooted tree; sibling order is part of identity. A single planar (ordered) rooted tree, initialised by its list representation. Unlike ``Tree``, the order of the children matters: ``PlanarTree([[],[[]]])`` and ``PlanarTree([[[]],[]])`` represent **different** trees. :param list_repr: The nested list representation of the tree .. kauri-exec:: t1 = kr.PlanarTree([[],[[]]]) # An unlabelled planar tree t2 = kr.PlanarTree([[1],[[3],2],0]) # A labelled planar tree t3 = kr.PlanarTree([[1],[[3],2]]) # Same as t2 (missing label defaults to 0) kr.display(t1, t2, t3) """ list_repr: Union[tuple, list, None] = None unlabelled_repr = None _max_color = 0 def __post_init__(self) -> None: if self.list_repr is not None: if not _check_valid(self.list_repr): raise ValueError(f"{self.list_repr!r} is not a valid planar tree representation.") tuple_repr: tuple = _to_labelled_tuple(self.list_repr) object.__setattr__(self, "list_repr", tuple_repr) object.__setattr__(self, "unlabelled_repr", _to_unlabelled_tuple(tuple_repr)) object.__setattr__(self, "_max_color", _get_max_color(tuple_repr)) def nodes(self) -> int: """Returns the number of nodes in the planar tree, :math:`|t|`. **Example usage:** .. kauri-exec:: t1 = PlanarTree([[],[[]]]) t2 = PlanarTree([[[]],[]]) print(t1.nodes(), t2.nodes()) # Same node count, different planar trees """ return _nodes(self.unlabelled_repr) def factorial(self) -> int: """Compute the tree factorial for a planar tree. Uses the same recursion as ``Tree.factorial()``. **Example usage:** .. kauri-exec:: t1 = PlanarTree([[],[[]]]) t2 = PlanarTree([[[]],[]]) print(t1.factorial(), t2.factorial()) # Same factorial, different planar trees """ return _factorial(self.unlabelled_repr)[0] def sigma(self) -> int: """Symmetry factor of an ordered tree — always 1 (sibling order is part of identity). **Example usage:** .. kauri-exec:: t = PlanarTree([[],[]]) print(t.sigma()) # Always 1 for planar (cf. Tree([[],[]]).sigma() == 2) """ return 1 def height(self) -> int: """Returns the height of the tree (longest root-to-leaf path length). **Example usage:** .. kauri-exec:: t1 = PlanarTree([[],[[]]]) t2 = PlanarTree([[[]],[]]) print(t1.height(), t2.height()) # Same height, different planar trees """ return _height(self.unlabelled_repr) def density(self) -> float: """Density of the tree, :math:`t! / |t|!`. **Example usage:** .. kauri-exec:: t1 = PlanarTree([[],[[]]]) t2 = PlanarTree([[[]],[]]) print(t1.density(), t2.density()) # Same density, different planar trees """ return self.factorial() / math.factorial(self.nodes()) def alpha(self) -> int: """Number of monotone labellings (up to symmetry). Since sigma=1 for planar trees, equals ``beta() / factorial()``. **Example usage:** .. kauri-exec:: t = PlanarTree([[],[]]) print(t.alpha()) # cf. Tree([[],[]]).alpha() == 1 """ return self.beta() // self.factorial() def beta(self) -> int: """Number of distinct labellings (up to symmetry). Since sigma=1 for planar trees, equals ``nodes()!``. **Example usage:** .. kauri-exec:: t = PlanarTree([[],[]]) print(t.beta()) # cf. Tree([[],[]]).beta() == 3 """ return math.factorial(self.nodes()) def unjoin(self) -> 'NoncommutativeForest': """For a tree t = [t_1, ..., t_k], returns the forest t_1 ... t_k (the B- map). **Example usage:** .. kauri-exec:: t1 = PlanarTree([[[]],[]]) t2 = PlanarTree([[],[[]]]) kr.display(t1, '→', t1.unjoin()) # [[]],[] -- left child is taller kr.display(t2, '→', t2.unjoin()) # [],[[]] -- right child is taller """ if self.list_repr is None: return EMPTY_ORDERED_FOREST return NoncommutativeForest(tuple(PlanarTree(rep) for rep in self.list_repr[:-1])) def unlabelled(self) -> 'PlanarTree': """Returns the unlabelled version of the tree. **Example usage:** .. kauri-exec:: t = PlanarTree([[[3],1],[2],0]) kr.display(t, '→', t.unlabelled()) """ return PlanarTree(self.unlabelled_repr) def as_ordered_forest(self) -> 'OrderedForest': """Returns the tree as a single-tree ordered (noncommutative) forest.""" return OrderedForest((self,)) def as_forest(self) -> 'OrderedForest': """Returns the tree as a single-tree ordered (noncommutative) forest. Alias for :meth:`as_ordered_forest`, provided for API parity with :meth:`Tree.as_forest`. """ return self.as_ordered_forest() def to_nonplanar_tree(self) -> Tree: """Converts this planar tree to its non-planar equivalent (forgets sibling order).""" if self.list_repr is None: return Tree(None) return Tree(self.list_repr) def __hash__(self): return hash(self.list_repr) def __eq__(self, other): """Compares the tree with another object and returns true if they represent the same planar tree. Unlike ``Tree.__eq__``, sibling order matters. :param other: PlanarTree, NoncommutativeForest or ForestSum :rtype: bool **Example usage:** .. kauri-exec:: print(PlanarTree([[],[]]) == PlanarTree([[],[]]).as_ordered_forest()) print(PlanarTree([[[]],[]]) == PlanarTree([[],[[]]])) # Order matters! print(PlanarTree([[],[]]) == PlanarTree([[],[]])) """ if isinstance(other, PlanarTree): return self.list_repr == other.list_repr if isinstance(other, (NoncommutativeForest, ForestSum)): return self.as_forest_sum() == other return NotImplemented def __lt__(self, other): if not isinstance(other, PlanarTree): return NotImplemented if self.list_repr is None: if other.list_repr is None: return False return True if other.list_repr is None: return False if self.nodes() != other.nodes(): return self.nodes() < other.nodes() return LabelledReprComparison(self.list_repr) < LabelledReprComparison(other.list_repr) def __mul__(self, other): return self.as_ordered_forest().__mul__(other) def __rmul__(self, other): return self.as_ordered_forest().__rmul__(other) def __pow__(self, n): if not isinstance(n, int): raise TypeError("Exponent must be an int, not " + str(type(n))) if n < 0: raise ValueError("Cannot raise PlanarTree to a negative power") if n == 0: return EMPTY_ORDERED_FOREST return NoncommutativeForest((self,) * n).simplify() def __add__(self, other): if _is_scalar(other): return ForestSum(((1, self.as_ordered_forest()), (other, EMPTY_ORDERED_FOREST))).simplify() if isinstance(other, (PlanarTree, NoncommutativeForest)): return ForestSum(((1, self), (1, other))).simplify() if isinstance(other, ForestSum): _check_compatible(self, other) return ForestSum(((1, self),) + other.term_list).simplify() _check_compatible(self, other) raise TypeError("Cannot add PlanarTree and " + str(type(other))) __radd__ = __add__ def __sub__(self, other): return self + (-other) def __rsub__(self, other): return (-self) + other def __neg__(self): return ForestSum(((-1, self.as_ordered_forest()),)) def sign(self): """Returns the tree signed by the number of nodes, :math:`(-1)^{|t|} t`.""" return self.as_forest_sum() if self.nodes() % 2 == 0 else -self def colors(self) -> int: """Returns the number of colors/labels in a labelled planar tree. **Example usage:** .. kauri-exec:: print(PlanarTree([]).colors()) print(PlanarTree([0]).colors()) print(PlanarTree([[9],1]).colors()) """ if self.list_repr is None: return 0 return self._max_color + 1 def __repr__(self): if self.list_repr is None: return "\u2205" if self._max_color == 0: return repr(_to_list(self.unlabelled_repr)) return repr(_to_list(self.list_repr)) def _repr_svg_(self): from .display import _to_svg return _to_svg(self) def level_sequence(self) -> list: """Returns the level sequence of the planar tree (root at level 0). **Example usage:** .. kauri-exec:: t1 = PlanarTree([[],[[]]]) t2 = PlanarTree([[[]],[]]) print(t1.level_sequence()) # Different sequences for print(t2.level_sequence()) # different planar trees """ return _list_repr_to_level_sequence(self.unlabelled_repr) def color_sequence(self): """Returns the color (label) sequence of the planar tree in pre-order.""" return _list_repr_to_color_sequence(self.list_repr) def __matmul__(self, other): """Returns the tensor product of a PlanarTree and a scalar, PlanarTree, NoncommutativeForest or ForestSum. :param other: Other :return: Tensor product :rtype: TensorProductSum **Example usage:** .. kauri-exec:: t = PlanarTree([]) @ (PlanarTree([[]]) + PlanarTree([]) * PlanarTree([[],[]])) kr.display(t) """ if _is_scalar(other): return TensorProductSum(((other, self.as_ordered_forest(), EMPTY_ORDERED_FOREST),)) if isinstance(other, (PlanarTree, NoncommutativeForest)): return TensorProductSum(((1, self.as_ordered_forest(), _coerce_to_forest(other)),)) if isinstance(other, ForestSum): return TensorProductSum(tuple((c, self, f) for c, f in other)) raise TypeError("Cannot take tensor product of PlanarTree and " + str(type(other))) def sorted_list_repr(self): """Returns the list representation. For planar trees this is the identity (sibling order is part of the tree's identity). **Example usage:** .. kauri-exec:: t1 = PlanarTree([[],[[]]]) t2 = PlanarTree([[[]],[]]) print(t1.sorted_list_repr()) # Preserves original order (unlike Tree) print(t2.sorted_list_repr()) """ return self.list_repr def equals(self, other_tree): """Two planar trees are equal iff their list representations match exactly.""" return self.list_repr == other_tree.list_repr def __next__(self) -> 'PlanarTree': """Generates the next planar tree in lexicographic order of level sequences. .. note:: Enumeration runs from the star tree (widest) to the chain tree (tallest) within each order, which is the opposite direction from ``Tree.__next__()``. To iterate all planar trees of order *n*, start from the star tree ``PlanarTree([[]]*( n-1))``. :return: Next planar tree :rtype: PlanarTree **Example usage:** .. kauri-exec:: t = PlanarTree([[],[]]) kr.display(t, '\u2192', next(t)) """ if self.list_repr is None: return PlanarTree([]) if self._max_color > 0: warnings.warn("Calling next() on a labelled tree will ignore the labelling.") layout = self.level_sequence() next_ = _next_planar_layout(layout) return PlanarTree(_level_sequence_to_list_repr(next_)) def as_forest_sum(self): """Returns the planar tree as a ForestSum with coefficient 1.""" return ForestSum(((1, self.as_ordered_forest()),)) @dataclass(frozen=True) class NoncommutativeForest: """Noncommutative forest (word) of planar trees. :param tree_list: A list of planar trees contained in the forest **Example usage:** .. kauri-exec:: t1 = PlanarTree([]) t2 = PlanarTree([[]]) f1 = NoncommutativeForest([t1,t2]) f2 = NoncommutativeForest([t2,t1]) kr.display(f1, f2) # Order matters: these are different forests """ tree_list: tuple[PlanarTree, ...] = tuple() def __post_init__(self) -> None: values: tuple[PlanarTree, ...] = tuple(self.tree_list) if len(values) == 0: values = (EMPTY_PLANAR_TREE,) object.__setattr__(self, "tree_list", values) def __iter__(self) -> Iterator[PlanarTree]: yield from self.tree_list def __getitem__(self, index: int) -> PlanarTree: return self.tree_list[index] def simplify(self) -> 'NoncommutativeForest': """Removes empty trees from the forest, preserving order. **Example usage:** .. kauri-exec:: f1 = PlanarTree([[],[[]]]) * PlanarTree(None) f2 = f1.simplify() # PlanarTree([[],[[]]]) """ if len(self.tree_list) <= 1: return self filtered = tuple(tree for tree in self.tree_list if tree.list_repr is not None) if len(filtered) == 0: return EMPTY_ORDERED_FOREST if len(filtered) == len(self.tree_list): return self return NoncommutativeForest(filtered) def nodes(self) -> int: """Returns the total number of nodes in the forest. **Example usage:** .. kauri-exec:: f = PlanarTree([]) * PlanarTree([[]]) print(f.nodes()) """ return sum(tree.nodes() for tree in self.tree_list) def num_trees(self) -> int: """Returns the number of trees in the forest. **Example usage:** .. kauri-exec:: f = PlanarTree([]) * PlanarTree([[]]) print(f.num_trees()) """ return len(self.tree_list) def singleton_reduced(self) -> 'NoncommutativeForest': """Remove single-node trees from the forest, preserving order. **Example usage:** .. kauri-exec:: f1 = PlanarTree([]) * PlanarTree([[],[]]) f2 = PlanarTree([]) * PlanarTree([]) * PlanarTree([]) kr.display(f1, '\u2192', f1.singleton_reduced()) kr.display(f2, '\u2192', f2.singleton_reduced()) """ if self.colors() > 1: warnings.warn("Singleton reduced representation will not respect colorings") out = self.simplify() if len(out.tree_list) > 1: new_tree_list = tuple(t for t in out.tree_list if len(t.list_repr) != 1) if len(new_tree_list) == 0: new_tree_list = (PlanarTree([]),) out = NoncommutativeForest(new_tree_list) return out def as_forest(self): """Return self (protocol compatibility).""" return self def factorial(self) -> int: """Product of tree factorials: ``prod(t.factorial() for t in self.tree_list)``. **Example usage:** .. kauri-exec:: f = PlanarTree([[]]) * PlanarTree([[],[]]) print(f.factorial()) """ return math.prod(t.factorial() for t in self.tree_list) def __repr__(self): if len(self.tree_list) == 0: return "\u2205" return " ".join(repr(t) for t in self.tree_list) def _repr_svg_(self): from .display import _to_svg return _to_svg(self) def colors(self) -> int: """Returns the maximum number of colors across all trees in the forest. **Example usage:** .. kauri-exec:: print((PlanarTree([[9],0]) * PlanarTree([3])).colors()) """ return max((t.colors() for t in self.tree_list), default=0) def equals(self, other): """Two noncommutative forests are equal iff their tree lists match exactly.""" if not isinstance(other, NoncommutativeForest): return False return self.tree_list == other.tree_list def _forest_mul(self, other, *, prepend): if _is_scalar(other): return ForestSum(((other, self),)).simplify() if isinstance(other, PlanarTree): other_trees = (other,) elif isinstance(other, NoncommutativeForest): other_trees = other.tree_list elif isinstance(other, ForestSum): _check_compatible(self, other) terms = tuple( (coeff, NoncommutativeForest( (forest.tree_list + self.tree_list) if prepend else (self.tree_list + forest.tree_list) ).simplify()) for coeff, forest in other.term_list ) return ForestSum(terms).simplify() else: _check_compatible(self, other) side = (f"{type(other)} and NoncommutativeForest" if prepend else f"NoncommutativeForest and {type(other)}") raise TypeError(f"Cannot multiply {side}") trees = (other_trees + self.tree_list) if prepend else (self.tree_list + other_trees) return NoncommutativeForest(trees).simplify() def __mul__(self, other): """Multiplies a noncommutative forest by a scalar, PlanarTree, NoncommutativeForest or ForestSum. **Example usage:** .. kauri-exec:: t1 = PlanarTree([]) t2 = PlanarTree([[]]) kr.display(t1 * t2, t2 * t1) # Noncommutative: order matters """ return self._forest_mul(other, prepend=False) def __rmul__(self, other): return self._forest_mul(other, prepend=True) def __pow__(self, n): """Returns the n-th power of a noncommutative forest by repeating its tree list n times. **Example usage:** .. kauri-exec:: t = ( PlanarTree([]) * PlanarTree([[]]) ) ** 3 kr.display(t) """ if not isinstance(n, int): raise TypeError("Exponent must be an int, not " + str(type(n))) if n < 0: raise ValueError("Cannot raise NoncommutativeForest to a negative power") if n == 0: return EMPTY_ORDERED_FOREST return NoncommutativeForest(self.tree_list * n).simplify() def __add__(self, other): """Adds a noncommutative forest to a scalar, PlanarTree, NoncommutativeForest or ForestSum. **Example usage:** .. kauri-exec:: t = 2 + PlanarTree([[]]) + NoncommutativeForest([PlanarTree([]), PlanarTree([[],[[]]]) ]) kr.display(t) """ if _is_scalar(other): return ForestSum(((1, self), (other, EMPTY_ORDERED_FOREST))).simplify() if isinstance(other, (PlanarTree, NoncommutativeForest)): return ForestSum(((1, self), (1, other))).simplify() if isinstance(other, ForestSum): _check_compatible(self, other) return ForestSum(((1, self),) + other.term_list).simplify() _check_compatible(self, other) raise TypeError("Cannot add NoncommutativeForest and " + str(type(other))) __radd__ = __add__ def __sub__(self, other): return self + (-other) def __rsub__(self, other): return (-self) + other def __neg__(self): return ForestSum(((-1, self),)) def __eq__(self, other): """Compares the forest with another object. Unlike CommutativeForest, order matters: ``t1 * t2 != t2 * t1`` when the trees differ. **Example usage:** .. kauri-exec:: t1 = PlanarTree([]) t2 = PlanarTree([[]]) print(t1 * t2 == t1 * t2) print(t1 * t2 == t2 * t1) """ if isinstance(other, NoncommutativeForest): return self.tree_list == other.tree_list if isinstance(other, (PlanarTree, ForestSum)): return self.as_forest_sum() == other return NotImplemented def __hash__(self): return hash(self.tree_list) def sign(self): """Returns the forest signed by the number of nodes, :math:`(-1)^{|f|} f`. **Example usage:** .. kauri-exec:: f1 = PlanarTree([[]]) * PlanarTree([[],[]]) kr.display(f1.sign()) f2 = PlanarTree([]) * PlanarTree([[],[]]) kr.display(f2.sign()) """ return self.as_forest_sum() if self.nodes() % 2 == 0 else -self def __matmul__(self, other): """Returns the tensor product of a NoncommutativeForest and a scalar, PlanarTree, NoncommutativeForest or ForestSum. **Example usage:** .. kauri-exec:: tp = PlanarTree([]) @ (PlanarTree([[]]) + PlanarTree([]) * PlanarTree([[],[]])) kr.display(tp) """ if _is_scalar(other): return TensorProductSum(((other, self, EMPTY_ORDERED_FOREST),)) if isinstance(other, (PlanarTree, NoncommutativeForest)): return TensorProductSum(((1, self, _coerce_to_forest(other)),)) if isinstance(other, ForestSum): return TensorProductSum(tuple((c, self, f) for c, f in other)) raise TypeError("Cannot take tensor product of NoncommutativeForest and " + str(type(other))) def as_forest_sum(self): """Returns the forest as a ForestSum with coefficient 1.""" return ForestSum(((1, self),)) def join(self, root_color=0): """Joins the forest into a single planar tree by connecting all trees to a new root (the B+ map). **Example usage:** .. kauri-exec:: f1 = PlanarTree([]) * PlanarTree([[]]) f2 = PlanarTree([[]]) * PlanarTree([]) kr.display(f1, '\u2192', f1.join()) # Join preserves left-to-right order kr.display(f2, '\u2192', f2.join()) """ children = tuple(t.list_repr for t in self.tree_list if t.list_repr is not None) return PlanarTree(children + (root_color,)) OrderedForest = NoncommutativeForest OrderedForestSum = ForestSum EMPTY_PLANAR_TREE = PlanarTree(None) EMPTY_ORDERED_FOREST = NoncommutativeForest((EMPTY_PLANAR_TREE,)) _CROSS_TYPE_HINT = ( "Cannot combine planar and non-planar tree types. " "Use PlanarTree/OrderedForest with planar algebras (pgl, nck), " "or Tree/Forest with non-planar algebras (gl, bck, cem)." ) def _is_planar_obj(obj): """Return True if obj is planar, False if non-planar, None if unknown.""" if isinstance(obj, (PlanarTree, NoncommutativeForest)): return True if isinstance(obj, (Tree, CommutativeForest)): return False if isinstance(obj, ForestSum): for c, f in obj.term_list: if isinstance(f, NoncommutativeForest): return True if isinstance(f, CommutativeForest): return False return None return None def _check_compatible(a, b): """Raise TypeError if a and b mix planar and non-planar types.""" pa, pb = _is_planar_obj(a), _is_planar_obj(b) if pa is not None and pb is not None and pa != pb: raise TypeError(_CROSS_TYPE_HINT) def validate_order(order: int, *, allow_zero: bool = True) -> None: if not isinstance(order, int): raise TypeError("order must be an int, not " + str(type(order))) if allow_zero: if order < 0: raise ValueError("order must be non-negative") return if order <= 0: raise ValueError("order must be positive")