Skip to content

Commit 135d30b

Browse files
committed
fix: add initial method for class bind
1 parent bb3c817 commit 135d30b

File tree

3 files changed

+219
-5
lines changed

3 files changed

+219
-5
lines changed

lib/plugify/plugin.py

Lines changed: 206 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import os
33
import importlib.util
44
from copy import deepcopy
5+
from enum import Enum
6+
from typing import Any, Callable, List, Optional, Dict, Tuple
57

68

79
class Plugin:
@@ -189,7 +191,7 @@ def to_list(self):
189191
return deepcopy(self.m)
190192

191193

192-
def extract_required_modules(module_path, visited=None):
194+
def extract_required_modules(module_path: str, visited: Optional[set] = None):
193195
"""
194196
Recursively extract all imported modules and their fully qualified names.
195197
@@ -213,7 +215,7 @@ def extract_required_modules(module_path, visited=None):
213215
try:
214216
with open(module_path, "r", encoding="utf-8") as file:
215217
tree = ast.parse(file.read(), filename=module_path)
216-
218+
217219
for node in ast.walk(tree):
218220
if isinstance(node, ast.Import):
219221
for alias in node.names:
@@ -226,7 +228,7 @@ def extract_required_modules(module_path, visited=None):
226228
print(f"Error processing {module_path}: {e}")
227229
return required_modules
228230

229-
def find_module_path(module_name):
231+
def find_module_path(module_name: str):
230232
"""
231233
Locate the file path of a given Python module name, ensuring it's a .py file.
232234
"""
@@ -239,8 +241,8 @@ def find_module_path(module_name):
239241
pass
240242
return None
241243

244+
all_dependencies = set(required_modules)
242245
try:
243-
all_dependencies = set(required_modules)
244246
for module_name in required_modules:
245247
base_module = module_name.split('.')[0]
246248
module_file = find_module_path(base_module)
@@ -249,4 +251,203 @@ def find_module_path(module_name):
249251
except Exception as e:
250252
print(f"Error processing dependencies for {module_path}: {e}")
251253

252-
return all_dependencies
254+
return all_dependencies
255+
256+
257+
class Ownership(Enum):
258+
OWNED = True
259+
BORROWED = False
260+
261+
262+
# Store class registry for retAlias lookups
263+
_class_registry = {}
264+
265+
266+
def bind_class_methods(
267+
cls: type,
268+
constructors: List[Callable],
269+
destructor: Optional[Callable],
270+
methods: List[Tuple[str, Callable, bool, Optional[List[Tuple[str, bool]]], Optional[Tuple[str, bool]]]],
271+
invalid_value: Any = 0
272+
):
273+
"""
274+
Dynamically bind methods to a class for RAII handle management.
275+
276+
Args:
277+
cls: The class to extend with methods
278+
constructors: List of constructor functions (can be empty for handle-only construction)
279+
destructor: Destructor function (can be None)
280+
methods: List of [name, func, bindSelf, paramAliases, retAlias]
281+
- name (str): Method name
282+
- func (callable): Underlying C function
283+
- bindSelf (bool): Whether to pass self._handle as first param
284+
- paramAliases (list): List of pairs with 'name' and 'owner' values
285+
- retAlias (dict): Pair with 'name' and 'owner' values
286+
invalid_value: Value representing an invalid/closed handle (default: 0)
287+
"""
288+
289+
class_name = cls.__name__
290+
291+
# 1. Add __init__ method
292+
def __init__(self, *args, **kwargs):
293+
"""
294+
Initialize the wrapper. Supports two modes:
295+
1. Direct handle construction: ClassName(handle, Ownership.OWNED/BORROWED)
296+
2. Constructor call: ClassName(*constructor_args)
297+
"""
298+
# Check if this is handle + ownership construction
299+
if len(args) >= 2 and isinstance(args[1], Ownership):
300+
self._handle = args[0]
301+
self._owned = args[1]
302+
else:
303+
# Constructor call
304+
if len(constructors) == 0:
305+
raise ValueError(f"{class_name} requires handle and ownership for construction")
306+
elif len(constructors) == 1:
307+
# Single constructor - call it directly
308+
self._handle = constructors[0](*args, **kwargs)
309+
else:
310+
# Multiple constructors - first arg should be the constructor function
311+
if len(args) == 0:
312+
raise ValueError(
313+
f"{class_name} with multiple constructors requires constructor function as first argument")
314+
func = args[0]
315+
if func not in constructors:
316+
raise ValueError(f"Invalid constructor function for {class_name}")
317+
self._handle = func(*args[1:], **kwargs)
318+
319+
self._owned = Ownership.OWNED
320+
321+
cls.__init__ = __init__
322+
323+
# 2. Add lifecycle methods (close, __del__, __enter__, __exit__)
324+
def close(self):
325+
"""Close/destroy the handle if owned."""
326+
if self._handle != invalid_value and self._owned == Ownership.OWNED:
327+
if destructor is not None:
328+
destructor(self._handle)
329+
self._handle = invalid_value
330+
self._owned = Ownership.BORROWED
331+
332+
cls.close = close
333+
334+
def __del__(self):
335+
self.close()
336+
337+
cls.__del__ = __del__
338+
339+
def __enter__(self):
340+
return self
341+
342+
cls.__enter__ = __enter__
343+
344+
def __exit__(self, exc_type, exc_val, exc_tb):
345+
self.close()
346+
return False
347+
348+
cls.__exit__ = __exit__
349+
350+
# 3. Add utility methods (release, reset, get, valid)
351+
def release(self) -> Any:
352+
"""Release ownership of the handle and return it."""
353+
tmp = self._handle
354+
self._handle = invalid_value
355+
self._owned = Ownership.BORROWED
356+
return tmp
357+
358+
cls.release = release
359+
360+
def reset(self):
361+
"""Reset the handle by closing it."""
362+
self.close()
363+
364+
cls.reset = reset
365+
366+
def get(self) -> Any:
367+
"""Get the raw handle value without transferring ownership."""
368+
return self._handle
369+
370+
cls.get = get
371+
372+
def valid(self) -> bool:
373+
"""Check if the handle is valid."""
374+
return self._handle != invalid_value
375+
376+
cls.valid = valid
377+
378+
# Register this class for retAlias lookups
379+
_class_registry[class_name] = cls
380+
381+
# 4. Add bound methods from the methods list
382+
for method_info in methods:
383+
method_name = method_info[0]
384+
func = method_info[1]
385+
bind_self = method_info[2]
386+
param_aliases = method_info[3]
387+
ret_alias = method_info[4]
388+
389+
# Create the bound method with closure over parameters
390+
def create_method(
391+
func: Callable,
392+
bind_self: bool,
393+
param_aliases: Optional[List[Tuple[str, bool]]],
394+
ret_alias: Optional[Tuple[str, bool]],
395+
method_name: str
396+
):
397+
def method(self, *args, **kwargs):
398+
# Check if handle is valid
399+
if self._handle == invalid_value:
400+
raise RuntimeError(f"{class_name} handle is closed")
401+
402+
# Process arguments - convert to list for modification
403+
args_list = list(args)
404+
405+
# Handle paramAliases - extract handles from wrapper objects
406+
if param_aliases:
407+
for i, alias_info in enumerate(param_aliases):
408+
if alias_info and i < len(args_list):
409+
alias_name = alias_info[0]
410+
owner = alias_info[1]
411+
412+
if alias_name and args_list[i] is not None:
413+
arg = args_list[i]
414+
# Check if the argument has the expected methods
415+
if hasattr(arg, 'release') and hasattr(arg, 'get'):
416+
if owner:
417+
# Transfer ownership - use release()
418+
args_list[i] = arg.release()
419+
else:
420+
# Borrow - use get()
421+
args_list[i] = arg.get()
422+
423+
# Call the underlying function
424+
if bind_self:
425+
# Pass self._handle as first parameter
426+
result = func(self._handle, *args_list, **kwargs)
427+
else:
428+
# Don't pass self._handle
429+
result = func(*args_list, **kwargs)
430+
431+
# Handle retAlias - wrap return value in class
432+
if ret_alias:
433+
ret_name = ret_alias[0]
434+
owner = ret_alias[1]
435+
436+
# Look up the class
437+
ret_class = _class_registry.get(ret_name)
438+
if ret_class and result != invalid_value:
439+
ownership = Ownership.OWNED if owner else Ownership.BORROWED
440+
return ret_class(result, ownership)
441+
elif result == invalid_value:
442+
return None
443+
444+
return result
445+
446+
# Preserve function name for better debugging
447+
method.__name__ = method_name
448+
return method
449+
450+
# Bind the method to the class
451+
setattr(cls, method_name, create_method(func, bind_self, param_aliases, ret_alias, method_name))
452+
453+
return cls

src/module.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3215,6 +3215,13 @@ namespace py3lm {
32153215
return MakeError("Failed to find plugify.plugin.extract_required_modules function");
32163216
}
32173217

3218+
_BindClassMethodsObject = PyObject_GetAttrString(plugifyPluginModule, "bind_class_methods");
3219+
if (!_BindClassMethodsObject || !PyCallable_Check(_BindClassMethodsObject)) {
3220+
Py_DECREF(plugifyPluginModule);
3221+
LogError();
3222+
return MakeError("Failed to find plugify.plugin.bind_class_methods function");
3223+
}
3224+
32183225
Py_DECREF(plugifyPluginModule);
32193226

32203227
_ppsModule = PyImport_ImportModule("plugify.pps");
@@ -3376,6 +3383,10 @@ namespace py3lm {
33763383
Py_DECREF(_ExtractRequiredModulesObject);
33773384
}
33783385

3386+
if (_BindClassMethodsObject) {
3387+
Py_DECREF(_BindClassMethodsObject);
3388+
}
3389+
33793390
if (_PluginTypeObject) {
33803391
Py_DECREF(_PluginTypeObject);
33813392
}
@@ -3414,6 +3425,7 @@ namespace py3lm {
34143425
_Vector4TypeObject = nullptr;
34153426
_Matrix4x4TypeObject = nullptr;
34163427
_ExtractRequiredModulesObject = nullptr;
3428+
_BindClassMethodsObject = nullptr;
34173429
_PluginTypeObject = nullptr;
34183430
_PluginInfoTypeObject = nullptr;
34193431
_internalMap.clear();

src/module.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ namespace py3lm {
181181
PyObject* _Vector4TypeObject = nullptr;
182182
PyObject* _Matrix4x4TypeObject = nullptr;
183183
PyObject* _ExtractRequiredModulesObject = nullptr;
184+
PyObject* _BindClassMethodsObject = nullptr;
184185
PyObject* _ppsModule = nullptr;
185186
PyObject* _enumModule = nullptr;
186187
PyObject* _formatException = nullptr;

0 commit comments

Comments
 (0)