1616
1717import pytensor
1818from pytensor .configdefaults import config
19- from pytensor .graph .basic import Apply , NoParams , Variable
19+ from pytensor .graph .basic import Apply , Variable
2020from pytensor .graph .utils import (
2121 MetaObject ,
22- MethodNotDefined ,
2322 TestValueError ,
2423 add_tag_trace ,
2524 get_variable_trace_string ,
2625)
27- from pytensor .link .c .params_type import Params , ParamsType
2826
2927
3028if TYPE_CHECKING :
3735ComputeMapType = dict [Variable , list [bool ]]
3836InputStorageType = list [StorageCellType ]
3937OutputStorageType = list [StorageCellType ]
40- ParamsInputType = Optional [tuple [Any , ...]]
41- PerformMethodType = Callable [
42- [Apply , list [Any ], OutputStorageType , ParamsInputType ], None
43- ]
38+ PerformMethodType = Callable [[Apply , list [Any ], OutputStorageType ], None ]
4439BasicThunkType = Callable [[], None ]
4540ThunkCallableType = Callable [
4641 [PerformMethodType , StorageMapType , ComputeMapType , Apply ], None
@@ -202,7 +197,6 @@ class Op(MetaObject):
202197
203198 itypes : Optional [Sequence ["Type" ]] = None
204199 otypes : Optional [Sequence ["Type" ]] = None
205- params_type : Optional [ParamsType ] = None
206200
207201 _output_type_depends_on_input_value = False
208202 """
@@ -426,7 +420,6 @@ def perform(
426420 node : Apply ,
427421 inputs : Sequence [Any ],
428422 output_storage : OutputStorageType ,
429- params : ParamsInputType = None ,
430423 ) -> None :
431424 """Calculate the function on the inputs and put the variables in the output storage.
432425
@@ -442,8 +435,6 @@ def perform(
442435 these lists). Each sub-list corresponds to value of each
443436 `Variable` in :attr:`node.outputs`. The primary purpose of this method
444437 is to set the values of these sub-lists.
445- params
446- A tuple containing the values of each entry in :attr:`Op.__props__`.
447438
448439 Notes
449440 -----
@@ -481,22 +472,6 @@ def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool:
481472 """
482473 return True
483474
484- def get_params (self , node : Apply ) -> Params :
485- """Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
486- if isinstance (self .params_type , ParamsType ):
487- wrapper = self .params_type
488- if not all (hasattr (self , field ) for field in wrapper .fields ):
489- # Let's print missing attributes for debugging.
490- not_found = tuple (
491- field for field in wrapper .fields if not hasattr (self , field )
492- )
493- raise AttributeError (
494- f"{ type (self ).__name__ } : missing attributes { not_found } for ParamsType."
495- )
496- # ParamsType.get_params() will apply filtering to attributes.
497- return self .params_type .get_params (self )
498- raise MethodNotDefined ("get_params" )
499-
500475 def prepare_node (
501476 self ,
502477 node : Apply ,
@@ -538,34 +513,12 @@ def make_py_thunk(
538513 else :
539514 p = node .op .perform
540515
541- params = node .run_params ()
542-
543- if params is NoParams :
544- # default arguments are stored in the closure of `rval`
545- @is_thunk_type
546- def rval (
547- p = p , i = node_input_storage , o = node_output_storage , n = node , params = None
548- ):
549- r = p (n , [x [0 ] for x in i ], o )
550- for o in node .outputs :
551- compute_map [o ][0 ] = True
552- return r
553-
554- else :
555- params_val = node .params_type .filter (params )
556-
557- @is_thunk_type
558- def rval (
559- p = p ,
560- i = node_input_storage ,
561- o = node_output_storage ,
562- n = node ,
563- params = params_val ,
564- ):
565- r = p (n , [x [0 ] for x in i ], o , params )
566- for o in node .outputs :
567- compute_map [o ][0 ] = True
568- return r
516+ @is_thunk_type
517+ def rval (p = p , i = node_input_storage , o = node_output_storage , n = node ):
518+ r = p (n , [x [0 ] for x in i ], o )
519+ for o in node .outputs :
520+ compute_map [o ][0 ] = True
521+ return r
569522
570523 rval .inputs = node_input_storage
571524 rval .outputs = node_output_storage
@@ -640,7 +593,7 @@ class _NoPythonOp(Op):
640593
641594 """
642595
643- def perform (self , node , inputs , output_storage , params = None ):
596+ def perform (self , node , inputs , output_storage ):
644597 raise NotImplementedError ("No Python implementation is provided by this Op." )
645598
646599
0 commit comments