1 change: 0 additions & 1 deletion README.md
@@ -35,7 +35,6 @@ Or to install with MLX support, run:
pip install genlm-backend[mlx]
```


## 🧪 Example: Autobatched Sequential Importance Sampling with LLMs

This example demonstrates how `genlm-backend` enables concise, scalable probabilistic inference with language models. It implements a Sequential Importance Sampling (SIS) algorithm that makes asynchronous log-probability requests, which the language model automatically batches.
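As a minimal sketch of the pattern described above (not the repository's own example, which is omitted from this diff): several SIS particles concurrently await `next_token_logprobs`, and the backend batches those requests into shared forward passes. The tempered target, particle count, and model name are illustrative choices.

```python
import asyncio
import numpy as np
from genlm.backend.llm import load_model_by_name

def logsumexp(x):
    m = x.max()
    return m + np.log(np.exp(x - m).sum())

async def sis_particle(llm, prompt_ids, steps=5, tau=0.7):
    """One SIS particle: propose tokens from the model and weight them toward
    a temperature-`tau` target distribution."""
    token_ids, log_w = list(prompt_ids), 0.0
    for _ in range(steps):
        # Concurrent awaits from many particles are batched into one forward pass.
        logprobs = (await llm.next_token_logprobs(token_ids)).cpu().numpy().astype(np.float64)
        probs = np.exp(logprobs)
        next_tok = int(np.random.choice(len(probs), p=probs / probs.sum()))
        # Incremental weight = target(x) / proposal(x) for the tempered target.
        log_w += logprobs[next_tok] / tau - logsumexp(logprobs / tau) - logprobs[next_tok]
        token_ids.append(next_tok)
    return token_ids, log_w

async def main():
    llm = load_model_by_name("HuggingFaceTB/SmolLM-135M", backend="hf")
    prompt = llm.tokenizer.encode("Sequential importance sampling")
    particles = await asyncio.gather(*[sis_particle(llm, prompt) for _ in range(8)])
    for ids, log_w in particles:
        print(round(log_w, 3), llm.tokenizer.decode(ids))

asyncio.run(main())
```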
36 changes: 36 additions & 0 deletions genlm/backend/llm/hf.py
@@ -162,6 +162,42 @@ def cache_kv(self, prompt_tokens):
result = self.model(torch.tensor([prompt_tokens]).to(self.device))
node = self.cache.extend_cache(0, prompt_tokens, result.logits[0], 0)
node.past_key_values = result.past_key_values

def load_lora(self, lora_path, lora_name='lora_1'):
"""Load a LoRA adapter into the base model.

Args:
            lora_path (str): Path to the adapter weights directory, or a model identifier on the Hugging Face Hub.
lora_name (str): Name to assign to the loaded adapter.

Notes:
This does not activate the adapter immediately. Call `set_lora()` to enable the adapter.
"""
        if lora_path is None:
            raise ValueError(
                "lora_path must be a LoRA adapter directory or a Hugging Face Hub identifier."
            )
        self.model.load_adapter(lora_path, lora_name)

def set_lora(self, lora_name='lora_1'):
"""Activate a previously loaded LoRA adapter.

Args:
lora_name (str): Name of the LoRA adapter to activate.

"""
self.clear_kv_cache()
self.clear_cache()
self.model.set_adapter(lora_name)

def clear_lora(self):
"""
Deactivate all LoRA adapters.
"""
self.clear_kv_cache()
self.clear_cache()
self.model.set_adapter([])

@torch.no_grad()
def batch_evaluate_queries(self):
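For orientation, toggling the new HF adapter support on and off might look like the sketch below. This is a hedged usage example, not part of the PR: the adapter path is a placeholder, and `load_model_by_name` / `next_token_logprobs` are the entry points used by the tests further down.

```python
import asyncio
from genlm.backend.llm import load_model_by_name

llm = load_model_by_name("HuggingFaceTB/SmolLM-135M", backend="hf")
prompt_ids = llm.tokenizer.encode("Hello")

# Load the adapter weights once, then toggle them as needed.
llm.load_lora("path/to/adapter", lora_name="lora_1")  # placeholder adapter path
llm.set_lora(lora_name="lora_1")                      # activates the adapter and clears output/KV caches
with_lora = asyncio.run(llm.next_token_logprobs(prompt_ids))

llm.clear_lora()                                      # back to the base weights
without_lora = asyncio.run(llm.next_token_logprobs(prompt_ids))
```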
23 changes: 22 additions & 1 deletion genlm/backend/llm/vllm.py
@@ -7,6 +7,7 @@

try:
from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs
from vllm.lora.request import LoRARequest
from vllm.utils import Counter
from vllm.inputs import TokensPrompt

@@ -81,6 +82,7 @@ def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
if cache_size > 0
else None
)
self.lora_request = None

async_llm_engine.engine.log_stats = False

@@ -128,6 +130,22 @@ def from_name(cls, model_name, engine_opts=None, **kwargs):
def underlying_model(self):
return self.async_llm_engine.engine.model_executor.driver_worker.model_runner.model

def clear_lora(self):
"""
Disable any active LoRA adapter for the vLLM engine.
"""
self.lora_request = None

def set_lora(self, lora_path, lora_name="current_lora", lora_id=1):
"""Configure a LoRA adapter request for the vLLM engine.

Args:
            lora_path (str): Path to the adapter weights directory, or a model identifier on the Hugging Face Hub.
lora_name (str): Identifier name to associate with this LoRA adapter within vLLM.
lora_id (int): Globally unique ID for the adapter.
"""
self.lora_request = LoRARequest(lora_name, lora_id, lora_path)

async def next_token_logprobs(self, token_ids):
"""Request log probabilities of next token asynchronously with output caching.

@@ -172,6 +190,7 @@ async def _next_token_logprobs(self, token_ids):
sampling_params=SamplingParams(
**self.default_params, logits_processors=[processor]
),
lora_request=self.lora_request,
request_id=req_id,
):
if output.finished:
@@ -215,11 +234,12 @@ def batch_next_token_logprobs_sync(self, token_ids_list):
params=SamplingParams(
**self.default_params, logits_processors=[processor]
),
lora_request=self.lora_request,
request_id=req_id,
)

while self.async_llm_engine.engine.has_unfinished_requests():
output = self.async_llm_engine.engine.step()
output = self.async_llm_engine.engine.step()
for out in output:
if out.finished:
assert out.request_id in req_id2processors, (
@@ -275,6 +295,7 @@ async def sample(
seed=seed,
stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
),
lora_request=self.lora_request,
request_id=str(next(self.request_counter)),
):
if output.finished:
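The vLLM-side counterpart might be used as sketched below. Hedged assumptions: `load_model_by_name` accepts `backend="vllm"` on a CUDA machine, `engine_opts` is forwarded to `AsyncEngineArgs`, and vLLM requires `enable_lora=True` for `LoRARequest` objects to take effect; the adapter path and token ids are placeholders.

```python
import asyncio
from genlm.backend.llm import load_model_by_name

# Assumption: engine_opts reaches AsyncEngineArgs; vLLM only honors LoRARequest
# objects when the engine was built with enable_lora=True.
llm = load_model_by_name(
    "HuggingFaceTB/SmolLM-135M",
    backend="vllm",
    llm_opts={"engine_opts": {"enable_lora": True}},
)

token_ids = [1, 2, 3]  # placeholder prompt token ids

# Route subsequent requests through the adapter (placeholder path).
llm.set_lora("path/to/adapter", lora_name="current_lora", lora_id=1)
with_lora = asyncio.run(llm.next_token_logprobs(token_ids))

llm.clear_lora()  # drop the LoRARequest; requests use the base weights again
without_lora = asyncio.run(llm.next_token_logprobs(token_ids))
```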
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -17,13 +17,15 @@ dependencies = [
"numba",
"vllm>=0.6.6,<=0.10.0; sys_platform == 'linux'",
"triton>=3.2.0; sys_platform == 'linux'",
"peft"
]

[project.optional-dependencies]
mlx = [
"mlx",
"mlx-lm"
]

docs = [
"mkdocs",
"mkdocstrings[python]",
10 changes: 10 additions & 0 deletions tests/conftest.py
@@ -9,6 +9,7 @@
destroy_model_parallel,
destroy_distributed_environment,
)
from vllm.lora.request import LoRARequest

HAS_VLLM = True
except ImportError:
@@ -142,6 +143,7 @@ def __init__(self, llm):
stop=None,
ignore_eos=True,
)
self.lora_request = None

self.llm.llm_engine.log_stats = False

@@ -158,11 +160,18 @@ def from_name(cls, model_name, llm_opts=None):
llm = LLM(model=model_name, tokenizer=model_name, **llm_opts)
return cls(llm)

def clear_lora(self):
self.lora_request = None

def set_lora(self, lora_path, lora_name="current_lora", lora_id=1):
self.lora_request = LoRARequest(lora_name, lora_id, lora_path)

def next_token_logprobs_sync(self, token_ids):
outputs = self.llm.generate(
prompts=TokensPrompt(prompt_token_ids=token_ids),
sampling_params=self.DEFAULT_SAMPLING_PARAMS,
use_tqdm=False,
lora_request=self.lora_request
)
logprobs = np.array(
[
@@ -185,6 +194,7 @@ async def batch_next_token_logprobs(self, token_ids_list):
prompts=prompts,
sampling_params=self.DEFAULT_SAMPLING_PARAMS,
use_tqdm=False,
lora_request=self.lora_request
)
logprobs = np.array(
[
206 changes: 206 additions & 0 deletions tests/test_hf_lora.py
@@ -0,0 +1,206 @@
import pytest
import asyncio
import torch
from conftest import cuda_only
from arsenal.maths import compare
from genlm.backend.llm import load_model_by_name

@pytest.fixture(scope="module")
def model_name():
return "HuggingFaceTB/SmolLM-135M"

@pytest.fixture(scope="module")
def merged_path():
return 'vxef/smol_merged_toy'

@pytest.fixture(scope="module")
def lora_path():
return "vxef/smol_lora_toy"

@pytest.fixture(scope="module")
def transformer_merged_llm(merged_path):
return load_model_by_name(
merged_path, backend="hf", llm_opts={"hf_opts": {"torch_dtype": torch.float32}}
)

@pytest.fixture(scope="module")
def transformer_llm(model_name):
return load_model_by_name(
model_name, backend="hf", llm_opts={"hf_opts": {"torch_dtype": torch.float32}}
)

@pytest.fixture(scope="module")
def transformer_llm_nolora(model_name):
return load_model_by_name(
model_name, backend="hf", llm_opts={"hf_opts": {"torch_dtype": torch.float32}}
)

@pytest.fixture(scope="module", autouse=True)
def load_lora(transformer_llm, lora_path):
transformer_llm.load_lora(lora_path, 'lora_1')
transformer_llm.set_lora(lora_name='lora_1')


@pytest.fixture(scope="module")
def token_ids_list(transformer_llm):
test_prompts = [
"There might be something wrong",
"with the language model code",
"It's probably this or that",
"with the language model code",
]
return [transformer_llm.tokenizer.encode(p) for p in test_prompts]

@cuda_only
def test_transformer_llm(transformer_llm):
assert transformer_llm is not None

@cuda_only
def test_transformer_merged_llm(transformer_merged_llm):
assert transformer_merged_llm is not None

@cuda_only
def test_next_token_logprobs_lora_uncached(transformer_llm, transformer_merged_llm, token_ids_list):
for token_ids in token_ids_list:
unmerged_logprobs = transformer_llm.next_token_logprobs_uncached(token_ids).cpu().numpy()
merged_logprobs = transformer_merged_llm.next_token_logprobs_uncached(token_ids).cpu().numpy()
assert compare(unmerged_logprobs, merged_logprobs).max_rel_err < 1e-3, token_ids

@cuda_only
def test_next_token_logprobs_lora(transformer_llm, transformer_merged_llm, token_ids_list):
for token_ids in token_ids_list:
unmerged_logprobs = asyncio.run(transformer_llm.next_token_logprobs(token_ids)).cpu().numpy()
merged_logprobs = asyncio.run(transformer_merged_llm.next_token_logprobs(token_ids)).cpu().numpy()
assert compare(unmerged_logprobs, merged_logprobs).max_rel_err < 1e-3, token_ids

@cuda_only
def test_token_logprobs_lora_sync(transformer_llm, transformer_merged_llm, token_ids_list):
unmerged_logprobs = [transformer_llm.next_token_logprobs_sync(token_ids).cpu().numpy() for token_ids in token_ids_list]
merged_logprobs = [transformer_merged_llm.next_token_logprobs_sync(token_ids).cpu().numpy() for token_ids in token_ids_list]

for i, (unmerged_logprob, merged_logprob) in enumerate(zip(unmerged_logprobs, merged_logprobs)):
assert compare(unmerged_logprob, merged_logprob).max_rel_err < 1e-3, token_ids_list[i]

@cuda_only
def test_batch_token_logprobs_lora(transformer_llm, transformer_merged_llm, token_ids_list):
unmerged_logprobs = (
asyncio.run(transformer_llm.batch_next_token_logprobs(token_ids_list)).cpu().numpy()
)
merged_logprobs = (
asyncio.run(transformer_merged_llm.batch_next_token_logprobs(token_ids_list)).cpu().numpy()
)
for i, (unmerged_logprob, merged_logprob) in enumerate(zip(unmerged_logprobs, merged_logprobs)):
assert compare(unmerged_logprob, merged_logprob).max_rel_err < 1e-3, token_ids_list[i]

@cuda_only
def test_batch_token_logprobs_lora_sync(transformer_llm, transformer_merged_llm, token_ids_list):
unmerged_logprobs = transformer_llm.batch_next_token_logprobs_sync(token_ids_list).cpu().numpy()
    merged_logprobs = transformer_merged_llm.batch_next_token_logprobs_sync(token_ids_list).cpu().numpy()
for i, (unmerged_logprob, merged_logprob) in enumerate(zip(unmerged_logprobs, merged_logprobs)):
assert compare(unmerged_logprob, merged_logprob).max_rel_err < 1e-3, token_ids_list[i]

@cuda_only
def test_set_disable_swap(transformer_llm, token_ids_list, transformer_llm_nolora):
lora_logprobs_noswapped = []
nolora_logprobs_noswapped = []
for token_ids in token_ids_list:
lora_logprobs_noswapped.append(asyncio.run(transformer_llm.next_token_logprobs(token_ids)).cpu().numpy())
nolora_logprobs_noswapped.append(asyncio.run(transformer_llm_nolora.next_token_logprobs(token_ids)).cpu().numpy())

lora_logprobs_swapped = []
nolora_logprobs_swapped = []
for token_ids in token_ids_list:
lora_logprobs_swapped.append(asyncio.run(transformer_llm.next_token_logprobs(token_ids)).cpu().numpy())
transformer_llm.clear_lora()
nolora_logprobs_swapped.append(asyncio.run(transformer_llm.next_token_logprobs(token_ids)).cpu().numpy())
transformer_llm.set_lora('lora_1')

for i, (noswapped, swapped) in enumerate(zip(lora_logprobs_noswapped, lora_logprobs_swapped)):
assert compare(noswapped, swapped).max_rel_err < 1e-3, token_ids_list[i]
for i, (noswapped, swapped) in enumerate(zip(nolora_logprobs_noswapped, nolora_logprobs_swapped)):
assert compare(noswapped, swapped).max_rel_err < 1e-3, token_ids_list[i]

@cuda_only
def test_set_disable_swap_uncached(transformer_llm, token_ids_list, transformer_llm_nolora):
lora_logprobs_noswapped = []
nolora_logprobs_noswapped = []
for token_ids in token_ids_list:
lora_logprobs_noswapped.append(transformer_llm.next_token_logprobs_uncached(token_ids).cpu().numpy())
nolora_logprobs_noswapped.append(transformer_llm_nolora.next_token_logprobs_uncached(token_ids).cpu().numpy())

lora_logprobs_swapped = []
nolora_logprobs_swapped = []
for token_ids in token_ids_list:
lora_logprobs_swapped.append(transformer_llm.next_token_logprobs_uncached(token_ids).cpu().numpy())
transformer_llm.clear_lora()
nolora_logprobs_swapped.append(transformer_llm.next_token_logprobs_uncached(token_ids).cpu().numpy())
transformer_llm.set_lora('lora_1')

for i, (noswapped, swapped) in enumerate(zip(lora_logprobs_noswapped, lora_logprobs_swapped)):
assert compare(noswapped, swapped).max_rel_err < 1e-3, token_ids_list[i]
for i, (noswapped, swapped) in enumerate(zip(nolora_logprobs_noswapped, nolora_logprobs_swapped)):
assert compare(noswapped, swapped).max_rel_err < 1e-3, token_ids_list[i]


@cuda_only
def test_set_disable_swap_sync(transformer_llm, token_ids_list, transformer_llm_nolora):
lora_logprobs_noswapped = [transformer_llm.next_token_logprobs_sync(token_ids).cpu().numpy() for token_ids in token_ids_list]
nolora_logprobs_noswapped = [transformer_llm_nolora.next_token_logprobs_sync(token_ids).cpu().numpy() for token_ids in token_ids_list]

lora_logprobs_swapped = []
nolora_logprobs_swapped = []
for token_ids in token_ids_list:
lora_logprobs_swapped.append(transformer_llm.next_token_logprobs_sync(token_ids).cpu().numpy())
transformer_llm.clear_lora()
nolora_logprobs_swapped.append(transformer_llm.next_token_logprobs_sync(token_ids).cpu().numpy())
transformer_llm.set_lora('lora_1')

for i, (noswapped, swapped) in enumerate(zip(lora_logprobs_noswapped, lora_logprobs_swapped)):
assert compare(noswapped, swapped).max_rel_err < 1e-3, token_ids_list[i]
for i, (noswapped, swapped) in enumerate(zip(nolora_logprobs_noswapped, nolora_logprobs_swapped)):
assert compare(noswapped, swapped).max_rel_err < 1e-3, token_ids_list[i]


@cuda_only
def test_set_disable_swap_batch(transformer_llm, token_ids_list, transformer_llm_nolora):
lora_logprobs_noswapped = (
asyncio.run(transformer_llm.batch_next_token_logprobs(token_ids_list)).cpu().numpy()
)
nolora_logprobs_noswapped = (
asyncio.run(transformer_llm_nolora.batch_next_token_logprobs(token_ids_list)).cpu().numpy()
)

batches = [token_ids_list[i:i+2] for i in range(0, len(token_ids_list), 2)]

lora_logprobs_swapped = []
nolora_logprobs_swapped = []
    for batch in batches:
        lora_logprobs_swapped.extend(asyncio.run(transformer_llm.batch_next_token_logprobs(batch)).cpu().numpy())
        transformer_llm.clear_lora()
        nolora_logprobs_swapped.extend(asyncio.run(transformer_llm.batch_next_token_logprobs(batch)).cpu().numpy())
        transformer_llm.set_lora('lora_1')

for i, (noswapped, swapped) in enumerate(zip(lora_logprobs_noswapped, lora_logprobs_swapped)):
assert compare(noswapped, swapped).max_rel_err < 1e-3, token_ids_list[i]
for i, (noswapped, swapped) in enumerate(zip(nolora_logprobs_noswapped, nolora_logprobs_swapped)):
assert compare(noswapped, swapped).max_rel_err < 1e-3, token_ids_list[i]

@cuda_only
def test_set_disable_swap_batch_sync(transformer_llm, token_ids_list, transformer_llm_nolora):
lora_logprobs_noswapped = transformer_llm.batch_next_token_logprobs_sync(token_ids_list).cpu().numpy()
nolora_logprobs_noswapped = transformer_llm_nolora.batch_next_token_logprobs_sync(token_ids_list).cpu().numpy()

batches = [token_ids_list[i:i+2] for i in range(0, len(token_ids_list), 2)]

lora_logprobs_swapped = []
nolora_logprobs_swapped = []
    for batch in batches:
        lora_logprobs_swapped.extend(transformer_llm.batch_next_token_logprobs_sync(batch).cpu().numpy())
        transformer_llm.clear_lora()
        nolora_logprobs_swapped.extend(transformer_llm.batch_next_token_logprobs_sync(batch).cpu().numpy())
        transformer_llm.set_lora('lora_1')

for i, (noswapped, swapped) in enumerate(zip(lora_logprobs_noswapped, lora_logprobs_swapped)):
assert compare(noswapped, swapped).max_rel_err < 1e-3, token_ids_list[i]
for i, (noswapped, swapped) in enumerate(zip(nolora_logprobs_noswapped, nolora_logprobs_swapped)):
assert compare(noswapped, swapped).max_rel_err < 1e-3, token_ids_list[i]