import unittest

import torch

# Import the custom ops to ensure they are registered
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

9+ class TestUpdateCrossAttnCache (unittest .TestCase ):
10+ def test_update_cross_attn_cache (self ):
11+
12+ # Create tensors
13+ # Cache: [B=2, H=1, S_max=4, D=4]
14+ cache = torch .zeros (2 , 1 , 4 , 4 , dtype = torch .float32 )
15+ # Value: [B=2, H=1, S=2, D=4] (S < S_max)
16+ value = torch .randn (2 , 1 , 2 , 4 , dtype = torch .float32 )
17+
18+ # Compile a function that uses the op
19+ @torch .compile
20+ def fn (v , c ):
21+ return torch .ops .executorch .update_cross_attn_cache (v , c )
22+
23+ # Run it
24+ out = fn (value , cache )
25+
26+ # Check correctness
27+ # The first 2 elements in dim 2 (sequence dim) should match value
28+ torch .testing .assert_close (
29+ cache [:, :, :2 , :], value , msg = "Cache slice not updated correctly"
30+ )
31+
32+ # Make sure out and cache are close. In eager they are the same objects.
33+ torch .testing .assert_close (
34+ out , cache , msg = "Output and cache are different objects"
35+ )
36+
37+ # The rest should be zeros
38+ torch .testing .assert_close (
39+ cache [:, :, 2 :, :],
40+ torch .zeros_like (cache [:, :, 2 :, :]),
41+ msg = "Rest of cache was modified" ,
42+ )
43+
44+ def test_update_cross_attn_cache_in_cond (self ):
45+ # Create tensors
46+
47+ # Value: [B=2, H=1, S=2, D=4]
48+ value = torch .randn (2 , 1 , 2 , 4 , dtype = torch .float32 )
49+ # Alternative value for false branch
50+ value_alt = torch .randn (2 , 1 , 2 , 4 , dtype = torch .float32 )
51+
52+ # Define a function that uses the op inside torch.cond
53+ def fn_with_cond (pred , v1 , v2 , c ):
54+ def true_fn (v1 , v2 , cache ):
55+ return torch .ops .executorch .update_cross_attn_cache (v1 , cache )
56+
57+ def false_fn (v1 , v2 , cache ):
58+ return torch .ops .executorch .update_cross_attn_cache (v2 , cache )
59+
60+ return torch .cond (pred , true_fn , false_fn , (v1 , v2 , c ))
61+
62+ # Test with true condition
63+ pred_true = torch .tensor (True )
64+ cache_true = torch .zeros (2 , 1 , 4 , 4 , dtype = torch .float32 )
65+
66+ # Compile the function
67+ @torch .compile
68+ def compiled_fn (pred , v1 , v2 , c ):
69+ return fn_with_cond (pred , v1 , v2 , c )
70+
71+ # Run with true condition
72+ compiled_fn (pred_true , value , value_alt , cache_true )
73+
74+ # Check that the true branch was executed (value was used)
75+ torch .testing .assert_close (
76+ cache_true [:, :, :2 , :],
77+ value ,
78+ msg = "Cache not updated correctly in true branch" ,
79+ )
80+
81+ # Test with false condition
82+ pred_false = torch .tensor (False )
83+ cache_false = torch .zeros (2 , 1 , 4 , 4 , dtype = torch .float32 )
84+
85+ compiled_fn (pred_false , value , value_alt , cache_false )
86+
87+ # Check that the false branch was executed (value_alt was used)
88+ torch .testing .assert_close (
89+ cache_false [:, :, :2 , :],
90+ value_alt ,
91+ msg = "Cache not updated correctly in false branch" ,
92+ )
93+
94+ def test_update_cross_attn_cache_export (self ):
95+
96+ # Create tensors
97+ # Cache: [B=2, H=1, S_max=4, D=4]
98+ cache = torch .zeros (2 , 1 , 4 , 4 , dtype = torch .float32 )
99+ # Value: [B=2, H=1, S=2, D=4]
100+ value = torch .randn (2 , 1 , 2 , 4 , dtype = torch .float32 )
101+ # Alternative value for false branch
102+ value_alt = torch .randn (2 , 1 , 2 , 4 , dtype = torch .float32 )
103+
104+ # Define a module that uses torch.cond with the op
105+ class UpdateCacheCondModule (torch .nn .Module ):
106+ def forward (self , pred , v1 , v2 , c ):
107+ def true_fn (v1 , v2 , cache ):
108+ return torch .ops .executorch .update_cross_attn_cache (v1 , cache )
109+
110+ def false_fn (v1 , v2 , cache ):
111+ return torch .ops .executorch .update_cross_attn_cache (v2 , cache )
112+
113+ return torch .cond (pred , true_fn , false_fn , (v1 , v2 , c ))
114+
115+ module = UpdateCacheCondModule ()
116+
117+ # Export the module with true condition
118+ pred_true = torch .tensor (True )
119+ exported_program = torch .export .export (
120+ module ,
121+ (pred_true , value , value_alt , cache ),
122+ )
123+
124+ # Run the exported program with true condition
125+ cache_true = torch .zeros (2 , 1 , 4 , 4 , dtype = torch .float32 )
126+ exported_program .module ()(pred_true , value , value_alt , cache_true )
127+
128+ # Check that the true branch was executed (value was used)
129+ torch .testing .assert_close (
130+ cache_true [:, :, :2 , :],
131+ value ,
132+ msg = "Cache not updated correctly in true branch after export" ,
133+ )
134+
135+ # Run the exported program with false condition
136+ pred_false = torch .tensor (False )
137+ cache_false = torch .zeros (2 , 1 , 4 , 4 , dtype = torch .float32 )
138+ exported_program .module ()(pred_false , value , value_alt , cache_false )
139+
140+ # Check that the false branch was executed (value_alt was used)
141+ torch .testing .assert_close (
142+ cache_false [:, :, :2 , :],
143+ value_alt ,
144+ msg = "Cache not updated correctly in false branch after export" ,
145+ )
146+
147+ def test_update_cross_attn_cache_different_shapes (self ):
148+ print ("Testing executorch::update_cross_attn_cache with different shapes..." )
149+
150+ # Test with different batch sizes and sequence lengths
151+ test_cases = [
152+ # (B, H, S_max, S, D)
153+ (1 , 2 , 10 , 5 , 8 ),
154+ (4 , 4 , 8 , 3 , 16 ),
155+ (2 , 1 , 16 , 10 , 32 ),
156+ ]
157+
158+ for B , H , S_max , S , D in test_cases :
159+ # Cache: [B, H, S_max, D], Value: [B, H, S, D]
160+ cache = torch .zeros (B , H , S_max , D , dtype = torch .float32 )
161+ value = torch .randn (B , H , S , D , dtype = torch .float32 )
162+
163+ @torch .compile
164+ def fn (v , c ):
165+ return torch .ops .executorch .update_cross_attn_cache (v , c )
166+
167+ fn (value , cache )
168+
169+ # Check that the first S positions in dim 2 are updated
170+ torch .testing .assert_close (
171+ cache [:, :, :S , :],
172+ value ,
173+ msg = f"Failed for shape B={ B } , H={ H } , S_max={ S_max } , S={ S } , D={ D } " ,
174+ )
175+
176+ # Check that the rest remain zeros
177+ if S < S_max :
178+ torch .testing .assert_close (
179+ cache [:, :, S :, :],
180+ torch .zeros_like (cache [:, :, S :, :]),
181+ msg = f"Remaining cache modified for shape B={ B } , H={ H } , S_max={ S_max } , S={ S } , D={ D } " ,
182+ )
183+
184+ def test_update_cross_attn_cache_full_sequence (self ):
185+
186+ # Cache: [B=2, H=1, S_max=4, D=4]
187+ cache = torch .zeros (2 , 1 , 4 , 4 , dtype = torch .float32 )
188+ # Value: [B=2, H=1, S=4, D=4] (S == S_max)
189+ value = torch .randn (2 , 1 , 4 , 4 , dtype = torch .float32 )
190+
191+ @torch .compile
192+ def fn (v , c ):
193+ return torch .ops .executorch .update_cross_attn_cache (v , c )
194+
195+ fn (value , cache )
196+
197+ # The entire cache should match value
198+ torch .testing .assert_close (
199+ cache , value , msg = "Cache not fully updated when S == S_max"
200+ )