// BuildResponseSummary renders the prompt that asks the model to break an
// interview answer into a JSON list of standalone technical points, returned
// under the "user_response_summary" key.
//
// question and response are substituted verbatim between double quotes at the
// end of the prompt; neither value is escaped.
func BuildResponseSummary(question, response string) string {
	// The two %s verbs are filled, in order, with the question and the answer.
	const promptTemplate = `Extract the **key technical points** from the following backend interview answer.

Break the response into a list of concise, self-contained statements. Each item should:
- Represent a distinct technical idea, method, or decision
- Be understandable without the original question
- Focus only on what the user **actually said**, not what they should have said
- Exclude filler, vague claims, or generalities
- Be written in the past tense

Output only valid JSON in this format:
{
"user_response_summary":[
    "First technical point...",
    "Second technical point...",
    ...
]}

Interview question:
"%s"

User’s answer:
"%s"`

	prompt := fmt.Sprintf(promptTemplate, question, response)
	return prompt
}
!= nil { log.Printf("GetConversationHistory failed: %v", err) return nil, "", err } + chatGPTResponse, err := openAI.GetChatGPTResponseConversation(conversationHistory) if err != nil { log.Printf("getNextQuestion failed: %v", err) @@ -31,7 +34,7 @@ func GetChatGPTResponses(conversation *Conversation, openAI chatgpt.AIClient, in return chatGPTResponse, chatGPTResponseString, nil } -func GetConversationHistory(conversation *Conversation, interviewRepo interview.InterviewRepo) ([]map[string]string, error) { +func GetConversationHistory(conversation *Conversation, interviewRepo interview.InterviewRepo, conversationContext []string) ([]map[string]string, error) { var arrayOfTopics []string var currentTopic string chatGPTConversationArray := make([]map[string]string, 0) @@ -65,6 +68,7 @@ func GetConversationHistory(conversation *Conversation, interviewRepo interview. questionNumbersSorted = append(questionNumbersSorted, questionNumber) } sort.Ints(questionNumbersSorted) + lastQuestionNumber := questionNumbersSorted[len(questionNumbersSorted)-1] for _, questionNumber := range questionNumbersSorted { question := topic.Questions[questionNumber] for i, message := range question.Messages { @@ -78,13 +82,33 @@ func GetConversationHistory(conversation *Conversation, interviewRepo interview. 
if message.Author == "interviewer" { role = "assistant" } + + content := message.Content + isFinalInjectionTarget := questionNumber == lastQuestionNumber && + message.Author == "user" + // DEBUG + fmt.Printf("isFinalInjectionTarget: %v\n", isFinalInjectionTarget) + fmt.Printf("conversationContext: %v\n", conversationContext) + if isFinalInjectionTarget && len(conversationContext) > 0 { + formattedContext := strings.Join(conversationContext, "\n") + content = fmt.Sprintf("Relevant prior user context:\n%s\n\n--- BEGIN USER'S ACTUAL RESPONSE ---\n%s", formattedContext, content) + } + chatGPTConversationArray = append(chatGPTConversationArray, map[string]string{ "role": role, - "content": message.Content, + "content": content, }) } } + fmt.Println("------ DEBUG: Formatted Conversation History ------") + for i, msg := range chatGPTConversationArray { + fmt.Printf("\n--- Message %d ---\n", i+1) + fmt.Printf("Role : %s\n", msg["role"]) + fmt.Printf("Content:\n%s\n", msg["content"]) + } + fmt.Println("------ END DEBUG ------") + return chatGPTConversationArray, nil } diff --git a/conversation/model.go b/conversation/model.go index 1ca74a5..34894eb 100644 --- a/conversation/model.go +++ b/conversation/model.go @@ -72,7 +72,7 @@ type ConversationRepo interface { CreateQuestion(conversation *Conversation, prompt string) (int, error) AddQuestion(question *Question) (int, error) GetQuestions(Conversation *Conversation) ([]*Question, error) - CreateMessages(conversation *Conversation, messages []Message) error + CreateMessages(conversation *Conversation, messages []Message) (int, error) AddMessage(conversationID, topic_id, questionNumber int, message Message) (int, error) GetMessages(conversationID, topic_id, questionNumber int) ([]Message, error) } diff --git a/conversation/repository.go b/conversation/repository.go index 4ac712d..b5d8f5e 100644 --- a/conversation/repository.go +++ b/conversation/repository.go @@ -210,7 +210,7 @@ func (repo *Repository) 
GetQuestions(conversation *Conversation) ([]*Question, e return questions, nil } -func (repo *Repository) CreateMessages(conversation *Conversation, messages []Message) error { +func (repo *Repository) CreateMessages(conversation *Conversation, messages []Message) (int, error) { var id int for _, message := range messages { query := ` @@ -229,14 +229,14 @@ func (repo *Repository) CreateMessages(conversation *Conversation, messages []Me ).Scan(&id) if err == sql.ErrNoRows { - return err + return 0, err } else if err != nil { log.Printf("Error querying conversation: %v\n", err) - return err + return 0, err } } - return nil + return id, nil } func (repo *Repository) AddMessage(conversationID, topic_id, questionNumber int, message Message) (int, error) { diff --git a/conversation/service.go b/conversation/service.go index ad66d6f..034a19b 100644 --- a/conversation/service.go +++ b/conversation/service.go @@ -1,10 +1,13 @@ package conversation import ( + "context" "errors" "log" + "time" "github.com/michaelboegner/interviewer/chatgpt" + "github.com/michaelboegner/interviewer/embedding" "github.com/michaelboegner/interviewer/interview" ) @@ -30,9 +33,11 @@ func CreateEmptyConversation(repo ConversationRepo, interviewID int, subTopic st } func CreateConversation( + ctx context.Context, repo ConversationRepo, interviewRepo interview.InterviewRepo, openAI chatgpt.AIClient, + embeddingService embedding.Service, conversation *Conversation, interviewID int, prompt, @@ -61,13 +66,29 @@ func CreateConversation( topic.Questions[questionNumber] = NewQuestion(conversationID, topicID, questionNumber, firstQuestion, messages) conversation.Topics[topicID] = topic - err = repo.CreateMessages(conversation, messages) + messageID, err := repo.CreateMessages(conversation, messages) if err != nil { log.Printf("repo.CreateMessages failed: %v", err) return nil, err } - chatGPTResponse, chatGPTResponseString, err := GetChatGPTResponses(conversation, openAI, interviewRepo) + embedInput := 
embedding.EmbedInput{ + InterviewID: interviewID, + ConversationID: conversationID, + TopicID: topicID, + QuestionNumber: questionNumber, + MessageID: messageID, + Question: firstQuestion, + UserResponse: message, + CreatedAt: time.Now().UTC(), + } + + conversationContext, err := embeddingService.ProcessAndRetrieve(ctx, embedInput) + if err != nil { + log.Printf("embeddingService.ProcessAndRetrieve failed: %v", err) + } + + chatGPTResponse, chatGPTResponseString, err := GetChatGPTResponses(conversation, openAI, interviewRepo, conversationContext) if err != nil { log.Printf("getChatGPTResponses failed: %v", err) return nil, err @@ -108,9 +129,11 @@ func CreateConversation( } func AppendConversation( + ctx context.Context, repo ConversationRepo, interviewRepo interview.InterviewRepo, openAI chatgpt.AIClient, + embeddingService embedding.Service, interviewID, userID int, conversation *Conversation, @@ -125,13 +148,28 @@ func AppendConversation( } messageUser := NewMessage(conversationID, topicID, questionNumber, User, message) - _, err := repo.AddMessage(conversationID, topicID, questionNumber, messageUser) + messageID, err := repo.AddMessage(conversationID, topicID, questionNumber, messageUser) if err != nil { return nil, err } conversation.Topics[topicID].Questions[questionNumber].Messages = append(conversation.Topics[topicID].Questions[questionNumber].Messages, messageUser) - chatGPTResponse, chatGPTResponseString, err := GetChatGPTResponses(conversation, openAI, interviewRepo) + embedInput := embedding.EmbedInput{ + InterviewID: interviewID, + ConversationID: conversationID, + TopicID: topicID, + QuestionNumber: questionNumber, + MessageID: messageID, + Question: conversation.Topics[topicID].Questions[questionNumber].Prompt, + UserResponse: message, + CreatedAt: time.Now().UTC(), + } + + conversationContext, err := embeddingService.ProcessAndRetrieve(ctx, embedInput) + if err != nil { + log.Printf("embeddingService.ProcessAndRetrieve failed: %v", err) + } + 
-- Up migration: creates the conversation_embeddings table and its indexes.
-- Each row holds one summary string plus its 384-dimensional vector, keyed by
-- the interview / conversation / topic / question / message it came from.

-- The VECTOR column type and the ivfflat access method below come from the
-- "vector" extension.
CREATE EXTENSION IF NOT EXISTS vector;

CREATE TABLE conversation_embeddings (
    id SERIAL PRIMARY KEY,
    interview_id INT NOT NULL,
    conversation_id INT NOT NULL,
    topic_id INT NOT NULL,
    question_number INT NOT NULL,
    message_id INT NOT NULL,
    summary TEXT NOT NULL,
    embedding VECTOR(384) NOT NULL,
    created_at TIMESTAMP DEFAULT now()
);

-- Approximate nearest-neighbour index using the cosine operator class.
-- NOTE(review): an index built with vector_cosine_ops presumably only serves
-- queries ordered by the cosine-distance operator (<=>); confirm that the
-- similarity queries use that operator.
-- NOTE(review): the index is built while the table is still empty; confirm
-- that is acceptable for an ivfflat index with lists = 100.
CREATE INDEX conversation_embeddings_embedding_idx
    ON conversation_embeddings USING ivfflat (embedding vector_cosine_ops)
    WITH (lists = 100);

-- B-tree index for lookups by question within an interview.
CREATE INDEX convo_embeddings_by_question_idx
    ON conversation_embeddings (interview_id, topic_id, question_number);

-- B-tree index for lookups by message within an interview.
CREATE INDEX conversation_embeddings_lookup_idx
    ON conversation_embeddings (interview_id, message_id);

-- Refresh planner statistics for the new table.
ANALYZE conversation_embeddings;
// HTTPEmbedder obtains text embeddings from an external HTTP service.
type HTTPEmbedder struct {
	// Endpoint is the URL that embedding requests are POSTed to.
	Endpoint string
	// Timeout is applied to each request via http.Client.Timeout.
	Timeout time.Duration
}

// NewHTTPEmbedder builds an HTTPEmbedder whose endpoint is read from the
// EMBEDDING_URL environment variable, with a 10-second request timeout.
// It returns an error when the variable is unset or empty.
func NewHTTPEmbedder() (*HTTPEmbedder, error) {
	endpoint := os.Getenv("EMBEDDING_URL")
	if endpoint == "" {
		return nil, errors.New("env not set for EMBEDDING_URL")
	}

	return &HTTPEmbedder{
		Endpoint: endpoint,
		Timeout:  10 * time.Second,
	}, nil
}

// EmbedText POSTs {"text": input} as JSON to the configured endpoint and
// returns the "embedding" array from the JSON response.
//
// An error is returned when the request cannot be built or sent, when the
// service answers with a non-2xx status, when the body is not valid JSON, or
// when the decoded embedding is empty. The request is bound to ctx and to the
// embedder's Timeout.
func (e *HTTPEmbedder) EmbedText(ctx context.Context, input string) ([]float32, error) {
	body, err := json.Marshal(map[string]string{"text": input})
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.Endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: e.Timeout}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// A JSON error payload from a failing service would otherwise decode
	// cleanly into a zero-value result and be reported as success.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return nil, errors.New("embedding service returned status " + resp.Status)
	}

	var result struct {
		Embedding []float32 `json:"embedding"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, err
	}

	// A response without a vector is unusable downstream; fail here rather
	// than handing callers an empty embedding.
	if len(result.Embedding) == 0 {
		return nil, errors.New("embedding service returned an empty embedding")
	}

	return result.Embedding, nil
}
Embedding) error + GetSimilarEmbeddings(ctx context.Context, interviewID, topicID, questionNumber, excludeMessageID int, queryVec pgvector.Vector, limit int) ([]string, error) +} + +type Embedder interface { + EmbedText(ctx context.Context, input string) ([]float32, error) +} + +type Summarizer interface { + ExtractResponseSummary(question, response string) (*chatgpt.ChatGPTResponse, error) +} + +func NewService(repo Repository, embedder Embedder, summarizer Summarizer) *Service { + return &Service{ + Repo: repo, + Embedder: embedder, + Summarizer: summarizer, + } +} diff --git a/embedding/repository.go b/embedding/repository.go new file mode 100644 index 0000000..7f9a015 --- /dev/null +++ b/embedding/repository.go @@ -0,0 +1,89 @@ +package embedding + +import ( + "context" + "database/sql" + "fmt" + + "github.com/pgvector/pgvector-go" +) + +type PGRepository struct { + DB *sql.DB +} + +func NewRepository(db *sql.DB) *PGRepository { + return &PGRepository{DB: db} +} + +func (r *PGRepository) StoreEmbedding(ctx context.Context, e Embedding) error { + query := ` + INSERT INTO conversation_embeddings ( + interview_id, + conversation_id, + message_id, + topic_id, + question_number, + summary, + embedding, + created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ` + + _, err := r.DB.ExecContext(ctx, query, + e.InterviewID, + e.ConversationID, + e.MessageID, + e.TopicID, + e.QuestionNumber, + e.Summary, + e.Vector, + e.CreatedAt, + ) + + // DEBUG + fmt.Printf("EMBEDDING STORED\n\n") + return err +} + +func (r *PGRepository) GetSimilarEmbeddings( + ctx context.Context, + interviewID, topicID, questionNumber, excludeMessageID int, + queryVec pgvector.Vector, + limit int, +) ([]string, error) { + query := ` + SELECT summary + FROM conversation_embeddings + WHERE interview_id = $1 + AND message_id != $2 + ORDER BY embedding <-> $3 + LIMIT $4; + ` + + rows, err := r.DB.QueryContext(ctx, query, + interviewID, + excludeMessageID, + queryVec, + limit, + ) + if err != nil { + 
return nil, fmt.Errorf("query error: %w", err) + } + defer rows.Close() + + var summaries []string + for rows.Next() { + var s string + if err := rows.Scan(&s); err != nil { + return nil, fmt.Errorf("row scan error: %w", err) + } + summaries = append(summaries, s) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("rows error: %w", err) + } + + return summaries, nil +} diff --git a/embedding/service.go b/embedding/service.go new file mode 100644 index 0000000..6cf2358 --- /dev/null +++ b/embedding/service.go @@ -0,0 +1,76 @@ +package embedding + +import ( + "context" + "fmt" + "log" + + "github.com/pgvector/pgvector-go" +) + +func (s *Service) ProcessAndRetrieve(ctx context.Context, input EmbedInput) ([]string, error) { + fmt.Printf("ProcessAndRetrieve firing\n") + + summaryResp, err := s.Summarizer.ExtractResponseSummary(input.Question, input.UserResponse) + if err != nil { + log.Printf("s.Summarizer.ExtractResponseSummary failed: %v", err) + return nil, err + } + + fmt.Printf("SummaryResp: %v\n", summaryResp) + + allRelevant := []string{} + seen := map[string]struct{}{} + limit := 1 + + for _, point := range summaryResp.UserRespSummary { + rawVec, err := s.Embedder.EmbedText(ctx, point) + if err != nil { + log.Printf("s.Embedder.EmbedText failed: %v", err) + return nil, err + } + vector := pgvector.NewVector(rawVec) + + fmt.Printf("vector: %v\n", vector) + + err = s.Repo.StoreEmbedding(ctx, Embedding{ + InterviewID: input.InterviewID, + ConversationID: input.ConversationID, + MessageID: input.MessageID, + TopicID: input.TopicID, + QuestionNumber: input.QuestionNumber, + Summary: point, + Vector: vector, + CreatedAt: input.CreatedAt, + }) + if err != nil { + log.Printf("s.Repo.StoreEmbedding failed: %v", err) + return nil, err + } + + relevantEmbeddings, err := s.Repo.GetSimilarEmbeddings( + ctx, + input.InterviewID, + input.TopicID, + input.QuestionNumber, + input.MessageID, + vector, + limit, + ) + if err != nil { + 
log.Printf("s.Repo.GetSimilarEmbeddings failed: %v", err) + return nil, err + } + + for _, r := range relevantEmbeddings { + if _, exists := seen[r]; !exists { + seen[r] = struct{}{} + allRelevant = append(allRelevant, r) + } + } + + fmt.Printf("relevant: %v\n", allRelevant) + } + + return allRelevant, nil +} diff --git a/go.mod b/go.mod index e48b15a..98b0e5b 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,14 @@ module github.com/michaelboegner/interviewer -go 1.23 +go 1.23.0 toolchain go1.23.8 require ( github.com/golang-jwt/jwt/v5 v5.2.1 - github.com/google/go-cmp v0.5.9 + github.com/google/go-cmp v0.7.0 github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 - golang.org/x/crypto v0.27.0 + github.com/pgvector/pgvector-go v0.3.0 + golang.org/x/crypto v0.36.0 ) diff --git a/go.sum b/go.sum index 931a771..358b9d4 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,64 @@ +entgo.io/ent v0.14.3 h1:wokAV/kIlH9TeklJWGGS7AYJdVckr0DloWjIcO9iIIQ= +entgo.io/ent v0.14.3/go.mod h1:aDPE/OziPEu8+OWbzy4UlvWmD2/kbRuWfK2A40hcxJM= +github.com/go-pg/pg/v10 v10.11.0 h1:CMKJqLgTrfpE/aOVeLdybezR2om071Vh38OLZjsyMI0= +github.com/go-pg/pg/v10 v10.11.0/go.mod h1:4BpHRoxE61y4Onpof3x1a2SQvi9c+q1dJnrNdMjsroA= +github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU= +github.com/go-pg/zerochecker v0.2.0/go.mod h1:NJZ4wKL0NmTtz0GKCoJ8kym6Xn/EQzXRl2OnAe7MmDo= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid 
v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI= +github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= -golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +github.com/pgvector/pgvector-go v0.3.0 h1:Ij+Yt78R//uYqs3Zk35evZFvr+G0blW0OUN+Q2D1RWc= +github.com/pgvector/pgvector-go v0.3.0/go.mod h1:duFy+PXWfW7QQd5ibqutBO4GxLsUZ9RVXhFZGIBsWSA= 
+github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= +github.com/uptrace/bun v1.1.12 h1:sOjDVHxNTuM6dNGaba0wUuz7KvDE1BmNu9Gqs2gJSXQ= +github.com/uptrace/bun v1.1.12/go.mod h1:NPG6JGULBeQ9IU6yHp7YGELRa5Agmd7ATZdz4tGZ6z0= +github.com/uptrace/bun/dialect/pgdialect v1.1.12 h1:m/CM1UfOkoBTglGO5CUTKnIKKOApOYxkcP2qn0F9tJk= +github.com/uptrace/bun/dialect/pgdialect v1.1.12/go.mod h1:Ij6WIxQILxLlL2frUBxUBOZJtLElD2QQNDcu/PWDHTc= +github.com/uptrace/bun/driver/pgdriver v1.1.12 h1:3rRWB1GK0psTJrHwxzNfEij2MLibggiLdTqjTtfHc1w= +github.com/uptrace/bun/driver/pgdriver v1.1.12/go.mod h1:ssYUP+qwSEgeDDS1xm2XBip9el1y9Mi5mTAvLoiADLM= +github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94= +github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ= +github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= +github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc= +github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod 
h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= +gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +mellium.im/sasl v0.3.1 h1:wE0LW6g7U83vhvxjC1IY8DnXM+EU095yeo8XClvCdfo= +mellium.im/sasl v0.3.1/go.mod h1:xm59PUYpZHhgQ9ZqoJ5QaCqzWMi8IeS49dhp6plPCzw= diff --git a/handlers/handlers.go b/handlers/handlers.go index ba6474d..b77b521 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -697,9 +697,11 @@ func (h *Handler) CreateConversationsHandler(w http.ResponseWriter, r *http.Requ } conversationCreated, err := conversation.CreateConversation( + r.Context(), h.ConversationRepo, h.InterviewRepo, h.OpenAI, + *h.Embedding, conversationReturned, interviewID, interviewReturned.Prompt, @@ -779,9 +781,11 @@ func (h *Handler) AppendConversationsHandler(w http.ResponseWriter, r *http.Requ } conversationReturned, err = conversation.AppendConversation( + r.Context(), h.ConversationRepo, h.InterviewRepo, h.OpenAI, + *h.Embedding, interviewID, userID, conversationReturned, diff --git a/handlers/model.go b/handlers/model.go index cd51199..e5460e9 100644 --- a/handlers/model.go +++ b/handlers/model.go @@ -6,6 +6,7 @@ import ( "github.com/michaelboegner/interviewer/billing" "github.com/michaelboegner/interviewer/chatgpt" "github.com/michaelboegner/interviewer/conversation" + "github.com/michaelboegner/interviewer/embedding" "github.com/michaelboegner/interviewer/interview" 
"github.com/michaelboegner/interviewer/mailer" "github.com/michaelboegner/interviewer/token" @@ -59,6 +60,7 @@ type Handler struct { Billing *billing.Billing Mailer *mailer.Mailer OpenAI chatgpt.AIClient + Embedding *embedding.Service DB *sql.DB } @@ -71,6 +73,7 @@ func NewHandler( billing *billing.Billing, mailer *mailer.Mailer, openAI chatgpt.AIClient, + embeddingService *embedding.Service, db *sql.DB) *Handler { return &Handler{ InterviewRepo: interviewRepo, @@ -81,6 +84,7 @@ func NewHandler( Billing: billing, Mailer: mailer, OpenAI: openAI, + Embedding: embeddingService, DB: db, } } diff --git a/internal/server/server.go b/internal/server/server.go index 027f902..7204f3d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -9,6 +9,7 @@ import ( "github.com/michaelboegner/interviewer/chatgpt" "github.com/michaelboegner/interviewer/conversation" "github.com/michaelboegner/interviewer/database" + "github.com/michaelboegner/interviewer/embedding" "github.com/michaelboegner/interviewer/handlers" "github.com/michaelboegner/interviewer/interview" "github.com/michaelboegner/interviewer/mailer" @@ -34,15 +35,32 @@ func NewServer() (*Server, error) { tokenRepo := token.NewRepository(db) conversationRepo := conversation.NewRepository(db) billingRepo := billing.NewRepository(db) + embeddingRepo := embedding.NewRepository(db) openAI := chatgpt.NewOpenAI() mailer := mailer.NewMailer() + embedder, err := embedding.NewHTTPEmbedder() + if err != nil { + log.Printf("embedding.NewHTTPEmbedder failed: %v", err) + return nil, err + } + embedding := embedding.NewService(embeddingRepo, embedder, openAI) billing, err := billing.NewBilling() if err != nil { log.Printf("billing.NewBilling failed: %v", err) return nil, err } - handler := handlers.NewHandler(interviewRepo, userRepo, tokenRepo, conversationRepo, billingRepo, billing, mailer, openAI, db) + handler := handlers.NewHandler( + interviewRepo, + userRepo, + tokenRepo, + conversationRepo, + billingRepo, + 
billing, + mailer, + openAI, + embedding, + db) mux.Handle("/api/users", http.HandlerFunc(handler.CreateUsersHandler)) mux.Handle("/api/auth/login", http.HandlerFunc(handler.LoginHandler))