-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcoder.py
More file actions
128 lines (90 loc) · 3.9 KB
/
coder.py
File metadata and controls
128 lines (90 loc) · 3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import asyncio
import re
from typing import Dict
from typing import List
from typing import Optional
from typing import Set

from context import ContextManager
from edit_prompts import EditPrompts
from editor import SmartEditor
from models import Model
class Coder:
    """Basic code editing orchestrator.

    For each user message: gathers repository context through
    ``ContextManager``, builds a chat payload (edit instructions, context,
    recent conversation, new message), sends it to ``model``, and, when the
    reply looks like it contains edit instructions, hands it to
    ``SmartEditor.apply_edits``.
    """

    # Upper bound on files handed to the context manager when no explicit
    # file list was given, so a large repository does not overwhelm the prompt.
    max_scan_files = 10
    # How many of the most recent chat-history entries (user + assistant
    # messages) are replayed to the model each turn; bounds prompt growth in
    # long sessions. 0 disables replay.
    max_history_messages = 20

    # ``name(`` -- a function-call style mention (whitespace allowed before "(").
    _FUNC_PATTERN = re.compile(r'\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\(')
    # Any capitalised word, treated as a possible class name.
    _CLASS_PATTERN = re.compile(r'\b([A-Z][a-zA-Z0-9_]*)\b')
    # Unified-diff file headers. Anchored to the start of a line (optionally
    # indented) so prose such as "A --- B" is not mistaken for an edit.
    _DIFF_HEADER_PATTERN = re.compile(r'^[ \t]*(?:---|\+\+\+) ', re.MULTILINE)
    # Markers specific enough to be matched anywhere in a response.
    _EDIT_MARKERS = ("<<<<<<< SEARCH", "```diff")

    def __init__(self, model: "Model", files: Optional[List[str]] = None):
        """Create a coder.

        Args:
            model: Object exposing an awaitable ``send_completion(messages)``.
            files: Files to use as context. ``None`` or an empty list means
                the repository is scanned on every message instead.
        """
        self.model = model
        self.files = files or []
        self.context_manager = ContextManager()
        # Completed turns as {"role": ..., "content": ...} dicts, user and
        # assistant entries alternating.
        self.chat_history: List[Dict[str, str]] = []
        self.editor = SmartEditor(self.context_manager.root_path)
        self.edit_prompts = EditPrompts()
        self.edit_format = "diff"

    async def run_single(self, message: str):
        """Process one message and print the assistant's reply."""
        response = await self._process_message(message)
        print("Assistant:", response)

    async def run_chat_loop(self):
        """Interactive chat loop.

        Reads messages from stdin until the user types "exit"/"quit"
        (case-insensitive, surrounding whitespace ignored), presses Ctrl+C,
        or input ends (Ctrl+D / exhausted pipe). Blank lines are skipped
        rather than sent to the model.
        """
        while True:
            try:
                user_input = input("> ")
            except (KeyboardInterrupt, EOFError):
                # EOFError is raised by input() on Ctrl+D or end of piped
                # stdin; exit cleanly instead of crashing with a traceback.
                print("\nGoodbye!")
                break

            command = user_input.strip()
            if not command:
                continue  # don't spend a model call on an empty line
            if command.lower() in ("exit", "quit"):
                break

            # EOFError is deliberately NOT caught here: asyncio's
            # IncompleteReadError subclasses it, and a network failure must
            # not be silently reported as "Goodbye".
            try:
                response = await self._process_message(user_input)
                print("Assistant:", response)
            except KeyboardInterrupt:
                print("\nGoodbye!")
                break

    async def _process_message(self, message: str) -> str:
        """Send ``message`` to the model with context and apply any edits.

        Returns the model's reply with a note appended saying whether the
        edits it contained were applied. Both sides of the exchange are then
        recorded in ``self.chat_history`` for replay on later turns.
        """
        context = self.context_manager.get_context_for_message(
            self._files_for_context(), message
        )
        edit_prompt = self.edit_prompts.get_edit_prompt(
            self.edit_format, context, message
        )
        messages = self._build_messages(edit_prompt, context, message)

        response = await self.model.send_completion(messages)

        if self._contains_edits(response):
            edit_results = self.editor.apply_edits(response, self.edit_format)
            if edit_results["success"]:
                edited_files = ", ".join(edit_results["edited_files"])
                response += f"\n\n Successfully applied edits to: {edited_files}"
            else:
                response += f"\n\n Failed to apply some edits. Errors: {edit_results['errors']}"

        # Record the whole turn (previously only the user side was kept and
        # history was never sent), so follow-ups like "now do the same for
        # bar" reach the model with the earlier exchange.
        self.chat_history.append({"role": "user", "content": message})
        self.chat_history.append({"role": "assistant", "content": response})
        return response

    def _files_for_context(self):
        """Return the explicit file list, or a capped repository scan if none was given."""
        if self.files:
            return self.files
        # No files specified: let the context manager scan the repository and
        # keep only the first ``max_scan_files`` entries it reports.
        repo_info = self.context_manager.scan_repository()
        return repo_info["files"][:self.max_scan_files]

    def _build_messages(self, edit_prompt, context: str, message: str) -> List[Dict[str, str]]:
        """Assemble the chat payload sent to the model.

        Order: edit instructions (system), repository context (system, only
        when non-blank), the most recent ``max_history_messages`` history
        entries, then the new user message.
        """
        messages = [{"role": "system", "content": edit_prompt}]
        if context.strip():
            messages.append({
                "role": "system",
                "content": f"Multi-language Repository Context:\n{context}",
            })
        # Guard needed: history[-0:] would return the whole list.
        if self.max_history_messages > 0:
            messages.extend(self.chat_history[-self.max_history_messages:])
        messages.append({"role": "user", "content": message})
        return messages

    def _extract_mentioned_symbols(self, message: str) -> Set[str]:
        """Return identifiers in ``message`` that look like function or class names.

        Heuristic: ``name(`` counts as a function mention, and any capitalised
        word counts as a possible class name (so ordinary capitalised words
        such as "Please" are included too). Not called by ``_process_message``.
        """
        symbols = set(self._FUNC_PATTERN.findall(message))
        symbols.update(self._CLASS_PATTERN.findall(message))
        return symbols

    def _contains_edits(self, response: str) -> bool:
        """Return True if ``response`` appears to contain edit instructions.

        SEARCH/REPLACE markers and ```diff fences are matched anywhere.
        Unified-diff ``--- ``/``+++ `` headers count only at the start of a
        line (leading spaces/tabs allowed).
        """
        if any(marker in response for marker in self._EDIT_MARKERS):
            return True
        return self._DIFF_HEADER_PATTERN.search(response) is not None