@@ -410,6 +410,12 @@ def _contraction_list_from_path(
410410 return contraction_list
411411
412412
413+ def _right_to_left_path (n : int ) -> tuple [tuple [int , int ], ...]:
414+ # Create a right to left contraction path
415+ # if n = 5, out = ((4, 3), (3, 2), (2, 1), (1, 0))
416+ return tuple (pairwise (reversed (range (n ))))
417+
418+
413419def einsum (subscripts : str , * operands : "TensorLike" , optimize = None ) -> TensorVariable :
414420 """
415421 Multiplication and summation of tensors using the Einstein summation convention.
@@ -563,7 +569,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
563569 else :
564570 # By default, we try right to left because we assume that most graphs
565571 # have a lower dimensional rightmost operand
566- path = tuple ( pairwise ( reversed ( range ( len (tensor_operands ))) ))
572+ path = _right_to_left_path ( len (tensor_operands ))
567573 contraction_list = _contraction_list_from_path (
568574 subscripts , tensor_operands , path
569575 )
@@ -581,7 +587,18 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
581587 einsum_call = True , # Not part of public API
582588 optimize = "optimal" ,
583589 ) # type: ignore
584- path = tuple (contraction [0 ] for contraction in contraction_list )
590+ np_path = tuple (contraction [0 ] for contraction in contraction_list )
591+
592+ if len (np_path ) == 1 and len (np_path [0 ]) > 2 :
593+ # When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
594+ # pairwise reductions, which our implementation below demands.
595+ path = _right_to_left_path (len (tensor_operands ))
596+ contraction_list = _contraction_list_from_path (
597+ subscripts , tensor_operands , path
598+ )
599+ else :
600+ path = np_path
601+
585602 optimized = True
586603
587604 def removechars (s , chars ):
@@ -744,7 +761,7 @@ def filter_singleton_dims(operand, names, other_operand, other_names):
744761 )
745762 else :
746763 raise ValueError (
747- f"Each step of einsum must have 1 or 2 operands, got { len (operand_indices )} "
764+ f"Each step of einsum must have 1 or 2 operands, got { len (operand_indices )} , { path = } . "
748765 )
749766
750767 # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
0 commit comments