Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 59 additions & 35 deletions manim/mobject/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def construct(self):

import itertools as it
from collections.abc import Callable, Iterable, Sequence
from typing import Any, Self

from manim.mobject.geometry.line import Line
from manim.mobject.geometry.polygram import Polygon
Expand Down Expand Up @@ -186,9 +187,9 @@ def construct(self):

def __init__(
self,
table: Iterable[Iterable[float | str | VMobject]],
row_labels: Iterable[VMobject] | None = None,
col_labels: Iterable[VMobject] | None = None,
table: Sequence[Sequence[float | str | VMobject]],
row_labels: Sequence[VMobject] | None = None,
col_labels: Sequence[VMobject] | None = None,
top_left_entry: VMobject | None = None,
v_buff: float = 0.8,
h_buff: float = 1.3,
Expand All @@ -198,16 +199,25 @@ def __init__(
include_background_rectangle: bool = False,
background_rectangle_color: ParsableManimColor = BLACK,
element_to_mobject: Callable[
[float | str],
VMobject,
]
| Callable[
[VMobject],
VMobject,
]
| Callable[
[float | str | VMobject],
VMobject,
] = Paragraph,
]
| type[VMobject] = Paragraph,
element_to_mobject_config: dict = {},
arrange_in_grid_config: dict = {},
line_config: dict = {},
**kwargs,
**kwargs: Any,
):
self.row_labels = row_labels
self.col_labels = col_labels
self.row_labels = list(row_labels) if row_labels else None
self.col_labels = list(col_labels) if col_labels else None
self.top_left_entry = top_left_entry
self.row_dim = len(table)
self.col_dim = len(table[0])
Expand All @@ -230,7 +240,7 @@ def __init__(
raise ValueError("Not all rows in table have the same length.")

super().__init__(**kwargs)
mob_table = self._table_to_mob_table(table)
mob_table: list[list[VMobject]] = self._table_to_mob_table(table)
self.elements_without_labels = VGroup(*it.chain(*mob_table))
mob_table = self._add_labels(mob_table)
self._organize_mob_table(mob_table)
Expand All @@ -252,7 +262,7 @@ def __init__(
def _table_to_mob_table(
self,
table: Iterable[Iterable[float | str | VMobject]],
) -> list:
) -> list[list[VMobject]]:
"""Initializes the entries of ``table`` as :class:`~.VMobject`.

Parameters
Expand All @@ -268,13 +278,15 @@ def _table_to_mob_table(
"""
return [
[
self.element_to_mobject(item, **self.element_to_mobject_config)
# error: Argument 1 has incompatible type "float | str | VMobject"; expected "float | str" [arg-type]
# error: Argument 1 has incompatible type "float | str | VMobject"; expected "VMobject" [arg-type]
self.element_to_mobject(item, **self.element_to_mobject_config) # type: ignore[arg-type]
for item in row
]
for row in table
]

def _organize_mob_table(self, table: Iterable[Iterable[VMobject]]) -> VGroup:
def _organize_mob_table(self, table: Sequence[Sequence[VMobject]]) -> VGroup:
"""Arranges the :class:`~.VMobject` of ``table`` in a grid.

Parameters
Expand All @@ -300,7 +312,7 @@ def _organize_mob_table(self, table: Iterable[Iterable[VMobject]]) -> VGroup:
)
return help_table

def _add_labels(self, mob_table: VGroup) -> VGroup:
def _add_labels(self, mob_table: list[list[VMobject]]) -> list[list[VMobject]]:
"""Adds labels to an in a grid arranged :class:`~.VGroup`.

Parameters
Expand All @@ -319,13 +331,13 @@ def _add_labels(self, mob_table: VGroup) -> VGroup:
if self.col_labels is not None:
if self.row_labels is not None:
if self.top_left_entry is not None:
col_labels = [self.top_left_entry] + self.col_labels
col_labels = [self.top_left_entry] + list(self.col_labels)
mob_table.insert(0, col_labels)
else:
# Placeholder to use arrange_in_grid if top_left_entry is not set.
# Import OpenGLVMobject to work with --renderer=opengl
dummy_mobject = get_vectorized_mobject_class()()
col_labels = [dummy_mobject] + self.col_labels
col_labels = [dummy_mobject] + list(self.col_labels)
mob_table.insert(0, col_labels)
else:
mob_table.insert(0, self.col_labels)
Expand Down Expand Up @@ -682,7 +694,10 @@ def construct(self):
item.set_color(random_bright_color())
self.add(table)
"""
return VGroup(*self.row_labels)
if self.row_labels:
return VGroup(*self.row_labels)
else:
return VGroup()

def get_col_labels(self) -> VGroup:
"""Return the column labels of the table.
Expand Down Expand Up @@ -710,7 +725,10 @@ def construct(self):
item.set_color(random_bright_color())
self.add(table)
"""
return VGroup(*self.col_labels)
if self.col_labels:
return VGroup(*self.col_labels)
else:
return VGroup()

def get_labels(self) -> VGroup:
"""Returns the labels of the table.
Expand Down Expand Up @@ -753,7 +771,7 @@ def add_background_to_entries(self, color: ParsableManimColor = BLACK) -> Table:
mob.add_background_rectangle(color=ManimColor(color))
return self

def get_cell(self, pos: Sequence[int] = (1, 1), **kwargs) -> Polygon:
def get_cell(self, pos: Sequence[int] = (1, 1), **kwargs: Any) -> Polygon:
"""Returns one specific cell as a rectangular :class:`~.Polygon` without the entry.

Parameters
Expand Down Expand Up @@ -814,7 +832,7 @@ def get_highlighted_cell(
self,
pos: Sequence[int] = (1, 1),
color: ParsableManimColor = PURE_YELLOW,
**kwargs,
**kwargs: Any,
) -> BackgroundRectangle:
"""Returns a :class:`~.BackgroundRectangle` of the cell at the given position.

Expand Down Expand Up @@ -853,7 +871,7 @@ def add_highlighted_cell(
self,
pos: Sequence[int] = (1, 1),
color: ParsableManimColor = PURE_YELLOW,
**kwargs,
**kwargs: Any,
) -> Table:
"""Highlights one cell at a specific position on the table by adding a :class:`~.BackgroundRectangle`.

Expand Down Expand Up @@ -896,7 +914,7 @@ def create(
label_animation: Callable[[VMobject | VGroup], Animation] = Write,
element_animation: Callable[[VMobject | VGroup], Animation] = Create,
entry_animation: Callable[[VMobject | VGroup], Animation] = FadeIn,
**kwargs,
**kwargs: Any,
) -> AnimationGroup:
"""Customized create-type function for tables.

Expand Down Expand Up @@ -936,7 +954,7 @@ def construct(self):
self.play(table.create())
self.wait()
"""
animations: Sequence[Animation] = [
animations: list[Animation] = [
line_animation(
VGroup(self.vertical_lines, self.horizontal_lines),
**kwargs,
Expand All @@ -963,12 +981,14 @@ def construct(self):

return AnimationGroup(*animations, lag_ratio=lag_ratio)

def scale(self, scale_factor: float, **kwargs):
def scale(
self, scale_factor: float, scale_stroke: bool = False, **kwargs: Any
) -> Self:
# h_buff and v_buff must be adjusted so that Table.get_cell
# can construct an accurate polygon for a cell.
self.h_buff *= scale_factor
self.v_buff *= scale_factor
super().scale(scale_factor, **kwargs)
super().scale(scale_factor, scale_stroke=scale_stroke, **kwargs)
return self


Expand All @@ -994,9 +1014,10 @@ def construct(self):

def __init__(
self,
table: Iterable[Iterable[float | str]],
element_to_mobject: Callable[[float | str], VMobject] = MathTex,
**kwargs,
table: Sequence[Sequence[float | str]],
element_to_mobject: Callable[[float | str], VMobject]
| type[VMobject] = MathTex,
**kwargs: Any,
):
"""
Special case of :class:`~.Table` with `element_to_mobject` set to :class:`~.MathTex`.
Expand Down Expand Up @@ -1049,9 +1070,10 @@ def construct(self):

def __init__(
self,
table: Iterable[Iterable[VMobject]],
element_to_mobject: Callable[[VMobject], VMobject] = lambda m: m,
**kwargs,
table: Sequence[Sequence[VMobject]],
element_to_mobject: Callable[[VMobject], VMobject]
| type[VMobject] = lambda m: m,
**kwargs: Any,
):
"""
Special case of :class:`~.Table` with ``element_to_mobject`` set to an identity function.
Expand Down Expand Up @@ -1097,9 +1119,10 @@ def construct(self):

def __init__(
self,
table: Iterable[Iterable[float | str]],
element_to_mobject: Callable[[float | str], VMobject] = Integer,
**kwargs,
table: Sequence[Sequence[float | str]],
element_to_mobject: Callable[[float | str], VMobject]
| type[VMobject] = Integer,
**kwargs: Any,
):
"""
Special case of :class:`~.Table` with `element_to_mobject` set to :class:`~.Integer`.
Expand Down Expand Up @@ -1141,10 +1164,11 @@ def construct(self):

def __init__(
self,
table: Iterable[Iterable[float | str]],
element_to_mobject: Callable[[float | str], VMobject] = DecimalNumber,
table: Sequence[Sequence[float | str]],
element_to_mobject: Callable[[float | str], VMobject]
| type[VMobject] = DecimalNumber,
element_to_mobject_config: dict = {"num_decimal_places": 1},
**kwargs,
**kwargs: Any,
):
"""
Special case of :class:`~.Table` with ``element_to_mobject`` set to :class:`~.DecimalNumber`.
Expand Down
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ ignore_errors = True
[mypy-manim.mobject.opengl.opengl_vectorized_mobject]
ignore_errors = True

[mypy-manim.mobject.table]
ignore_errors = True

[mypy-manim.mobject.types.point_cloud_mobject]
ignore_errors = True

Expand Down
Loading