From 090042c25d2f6c7cf6b6f54930cd8bd4e5fd0cca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Manr=C3=ADquez?= Date: Sat, 15 Jun 2024 22:48:11 -0400 Subject: [PATCH 01/10] Fix updaters, add UpdaterWrappers and some type aliases --- .../source/reference_index/utilities_misc.rst | 1 + manim/animation/speedmodifier.py | 17 +- manim/mobject/mobject.py | 90 +++++------ manim/mobject/opengl/opengl_mobject.py | 87 +++++----- manim/renderer/shader.py | 97 ++++++----- manim/scene/scene.py | 56 ++++--- manim/utils/updaters.py | 153 ++++++++++++++++++ tests/module/utils/test_updaters.py | 63 ++++++++ 8 files changed, 409 insertions(+), 155 deletions(-) create mode 100644 manim/utils/updaters.py create mode 100644 tests/module/utils/test_updaters.py diff --git a/docs/source/reference_index/utilities_misc.rst b/docs/source/reference_index/utilities_misc.rst index 1fe9962079..5e7bee44ca 100644 --- a/docs/source/reference_index/utilities_misc.rst +++ b/docs/source/reference_index/utilities_misc.rst @@ -31,3 +31,4 @@ Module Index ~utils.tex_file_writing ~utils.tex_templates typing + ~utils.updaters diff --git a/manim/animation/speedmodifier.py b/manim/animation/speedmodifier.py index 63b9b2e5b3..e28bf883e9 100644 --- a/manim/animation/speedmodifier.py +++ b/manim/animation/speedmodifier.py @@ -2,19 +2,19 @@ from __future__ import annotations -import inspect import types from typing import TYPE_CHECKING, Callable from numpy import piecewise -from ..animation.animation import Animation, Wait, prepare_animation -from ..animation.composition import AnimationGroup -from ..mobject.mobject import Mobject, _AnimationBuilder -from ..scene.scene import Scene +from manim.animation.animation import Animation, Wait, prepare_animation +from manim.animation.composition import AnimationGroup +from manim.mobject.mobject import Mobject, _AnimationBuilder +from manim.scene.scene import Scene +from manim.utils.updaters import MobjectUpdaterWrapper if TYPE_CHECKING: - from ..mobject.mobject import Updater + from manim.utils.updaters import MobjectUpdater __all__ = ["ChangeSpeed"] @@ -235,7 +235,7 @@ def get_scaled_total_time(self) -> float: def add_updater( cls, mobject: Mobject, - update_function: Updater, + update_function: MobjectUpdater, index: int | None = None, call_updater: bool = False, ): @@ -264,7 +264,8 @@ def add_updater( :class:`.ChangeSpeed` :meth:`.Mobject.add_updater` """ - if "dt" in inspect.signature(update_function).parameters: + wrapper = MobjectUpdaterWrapper(update_function) + if wrapper.is_time_based: mobject.add_updater( lambda mob, dt: update_function( mob, ChangeSpeed.dt if ChangeSpeed.is_changing_dt else dt diff --git a/manim/mobject/mobject.py b/manim/mobject/mobject.py index 6e405a18eb..ce5ff40a6f 100644 --- a/manim/mobject/mobject.py +++ b/manim/mobject/mobject.py @@ -6,7 +6,6 @@ import copy -import inspect import itertools as it import math import operator as op @@ -14,18 +13,17 @@ import sys import types import warnings -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from functools import partialmethod, reduce from pathlib import Path from typing import TYPE_CHECKING, Callable, Literal import numpy as np +from manim import config, logger +from manim.constants import * from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL - -from .. import config, logger -from ..constants import * -from ..utils.color import ( +from manim.utils.color import ( BLACK, WHITE, YELLOW_C, @@ -34,13 +32,14 @@ color_gradient, interpolate_color, ) -from ..utils.exceptions import MultiAnimationOverrideException -from ..utils.iterables import list_update, remove_list_redundancies -from ..utils.paths import straight_path -from ..utils.space_ops import angle_between_vectors, normalize, rotation_matrix +from manim.utils.exceptions import MultiAnimationOverrideException +from manim.utils.iterables import list_update, remove_list_redundancies +from manim.utils.paths import straight_path +from manim.utils.space_ops import angle_between_vectors, normalize, rotation_matrix +from manim.utils.updaters import MobjectUpdaterWrapper if TYPE_CHECKING: - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from manim.typing import ( FunctionOverride, @@ -53,13 +52,10 @@ Point3D_Array, Vector3D, ) + from manim.utils.updaters import MobjectTimeBasedUpdater, MobjectUpdater from ..animation.animation import Animation - TimeBasedUpdater: TypeAlias = Callable[["Mobject", float], object] - NonTimeBasedUpdater: TypeAlias = Callable[["Mobject"], object] - Updater: TypeAlias = NonTimeBasedUpdater | TimeBasedUpdater - class Mobject: """Mathematical Object: base class for objects that can be displayed on screen. @@ -71,6 +67,7 @@ class Mobject: Attributes ---------- submobjects : List[:class:`Mobject`] + The contained objects. points : :class:`numpy.ndarray` The points of the objects. @@ -96,7 +93,7 @@ def __init_subclass__(cls, **kwargs) -> None: def __init__( self, - color: ParsableManimColor | list[ParsableManimColor] = WHITE, + color: ParsableManimColor | Sequence[ParsableManimColor] = WHITE, name: str | None = None, dim: int = 3, target=None, @@ -108,7 +105,7 @@ def __init__( self.z_index = z_index self.point_hash = None self.submobjects = [] - self.updaters: list[Updater] = [] + self.updater_wrappers: Sequence[MobjectUpdaterWrapper] = [] self.updating_suspended = False self.color = ManimColor.parse(color) @@ -868,6 +865,10 @@ def generate_target(self, use_deepcopy: bool = False) -> Self: # Updating + @property + def updaters(self) -> Sequence[MobjectUpdater]: + return self.get_updaters() + def update(self, dt: float = 0, recursive: bool = True) -> Self: """Apply all updaters. @@ -894,17 +895,17 @@ def update(self, dt: float = 0, recursive: bool = True) -> Self: """ if self.updating_suspended: return self - for updater in self.updaters: - if "dt" in inspect.signature(updater).parameters: - updater(self, dt) + for wrapper in self.updater_wrappers: + if wrapper.is_time_based: + wrapper.updater(self, dt) else: - updater(self) + wrapper.updater(self) if recursive: for submob in self.submobjects: submob.update(dt, recursive) return self - def get_time_based_updaters(self) -> list[TimeBasedUpdater]: + def get_time_based_updaters(self) -> Sequence[MobjectTimeBasedUpdater]: """Return all updaters using the ``dt`` parameter. The updaters use this parameter as the input for difference in time. @@ -920,11 +921,7 @@ def get_time_based_updaters(self) -> list[TimeBasedUpdater]: :meth:`has_time_based_updater` """ - return [ - updater - for updater in self.updaters - if "dt" in inspect.signature(updater).parameters - ] + return [wrapper.updater for wrapper in self.updater_wrappers if t.is_time_based] def has_time_based_updater(self) -> bool: """Test if ``self`` has a time based updater. @@ -940,11 +937,9 @@ def has_time_based_updater(self) -> bool: :meth:`get_time_based_updaters` """ - return any( - "dt" in inspect.signature(updater).parameters for updater in self.updaters - ) + return any(wrapper.is_time_based for wrapper in self.updater_wrappers) - def get_updaters(self) -> list[Updater]: + def get_updaters(self) -> Sequence[MobjectUpdater]: """Return all updaters. Returns @@ -958,14 +953,14 @@ def get_updaters(self) -> list[Updater]: :meth:`get_time_based_updaters` """ - return self.updaters + return [wrapper.updater for wrapper in self.updater_wrappers] - def get_family_updaters(self) -> list[Updater]: + def get_family_updaters(self) -> Sequence[MobjectUpdater]: return list(it.chain(*(sm.get_updaters() for sm in self.get_family()))) def add_updater( self, - update_function: Updater, + update_function: MobjectUpdater, index: int | None = None, call_updater: bool = False, ) -> Self: @@ -1029,20 +1024,19 @@ def construct(self): :meth:`remove_updater` :class:`~.UpdateFromFunc` """ - + wrapper = MobjectUpdaterWrapper(update_function) if index is None: - self.updaters.append(update_function) + self.updater_wrappers.append(wrapper) else: - self.updaters.insert(index, update_function) + self.updater_wrappers.insert(index, wrapper) if call_updater: - parameters = inspect.signature(update_function).parameters - if "dt" in parameters: - update_function(self, 0) + if wrapper.is_time_based: + wrapper.updater(self, 0) else: - update_function(self) + wrapper.updater(self) return self - def remove_updater(self, update_function: Updater) -> Self: + def remove_updater(self, update_function: MobjectUpdater) -> Self: """Remove an updater. If the same updater is applied multiple times, every instance gets removed. @@ -1065,8 +1059,11 @@ def remove_updater(self, update_function: Updater) -> Self: :meth:`get_updaters` """ - while update_function in self.updaters: - self.updaters.remove(update_function) + self.updater_wrappers = [ + wrapper + for wrapper in self.updater_wrappers + if wrapper.updater != update_function + ] return self def clear_updaters(self, recursive: bool = True) -> Self: @@ -1089,7 +1086,7 @@ def clear_updaters(self, recursive: bool = True) -> Self: :meth:`get_updaters` """ - self.updaters = [] + self.updater_wrappers = [] if recursive: for submob in self.submobjects: submob.clear_updaters() @@ -1121,8 +1118,7 @@ def match_updaters(self, mobject: Mobject) -> Self: """ self.clear_updaters() - for updater in mobject.get_updaters(): - self.add_updater(updater) + self.updater_wrappers = mobject.updater_wrappers.copy() return self def suspend_updating(self, recursive: bool = True) -> Self: diff --git a/manim/mobject/opengl/opengl_mobject.py b/manim/mobject/opengl/opengl_mobject.py index 932b1d0d10..b08a827e4b 100644 --- a/manim/mobject/opengl/opengl_mobject.py +++ b/manim/mobject/opengl/opengl_mobject.py @@ -1,13 +1,13 @@ from __future__ import annotations import copy -import inspect import itertools as it import random import sys from collections.abc import Iterable, Sequence from functools import partialmethod, wraps from math import ceil +from typing import TYPE_CHECKING import moderngl import numpy as np @@ -43,6 +43,10 @@ normalize, rotation_matrix_transpose, ) +from manim.utils.updaters import MobjectUpdaterWrapper + +if TYPE_CHECKING: + from manim.utils.updaters import MobjectTimeBasedUpdater, MobjectUpdater def affects_shader_info_id(func): @@ -1390,82 +1394,91 @@ def restore(self): # Updating - def init_updaters(self): - self.time_based_updaters = [] - self.non_time_updaters = [] + def init_updaters(self) -> None: + self.updater_wrappers: Sequence[MobjectUpdaterWrapper] = [] self.has_updaters = False self.updating_suspended = False - def update(self, dt=0, recurse=True): + @property + def updaters(self) -> Sequence[MobjectUpdater]: + return self.get_updaters() + + def update(self, dt: float = 0, recurse: bool = True) -> Self: if not self.has_updaters or self.updating_suspended: return self - for updater in self.time_based_updaters: - updater(self, dt) - for updater in self.non_time_updaters: - updater(self) + for wrapper in self.updater_wrappers: + if wrapper.is_time_based: + wrapper.updater(self, dt) + else: + wrapper.updater(self) if recurse: for submob in self.submobjects: submob.update(dt, recurse) return self - def get_time_based_updaters(self): - return self.time_based_updaters + def get_time_based_updaters(self) -> Sequence[MobjectTimeBasedUpdater]: + return [ + wrapper.updater + for wrapper in self.updater_wrappers + if wrapper.is_time_based + ] - def has_time_based_updater(self): - return len(self.time_based_updaters) > 0 + def has_time_based_updater(self) -> bool: + return any(wrapper.is_time_based for wrapper in self.updater_wrappers) - def get_updaters(self): - return self.time_based_updaters + self.non_time_updaters + def get_updaters(self) -> Sequence[MobjectUpdaterWrapper]: + return [wrapper.updater for wrapper in self.updater_wrappers] - def get_family_updaters(self): + def get_family_updaters(self) -> Sequence[MobjectUpdaterWrapper]: return list(it.chain(*(sm.get_updaters() for sm in self.get_family()))) - def add_updater(self, update_function, index=None, call_updater=False): - if "dt" in inspect.signature(update_function).parameters: - updater_list = self.time_based_updaters - else: - updater_list = self.non_time_updaters - + def add_updater( + self, + update_function: MobjectUpdater, + index: int | None = None, + call_updater: bool = False, + ) -> Self: + wrapper = MobjectUpdaterWrapper(update_function) if index is None: - updater_list.append(update_function) + self.updater_wrappers.append(wrapper) else: - updater_list.insert(index, update_function) + self.updater_wrappers.insert(index, wrapper) self.refresh_has_updater_status() if call_updater: self.update() return self - def remove_updater(self, update_function): - for updater_list in [self.time_based_updaters, self.non_time_updaters]: - while update_function in updater_list: - updater_list.remove(update_function) + def remove_updater(self, update_function: MobjectUpdater) -> Self: + self.updater_wrappers = [ + wrapper + for wrapper in self.updater_wrappers + if wrapper.updater != update_function + ] self.refresh_has_updater_status() return self - def clear_updaters(self, recurse=True): - self.time_based_updaters = [] - self.non_time_updaters = [] + def clear_updaters(self, recurse: bool = True) -> Self: + self.updater_wrappers = [] self.refresh_has_updater_status() if recurse: for submob in self.submobjects: submob.clear_updaters() return self - def match_updaters(self, mobject): + def match_updaters(self, mobject: OpenGLMobject) -> Self: self.clear_updaters() - for updater in mobject.get_updaters(): - self.add_updater(updater) + self.updater_wrappers = mobject.updater_wrappers.copy() return self - def suspend_updating(self, recurse=True): + def suspend_updating(self, recurse: bool = True) -> Self: self.updating_suspended = True if recurse: for submob in self.submobjects: submob.suspend_updating(recurse) return self - def resume_updating(self, recurse=True, call_updater=True): + def resume_updating(self, recurse: bool = True, call_updater: bool = True) -> Self: self.updating_suspended = False if recurse: for submob in self.submobjects: @@ -1476,7 +1489,7 @@ def resume_updating(self, recurse=True, call_updater=True): self.update(dt=0, recurse=recurse) return self - def refresh_has_updater_status(self): + def refresh_has_updater_status(self) -> Self: self.has_updaters = any(mob.get_updaters() for mob in self.get_family()) return self diff --git a/manim/renderer/shader.py b/manim/renderer/shader.py index 85b9dad14a..7b0078d5cc 100644 --- a/manim/renderer/shader.py +++ b/manim/renderer/shader.py @@ -1,15 +1,22 @@ from __future__ import annotations -import inspect import re import textwrap +from collections.abc import Sequence from pathlib import Path +from typing import TYPE_CHECKING import moderngl import numpy as np -from .. import config -from ..utils import opengl +from manim import config +from manim.utils import opengl +from manim.utils.updaters import MeshUpdaterWrapper + +if TYPE_CHECKING: + from typing_extensions import Self + + from manim.utils.updaters import MeshTimeBasedUpdater, MeshUpdater SHADER_FOLDER = Path(__file__).parent / "shaders" shader_program_cache: dict = {} @@ -174,76 +181,90 @@ def hierarchical_normal_matrix(self): current_object = current_object.parent return np.linalg.multi_dot(list(reversed(normal_matrices)))[:3, :3] - def init_updaters(self): - self.time_based_updaters = [] - self.non_time_updaters = [] + # Updating + + @property + def updaters(self) -> Sequence[MeshUpdater]: + return self.get_updaters() + + def init_updaters(self) -> None: + self.updater_wrappers: Sequence[MeshUpdaterWrapper] = [] self.has_updaters = False self.updating_suspended = False - def update(self, dt=0): + def update(self, dt: float = 0) -> Self: if not self.has_updaters or self.updating_suspended: return self - for updater in self.time_based_updaters: - updater(self, dt) - for updater in self.non_time_updaters: - updater(self) + for wrapper in self.updater_wrappers: + if wrapper.is_time_based: + wrapper.updater(self, dt) + else: + wrapper.updater(self) return self - def get_time_based_updaters(self): - return self.time_based_updaters - - def has_time_based_updater(self): - return len(self.time_based_updaters) > 0 + def get_time_based_updaters(self) -> Sequence[MeshTimeBasedUpdater]: + return [ + wrapper.updater + for wrapper in self.updater_wrappers + if wrapper.is_time_based + ] - def get_updaters(self): - return self.time_based_updaters + self.non_time_updaters + def has_time_based_updater(self) -> bool: + return any(wrapper.is_time_based for wrapper in self.updater_wrappers) - def add_updater(self, update_function, index=None, call_updater=True): - if "dt" in inspect.signature(update_function).parameters: - updater_list = self.time_based_updaters - else: - updater_list = self.non_time_updaters + def get_updaters(self) -> Sequence[MeshUpdater]: + return [wrapper.updater for wrapper in self.updater_wrappers] + def add_updater( + self, + update_function: MeshUpdater, + index: int | None = None, + call_updater: bool = False, + ) -> Self: + wrapper = MeshUpdaterWrapper(update_function) if index is None: - updater_list.append(update_function) + self.updater_wrappers.append(wrapper) else: - updater_list.insert(index, update_function) + self.updater_wrappers.insert(index, wrapper) self.refresh_has_updater_status() if call_updater: self.update() return self - def remove_updater(self, update_function): - for updater_list in [self.time_based_updaters, self.non_time_updaters]: - while update_function in updater_list: - updater_list.remove(update_function) + def remove_updater(self, update_function: MeshUpdater) -> Self: + self.updater_wrappers = [ + wrapper + for wrapper in self.updater_wrappers + if wrapper.updater != update_function + ] self.refresh_has_updater_status() return self - def clear_updaters(self): - self.time_based_updaters = [] - self.non_time_updaters = [] + def clear_updaters(self, recurse: bool = True) -> Self: + self.updater_wrappers = [] self.refresh_has_updater_status() + if recurse: + for submob in self.submobjects: + submob.clear_updaters() return self - def match_updaters(self, mobject): + def match_updaters(self, obj: Object3D) -> Self: self.clear_updaters() - for updater in mobject.get_updaters(): - self.add_updater(updater) + self.updater_wrappers = obj.updater_wrappers.copy() return self - def suspend_updating(self): + def suspend_updating(self) -> Self: self.updating_suspended = True return self - def resume_updating(self, call_updater=True): + def resume_updating(self, call_updater: bool = True) -> Self: self.updating_suspended = False if call_updater: self.update(dt=0) return self - def refresh_has_updater_status(self): + def refresh_has_updater_status(self) -> Self: self.has_updaters = len(self.get_updaters()) > 0 return self diff --git a/manim/scene/scene.py b/manim/scene/scene.py index 3f80c91864..a13eab161a 100644 --- a/manim/scene/scene.py +++ b/manim/scene/scene.py @@ -33,28 +33,30 @@ from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer +from manim import config, logger +from manim.animation.animation import Animation, Wait, prepare_animation +from manim.camera.camera import Camera +from manim.constants import * +from manim.gui.gui import configure_pygui from manim.mobject.mobject import Mobject from manim.mobject.opengl.opengl_mobject import OpenGLPoint - -from .. import config, logger -from ..animation.animation import Animation, Wait, prepare_animation -from ..camera.camera import Camera -from ..constants import * -from ..gui.gui import configure_pygui -from ..renderer.cairo_renderer import CairoRenderer -from ..renderer.opengl_renderer import OpenGLRenderer -from ..renderer.shader import Object3D -from ..utils import opengl, space_ops -from ..utils.exceptions import EndSceneEarlyException, RerunSceneException -from ..utils.family import extract_mobject_family_members -from ..utils.family_ops import restructure_list_to_exclude_certain_family_members -from ..utils.file_ops import open_media_file -from ..utils.iterables import list_difference_update, list_update +from manim.renderer.cairo_renderer import CairoRenderer +from manim.renderer.opengl_renderer import OpenGLRenderer +from manim.renderer.shader import Object3D +from manim.utils import opengl, space_ops +from manim.utils.exceptions import EndSceneEarlyException, RerunSceneException +from manim.utils.family import extract_mobject_family_members +from manim.utils.family_ops import restructure_list_to_exclude_certain_family_members +from manim.utils.file_ops import open_media_file +from manim.utils.iterables import list_difference_update, list_update +from manim.utils.updaters import MobjectUpdaterWrapper if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Sequence from typing import Callable + from manim.utils.updaters import SceneUpdater + class RerunSceneHandler(FileSystemEventHandler): """A class to handle rerunning a Scene after the input file is modified.""" @@ -124,7 +126,7 @@ def __init__( self.camera_target = ORIGIN self.widgets = [] self.dearpygui_imported = dearpygui_imported - self.updaters = [] + self.updaters: Sequence[SceneUpdater] = [] self.point_lights = [] self.ambient_light = None self.key_to_function_map = {} @@ -168,12 +170,14 @@ def __deepcopy__(self, clone_from_id): if k == "camera_class": setattr(result, k, v) setattr(result, k, copy.deepcopy(v, clone_from_id)) + # TODO: where is this attribute even defined? result.mobject_updater_lists = [] # Update updaters for mobject in self.mobjects: - cloned_updaters = [] - for updater in mobject.updaters: + cloned_updater_wrappers = [] + for wrapper in mobject.updater_wrappers: + updater = wrapper.updater # Make the cloned updater use the cloned Mobjects as free variables # rather than the original ones. Analyzing function bytecode with the # dis module will help in understanding this. @@ -209,11 +213,13 @@ def __deepcopy__(self, clone_from_id): updater.__defaults__, tuple(cloned_closure), ) - cloned_updaters.append(cloned_updater) + cloned_updater_wrappers.append(MobjectUpdaterWrapper(cloned_updater)) mobject_clone = clone_from_id[id(mobject)] - mobject_clone.updaters = cloned_updaters - if len(cloned_updaters) > 0: - result.mobject_updater_lists.append((mobject_clone, cloned_updaters)) + mobject_clone.updater_wrappers = cloned_updater_wrappers + if len(cloned_updater_wrappers) > 0: + result.mobject_updater_lists.append( + (mobject_clone, cloned_updater_wrappers) + ) return result def render(self, preview: bool = False): @@ -572,7 +578,7 @@ def replace_in_list( if not replaced: raise ValueError(f"Could not find {old_mobject} in scene") - def add_updater(self, func: Callable[[float], None]) -> None: + def add_updater(self, func: SceneUpdater) -> None: """Add an update function to the scene. The scene updater functions are run every frame, @@ -603,7 +609,7 @@ def add_updater(self, func: Callable[[float], None]) -> None: """ self.updaters.append(func) - def remove_updater(self, func: Callable[[float], None]) -> None: + def remove_updater(self, func: SceneUpdater) -> None: """Remove an update function from the scene. Parameters diff --git a/manim/utils/updaters.py b/manim/utils/updaters.py new file mode 100644 index 0000000000..cd1e706e49 --- /dev/null +++ b/manim/utils/updaters.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import inspect +from collections.abc import Callable +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from manim.mobject.mobject import Mobject + from manim.opengl.shader import Object3D + + +__all__ = [ + "MobjectTimeBasedUpdater", + "MobjectNonTimeBasedUpdater", + "MobjectUpdater", + "MeshTimeBasedUpdater", + "MeshNonTimeBasedUpdater", + "MeshUpdater", + "SceneUpdater", + "MobjectUpdaterWrapper", + "MeshUpdaterWrapper", +] + + +MobjectTimeBasedUpdater: TypeAlias = Callable[["Mobject", float], "Mobject"] +MobjectNonTimeBasedUpdater: TypeAlias = Callable[["Mobject"], "Mobject"] +MobjectUpdater: TypeAlias = MobjectNonTimeBasedUpdater | MobjectTimeBasedUpdater + +MeshTimeBasedUpdater: TypeAlias = Callable[["Object3D", float], "Object3D"] +MeshNonTimeBasedUpdater: TypeAlias = Callable[["Object3D"], "Object3D"] +MeshUpdater: TypeAlias = MeshNonTimeBasedUpdater | MeshTimeBasedUpdater + +SceneUpdater: TypeAlias = Callable[[float], None] + + +class AbstractUpdaterWrapper: + """Wraps an :class:`Updater` function, inspects its signature and + calculates whether it's time-based or not. + + If it has a single parameter (with no default value), it's considered + non-time-based: it doesn't depend on time. + + If it has two parameters (with no default values), it's considered + time-based: it depends on time, and the affected Mobject has a change + on every frame which depends on the frame's duration dt. + + .. note:: + It's not mandatory that the parameters are named ``mob`` and ``dt``. + + **Only parameters with no default values are considered in when determining + whether the updater is time-based or not.** For example, an updater + ``lambda mob, rate=5: ...`` is considered non-time-based since the 2nd + parameter ``rate`` has a default value of 5. **This allows for passing + functions with more than 2 parameters, as long as the extra parameters have + default values.** + + A ``ValueError`` is raised if a function is passed which has 0 or more than + 2 parameters with no default values. + + When passing an instance method, the first parameter `self` is excluded + from the count. When passing a class method, the first parameter `cls` is + also excluded. + + .. note :: + It is fine to call the 1st parameter ``self`` if the updater is not an + instance method: it will still be counted as a parameter. The rule + above only applies for instance methods. For example, + ``lambda self: self.move_to(square)`` is a valid non-time-based + updater, and ``lambda self, dt: self.rotate(dt)`` is time-based. + + .. warning:: + Do **NOT** name the 1st parameter ``cls`` if the function is not a class + method. + + Attributes + ---------- + updater + An updater function, whose first parameter is a Mobject and might + optionally have a second parameter which should be a float + representing a time change dt. + + Raises + ------ + ValueError + If an updater is passed with 0 or more than 2 parameters with no + default values. + """ + + __slots__ = ["updater", "is_time_based"] + + def __init__(self, updater: MobjectUpdater | MeshUpdater): + self.updater = updater + + signature = inspect.signature(updater) + parameters = [str(param) for param in signature.parameters.values()] + + for i, param in enumerate(parameters): + # Stop when finding **kwargs or parameters with default values + if param.startswith("**") or "=" in param: + num_non_default_parameters = i + break + else: + num_non_default_parameters = len(parameters) + + if num_non_default_parameters == 0: + self._raise_error(0) + + # If this is a method being called from an instance, exclude the 1st + # parameter if it's called "self" + if inspect.ismethod(updater) and parameters[0] == "self": + num_non_default_parameters -= 1 + # Exclude the "cls" parameter from class methods + if parameters[0] == "cls": + num_non_default_parameters -= 1 + + # Handle functions containing *args, assuming that all can be passed + # a 2nd parameter dt + if 1 <= num_non_default_parameters <= 3 and parameters[-1].startswith("*"): + num_non_default_parameters = 2 + + if num_non_default_parameters == 1: + self.is_time_based = False + elif num_non_default_parameters == 2: + self.is_time_based = True + else: + self._raise_error(num_non_default_parameters) + + def _raise_error(self, num_non_default_parameters: int): + updater_name = self.updater.__code__.co_qualname + signature = str(inspect.signature(self.updater)) + full_name = updater_name + signature + + if num_non_default_parameters == 0: + num_non_default_parameters = "no" + + raise ValueError( + "An updater function must accept either 1 or 2 parameters without " + "default values (not including 'self' or 'cls' for methods), but " + f"the function {full_name} has {num_non_default_parameters} such " + "parameters." + ) + + +class MobjectUpdaterWrapper(AbstractUpdaterWrapper): + def __init__(self, updater: MobjectUpdater): + super().__init__(updater) + + +class MeshUpdaterWrapper(AbstractUpdaterWrapper): + def __init__(self, updater: MeshUpdater): + super().__init__(updater) diff --git a/tests/module/utils/test_updaters.py b/tests/module/utils/test_updaters.py new file mode 100644 index 0000000000..2e19fe63de --- /dev/null +++ b/tests/module/utils/test_updaters.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import pytest + +from manim.constants import RIGHT +from manim.mobject.geometry.polygram import Square +from manim.mobject.graph import Graph +from manim.utils.updaters import MobjectUpdaterWrapper + + +def test_UpdaterWrapper() -> None: + square = Square().move_to(RIGHT) + + # Non-time-based updater: 1 parameter with no default value + wrapper = MobjectUpdaterWrapper(lambda mob: mob.next_to(square)) + assert not wrapper.is_time_based + + # Time-based updater: 2 parameters with no default value + wrapper = MobjectUpdaterWrapper(lambda mob, dt: mob.rotate(dt)) + assert wrapper.is_time_based + + # It's not necessary for the 2nd parameter to be called dt + wrapper = MobjectUpdaterWrapper(lambda mob, delta_time: mob.rotate(delta_time)) + assert wrapper.is_time_based + + # An updater can even have more than 2 parameters, as long as they have + # default values + wrapper = MobjectUpdaterWrapper(lambda mob, dt, rate=2: mob.rotate(rate * dt)) + assert wrapper.is_time_based + + # An updater cannot have no parameters + with pytest.raises(ValueError) as no_parameters_info: + wrapper = MobjectUpdaterWrapper(lambda: "Hello world") + + # An updater cannot have more than 2 parameters without a default value + with pytest.raises(ValueError) as three_parameters_info: + wrapper = MobjectUpdaterWrapper(lambda mob, rate, third: mob.rotate(rate)) + + # Only parameters with no default value are considered when determining + # whether the updater is time-based or not. If an updater has 2 parameters, + # but the 2nd one has a default value, it's considered non-time-based. + wrapper = MobjectUpdaterWrapper(lambda mob, other=square: mob.next_to(other)) + assert not wrapper.is_time_based + + # When using an instance method, the first argument is ignored if it's + # called 'self'. This is an attempt to exclude static methods from this + # rule. + graph = Graph([1, 2], [(1, 2)]) + wrapper = MobjectUpdaterWrapper(graph.update_edges) # signature: (self, graph) + assert not wrapper.is_time_based # since only the 'graph' param is counted + + # This doesn't happen when calling the method from the class rather than + # from an instance. + wrapper = MobjectUpdaterWrapper(Graph.update_edges) # signature: (self, graph) + # 'self' is the 1st, 'graph' is the 2nd, (incorrectly) considered as time + assert wrapper.is_time_based + + # In general, if the function is not an instance method, the 1st parameter + # is always included, even if it's called "self". + wrapper = MobjectUpdaterWrapper(lambda self: self.move_to(square)) + assert not wrapper.is_time_based + wrapper = MobjectUpdaterWrapper(lambda self, dt: self.rotate(dt)) + assert wrapper.is_time_based From f97929cf5307bc5f0e0eb44112325e880303870d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Manr=C3=ADquez?= Date: Sat, 15 Jun 2024 23:49:33 -0400 Subject: [PATCH 02/10] Address errors --- manim/utils/updaters.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/manim/utils/updaters.py b/manim/utils/updaters.py index cd1e706e49..aa72968112 100644 --- a/manim/utils/updaters.py +++ b/manim/utils/updaters.py @@ -2,7 +2,7 @@ import inspect from collections.abc import Callable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -26,11 +26,11 @@ MobjectTimeBasedUpdater: TypeAlias = Callable[["Mobject", float], "Mobject"] MobjectNonTimeBasedUpdater: TypeAlias = Callable[["Mobject"], "Mobject"] -MobjectUpdater: TypeAlias = MobjectNonTimeBasedUpdater | MobjectTimeBasedUpdater +MobjectUpdater: TypeAlias = Union[MobjectNonTimeBasedUpdater, MobjectTimeBasedUpdater] MeshTimeBasedUpdater: TypeAlias = Callable[["Object3D", float], "Object3D"] MeshNonTimeBasedUpdater: TypeAlias = Callable[["Object3D"], "Object3D"] -MeshUpdater: TypeAlias = MeshNonTimeBasedUpdater | MeshTimeBasedUpdater +MeshUpdater: TypeAlias = Union[MeshNonTimeBasedUpdater, MeshTimeBasedUpdater] SceneUpdater: TypeAlias = Callable[[float], None] @@ -128,7 +128,7 @@ def __init__(self, updater: MobjectUpdater | MeshUpdater): self._raise_error(num_non_default_parameters) def _raise_error(self, num_non_default_parameters: int): - updater_name = self.updater.__code__.co_qualname + updater_name = self.updater.__qualname__ signature = str(inspect.signature(self.updater)) full_name = updater_name + signature From 8165502fdba87b748b03e04d8b4001e04912343b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Manr=C3=ADquez?= Date: Sun, 16 Jun 2024 00:21:58 -0400 Subject: [PATCH 03/10] Move docstrings to MobjectUpdaterWrapper --- manim/utils/updaters.py | 104 ++++++++++++++++++++-------------------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/manim/utils/updaters.py b/manim/utils/updaters.py index aa72968112..b285098493 100644 --- a/manim/utils/updaters.py +++ b/manim/utils/updaters.py @@ -36,58 +36,6 @@ class AbstractUpdaterWrapper: - """Wraps an :class:`Updater` function, inspects its signature and - calculates whether it's time-based or not. - - If it has a single parameter (with no default value), it's considered - non-time-based: it doesn't depend on time. - - If it has two parameters (with no default values), it's considered - time-based: it depends on time, and the affected Mobject has a change - on every frame which depends on the frame's duration dt. - - .. note:: - It's not mandatory that the parameters are named ``mob`` and ``dt``. - - **Only parameters with no default values are considered in when determining - whether the updater is time-based or not.** For example, an updater - ``lambda mob, rate=5: ...`` is considered non-time-based since the 2nd - parameter ``rate`` has a default value of 5. **This allows for passing - functions with more than 2 parameters, as long as the extra parameters have - default values.** - - A ``ValueError`` is raised if a function is passed which has 0 or more than - 2 parameters with no default values. - - When passing an instance method, the first parameter `self` is excluded - from the count. When passing a class method, the first parameter `cls` is - also excluded. - - .. note :: - It is fine to call the 1st parameter ``self`` if the updater is not an - instance method: it will still be counted as a parameter. The rule - above only applies for instance methods. For example, - ``lambda self: self.move_to(square)`` is a valid non-time-based - updater, and ``lambda self, dt: self.rotate(dt)`` is time-based. - - .. warning:: - Do **NOT** name the 1st parameter ``cls`` if the function is not a class - method. - - Attributes - ---------- - updater - An updater function, whose first parameter is a Mobject and might - optionally have a second parameter which should be a float - representing a time change dt. - - Raises - ------ - ValueError - If an updater is passed with 0 or more than 2 parameters with no - default values. - """ - __slots__ = ["updater", "is_time_based"] def __init__(self, updater: MobjectUpdater | MeshUpdater): @@ -144,6 +92,58 @@ def _raise_error(self, num_non_default_parameters: int): class MobjectUpdaterWrapper(AbstractUpdaterWrapper): + """Wraps a :class:`MobjectUpdater` function, inspects its signature and + calculates whether it's time-based or not. + + If it has a single parameter (with no default value), it's considered + non-time-based: it doesn't depend on time. + + If it has two parameters (with no default values), it's considered + time-based: it depends on time, and the affected Mobject has a change + on every frame which depends on the frame's duration dt. + + .. note:: + It's not mandatory that the parameters are named ``mob`` and ``dt``. + + **Only parameters with no default values are considered in when determining + whether the updater is time-based or not.** For example, an updater + ``lambda mob, rate=5: ...`` is considered non-time-based since the 2nd + parameter ``rate`` has a default value of 5. **This allows for passing + functions with more than 2 parameters, as long as the extra parameters have + default values.** + + A ``ValueError`` is raised if a function is passed which has 0 or more than + 2 parameters with no default values. + + When passing an instance method, the first parameter `self` is excluded + from the count. When passing a class method, the first parameter `cls` is + also excluded. + + .. note :: + It is fine to call the 1st parameter ``self`` if the updater is not an + instance method: it will still be counted as a parameter. The rule + above only applies for instance methods. For example, + ``lambda self: self.move_to(square)`` is a valid non-time-based + updater, and ``lambda self, dt: self.rotate(dt)`` is time-based. + + .. warning:: + Do **NOT** name the 1st parameter ``cls`` if the function is not a + class method. + + Attributes + ---------- + updater + An updater function, whose first parameter is a :class:`Mobject` and + might optionally have a second parameter which should be a ``float`` + representing a time change ``dt``. + + Raises + ------ + ValueError + If an updater is passed with 0 or more than 2 parameters with no + default values. + """ + def __init__(self, updater: MobjectUpdater): super().__init__(updater) From 33737d6db78d26e65195f5c211ab221299cd3b73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Manr=C3=ADquez?= Date: Sun, 16 Jun 2024 11:16:42 -0400 Subject: [PATCH 04/10] Add more docstrings in manim.utils.updaters --- manim/utils/updaters.py | 75 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 70 insertions(+), 5 deletions(-) diff --git a/manim/utils/updaters.py b/manim/utils/updaters.py index b285098493..9859577150 100644 --- a/manim/utils/updaters.py +++ b/manim/utils/updaters.py @@ -2,12 +2,13 @@ import inspect from collections.abc import Callable -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, TypeVar, Union if TYPE_CHECKING: from typing_extensions import TypeAlias from manim.mobject.mobject import Mobject + from manim.mobject.opengl.opengl_mobject import OpenGLMobject from manim.opengl.shader import Object3D @@ -24,8 +25,10 @@ ] -MobjectTimeBasedUpdater: TypeAlias = Callable[["Mobject", float], "Mobject"] -MobjectNonTimeBasedUpdater: TypeAlias = Callable[["Mobject"], "Mobject"] +M = TypeVar("M", bound=Union["Mobject", "OpenGLMobject"]) + +MobjectTimeBasedUpdater: TypeAlias = Callable[[M, float], M] +MobjectNonTimeBasedUpdater: TypeAlias = Callable[[M], M] MobjectUpdater: TypeAlias = Union[MobjectNonTimeBasedUpdater, MobjectTimeBasedUpdater] MeshTimeBasedUpdater: TypeAlias = Callable[["Object3D", float], "Object3D"] @@ -36,6 +39,33 @@ class AbstractUpdaterWrapper: + """Base class for :class:`MobjectUpdaterWrapper` and + :class:`MeshUpdaterWrapper`. See :class:`MobjectUpdaterWrapper` for more + information. + + Parameters + ---------- + updater + An updater function, whose first parameter is either a :class:`Mobject` + or an :class:`Object3D` (parent of :class:`Mesh`) and + might optionally have a second parameter which should be a ``float`` + representing a time change ``dt``. This function should return the same + object in the 1st parameter after applying a change on it. + + Attributes + ---------- + updater + The same updater function passed as a parameter. + is_time_based + Whether :attr:`updater` is a time-based updater or not. + + Raises + ------ + ValueError + If an updater is passed with 0 or more than 2 parameters with no + default values. + """ + __slots__ = ["updater", "is_time_based"] def __init__(self, updater: MobjectUpdater | MeshUpdater): @@ -130,12 +160,20 @@ class MobjectUpdaterWrapper(AbstractUpdaterWrapper): Do **NOT** name the 1st parameter ``cls`` if the function is not a class method. - Attributes + Parameters ---------- updater An updater function, whose first parameter is a :class:`Mobject` and might optionally have a second parameter which should be a ``float`` - representing a time change ``dt``. + representing a time change ``dt``. This function should return the same + :class:`Mobject` after applying a change on it. + + Attributes + ---------- + updater + The same updater function passed as a parameter. + is_time_based + Whether :attr:`updater` is a time-based updater or not. Raises ------ @@ -149,5 +187,32 @@ def __init__(self, updater: MobjectUpdater): class MeshUpdaterWrapper(AbstractUpdaterWrapper): + """Similar to :class:`MobjectUpdaterWrapper`, but for :class:`Object3D`, + parent of :class:`Mesh`. See the docs for :class:`MobjectUpdaterWrapper` + for more information. + + Parameters + ---------- + updater + An updater function, whose first parameter is an :class:`Object3D` + (parent of :class:`Mesh`) and might optionally have a second parameter + which should be a ``float`` representing a time change ``dt``. This + function should return the same :class:`Object3D` after applying a + change on it. + + Attributes + ---------- + updater + The same updater function passed as a parameter. + is_time_based + Whether :attr:`updater` is a time-based updater or not. + + Raises + ------ + ValueError + If an updater is passed with 0 or more than 2 parameters with no + default values. + """ + def __init__(self, updater: MeshUpdater): super().__init__(updater) From dd085f0d079a46fe6c352cc046e918ea5e5b70f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Manr=C3=ADquez?= Date: Sun, 16 Jun 2024 13:36:15 -0400 Subject: [PATCH 05/10] Fix Mobject.get_time_based_updaters --- manim/mobject/mobject.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/manim/mobject/mobject.py b/manim/mobject/mobject.py index ce5ff40a6f..922e259323 100644 --- a/manim/mobject/mobject.py +++ b/manim/mobject/mobject.py @@ -41,6 +41,7 @@ if TYPE_CHECKING: from typing_extensions import Self + from manim.animation.animation import Animation from manim.typing import ( FunctionOverride, Image, @@ -54,8 +55,6 @@ ) from manim.utils.updaters import MobjectTimeBasedUpdater, MobjectUpdater - from ..animation.animation import Animation - class Mobject: """Mathematical Object: base class for objects that can be displayed on screen. @@ -921,7 +920,11 @@ def get_time_based_updaters(self) -> Sequence[MobjectTimeBasedUpdater]: :meth:`has_time_based_updater` """ - return [wrapper.updater for wrapper in self.updater_wrappers if t.is_time_based] + return [ + wrapper.updater + for wrapper in self.updater_wrappers + if wrapper.is_time_based + ] def has_time_based_updater(self) -> bool: """Test if ``self`` has a time based updater. From 61e35814cf7445d96c89705295a1aa26b40cdcca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Manr=C3=ADquez?= Date: Mon, 17 Jun 2024 15:09:02 -0400 Subject: [PATCH 06/10] Address requested changes --- manim/mobject/mobject.py | 4 +- manim/mobject/opengl/opengl_mobject.py | 4 +- manim/renderer/shader.py | 4 +- manim/utils/updaters.py | 110 ++++++++++++++++--------- tests/module/utils/test_updaters.py | 18 +++- 5 files changed, 93 insertions(+), 47 deletions(-) diff --git a/manim/mobject/mobject.py b/manim/mobject/mobject.py index 922e259323..b008153a03 100644 --- a/manim/mobject/mobject.py +++ b/manim/mobject/mobject.py @@ -53,7 +53,7 @@ Point3D_Array, Vector3D, ) - from manim.utils.updaters import MobjectTimeBasedUpdater, MobjectUpdater + from manim.utils.updaters import MobjectDtUpdater, MobjectUpdater class Mobject: @@ -904,7 +904,7 @@ def update(self, dt: float = 0, recursive: bool = True) -> Self: submob.update(dt, recursive) return self - def get_time_based_updaters(self) -> Sequence[MobjectTimeBasedUpdater]: + def get_time_based_updaters(self) -> Sequence[MobjectDtUpdater]: """Return all updaters using the ``dt`` parameter. The updaters use this parameter as the input for difference in time. diff --git a/manim/mobject/opengl/opengl_mobject.py b/manim/mobject/opengl/opengl_mobject.py index 961fc0c343..1a25606788 100644 --- a/manim/mobject/opengl/opengl_mobject.py +++ b/manim/mobject/opengl/opengl_mobject.py @@ -60,7 +60,7 @@ Point3D_Array, Vector3D, ) - from manim.utils.updaters import MobjectTimeBasedUpdater, MobjectUpdater + from manim.utils.updaters import MobjectDtUpdater, MobjectUpdater T = TypeVar("T") @@ -1467,7 +1467,7 @@ def update(self, dt: float = 0, recurse: bool = True) -> Self: submob.update(dt, recurse) return self - def get_time_based_updaters(self) -> Sequence[MobjectTimeBasedUpdater]: + def get_time_based_updaters(self) -> Sequence[MobjectDtUpdater]: return [ wrapper.updater for wrapper in self.updater_wrappers diff --git a/manim/renderer/shader.py b/manim/renderer/shader.py index 7b0078d5cc..276b49114e 100644 --- a/manim/renderer/shader.py +++ b/manim/renderer/shader.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from typing_extensions import Self - from manim.utils.updaters import MeshTimeBasedUpdater, MeshUpdater + from manim.utils.updaters import MeshDtUpdater, MeshUpdater SHADER_FOLDER = Path(__file__).parent / "shaders" shader_program_cache: dict = {} @@ -202,7 +202,7 @@ def update(self, dt: float = 0) -> Self: wrapper.updater(self) return self - def get_time_based_updaters(self) -> Sequence[MeshTimeBasedUpdater]: + def get_time_based_updaters(self) -> Sequence[MeshDtUpdater]: return [ wrapper.updater for wrapper in self.updater_wrappers diff --git a/manim/utils/updaters.py b/manim/utils/updaters.py index 9859577150..56a0631609 100644 --- a/manim/utils/updaters.py +++ b/manim/utils/updaters.py @@ -1,3 +1,7 @@ +"""Updater types and classes. Updaters are functions which update an object +(a :class:`Mobject`, :class:`OpenGLMobject`, :class:`Object3D` or +:class:`Scene`) on every frame, and might""" + from __future__ import annotations import inspect @@ -9,15 +13,15 @@ from manim.mobject.mobject import Mobject from manim.mobject.opengl.opengl_mobject import OpenGLMobject - from manim.opengl.shader import Object3D + from manim.renderer.shader import Object3D __all__ = [ - "MobjectTimeBasedUpdater", - "MobjectNonTimeBasedUpdater", + "MobjectBasicUpdater", + "MobjectDtUpdater", "MobjectUpdater", - "MeshTimeBasedUpdater", - "MeshNonTimeBasedUpdater", + "MeshBasicUpdater", + "MeshDtUpdater", "MeshUpdater", "SceneUpdater", "MobjectUpdaterWrapper", @@ -25,17 +29,38 @@ ] -M = TypeVar("M", bound=Union["Mobject", "OpenGLMobject"]) +Mobj = TypeVar("Mobj", bound=Union["Mobject", "OpenGLMobject"]) + +MobjectBasicUpdater: TypeAlias = Callable[[Mobj], Mobj] +"""A function which updates a :class:`~.Mobject` on every frame.""" + +MobjectDtUpdater: TypeAlias = Callable[[Mobj, float], Mobj] +"""A function which updates a :class:`~.Mobject` on every frame, also depending +also depending on the frame's duration ``dt``. +""" -MobjectTimeBasedUpdater: TypeAlias = Callable[[M, float], M] -MobjectNonTimeBasedUpdater: TypeAlias = Callable[[M], M] -MobjectUpdater: TypeAlias = Union[MobjectNonTimeBasedUpdater, MobjectTimeBasedUpdater] +MobjectUpdater: TypeAlias = Union[MobjectBasicUpdater, MobjectDtUpdater] +"""A function which updates a :class:`~.Mobject` on every frame, possibly +depending on the frame's duration ``dt``. +""" -MeshTimeBasedUpdater: TypeAlias = Callable[["Object3D", float], "Object3D"] -MeshNonTimeBasedUpdater: TypeAlias = Callable[["Object3D"], "Object3D"] -MeshUpdater: TypeAlias = Union[MeshNonTimeBasedUpdater, MeshTimeBasedUpdater] +MeshBasicUpdater: TypeAlias = Callable[["Object3D"], "Object3D"] +"""A function which updates an :class:`~.Object3D` on every frame.""" + +MeshDtUpdater: TypeAlias = Callable[["Object3D", float], "Object3D"] +"""A function which updates an :class:`~.Object3D` on every frame, also +depending on the frame's duration ``dt``. +""" + +MeshUpdater: TypeAlias = Union[MeshBasicUpdater, MeshDtUpdater] +"""A function which updates an :class:`~.Object3D` on every frame, possibly +depending on the frame's duration ``dt``. +""" SceneUpdater: TypeAlias = Callable[[float], None] +"""A function which updates a :class:`~.Scene` on every frame, depending on the +frame's duration ``dt``. +""" class AbstractUpdaterWrapper: @@ -46,11 +71,12 @@ class AbstractUpdaterWrapper: Parameters ---------- updater - An updater function, whose first parameter is either a :class:`Mobject` - or an :class:`Object3D` (parent of :class:`Mesh`) and - might optionally have a second parameter which should be a ``float`` - representing a time change ``dt``. This function should return the same - object in the 1st parameter after applying a change on it. + An updater function, whose first parameter is either a + :class:`~.Mobject` or an :class:`~.Object3D` (parent of + :class:`~.Mesh`), and might optionally have a second parameter which + should be a ``float`` representing a time change ``dt``. This function + should return the same object in the 1st parameter after applying a + change on it. Attributes ---------- @@ -125,48 +151,57 @@ class MobjectUpdaterWrapper(AbstractUpdaterWrapper): """Wraps a :class:`MobjectUpdater` function, inspects its signature and calculates whether it's time-based or not. - If it has a single parameter (with no default value), it's considered - non-time-based: it doesn't depend on time. + If it has a single parameter (with no default value), it's considered a + :class:`MobjectBasicUpdater` which doesn't depend on time. - If it has two parameters (with no default values), it's considered - time-based: it depends on time, and the affected Mobject has a change - on every frame which depends on the frame's duration dt. + If it has two parameters (with no default values), it's considered a + :class:`MobjectDtUpdater` which depends on time, and the affected + :class:`~.Mobject` has a change on every frame which depends on the frame's + duration ``dt``. .. note:: It's not mandatory that the parameters are named ``mob`` and ``dt``. **Only parameters with no default values are considered in when determining whether the updater is time-based or not.** For example, an updater - ``lambda mob, rate=5: ...`` is considered non-time-based since the 2nd - parameter ``rate`` has a default value of 5. **This allows for passing - functions with more than 2 parameters, as long as the extra parameters have - default values.** + ``lambda mob, rate=5: ...`` is considered a :class:`MobjectBasicUpdater` + since the 2nd parameter ``rate`` has a default value of 5. + + .. note:: + The above rule allows for passing functions with more than 2 + parameters, as long as the extra parameters have default values. A ``ValueError`` is raised if a function is passed which has 0 or more than 2 parameters with no default values. - When passing an instance method, the first parameter `self` is excluded - from the count. When passing a class method, the first parameter `cls` is + When passing an instance method, the first parameter ``self`` is excluded + from the count. When passing a class method, the first parameter ``cls`` is also excluded. .. note :: It is fine to call the 1st parameter ``self`` if the updater is not an instance method: it will still be counted as a parameter. The rule - above only applies for instance methods. For example, - ``lambda self: self.move_to(square)`` is a valid non-time-based - updater, and ``lambda self, dt: self.rotate(dt)`` is time-based. + above only applies for instance methods. + + For example, ``lambda self: self.move_to(square)`` is a valid + :class:`MobjectBasicUpdater`, and ``lambda self, dt: self.rotate(dt)`` + is a :class:`MobjectDtUpdater`. + + However, it is not recommended to name the 1st parameter ``self`` if + the updater is not a method, because it is not Pythonic. A better + option is ``mob``. .. warning:: Do **NOT** name the 1st parameter ``cls`` if the function is not a - class method. + class method. Otherwise, the updater will not be parsed correctly. Parameters ---------- updater - An updater function, whose first parameter is a :class:`Mobject` and + An updater function, whose first parameter is a :class:`~.Mobject` and might optionally have a second parameter which should be a ``float`` representing a time change ``dt``. This function should return the same - :class:`Mobject` after applying a change on it. + :class:`~.Mobject` after applying a change on it. Attributes ---------- @@ -187,8 +222,9 @@ def __init__(self, updater: MobjectUpdater): class MeshUpdaterWrapper(AbstractUpdaterWrapper): - """Similar to :class:`MobjectUpdaterWrapper`, but for :class:`Object3D`, - parent of :class:`Mesh`. See the docs for :class:`MobjectUpdaterWrapper` + """Similar to :class:`MobjectUpdaterWrapper`, but wraps instead a + :class:`MeshUpdater` which modifies an :class:`~.Object3D`, parent of + :class:`~.Mesh`. See the docs for :class:`MobjectUpdaterWrapper` for more information. Parameters @@ -197,7 +233,7 @@ class MeshUpdaterWrapper(AbstractUpdaterWrapper): An updater function, whose first parameter is an :class:`Object3D` (parent of :class:`Mesh`) and might optionally have a second parameter which should be a ``float`` representing a time change ``dt``. This - function should return the same :class:`Object3D` after applying a + function should return the same :class:`~.Object3D` after applying a change on it. Attributes diff --git a/tests/module/utils/test_updaters.py b/tests/module/utils/test_updaters.py index 2e19fe63de..42b07d4537 100644 --- a/tests/module/utils/test_updaters.py +++ b/tests/module/utils/test_updaters.py @@ -29,11 +29,11 @@ def test_UpdaterWrapper() -> None: assert wrapper.is_time_based # An updater cannot have no parameters - with pytest.raises(ValueError) as no_parameters_info: - wrapper = MobjectUpdaterWrapper(lambda: "Hello world") + with pytest.raises(ValueError): + wrapper = MobjectUpdaterWrapper(lambda: square.move_to(RIGHT)) # An updater cannot have more than 2 parameters without a default value - with pytest.raises(ValueError) as three_parameters_info: + with pytest.raises(ValueError): wrapper = MobjectUpdaterWrapper(lambda mob, rate, third: mob.rotate(rate)) # Only parameters with no default value are considered when determining @@ -56,8 +56,18 @@ def test_UpdaterWrapper() -> None: assert wrapper.is_time_based # In general, if the function is not an instance method, the 1st parameter - # is always included, even if it's called "self". + # is almost always included, even if it's called "self". wrapper = MobjectUpdaterWrapper(lambda self: self.move_to(square)) assert not wrapper.is_time_based wrapper = MobjectUpdaterWrapper(lambda self, dt: self.rotate(dt)) assert wrapper.is_time_based + + # The only exception is if it's called "cls". Don't call it "cls" if it's + # not a class method. + wrapper = MobjectUpdaterWrapper(lambda cls, dt: cls.rotate(dt)) + assert ( + not wrapper.is_time_based + ) # Only 1 parameter, dt, is considered, and it's used as a Mobject, not float + with pytest.raises(ValueError): + # Since cls is excluded, there are no other parameters + wrapper = MobjectUpdaterWrapper(lambda cls: cls.next_to(square)) From b2b472e76a7c6330c39c9eba43fa65073de96169 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Manr=C3=ADquez=20Novoa?= <49853152+chopan050@users.noreply.github.com> Date: Mon, 17 Jun 2024 15:14:59 -0400 Subject: [PATCH 07/10] Add missing backreferences in updaters.py --- manim/utils/updaters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/manim/utils/updaters.py b/manim/utils/updaters.py index 56a0631609..f1fce1d347 100644 --- a/manim/utils/updaters.py +++ b/manim/utils/updaters.py @@ -230,8 +230,8 @@ class MeshUpdaterWrapper(AbstractUpdaterWrapper): Parameters ---------- updater - An updater function, whose first parameter is an :class:`Object3D` - (parent of :class:`Mesh`) and might optionally have a second parameter + An updater function, whose first parameter is an :class:`~.Object3D` + (parent of :class:`~.Mesh`) and might optionally have a second parameter which should be a ``float`` representing a time change ``dt``. This function should return the same :class:`~.Object3D` after applying a change on it. From b11214a34fd34bd0f5860e1f8aaea8e07ba0c47b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Manr=C3=ADquez?= Date: Mon, 17 Jun 2024 15:24:12 -0400 Subject: [PATCH 08/10] Include OpenGLMobject in Updater docs --- manim/utils/updaters.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/manim/utils/updaters.py b/manim/utils/updaters.py index f1fce1d347..f654ef4f8f 100644 --- a/manim/utils/updaters.py +++ b/manim/utils/updaters.py @@ -32,16 +32,18 @@ Mobj = TypeVar("Mobj", bound=Union["Mobject", "OpenGLMobject"]) MobjectBasicUpdater: TypeAlias = Callable[[Mobj], Mobj] -"""A function which updates a :class:`~.Mobject` on every frame.""" +"""A function which updates a :class:`~.Mobject` or :class:`~.OpenGLMobject` on +every frame. +""" MobjectDtUpdater: TypeAlias = Callable[[Mobj, float], Mobj] -"""A function which updates a :class:`~.Mobject` on every frame, also depending -also depending on the frame's duration ``dt``. +"""A function which updates a :class:`~.Mobject` or :class:`~.OpenGLMobject` on +every frame, also depending on the frame's duration ``dt``. """ MobjectUpdater: TypeAlias = Union[MobjectBasicUpdater, MobjectDtUpdater] -"""A function which updates a :class:`~.Mobject` on every frame, possibly -depending on the frame's duration ``dt``. +"""A function which updates a :class:`~.Mobject` or :class:`~.OpenGLMobject` on +every frame, possibly depending on the frame's duration ``dt``. """ MeshBasicUpdater: TypeAlias = Callable[["Object3D"], "Object3D"] From b2e88c9b9750478dc82d23351aac8d9de251fe61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Manr=C3=ADquez?= Date: Mon, 17 Jun 2024 15:46:11 -0400 Subject: [PATCH 09/10] Complete updaters.py module docstring --- manim/utils/updaters.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/manim/utils/updaters.py b/manim/utils/updaters.py index f654ef4f8f..4916ba652d 100644 --- a/manim/utils/updaters.py +++ b/manim/utils/updaters.py @@ -1,6 +1,8 @@ """Updater types and classes. Updaters are functions which update an object -(a :class:`Mobject`, :class:`OpenGLMobject`, :class:`Object3D` or -:class:`Scene`) on every frame, and might""" +(a :class:`~.Mobject`, :class:`~.OpenGLMobject`, :class:`~.Object3D` or +:class:`~.Scene`) on every frame, and might accept an additional parameter +``dt`` which represents the frame's duration. +""" from __future__ import annotations From 8d836f6030aaab986e5448a8dee45bb70b3d070f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 27 Oct 2024 13:12:16 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- manim/renderer/shader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/manim/renderer/shader.py b/manim/renderer/shader.py index fb4bb7cc4f..b1faacd665 100644 --- a/manim/renderer/shader.py +++ b/manim/renderer/shader.py @@ -1,7 +1,6 @@ from __future__ import annotations import contextlib -import inspect import re import textwrap from collections.abc import Sequence