diff --git a/docs/source/reference_index/utilities_misc.rst b/docs/source/reference_index/utilities_misc.rst index 1fe9962079..5e7bee44ca 100644 --- a/docs/source/reference_index/utilities_misc.rst +++ b/docs/source/reference_index/utilities_misc.rst @@ -31,3 +31,4 @@ Module Index ~utils.tex_file_writing ~utils.tex_templates typing + ~utils.updaters diff --git a/manim/animation/speedmodifier.py b/manim/animation/speedmodifier.py index 63b9b2e5b3..e28bf883e9 100644 --- a/manim/animation/speedmodifier.py +++ b/manim/animation/speedmodifier.py @@ -2,19 +2,19 @@ from __future__ import annotations -import inspect import types from typing import TYPE_CHECKING, Callable from numpy import piecewise -from ..animation.animation import Animation, Wait, prepare_animation -from ..animation.composition import AnimationGroup -from ..mobject.mobject import Mobject, _AnimationBuilder -from ..scene.scene import Scene +from manim.animation.animation import Animation, Wait, prepare_animation +from manim.animation.composition import AnimationGroup +from manim.mobject.mobject import Mobject, _AnimationBuilder +from manim.scene.scene import Scene +from manim.utils.updaters import MobjectUpdaterWrapper if TYPE_CHECKING: - from ..mobject.mobject import Updater + from manim.utils.updaters import MobjectUpdater __all__ = ["ChangeSpeed"] @@ -235,7 +235,7 @@ def get_scaled_total_time(self) -> float: def add_updater( cls, mobject: Mobject, - update_function: Updater, + update_function: MobjectUpdater, index: int | None = None, call_updater: bool = False, ): @@ -264,7 +264,8 @@ def add_updater( :class:`.ChangeSpeed` :meth:`.Mobject.add_updater` """ - if "dt" in inspect.signature(update_function).parameters: + wrapper = MobjectUpdaterWrapper(update_function) + if wrapper.is_time_based: mobject.add_updater( lambda mob, dt: update_function( mob, ChangeSpeed.dt if ChangeSpeed.is_changing_dt else dt diff --git a/manim/mobject/mobject.py b/manim/mobject/mobject.py index 2079de7923..dc147c35b7 100644 --- a/manim/mobject/mobject.py +++ b/manim/mobject/mobject.py @@ -6,7 +6,6 @@ import copy -import inspect import itertools as it import math import operator as op @@ -14,18 +13,17 @@ import sys import types import warnings -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from functools import partialmethod, reduce from pathlib import Path from typing import TYPE_CHECKING, Callable, Literal import numpy as np +from manim import config, logger +from manim.constants import * from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL - -from .. import config, logger -from ..constants import * -from ..utils.color import ( +from manim.utils.color import ( BLACK, WHITE, YELLOW_C, @@ -34,14 +32,16 @@ color_gradient, interpolate_color, ) -from ..utils.exceptions import MultiAnimationOverrideException -from ..utils.iterables import list_update, remove_list_redundancies -from ..utils.paths import straight_path -from ..utils.space_ops import angle_between_vectors, normalize, rotation_matrix +from manim.utils.exceptions import MultiAnimationOverrideException +from manim.utils.iterables import list_update, remove_list_redundancies +from manim.utils.paths import straight_path +from manim.utils.space_ops import angle_between_vectors, normalize, rotation_matrix +from manim.utils.updaters import MobjectUpdaterWrapper if TYPE_CHECKING: - from typing_extensions import Self, TypeAlias + from typing_extensions import Self + from manim.animation.animation import Animation from manim.typing import ( FunctionOverride, ManimFloat, @@ -53,12 +53,7 @@ Point3D_Array, Vector3D, ) - - from ..animation.animation import Animation - - TimeBasedUpdater: TypeAlias = Callable[["Mobject", float], object] - NonTimeBasedUpdater: TypeAlias = Callable[["Mobject"], object] - Updater: TypeAlias = NonTimeBasedUpdater | TimeBasedUpdater + from manim.utils.updaters import MobjectDtUpdater, MobjectUpdater class Mobject: @@ -71,6 +66,7 @@ class Mobject: Attributes ---------- submobjects : List[:class:`Mobject`] + The contained objects. points : :class:`numpy.ndarray` The points of the objects. @@ -96,7 +92,7 @@ def __init_subclass__(cls, **kwargs) -> None: def __init__( self, - color: ParsableManimColor | list[ParsableManimColor] = WHITE, + color: ParsableManimColor | Sequence[ParsableManimColor] = WHITE, name: str | None = None, dim: int = 3, target=None, @@ -108,7 +104,7 @@ def __init__( self.z_index = z_index self.point_hash = None self.submobjects = [] - self.updaters: list[Updater] = [] + self.updater_wrappers: Sequence[MobjectUpdaterWrapper] = [] self.updating_suspended = False self.color = ManimColor.parse(color) @@ -865,6 +861,10 @@ def generate_target(self, use_deepcopy: bool = False) -> Self: # Updating + @property + def updaters(self) -> Sequence[MobjectUpdater]: + return self.get_updaters() + def update(self, dt: float = 0, recursive: bool = True) -> Self: """Apply all updaters. @@ -891,17 +891,17 @@ def update(self, dt: float = 0, recursive: bool = True) -> Self: """ if self.updating_suspended: return self - for updater in self.updaters: - if "dt" in inspect.signature(updater).parameters: - updater(self, dt) + for wrapper in self.updater_wrappers: + if wrapper.is_time_based: + wrapper.updater(self, dt) else: - updater(self) + wrapper.updater(self) if recursive: for submob in self.submobjects: submob.update(dt, recursive) return self - def get_time_based_updaters(self) -> list[TimeBasedUpdater]: + def get_time_based_updaters(self) -> Sequence[MobjectDtUpdater]: """Return all updaters using the ``dt`` parameter. The updaters use this parameter as the input for difference in time. @@ -918,9 +918,9 @@ def get_time_based_updaters(self) -> list[TimeBasedUpdater]: """ return [ - updater - for updater in self.updaters - if "dt" in inspect.signature(updater).parameters + wrapper.updater + for wrapper in self.updater_wrappers + if wrapper.is_time_based ] def has_time_based_updater(self) -> bool: @@ -937,11 +937,9 @@ def has_time_based_updater(self) -> bool: :meth:`get_time_based_updaters` """ - return any( - "dt" in inspect.signature(updater).parameters for updater in self.updaters - ) + return any(wrapper.is_time_based for wrapper in self.updater_wrappers) - def get_updaters(self) -> list[Updater]: + def get_updaters(self) -> Sequence[MobjectUpdater]: """Return all updaters. Returns @@ -955,14 +953,14 @@ def get_updaters(self) -> list[Updater]: :meth:`get_time_based_updaters` """ - return self.updaters + return [wrapper.updater for wrapper in self.updater_wrappers] - def get_family_updaters(self) -> list[Updater]: + def get_family_updaters(self) -> Sequence[MobjectUpdater]: return list(it.chain(*(sm.get_updaters() for sm in self.get_family()))) def add_updater( self, - update_function: Updater, + update_function: MobjectUpdater, index: int | None = None, call_updater: bool = False, ) -> Self: @@ -1026,19 +1024,19 @@ def construct(self): :meth:`remove_updater` :class:`~.UpdateFromFunc` """ + wrapper = MobjectUpdaterWrapper(update_function) if index is None: - self.updaters.append(update_function) + self.updater_wrappers.append(wrapper) else: - self.updaters.insert(index, update_function) + self.updater_wrappers.insert(index, wrapper) if call_updater: - parameters = inspect.signature(update_function).parameters - if "dt" in parameters: - update_function(self, 0) + if wrapper.is_time_based: + wrapper.updater(self, 0) else: - update_function(self) + wrapper.updater(self) return self - def remove_updater(self, update_function: Updater) -> Self: + def remove_updater(self, update_function: MobjectUpdater) -> Self: """Remove an updater. If the same updater is applied multiple times, every instance gets removed. @@ -1061,8 +1059,11 @@ def remove_updater(self, update_function: Updater) -> Self: :meth:`get_updaters` """ - while update_function in self.updaters: - self.updaters.remove(update_function) + self.updater_wrappers = [ + wrapper + for wrapper in self.updater_wrappers + if wrapper.updater != update_function + ] return self def clear_updaters(self, recursive: bool = True) -> Self: @@ -1085,7 +1086,7 @@ def clear_updaters(self, recursive: bool = True) -> Self: :meth:`get_updaters` """ - self.updaters = [] + self.updater_wrappers = [] if recursive: for submob in self.submobjects: submob.clear_updaters() @@ -1116,8 +1117,7 @@ def match_updaters(self, mobject: Mobject) -> Self: """ self.clear_updaters() - for updater in mobject.get_updaters(): - self.add_updater(updater) + self.updater_wrappers = mobject.updater_wrappers.copy() return self def suspend_updating(self, recursive: bool = True) -> Self: diff --git a/manim/mobject/opengl/opengl_mobject.py b/manim/mobject/opengl/opengl_mobject.py index 1f7d44d6a3..adab5ce4ac 100644 --- a/manim/mobject/opengl/opengl_mobject.py +++ b/manim/mobject/opengl/opengl_mobject.py @@ -1,7 +1,6 @@ from __future__ import annotations import copy -import inspect import itertools as it import random import sys @@ -45,10 +44,11 @@ normalize, rotation_matrix_transpose, ) +from manim.utils.updaters import MobjectUpdaterWrapper if TYPE_CHECKING: import numpy.typing as npt - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from manim.renderer.shader_wrapper import ShaderWrapper from manim.typing import ( @@ -60,10 +60,7 @@ Point3D_Array, Vector3D, ) - - TimeBasedUpdater: TypeAlias = Callable[["Mobject", float], object] - NonTimeBasedUpdater: TypeAlias = Callable[["Mobject"], object] - Updater: TypeAlias = NonTimeBasedUpdater | TimeBasedUpdater + from manim.utils.updaters import MobjectDtUpdater, MobjectUpdater T = TypeVar("T") @@ -1448,66 +1445,71 @@ def restore(self) -> Self: # Updating def init_updaters(self) -> None: - self.time_based_updaters = [] - self.non_time_updaters = [] + self.updater_wrappers: Sequence[MobjectUpdaterWrapper] = [] self.has_updaters = False self.updating_suspended = False + @property + def updaters(self) -> Sequence[MobjectUpdater]: + return self.get_updaters() + def update(self, dt: float = 0, recurse: bool = True) -> Self: if not self.has_updaters or self.updating_suspended: return self - for updater in self.time_based_updaters: - updater(self, dt) - for updater in self.non_time_updaters: - updater(self) + for wrapper in self.updater_wrappers: + if wrapper.is_time_based: + wrapper.updater(self, dt) + else: + wrapper.updater(self) if recurse: for submob in self.submobjects: submob.update(dt, recurse) return self - def get_time_based_updaters(self) -> Sequence[TimeBasedUpdater]: - return self.time_based_updaters + def get_time_based_updaters(self) -> Sequence[MobjectDtUpdater]: + return [ + wrapper.updater + for wrapper in self.updater_wrappers + if wrapper.is_time_based + ] def has_time_based_updater(self) -> bool: - return len(self.time_based_updaters) > 0 + return any(wrapper.is_time_based for wrapper in self.updater_wrappers) - def get_updaters(self) -> Sequence[Updater]: - return self.time_based_updaters + self.non_time_updaters + def get_updaters(self) -> Sequence[MobjectUpdater]: + return [wrapper.updater for wrapper in self.updater_wrappers] - def get_family_updaters(self) -> Sequence[Updater]: + def get_family_updaters(self) -> Sequence[MobjectUpdater]: return list(it.chain(*(sm.get_updaters() for sm in self.get_family()))) def add_updater( self, - update_function: Updater, + update_function: MobjectUpdater, index: int | None = None, call_updater: bool = False, ) -> Self: - if "dt" in inspect.signature(update_function).parameters: - updater_list = self.time_based_updaters - else: - updater_list = self.non_time_updaters - + wrapper = MobjectUpdaterWrapper(update_function) if index is None: - updater_list.append(update_function) + self.updater_wrappers.append(wrapper) else: - updater_list.insert(index, update_function) + self.updater_wrappers.insert(index, wrapper) self.refresh_has_updater_status() if call_updater: self.update() return self - def remove_updater(self, update_function: Updater) -> Self: - for updater_list in [self.time_based_updaters, self.non_time_updaters]: - while update_function in updater_list: - updater_list.remove(update_function) + def remove_updater(self, update_function: MobjectUpdater) -> Self: + self.updater_wrappers = [ + wrapper + for wrapper in self.updater_wrappers + if wrapper.updater != update_function + ] self.refresh_has_updater_status() return self def clear_updaters(self, recurse: bool = True) -> Self: - self.time_based_updaters = [] - self.non_time_updaters = [] + self.updater_wrappers = [] self.refresh_has_updater_status() if recurse: for submob in self.submobjects: @@ -1516,8 +1518,7 @@ def clear_updaters(self, recurse: bool = True) -> Self: def match_updaters(self, mobject: OpenGLMobject) -> Self: self.clear_updaters() - for updater in mobject.get_updaters(): - self.add_updater(updater) + self.updater_wrappers = mobject.updater_wrappers.copy() return self def suspend_updating(self, recurse: bool = True) -> Self: diff --git a/manim/renderer/shader.py b/manim/renderer/shader.py index a098ed30ca..b1faacd665 100644 --- a/manim/renderer/shader.py +++ b/manim/renderer/shader.py @@ -1,16 +1,23 @@ from __future__ import annotations import contextlib -import inspect import re import textwrap +from collections.abc import Sequence from pathlib import Path +from typing import TYPE_CHECKING import moderngl import numpy as np -from .. import config -from ..utils import opengl +from manim import config +from manim.utils import opengl +from manim.utils.updaters import MeshUpdaterWrapper + +if TYPE_CHECKING: + from typing_extensions import Self + + from manim.utils.updaters import MeshDtUpdater, MeshUpdater SHADER_FOLDER = Path(__file__).parent / "shaders" shader_program_cache: dict = {} @@ -175,76 +182,90 @@ def hierarchical_normal_matrix(self): current_object = current_object.parent return np.linalg.multi_dot(list(reversed(normal_matrices)))[:3, :3] - def init_updaters(self): - self.time_based_updaters = [] - self.non_time_updaters = [] + # Updating + + @property + def updaters(self) -> Sequence[MeshUpdater]: + return self.get_updaters() + + def init_updaters(self) -> None: + self.updater_wrappers: Sequence[MeshUpdaterWrapper] = [] self.has_updaters = False self.updating_suspended = False - def update(self, dt=0): + def update(self, dt: float = 0) -> Self: if not self.has_updaters or self.updating_suspended: return self - for updater in self.time_based_updaters: - updater(self, dt) - for updater in self.non_time_updaters: - updater(self) + for wrapper in self.updater_wrappers: + if wrapper.is_time_based: + wrapper.updater(self, dt) + else: + wrapper.updater(self) return self - def get_time_based_updaters(self): - return self.time_based_updaters - - def has_time_based_updater(self): - return len(self.time_based_updaters) > 0 + def get_time_based_updaters(self) -> Sequence[MeshDtUpdater]: + return [ + wrapper.updater + for wrapper in self.updater_wrappers + if wrapper.is_time_based + ] - def get_updaters(self): - return self.time_based_updaters + self.non_time_updaters + def has_time_based_updater(self) -> bool: + return any(wrapper.is_time_based for wrapper in self.updater_wrappers) - def add_updater(self, update_function, index=None, call_updater=True): - if "dt" in inspect.signature(update_function).parameters: - updater_list = self.time_based_updaters - else: - updater_list = self.non_time_updaters + def get_updaters(self) -> Sequence[MeshUpdater]: + return [wrapper.updater for wrapper in self.updater_wrappers] + def add_updater( + self, + update_function: MeshUpdater, + index: int | None = None, + call_updater: bool = False, + ) -> Self: + wrapper = MeshUpdaterWrapper(update_function) if index is None: - updater_list.append(update_function) + self.updater_wrappers.append(wrapper) else: - updater_list.insert(index, update_function) + self.updater_wrappers.insert(index, wrapper) self.refresh_has_updater_status() if call_updater: self.update() return self - def remove_updater(self, update_function): - for updater_list in [self.time_based_updaters, self.non_time_updaters]: - while update_function in updater_list: - updater_list.remove(update_function) + def remove_updater(self, update_function: MeshUpdater) -> Self: + self.updater_wrappers = [ + wrapper + for wrapper in self.updater_wrappers + if wrapper.updater != update_function + ] self.refresh_has_updater_status() return self - def clear_updaters(self): - self.time_based_updaters = [] - self.non_time_updaters = [] + def clear_updaters(self, recurse: bool = True) -> Self: + self.updater_wrappers = [] self.refresh_has_updater_status() + if recurse: + for submob in self.submobjects: + submob.clear_updaters() return self - def match_updaters(self, mobject): + def match_updaters(self, obj: Object3D) -> Self: self.clear_updaters() - for updater in mobject.get_updaters(): - self.add_updater(updater) + self.updater_wrappers = obj.updater_wrappers.copy() return self - def suspend_updating(self): + def suspend_updating(self) -> Self: self.updating_suspended = True return self - def resume_updating(self, call_updater=True): + def resume_updating(self, call_updater: bool = True) -> Self: self.updating_suspended = False if call_updater: self.update(dt=0) return self - def refresh_has_updater_status(self): + def refresh_has_updater_status(self) -> Self: self.has_updaters = len(self.get_updaters()) > 0 return self diff --git a/manim/scene/scene.py b/manim/scene/scene.py index 02c548cf7f..fff9fe892f 100644 --- a/manim/scene/scene.py +++ b/manim/scene/scene.py @@ -33,28 +33,30 @@ from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer +from manim import config, logger +from manim.animation.animation import Animation, Wait, prepare_animation +from manim.camera.camera import Camera +from manim.constants import * +from manim.gui.gui import configure_pygui from manim.mobject.mobject import Mobject from manim.mobject.opengl.opengl_mobject import OpenGLPoint - -from .. import config, logger -from ..animation.animation import Animation, Wait, prepare_animation -from ..camera.camera import Camera -from ..constants import * -from ..gui.gui import configure_pygui -from ..renderer.cairo_renderer import CairoRenderer -from ..renderer.opengl_renderer import OpenGLRenderer -from ..renderer.shader import Object3D -from ..utils import opengl, space_ops -from ..utils.exceptions import EndSceneEarlyException, RerunSceneException -from ..utils.family import extract_mobject_family_members -from ..utils.family_ops import restructure_list_to_exclude_certain_family_members -from ..utils.file_ops import open_media_file -from ..utils.iterables import list_difference_update, list_update +from manim.renderer.cairo_renderer import CairoRenderer +from manim.renderer.opengl_renderer import OpenGLRenderer +from manim.renderer.shader import Object3D +from manim.utils import opengl, space_ops +from manim.utils.exceptions import EndSceneEarlyException, RerunSceneException +from manim.utils.family import extract_mobject_family_members +from manim.utils.family_ops import restructure_list_to_exclude_certain_family_members +from manim.utils.file_ops import open_media_file +from manim.utils.iterables import list_difference_update, list_update +from manim.utils.updaters import MobjectUpdaterWrapper if TYPE_CHECKING: from collections.abc import Iterable, Sequence from typing import Callable + from manim.utils.updaters import SceneUpdater + class RerunSceneHandler(FileSystemEventHandler): """A class to handle rerunning a Scene after the input file is modified.""" @@ -124,7 +126,7 @@ def __init__( self.camera_target = ORIGIN self.widgets = [] self.dearpygui_imported = dearpygui_imported - self.updaters = [] + self.updaters: Sequence[SceneUpdater] = [] self.point_lights = [] self.ambient_light = None self.key_to_function_map = {} @@ -168,12 +170,14 @@ def __deepcopy__(self, clone_from_id): if k == "camera_class": setattr(result, k, v) setattr(result, k, copy.deepcopy(v, clone_from_id)) + # TODO: where is this attribute even defined? result.mobject_updater_lists = [] # Update updaters for mobject in self.mobjects: - cloned_updaters = [] - for updater in mobject.updaters: + cloned_updater_wrappers = [] + for wrapper in mobject.updater_wrappers: + updater = wrapper.updater # Make the cloned updater use the cloned Mobjects as free variables # rather than the original ones. Analyzing function bytecode with the # dis module will help in understanding this. @@ -209,11 +213,13 @@ def __deepcopy__(self, clone_from_id): updater.__defaults__, tuple(cloned_closure), ) - cloned_updaters.append(cloned_updater) + cloned_updater_wrappers.append(MobjectUpdaterWrapper(cloned_updater)) mobject_clone = clone_from_id[id(mobject)] - mobject_clone.updaters = cloned_updaters - if len(cloned_updaters) > 0: - result.mobject_updater_lists.append((mobject_clone, cloned_updaters)) + mobject_clone.updater_wrappers = cloned_updater_wrappers + if len(cloned_updater_wrappers) > 0: + result.mobject_updater_lists.append( + (mobject_clone, cloned_updater_wrappers) + ) return result def render(self, preview: bool = False): @@ -572,7 +578,7 @@ def replace_in_list( if not replaced: raise ValueError(f"Could not find {old_mobject} in scene") - def add_updater(self, func: Callable[[float], None]) -> None: + def add_updater(self, func: SceneUpdater) -> None: """Add an update function to the scene. The scene updater functions are run every frame, @@ -603,7 +609,7 @@ def add_updater(self, func: Callable[[float], None]) -> None: """ self.updaters.append(func) - def remove_updater(self, func: Callable[[float], None]) -> None: + def remove_updater(self, func: SceneUpdater) -> None: """Remove an update function from the scene. Parameters diff --git a/manim/utils/updaters.py b/manim/utils/updaters.py new file mode 100644 index 0000000000..4916ba652d --- /dev/null +++ b/manim/utils/updaters.py @@ -0,0 +1,258 @@ +"""Updater types and classes. Updaters are functions which update an object +(a :class:`~.Mobject`, :class:`~.OpenGLMobject`, :class:`~.Object3D` or +:class:`~.Scene`) on every frame, and might accept an additional parameter +``dt`` which represents the frame's duration. +""" + +from __future__ import annotations + +import inspect +from collections.abc import Callable +from typing import TYPE_CHECKING, TypeVar, Union + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from manim.mobject.mobject import Mobject + from manim.mobject.opengl.opengl_mobject import OpenGLMobject + from manim.renderer.shader import Object3D + + +__all__ = [ + "MobjectBasicUpdater", + "MobjectDtUpdater", + "MobjectUpdater", + "MeshBasicUpdater", + "MeshDtUpdater", + "MeshUpdater", + "SceneUpdater", + "MobjectUpdaterWrapper", + "MeshUpdaterWrapper", +] + + +Mobj = TypeVar("Mobj", bound=Union["Mobject", "OpenGLMobject"]) + +MobjectBasicUpdater: TypeAlias = Callable[[Mobj], Mobj] +"""A function which updates a :class:`~.Mobject` or :class:`~.OpenGLMobject` on +every frame. +""" + +MobjectDtUpdater: TypeAlias = Callable[[Mobj, float], Mobj] +"""A function which updates a :class:`~.Mobject` or :class:`~.OpenGLMobject` on +every frame, also depending on the frame's duration ``dt``. +""" + +MobjectUpdater: TypeAlias = Union[MobjectBasicUpdater, MobjectDtUpdater] +"""A function which updates a :class:`~.Mobject` or :class:`~.OpenGLMobject` on +every frame, possibly depending on the frame's duration ``dt``. +""" + +MeshBasicUpdater: TypeAlias = Callable[["Object3D"], "Object3D"] +"""A function which updates an :class:`~.Object3D` on every frame.""" + +MeshDtUpdater: TypeAlias = Callable[["Object3D", float], "Object3D"] +"""A function which updates an :class:`~.Object3D` on every frame, also +depending on the frame's duration ``dt``. +""" + +MeshUpdater: TypeAlias = Union[MeshBasicUpdater, MeshDtUpdater] +"""A function which updates an :class:`~.Object3D` on every frame, possibly +depending on the frame's duration ``dt``. +""" + +SceneUpdater: TypeAlias = Callable[[float], None] +"""A function which updates a :class:`~.Scene` on every frame, depending on the +frame's duration ``dt``. +""" + + +class AbstractUpdaterWrapper: + """Base class for :class:`MobjectUpdaterWrapper` and + :class:`MeshUpdaterWrapper`. See :class:`MobjectUpdaterWrapper` for more + information. + + Parameters + ---------- + updater + An updater function, whose first parameter is either a + :class:`~.Mobject` or an :class:`~.Object3D` (parent of + :class:`~.Mesh`), and might optionally have a second parameter which + should be a ``float`` representing a time change ``dt``. This function + should return the same object in the 1st parameter after applying a + change on it. + + Attributes + ---------- + updater + The same updater function passed as a parameter. + is_time_based + Whether :attr:`updater` is a time-based updater or not. + + Raises + ------ + ValueError + If an updater is passed with 0 or more than 2 parameters with no + default values. + """ + + __slots__ = ["updater", "is_time_based"] + + def __init__(self, updater: MobjectUpdater | MeshUpdater): + self.updater = updater + + signature = inspect.signature(updater) + parameters = [str(param) for param in signature.parameters.values()] + + for i, param in enumerate(parameters): + # Stop when finding **kwargs or parameters with default values + if param.startswith("**") or "=" in param: + num_non_default_parameters = i + break + else: + num_non_default_parameters = len(parameters) + + if num_non_default_parameters == 0: + self._raise_error(0) + + # If this is a method being called from an instance, exclude the 1st + # parameter if it's called "self" + if inspect.ismethod(updater) and parameters[0] == "self": + num_non_default_parameters -= 1 + # Exclude the "cls" parameter from class methods + if parameters[0] == "cls": + num_non_default_parameters -= 1 + + # Handle functions containing *args, assuming that all can be passed + # a 2nd parameter dt + if 1 <= num_non_default_parameters <= 3 and parameters[-1].startswith("*"): + num_non_default_parameters = 2 + + if num_non_default_parameters == 1: + self.is_time_based = False + elif num_non_default_parameters == 2: + self.is_time_based = True + else: + self._raise_error(num_non_default_parameters) + + def _raise_error(self, num_non_default_parameters: int): + updater_name = self.updater.__qualname__ + signature = str(inspect.signature(self.updater)) + full_name = updater_name + signature + + if num_non_default_parameters == 0: + num_non_default_parameters = "no" + + raise ValueError( + "An updater function must accept either 1 or 2 parameters without " + "default values (not including 'self' or 'cls' for methods), but " + f"the function {full_name} has {num_non_default_parameters} such " + "parameters." + ) + + +class MobjectUpdaterWrapper(AbstractUpdaterWrapper): + """Wraps a :class:`MobjectUpdater` function, inspects its signature and + calculates whether it's time-based or not. + + If it has a single parameter (with no default value), it's considered a + :class:`MobjectBasicUpdater` which doesn't depend on time. + + If it has two parameters (with no default values), it's considered a + :class:`MobjectDtUpdater` which depends on time, and the affected + :class:`~.Mobject` has a change on every frame which depends on the frame's + duration ``dt``. + + .. note:: + It's not mandatory that the parameters are named ``mob`` and ``dt``. + + **Only parameters with no default values are considered in when determining + whether the updater is time-based or not.** For example, an updater + ``lambda mob, rate=5: ...`` is considered a :class:`MobjectBasicUpdater` + since the 2nd parameter ``rate`` has a default value of 5. + + .. note:: + The above rule allows for passing functions with more than 2 + parameters, as long as the extra parameters have default values. + + A ``ValueError`` is raised if a function is passed which has 0 or more than + 2 parameters with no default values. + + When passing an instance method, the first parameter ``self`` is excluded + from the count. When passing a class method, the first parameter ``cls`` is + also excluded. + + .. note :: + It is fine to call the 1st parameter ``self`` if the updater is not an + instance method: it will still be counted as a parameter. The rule + above only applies for instance methods. + + For example, ``lambda self: self.move_to(square)`` is a valid + :class:`MobjectBasicUpdater`, and ``lambda self, dt: self.rotate(dt)`` + is a :class:`MobjectDtUpdater`. + + However, it is not recommended to name the 1st parameter ``self`` if + the updater is not a method, because it is not Pythonic. A better + option is ``mob``. + + .. warning:: + Do **NOT** name the 1st parameter ``cls`` if the function is not a + class method. Otherwise, the updater will not be parsed correctly. + + Parameters + ---------- + updater + An updater function, whose first parameter is a :class:`~.Mobject` and + might optionally have a second parameter which should be a ``float`` + representing a time change ``dt``. This function should return the same + :class:`~.Mobject` after applying a change on it. + + Attributes + ---------- + updater + The same updater function passed as a parameter. + is_time_based + Whether :attr:`updater` is a time-based updater or not. + + Raises + ------ + ValueError + If an updater is passed with 0 or more than 2 parameters with no + default values. + """ + + def __init__(self, updater: MobjectUpdater): + super().__init__(updater) + + +class MeshUpdaterWrapper(AbstractUpdaterWrapper): + """Similar to :class:`MobjectUpdaterWrapper`, but wraps instead a + :class:`MeshUpdater` which modifies an :class:`~.Object3D`, parent of + :class:`~.Mesh`. See the docs for :class:`MobjectUpdaterWrapper` + for more information. + + Parameters + ---------- + updater + An updater function, whose first parameter is an :class:`~.Object3D` + (parent of :class:`~.Mesh`) and might optionally have a second parameter + which should be a ``float`` representing a time change ``dt``. This + function should return the same :class:`~.Object3D` after applying a + change on it. + + Attributes + ---------- + updater + The same updater function passed as a parameter. + is_time_based + Whether :attr:`updater` is a time-based updater or not. + + Raises + ------ + ValueError + If an updater is passed with 0 or more than 2 parameters with no + default values. + """ + + def __init__(self, updater: MeshUpdater): + super().__init__(updater) diff --git a/tests/module/utils/test_updaters.py b/tests/module/utils/test_updaters.py new file mode 100644 index 0000000000..42b07d4537 --- /dev/null +++ b/tests/module/utils/test_updaters.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import pytest + +from manim.constants import RIGHT +from manim.mobject.geometry.polygram import Square +from manim.mobject.graph import Graph +from manim.utils.updaters import MobjectUpdaterWrapper + + +def test_UpdaterWrapper() -> None: + square = Square().move_to(RIGHT) + + # Non-time-based updater: 1 parameter with no default value + wrapper = MobjectUpdaterWrapper(lambda mob: mob.next_to(square)) + assert not wrapper.is_time_based + + # Time-based updater: 2 parameters with no default value + wrapper = MobjectUpdaterWrapper(lambda mob, dt: mob.rotate(dt)) + assert wrapper.is_time_based + + # It's not necessary for the 2nd parameter to be called dt + wrapper = MobjectUpdaterWrapper(lambda mob, delta_time: mob.rotate(delta_time)) + assert wrapper.is_time_based + + # An updater can even have more than 2 parameters, as long as they have + # default values + wrapper = MobjectUpdaterWrapper(lambda mob, dt, rate=2: mob.rotate(rate * dt)) + assert wrapper.is_time_based + + # An updater cannot have no parameters + with pytest.raises(ValueError): + wrapper = MobjectUpdaterWrapper(lambda: square.move_to(RIGHT)) + + # An updater cannot have more than 2 parameters without a default value + with pytest.raises(ValueError): + wrapper = MobjectUpdaterWrapper(lambda mob, rate, third: mob.rotate(rate)) + + # Only parameters with no default value are considered when determining + # whether the updater is time-based or not. If an updater has 2 parameters, + # but the 2nd one has a default value, it's considered non-time-based. + wrapper = MobjectUpdaterWrapper(lambda mob, other=square: mob.next_to(other)) + assert not wrapper.is_time_based + + # When using an instance method, the first argument is ignored if it's + # called 'self'. This is an attempt to exclude static methods from this + # rule. + graph = Graph([1, 2], [(1, 2)]) + wrapper = MobjectUpdaterWrapper(graph.update_edges) # signature: (self, graph) + assert not wrapper.is_time_based # since only the 'graph' param is counted + + # This doesn't happen when calling the method from the class rather than + # from an instance. + wrapper = MobjectUpdaterWrapper(Graph.update_edges) # signature: (self, graph) + # 'self' is the 1st, 'graph' is the 2nd, (incorrectly) considered as time + assert wrapper.is_time_based + + # In general, if the function is not an instance method, the 1st parameter + # is almost always included, even if it's called "self". + wrapper = MobjectUpdaterWrapper(lambda self: self.move_to(square)) + assert not wrapper.is_time_based + wrapper = MobjectUpdaterWrapper(lambda self, dt: self.rotate(dt)) + assert wrapper.is_time_based + + # The only exception is if it's called "cls". Don't call it "cls" if it's + # not a class method. + wrapper = MobjectUpdaterWrapper(lambda cls, dt: cls.rotate(dt)) + assert ( + not wrapper.is_time_based + ) # Only 1 parameter, dt, is considered, and it's used as a Mobject, not float + with pytest.raises(ValueError): + # Since cls is excluded, there are no other parameters + wrapper = MobjectUpdaterWrapper(lambda cls: cls.next_to(square))