Skip to content
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
Expand Down Expand Up @@ -160,6 +162,8 @@ golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sU
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0=
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ=
golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
Expand Down Expand Up @@ -188,6 +192,8 @@ golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
Expand Down
4 changes: 2 additions & 2 deletions internal/api/chat/create_conversation_message_stream_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ func (s *ChatServerV2) CreateConversationMessageStream(
APIKey: settings.OpenAIAPIKey,
}

openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider)
openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.UserID, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider)
if err != nil {
return s.sendStreamError(stream, err)
}
Expand All @@ -307,7 +307,7 @@ func (s *ChatServerV2) CreateConversationMessageStream(
for i, bsonMsg := range conversation.InappChatHistory {
protoMessages[i] = mapper.BSONToChatMessageV2(bsonMsg)
}
title, err := s.aiClientV2.GetConversationTitleV2(ctx, protoMessages, llmProvider)
title, err := s.aiClientV2.GetConversationTitleV2(ctx, conversation.UserID, protoMessages, llmProvider)
if err != nil {
s.logger.Error("Failed to get conversation title", "error", err, "conversationID", conversation.ID.Hex())
return
Expand Down
3 changes: 3 additions & 0 deletions internal/api/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
chatv2 "paperdebugger/pkg/gen/api/chat/v2"
commentv1 "paperdebugger/pkg/gen/api/comment/v1"
projectv1 "paperdebugger/pkg/gen/api/project/v1"
usagev1 "paperdebugger/pkg/gen/api/usage/v1"
userv1 "paperdebugger/pkg/gen/api/user/v1"

// "github.com/grpc-ecosystem/go-grpc-middleware"
Expand Down Expand Up @@ -106,6 +107,7 @@ func NewGrpcServer(
userServer userv1.UserServiceServer,
projectServer projectv1.ProjectServiceServer,
commentServer commentv1.CommentServiceServer,
usageServer usagev1.UsageServiceServer,
) *GrpcServer {
grpcServer := &GrpcServer{}
grpcServer.userService = userService
Expand All @@ -121,5 +123,6 @@ func NewGrpcServer(
userv1.RegisterUserServiceServer(grpcServer.Server, userServer)
projectv1.RegisterProjectServiceServer(grpcServer.Server, projectServer)
commentv1.RegisterCommentServiceServer(grpcServer.Server, commentServer)
usagev1.RegisterUsageServiceServer(grpcServer.Server, usageServer)
return grpcServer
}
23 changes: 18 additions & 5 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ import (
"paperdebugger/internal/libs/logger"
"paperdebugger/internal/libs/metadatautil"
"paperdebugger/internal/libs/shared"
"paperdebugger/internal/services"
authv1 "paperdebugger/pkg/gen/api/auth/v1"
chatv1 "paperdebugger/pkg/gen/api/chat/v1"
chatv2 "paperdebugger/pkg/gen/api/chat/v2"
commentv1 "paperdebugger/pkg/gen/api/comment/v1"
projectv1 "paperdebugger/pkg/gen/api/project/v1"
sharedv1 "paperdebugger/pkg/gen/api/shared/v1"
usagev1 "paperdebugger/pkg/gen/api/usage/v1"
userv1 "paperdebugger/pkg/gen/api/user/v1"

"github.com/gin-gonic/gin"
Expand All @@ -30,25 +32,31 @@ import (
)

type Server struct {
grpcServer *GrpcServer
ginServer *GinServer
grpcServer *GrpcServer
ginServer *GinServer
pricingService *services.PricingService

logger *logger.Logger
}

func NewServer(
grpcServer *GrpcServer,
ginServer *GinServer,
pricingService *services.PricingService,
logger *logger.Logger,
) *Server {
return &Server{
grpcServer: grpcServer,
ginServer: ginServer,
logger: logger,
grpcServer: grpcServer,
ginServer: ginServer,
pricingService: pricingService,
logger: logger,
}
}

func (s *Server) Run(addr string) {
// Start the pricing updater in the background
s.pricingService.StartPriceUpdater(context.Background())

listener, err := net.Listen("tcp", ":0")
if err != nil {
s.logger.Fatalf("failed to start grpc server listener: %v", err)
Expand Down Expand Up @@ -105,6 +113,11 @@ func (s *Server) Run(addr string) {
s.logger.Fatalf("failed to register comment service grpc gateway: %v", err)
return
}
err = usagev1.RegisterUsageServiceHandler(context.Background(), mux, client)
if err != nil {
s.logger.Fatalf("failed to register usage service grpc gateway: %v", err)
return
}

s.logger.Infof("[PAPERDEBUGGER] http server listening on %s", addr)
s.ginServer.Any("/_pd/api/*path", func(c *gin.Context) { mux.ServeHTTP(c.Writer, c.Request) })
Expand Down
66 changes: 66 additions & 0 deletions internal/api/usage/get_session_usage.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package usage

import (
"context"

"paperdebugger/internal/libs/contextutil"
"paperdebugger/internal/models"
usagev1 "paperdebugger/pkg/gen/api/usage/v1"

"google.golang.org/protobuf/types/known/timestamppb"
)

func (s *UsageServer) GetSessionUsage(
ctx context.Context,
req *usagev1.GetSessionUsageRequest,
) (*usagev1.GetSessionUsageResponse, error) {
actor, err := contextutil.GetActor(ctx)
if err != nil {
return nil, err
}

session, err := s.usageService.GetActiveSession(ctx, actor.ID)
if err != nil {
return nil, err
}

if session == nil {
return &usagev1.GetSessionUsageResponse{
Session: nil,
}, nil
}

// Get pricing map for cost calculation
pricingMap, err := s.pricingService.GetPricingMap(ctx)
if err != nil {
s.logger.Warn("Failed to get pricing map", "error", err)
pricingMap = make(map[string]*models.ModelPricing)
}

// Convert models map to proto format and calculate costs
protoModels := make(map[string]*usagev1.ModelTokens)
var totalCostUSD float64
for modelName, tokens := range session.Models {
var costUSD float64
if pricing, ok := pricingMap[modelName]; ok && pricing != nil {
costUSD = float64(tokens.PromptTokens)*pricing.PromptPrice +
float64(tokens.CompletionTokens)*pricing.CompletionPrice
totalCostUSD += costUSD
}
protoModels[modelName] = &usagev1.ModelTokens{
PromptTokens: tokens.PromptTokens,
CompletionTokens: tokens.CompletionTokens,
TotalTokens: tokens.TotalTokens,
RequestCount: tokens.RequestCount,
CostUsd: costUSD,
}
}

return &usagev1.GetSessionUsageResponse{
Session: &usagev1.SessionUsage{
SessionExpiry: timestamppb.New(session.SessionExpiry.Time()),
Models: protoModels,
TotalCostUsd: totalCostUSD,
},
}, nil
}
58 changes: 58 additions & 0 deletions internal/api/usage/get_weekly_usage.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package usage

import (
"context"

"paperdebugger/internal/libs/contextutil"
"paperdebugger/internal/models"
usagev1 "paperdebugger/pkg/gen/api/usage/v1"
)

func (s *UsageServer) GetWeeklyUsage(
ctx context.Context,
req *usagev1.GetWeeklyUsageRequest,
) (*usagev1.GetWeeklyUsageResponse, error) {
actor, err := contextutil.GetActor(ctx)
if err != nil {
return nil, err
}

stats, err := s.usageService.GetWeeklyUsage(ctx, actor.ID)
if err != nil {
return nil, err
}

// Get pricing map for cost calculation
pricingMap, err := s.pricingService.GetPricingMap(ctx)
if err != nil {
s.logger.Warn("Failed to get pricing map", "error", err)
pricingMap = make(map[string]*models.ModelPricing)
}

// Convert models map to proto format and calculate costs
protoModels := make(map[string]*usagev1.ModelTokens)
var totalCostUSD float64
for modelName, tokens := range stats.Models {
var costUSD float64
if pricing, ok := pricingMap[modelName]; ok && pricing != nil {
costUSD = float64(tokens.PromptTokens)*pricing.PromptPrice +
float64(tokens.CompletionTokens)*pricing.CompletionPrice
totalCostUSD += costUSD
}
protoModels[modelName] = &usagev1.ModelTokens{
PromptTokens: tokens.PromptTokens,
CompletionTokens: tokens.CompletionTokens,
TotalTokens: tokens.TotalTokens,
RequestCount: tokens.RequestCount,
CostUsd: costUSD,
}
}

return &usagev1.GetWeeklyUsageResponse{
Usage: &usagev1.WeeklyUsage{
Models: protoModels,
SessionCount: stats.SessionCount,
TotalCostUsd: totalCostUSD,
},
}, nil
}
27 changes: 27 additions & 0 deletions internal/api/usage/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package usage

import (
"paperdebugger/internal/libs/logger"
"paperdebugger/internal/services"
usagev1 "paperdebugger/pkg/gen/api/usage/v1"
)

type UsageServer struct {
usagev1.UnimplementedUsageServiceServer

usageService *services.UsageService
pricingService *services.PricingService
logger *logger.Logger
}

func NewUsageServer(
usageService *services.UsageService,
pricingService *services.PricingService,
logger *logger.Logger,
) usagev1.UsageServiceServer {
return &UsageServer{
usageService: usageService,
pricingService: pricingService,
logger: logger,
}
}
42 changes: 41 additions & 1 deletion internal/libs/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"paperdebugger/internal/libs/cfg"
"paperdebugger/internal/libs/logger"
"paperdebugger/internal/models"

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
Expand Down Expand Up @@ -43,5 +44,44 @@ func NewDB(cfg *cfg.Cfg, logger *logger.Logger) (*DB, error) {
}

logger.Info("[MONGO] initialized")
return &DB{Client: client, cfg: cfg, logger: logger}, nil

db := &DB{Client: client, cfg: cfg, logger: logger}
db.ensureIndexes()
return db, nil
}

// ensureIndexes creates necessary indexes for the database collections.
func (db *DB) ensureIndexes() {
sessions := db.Database("paperdebugger").Collection((models.LLMSession{}).CollectionName())

// TTL index: auto-delete sessions after 30 days past their expiry time
_, err := sessions.Indexes().CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{{Key: "session_expiry", Value: 1}},
Options: options.Index().SetExpireAfterSeconds(30 * 24 * 60 * 60),
})
if err != nil {
db.logger.Error("Failed to create TTL index on llm_sessions", "error", err)
}

// Compound index for efficient active session lookups
_, err = sessions.Indexes().CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{
{Key: "user_id", Value: 1},
{Key: "session_expiry", Value: -1},
},
})
if err != nil {
db.logger.Error("Failed to create compound index on llm_sessions", "error", err)
}

// Compound index for usage queries and recent session lookups
_, err = sessions.Indexes().CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{
{Key: "user_id", Value: 1},
{Key: "session_start", Value: -1},
},
})
if err != nil {
db.logger.Error("Failed to create session_start index on llm_sessions", "error", err)
}
}
23 changes: 23 additions & 0 deletions internal/models/model_pricing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package models

import (
"time"

"go.mongodb.org/mongo-driver/v2/bson"
)

// ModelPricing stores the pricing information for an LLM model.
// Prices are in USD per token.
type ModelPricing struct {
ID bson.ObjectID `bson:"_id"`
ModelID string `bson:"model_id"` // e.g., "openai/gpt-4"
ModelSlug string `bson:"model_slug"` // e.g., "gpt-4" (short name used in our app)
Name string `bson:"name"` // e.g., "OpenAI: GPT-4"
PromptPrice float64 `bson:"prompt_price"` // USD per token
CompletionPrice float64 `bson:"completion_price"` // USD per token
UpdatedAt time.Time `bson:"updated_at"`
}

func (m ModelPricing) CollectionName() string {
return "model_pricing"
}
25 changes: 25 additions & 0 deletions internal/models/usage.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package models

import "go.mongodb.org/mongo-driver/v2/bson"

// ModelTokens stores token counts for a specific model.
type ModelTokens struct {
PromptTokens int64 `bson:"prompt_tokens"`
CompletionTokens int64 `bson:"completion_tokens"`
TotalTokens int64 `bson:"total_tokens"`
RequestCount int64 `bson:"request_count"`
}

// LLMSession represents a user's session for tracking LLM usage and token counts.
// Tokens are stored per model in the Models map.
type LLMSession struct {
ID bson.ObjectID `bson:"_id"`
UserID bson.ObjectID `bson:"user_id"`
SessionStart bson.DateTime `bson:"session_start"`
SessionExpiry bson.DateTime `bson:"session_expiry"`
Models map[string]*ModelTokens `bson:"models"`
}

func (s LLMSession) CollectionName() string {
return "llm_sessions"
}
Loading