diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py
index 1a1703340..959404f44 100644
--- a/src/memos/llms/openai.py
+++ b/src/memos/llms/openai.py
@@ -70,6 +70,7 @@ def generate(self, messages: MessageList, **kwargs) -> str:
             temperature=temperature,
             max_tokens=max_tokens,
             top_p=top_p,
+            **kwargs,
         )
         logger.info(f"Response from OpenAI: {response.model_dump_json()}")
         response_content = response.choices[0].message.content
diff --git a/src/memos/llms/vllm.py b/src/memos/llms/vllm.py
index c3750bb4b..1bea2879d 100644
--- a/src/memos/llms/vllm.py
+++ b/src/memos/llms/vllm.py
@@ -85,16 +85,16 @@ def build_vllm_kv_cache(self, messages: Any) -> str:
 
         return prompt
 
-    def generate(self, messages: list[MessageDict]) -> str:
+    def generate(self, messages: list[MessageDict], **kwargs) -> str:
         """
         Generate a response from the model.
         """
         if self.client:
-            return self._generate_with_api_client(messages)
+            return self._generate_with_api_client(messages, **kwargs)
         else:
             raise RuntimeError("API client is not available")
 
-    def _generate_with_api_client(self, messages: list[MessageDict]) -> str:
+    def _generate_with_api_client(self, messages: list[MessageDict], **kwargs) -> str:
         """
         Generate response using vLLM API client.
         """
@@ -106,6 +106,7 @@ def _generate_with_api_client(self, messages: list[MessageDict]) -> str:
             "max_tokens": int(getattr(self.config, "max_tokens", 1024)),
             "top_p": float(getattr(self.config, "top_p", 0.9)),
             "extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
+            **kwargs,
         }
 
         response = self.client.chat.completions.create(**completion_kwargs)
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 13515c038..8541ec984 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -27,6 +27,7 @@
     SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH,
     SIMPLE_STRUCT_MEM_READER_PROMPT,
     SIMPLE_STRUCT_MEM_READER_PROMPT_ZH,
+    reader_output_schema,
 )
 
 from memos.utils import timed
@@ -209,7 +210,9 @@ def _get_llm_response(self, mem_str: str) -> dict:
             prompt = prompt.replace(examples, "")
         messages = [{"role": "user", "content": prompt}]
         try:
-            response_text = self.llm.generate(messages)
+            response_text = self.llm.generate(
+                messages, response_format={"type": "json_object", "schema": reader_output_schema}
+            )
             response_json = self.parse_json_result(response_text)
         except Exception as e:
             logger.error(f"[LLM] Exception during chat generation: {e}")
diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py
index ec6812743..9cea54933 100644
--- a/src/memos/templates/mem_reader_prompts.py
+++ b/src/memos/templates/mem_reader_prompts.py
@@ -341,3 +341,46 @@
 
 }
 """
+
+reader_output_schema = {
+    "$schema": "https://json-schema.org/draft/2020-12/schema",
+    "type": "object",
+    "properties": {
+        "memory list": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "key": {
+                        "type": "string",
+                        "description": "A brief title or identifier for the memory.",
+                    },
+                    "memory_type": {
+                        "type": "string",
+                        "enum": ["LongTermMemory", "ShortTermMemory", "WorkingMemory"],
+                        "description": "The type of memory, expected to be 'LongTermMemory' in this context.",
+                    },
+                    "value": {
+                        "type": "string",
+                        "description": "Detailed description of the memory, including viewpoint, time, and content.",
+                    },
+                    "tags": {
+                        "type": "array",
+                        "items": {"type": "string"},
+                        "description": "List of keywords or categories associated with the memory.",
+                    },
+                },
+                "required": ["key", "memory_type", "value", "tags"],
+                "additionalProperties": False,
+            },
+            "description": "List of memory entries.",
+        },
+        "summary": {
+            "type": "string",
+            "description": "A synthesized summary of the overall situation based on all memories.",
+        },
+    },
+    "required": ["memory list", "summary"],
+    "additionalProperties": False,
+    "description": "Structured output containing a list of memories and a summary.",
+}