40 changes: 27 additions & 13 deletions generation_inference.py
@@ -21,6 +21,9 @@
 from transformers import AutoModelForCausalLM

 from janus.models import MultiModalityCausalLM, VLChatProcessor
+from janus.utils.cuda_memory_manager import (
+    monitor_memory,
+)
 import numpy as np
 import os
 import PIL.Image
@@ -51,6 +54,7 @@
 prompt = sft_format + vl_chat_processor.image_start_tag


+@monitor_memory(warning_threshold_gb=1.5, track_stats=True)
 @torch.inference_mode()
 def generate(
     mmgpt: MultiModalityCausalLM,
@@ -66,51 +70,61 @@ def generate(
     input_ids = vl_chat_processor.tokenizer.encode(prompt)
     input_ids = torch.LongTensor(input_ids)

-    tokens = torch.zeros((parallel_size*2, len(input_ids)), dtype=torch.int).cuda()
-    for i in range(parallel_size*2):
+    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
+    for i in range(parallel_size * 2):
         tokens[i, :] = input_ids
         if i % 2 != 0:
             tokens[i, 1:-1] = vl_chat_processor.pad_id

     inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)

-    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
+    generated_tokens = torch.zeros(
+        (parallel_size, image_token_num_per_image), dtype=torch.int
+    ).cuda()

     for i in range(image_token_num_per_image):
-        outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
+        outputs = mmgpt.language_model.model(
+            inputs_embeds=inputs_embeds,
+            use_cache=True,
+            past_key_values=outputs.past_key_values if i != 0 else None,
+        )
         hidden_states = outputs.last_hidden_state

         logits = mmgpt.gen_head(hidden_states[:, -1, :])
         logit_cond = logits[0::2, :]
         logit_uncond = logits[1::2, :]
-        logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
-
+        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
         probs = torch.softmax(logits / temperature, dim=-1)

         next_token = torch.multinomial(probs, num_samples=1)
         generated_tokens[:, i] = next_token.squeeze(dim=-1)

-        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+        next_token = torch.cat(
+            [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
+        ).view(-1)
         img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
         inputs_embeds = img_embeds.unsqueeze(dim=1)

-
-    dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size])
+    dec = mmgpt.gen_vision_model.decode_code(
+        generated_tokens.to(dtype=torch.int),
+        shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size],
+    )
     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

     dec = np.clip((dec + 1) / 2 * 255, 0, 255)

     visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
     visual_img[:, :, :] = dec

-    os.makedirs('generated_samples', exist_ok=True)
+    os.makedirs("generated_samples", exist_ok=True)
     for i in range(parallel_size):
-        save_path = os.path.join('generated_samples', "img_{}.jpg".format(i))
+        save_path = os.path.join("generated_samples", "img_{}.jpg".format(i))
         PIL.Image.fromarray(visual_img[i]).save(save_path)


 generate(
     vl_gpt,
     vl_chat_processor,
     prompt,
-)
+)
4 changes: 4 additions & 0 deletions janus/janusflow/models/clip_encoder.py
@@ -25,6 +25,9 @@
 from einops import rearrange

 from janus.janusflow.models.siglip_vit import create_siglip_vit
+from janus.utils.cuda_memory_manager import (
+    monitor_memory,
+)


 class CLIPVisionTower(nn.Module):
@@ -104,6 +107,7 @@ def feature_select(self, image_forward_outs):
             raise ValueError(f"Unexpected select feature: {self.select_feature}")
         return image_features

+    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
     def forward(self, images):
         """

2 changes: 2 additions & 0 deletions janus/janusflow/models/modeling_vlm.py
@@ -31,6 +31,7 @@
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
 from janus.janusflow.models.clip_encoder import CLIPVisionTower
 from janus.janusflow.models.uvit import ShallowUViTEncoder, ShallowUViTDecoder
+from janus.utils.cuda_memory_manager import monitor_memory
 import torch.nn as nn


@@ -168,6 +169,7 @@ def __init__(self, config: MultiModalityConfig):
         )
         self.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)

+    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
     def prepare_inputs_embeds(
         self,
         input_ids: torch.LongTensor,
2 changes: 2 additions & 0 deletions janus/janusflow/models/processing_vlm.py
@@ -27,6 +27,7 @@

 from janus.janusflow.models.image_processing_vlm import VLMImageProcessor
 from janus.utils.conversation import get_conv_template
+from janus.utils.cuda_memory_manager import monitor_memory


 class DictOutput(object):
@@ -384,6 +385,7 @@ def __call__(

         return prepare

+    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
     def batchify(
         self, prepare_list: List[VLChatProcessorOutput]
     ) -> BatchedVLChatProcessorOutput:
4 changes: 4 additions & 0 deletions janus/models/clip_encoder.py
@@ -25,6 +25,9 @@
 from einops import rearrange

 from janus.models.siglip_vit import create_siglip_vit
+from janus.utils.cuda_memory_manager import (
+    monitor_memory,
+)


 class CLIPVisionTower(nn.Module):
@@ -104,6 +107,7 @@ def feature_select(self, image_forward_outs):
             raise ValueError(f"Unexpected select feature: {self.select_feature}")
         return image_features

+    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
     def forward(self, images):
         """

5 changes: 5 additions & 0 deletions janus/models/modeling_vlm.py
@@ -31,6 +31,9 @@

 from janus.models.clip_encoder import CLIPVisionTower
 from janus.models.projector import MlpProjector
+from janus.utils.cuda_memory_manager import (
+    monitor_memory,
+)


 class vision_head(torch.nn.Module):
@@ -218,6 +221,7 @@ def __init__(self, config: MultiModalityConfig):
         language_config = config.language_config
         self.language_model = LlamaForCausalLM(language_config)

+    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
     def prepare_inputs_embeds(
         self,
         input_ids: torch.LongTensor,
@@ -259,6 +263,7 @@ def prepare_inputs_embeds(

         return inputs_embeds

+    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))

4 changes: 4 additions & 0 deletions janus/models/processing_vlm.py
@@ -27,6 +27,9 @@

 from janus.models.image_processing_vlm import VLMImageProcessor
 from janus.utils.conversation import get_conv_template
+from janus.utils.cuda_memory_manager import (
+    monitor_memory,
+)


 class DictOutput(object):
@@ -354,6 +357,7 @@ def __call__(

         return prepare

+    @monitor_memory(warning_threshold_gb=1.5, track_stats=True)
     def batchify(
         self, prepare_list: List[VLChatProcessorOutput]
     ) -> BatchedVLChatProcessorOutput:
68 changes: 68 additions & 0 deletions janus/utils/cuda_memory_manager.py
@@ -0,0 +1,68 @@
+from functools import wraps
+from typing import Callable, Any
+import torch
+import warnings
+
+
+def monitor_memory(
+    warning_threshold_gb: float = 2.0,
+    track_stats: bool = True,
+    cleanup_on_warning: bool = True,
+) -> Callable:
+    """Memory monitoring decorator for CUDA operations.
+
+    Args:
+        warning_threshold_gb: Memory threshold in GB to trigger warnings
+        track_stats: Whether to track and print memory statistics
+        cleanup_on_warning: Whether to attempt memory cleanup when threshold is reached
+
+    Returns:
+        Decorator function that monitors memory usage
+    """
+
+    def decorator(func: Callable) -> Callable:
+        @wraps(func)
+        def wrapper(*args, **kwargs) -> Any:
+            if not torch.cuda.is_available():
+                return func(*args, **kwargs)
+
+            # Get initial memory state
+            free_before = torch.cuda.mem_get_info()[0] / 1024**3
+
+            try:
+                # Check memory state and cleanup if needed
+                if free_before < warning_threshold_gb and cleanup_on_warning:
+                    torch.cuda.empty_cache()
+                    free_after_cleanup = torch.cuda.mem_get_info()[0] / 1024**3
+
+                    if free_after_cleanup < warning_threshold_gb:
+                        warnings.warn(
+                            f"Low memory in {func.__name__}: {free_after_cleanup:.2f}GB free"
+                        )
+
+                result = func(*args, **kwargs)
+
+                # Track memory statistics if enabled
+                if track_stats:
+                    peak = torch.cuda.max_memory_allocated() / 1024**3
+                    free_after = torch.cuda.mem_get_info()[0] / 1024**3
+                    print(
+                        f"Memory stats for {func.__name__}:\n"
+                        f"Peak: {peak:.2f}GB | Delta: {free_before - free_after:.2f}GB"
+                    )
+                    torch.cuda.reset_peak_memory_stats()
+
+                return result
+
+            except RuntimeError as e:
+                if "out of memory" in str(e):
+                    free = torch.cuda.mem_get_info()[0] / 1024**3
+                    raise RuntimeError(
+                        f"OOM in {func.__name__} with {free:.2f}GB free. "
+                        "Consider reducing batch size or image resolution."
+                    ) from e
+                raise
+
+        return wrapper
+
+    return decorator
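
For reference, a minimal usage sketch of the decorator added above. The function name and tensor sizes here are hypothetical and not part of the PR; on a machine without CUDA the wrapper is a pass-through, and on CUDA it prints the peak/delta stats and warns when free memory stays below the threshold even after `torch.cuda.empty_cache()`:

```python
import torch

from janus.utils.cuda_memory_manager import monitor_memory


@monitor_memory(warning_threshold_gb=1.5, track_stats=True)
def dummy_forward(batch: torch.Tensor) -> torch.Tensor:
    # Hypothetical workload standing in for a model forward pass.
    return batch @ batch.transpose(-1, -2)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(8, 1024, 1024, device=device)
    # On CUDA this prints e.g.
    #   Memory stats for dummy_forward:
    #   Peak: 0.07GB | Delta: 0.03GB
    # and resets the peak-memory counter after each call.
    dummy_forward(x)
```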