@@ -99,6 +99,7 @@ def ragged_dot(
     block_k: int,
     max_concurrent_steps: int,
     grid_block_n: int,
+    transpose_rhs: bool = False,
 ) -> jax.Array:
   if lhs.dtype != rhs.dtype:
     raise NotImplementedError(
@@ -107,6 +108,9 @@ def ragged_dot(
   m, k = lhs.shape
   g, k2, n = rhs.shape
 
+  if transpose_rhs:
+    k2, n = n, k2
+
   if group_sizes.shape[0] != g:
     raise ValueError(
         f"Expected group_sizes to have shape {g} but got {group_sizes.shape}"
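With `transpose_rhs=True` the kernel expects the right-hand operand already transposed to `(g, n, k)` instead of `(g, k, n)`; the two layouts carry the same weights and differ only by a swap of the last two axes. A minimal shape-level sketch, with illustrative sizes that are not taken from the example:

import jax
import jax.numpy as jnp

g, k, n = 4, 256, 512
rhs = jax.random.normal(jax.random.key(0), (g, k, n), jnp.float16)   # transpose_rhs=False layout
rhs_t = jnp.swapaxes(rhs, 1, 2)                                      # (g, n, k): transpose_rhs=True layout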
@@ -134,7 +138,9 @@ def mn_loop(loop_info: plgpu.NDLoopInfo):  # pylint: disable=unused-variable
     def acc_scope(acc_ref):
       plgpu.emit_pipeline(
           lambda _, lhs_smem, rhs_smem: plgpu.wgmma(
-              acc_ref, lhs_smem, rhs_smem
+              acc_ref,
+              lhs_smem,
+              plgpu.transpose_ref(rhs_smem, (1, 0)) if transpose_rhs else rhs_smem,
           ),
           grid=(k // block_k,),
           in_specs=[
@@ -144,8 +150,8 @@ def acc_scope(acc_ref):
                   delay_release=1,
               ),
               plgpu.BlockSpec(
-                  (block_k, block_n),
-                  lambda k: (k, ni),
+                  (block_n, block_k) if transpose_rhs else (block_k, block_n),
+                  lambda k: (ni, k) if transpose_rhs else (k, ni),
                   delay_release=1,
               ),
           ],
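The BlockSpec change above pairs with the `plgpu.transpose_ref` call: when the rhs is stored as `(g, n, k)`, each `(block_n, block_k)` tile is fetched with the block indices swapped and then viewed as `(block_k, block_n)` for the MMA. A NumPy-level sanity check of that indexing, with made-up tile sizes and indices for illustration only:

import numpy as np

k, n, block_k, block_n = 256, 512, 64, 128
rhs = np.random.randn(k, n).astype(np.float16)   # one group, transpose_rhs=False layout
rhs_t = rhs.T                                    # (n, k), transpose_rhs=True layout
ki, ni = 2, 3
tile = rhs[ki * block_k:(ki + 1) * block_k, ni * block_n:(ni + 1) * block_n]
tile_t = rhs_t[ni * block_n:(ni + 1) * block_n, ki * block_k:(ki + 1) * block_k]
np.testing.assert_array_equal(tile, tile_t.T)    # the transposed view recovers the same tile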
@@ -233,68 +239,84 @@ def _():
 
 
 def main(unused_argv):
-  m, k, n, num_groups = 16 * 1024, 2048, 16 * 1024, 16
-  kx, ky, kz = random.split(random.key(1234), num=3)
+  for transpose_rhs in [False, True]:
+    m, k, n, num_groups = 16 * 1024, 2048, 16 * 1024, 16
+    kx, ky, kz = random.split(random.key(1234), num=3)
 
-  lhs = jax.random.normal(kx, (m, k), jnp.float16)
-  rhs = jax.random.normal(ky, (num_groups, k, n), jnp.float16)
-  group_boundaries = jax.lax.sort(
-      jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32)
-  )
-  group_starts = lax.concatenate(
-      [jnp.array([0], dtype=jnp.int32), group_boundaries], 0
-  )
-  group_ends = lax.concatenate(
-      [group_boundaries, jnp.array([m], dtype=jnp.int32)], 0
-  )
-  group_sizes = group_ends - group_starts
-  assert group_sizes.shape == (num_groups,)
-
-  block_m = block_n = (64, 128, 192)
-  block_k = (64,)
-  max_concurrent_steps = (2, 4, 5, 6)
-  grid_block_n = (1, 2, 4, 8, 16)
-  configs = itertools.product(
-      block_m, block_n, block_k, max_concurrent_steps, grid_block_n
-  )
-  names = (
-      "block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n"
-  )
-  best_runtime = float("inf")
-  best_kwargs = {}
-  for config in configs:
-    kwargs = dict(zip(names, config))
-    if n % (kwargs["grid_block_n"] * kwargs["block_n"]):
-      continue
-    try:
-      f = functools.partial(ragged_dot, group_sizes=group_sizes, **kwargs)
-      _, runtime = profiler.measure(f)(lhs, rhs)
-    except ValueError as e:
-      if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
-        raise
-      runtime = float("inf")
-    # Enable this to get more detailed information.
+    lhs = jax.random.normal(kx, (m, k), jnp.float16)
+    if transpose_rhs:
+      rhs = jax.random.normal(ky, (num_groups, n, k), jnp.float16)
     else:
-      print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
-    if runtime < best_runtime:  # pytype: disable=unsupported-operands
-      best_runtime = runtime
-      best_kwargs = kwargs
-  if not best_kwargs:
-    raise ValueError("No valid configuration found")
-
-  ref, ref_runtime = profiler.measure(jax.lax.ragged_dot)(
-      lhs, rhs, group_sizes=group_sizes
-  )
-  result = ragged_dot(lhs, rhs, group_sizes=group_sizes, **best_kwargs)
-  np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3)
+      rhs = jax.random.normal(ky, (num_groups, k, n), jnp.float16)
+    group_boundaries = jax.lax.sort(
+        jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32)
+    )
+    group_starts = lax.concatenate(
+        [jnp.array([0], dtype=jnp.int32), group_boundaries], 0
+    )
+    group_ends = lax.concatenate(
+        [group_boundaries, jnp.array([m], dtype=jnp.int32)], 0
+    )
+    group_sizes = group_ends - group_starts
+    assert group_sizes.shape == (num_groups,)
+
+    block_m = block_n = (64, 128, 192)
+    block_k = (64,)
+    max_concurrent_steps = (2, 4, 5, 6)
+    grid_block_n = (1, 2, 4, 8, 16)
+    configs = itertools.product(
+        block_m, block_n, block_k, max_concurrent_steps, grid_block_n
+    )
+    names = (
+        "block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n"
+    )
+    best_runtime = float("inf")
+    best_kwargs = {}
+    for config in configs:
+      kwargs = dict(zip(names, config))
+      if n % (kwargs["grid_block_n"] * kwargs["block_n"]):
+        continue
+      try:
+        f = functools.partial(
+            ragged_dot, group_sizes=group_sizes, transpose_rhs=transpose_rhs,
+            **kwargs
+        )
+        _, runtime = profiler.measure(f)(lhs, rhs)
+      except ValueError as e:
+        if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
+          raise
+        runtime = float("inf")
+      # Enable this to get more detailed information.
+      else:
+        print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
+      if runtime < best_runtime:  # pytype: disable=unsupported-operands
+        best_runtime = runtime
+        best_kwargs = kwargs
+    if not best_kwargs:
+      raise ValueError("No valid configuration found")
+
+    def ref_ragged_dot(lhs, rhs, group_sizes):
+      if transpose_rhs:
+        rhs = jnp.transpose(rhs, (0, 2, 1))
+      return jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes)
+
+    ref, ref_runtime = profiler.measure(ref_ragged_dot)(
+        lhs, rhs, group_sizes=group_sizes
+    )
+    result = ragged_dot(
+        lhs, rhs, group_sizes=group_sizes, transpose_rhs=transpose_rhs,
+        **best_kwargs
+    )
+    np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3)
 
-  tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12
-  ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
-  print(
-      "Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())
-  )
-  print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
-  print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")
+    tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12
+    ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
+    print(f"Transpose RHS: {transpose_rhs}")
+    print(
+        "Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())
+    )
+    print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
+    print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")
 
 
 if __name__ == "__main__":
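End to end, the new flag is used as in the following sketch. The shapes, group sizes, and tuning parameters below are illustrative only (they are not the tuned values from main, and must still satisfy the kernel's divisibility and shared-memory constraints); the sketch assumes the example's module-level imports of jax, jnp, np, and random.

m, k, n, g = 1024, 256, 512, 4
kx, ky = random.split(random.key(0))
lhs = jax.random.normal(kx, (m, k), jnp.float16)
rhs_t = jax.random.normal(ky, (g, n, k), jnp.float16)   # pre-transposed (g, n, k) layout
group_sizes = jnp.full((g,), m // g, jnp.int32)
out = ragged_dot(
    lhs, rhs_t, group_sizes=group_sizes,
    block_m=64, block_n=64, block_k=64,
    max_concurrent_steps=2, grid_block_n=1,
    transpose_rhs=True,
)
ref = jax.lax.ragged_dot(lhs, jnp.swapaxes(rhs_t, 1, 2), group_sizes=group_sizes)
np.testing.assert_allclose(out, ref, atol=1e-3, rtol=1e-3)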