From d93462e42787852858eb73966751e7d94b89fca4 Mon Sep 17 00:00:00 2001 From: wjiayis Date: Sat, 21 Feb 2026 14:50:53 +0800 Subject: [PATCH 01/13] feat: LLM usage tracking, backend --- go.sum | 6 + .../create_conversation_message_stream_v2.go | 4 +- internal/api/grpc.go | 3 + internal/api/server.go | 6 + internal/api/usage/get_session_usage.go | 43 ++ internal/api/usage/get_weekly_usage.go | 33 ++ internal/api/usage/server.go | 24 + internal/libs/db/db.go | 31 +- internal/models/usage.go | 19 + internal/services/toolkit/client/client_v2.go | 3 + .../services/toolkit/client/completion_v2.go | 27 +- .../toolkit/client/get_citation_keys.go | 2 +- .../toolkit/client/get_citation_keys_test.go | 2 + .../client/get_conversation_title_v2.go | 5 +- internal/services/toolkit/client/utils_v2.go | 6 + internal/services/usage.go | 175 +++++++ internal/wire.go | 3 + internal/wire_gen.go | 9 +- pkg/gen/api/chat/v2/chat.pb.go | 7 +- pkg/gen/api/usage/v1/usage.pb.go | 446 ++++++++++++++++++ pkg/gen/api/usage/v1/usage.pb.gw.go | 211 +++++++++ pkg/gen/api/usage/v1/usage_grpc.pb.go | 159 +++++++ proto/usage/v1/usage.proto | 49 ++ .../pkg/gen/apiclient/usage/v1/usage_pb.ts | 186 ++++++++ 24 files changed, 1440 insertions(+), 19 deletions(-) create mode 100644 internal/api/usage/get_session_usage.go create mode 100644 internal/api/usage/get_weekly_usage.go create mode 100644 internal/api/usage/server.go create mode 100644 internal/models/usage.go create mode 100644 internal/services/usage.go create mode 100644 pkg/gen/api/usage/v1/usage.pb.go create mode 100644 pkg/gen/api/usage/v1/usage.pb.gw.go create mode 100644 pkg/gen/api/usage/v1/usage_grpc.pb.go create mode 100644 proto/usage/v1/usage.proto create mode 100644 webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts diff --git a/go.sum b/go.sum index 1943dc8d..65f5adb4 100644 --- a/go.sum +++ b/go.sum @@ -55,6 +55,8 @@ github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/google/go-cmp v0.7.0 
h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -160,6 +162,8 @@ golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sU golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0= golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= +golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -188,6 +192,8 @@ golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.36.0 
h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= +golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= diff --git a/internal/api/chat/create_conversation_message_stream_v2.go b/internal/api/chat/create_conversation_message_stream_v2.go index b82adf5d..8c715a36 100644 --- a/internal/api/chat/create_conversation_message_stream_v2.go +++ b/internal/api/chat/create_conversation_message_stream_v2.go @@ -281,7 +281,7 @@ func (s *ChatServerV2) CreateConversationMessageStream( APIKey: settings.OpenAIAPIKey, } - openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider) + openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.UserID, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider) if err != nil { return s.sendStreamError(stream, err) } @@ -307,7 +307,7 @@ func (s *ChatServerV2) CreateConversationMessageStream( for i, bsonMsg := range conversation.InappChatHistory { protoMessages[i] = mapper.BSONToChatMessageV2(bsonMsg) } - title, err := s.aiClientV2.GetConversationTitleV2(ctx, protoMessages, llmProvider) + title, err := s.aiClientV2.GetConversationTitleV2(ctx, conversation.UserID, protoMessages, llmProvider) if err != nil { s.logger.Error("Failed to get conversation title", "error", err, "conversationID", conversation.ID.Hex()) return diff --git a/internal/api/grpc.go b/internal/api/grpc.go index ed9dc2b0..3451d667 100644 --- a/internal/api/grpc.go +++ b/internal/api/grpc.go @@ -15,6 +15,7 @@ import ( chatv2 "paperdebugger/pkg/gen/api/chat/v2" commentv1 
"paperdebugger/pkg/gen/api/comment/v1" projectv1 "paperdebugger/pkg/gen/api/project/v1" + usagev1 "paperdebugger/pkg/gen/api/usage/v1" userv1 "paperdebugger/pkg/gen/api/user/v1" // "github.com/grpc-ecosystem/go-grpc-middleware" @@ -106,6 +107,7 @@ func NewGrpcServer( userServer userv1.UserServiceServer, projectServer projectv1.ProjectServiceServer, commentServer commentv1.CommentServiceServer, + usageServer usagev1.UsageServiceServer, ) *GrpcServer { grpcServer := &GrpcServer{} grpcServer.userService = userService @@ -121,5 +123,6 @@ func NewGrpcServer( userv1.RegisterUserServiceServer(grpcServer.Server, userServer) projectv1.RegisterProjectServiceServer(grpcServer.Server, projectServer) commentv1.RegisterCommentServiceServer(grpcServer.Server, commentServer) + usagev1.RegisterUsageServiceServer(grpcServer.Server, usageServer) return grpcServer } diff --git a/internal/api/server.go b/internal/api/server.go index b093c767..d8e9b36a 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -17,6 +17,7 @@ import ( commentv1 "paperdebugger/pkg/gen/api/comment/v1" projectv1 "paperdebugger/pkg/gen/api/project/v1" sharedv1 "paperdebugger/pkg/gen/api/shared/v1" + usagev1 "paperdebugger/pkg/gen/api/usage/v1" userv1 "paperdebugger/pkg/gen/api/user/v1" "github.com/gin-gonic/gin" @@ -105,6 +106,11 @@ func (s *Server) Run(addr string) { s.logger.Fatalf("failed to register comment service grpc gateway: %v", err) return } + err = usagev1.RegisterUsageServiceHandler(context.Background(), mux, client) + if err != nil { + s.logger.Fatalf("failed to register usage service grpc gateway: %v", err) + return + } s.logger.Infof("[PAPERDEBUGGER] http server listening on %s", addr) s.ginServer.Any("/_pd/api/*path", func(c *gin.Context) { mux.ServeHTTP(c.Writer, c.Request) }) diff --git a/internal/api/usage/get_session_usage.go b/internal/api/usage/get_session_usage.go new file mode 100644 index 00000000..d59c3b08 --- /dev/null +++ b/internal/api/usage/get_session_usage.go @@ -0,0 
+1,41 @@
+package usage
+
+import (
+	"context"
+
+	"paperdebugger/internal/libs/contextutil"
+	usagev1 "paperdebugger/pkg/gen/api/usage/v1"
+
+	"google.golang.org/protobuf/types/known/timestamppb"
+)
+
+// GetSessionUsage reports the active usage session of the actor found in ctx.
+// The response's Session is left nil when the usage service returns no active session.
+func (s *UsageServer) GetSessionUsage(
+	ctx context.Context,
+	req *usagev1.GetSessionUsageRequest,
+) (*usagev1.GetSessionUsageResponse, error) {
+	actor, err := contextutil.GetActor(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	active, err := s.usageService.GetActiveSession(ctx, actor.ID)
+	if err != nil {
+		return nil, err
+	}
+
+	resp := &usagev1.GetSessionUsageResponse{}
+	if active != nil {
+		resp.Session = &usagev1.SessionUsage{
+			Id:               active.ID.Hex(),
+			SessionStart:     timestamppb.New(active.SessionStart.Time()),
+			SessionExpiry:    timestamppb.New(active.SessionExpiry.Time()),
+			PromptTokens:     active.PromptTokens,
+			CompletionTokens: active.CompletionTokens,
+			TotalTokens:      active.TotalTokens,
+			RequestCount:     active.RequestCount,
+		}
+	}
+	return resp, nil
+}
diff --git a/internal/api/usage/get_weekly_usage.go b/internal/api/usage/get_weekly_usage.go
new file mode 100644
index 00000000..e3c168f1
--- /dev/null
+++ b/internal/api/usage/get_weekly_usage.go
@@ -0,0 +1,34 @@
+package usage
+
+import (
+	"context"
+
+	"paperdebugger/internal/libs/contextutil"
+	usagev1 "paperdebugger/pkg/gen/api/usage/v1"
+)
+
+// GetWeeklyUsage reports the weekly usage totals that the usage service
+// computes for the actor found in ctx.
+func (s *UsageServer) GetWeeklyUsage(
+	ctx context.Context,
+	req *usagev1.GetWeeklyUsageRequest,
+) (*usagev1.GetWeeklyUsageResponse, error) {
+	actor, err := contextutil.GetActor(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	weekly, err := s.usageService.GetWeeklyUsage(ctx, actor.ID)
+	if err != nil {
+		return nil, err
+	}
+
+	totals := &usagev1.WeeklyUsage{
+		PromptTokens:     weekly.PromptTokens,
+		CompletionTokens: weekly.CompletionTokens,
+		TotalTokens:      weekly.TotalTokens,
+		RequestCount:     weekly.RequestCount,
+		SessionCount:     weekly.SessionCount,
+	}
+	return &usagev1.GetWeeklyUsageResponse{Usage: totals}, nil
+}
diff --git a/internal/api/usage/server.go b/internal/api/usage/server.go
new file mode 100644
index 00000000..5d64854e
--- /dev/null
+++ b/internal/api/usage/server.go
@@ -0,0 +1,26 @@
+package usage
+
+import (
+	"paperdebugger/internal/libs/logger"
+	"paperdebugger/internal/services"
+	usagev1 "paperdebugger/pkg/gen/api/usage/v1"
+)
+
+// UsageServer serves the usagev1 usage RPCs by delegating to services.UsageService.
+type UsageServer struct {
+	usagev1.UnimplementedUsageServiceServer
+
+	usageService *services.UsageService
+	logger       *logger.Logger
+}
+
+// NewUsageServer returns a UsageServer backed by usageService, typed as the generated service interface.
+func NewUsageServer(
+	usageService *services.UsageService,
+	logger *logger.Logger,
+) usagev1.UsageServiceServer {
+	srv := new(UsageServer)
+	srv.usageService = usageService
+	srv.logger = logger
+	return srv
+}
diff --git a/internal/libs/db/db.go b/internal/libs/db/db.go
index 52a5548c..8468f73c 100644
--- a/internal/libs/db/db.go
+++ b/internal/libs/db/db.go
@@ -6,6 +6,7 @@ import (
 
 	"paperdebugger/internal/libs/cfg"
 	"paperdebugger/internal/libs/logger"
+	"paperdebugger/internal/models"
 
 	"go.mongodb.org/mongo-driver/v2/bson"
 	"go.mongodb.org/mongo-driver/v2/mongo"
@@ -43,5 +44,33 @@ func NewDB(cfg *cfg.Cfg, logger *logger.Logger) (*DB, error) {
 	}
 
 	logger.Info("[MONGO] initialized")
-	return &DB{Client: client, cfg: cfg, logger: logger}, nil
+
+	db := &DB{Client: client, cfg: cfg, logger: logger}
+	db.ensureIndexes()
+	return db, nil
+}
+
+// ensureIndexes creates necessary indexes for the database collections.
+func (db *DB) ensureIndexes() {
+	coll := db.Database("paperdebugger").Collection((models.LLMSession{}).CollectionName())
+	// TTL index: sessions are deleted 30 days after their session_expiry
+	ttl := mongo.IndexModel{
+		Keys:    bson.D{{Key: "session_expiry", Value: 1}},
+		Options: options.Index().SetExpireAfterSeconds(30 * 24 * 60 * 60),
+	}
+	// Compound index for efficient active session lookups
+	lookup := mongo.IndexModel{Keys: bson.D{{Key: "user_id", Value: 1}, {Key: "session_expiry", Value: -1}}}
+
+	// The indexes are created one by one; a failure is logged and the next one is still attempted.
+	for _, idx := range []struct {
+		model   mongo.IndexModel
+		failMsg string
+	}{
+		{ttl, "Failed to create TTL index on llm_sessions"},
+		{lookup, "Failed to create compound index on llm_sessions"},
+	} {
+		if _, err := coll.Indexes().CreateOne(context.Background(), idx.model); err != nil {
+			db.logger.Error(idx.failMsg, "error", err)
+		}
+	}
 }
diff --git a/internal/models/usage.go b/internal/models/usage.go
new file mode 100644
index 00000000..91d73273
--- /dev/null
+++ b/internal/models/usage.go
@@ -0,0 +1,21 @@
+package models
+
+import "go.mongodb.org/mongo-driver/v2/bson"
+
+// LLMSession represents a user's session for tracking LLM usage and token counts.
+// It stores the owning user, the session's start and expiry times, and running
+// counters for prompt, completion and total tokens and for requests.
+type LLMSession struct { + ID bson.ObjectID `bson:"_id"` + UserID bson.ObjectID `bson:"user_id"` + SessionStart bson.DateTime `bson:"session_start"` + SessionExpiry bson.DateTime `bson:"session_expiry"` + PromptTokens int64 `bson:"prompt_tokens"` + CompletionTokens int64 `bson:"completion_tokens"` + TotalTokens int64 `bson:"total_tokens"` + RequestCount int64 `bson:"request_count"` +} + +func (s LLMSession) CollectionName() string { + return "llm_sessions" +} diff --git a/internal/services/toolkit/client/client_v2.go b/internal/services/toolkit/client/client_v2.go index 87a1e26a..4bbcf816 100644 --- a/internal/services/toolkit/client/client_v2.go +++ b/internal/services/toolkit/client/client_v2.go @@ -20,6 +20,7 @@ type AIClientV2 struct { reverseCommentService *services.ReverseCommentService projectService *services.ProjectService + usageService *services.UsageService cfg *cfg.Cfg logger *logger.Logger } @@ -60,6 +61,7 @@ func NewAIClientV2( reverseCommentService *services.ReverseCommentService, projectService *services.ProjectService, + usageService *services.UsageService, cfg *cfg.Cfg, logger *logger.Logger, ) *AIClientV2 { @@ -107,6 +109,7 @@ func NewAIClientV2( reverseCommentService: reverseCommentService, projectService: projectService, + usageService: usageService, cfg: cfg, logger: logger, } diff --git a/internal/services/toolkit/client/completion_v2.go b/internal/services/toolkit/client/completion_v2.go index f10082bf..463e5e0a 100644 --- a/internal/services/toolkit/client/completion_v2.go +++ b/internal/services/toolkit/client/completion_v2.go @@ -4,11 +4,13 @@ import ( "context" "encoding/json" "paperdebugger/internal/models" + "paperdebugger/internal/services" "paperdebugger/internal/services/toolkit/handler" chatv2 "paperdebugger/pkg/gen/api/chat/v2" "strings" "github.com/openai/openai-go/v3" + "go.mongodb.org/mongo-driver/v2/bson" ) // define []openai.ChatCompletionMessageParamUnion as OpenAIChatHistory @@ -25,8 +27,8 @@ import ( // 1. 
The full chat history sent to the language model (including any tool call results). // 2. The incremental chat history visible to the user (including tool call results and assistant responses). // 3. An error, if any occurred during the process. -func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) { - openaiChatHistory, inappChatHistory, err := a.ChatCompletionStreamV2(ctx, nil, "", modelSlug, messages, llmProvider) +func (a *AIClientV2) ChatCompletionV2(ctx context.Context, userID bson.ObjectID, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) { + openaiChatHistory, inappChatHistory, err := a.ChatCompletionStreamV2(ctx, nil, userID, "", modelSlug, messages, llmProvider) if err != nil { return nil, nil, err } @@ -54,7 +56,7 @@ func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, mes // - If tool calls are required, it handles them and appends the results to the chat history, then continues the loop. // - If no tool calls are needed, it appends the assistant's response and exits the loop. // - Finally, it returns the updated chat histories and any error encountered. 
-func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) { +func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, userID bson.ObjectID, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) { openaiChatHistory := messages inappChatHistory := AppChatHistory{} @@ -96,8 +98,22 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chunk := stream.Current() if len(chunk.Choices) == 0 { - // Handle usage information - // fmt.Printf("Usage: %+v\n", chunk.Usage) + if chunk.Usage.TotalTokens > 0 { + // Record usage and log stats asynchronously to avoid blocking the response + go func(usage services.UsageRecord) { + bgCtx := context.Background() + if err := a.usageService.RecordUsage(bgCtx, usage); err != nil { + a.logger.Error("Failed to store usage", "error", err) + return + } + + }(services.UsageRecord{ + UserID: userID, + PromptTokens: chunk.Usage.PromptTokens, + CompletionTokens: chunk.Usage.CompletionTokens, + TotalTokens: chunk.Usage.TotalTokens, + }) + } continue } @@ -185,7 +201,6 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream // answer_content += chunk.Choices[0].Delta.Content // fmt.Printf("answer_content: %s\n", answer_content) streamHandler.HandleTextDoneItem(chunk, answer_content, reasoning_content) - break } } diff --git a/internal/services/toolkit/client/get_citation_keys.go b/internal/services/toolkit/client/get_citation_keys.go index 1995d590..5cc43ce5 100644 --- a/internal/services/toolkit/client/get_citation_keys.go +++ b/internal/services/toolkit/client/get_citation_keys.go @@ -241,7 
+241,7 @@ func (a *AIClientV2) GetCitationKeys(ctx context.Context, sentence string, userI // Bibliography is placed at the start of the prompt to leverage prompt caching message := fmt.Sprintf("Bibliography: %s\nSentence: %s\nBased on the sentence and bibliography, suggest only the most relevant citation keys separated by commas with no spaces (e.g. key1,key2). Be selective and only include citations that are directly relevant. Avoid suggesting more than 3 citations. If no relevant citations are found, return '%s'.", bibliography, sentence, emptyCitation) - _, resp, err := a.ChatCompletionV2(ctx, "gpt-5.2", OpenAIChatHistory{ + _, resp, err := a.ChatCompletionV2(ctx, userId, "gpt-5.2", OpenAIChatHistory{ openai.SystemMessage("You are a helpful assistant that suggests relevant citation keys."), openai.UserMessage(message), }, llmProvider) diff --git a/internal/services/toolkit/client/get_citation_keys_test.go b/internal/services/toolkit/client/get_citation_keys_test.go index 4d2a857d..802e6bbf 100644 --- a/internal/services/toolkit/client/get_citation_keys_test.go +++ b/internal/services/toolkit/client/get_citation_keys_test.go @@ -25,10 +25,12 @@ func setupTestClient(t *testing.T) (*client.AIClientV2, *services.ProjectService } projectService := services.NewProjectService(dbInstance, cfg.GetCfg(), logger.GetLogger()) + usageService := services.NewUsageService(dbInstance, cfg.GetCfg(), logger.GetLogger()) aiClient := client.NewAIClientV2( dbInstance, &services.ReverseCommentService{}, projectService, + usageService, cfg.GetCfg(), logger.GetLogger(), ) diff --git a/internal/services/toolkit/client/get_conversation_title_v2.go b/internal/services/toolkit/client/get_conversation_title_v2.go index 6c92f0c2..f3fd5c8c 100644 --- a/internal/services/toolkit/client/get_conversation_title_v2.go +++ b/internal/services/toolkit/client/get_conversation_title_v2.go @@ -11,9 +11,10 @@ import ( "github.com/openai/openai-go/v3" "github.com/samber/lo" + 
"go.mongodb.org/mongo-driver/v2/bson" ) -func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig) (string, error) { +func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, userID bson.ObjectID, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig) (string, error) { messages := lo.Map(inappChatHistory, func(message *chatv2.Message, _ int) string { if _, ok := message.Payload.MessageType.(*chatv2.MessagePayload_Assistant); ok { return fmt.Sprintf("Assistant: %s", message.Payload.GetAssistant().GetContent()) @@ -29,7 +30,7 @@ func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistor message := strings.Join(messages, "\n") message = fmt.Sprintf("%s\nBased on above conversation, generate a short, clear, and descriptive title that summarizes the main topic or purpose of the discussion. The title should be concise, specific, and use natural language. Avoid vague or generic titles. Use abbreviation and short words if possible. Use 3-5 words if possible. 
Give me the title only, no other text including any other words.", message) - _, resp, err := a.ChatCompletionV2(ctx, "gpt-5-nano", OpenAIChatHistory{ + _, resp, err := a.ChatCompletionV2(ctx, userID, "gpt-5-nano", OpenAIChatHistory{ openai.SystemMessage("You are a helpful assistant that generates a title for a conversation."), openai.UserMessage(message), }, llmProvider) diff --git a/internal/services/toolkit/client/utils_v2.go b/internal/services/toolkit/client/utils_v2.go index 69e73071..47829575 100644 --- a/internal/services/toolkit/client/utils_v2.go +++ b/internal/services/toolkit/client/utils_v2.go @@ -74,6 +74,9 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2) Tools: toolRegistry.GetTools(), ParallelToolCalls: openaiv3.Bool(true), Store: openaiv3.Bool(false), + StreamOptions: openaiv3.ChatCompletionStreamOptionsParam{ + IncludeUsage: openaiv3.Bool(true), + }, } } } @@ -85,6 +88,9 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2) Tools: toolRegistry.GetTools(), // Tool registration is managed centrally by the registry ParallelToolCalls: openaiv3.Bool(true), Store: openaiv3.Bool(false), // Must set to false, because we are construct our own chat history. 
+		StreamOptions: openaiv3.ChatCompletionStreamOptionsParam{
+			IncludeUsage: openaiv3.Bool(true),
+		},
 	}
 }
 
diff --git a/internal/services/usage.go b/internal/services/usage.go
new file mode 100644
index 00000000..06603d0b
--- /dev/null
+++ b/internal/services/usage.go
@@ -0,0 +1,177 @@
+package services
+
+import (
+	"context"
+	"errors"
+	"time"
+
+	"paperdebugger/internal/libs/cfg"
+	"paperdebugger/internal/libs/db"
+	"paperdebugger/internal/libs/logger"
+	"paperdebugger/internal/models"
+
+	"go.mongodb.org/mongo-driver/v2/bson"
+	"go.mongodb.org/mongo-driver/v2/mongo"
+	"go.mongodb.org/mongo-driver/v2/mongo/options"
+)
+
+// SessionDuration is how long a usage session stays active after it is created.
+const SessionDuration = 5 * time.Hour
+
+// UsageService records per-user LLM token usage in sessions and reports on it.
+type UsageService struct {
+	BaseService
+	sessionCollection *mongo.Collection
+}
+
+// UsageRecord is the token usage of one LLM request, attributed to a user.
+type UsageRecord struct {
+	UserID           bson.ObjectID
+	PromptTokens     int64
+	CompletionTokens int64
+	TotalTokens      int64
+}
+
+// UsageStats is usage summed over a set of sessions.
+type UsageStats struct {
+	PromptTokens     int64 `bson:"prompt_tokens"`
+	CompletionTokens int64 `bson:"completion_tokens"`
+	TotalTokens      int64 `bson:"total_tokens"`
+	RequestCount     int64 `bson:"request_count"`
+	SessionCount     int64 `bson:"session_count"`
+}
+
+// NewUsageService returns a UsageService that stores sessions in the LLMSession collection.
+func NewUsageService(db *db.DB, cfg *cfg.Cfg, logger *logger.Logger) *UsageService {
+	base := NewBaseService(db, cfg, logger)
+	return &UsageService{
+		BaseService:       base,
+		sessionCollection: base.db.Collection((models.LLMSession{}).CollectionName()),
+	}
+}
+
+// RecordUsage adds record to the user's active session, or starts a new session
+// (expiring SessionDuration from now) when the user has none.
+//
+// Both cases are a single upsert, so a write failure is always returned to the
+// caller. Concurrent first requests may still each create a session unless the
+// collection enforces uniqueness; readers must not assume one active session per user.
+func (s *UsageService) RecordUsage(ctx context.Context, record UsageRecord) error {
+	now := time.Now()
+	nowBson := bson.DateTime(now.UnixMilli())
+
+	filter := bson.M{
+		"user_id":        record.UserID,
+		"session_expiry": bson.M{"$gt": nowBson},
+	}
+	update := bson.M{
+		"$inc": bson.M{
+			"prompt_tokens":     record.PromptTokens,
+			"completion_tokens": record.CompletionTokens,
+			"total_tokens":      record.TotalTokens,
+			"request_count":     int64(1),
+		},
+		// Applied only when the upsert inserts; user_id comes from the filter's equality match.
+		"$setOnInsert": bson.M{
+			"_id":            bson.NewObjectID(),
+			"session_start":  nowBson,
+			"session_expiry": bson.DateTime(now.Add(SessionDuration).UnixMilli()),
+		},
+	}
+
+	_, err := s.sessionCollection.UpdateOne(ctx, filter, update, options.UpdateOne().SetUpsert(true))
+	return err
+}
+
+// GetActiveSession returns an unexpired session of the user, or nil with a nil
+// error when the user has none.
+func (s *UsageService) GetActiveSession(ctx context.Context, userID bson.ObjectID) (*models.LLMSession, error) {
+	now := bson.DateTime(time.Now().UnixMilli())
+	filter := bson.M{
+		"user_id":        userID,
+		"session_expiry": bson.M{"$gt": now},
+	}
+
+	var session models.LLMSession
+	err := s.sessionCollection.FindOne(ctx, filter).Decode(&session)
+	if errors.Is(err, mongo.ErrNoDocuments) {
+		return nil, nil
+	}
+	if err != nil {
+		return nil, err
+	}
+	return &session, nil
+}
+
+// GetWeeklyUsage returns aggregated usage for a user for the current week
+// (Monday-Sunday, UTC). A session counts toward the week in which it started.
+func (s *UsageService) GetWeeklyUsage(ctx context.Context, userID bson.ObjectID) (*UsageStats, error) {
+	weekStart := startOfWeek(time.Now())
+	return s.getUsageSince(ctx, userID, weekStart)
+}
+
+// getUsageSince sums the counters of the user's sessions whose session_start is at
+// or after since. It returns zero-valued stats when no session matches.
+func (s *UsageService) getUsageSince(ctx context.Context, userID bson.ObjectID, since time.Time) (*UsageStats, error) {
+	pipeline := bson.A{
+		bson.M{"$match": bson.M{
+			"user_id":       userID,
+			"session_start": bson.M{"$gte": bson.DateTime(since.UnixMilli())},
+		}},
+		bson.M{"$group": bson.M{
+			"_id":               nil,
+			"prompt_tokens":     bson.M{"$sum": "$prompt_tokens"},
+			"completion_tokens": bson.M{"$sum": "$completion_tokens"},
+			"total_tokens":      bson.M{"$sum": "$total_tokens"},
+			"request_count":     bson.M{"$sum": "$request_count"},
+			"session_count":     bson.M{"$sum": 1},
+		}},
+	}
+
+	cursor, err := s.sessionCollection.Aggregate(ctx, pipeline)
+	if err != nil {
+		return nil, err
+	}
+	defer cursor.Close(ctx)
+
+	if !cursor.Next(ctx) {
+		// Next also returns false on failure; only a clean end of results means "no usage".
+		if err := cursor.Err(); err != nil {
+			return nil, err
+		}
+		return &UsageStats{}, nil
+	}
+
+	var result UsageStats
+	if err := cursor.Decode(&result); err != nil {
+		return nil, err
+	}
+	return &result, nil
+}
+
+// startOfWeek returns the start of the week (Monday 00:00:00 UTC).
+func startOfWeek(t time.Time) time.Time {
+	t = t.UTC()
+	daysFromMonday := (int(t.Weekday()) + 6) % 7 // Sunday=6, Monday=0, Tuesday=1, ...
+	return time.Date(t.Year(), t.Month(), t.Day()-daysFromMonday, 0, 0, 0, 0, time.UTC)
+}
+
+// ListRecentSessions returns a user's sessions ordered by session_start, newest first.
+func (s *UsageService) ListRecentSessions(ctx context.Context, userID bson.ObjectID, limit int64) ([]models.LLMSession, error) {
+	filter := bson.M{"user_id": userID}
+	opts := options.Find().
+		SetSort(bson.D{{Key: "session_start", Value: -1}}).
+ SetLimit(limit) + + cursor, err := s.sessionCollection.Find(ctx, filter, opts) + if err != nil { + return nil, err + } + defer cursor.Close(ctx) + + var sessions []models.LLMSession + if err := cursor.All(ctx, &sessions); err != nil { + return nil, err + } + return sessions, nil +} diff --git a/internal/wire.go b/internal/wire.go index f823bc2e..52e6ff28 100644 --- a/internal/wire.go +++ b/internal/wire.go @@ -9,6 +9,7 @@ import ( "paperdebugger/internal/api/chat" "paperdebugger/internal/api/comment" "paperdebugger/internal/api/project" + "paperdebugger/internal/api/usage" "paperdebugger/internal/api/user" "paperdebugger/internal/libs/cfg" "paperdebugger/internal/libs/db" @@ -32,6 +33,7 @@ var Set = wire.NewSet( user.NewUserServer, project.NewProjectServer, comment.NewCommentServer, + usage.NewUsageServer, aiclient.NewAIClient, aiclient.NewAIClientV2, @@ -43,6 +45,7 @@ var Set = wire.NewSet( services.NewProjectService, services.NewPromptService, services.NewOAuthService, + services.NewUsageService, cfg.GetCfg, logger.GetLogger, diff --git a/internal/wire_gen.go b/internal/wire_gen.go index 75c4e91a..a706db0f 100644 --- a/internal/wire_gen.go +++ b/internal/wire_gen.go @@ -13,6 +13,7 @@ import ( "paperdebugger/internal/api/chat" "paperdebugger/internal/api/comment" "paperdebugger/internal/api/project" + "paperdebugger/internal/api/usage" "paperdebugger/internal/api/user" "paperdebugger/internal/libs/cfg" "paperdebugger/internal/libs/db" @@ -38,14 +39,16 @@ func InitializeApp() (*api.Server, error) { aiClient := client.NewAIClient(dbDB, reverseCommentService, projectService, cfgCfg, loggerLogger) chatService := services.NewChatService(dbDB, cfgCfg, loggerLogger) chatServiceServer := chat.NewChatServer(aiClient, chatService, projectService, userService, loggerLogger, cfgCfg) - aiClientV2 := client.NewAIClientV2(dbDB, reverseCommentService, projectService, cfgCfg, loggerLogger) + usageService := services.NewUsageService(dbDB, cfgCfg, loggerLogger) + aiClientV2 := 
client.NewAIClientV2(dbDB, reverseCommentService, projectService, usageService, cfgCfg, loggerLogger) chatServiceV2 := services.NewChatServiceV2(dbDB, cfgCfg, loggerLogger) chatv2ChatServiceServer := chat.NewChatServerV2(aiClientV2, chatServiceV2, projectService, userService, loggerLogger, cfgCfg) promptService := services.NewPromptService(dbDB, cfgCfg, loggerLogger) userServiceServer := user.NewUserServer(userService, promptService, cfgCfg, loggerLogger) projectServiceServer := project.NewProjectServer(projectService, loggerLogger, cfgCfg) commentServiceServer := comment.NewCommentServer(projectService, chatService, reverseCommentService, loggerLogger, cfgCfg) - grpcServer := api.NewGrpcServer(userService, cfgCfg, authServiceServer, chatServiceServer, chatv2ChatServiceServer, userServiceServer, projectServiceServer, commentServiceServer) + usageServiceServer := usage.NewUsageServer(usageService, loggerLogger) + grpcServer := api.NewGrpcServer(userService, cfgCfg, authServiceServer, chatServiceServer, chatv2ChatServiceServer, userServiceServer, projectServiceServer, commentServiceServer, usageServiceServer) oAuthService := services.NewOAuthService(dbDB, cfgCfg, loggerLogger) oAuthHandler := auth.NewOAuthHandler(oAuthService) ginServer := api.NewGinServer(cfgCfg, oAuthHandler) @@ -55,4 +58,4 @@ func InitializeApp() (*api.Server, error) { // wire.go: -var Set = wire.NewSet(api.NewServer, api.NewGrpcServer, api.NewGinServer, auth.NewOAuthHandler, auth.NewAuthServer, chat.NewChatServer, chat.NewChatServerV2, user.NewUserServer, project.NewProjectServer, comment.NewCommentServer, client.NewAIClient, client.NewAIClientV2, services.NewReverseCommentService, services.NewChatService, services.NewChatServiceV2, services.NewTokenService, services.NewUserService, services.NewProjectService, services.NewPromptService, services.NewOAuthService, cfg.GetCfg, logger.GetLogger, db.NewDB) +var Set = wire.NewSet(api.NewServer, api.NewGrpcServer, api.NewGinServer, auth.NewOAuthHandler, 
auth.NewAuthServer, chat.NewChatServer, chat.NewChatServerV2, user.NewUserServer, project.NewProjectServer, comment.NewCommentServer, usage.NewUsageServer, client.NewAIClient, client.NewAIClientV2, services.NewReverseCommentService, services.NewChatService, services.NewChatServiceV2, services.NewTokenService, services.NewUserService, services.NewProjectService, services.NewPromptService, services.NewOAuthService, services.NewUsageService, cfg.GetCfg, logger.GetLogger, db.NewDB) diff --git a/pkg/gen/api/chat/v2/chat.pb.go b/pkg/gen/api/chat/v2/chat.pb.go index 0d312c55..485bfd0f 100644 --- a/pkg/gen/api/chat/v2/chat.pb.go +++ b/pkg/gen/api/chat/v2/chat.pb.go @@ -7,13 +7,12 @@ package chatv2 import ( - reflect "reflect" - sync "sync" - unsafe "unsafe" - _ "google.golang.org/genproto/googleapis/api/annotations" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" ) const ( diff --git a/pkg/gen/api/usage/v1/usage.pb.go b/pkg/gen/api/usage/v1/usage.pb.go new file mode 100644 index 00000000..3e914ee0 --- /dev/null +++ b/pkg/gen/api/usage/v1/usage.pb.go @@ -0,0 +1,446 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc (unknown) +// source: usage/v1/usage.proto + +package usagev1 + +import ( + _ "google.golang.org/genproto/googleapis/api/annotations" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. 
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type SessionUsage struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + SessionStart *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=session_start,json=sessionStart,proto3" json:"session_start,omitempty"` + SessionExpiry *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=session_expiry,json=sessionExpiry,proto3" json:"session_expiry,omitempty"` + PromptTokens int64 `protobuf:"varint,4,opt,name=prompt_tokens,json=promptTokens,proto3" json:"prompt_tokens,omitempty"` + CompletionTokens int64 `protobuf:"varint,5,opt,name=completion_tokens,json=completionTokens,proto3" json:"completion_tokens,omitempty"` + TotalTokens int64 `protobuf:"varint,6,opt,name=total_tokens,json=totalTokens,proto3" json:"total_tokens,omitempty"` + RequestCount int64 `protobuf:"varint,7,opt,name=request_count,json=requestCount,proto3" json:"request_count,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SessionUsage) Reset() { + *x = SessionUsage{} + mi := &file_usage_v1_usage_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SessionUsage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SessionUsage) ProtoMessage() {} + +func (x *SessionUsage) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SessionUsage.ProtoReflect.Descriptor instead. 
+func (*SessionUsage) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{0} +} + +func (x *SessionUsage) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *SessionUsage) GetSessionStart() *timestamppb.Timestamp { + if x != nil { + return x.SessionStart + } + return nil +} + +func (x *SessionUsage) GetSessionExpiry() *timestamppb.Timestamp { + if x != nil { + return x.SessionExpiry + } + return nil +} + +func (x *SessionUsage) GetPromptTokens() int64 { + if x != nil { + return x.PromptTokens + } + return 0 +} + +func (x *SessionUsage) GetCompletionTokens() int64 { + if x != nil { + return x.CompletionTokens + } + return 0 +} + +func (x *SessionUsage) GetTotalTokens() int64 { + if x != nil { + return x.TotalTokens + } + return 0 +} + +func (x *SessionUsage) GetRequestCount() int64 { + if x != nil { + return x.RequestCount + } + return 0 +} + +type WeeklyUsage struct { + state protoimpl.MessageState `protogen:"open.v1"` + PromptTokens int64 `protobuf:"varint,1,opt,name=prompt_tokens,json=promptTokens,proto3" json:"prompt_tokens,omitempty"` + CompletionTokens int64 `protobuf:"varint,2,opt,name=completion_tokens,json=completionTokens,proto3" json:"completion_tokens,omitempty"` + TotalTokens int64 `protobuf:"varint,3,opt,name=total_tokens,json=totalTokens,proto3" json:"total_tokens,omitempty"` + RequestCount int64 `protobuf:"varint,4,opt,name=request_count,json=requestCount,proto3" json:"request_count,omitempty"` + SessionCount int64 `protobuf:"varint,5,opt,name=session_count,json=sessionCount,proto3" json:"session_count,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WeeklyUsage) Reset() { + *x = WeeklyUsage{} + mi := &file_usage_v1_usage_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WeeklyUsage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WeeklyUsage) 
ProtoMessage() {} + +func (x *WeeklyUsage) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WeeklyUsage.ProtoReflect.Descriptor instead. +func (*WeeklyUsage) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{1} +} + +func (x *WeeklyUsage) GetPromptTokens() int64 { + if x != nil { + return x.PromptTokens + } + return 0 +} + +func (x *WeeklyUsage) GetCompletionTokens() int64 { + if x != nil { + return x.CompletionTokens + } + return 0 +} + +func (x *WeeklyUsage) GetTotalTokens() int64 { + if x != nil { + return x.TotalTokens + } + return 0 +} + +func (x *WeeklyUsage) GetRequestCount() int64 { + if x != nil { + return x.RequestCount + } + return 0 +} + +func (x *WeeklyUsage) GetSessionCount() int64 { + if x != nil { + return x.SessionCount + } + return 0 +} + +type GetSessionUsageRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetSessionUsageRequest) Reset() { + *x = GetSessionUsageRequest{} + mi := &file_usage_v1_usage_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetSessionUsageRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetSessionUsageRequest) ProtoMessage() {} + +func (x *GetSessionUsageRequest) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetSessionUsageRequest.ProtoReflect.Descriptor instead. 
+func (*GetSessionUsageRequest) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{2} +} + +type GetSessionUsageResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Active session usage, null if no active session + Session *SessionUsage `protobuf:"bytes,1,opt,name=session,proto3" json:"session,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetSessionUsageResponse) Reset() { + *x = GetSessionUsageResponse{} + mi := &file_usage_v1_usage_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetSessionUsageResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetSessionUsageResponse) ProtoMessage() {} + +func (x *GetSessionUsageResponse) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetSessionUsageResponse.ProtoReflect.Descriptor instead. 
+func (*GetSessionUsageResponse) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{3} +} + +func (x *GetSessionUsageResponse) GetSession() *SessionUsage { + if x != nil { + return x.Session + } + return nil +} + +type GetWeeklyUsageRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetWeeklyUsageRequest) Reset() { + *x = GetWeeklyUsageRequest{} + mi := &file_usage_v1_usage_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetWeeklyUsageRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetWeeklyUsageRequest) ProtoMessage() {} + +func (x *GetWeeklyUsageRequest) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetWeeklyUsageRequest.ProtoReflect.Descriptor instead. 
+func (*GetWeeklyUsageRequest) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{4} +} + +type GetWeeklyUsageResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Usage *WeeklyUsage `protobuf:"bytes,1,opt,name=usage,proto3" json:"usage,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetWeeklyUsageResponse) Reset() { + *x = GetWeeklyUsageResponse{} + mi := &file_usage_v1_usage_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetWeeklyUsageResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetWeeklyUsageResponse) ProtoMessage() {} + +func (x *GetWeeklyUsageResponse) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetWeeklyUsageResponse.ProtoReflect.Descriptor instead. 
+func (*GetWeeklyUsageResponse) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{5} +} + +func (x *GetWeeklyUsageResponse) GetUsage() *WeeklyUsage { + if x != nil { + return x.Usage + } + return nil +} + +var File_usage_v1_usage_proto protoreflect.FileDescriptor + +const file_usage_v1_usage_proto_rawDesc = "" + + "\n" + + "\x14usage/v1/usage.proto\x12\busage.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xbc\x02\n" + + "\fSessionUsage\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12?\n" + + "\rsession_start\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\fsessionStart\x12A\n" + + "\x0esession_expiry\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\rsessionExpiry\x12#\n" + + "\rprompt_tokens\x18\x04 \x01(\x03R\fpromptTokens\x12+\n" + + "\x11completion_tokens\x18\x05 \x01(\x03R\x10completionTokens\x12!\n" + + "\ftotal_tokens\x18\x06 \x01(\x03R\vtotalTokens\x12#\n" + + "\rrequest_count\x18\a \x01(\x03R\frequestCount\"\xcc\x01\n" + + "\vWeeklyUsage\x12#\n" + + "\rprompt_tokens\x18\x01 \x01(\x03R\fpromptTokens\x12+\n" + + "\x11completion_tokens\x18\x02 \x01(\x03R\x10completionTokens\x12!\n" + + "\ftotal_tokens\x18\x03 \x01(\x03R\vtotalTokens\x12#\n" + + "\rrequest_count\x18\x04 \x01(\x03R\frequestCount\x12#\n" + + "\rsession_count\x18\x05 \x01(\x03R\fsessionCount\"\x18\n" + + "\x16GetSessionUsageRequest\"K\n" + + "\x17GetSessionUsageResponse\x120\n" + + "\asession\x18\x01 \x01(\v2\x16.usage.v1.SessionUsageR\asession\"\x17\n" + + "\x15GetWeeklyUsageRequest\"E\n" + + "\x16GetWeeklyUsageResponse\x12+\n" + + "\x05usage\x18\x01 \x01(\v2\x15.usage.v1.WeeklyUsageR\x05usage2\x9a\x02\n" + + "\fUsageService\x12\x85\x01\n" + + "\x0fGetSessionUsage\x12 .usage.v1.GetSessionUsageRequest\x1a!.usage.v1.GetSessionUsageResponse\"-\x82\xd3\xe4\x93\x02'\x12%/_pd/api/v1/users/@self/usage/session\x12\x81\x01\n" + + "\x0eGetWeeklyUsage\x12\x1f.usage.v1.GetWeeklyUsageRequest\x1a 
.usage.v1.GetWeeklyUsageResponse\",\x82\xd3\xe4\x93\x02&\x12$/_pd/api/v1/users/@self/usage/weeklyB\x87\x01\n" + + "\fcom.usage.v1B\n" + + "UsageProtoP\x01Z*paperdebugger/pkg/gen/api/usage/v1;usagev1\xa2\x02\x03UXX\xaa\x02\bUsage.V1\xca\x02\bUsage\\V1\xe2\x02\x14Usage\\V1\\GPBMetadata\xea\x02\tUsage::V1b\x06proto3" + +var ( + file_usage_v1_usage_proto_rawDescOnce sync.Once + file_usage_v1_usage_proto_rawDescData []byte +) + +func file_usage_v1_usage_proto_rawDescGZIP() []byte { + file_usage_v1_usage_proto_rawDescOnce.Do(func() { + file_usage_v1_usage_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_usage_v1_usage_proto_rawDesc), len(file_usage_v1_usage_proto_rawDesc))) + }) + return file_usage_v1_usage_proto_rawDescData +} + +var file_usage_v1_usage_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_usage_v1_usage_proto_goTypes = []any{ + (*SessionUsage)(nil), // 0: usage.v1.SessionUsage + (*WeeklyUsage)(nil), // 1: usage.v1.WeeklyUsage + (*GetSessionUsageRequest)(nil), // 2: usage.v1.GetSessionUsageRequest + (*GetSessionUsageResponse)(nil), // 3: usage.v1.GetSessionUsageResponse + (*GetWeeklyUsageRequest)(nil), // 4: usage.v1.GetWeeklyUsageRequest + (*GetWeeklyUsageResponse)(nil), // 5: usage.v1.GetWeeklyUsageResponse + (*timestamppb.Timestamp)(nil), // 6: google.protobuf.Timestamp +} +var file_usage_v1_usage_proto_depIdxs = []int32{ + 6, // 0: usage.v1.SessionUsage.session_start:type_name -> google.protobuf.Timestamp + 6, // 1: usage.v1.SessionUsage.session_expiry:type_name -> google.protobuf.Timestamp + 0, // 2: usage.v1.GetSessionUsageResponse.session:type_name -> usage.v1.SessionUsage + 1, // 3: usage.v1.GetWeeklyUsageResponse.usage:type_name -> usage.v1.WeeklyUsage + 2, // 4: usage.v1.UsageService.GetSessionUsage:input_type -> usage.v1.GetSessionUsageRequest + 4, // 5: usage.v1.UsageService.GetWeeklyUsage:input_type -> usage.v1.GetWeeklyUsageRequest + 3, // 6: usage.v1.UsageService.GetSessionUsage:output_type -> 
usage.v1.GetSessionUsageResponse + 5, // 7: usage.v1.UsageService.GetWeeklyUsage:output_type -> usage.v1.GetWeeklyUsageResponse + 6, // [6:8] is the sub-list for method output_type + 4, // [4:6] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name +} + +func init() { file_usage_v1_usage_proto_init() } +func file_usage_v1_usage_proto_init() { + if File_usage_v1_usage_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_usage_v1_usage_proto_rawDesc), len(file_usage_v1_usage_proto_rawDesc)), + NumEnums: 0, + NumMessages: 6, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_usage_v1_usage_proto_goTypes, + DependencyIndexes: file_usage_v1_usage_proto_depIdxs, + MessageInfos: file_usage_v1_usage_proto_msgTypes, + }.Build() + File_usage_v1_usage_proto = out.File + file_usage_v1_usage_proto_goTypes = nil + file_usage_v1_usage_proto_depIdxs = nil +} diff --git a/pkg/gen/api/usage/v1/usage.pb.gw.go b/pkg/gen/api/usage/v1/usage.pb.gw.go new file mode 100644 index 00000000..3a455736 --- /dev/null +++ b/pkg/gen/api/usage/v1/usage.pb.gw.go @@ -0,0 +1,211 @@ +// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. +// source: usage/v1/usage.proto + +/* +Package usagev1 is a reverse proxy. + +It translates gRPC into RESTful JSON APIs. 
+*/ +package usagev1 + +import ( + "context" + "errors" + "io" + "net/http" + + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "github.com/grpc-ecosystem/grpc-gateway/v2/utilities" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +// Suppress "imported and not used" errors +var ( + _ codes.Code + _ io.Reader + _ status.Status + _ = errors.New + _ = runtime.String + _ = utilities.NewDoubleArray + _ = metadata.Join +) + +func request_UsageService_GetSessionUsage_0(ctx context.Context, marshaler runtime.Marshaler, client UsageServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetSessionUsageRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.GetSessionUsage(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_UsageService_GetSessionUsage_0(ctx context.Context, marshaler runtime.Marshaler, server UsageServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetSessionUsageRequest + metadata runtime.ServerMetadata + ) + msg, err := server.GetSessionUsage(ctx, &protoReq) + return msg, metadata, err +} + +func request_UsageService_GetWeeklyUsage_0(ctx context.Context, marshaler runtime.Marshaler, client UsageServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetWeeklyUsageRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.GetWeeklyUsage(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, 
metadata, err +} + +func local_request_UsageService_GetWeeklyUsage_0(ctx context.Context, marshaler runtime.Marshaler, server UsageServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetWeeklyUsageRequest + metadata runtime.ServerMetadata + ) + msg, err := server.GetWeeklyUsage(ctx, &protoReq) + return msg, metadata, err +} + +// RegisterUsageServiceHandlerServer registers the http handlers for service UsageService to "mux". +// UnaryRPC :call UsageServiceServer directly. +// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. +// Note that using this registration option will cause many gRPC library features to stop working. Consider using RegisterUsageServiceHandlerFromEndpoint instead. +// GRPC interceptors will not work for this type of registration. To use interceptors, you must use the "runtime.WithMiddlewares" option in the "runtime.NewServeMux" call. +func RegisterUsageServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux, server UsageServiceServer) error { + mux.Handle(http.MethodGet, pattern_UsageService_GetSessionUsage_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/usage.v1.UsageService/GetSessionUsage", runtime.WithHTTPPathPattern("/_pd/api/v1/users/@self/usage/session")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_UsageService_GetSessionUsage_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), 
metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_UsageService_GetSessionUsage_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodGet, pattern_UsageService_GetWeeklyUsage_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/usage.v1.UsageService/GetWeeklyUsage", runtime.WithHTTPPathPattern("/_pd/api/v1/users/@self/usage/weekly")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_UsageService_GetWeeklyUsage_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_UsageService_GetWeeklyUsage_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + + return nil +} + +// RegisterUsageServiceHandlerFromEndpoint is same as RegisterUsageServiceHandler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func RegisterUsageServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.NewClient(endpoint, opts...) 
+ if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Errorf("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Errorf("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + return RegisterUsageServiceHandler(ctx, mux, conn) +} + +// RegisterUsageServiceHandler registers the http handlers for service UsageService to "mux". +// The handlers forward requests to the grpc endpoint over "conn". +func RegisterUsageServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + return RegisterUsageServiceHandlerClient(ctx, mux, NewUsageServiceClient(conn)) +} + +// RegisterUsageServiceHandlerClient registers the http handlers for service UsageService +// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "UsageServiceClient". +// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "UsageServiceClient" +// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in +// "UsageServiceClient" to call the correct interceptors. This client ignores the HTTP middlewares. 
+func RegisterUsageServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux, client UsageServiceClient) error { + mux.Handle(http.MethodGet, pattern_UsageService_GetSessionUsage_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/usage.v1.UsageService/GetSessionUsage", runtime.WithHTTPPathPattern("/_pd/api/v1/users/@self/usage/session")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_UsageService_GetSessionUsage_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_UsageService_GetSessionUsage_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
+ }) + mux.Handle(http.MethodGet, pattern_UsageService_GetWeeklyUsage_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/usage.v1.UsageService/GetWeeklyUsage", runtime.WithHTTPPathPattern("/_pd/api/v1/users/@self/usage/weekly")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_UsageService_GetWeeklyUsage_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_UsageService_GetWeeklyUsage_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + return nil +} + +var ( + pattern_UsageService_GetSessionUsage_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4, 2, 5, 2, 6}, []string{"_pd", "api", "v1", "users", "@self", "usage", "session"}, "")) + pattern_UsageService_GetWeeklyUsage_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4, 2, 5, 2, 6}, []string{"_pd", "api", "v1", "users", "@self", "usage", "weekly"}, "")) +) + +var ( + forward_UsageService_GetSessionUsage_0 = runtime.ForwardResponseMessage + forward_UsageService_GetWeeklyUsage_0 = runtime.ForwardResponseMessage +) diff --git a/pkg/gen/api/usage/v1/usage_grpc.pb.go b/pkg/gen/api/usage/v1/usage_grpc.pb.go new file mode 100644 index 00000000..7d33c1dd --- /dev/null +++ b/pkg/gen/api/usage/v1/usage_grpc.pb.go @@ -0,0 +1,159 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
+// versions: +// - protoc-gen-go-grpc v1.6.1 +// - protoc (unknown) +// source: usage/v1/usage.proto + +package usagev1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + UsageService_GetSessionUsage_FullMethodName = "/usage.v1.UsageService/GetSessionUsage" + UsageService_GetWeeklyUsage_FullMethodName = "/usage.v1.UsageService/GetWeeklyUsage" +) + +// UsageServiceClient is the client API for UsageService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type UsageServiceClient interface { + GetSessionUsage(ctx context.Context, in *GetSessionUsageRequest, opts ...grpc.CallOption) (*GetSessionUsageResponse, error) + GetWeeklyUsage(ctx context.Context, in *GetWeeklyUsageRequest, opts ...grpc.CallOption) (*GetWeeklyUsageResponse, error) +} + +type usageServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewUsageServiceClient(cc grpc.ClientConnInterface) UsageServiceClient { + return &usageServiceClient{cc} +} + +func (c *usageServiceClient) GetSessionUsage(ctx context.Context, in *GetSessionUsageRequest, opts ...grpc.CallOption) (*GetSessionUsageResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetSessionUsageResponse) + err := c.cc.Invoke(ctx, UsageService_GetSessionUsage_FullMethodName, in, out, cOpts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +func (c *usageServiceClient) GetWeeklyUsage(ctx context.Context, in *GetWeeklyUsageRequest, opts ...grpc.CallOption) (*GetWeeklyUsageResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetWeeklyUsageResponse) + err := c.cc.Invoke(ctx, UsageService_GetWeeklyUsage_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// UsageServiceServer is the server API for UsageService service. +// All implementations must embed UnimplementedUsageServiceServer +// for forward compatibility. +type UsageServiceServer interface { + GetSessionUsage(context.Context, *GetSessionUsageRequest) (*GetSessionUsageResponse, error) + GetWeeklyUsage(context.Context, *GetWeeklyUsageRequest) (*GetWeeklyUsageResponse, error) + mustEmbedUnimplementedUsageServiceServer() +} + +// UnimplementedUsageServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedUsageServiceServer struct{} + +func (UnimplementedUsageServiceServer) GetSessionUsage(context.Context, *GetSessionUsageRequest) (*GetSessionUsageResponse, error) { + return nil, status.Error(codes.Unimplemented, "method GetSessionUsage not implemented") +} +func (UnimplementedUsageServiceServer) GetWeeklyUsage(context.Context, *GetWeeklyUsageRequest) (*GetWeeklyUsageResponse, error) { + return nil, status.Error(codes.Unimplemented, "method GetWeeklyUsage not implemented") +} +func (UnimplementedUsageServiceServer) mustEmbedUnimplementedUsageServiceServer() {} +func (UnimplementedUsageServiceServer) testEmbeddedByValue() {} + +// UnsafeUsageServiceServer may be embedded to opt out of forward compatibility for this service. 
+// Use of this interface is not recommended, as added methods to UsageServiceServer will +// result in compilation errors. +type UnsafeUsageServiceServer interface { + mustEmbedUnimplementedUsageServiceServer() +} + +func RegisterUsageServiceServer(s grpc.ServiceRegistrar, srv UsageServiceServer) { + // If the following call panics, it indicates UnimplementedUsageServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&UsageService_ServiceDesc, srv) +} + +func _UsageService_GetSessionUsage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetSessionUsageRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(UsageServiceServer).GetSessionUsage(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: UsageService_GetSessionUsage_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(UsageServiceServer).GetSessionUsage(ctx, req.(*GetSessionUsageRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _UsageService_GetWeeklyUsage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetWeeklyUsageRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(UsageServiceServer).GetWeeklyUsage(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: UsageService_GetWeeklyUsage_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return 
srv.(UsageServiceServer).GetWeeklyUsage(ctx, req.(*GetWeeklyUsageRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// UsageService_ServiceDesc is the grpc.ServiceDesc for UsageService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var UsageService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "usage.v1.UsageService", + HandlerType: (*UsageServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "GetSessionUsage", + Handler: _UsageService_GetSessionUsage_Handler, + }, + { + MethodName: "GetWeeklyUsage", + Handler: _UsageService_GetWeeklyUsage_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "usage/v1/usage.proto", +} diff --git a/proto/usage/v1/usage.proto b/proto/usage/v1/usage.proto new file mode 100644 index 00000000..915e38a7 --- /dev/null +++ b/proto/usage/v1/usage.proto @@ -0,0 +1,49 @@ +syntax = "proto3"; + +package usage.v1; + +import "google/api/annotations.proto"; +import "google/protobuf/timestamp.proto"; + +option go_package = "paperdebugger/pkg/gen/api/usage/v1;usagev1"; + +service UsageService { + rpc GetSessionUsage(GetSessionUsageRequest) returns (GetSessionUsageResponse) { + option (google.api.http) = {get: "/_pd/api/v1/users/@self/usage/session"}; + } + + rpc GetWeeklyUsage(GetWeeklyUsageRequest) returns (GetWeeklyUsageResponse) { + option (google.api.http) = {get: "/_pd/api/v1/users/@self/usage/weekly"}; + } +} + +message SessionUsage { + string id = 1; + google.protobuf.Timestamp session_start = 2; + google.protobuf.Timestamp session_expiry = 3; + int64 prompt_tokens = 4; + int64 completion_tokens = 5; + int64 total_tokens = 6; + int64 request_count = 7; +} + +message WeeklyUsage { + int64 prompt_tokens = 1; + int64 completion_tokens = 2; + int64 total_tokens = 3; + int64 request_count = 4; + int64 session_count = 5; +} + +message GetSessionUsageRequest {} + +message GetSessionUsageResponse { + // Active session usage, 
null if no active session + SessionUsage session = 1; +} + +message GetWeeklyUsageRequest {} + +message GetWeeklyUsageResponse { + WeeklyUsage usage = 1; +} diff --git a/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts b/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts new file mode 100644 index 00000000..86f78383 --- /dev/null +++ b/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts @@ -0,0 +1,186 @@ +// @generated by protoc-gen-es v2.11.0 with parameter "target=ts" +// @generated from file usage/v1/usage.proto (package usage.v1, syntax proto3) +/* eslint-disable */ + +import type { GenFile, GenMessage, GenService } from "@bufbuild/protobuf/codegenv2"; +import { fileDesc, messageDesc, serviceDesc } from "@bufbuild/protobuf/codegenv2"; +import { file_google_api_annotations } from "@buf/googleapis_googleapis.bufbuild_es/google/api/annotations_pb"; +import type { Timestamp } from "@bufbuild/protobuf/wkt"; +import { file_google_protobuf_timestamp } from "@bufbuild/protobuf/wkt"; +import type { Message } from "@bufbuild/protobuf"; + +/** + * Describes the file usage/v1/usage.proto. 
+ */ +export const file_usage_v1_usage: GenFile = /*@__PURE__*/ + fileDesc("ChR1c2FnZS92MS91c2FnZS5wcm90bxIIdXNhZ2UudjEi4AEKDFNlc3Npb25Vc2FnZRIKCgJpZBgBIAEoCRIxCg1zZXNzaW9uX3N0YXJ0GAIgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcBIyCg5zZXNzaW9uX2V4cGlyeRgDIAEoCzIaLmdvb2dsZS5wcm90b2J1Zi5UaW1lc3RhbXASFQoNcHJvbXB0X3Rva2VucxgEIAEoAxIZChFjb21wbGV0aW9uX3Rva2VucxgFIAEoAxIUCgx0b3RhbF90b2tlbnMYBiABKAMSFQoNcmVxdWVzdF9jb3VudBgHIAEoAyKDAQoLV2Vla2x5VXNhZ2USFQoNcHJvbXB0X3Rva2VucxgBIAEoAxIZChFjb21wbGV0aW9uX3Rva2VucxgCIAEoAxIUCgx0b3RhbF90b2tlbnMYAyABKAMSFQoNcmVxdWVzdF9jb3VudBgEIAEoAxIVCg1zZXNzaW9uX2NvdW50GAUgASgDIhgKFkdldFNlc3Npb25Vc2FnZVJlcXVlc3QiQgoXR2V0U2Vzc2lvblVzYWdlUmVzcG9uc2USJwoHc2Vzc2lvbhgBIAEoCzIWLnVzYWdlLnYxLlNlc3Npb25Vc2FnZSIXChVHZXRXZWVrbHlVc2FnZVJlcXVlc3QiPgoWR2V0V2Vla2x5VXNhZ2VSZXNwb25zZRIkCgV1c2FnZRgBIAEoCzIVLnVzYWdlLnYxLldlZWtseVVzYWdlMpoCCgxVc2FnZVNlcnZpY2UShQEKD0dldFNlc3Npb25Vc2FnZRIgLnVzYWdlLnYxLkdldFNlc3Npb25Vc2FnZVJlcXVlc3QaIS51c2FnZS52MS5HZXRTZXNzaW9uVXNhZ2VSZXNwb25zZSItgtPkkwInEiUvX3BkL2FwaS92MS91c2Vycy9Ac2VsZi91c2FnZS9zZXNzaW9uEoEBCg5HZXRXZWVrbHlVc2FnZRIfLnVzYWdlLnYxLkdldFdlZWtseVVzYWdlUmVxdWVzdBogLnVzYWdlLnYxLkdldFdlZWtseVVzYWdlUmVzcG9uc2UiLILT5JMCJhIkL19wZC9hcGkvdjEvdXNlcnMvQHNlbGYvdXNhZ2Uvd2Vla2x5QocBCgxjb20udXNhZ2UudjFCClVzYWdlUHJvdG9QAVoqcGFwZXJkZWJ1Z2dlci9wa2cvZ2VuL2FwaS91c2FnZS92MTt1c2FnZXYxogIDVVhYqgIIVXNhZ2UuVjHKAghVc2FnZVxWMeICFFVzYWdlXFYxXEdQQk1ldGFkYXRh6gIJVXNhZ2U6OlYxYgZwcm90bzM", [file_google_api_annotations, file_google_protobuf_timestamp]); + +/** + * @generated from message usage.v1.SessionUsage + */ +export type SessionUsage = Message<"usage.v1.SessionUsage"> & { + /** + * @generated from field: string id = 1; + */ + id: string; + + /** + * @generated from field: google.protobuf.Timestamp session_start = 2; + */ + sessionStart?: Timestamp; + + /** + * @generated from field: google.protobuf.Timestamp session_expiry = 3; + */ + sessionExpiry?: Timestamp; + + /** + * @generated from field: int64 prompt_tokens = 4; + */ + promptTokens: bigint; + 
+ /** + * @generated from field: int64 completion_tokens = 5; + */ + completionTokens: bigint; + + /** + * @generated from field: int64 total_tokens = 6; + */ + totalTokens: bigint; + + /** + * @generated from field: int64 request_count = 7; + */ + requestCount: bigint; +}; + +/** + * Describes the message usage.v1.SessionUsage. + * Use `create(SessionUsageSchema)` to create a new message. + */ +export const SessionUsageSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 0); + +/** + * @generated from message usage.v1.WeeklyUsage + */ +export type WeeklyUsage = Message<"usage.v1.WeeklyUsage"> & { + /** + * @generated from field: int64 prompt_tokens = 1; + */ + promptTokens: bigint; + + /** + * @generated from field: int64 completion_tokens = 2; + */ + completionTokens: bigint; + + /** + * @generated from field: int64 total_tokens = 3; + */ + totalTokens: bigint; + + /** + * @generated from field: int64 request_count = 4; + */ + requestCount: bigint; + + /** + * @generated from field: int64 session_count = 5; + */ + sessionCount: bigint; +}; + +/** + * Describes the message usage.v1.WeeklyUsage. + * Use `create(WeeklyUsageSchema)` to create a new message. + */ +export const WeeklyUsageSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 1); + +/** + * @generated from message usage.v1.GetSessionUsageRequest + */ +export type GetSessionUsageRequest = Message<"usage.v1.GetSessionUsageRequest"> & { +}; + +/** + * Describes the message usage.v1.GetSessionUsageRequest. + * Use `create(GetSessionUsageRequestSchema)` to create a new message. 
+ */ +export const GetSessionUsageRequestSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 2); + +/** + * @generated from message usage.v1.GetSessionUsageResponse + */ +export type GetSessionUsageResponse = Message<"usage.v1.GetSessionUsageResponse"> & { + /** + * Active session usage, null if no active session + * + * @generated from field: usage.v1.SessionUsage session = 1; + */ + session?: SessionUsage; +}; + +/** + * Describes the message usage.v1.GetSessionUsageResponse. + * Use `create(GetSessionUsageResponseSchema)` to create a new message. + */ +export const GetSessionUsageResponseSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 3); + +/** + * @generated from message usage.v1.GetWeeklyUsageRequest + */ +export type GetWeeklyUsageRequest = Message<"usage.v1.GetWeeklyUsageRequest"> & { +}; + +/** + * Describes the message usage.v1.GetWeeklyUsageRequest. + * Use `create(GetWeeklyUsageRequestSchema)` to create a new message. + */ +export const GetWeeklyUsageRequestSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 4); + +/** + * @generated from message usage.v1.GetWeeklyUsageResponse + */ +export type GetWeeklyUsageResponse = Message<"usage.v1.GetWeeklyUsageResponse"> & { + /** + * @generated from field: usage.v1.WeeklyUsage usage = 1; + */ + usage?: WeeklyUsage; +}; + +/** + * Describes the message usage.v1.GetWeeklyUsageResponse. + * Use `create(GetWeeklyUsageResponseSchema)` to create a new message. 
+ */ +export const GetWeeklyUsageResponseSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 5); + +/** + * @generated from service usage.v1.UsageService + */ +export const UsageService: GenService<{ + /** + * @generated from rpc usage.v1.UsageService.GetSessionUsage + */ + getSessionUsage: { + methodKind: "unary"; + input: typeof GetSessionUsageRequestSchema; + output: typeof GetSessionUsageResponseSchema; + }, + /** + * @generated from rpc usage.v1.UsageService.GetWeeklyUsage + */ + getWeeklyUsage: { + methodKind: "unary"; + input: typeof GetWeeklyUsageRequestSchema; + output: typeof GetWeeklyUsageResponseSchema; + }, +}> = /*@__PURE__*/ + serviceDesc(file_usage_v1_usage, 0); + From 3a73435dd466dfc9492a69b59774ee4be1b6da5a Mon Sep 17 00:00:00 2001 From: wjiayis Date: Sat, 21 Feb 2026 15:01:33 +0800 Subject: [PATCH 02/13] feat: LLM usage tracking, frontend --- webapp/_webapp/src/paperdebugger.tsx | 8 ++ webapp/_webapp/src/query/api.ts | 26 +++++++ webapp/_webapp/src/query/index.ts | 27 +++++++ webapp/_webapp/src/query/keys.ts | 4 + webapp/_webapp/src/views/usage/index.tsx | 98 ++++++++++++++++++++++++ 5 files changed, 163 insertions(+) create mode 100644 webapp/_webapp/src/views/usage/index.tsx diff --git a/webapp/_webapp/src/paperdebugger.tsx b/webapp/_webapp/src/paperdebugger.tsx index 5cdc5e5d..172a897e 100644 --- a/webapp/_webapp/src/paperdebugger.tsx +++ b/webapp/_webapp/src/paperdebugger.tsx @@ -2,6 +2,7 @@ import { Chat } from "./views/chat"; import { Tabs } from "./components/tabs"; import { Settings } from "./views/settings"; import { Prompts } from "./views/prompts"; +import { Usage } from "./views/usage"; import { PdAppBodyContainer } from "./components/pd-app-body-container"; export const PaperDebugger = () => { @@ -23,6 +24,13 @@ export const PaperDebugger = () => { children: , tooltip: "Prompt Library", }, + { + key: "usage", + title: "Usage", + icon: "tabler:chart-bar", + children: , + tooltip: "Usage Statistics", + }, { key: 
"settings", title: "Settings", diff --git a/webapp/_webapp/src/query/api.ts b/webapp/_webapp/src/query/api.ts index 4098a018..3ae67e4b 100644 --- a/webapp/_webapp/src/query/api.ts +++ b/webapp/_webapp/src/query/api.ts @@ -224,3 +224,29 @@ export const acceptComments = async (data: PlainMessage const response = await apiclient.post(`/comments/accepted`, data); return fromJson(CommentsAcceptedResponseSchema, response); }; + +// Usage +import { + GetSessionUsageResponseSchema, + GetWeeklyUsageResponseSchema, +} from "../pkg/gen/apiclient/usage/v1/usage_pb"; + +export const getSessionUsage = async () => { + if (!apiclient.hasToken()) { + throw new Error("No token"); + } + const response = await apiclient.get("/users/@self/usage/session", undefined, { + ignoreErrorToast: true, + }); + return fromJson(GetSessionUsageResponseSchema, response); +}; + +export const getWeeklyUsage = async () => { + if (!apiclient.hasToken()) { + throw new Error("No token"); + } + const response = await apiclient.get("/users/@self/usage/weekly", undefined, { + ignoreErrorToast: true, + }); + return fromJson(GetWeeklyUsageResponseSchema, response); +}; diff --git a/webapp/_webapp/src/query/index.ts b/webapp/_webapp/src/query/index.ts index 2c05d959..4c9ea5cc 100644 --- a/webapp/_webapp/src/query/index.ts +++ b/webapp/_webapp/src/query/index.ts @@ -22,6 +22,8 @@ import { upsertUserInstructions, getProjectInstructions, upsertProjectInstructions, + getSessionUsage, + getWeeklyUsage, } from "./api"; import { CreatePromptResponse, @@ -37,6 +39,10 @@ import { GetProjectInstructionsResponse, UpsertProjectInstructionsResponse, } from "../pkg/gen/apiclient/project/v1/project_pb"; +import { + GetSessionUsageResponse, + GetWeeklyUsageResponse, +} from "../pkg/gen/apiclient/usage/v1/usage_pb"; import { useAuthStore } from "../stores/auth-store"; export const useGetProjectQuery = (projectId: string, opts?: UseQueryOptionsOverride) => { @@ -166,3 +172,24 @@ export const useUpsertProjectInstructionsMutation 
= ( ...opts, }); }; + +// Usage +export const useGetSessionUsageQuery = (opts?: UseQueryOptionsOverride) => { + const { user } = useAuthStore(); + return useQuery({ + queryKey: queryKeys.usage.getSessionUsage().queryKey, + queryFn: () => getSessionUsage(), + enabled: !!user, + ...opts, + }); +}; + +export const useGetWeeklyUsageQuery = (opts?: UseQueryOptionsOverride) => { + const { user } = useAuthStore(); + return useQuery({ + queryKey: queryKeys.usage.getWeeklyUsage().queryKey, + queryFn: () => getWeeklyUsage(), + enabled: !!user, + ...opts, + }); +}; diff --git a/webapp/_webapp/src/query/keys.ts b/webapp/_webapp/src/query/keys.ts index e09bfd7e..dfa3fc34 100644 --- a/webapp/_webapp/src/query/keys.ts +++ b/webapp/_webapp/src/query/keys.ts @@ -5,6 +5,10 @@ export const queryKeys = createQueryKeyStore({ getUser: () => ["users", "@self"], getUserInstructions: () => ["users", "@self", "instructions"], }, + usage: { + getSessionUsage: () => ["users", "@self", "usage", "session"], + getWeeklyUsage: () => ["users", "@self", "usage", "weekly"], + }, prompts: { listPrompts: () => ["users", "@self", "prompts"], }, diff --git a/webapp/_webapp/src/views/usage/index.tsx b/webapp/_webapp/src/views/usage/index.tsx new file mode 100644 index 00000000..98ad56a1 --- /dev/null +++ b/webapp/_webapp/src/views/usage/index.tsx @@ -0,0 +1,98 @@ +import { Spinner } from "@heroui/react"; +import { TabHeader } from "../../components/tab-header"; +import { useGetSessionUsageQuery, useGetWeeklyUsageQuery } from "../../query"; +import CellWrapper from "../../components/cell-wrapper"; + +const formatNumber = (n: bigint | number | undefined): string => { + if (n === undefined) return "0"; + return Number(n).toLocaleString(); +}; + +const formatDate = (timestamp: { seconds?: bigint; nanos?: number } | undefined): string => { + if (!timestamp || !timestamp.seconds) return "N/A"; + const date = new Date(Number(timestamp.seconds) * 1000); + return date.toLocaleString(); +}; + +const 
SectionContainer = ({ children }: { children: React.ReactNode }) => { + return
{children}
; +}; + +const SectionTitle = ({ children }: { children: React.ReactNode }) => { + return
{children}
; +}; + +const StatItem = ({ label, value }: { label: string; value: string }) => { + return ( +
+ {label} + {value} +
+ ); +}; + +export const Usage = () => { + const { data: sessionData, isLoading: sessionLoading } = useGetSessionUsageQuery(); + const { data: weeklyData, isLoading: weeklyLoading } = useGetWeeklyUsageQuery(); + + const isLoading = sessionLoading || weeklyLoading; + + if (isLoading) { + return ( +
+ +
+ ); + } + + const session = sessionData?.session; + const weekly = weeklyData?.usage; + + return ( +
+ +
+ + Current Session + {session ? ( + +
+ + +
+ + + + +
+ + ) : ( + +
No active session
+
+ )} + + + + Weekly Summary + {weekly ? ( + +
+ + + +
+ + +
+ + ) : ( + +
No usage data available
+
+ )} + +
+
+ ); +}; From c882e6eb1456e7d2f865c316dd0d9017a3179b21 Mon Sep 17 00:00:00 2001 From: wjiayis Date: Sat, 21 Feb 2026 15:14:53 +0800 Subject: [PATCH 03/13] feat: simplify frontend display --- internal/api/usage/get_session_usage.go | 9 +- internal/api/usage/get_weekly_usage.go | 6 +- pkg/gen/api/usage/v1/usage.pb.go | 136 ++++-------------- proto/usage/v1/usage.proto | 15 +- .../pkg/gen/apiclient/usage/v1/usage_pb.ts | 53 +------ webapp/_webapp/src/views/usage/index.tsx | 44 +++--- 6 files changed, 62 insertions(+), 201 deletions(-) diff --git a/internal/api/usage/get_session_usage.go b/internal/api/usage/get_session_usage.go index d59c3b08..06a28718 100644 --- a/internal/api/usage/get_session_usage.go +++ b/internal/api/usage/get_session_usage.go @@ -31,13 +31,8 @@ func (s *UsageServer) GetSessionUsage( return &usagev1.GetSessionUsageResponse{ Session: &usagev1.SessionUsage{ - Id: session.ID.Hex(), - SessionStart: timestamppb.New(session.SessionStart.Time()), - SessionExpiry: timestamppb.New(session.SessionExpiry.Time()), - PromptTokens: session.PromptTokens, - CompletionTokens: session.CompletionTokens, - TotalTokens: session.TotalTokens, - RequestCount: session.RequestCount, + SessionExpiry: timestamppb.New(session.SessionExpiry.Time()), + TotalTokens: session.TotalTokens, }, }, nil } diff --git a/internal/api/usage/get_weekly_usage.go b/internal/api/usage/get_weekly_usage.go index e3c168f1..f87cad60 100644 --- a/internal/api/usage/get_weekly_usage.go +++ b/internal/api/usage/get_weekly_usage.go @@ -23,11 +23,7 @@ func (s *UsageServer) GetWeeklyUsage( return &usagev1.GetWeeklyUsageResponse{ Usage: &usagev1.WeeklyUsage{ - PromptTokens: stats.PromptTokens, - CompletionTokens: stats.CompletionTokens, - TotalTokens: stats.TotalTokens, - RequestCount: stats.RequestCount, - SessionCount: stats.SessionCount, + TotalTokens: stats.TotalTokens, }, }, nil } diff --git a/pkg/gen/api/usage/v1/usage.pb.go b/pkg/gen/api/usage/v1/usage.pb.go index 3e914ee0..1fcf6299 100644 --- 
a/pkg/gen/api/usage/v1/usage.pb.go +++ b/pkg/gen/api/usage/v1/usage.pb.go @@ -24,16 +24,11 @@ const ( ) type SessionUsage struct { - state protoimpl.MessageState `protogen:"open.v1"` - Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` - SessionStart *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=session_start,json=sessionStart,proto3" json:"session_start,omitempty"` - SessionExpiry *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=session_expiry,json=sessionExpiry,proto3" json:"session_expiry,omitempty"` - PromptTokens int64 `protobuf:"varint,4,opt,name=prompt_tokens,json=promptTokens,proto3" json:"prompt_tokens,omitempty"` - CompletionTokens int64 `protobuf:"varint,5,opt,name=completion_tokens,json=completionTokens,proto3" json:"completion_tokens,omitempty"` - TotalTokens int64 `protobuf:"varint,6,opt,name=total_tokens,json=totalTokens,proto3" json:"total_tokens,omitempty"` - RequestCount int64 `protobuf:"varint,7,opt,name=request_count,json=requestCount,proto3" json:"request_count,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + SessionExpiry *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=session_expiry,json=sessionExpiry,proto3" json:"session_expiry,omitempty"` + TotalTokens int64 `protobuf:"varint,2,opt,name=total_tokens,json=totalTokens,proto3" json:"total_tokens,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *SessionUsage) Reset() { @@ -66,20 +61,6 @@ func (*SessionUsage) Descriptor() ([]byte, []int) { return file_usage_v1_usage_proto_rawDescGZIP(), []int{0} } -func (x *SessionUsage) GetId() string { - if x != nil { - return x.Id - } - return "" -} - -func (x *SessionUsage) GetSessionStart() *timestamppb.Timestamp { - if x != nil { - return x.SessionStart - } - return nil -} - func (x *SessionUsage) GetSessionExpiry() *timestamppb.Timestamp { if x != nil { return x.SessionExpiry @@ -87,20 +68,6 @@ 
func (x *SessionUsage) GetSessionExpiry() *timestamppb.Timestamp { return nil } -func (x *SessionUsage) GetPromptTokens() int64 { - if x != nil { - return x.PromptTokens - } - return 0 -} - -func (x *SessionUsage) GetCompletionTokens() int64 { - if x != nil { - return x.CompletionTokens - } - return 0 -} - func (x *SessionUsage) GetTotalTokens() int64 { if x != nil { return x.TotalTokens @@ -108,22 +75,11 @@ func (x *SessionUsage) GetTotalTokens() int64 { return 0 } -func (x *SessionUsage) GetRequestCount() int64 { - if x != nil { - return x.RequestCount - } - return 0 -} - type WeeklyUsage struct { - state protoimpl.MessageState `protogen:"open.v1"` - PromptTokens int64 `protobuf:"varint,1,opt,name=prompt_tokens,json=promptTokens,proto3" json:"prompt_tokens,omitempty"` - CompletionTokens int64 `protobuf:"varint,2,opt,name=completion_tokens,json=completionTokens,proto3" json:"completion_tokens,omitempty"` - TotalTokens int64 `protobuf:"varint,3,opt,name=total_tokens,json=totalTokens,proto3" json:"total_tokens,omitempty"` - RequestCount int64 `protobuf:"varint,4,opt,name=request_count,json=requestCount,proto3" json:"request_count,omitempty"` - SessionCount int64 `protobuf:"varint,5,opt,name=session_count,json=sessionCount,proto3" json:"session_count,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + TotalTokens int64 `protobuf:"varint,1,opt,name=total_tokens,json=totalTokens,proto3" json:"total_tokens,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *WeeklyUsage) Reset() { @@ -156,20 +112,6 @@ func (*WeeklyUsage) Descriptor() ([]byte, []int) { return file_usage_v1_usage_proto_rawDescGZIP(), []int{1} } -func (x *WeeklyUsage) GetPromptTokens() int64 { - if x != nil { - return x.PromptTokens - } - return 0 -} - -func (x *WeeklyUsage) GetCompletionTokens() int64 { - if x != nil { - return x.CompletionTokens - } - return 0 -} - func (x 
*WeeklyUsage) GetTotalTokens() int64 { if x != nil { return x.TotalTokens @@ -177,20 +119,6 @@ func (x *WeeklyUsage) GetTotalTokens() int64 { return 0 } -func (x *WeeklyUsage) GetRequestCount() int64 { - if x != nil { - return x.RequestCount - } - return 0 -} - -func (x *WeeklyUsage) GetSessionCount() int64 { - if x != nil { - return x.SessionCount - } - return 0 -} - type GetSessionUsageRequest struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -356,21 +284,12 @@ var File_usage_v1_usage_proto protoreflect.FileDescriptor const file_usage_v1_usage_proto_rawDesc = "" + "\n" + - "\x14usage/v1/usage.proto\x12\busage.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xbc\x02\n" + - "\fSessionUsage\x12\x0e\n" + - "\x02id\x18\x01 \x01(\tR\x02id\x12?\n" + - "\rsession_start\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\fsessionStart\x12A\n" + - "\x0esession_expiry\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\rsessionExpiry\x12#\n" + - "\rprompt_tokens\x18\x04 \x01(\x03R\fpromptTokens\x12+\n" + - "\x11completion_tokens\x18\x05 \x01(\x03R\x10completionTokens\x12!\n" + - "\ftotal_tokens\x18\x06 \x01(\x03R\vtotalTokens\x12#\n" + - "\rrequest_count\x18\a \x01(\x03R\frequestCount\"\xcc\x01\n" + - "\vWeeklyUsage\x12#\n" + - "\rprompt_tokens\x18\x01 \x01(\x03R\fpromptTokens\x12+\n" + - "\x11completion_tokens\x18\x02 \x01(\x03R\x10completionTokens\x12!\n" + - "\ftotal_tokens\x18\x03 \x01(\x03R\vtotalTokens\x12#\n" + - "\rrequest_count\x18\x04 \x01(\x03R\frequestCount\x12#\n" + - "\rsession_count\x18\x05 \x01(\x03R\fsessionCount\"\x18\n" + + "\x14usage/v1/usage.proto\x12\busage.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"t\n" + + "\fSessionUsage\x12A\n" + + "\x0esession_expiry\x18\x01 \x01(\v2\x1a.google.protobuf.TimestampR\rsessionExpiry\x12!\n" + + "\ftotal_tokens\x18\x02 \x01(\x03R\vtotalTokens\"0\n" + + "\vWeeklyUsage\x12!\n" + + "\ftotal_tokens\x18\x01 
\x01(\x03R\vtotalTokens\"\x18\n" + "\x16GetSessionUsageRequest\"K\n" + "\x17GetSessionUsageResponse\x120\n" + "\asession\x18\x01 \x01(\v2\x16.usage.v1.SessionUsageR\asession\"\x17\n" + @@ -406,19 +325,18 @@ var file_usage_v1_usage_proto_goTypes = []any{ (*timestamppb.Timestamp)(nil), // 6: google.protobuf.Timestamp } var file_usage_v1_usage_proto_depIdxs = []int32{ - 6, // 0: usage.v1.SessionUsage.session_start:type_name -> google.protobuf.Timestamp - 6, // 1: usage.v1.SessionUsage.session_expiry:type_name -> google.protobuf.Timestamp - 0, // 2: usage.v1.GetSessionUsageResponse.session:type_name -> usage.v1.SessionUsage - 1, // 3: usage.v1.GetWeeklyUsageResponse.usage:type_name -> usage.v1.WeeklyUsage - 2, // 4: usage.v1.UsageService.GetSessionUsage:input_type -> usage.v1.GetSessionUsageRequest - 4, // 5: usage.v1.UsageService.GetWeeklyUsage:input_type -> usage.v1.GetWeeklyUsageRequest - 3, // 6: usage.v1.UsageService.GetSessionUsage:output_type -> usage.v1.GetSessionUsageResponse - 5, // 7: usage.v1.UsageService.GetWeeklyUsage:output_type -> usage.v1.GetWeeklyUsageResponse - 6, // [6:8] is the sub-list for method output_type - 4, // [4:6] is the sub-list for method input_type - 4, // [4:4] is the sub-list for extension type_name - 4, // [4:4] is the sub-list for extension extendee - 0, // [0:4] is the sub-list for field type_name + 6, // 0: usage.v1.SessionUsage.session_expiry:type_name -> google.protobuf.Timestamp + 0, // 1: usage.v1.GetSessionUsageResponse.session:type_name -> usage.v1.SessionUsage + 1, // 2: usage.v1.GetWeeklyUsageResponse.usage:type_name -> usage.v1.WeeklyUsage + 2, // 3: usage.v1.UsageService.GetSessionUsage:input_type -> usage.v1.GetSessionUsageRequest + 4, // 4: usage.v1.UsageService.GetWeeklyUsage:input_type -> usage.v1.GetWeeklyUsageRequest + 3, // 5: usage.v1.UsageService.GetSessionUsage:output_type -> usage.v1.GetSessionUsageResponse + 5, // 6: usage.v1.UsageService.GetWeeklyUsage:output_type -> usage.v1.GetWeeklyUsageResponse + 5, // 
[5:7] is the sub-list for method output_type + 3, // [3:5] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name } func init() { file_usage_v1_usage_proto_init() } diff --git a/proto/usage/v1/usage.proto b/proto/usage/v1/usage.proto index 915e38a7..d9141dd0 100644 --- a/proto/usage/v1/usage.proto +++ b/proto/usage/v1/usage.proto @@ -18,21 +18,12 @@ service UsageService { } message SessionUsage { - string id = 1; - google.protobuf.Timestamp session_start = 2; - google.protobuf.Timestamp session_expiry = 3; - int64 prompt_tokens = 4; - int64 completion_tokens = 5; - int64 total_tokens = 6; - int64 request_count = 7; + google.protobuf.Timestamp session_expiry = 1; + int64 total_tokens = 2; } message WeeklyUsage { - int64 prompt_tokens = 1; - int64 completion_tokens = 2; - int64 total_tokens = 3; - int64 request_count = 4; - int64 session_count = 5; + int64 total_tokens = 1; } message GetSessionUsageRequest {} diff --git a/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts b/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts index 86f78383..35ec21ae 100644 --- a/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts +++ b/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts @@ -13,46 +13,21 @@ import type { Message } from "@bufbuild/protobuf"; * Describes the file usage/v1/usage.proto. 
*/ export const file_usage_v1_usage: GenFile = /*@__PURE__*/ - fileDesc("ChR1c2FnZS92MS91c2FnZS5wcm90bxIIdXNhZ2UudjEi4AEKDFNlc3Npb25Vc2FnZRIKCgJpZBgBIAEoCRIxCg1zZXNzaW9uX3N0YXJ0GAIgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcBIyCg5zZXNzaW9uX2V4cGlyeRgDIAEoCzIaLmdvb2dsZS5wcm90b2J1Zi5UaW1lc3RhbXASFQoNcHJvbXB0X3Rva2VucxgEIAEoAxIZChFjb21wbGV0aW9uX3Rva2VucxgFIAEoAxIUCgx0b3RhbF90b2tlbnMYBiABKAMSFQoNcmVxdWVzdF9jb3VudBgHIAEoAyKDAQoLV2Vla2x5VXNhZ2USFQoNcHJvbXB0X3Rva2VucxgBIAEoAxIZChFjb21wbGV0aW9uX3Rva2VucxgCIAEoAxIUCgx0b3RhbF90b2tlbnMYAyABKAMSFQoNcmVxdWVzdF9jb3VudBgEIAEoAxIVCg1zZXNzaW9uX2NvdW50GAUgASgDIhgKFkdldFNlc3Npb25Vc2FnZVJlcXVlc3QiQgoXR2V0U2Vzc2lvblVzYWdlUmVzcG9uc2USJwoHc2Vzc2lvbhgBIAEoCzIWLnVzYWdlLnYxLlNlc3Npb25Vc2FnZSIXChVHZXRXZWVrbHlVc2FnZVJlcXVlc3QiPgoWR2V0V2Vla2x5VXNhZ2VSZXNwb25zZRIkCgV1c2FnZRgBIAEoCzIVLnVzYWdlLnYxLldlZWtseVVzYWdlMpoCCgxVc2FnZVNlcnZpY2UShQEKD0dldFNlc3Npb25Vc2FnZRIgLnVzYWdlLnYxLkdldFNlc3Npb25Vc2FnZVJlcXVlc3QaIS51c2FnZS52MS5HZXRTZXNzaW9uVXNhZ2VSZXNwb25zZSItgtPkkwInEiUvX3BkL2FwaS92MS91c2Vycy9Ac2VsZi91c2FnZS9zZXNzaW9uEoEBCg5HZXRXZWVrbHlVc2FnZRIfLnVzYWdlLnYxLkdldFdlZWtseVVzYWdlUmVxdWVzdBogLnVzYWdlLnYxLkdldFdlZWtseVVzYWdlUmVzcG9uc2UiLILT5JMCJhIkL19wZC9hcGkvdjEvdXNlcnMvQHNlbGYvdXNhZ2Uvd2Vla2x5QocBCgxjb20udXNhZ2UudjFCClVzYWdlUHJvdG9QAVoqcGFwZXJkZWJ1Z2dlci9wa2cvZ2VuL2FwaS91c2FnZS92MTt1c2FnZXYxogIDVVhYqgIIVXNhZ2UuVjHKAghVc2FnZVxWMeICFFVzYWdlXFYxXEdQQk1ldGFkYXRh6gIJVXNhZ2U6OlYxYgZwcm90bzM", [file_google_api_annotations, file_google_protobuf_timestamp]); + 
fileDesc("ChR1c2FnZS92MS91c2FnZS5wcm90bxIIdXNhZ2UudjEiWAoMU2Vzc2lvblVzYWdlEjIKDnNlc3Npb25fZXhwaXJ5GAEgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcBIUCgx0b3RhbF90b2tlbnMYAiABKAMiIwoLV2Vla2x5VXNhZ2USFAoMdG90YWxfdG9rZW5zGAEgASgDIhgKFkdldFNlc3Npb25Vc2FnZVJlcXVlc3QiQgoXR2V0U2Vzc2lvblVzYWdlUmVzcG9uc2USJwoHc2Vzc2lvbhgBIAEoCzIWLnVzYWdlLnYxLlNlc3Npb25Vc2FnZSIXChVHZXRXZWVrbHlVc2FnZVJlcXVlc3QiPgoWR2V0V2Vla2x5VXNhZ2VSZXNwb25zZRIkCgV1c2FnZRgBIAEoCzIVLnVzYWdlLnYxLldlZWtseVVzYWdlMpoCCgxVc2FnZVNlcnZpY2UShQEKD0dldFNlc3Npb25Vc2FnZRIgLnVzYWdlLnYxLkdldFNlc3Npb25Vc2FnZVJlcXVlc3QaIS51c2FnZS52MS5HZXRTZXNzaW9uVXNhZ2VSZXNwb25zZSItgtPkkwInEiUvX3BkL2FwaS92MS91c2Vycy9Ac2VsZi91c2FnZS9zZXNzaW9uEoEBCg5HZXRXZWVrbHlVc2FnZRIfLnVzYWdlLnYxLkdldFdlZWtseVVzYWdlUmVxdWVzdBogLnVzYWdlLnYxLkdldFdlZWtseVVzYWdlUmVzcG9uc2UiLILT5JMCJhIkL19wZC9hcGkvdjEvdXNlcnMvQHNlbGYvdXNhZ2Uvd2Vla2x5QocBCgxjb20udXNhZ2UudjFCClVzYWdlUHJvdG9QAVoqcGFwZXJkZWJ1Z2dlci9wa2cvZ2VuL2FwaS91c2FnZS92MTt1c2FnZXYxogIDVVhYqgIIVXNhZ2UuVjHKAghVc2FnZVxWMeICFFVzYWdlXFYxXEdQQk1ldGFkYXRh6gIJVXNhZ2U6OlYxYgZwcm90bzM", [file_google_api_annotations, file_google_protobuf_timestamp]); /** * @generated from message usage.v1.SessionUsage */ export type SessionUsage = Message<"usage.v1.SessionUsage"> & { /** - * @generated from field: string id = 1; - */ - id: string; - - /** - * @generated from field: google.protobuf.Timestamp session_start = 2; - */ - sessionStart?: Timestamp; - - /** - * @generated from field: google.protobuf.Timestamp session_expiry = 3; + * @generated from field: google.protobuf.Timestamp session_expiry = 1; */ sessionExpiry?: Timestamp; /** - * @generated from field: int64 prompt_tokens = 4; - */ - promptTokens: bigint; - - /** - * @generated from field: int64 completion_tokens = 5; - */ - completionTokens: bigint; - - /** - * @generated from field: int64 total_tokens = 6; + * @generated from field: int64 total_tokens = 2; */ totalTokens: bigint; - - /** - * @generated from field: int64 request_count = 7; - */ - requestCount: bigint; }; 
/** @@ -67,29 +42,9 @@ export const SessionUsageSchema: GenMessage = /*@__PURE__*/ */ export type WeeklyUsage = Message<"usage.v1.WeeklyUsage"> & { /** - * @generated from field: int64 prompt_tokens = 1; - */ - promptTokens: bigint; - - /** - * @generated from field: int64 completion_tokens = 2; - */ - completionTokens: bigint; - - /** - * @generated from field: int64 total_tokens = 3; + * @generated from field: int64 total_tokens = 1; */ totalTokens: bigint; - - /** - * @generated from field: int64 request_count = 4; - */ - requestCount: bigint; - - /** - * @generated from field: int64 session_count = 5; - */ - sessionCount: bigint; }; /** diff --git a/webapp/_webapp/src/views/usage/index.tsx b/webapp/_webapp/src/views/usage/index.tsx index 98ad56a1..26dd15dc 100644 --- a/webapp/_webapp/src/views/usage/index.tsx +++ b/webapp/_webapp/src/views/usage/index.tsx @@ -8,10 +8,22 @@ const formatNumber = (n: bigint | number | undefined): string => { return Number(n).toLocaleString(); }; -const formatDate = (timestamp: { seconds?: bigint; nanos?: number } | undefined): string => { - if (!timestamp || !timestamp.seconds) return "N/A"; - const date = new Date(Number(timestamp.seconds) * 1000); - return date.toLocaleString(); +const formatTimeRemaining = (timestamp: { seconds?: bigint; nanos?: number } | undefined): string => { + if (!timestamp || !timestamp.seconds) return ""; + const expiryMs = Number(timestamp.seconds) * 1000; + const nowMs = Date.now(); + const diffMs = expiryMs - nowMs; + + if (diffMs <= 0) return ""; + + const totalMinutes = Math.floor(diffMs / 60000); + const hours = Math.floor(totalMinutes / 60); + const minutes = totalMinutes % 60; + + if (hours > 0) { + return `resets in ${hours} hr ${minutes} min`; + } + return `resets in ${minutes} min`; }; const SectionContainer = ({ children }: { children: React.ReactNode }) => { @@ -53,17 +65,16 @@ export const Usage = () => {
- Current Session + + Current Session + {session?.sessionExpiry && ( + ({formatTimeRemaining(session.sessionExpiry)}) + )} + {session ? (
- - -
- - - - +
) : ( @@ -74,16 +85,11 @@ export const Usage = () => { - Weekly Summary + Weekly Limits {weekly ? (
- - - -
- - +
) : ( From baeeaa787decd6432fcb7984fb41208a3ab5d5c6 Mon Sep 17 00:00:00 2001 From: wjiayis Date: Sat, 21 Feb 2026 15:22:58 +0800 Subject: [PATCH 04/13] feat: refresh handling --- webapp/_webapp/src/views/usage/index.tsx | 63 ++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 3 deletions(-) diff --git a/webapp/_webapp/src/views/usage/index.tsx b/webapp/_webapp/src/views/usage/index.tsx index 26dd15dc..36756be7 100644 --- a/webapp/_webapp/src/views/usage/index.tsx +++ b/webapp/_webapp/src/views/usage/index.tsx @@ -1,4 +1,6 @@ -import { Spinner } from "@heroui/react"; +import { Spinner, Button } from "@heroui/react"; +import { Icon } from "@iconify/react"; +import { useState, useEffect } from "react"; import { TabHeader } from "../../components/tab-header"; import { useGetSessionUsageQuery, useGetWeeklyUsageQuery } from "../../query"; import CellWrapper from "../../components/cell-wrapper"; @@ -26,6 +28,20 @@ const formatTimeRemaining = (timestamp: { seconds?: bigint; nanos?: number } | u return `resets in ${minutes} min`; }; +const formatLastUpdated = (timestamp: number): string => { + const diffMs = Date.now() - timestamp; + const seconds = Math.floor(diffMs / 1000); + const minutes = Math.floor(seconds / 60); + const hours = Math.floor(minutes / 60); + + if (seconds < 10) return "just now"; + if (seconds < 60) return `${seconds} seconds ago`; + if (minutes === 1) return "1 minute ago"; + if (minutes < 60) return `${minutes} minutes ago`; + if (hours === 1) return "1 hour ago"; + return `${hours} hours ago`; +}; + const SectionContainer = ({ children }: { children: React.ReactNode }) => { return
{children}
; }; @@ -44,10 +60,35 @@ const StatItem = ({ label, value }: { label: string; value: string }) => { }; export const Usage = () => { - const { data: sessionData, isLoading: sessionLoading } = useGetSessionUsageQuery(); - const { data: weeklyData, isLoading: weeklyLoading } = useGetWeeklyUsageQuery(); + const { + data: sessionData, + isLoading: sessionLoading, + dataUpdatedAt: sessionUpdatedAt, + refetch: refetchSession, + isFetching: sessionFetching, + } = useGetSessionUsageQuery(); + const { + data: weeklyData, + isLoading: weeklyLoading, + refetch: refetchWeekly, + isFetching: weeklyFetching, + } = useGetWeeklyUsageQuery(); + + const [, setTick] = useState(0); + + // Update the "last updated" text periodically + useEffect(() => { + const interval = setInterval(() => setTick((t) => t + 1), 10000); + return () => clearInterval(interval); + }, []); const isLoading = sessionLoading || weeklyLoading; + const isFetching = sessionFetching || weeklyFetching; + + const handleRefresh = () => { + refetchSession(); + refetchWeekly(); + }; if (isLoading) { return ( @@ -98,6 +139,22 @@ export const Usage = () => { )} + +
+ + Last updated: {formatLastUpdated(sessionUpdatedAt)} + + +
); From 27092aeb08a743ba8d7745b4963d21f70a3b9f70 Mon Sep 17 00:00:00 2001 From: wjiayis Date: Sat, 21 Feb 2026 16:23:54 +0800 Subject: [PATCH 05/13] chore: update comments --- internal/services/toolkit/client/completion_v2.go | 3 ++- internal/services/usage.go | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/services/toolkit/client/completion_v2.go b/internal/services/toolkit/client/completion_v2.go index 463e5e0a..2c8daa0e 100644 --- a/internal/services/toolkit/client/completion_v2.go +++ b/internal/services/toolkit/client/completion_v2.go @@ -98,8 +98,9 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chunk := stream.Current() if len(chunk.Choices) == 0 { + // Handle usage information if chunk.Usage.TotalTokens > 0 { - // Record usage and log stats asynchronously to avoid blocking the response + // Record usage asynchronously to avoid blocking the response go func(usage services.UsageRecord) { bgCtx := context.Background() if err := a.usageService.RecordUsage(bgCtx, usage); err != nil { diff --git a/internal/services/usage.go b/internal/services/usage.go index 06603d0b..d40a7156 100644 --- a/internal/services/usage.go +++ b/internal/services/usage.go @@ -45,7 +45,7 @@ func NewUsageService(db *db.DB, cfg *cfg.Cfg, logger *logger.Logger) *UsageServi } // RecordUsage updates the active session or creates a new one if none exists. -// Uses retry logic to handle race conditions when multiple requests try to create a session. +// Falls back to update if insert fails (handles race when another request created a session). 
func (s *UsageService) RecordUsage(ctx context.Context, record UsageRecord) error { now := time.Now() nowBson := bson.DateTime(now.UnixMilli()) From 3c6a7f20bd470c8a73bbf81f3abfbdeb05857905 Mon Sep 17 00:00:00 2001 From: wjiayis Date: Mon, 23 Feb 2026 22:02:48 +0800 Subject: [PATCH 06/13] chore: clarify comments --- internal/libs/db/db.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/libs/db/db.go b/internal/libs/db/db.go index 8468f73c..797fe586 100644 --- a/internal/libs/db/db.go +++ b/internal/libs/db/db.go @@ -54,7 +54,7 @@ func NewDB(cfg *cfg.Cfg, logger *logger.Logger) (*DB, error) { func (db *DB) ensureIndexes() { sessions := db.Database("paperdebugger").Collection((models.LLMSession{}).CollectionName()) - // TTL index: auto-delete sessions after 30 days + // TTL index: auto-delete sessions after 30 days past their expiry time _, err := sessions.Indexes().CreateOne(context.Background(), mongo.IndexModel{ Keys: bson.D{{Key: "session_expiry", Value: 1}}, Options: options.Index().SetExpireAfterSeconds(30 * 24 * 60 * 60), From 4660a2995893f66b2b67602c2cfda79c0aa2af51 Mon Sep 17 00:00:00 2001 From: wjiayis Date: Wed, 25 Feb 2026 21:46:07 +0800 Subject: [PATCH 07/13] feat: add usage_test.go --- internal/services/usage_test.go | 565 ++++++++++++++++++++++++++++++++ 1 file changed, 565 insertions(+) create mode 100644 internal/services/usage_test.go diff --git a/internal/services/usage_test.go b/internal/services/usage_test.go new file mode 100644 index 00000000..0e28466c --- /dev/null +++ b/internal/services/usage_test.go @@ -0,0 +1,565 @@ +package services_test + +import ( + "context" + "os" + "sync" + "testing" + "time" + + "paperdebugger/internal/libs/cfg" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/libs/logger" + "paperdebugger/internal/models" + "paperdebugger/internal/services" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/v2/bson" + 
"go.mongodb.org/mongo-driver/v2/mongo" +) + +func setupTestUsageService(t *testing.T) (*services.UsageService, *mongo.Collection) { + os.Setenv("PD_MONGO_URI", "mongodb://localhost:27017") + dbInstance, err := db.NewDB(cfg.GetCfg(), logger.GetLogger()) + if err != nil { + t.Fatalf("failed to connect to test db: %v", err) + } + + svc := services.NewUsageService(dbInstance, cfg.GetCfg(), logger.GetLogger()) + collection := dbInstance.Database("paperdebugger").Collection((models.LLMSession{}).CollectionName()) + + return svc, collection +} + +func cleanupSessions(t *testing.T, collection *mongo.Collection, userID bson.ObjectID) { + ctx := context.Background() + _, err := collection.DeleteMany(ctx, bson.M{"user_id": userID}) + if err != nil { + t.Logf("cleanup warning: %v", err) + } +} + +func TestUsageService_RecordUsage_NewSession(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + record := services.UsageRecord{ + UserID: userID, + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + } + + err := svc.RecordUsage(ctx, record) + require.NoError(t, err) + + session, err := svc.GetActiveSession(ctx, userID) + require.NoError(t, err) + require.NotNil(t, session) + + assert.Equal(t, userID, session.UserID) + assert.Equal(t, int64(100), session.PromptTokens) + assert.Equal(t, int64(200), session.CompletionTokens) + assert.Equal(t, int64(300), session.TotalTokens) + assert.Equal(t, int64(1), session.RequestCount) + + // Verify session expiry is set correctly (5 hours from now) + now := time.Now() + expiryTime := time.UnixMilli(int64(session.SessionExpiry)) + expectedExpiry := now.Add(services.SessionDuration) + assert.WithinDuration(t, expectedExpiry, expiryTime, 2*time.Second) +} + +func TestUsageService_RecordUsage_ExistingActiveSession(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := 
bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Record first usage (creates session) + record1 := services.UsageRecord{ + UserID: userID, + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + } + err := svc.RecordUsage(ctx, record1) + require.NoError(t, err) + + // Record second usage to same session + record2 := services.UsageRecord{ + UserID: userID, + PromptTokens: 50, + CompletionTokens: 75, + TotalTokens: 125, + } + err = svc.RecordUsage(ctx, record2) + require.NoError(t, err) + + // Verify tokens are accumulated + session, err := svc.GetActiveSession(ctx, userID) + require.NoError(t, err) + require.NotNil(t, session) + + assert.Equal(t, int64(150), session.PromptTokens) + assert.Equal(t, int64(275), session.CompletionTokens) + assert.Equal(t, int64(425), session.TotalTokens) + assert.Equal(t, int64(2), session.RequestCount) +} + +func TestUsageService_RecordUsage_ExpiredSession(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Create an expired session manually + now := time.Now() + expiredSession := models.LLMSession{ + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-6 * time.Hour).UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(-1 * time.Hour).UnixMilli()), // Expired 1 hour ago + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 1, + } + _, err := collection.InsertOne(ctx, expiredSession) + require.NoError(t, err) + + // Record new usage - should create a new session, not update the expired one + record := services.UsageRecord{ + UserID: userID, + PromptTokens: 50, + CompletionTokens: 75, + TotalTokens: 125, + } + err = svc.RecordUsage(ctx, record) + require.NoError(t, err) + + // Get active session + activeSession, err := svc.GetActiveSession(ctx, userID) + require.NoError(t, err) + require.NotNil(t, activeSession) + + // 
Should be a new session with only the new usage + assert.NotEqual(t, expiredSession.ID, activeSession.ID) + assert.Equal(t, int64(50), activeSession.PromptTokens) + assert.Equal(t, int64(75), activeSession.CompletionTokens) + assert.Equal(t, int64(125), activeSession.TotalTokens) + assert.Equal(t, int64(1), activeSession.RequestCount) +} + +func TestUsageService_RecordUsage_RaceCondition(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Simulate concurrent requests trying to create sessions + concurrentRequests := 10 + var wg sync.WaitGroup + errors := make([]error, concurrentRequests) + + // Use a channel to synchronize goroutine starts for maximum race condition + start := make(chan struct{}) + + for i := range concurrentRequests { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start // Wait for signal to start + record := services.UsageRecord{ + UserID: userID, + PromptTokens: 10, + CompletionTokens: 20, + TotalTokens: 30, + } + errors[idx] = svc.RecordUsage(ctx, record) + }(i) + } + + // Start all goroutines at once + close(start) + wg.Wait() + + // All requests should succeed (no errors) + for i, err := range errors { + assert.NoError(t, err, "Request %d should not have errored", i) + } + + // Count total sessions created (should be 1 or possibly more if race occurred) + filter := bson.M{"user_id": userID} + count, err := collection.CountDocuments(ctx, filter) + require.NoError(t, err) + + // Get all sessions to see the full picture + cursor, err := collection.Find(ctx, filter) + require.NoError(t, err) + var sessions []models.LLMSession + err = cursor.All(ctx, &sessions) + require.NoError(t, err) + + // Calculate total usage across all sessions + var totalPrompt, totalCompletion, totalTokens, totalRequests int64 + for _, s := range sessions { + totalPrompt += s.PromptTokens + totalCompletion += s.CompletionTokens + totalTokens 
+= s.TotalTokens + totalRequests += s.RequestCount + } + + // All tokens should be accumulated across all sessions + assert.Equal(t, int64(100), totalPrompt, "Expected 10 requests * 10 tokens each") + assert.Equal(t, int64(200), totalCompletion, "Expected 10 requests * 20 tokens each") + assert.Equal(t, int64(300), totalTokens, "Expected 10 requests * 30 tokens each") + assert.Equal(t, int64(10), totalRequests, "Expected 10 requests recorded") + + // Note: In a race condition, multiple sessions might be created if concurrent + // InsertOne calls succeed simultaneously (no unique index prevents this). + // The important guarantee is that no usage data is lost - all requests are + // recorded correctly, even if spread across multiple sessions. + t.Logf("Sessions created during race: %d", count) + if count > 1 { + t.Logf("Multiple sessions created due to race condition (expected behavior)") + } +} + +func TestUsageService_GetActiveSession_NoSession(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + session, err := svc.GetActiveSession(ctx, userID) + require.NoError(t, err) + assert.Nil(t, session) +} + +func TestUsageService_GetWeeklyUsage_SingleSession(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Record some usage + record := services.UsageRecord{ + UserID: userID, + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + } + err := svc.RecordUsage(ctx, record) + require.NoError(t, err) + + // Get weekly usage + stats, err := svc.GetWeeklyUsage(ctx, userID) + require.NoError(t, err) + require.NotNil(t, stats) + + assert.Equal(t, int64(100), stats.PromptTokens) + assert.Equal(t, int64(200), stats.CompletionTokens) + assert.Equal(t, int64(300), stats.TotalTokens) + assert.Equal(t, int64(1), 
stats.RequestCount) + assert.Equal(t, int64(1), stats.SessionCount) +} + +func TestUsageService_GetWeeklyUsage_MultipleSessions(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Create multiple sessions within the current week + now := time.Now() + sessions := []models.LLMSession{ + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-2 * 24 * time.Hour).UnixMilli()), // 2 days ago + SessionExpiry: bson.DateTime(now.Add(-2*24*time.Hour + services.SessionDuration).UnixMilli()), + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 5, + }, + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-1 * 24 * time.Hour).UnixMilli()), // 1 day ago + SessionExpiry: bson.DateTime(now.Add(-1*24*time.Hour + services.SessionDuration).UnixMilli()), + PromptTokens: 50, + CompletionTokens: 75, + TotalTokens: 125, + RequestCount: 3, + }, + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.UnixMilli()), // Now + SessionExpiry: bson.DateTime(now.Add(services.SessionDuration).UnixMilli()), + PromptTokens: 200, + CompletionTokens: 300, + TotalTokens: 500, + RequestCount: 10, + }, + } + + for _, session := range sessions { + _, err := collection.InsertOne(ctx, session) + require.NoError(t, err) + } + + // Get weekly usage + stats, err := svc.GetWeeklyUsage(ctx, userID) + require.NoError(t, err) + require.NotNil(t, stats) + + // Verify aggregation + assert.Equal(t, int64(350), stats.PromptTokens) + assert.Equal(t, int64(575), stats.CompletionTokens) + assert.Equal(t, int64(925), stats.TotalTokens) + assert.Equal(t, int64(18), stats.RequestCount) + assert.Equal(t, int64(3), stats.SessionCount) +} + +func TestUsageService_GetWeeklyUsage_ExcludesOldSessions(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() 
+ userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + now := time.Now() + + // Create an old session (from last week) + oldSession := models.LLMSession{ + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-10 * 24 * time.Hour).UnixMilli()), // 10 days ago + SessionExpiry: bson.DateTime(now.Add(-10*24*time.Hour + services.SessionDuration).UnixMilli()), + PromptTokens: 1000, + CompletionTokens: 2000, + TotalTokens: 3000, + RequestCount: 50, + } + _, err := collection.InsertOne(ctx, oldSession) + require.NoError(t, err) + + // Create a current session + currentSession := models.LLMSession{ + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(services.SessionDuration).UnixMilli()), + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 5, + } + _, err = collection.InsertOne(ctx, currentSession) + require.NoError(t, err) + + // Get weekly usage + stats, err := svc.GetWeeklyUsage(ctx, userID) + require.NoError(t, err) + require.NotNil(t, stats) + + // Should only include the current session + assert.Equal(t, int64(100), stats.PromptTokens) + assert.Equal(t, int64(200), stats.CompletionTokens) + assert.Equal(t, int64(300), stats.TotalTokens) + assert.Equal(t, int64(5), stats.RequestCount) + assert.Equal(t, int64(1), stats.SessionCount) +} + +func TestUsageService_GetWeeklyUsage_NoSessions(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + stats, err := svc.GetWeeklyUsage(ctx, userID) + require.NoError(t, err) + require.NotNil(t, stats) + + // Should return zero stats + assert.Equal(t, int64(0), stats.PromptTokens) + assert.Equal(t, int64(0), stats.CompletionTokens) + assert.Equal(t, int64(0), stats.TotalTokens) + assert.Equal(t, int64(0), stats.RequestCount) + assert.Equal(t, int64(0), 
stats.SessionCount) +} + +func TestUsageService_ListRecentSessions(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Create multiple sessions at different times + now := time.Now() + sessions := []models.LLMSession{ + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-3 * 24 * time.Hour).UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(-3*24*time.Hour + services.SessionDuration).UnixMilli()), + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 1, + }, + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-2 * 24 * time.Hour).UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(-2*24*time.Hour + services.SessionDuration).UnixMilli()), + PromptTokens: 150, + CompletionTokens: 250, + TotalTokens: 400, + RequestCount: 2, + }, + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-1 * 24 * time.Hour).UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(-1*24*time.Hour + services.SessionDuration).UnixMilli()), + PromptTokens: 200, + CompletionTokens: 300, + TotalTokens: 500, + RequestCount: 3, + }, + } + + for _, session := range sessions { + _, err := collection.InsertOne(ctx, session) + require.NoError(t, err) + } + + // List recent sessions (limit 2) + recent, err := svc.ListRecentSessions(ctx, userID, 2) + require.NoError(t, err) + assert.Len(t, recent, 2) + + // Should be in reverse chronological order (most recent first) + assert.Equal(t, int64(200), recent[0].PromptTokens) // Most recent + assert.Equal(t, int64(150), recent[1].PromptTokens) // Second most recent + + // List all sessions + all, err := svc.ListRecentSessions(ctx, userID, 10) + require.NoError(t, err) + assert.Len(t, all, 3) +} + +func TestStartOfWeek(t *testing.T) { + tests := []struct { + name string + input time.Time + expected time.Time + }{ 
+ { + name: "Monday should return same day at 00:00", + input: time.Date(2024, 1, 1, 15, 30, 45, 0, time.UTC), // Monday + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "Tuesday should return previous Monday", + input: time.Date(2024, 1, 2, 15, 30, 45, 0, time.UTC), // Tuesday + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "Sunday should return previous Monday", + input: time.Date(2024, 1, 7, 15, 30, 45, 0, time.UTC), // Sunday + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "Wednesday mid-week should return Monday", + input: time.Date(2024, 1, 3, 12, 0, 0, 0, time.UTC), // Wednesday + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "Saturday should return previous Monday", + input: time.Date(2024, 1, 6, 23, 59, 59, 0, time.UTC), // Saturday + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // We need to test the private function indirectly via GetWeeklyUsage + // But for this specific test, we'll verify the logic manually + input := tt.input.UTC() + daysFromMonday := (int(input.Weekday()) + 6) % 7 + result := time.Date(input.Year(), input.Month(), input.Day()-daysFromMonday, 0, 0, 0, 0, time.UTC) + + assert.Equal(t, tt.expected, result) + assert.Equal(t, time.Monday, result.Weekday(), "Start of week should be Monday") + }) + } +} + +func TestUsageService_GetWeeklyUsage_WeekBoundary(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Get the start of this week + now := time.Now().UTC() + daysFromMonday := (int(now.Weekday()) + 6) % 7 + weekStart := time.Date(now.Year(), now.Month(), now.Day()-daysFromMonday, 0, 0, 0, 0, time.UTC) + + // Create sessions on both sides of the week boundary + sessions := []models.LLMSession{ + { + ID: bson.NewObjectID(), + 
UserID: userID, + SessionStart: bson.DateTime(weekStart.Add(-1 * time.Hour).UnixMilli()), // Just before week start + SessionExpiry: bson.DateTime(weekStart.Add(-1*time.Hour + services.SessionDuration).UnixMilli()), + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 1, + }, + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(weekStart.UnixMilli()), // Exactly at week start + SessionExpiry: bson.DateTime(weekStart.Add(services.SessionDuration).UnixMilli()), + PromptTokens: 50, + CompletionTokens: 75, + TotalTokens: 125, + RequestCount: 1, + }, + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(weekStart.Add(1 * time.Hour).UnixMilli()), // Just after week start + SessionExpiry: bson.DateTime(weekStart.Add(1*time.Hour + services.SessionDuration).UnixMilli()), + PromptTokens: 25, + CompletionTokens: 50, + TotalTokens: 75, + RequestCount: 1, + }, + } + + for _, session := range sessions { + _, err := collection.InsertOne(ctx, session) + require.NoError(t, err) + } + + stats, err := svc.GetWeeklyUsage(ctx, userID) + require.NoError(t, err) + require.NotNil(t, stats) + + // Should only include sessions at or after week start (last 2 sessions) + assert.Equal(t, int64(75), stats.PromptTokens) + assert.Equal(t, int64(125), stats.CompletionTokens) + assert.Equal(t, int64(200), stats.TotalTokens) + assert.Equal(t, int64(2), stats.RequestCount) + assert.Equal(t, int64(2), stats.SessionCount) +} From 92a99adace50499414e0be72c1f447c2aade169c Mon Sep 17 00:00:00 2001 From: wjiayis Date: Wed, 25 Feb 2026 22:05:16 +0800 Subject: [PATCH 08/13] feat: improve llm_sessions compound index --- internal/libs/db/db.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/internal/libs/db/db.go b/internal/libs/db/db.go index 797fe586..394866f0 100644 --- a/internal/libs/db/db.go +++ b/internal/libs/db/db.go @@ -73,4 +73,15 @@ func (db *DB) ensureIndexes() { if err != nil { db.logger.Error("Failed to 
create compound index on llm_sessions", "error", err) } + + // Compound index for usage queries and recent session lookups + _, err = sessions.Indexes().CreateOne(context.Background(), mongo.IndexModel{ + Keys: bson.D{ + {Key: "user_id", Value: 1}, + {Key: "session_start", Value: -1}, + }, + }) + if err != nil { + db.logger.Error("Failed to create session_start index on llm_sessions", "error", err) + } } From 736871b34d84c5cddac0423cdb075a625ae2a775 Mon Sep 17 00:00:00 2001 From: wjiayis Date: Wed, 25 Feb 2026 22:15:43 +0800 Subject: [PATCH 09/13] chore: improve race condition error handling --- internal/services/usage.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/internal/services/usage.go b/internal/services/usage.go index d40a7156..387be354 100644 --- a/internal/services/usage.go +++ b/internal/services/usage.go @@ -84,10 +84,24 @@ func (s *UsageService) RecordUsage(ctx context.Context, record UsageRecord) erro } _, err = s.sessionCollection.InsertOne(ctx, session) if err != nil { - // Insert failed (race condition or other error) - retry update - _, err = s.sessionCollection.UpdateOne(ctx, filter, update) + // Only retry with update if insert failed due to duplicate key (race condition) + if mongo.IsDuplicateKeyError(err) { + _, updateErr := s.sessionCollection.UpdateOne(ctx, filter, update) + if updateErr != nil { + // Log both errors for debugging + s.logger.Warn("Insert failed with duplicate key, update also failed", + "insertErr", err, + "updateErr", updateErr, + "userID", record.UserID) + return updateErr + } + // Race condition handled successfully + return nil + } + // Insert failed for non-duplicate-key reason (network, validation, etc.) + return err } - return err + return nil } // GetActiveSession returns the current active session for a user, if any. 
From ea8782fa631032eed501d2163029399a3c5f4d2d Mon Sep 17 00:00:00 2001 From: wjiayis Date: Wed, 4 Mar 2026 20:19:51 +0800 Subject: [PATCH 10/13] feat: break down to per-model usage --- internal/api/usage/get_session_usage.go | 13 +- internal/api/usage/get_weekly_usage.go | 14 +- internal/models/usage.go | 22 +- .../services/toolkit/client/completion_v2.go | 1 + internal/services/usage.go | 109 ++++- internal/services/usage_test.go | 414 ++++++++++++------ pkg/gen/api/usage/v1/usage.pb.go | 201 ++++++--- proto/usage/v1/usage.proto | 14 +- .../pkg/gen/apiclient/usage/v1/usage_pb.ts | 63 ++- webapp/_webapp/src/views/usage/index.tsx | 45 +- 10 files changed, 649 insertions(+), 247 deletions(-) diff --git a/internal/api/usage/get_session_usage.go b/internal/api/usage/get_session_usage.go index 06a28718..5eb35cad 100644 --- a/internal/api/usage/get_session_usage.go +++ b/internal/api/usage/get_session_usage.go @@ -29,10 +29,21 @@ func (s *UsageServer) GetSessionUsage( }, nil } + // Convert models map to proto format + models := make(map[string]*usagev1.ModelTokens) + for modelName, tokens := range session.Models { + models[modelName] = &usagev1.ModelTokens{ + PromptTokens: tokens.PromptTokens, + CompletionTokens: tokens.CompletionTokens, + TotalTokens: tokens.TotalTokens, + RequestCount: tokens.RequestCount, + } + } + return &usagev1.GetSessionUsageResponse{ Session: &usagev1.SessionUsage{ SessionExpiry: timestamppb.New(session.SessionExpiry.Time()), - TotalTokens: session.TotalTokens, + Models: models, }, }, nil } diff --git a/internal/api/usage/get_weekly_usage.go b/internal/api/usage/get_weekly_usage.go index f87cad60..e244b3c6 100644 --- a/internal/api/usage/get_weekly_usage.go +++ b/internal/api/usage/get_weekly_usage.go @@ -21,9 +21,21 @@ func (s *UsageServer) GetWeeklyUsage( return nil, err } + // Convert models map to proto format + models := make(map[string]*usagev1.ModelTokens) + for modelName, tokens := range stats.Models { + models[modelName] = 
&usagev1.ModelTokens{ + PromptTokens: tokens.PromptTokens, + CompletionTokens: tokens.CompletionTokens, + TotalTokens: tokens.TotalTokens, + RequestCount: tokens.RequestCount, + } + } + return &usagev1.GetWeeklyUsageResponse{ Usage: &usagev1.WeeklyUsage{ - TotalTokens: stats.TotalTokens, + Models: models, + SessionCount: stats.SessionCount, }, }, nil } diff --git a/internal/models/usage.go b/internal/models/usage.go index 91d73273..0f0aa6f2 100644 --- a/internal/models/usage.go +++ b/internal/models/usage.go @@ -2,16 +2,22 @@ package models import "go.mongodb.org/mongo-driver/v2/bson" +// ModelTokens stores token counts for a specific model. +type ModelTokens struct { + PromptTokens int64 `bson:"prompt_tokens"` + CompletionTokens int64 `bson:"completion_tokens"` + TotalTokens int64 `bson:"total_tokens"` + RequestCount int64 `bson:"request_count"` +} + // LLMSession represents a user's session for tracking LLM usage and token counts. +// Tokens are stored per model in the Models map. type LLMSession struct { - ID bson.ObjectID `bson:"_id"` - UserID bson.ObjectID `bson:"user_id"` - SessionStart bson.DateTime `bson:"session_start"` - SessionExpiry bson.DateTime `bson:"session_expiry"` - PromptTokens int64 `bson:"prompt_tokens"` - CompletionTokens int64 `bson:"completion_tokens"` - TotalTokens int64 `bson:"total_tokens"` - RequestCount int64 `bson:"request_count"` + ID bson.ObjectID `bson:"_id"` + UserID bson.ObjectID `bson:"user_id"` + SessionStart bson.DateTime `bson:"session_start"` + SessionExpiry bson.DateTime `bson:"session_expiry"` + Models map[string]*ModelTokens `bson:"models"` } func (s LLMSession) CollectionName() string { diff --git a/internal/services/toolkit/client/completion_v2.go b/internal/services/toolkit/client/completion_v2.go index 2c8daa0e..316c617b 100644 --- a/internal/services/toolkit/client/completion_v2.go +++ b/internal/services/toolkit/client/completion_v2.go @@ -110,6 +110,7 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx 
context.Context, callbackStream }(services.UsageRecord{ UserID: userID, + Model: modelSlug, PromptTokens: chunk.Usage.PromptTokens, CompletionTokens: chunk.Usage.CompletionTokens, TotalTokens: chunk.Usage.TotalTokens, diff --git a/internal/services/usage.go b/internal/services/usage.go index 387be354..332e243e 100644 --- a/internal/services/usage.go +++ b/internal/services/usage.go @@ -23,17 +23,23 @@ type UsageService struct { type UsageRecord struct { UserID bson.ObjectID + Model string PromptTokens int64 CompletionTokens int64 TotalTokens int64 } -type UsageStats struct { +// ModelUsageStats stores aggregated usage statistics for a specific model. +type ModelUsageStats struct { PromptTokens int64 `bson:"prompt_tokens"` CompletionTokens int64 `bson:"completion_tokens"` TotalTokens int64 `bson:"total_tokens"` RequestCount int64 `bson:"request_count"` - SessionCount int64 `bson:"session_count"` +} + +type UsageStats struct { + Models map[string]*ModelUsageStats `bson:"models"` + SessionCount int64 `bson:"session_count"` } func NewUsageService(db *db.DB, cfg *cfg.Cfg, logger *logger.Logger) *UsageService { @@ -45,21 +51,23 @@ func NewUsageService(db *db.DB, cfg *cfg.Cfg, logger *logger.Logger) *UsageServi } // RecordUsage updates the active session or creates a new one if none exists. -// Falls back to update if insert fails (handles race when another request created a session). +// Falls back to update if insert fails (handles race when another request created a session). func (s *UsageService) RecordUsage(ctx context.Context, record UsageRecord) error { now := time.Now() nowBson := bson.DateTime(now.UnixMilli()) + // Build field paths for per-model token storage + modelPrefix := "models." 
+ record.Model filter := bson.M{ "user_id": record.UserID, "session_expiry": bson.M{"$gt": nowBson}, } update := bson.M{ "$inc": bson.M{ - "prompt_tokens": record.PromptTokens, - "completion_tokens": record.CompletionTokens, - "total_tokens": record.TotalTokens, - "request_count": 1, + modelPrefix + ".prompt_tokens": record.PromptTokens, + modelPrefix + ".completion_tokens": record.CompletionTokens, + modelPrefix + ".total_tokens": record.TotalTokens, + modelPrefix + ".request_count": 1, }, } @@ -73,14 +81,18 @@ func (s *UsageService) RecordUsage(ctx context.Context, record UsageRecord) erro // No active session found - create a new one session := models.LLMSession{ - ID: bson.NewObjectID(), - UserID: record.UserID, - SessionStart: nowBson, - SessionExpiry: bson.DateTime(now.Add(SessionDuration).UnixMilli()), - PromptTokens: record.PromptTokens, - CompletionTokens: record.CompletionTokens, - TotalTokens: record.TotalTokens, - RequestCount: 1, + ID: bson.NewObjectID(), + UserID: record.UserID, + SessionStart: nowBson, + SessionExpiry: bson.DateTime(now.Add(SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + record.Model: { + PromptTokens: record.PromptTokens, + CompletionTokens: record.CompletionTokens, + TotalTokens: record.TotalTokens, + RequestCount: 1, + }, + }, } _, err = s.sessionCollection.InsertOne(ctx, session) if err != nil { @@ -135,13 +147,40 @@ func (s *UsageService) getUsageSince(ctx context.Context, userID bson.ObjectID, "user_id": userID, "session_start": bson.M{"$gte": bson.DateTime(since.UnixMilli())}, }}, + // Convert models map to array for aggregation + bson.M{"$project": bson.M{ + "models_array": bson.M{"$objectToArray": "$models"}, + "session_count": bson.M{"$literal": 1}, + }}, + // Unwind the models array to aggregate per model + bson.M{"$unwind": bson.M{ + "path": "$models_array", + "preserveNullAndEmptyArrays": true, + }}, + // Group by model name and sum tokens + bson.M{"$group": bson.M{ + "_id": "$models_array.k", 
+ "prompt_tokens": bson.M{"$sum": "$models_array.v.prompt_tokens"}, + "completion_tokens": bson.M{"$sum": "$models_array.v.completion_tokens"}, + "total_tokens": bson.M{"$sum": "$models_array.v.total_tokens"}, + "request_count": bson.M{"$sum": "$models_array.v.request_count"}, + }}, + // Reshape into array of model stats bson.M{"$group": bson.M{ - "_id": nil, - "prompt_tokens": bson.M{"$sum": "$prompt_tokens"}, - "completion_tokens": bson.M{"$sum": "$completion_tokens"}, - "total_tokens": bson.M{"$sum": "$total_tokens"}, - "request_count": bson.M{"$sum": "$request_count"}, - "session_count": bson.M{"$sum": 1}, + "_id": nil, + "models": bson.M{"$push": bson.M{ + "k": "$_id", + "v": bson.M{ + "prompt_tokens": "$prompt_tokens", + "completion_tokens": "$completion_tokens", + "total_tokens": "$total_tokens", + "request_count": "$request_count", + }, + }}, + }}, + // Convert back to object + bson.M{"$project": bson.M{ + "models": bson.M{"$arrayToObject": "$models"}, }}, } @@ -151,14 +190,40 @@ func (s *UsageService) getUsageSince(ctx context.Context, userID bson.ObjectID, } defer cursor.Close(ctx) + // Get session count separately (simpler query) + countPipeline := bson.A{ + bson.M{"$match": bson.M{ + "user_id": userID, + "session_start": bson.M{"$gte": bson.DateTime(since.UnixMilli())}, + }}, + bson.M{"$count": "session_count"}, + } + countCursor, err := s.sessionCollection.Aggregate(ctx, countPipeline) + if err != nil { + return nil, err + } + defer countCursor.Close(ctx) + + var sessionCount int64 + if countCursor.Next(ctx) { + var countResult struct { + SessionCount int64 `bson:"session_count"` + } + if err := countCursor.Decode(&countResult); err != nil { + return nil, err + } + sessionCount = countResult.SessionCount + } + if cursor.Next(ctx) { var result UsageStats if err := cursor.Decode(&result); err != nil { return nil, err } + result.SessionCount = sessionCount return &result, nil } - return &UsageStats{}, nil + return &UsageStats{Models: 
make(map[string]*ModelUsageStats)}, nil } // startOfWeek returns the start of the week (Monday 00:00:00 UTC). diff --git a/internal/services/usage_test.go b/internal/services/usage_test.go index 0e28466c..5acfffbd 100644 --- a/internal/services/usage_test.go +++ b/internal/services/usage_test.go @@ -48,6 +48,7 @@ func TestUsageService_RecordUsage_NewSession(t *testing.T) { record := services.UsageRecord{ UserID: userID, + Model: "gpt-4", PromptTokens: 100, CompletionTokens: 200, TotalTokens: 300, @@ -61,10 +62,12 @@ func TestUsageService_RecordUsage_NewSession(t *testing.T) { require.NotNil(t, session) assert.Equal(t, userID, session.UserID) - assert.Equal(t, int64(100), session.PromptTokens) - assert.Equal(t, int64(200), session.CompletionTokens) - assert.Equal(t, int64(300), session.TotalTokens) - assert.Equal(t, int64(1), session.RequestCount) + require.NotNil(t, session.Models) + require.NotNil(t, session.Models["gpt-4"]) + assert.Equal(t, int64(100), session.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(200), session.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(300), session.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(1), session.Models["gpt-4"].RequestCount) // Verify session expiry is set correctly (5 hours from now) now := time.Now() @@ -82,6 +85,7 @@ func TestUsageService_RecordUsage_ExistingActiveSession(t *testing.T) { // Record first usage (creates session) record1 := services.UsageRecord{ UserID: userID, + Model: "gpt-4", PromptTokens: 100, CompletionTokens: 200, TotalTokens: 300, @@ -89,9 +93,10 @@ func TestUsageService_RecordUsage_ExistingActiveSession(t *testing.T) { err := svc.RecordUsage(ctx, record1) require.NoError(t, err) - // Record second usage to same session + // Record second usage to same session with same model record2 := services.UsageRecord{ UserID: userID, + Model: "gpt-4", PromptTokens: 50, CompletionTokens: 75, TotalTokens: 125, @@ -99,34 +104,110 @@ func 
TestUsageService_RecordUsage_ExistingActiveSession(t *testing.T) { err = svc.RecordUsage(ctx, record2) require.NoError(t, err) - // Verify tokens are accumulated + // Verify tokens are accumulated for the model session, err := svc.GetActiveSession(ctx, userID) require.NoError(t, err) require.NotNil(t, session) - assert.Equal(t, int64(150), session.PromptTokens) - assert.Equal(t, int64(275), session.CompletionTokens) - assert.Equal(t, int64(425), session.TotalTokens) - assert.Equal(t, int64(2), session.RequestCount) + require.NotNil(t, session.Models["gpt-4"]) + assert.Equal(t, int64(150), session.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(275), session.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(425), session.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(2), session.Models["gpt-4"].RequestCount) } -func TestUsageService_RecordUsage_ExpiredSession(t *testing.T) { +func TestUsageService_RecordUsage_MultipleModels(t *testing.T) { svc, collection := setupTestUsageService(t) ctx := context.Background() userID := bson.NewObjectID() defer cleanupSessions(t, collection, userID) - // Create an expired session manually - now := time.Now() - expiredSession := models.LLMSession{ - ID: bson.NewObjectID(), + // Record usage for gpt-4 + record1 := services.UsageRecord{ UserID: userID, - SessionStart: bson.DateTime(now.Add(-6 * time.Hour).UnixMilli()), - SessionExpiry: bson.DateTime(now.Add(-1 * time.Hour).UnixMilli()), // Expired 1 hour ago + Model: "gpt-4", PromptTokens: 100, CompletionTokens: 200, TotalTokens: 300, - RequestCount: 1, + } + err := svc.RecordUsage(ctx, record1) + require.NoError(t, err) + + // Record usage for claude-3 + record2 := services.UsageRecord{ + UserID: userID, + Model: "claude-3", + PromptTokens: 50, + CompletionTokens: 75, + TotalTokens: 125, + } + err = svc.RecordUsage(ctx, record2) + require.NoError(t, err) + + // Record more usage for gpt-4 + record3 := services.UsageRecord{ + UserID: userID, + Model: "gpt-4", + 
PromptTokens: 25, + CompletionTokens: 30, + TotalTokens: 55, + } + err = svc.RecordUsage(ctx, record3) + require.NoError(t, err) + + // Verify per-model token storage + session, err := svc.GetActiveSession(ctx, userID) + require.NoError(t, err) + require.NotNil(t, session) + require.NotNil(t, session.Models) + + // Check gpt-4 tokens (accumulated from 2 records) + require.NotNil(t, session.Models["gpt-4"]) + assert.Equal(t, int64(125), session.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(230), session.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(355), session.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(2), session.Models["gpt-4"].RequestCount) + + // Check claude-3 tokens (single record) + require.NotNil(t, session.Models["claude-3"]) + assert.Equal(t, int64(50), session.Models["claude-3"].PromptTokens) + assert.Equal(t, int64(75), session.Models["claude-3"].CompletionTokens) + assert.Equal(t, int64(125), session.Models["claude-3"].TotalTokens) + assert.Equal(t, int64(1), session.Models["claude-3"].RequestCount) + + // Verify weekly usage aggregates per model + stats, err := svc.GetWeeklyUsage(ctx, userID) + require.NoError(t, err) + require.NotNil(t, stats.Models) + + require.NotNil(t, stats.Models["gpt-4"]) + assert.Equal(t, int64(125), stats.Models["gpt-4"].PromptTokens) + + require.NotNil(t, stats.Models["claude-3"]) + assert.Equal(t, int64(50), stats.Models["claude-3"].PromptTokens) +} + +func TestUsageService_RecordUsage_ExpiredSession(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Create an expired session manually + now := time.Now() + expiredSession := models.LLMSession{ + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-6 * time.Hour).UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(-1 * time.Hour).UnixMilli()), // Expired 1 hour ago + Models: 
map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 1, + }, + }, } _, err := collection.InsertOne(ctx, expiredSession) require.NoError(t, err) @@ -134,6 +215,7 @@ func TestUsageService_RecordUsage_ExpiredSession(t *testing.T) { // Record new usage - should create a new session, not update the expired one record := services.UsageRecord{ UserID: userID, + Model: "gpt-4", PromptTokens: 50, CompletionTokens: 75, TotalTokens: 125, @@ -148,10 +230,11 @@ func TestUsageService_RecordUsage_ExpiredSession(t *testing.T) { // Should be a new session with only the new usage assert.NotEqual(t, expiredSession.ID, activeSession.ID) - assert.Equal(t, int64(50), activeSession.PromptTokens) - assert.Equal(t, int64(75), activeSession.CompletionTokens) - assert.Equal(t, int64(125), activeSession.TotalTokens) - assert.Equal(t, int64(1), activeSession.RequestCount) + require.NotNil(t, activeSession.Models["gpt-4"]) + assert.Equal(t, int64(50), activeSession.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(75), activeSession.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(125), activeSession.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(1), activeSession.Models["gpt-4"].RequestCount) } func TestUsageService_RecordUsage_RaceCondition(t *testing.T) { @@ -175,6 +258,7 @@ func TestUsageService_RecordUsage_RaceCondition(t *testing.T) { <-start // Wait for signal to start record := services.UsageRecord{ UserID: userID, + Model: "gpt-4", PromptTokens: 10, CompletionTokens: 20, TotalTokens: 30, @@ -204,13 +288,15 @@ func TestUsageService_RecordUsage_RaceCondition(t *testing.T) { err = cursor.All(ctx, &sessions) require.NoError(t, err) - // Calculate total usage across all sessions + // Calculate total usage across all sessions for all models var totalPrompt, totalCompletion, totalTokens, totalRequests int64 for _, s := range sessions { - totalPrompt += s.PromptTokens - totalCompletion += 
s.CompletionTokens - totalTokens += s.TotalTokens - totalRequests += s.RequestCount + for _, m := range s.Models { + totalPrompt += m.PromptTokens + totalCompletion += m.CompletionTokens + totalTokens += m.TotalTokens + totalRequests += m.RequestCount + } } // All tokens should be accumulated across all sessions @@ -249,6 +335,7 @@ func TestUsageService_GetWeeklyUsage_SingleSession(t *testing.T) { // Record some usage record := services.UsageRecord{ UserID: userID, + Model: "gpt-4", PromptTokens: 100, CompletionTokens: 200, TotalTokens: 300, @@ -261,10 +348,12 @@ func TestUsageService_GetWeeklyUsage_SingleSession(t *testing.T) { require.NoError(t, err) require.NotNil(t, stats) - assert.Equal(t, int64(100), stats.PromptTokens) - assert.Equal(t, int64(200), stats.CompletionTokens) - assert.Equal(t, int64(300), stats.TotalTokens) - assert.Equal(t, int64(1), stats.RequestCount) + require.NotNil(t, stats.Models) + require.NotNil(t, stats.Models["gpt-4"]) + assert.Equal(t, int64(100), stats.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(200), stats.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(300), stats.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(1), stats.Models["gpt-4"].RequestCount) assert.Equal(t, int64(1), stats.SessionCount) } @@ -278,34 +367,46 @@ func TestUsageService_GetWeeklyUsage_MultipleSessions(t *testing.T) { now := time.Now() sessions := []models.LLMSession{ { - ID: bson.NewObjectID(), - UserID: userID, - SessionStart: bson.DateTime(now.Add(-2 * 24 * time.Hour).UnixMilli()), // 2 days ago - SessionExpiry: bson.DateTime(now.Add(-2*24*time.Hour + services.SessionDuration).UnixMilli()), - PromptTokens: 100, - CompletionTokens: 200, - TotalTokens: 300, - RequestCount: 5, + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-2 * 24 * time.Hour).UnixMilli()), // 2 days ago + SessionExpiry: bson.DateTime(now.Add(-2*24*time.Hour + services.SessionDuration).UnixMilli()), + Models: 
map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 5, + }, + }, }, { - ID: bson.NewObjectID(), - UserID: userID, - SessionStart: bson.DateTime(now.Add(-1 * 24 * time.Hour).UnixMilli()), // 1 day ago - SessionExpiry: bson.DateTime(now.Add(-1*24*time.Hour + services.SessionDuration).UnixMilli()), - PromptTokens: 50, - CompletionTokens: 75, - TotalTokens: 125, - RequestCount: 3, + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-1 * 24 * time.Hour).UnixMilli()), // 1 day ago + SessionExpiry: bson.DateTime(now.Add(-1*24*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 50, + CompletionTokens: 75, + TotalTokens: 125, + RequestCount: 3, + }, + }, }, { - ID: bson.NewObjectID(), - UserID: userID, - SessionStart: bson.DateTime(now.UnixMilli()), // Now - SessionExpiry: bson.DateTime(now.Add(services.SessionDuration).UnixMilli()), - PromptTokens: 200, - CompletionTokens: 300, - TotalTokens: 500, - RequestCount: 10, + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.UnixMilli()), // Now + SessionExpiry: bson.DateTime(now.Add(services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 200, + CompletionTokens: 300, + TotalTokens: 500, + RequestCount: 10, + }, + }, }, } @@ -319,11 +420,13 @@ func TestUsageService_GetWeeklyUsage_MultipleSessions(t *testing.T) { require.NoError(t, err) require.NotNil(t, stats) - // Verify aggregation - assert.Equal(t, int64(350), stats.PromptTokens) - assert.Equal(t, int64(575), stats.CompletionTokens) - assert.Equal(t, int64(925), stats.TotalTokens) - assert.Equal(t, int64(18), stats.RequestCount) + // Verify aggregation per model + require.NotNil(t, stats.Models) + require.NotNil(t, stats.Models["gpt-4"]) + assert.Equal(t, int64(350), stats.Models["gpt-4"].PromptTokens) + assert.Equal(t, 
int64(575), stats.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(925), stats.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(18), stats.Models["gpt-4"].RequestCount) assert.Equal(t, int64(3), stats.SessionCount) } @@ -337,28 +440,36 @@ func TestUsageService_GetWeeklyUsage_ExcludesOldSessions(t *testing.T) { // Create an old session (from last week) oldSession := models.LLMSession{ - ID: bson.NewObjectID(), - UserID: userID, - SessionStart: bson.DateTime(now.Add(-10 * 24 * time.Hour).UnixMilli()), // 10 days ago - SessionExpiry: bson.DateTime(now.Add(-10*24*time.Hour + services.SessionDuration).UnixMilli()), - PromptTokens: 1000, - CompletionTokens: 2000, - TotalTokens: 3000, - RequestCount: 50, + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-10 * 24 * time.Hour).UnixMilli()), // 10 days ago + SessionExpiry: bson.DateTime(now.Add(-10*24*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 1000, + CompletionTokens: 2000, + TotalTokens: 3000, + RequestCount: 50, + }, + }, } _, err := collection.InsertOne(ctx, oldSession) require.NoError(t, err) // Create a current session currentSession := models.LLMSession{ - ID: bson.NewObjectID(), - UserID: userID, - SessionStart: bson.DateTime(now.UnixMilli()), - SessionExpiry: bson.DateTime(now.Add(services.SessionDuration).UnixMilli()), - PromptTokens: 100, - CompletionTokens: 200, - TotalTokens: 300, - RequestCount: 5, + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 5, + }, + }, } _, err = collection.InsertOne(ctx, currentSession) require.NoError(t, err) @@ -369,10 +480,12 @@ func TestUsageService_GetWeeklyUsage_ExcludesOldSessions(t *testing.T) { 
require.NotNil(t, stats) // Should only include the current session - assert.Equal(t, int64(100), stats.PromptTokens) - assert.Equal(t, int64(200), stats.CompletionTokens) - assert.Equal(t, int64(300), stats.TotalTokens) - assert.Equal(t, int64(5), stats.RequestCount) + require.NotNil(t, stats.Models) + require.NotNil(t, stats.Models["gpt-4"]) + assert.Equal(t, int64(100), stats.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(200), stats.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(300), stats.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(5), stats.Models["gpt-4"].RequestCount) assert.Equal(t, int64(1), stats.SessionCount) } @@ -386,11 +499,8 @@ func TestUsageService_GetWeeklyUsage_NoSessions(t *testing.T) { require.NoError(t, err) require.NotNil(t, stats) - // Should return zero stats - assert.Equal(t, int64(0), stats.PromptTokens) - assert.Equal(t, int64(0), stats.CompletionTokens) - assert.Equal(t, int64(0), stats.TotalTokens) - assert.Equal(t, int64(0), stats.RequestCount) + // Should return empty models map + assert.Empty(t, stats.Models) assert.Equal(t, int64(0), stats.SessionCount) } @@ -404,34 +514,46 @@ func TestUsageService_ListRecentSessions(t *testing.T) { now := time.Now() sessions := []models.LLMSession{ { - ID: bson.NewObjectID(), - UserID: userID, - SessionStart: bson.DateTime(now.Add(-3 * 24 * time.Hour).UnixMilli()), - SessionExpiry: bson.DateTime(now.Add(-3*24*time.Hour + services.SessionDuration).UnixMilli()), - PromptTokens: 100, - CompletionTokens: 200, - TotalTokens: 300, - RequestCount: 1, + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-3 * 24 * time.Hour).UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(-3*24*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 1, + }, + }, }, { - ID: bson.NewObjectID(), - UserID: userID, - SessionStart: 
bson.DateTime(now.Add(-2 * 24 * time.Hour).UnixMilli()), - SessionExpiry: bson.DateTime(now.Add(-2*24*time.Hour + services.SessionDuration).UnixMilli()), - PromptTokens: 150, - CompletionTokens: 250, - TotalTokens: 400, - RequestCount: 2, + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-2 * 24 * time.Hour).UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(-2*24*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 150, + CompletionTokens: 250, + TotalTokens: 400, + RequestCount: 2, + }, + }, }, { - ID: bson.NewObjectID(), - UserID: userID, - SessionStart: bson.DateTime(now.Add(-1 * 24 * time.Hour).UnixMilli()), - SessionExpiry: bson.DateTime(now.Add(-1*24*time.Hour + services.SessionDuration).UnixMilli()), - PromptTokens: 200, - CompletionTokens: 300, - TotalTokens: 500, - RequestCount: 3, + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-1 * 24 * time.Hour).UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(-1*24*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 200, + CompletionTokens: 300, + TotalTokens: 500, + RequestCount: 3, + }, + }, }, } @@ -446,8 +568,8 @@ func TestUsageService_ListRecentSessions(t *testing.T) { assert.Len(t, recent, 2) // Should be in reverse chronological order (most recent first) - assert.Equal(t, int64(200), recent[0].PromptTokens) // Most recent - assert.Equal(t, int64(150), recent[1].PromptTokens) // Second most recent + assert.Equal(t, int64(200), recent[0].Models["gpt-4"].PromptTokens) // Most recent + assert.Equal(t, int64(150), recent[1].Models["gpt-4"].PromptTokens) // Second most recent // List all sessions all, err := svc.ListRecentSessions(ctx, userID, 10) @@ -516,34 +638,46 @@ func TestUsageService_GetWeeklyUsage_WeekBoundary(t *testing.T) { // Create sessions on both sides of the week boundary sessions := 
[]models.LLMSession{ { - ID: bson.NewObjectID(), - UserID: userID, - SessionStart: bson.DateTime(weekStart.Add(-1 * time.Hour).UnixMilli()), // Just before week start - SessionExpiry: bson.DateTime(weekStart.Add(-1*time.Hour + services.SessionDuration).UnixMilli()), - PromptTokens: 100, - CompletionTokens: 200, - TotalTokens: 300, - RequestCount: 1, + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(weekStart.Add(-1 * time.Hour).UnixMilli()), // Just before week start + SessionExpiry: bson.DateTime(weekStart.Add(-1*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 1, + }, + }, }, { - ID: bson.NewObjectID(), - UserID: userID, - SessionStart: bson.DateTime(weekStart.UnixMilli()), // Exactly at week start - SessionExpiry: bson.DateTime(weekStart.Add(services.SessionDuration).UnixMilli()), - PromptTokens: 50, - CompletionTokens: 75, - TotalTokens: 125, - RequestCount: 1, + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(weekStart.UnixMilli()), // Exactly at week start + SessionExpiry: bson.DateTime(weekStart.Add(services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 50, + CompletionTokens: 75, + TotalTokens: 125, + RequestCount: 1, + }, + }, }, { - ID: bson.NewObjectID(), - UserID: userID, - SessionStart: bson.DateTime(weekStart.Add(1 * time.Hour).UnixMilli()), // Just after week start - SessionExpiry: bson.DateTime(weekStart.Add(1*time.Hour + services.SessionDuration).UnixMilli()), - PromptTokens: 25, - CompletionTokens: 50, - TotalTokens: 75, - RequestCount: 1, + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(weekStart.Add(1 * time.Hour).UnixMilli()), // Just after week start + SessionExpiry: bson.DateTime(weekStart.Add(1*time.Hour + services.SessionDuration).UnixMilli()), + Models: 
map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 25, + CompletionTokens: 50, + TotalTokens: 75, + RequestCount: 1, + }, + }, }, } @@ -557,9 +691,11 @@ func TestUsageService_GetWeeklyUsage_WeekBoundary(t *testing.T) { require.NotNil(t, stats) // Should only include sessions at or after week start (last 2 sessions) - assert.Equal(t, int64(75), stats.PromptTokens) - assert.Equal(t, int64(125), stats.CompletionTokens) - assert.Equal(t, int64(200), stats.TotalTokens) - assert.Equal(t, int64(2), stats.RequestCount) + require.NotNil(t, stats.Models) + require.NotNil(t, stats.Models["gpt-4"]) + assert.Equal(t, int64(75), stats.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(125), stats.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(200), stats.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(2), stats.Models["gpt-4"].RequestCount) assert.Equal(t, int64(2), stats.SessionCount) } diff --git a/pkg/gen/api/usage/v1/usage.pb.go b/pkg/gen/api/usage/v1/usage.pb.go index 1fcf6299..38d4fd62 100644 --- a/pkg/gen/api/usage/v1/usage.pb.go +++ b/pkg/gen/api/usage/v1/usage.pb.go @@ -23,17 +23,86 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type ModelTokens struct { + state protoimpl.MessageState `protogen:"open.v1"` + PromptTokens int64 `protobuf:"varint,1,opt,name=prompt_tokens,json=promptTokens,proto3" json:"prompt_tokens,omitempty"` + CompletionTokens int64 `protobuf:"varint,2,opt,name=completion_tokens,json=completionTokens,proto3" json:"completion_tokens,omitempty"` + TotalTokens int64 `protobuf:"varint,3,opt,name=total_tokens,json=totalTokens,proto3" json:"total_tokens,omitempty"` + RequestCount int64 `protobuf:"varint,4,opt,name=request_count,json=requestCount,proto3" json:"request_count,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ModelTokens) Reset() { + *x = ModelTokens{} + mi := &file_usage_v1_usage_proto_msgTypes[0] + ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ModelTokens) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ModelTokens) ProtoMessage() {} + +func (x *ModelTokens) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ModelTokens.ProtoReflect.Descriptor instead. +func (*ModelTokens) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{0} +} + +func (x *ModelTokens) GetPromptTokens() int64 { + if x != nil { + return x.PromptTokens + } + return 0 +} + +func (x *ModelTokens) GetCompletionTokens() int64 { + if x != nil { + return x.CompletionTokens + } + return 0 +} + +func (x *ModelTokens) GetTotalTokens() int64 { + if x != nil { + return x.TotalTokens + } + return 0 +} + +func (x *ModelTokens) GetRequestCount() int64 { + if x != nil { + return x.RequestCount + } + return 0 +} + type SessionUsage struct { state protoimpl.MessageState `protogen:"open.v1"` SessionExpiry *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=session_expiry,json=sessionExpiry,proto3" json:"session_expiry,omitempty"` - TotalTokens int64 `protobuf:"varint,2,opt,name=total_tokens,json=totalTokens,proto3" json:"total_tokens,omitempty"` + // Tokens per model (model_slug -> tokens) + Models map[string]*ModelTokens `protobuf:"bytes,2,rep,name=models,proto3" json:"models,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *SessionUsage) Reset() { *x = SessionUsage{} - mi := &file_usage_v1_usage_proto_msgTypes[0] + mi := &file_usage_v1_usage_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -45,7 +114,7 @@ func 
(x *SessionUsage) String() string { func (*SessionUsage) ProtoMessage() {} func (x *SessionUsage) ProtoReflect() protoreflect.Message { - mi := &file_usage_v1_usage_proto_msgTypes[0] + mi := &file_usage_v1_usage_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -58,7 +127,7 @@ func (x *SessionUsage) ProtoReflect() protoreflect.Message { // Deprecated: Use SessionUsage.ProtoReflect.Descriptor instead. func (*SessionUsage) Descriptor() ([]byte, []int) { - return file_usage_v1_usage_proto_rawDescGZIP(), []int{0} + return file_usage_v1_usage_proto_rawDescGZIP(), []int{1} } func (x *SessionUsage) GetSessionExpiry() *timestamppb.Timestamp { @@ -68,23 +137,25 @@ func (x *SessionUsage) GetSessionExpiry() *timestamppb.Timestamp { return nil } -func (x *SessionUsage) GetTotalTokens() int64 { +func (x *SessionUsage) GetModels() map[string]*ModelTokens { if x != nil { - return x.TotalTokens + return x.Models } - return 0 + return nil } type WeeklyUsage struct { - state protoimpl.MessageState `protogen:"open.v1"` - TotalTokens int64 `protobuf:"varint,1,opt,name=total_tokens,json=totalTokens,proto3" json:"total_tokens,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + // Tokens per model (model_slug -> tokens) + Models map[string]*ModelTokens `protobuf:"bytes,1,rep,name=models,proto3" json:"models,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + SessionCount int64 `protobuf:"varint,2,opt,name=session_count,json=sessionCount,proto3" json:"session_count,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *WeeklyUsage) Reset() { *x = WeeklyUsage{} - mi := &file_usage_v1_usage_proto_msgTypes[1] + mi := &file_usage_v1_usage_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -96,7 +167,7 @@ func (x *WeeklyUsage) String() string { func (*WeeklyUsage) ProtoMessage() {} func 
(x *WeeklyUsage) ProtoReflect() protoreflect.Message { - mi := &file_usage_v1_usage_proto_msgTypes[1] + mi := &file_usage_v1_usage_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -109,12 +180,19 @@ func (x *WeeklyUsage) ProtoReflect() protoreflect.Message { // Deprecated: Use WeeklyUsage.ProtoReflect.Descriptor instead. func (*WeeklyUsage) Descriptor() ([]byte, []int) { - return file_usage_v1_usage_proto_rawDescGZIP(), []int{1} + return file_usage_v1_usage_proto_rawDescGZIP(), []int{2} } -func (x *WeeklyUsage) GetTotalTokens() int64 { +func (x *WeeklyUsage) GetModels() map[string]*ModelTokens { if x != nil { - return x.TotalTokens + return x.Models + } + return nil +} + +func (x *WeeklyUsage) GetSessionCount() int64 { + if x != nil { + return x.SessionCount } return 0 } @@ -127,7 +205,7 @@ type GetSessionUsageRequest struct { func (x *GetSessionUsageRequest) Reset() { *x = GetSessionUsageRequest{} - mi := &file_usage_v1_usage_proto_msgTypes[2] + mi := &file_usage_v1_usage_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -139,7 +217,7 @@ func (x *GetSessionUsageRequest) String() string { func (*GetSessionUsageRequest) ProtoMessage() {} func (x *GetSessionUsageRequest) ProtoReflect() protoreflect.Message { - mi := &file_usage_v1_usage_proto_msgTypes[2] + mi := &file_usage_v1_usage_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -152,7 +230,7 @@ func (x *GetSessionUsageRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetSessionUsageRequest.ProtoReflect.Descriptor instead. 
func (*GetSessionUsageRequest) Descriptor() ([]byte, []int) { - return file_usage_v1_usage_proto_rawDescGZIP(), []int{2} + return file_usage_v1_usage_proto_rawDescGZIP(), []int{3} } type GetSessionUsageResponse struct { @@ -165,7 +243,7 @@ type GetSessionUsageResponse struct { func (x *GetSessionUsageResponse) Reset() { *x = GetSessionUsageResponse{} - mi := &file_usage_v1_usage_proto_msgTypes[3] + mi := &file_usage_v1_usage_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -177,7 +255,7 @@ func (x *GetSessionUsageResponse) String() string { func (*GetSessionUsageResponse) ProtoMessage() {} func (x *GetSessionUsageResponse) ProtoReflect() protoreflect.Message { - mi := &file_usage_v1_usage_proto_msgTypes[3] + mi := &file_usage_v1_usage_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -190,7 +268,7 @@ func (x *GetSessionUsageResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetSessionUsageResponse.ProtoReflect.Descriptor instead. 
func (*GetSessionUsageResponse) Descriptor() ([]byte, []int) { - return file_usage_v1_usage_proto_rawDescGZIP(), []int{3} + return file_usage_v1_usage_proto_rawDescGZIP(), []int{4} } func (x *GetSessionUsageResponse) GetSession() *SessionUsage { @@ -208,7 +286,7 @@ type GetWeeklyUsageRequest struct { func (x *GetWeeklyUsageRequest) Reset() { *x = GetWeeklyUsageRequest{} - mi := &file_usage_v1_usage_proto_msgTypes[4] + mi := &file_usage_v1_usage_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -220,7 +298,7 @@ func (x *GetWeeklyUsageRequest) String() string { func (*GetWeeklyUsageRequest) ProtoMessage() {} func (x *GetWeeklyUsageRequest) ProtoReflect() protoreflect.Message { - mi := &file_usage_v1_usage_proto_msgTypes[4] + mi := &file_usage_v1_usage_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -233,7 +311,7 @@ func (x *GetWeeklyUsageRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetWeeklyUsageRequest.ProtoReflect.Descriptor instead. 
func (*GetWeeklyUsageRequest) Descriptor() ([]byte, []int) { - return file_usage_v1_usage_proto_rawDescGZIP(), []int{4} + return file_usage_v1_usage_proto_rawDescGZIP(), []int{5} } type GetWeeklyUsageResponse struct { @@ -245,7 +323,7 @@ type GetWeeklyUsageResponse struct { func (x *GetWeeklyUsageResponse) Reset() { *x = GetWeeklyUsageResponse{} - mi := &file_usage_v1_usage_proto_msgTypes[5] + mi := &file_usage_v1_usage_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -257,7 +335,7 @@ func (x *GetWeeklyUsageResponse) String() string { func (*GetWeeklyUsageResponse) ProtoMessage() {} func (x *GetWeeklyUsageResponse) ProtoReflect() protoreflect.Message { - mi := &file_usage_v1_usage_proto_msgTypes[5] + mi := &file_usage_v1_usage_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -270,7 +348,7 @@ func (x *GetWeeklyUsageResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetWeeklyUsageResponse.ProtoReflect.Descriptor instead. 
func (*GetWeeklyUsageResponse) Descriptor() ([]byte, []int) { - return file_usage_v1_usage_proto_rawDescGZIP(), []int{5} + return file_usage_v1_usage_proto_rawDescGZIP(), []int{6} } func (x *GetWeeklyUsageResponse) GetUsage() *WeeklyUsage { @@ -284,12 +362,24 @@ var File_usage_v1_usage_proto protoreflect.FileDescriptor const file_usage_v1_usage_proto_rawDesc = "" + "\n" + - "\x14usage/v1/usage.proto\x12\busage.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"t\n" + + "\x14usage/v1/usage.proto\x12\busage.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xa7\x01\n" + + "\vModelTokens\x12#\n" + + "\rprompt_tokens\x18\x01 \x01(\x03R\fpromptTokens\x12+\n" + + "\x11completion_tokens\x18\x02 \x01(\x03R\x10completionTokens\x12!\n" + + "\ftotal_tokens\x18\x03 \x01(\x03R\vtotalTokens\x12#\n" + + "\rrequest_count\x18\x04 \x01(\x03R\frequestCount\"\xdf\x01\n" + "\fSessionUsage\x12A\n" + - "\x0esession_expiry\x18\x01 \x01(\v2\x1a.google.protobuf.TimestampR\rsessionExpiry\x12!\n" + - "\ftotal_tokens\x18\x02 \x01(\x03R\vtotalTokens\"0\n" + - "\vWeeklyUsage\x12!\n" + - "\ftotal_tokens\x18\x01 \x01(\x03R\vtotalTokens\"\x18\n" + + "\x0esession_expiry\x18\x01 \x01(\v2\x1a.google.protobuf.TimestampR\rsessionExpiry\x12:\n" + + "\x06models\x18\x02 \x03(\v2\".usage.v1.SessionUsage.ModelsEntryR\x06models\x1aP\n" + + "\vModelsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12+\n" + + "\x05value\x18\x02 \x01(\v2\x15.usage.v1.ModelTokensR\x05value:\x028\x01\"\xbf\x01\n" + + "\vWeeklyUsage\x129\n" + + "\x06models\x18\x01 \x03(\v2!.usage.v1.WeeklyUsage.ModelsEntryR\x06models\x12#\n" + + "\rsession_count\x18\x02 \x01(\x03R\fsessionCount\x1aP\n" + + "\vModelsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12+\n" + + "\x05value\x18\x02 \x01(\v2\x15.usage.v1.ModelTokensR\x05value:\x028\x01\"\x18\n" + "\x16GetSessionUsageRequest\"K\n" + "\x17GetSessionUsageResponse\x120\n" + "\asession\x18\x01 
\x01(\v2\x16.usage.v1.SessionUsageR\asession\"\x17\n" + @@ -314,29 +404,36 @@ func file_usage_v1_usage_proto_rawDescGZIP() []byte { return file_usage_v1_usage_proto_rawDescData } -var file_usage_v1_usage_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_usage_v1_usage_proto_msgTypes = make([]protoimpl.MessageInfo, 9) var file_usage_v1_usage_proto_goTypes = []any{ - (*SessionUsage)(nil), // 0: usage.v1.SessionUsage - (*WeeklyUsage)(nil), // 1: usage.v1.WeeklyUsage - (*GetSessionUsageRequest)(nil), // 2: usage.v1.GetSessionUsageRequest - (*GetSessionUsageResponse)(nil), // 3: usage.v1.GetSessionUsageResponse - (*GetWeeklyUsageRequest)(nil), // 4: usage.v1.GetWeeklyUsageRequest - (*GetWeeklyUsageResponse)(nil), // 5: usage.v1.GetWeeklyUsageResponse - (*timestamppb.Timestamp)(nil), // 6: google.protobuf.Timestamp + (*ModelTokens)(nil), // 0: usage.v1.ModelTokens + (*SessionUsage)(nil), // 1: usage.v1.SessionUsage + (*WeeklyUsage)(nil), // 2: usage.v1.WeeklyUsage + (*GetSessionUsageRequest)(nil), // 3: usage.v1.GetSessionUsageRequest + (*GetSessionUsageResponse)(nil), // 4: usage.v1.GetSessionUsageResponse + (*GetWeeklyUsageRequest)(nil), // 5: usage.v1.GetWeeklyUsageRequest + (*GetWeeklyUsageResponse)(nil), // 6: usage.v1.GetWeeklyUsageResponse + nil, // 7: usage.v1.SessionUsage.ModelsEntry + nil, // 8: usage.v1.WeeklyUsage.ModelsEntry + (*timestamppb.Timestamp)(nil), // 9: google.protobuf.Timestamp } var file_usage_v1_usage_proto_depIdxs = []int32{ - 6, // 0: usage.v1.SessionUsage.session_expiry:type_name -> google.protobuf.Timestamp - 0, // 1: usage.v1.GetSessionUsageResponse.session:type_name -> usage.v1.SessionUsage - 1, // 2: usage.v1.GetWeeklyUsageResponse.usage:type_name -> usage.v1.WeeklyUsage - 2, // 3: usage.v1.UsageService.GetSessionUsage:input_type -> usage.v1.GetSessionUsageRequest - 4, // 4: usage.v1.UsageService.GetWeeklyUsage:input_type -> usage.v1.GetWeeklyUsageRequest - 3, // 5: usage.v1.UsageService.GetSessionUsage:output_type -> 
usage.v1.GetSessionUsageResponse - 5, // 6: usage.v1.UsageService.GetWeeklyUsage:output_type -> usage.v1.GetWeeklyUsageResponse - 5, // [5:7] is the sub-list for method output_type - 3, // [3:5] is the sub-list for method input_type - 3, // [3:3] is the sub-list for extension type_name - 3, // [3:3] is the sub-list for extension extendee - 0, // [0:3] is the sub-list for field type_name + 9, // 0: usage.v1.SessionUsage.session_expiry:type_name -> google.protobuf.Timestamp + 7, // 1: usage.v1.SessionUsage.models:type_name -> usage.v1.SessionUsage.ModelsEntry + 8, // 2: usage.v1.WeeklyUsage.models:type_name -> usage.v1.WeeklyUsage.ModelsEntry + 1, // 3: usage.v1.GetSessionUsageResponse.session:type_name -> usage.v1.SessionUsage + 2, // 4: usage.v1.GetWeeklyUsageResponse.usage:type_name -> usage.v1.WeeklyUsage + 0, // 5: usage.v1.SessionUsage.ModelsEntry.value:type_name -> usage.v1.ModelTokens + 0, // 6: usage.v1.WeeklyUsage.ModelsEntry.value:type_name -> usage.v1.ModelTokens + 3, // 7: usage.v1.UsageService.GetSessionUsage:input_type -> usage.v1.GetSessionUsageRequest + 5, // 8: usage.v1.UsageService.GetWeeklyUsage:input_type -> usage.v1.GetWeeklyUsageRequest + 4, // 9: usage.v1.UsageService.GetSessionUsage:output_type -> usage.v1.GetSessionUsageResponse + 6, // 10: usage.v1.UsageService.GetWeeklyUsage:output_type -> usage.v1.GetWeeklyUsageResponse + 9, // [9:11] is the sub-list for method output_type + 7, // [7:9] is the sub-list for method input_type + 7, // [7:7] is the sub-list for extension type_name + 7, // [7:7] is the sub-list for extension extendee + 0, // [0:7] is the sub-list for field type_name } func init() { file_usage_v1_usage_proto_init() } @@ -350,7 +447,7 @@ func file_usage_v1_usage_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_usage_v1_usage_proto_rawDesc), len(file_usage_v1_usage_proto_rawDesc)), NumEnums: 0, - NumMessages: 6, + NumMessages: 9, NumExtensions: 0, NumServices: 1, }, 
diff --git a/proto/usage/v1/usage.proto b/proto/usage/v1/usage.proto index d9141dd0..5696fa06 100644 --- a/proto/usage/v1/usage.proto +++ b/proto/usage/v1/usage.proto @@ -17,13 +17,23 @@ service UsageService { } } +message ModelTokens { + int64 prompt_tokens = 1; + int64 completion_tokens = 2; + int64 total_tokens = 3; + int64 request_count = 4; +} + message SessionUsage { google.protobuf.Timestamp session_expiry = 1; - int64 total_tokens = 2; + // Tokens per model (model_slug -> tokens) + map models = 2; } message WeeklyUsage { - int64 total_tokens = 1; + // Tokens per model (model_slug -> tokens) + map models = 1; + int64 session_count = 2; } message GetSessionUsageRequest {} diff --git a/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts b/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts index 35ec21ae..efba93f5 100644 --- a/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts +++ b/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts @@ -13,7 +13,39 @@ import type { Message } from "@bufbuild/protobuf"; * Describes the file usage/v1/usage.proto. 
*/ export const file_usage_v1_usage: GenFile = /*@__PURE__*/ - fileDesc("ChR1c2FnZS92MS91c2FnZS5wcm90bxIIdXNhZ2UudjEiWAoMU2Vzc2lvblVzYWdlEjIKDnNlc3Npb25fZXhwaXJ5GAEgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcBIUCgx0b3RhbF90b2tlbnMYAiABKAMiIwoLV2Vla2x5VXNhZ2USFAoMdG90YWxfdG9rZW5zGAEgASgDIhgKFkdldFNlc3Npb25Vc2FnZVJlcXVlc3QiQgoXR2V0U2Vzc2lvblVzYWdlUmVzcG9uc2USJwoHc2Vzc2lvbhgBIAEoCzIWLnVzYWdlLnYxLlNlc3Npb25Vc2FnZSIXChVHZXRXZWVrbHlVc2FnZVJlcXVlc3QiPgoWR2V0V2Vla2x5VXNhZ2VSZXNwb25zZRIkCgV1c2FnZRgBIAEoCzIVLnVzYWdlLnYxLldlZWtseVVzYWdlMpoCCgxVc2FnZVNlcnZpY2UShQEKD0dldFNlc3Npb25Vc2FnZRIgLnVzYWdlLnYxLkdldFNlc3Npb25Vc2FnZVJlcXVlc3QaIS51c2FnZS52MS5HZXRTZXNzaW9uVXNhZ2VSZXNwb25zZSItgtPkkwInEiUvX3BkL2FwaS92MS91c2Vycy9Ac2VsZi91c2FnZS9zZXNzaW9uEoEBCg5HZXRXZWVrbHlVc2FnZRIfLnVzYWdlLnYxLkdldFdlZWtseVVzYWdlUmVxdWVzdBogLnVzYWdlLnYxLkdldFdlZWtseVVzYWdlUmVzcG9uc2UiLILT5JMCJhIkL19wZC9hcGkvdjEvdXNlcnMvQHNlbGYvdXNhZ2Uvd2Vla2x5QocBCgxjb20udXNhZ2UudjFCClVzYWdlUHJvdG9QAVoqcGFwZXJkZWJ1Z2dlci9wa2cvZ2VuL2FwaS91c2FnZS92MTt1c2FnZXYxogIDVVhYqgIIVXNhZ2UuVjHKAghVc2FnZVxWMeICFFVzYWdlXFYxXEdQQk1ldGFkYXRh6gIJVXNhZ2U6OlYxYgZwcm90bzM", [file_google_api_annotations, file_google_protobuf_timestamp]); + 
fileDesc("ChR1c2FnZS92MS91c2FnZS5wcm90bxIIdXNhZ2UudjEibAoLTW9kZWxUb2tlbnMSFQoNcHJvbXB0X3Rva2VucxgBIAEoAxIZChFjb21wbGV0aW9uX3Rva2VucxgCIAEoAxIUCgx0b3RhbF90b2tlbnMYAyABKAMSFQoNcmVxdWVzdF9jb3VudBgEIAEoAyK8AQoMU2Vzc2lvblVzYWdlEjIKDnNlc3Npb25fZXhwaXJ5GAEgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcBIyCgZtb2RlbHMYAiADKAsyIi51c2FnZS52MS5TZXNzaW9uVXNhZ2UuTW9kZWxzRW50cnkaRAoLTW9kZWxzRW50cnkSCwoDa2V5GAEgASgJEiQKBXZhbHVlGAIgASgLMhUudXNhZ2UudjEuTW9kZWxUb2tlbnM6AjgBIp0BCgtXZWVrbHlVc2FnZRIxCgZtb2RlbHMYASADKAsyIS51c2FnZS52MS5XZWVrbHlVc2FnZS5Nb2RlbHNFbnRyeRIVCg1zZXNzaW9uX2NvdW50GAIgASgDGkQKC01vZGVsc0VudHJ5EgsKA2tleRgBIAEoCRIkCgV2YWx1ZRgCIAEoCzIVLnVzYWdlLnYxLk1vZGVsVG9rZW5zOgI4ASIYChZHZXRTZXNzaW9uVXNhZ2VSZXF1ZXN0IkIKF0dldFNlc3Npb25Vc2FnZVJlc3BvbnNlEicKB3Nlc3Npb24YASABKAsyFi51c2FnZS52MS5TZXNzaW9uVXNhZ2UiFwoVR2V0V2Vla2x5VXNhZ2VSZXF1ZXN0Ij4KFkdldFdlZWtseVVzYWdlUmVzcG9uc2USJAoFdXNhZ2UYASABKAsyFS51c2FnZS52MS5XZWVrbHlVc2FnZTKaAgoMVXNhZ2VTZXJ2aWNlEoUBCg9HZXRTZXNzaW9uVXNhZ2USIC51c2FnZS52MS5HZXRTZXNzaW9uVXNhZ2VSZXF1ZXN0GiEudXNhZ2UudjEuR2V0U2Vzc2lvblVzYWdlUmVzcG9uc2UiLYLT5JMCJxIlL19wZC9hcGkvdjEvdXNlcnMvQHNlbGYvdXNhZ2Uvc2Vzc2lvbhKBAQoOR2V0V2Vla2x5VXNhZ2USHy51c2FnZS52MS5HZXRXZWVrbHlVc2FnZVJlcXVlc3QaIC51c2FnZS52MS5HZXRXZWVrbHlVc2FnZVJlc3BvbnNlIiyC0+STAiYSJC9fcGQvYXBpL3YxL3VzZXJzL0BzZWxmL3VzYWdlL3dlZWtseUKHAQoMY29tLnVzYWdlLnYxQgpVc2FnZVByb3RvUAFaKnBhcGVyZGVidWdnZXIvcGtnL2dlbi9hcGkvdXNhZ2UvdjE7dXNhZ2V2MaICA1VYWKoCCFVzYWdlLlYxygIIVXNhZ2VcVjHiAhRVc2FnZVxWMVxHUEJNZXRhZGF0YeoCCVVzYWdlOjpWMWIGcHJvdG8z", [file_google_api_annotations, file_google_protobuf_timestamp]); + +/** + * @generated from message usage.v1.ModelTokens + */ +export type ModelTokens = Message<"usage.v1.ModelTokens"> & { + /** + * @generated from field: int64 prompt_tokens = 1; + */ + promptTokens: bigint; + + /** + * @generated from field: int64 completion_tokens = 2; + */ + completionTokens: bigint; + + /** + * @generated from field: int64 total_tokens = 3; + */ + totalTokens: bigint; + + /** + * @generated from field: int64 
request_count = 4; + */ + requestCount: bigint; +}; + +/** + * Describes the message usage.v1.ModelTokens. + * Use `create(ModelTokensSchema)` to create a new message. + */ +export const ModelTokensSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 0); /** * @generated from message usage.v1.SessionUsage @@ -25,9 +57,11 @@ export type SessionUsage = Message<"usage.v1.SessionUsage"> & { sessionExpiry?: Timestamp; /** - * @generated from field: int64 total_tokens = 2; + * Tokens per model (model_slug -> tokens) + * + * @generated from field: map models = 2; */ - totalTokens: bigint; + models: { [key: string]: ModelTokens }; }; /** @@ -35,16 +69,23 @@ export type SessionUsage = Message<"usage.v1.SessionUsage"> & { * Use `create(SessionUsageSchema)` to create a new message. */ export const SessionUsageSchema: GenMessage = /*@__PURE__*/ - messageDesc(file_usage_v1_usage, 0); + messageDesc(file_usage_v1_usage, 1); /** * @generated from message usage.v1.WeeklyUsage */ export type WeeklyUsage = Message<"usage.v1.WeeklyUsage"> & { /** - * @generated from field: int64 total_tokens = 1; + * Tokens per model (model_slug -> tokens) + * + * @generated from field: map models = 1; */ - totalTokens: bigint; + models: { [key: string]: ModelTokens }; + + /** + * @generated from field: int64 session_count = 2; + */ + sessionCount: bigint; }; /** @@ -52,7 +93,7 @@ export type WeeklyUsage = Message<"usage.v1.WeeklyUsage"> & { * Use `create(WeeklyUsageSchema)` to create a new message. */ export const WeeklyUsageSchema: GenMessage = /*@__PURE__*/ - messageDesc(file_usage_v1_usage, 1); + messageDesc(file_usage_v1_usage, 2); /** * @generated from message usage.v1.GetSessionUsageRequest @@ -65,7 +106,7 @@ export type GetSessionUsageRequest = Message<"usage.v1.GetSessionUsageRequest"> * Use `create(GetSessionUsageRequestSchema)` to create a new message. 
*/ export const GetSessionUsageRequestSchema: GenMessage = /*@__PURE__*/ - messageDesc(file_usage_v1_usage, 2); + messageDesc(file_usage_v1_usage, 3); /** * @generated from message usage.v1.GetSessionUsageResponse @@ -84,7 +125,7 @@ export type GetSessionUsageResponse = Message<"usage.v1.GetSessionUsageResponse" * Use `create(GetSessionUsageResponseSchema)` to create a new message. */ export const GetSessionUsageResponseSchema: GenMessage = /*@__PURE__*/ - messageDesc(file_usage_v1_usage, 3); + messageDesc(file_usage_v1_usage, 4); /** * @generated from message usage.v1.GetWeeklyUsageRequest @@ -97,7 +138,7 @@ export type GetWeeklyUsageRequest = Message<"usage.v1.GetWeeklyUsageRequest"> & * Use `create(GetWeeklyUsageRequestSchema)` to create a new message. */ export const GetWeeklyUsageRequestSchema: GenMessage = /*@__PURE__*/ - messageDesc(file_usage_v1_usage, 4); + messageDesc(file_usage_v1_usage, 5); /** * @generated from message usage.v1.GetWeeklyUsageResponse @@ -114,7 +155,7 @@ export type GetWeeklyUsageResponse = Message<"usage.v1.GetWeeklyUsageResponse"> * Use `create(GetWeeklyUsageResponseSchema)` to create a new message. 
*/ export const GetWeeklyUsageResponseSchema: GenMessage = /*@__PURE__*/ - messageDesc(file_usage_v1_usage, 5); + messageDesc(file_usage_v1_usage, 6); /** * @generated from service usage.v1.UsageService diff --git a/webapp/_webapp/src/views/usage/index.tsx b/webapp/_webapp/src/views/usage/index.tsx index 36756be7..bebca148 100644 --- a/webapp/_webapp/src/views/usage/index.tsx +++ b/webapp/_webapp/src/views/usage/index.tsx @@ -4,6 +4,7 @@ import { useState, useEffect } from "react"; import { TabHeader } from "../../components/tab-header"; import { useGetSessionUsageQuery, useGetWeeklyUsageQuery } from "../../query"; import CellWrapper from "../../components/cell-wrapper"; +import type { ModelTokens } from "../../pkg/gen/apiclient/usage/v1/usage_pb"; const formatNumber = (n: bigint | number | undefined): string => { if (n === undefined) return "0"; @@ -50,11 +51,26 @@ const SectionTitle = ({ children }: { children: React.ReactNode }) => { return
{children}
; }; -const StatItem = ({ label, value }: { label: string; value: string }) => { +const ModelUsageItem = ({ model, tokens }: { model: string; tokens: ModelTokens }) => { return ( -
- {label} - {value} +
+
{model}
+
+ Total + {formatNumber(tokens.totalTokens)} +
+
+ Prompt + {formatNumber(tokens.promptTokens)} +
+
+ Completion + {formatNumber(tokens.completionTokens)} +
+
+ Requests + {formatNumber(tokens.requestCount)} +
); }; @@ -101,6 +117,9 @@ export const Usage = () => { const session = sessionData?.session; const weekly = weeklyData?.usage; + const sessionModels = session?.models ? Object.entries(session.models) : []; + const weeklyModels = weekly?.models ? Object.entries(weekly.models) : []; + return (
@@ -112,10 +131,12 @@ export const Usage = () => { ({formatTimeRemaining(session.sessionExpiry)}) )} - {session ? ( + {session && sessionModels.length > 0 ? ( -
- +
+ {sessionModels.map(([model, tokens]) => ( + + ))}
) : ( @@ -126,11 +147,13 @@ export const Usage = () => { - Weekly Limits - {weekly ? ( + Weekly Usage + {weekly && weeklyModels.length > 0 ? ( -
- +
+ {weeklyModels.map(([model, tokens]) => ( + + ))}
) : ( From 18402b2b92d2e8e7a8cd2b1555d2a8efe125742f Mon Sep 17 00:00:00 2001 From: wjiayis Date: Wed, 4 Mar 2026 20:35:28 +0800 Subject: [PATCH 11/13] feat: display USD token prices --- internal/api/server.go | 17 +- internal/api/usage/get_session_usage.go | 25 ++- internal/api/usage/get_weekly_usage.go | 25 ++- internal/api/usage/server.go | 11 +- internal/models/model_pricing.go | 23 ++ internal/services/pricing.go | 204 ++++++++++++++++++ internal/services/usage.go | 23 +- internal/wire.go | 1 + internal/wire_gen.go | 7 +- pkg/gen/api/usage/v1/usage.pb.go | 37 +++- proto/usage/v1/usage.proto | 3 + .../pkg/gen/apiclient/usage/v1/usage_pb.ts | 23 +- webapp/_webapp/src/views/usage/index.tsx | 42 ++-- 13 files changed, 396 insertions(+), 45 deletions(-) create mode 100644 internal/models/model_pricing.go create mode 100644 internal/services/pricing.go diff --git a/internal/api/server.go b/internal/api/server.go index d8e9b36a..c377d5a3 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -11,6 +11,7 @@ import ( "paperdebugger/internal/libs/logger" "paperdebugger/internal/libs/metadatautil" "paperdebugger/internal/libs/shared" + "paperdebugger/internal/services" authv1 "paperdebugger/pkg/gen/api/auth/v1" chatv1 "paperdebugger/pkg/gen/api/chat/v1" chatv2 "paperdebugger/pkg/gen/api/chat/v2" @@ -31,8 +32,9 @@ import ( ) type Server struct { - grpcServer *GrpcServer - ginServer *GinServer + grpcServer *GrpcServer + ginServer *GinServer + pricingService *services.PricingService logger *logger.Logger } @@ -40,16 +42,21 @@ type Server struct { func NewServer( grpcServer *GrpcServer, ginServer *GinServer, + pricingService *services.PricingService, logger *logger.Logger, ) *Server { return &Server{ - grpcServer: grpcServer, - ginServer: ginServer, - logger: logger, + grpcServer: grpcServer, + ginServer: ginServer, + pricingService: pricingService, + logger: logger, } } func (s *Server) Run(addr string) { + // Start the pricing updater in the background + 
s.pricingService.StartPriceUpdater(context.Background()) + listener, err := net.Listen("tcp", ":0") if err != nil { s.logger.Fatalf("failed to start grpc server listener: %v", err) diff --git a/internal/api/usage/get_session_usage.go b/internal/api/usage/get_session_usage.go index 5eb35cad..0b569943 100644 --- a/internal/api/usage/get_session_usage.go +++ b/internal/api/usage/get_session_usage.go @@ -4,6 +4,7 @@ import ( "context" "paperdebugger/internal/libs/contextutil" + "paperdebugger/internal/models" usagev1 "paperdebugger/pkg/gen/api/usage/v1" "google.golang.org/protobuf/types/known/timestamppb" @@ -29,21 +30,37 @@ func (s *UsageServer) GetSessionUsage( }, nil } - // Convert models map to proto format - models := make(map[string]*usagev1.ModelTokens) + // Get pricing map for cost calculation + pricingMap, err := s.pricingService.GetPricingMap(ctx) + if err != nil { + s.logger.Warn("Failed to get pricing map", "error", err) + pricingMap = make(map[string]*models.ModelPricing) + } + + // Convert models map to proto format and calculate costs + protoModels := make(map[string]*usagev1.ModelTokens) + var totalCostUSD float64 for modelName, tokens := range session.Models { - models[modelName] = &usagev1.ModelTokens{ + var costUSD float64 + if pricing, ok := pricingMap[modelName]; ok && pricing != nil { + costUSD = float64(tokens.PromptTokens)*pricing.PromptPrice + + float64(tokens.CompletionTokens)*pricing.CompletionPrice + totalCostUSD += costUSD + } + protoModels[modelName] = &usagev1.ModelTokens{ PromptTokens: tokens.PromptTokens, CompletionTokens: tokens.CompletionTokens, TotalTokens: tokens.TotalTokens, RequestCount: tokens.RequestCount, + CostUsd: costUSD, } } return &usagev1.GetSessionUsageResponse{ Session: &usagev1.SessionUsage{ SessionExpiry: timestamppb.New(session.SessionExpiry.Time()), - Models: models, + Models: protoModels, + TotalCostUsd: totalCostUSD, }, }, nil } diff --git a/internal/api/usage/get_weekly_usage.go 
b/internal/api/usage/get_weekly_usage.go index e244b3c6..c791858f 100644 --- a/internal/api/usage/get_weekly_usage.go +++ b/internal/api/usage/get_weekly_usage.go @@ -4,6 +4,7 @@ import ( "context" "paperdebugger/internal/libs/contextutil" + "paperdebugger/internal/models" usagev1 "paperdebugger/pkg/gen/api/usage/v1" ) @@ -21,21 +22,37 @@ func (s *UsageServer) GetWeeklyUsage( return nil, err } - // Convert models map to proto format - models := make(map[string]*usagev1.ModelTokens) + // Get pricing map for cost calculation + pricingMap, err := s.pricingService.GetPricingMap(ctx) + if err != nil { + s.logger.Warn("Failed to get pricing map", "error", err) + pricingMap = make(map[string]*models.ModelPricing) + } + + // Convert models map to proto format and calculate costs + protoModels := make(map[string]*usagev1.ModelTokens) + var totalCostUSD float64 for modelName, tokens := range stats.Models { - models[modelName] = &usagev1.ModelTokens{ + var costUSD float64 + if pricing, ok := pricingMap[modelName]; ok && pricing != nil { + costUSD = float64(tokens.PromptTokens)*pricing.PromptPrice + + float64(tokens.CompletionTokens)*pricing.CompletionPrice + totalCostUSD += costUSD + } + protoModels[modelName] = &usagev1.ModelTokens{ PromptTokens: tokens.PromptTokens, CompletionTokens: tokens.CompletionTokens, TotalTokens: tokens.TotalTokens, RequestCount: tokens.RequestCount, + CostUsd: costUSD, } } return &usagev1.GetWeeklyUsageResponse{ Usage: &usagev1.WeeklyUsage{ - Models: models, + Models: protoModels, SessionCount: stats.SessionCount, + TotalCostUsd: totalCostUSD, }, }, nil } diff --git a/internal/api/usage/server.go b/internal/api/usage/server.go index 5d64854e..69a4eb77 100644 --- a/internal/api/usage/server.go +++ b/internal/api/usage/server.go @@ -9,16 +9,19 @@ import ( type UsageServer struct { usagev1.UnimplementedUsageServiceServer - usageService *services.UsageService - logger *logger.Logger + usageService *services.UsageService + pricingService 
*services.PricingService + logger *logger.Logger } func NewUsageServer( usageService *services.UsageService, + pricingService *services.PricingService, logger *logger.Logger, ) usagev1.UsageServiceServer { return &UsageServer{ - usageService: usageService, - logger: logger, + usageService: usageService, + pricingService: pricingService, + logger: logger, } } diff --git a/internal/models/model_pricing.go b/internal/models/model_pricing.go new file mode 100644 index 00000000..adfbf114 --- /dev/null +++ b/internal/models/model_pricing.go @@ -0,0 +1,23 @@ +package models + +import ( + "time" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +// ModelPricing stores the pricing information for an LLM model. +// Prices are in USD per token. +type ModelPricing struct { + ID bson.ObjectID `bson:"_id"` + ModelID string `bson:"model_id"` // e.g., "openai/gpt-4" + ModelSlug string `bson:"model_slug"` // e.g., "gpt-4" (short name used in our app) + Name string `bson:"name"` // e.g., "OpenAI: GPT-4" + PromptPrice float64 `bson:"prompt_price"` // USD per token + CompletionPrice float64 `bson:"completion_price"` // USD per token + UpdatedAt time.Time `bson:"updated_at"` +} + +func (m ModelPricing) CollectionName() string { + return "model_pricing" +} diff --git a/internal/services/pricing.go b/internal/services/pricing.go new file mode 100644 index 00000000..3862c689 --- /dev/null +++ b/internal/services/pricing.go @@ -0,0 +1,204 @@ +package services + +import ( + "context" + "encoding/json" + "net/http" + "strconv" + "strings" + "time" + + "paperdebugger/internal/libs/cfg" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/libs/logger" + "paperdebugger/internal/models" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +const ( + OpenRouterModelsURL = "https://openrouter.ai/api/v1/models" + PriceRefreshInterval = 24 * time.Hour +) + +type PricingService struct { + BaseService + collection 
*mongo.Collection + httpClient *http.Client +} + +// OpenRouterModel represents a model from the OpenRouter API. +type OpenRouterModel struct { + ID string `json:"id"` + Name string `json:"name"` + Pricing struct { + Prompt string `json:"prompt"` + Completion string `json:"completion"` + } `json:"pricing"` +} + +// OpenRouterResponse is the response from the OpenRouter models API. +type OpenRouterResponse struct { + Data []OpenRouterModel `json:"data"` +} + +func NewPricingService(db *db.DB, cfg *cfg.Cfg, logger *logger.Logger) *PricingService { + base := NewBaseService(db, cfg, logger) + return &PricingService{ + BaseService: base, + collection: base.db.Collection((models.ModelPricing{}).CollectionName()), + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// FetchAndUpdatePrices fetches model prices from OpenRouter and updates the database. +func (s *PricingService) FetchAndUpdatePrices(ctx context.Context) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, OpenRouterModelsURL, nil) + if err != nil { + return err + } + + resp, err := s.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + var openRouterResp OpenRouterResponse + if err := json.NewDecoder(resp.Body).Decode(&openRouterResp); err != nil { + return err + } + + now := time.Now() + for _, model := range openRouterResp.Data { + promptPrice, _ := strconv.ParseFloat(model.Pricing.Prompt, 64) + completionPrice, _ := strconv.ParseFloat(model.Pricing.Completion, 64) + + // Skip models with no pricing + if promptPrice == 0 && completionPrice == 0 { + continue + } + + // Extract model slug (short name) from the full model ID + // e.g., "openai/gpt-4" -> "gpt-4" + modelSlug := extractModelSlug(model.ID) + + filter := bson.M{"model_id": model.ID} + update := bson.M{ + "$set": bson.M{ + "model_id": model.ID, + "model_slug": modelSlug, + "name": model.Name, + "prompt_price": promptPrice, + "completion_price": completionPrice, + "updated_at": now, 
+ }, + "$setOnInsert": bson.M{ + "_id": bson.NewObjectID(), + }, + } + opts := options.UpdateOne().SetUpsert(true) + _, err := s.collection.UpdateOne(ctx, filter, update, opts) + if err != nil { + s.logger.Warn("Failed to update model pricing", "modelID", model.ID, "error", err) + } + } + + s.logger.Info("Updated model pricing", "count", len(openRouterResp.Data)) + return nil +} + +// GetPricing returns the pricing for a model by its slug. +func (s *PricingService) GetPricing(ctx context.Context, modelSlug string) (*models.ModelPricing, error) { + // Try exact match first + filter := bson.M{"model_slug": modelSlug} + var pricing models.ModelPricing + err := s.collection.FindOne(ctx, filter).Decode(&pricing) + if err == nil { + return &pricing, nil + } + if err != mongo.ErrNoDocuments { + return nil, err + } + + // Try partial match (model slug might be a prefix) + filter = bson.M{"model_slug": bson.M{"$regex": "^" + modelSlug}} + err = s.collection.FindOne(ctx, filter).Decode(&pricing) + if err == mongo.ErrNoDocuments { + return nil, nil + } + if err != nil { + return nil, err + } + return &pricing, nil +} + +// GetAllPricing returns all model pricing. +func (s *PricingService) GetAllPricing(ctx context.Context) ([]models.ModelPricing, error) { + cursor, err := s.collection.Find(ctx, bson.M{}) + if err != nil { + return nil, err + } + defer cursor.Close(ctx) + + var pricings []models.ModelPricing + if err := cursor.All(ctx, &pricings); err != nil { + return nil, err + } + return pricings, nil +} + +// GetPricingMap returns a map of model slug to pricing for quick lookup. 
+func (s *PricingService) GetPricingMap(ctx context.Context) (map[string]*models.ModelPricing, error) { + pricings, err := s.GetAllPricing(ctx) + if err != nil { + return nil, err + } + + result := make(map[string]*models.ModelPricing) + for i := range pricings { + result[pricings[i].ModelSlug] = &pricings[i] + } + return result, nil +} + +// extractModelSlug extracts the short model name from a full model ID. +// e.g., "openai/gpt-4" -> "gpt-4", "anthropic/claude-3-opus" -> "claude-3-opus" +func extractModelSlug(modelID string) string { + parts := strings.Split(modelID, "/") + if len(parts) > 1 { + return parts[len(parts)-1] + } + return modelID +} + +// StartPriceUpdater starts a background goroutine that periodically updates prices. +func (s *PricingService) StartPriceUpdater(ctx context.Context) { + // Fetch immediately on startup + go func() { + if err := s.FetchAndUpdatePrices(ctx); err != nil { + s.logger.Error("Failed to fetch initial model pricing", "error", err) + } + }() + + // Then fetch periodically + go func() { + ticker := time.NewTicker(PriceRefreshInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.FetchAndUpdatePrices(context.Background()); err != nil { + s.logger.Error("Failed to update model pricing", "error", err) + } + } + } + }() +} diff --git a/internal/services/usage.go b/internal/services/usage.go index 332e243e..4c9d75f4 100644 --- a/internal/services/usage.go +++ b/internal/services/usage.go @@ -31,15 +31,30 @@ type UsageRecord struct { // ModelUsageStats stores aggregated usage statistics for a specific model. 
type ModelUsageStats struct { - PromptTokens int64 `bson:"prompt_tokens"` - CompletionTokens int64 `bson:"completion_tokens"` - TotalTokens int64 `bson:"total_tokens"` - RequestCount int64 `bson:"request_count"` + PromptTokens int64 `bson:"prompt_tokens"` + CompletionTokens int64 `bson:"completion_tokens"` + TotalTokens int64 `bson:"total_tokens"` + RequestCount int64 `bson:"request_count"` + CostUSD float64 `bson:"-"` // Calculated field, not stored } type UsageStats struct { Models map[string]*ModelUsageStats `bson:"models"` SessionCount int64 `bson:"session_count"` + TotalCostUSD float64 `bson:"-"` // Calculated field, not stored +} + +// CalculateCosts calculates the cost in USD for each model and total. +// pricingMap maps model slug to pricing info. +func (s *UsageStats) CalculateCosts(pricingMap map[string]*models.ModelPricing) { + s.TotalCostUSD = 0 + for modelSlug, stats := range s.Models { + if pricing, ok := pricingMap[modelSlug]; ok && pricing != nil { + stats.CostUSD = float64(stats.PromptTokens)*pricing.PromptPrice + + float64(stats.CompletionTokens)*pricing.CompletionPrice + s.TotalCostUSD += stats.CostUSD + } + } } func NewUsageService(db *db.DB, cfg *cfg.Cfg, logger *logger.Logger) *UsageService { diff --git a/internal/wire.go b/internal/wire.go index 52e6ff28..8c7a111e 100644 --- a/internal/wire.go +++ b/internal/wire.go @@ -46,6 +46,7 @@ var Set = wire.NewSet( services.NewPromptService, services.NewOAuthService, services.NewUsageService, + services.NewPricingService, cfg.GetCfg, logger.GetLogger, diff --git a/internal/wire_gen.go b/internal/wire_gen.go index a706db0f..a8d00490 100644 --- a/internal/wire_gen.go +++ b/internal/wire_gen.go @@ -47,15 +47,16 @@ func InitializeApp() (*api.Server, error) { userServiceServer := user.NewUserServer(userService, promptService, cfgCfg, loggerLogger) projectServiceServer := project.NewProjectServer(projectService, loggerLogger, cfgCfg) commentServiceServer := comment.NewCommentServer(projectService, 
chatService, reverseCommentService, loggerLogger, cfgCfg) - usageServiceServer := usage.NewUsageServer(usageService, loggerLogger) + pricingService := services.NewPricingService(dbDB, cfgCfg, loggerLogger) + usageServiceServer := usage.NewUsageServer(usageService, pricingService, loggerLogger) grpcServer := api.NewGrpcServer(userService, cfgCfg, authServiceServer, chatServiceServer, chatv2ChatServiceServer, userServiceServer, projectServiceServer, commentServiceServer, usageServiceServer) oAuthService := services.NewOAuthService(dbDB, cfgCfg, loggerLogger) oAuthHandler := auth.NewOAuthHandler(oAuthService) ginServer := api.NewGinServer(cfgCfg, oAuthHandler) - server := api.NewServer(grpcServer, ginServer, loggerLogger) + server := api.NewServer(grpcServer, ginServer, pricingService, loggerLogger) return server, nil } // wire.go: -var Set = wire.NewSet(api.NewServer, api.NewGrpcServer, api.NewGinServer, auth.NewOAuthHandler, auth.NewAuthServer, chat.NewChatServer, chat.NewChatServerV2, user.NewUserServer, project.NewProjectServer, comment.NewCommentServer, usage.NewUsageServer, client.NewAIClient, client.NewAIClientV2, services.NewReverseCommentService, services.NewChatService, services.NewChatServiceV2, services.NewTokenService, services.NewUserService, services.NewProjectService, services.NewPromptService, services.NewOAuthService, services.NewUsageService, cfg.GetCfg, logger.GetLogger, db.NewDB) +var Set = wire.NewSet(api.NewServer, api.NewGrpcServer, api.NewGinServer, auth.NewOAuthHandler, auth.NewAuthServer, chat.NewChatServer, chat.NewChatServerV2, user.NewUserServer, project.NewProjectServer, comment.NewCommentServer, usage.NewUsageServer, client.NewAIClient, client.NewAIClientV2, services.NewReverseCommentService, services.NewChatService, services.NewChatServiceV2, services.NewTokenService, services.NewUserService, services.NewProjectService, services.NewPromptService, services.NewOAuthService, services.NewUsageService, services.NewPricingService, 
cfg.GetCfg, logger.GetLogger, db.NewDB) diff --git a/pkg/gen/api/usage/v1/usage.pb.go b/pkg/gen/api/usage/v1/usage.pb.go index 38d4fd62..33530afd 100644 --- a/pkg/gen/api/usage/v1/usage.pb.go +++ b/pkg/gen/api/usage/v1/usage.pb.go @@ -29,6 +29,7 @@ type ModelTokens struct { CompletionTokens int64 `protobuf:"varint,2,opt,name=completion_tokens,json=completionTokens,proto3" json:"completion_tokens,omitempty"` TotalTokens int64 `protobuf:"varint,3,opt,name=total_tokens,json=totalTokens,proto3" json:"total_tokens,omitempty"` RequestCount int64 `protobuf:"varint,4,opt,name=request_count,json=requestCount,proto3" json:"request_count,omitempty"` + CostUsd float64 `protobuf:"fixed64,5,opt,name=cost_usd,json=costUsd,proto3" json:"cost_usd,omitempty"` // Cost in USD for this model unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -91,11 +92,19 @@ func (x *ModelTokens) GetRequestCount() int64 { return 0 } +func (x *ModelTokens) GetCostUsd() float64 { + if x != nil { + return x.CostUsd + } + return 0 +} + type SessionUsage struct { state protoimpl.MessageState `protogen:"open.v1"` SessionExpiry *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=session_expiry,json=sessionExpiry,proto3" json:"session_expiry,omitempty"` // Tokens per model (model_slug -> tokens) Models map[string]*ModelTokens `protobuf:"bytes,2,rep,name=models,proto3" json:"models,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + TotalCostUsd float64 `protobuf:"fixed64,3,opt,name=total_cost_usd,json=totalCostUsd,proto3" json:"total_cost_usd,omitempty"` // Total cost in USD across all models unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -144,11 +153,19 @@ func (x *SessionUsage) GetModels() map[string]*ModelTokens { return nil } +func (x *SessionUsage) GetTotalCostUsd() float64 { + if x != nil { + return x.TotalCostUsd + } + return 0 +} + type WeeklyUsage struct { state protoimpl.MessageState `protogen:"open.v1"` // Tokens per 
model (model_slug -> tokens) Models map[string]*ModelTokens `protobuf:"bytes,1,rep,name=models,proto3" json:"models,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` SessionCount int64 `protobuf:"varint,2,opt,name=session_count,json=sessionCount,proto3" json:"session_count,omitempty"` + TotalCostUsd float64 `protobuf:"fixed64,3,opt,name=total_cost_usd,json=totalCostUsd,proto3" json:"total_cost_usd,omitempty"` // Total cost in USD across all models unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -197,6 +214,13 @@ func (x *WeeklyUsage) GetSessionCount() int64 { return 0 } +func (x *WeeklyUsage) GetTotalCostUsd() float64 { + if x != nil { + return x.TotalCostUsd + } + return 0 +} + type GetSessionUsageRequest struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -362,21 +386,24 @@ var File_usage_v1_usage_proto protoreflect.FileDescriptor const file_usage_v1_usage_proto_rawDesc = "" + "\n" + - "\x14usage/v1/usage.proto\x12\busage.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xa7\x01\n" + + "\x14usage/v1/usage.proto\x12\busage.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xc2\x01\n" + "\vModelTokens\x12#\n" + "\rprompt_tokens\x18\x01 \x01(\x03R\fpromptTokens\x12+\n" + "\x11completion_tokens\x18\x02 \x01(\x03R\x10completionTokens\x12!\n" + "\ftotal_tokens\x18\x03 \x01(\x03R\vtotalTokens\x12#\n" + - "\rrequest_count\x18\x04 \x01(\x03R\frequestCount\"\xdf\x01\n" + + "\rrequest_count\x18\x04 \x01(\x03R\frequestCount\x12\x19\n" + + "\bcost_usd\x18\x05 \x01(\x01R\acostUsd\"\x85\x02\n" + "\fSessionUsage\x12A\n" + "\x0esession_expiry\x18\x01 \x01(\v2\x1a.google.protobuf.TimestampR\rsessionExpiry\x12:\n" + - "\x06models\x18\x02 \x03(\v2\".usage.v1.SessionUsage.ModelsEntryR\x06models\x1aP\n" + + "\x06models\x18\x02 \x03(\v2\".usage.v1.SessionUsage.ModelsEntryR\x06models\x12$\n" + + 
"\x0etotal_cost_usd\x18\x03 \x01(\x01R\ftotalCostUsd\x1aP\n" + "\vModelsEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12+\n" + - "\x05value\x18\x02 \x01(\v2\x15.usage.v1.ModelTokensR\x05value:\x028\x01\"\xbf\x01\n" + + "\x05value\x18\x02 \x01(\v2\x15.usage.v1.ModelTokensR\x05value:\x028\x01\"\xe5\x01\n" + "\vWeeklyUsage\x129\n" + "\x06models\x18\x01 \x03(\v2!.usage.v1.WeeklyUsage.ModelsEntryR\x06models\x12#\n" + - "\rsession_count\x18\x02 \x01(\x03R\fsessionCount\x1aP\n" + + "\rsession_count\x18\x02 \x01(\x03R\fsessionCount\x12$\n" + + "\x0etotal_cost_usd\x18\x03 \x01(\x01R\ftotalCostUsd\x1aP\n" + "\vModelsEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12+\n" + "\x05value\x18\x02 \x01(\v2\x15.usage.v1.ModelTokensR\x05value:\x028\x01\"\x18\n" + diff --git a/proto/usage/v1/usage.proto b/proto/usage/v1/usage.proto index 5696fa06..f5f480ce 100644 --- a/proto/usage/v1/usage.proto +++ b/proto/usage/v1/usage.proto @@ -22,18 +22,21 @@ message ModelTokens { int64 completion_tokens = 2; int64 total_tokens = 3; int64 request_count = 4; + double cost_usd = 5; // Cost in USD for this model } message SessionUsage { google.protobuf.Timestamp session_expiry = 1; // Tokens per model (model_slug -> tokens) map models = 2; + double total_cost_usd = 3; // Total cost in USD across all models } message WeeklyUsage { // Tokens per model (model_slug -> tokens) map models = 1; int64 session_count = 2; + double total_cost_usd = 3; // Total cost in USD across all models } message GetSessionUsageRequest {} diff --git a/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts b/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts index efba93f5..e38175ee 100644 --- a/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts +++ b/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts @@ -13,7 +13,7 @@ import type { Message } from "@bufbuild/protobuf"; * Describes the file usage/v1/usage.proto. 
*/ export const file_usage_v1_usage: GenFile = /*@__PURE__*/ - fileDesc("ChR1c2FnZS92MS91c2FnZS5wcm90bxIIdXNhZ2UudjEibAoLTW9kZWxUb2tlbnMSFQoNcHJvbXB0X3Rva2VucxgBIAEoAxIZChFjb21wbGV0aW9uX3Rva2VucxgCIAEoAxIUCgx0b3RhbF90b2tlbnMYAyABKAMSFQoNcmVxdWVzdF9jb3VudBgEIAEoAyK8AQoMU2Vzc2lvblVzYWdlEjIKDnNlc3Npb25fZXhwaXJ5GAEgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcBIyCgZtb2RlbHMYAiADKAsyIi51c2FnZS52MS5TZXNzaW9uVXNhZ2UuTW9kZWxzRW50cnkaRAoLTW9kZWxzRW50cnkSCwoDa2V5GAEgASgJEiQKBXZhbHVlGAIgASgLMhUudXNhZ2UudjEuTW9kZWxUb2tlbnM6AjgBIp0BCgtXZWVrbHlVc2FnZRIxCgZtb2RlbHMYASADKAsyIS51c2FnZS52MS5XZWVrbHlVc2FnZS5Nb2RlbHNFbnRyeRIVCg1zZXNzaW9uX2NvdW50GAIgASgDGkQKC01vZGVsc0VudHJ5EgsKA2tleRgBIAEoCRIkCgV2YWx1ZRgCIAEoCzIVLnVzYWdlLnYxLk1vZGVsVG9rZW5zOgI4ASIYChZHZXRTZXNzaW9uVXNhZ2VSZXF1ZXN0IkIKF0dldFNlc3Npb25Vc2FnZVJlc3BvbnNlEicKB3Nlc3Npb24YASABKAsyFi51c2FnZS52MS5TZXNzaW9uVXNhZ2UiFwoVR2V0V2Vla2x5VXNhZ2VSZXF1ZXN0Ij4KFkdldFdlZWtseVVzYWdlUmVzcG9uc2USJAoFdXNhZ2UYASABKAsyFS51c2FnZS52MS5XZWVrbHlVc2FnZTKaAgoMVXNhZ2VTZXJ2aWNlEoUBCg9HZXRTZXNzaW9uVXNhZ2USIC51c2FnZS52MS5HZXRTZXNzaW9uVXNhZ2VSZXF1ZXN0GiEudXNhZ2UudjEuR2V0U2Vzc2lvblVzYWdlUmVzcG9uc2UiLYLT5JMCJxIlL19wZC9hcGkvdjEvdXNlcnMvQHNlbGYvdXNhZ2Uvc2Vzc2lvbhKBAQoOR2V0V2Vla2x5VXNhZ2USHy51c2FnZS52MS5HZXRXZWVrbHlVc2FnZVJlcXVlc3QaIC51c2FnZS52MS5HZXRXZWVrbHlVc2FnZVJlc3BvbnNlIiyC0+STAiYSJC9fcGQvYXBpL3YxL3VzZXJzL0BzZWxmL3VzYWdlL3dlZWtseUKHAQoMY29tLnVzYWdlLnYxQgpVc2FnZVByb3RvUAFaKnBhcGVyZGVidWdnZXIvcGtnL2dlbi9hcGkvdXNhZ2UvdjE7dXNhZ2V2MaICA1VYWKoCCFVzYWdlLlYxygIIVXNhZ2VcVjHiAhRVc2FnZVxWMVxHUEJNZXRhZGF0YeoCCVVzYWdlOjpWMWIGcHJvdG8z", [file_google_api_annotations, file_google_protobuf_timestamp]); + 
fileDesc("ChR1c2FnZS92MS91c2FnZS5wcm90bxIIdXNhZ2UudjEifgoLTW9kZWxUb2tlbnMSFQoNcHJvbXB0X3Rva2VucxgBIAEoAxIZChFjb21wbGV0aW9uX3Rva2VucxgCIAEoAxIUCgx0b3RhbF90b2tlbnMYAyABKAMSFQoNcmVxdWVzdF9jb3VudBgEIAEoAxIQCghjb3N0X3VzZBgFIAEoASLUAQoMU2Vzc2lvblVzYWdlEjIKDnNlc3Npb25fZXhwaXJ5GAEgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcBIyCgZtb2RlbHMYAiADKAsyIi51c2FnZS52MS5TZXNzaW9uVXNhZ2UuTW9kZWxzRW50cnkSFgoOdG90YWxfY29zdF91c2QYAyABKAEaRAoLTW9kZWxzRW50cnkSCwoDa2V5GAEgASgJEiQKBXZhbHVlGAIgASgLMhUudXNhZ2UudjEuTW9kZWxUb2tlbnM6AjgBIrUBCgtXZWVrbHlVc2FnZRIxCgZtb2RlbHMYASADKAsyIS51c2FnZS52MS5XZWVrbHlVc2FnZS5Nb2RlbHNFbnRyeRIVCg1zZXNzaW9uX2NvdW50GAIgASgDEhYKDnRvdGFsX2Nvc3RfdXNkGAMgASgBGkQKC01vZGVsc0VudHJ5EgsKA2tleRgBIAEoCRIkCgV2YWx1ZRgCIAEoCzIVLnVzYWdlLnYxLk1vZGVsVG9rZW5zOgI4ASIYChZHZXRTZXNzaW9uVXNhZ2VSZXF1ZXN0IkIKF0dldFNlc3Npb25Vc2FnZVJlc3BvbnNlEicKB3Nlc3Npb24YASABKAsyFi51c2FnZS52MS5TZXNzaW9uVXNhZ2UiFwoVR2V0V2Vla2x5VXNhZ2VSZXF1ZXN0Ij4KFkdldFdlZWtseVVzYWdlUmVzcG9uc2USJAoFdXNhZ2UYASABKAsyFS51c2FnZS52MS5XZWVrbHlVc2FnZTKaAgoMVXNhZ2VTZXJ2aWNlEoUBCg9HZXRTZXNzaW9uVXNhZ2USIC51c2FnZS52MS5HZXRTZXNzaW9uVXNhZ2VSZXF1ZXN0GiEudXNhZ2UudjEuR2V0U2Vzc2lvblVzYWdlUmVzcG9uc2UiLYLT5JMCJxIlL19wZC9hcGkvdjEvdXNlcnMvQHNlbGYvdXNhZ2Uvc2Vzc2lvbhKBAQoOR2V0V2Vla2x5VXNhZ2USHy51c2FnZS52MS5HZXRXZWVrbHlVc2FnZVJlcXVlc3QaIC51c2FnZS52MS5HZXRXZWVrbHlVc2FnZVJlc3BvbnNlIiyC0+STAiYSJC9fcGQvYXBpL3YxL3VzZXJzL0BzZWxmL3VzYWdlL3dlZWtseUKHAQoMY29tLnVzYWdlLnYxQgpVc2FnZVByb3RvUAFaKnBhcGVyZGVidWdnZXIvcGtnL2dlbi9hcGkvdXNhZ2UvdjE7dXNhZ2V2MaICA1VYWKoCCFVzYWdlLlYxygIIVXNhZ2VcVjHiAhRVc2FnZVxWMVxHUEJNZXRhZGF0YeoCCVVzYWdlOjpWMWIGcHJvdG8z", [file_google_api_annotations, file_google_protobuf_timestamp]); /** * @generated from message usage.v1.ModelTokens @@ -38,6 +38,13 @@ export type ModelTokens = Message<"usage.v1.ModelTokens"> & { * @generated from field: int64 request_count = 4; */ requestCount: bigint; + + /** + * Cost in USD for this model + * + * @generated from field: double cost_usd = 5; + */ + costUsd: number; }; /** @@ -62,6 +69,13 @@ export 
type SessionUsage = Message<"usage.v1.SessionUsage"> & { * @generated from field: map models = 2; */ models: { [key: string]: ModelTokens }; + + /** + * Total cost in USD across all models + * + * @generated from field: double total_cost_usd = 3; + */ + totalCostUsd: number; }; /** @@ -86,6 +100,13 @@ export type WeeklyUsage = Message<"usage.v1.WeeklyUsage"> & { * @generated from field: int64 session_count = 2; */ sessionCount: bigint; + + /** + * Total cost in USD across all models + * + * @generated from field: double total_cost_usd = 3; + */ + totalCostUsd: number; }; /** diff --git a/webapp/_webapp/src/views/usage/index.tsx b/webapp/_webapp/src/views/usage/index.tsx index bebca148..f00bc285 100644 --- a/webapp/_webapp/src/views/usage/index.tsx +++ b/webapp/_webapp/src/views/usage/index.tsx @@ -11,6 +11,12 @@ const formatNumber = (n: bigint | number | undefined): string => { return Number(n).toLocaleString(); }; +const formatCost = (cost: number | undefined): string => { + if (cost === undefined || cost === 0) return "$0.00"; + if (cost < 0.01) return `$${cost.toFixed(4)}`; + return `$${cost.toFixed(2)}`; +}; + const formatTimeRemaining = (timestamp: { seconds?: bigint; nanos?: number } | undefined): string => { if (!timestamp || !timestamp.seconds) return ""; const expiryMs = Number(timestamp.seconds) * 1000; @@ -53,28 +59,32 @@ const SectionTitle = ({ children }: { children: React.ReactNode }) => { const ModelUsageItem = ({ model, tokens }: { model: string; tokens: ModelTokens }) => { return ( -
-
{model}
-
- Total - {formatNumber(tokens.totalTokens)} -
-
- Prompt - {formatNumber(tokens.promptTokens)} +
+
+ {model} + {formatCost(tokens.costUsd)}
- Completion - {formatNumber(tokens.completionTokens)} + Tokens + {formatNumber(tokens.totalTokens)}
- Requests - {formatNumber(tokens.requestCount)} + Requests + {formatNumber(tokens.requestCount)}
); }; +const TotalCostDisplay = ({ cost }: { cost: number | undefined }) => { + return ( +
+ Total Cost + {formatCost(cost)} +
+ ); +}; + export const Usage = () => { const { data: sessionData, @@ -133,10 +143,11 @@ export const Usage = () => { {session && sessionModels.length > 0 ? ( -
+
{sessionModels.map(([model, tokens]) => ( ))} +
) : ( @@ -150,10 +161,11 @@ export const Usage = () => { Weekly Usage {weekly && weeklyModels.length > 0 ? ( -
+
{weeklyModels.map(([model, tokens]) => ( ))} +
) : ( From 219396f3f862e28fd739316d0dfd839388c6f174 Mon Sep 17 00:00:00 2001 From: wjiayis Date: Wed, 4 Mar 2026 20:52:05 +0800 Subject: [PATCH 12/13] feat: improve display of usage costs --- webapp/_webapp/src/views/usage/index.tsx | 65 +++++------------------- 1 file changed, 14 insertions(+), 51 deletions(-) diff --git a/webapp/_webapp/src/views/usage/index.tsx b/webapp/_webapp/src/views/usage/index.tsx index f00bc285..526cd18e 100644 --- a/webapp/_webapp/src/views/usage/index.tsx +++ b/webapp/_webapp/src/views/usage/index.tsx @@ -4,17 +4,11 @@ import { useState, useEffect } from "react"; import { TabHeader } from "../../components/tab-header"; import { useGetSessionUsageQuery, useGetWeeklyUsageQuery } from "../../query"; import CellWrapper from "../../components/cell-wrapper"; -import type { ModelTokens } from "../../pkg/gen/apiclient/usage/v1/usage_pb"; - -const formatNumber = (n: bigint | number | undefined): string => { - if (n === undefined) return "0"; - return Number(n).toLocaleString(); -}; const formatCost = (cost: number | undefined): string => { - if (cost === undefined || cost === 0) return "$0.00"; - if (cost < 0.01) return `$${cost.toFixed(4)}`; - return `$${cost.toFixed(2)}`; + if (cost === undefined || cost === 0) return "USD $0.00"; + if (cost < 0.01) return `USD $${cost.toFixed(4)}`; + return `USD $${cost.toFixed(2)}`; }; const formatTimeRemaining = (timestamp: { seconds?: bigint; nanos?: number } | undefined): string => { @@ -57,30 +51,10 @@ const SectionTitle = ({ children }: { children: React.ReactNode }) => { return
{children}
; }; -const ModelUsageItem = ({ model, tokens }: { model: string; tokens: ModelTokens }) => { +const CostDisplay = ({ cost }: { cost: number | undefined }) => { return ( -
-
- {model} - {formatCost(tokens.costUsd)} -
-
- Tokens - {formatNumber(tokens.totalTokens)} -
-
- Requests - {formatNumber(tokens.requestCount)} -
-
- ); -}; - -const TotalCostDisplay = ({ cost }: { cost: number | undefined }) => { - return ( -
- Total Cost - {formatCost(cost)} +
+ {formatCost(cost)}
); }; @@ -127,28 +101,20 @@ export const Usage = () => { const session = sessionData?.session; const weekly = weeklyData?.usage; - const sessionModels = session?.models ? Object.entries(session.models) : []; - const weeklyModels = weekly?.models ? Object.entries(weekly.models) : []; - return (
- Current Session + Current Session Usage {session?.sessionExpiry && ( ({formatTimeRemaining(session.sessionExpiry)}) )} - {session && sessionModels.length > 0 ? ( + {session ? ( -
- {sessionModels.map(([model, tokens]) => ( - - ))} - -
+
) : ( @@ -159,14 +125,9 @@ export const Usage = () => { Weekly Usage - {weekly && weeklyModels.length > 0 ? ( + {weekly ? ( -
- {weeklyModels.map(([model, tokens]) => ( - - ))} - -
+
) : ( @@ -174,7 +135,9 @@ export const Usage = () => { )}
- +
+ All costs displayed are fully covered by the PaperDebugger Team. +
Last updated: {formatLastUpdated(sessionUpdatedAt)} From dc78d5b6ac93f210f9e782bb2c3e8dd47d6a7d07 Mon Sep 17 00:00:00 2001 From: wjiayis Date: Wed, 4 Mar 2026 20:58:53 +0800 Subject: [PATCH 13/13] fix: only track usage for non-BYOK --- .../services/toolkit/client/completion_v2.go | 4 ++-- webapp/_webapp/src/views/usage/index.tsx | 20 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/internal/services/toolkit/client/completion_v2.go b/internal/services/toolkit/client/completion_v2.go index 316c617b..2b1a0d36 100644 --- a/internal/services/toolkit/client/completion_v2.go +++ b/internal/services/toolkit/client/completion_v2.go @@ -98,8 +98,8 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chunk := stream.Current() if len(chunk.Choices) == 0 { - // Handle usage information - if chunk.Usage.TotalTokens > 0 { + // Handle usage information - only record for non-BYOK users + if chunk.Usage.TotalTokens > 0 && !llmProvider.IsCustom() { // Record usage asynchronously to avoid blocking the response go func(usage services.UsageRecord) { bgCtx := context.Background() diff --git a/webapp/_webapp/src/views/usage/index.tsx b/webapp/_webapp/src/views/usage/index.tsx index 526cd18e..3465d9ad 100644 --- a/webapp/_webapp/src/views/usage/index.tsx +++ b/webapp/_webapp/src/views/usage/index.tsx @@ -4,6 +4,7 @@ import { useState, useEffect } from "react"; import { TabHeader } from "../../components/tab-header"; import { useGetSessionUsageQuery, useGetWeeklyUsageQuery } from "../../query"; import CellWrapper from "../../components/cell-wrapper"; +import { useSettingStore } from "../../stores/setting-store"; const formatCost = (cost: number | undefined): string => { if (cost === undefined || cost === 0) return "USD $0.00"; @@ -60,6 +61,9 @@ const CostDisplay = ({ cost }: { cost: number | undefined }) => { }; export const Usage = () => { + const { settings } = useSettingStore(); + const isBYOK = 
Boolean(settings?.openaiApiKey); + const { data: sessionData, isLoading: sessionLoading, @@ -98,6 +102,22 @@ export const Usage = () => { ); } + // Show message for BYOK users + if (isBYOK) { + return ( +
+ +
+ +
+ Usage tracking is not available when using your own API key. +
+
+
+
+ ); + } + const session = sessionData?.session; const weekly = weeklyData?.usage;