From 712386162eaaf88797f9c896992ff48d6600c9cd Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 20 Feb 2026 04:03:48 +0000 Subject: [PATCH] Refactor messages system to support categorization, editing, and metadata - Added `category` attribute to `MessageThread` and updated storage to use subdirectories. - Updated `Application._load_persisted_threads` to recursively load threads. - Refactored `ChatList` to use a `Tree` widget for categorized thread display. - Added `/thread move` and `/category create` commands. - Added `edit_message` and `delete_message` methods to `MessageThread`. - Added `/message edit` and `/message delete` commands. - Updated `MessageBubble` UI with Edit/Delete buttons and inline editing. - Updated `ChatViewWidget` to handle message updates and display thread mode. - Refactored `MessageService.stream_message` to append user messages instead of overwriting history, preserving metadata and edits. - Added metadata support (e.g., model name) to messages. Co-authored-by: presstab <6628210+presstab@users.noreply.github.com> --- src/jrdev/commands/__init__.py | 4 + src/jrdev/commands/category.py | 31 ++++ src/jrdev/commands/message.py | 81 ++++++++++ src/jrdev/commands/thread.py | 57 +++++++ src/jrdev/core/application.py | 64 ++++---- src/jrdev/core/commands.py | 6 +- src/jrdev/core/state.py | 7 +- src/jrdev/messages/thread.py | 81 +++++++++- src/jrdev/services/message_service.py | 13 +- src/jrdev/ui/tui/chat/chat_list.py | 176 ++++++++++++++-------- src/jrdev/ui/tui/chat/chat_view_widget.py | 56 +++++-- src/jrdev/ui/tui/chat/message_bubble.py | 128 ++++++++++++++-- 12 files changed, 575 insertions(+), 129 deletions(-) create mode 100644 src/jrdev/commands/category.py create mode 100644 src/jrdev/commands/message.py diff --git a/src/jrdev/commands/__init__.py b/src/jrdev/commands/__init__.py index 30f32ea..bdf173b 100644 --- a/src/jrdev/commands/__init__.py +++ b/src/jrdev/commands/__init__.py @@ -7,6 +7,7 @@ from jrdev.commands.addcontext import handle_addcontext from jrdev.commands.asyncsend import handle_asyncsend from jrdev.commands.cancel import handle_cancel +from jrdev.commands.category import handle_category from jrdev.commands.clearcontext import handle_clearcontext from jrdev.commands.code import handle_code from jrdev.commands.compact import handle_compact @@ -17,6 +18,7 @@ from jrdev.commands.help import handle_help from jrdev.commands.init import handle_init from jrdev.commands.keys import handle_keys +from jrdev.commands.message import handle_message from jrdev.commands.migrate import handle_migrate from jrdev.commands.model import handle_model from jrdev.commands.modelprofile import handle_modelprofile @@ -34,6 +36,7 @@ "handle_addcontext", "handle_asyncsend", "handle_cancel", + "handle_category", "handle_code", "handle_compact", "handle_cost", @@ -44,6 +47,7 @@ "handle_help", "handle_init", "handle_keys", + "handle_message", "handle_migrate", "handle_model", "handle_models", diff --git a/src/jrdev/commands/category.py b/src/jrdev/commands/category.py new file mode 100644 index 0000000..09180b1 --- /dev/null +++ b/src/jrdev/commands/category.py @@ -0,0 +1,31 @@ +import argparse +from typing import Any, List +import os + +from jrdev.ui.ui import PrintType +from jrdev.messages.thread import THREADS_DIR + +async def handle_category(app: Any, args: List[str], _worker_id: str) -> None: + """ + Manage categories for message threads. + + Usage: + /category create + """ + if len(args) < 3 or args[1] != "create": + app.ui.print_text("Usage: /category create ", PrintType.ERROR) + return + + category_name = args[2] + + # Validate name (no slashes, dots, etc) + if not category_name.replace('_', '').replace('-', '').isalnum(): + app.ui.print_text("Error: Category name must be alphanumeric (with underscores or hyphens).", PrintType.ERROR) + return + + cat_dir = os.path.join(THREADS_DIR, category_name) + try: + os.makedirs(cat_dir, exist_ok=True) + app.ui.print_text(f"Category '{category_name}' created.", PrintType.SUCCESS) + except OSError as e: + app.ui.print_text(f"Error creating category: {e}", PrintType.ERROR) diff --git a/src/jrdev/commands/message.py b/src/jrdev/commands/message.py new file mode 100644 index 0000000..2805b72 --- /dev/null +++ b/src/jrdev/commands/message.py @@ -0,0 +1,81 @@ +import argparse +from typing import Any, List +from jrdev.ui.ui import PrintType + +async def handle_message(app: Any, args: List[str], _worker_id: str) -> None: + """ + Manage individual messages. + + Usage: + /message edit + /message delete + """ + if len(args) < 2: + return + + subcommand = args[1] + + if subcommand == "edit": + if len(args) < 5: + app.ui.print_text("Usage: /message edit ", PrintType.ERROR) + return + + thread_id = args[2] + try: + index = int(args[3]) + except ValueError: + app.ui.print_text("Error: index must be an integer", PrintType.ERROR) + return + + content = " ".join(args[4:]) + + thread = app.state.threads.get(thread_id) + if not thread: + # try prefix + thread_id_pre = f"thread_{thread_id}" + thread = app.state.threads.get(thread_id_pre) + if thread: + thread_id = thread_id_pre + + if not thread: + app.ui.print_text(f"Error: Thread '{thread_id}' not found.", PrintType.ERROR) + return + + if thread.edit_message(index, content): + app.ui.print_text("Message updated.", PrintType.SUCCESS) + app.ui.chat_thread_update(thread_id) + else: + app.ui.print_text("Error: Failed to update message (invalid index?)", PrintType.ERROR) + + elif subcommand == "delete": + if len(args) < 4: + app.ui.print_text("Usage: /message delete ", PrintType.ERROR) + return + + thread_id = args[2] + try: + index = int(args[3]) + except ValueError: + app.ui.print_text("Error: index must be an integer", PrintType.ERROR) + return + + thread = app.state.threads.get(thread_id) + if not thread: + # try prefix + thread_id_pre = f"thread_{thread_id}" + thread = app.state.threads.get(thread_id_pre) + if thread: + thread_id = thread_id_pre + + if not thread: + app.ui.print_text(f"Error: Thread '{thread_id}' not found.", PrintType.ERROR) + return + + if thread.delete_message(index): + app.ui.print_text("Message deleted.", PrintType.SUCCESS) + app.ui.chat_thread_update(thread_id) + else: + app.ui.print_text("Error: Failed to delete message (invalid index?)", PrintType.ERROR) + + else: + app.ui.print_text(f"Unknown message subcommand: {subcommand}", PrintType.ERROR) diff --git a/src/jrdev/commands/thread.py b/src/jrdev/commands/thread.py index f02baf4..8f1d9e1 100644 --- a/src/jrdev/commands/thread.py +++ b/src/jrdev/commands/thread.py @@ -147,6 +147,16 @@ async def handle_thread(app: Any, args: List[str], _worker_id: str) -> None: help="Turn web search on or off for the current thread", ) + # Move thread command + move_parser = subparsers.add_parser( + "move", + help="Move a thread to a category", + description="Move an existing conversation thread to a category", + epilog=f"Example: {format_command_with_args_plain('/thread move', 'thread_abc feature_requests')}", + ) + move_parser.add_argument("thread_id", type=str, help="Unique ID of the thread to move") + move_parser.add_argument("category", type=str, help="Target category name") + try: if any(arg in ["-h", "--help"] for arg in args[1:]): if len(args) == 2 and args[1] in ["-h", "--help"]: @@ -186,6 +196,8 @@ async def handle_thread(app: Any, args: List[str], _worker_id: str) -> None: await _handle_delete_thread(app, parsed_args) elif parsed_args.subcommand == "websearch": await _handle_websearch_toggle(app, parsed_args) + elif parsed_args.subcommand == "move": + await _handle_move_thread(app, parsed_args) else: app.ui.print_text("Error: Missing subcommand", PrintType.ERROR) app.ui.print_text("Available Thread Subcommands:", PrintType.HEADER) @@ -198,6 +210,7 @@ async def handle_thread(app: Any, args: List[str], _worker_id: str) -> None: ("info", "", "Show current thread details", "thread info"), ("view", "[count]", "Display message history", "thread view 5"), ("delete", "", "Delete an existing thread", "thread delete thread_abc"), + ("move", " ", "Move thread to category", "thread move thread_abc work"), ] for cmd, cmd_args, desc, example in subcommands: @@ -252,6 +265,11 @@ async def handle_thread(app: Any, args: List[str], _worker_id: str) -> None: format_command_with_args_plain("/thread websearch", "on|off"), "Enable or disable per-thread web search mode used by chat input\nExample: /thread websearch on", ), + ( + "Move Thread", + format_command_with_args_plain("/thread move", " "), + "Move thread to a category\nExample: /thread move thread_abc work", + ), ] for header, cmd, desc in sections: @@ -467,3 +485,42 @@ async def _handle_websearch_toggle(app: Any, args: argparse.Namespace) -> None: state_str = "enabled" if enable else "disabled" app.ui.print_text(f"Web search {state_str} for thread {thread.thread_id}", PrintType.SUCCESS) app.ui.chat_thread_update(thread.thread_id) + + +async def _handle_move_thread(app: Any, args: argparse.Namespace) -> None: + """Move an existing message thread to a category.""" + thread_id = args.thread_id + category = args.category + + thread = app.state.threads.get(thread_id) + if not thread: + # try with prefix + thread_id_pre = f"thread_{thread_id}" + thread = app.state.threads.get(thread_id_pre) + if thread: + thread_id = thread_id_pre + + if not thread: + app.ui.print_text(f"Error: Thread '{args.thread_id}' not found.", PrintType.ERROR) + return + + old_category = getattr(thread, "category", "default") + + # If category is same, do nothing + if old_category == category: + app.ui.print_text(f"Thread is already in category '{category}'.", PrintType.INFO) + return + + # Delete old file + thread.delete_persisted_file() + + # Update category + thread.category = category + + # Save (creates new file in new dir) + thread.save() + + app.ui.print_text(f"Thread '{thread_id}' moved to category '{category}'.", PrintType.SUCCESS) + + # Update UI + app.ui.chat_thread_update(thread_id) diff --git a/src/jrdev/core/application.py b/src/jrdev/core/application.py index 0dd36ed..918ff07 100644 --- a/src/jrdev/core/application.py +++ b/src/jrdev/core/application.py @@ -84,39 +84,51 @@ def write_terminal_text_styles(self) -> None: self.logger.error("Error writing terminal text styles") def _load_persisted_threads(self) -> Dict[str, MessageThread]: - """Load all persisted message threads from disk.""" + """Load all persisted message threads from disk, including subdirectories.""" loaded_threads: Dict[str, MessageThread] = {} if not os.path.isdir(THREADS_DIR): self.logger.info(f"Threads directory '{THREADS_DIR}' not found. No threads to load.") return loaded_threads self.logger.info(f"Loading persisted threads from '{THREADS_DIR}'...") - for filename in os.listdir(THREADS_DIR): - if filename.endswith(".json"): - file_path = os.path.join(THREADS_DIR, filename) - try: - with open(file_path, 'r', encoding='utf-8') as f: - data = json.load(f) + + for root, dirs, files in os.walk(THREADS_DIR): + for filename in files: + if filename.endswith(".json"): + file_path = os.path.join(root, filename) + + # Infer category from parent directory relative to THREADS_DIR + rel_dir = os.path.relpath(root, THREADS_DIR) + category = "default" + if rel_dir != ".": + category = rel_dir - if "thread_id" not in data: - self.logger.warning(f"File {file_path} is missing 'thread_id'. Skipping.") - continue - - thread = MessageThread.from_dict(data) - - # Don't load old router threads - thread_type = thread.metadata.get("type") - if thread_type and thread_type == "router": - continue - - loaded_threads[thread.thread_id] = thread - self.logger.debug(f"Successfully loaded thread: {thread.thread_id} from {file_path}") - except json.JSONDecodeError as e: - self.logger.error(f"Error decoding JSON from {file_path}: {e}. Skipping file.") - except KeyError as e: - self.logger.error(f"Missing key in thread data from {file_path}: {e}. Skipping file.") - except Exception as e: - self.logger.error(f"Unexpected error loading thread from {file_path}: {e}. Skipping file.") + try: + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + if "thread_id" not in data: + self.logger.warning(f"File {file_path} is missing 'thread_id'. Skipping.") + continue + + thread = MessageThread.from_dict(data) + + # Ensure loaded category matches filesystem structure + thread.category = category + + # Don't load old router threads + thread_type = thread.metadata.get("type") + if thread_type and thread_type == "router": + continue + + loaded_threads[thread.thread_id] = thread + self.logger.debug(f"Successfully loaded thread: {thread.thread_id} from {file_path}") + except json.JSONDecodeError as e: + self.logger.error(f"Error decoding JSON from {file_path}: {e}. Skipping file.") + except KeyError as e: + self.logger.error(f"Missing key in thread data from {file_path}: {e}. Skipping file.") + except Exception as e: + self.logger.error(f"Unexpected error loading thread from {file_path}: {e}. Skipping file.") self.logger.info(f"Finished loading threads. Total loaded: {len(loaded_threads)}.") return loaded_threads diff --git a/src/jrdev/core/commands.py b/src/jrdev/core/commands.py index 634fe15..9f98b1f 100644 --- a/src/jrdev/core/commands.py +++ b/src/jrdev/core/commands.py @@ -5,6 +5,7 @@ handle_addcontext, handle_asyncsend, handle_cancel, + handle_category, handle_clearcontext, handle_code, handle_compact, @@ -14,6 +15,7 @@ handle_help, handle_init, handle_keys, + handle_message, handle_migrate, handle_model, handle_models, @@ -58,6 +60,7 @@ def _register_core_commands(self) -> None: "/viewcontext": handle_viewcontext, "/asyncsend": handle_asyncsend, "/tasks": handle_tasks, + "/category": handle_category, "/cancel": handle_cancel, "/code": handle_code, "/projectcontext": handle_projectcontext, @@ -67,7 +70,8 @@ def _register_core_commands(self) -> None: "/research": handle_research, "/routeragent": handle_routeragent, "/thread": handle_thread, - "/migrate": handle_migrate + "/migrate": handle_migrate, + "/message": handle_message } self.commands.update(core_commands) diff --git a/src/jrdev/core/state.py b/src/jrdev/core/state.py index e9bf5e2..aa0e378 100644 --- a/src/jrdev/core/state.py +++ b/src/jrdev/core/state.py @@ -100,7 +100,7 @@ def get_all_threads(self) -> List[MessageThread]: return list(self.threads.values()) # Thread management - def create_thread(self, thread_id: str="", meta_data: Dict[str, str]=None) -> str: + def create_thread(self, thread_id: str="", meta_data: Dict[str, str]=None, category: str="default", mode: str="chat") -> str: """Create a new message thread""" if thread_id == "": thread_id = f"thread_{uuid.uuid4().hex[:8]}" @@ -109,7 +109,10 @@ def create_thread(self, thread_id: str="", meta_data: Dict[str, str]=None) -> st # This is handled by @auto_persist on MessageThread methods like set_name or if it's saved on creation # For now, MessageThread constructor doesn't auto-save, so an explicit save might be needed # or ensure first mutation triggers save. The current design relies on mutation. - self.threads[thread_id] = MessageThread(thread_id) + thread = MessageThread(thread_id) + thread.category = category + thread.metadata["mode"] = mode + self.threads[thread_id] = thread if meta_data: for k, v in meta_data.items(): self.threads[thread_id].metadata[k] = v diff --git a/src/jrdev/messages/thread.py b/src/jrdev/messages/thread.py index 4795b02..f310697 100644 --- a/src/jrdev/messages/thread.py +++ b/src/jrdev/messages/thread.py @@ -36,6 +36,7 @@ def __init__(self, thread_id: str): thread_id: Unique identifier for this thread """ self.thread_id: str = thread_id + self.category: str = "default" self.name: Optional[str] = None self.messages: List[Dict[str, str]] = [] self.context: Set[str] = set() @@ -44,6 +45,7 @@ def __init__(self, thread_id: str): self.metadata: Dict[str, Any] = { "created_at": datetime.now(), "last_modified": datetime.now(), + "mode": "chat", } def to_dict(self) -> Dict[str, Any]: @@ -58,6 +60,7 @@ def to_dict(self) -> Dict[str, Any]: return { "thread_id": self.thread_id, + "category": self.category, "name": self.name, "messages": self.messages, "context": list(self.context), @@ -70,6 +73,7 @@ def to_dict(self) -> Dict[str, Any]: def from_dict(cls, data: Dict[str, Any]) -> 'MessageThread': """Create a MessageThread instance from a dictionary.""" thread = cls(data["thread_id"]) + thread.category = data.get("category", "default") thread.name = data.get("name") thread.messages = data.get("messages", []) thread.context = set(data.get("context", [])) @@ -103,7 +107,12 @@ def from_dict(cls, data: Dict[str, Any]) -> 'MessageThread': def save(self) -> None: """Save the thread's state to a JSON file.""" - file_path = os.path.join(THREADS_DIR, f"{self.thread_id}.json") + # Ensure category directory exists + category = getattr(self, "category", "default") or "default" + category_dir = os.path.join(THREADS_DIR, category) + os.makedirs(category_dir, exist_ok=True) + + file_path = os.path.join(category_dir, f"{self.thread_id}.json") tmp_file_path = file_path + ".tmp" try: with open(tmp_file_path, "w") as f: @@ -121,10 +130,22 @@ def save(self) -> None: def delete_persisted_file(self) -> None: """Delete the persisted JSON file for this thread.""" - file_path = os.path.join(THREADS_DIR, f"{self.thread_id}.json") + category = getattr(self, "category", "default") or "default" + file_path = os.path.join(THREADS_DIR, category, f"{self.thread_id}.json") try: if os.path.exists(file_path): os.remove(file_path) + # Also check default if not found (in case it was moved implicitly without update, though unlikely) + elif category != "default": + default_path = os.path.join(THREADS_DIR, "default", f"{self.thread_id}.json") + if os.path.exists(default_path): + os.remove(default_path) + else: + # Also check root for legacy files + root_path = os.path.join(THREADS_DIR, f"{self.thread_id}.json") + if os.path.exists(root_path): + os.remove(root_path) + except OSError as e: # Optionally log this error # print(f"Error deleting persisted file {file_path}: {e}") # For debugging @@ -171,27 +192,53 @@ def add_embedded_files(self, files: List[str]) -> None: self.metadata["last_modified"] = datetime.now() @auto_persist - def add_response(self, response: str) -> None: + def add_user_message(self, content: str, metadata: Dict[str, Any] = None) -> None: + """Add a user message to the thread history.""" + msg = {"role": "user", "content": content} + if metadata: + msg["metadata"] = metadata + self.messages.append(msg) + self.metadata["last_modified"] = datetime.now() + + @auto_persist + def add_response(self, response: str, metadata: Dict[str, Any] = None) -> None: """Add a complete assistant response to the thread history.""" - self.messages.append({"role": "assistant", "content": response}) + msg = {"role": "assistant", "content": response} + if metadata: + msg["metadata"] = metadata + self.messages.append(msg) self.metadata["last_modified"] = datetime.now() @auto_persist - def add_response_partial(self, chunk: str) -> None: + def add_response_partial(self, chunk: str, metadata: Dict[str, Any] = None) -> None: """Add a partial assistant response chunk to the thread history.""" if self.messages and self.messages[-1].get("role") == "assistant": self.messages[-1]["content"] += chunk + if metadata: + if "metadata" not in self.messages[-1]: + self.messages[-1]["metadata"] = {} + self.messages[-1]["metadata"].update(metadata) else: - self.messages.append({"role": "assistant", "content": chunk}) + msg = {"role": "assistant", "content": chunk} + if metadata: + msg["metadata"] = metadata + self.messages.append(msg) self.metadata["last_modified"] = datetime.now() @auto_persist - def finalize_response(self, full_text: str) -> None: + def finalize_response(self, full_text: str, metadata: Dict[str, Any] = None) -> None: """Finalize the assistant response, replacing partials with full text.""" if self.messages and self.messages[-1].get("role") == "assistant": self.messages[-1]["content"] = full_text + if metadata: + if "metadata" not in self.messages[-1]: + self.messages[-1]["metadata"] = {} + self.messages[-1]["metadata"].update(metadata) else: - self.messages.append({"role": "assistant", "content": full_text}) + msg = {"role": "assistant", "content": full_text} + if metadata: + msg["metadata"] = metadata + self.messages.append(msg) self.metadata["last_modified"] = datetime.now() @auto_persist @@ -201,3 +248,21 @@ def set_compacted(self, messages: List[Dict[str, str]]) -> None: self.context = set() self.embedded_files = set() self.metadata["last_modified"] = datetime.now() + + @auto_persist + def edit_message(self, index: int, new_content: str) -> bool: + """Edit the content of a specific message.""" + if 0 <= index < len(self.messages): + self.messages[index]["content"] = new_content + self.metadata["last_modified"] = datetime.now() + return True + return False + + @auto_persist + def delete_message(self, index: int) -> bool: + """Delete a specific message from the history.""" + if 0 <= index < len(self.messages): + del self.messages[index] + self.metadata["last_modified"] = datetime.now() + return True + return False diff --git a/src/jrdev/services/message_service.py b/src/jrdev/services/message_service.py index 02dc421..18c3d61 100644 --- a/src/jrdev/services/message_service.py +++ b/src/jrdev/services/message_service.py @@ -57,7 +57,11 @@ async def stream_message(self, msg_thread: MessageThread, content: str, task_id: # Update message thread state with the new user message and context used # This ensures the user's message is part of the history before the assistant responds. msg_thread.add_embedded_files(builder.get_files()) # Files used are now "embedded" - msg_thread.messages = messages_for_llm # Update thread history to include this user's message + + # New: Append user message instead of overwriting entire history + if messages_for_llm and messages_for_llm[-1]["role"] == "user": + new_user_msg = messages_for_llm[-1] + msg_thread.add_user_message(new_user_msg["content"], metadata=new_user_msg.get("metadata")) # Stream response from LLM response_accumulator = "" @@ -73,6 +77,7 @@ async def stream_message(self, msg_thread: MessageThread, content: str, task_id: # completely filter out thinking is_first_chunk = True in_think = False + metadata = {"model": self.app.state.model} async for chunk in llm_response_stream: if is_first_chunk: is_first_chunk = False @@ -81,18 +86,18 @@ async def stream_message(self, msg_thread: MessageThread, content: str, task_id: yield "Thinking..." else: response_accumulator += chunk - msg_thread.add_response_partial(chunk) # Update thread with partial assistant response + msg_thread.add_response_partial(chunk, metadata=metadata) # Update thread with partial assistant response yield chunk elif in_think: if chunk == "": in_think = False else: response_accumulator += chunk - msg_thread.add_response_partial(chunk) # Update thread with partial assistant response + msg_thread.add_response_partial(chunk, metadata=metadata) # Update thread with partial assistant response yield chunk # Finalize the full response in the message thread - msg_thread.finalize_response(response_accumulator.strip()) + msg_thread.finalize_response(response_accumulator.strip(), metadata=metadata) except Exception as e: logger.error("Message Service: %s", e) if task_id: diff --git a/src/jrdev/ui/tui/chat/chat_list.py b/src/jrdev/ui/tui/chat/chat_list.py index ad4d110..2efe8e1 100644 --- a/src/jrdev/ui/tui/chat/chat_list.py +++ b/src/jrdev/ui/tui/chat/chat_list.py @@ -1,9 +1,10 @@ -from textual import events, on +from textual import on from textual.app import ComposeResult from textual.widget import Widget -from textual.widgets import Button +from textual.widgets import Button, Tree +from textual.widgets.tree import TreeNode from textual.message import Message -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Any import logging from jrdev.messages.thread import MessageThread @@ -18,31 +19,47 @@ def __init__(self, thread_id: str) -> None: self.thread_id = thread_id super().__init__() + DEFAULT_CSS = """ + ChatList { + layout: vertical; + width: 100%; + height: 100%; + } + #new_thread { + dock: top; + width: 100%; + height: 3; + min-height: 3; + margin-bottom: 1; + } + Tree { + width: 100%; + height: 1fr; + padding: 1; + } + """ + def __init__(self, core_app, id: Optional[str] = None) -> None: super().__init__(id=id) self.core_app = core_app - self.buttons: Dict[str, Button] = {} # id -> Button + # We need to map thread_id to TreeNode to easily find/update them + self.thread_nodes: Dict[str, TreeNode] = {} + self.category_nodes: Dict[str, TreeNode] = {} self.threads: Dict[str, MessageThread] = {} # id -> MsgThread self.active_thread_id: Optional[str] = None self.new_button = Button(label="+ New Chat", id="new_thread", classes="sidebar_button") + self.tree = Tree("Chats") + self.tree.show_root = False + self.tree.guide_depth = 2 def compose(self) -> ComposeResult: - for button in self.buttons.values(): - yield button yield self.new_button + yield self.tree async def on_mount(self) -> None: - self.can_focus = False - for button in self.buttons.values(): - self.style_button(button) - self.style_button(self.new_button) - - def style_button(self, btn): - btn.can_focus = False - btn.styles.border = "none" - btn.styles.min_width = 4 - btn.styles.width = "100%" - btn.styles.align_horizontal = "center" + self.new_button.can_focus = False + self.tree.focus() # Let the tree take focus for navigation + pass async def add_thread(self, msg_thread: MessageThread) -> None: # filter out any router threads @@ -51,79 +68,106 @@ async def add_thread(self, msg_thread: MessageThread) -> None: return tid = msg_thread.thread_id + is_new = tid not in self.threads + self.threads[tid] = msg_thread + + category = getattr(msg_thread, "category", "default") or "default" + + # Find or create category node + if category not in self.category_nodes: + # Add category node + cat_node = self.tree.root.add(category, expand=True) + self.category_nodes[category] = cat_node + else: + cat_node = self.category_nodes[category] + name = tid.removeprefix("thread_") if msg_thread.name: name = msg_thread.name - btn = Button(label=name, id=tid, classes="sidebar_button") - self.buttons[tid] = btn - self.threads[tid] = msg_thread - await self.mount(btn) - # if this is the first thread, make it active + + # Add thread node or update + if tid in self.thread_nodes: + node = self.thread_nodes[tid] + # Check if parent matches (category change) + if node.parent != cat_node: + node.remove() + # Re-add in correct category + thread_node = cat_node.add_leaf(name, data=tid) + self.thread_nodes[tid] = thread_node + # Reselect if active + if self.active_thread_id == tid: + self.tree.select_node(thread_node) + else: + # Just update label + if str(node.label) != name: + node.label = name + else: + thread_node = cat_node.add_leaf(name, data=tid) + self.thread_nodes[tid] = thread_node + # If this is active (e.g. loaded on startup), select it + if self.active_thread_id == tid: + self.tree.select_node(thread_node) + + # if this is the first thread and none active, make it active if self.active_thread_id is None: self.set_active(tid) - self.style_button(btn) + + if is_new and self.active_thread_id == tid: + self.tree.select_node(self.thread_nodes[tid]) + def check_threads(self, all_threads: List[str]) -> None: # check our list of threads against the list from app state to_remove = [tid for tid in self.threads.keys() if tid not in all_threads] for tid in to_remove: - btn = self.buttons.pop(tid, None) - if btn is not None: - btn.remove() + if tid in self.thread_nodes: + self.thread_nodes[tid].remove() + del self.thread_nodes[tid] + self.threads.pop(tid, None) if self.active_thread_id == tid: self.active_thread_id = None + # Optional: Cleanup empty categories + cats_to_remove = [] + for cat, node in self.category_nodes.items(): + if not node.children: + node.remove() + cats_to_remove.append(cat) + for cat in cats_to_remove: + del self.category_nodes[cat] + def set_active(self, thread_id: str) -> None: # if this is already active thread, then ignore if self.active_thread_id == thread_id: return - # remove “active” from old - if self.active_thread_id and self.active_thread_id in self.buttons: - self.buttons[self.active_thread_id].remove_class("active") - # set new self.active_thread_id = thread_id - if thread_id in self.buttons: - self.buttons[thread_id].add_class("active") - - @on(Button.Pressed, ".sidebar_button") - async def handle_thread_button_click(self, event: Button.Pressed): - btn = event.button - - # check if it is the new thread button - if btn.id == "new_thread": - # Post the command to create a new thread - self.post_message(CommandRequest("/thread new")) - # Wait for the thread to be created in the backend, then update UI - # We'll optimistically try to find the new thread after a short delay - # but the main UI will also update us via ChatThreadUpdate event. - # Instead, we emit a message to parent to switch to the new chat after it's created. - # So, we do nothing here except post the command. - return - - if btn.id not in self.buttons: - # ignore button if it doesn't belong to chat_list - return - - # switch chat thread - self.post_message(CommandRequest(f"/thread switch {btn.id}")) - - async def on_command_request(self, event: CommandRequest) -> None: - # Listen for /thread new command completion by monitoring thread list changes - # This is handled by the parent UI, so nothing to do here - pass + if thread_id in self.thread_nodes: + self.tree.select_node(self.thread_nodes[thread_id]) + + @on(Button.Pressed, "#new_thread") + async def handle_new_thread_click(self, event: Button.Pressed): + self.post_message(CommandRequest("/thread new")) + + @on(Tree.NodeSelected) + async def handle_tree_selection(self, event: Tree.NodeSelected): + node = event.node + if node.data: + # It's a thread node + thread_id = node.data + if thread_id != self.active_thread_id: + self.post_message(CommandRequest(f"/thread switch {thread_id}")) + else: + # Category node, toggle expansion + node.toggle() async def thread_update(self, msg_thread: MessageThread): - # Overridden to handle new thread activation + # Overridden to handle new thread activation or updates is_new = self.threads.get(msg_thread.thread_id, None) is None + await self.add_thread(msg_thread) + if is_new: - await self.add_thread(msg_thread) # Set as active and notify parent to switch view self.set_active(msg_thread.thread_id) self.post_message(self.NewChatActivated(msg_thread.thread_id)) - else: - # check if thread name may have changed - btn = self.buttons[msg_thread.thread_id] - if msg_thread.name and msg_thread.name != str(btn.label): - btn.label = msg_thread.name diff --git a/src/jrdev/ui/tui/chat/chat_view_widget.py b/src/jrdev/ui/tui/chat/chat_view_widget.py index 4d6d9a6..fea07f8 100644 --- a/src/jrdev/ui/tui/chat/chat_view_widget.py +++ b/src/jrdev/ui/tui/chat/chat_view_widget.py @@ -244,10 +244,13 @@ async def _update_layout_output_border_title(self, thread: MessageThread = None) thread: Optional[MessageThread] = self.core_app.get_current_thread() if thread: thread_name = thread.name.strip() if thread.name else None + mode = thread.metadata.get("mode", "chat") + mode_str = f" [{mode.upper()}]" if mode != "chat" else "" + if thread_name: - self.layout_output.border_title = f"Chat: {thread_name}" + self.layout_output.border_title = f"Chat: {thread_name}{mode_str}" else: - self.layout_output.border_title = f"Chat: {thread.thread_id}" + self.layout_output.border_title = f"Chat: {thread.thread_id}{mode_str}" else: self.layout_output.border_title = "Chat" @@ -269,13 +272,20 @@ async def _load_current_thread(self) -> None: if self.current_thread_id == thread.thread_id and self.message_scroller.children: # If it's the same thread and we already have bubbles, just scroll - self.message_scroller.scroll_end(animate=False) - return + # But wait, if messages were edited, we need to reload. + # ChatViewWidget doesn't know if messages were edited unless notified. + # For now, let's assume we reload if we are told to via on_thread_switched. + # But on_thread_switched is called by handle_chat_update. + # So if we edited a message, handle_chat_update is called, calling on_thread_switched. + # So we should probably reload even if thread ID is same, to reflect edits. + # But rebuilding all bubbles is expensive. + # For now, we will rebuild to ensure correctness. + pass self.current_thread_id = thread.thread_id await self.message_scroller.remove_children() - for msg in thread.messages: + for idx, msg in enumerate(thread.messages): role = msg["role"] body = msg["content"] @@ -288,7 +298,7 @@ async def _load_current_thread(self) -> None: else: display_content = body - bubble = MessageBubble(display_content, role=role) + bubble = MessageBubble(display_content, role=role, thread_id=thread.thread_id, message_index=idx) await self.message_scroller.mount(bubble) await self._prune_bubbles() @@ -300,7 +310,12 @@ async def add_user_message(self, raw_user_input: str) -> None: Adds a new user message bubble to the UI. Called by JrDevUI when user submits input via ChatInputWidget. """ - bubble = MessageBubble(raw_user_input, role="user") + thread = self.core_app.get_current_thread() + if not thread: + return + + idx = len(thread.messages) + bubble = MessageBubble(raw_user_input, role="user", thread_id=thread.thread_id, message_index=idx) await self.message_scroller.mount(bubble) await self._prune_bubbles() # Context display is updated when the thread itself is updated (e.g., via _load_current_thread) @@ -322,12 +337,35 @@ async def handle_stream_chunk(self, event: TextualEvents.StreamChunk) -> None: if last_bubble and last_bubble.role == "assistant": last_bubble.append_chunk(event.chunk) else: - new_bubble = MessageBubble(event.chunk, role="assistant") + idx = len(active_thread.messages) - 1 + new_bubble = MessageBubble(event.chunk, role="assistant", thread_id=active_thread.thread_id, message_index=idx) await self.message_scroller.mount(new_bubble) await self._prune_bubbles() self.message_scroller.scroll_end(animate=False) + @on(MessageBubble.MessageEdited) + async def on_message_edited(self, event: MessageBubble.MessageEdited) -> None: + # Use simple string replacement for now, avoiding newlines issues by just passing content + # But wait, CommandRequest takes string. + # If content has newlines, I can't pass it easily via command line arguments unless I quote it carefully. + # Since I'm inside the app, I can call the command handler directly or invoke the thread method? + # But keeping it consistent with commands is good. + # The /message command uses " ".join(args), so newlines are lost if split by shell. + # I'll invoke the thread method directly for reliability, then trigger update. + + thread = self.core_app.state.threads.get(event.thread_id) + if thread: + if thread.edit_message(event.index, event.content): + self.notify("Message updated.") + self.core_app.ui.chat_thread_update(event.thread_id) + else: + self.notify("Failed to update message.", severity="error") + + @on(MessageBubble.MessageDeleted) + async def on_message_deleted(self, event: MessageBubble.MessageDeleted) -> None: + self.post_message(CommandRequest(f"/message delete {event.thread_id} {event.index}")) + def set_project_context_on(self, is_on: bool) -> None: """Programmatically sets the project context switch state.""" if self.context_switch.value != is_on: @@ -506,4 +544,4 @@ def update_models(self) -> None: def handle_external_update(self, is_enabled: bool) -> None: """Handles external updates to the project context state (e.g., from core app).""" if self.context_switch.value != is_enabled: - self.set_project_context_on(is_enabled) \ No newline at end of file + self.set_project_context_on(is_enabled) diff --git a/src/jrdev/ui/tui/chat/message_bubble.py b/src/jrdev/ui/tui/chat/message_bubble.py index 5c1fb11..a541512 100644 --- a/src/jrdev/ui/tui/chat/message_bubble.py +++ b/src/jrdev/ui/tui/chat/message_bubble.py @@ -1,17 +1,19 @@ import pyperclip import logging +from typing import Optional from textual import on from textual.app import ComposeResult from textual.color import Color -from textual.containers import Vertical +from textual.containers import Vertical, Horizontal from textual.widgets import Button +from textual.message import Message from jrdev.ui.tui.terminal.terminal_text_area import TerminalTextArea logger = logging.getLogger("jrdev") class MessageBubble(Vertical): - """A widget to display a single chat message with a copy button.""" + """A widget to display a single chat message with copy, edit, and delete buttons.""" DEFAULT_CSS = """ MessageBubble { @@ -25,19 +27,46 @@ class MessageBubble(Vertical): border: none; margin-bottom: 0; /* Space between text area and button */ } - MessageBubble > Button { + .bubble_actions { dock: bottom; + height: 1; + width: 100%; + layout: horizontal; + } + .bubble_btn { height: 1; width: auto; - min-width: 8; /* "Copy" + padding */ + min-width: 6; + margin-right: 1; + border: none; + } + .bubble_btn:hover { + text-style: reverse; } """ - def __init__(self, message_content: str, role: str, id: str | None = None) -> None: + class MessageEdited(Message): + def __init__(self, thread_id: str, index: int, content: str) -> None: + self.thread_id = thread_id + self.index = index + self.content = content + super().__init__() + + class MessageDeleted(Message): + def __init__(self, thread_id: str, index: int) -> None: + self.thread_id = thread_id + self.index = index + super().__init__() + + def __init__(self, message_content: str, role: str, thread_id: str, message_index: int, id: str | None = None) -> None: super().__init__(id=id) self.message_content = message_content self.role = role + self.thread_id = thread_id + self.message_index = message_index self.is_thinking = message_content == "Thinking..." + self.is_editing = False + self.is_deleting = False border_color_map = { "user": "green", @@ -46,23 +75,35 @@ def __init__(self, message_content: str, role: str, id: str | None = None) -> No color = border_color_map.get(self.role, "grey") # Default to grey if role is unknown self.styles.border = ("round", color) - # Example: self.border_title = self.role.capitalize() # If you want a title def compose(self) -> ComposeResult: - """Compose the message bubble with a text area and a copy button.""" + """Compose the message bubble with a text area and action buttons.""" self.text_area = TerminalTextArea(_id=f"{self.id}-text_area") yield self.text_area - yield Button("Copy Selection") + + with Horizontal(classes="bubble_actions"): + yield Button("Copy", id="copy_btn", classes="bubble_btn") + yield Button("Edit", id="edit_btn", classes="bubble_btn") + yield Button("Delete", id="delete_btn", classes="bubble_btn") + # Create hidden buttons for edit/delete confirmation/actions + yield Button("Save", id="save_btn", classes="bubble_btn") + yield Button("Cancel", id="cancel_btn", classes="bubble_btn") + yield Button("Confirm Delete", id="confirm_delete_btn", classes="bubble_btn") def on_mount(self) -> None: """Called when the widget is mounted in the DOM.""" self.text_area.read_only = True self.text_area.soft_wrap = True self.text_area.text = self.message_content - self.text_area.cursor_blink = False # Already set in TerminalTextArea but good to be explicit + self.text_area.cursor_blink = False self.text_area.show_line_numbers = False self.text_area.can_focus = False + # Hide edit controls initially + self.query_one("#save_btn").visible = False + self.query_one("#cancel_btn").visible = False + self.query_one("#confirm_delete_btn").visible = False + if self.role == "user": self.styles.border = ("round", Color.parse("#63f554")) self.border_title = "Me" @@ -70,7 +111,7 @@ def on_mount(self) -> None: self.styles.border = ("round", Color.parse("#27dfd0")) self.border_title = "Assistant" - @on(Button.Pressed) + @on(Button.Pressed, "#copy_btn") async def handle_copy_button(self, event: Button.Pressed) -> None: """Handles the copy button press, copying selected or all text.""" text_to_copy = self.text_area.selected_text or self.text_area.text @@ -88,16 +129,77 @@ async def handle_copy_button(self, event: Button.Pressed) -> None: else: self.notify("Nothing to copy.", timeout=2) + @on(Button.Pressed, "#edit_btn") + async def handle_edit_start(self, event: Button.Pressed) -> None: + self.is_editing = True + self.text_area.read_only = False + self.text_area.can_focus = True + self.text_area.focus() + + self.query_one("#copy_btn").visible = False + self.query_one("#edit_btn").visible = False + self.query_one("#delete_btn").visible = False + self.query_one("#save_btn").visible = True + self.query_one("#cancel_btn").visible = True + self.query_one("#confirm_delete_btn").visible = False + + @on(Button.Pressed, "#cancel_btn") + async def handle_edit_cancel(self, event: Button.Pressed) -> None: + self.is_editing = False + self.is_deleting = False + self.text_area.read_only = True + self.text_area.can_focus = False + # Revert text changes + self.text_area.text = self.message_content + + self.query_one("#copy_btn").visible = True + self.query_one("#edit_btn").visible = True + self.query_one("#delete_btn").visible = True + self.query_one("#save_btn").visible = False + self.query_one("#cancel_btn").visible = False + self.query_one("#confirm_delete_btn").visible = False + + @on(Button.Pressed, "#save_btn") + async def handle_edit_save(self, event: Button.Pressed) -> None: + new_content = self.text_area.text + self.post_message(self.MessageEdited(self.thread_id, self.message_index, new_content)) + + self.message_content = new_content + + # Reset UI + self.is_editing = False + self.text_area.read_only = True + self.text_area.can_focus = False + + self.query_one("#copy_btn").visible = True + self.query_one("#edit_btn").visible = True + self.query_one("#delete_btn").visible = True + self.query_one("#save_btn").visible = False + self.query_one("#cancel_btn").visible = False + self.query_one("#confirm_delete_btn").visible = False + + + @on(Button.Pressed, "#delete_btn") + async def handle_delete_start(self, event: Button.Pressed) -> None: + self.is_deleting = True + self.query_one("#copy_btn").visible = False + self.query_one("#edit_btn").visible = False + self.query_one("#delete_btn").visible = False + self.query_one("#save_btn").visible = False + self.query_one("#cancel_btn").visible = True + self.query_one("#confirm_delete_btn").visible = True + + @on(Button.Pressed, "#confirm_delete_btn") + async def handle_delete_confirm(self, event: Button.Pressed) -> None: + self.post_message(self.MessageDeleted(self.thread_id, self.message_index)) + def append_chunk(self, chunk: str) -> None: """Appends a chunk of text to the message bubble's text area for streaming.""" - # if this was previously in thinking state, clear the thinking message if self.is_thinking: self.text_area.clear() self.is_thinking = False - # often the first chunks after thinking will be new lines while chunk.startswith("\n"): chunk = chunk.removeprefix("\n") - # add the new text self.text_area.append_text(chunk)