# Copyright 2025 Daniil Shmelev
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
"""
SVG rendering for Tree, Forest, ForestSum and TensorProductSum objects.
"""
import warnings
from typing import Union
from .trees import (Tree, ForestSum, Forest, TensorProductSum,
PlanarTree, NoncommutativeForest, _is_scalar)
from .utils import _branch_level_sequences, _str
# ── Configuration constants ──────────────────────────────────────────────
NODE_RADIUS = 4
EDGE_WIDTH = 1.5
LEVEL_SPACING = 20
SIBLING_GAP = 10
TREE_GAP = 5
TERM_GAP = 18
COEFF_GAP = 6
TENSOR_GAP = 14
DISPLAY_GAP = 20
FONT_SIZE = 14
PADDING = 8
CHAR_WIDTH_FACTOR = 0.6 # estimated character width as fraction of font size
NODE_STROKE_WIDTH = 1.0
COLORS = ['black',
'firebrick',
'mediumblue',
'forestgreen',
'rebeccapurple',
'darkorange',
'grey',
'dodgerblue',
'deeppink']
# ── Helpers ──────────────────────────────────────────────────────────────
def _shift_items_x(items, offset):
"""Shift all x-coordinates in render items by offset."""
shifted = []
for item in items:
kind = item[0]
if kind == 'node':
shifted.append(('node', item[1] + offset, item[2], item[3]))
elif kind == 'edge':
shifted.append(('edge', item[1] + offset, item[2],
item[3] + offset, item[4]))
elif kind == 'text':
shifted.append(('text', item[1] + offset, item[2],
item[3], item[4]))
else:
raise ValueError("Unknown render item kind: " + repr(kind))
return shifted
# ── Layer 1: Layout ──────────────────────────────────────────────────────
def _layout_tree(level_seq, color_seq, x_center, y_base, scale):
"""Recursively compute layout items for a single tree.
Returns (items, width, height) where items are render primitives and
y grows *upward* (root at bottom).
"""
gap = SIBLING_GAP * scale
level_sp = LEVEL_SPACING * scale
if level_seq == []:
return [], 0, 0
if level_seq == [0]:
return [('node', x_center, y_base, color_seq[0])], gap, level_sp
items = []
branches = _branch_level_sequences(level_seq)
# Split color_seq into per-branch sequences
branch_colors = []
idx = 1
for branch in branches:
branch_colors.append(color_seq[idx:idx + len(branch)])
idx += len(branch)
# Layout each branch
branch_items = []
branch_widths = []
branch_heights = []
for branch, bcols in zip(branches, branch_colors):
b_items, b_w, b_h = _layout_tree(branch, bcols, 0, y_base + level_sp, scale)
branch_items.append(b_items)
branch_widths.append(b_w)
branch_heights.append(b_h)
total_width = sum(branch_widths) + (len(branch_widths) - 1) * gap
total_width = max(total_width, gap)
# Position branches left-to-right centred on x_center
pos = x_center - total_width / 2
for i in range(len(branch_items)):
offset_x = pos + branch_widths[i] / 2
items.extend(_shift_items_x(branch_items[i], offset_x))
# Edge from root to branch root (branch root is always at local
# (0, y_base+level_sp), shifted by offset_x)
if branch_items[i]:
items.append(('edge', x_center, y_base,
offset_x, y_base + level_sp))
pos += branch_widths[i] + gap
# Add root node
items.append(('node', x_center, y_base, color_seq[0]))
max_height = level_sp
if branch_heights:
max_height = max(branch_heights) + level_sp
return items, total_width, max_height
def _layout_forest(forest, x_start, y_base, scale, show_empty=False):
"""Lay out a forest (sequence of trees) left-to-right.
Returns (items, width, height).
"""
gap = TREE_GAP * scale
cw = CHAR_WIDTH_FACTOR
if show_empty and len(forest.tree_list) == 1 and forest.tree_list[0].list_repr is None:
fs = FONT_SIZE * scale
items = [('text', x_start + fs * cw / 2, y_base, '\u2205', fs)]
return items, fs * cw, LEVEL_SPACING * scale
items = []
x = x_start
max_height = 0
for t in forest.tree_list:
level_seq = t.level_sequence()
color_seq = t.color_sequence()
if level_seq == []:
continue
t_items, t_w, t_h = _layout_tree(level_seq, color_seq, 0, y_base, scale)
items.extend(_shift_items_x(t_items, x + t_w / 2))
x += t_w + gap
max_height = max(max_height, t_h)
width = max(x - x_start - gap, 0) if items else 0
if max_height == 0:
max_height = LEVEL_SPACING * scale
return items, width, max_height
def _format_coeff(c, is_first, rationalise):
"""Format coefficient for display, suppressing trivial '1' coefficients."""
try:
abs_c = abs(c)
is_neg = bool(c < 0)
except TypeError:
# Symbolic coefficient (e.g., sympy.Symbol) — treat as non-negative
return _str(c, rationalise)
if abs_c == 1:
if is_first:
return '' if not is_neg else '\u2212'
else:
return '' # sign handled by +/- operator
else:
s = _str(abs_c, rationalise)
if is_first and is_neg:
return '\u2212' + s
return s
def _layout_coeff_op(items, x, c, is_first, scale, rationalise):
"""Lay out the operator (+/−) and coefficient for a single term.
Returns the new x position after emitting any text items.
"""
term_gap = TERM_GAP * scale
coeff_gap = COEFF_GAP * scale
fs = FONT_SIZE * scale
cw = CHAR_WIDTH_FACTOR
if not is_first:
try:
op = '+' if bool(c >= 0) else '\u2212'
except TypeError:
op = '+' # symbolic coefficient — treat as positive
items.append(('text', x + term_gap / 2, 0, op, fs))
x += term_gap
coeff_str = _format_coeff(c, is_first, rationalise)
if coeff_str:
items.append(('text', x + len(coeff_str) * fs * cw / 2, 0, coeff_str, fs))
x += len(coeff_str) * fs * cw + coeff_gap
return x
def _layout_forest_sum(forest_sum, scale, rationalise=False):
"""Lay out a ForestSum left-to-right.
Returns (items, total_width, total_height).
"""
if not isinstance(forest_sum, ForestSum):
if _is_scalar(forest_sum):
forest_sum = Tree(None) * forest_sum
else:
forest_sum = forest_sum.as_forest_sum()
if len(forest_sum.term_list) == 0:
fs = FONT_SIZE * scale
items = [('text', 0, 0, '0', fs)]
return items, fs * 0.5, fs
items = []
x = 0
max_height = 0
coeff_gap = COEFF_GAP * scale
for i, (c, f) in enumerate(forest_sum.term_list):
x = _layout_coeff_op(items, x, c, i == 0, scale, rationalise)
f_items, f_w, f_h = _layout_forest(f, x, 0, scale, show_empty=True)
items.extend(f_items)
x += f_w + coeff_gap / 2
max_height = max(max_height, f_h)
if max_height == 0:
max_height = LEVEL_SPACING * scale
return items, x, max_height
def _layout_tensor_sum(tensor_sum, scale, rationalise=False):
"""Lay out a TensorProductSum left-to-right.
Returns (items, total_width, total_height).
"""
fs = FONT_SIZE * scale
cw = CHAR_WIDTH_FACTOR
if not tensor_sum.term_list:
items = [('text', 0, 0, '0', fs)]
return items, fs * 0.5, fs
items = []
x = 0
max_height = 0
coeff_gap = COEFF_GAP * scale
tensor_gap = TENSOR_GAP * scale
for i, (c, f1, f2) in enumerate(tensor_sum.term_list):
x = _layout_coeff_op(items, x, c, i == 0, scale, rationalise)
# Left forest
f1_items, f1_w, f1_h = _layout_forest(f1, x, 0, scale, show_empty=True)
items.extend(f1_items)
x += f1_w + tensor_gap / 2
# Tensor symbol
items.append(('text', x + fs * cw / 2, 0, '\u2297', fs))
x += fs * cw + tensor_gap / 2
# Right forest
f2_items, f2_w, f2_h = _layout_forest(f2, x, 0, scale, show_empty=True)
items.extend(f2_items)
x += f2_w + coeff_gap / 2
max_height = max(max_height, f1_h, f2_h)
if max_height == 0:
max_height = LEVEL_SPACING * scale
return items, x, max_height
# ── Layer 2: SVG Rendering ───────────────────────────────────────────────
def _render_svg(items, width, height, scale):
"""Convert render items into an SVG string.
Layout uses y-up coordinates (root at bottom). SVG uses y-down,
so we flip: svg_y = height - layout_y.
"""
pad = PADDING * scale
svg_w = width + 2 * pad
svg_h = height + 2 * pad
r = NODE_RADIUS * scale
ew = EDGE_WIDTH * scale
nsw = NODE_STROKE_WIDTH
parts = [
f'<svg xmlns="http://www.w3.org/2000/svg" '
f'width="{svg_w:.1f}" height="{svg_h:.1f}" '
f'viewBox="0 0 {svg_w:.1f} {svg_h:.1f}">'
]
# Render order: edges (behind), then nodes, then text (on top)
_ORDER = {'edge': 0, 'node': 1, 'text': 2}
for item in sorted(items, key=lambda it: _ORDER[it[0]]):
kind = item[0]
if kind == 'edge':
_, x1, y1, x2, y2 = item
parts.append(
f'<line x1="{x1 + pad:.1f}" y1="{height - y1 + pad:.1f}" '
f'x2="{x2 + pad:.1f}" y2="{height - y2 + pad:.1f}" '
f'stroke="black" stroke-width="{ew:.1f}" '
f'stroke-linecap="round" />'
)
elif kind == 'node':
_, x, y, color_idx = item
sx, sy = x + pad, height - y + pad
color = COLORS[color_idx] if color_idx < len(COLORS) else 'black'
if color_idx > 0:
parts.append(
f'<circle cx="{sx:.1f}" cy="{sy:.1f}" r="{r:.1f}" '
f'fill="{color}" stroke="black" stroke-width="{nsw}" />'
)
else:
parts.append(
f'<circle cx="{sx:.1f}" cy="{sy:.1f}" r="{r:.1f}" '
f'fill="black" />'
)
else:
_, x, y, text, font_size = item
parts.append(
f'<text x="{x + pad:.1f}" y="{height - y + pad:.1f}" '
f'font-size="{font_size:.1f}" font-family="sans-serif" '
f'text-anchor="middle" dominant-baseline="central">'
f'{text}</text>'
)
parts.append('</svg>')
return '\n'.join(parts)
# ── Layer 3: Orchestration ───────────────────────────────────────────────
def _layout_single(obj, scale, rationalise):
"""Layout a single tree-algebra object.
Returns (items, width, height).
"""
if isinstance(obj, TensorProductSum):
return _layout_tensor_sum(obj, scale, rationalise)
elif isinstance(obj, ForestSum):
return _layout_forest_sum(obj, scale, rationalise)
elif isinstance(obj, (Forest, NoncommutativeForest)):
return _layout_forest(obj, 0, 0, scale, show_empty=True)
elif isinstance(obj, (Tree, PlanarTree)):
if obj.list_repr is None:
fs = FONT_SIZE * scale
cw = CHAR_WIDTH_FACTOR
items = [('text', fs * cw / 2, fs * cw / 2, '\u2205', fs)]
return items, fs * cw, fs * cw
else:
level_seq = obj.level_sequence()
color_seq = obj.color_sequence()
items, w, h = _layout_tree(level_seq, color_seq, 0, 0, scale)
items = _shift_items_x(items, w / 2)
return items, w, h
elif isinstance(obj, str):
fs = FONT_SIZE * scale
cw = CHAR_WIDTH_FACTOR
w = len(obj) * fs * cw
items = [('text', w / 2, 0, obj, fs)]
return items, w, LEVEL_SPACING * scale
elif isinstance(obj, (int, float)):
return _layout_single(str(obj), scale, rationalise)
else:
# Try converting to string as a fallback (e.g. sympy expressions)
try:
return _layout_single(str(obj), scale, rationalise)
except Exception:
raise TypeError("Cannot display object of type " + str(type(obj)))
def _to_svg(*objects, scale=1.0, rationalise=False):
"""Generate an SVG string for one or more tree-algebra objects.
Multiple arguments are laid out side by side with extra spacing.
"""
if scale <= 0:
raise ValueError("scale must be positive, got " + str(scale))
if len(objects) == 1:
items, w, h = _layout_single(objects[0], scale, rationalise)
return _render_svg(items, w, h, scale)
gap = DISPLAY_GAP * scale
all_items = []
total_w = 0
max_h = 0
for i, obj in enumerate(objects):
items, w, h = _layout_single(obj, scale, rationalise)
all_items.extend(_shift_items_x(items, total_w))
total_w += w
max_h = max(max_h, h)
if i < len(objects) - 1:
total_w += gap
return _render_svg(all_items, total_w, max_h, scale)
# ── Jupyter detection ────────────────────────────────────────────────────
def _in_jupyter():
try:
from IPython import get_ipython
ip = get_ipython()
return ip is not None and 'IPKernelApp' in ip.config
except (ImportError, AttributeError):
return False
# ── Public API ───────────────────────────────────────────────────────────
[docs]
def display(*objects: Union[Tree, Forest, ForestSum, TensorProductSum,
PlanarTree, NoncommutativeForest, str],
scale: float = 1.0,
fig_size: tuple = None,
file_name: str = None,
use_plt: bool = None,
rationalise: bool = False) -> None:
"""
Display one or more tree-algebra objects, optionally with string labels.
In Jupyter, renders inline SVG.
Multiple arguments are rendered side by side in a single image,
analogous to ``print(a, b, c)``. Strings are rendered as text
labels between objects, e.g. ``display(t, "\u2192", t.unjoin())``.
:param objects: One or more objects or strings to display
:param scale: Scale factor for SVG output (default 1.0)
:param file_name: If provided, saves SVG to ``file_name.svg``
:param rationalise: If True, rationalise float coefficients
"""
_TREE_TYPES = (Tree, Forest, ForestSum, TensorProductSum,
PlanarTree, NoncommutativeForest)
if not objects:
raise TypeError("display() requires at least one argument.")
for obj in objects:
if isinstance(obj, (str, int, float)):
continue
if not isinstance(obj, _TREE_TYPES):
# Allow anything with a __str__ method (e.g. sympy expressions)
if not hasattr(obj, '__str__'):
raise TypeError("Cannot display object of type " + str(type(obj)))
continue
if isinstance(obj, ForestSum) and len(obj.term_list) == 0:
continue
if isinstance(obj, TensorProductSum) and (obj.term_list is None or len(obj.term_list) == 0):
continue
if obj.colors() > 9:
raise ValueError("Cannot display labelled trees with more than 9 different colors.")
if use_plt is not None:
warnings.warn("use_plt is deprecated and ignored. Output is now SVG.",
DeprecationWarning, stacklevel=2)
if fig_size is not None:
warnings.warn("fig_size is deprecated and ignored. SVG auto-sizes.",
DeprecationWarning, stacklevel=2)
svg = _to_svg(*objects, scale=scale, rationalise=rationalise)
if file_name is not None:
with open(file_name + '.svg', 'w', encoding='utf-8') as fh:
fh.write(svg)
if _in_jupyter():
from IPython.display import display as ipy_display
ipy_display({'image/svg+xml': svg}, raw=True)