Skip to content

Commit 177f2f3

Browse files
committed
Fix compiling when Triton is enabled
Cleanups/style fixes
1 parent 5ab6901 commit 177f2f3

1 file changed

Lines changed: 55 additions & 129 deletions

File tree

dequant_triton.py

Lines changed: 55 additions & 129 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,19 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass, field as dcfield
4-
from typing import Any, NamedTuple
4+
from typing import Any, TypeVar
55

66
import torch
77
import triton
88
import triton.language as tl
99
from gguf import GGML_QUANT_SIZES, GGMLQuantizationType
1010

11+
C = TypeVar("C")
12+
def passthroughdecorator(c: C) -> C:
13+
return c
14+
15+
nocompiledecorator = getattr(getattr(torch, "compiler", None), "disable", None) or passthroughdecorator
16+
1117
TRITON_MAJOR, TRITON_MINOR = (
1218
int(part) for part in triton.__version__.split(".", 3)[:2]
1319
)
@@ -16,9 +22,7 @@
1622
# in 3.3 or 3.4 whether or not the method is decorated with staticmethod. Just afraid of that
1723
# changing and breaking stuff in future versions. Triton 3.4+ can deal with the staticmethod decorator.
1824
if TRITON_MAJOR == 3 and TRITON_MINOR <= 3:
19-
20-
def maybestaticmethod(c: Any) -> Any:
21-
return c
25+
maybestaticmethod = passthroughdecorator
2226
elif TRITON_MAJOR == 3 and TRITON_MINOR >= 4:
2327
maybestaticmethod = staticmethod
2428
elif TRITON_MAJOR < 3:
@@ -39,11 +43,7 @@ def maybestaticmethod(c: Any) -> Any:
3943
torch.float32: tl.float32,
4044
torch.float16: tl.float16,
4145
torch.bfloat16: tl.bfloat16,
42-
} | (
43-
{torch.float8_e4m3fn: tl.float8e4nv}
44-
if hasattr(torch, "float8_e4m3fn") and hasattr(tl, "float8e4nv")
45-
else {}
46-
)
46+
}
4747

4848
_DEFAULT_AUTOTUNE_CONFIGS: list[triton.Config] = [
4949
triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=2),
@@ -60,7 +60,7 @@ def maybestaticmethod(c: Any) -> Any:
6060
_AUTOTUNE_CONFIGS: dict[str, list[triton.Config]] = {}
6161

6262

63-
@dataclass
63+
@dataclass(frozen=True)
6464
class KernelImpl:
6565
type_size: tl.constexpr
6666
block_size: tl.constexpr
@@ -70,14 +70,7 @@ def get_autotuner(self, **kwargs: dict) -> triton.runtime.Autotuner:
7070

7171
@maybestaticmethod
7272
@triton.jit
73-
def dequantize_kernel(
74-
q_tensor_ptr,
75-
out_tensor_ptr,
76-
n_total_blocks,
77-
DTYPE: tl.constexpr,
78-
N_BLOCKS_PER_PROG: tl.constexpr,
79-
CTX: tl.constexpr,
80-
) -> None:
73+
def dequantize_kernel(q_tensor_ptr, out_tensor_ptr, n_total_blocks, DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, CTX: tl.constexpr) -> None:
8174
pid = tl.program_id(axis=0)
8275
start_block_idx = pid * N_BLOCKS_PER_PROG
8376
n_blocks = n_total_blocks - start_block_idx
@@ -94,24 +87,19 @@ def dequantize_kernel(
9487
CTX.value.dequantize_block_kernel(
9588
quantized_block_ptr,
9689
output_ptr,
97-
CTX=CTX,
90+
CTX=tl.constexpr(CTX),
9891
DTYPE=DTYPE,
9992
)
10093

10194

102-
class KernelDefinition(NamedTuple):
95+
class KernelDefinition:
10396
qtype: GGMLQuantizationType
10497
block_size: int
10598
type_size: int
10699
kernel: KernelImpl
107100
autotuner_kernel: triton.runtime.Autotuner
108101

109-
@classmethod
110-
def build(
111-
cls,
112-
qtype: GGMLQuantizationType,
113-
kernel_class: type[KernelImpl],
114-
) -> "KernelDefinition":
102+
def __init__(self, qtype: GGMLQuantizationType, kernel_class: type[KernelImpl]):
115103
block_size, type_size = GGML_QUANT_SIZES[qtype]
116104
kernel_instance = kernel_class(
117105
block_size=tl.constexpr(block_size),
@@ -123,22 +111,15 @@ def build(
123111
),
124112
key=["n_total_blocks"],
125113
)
126-
return cls(
127-
qtype=qtype,
128-
block_size=block_size,
129-
type_size=type_size,
130-
kernel=kernel_instance,
131-
autotuner_kernel=autotuner_kernel,
132-
)
114+
self.qtype = qtype
115+
self.block_size = block_size
116+
self.type_size = type_size
117+
self.kernel = kernel_instance
118+
self.autotuner_kernel = autotuner_kernel
119+
133120

134-
def __call__(
135-
self,
136-
blocks: torch.Tensor,
137-
block_size: int,
138-
type_size: int,
139-
dtype: torch.dtype | None = None,
140-
_math_dtype: tl.dtype | None = tl.float32,
141-
) -> torch.Tensor:
121+
@nocompiledecorator
122+
def __call__(self, blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None, _math_dtype: tl.dtype | None = tl.float32) -> torch.Tensor:
142123
qtype, ggml_type_size = self.qtype, self.type_size
143124
if blocks.dtype != torch.uint8:
144125
if blocks.dtype == torch.int8:
@@ -190,23 +171,18 @@ def grid(meta: dict[str, Any]) -> tuple[int]:
190171
### K-quants
191172

192173

193-
@dataclass
174+
@dataclass(frozen=True)
194175
class KernelImpl_K_Quant(KernelImpl):
195176
k_scale_size: tl.constexpr = dcfield(
196177
default_factory=lambda: tl.constexpr(K_SCALE_SIZE)
197178
)
198179

199180

200-
@dataclass
181+
@dataclass(frozen=True)
201182
class KernelImpl_Q2_K(KernelImpl_K_Quant):
202183
@maybestaticmethod
203184
@triton.jit
204-
def dequantize_block_kernel(
205-
block_start_ptr,
206-
out_tensor_ptr,
207-
CTX: tl.constexpr,
208-
DTYPE: tl.constexpr,
209-
) -> None:
185+
def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None:
210186
# Vector of offsets for a 16-element chunk
211187
offsets_16 = tl.arange(0, 16)
212188

@@ -256,16 +232,11 @@ def dequantize_block_kernel(
256232
tl.store(output_ptr + offsets_16, dequant_16)
257233

258234

259-
@dataclass
235+
@dataclass(frozen=True)
260236
class KernelImpl_Q3_K(KernelImpl_K_Quant):
261237
@maybestaticmethod
262238
@triton.jit
263-
def dequantize_block_kernel(
264-
block_start_ptr,
265-
out_tensor_ptr,
266-
CTX: tl.constexpr,
267-
DTYPE: tl.constexpr,
268-
) -> None:
239+
def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None:
269240
# Vector of offsets for a 16-element chunk (one row of the output matrix)
270241
offsets_16 = tl.arange(0, 16)
271242

@@ -327,17 +298,12 @@ def dequantize_block_kernel(
327298
tl.store(output_ptr, dequant_16)
328299

329300

330-
@dataclass
301+
@dataclass(frozen=True)
331302
class KernelImpl_Q4_K(KernelImpl_K_Quant):
332303
# Helper function, shared by Q4_K and Q5_K.
333304
@maybestaticmethod
334305
@triton.jit
335-
def get_scales_min(
336-
k_idx: int,
337-
d_sc_word: tl.tensor,
338-
m_word: tl.tensor,
339-
m_sc_word: tl.tensor,
340-
) -> tl.tuple:
306+
def get_scales_min(k_idx: int, d_sc_word: tl.tensor, m_word: tl.tensor, m_sc_word: tl.tensor) -> tl.tuple:
341307
if k_idx < 4:
342308
k_idx_x8 = k_idx * 8
343309
d_sc_byte = d_sc_word >> k_idx_x8
@@ -355,12 +321,7 @@ def get_scales_min(
355321

356322
@maybestaticmethod
357323
@triton.jit
358-
def dequantize_block_kernel(
359-
block_start_ptr,
360-
out_tensor_ptr,
361-
CTX: tl.constexpr,
362-
DTYPE: tl.constexpr,
363-
) -> None:
324+
def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None:
364325
offsets_32 = tl.arange(0, 32)
365326
offsets_scale = offsets_32 + 4 + CTX.value.k_scale_size
366327

@@ -407,16 +368,11 @@ def dequantize_block_kernel(
407368
(output_chunk_ptr + 32).store(dequant_high)
408369

409370

410-
@dataclass
371+
@dataclass(frozen=True)
411372
class KernelImpl_Q5_K(KernelImpl_Q4_K):
412373
@maybestaticmethod
413374
@triton.jit
414-
def dequantize_block_kernel(
415-
block_start_ptr,
416-
out_tensor_ptr,
417-
CTX: tl.constexpr,
418-
DTYPE: tl.constexpr,
419-
) -> None:
375+
def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None:
420376
offsets_32 = tl.arange(0, 32)
421377
offsets_scale = offsets_32 + 4 + CTX.value.k_scale_size
422378

@@ -459,16 +415,11 @@ def dequantize_block_kernel(
459415
output_ptr.store(dequant_32)
460416

461417

462-
@dataclass
418+
@dataclass(frozen=True)
463419
class KernelImpl_Q6_K(KernelImpl_K_Quant):
464420
@maybestaticmethod
465421
@triton.jit
466-
def dequantize_block_kernel(
467-
block_start_ptr,
468-
out_tensor_ptr,
469-
CTX: tl.constexpr,
470-
DTYPE: tl.constexpr,
471-
) -> None:
422+
def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None:
472423
offsets_32 = tl.arange(0, 32)
473424
mask_16 = offsets_32 < 16
474425

@@ -520,7 +471,7 @@ def dequantize_block_kernel(
520471
### Legacy quants
521472

522473

523-
@dataclass
474+
@dataclass(frozen=True)
524475
class KernelImpl_Legacy(KernelImpl):
525476
@maybestaticmethod
526477
@triton.jit
@@ -535,16 +486,11 @@ def store_output(out_tensor_ptr, dequant_low, dequant_high) -> None:
535486
out_ptrs_high.store(dequant_high)
536487

537488

538-
@dataclass
489+
@dataclass(frozen=True)
539490
class KernelImpl_Q4_0(KernelImpl_Legacy):
540491
@maybestaticmethod
541492
@triton.jit
542-
def dequantize_block_kernel(
543-
block_start_ptr,
544-
out_tensor_ptr,
545-
CTX: tl.constexpr,
546-
DTYPE: tl.constexpr,
547-
) -> None:
493+
def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None:
548494
# Vector of offsets for the 16 bytes of quantized data
549495
offsets_16 = tl.arange(0, 16)
550496

@@ -572,16 +518,11 @@ def dequantize_block_kernel(
572518
CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high)
573519

574520

575-
@dataclass
521+
@dataclass(frozen=True)
576522
class KernelImpl_Q4_1(KernelImpl_Legacy):
577523
@maybestaticmethod
578524
@triton.jit
579-
def dequantize_block_kernel(
580-
block_start_ptr,
581-
out_tensor_ptr,
582-
CTX: tl.constexpr,
583-
DTYPE: tl.constexpr,
584-
) -> None:
525+
def dequantize_block_kernel( block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None:
585526
# Vector of offsets for the 16 bytes of quantized data
586527
offsets_16 = tl.arange(0, 16)
587528

@@ -607,16 +548,11 @@ def dequantize_block_kernel(
607548
CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high)
608549

609550

610-
@dataclass
551+
@dataclass(frozen=True)
611552
class KernelImpl_Q5_0(KernelImpl_Legacy):
612553
@maybestaticmethod
613554
@triton.jit
614-
def dequantize_block_kernel(
615-
block_start_ptr,
616-
out_tensor_ptr,
617-
CTX: tl.constexpr,
618-
DTYPE: tl.constexpr,
619-
) -> None:
555+
def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None:
620556
offsets_16 = tl.arange(0, 16)
621557
offsets_4 = tl.arange(0, 4)
622558

@@ -642,16 +578,11 @@ def dequantize_block_kernel(
642578
CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high)
643579

644580

645-
@dataclass
581+
@dataclass(frozen=True)
646582
class KernelImpl_Q5_1(KernelImpl_Legacy):
647583
@maybestaticmethod
648584
@triton.jit
649-
def dequantize_block_kernel(
650-
block_start_ptr,
651-
out_tensor_ptr,
652-
CTX: tl.constexpr,
653-
DTYPE: tl.constexpr,
654-
) -> None:
585+
def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None:
655586
offsets_16 = tl.arange(0, 16)
656587

657588
# Data layout: 2 bytes 'd', 2 bytes 'm', 4 bytes 'qh', 16 bytes 'qs'
@@ -684,16 +615,11 @@ def dequantize_block_kernel(
684615
CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high)
685616

686617

687-
@dataclass
618+
@dataclass(frozen=True)
688619
class KernelImpl_Q8_0(KernelImpl_Legacy):
689620
@maybestaticmethod
690621
@triton.jit
691-
def dequantize_block_kernel(
692-
block_start_ptr,
693-
out_tensor_ptr,
694-
CTX: tl.constexpr,
695-
DTYPE: tl.constexpr,
696-
) -> None:
622+
def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None:
697623
offsets_32 = tl.arange(0, 32)
698624
d_ptr = block_start_ptr.to(tl.pointer_type(tl.float16), bitcast=True) + 0
699625
x_ptr = (
@@ -707,17 +633,17 @@ def dequantize_block_kernel(
707633

708634
dequantize_functions: dict[GGMLQuantizationType, KernelDefinition] = {
709635
# Legacy quants
710-
GQT.Q4_0: KernelDefinition.build(GQT.Q4_0, KernelImpl_Q4_0),
711-
GQT.Q4_1: KernelDefinition.build(GQT.Q4_1, KernelImpl_Q4_1),
712-
GQT.Q5_0: KernelDefinition.build(GQT.Q5_0, KernelImpl_Q5_0),
713-
GQT.Q5_1: KernelDefinition.build(GQT.Q5_1, KernelImpl_Q5_1),
714-
GQT.Q8_0: KernelDefinition.build(GQT.Q8_0, KernelImpl_Q8_0),
636+
GQT.Q4_0: KernelDefinition(GQT.Q4_0, KernelImpl_Q4_0),
637+
GQT.Q4_1: KernelDefinition(GQT.Q4_1, KernelImpl_Q4_1),
638+
GQT.Q5_0: KernelDefinition(GQT.Q5_0, KernelImpl_Q5_0),
639+
GQT.Q5_1: KernelDefinition(GQT.Q5_1, KernelImpl_Q5_1),
640+
GQT.Q8_0: KernelDefinition(GQT.Q8_0, KernelImpl_Q8_0),
715641
# K-quants
716-
GQT.Q2_K: KernelDefinition.build(GQT.Q2_K, KernelImpl_Q2_K),
717-
GQT.Q3_K: KernelDefinition.build(GQT.Q3_K, KernelImpl_Q3_K),
718-
GQT.Q4_K: KernelDefinition.build(GQT.Q4_K, KernelImpl_Q4_K),
719-
GQT.Q5_K: KernelDefinition.build(GQT.Q5_K, KernelImpl_Q5_K),
720-
GQT.Q6_K: KernelDefinition.build(GQT.Q6_K, KernelImpl_Q6_K),
642+
GQT.Q2_K: KernelDefinition(GQT.Q2_K, KernelImpl_Q2_K),
643+
GQT.Q3_K: KernelDefinition(GQT.Q3_K, KernelImpl_Q3_K),
644+
GQT.Q4_K: KernelDefinition(GQT.Q4_K, KernelImpl_Q4_K),
645+
GQT.Q5_K: KernelDefinition(GQT.Q5_K, KernelImpl_Q5_K),
646+
GQT.Q6_K: KernelDefinition(GQT.Q6_K, KernelImpl_Q6_K),
721647
}
722648

723649
__all__ = ("dequantize_functions",)

0 commit comments

Comments (0)