I've got a number of issues and the fixes #11
Description
Bug 1: WSL2 CUDA_ERROR_ILLEGAL_ADDRESS in softmax_topk / decode
WSL2's CUDA paravirtualization layer returns invalid device pointers from cuMemHostGetDevicePointer_v2 when the host buffer was allocated with CU_MEMHOSTALLOC_DEVICEMAP (flag 0x02). Both the allocation and the pointer lookup return CUDA_SUCCESS, but the device pointer is garbage. When softmax_topk or sigmoid_topk writes top-k results through that pointer, the kernel faults with CUDA_ERROR_ILLEGAL_ADDRESS.
Affects: Anyone running Krasis on WSL2 (Windows Subsystem for Linux).
Fix: In src/gpu_decode.rs, detect WSL2 and skip DEVICEMAP allocation. The existing D2H copy fallback path handles it automatically:
```rust
fn is_wsl2() -> bool {
    std::fs::read_to_string("/proc/version")
        .map(|v| {
            let v = v.to_lowercase();
            v.contains("microsoft") || v.contains("wsl")
        })
        .unwrap_or(false)
}

// In PinnedMapped::new(), add at the top:
if is_wsl2() {
    return Err("WSL2: DEVICEMAP pinned memory not supported".to_string());
}
```
When PinnedMapped::new() returns Err, pinned_topk_ids becomes None, use_pinned becomes false, and the code falls back to writing top-k results to normal device memory followed by cuMemcpyDtoH_v2. This works correctly, at a small performance cost (~2 extra D2H copies per layer per token).
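A minimal Python sketch of that fallback decision (the names mirror the issue text, not the actual Rust source):

```python
# Sketch of the fallback path described above. Names mirror the issue
# text (pinned_topk_ids, use_pinned); the real logic lives in Rust.
def choose_topk_path(pinned_alloc_ok: bool) -> str:
    pinned_topk_ids = "mapped host buffer" if pinned_alloc_ok else None
    use_pinned = pinned_topk_ids is not None
    # use_pinned == False -> write top-k to device memory, then D2H copy
    return "zero-copy pinned" if use_pinned else "device write + cuMemcpyDtoH_v2"

# On WSL2, PinnedMapped::new() errors out, so the copy path is taken:
assert choose_topk_path(pinned_alloc_ok=False) == "device write + cuMemcpyDtoH_v2"
```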
Bug 2: FP8 KV cache on Ampere GPUs → ILLEGAL_ADDRESS in Rust GQA decode
server.py defaults --kv-dtype to fp8_e4m3 (line 603). The launcher has auto-downgrade logic for GPUs without FP8 support (SM < 8.9), but the server entry point doesn't. On Ampere (RTX 3060/3070/3080/3090, SM 8.6):
- PyTorch allocates FP8 tensors fine (it's just a storage format, 1 byte per element)
- FlashInfer prefill handles FP8 correctly
- The Rust GQA decode kernel computes memory offsets assuming BF16 (2 bytes per element)
- First full_attention layer → reads past buffer bounds → CUDA_ERROR_ILLEGAL_ADDRESS
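The out-of-bounds read is simple byte arithmetic; a sketch with a hypothetical buffer size (not Krasis's actual KV layout):

```python
# FP8 stores 1 byte/element; the decode kernel computes offsets as if
# elements were BF16 (2 bytes/element). Hypothetical 4096-element buffer.
num_elements = 4096
fp8_buffer_bytes = num_elements * 1          # what PyTorch allocated
last_bf16_offset = (num_elements - 1) * 2    # where the kernel's last read starts

# The kernel's last read begins well past the end of the FP8 buffer:
assert last_bf16_offset >= fp8_buffer_bytes
```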
Affects: Anyone running krasis CLI (not the launcher UI) on any pre-Ada GPU.
Fix: Either add SM auto-detection to server.py like the launcher has, or add --kv-dtype bf16 to the launch command. The server should probably match the launcher's behavior:
In server.py, after argument parsing:
```python
if args.kv_dtype == "fp8_e4m3":
    cc = torch.cuda.get_device_capability()
    if cc[0] < 9 and not (cc[0] == 8 and cc[1] >= 9):
        logger.info("GPU SM %d.%d < 8.9: downgrading KV dtype from fp8 to bf16", cc[0], cc[1])
        args.kv_dtype = "bf16"
```
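To sanity-check the condition, here is the same predicate pulled out into a standalone function (a sketch; the name is mine) and exercised against a few compute capabilities. FP8 KV cache needs SM >= 8.9, i.e. Ada or newer:

```python
def needs_bf16_downgrade(cc):
    """True when the GPU lacks FP8 support (compute capability < 8.9)."""
    # Same predicate as the server.py snippet above.
    return cc[0] < 9 and not (cc[0] == 8 and cc[1] >= 9)

assert needs_bf16_downgrade((8, 6))        # Ampere (RTX 30xx): downgrade
assert not needs_bf16_downgrade((8, 9))    # Ada (RTX 40xx): keep fp8
assert not needs_bf16_downgrade((9, 0))    # Hopper (H100): keep fp8
```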
Bug 3: Prefill/decode warmup OOM on 8GB GPUs crashes server
On 8GB cards, after loading model weights + KV cache + AWQ attention quantization, VRAM hits 0 MB free. The warmup phases (_warmup_prefill and _warmup_decode) then OOM. Since warmup is only pre-populating CUDA caches (torch.compile, cuBLAS handles, FlashInfer workspace), it should be non-fatal — the first real request just runs slightly slower.
Affects: 8GB GPUs (RTX 3060 Ti, 3070, 4060, etc.) with larger models.
Fix: Wrap both warmup calls in try/except in server.py:
```python
try:
    _warmup_prefill(_model)
except Exception as e:
    logger.warning("Prefill warmup failed (non-fatal, skipping): %s", e)

try:
    _warmup_decode(_model, num_steps=1)
except Exception as e:
    logger.warning("Decode warmup failed (non-fatal, skipping): %s", e)
```
Do the same for the decode validation warmup later (~line 1309). Also call gc.collect(); torch.cuda.empty_cache() before setup_gpu_decode_store() to reclaim PyTorch's cached allocator blocks before the Rust-side cudarc allocations.
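The cache-reclaim step might look like this (a sketch; the helper name is mine, and the guarded import is only so the snippet runs where torch isn't installed):

```python
import gc

def reclaim_torch_cache():
    """Release cached CUDA allocator blocks back to the driver.

    Intended to run right before setup_gpu_decode_store(), so the
    Rust-side cudarc allocations see the freed VRAM.
    """
    gc.collect()  # drop Python references still holding tensors alive
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # return cached blocks to the driver
    except ImportError:
        pass  # torch absent (e.g. in a test environment); nothing to do
```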
Bug 4: VRAM calibration crash kills server on low-VRAM cards
The 4-point VRAM calibration (prefill short/long + decode short/long) fails on 8GB GPUs because there's not enough headroom. When any measurement fails, the server raises RuntimeError and exits. This is too aggressive — the server can function without HCS by streaming all experts from CPU via DMA.
Affects: 8GB GPUs where calibration inference OOMs.
Fix: Fall back to zero HCS budget instead of crashing:
```python
if (prefill_short_free is None or prefill_long_free is None
        or decode_short_free is None or decode_long_free is None):
    logger.warning("VRAM calibration failed — using zero HCS budget (all experts from CPU)")
    prefill_short_free = prefill_long_free = 0
    decode_short_free = decode_long_free = 0
    _baseline_free = 0
```
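The all-or-nothing check can be isolated into a hypothetical helper for clarity (the four-measurement shape follows the snippet; the issue doesn't show how the actual budget is computed from successful measurements):

```python
# Hypothetical helper: collapse the four calibration measurements to a
# zero HCS budget if any of them failed (None = OOM during calibration).
def calibrated_or_zero(*measurements):
    if any(m is None for m in measurements):
        return (0,) * len(measurements)
    return measurements

assert calibrated_or_zero(512, None, 256, 128) == (0, 0, 0, 0)
assert calibrated_or_zero(512, 384, 256, 128) == (512, 384, 256, 128)
```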
I did use AI to generate all of this, but I hope someone finds it useful.