diff --git a/llmvm/client/printing.py b/llmvm/client/printing.py index 593bc6e..cc90bb6 100644 --- a/llmvm/client/printing.py +++ b/llmvm/client/printing.py @@ -15,6 +15,7 @@ from rich.console import Console from rich.markdown import Markdown +from rich.syntax import Syntax from rich.theme import Theme from llmvm.common.container import Container @@ -109,12 +110,15 @@ async def decode(content: str) -> bool: class StreamPrinter(): def __init__(self, file=sys.stderr): - self.buffer = '' self.console = Console(file=file) - self.markdown_mode = False self.token_color = Container.get_config_variable('client_stream_token_color', default='bright_black') self.thinking_token_color = Container.get_config_variable('client_stream_thinking_token_color', default='cyan') + # state for line-by-line rich rendering + self.current_line = '' + self.in_code_block = False + self.code_lang = '' + async def display_image(self, image_bytes): if len(image_bytes) < 10: return @@ -208,8 +212,46 @@ async def write(self, node: AstNode): string = str(node) if string: - self.buffer += string self.console.print(string, end='', style=f"{token_color}", highlight=False) + self.current_line += string + + while '\n' in self.current_line: + line, self.current_line = self.current_line.split('\n', 1) + await self._render_line(line + '\n') + + if isinstance(node, TokenStopNode) or isinstance(node, StreamingStopNode): + if self.current_line: + await self._render_line(self.current_line) + self.current_line = '' + return + + async def _erase_last_line(self): + self.console.file.write('\x1b[1A\r\x1b[2K') + self.console.file.flush() + + async def _render_line(self, text: str): + await self._erase_last_line() + if text.strip().startswith('```'): + Markdown.__rich_console__ = markdown__rich_console__ + self.console.print(Markdown(text), end='') + fence_content = text.strip()[3:].strip() + if self.in_code_block: + self.in_code_block = False + self.code_lang = '' + else: + self.in_code_block = True + if fence_content: + self.code_lang = fence_content + elif self.in_code_block: + lang = self.code_lang if self.code_lang else 'text' + syntax = Syntax(text.rstrip('\n'), lang, theme="monokai", background_color="default", word_wrap=True, padding=0) + self.console.print(syntax) + if text.endswith('\n'): + self.console.print() + else: + Markdown.__rich_console__ = markdown__rich_console__ + self.console.print(Markdown(text), end='') + class ConsolePrinter: