# Source code for kauri.cem.cem

# Copyright 2025 Daniil Shmelev
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

"""
The CEM Hopf algebra module
"""
from functools import cache
import itertools
from ..maps import Map
from ..trees import (Tree, PlanarTree, Forest, TensorProductSum)
from ..generic_algebra import func_power


def counit_impl(t):
    """Evaluate the CEM counit on a single tree.

    :param t: object exposing a ``list_repr`` attribute (``None`` for the
        empty tree, otherwise a sequence)
    :return: ``1`` when ``t`` is the empty tree or its ``list_repr`` has
        exactly one entry (the single-node tree, which is the unit in CEM),
        and ``0`` for every other tree
    """
    rep = t.list_repr
    # Anything that is neither empty nor a singleton is killed by the counit.
    if rep is not None and len(rep) != 1:
        return 0
    return 1

@cache
def antipode_impl(t):
    """Recursively compute the CEM antipode of a tree.

    The empty tree and the single-node tree are treated as the same element,
    because the single-node tree is the unit of the algebra. For any other
    tree the antipode is ``-t`` minus, for every coproduct term other than
    the two trivial ones, ``coefficient * antipode(right tree) * left forest``.

    Results are memoised with ``functools.cache``, so ``t`` must be hashable.

    :param t: tree to evaluate the antipode on
    :return: the result of ``singleton_reduced().simplify()`` on the
        accumulated sum (or ``as_forest_sum()`` of the unit in the base cases)
    """
    rep = t.list_repr
    if rep is None:
        # Empty tree: identified with the single-node unit.
        return Tree([]).as_forest_sum()
    if len(rep) == 1:
        # Single-node tree: the unit is its own antipode.
        return t.as_forest_sum()

    terms = coproduct_impl(t)
    result = -t.as_forest_sum()  # Leading term, -t
    for coeff, left_forest, right_forest in terms:
        right_tree = right_forest[0]  # The right factor holds exactly one tree
        # The two trivial terms (left factor equal to t, or right factor
        # equal to t) are already covered by the leading -t.
        if not left_forest.equals(t.as_forest()) and not right_tree.equals(t):
            contribution = coeff * antipode_impl(right_tree) * left_forest
            result = result - contribution

    # The single-node tree is the unit, so singleton factors can be dropped.
    reduced = result.singleton_reduced()
    return reduced.simplify()

@cache
def _coproduct_helper(t):
    """Recursively enumerate the terms of the CEM coproduct of ``t``.

    :param t: tree whose ``list_repr`` is ``None`` (empty tree) or a sequence
        of child representations followed by the root label
    :return: a pair ``(forests, trees)`` of equal-length lists. The entry at
        index ``i`` of ``forests`` is the left factor and the entry at index
        ``i`` of ``trees`` is the right factor of the ``i``-th coproduct term.

    The result is memoised with ``functools.cache``: ``t`` must be hashable,
    and the returned lists are shared between calls, so callers must not
    mutate them.
    """
    # This returns the coproduct as a list of Forests and a list of Trees.
    # The function _build_coproduct then pairs them up positionally, converts
    # the pairs to a tensor product sum and simplifies.

    # We compute the coproduct for a tree t = [t_1, t_2, ..., t_k] recursively,
    # similarly to how we did for the BCK coproduct. Instead of cutting edges,
    # however, we are now contracting them.

    # Caching of this function makes it fairly efficient for large computations.

    if t.list_repr is None: # Empty tree
        return [Tree([]).as_forest()], [Tree([])]
    if len(t.list_repr) == 1: # Singleton tree
        return [t.as_forest()], [t]

    # Compute the coproducts of t_1, t_2, ..., t_k
    subtree_coproduct_trees = []
    subtree_coproduct_forests = []
    for rep in t.list_repr[:-1]: # Recall last element is the root label, so take [:-1]
        b, s = _coproduct_helper(Tree(rep))
        subtree_coproduct_trees.append(s)
        subtree_coproduct_forests.append(b)

    # Now compute the coproduct of t
    t_coproduct_trees = []
    t_coproduct_forests = []

    k = len(subtree_coproduct_trees) # Number of subtrees

    # For each edge connecting a subtree to the root, we can either contract
    # it or not. The "edges" parameter here is a list of binary flags which
    # determine which edges to contract. We must loop over all possible
    # combinations of these flags, of which there are 2^k.
    for edges in itertools.product([0, 1], repeat=k):

        # For a given choice of edges to contract, we must now loop over all
        # possible combinations of coproduct terms of subtrees. This is the same
        # logic as in the BCK case, but split into two loops handling the trees
        # and forests separately, as it's slightly easier to accommodate the
        # contraction logic this way.
        # Both loops iterate itertools.product over per-subtree lists of the
        # same lengths, so for each choice of "edges" they append the same
        # number of entries in the same order. This keeps the two output
        # lists aligned index by index.

        # Get trees
        for p in itertools.product(*subtree_coproduct_trees):
            rep = []
            for edge, t_ in zip(edges, p):
                if t_.list_repr is None:
                    continue # Empty trees contribute nothing to the new tree
                rep += t_.list_repr[:-1] if edge else [t_.list_repr] # Add to tree, contracting the edge if necessary
            t_coproduct_trees.append(Tree(rep))

        # Get forests
        for p in itertools.product(*subtree_coproduct_forests):
            # Must ensure that the first tree in the forest is connected to the root,
            # as it's important to know what this tree is for the recursion.
            # If no edge is contracted, root_tree_repr stays empty and the
            # leading entry is Tree([]), which acts as a placeholder.
            # Forest.simplify() is not invoked here, so the placeholder is
            # still the first entry when the recursion reads tree_list[0].
            t_list_ = []
            root_tree_repr = [] # The tree connected to the root
            for edge, f in zip(edges, p):
                if edge:
                    # Contracted edge: the subtree's root-connected tree is
                    # attached under the new root; the rest stay detached.
                    root_tree_repr += [f.tree_list[0].list_repr]
                    t_list_ += f.tree_list[1:]
                else:
                    # Edge kept: every tree of the sub-forest stays detached.
                    t_list_ += f.tree_list
            t_list_ = [Tree(root_tree_repr)] + t_list_ # Keep tree connected to root at the start
            t_coproduct_forests.append(Forest(t_list_))

    return t_coproduct_forests, t_coproduct_trees

def _build_coproduct(t, singleton_reduce=True):
    """Assemble the CEM coproduct of ``t`` as a simplified TensorProductSum.

    The left forests and right trees produced by ``_coproduct_helper`` are
    paired positionally, each pair receiving coefficient ``1``.

    :param t: tree to take the coproduct of
    :param singleton_reduce: when true, every left forest is simplified and
        then passed through ``singleton_reduced()`` (used for display and the
        public API). When false, the left forests are only simplified, so
        bullet factors are preserved (needed for convolution products where
        f(bullet) != 1).
    :return: the simplified ``TensorProductSum``
    """
    forests, trees = _coproduct_helper(t)

    if singleton_reduce:
        left_factors = [forest.simplify().singleton_reduced() for forest in forests]
    else:
        left_factors = [forest.simplify() for forest in forests]

    terms = tuple((1, left, right) for left, right in zip(left_factors, trees))
    return TensorProductSum(terms).simplify()

@cache
def _coproduct_raw(t):
    """Cached internal coproduct that keeps bullet factors (for convolution use).

    :param t: tree to take the coproduct of
    :return: ``_build_coproduct(t, singleton_reduce=False)``
    :raises TypeError: if ``t`` is not a ``Tree``; the message carries an
        extra hint when ``t`` is a ``PlanarTree``
    """
    if isinstance(t, Tree):
        return _build_coproduct(t, singleton_reduce=False)

    if isinstance(t, PlanarTree):
        hint = " The CEM algebra is only defined for non-planar trees."
    else:
        hint = ""
    raise TypeError("CEM coproduct expects a Tree, not " + str(type(t)) + "." + hint)

def coproduct_impl(t):
    """Compute the singleton-reduced CEM coproduct of a tree.

    :param t: tree to take the coproduct of
    :return: ``_build_coproduct(t, singleton_reduce=True)``
    :raises TypeError: if ``t`` is not a ``Tree``; the message carries an
        extra hint when ``t`` is a ``PlanarTree``
    """
    is_tree = isinstance(t, Tree)
    if not is_tree:
        hint = ""
        if isinstance(t, PlanarTree):
            hint = " The CEM algebra is only defined for non-planar trees."
        raise TypeError("CEM coproduct expects a Tree, not " + str(type(t)) + "." + hint)

    return _build_coproduct(t, singleton_reduce=True)

# Public counit: counit_impl wrapped in a Map. The docstring is assigned
# afterwards because the object is created by a call rather than a def.
counit = Map(counit_impl)
counit.__doc__ = """
The counit :math:`\\varepsilon_{CEM}` of the CEM Hopf algebra.

:type: Map

**Example usage:**

.. kauri-exec::

    print(cem.counit(Tree([])))  # Returns 1
    print(cem.counit(Tree([[]])))  # Returns 0
"""

def _safe_antipode(t):
    """Validate ``t`` and then delegate to ``antipode_impl``.

    :param t: tree to evaluate the antipode on
    :return: ``antipode_impl(t)``
    :raises TypeError: if ``t`` is not a ``Tree``; the message carries an
        extra hint when ``t`` is a ``PlanarTree``
    :raises ValueError: if ``t.colors()`` is greater than 1
    """
    if isinstance(t, Tree):
        if t.colors() > 1:
            raise ValueError("The CEM Hopf algebra is only defined for unlabelled trees")
        return antipode_impl(t)

    if isinstance(t, PlanarTree):
        hint = " The CEM algebra is only defined for non-planar trees."
    else:
        hint = ""
    raise TypeError("Argument to cem.antipode must be a Tree, not " + str(type(t)) + "." + hint)

# Public antipode: wraps _safe_antipode (which validates its argument before
# calling antipode_impl) in a Map. The docstring is assigned afterwards
# because the object is created by a call rather than a def.
antipode = Map(_safe_antipode)
antipode.__doc__ = """
The antipode :math:`S_{CEM}` of the CEM Hopf algebra.

:type: Map

**Example usage:**

.. kauri-exec::

    t = Tree([[[]],[]])
    kr.display(cem.antipode(t))
"""

def coproduct(t: Tree) -> TensorProductSum:
    """
    The coproduct :math:`\\Delta_{CEM}` of the CEM Hopf algebra.

    :param t: tree
    :type t: Tree
    :rtype: TensorProductSum
    :raises TypeError: if ``t`` is not a ``Tree``
    :raises ValueError: if ``t`` uses more than one colour (label)

    **Example usage:**

    .. kauri-exec::

        t = Tree([[[]],[]])
        kr.display(cem.coproduct(t))
    """
    if not isinstance(t, Tree):
        # Planar trees get an extra hint, since CEM is only defined for non-planar trees.
        hint = " The CEM algebra is only defined for non-planar trees." if isinstance(t, PlanarTree) else ""
        raise TypeError("Argument to cem.coproduct must be a Tree, not " + str(type(t)) + "." + hint)
    if t.colors() > 1:
        raise ValueError("The CEM Hopf algebra is only defined for unlabelled trees")
    return coproduct_impl(t)
def map_product(f: Map, g: Map) -> Map:
    """
    Returns the product of maps in the CEM Hopf algebra, defined by

    .. math::

        (f \\cdot g)(t) := \\mu \\circ (f \\otimes g) \\circ \\Delta_{CEM} (t)

    .. note::
        `cem.map_product(f,g)` is equivalent to the Map operator `f ^ g`

    :param f: f
    :type f: Map
    :param g: g
    :type g: Map
    :rtype: Map
    :raises TypeError: if either argument is not a ``Map``

    **Example usage:**

    .. kauri-exec::

        f = cem.map_product(ident, cem.antipode)
        print(f(Tree([[]])))
    """
    if not (isinstance(f, Map) and isinstance(g, Map)):
        raise TypeError("Arguments in cem.map_product must be of type Map, not " + str(type(f)) + " and " + str(type(g)))
    # The Map ``^`` operator implements the CEM convolution product.
    return f ^ g
def map_power(f: Map, exponent: int) -> Map:
    """
    Returns the power of a map in the CEM Hopf algebra, where the product of
    functions is defined by

    .. math::

        (f \\cdot g)(t) := \\mu \\circ (f \\otimes g) \\circ \\Delta_{CEM} (t)

    and negative powers are defined as :math:`f^{-n} = (f \\circ S_{CEM})^n`,
    where :math:`S_{CEM}` is the CEM antipode.

    :param f: f
    :type f: Map
    :param exponent: exponent
    :type exponent: int
    :rtype: Map
    :raises TypeError: if ``f`` is not a ``Map`` or ``exponent`` is not an ``int``

    **Example usage:**

    .. kauri-exec::

        S = cem.map_power(ident, -1) # antipode
        print(S(Tree([[]])))
    """
    if not isinstance(f, Map):
        raise TypeError("f must be a Map, not " + str(type(f)))
    if not isinstance(exponent, int):
        raise TypeError("exponent must be an int, not " + str(type(exponent)))
    # The raw (non-singleton-reduced) coproduct is passed so that bullet
    # factors are preserved while the power is evaluated.
    return Map(lambda x: func_power(x, f.func, exponent, _coproduct_raw, counit_impl, antipode_impl, singleton_reduce=True))