11 changes: 11 additions & 0 deletions exllamav3/loader/safetensors.py
@@ -256,6 +256,7 @@ def get_tensor(
         transpose: bool = False,
         pad_to: tuple = None,
         fidx: int = None,
+        auto_transpose_to_pad: bool = False,
     ) -> torch.Tensor | None:
 
         if device is None:
@@ -297,6 +298,16 @@ def get_tensor(
         end = beg + esize * numel
         bytesize = end - beg
 
+        if auto_transpose_to_pad and pad_to is not None and len(shape) == len(pad_to) == 2:
+            # Some resaved checkpoints preserve the same architecture but store 2D linear weights
+            # in the opposite orientation. Only flip when the configured orientation does not fit.
+            shape_current = tuple(shape) if not transpose else (shape[1], shape[0])
+            shape_alt = (shape[1], shape[0]) if not transpose else tuple(shape)
+            current_fits = all(a <= b for a, b in zip(shape_current, pad_to))
+            alt_fits = all(a <= b for a, b in zip(shape_alt, pad_to))
+            if not current_fits and alt_fits:
+                transpose = not transpose
+
         load_method = self.load_method
         if load_method == "mt_fread" and self.deferred_mode and not no_defer:
             load_method = "defer"
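
The decision rule above can be exercised in isolation. Below is a minimal standalone sketch of the same check (not library code; the helper name should_flip and the example shapes are ours), showing that a flip happens only when the configured orientation cannot fit the padded shape but the flipped one can:

def should_flip(shape, pad_to, transpose):
    # Orientation as currently configured, and the flipped alternative.
    shape_current = tuple(shape) if not transpose else (shape[1], shape[0])
    shape_alt = (shape[1], shape[0]) if not transpose else tuple(shape)
    current_fits = all(a <= b for a, b in zip(shape_current, pad_to))
    alt_fits = all(a <= b for a, b in zip(shape_alt, pad_to))
    # Flip only when the configured orientation does not fit but the other does.
    return not current_fits and alt_fits

# A (2, 3) weight padded to (3, 2): only the flipped orientation fits.
assert should_flip((2, 3), (3, 2), transpose = False)
# A (2, 3) weight padded to (3, 4): the configured orientation already fits.
assert not should_flip((2, 3), (3, 4), transpose = False)
# Neither orientation fits: no flip; the caller's padding fails as before.
assert not should_flip((5, 5), (3, 4), transpose = False)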
10 changes: 9 additions & 1 deletion exllamav3/modules/linear.py
@@ -136,7 +136,15 @@ def load_fp16(self, key: str | list[str]) -> bool:
         scale_inv = self.config.stc.get_tensor(key + ".weight_scale_inv", dev, transpose = self.transposed_load, optional = True, no_defer = True)
         assert scale is None or scale_inv is None
         no_defer = scale is not None or scale_inv is not None
-        weight = self.config.stc.get_tensor(key + ".weight", dev, float2half = True, transpose = self.transposed_load, pad_to = pad2, no_defer = no_defer)
+        weight = self.config.stc.get_tensor(
+            key + ".weight",
+            dev,
+            float2half = True,
+            transpose = self.transposed_load,
+            pad_to = pad2,
+            no_defer = no_defer,
+            auto_transpose_to_pad = True,
+        )
         bias = self.config.stc.get_tensor(key + ".bias", dev, float2half = True, optional = True, pad_to = pad1)
         if scale is not None:
             weight = self.apply_fp8_scales_(weight, scale)
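
Note how this interacts with transposed_load: the new flag only overrides transpose when the shapes force it. Reusing the should_flip sketch above (our helper, not library code):

# Stored (3, 2), transposed_load configures (2, 3): flip back if only (3, 2) fits.
assert should_flip((3, 2), pad_to = (3, 2), transpose = True)
# If the configured (2, 3) orientation already fits, the flag is a no-op.
assert not should_flip((3, 2), pad_to = (2, 3), transpose = True)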
48 changes: 48 additions & 0 deletions tests/test_checkpoint_layout_2d.py
@@ -0,0 +1,48 @@
+import os
+import sys
+
+import torch
+from safetensors.torch import save_file
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from exllamav3.loader.safetensors import SafetensorsCollection
+
+
+def test_auto_transpose_to_pad_flips_only_when_needed(tmp_path):
+    weight = torch.arange(6, dtype = torch.float16).reshape(2, 3)
+    save_file({"layer.weight": weight}, tmp_path / "model.safetensors")
+
+    stc = SafetensorsCollection(str(tmp_path), load_method = "python")
+    loaded = stc.get_tensor(
+        "layer.weight",
+        device = "cpu",
+        float2half = True,
+        transpose = False,
+        pad_to = (3, 2),
+        auto_transpose_to_pad = True,
+    )
+
+    assert loaded.shape == (3, 2)
+    torch.testing.assert_close(loaded, weight.T.contiguous())
+
+
+def test_auto_transpose_to_pad_keeps_current_orientation_when_it_already_fits(tmp_path):
+    weight = torch.arange(6, dtype = torch.float16).reshape(2, 3)
+    save_file({"layer.weight": weight}, tmp_path / "model.safetensors")
+
+    stc = SafetensorsCollection(str(tmp_path), load_method = "python")
+    loaded = stc.get_tensor(
+        "layer.weight",
+        device = "cpu",
+        float2half = True,
+        transpose = False,
+        pad_to = (3, 4),
+        auto_transpose_to_pad = True,
+    )
+
+    expected = torch.zeros((3, 4), dtype = torch.float16)
+    expected[:2, :3] = weight
+
+    assert loaded.shape == (3, 4)
+    torch.testing.assert_close(loaded, expected)
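
For quick manual verification outside pytest, the same behavior can be reproduced directly; a sketch assuming only the surface the tests use (SafetensorsCollection(path, load_method = "python") and the new get_tensor keyword):

import tempfile
from pathlib import Path

import torch
from safetensors.torch import save_file

from exllamav3.loader.safetensors import SafetensorsCollection

with tempfile.TemporaryDirectory() as tmp:
    weight = torch.arange(6, dtype = torch.float16).reshape(2, 3)
    save_file({"layer.weight": weight}, str(Path(tmp) / "model.safetensors"))
    stc = SafetensorsCollection(tmp, load_method = "python")
    loaded = stc.get_tensor(
        "layer.weight",
        device = "cpu",
        float2half = True,
        transpose = False,
        pad_to = (3, 2),
        auto_transpose_to_pad = True,
    )
    print(loaded.shape)  # expected: torch.Size([3, 2])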