11from __future__ import annotations
22
33from dataclasses import dataclass , field as dcfield
4- from typing import Any , NamedTuple
4+ from typing import Any , TypeVar
55
66import torch
77import triton
88import triton .language as tl
99from gguf import GGML_QUANT_SIZES , GGMLQuantizationType
1010
C = TypeVar("C")


def passthroughdecorator(c: C) -> C:
    """Identity decorator: return the decorated object unchanged."""
    return c


# Prefer torch.compiler.disable when this torch build provides it (so the
# wrapped callable is skipped by torch.compile); fall back to the no-op
# decorator on older torch versions that lack torch.compiler.
_torch_compiler = getattr(torch, "compiler", None)
nocompiledecorator = getattr(_torch_compiler, "disable", None) or passthroughdecorator
# First two components of the Triton version string, e.g. "3.2.1" -> (3, 2).
TRITON_MAJOR, TRITON_MINOR = map(int, triton.__version__.split(".", 3)[:2])
1622# in 3.3 or 3.4 whether or not the method is decorated with staticmethod. Just afraid of that
1723# changing and breaking stuff in future versions. Triton 3.4+ can deal with the staticmethod decorator.
1824if TRITON_MAJOR == 3 and TRITON_MINOR <= 3 :
19-
20- def maybestaticmethod (c : Any ) -> Any :
21- return c
25+ maybestaticmethod = passthroughdecorator
2226elif TRITON_MAJOR == 3 and TRITON_MINOR >= 4 :
2327 maybestaticmethod = staticmethod
2428elif TRITON_MAJOR < 3 :
@@ -39,11 +43,7 @@ def maybestaticmethod(c: Any) -> Any:
3943 torch .float32 : tl .float32 ,
4044 torch .float16 : tl .float16 ,
4145 torch .bfloat16 : tl .bfloat16 ,
42- } | (
43- {torch .float8_e4m3fn : tl .float8e4nv }
44- if hasattr (torch , "float8_e4m3fn" ) and hasattr (tl , "float8e4nv" )
45- else {}
46- )
46+ }
4747
4848_DEFAULT_AUTOTUNE_CONFIGS : list [triton .Config ] = [
4949 triton .Config ({"N_BLOCKS_PER_PROG" : 1 }, num_warps = 2 ),
@@ -60,7 +60,7 @@ def maybestaticmethod(c: Any) -> Any:
6060_AUTOTUNE_CONFIGS : dict [str , list [triton .Config ]] = {}
6161
6262
63- @dataclass
63+ @dataclass ( frozen = True )
6464class KernelImpl :
6565 type_size : tl .constexpr
6666 block_size : tl .constexpr
@@ -70,14 +70,7 @@ def get_autotuner(self, **kwargs: dict) -> triton.runtime.Autotuner:
7070
7171 @maybestaticmethod
7272 @triton .jit
73- def dequantize_kernel (
74- q_tensor_ptr ,
75- out_tensor_ptr ,
76- n_total_blocks ,
77- DTYPE : tl .constexpr ,
78- N_BLOCKS_PER_PROG : tl .constexpr ,
79- CTX : tl .constexpr ,
80- ) -> None :
73+ def dequantize_kernel (q_tensor_ptr , out_tensor_ptr , n_total_blocks , DTYPE : tl .constexpr , N_BLOCKS_PER_PROG : tl .constexpr , CTX : tl .constexpr ) -> None :
8174 pid = tl .program_id (axis = 0 )
8275 start_block_idx = pid * N_BLOCKS_PER_PROG
8376 n_blocks = n_total_blocks - start_block_idx
@@ -94,24 +87,19 @@ def dequantize_kernel(
9487 CTX .value .dequantize_block_kernel (
9588 quantized_block_ptr ,
9689 output_ptr ,
97- CTX = CTX ,
90+ CTX = tl . constexpr ( CTX ) ,
9891 DTYPE = DTYPE ,
9992 )
10093
10194
102- class KernelDefinition ( NamedTuple ) :
95+ class KernelDefinition :
10396 qtype : GGMLQuantizationType
10497 block_size : int
10598 type_size : int
10699 kernel : KernelImpl
107100 autotuner_kernel : triton .runtime .Autotuner
108101
109- @classmethod
110- def build (
111- cls ,
112- qtype : GGMLQuantizationType ,
113- kernel_class : type [KernelImpl ],
114- ) -> "KernelDefinition" :
102+ def __init__ (self , qtype : GGMLQuantizationType , kernel_class : type [KernelImpl ]):
115103 block_size , type_size = GGML_QUANT_SIZES [qtype ]
116104 kernel_instance = kernel_class (
117105 block_size = tl .constexpr (block_size ),
@@ -123,22 +111,15 @@ def build(
123111 ),
124112 key = ["n_total_blocks" ],
125113 )
126- return cls (
127- qtype = qtype ,
128- block_size = block_size ,
129- type_size = type_size ,
130- kernel = kernel_instance ,
131- autotuner_kernel = autotuner_kernel ,
132- )
114+ self .qtype = qtype
115+ self .block_size = block_size
116+ self .type_size = type_size
117+ self .kernel = kernel_instance
118+ self .autotuner_kernel = autotuner_kernel
119+
133120
134- def __call__ (
135- self ,
136- blocks : torch .Tensor ,
137- block_size : int ,
138- type_size : int ,
139- dtype : torch .dtype | None = None ,
140- _math_dtype : tl .dtype | None = tl .float32 ,
141- ) -> torch .Tensor :
121+ @nocompiledecorator
122+ def __call__ (self , blocks : torch .Tensor , block_size : int , type_size : int , dtype : torch .dtype | None = None , _math_dtype : tl .dtype | None = tl .float32 ) -> torch .Tensor :
142123 qtype , ggml_type_size = self .qtype , self .type_size
143124 if blocks .dtype != torch .uint8 :
144125 if blocks .dtype == torch .int8 :
@@ -190,23 +171,18 @@ def grid(meta: dict[str, Any]) -> tuple[int]:
190171### K-quants
191172
192173
@dataclass(frozen=True)
class KernelImpl_K_Quant(KernelImpl):
    """Base for K-quant kernel implementations.

    Adds the scale-section size constant shared by the K-quant kernels; the
    default is taken from the module-level K_SCALE_SIZE via a factory so each
    instance gets its own tl.constexpr wrapper.
    """

    k_scale_size: tl.constexpr = dcfield(default_factory=lambda: tl.constexpr(K_SCALE_SIZE))
198179
199180
200- @dataclass
181+ @dataclass ( frozen = True )
201182class KernelImpl_Q2_K (KernelImpl_K_Quant ):
202183 @maybestaticmethod
203184 @triton .jit
204- def dequantize_block_kernel (
205- block_start_ptr ,
206- out_tensor_ptr ,
207- CTX : tl .constexpr ,
208- DTYPE : tl .constexpr ,
209- ) -> None :
185+ def dequantize_block_kernel (block_start_ptr , out_tensor_ptr , CTX : tl .constexpr , DTYPE : tl .constexpr ) -> None :
210186 # Vector of offsets for a 16-element chunk
211187 offsets_16 = tl .arange (0 , 16 )
212188
@@ -256,16 +232,11 @@ def dequantize_block_kernel(
256232 tl .store (output_ptr + offsets_16 , dequant_16 )
257233
258234
259- @dataclass
235+ @dataclass ( frozen = True )
260236class KernelImpl_Q3_K (KernelImpl_K_Quant ):
261237 @maybestaticmethod
262238 @triton .jit
263- def dequantize_block_kernel (
264- block_start_ptr ,
265- out_tensor_ptr ,
266- CTX : tl .constexpr ,
267- DTYPE : tl .constexpr ,
268- ) -> None :
239+ def dequantize_block_kernel (block_start_ptr , out_tensor_ptr , CTX : tl .constexpr , DTYPE : tl .constexpr ) -> None :
269240 # Vector of offsets for a 16-element chunk (one row of the output matrix)
270241 offsets_16 = tl .arange (0 , 16 )
271242
@@ -327,17 +298,12 @@ def dequantize_block_kernel(
327298 tl .store (output_ptr , dequant_16 )
328299
329300
330- @dataclass
301+ @dataclass ( frozen = True )
331302class KernelImpl_Q4_K (KernelImpl_K_Quant ):
332303 # Helper function, shared by Q4_K and Q5_K.
333304 @maybestaticmethod
334305 @triton .jit
335- def get_scales_min (
336- k_idx : int ,
337- d_sc_word : tl .tensor ,
338- m_word : tl .tensor ,
339- m_sc_word : tl .tensor ,
340- ) -> tl .tuple :
306+ def get_scales_min (k_idx : int , d_sc_word : tl .tensor , m_word : tl .tensor , m_sc_word : tl .tensor ) -> tl .tuple :
341307 if k_idx < 4 :
342308 k_idx_x8 = k_idx * 8
343309 d_sc_byte = d_sc_word >> k_idx_x8
@@ -355,12 +321,7 @@ def get_scales_min(
355321
356322 @maybestaticmethod
357323 @triton .jit
358- def dequantize_block_kernel (
359- block_start_ptr ,
360- out_tensor_ptr ,
361- CTX : tl .constexpr ,
362- DTYPE : tl .constexpr ,
363- ) -> None :
324+ def dequantize_block_kernel (block_start_ptr , out_tensor_ptr , CTX : tl .constexpr , DTYPE : tl .constexpr ) -> None :
364325 offsets_32 = tl .arange (0 , 32 )
365326 offsets_scale = offsets_32 + 4 + CTX .value .k_scale_size
366327
@@ -407,16 +368,11 @@ def dequantize_block_kernel(
407368 (output_chunk_ptr + 32 ).store (dequant_high )
408369
409370
410- @dataclass
371+ @dataclass ( frozen = True )
411372class KernelImpl_Q5_K (KernelImpl_Q4_K ):
412373 @maybestaticmethod
413374 @triton .jit
414- def dequantize_block_kernel (
415- block_start_ptr ,
416- out_tensor_ptr ,
417- CTX : tl .constexpr ,
418- DTYPE : tl .constexpr ,
419- ) -> None :
375+ def dequantize_block_kernel (block_start_ptr , out_tensor_ptr , CTX : tl .constexpr , DTYPE : tl .constexpr ) -> None :
420376 offsets_32 = tl .arange (0 , 32 )
421377 offsets_scale = offsets_32 + 4 + CTX .value .k_scale_size
422378
@@ -459,16 +415,11 @@ def dequantize_block_kernel(
459415 output_ptr .store (dequant_32 )
460416
461417
462- @dataclass
418+ @dataclass ( frozen = True )
463419class KernelImpl_Q6_K (KernelImpl_K_Quant ):
464420 @maybestaticmethod
465421 @triton .jit
466- def dequantize_block_kernel (
467- block_start_ptr ,
468- out_tensor_ptr ,
469- CTX : tl .constexpr ,
470- DTYPE : tl .constexpr ,
471- ) -> None :
422+ def dequantize_block_kernel (block_start_ptr , out_tensor_ptr , CTX : tl .constexpr , DTYPE : tl .constexpr ) -> None :
472423 offsets_32 = tl .arange (0 , 32 )
473424 mask_16 = offsets_32 < 16
474425
@@ -520,7 +471,7 @@ def dequantize_block_kernel(
520471### Legacy quants
521472
522473
523- @dataclass
474+ @dataclass ( frozen = True )
524475class KernelImpl_Legacy (KernelImpl ):
525476 @maybestaticmethod
526477 @triton .jit
@@ -535,16 +486,11 @@ def store_output(out_tensor_ptr, dequant_low, dequant_high) -> None:
535486 out_ptrs_high .store (dequant_high )
536487
537488
538- @dataclass
489+ @dataclass ( frozen = True )
539490class KernelImpl_Q4_0 (KernelImpl_Legacy ):
540491 @maybestaticmethod
541492 @triton .jit
542- def dequantize_block_kernel (
543- block_start_ptr ,
544- out_tensor_ptr ,
545- CTX : tl .constexpr ,
546- DTYPE : tl .constexpr ,
547- ) -> None :
493+ def dequantize_block_kernel (block_start_ptr , out_tensor_ptr , CTX : tl .constexpr , DTYPE : tl .constexpr ) -> None :
548494 # Vector of offsets for the 16 bytes of quantized data
549495 offsets_16 = tl .arange (0 , 16 )
550496
@@ -572,16 +518,11 @@ def dequantize_block_kernel(
572518 CTX .value .store_output (out_tensor_ptr , dequant_low , dequant_high )
573519
574520
575- @dataclass
521+ @dataclass ( frozen = True )
576522class KernelImpl_Q4_1 (KernelImpl_Legacy ):
577523 @maybestaticmethod
578524 @triton .jit
579- def dequantize_block_kernel (
580- block_start_ptr ,
581- out_tensor_ptr ,
582- CTX : tl .constexpr ,
583- DTYPE : tl .constexpr ,
584- ) -> None :
525+ def dequantize_block_kernel ( block_start_ptr , out_tensor_ptr , CTX : tl .constexpr , DTYPE : tl .constexpr ) -> None :
585526 # Vector of offsets for the 16 bytes of quantized data
586527 offsets_16 = tl .arange (0 , 16 )
587528
@@ -607,16 +548,11 @@ def dequantize_block_kernel(
607548 CTX .value .store_output (out_tensor_ptr , dequant_low , dequant_high )
608549
609550
610- @dataclass
551+ @dataclass ( frozen = True )
611552class KernelImpl_Q5_0 (KernelImpl_Legacy ):
612553 @maybestaticmethod
613554 @triton .jit
614- def dequantize_block_kernel (
615- block_start_ptr ,
616- out_tensor_ptr ,
617- CTX : tl .constexpr ,
618- DTYPE : tl .constexpr ,
619- ) -> None :
555+ def dequantize_block_kernel (block_start_ptr , out_tensor_ptr , CTX : tl .constexpr , DTYPE : tl .constexpr ) -> None :
620556 offsets_16 = tl .arange (0 , 16 )
621557 offsets_4 = tl .arange (0 , 4 )
622558
@@ -642,16 +578,11 @@ def dequantize_block_kernel(
642578 CTX .value .store_output (out_tensor_ptr , dequant_low , dequant_high )
643579
644580
645- @dataclass
581+ @dataclass ( frozen = True )
646582class KernelImpl_Q5_1 (KernelImpl_Legacy ):
647583 @maybestaticmethod
648584 @triton .jit
649- def dequantize_block_kernel (
650- block_start_ptr ,
651- out_tensor_ptr ,
652- CTX : tl .constexpr ,
653- DTYPE : tl .constexpr ,
654- ) -> None :
585+ def dequantize_block_kernel (block_start_ptr , out_tensor_ptr , CTX : tl .constexpr , DTYPE : tl .constexpr ) -> None :
655586 offsets_16 = tl .arange (0 , 16 )
656587
657588 # Data layout: 2 bytes 'd', 2 bytes 'm', 4 bytes 'qh', 16 bytes 'qs'
@@ -684,16 +615,11 @@ def dequantize_block_kernel(
684615 CTX .value .store_output (out_tensor_ptr , dequant_low , dequant_high )
685616
686617
687- @dataclass
618+ @dataclass ( frozen = True )
688619class KernelImpl_Q8_0 (KernelImpl_Legacy ):
689620 @maybestaticmethod
690621 @triton .jit
691- def dequantize_block_kernel (
692- block_start_ptr ,
693- out_tensor_ptr ,
694- CTX : tl .constexpr ,
695- DTYPE : tl .constexpr ,
696- ) -> None :
622+ def dequantize_block_kernel (block_start_ptr , out_tensor_ptr , CTX : tl .constexpr , DTYPE : tl .constexpr ) -> None :
697623 offsets_32 = tl .arange (0 , 32 )
698624 d_ptr = block_start_ptr .to (tl .pointer_type (tl .float16 ), bitcast = True ) + 0
699625 x_ptr = (
@@ -707,17 +633,17 @@ def dequantize_block_kernel(
707633
# Kernel implementation class for each supported GGML quantization type.
_QTYPE_TO_KERNEL_CLASS: dict[GGMLQuantizationType, type[KernelImpl]] = {
    # Legacy quants
    GQT.Q4_0: KernelImpl_Q4_0,
    GQT.Q4_1: KernelImpl_Q4_1,
    GQT.Q5_0: KernelImpl_Q5_0,
    GQT.Q5_1: KernelImpl_Q5_1,
    GQT.Q8_0: KernelImpl_Q8_0,
    # K-quants
    GQT.Q2_K: KernelImpl_Q2_K,
    GQT.Q3_K: KernelImpl_Q3_K,
    GQT.Q4_K: KernelImpl_Q4_K,
    GQT.Q5_K: KernelImpl_Q5_K,
    GQT.Q6_K: KernelImpl_Q6_K,
}

# Public dispatch table: quantization type -> ready-to-call KernelDefinition.
dequantize_functions: dict[GGMLQuantizationType, KernelDefinition] = {
    qtype: KernelDefinition(qtype, kernel_cls)
    for qtype, kernel_cls in _QTYPE_TO_KERNEL_CLASS.items()
}

__all__ = ("dequantize_functions",)