diff --git a/erospector/erospector b/erospector/erospector new file mode 100755 index 0000000000000..b6026ab611c05 Binary files /dev/null and b/erospector/erospector differ diff --git a/erospector/go.mod b/erospector/go.mod new file mode 100644 index 0000000000000..b1b4d35b665ba --- /dev/null +++ b/erospector/go.mod @@ -0,0 +1,8 @@ +module erospector + +go 1.24.3 + +require ( + github.com/joho/godotenv v1.5.1 // indirect + github.com/sashabaranov/go-openai v1.40.5 // indirect +) diff --git a/erospector/go.sum b/erospector/go.sum new file mode 100644 index 0000000000000..b9c67b4f5c814 --- /dev/null +++ b/erospector/go.sum @@ -0,0 +1,4 @@ +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/sashabaranov/go-openai v1.40.5 h1:SwIlNdWflzR1Rxd1gv3pUg6pwPc6cQ2uMoHs8ai+/NY= +github.com/sashabaranov/go-openai v1.40.5/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= diff --git a/erospector/gpt.py b/erospector/gpt.py new file mode 100644 index 0000000000000..07f3306994291 --- /dev/null +++ b/erospector/gpt.py @@ -0,0 +1,302 @@ +from azure.identity import DefaultAzureCredential, get_bearer_token_provider +from openai import AzureOpenAI +import sys +import json +import subprocess +import re +import os +import tempfile +import shutil + +def get_client(): + """Initialize and return the Azure OpenAI client""" + token_provider = get_bearer_token_provider( + DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" + ) + + client = AzureOpenAI( + api_version="2024-04-01-preview", + azure_endpoint="https://aadagarwal-hack.openai.azure.com/", + azure_ad_token_provider=token_provider + ) + return client + + +def run_go_tests_with_code(source_file_code: str, test_file_code: str, source_file_path: str = None) -> dict: + """Run Go tests with source code strings, using actual file paths when available""" + if source_file_path: + # Use the actual file location for better context + source_dir = os.path.dirname(source_file_path) + source_filename = os.path.basename(source_file_path) + + # Generate test file name by replacing .go with _test.go + if source_filename.endswith('.go'): + test_filename = source_filename[:-3] + '_test.go' + else: + test_filename = source_filename + '_test.go' + + test_file_path = os.path.join(source_dir, test_filename) + + # Write the test file in the same directory as the source + try: + with open(test_file_path, 'w') as f: + f.write(test_file_code) + + # Run tests in the source directory + result = subprocess.run( + ["go", "test", "-coverprofile=cover.out", "./..."], + cwd=source_dir, + capture_output=True, + text=True, + timeout=30, + ) + + # Parse coverage % from go tool output + try: + cov = subprocess.run( + ["go", "tool", "cover", "-func=cover.out"], + cwd=source_dir, + capture_output=True, + text=True, + ) + match = re.search(r"total:\s+\(statements\)\s+([\d.]+)%", cov.stdout) + coverage = float(match.group(1)) if match else 0.0 + except Exception: + coverage = 0.0 + + return { + "stdout": result.stdout, + "stderr": result.stderr, + "coverage_percent": coverage, + "returncode": result.returncode, + "test_file_path": test_file_path + } + + except Exception as e: + return {"error": f"Failed to write test file or run tests: {str(e)}"} + else: + # Fallback to temporary directory method + with tempfile.TemporaryDirectory() as temp_dir: + # Write the source code to files + source_path = os.path.join(temp_dir, "source.go") + test_path = os.path.join(temp_dir, "source_test.go") + + with open(source_path, 'w') as f: + f.write(source_file_code) + + with open(test_path, 'w') as f: + f.write(test_file_code) + + try: + result = subprocess.run( + ["go", "test", "-coverprofile=cover.out"], + cwd=temp_dir, + capture_output=True, + text=True, + timeout=10, + ) + except Exception as e: + return {"error": str(e)} + + # Parse coverage % from go tool output + try: + cov = subprocess.run( + ["go", "tool", "cover", "-func=cover.out"], + cwd=temp_dir, + capture_output=True, + text=True, + ) + match = re.search(r"total:\s+\(statements\)\s+([\d.]+)%", cov.stdout) + coverage = float(match.group(1)) if match else 0.0 + except Exception: + coverage = 0.0 + + return { + "stdout": result.stdout, + "stderr": result.stderr, + "coverage_percent": coverage, + "returncode": result.returncode + } + +def ask_gpt_for_test(file_code, primer): + """Generate a test for the whole file using GPT""" + client = get_client() + + prompt = f""" +Here is the full Go file content: +{file_code} + +Please generate unit tests for the entire file. +Do not mock out any dependencies. +Please simplify the tests as much as possible. +""" + response = client.chat.completions.create( + model="gpt-4.1", # model = "deployment_name" + messages=[ + {"role": "system", "content": "You are a Go expert writing unit tests for containerd."}, + {"role": "user", "content": "Here is some relevant code and tests to understand before writing any tests:"}, + {"role": "user", "content": primer}, + {"role": "user", "content": prompt} + ], + # input=[{"role": "user", "content": "Once the test is generated, run the go tests, and fix any issues until the tests run."}], + # tools=tools + ) + return response.choices[0].message.content + +def ask_gpt_for_test_and_run(file_code, primer, source_file_path=None): + """Generate a test for the whole file using GPT with iterative testing""" + client = get_client() + + tools = [{ + "type": "function", + "function": { + "name": "run_go_tests_with_code", + "description": "Run Go tests and return the results.", + "parameters": { + "type": "object", + "properties": { + "source_file_code": { + "type": "string", + "description": "Code of the Go source file." + }, + "test_file_code": { + "type": "string", + "description": "Code of the Go test file." + }, + "source_file_path": { + "type": "string", + "description": "Path to the Go source file for testing in proper context." + } + }, + "required": [ + "source_file_code", + "test_file_code", + "source_file_path" + ], + "additionalProperties": False + } + } + }] + + prompt = f""" +Here is the full Go file content: +{file_code} + +Please generate unit tests for the entire file. +Do not mock out any dependencies. +Please simplify the tests as much as possible. + +IMPORTANT: After generating the test code, you MUST use the run_go_tests_with_code function to test it immediately. +Call the function with the original source code, your generated test code, and the source file path: {source_file_path} +If there are any errors, fix them and run the tests again until they pass. +Do not ask for permission - just run the tests automatically. +""" + + messages = [ + {"role": "system", "content": "You are a Go expert writing unit tests for containerd. You MUST run tests automatically using the run_go_tests_with_code function after generating test code. Do not ask for permission. In addition, fix any issues until the tests run successfully"}, + {"role": "user", "content": "Here is some relevant code and tests to understand before writing any tests:"}, + {"role": "user", "content": primer}, + {"role": "user", "content": prompt} + ] + + response = client.chat.completions.create( + model="gpt-4.1", + messages=messages, + tools=tools + ) + + # Handle function calls if the model wants to use tools + if response.choices[0].message.tool_calls: + # Add the assistant's response to messages + messages.append(response.choices[0].message) + + # Process each tool call + for tool_call in response.choices[0].message.tool_calls: + if tool_call.function.name == "run_go_tests_with_code": + # Parse the function arguments + args = json.loads(tool_call.function.arguments) + + # Execute the function with code strings and source file path + result = run_go_tests_with_code( + args["source_file_code"], + args["test_file_code"], + args.get("source_file_path", source_file_path) + ) + + # Add the function result to messages + messages.append({ + "role": "tool", + "content": json.dumps(result), + "tool_call_id": tool_call.id + }) + + # Get the final response after function execution + final_response = client.chat.completions.create( + model="gpt-4.1", + messages=messages + [{ + "role": "user", + "content": "Please provide the final, corrected test code as a complete Go test file. Do not describe what you did - just provide the working code." + }], + tools=tools + ) + + return final_response.choices[0].message.content + # else: + # # No function calls were made, let's force a test run + # # Extract the test code from the response and run it + # test_code = response.choices[0].message.content + + # # Try to run the tests with the generated code + # try: + # result = run_go_tests_with_code(file_code, test_code, source_file_path) + + # # Add the result to the conversation and ask the model to improve if needed + # messages.append(response.choices[0].message) + # messages.append({ + # "role": "user", + # "content": f"I ran your test code and got these results:\n{json.dumps(result, indent=2)}\n\nPlease analyze the results and provide the final, corrected test code as a complete Go test file. If there were errors, fix them. Do not describe what you're doing - just provide the working code." + # }) + + # final_response = client.chat.completions.create( + # model="gpt-4.1", + # messages=messages + # ) + + # return final_response.choices[0].message.content + # except Exception as e: + # # If we can't run the tests, just return the original response + # return test_code + + +def main(): + """Main function to handle file input""" + if len(sys.argv) != 2: + print(json.dumps({"error": "Usage: python gpt.py "})) + return 1 + + input_file_path = sys.argv[1] + + try: + # Read input data from file + with open(input_file_path, 'r') as f: + input_data = json.load(f) + + file_code = input_data.get('fileCode', '') + primer = input_data.get('primer', '') + mode = input_data.get('mode', 'simple') # 'simple' or 'test_and_run' + source_file_path = input_data.get('sourceFilePath', '') + + if mode == 'test_and_run': + result = ask_gpt_for_test_and_run(file_code, primer, source_file_path) + else: + result = ask_gpt_for_test(file_code, primer) + + print(json.dumps({"result": result})) + return 0 + except Exception as e: + print(json.dumps({"error": str(e)})) + return 1 + +# If the script is run directly, call main() +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/erospector/gpt_python.go b/erospector/gpt_python.go new file mode 100644 index 0000000000000..011cc69f8ebac --- /dev/null +++ b/erospector/gpt_python.go @@ -0,0 +1,188 @@ +package main + +import ( + "encoding/json" + "errors" + "os" + "os/exec" + "path/filepath" + "runtime" +) + +// Response represents the JSON response from the Python script +type Response struct { + Result string `json:"result"` + Error string `json:"error"` +} + +// AskGPTForTestPython calls the Python script to generate a test using Azure OpenAI +func AskGPTForTestPython(fileCode string, primer string) (string, error) { + // Get the path to the Python script + _, filename, _, _ := runtime.Caller(0) + dir := filepath.Dir(filename) + scriptPath := filepath.Join(dir, "gpt.py") + + // Check if the script exists + if _, err := os.Stat(scriptPath); os.IsNotExist(err) { + return "", errors.New("Python script not found at: " + scriptPath) + } + + // Create a temporary file for the input data + inputFile, err := os.CreateTemp("", "erospector-input-*.json") + if err != nil { + return "", errors.New("Error creating temporary file: " + err.Error()) + } + defer os.Remove(inputFile.Name()) + + // Write the input data as JSON + inputData := map[string]string{ + "fileCode": fileCode, + "primer": primer, + } + + inputJSON, err := json.Marshal(inputData) + if err != nil { + return "", errors.New("Error marshaling input data: " + err.Error()) + } + if _, err := inputFile.Write(inputJSON); err != nil { + return "", errors.New("Error writing to temporary file: " + err.Error()) + } + if err := inputFile.Close(); err != nil { + return "", errors.New("Error closing temporary file: " + err.Error()) + } + + // Run the Python script with the input file path as argument + cmd := exec.Command("python3", scriptPath, inputFile.Name()) + output, err := cmd.Output() + if err != nil { + return "", errors.New("Error executing Python script: " + err.Error()) + } + + // Parse the JSON response + var response Response + if err := json.Unmarshal(output, &response); err != nil { + return "", errors.New("Error parsing Python output: " + err.Error()) + } + + // Check if there was an error in the Python script + if response.Error != "" { + return "", errors.New("Python error: " + response.Error) + } + + return response.Result, nil +} + +// Create a function that calls ask_gpt_for_test_and_run in gpt.py +func AskGPTForTestAndRun(fileCode string, testFileCode string) (string, error) { + // Get the path to the Python script + _, filename, _, _ := runtime.Caller(0) + dir := filepath.Dir(filename) + scriptPath := filepath.Join(dir, "gpt.py") + + // Check if the script exists + if _, err := os.Stat(scriptPath); os.IsNotExist(err) { + return "", errors.New("Python script not found at: " + scriptPath) + } + + // Create a temporary file for the input data + inputFile, err := os.CreateTemp("", "erospector-input-*.json") + if err != nil { + return "", errors.New("Error creating temporary file: " + err.Error()) + } + defer os.Remove(inputFile.Name()) + + // Write the input data as JSON + inputData := map[string]string{ + "fileCode": fileCode, + "testFileCode": testFileCode, + } + + inputJSON, err := json.Marshal(inputData) + if err != nil { + return "", errors.New("Error marshaling input data: " + err.Error()) + } + if _, err := inputFile.Write(inputJSON); err != nil { + return "", errors.New("Error writing to temporary file: " + err.Error()) + } + if err := inputFile.Close(); err != nil { + return "", errors.New("Error closing temporary file: " + err.Error()) + } + + // Run the Python script with the input file path as argument + cmd := exec.Command("python3", scriptPath, inputFile.Name()) + output, err := cmd.Output() + if err != nil { + return "", errors.New("Error executing Python script: " + err.Error()) + } + + // Parse the JSON response + var response Response + if err := json.Unmarshal(output, &response); err != nil { + return "", errors.New("Error parsing Python output: " + err.Error()) + } + + // Check if there was an error in the Python script + if response.Error != "" { + return "", errors.New("Python error: " + response.Error) + } + + return response.Result, nil +} + +// AskGPTForTestPythonWithMode calls the Python script with a specified mode +func AskGPTForTestPythonWithMode(fileCode string, primer string, mode string, sourceFilePath string) (string, error) { + // Get the path to the Python script + _, filename, _, _ := runtime.Caller(0) + dir := filepath.Dir(filename) + scriptPath := filepath.Join(dir, "gpt.py") + + // Check if the script exists + if _, err := os.Stat(scriptPath); os.IsNotExist(err) { + return "", errors.New("Python script not found at: " + scriptPath) + } + + // Create a temporary file for the input data + inputFile, err := os.CreateTemp("", "erospector-input-*.json") + if err != nil { + return "", errors.New("Error creating temporary file: " + err.Error()) + } + defer os.Remove(inputFile.Name()) + + // Write the input data as JSON + inputData := map[string]string{ + "fileCode": fileCode, + "primer": primer, + "mode": mode, + "sourceFilePath": sourceFilePath, + } + inputJSON, err := json.Marshal(inputData) + if err != nil { + return "", errors.New("Error marshaling input data: " + err.Error()) + } + if _, err := inputFile.Write(inputJSON); err != nil { + return "", errors.New("Error writing to temporary file: " + err.Error()) + } + if err := inputFile.Close(); err != nil { + return "", errors.New("Error closing temporary file: " + err.Error()) + } + + // Run the Python script with the input file path as argument + cmd := exec.Command("python3", scriptPath, inputFile.Name()) + output, err := cmd.Output() + if err != nil { + return "", errors.New("Error executing Python script: " + err.Error()) + } + + // Parse the JSON response + var response Response + if err := json.Unmarshal(output, &response); err != nil { + return "", errors.New("Error parsing Python output: " + err.Error()) + } + + // Check if there was an error in the Python script + if response.Error != "" { + return "", errors.New("Python error: " + response.Error) + } + + return response.Result, nil +} diff --git a/erospector/main.go b/erospector/main.go new file mode 100644 index 0000000000000..15b8f9c46c193 --- /dev/null +++ b/erospector/main.go @@ -0,0 +1,145 @@ +package main + +import ( + "flag" + "fmt" + "os" + "path/filepath" + "strings" +) + +func main() { + // Define command-line flags + var targetFunctions string + var listOnly bool + var testMode string + + flag.StringVar(&targetFunctions, "funcs", "", "Comma-separated list of function names to generate tests for (e.g., 'Function1,Function2')") + flag.StringVar(&targetFunctions, "f", "", "Comma-separated list of function names to generate tests for (short form)") + flag.BoolVar(&listOnly, "list", false, "Only list available functions without generating tests") + flag.BoolVar(&listOnly, "l", false, "Only list available functions without generating tests (short form)") + flag.StringVar(&testMode, "mode", "simple", "Test generation mode: 'simple' or 'test_and_run'") + flag.StringVar(&testMode, "m", "simple", "Test generation mode: 'simple' or 'test_and_run' (short form)") + + // Parse flags but leave os.Args[0] which is the program name + flag.CommandLine.Parse(os.Args[1:]) + + // After flag parsing, the remaining arguments are in flag.Args() + if len(flag.Args()) < 1 { + fmt.Printf("Usage: erospector [options] \n") + fmt.Printf("\nOptions:\n") + fmt.Printf(" -funcs, -f Comma-separated list of function names to generate tests for\n") + fmt.Printf(" -list, -l Only list available functions without generating tests\n") + fmt.Printf(" -mode, -m Test generation mode: 'simple' (default) or 'test_and_run'\n") + fmt.Printf("\nExamples:\n") + fmt.Printf(" erospector /path/to/file.go Generate tests for all functions\n") + fmt.Printf(" erospector -list /path/to/file.go List all functions without generating tests\n") + fmt.Printf(" erospector -f Func1,Func2 /path/to/file.go Generate tests only for Func1 and Func2\n") + fmt.Printf(" erospector -m test_and_run /path/to/file.go Generate tests with iterative improvement\n") + return + } + + filePath := flag.Args()[0] + funcs, err := ExtractFunctions(filePath) + if err != nil { + panic(err) + } + + // Filter functions if specific names were provided + var targetFuncs []GoFunction + if targetFunctions != "" { + funcNames := strings.Split(targetFunctions, ",") + funcNameMap := make(map[string]bool) + for _, name := range funcNames { + funcNameMap[strings.TrimSpace(name)] = true + } + + for _, fn := range funcs { + if funcNameMap[fn.Name] { + targetFuncs = append(targetFuncs, fn) + } + } + + if len(targetFuncs) == 0 { + fmt.Printf("Warning: None of the specified functions were found in the file.\n") + fmt.Printf("Available functions in %s:\n", filePath) + for _, fn := range funcs { + fmt.Printf(" - %s\n", fn.Name) + } + return + } + + funcs = targetFuncs + } + + // If list-only mode is enabled, just display functions and exit + if listOnly { + fmt.Printf("\nAvailable functions:\n") + for i, fn := range funcs { + fmt.Printf("%d. %s\n", i+1, fn.Name) + } + return + } + + // Get the absolute path to the repo root directory + repoRoot := "/home/aadhar/repos/containerd/" + + // Directories to scan recursively for Go files + directoriesWithGoFiles := []string{ + filepath.Join(repoRoot, "plugins/snapshots/erofs/"), + filepath.Join(repoRoot, "plugins/diff/erofs/"), + filepath.Join(repoRoot, "core/diff/"), + filepath.Join(repoRoot, "core/snapshots/"), + filepath.Join(repoRoot, "internal/erofsutils/"), + } + + // Find all Go files in the specified directories + primerFiles, missingDirs := FindGoFilesInDirectories(directoriesWithGoFiles) + + if len(missingDirs) > 0 { + fmt.Printf("\nWarning: The following directories could not be found:\n") + for _, dir := range missingDirs { + fmt.Printf(" - %s\n", dir) + } + } + + // Load all the files into the primer + primer, err := LoadPrimerFromFiles(primerFiles) + if err != nil { + fmt.Printf("Warning: Error loading primer from files: %v\n", err) + } + + // fmt.Printf("Preloaded the following files for context:\n") + // for _, file := range primerFiles { + // fmt.Printf(" - %s\n", file) + // } + + // Read the source file content once + fileContent, err := os.ReadFile(filePath) + if err != nil { + fmt.Printf("Error reading source file: %v\n", err) + return + } + fileCodeString := string(fileContent) + + // Validate test mode + if testMode != "simple" && testMode != "test_and_run" { + fmt.Printf("Error: Invalid test mode '%s'. Must be 'simple' or 'test_and_run'\n", testMode) + return + } + + fmt.Printf("Using test generation mode: %s\n", testMode) + + testCode, err := AskGPTForTestPythonWithMode(fileCodeString, primer, testMode, filePath) + if err != nil { + fmt.Printf("Error from GPT: %s\n", err) + return + } + + _, err = WriteTestFile(filePath, testCode) + if err != nil { + fmt.Printf("File write error: %s\n", err) + return + } + fmt.Printf("Test written\n") +} diff --git a/erospector/parser.go b/erospector/parser.go new file mode 100644 index 0000000000000..6e4f4d29b0291 --- /dev/null +++ b/erospector/parser.go @@ -0,0 +1,38 @@ +package main + +import ( + "go/ast" + "go/parser" + "go/token" + "os" +) + +type GoFunction struct { + Name string + Code string +} + +func ExtractFunctions(filepath string) ([]GoFunction, error) { + src, err := os.ReadFile(filepath) + if err != nil { + return nil, err + } + + fset := token.NewFileSet() + node, err := parser.ParseFile(fset, filepath, src, parser.ParseComments) + if err != nil { + return nil, err + } + + var funcs []GoFunction + for _, decl := range node.Decls { + if fn, ok := decl.(*ast.FuncDecl); ok { + code := src[fn.Pos()-1 : fn.End()-1] + funcs = append(funcs, GoFunction{ + Name: fn.Name.Name, + Code: string(code), + }) + } + } + return funcs, nil +} diff --git a/erospector/primer_loader.go b/erospector/primer_loader.go new file mode 100644 index 0000000000000..9702861e2dc57 --- /dev/null +++ b/erospector/primer_loader.go @@ -0,0 +1,65 @@ +package main + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +// FindGoFilesInDirectories finds all .go files in the specified directories recursively +func FindGoFilesInDirectories(dirs []string) ([]string, []string) { + var files []string + var missingDirs []string + + for _, dir := range dirs { + // Check if directory exists first + if _, err := os.Stat(dir); os.IsNotExist(err) { + missingDirs = append(missingDirs, dir) + fmt.Printf("Warning: Directory doesn't exist: %s\n", dir) + continue + } + + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + fmt.Printf("Warning: Error accessing path %s: %v\n", path, err) + return nil // Continue walking even if there's an error with this file + } + + // Skip directories themselves + if info.IsDir() { + return nil + } + + // Only include .go files + if filepath.Ext(path) == ".go" { + files = append(files, path) + } + + return nil + }) + + if err != nil { + fmt.Printf("Warning: Error walking directory %s: %v\n", dir, err) + } + } + + return files, missingDirs +} + +func LoadPrimerFromFiles(files []string) (string, error) { + var primer strings.Builder + + for _, path := range files { + content, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("failed to read %s: %w", path, err) + } + + primer.WriteString(fmt.Sprintf("\n--- BEGIN FILE: %s ---\n", filepath.Base(path))) + primer.WriteString(string(content)) + primer.WriteString(fmt.Sprintf("\n--- END FILE: %s ---\n\n", filepath.Base(path))) + } + + return primer.String(), nil +} diff --git a/erospector/writer.go b/erospector/writer.go new file mode 100644 index 0000000000000..a5322a02ca89f --- /dev/null +++ b/erospector/writer.go @@ -0,0 +1,96 @@ +package main + +import ( + "fmt" + "os" + "path/filepath" + "regexp" + "strings" +) + +// extractGoCodeBlocks extracts code inside Go code blocks marked by ```go or ```golang +func extractGoCodeBlocks(content string) string { + // Regular expression to match Go code blocks in markdown + codeBlockRegex := regexp.MustCompile("(?s)```(?:go|golang)\\s*\n(.*?)```") + + // Find all matches + matches := codeBlockRegex.FindAllStringSubmatch(content, -1) + + if len(matches) == 0 { + // If no Go code blocks are found, try to find any code blocks + codeBlockRegex = regexp.MustCompile("(?s)```\\s*\n(.*?)```") + matches = codeBlockRegex.FindAllStringSubmatch(content, -1) + if len(matches) == 0 { + // If still no code blocks, return the original content + return content + } + } + + // Concatenate all Go code blocks + var result strings.Builder + for _, match := range matches { + if len(match) >= 2 { + // Use the code exactly as it appears in the code block + code := match[1] + result.WriteString(code) + result.WriteString("\n\n") + } + } + + return result.String() +} + +func WriteTestFile(sourceFile string, testCode string) (string, error) { + // Get the directory where the source file is located + sourceDir := filepath.Dir(sourceFile) + + // Get the base filename without extension + base := strings.TrimSuffix(filepath.Base(sourceFile), ".go") + + // Create the test file path in the same directory as the source file + testFile := filepath.Join(sourceDir, base+"_test.go") + + // Extract Go code from the GPT response + // goCode := extractGoCodeBlocks(testCode) + + // Check if the file exists + fileExists := false + _, err := os.Stat(testFile) + if err == nil { + fileExists = true + } + + var f *os.File + + if !fileExists { + // Create a new file if it doesn't exist + f, err = os.Create(testFile) + if err != nil { + return "", err + } + fmt.Printf("Creating new test file: %s\n", testFile) + } else { + // Overwrite the existing file with new test code + f, err = os.Create(testFile) + if err != nil { + return "", err + } + fmt.Printf("Updating existing test file: %s\n", testFile) + } + + // Extract the actual Go code from the response (in case it's wrapped in markdown) + goCode := extractGoCodeBlocks(testCode) + if goCode == testCode { + // If no code blocks were found, use the original content + goCode = testCode + } + + // Write the test code to the file + _, err = f.WriteString(goCode) + if err != nil { + return "", err + } + defer f.Close() + + return testFile, err +}