From aba6937060b507514164b459cd95a68601f7d16e Mon Sep 17 00:00:00 2001 From: Shaw Date: Sat, 4 Apr 2026 19:33:15 -0700 Subject: [PATCH 01/11] cloud: add managed Discord install flow --- .../discord/eliza-app/messages/route.ts | 150 +++++++++ app/api/v1/discord/callback/route.ts | 112 ++++++- .../agents/[agentId]/discord/oauth/route.ts | 117 +++++++ .../milady/agents/[agentId]/discord/route.ts | 72 ++++ .../agents/[agentId]/discord/oauth/route.ts | 1 + .../milaidy/agents/[agentId]/discord/route.ts | 1 + packages/db/repositories/milady-sandboxes.ts | 17 + .../lib/services/discord-automation/index.ts | 309 ++++++++++++++++-- .../lib/services/discord-automation/types.ts | 14 + packages/lib/services/milady-agent-config.ts | 97 ++++++ .../lib/services/milady-managed-discord.ts | 241 ++++++++++++++ .../gateway-discord/src/gateway-manager.ts | 127 ++++++- .../managed-discord-eliza-app-route.test.ts | 175 ++++++++++ .../unit/milady-agent-discord-routes.test.ts | 264 +++++++++++++++ .../tests/unit/milady-managed-discord.test.ts | 65 ++++ 15 files changed, 1725 insertions(+), 37 deletions(-) create mode 100644 app/api/internal/discord/eliza-app/messages/route.ts create mode 100644 app/api/v1/milady/agents/[agentId]/discord/oauth/route.ts create mode 100644 app/api/v1/milady/agents/[agentId]/discord/route.ts create mode 100644 app/api/v1/milaidy/agents/[agentId]/discord/oauth/route.ts create mode 100644 app/api/v1/milaidy/agents/[agentId]/discord/route.ts create mode 100644 packages/lib/services/milady-managed-discord.ts create mode 100644 packages/tests/unit/managed-discord-eliza-app-route.test.ts create mode 100644 packages/tests/unit/milady-agent-discord-routes.test.ts create mode 100644 packages/tests/unit/milady-managed-discord.test.ts diff --git a/app/api/internal/discord/eliza-app/messages/route.ts b/app/api/internal/discord/eliza-app/messages/route.ts new file mode 100644 index 000000000..2de697cf3 --- /dev/null +++ b/app/api/internal/discord/eliza-app/messages/route.ts @@ -0,0 +1,150 @@ +import { randomUUID } from "crypto"; +import { NextRequest, NextResponse } from "next/server"; +import { z } from "zod"; +import { miladySandboxesRepository } from "@/db/repositories/milady-sandboxes"; +import { withInternalAuth } from "@/lib/auth/internal-api"; +import type { BridgeRequest } from "@/lib/services/milady-sandbox"; +import { miladySandboxService } from "@/lib/services/milady-sandbox"; +import { logger } from "@/lib/utils/logger"; + +export const dynamic = "force-dynamic"; + +const senderSchema = z.object({ + id: z.string().trim().min(1), + username: z.string().trim().min(1), + displayName: z.string().trim().optional(), + avatar: z.string().trim().nullable().optional(), +}); + +const requestSchema = z.object({ + guildId: z.string().trim().min(1), + channelId: z.string().trim().min(1), + messageId: z.string().trim().min(1), + content: z.string().trim().min(1), + sender: senderSchema, +}); + +type ManagedDiscordRouteResponse = { + handled: boolean; + replyText?: string | null; + reason?: string; + agentId?: string; +}; + +export const POST = withInternalAuth(async (request: NextRequest) => { + let body: unknown; + try { + body = await request.json(); + } catch { + return NextResponse.json({ error: "Invalid JSON" }, { status: 400 }); + } + + const parsed = requestSchema.safeParse(body); + if (!parsed.success) { + return NextResponse.json( + { error: "Invalid payload", details: parsed.error.issues }, + { status: 400 }, + ); + } + + const { guildId, channelId, messageId, content, sender } = parsed.data; + + const linkedSandboxes = await miladySandboxesRepository.findByManagedDiscordGuildId(guildId); + if (linkedSandboxes.length === 0) { + return NextResponse.json({ + handled: false, + reason: "not_linked", + }); + } + + if (linkedSandboxes.length > 1) { + logger.warn("[managed-discord] Multiple Milady agents linked to the same guild", { + guildId, + agentIds: linkedSandboxes.map((sandbox) => sandbox.id), + }); + return NextResponse.json({ + handled: false, + reason: "ambiguous_guild_link", + }); + } + + const sandbox = linkedSandboxes[0]; + if (!sandbox) { + return NextResponse.json({ + handled: false, + reason: "not_linked", + }); + } + + if (sandbox.status !== "running") { + return NextResponse.json({ + handled: false, + reason: "agent_not_running", + agentId: sandbox.id, + }); + } + + const rpcRequest: BridgeRequest = { + jsonrpc: "2.0", + id: randomUUID(), + method: "message.send", + params: { + text: content, + roomId: `discord-guild:${guildId}:channel:${channelId}`, + channelType: "GROUP", + source: "discord", + sender: { + id: sender.id, + username: sender.username, + ...(sender.displayName ? { displayName: sender.displayName } : {}), + metadata: { + discord: { + userId: sender.id, + username: sender.username, + ...(sender.displayName ? { globalName: sender.displayName } : {}), + ...(sender.avatar ? { avatar: sender.avatar } : {}), + }, + }, + }, + metadata: { + discord: { + guildId, + channelId, + messageId, + }, + }, + }, + }; + + const bridgeResponse = await miladySandboxService.bridge( + sandbox.id, + sandbox.organization_id, + rpcRequest, + ); + + if (bridgeResponse.error) { + logger.warn("[managed-discord] Sandbox bridge rejected Discord message", { + guildId, + agentId: sandbox.id, + error: bridgeResponse.error.message, + }); + return NextResponse.json({ + handled: false, + reason: "bridge_failed", + agentId: sandbox.id, + }); + } + + const replyText = + bridgeResponse.result && + typeof bridgeResponse.result === "object" && + typeof bridgeResponse.result.text === "string" + ? bridgeResponse.result.text + : null; + + return NextResponse.json({ + handled: true, + replyText, + agentId: sandbox.id, + }); +}); diff --git a/app/api/v1/discord/callback/route.ts b/app/api/v1/discord/callback/route.ts index 1034a8b72..5638c7324 100644 --- a/app/api/v1/discord/callback/route.ts +++ b/app/api/v1/discord/callback/route.ts @@ -6,12 +6,52 @@ */ import { NextRequest, NextResponse } from "next/server"; -import { resolveSafeRedirectTarget } from "@/lib/security/redirect-validation"; +import { + assertAllowedAbsoluteRedirectUrl, + getDefaultPlatformRedirectOrigins, + resolveSafeRedirectTarget, + sanitizeRelativeRedirectPath, +} from "@/lib/security/redirect-validation"; import { discordAutomationService } from "@/lib/services/discord-automation"; +import { managedMiladyDiscordService } from "@/lib/services/milady-managed-discord"; import { logger } from "@/lib/utils/logger"; export const maxDuration = 60; +const LOOPBACK_REDIRECT_ORIGINS = [ + "http://localhost:*", + "http://127.0.0.1:*", + "https://localhost:*", + "https://127.0.0.1:*", +] as const; + +function resolveOAuthReturnTarget( + baseUrl: string, + returnUrl: string | undefined, + managedFlow: boolean, +): URL { + const fallbackPath = managedFlow + ? "/dashboard/settings?tab=agents" + : "/dashboard/settings?tab=connections"; + + if (managedFlow && returnUrl) { + if (returnUrl.startsWith("/")) { + return new URL(sanitizeRelativeRedirectPath(returnUrl, fallbackPath), baseUrl); + } + + try { + return assertAllowedAbsoluteRedirectUrl(returnUrl, [ + ...getDefaultPlatformRedirectOrigins(), + ...LOOPBACK_REDIRECT_ORIGINS, + ]); + } catch { + // Fall through to standard same-origin fallback below. + } + } + + return resolveSafeRedirectTarget(returnUrl, baseUrl, fallbackPath); +} + export async function GET(request: NextRequest): Promise { const { searchParams } = new URL(request.url); const code = searchParams.get("code"); @@ -24,15 +64,24 @@ export async function GET(request: NextRequest): Promise { const baseUrl = process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000"; // Parse state for return URL (do this early for error redirects) - const defaultReturnPath = "/dashboard/settings?tab=connections"; - let returnTarget = resolveSafeRedirectTarget(undefined, baseUrl, defaultReturnPath); + let decodedState: + | { + returnUrl?: string; + flow?: "organization-install" | "milady-managed"; + agentId?: string; + organizationId?: string; + userId?: string; + botNickname?: string; + } + | null = null; + let returnTarget = resolveOAuthReturnTarget(baseUrl, undefined, false); if (state) { try { - const stateData = JSON.parse(Buffer.from(state, "base64").toString()); - returnTarget = resolveSafeRedirectTarget( - typeof stateData.returnUrl === "string" ? stateData.returnUrl : undefined, + decodedState = discordAutomationService.decodeOAuthState(state); + returnTarget = resolveOAuthReturnTarget( baseUrl, - defaultReturnPath, + typeof decodedState.returnUrl === "string" ? decodedState.returnUrl : undefined, + decodedState.flow === "milady-managed", ); } catch { // Use default return URL @@ -60,23 +109,62 @@ export async function GET(request: NextRequest): Promise { } // For bot OAuth, guild_id is returned directly in URL params - if (!guildId || !state) { + if (!guildId || !state || !code || !decodedState) { logger.warn("[Discord Callback] Missing params", { hasGuildId: !!guildId, hasState: !!state, hasCode: !!code, + hasDecodedState: !!decodedState, }); return redirectWithStatus("error", { message: "missing_params" }); } try { - const result = await discordAutomationService.handleBotOAuthCallback( + const result = await discordAutomationService.handleBotOAuthCallback({ + code, guildId, - state, - permissions || undefined, - ); + oauthState: decodedState, + permissions: permissions || undefined, + }); if (result.success) { + if ( + decodedState.flow === "milady-managed" && + decodedState.agentId && + decodedState.organizationId && + decodedState.userId && + result.discordUser + ) { + const connected = await managedMiladyDiscordService.connectAgent({ + agentId: decodedState.agentId, + organizationId: decodedState.organizationId, + binding: { + mode: "cloud-managed", + applicationId: discordAutomationService.getApplicationId() ?? undefined, + guildId: result.guildId ?? guildId, + guildName: result.guildName || "", + adminDiscordUserId: result.discordUser.id, + adminDiscordUsername: result.discordUser.username, + adminElizaUserId: decodedState.userId, + connectedAt: new Date().toISOString(), + ...(result.discordUser.globalName + ? { adminDiscordDisplayName: result.discordUser.globalName } + : {}), + ...(decodedState.botNickname?.trim() + ? { botNickname: decodedState.botNickname.trim() } + : {}), + }, + }); + + return redirectWithStatus("connected", { + managed: "1", + agentId: decodedState.agentId, + guildId: result.guildId ?? guildId, + guildName: result.guildName || "", + restarted: connected.restarted ? "1" : "0", + }); + } + return redirectWithStatus("connected", { guildId: result.guildId ?? guildId, guildName: result.guildName || "", diff --git a/app/api/v1/milady/agents/[agentId]/discord/oauth/route.ts b/app/api/v1/milady/agents/[agentId]/discord/oauth/route.ts new file mode 100644 index 000000000..fa54c2d8a --- /dev/null +++ b/app/api/v1/milady/agents/[agentId]/discord/oauth/route.ts @@ -0,0 +1,117 @@ +import { randomBytes } from "crypto"; +import { NextRequest, NextResponse } from "next/server"; +import { z } from "zod"; +import { errorToResponse } from "@/lib/api/errors"; +import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; +import { + assertAllowedAbsoluteRedirectUrl, + getDefaultPlatformRedirectOrigins, + sanitizeRelativeRedirectPath, +} from "@/lib/security/redirect-validation"; +import { discordAutomationService } from "@/lib/services/discord-automation"; +import { miladySandboxService } from "@/lib/services/milady-sandbox"; +import { applyCorsHeaders, handleCorsOptions } from "@/lib/services/proxy/cors"; + +export const dynamic = "force-dynamic"; + +const CORS_METHODS = "POST, OPTIONS"; +const LOOPBACK_REDIRECT_ORIGINS = [ + "http://localhost:*", + "http://127.0.0.1:*", + "https://localhost:*", + "https://127.0.0.1:*", +] as const; + +const oauthLinkSchema = z.object({ + returnUrl: z.string().trim().optional(), + botNickname: z.string().trim().max(32).optional(), +}); + +export function OPTIONS() { + return handleCorsOptions(CORS_METHODS); +} + +function resolveManagedReturnUrl(rawValue: string | undefined): string { + const baseUrl = process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000"; + const defaultPath = "/dashboard/settings?tab=agents"; + + if (!rawValue) { + return new URL(defaultPath, baseUrl).toString(); + } + + if (rawValue.startsWith("/")) { + return new URL(sanitizeRelativeRedirectPath(rawValue, defaultPath), baseUrl).toString(); + } + + return assertAllowedAbsoluteRedirectUrl(rawValue, [ + ...getDefaultPlatformRedirectOrigins(), + ...LOOPBACK_REDIRECT_ORIGINS, + ]).toString(); +} + +export async function POST( + request: NextRequest, + { params }: { params: Promise<{ agentId: string }> }, +) { + try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + const { agentId } = await params; + + if (!discordAutomationService.isOAuthConfigured()) { + return applyCorsHeaders( + NextResponse.json( + { success: false, error: "Discord integration is not configured" }, + { status: 503 }, + ), + CORS_METHODS, + ); + } + + const sandbox = await miladySandboxService.getAgent(agentId, user.organization_id); + if (!sandbox) { + return applyCorsHeaders( + NextResponse.json({ success: false, error: "Agent not found" }, { status: 404 }), + CORS_METHODS, + ); + } + + const body = await request.json().catch(() => ({})); + const parsed = oauthLinkSchema.safeParse(body); + if (!parsed.success) { + return applyCorsHeaders( + NextResponse.json( + { success: false, error: "Invalid request", details: parsed.error.issues }, + { status: 400 }, + ), + CORS_METHODS, + ); + } + + const authorizeUrl = discordAutomationService.generateOAuthUrl({ + organizationId: user.organization_id, + userId: user.id, + agentId, + flow: "milady-managed", + nonce: randomBytes(16).toString("hex"), + returnUrl: resolveManagedReturnUrl(parsed.data.returnUrl), + ...(parsed.data.botNickname?.trim() + ? { botNickname: parsed.data.botNickname.trim() } + : sandbox.agent_name?.trim() + ? { botNickname: sandbox.agent_name.trim().slice(0, 32) } + : {}), + }); + + return applyCorsHeaders( + NextResponse.json({ + success: true, + data: { + authorizeUrl, + applicationId: discordAutomationService.getApplicationId(), + }, + }), + CORS_METHODS, + ); + } catch (error) { + return applyCorsHeaders(errorToResponse(error), CORS_METHODS); + } +} diff --git a/app/api/v1/milady/agents/[agentId]/discord/route.ts b/app/api/v1/milady/agents/[agentId]/discord/route.ts new file mode 100644 index 000000000..cc171a7ac --- /dev/null +++ b/app/api/v1/milady/agents/[agentId]/discord/route.ts @@ -0,0 +1,72 @@ +import { NextRequest, NextResponse } from "next/server"; +import { errorToResponse } from "@/lib/api/errors"; +import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; +import { discordAutomationService } from "@/lib/services/discord-automation"; +import { managedMiladyDiscordService } from "@/lib/services/milady-managed-discord"; +import { applyCorsHeaders, handleCorsOptions } from "@/lib/services/proxy/cors"; + +export const dynamic = "force-dynamic"; + +const CORS_METHODS = "GET, DELETE, OPTIONS"; + +export function OPTIONS() { + return handleCorsOptions(CORS_METHODS); +} + +export async function GET( + request: NextRequest, + { params }: { params: Promise<{ agentId: string }> }, +) { + try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + const { agentId } = await params; + + const status = await managedMiladyDiscordService.getStatus({ + agentId, + organizationId: user.organization_id, + configured: discordAutomationService.isOAuthConfigured(), + applicationId: discordAutomationService.getApplicationId(), + }); + + if (!status) { + return applyCorsHeaders( + NextResponse.json({ success: false, error: "Agent not found" }, { status: 404 }), + CORS_METHODS, + ); + } + + return applyCorsHeaders(NextResponse.json({ success: true, data: status }), CORS_METHODS); + } catch (error) { + return applyCorsHeaders(errorToResponse(error), CORS_METHODS); + } +} + +export async function DELETE( + request: NextRequest, + { params }: { params: Promise<{ agentId: string }> }, +) { + try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + const { agentId } = await params; + + const result = await managedMiladyDiscordService.disconnectAgent({ + agentId, + organizationId: user.organization_id, + configured: discordAutomationService.isOAuthConfigured(), + applicationId: discordAutomationService.getApplicationId(), + }); + + return applyCorsHeaders( + NextResponse.json({ + success: true, + data: { + ...result.status, + restarted: result.restarted, + }, + }), + CORS_METHODS, + ); + } catch (error) { + return applyCorsHeaders(errorToResponse(error), CORS_METHODS); + } +} diff --git a/app/api/v1/milaidy/agents/[agentId]/discord/oauth/route.ts b/app/api/v1/milaidy/agents/[agentId]/discord/oauth/route.ts new file mode 100644 index 000000000..968bb383f --- /dev/null +++ b/app/api/v1/milaidy/agents/[agentId]/discord/oauth/route.ts @@ -0,0 +1 @@ +export * from "@/app/api/v1/milady/agents/[agentId]/discord/oauth/route"; diff --git a/app/api/v1/milaidy/agents/[agentId]/discord/route.ts b/app/api/v1/milaidy/agents/[agentId]/discord/route.ts new file mode 100644 index 000000000..8bd0e5b0b --- /dev/null +++ b/app/api/v1/milaidy/agents/[agentId]/discord/route.ts @@ -0,0 +1 @@ +export * from "@/app/api/v1/milady/agents/[agentId]/discord/route"; diff --git a/packages/db/repositories/milady-sandboxes.ts b/packages/db/repositories/milady-sandboxes.ts index 544fdffcd..39ec6002a 100644 --- a/packages/db/repositories/milady-sandboxes.ts +++ b/packages/db/repositories/milady-sandboxes.ts @@ -1,5 +1,6 @@ import { and, desc, eq, inArray, notInArray, sql } from "drizzle-orm"; import { dbRead, dbWrite } from "@/db/helpers"; +import { MILADY_MANAGED_DISCORD_KEY } from "@/lib/services/milady-agent-config"; import { type MiladyBackupSnapshotType, type MiladySandbox, @@ -99,6 +100,22 @@ export class MiladySandboxesRepository { return r; } + async findByManagedDiscordGuildId(guildId: string): Promise { + const trimmedGuildId = guildId.trim(); + if (!trimmedGuildId) { + return []; + } + + const result = await dbWrite.execute(sql` + SELECT * + FROM ${miladySandboxes} + WHERE (${miladySandboxes.agent_config} -> ${MILADY_MANAGED_DISCORD_KEY} ->> 'guildId') = ${trimmedGuildId} + ORDER BY ${miladySandboxes.updated_at} DESC + `); + + return result.rows; + } + // Writes async create(data: NewMiladySandbox): Promise { diff --git a/packages/lib/services/discord-automation/index.ts b/packages/lib/services/discord-automation/index.ts index 1bf5392d3..2f2049b73 100644 --- a/packages/lib/services/discord-automation/index.ts +++ b/packages/lib/services/discord-automation/index.ts @@ -5,6 +5,7 @@ * Uses Discord REST API for all operations (serverless-compatible). */ +import { createHmac, timingSafeEqual } from "node:crypto"; import { discordChannelsRepository } from "@/db/repositories/discord-channels"; import { discordGuildsRepository } from "@/db/repositories/discord-guilds"; import { @@ -19,6 +20,7 @@ import type { DiscordChannelInfo, DiscordConnectionStatus, DiscordEmbed, + DiscordOAuthIdentity, OAuthState, SendMessageResult, } from "./types"; @@ -33,9 +35,41 @@ const DISCORD_BOT_TOKEN = process.env.DISCORD_BOT_TOKEN; const APP_URL = process.env.NEXT_PUBLIC_APP_URL || "https://www.elizacloud.ai"; // OAuth2 scopes and permissions -const OAUTH_SCOPES = "bot"; -// Permissions: Send Messages (2048) + Embed Links (16384) + Read Message History (65536) -const BOT_PERMISSIONS = "83968"; +const OAUTH_SCOPES = "identify guilds bot applications.commands"; +// Permissions: +// - View Channels (1024) +// - Send Messages (2048) +// - Embed Links (16384) +// - Read Message History (65536) +// - Change Nickname (67108864) +const BOT_PERMISSIONS = "67193856"; +const OAUTH_CALLBACK_PATH = "/api/v1/discord/callback"; + +interface DiscordTokenResponse { + access_token?: string; + refresh_token?: string; + token_type?: string; + expires_in?: number; + scope?: string; +} + +interface DiscordApiUser { + id: string; + username: string; + global_name: string | null; + avatar: string | null; + bot?: boolean; + system?: boolean; +} + +interface DiscordApiGuild { + id: string; + name: string; + icon: string | null; + owner: boolean; + permissions: string; + features: string[]; +} class DiscordAutomationService { /** @@ -54,6 +88,14 @@ class DiscordAutomationService { return Boolean(DISCORD_BOT_TOKEN); } + getApplicationId(): string | null { + return DISCORD_CLIENT_ID?.trim() || null; + } + + getOAuthRedirectUri(): string { + return `${APP_URL}${OAUTH_CALLBACK_PATH}`; + } + /** * Check if Discord is configured (alias for isOAuthConfigured for backwards compatibility) * @deprecated Use isOAuthConfigured() or canSendMessages() instead @@ -66,17 +108,21 @@ class DiscordAutomationService { * Generate OAuth2 URL for adding bot to a server */ generateOAuthUrl(state: OAuthState): string { - if (!DISCORD_CLIENT_ID) { - throw new Error("Discord client ID not configured"); + const clientId = this.getApplicationId(); + if (!clientId || !DISCORD_CLIENT_SECRET) { + throw new Error("Discord OAuth is not configured"); } - const stateEncoded = Buffer.from(JSON.stringify(state)).toString("base64"); + const stateEncoded = this.encodeOAuthState({ + ...state, + flow: state.flow ?? "organization-install", + }); const params = new URLSearchParams({ - client_id: DISCORD_CLIENT_ID, + client_id: clientId, permissions: BOT_PERMISSIONS, scope: OAUTH_SCOPES, - redirect_uri: `${APP_URL}/api/v1/discord/callback`, + redirect_uri: this.getOAuthRedirectUri(), response_type: "code", state: stateEncoded, }); @@ -84,18 +130,154 @@ class DiscordAutomationService { return `https://discord.com/oauth2/authorize?${params.toString()}`; } + decodeOAuthState(stateValue: string): OAuthState { + if (!DISCORD_CLIENT_SECRET) { + throw new Error("Discord OAuth is not configured"); + } + + const [payloadBase64, signature] = stateValue.split(".", 2); + if (!payloadBase64 || !signature) { + throw new Error("Invalid Discord OAuth state"); + } + + const expectedSignature = createHmac("sha256", DISCORD_CLIENT_SECRET) + .update(payloadBase64) + .digest("base64url"); + + const providedBytes = Buffer.from(signature); + const expectedBytes = Buffer.from(expectedSignature); + if ( + providedBytes.length !== expectedBytes.length || + !timingSafeEqual(providedBytes, expectedBytes) + ) { + throw new Error("Invalid Discord OAuth state signature"); + } + + const parsed = JSON.parse(Buffer.from(payloadBase64, "base64url").toString("utf8")); + if (!parsed || typeof parsed !== "object") { + throw new Error("Invalid Discord OAuth state payload"); + } + + return parsed as OAuthState; + } + + async resolveOAuthIdentity(code: string): Promise { + if (!DISCORD_CLIENT_ID || !DISCORD_CLIENT_SECRET) { + logger.error("[Discord] Discord OAuth is not configured"); + return null; + } + + let tokenData: DiscordTokenResponse; + try { + const tokenResponse = await fetch(`${DISCORD_API_BASE}/oauth2/token`, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: new URLSearchParams({ + client_id: DISCORD_CLIENT_ID, + client_secret: DISCORD_CLIENT_SECRET, + grant_type: "authorization_code", + code, + redirect_uri: this.getOAuthRedirectUri(), + }), + }); + + if (!tokenResponse.ok) { + const errorText = await tokenResponse.text(); + logger.warn("[Discord] Token exchange failed", { + status: tokenResponse.status, + error: errorText.slice(0, 200), + }); + return null; + } + + tokenData = (await tokenResponse.json()) as DiscordTokenResponse; + if (!tokenData.access_token) { + logger.warn("[Discord] Missing access token in OAuth response"); + return null; + } + } catch (error) { + logger.error("[Discord] Token exchange request failed", { + error: error instanceof Error ? error.message : String(error), + }); + return null; + } + + try { + const [userResponse, guildsResponse] = await Promise.all([ + fetch(`${DISCORD_API_BASE}/users/@me`, { + headers: { + Authorization: `Bearer ${tokenData.access_token}`, + }, + }), + fetch(`${DISCORD_API_BASE}/users/@me/guilds`, { + headers: { + Authorization: `Bearer ${tokenData.access_token}`, + }, + }), + ]); + + if (!userResponse.ok || !guildsResponse.ok) { + logger.warn("[Discord] Failed to fetch OAuth identity", { + userStatus: userResponse.status, + guildsStatus: guildsResponse.status, + }); + return null; + } + + const user = (await userResponse.json()) as DiscordApiUser; + if (!user.id || !user.username || user.bot || user.system) { + logger.warn("[Discord] Invalid OAuth user", { + hasId: !!user.id, + hasUsername: !!user.username, + bot: user.bot, + system: user.system, + }); + return null; + } + + const guilds = (await guildsResponse.json()) as DiscordApiGuild[]; + + return { + accessToken: tokenData.access_token, + guilds: guilds.map((guild) => ({ + id: guild.id, + name: guild.name, + icon: guild.icon, + owner: guild.owner, + permissions: guild.permissions, + features: guild.features, + })), + user: { + id: user.id, + username: user.username, + globalName: user.global_name, + avatar: user.avatar, + }, + }; + } catch (error) { + logger.error("[Discord] Failed to fetch OAuth identity", { + error: error instanceof Error ? error.message : String(error), + }); + return null; + } + } + /** * Handle Bot OAuth callback - uses guild_id directly from URL params * For bot OAuth (scope=bot), Discord returns guild_id in the callback URL */ - async handleBotOAuthCallback( - guildId: string, - stateBase64: string, - permissions?: string, - ): Promise<{ + async handleBotOAuthCallback(args: { + code: string; + guildId: string; + oauthState: OAuthState; + permissions?: string; + }): Promise<{ success: boolean; guildId?: string; guildName?: string; + discordUser?: DiscordOAuthIdentity["user"]; error?: string; }> { if (!DISCORD_BOT_TOKEN) { @@ -103,10 +285,28 @@ class DiscordAutomationService { } try { - const state: OAuthState = JSON.parse(Buffer.from(stateBase64, "base64").toString()); + const identity = await this.resolveOAuthIdentity(args.code); + if (!identity) { + return { success: false, error: "Failed to verify Discord account" }; + } + + const guildAccess = identity.guilds.find((guild) => guild.id === args.guildId); + if (!guildAccess) { + return { + success: false, + error: "Discord account does not have access to this server", + }; + } + + if (args.oauthState.flow === "milady-managed" && !guildAccess.owner) { + return { + success: false, + error: "Discord account must own the server", + }; + } // Fetch guild info using bot token - const guildResponse = await fetch(`${DISCORD_API_BASE}/guilds/${guildId}`, { + const guildResponse = await fetch(`${DISCORD_API_BASE}/guilds/${args.guildId}`, { headers: { Authorization: `Bot ${DISCORD_BOT_TOKEN}`, }, @@ -115,7 +315,7 @@ class DiscordAutomationService { if (!guildResponse.ok) { const errorText = await guildResponse.text(); logger.error("[Discord] Failed to fetch guild info:", { - guildId, + guildId: args.guildId, status: guildResponse.status, error: errorText, }); @@ -132,24 +332,35 @@ class DiscordAutomationService { // Store guild in database await discordGuildsRepository.upsert({ - organization_id: state.organizationId, + organization_id: args.oauthState.organizationId, guild_id: guild.id, guild_name: guild.name, icon_hash: guild.icon, - owner_id: state.userId, - bot_permissions: permissions || BOT_PERMISSIONS, + owner_id: identity.user.id, + bot_permissions: args.permissions || BOT_PERMISSIONS, }); // Fetch and cache channels - await this.refreshChannels(state.organizationId, guild.id); + await this.refreshChannels(args.oauthState.organizationId, guild.id); + + const requestedNickname = args.oauthState.botNickname?.trim(); + if (requestedNickname) { + await this.setGuildBotNickname(guild.id, requestedNickname); + } logger.info("[Discord] Bot added to guild", { - organizationId: state.organizationId, + organizationId: args.oauthState.organizationId, guildId: guild.id, guildName: guild.name, + oauthUserId: identity.user.id, }); - return { success: true, guildId: guild.id, guildName: guild.name }; + return { + success: true, + guildId: guild.id, + guildName: guild.name, + discordUser: identity.user, + }; } catch (error) { logger.error("[Discord] Bot OAuth callback error:", { error: error instanceof Error ? error.message : "Unknown error", @@ -158,6 +369,60 @@ class DiscordAutomationService { } } + async setGuildBotNickname(guildId: string, nickname: string): Promise { + if (!DISCORD_BOT_TOKEN) { + return false; + } + + const trimmed = nickname.trim(); + if (!trimmed) { + return true; + } + + try { + const response = await fetch(`${DISCORD_API_BASE}/guilds/${guildId}/members/@me`, { + method: "PATCH", + headers: { + Authorization: `Bot ${DISCORD_BOT_TOKEN}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + nick: trimmed.slice(0, 32), + }), + }); + + if (!response.ok) { + const errorText = await response.text(); + logger.warn("[Discord] Failed to set bot nickname", { + guildId, + status: response.status, + error: errorText.slice(0, 200), + }); + return false; + } + + return true; + } catch (error) { + logger.warn("[Discord] Failed to set bot nickname", { + guildId, + error: error instanceof Error ? error.message : String(error), + }); + return false; + } + } + + private encodeOAuthState(state: OAuthState): string { + if (!DISCORD_CLIENT_SECRET) { + throw new Error("Discord OAuth is not configured"); + } + + const payloadBase64 = Buffer.from(JSON.stringify(state), "utf8").toString("base64url"); + const signature = createHmac("sha256", DISCORD_CLIENT_SECRET) + .update(payloadBase64) + .digest("base64url"); + return `${payloadBase64}.${signature}`; + } + /** * Get connection status for an organization * Uses canSendMessages() to check if bot can actually post (only needs bot token) diff --git a/packages/lib/services/discord-automation/types.ts b/packages/lib/services/discord-automation/types.ts index 6f58f4ad3..440905040 100644 --- a/packages/lib/services/discord-automation/types.ts +++ b/packages/lib/services/discord-automation/types.ts @@ -67,6 +67,20 @@ export interface OAuthState { userId: string; returnUrl: string; nonce: string; + flow?: "organization-install" | "milady-managed"; + agentId?: string; + botNickname?: string; +} + +export interface DiscordOAuthIdentity { + accessToken: string; + guilds: DiscordGuildInfo[]; + user: { + id: string; + username: string; + globalName: string | null; + avatar: string | null; + }; } export interface SendMessageResult { diff --git a/packages/lib/services/milady-agent-config.ts b/packages/lib/services/milady-agent-config.ts index 8b1ee8923..61a386d5b 100644 --- a/packages/lib/services/milady-agent-config.ts +++ b/packages/lib/services/milady-agent-config.ts @@ -1,6 +1,30 @@ export const MILADY_INTERNAL_CONFIG_PREFIX = "__milady"; export const MILADY_CHARACTER_OWNERSHIP_KEY = "__miladyCharacterOwnership"; export const MILADY_REUSE_EXISTING_CHARACTER = "reuse-existing"; +export const MILADY_MANAGED_DISCORD_KEY = "__miladyManagedDiscord"; + +export interface ManagedMiladyDiscordBinding { + mode: "cloud-managed"; + applicationId?: string; + guildId: string; + guildName: string; + adminDiscordUserId: string; + adminDiscordUsername: string; + adminDiscordDisplayName?: string; + adminElizaUserId: string; + botNickname?: string; + connectedAt: string; +} + +function asRecord(value: unknown): Record | null { + return value && typeof value === "object" && !Array.isArray(value) + ? (value as Record) + : null; +} + +function cloneAgentConfig(agentConfig?: Record | null): Record { + return asRecord(agentConfig) ? { ...agentConfig } : {}; +} export function stripReservedMiladyConfigKeys( agentConfig?: Record | null, @@ -30,3 +54,76 @@ export function reusesExistingMiladyCharacter( ): boolean { return agentConfig?.[MILADY_CHARACTER_OWNERSHIP_KEY] === MILADY_REUSE_EXISTING_CHARACTER; } + +export function readManagedMiladyDiscordBinding( + agentConfig?: Record | null, +): ManagedMiladyDiscordBinding | null { + const binding = asRecord(agentConfig?.[MILADY_MANAGED_DISCORD_KEY]); + if (!binding) { + return null; + } + + const guildId = typeof binding.guildId === "string" ? binding.guildId.trim() : ""; + const guildName = typeof binding.guildName === "string" ? binding.guildName.trim() : ""; + const adminDiscordUserId = + typeof binding.adminDiscordUserId === "string" ? binding.adminDiscordUserId.trim() : ""; + const adminDiscordUsername = + typeof binding.adminDiscordUsername === "string" ? binding.adminDiscordUsername.trim() : ""; + const adminElizaUserId = + typeof binding.adminElizaUserId === "string" ? binding.adminElizaUserId.trim() : ""; + const connectedAt = typeof binding.connectedAt === "string" ? binding.connectedAt.trim() : ""; + + if (!guildId || !guildName || !adminDiscordUserId || !adminDiscordUsername || !adminElizaUserId) { + return null; + } + + return { + mode: "cloud-managed", + guildId, + guildName, + adminDiscordUserId, + adminDiscordUsername, + adminElizaUserId, + connectedAt: connectedAt || new Date(0).toISOString(), + ...(typeof binding.applicationId === "string" && binding.applicationId.trim() + ? { applicationId: binding.applicationId.trim() } + : {}), + ...(typeof binding.adminDiscordDisplayName === "string" && + binding.adminDiscordDisplayName.trim() + ? { adminDiscordDisplayName: binding.adminDiscordDisplayName.trim() } + : {}), + ...(typeof binding.botNickname === "string" && binding.botNickname.trim() + ? { botNickname: binding.botNickname.trim() } + : {}), + }; +} + +export function withManagedMiladyDiscordBinding( + agentConfig: Record | null | undefined, + binding: ManagedMiladyDiscordBinding, +): Record { + const next = cloneAgentConfig(agentConfig); + next[MILADY_MANAGED_DISCORD_KEY] = { + mode: "cloud-managed", + guildId: binding.guildId, + guildName: binding.guildName, + adminDiscordUserId: binding.adminDiscordUserId, + adminDiscordUsername: binding.adminDiscordUsername, + adminElizaUserId: binding.adminElizaUserId, + connectedAt: binding.connectedAt, + ...(binding.applicationId ? { applicationId: binding.applicationId } : {}), + ...(binding.adminDiscordDisplayName + ? { adminDiscordDisplayName: binding.adminDiscordDisplayName } + : {}), + ...(binding.botNickname ? { botNickname: binding.botNickname } : {}), + }; + return next; +} + +export function withoutManagedMiladyDiscordBinding( + agentConfig: Record | null | undefined, +): Record { + const next = cloneAgentConfig(agentConfig); + delete next[MILADY_MANAGED_DISCORD_KEY]; + return next; +} diff --git a/packages/lib/services/milady-managed-discord.ts b/packages/lib/services/milady-managed-discord.ts new file mode 100644 index 000000000..480e1f781 --- /dev/null +++ b/packages/lib/services/milady-managed-discord.ts @@ -0,0 +1,241 @@ +import { miladySandboxesRepository } from "@/db/repositories/milady-sandboxes"; +import { logger } from "@/lib/utils/logger"; +import { + type ManagedMiladyDiscordBinding, + readManagedMiladyDiscordBinding, + withManagedMiladyDiscordBinding, + withoutManagedMiladyDiscordBinding, +} from "./milady-agent-config"; +import { miladySandboxService } from "@/lib/services/milady-sandbox"; + +const ROLES_PLUGIN_ID = "@miladyai/plugin-roles"; +export const DISCORD_DEVELOPER_PORTAL_URL = "https://discord.com/developers/applications"; + +function asRecord(value: unknown): Record | null { + return value && typeof value === "object" && !Array.isArray(value) + ? (value as Record) + : null; +} + +function ensureRecord(parent: Record, key: string): Record { + const existing = asRecord(parent[key]); + if (existing) { + return existing; + } + + const next: Record = {}; + parent[key] = next; + return next; +} + +function withDiscordConnectorAdmin( + agentConfig: Record | null | undefined, + adminDiscordUserId: string, +): Record { + const next = { ...(agentConfig ?? {}) }; + const plugins = ensureRecord(next, "plugins"); + const entries = ensureRecord(plugins, "entries"); + const rolesEntry = ensureRecord(entries, ROLES_PLUGIN_ID); + rolesEntry.enabled = true; + + const roleConfig = ensureRecord(rolesEntry, "config"); + const connectorAdmins = ensureRecord(roleConfig, "connectorAdmins"); + connectorAdmins.discord = [adminDiscordUserId]; + + return next; +} + +function withoutDiscordConnectorAdmin( + agentConfig: Record | null | undefined, +): Record { + const next = { ...(agentConfig ?? {}) }; + const plugins = asRecord(next.plugins); + const entries = asRecord(plugins?.entries); + const rolesEntry = asRecord(entries?.[ROLES_PLUGIN_ID]); + const roleConfig = asRecord(rolesEntry?.config); + const connectorAdmins = asRecord(roleConfig?.connectorAdmins); + + if (connectorAdmins) { + delete connectorAdmins.discord; + if (Object.keys(connectorAdmins).length === 0 && roleConfig) { + delete roleConfig.connectorAdmins; + } + } + + if (roleConfig && Object.keys(roleConfig).length === 0 && rolesEntry) { + delete rolesEntry.config; + } + + return next; +} + +export interface ManagedMiladyDiscordStatus { + applicationId: string | null; + configured: boolean; + connected: boolean; + developerPortalUrl: string; + guildId: string | null; + guildName: string | null; + adminDiscordUserId: string | null; + adminDiscordUsername: string | null; + adminDiscordDisplayName: string | null; + adminElizaUserId: string | null; + botNickname: string | null; + connectedAt: string | null; +} + +function toStatus( + agentConfig: Record | null | undefined, + configured: boolean, + applicationId: string | null, +): ManagedMiladyDiscordStatus { + const binding = readManagedMiladyDiscordBinding(agentConfig); + + return { + applicationId, + configured, + connected: Boolean(binding), + developerPortalUrl: DISCORD_DEVELOPER_PORTAL_URL, + guildId: binding?.guildId ?? null, + guildName: binding?.guildName ?? null, + adminDiscordUserId: binding?.adminDiscordUserId ?? null, + adminDiscordUsername: binding?.adminDiscordUsername ?? null, + adminDiscordDisplayName: binding?.adminDiscordDisplayName ?? null, + adminElizaUserId: binding?.adminElizaUserId ?? null, + botNickname: binding?.botNickname ?? null, + connectedAt: binding?.connectedAt ?? null, + }; +} + +export class ManagedMiladyDiscordService { + async getStatus(params: { + agentId: string; + organizationId: string; + configured: boolean; + applicationId: string | null; + }): Promise { + const sandbox = await miladySandboxesRepository.findByIdAndOrg( + params.agentId, + params.organizationId, + ); + if (!sandbox) { + return null; + } + + return toStatus( + (sandbox.agent_config as Record | null) ?? {}, + params.configured, + params.applicationId, + ); + } + + async connectAgent(params: { + agentId: string; + organizationId: string; + binding: ManagedMiladyDiscordBinding; + }): Promise<{ restarted: boolean; status: ManagedMiladyDiscordStatus }> { + const conflictingGuildLinks = await miladySandboxesRepository.findByManagedDiscordGuildId( + params.binding.guildId, + ); + const conflict = conflictingGuildLinks.find((sandbox) => sandbox.id !== params.agentId); + if (conflict) { + throw new Error("Discord server is already linked to another agent"); + } + + const sandbox = await miladySandboxesRepository.findByIdAndOrg( + params.agentId, + params.organizationId, + ); + if (!sandbox) { + throw new Error("Agent not found"); + } + + let nextConfig = withManagedMiladyDiscordBinding( + (sandbox.agent_config as Record | null) ?? {}, + params.binding, + ); + nextConfig = withDiscordConnectorAdmin(nextConfig, params.binding.adminDiscordUserId); + + await miladySandboxesRepository.update(sandbox.id, { + agent_config: nextConfig, + }); + + let restarted = false; + if (sandbox.status === "running") { + const shutdown = await miladySandboxService.shutdown(sandbox.id, params.organizationId); + if (!shutdown.success) { + throw new Error(shutdown.error || "Failed to restart agent"); + } + + const provision = await miladySandboxService.provision(sandbox.id, params.organizationId); + if (!provision.success) { + throw new Error(provision.error || "Failed to restart agent"); + } + restarted = true; + } + + logger.info("[managed-discord] Linked Discord to managed Milady agent", { + agentId: sandbox.id, + organizationId: params.organizationId, + guildId: params.binding.guildId, + adminDiscordUserId: params.binding.adminDiscordUserId, + restarted, + }); + + return { + restarted, + status: toStatus(nextConfig, true, params.binding.applicationId ?? null), + }; + } + + async disconnectAgent(params: { + agentId: string; + organizationId: string; + configured: boolean; + applicationId: string | null; + }): Promise<{ restarted: boolean; status: ManagedMiladyDiscordStatus }> { + const sandbox = await miladySandboxesRepository.findByIdAndOrg( + params.agentId, + params.organizationId, + ); + if (!sandbox) { + throw new Error("Agent not found"); + } + + let nextConfig = withoutManagedMiladyDiscordBinding( + (sandbox.agent_config as Record | null) ?? {}, + ); + nextConfig = withoutDiscordConnectorAdmin(nextConfig); + + await miladySandboxesRepository.update(sandbox.id, { + agent_config: nextConfig, + }); + + let restarted = false; + if (sandbox.status === "running") { + const shutdown = await miladySandboxService.shutdown(sandbox.id, params.organizationId); + if (!shutdown.success) { + throw new Error(shutdown.error || "Failed to restart agent"); + } + + const provision = await miladySandboxService.provision(sandbox.id, params.organizationId); + if (!provision.success) { + throw new Error(provision.error || "Failed to restart agent"); + } + restarted = true; + } + + logger.info("[managed-discord] Unlinked Discord from managed Milady agent", { + agentId: sandbox.id, + organizationId: params.organizationId, + restarted, + }); + + return { + restarted, + status: toStatus(nextConfig, params.configured, params.applicationId), + }; + } +} + +export const managedMiladyDiscordService = new ManagedMiladyDiscordService(); diff --git a/packages/services/gateway-discord/src/gateway-manager.ts b/packages/services/gateway-discord/src/gateway-manager.ts index 9ed1cc4be..06e64c9fe 100644 --- a/packages/services/gateway-discord/src/gateway-manager.ts +++ b/packages/services/gateway-discord/src/gateway-manager.ts @@ -1622,7 +1622,12 @@ export class GatewayManager { }); this.elizaAppClient = new Client({ - intents: [GatewayIntentBits.DirectMessages, GatewayIntentBits.MessageContent], + intents: [ + GatewayIntentBits.Guilds, + GatewayIntentBits.GuildMessages, + GatewayIntentBits.DirectMessages, + GatewayIntentBits.MessageContent, + ], // Partials required for DM support - DM channels are not cached by default partials: [Partials.Channel, Partials.Message], }); @@ -1682,11 +1687,14 @@ export class GatewayManager { /** * Handle a message received by the Eliza App bot. - * Filters to DM-only and forwards to the Eliza App webhook. + * Handles both DM identity routing and managed guild installs. */ private async handleElizaAppMessage(message: Message): Promise { if (message.author.bot) return; - if (message.guild) return; + if (message.guild) { + await this.handleManagedMiladyGuildMessage(message); + return; + } if (!message.content.trim()) return; if (!this.redis) { @@ -1772,6 +1780,119 @@ export class GatewayManager { } } + private async handleManagedMiladyGuildMessage(message: Message): Promise { + const botUserId = this.elizaAppClient?.user?.id; + if (!botUserId || !message.guildId) { + return; + } + + const trimmedContent = message.content.trim(); + if (!trimmedContent) { + return; + } + + const botMentionRegex = new RegExp(`<@!?${botUserId}>`, "g"); + const botMentioned = + message.mentions.users.has(botUserId) || botMentionRegex.test(trimmedContent); + if (!botMentioned) { + return; + } + + const mentionedOtherUser = message.mentions.users.some((user) => user.id !== botUserId); + const repliedUserId = message.mentions.repliedUser?.id; + const repliedToAnotherUser = Boolean(repliedUserId && repliedUserId !== botUserId); + if (mentionedOtherUser || message.mentions.everyone || repliedToAnotherUser) { + logger.debug("Ignoring managed guild message that targets someone else", { + guildId: message.guildId, + channelId: message.channelId, + messageId: message.id, + }); + return; + } + + const sanitizedContent = trimmedContent.replace(botMentionRegex, "").trim(); + if (!sanitizedContent) { + return; + } + + try { + if ("sendTyping" in message.channel) { + await message.channel.sendTyping(); + } + + const response = await fetchWithTimeout( + `${this.config.elizaCloudUrl}/api/internal/discord/eliza-app/messages`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + ...this.getAuthHeader(), + }, + body: JSON.stringify({ + guildId: message.guildId, + channelId: message.channelId, + messageId: message.id, + content: sanitizedContent, + sender: { + id: message.author.id, + username: message.author.username, + displayName: + message.member?.displayName ?? message.author.globalName ?? undefined, + avatar: message.author.displayAvatarURL() || null, + }, + }), + timeout: EVENT_FORWARD_TIMEOUT_MS, + }, + ); + + if (!response.ok) { + const errorText = await response.text().catch(() => ""); + logger.warn("Managed Milady Discord routing request failed", { + guildId: message.guildId, + channelId: message.channelId, + status: response.status, + error: errorText.slice(0, 200), + }); + return; + } + + const routed = (await response.json()) as { + handled?: boolean; + replyText?: string | null; + reason?: string; + agentId?: string; + }; + + if (!routed.handled) { + logger.debug("Managed Milady Discord message was not handled", { + guildId: message.guildId, + channelId: message.channelId, + reason: routed.reason, + agentId: routed.agentId, + }); + return; + } + + if (!routed.replyText?.trim()) { + return; + } + + const replyText = routed.replyText.trim(); + const truncated = replyText.length > 2000 ? replyText.slice(0, 2000) : replyText; + await message.reply({ + content: truncated, + allowedMentions: { repliedUser: false }, + }); + } catch (error) { + logger.error("Failed to route managed Milady guild message", { + guildId: message.guildId, + channelId: message.channelId, + messageId: message.id, + error: sanitizeError(error), + }); + } + } + getHealth(): HealthStatus { const bots = [...this.connections.values()]; const connectedBots = bots.filter((c) => c.status === "connected").length; diff --git a/packages/tests/unit/managed-discord-eliza-app-route.test.ts b/packages/tests/unit/managed-discord-eliza-app-route.test.ts new file mode 100644 index 000000000..0b85a2de7 --- /dev/null +++ b/packages/tests/unit/managed-discord-eliza-app-route.test.ts @@ -0,0 +1,175 @@ +import { afterAll, beforeAll, beforeEach, describe, expect, mock, test } from "bun:test"; +import { NextRequest } from "next/server"; + +const mockFindByManagedDiscordGuildId = mock(); +const mockBridge = mock(); +const originalEnv = { ...process.env }; +let POST: typeof import("@/app/api/internal/discord/eliza-app/messages/route").POST; + +mock.module("@/db/repositories/milady-sandboxes", () => ({ + miladySandboxesRepository: { + findByManagedDiscordGuildId: mockFindByManagedDiscordGuildId, + }, +})); + +mock.module("@/lib/services/milady-sandbox", () => ({ + miladySandboxService: { + bridge: mockBridge, + }, +})); + +mock.module("@/lib/utils/logger", () => ({ + logger: { + info: mock(), + warn: mock(), + error: mock(), + debug: mock(), + }, +})); +const TEST_PRIVATE_KEY = `-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgrQVTJ7WWtYbqub0Q +fLr2lzR+KLx0o6bljZjyK3+vmnehRANCAASqngGNae2HCVarjzxZ2mwfsM9Z8Us5 +tKQ751KrxuBykiNCX+Xo4twm4lFo2pNcJYVB7lRPNmFcjz8i2aDFOK/9 +-----END PRIVATE KEY-----`; +const TEST_PUBLIC_KEY = `-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEqp4BjWnthwlWq488WdpsH7DPWfFL +ObSkO+dSq8bgcpIjQl/l6OLcJuJRaNqTXCWFQe5UTzZhXI8/ItmgxTiv/Q== +-----END PUBLIC KEY-----`; + +async function createInternalAuthHeader(): Promise { + const { signInternalToken } = await import("@/lib/auth/jwt-internal"); + const { access_token } = await signInternalToken({ + subject: "test-discord-gateway", + service: "discord-gateway", + }); + return `Bearer ${access_token}`; +} + +describe("managed Discord Eliza App routing route", () => { + beforeAll(async () => { + process.env.JWT_SIGNING_PRIVATE_KEY = Buffer.from(TEST_PRIVATE_KEY).toString("base64"); + process.env.JWT_SIGNING_PUBLIC_KEY = Buffer.from(TEST_PUBLIC_KEY).toString("base64"); + process.env.JWT_SIGNING_KEY_ID = "test-key-id"; + process.env.NODE_ENV = "test"; + ({ POST } = await import("@/app/api/internal/discord/eliza-app/messages/route")); + }); + + afterAll(() => { + process.env = { ...originalEnv }; + }); + + beforeEach(() => { + mockFindByManagedDiscordGuildId.mockReset(); + mockBridge.mockReset(); + }); + + test("routes a managed guild message into the linked Milady sandbox bridge", async () => { + const authHeader = await createInternalAuthHeader(); + mockFindByManagedDiscordGuildId.mockResolvedValue([ + { + id: "agent-1", + organization_id: "org-1", + status: "running", + }, + ]); + mockBridge.mockResolvedValue({ + jsonrpc: "2.0", + id: "rpc-1", + result: { text: "hello from agent" }, + }); + + const response = await POST( + new NextRequest("https://example.com/api/internal/discord/eliza-app/messages", { + method: "POST", + headers: { Authorization: authHeader, "Content-Type": "application/json" }, + body: JSON.stringify({ + guildId: "guild-1", + channelId: "channel-1", + messageId: "message-1", + content: "hello bot", + sender: { + id: "discord-user-1", + username: "owner", + displayName: "Owner Person", + avatar: "https://cdn.discordapp.com/avatar.png", + }, + }), + }), + { + params: Promise.resolve({}), + } as never, + ); + + expect(response.status).toBe(200); + expect(await response.json()).toEqual({ + handled: true, + replyText: "hello from agent", + agentId: "agent-1", + }); + expect(mockBridge).toHaveBeenCalledWith( + "agent-1", + "org-1", + expect.objectContaining({ + jsonrpc: "2.0", + method: "message.send", + params: expect.objectContaining({ + text: "hello bot", + roomId: "discord-guild:guild-1:channel:channel-1", + channelType: "GROUP", + source: "discord", + sender: { + id: "discord-user-1", + username: "owner", + displayName: "Owner Person", + metadata: { + discord: { + userId: "discord-user-1", + username: "owner", + globalName: "Owner Person", + avatar: "https://cdn.discordapp.com/avatar.png", + }, + }, + }, + metadata: { + discord: { + guildId: "guild-1", + channelId: "channel-1", + messageId: "message-1", + }, + }, + }), + }), + ); + }); + + test("returns handled=false when no agent is linked to the guild", async () => { + const authHeader = await createInternalAuthHeader(); + mockFindByManagedDiscordGuildId.mockResolvedValue([]); + + const response = await POST( + new NextRequest("https://example.com/api/internal/discord/eliza-app/messages", { + method: "POST", + headers: { Authorization: authHeader, "Content-Type": "application/json" }, + body: JSON.stringify({ + guildId: "guild-1", + channelId: "channel-1", + messageId: "message-1", + content: "hello bot", + sender: { + id: "discord-user-1", + username: "owner", + }, + }), + }), + { + params: Promise.resolve({}), + } as never, + ); + + expect(response.status).toBe(200); + expect(await response.json()).toEqual({ + handled: false, + reason: "not_linked", + }); + }); +}); diff --git a/packages/tests/unit/milady-agent-discord-routes.test.ts b/packages/tests/unit/milady-agent-discord-routes.test.ts new file mode 100644 index 000000000..56367f6af --- /dev/null +++ b/packages/tests/unit/milady-agent-discord-routes.test.ts @@ -0,0 +1,264 @@ +import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { NextRequest } from "next/server"; +import { jsonRequest, routeParams } from "./api/route-test-helpers"; + +const mockRequireAuthOrApiKeyWithOrg = mock(); +const mockGetAgent = mock(); +const mockGetStatus = mock(); +const mockDisconnectAgent = mock(); +const mockConnectAgent = mock(); +const mockIsOAuthConfigured = mock(); +const mockGetApplicationId = mock(); +const mockGenerateOAuthUrl = mock(); +const mockDecodeOAuthState = mock(); +const mockHandleBotOAuthCallback = mock(); + +mock.module("@/lib/auth", () => ({ + requireAuthOrApiKeyWithOrg: mockRequireAuthOrApiKeyWithOrg, +})); + +mock.module("@/lib/services/milady-sandbox", () => ({ + miladySandboxService: { + getAgent: mockGetAgent, + }, +})); + +mock.module("@/lib/services/milady-managed-discord", () => ({ + managedMiladyDiscordService: { + getStatus: mockGetStatus, + disconnectAgent: mockDisconnectAgent, + connectAgent: mockConnectAgent, + }, +})); + +mock.module("@/lib/services/discord-automation", () => ({ + discordAutomationService: { + isOAuthConfigured: mockIsOAuthConfigured, + getApplicationId: mockGetApplicationId, + generateOAuthUrl: mockGenerateOAuthUrl, + decodeOAuthState: mockDecodeOAuthState, + handleBotOAuthCallback: mockHandleBotOAuthCallback, + }, +})); + +mock.module("@/lib/utils/logger", () => ({ + logger: { + info: mock(), + warn: mock(), + error: mock(), + debug: mock(), + }, +})); + +import { GET as discordCallbackGet } from "@/app/api/v1/discord/callback/route"; +import { + DELETE as deleteManagedDiscord, + GET as getManagedDiscord, +} from "@/app/api/v1/milady/agents/[agentId]/discord/route"; +import { POST as postManagedDiscordOauth } from "@/app/api/v1/milady/agents/[agentId]/discord/oauth/route"; + +describe("managed Milady Discord routes", () => { + beforeEach(() => { + mockRequireAuthOrApiKeyWithOrg.mockReset(); + mockGetAgent.mockReset(); + mockGetStatus.mockReset(); + mockDisconnectAgent.mockReset(); + mockConnectAgent.mockReset(); + mockIsOAuthConfigured.mockReset(); + mockGetApplicationId.mockReset(); + mockGenerateOAuthUrl.mockReset(); + mockDecodeOAuthState.mockReset(); + mockHandleBotOAuthCallback.mockReset(); + + mockRequireAuthOrApiKeyWithOrg.mockResolvedValue({ + user: { + id: "user-1", + organization_id: "org-1", + }, + }); + mockIsOAuthConfigured.mockReturnValue(true); + mockGetApplicationId.mockReturnValue("discord-app-1"); + }); + + test("POST /api/v1/milady/agents/[agentId]/discord/oauth returns an authorize URL for loopback Milady redirects", async () => { + mockGetAgent.mockResolvedValue({ + id: "agent-1", + agent_name: "Chen", + }); + mockGenerateOAuthUrl.mockReturnValue("https://discord.com/oauth2/authorize?mock=1"); + + const response = await postManagedDiscordOauth( + jsonRequest("https://example.com/api/v1/milady/agents/agent-1/discord/oauth", "POST", { + returnUrl: "http://127.0.0.1:31337/cloud?tab=agents", + botNickname: "Milady Chen", + }), + routeParams({ agentId: "agent-1" }), + ); + + expect(response.status).toBe(200); + expect(await response.json()).toEqual({ + success: true, + data: { + authorizeUrl: "https://discord.com/oauth2/authorize?mock=1", + applicationId: "discord-app-1", + }, + }); + expect(mockGenerateOAuthUrl).toHaveBeenCalledWith( + expect.objectContaining({ + agentId: "agent-1", + flow: "milady-managed", + botNickname: "Milady Chen", + organizationId: "org-1", + returnUrl: "http://127.0.0.1:31337/cloud?tab=agents", + userId: "user-1", + }), + ); + }); + + test("GET /api/v1/milady/agents/[agentId]/discord returns managed Discord status", async () => { + mockGetStatus.mockResolvedValue({ + applicationId: "discord-app-1", + configured: true, + connected: true, + developerPortalUrl: "https://discord.com/developers/applications", + guildId: "guild-1", + guildName: "Guild One", + adminDiscordUserId: "discord-user-1", + adminDiscordUsername: "owner", + adminDiscordDisplayName: "Owner", + adminElizaUserId: "user-1", + botNickname: "Milady", + connectedAt: "2026-04-04T16:00:00.000Z", + }); + + const response = await getManagedDiscord( + new NextRequest("https://example.com/api/v1/milady/agents/agent-1/discord"), + routeParams({ agentId: "agent-1" }), + ); + + expect(response.status).toBe(200); + expect(await response.json()).toEqual({ + success: true, + data: { + applicationId: "discord-app-1", + configured: true, + connected: true, + developerPortalUrl: "https://discord.com/developers/applications", + guildId: "guild-1", + guildName: "Guild One", + adminDiscordUserId: "discord-user-1", + adminDiscordUsername: "owner", + adminDiscordDisplayName: "Owner", + adminElizaUserId: "user-1", + botNickname: "Milady", + connectedAt: "2026-04-04T16:00:00.000Z", + }, + }); + }); + + test("GET /api/v1/discord/callback links managed Discord installs back to the Milady agent", async () => { + mockDecodeOAuthState.mockReturnValue({ + flow: "milady-managed", + agentId: "agent-1", + organizationId: "org-1", + userId: "user-1", + returnUrl: "http://127.0.0.1:31337/cloud?tab=agents", + botNickname: "Milady Chen", + }); + mockHandleBotOAuthCallback.mockResolvedValue({ + success: true, + guildId: "guild-1", + guildName: "Guild One", + discordUser: { + id: "discord-user-1", + username: "owner", + globalName: "Owner Person", + avatar: null, + }, + }); + mockConnectAgent.mockResolvedValue({ + restarted: true, + status: { + connected: true, + }, + }); + + const response = await discordCallbackGet( + new NextRequest( + "https://example.com/api/v1/discord/callback?code=oauth-code&state=signed-state&guild_id=guild-1", + ), + ); + + expect(response.status).toBe(307); + expect(response.headers.get("location")).toContain( + "http://127.0.0.1:31337/cloud?tab=agents", + ); + expect(response.headers.get("location")).toContain("managed=1"); + expect(response.headers.get("location")).toContain("agentId=agent-1"); + expect(response.headers.get("location")).toContain("guildId=guild-1"); + expect(response.headers.get("location")).toContain("restarted=1"); + expect(mockConnectAgent).toHaveBeenCalledWith({ + agentId: "agent-1", + organizationId: "org-1", + binding: { + mode: "cloud-managed", + applicationId: "discord-app-1", + guildId: "guild-1", + guildName: "Guild One", + adminDiscordUserId: "discord-user-1", + adminDiscordUsername: "owner", + adminDiscordDisplayName: "Owner Person", + adminElizaUserId: "user-1", + botNickname: "Milady Chen", + connectedAt: expect.any(String), + }, + }); + }); + + test("DELETE /api/v1/milady/agents/[agentId]/discord disconnects managed Discord", async () => { + mockDisconnectAgent.mockResolvedValue({ + restarted: false, + status: { + applicationId: "discord-app-1", + configured: true, + connected: false, + developerPortalUrl: "https://discord.com/developers/applications", + guildId: null, + guildName: null, + adminDiscordUserId: null, + adminDiscordUsername: null, + adminDiscordDisplayName: null, + adminElizaUserId: null, + botNickname: null, + connectedAt: null, + }, + }); + + const response = await deleteManagedDiscord( + new NextRequest("https://example.com/api/v1/milady/agents/agent-1/discord", { + method: "DELETE", + }), + routeParams({ agentId: "agent-1" }), + ); + + expect(response.status).toBe(200); + expect(await response.json()).toEqual({ + success: true, + data: { + applicationId: "discord-app-1", + configured: true, + connected: false, + developerPortalUrl: "https://discord.com/developers/applications", + guildId: null, + guildName: null, + adminDiscordUserId: null, + adminDiscordUsername: null, + adminDiscordDisplayName: null, + adminElizaUserId: null, + botNickname: null, + connectedAt: null, + restarted: false, + }, + }); + }); +}); diff --git a/packages/tests/unit/milady-managed-discord.test.ts b/packages/tests/unit/milady-managed-discord.test.ts new file mode 100644 index 000000000..871852cc5 --- /dev/null +++ b/packages/tests/unit/milady-managed-discord.test.ts @@ -0,0 +1,65 @@ +import { describe, expect, test } from "bun:test"; +import { + MILADY_CHARACTER_OWNERSHIP_KEY, + readManagedMiladyDiscordBinding, + withManagedMiladyDiscordBinding, + withoutManagedMiladyDiscordBinding, +} from "@/lib/services/milady-agent-config"; + +describe("managed Milady Discord config helpers", () => { + test("writes and reads the managed Discord binding payload", () => { + const config = withManagedMiladyDiscordBinding( + { + existing: true, + [MILADY_CHARACTER_OWNERSHIP_KEY]: "reuse-existing", + }, + { + mode: "cloud-managed", + applicationId: "discord-app-1", + guildId: "guild-1", + guildName: "Guild One", + adminDiscordUserId: "discord-user-1", + adminDiscordUsername: "owner", + adminDiscordDisplayName: "Owner Person", + adminElizaUserId: "user-1", + botNickname: "Milady", + connectedAt: "2026-04-04T16:00:00.000Z", + }, + ); + + expect(readManagedMiladyDiscordBinding(config)).toEqual({ + mode: "cloud-managed", + applicationId: "discord-app-1", + guildId: "guild-1", + guildName: "Guild One", + adminDiscordUserId: "discord-user-1", + adminDiscordUsername: "owner", + adminDiscordDisplayName: "Owner Person", + adminElizaUserId: "user-1", + botNickname: "Milady", + connectedAt: "2026-04-04T16:00:00.000Z", + }); + expect(config[MILADY_CHARACTER_OWNERSHIP_KEY]).toBe("reuse-existing"); + }); + + test("removes only the managed Discord binding", () => { + const config = withoutManagedMiladyDiscordBinding({ + existing: true, + [MILADY_CHARACTER_OWNERSHIP_KEY]: "reuse-existing", + __miladyManagedDiscord: { + guildId: "guild-1", + guildName: "Guild One", + adminDiscordUserId: "discord-user-1", + adminDiscordUsername: "owner", + adminElizaUserId: "user-1", + connectedAt: "2026-04-04T16:00:00.000Z", + }, + }); + + expect(readManagedMiladyDiscordBinding(config)).toBeNull(); + expect(config).toEqual({ + existing: true, + [MILADY_CHARACTER_OWNERSHIP_KEY]: "reuse-existing", + }); + }); +}); From 077659dda3d4fcb539557aefd71f0253f3bbd649 Mon Sep 17 00:00:00 2001 From: Shaw Date: Sat, 4 Apr 2026 21:32:08 -0700 Subject: [PATCH 02/11] cloud: format managed discord routes --- app/api/v1/discord/callback/route.ts | 18 ++++++++---------- packages/db/repositories/milady-sandboxes.ts | 2 +- .../lib/services/milady-managed-discord.ts | 2 +- .../gateway-discord/src/gateway-manager.ts | 3 +-- .../unit/milady-agent-discord-routes.test.ts | 6 ++---- 5 files changed, 13 insertions(+), 18 deletions(-) diff --git a/app/api/v1/discord/callback/route.ts b/app/api/v1/discord/callback/route.ts index 5638c7324..e6c462eda 100644 --- a/app/api/v1/discord/callback/route.ts +++ b/app/api/v1/discord/callback/route.ts @@ -64,16 +64,14 @@ export async function GET(request: NextRequest): Promise { const baseUrl = process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000"; // Parse state for return URL (do this early for error redirects) - let decodedState: - | { - returnUrl?: string; - flow?: "organization-install" | "milady-managed"; - agentId?: string; - organizationId?: string; - userId?: string; - botNickname?: string; - } - | null = null; + let decodedState: { + returnUrl?: string; + flow?: "organization-install" | "milady-managed"; + agentId?: string; + organizationId?: string; + userId?: string; + botNickname?: string; + } | null = null; let returnTarget = resolveOAuthReturnTarget(baseUrl, undefined, false); if (state) { try { diff --git a/packages/db/repositories/milady-sandboxes.ts b/packages/db/repositories/milady-sandboxes.ts index 39ec6002a..0d1f68b5d 100644 --- a/packages/db/repositories/milady-sandboxes.ts +++ b/packages/db/repositories/milady-sandboxes.ts @@ -1,6 +1,5 @@ import { and, desc, eq, inArray, notInArray, sql } from "drizzle-orm"; import { dbRead, dbWrite } from "@/db/helpers"; -import { MILADY_MANAGED_DISCORD_KEY } from "@/lib/services/milady-agent-config"; import { type MiladyBackupSnapshotType, type MiladySandbox, @@ -11,6 +10,7 @@ import { type NewMiladySandbox, type NewMiladySandboxBackup, } from "@/db/schemas/milady-sandboxes"; +import { MILADY_MANAGED_DISCORD_KEY } from "@/lib/services/milady-agent-config"; export type { MiladyBackupSnapshotType, diff --git a/packages/lib/services/milady-managed-discord.ts b/packages/lib/services/milady-managed-discord.ts index 480e1f781..0a561c9b3 100644 --- a/packages/lib/services/milady-managed-discord.ts +++ b/packages/lib/services/milady-managed-discord.ts @@ -1,4 +1,5 @@ import { miladySandboxesRepository } from "@/db/repositories/milady-sandboxes"; +import { miladySandboxService } from "@/lib/services/milady-sandbox"; import { logger } from "@/lib/utils/logger"; import { type ManagedMiladyDiscordBinding, @@ -6,7 +7,6 @@ import { withManagedMiladyDiscordBinding, withoutManagedMiladyDiscordBinding, } from "./milady-agent-config"; -import { miladySandboxService } from "@/lib/services/milady-sandbox"; const ROLES_PLUGIN_ID = "@miladyai/plugin-roles"; export const DISCORD_DEVELOPER_PORTAL_URL = "https://discord.com/developers/applications"; diff --git a/packages/services/gateway-discord/src/gateway-manager.ts b/packages/services/gateway-discord/src/gateway-manager.ts index 06e64c9fe..bd047722d 100644 --- a/packages/services/gateway-discord/src/gateway-manager.ts +++ b/packages/services/gateway-discord/src/gateway-manager.ts @@ -1836,8 +1836,7 @@ export class GatewayManager { sender: { id: message.author.id, username: message.author.username, - displayName: - message.member?.displayName ?? message.author.globalName ?? undefined, + displayName: message.member?.displayName ?? message.author.globalName ?? undefined, avatar: message.author.displayAvatarURL() || null, }, }), diff --git a/packages/tests/unit/milady-agent-discord-routes.test.ts b/packages/tests/unit/milady-agent-discord-routes.test.ts index 56367f6af..a615220de 100644 --- a/packages/tests/unit/milady-agent-discord-routes.test.ts +++ b/packages/tests/unit/milady-agent-discord-routes.test.ts @@ -51,11 +51,11 @@ mock.module("@/lib/utils/logger", () => ({ })); import { GET as discordCallbackGet } from "@/app/api/v1/discord/callback/route"; +import { POST as postManagedDiscordOauth } from "@/app/api/v1/milady/agents/[agentId]/discord/oauth/route"; import { DELETE as deleteManagedDiscord, GET as getManagedDiscord, } from "@/app/api/v1/milady/agents/[agentId]/discord/route"; -import { POST as postManagedDiscordOauth } from "@/app/api/v1/milady/agents/[agentId]/discord/oauth/route"; describe("managed Milady Discord routes", () => { beforeEach(() => { @@ -190,9 +190,7 @@ describe("managed Milady Discord routes", () => { ); expect(response.status).toBe(307); - expect(response.headers.get("location")).toContain( - "http://127.0.0.1:31337/cloud?tab=agents", - ); + expect(response.headers.get("location")).toContain("http://127.0.0.1:31337/cloud?tab=agents"); expect(response.headers.get("location")).toContain("managed=1"); expect(response.headers.get("location")).toContain("agentId=agent-1"); expect(response.headers.get("location")).toContain("guildId=guild-1"); From 99c88df16ed15b3553c01b2267e12bb218d0080d Mon Sep 17 00:00:00 2001 From: Shaw Date: Sat, 4 Apr 2026 21:49:11 -0700 Subject: [PATCH 03/11] cloud: fix managed discord callback typing --- app/api/v1/discord/callback/route.ts | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/app/api/v1/discord/callback/route.ts b/app/api/v1/discord/callback/route.ts index e6c462eda..bab39f521 100644 --- a/app/api/v1/discord/callback/route.ts +++ b/app/api/v1/discord/callback/route.ts @@ -13,6 +13,7 @@ import { sanitizeRelativeRedirectPath, } from "@/lib/security/redirect-validation"; import { discordAutomationService } from "@/lib/services/discord-automation"; +import type { OAuthState } from "@/lib/services/discord-automation/types"; import { managedMiladyDiscordService } from "@/lib/services/milady-managed-discord"; import { logger } from "@/lib/utils/logger"; @@ -64,14 +65,7 @@ export async function GET(request: NextRequest): Promise { const baseUrl = process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000"; // Parse state for return URL (do this early for error redirects) - let decodedState: { - returnUrl?: string; - flow?: "organization-install" | "milady-managed"; - agentId?: string; - organizationId?: string; - userId?: string; - botNickname?: string; - } | null = null; + let decodedState: OAuthState | null = null; let returnTarget = resolveOAuthReturnTarget(baseUrl, undefined, false); if (state) { try { From 4e6c6985bb74e9595b4bf161b1faaa3a5980ebda Mon Sep 17 00:00:00 2001 From: Shaw Date: Sat, 4 Apr 2026 22:11:23 -0700 Subject: [PATCH 04/11] cloud: fix managed discord route test env mutation --- .../tests/unit/managed-discord-eliza-app-route.test.ts | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/packages/tests/unit/managed-discord-eliza-app-route.test.ts b/packages/tests/unit/managed-discord-eliza-app-route.test.ts index 0b85a2de7..574141e52 100644 --- a/packages/tests/unit/managed-discord-eliza-app-route.test.ts +++ b/packages/tests/unit/managed-discord-eliza-app-route.test.ts @@ -4,6 +4,7 @@ import { NextRequest } from "next/server"; const mockFindByManagedDiscordGuildId = mock(); const mockBridge = mock(); const originalEnv = { ...process.env }; +const mutableEnv = process.env as Record; let POST: typeof import("@/app/api/internal/discord/eliza-app/messages/route").POST; mock.module("@/db/repositories/milady-sandboxes", () => ({ @@ -47,10 +48,10 @@ async function createInternalAuthHeader(): Promise { describe("managed Discord Eliza App routing route", () => { beforeAll(async () => { - process.env.JWT_SIGNING_PRIVATE_KEY = Buffer.from(TEST_PRIVATE_KEY).toString("base64"); - process.env.JWT_SIGNING_PUBLIC_KEY = Buffer.from(TEST_PUBLIC_KEY).toString("base64"); - process.env.JWT_SIGNING_KEY_ID = "test-key-id"; - process.env.NODE_ENV = "test"; + mutableEnv.JWT_SIGNING_PRIVATE_KEY = Buffer.from(TEST_PRIVATE_KEY).toString("base64"); + mutableEnv.JWT_SIGNING_PUBLIC_KEY = Buffer.from(TEST_PUBLIC_KEY).toString("base64"); + mutableEnv.JWT_SIGNING_KEY_ID = "test-key-id"; + mutableEnv.NODE_ENV = "test"; ({ POST } = await import("@/app/api/internal/discord/eliza-app/messages/route")); }); From 55bd2e14be9a9c09285fd17f2df5551359e6593c Mon Sep 17 00:00:00 2001 From: Shaw Date: Sat, 4 Apr 2026 23:55:53 -0700 Subject: [PATCH 05/11] cloud: add managed Google Milady routes --- .../v1/milady/google/calendar/events/route.ts | 65 ++ .../v1/milady/google/calendar/feed/route.ts | 44 + .../milady/google/connect/initiate/route.ts | 55 + app/api/v1/milady/google/disconnect/route.ts | 43 + .../milady/google/gmail/reply-send/route.ts | 53 + .../v1/milady/google/gmail/triage/route.ts | 41 + app/api/v1/milady/google/status/route.ts | 30 + .../lib/services/milady-google-connector.ts | 940 ++++++++++++++++++ .../unit/milady-google-connector.test.ts | 369 +++++++ .../tests/unit/milady-google-routes.test.ts | 226 +++++ 10 files changed, 1866 insertions(+) create mode 100644 app/api/v1/milady/google/calendar/events/route.ts create mode 100644 app/api/v1/milady/google/calendar/feed/route.ts create mode 100644 app/api/v1/milady/google/connect/initiate/route.ts create mode 100644 app/api/v1/milady/google/disconnect/route.ts create mode 100644 app/api/v1/milady/google/gmail/reply-send/route.ts create mode 100644 app/api/v1/milady/google/gmail/triage/route.ts create mode 100644 app/api/v1/milady/google/status/route.ts create mode 100644 packages/lib/services/milady-google-connector.ts create mode 100644 packages/tests/unit/milady-google-connector.test.ts create mode 100644 packages/tests/unit/milady-google-routes.test.ts diff --git a/app/api/v1/milady/google/calendar/events/route.ts b/app/api/v1/milady/google/calendar/events/route.ts new file mode 100644 index 000000000..fb5f71dc1 --- /dev/null +++ b/app/api/v1/milady/google/calendar/events/route.ts @@ -0,0 +1,65 @@ +import type { NextRequest } from "next/server"; +import { NextResponse } from "next/server"; +import { z } from "zod"; +import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; +import { + createManagedGoogleCalendarEvent, + MiladyGoogleConnectorError, +} from "@/lib/services/milady-google-connector"; + +export const dynamic = "force-dynamic"; +export const maxDuration = 30; + +const attendeeSchema = z.object({ + email: z.string().email(), + displayName: z.string().trim().min(1).optional(), + optional: z.boolean().optional(), +}); + +const requestSchema = z.object({ + calendarId: z.string().trim().min(1).optional(), + title: z.string().trim().min(1), + description: z.string().optional(), + location: z.string().optional(), + startAt: z.string().trim().min(1), + endAt: z.string().trim().min(1), + timeZone: z.string().trim().min(1), + attendees: z.array(attendeeSchema).optional(), +}); + +export async function POST(request: NextRequest) { + try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + const parsed = requestSchema.safeParse(await request.json()); + if (!parsed.success) { + return NextResponse.json( + { error: "Invalid calendar event request.", details: parsed.error.issues }, + { status: 400 }, + ); + } + + return NextResponse.json( + await createManagedGoogleCalendarEvent({ + organizationId: user.organization_id, + userId: user.id, + calendarId: parsed.data.calendarId ?? "primary", + title: parsed.data.title, + description: parsed.data.description, + location: parsed.data.location, + startAt: parsed.data.startAt, + endAt: parsed.data.endAt, + timeZone: parsed.data.timeZone, + attendees: parsed.data.attendees, + }), + { status: 201 }, + ); + } catch (error) { + if (error instanceof MiladyGoogleConnectorError) { + return NextResponse.json({ error: error.message }, { status: error.status }); + } + return NextResponse.json( + { error: error instanceof Error ? error.message : "Failed to create Google Calendar event." }, + { status: 500 }, + ); + } +} diff --git a/app/api/v1/milady/google/calendar/feed/route.ts b/app/api/v1/milady/google/calendar/feed/route.ts new file mode 100644 index 000000000..e6e1ce051 --- /dev/null +++ b/app/api/v1/milady/google/calendar/feed/route.ts @@ -0,0 +1,44 @@ +import type { NextRequest } from "next/server"; +import { NextResponse } from "next/server"; +import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; +import { + fetchManagedGoogleCalendarFeed, + MiladyGoogleConnectorError, +} from "@/lib/services/milady-google-connector"; + +export const dynamic = "force-dynamic"; +export const maxDuration = 30; + +export async function GET(request: NextRequest) { + try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + const searchParams = request.nextUrl.searchParams; + const calendarId = searchParams.get("calendarId")?.trim() || "primary"; + const timeMin = searchParams.get("timeMin")?.trim(); + const timeMax = searchParams.get("timeMax")?.trim(); + const timeZone = searchParams.get("timeZone")?.trim() || "UTC"; + + if (!timeMin || !timeMax) { + return NextResponse.json({ error: "timeMin and timeMax are required." }, { status: 400 }); + } + + return NextResponse.json( + await fetchManagedGoogleCalendarFeed({ + organizationId: user.organization_id, + userId: user.id, + calendarId, + timeMin, + timeMax, + timeZone, + }), + ); + } catch (error) { + if (error instanceof MiladyGoogleConnectorError) { + return NextResponse.json({ error: error.message }, { status: error.status }); + } + return NextResponse.json( + { error: error instanceof Error ? error.message : "Failed to fetch Google Calendar." }, + { status: 500 }, + ); + } +} diff --git a/app/api/v1/milady/google/connect/initiate/route.ts b/app/api/v1/milady/google/connect/initiate/route.ts new file mode 100644 index 000000000..3c7410d99 --- /dev/null +++ b/app/api/v1/milady/google/connect/initiate/route.ts @@ -0,0 +1,55 @@ +import type { NextRequest } from "next/server"; +import { NextResponse } from "next/server"; +import { z } from "zod"; +import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; +import { + initiateManagedGoogleConnection, + MiladyGoogleConnectorError, +} from "@/lib/services/milady-google-connector"; + +export const dynamic = "force-dynamic"; +export const maxDuration = 30; + +const requestSchema = z.object({ + redirectUrl: z.string().trim().min(1).optional(), + capabilities: z + .array( + z.enum([ + "google.basic_identity", + "google.calendar.read", + "google.calendar.write", + "google.gmail.triage", + "google.gmail.send", + ]), + ) + .optional(), +}); + +export async function POST(request: NextRequest) { + try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + const parsed = requestSchema.safeParse(await request.json().catch(() => ({}))); + if (!parsed.success) { + return NextResponse.json( + { error: "Invalid Google connector request.", details: parsed.error.issues }, + { status: 400 }, + ); + } + return NextResponse.json( + await initiateManagedGoogleConnection({ + organizationId: user.organization_id, + userId: user.id, + redirectUrl: parsed.data.redirectUrl, + capabilities: parsed.data.capabilities, + }), + ); + } catch (error) { + if (error instanceof MiladyGoogleConnectorError) { + return NextResponse.json({ error: error.message }, { status: error.status }); + } + return NextResponse.json( + { error: error instanceof Error ? error.message : "Failed to initiate Google OAuth." }, + { status: 500 }, + ); + } +} diff --git a/app/api/v1/milady/google/disconnect/route.ts b/app/api/v1/milady/google/disconnect/route.ts new file mode 100644 index 000000000..b0c411e15 --- /dev/null +++ b/app/api/v1/milady/google/disconnect/route.ts @@ -0,0 +1,43 @@ +import type { NextRequest } from "next/server"; +import { NextResponse } from "next/server"; +import { z } from "zod"; +import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; +import { + disconnectManagedGoogleConnection, + MiladyGoogleConnectorError, +} from "@/lib/services/milady-google-connector"; + +export const dynamic = "force-dynamic"; +export const maxDuration = 30; + +const requestSchema = z.object({ + connectionId: z.string().uuid().nullable().optional(), +}); + +export async function POST(request: NextRequest) { + try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + const parsed = requestSchema.safeParse(await request.json().catch(() => ({}))); + if (!parsed.success) { + return NextResponse.json( + { error: "Invalid disconnect request.", details: parsed.error.issues }, + { status: 400 }, + ); + } + + await disconnectManagedGoogleConnection({ + organizationId: user.organization_id, + userId: user.id, + connectionId: parsed.data.connectionId ?? null, + }); + return NextResponse.json({ ok: true }); + } catch (error) { + if (error instanceof MiladyGoogleConnectorError) { + return NextResponse.json({ error: error.message }, { status: error.status }); + } + return NextResponse.json( + { error: error instanceof Error ? error.message : "Failed to disconnect Google." }, + { status: 500 }, + ); + } +} diff --git a/app/api/v1/milady/google/gmail/reply-send/route.ts b/app/api/v1/milady/google/gmail/reply-send/route.ts new file mode 100644 index 000000000..81d1dfa2c --- /dev/null +++ b/app/api/v1/milady/google/gmail/reply-send/route.ts @@ -0,0 +1,53 @@ +import type { NextRequest } from "next/server"; +import { NextResponse } from "next/server"; +import { z } from "zod"; +import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; +import { + MiladyGoogleConnectorError, + sendManagedGoogleReply, +} from "@/lib/services/milady-google-connector"; + +export const dynamic = "force-dynamic"; +export const maxDuration = 30; + +const requestSchema = z.object({ + to: z.array(z.string().email()).min(1), + cc: z.array(z.string().email()).optional(), + subject: z.string().trim().min(1), + bodyText: z.string().min(1), + inReplyTo: z.string().trim().min(1).nullable().optional(), + references: z.string().trim().min(1).nullable().optional(), +}); + +export async function POST(request: NextRequest) { + try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + const parsed = requestSchema.safeParse(await request.json()); + if (!parsed.success) { + return NextResponse.json( + { error: "Invalid Gmail send request.", details: parsed.error.issues }, + { status: 400 }, + ); + } + + await sendManagedGoogleReply({ + organizationId: user.organization_id, + userId: user.id, + to: parsed.data.to, + cc: parsed.data.cc, + subject: parsed.data.subject, + bodyText: parsed.data.bodyText, + inReplyTo: parsed.data.inReplyTo ?? null, + references: parsed.data.references ?? null, + }); + return NextResponse.json({ ok: true }); + } catch (error) { + if (error instanceof MiladyGoogleConnectorError) { + return NextResponse.json({ error: error.message }, { status: error.status }); + } + return NextResponse.json( + { error: error instanceof Error ? error.message : "Failed to send Gmail reply." }, + { status: 500 }, + ); + } +} diff --git a/app/api/v1/milady/google/gmail/triage/route.ts b/app/api/v1/milady/google/gmail/triage/route.ts new file mode 100644 index 000000000..f94c52e14 --- /dev/null +++ b/app/api/v1/milady/google/gmail/triage/route.ts @@ -0,0 +1,41 @@ +import type { NextRequest } from "next/server"; +import { NextResponse } from "next/server"; +import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; +import { + fetchManagedGoogleGmailTriage, + MiladyGoogleConnectorError, +} from "@/lib/services/milady-google-connector"; + +export const dynamic = "force-dynamic"; +export const maxDuration = 30; + +export async function GET(request: NextRequest) { + try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + const rawMaxResults = request.nextUrl.searchParams.get("maxResults"); + const maxResults = + rawMaxResults && rawMaxResults.trim().length > 0 ? Number.parseInt(rawMaxResults, 10) : 12; + if (!Number.isFinite(maxResults) || maxResults <= 0) { + return NextResponse.json( + { error: "maxResults must be a positive integer." }, + { status: 400 }, + ); + } + + return NextResponse.json( + await fetchManagedGoogleGmailTriage({ + organizationId: user.organization_id, + userId: user.id, + maxResults, + }), + ); + } catch (error) { + if (error instanceof MiladyGoogleConnectorError) { + return NextResponse.json({ error: error.message }, { status: error.status }); + } + return NextResponse.json( + { error: error instanceof Error ? error.message : "Failed to fetch Gmail triage." }, + { status: 500 }, + ); + } +} diff --git a/app/api/v1/milady/google/status/route.ts b/app/api/v1/milady/google/status/route.ts new file mode 100644 index 000000000..7bf93c3a7 --- /dev/null +++ b/app/api/v1/milady/google/status/route.ts @@ -0,0 +1,30 @@ +import type { NextRequest } from "next/server"; +import { NextResponse } from "next/server"; +import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; +import { + getManagedGoogleConnectorStatus, + MiladyGoogleConnectorError, +} from "@/lib/services/milady-google-connector"; + +export const dynamic = "force-dynamic"; +export const maxDuration = 30; + +export async function GET(request: NextRequest) { + try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + return NextResponse.json( + await getManagedGoogleConnectorStatus({ + organizationId: user.organization_id, + userId: user.id, + }), + ); + } catch (error) { + if (error instanceof MiladyGoogleConnectorError) { + return NextResponse.json({ error: error.message }, { status: error.status }); + } + return NextResponse.json( + { error: error instanceof Error ? error.message : "Failed to resolve Google status." }, + { status: 500 }, + ); + } +} diff --git a/packages/lib/services/milady-google-connector.ts b/packages/lib/services/milady-google-connector.ts new file mode 100644 index 000000000..d7ae4bff0 --- /dev/null +++ b/packages/lib/services/milady-google-connector.ts @@ -0,0 +1,940 @@ +import { and, eq } from "drizzle-orm"; +import { dbRead } from "@/db/client"; +import { platformCredentials } from "@/db/schemas/platform-credentials"; +import { oauthService } from "@/lib/services/oauth"; +import { getPreferredActiveConnection } from "@/lib/services/oauth/oauth-service"; +import { getProvider, isProviderConfigured } from "@/lib/services/oauth/provider-registry"; +import { + applyTimeZone, + googleFetchWithToken, + sanitizeHeaderValue, +} from "@/lib/utils/google-mcp-shared"; + +const GOOGLE_CALENDAR_EVENTS_ENDPOINT = "https://www.googleapis.com/calendar/v3/calendars"; +const GOOGLE_GMAIL_MESSAGES_ENDPOINT = "https://gmail.googleapis.com/gmail/v1/users/me/messages"; +const GOOGLE_GMAIL_SEND_ENDPOINT = `${GOOGLE_GMAIL_MESSAGES_ENDPOINT}/send`; +const DEFAULT_GOOGLE_CONNECTOR_CAPABILITIES = [ + "google.basic_identity", + "google.calendar.read", +] as const; +const GMAIL_METADATA_HEADERS = [ + "Subject", + "From", + "To", + "Cc", + "Date", + "Reply-To", + "Message-Id", + "References", + "List-Id", + "Precedence", + "Auto-Submitted", +] as const; + +export type MiladyGoogleCapability = + | "google.basic_identity" + | "google.calendar.read" + | "google.calendar.write" + | "google.gmail.triage" + | "google.gmail.send"; + +export interface ManagedGoogleConnectorStatus { + provider: "google"; + mode: "cloud_managed"; + configured: boolean; + connected: boolean; + reason: "connected" | "disconnected" | "config_missing" | "token_missing" | "needs_reauth"; + identity: Record | null; + grantedCapabilities: MiladyGoogleCapability[]; + grantedScopes: string[]; + expiresAt: string | null; + hasRefreshToken: boolean; + connectionId: string | null; + linkedAt: string | null; + lastUsedAt: string | null; +} + +export interface ManagedGoogleCalendarEvent { + externalId: string; + calendarId: string; + title: string; + description: string; + location: string; + status: string; + startAt: string; + endAt: string; + isAllDay: boolean; + timezone: string | null; + htmlLink: string | null; + conferenceLink: string | null; + organizer: Record | null; + attendees: Array<{ + email: string | null; + displayName: string | null; + responseStatus: string | null; + self: boolean; + organizer: boolean; + optional: boolean; + }>; + metadata: Record; +} + +export interface ManagedGoogleGmailMessage { + externalId: string; + threadId: string; + subject: string; + from: string; + fromEmail: string | null; + replyTo: string | null; + to: string[]; + cc: string[]; + snippet: string; + receivedAt: string; + isUnread: boolean; + isImportant: boolean; + likelyReplyNeeded: boolean; + triageScore: number; + triageReason: string; + labels: string[]; + htmlLink: string | null; + metadata: Record; +} + +export class MiladyGoogleConnectorError extends Error { + constructor( + public readonly status: number, + message: string, + ) { + super(message); + this.name = "MiladyGoogleConnectorError"; + } +} + +type GoogleConnectionRow = typeof platformCredentials.$inferSelect; + +type GoogleCalendarEventDate = { + date?: string; + dateTime?: string; + timeZone?: string; +}; + +type GoogleCalendarApiEvent = { + id?: string; + status?: string; + summary?: string; + description?: string; + location?: string; + htmlLink?: string; + hangoutLink?: string; + iCalUID?: string; + recurringEventId?: string; + created?: string; + start?: GoogleCalendarEventDate; + end?: GoogleCalendarEventDate; + organizer?: { + email?: string; + displayName?: string; + self?: boolean; + }; + attendees?: Array<{ + email?: string; + displayName?: string; + responseStatus?: string; + self?: boolean; + organizer?: boolean; + optional?: boolean; + }>; + conferenceData?: { + entryPoints?: Array<{ + uri?: string; + }>; + }; +}; + +type GoogleGmailMetadataHeader = { + name?: string; + value?: string; +}; + +type GoogleGmailMetadataResponse = { + id?: string; + threadId?: string; + labelIds?: string[]; + snippet?: string; + internalDate?: string; + historyId?: string; + sizeEstimate?: number; + payload?: { + headers?: GoogleGmailMetadataHeader[]; + }; +}; + +type GoogleGmailListResponse = { + messages?: Array<{ + id?: string; + threadId?: string; + }>; +}; + +function fail(status: number, message: string): never { + throw new MiladyGoogleConnectorError(status, message); +} + +function normalizeCapabilities( + requested?: readonly MiladyGoogleCapability[], +): MiladyGoogleCapability[] { + const source = requested ?? DEFAULT_GOOGLE_CONNECTOR_CAPABILITIES; + const normalized = [...new Set(source)]; + return normalized.includes("google.basic_identity") + ? normalized + : ["google.basic_identity", ...normalized]; +} + +function capabilitiesToScopes(capabilities: readonly MiladyGoogleCapability[]): string[] { + const scopes = new Set(["openid", "email", "profile"]); + + for (const capability of normalizeCapabilities(capabilities)) { + if (capability === "google.calendar.read") { + scopes.add("https://www.googleapis.com/auth/calendar.readonly"); + } + if (capability === "google.calendar.write") { + scopes.add("https://www.googleapis.com/auth/calendar.events"); + } + if (capability === "google.gmail.triage") { + scopes.add("https://www.googleapis.com/auth/gmail.metadata"); + } + if (capability === "google.gmail.send") { + scopes.add("https://www.googleapis.com/auth/gmail.send"); + } + } + + return [...scopes]; +} + +function scopesToCapabilities(scopes: readonly string[]): MiladyGoogleCapability[] { + const granted = new Set(scopes); + const capabilities: MiladyGoogleCapability[] = []; + const hasIdentity = + granted.has("openid") || + granted.has("email") || + granted.has("profile") || + granted.has("https://www.googleapis.com/auth/userinfo.email") || + granted.has("https://www.googleapis.com/auth/userinfo.profile"); + if (hasIdentity) { + capabilities.push("google.basic_identity"); + } + if ( + granted.has("https://www.googleapis.com/auth/calendar.readonly") || + granted.has("https://www.googleapis.com/auth/calendar.events") || + granted.has("https://www.googleapis.com/auth/calendar") + ) { + capabilities.push("google.calendar.read"); + } + if ( + granted.has("https://www.googleapis.com/auth/calendar.events") || + granted.has("https://www.googleapis.com/auth/calendar") + ) { + capabilities.push("google.calendar.write"); + } + if ( + granted.has("https://www.googleapis.com/auth/gmail.metadata") || + granted.has("https://www.googleapis.com/auth/gmail.readonly") || + granted.has("https://www.googleapis.com/auth/gmail.modify") || + granted.has("https://www.googleapis.com/auth/gmail.compose") + ) { + capabilities.push("google.gmail.triage"); + } + if (granted.has("https://www.googleapis.com/auth/gmail.send")) { + capabilities.push("google.gmail.send"); + } + return normalizeCapabilities(capabilities); +} + +async function getConnectionRow( + organizationId: string, + connectionId: string, +): Promise { + const [row] = await dbRead + .select() + .from(platformCredentials) + .where( + and( + eq(platformCredentials.organization_id, organizationId), + eq(platformCredentials.id, connectionId), + ), + ) + .limit(1); + return row ?? null; +} + +async function getScopedGoogleConnections(args: { organizationId: string; userId: string }) { + return oauthService.listConnections({ + organizationId: args.organizationId, + userId: args.userId, + platform: "google", + }); +} + +async function getActiveGoogleConnectionRecord(args: { organizationId: string; userId: string }) { + const connections = await getScopedGoogleConnections(args); + const activeConnection = getPreferredActiveConnection(connections, args.userId); + const latestConnection = connections[0] ?? null; + const activeRow = activeConnection + ? await getConnectionRow(args.organizationId, activeConnection.id) + : null; + const latestRow = + latestConnection && latestConnection.id !== activeConnection?.id + ? await getConnectionRow(args.organizationId, latestConnection.id) + : activeRow; + + return { + connections, + activeConnection, + latestConnection, + activeRow, + latestRow, + }; +} + +async function getGoogleAccessToken(args: { + organizationId: string; + userId: string; +}): Promise<{ accessToken: string; connectionId: string }> { + try { + return await oauthService + .getValidTokenByPlatformWithConnectionId({ + organizationId: args.organizationId, + userId: args.userId, + platform: "google", + }) + .then((result) => ({ + accessToken: result.token.accessToken, + connectionId: result.connectionId, + })); + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + fail(409, message); + } +} + +async function googleFetch(args: { + organizationId: string; + userId: string; + url: string; + options?: RequestInit; +}): Promise { + const { accessToken } = await getGoogleAccessToken(args); + try { + return await googleFetchWithToken(accessToken, args.url, args.options); + } catch (error) { + fail(502, error instanceof Error ? error.message : String(error)); + } +} + +function readGoogleEventInstant( + value: GoogleCalendarEventDate | undefined, +): { iso: string; isAllDay: boolean; timeZone: string | null } | null { + if (!value) return null; + if (value.dateTime?.trim()) { + return { + iso: new Date(value.dateTime).toISOString(), + isAllDay: false, + timeZone: value.timeZone?.trim() || null, + }; + } + if (value.date?.trim()) { + return { + iso: new Date(`${value.date}T00:00:00.000Z`).toISOString(), + isAllDay: true, + timeZone: value.timeZone?.trim() || null, + }; + } + return null; +} + +function readConferenceLink(event: GoogleCalendarApiEvent): string | null { + if (event.hangoutLink?.trim()) { + return event.hangoutLink.trim(); + } + return event.conferenceData?.entryPoints?.find((entry) => entry.uri?.trim())?.uri?.trim() || null; +} + +function normalizeGoogleCalendarEvent( + calendarId: string, + event: GoogleCalendarApiEvent, +): ManagedGoogleCalendarEvent | null { + const externalId = event.id?.trim(); + const start = readGoogleEventInstant(event.start); + const end = readGoogleEventInstant(event.end); + if (!externalId || !start || !end) { + return null; + } + + return { + externalId, + calendarId, + title: event.summary?.trim() || "Untitled event", + description: event.description?.trim() || "", + location: event.location?.trim() || "", + status: event.status?.trim() || "confirmed", + startAt: start.iso, + endAt: end.iso, + isAllDay: start.isAllDay, + timezone: start.timeZone || end.timeZone, + htmlLink: event.htmlLink?.trim() || null, + conferenceLink: readConferenceLink(event), + organizer: event.organizer + ? { + email: event.organizer.email?.trim() || null, + displayName: event.organizer.displayName?.trim() || null, + self: Boolean(event.organizer.self), + } + : null, + attendees: (event.attendees ?? []).map((attendee) => ({ + email: attendee.email?.trim() || null, + displayName: attendee.displayName?.trim() || null, + responseStatus: attendee.responseStatus?.trim() || null, + self: Boolean(attendee.self), + organizer: Boolean(attendee.organizer), + optional: Boolean(attendee.optional), + })), + metadata: { + iCalUID: event.iCalUID?.trim() || null, + recurringEventId: event.recurringEventId?.trim() || null, + createdAt: event.created?.trim() || null, + }, + }; +} + +function splitMailboxHeader(value: string): string[] { + const parts: string[] = []; + let current = ""; + let inQuotes = false; + let angleDepth = 0; + + for (const char of value) { + if (char === '"') { + inQuotes = !inQuotes; + current += char; + continue; + } + if (!inQuotes && char === "<") { + angleDepth += 1; + current += char; + continue; + } + if (!inQuotes && char === ">") { + angleDepth = Math.max(0, angleDepth - 1); + current += char; + continue; + } + if (!inQuotes && angleDepth === 0 && char === ",") { + const trimmed = current.trim(); + if (trimmed.length > 0) { + parts.push(trimmed); + } + current = ""; + continue; + } + current += char; + } + + const trimmed = current.trim(); + if (trimmed.length > 0) { + parts.push(trimmed); + } + return parts; +} + +function stripQuotedDisplayName(value: string): string { + const trimmed = value.trim(); + if (trimmed.startsWith('"') && trimmed.endsWith('"') && trimmed.length >= 2) { + return trimmed.slice(1, -1).trim(); + } + return trimmed; +} + +function parseMailbox(value: string): { display: string; email: string | null } { + const trimmed = value.trim(); + const match = trimmed.match(/^(.*?)(?:<([^>]+)>)$/); + if (match) { + const display = stripQuotedDisplayName(match[1] ?? "").trim(); + const email = (match[2] ?? "").trim().toLowerCase(); + return { + display: display || email, + email: email.length > 0 ? email : null, + }; + } + const normalized = stripQuotedDisplayName(trimmed); + if (/^[^@\s]+@[^@\s]+\.[^@\s]+$/.test(normalized)) { + return { + display: normalized, + email: normalized.toLowerCase(), + }; + } + return { + display: normalized, + email: null, + }; +} + +function parseMailboxList(value: string | undefined) { + if (!value) return []; + return splitMailboxHeader(value) + .map((entry) => parseMailbox(entry)) + .filter((entry) => entry.display.length > 0 || entry.email !== null); +} + +function readHeaderValue( + headers: GoogleGmailMetadataHeader[] | undefined, + name: string, +): string | undefined { + const lowerName = name.toLowerCase(); + const header = headers?.find((candidate) => candidate.name?.trim().toLowerCase() === lowerName); + const value = header?.value?.trim(); + return value && value.length > 0 ? value : undefined; +} + +function normalizeSnippet(value: string | undefined): string { + return value?.replace(/\s+/g, " ").trim() || ""; +} + +function deriveHtmlLink(threadId: string): string { + return `https://mail.google.com/mail/u/0/#all/${encodeURIComponent(threadId)}`; +} + +function classifyReplyNeed(args: { + labels: string[]; + fromEmail: string | null; + to: string[]; + cc: string[]; + selfEmail: string | null; + precedence: string | undefined; + listId: string | undefined; + autoSubmitted: string | undefined; +}) { + const labels = new Set(args.labels.map((label) => label.trim().toUpperCase())); + const isUnread = labels.has("UNREAD"); + const explicitlyImportant = labels.has("IMPORTANT"); + const selfEmail = args.selfEmail?.trim().toLowerCase() || null; + const fromEmail = args.fromEmail?.trim().toLowerCase() || null; + const directRecipients = [...args.to, ...args.cc].map((entry) => entry.trim().toLowerCase()); + const directlyAddressed = selfEmail ? directRecipients.includes(selfEmail) : false; + const fromSelf = Boolean(selfEmail && fromEmail && selfEmail === fromEmail); + const precedence = args.precedence?.trim().toLowerCase(); + const autoSubmitted = args.autoSubmitted?.trim().toLowerCase(); + const automated = + Boolean( + fromEmail && + /(?:^|\b)(?:no-?reply|donotreply|notifications?|mailer-daemon)(?:\b|@)/i.test(fromEmail), + ) || + Boolean(args.listId) || + precedence === "bulk" || + precedence === "list" || + precedence === "junk" || + (autoSubmitted !== undefined && autoSubmitted !== "no"); + + let triageScore = 0; + const reasons: string[] = []; + + if (isUnread) { + triageScore += 30; + reasons.push("unread"); + } + if (explicitlyImportant) { + triageScore += 35; + reasons.push("important label"); + } + if (directlyAddressed) { + triageScore += 15; + reasons.push("directly addressed"); + } + if (!automated && !fromSelf && isUnread && directlyAddressed) { + triageScore += 30; + reasons.push("likely needs reply"); + } + if (automated) { + triageScore -= 25; + reasons.push("automated sender"); + } + if (fromSelf) { + triageScore -= 60; + reasons.push("sent by self"); + } + + return { + likelyReplyNeeded: !automated && !fromSelf && isUnread && directlyAddressed, + isImportant: explicitlyImportant || (!automated && !fromSelf && isUnread && directlyAddressed), + triageScore: Math.max(0, triageScore), + triageReason: reasons.join(", ") || "recent inbox message", + }; +} + +function normalizeGoogleGmailMessage( + message: GoogleGmailMetadataResponse, + selfEmail: string | null, +): ManagedGoogleGmailMessage | null { + const externalId = message.id?.trim(); + const threadId = message.threadId?.trim(); + if (!externalId || !threadId) { + return null; + } + + const headers = message.payload?.headers ?? []; + const subject = readHeaderValue(headers, "Subject") || "(no subject)"; + const fromHeader = readHeaderValue(headers, "From") || "Unknown sender"; + const fromMailbox = parseMailbox(fromHeader); + const replyToHeader = readHeaderValue(headers, "Reply-To"); + const replyToMailbox = replyToHeader ? parseMailbox(replyToHeader) : null; + const to = parseMailboxList(readHeaderValue(headers, "To")).map( + (entry) => entry.email || entry.display, + ); + const cc = parseMailboxList(readHeaderValue(headers, "Cc")).map( + (entry) => entry.email || entry.display, + ); + const labels = (message.labelIds ?? []).map((label) => label.trim()).filter(Boolean); + const receivedAtMs = Number(message.internalDate); + const receivedAt = Number.isFinite(receivedAtMs) + ? new Date(receivedAtMs).toISOString() + : new Date().toISOString(); + const precedence = readHeaderValue(headers, "Precedence"); + const listId = readHeaderValue(headers, "List-Id"); + const autoSubmitted = readHeaderValue(headers, "Auto-Submitted"); + const triage = classifyReplyNeed({ + labels, + fromEmail: fromMailbox.email, + to, + cc, + selfEmail, + precedence, + listId, + autoSubmitted, + }); + + return { + externalId, + threadId, + subject, + from: fromMailbox.display, + fromEmail: fromMailbox.email, + replyTo: replyToMailbox?.email || replyToMailbox?.display || null, + to, + cc, + snippet: normalizeSnippet(message.snippet), + receivedAt, + isUnread: labels.includes("UNREAD"), + isImportant: triage.isImportant, + likelyReplyNeeded: triage.likelyReplyNeeded, + triageScore: triage.triageScore, + triageReason: triage.triageReason, + labels, + htmlLink: deriveHtmlLink(threadId), + metadata: { + historyId: message.historyId?.trim() || null, + sizeEstimate: typeof message.sizeEstimate === "number" ? message.sizeEstimate : null, + dateHeader: readHeaderValue(headers, "Date") || null, + messageIdHeader: readHeaderValue(headers, "Message-Id") || null, + referencesHeader: readHeaderValue(headers, "References") || null, + listId: listId || null, + precedence: precedence || null, + autoSubmitted: autoSubmitted || null, + }, + }; +} + +function normalizeReplySubject(subject: string): string { + const trimmed = subject.trim(); + if (trimmed.length === 0) { + return "Re: your message"; + } + return /^re:/i.test(trimmed) ? trimmed : `Re: ${trimmed}`; +} + +export async function getManagedGoogleConnectorStatus(args: { + organizationId: string; + userId: string; +}): Promise { + const provider = getProvider("google"); + const configured = provider ? isProviderConfigured(provider) : false; + + if (!configured) { + return { + provider: "google", + mode: "cloud_managed", + configured: false, + connected: false, + reason: "config_missing", + identity: null, + grantedCapabilities: [], + grantedScopes: [], + expiresAt: null, + hasRefreshToken: false, + connectionId: null, + linkedAt: null, + lastUsedAt: null, + }; + } + + const { activeConnection, latestConnection, activeRow, latestRow } = + await getActiveGoogleConnectionRecord(args); + const currentConnection = activeConnection ?? latestConnection ?? null; + const currentRow = activeRow ?? latestRow ?? null; + + if (!currentConnection) { + return { + provider: "google", + mode: "cloud_managed", + configured: true, + connected: false, + reason: "disconnected", + identity: null, + grantedCapabilities: [], + grantedScopes: [], + expiresAt: null, + hasRefreshToken: false, + connectionId: null, + linkedAt: null, + lastUsedAt: null, + }; + } + + const connected = currentConnection.status === "active"; + const reason = connected + ? "connected" + : currentConnection.status === "expired" || currentConnection.status === "error" + ? "needs_reauth" + : "disconnected"; + + return { + provider: "google", + mode: "cloud_managed", + configured: true, + connected, + reason, + identity: { + id: currentConnection.platformUserId, + email: currentConnection.email ?? null, + name: currentConnection.displayName ?? currentConnection.username ?? null, + avatarUrl: currentConnection.avatarUrl ?? null, + }, + grantedCapabilities: scopesToCapabilities(currentConnection.scopes), + grantedScopes: [...currentConnection.scopes], + expiresAt: currentRow?.token_expires_at?.toISOString() ?? null, + hasRefreshToken: Boolean(currentRow?.refresh_token_secret_id), + connectionId: currentConnection.id, + linkedAt: currentConnection.linkedAt.toISOString(), + lastUsedAt: currentConnection.lastUsedAt?.toISOString() ?? null, + }; +} + +export async function initiateManagedGoogleConnection(args: { + organizationId: string; + userId: string; + redirectUrl?: string; + capabilities?: MiladyGoogleCapability[]; +}) { + const requestedCapabilities = normalizeCapabilities(args.capabilities); + const auth = await oauthService.initiateAuth({ + organizationId: args.organizationId, + userId: args.userId, + platform: "google", + redirectUrl: args.redirectUrl, + scopes: capabilitiesToScopes(requestedCapabilities), + }); + return { + provider: "google" as const, + mode: "cloud_managed" as const, + requestedCapabilities, + redirectUri: args.redirectUrl ?? "/auth/success?platform=google", + authUrl: auth.authUrl, + }; +} + +export async function disconnectManagedGoogleConnection(args: { + organizationId: string; + userId: string; + connectionId?: string | null; +}): Promise { + const connections = await getScopedGoogleConnections(args); + const activeConnection = + (args.connectionId + ? connections.find((connection) => connection.id === args.connectionId) + : getPreferredActiveConnection(connections, args.userId)) ?? + connections[0] ?? + null; + if (!activeConnection) { + return; + } + await oauthService.revokeConnection({ + organizationId: args.organizationId, + connectionId: activeConnection.id, + }); +} + +export async function fetchManagedGoogleCalendarFeed(args: { + organizationId: string; + userId: string; + calendarId: string; + timeMin: string; + timeMax: string; + timeZone: string; +}): Promise<{ calendarId: string; events: ManagedGoogleCalendarEvent[]; syncedAt: string }> { + const params = new URLSearchParams({ + singleEvents: "true", + orderBy: "startTime", + showDeleted: "false", + maxResults: "50", + timeMin: args.timeMin, + timeMax: args.timeMax, + fields: + "items(id,status,summary,description,location,htmlLink,hangoutLink,iCalUID,recurringEventId,created,start,end,organizer(email,displayName,self),attendees(email,displayName,responseStatus,self,organizer,optional),conferenceData(entryPoints(uri)))", + timeZone: args.timeZone, + }); + + const response = await googleFetch({ + organizationId: args.organizationId, + userId: args.userId, + url: `${GOOGLE_CALENDAR_EVENTS_ENDPOINT}/${encodeURIComponent(args.calendarId)}/events?${params.toString()}`, + }); + const parsed = (await response.json()) as { items?: GoogleCalendarApiEvent[] }; + return { + calendarId: args.calendarId, + events: (parsed.items ?? []) + .map((event) => normalizeGoogleCalendarEvent(args.calendarId, event)) + .filter((event): event is ManagedGoogleCalendarEvent => event !== null), + syncedAt: new Date().toISOString(), + }; +} + +export async function createManagedGoogleCalendarEvent(args: { + organizationId: string; + userId: string; + calendarId: string; + title: string; + description?: string; + location?: string; + startAt: string; + endAt: string; + timeZone: string; + attendees?: Array<{ + email: string; + displayName?: string; + optional?: boolean; + }>; +}): Promise<{ event: ManagedGoogleCalendarEvent }> { + const response = await googleFetch({ + organizationId: args.organizationId, + userId: args.userId, + url: `${GOOGLE_CALENDAR_EVENTS_ENDPOINT}/${encodeURIComponent(args.calendarId)}/events?conferenceDataVersion=1`, + options: { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + summary: args.title, + description: args.description ?? "", + location: args.location ?? "", + start: applyTimeZone(args.startAt, args.timeZone), + end: applyTimeZone(args.endAt, args.timeZone), + attendees: args.attendees ?? [], + }), + }, + }); + const parsed = (await response.json()) as GoogleCalendarApiEvent; + const event = normalizeGoogleCalendarEvent(args.calendarId, parsed); + if (!event) { + fail(502, "Google Calendar returned an incomplete event payload."); + } + return { event }; +} + +export async function fetchManagedGoogleGmailTriage(args: { + organizationId: string; + userId: string; + maxResults: number; +}): Promise<{ messages: ManagedGoogleGmailMessage[]; syncedAt: string }> { + const maxResults = Math.min(Math.max(args.maxResults, 1), 50); + const connectorStatus = await getManagedGoogleConnectorStatus({ + organizationId: args.organizationId, + userId: args.userId, + }); + const selfEmail = + connectorStatus.identity && typeof connectorStatus.identity.email === "string" + ? connectorStatus.identity.email + : null; + const listParams = new URLSearchParams({ + maxResults: String(maxResults), + includeSpamTrash: "false", + }); + listParams.append("labelIds", "INBOX"); + + const listResponse = await googleFetch({ + organizationId: args.organizationId, + userId: args.userId, + url: `${GOOGLE_GMAIL_MESSAGES_ENDPOINT}?${listParams.toString()}`, + }); + const listed = (await listResponse.json()) as GoogleGmailListResponse; + + const messages = await Promise.all( + (listed.messages ?? []).map(async (messageRef) => { + const messageId = messageRef.id?.trim(); + if (!messageId) return null; + const params = new URLSearchParams({ format: "metadata" }); + for (const header of GMAIL_METADATA_HEADERS) { + params.append("metadataHeaders", header); + } + const response = await googleFetch({ + organizationId: args.organizationId, + userId: args.userId, + url: `${GOOGLE_GMAIL_MESSAGES_ENDPOINT}/${encodeURIComponent(messageId)}?${params.toString()}`, + }); + const parsed = (await response.json()) as GoogleGmailMetadataResponse; + return normalizeGoogleGmailMessage(parsed, selfEmail); + }), + ); + + return { + messages: messages + .filter((message): message is ManagedGoogleGmailMessage => message !== null) + .sort((left, right) => { + const scoreDelta = right.triageScore - left.triageScore; + if (scoreDelta !== 0) return scoreDelta; + return Date.parse(right.receivedAt) - Date.parse(left.receivedAt); + }), + syncedAt: new Date().toISOString(), + }; +} + +export async function sendManagedGoogleReply(args: { + organizationId: string; + userId: string; + to: string[]; + cc?: string[]; + subject: string; + bodyText: string; + inReplyTo?: string | null; + references?: string | null; +}): Promise { + const lines = [ + `To: ${sanitizeHeaderValue(args.to.join(", "))}`, + ...(args.cc && args.cc.length > 0 ? [`Cc: ${sanitizeHeaderValue(args.cc.join(", "))}`] : []), + `Subject: ${sanitizeHeaderValue(normalizeReplySubject(args.subject))}`, + "MIME-Version: 1.0", + "Content-Type: text/plain; charset=UTF-8", + ...(args.inReplyTo ? [`In-Reply-To: ${sanitizeHeaderValue(args.inReplyTo)}`] : []), + ...(args.references ? [`References: ${sanitizeHeaderValue(args.references)}`] : []), + "", + args.bodyText.replace(/\r?\n/g, "\r\n"), + ]; + const raw = Buffer.from(lines.join("\r\n"), "utf-8").toString("base64url"); + + await googleFetch({ + organizationId: args.organizationId, + userId: args.userId, + url: GOOGLE_GMAIL_SEND_ENDPOINT, + options: { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ raw }), + }, + }); +} diff --git a/packages/tests/unit/milady-google-connector.test.ts b/packages/tests/unit/milady-google-connector.test.ts new file mode 100644 index 000000000..8515ffd1f --- /dev/null +++ b/packages/tests/unit/milady-google-connector.test.ts @@ -0,0 +1,369 @@ +import { beforeEach, describe, expect, mock, test } from "bun:test"; +import type { OAuthConnection } from "@/lib/services/oauth/types"; + +const mockListConnections = mock(); +const mockGetValidTokenByPlatformWithConnectionId = mock(); +const mockInitiateAuth = mock(); +const mockRevokeConnection = mock(); +const mockGoogleFetchWithToken = mock(); +const mockGetProvider = mock(); +const mockIsProviderConfigured = mock(); +const mockDbLimit = mock(); +const mockDbWhere = mock(() => ({ limit: mockDbLimit })); +const mockDbFrom = mock(() => ({ where: mockDbWhere })); +const mockDbSelect = mock(() => ({ from: mockDbFrom })); + +mock.module("drizzle-orm", () => ({ + and: (...args: unknown[]) => args, + eq: (left: unknown, right: unknown) => ({ left, right }), +})); + +mock.module("@/db/client", () => ({ + dbRead: { + select: mockDbSelect, + }, +})); + +mock.module("@/db/schemas/platform-credentials", () => ({ + platformCredentials: { + organization_id: "organization_id", + id: "id", + }, +})); + +mock.module("@/lib/services/oauth", () => ({ + oauthService: { + listConnections: mockListConnections, + getValidTokenByPlatformWithConnectionId: mockGetValidTokenByPlatformWithConnectionId, + initiateAuth: mockInitiateAuth, + revokeConnection: mockRevokeConnection, + }, +})); + +mock.module("@/lib/services/oauth/oauth-service", () => ({ + getPreferredActiveConnection: (connections: OAuthConnection[], userId?: string) => + connections.find( + (connection) => connection.status === "active" && (!userId || connection.userId === userId), + ) ?? + connections.find((connection) => connection.status === "active") ?? + null, +})); + +mock.module("@/lib/services/oauth/provider-registry", () => ({ + getProvider: mockGetProvider, + isProviderConfigured: mockIsProviderConfigured, +})); + +mock.module("@/lib/utils/google-mcp-shared", () => ({ + applyTimeZone: (dateTime: string, timeZone: string | undefined) => + timeZone ? { dateTime, timeZone } : { dateTime }, + googleFetchWithToken: mockGoogleFetchWithToken, + sanitizeHeaderValue: (value: string) => value.replace(/[\r\n]/g, ""), +})); + +import { + disconnectManagedGoogleConnection, + fetchManagedGoogleCalendarFeed, + fetchManagedGoogleGmailTriage, + getManagedGoogleConnectorStatus, + initiateManagedGoogleConnection, + sendManagedGoogleReply, +} from "@/lib/services/milady-google-connector"; + +function createConnection(overrides: Partial = {}): OAuthConnection { + return { + id: "conn-google-1", + userId: "user-1", + platform: "google", + platformUserId: "google-user-1", + email: "founder@example.com", + username: "founder", + displayName: "Founder Example", + avatarUrl: "https://example.com/avatar.png", + status: "active", + scopes: [ + "openid", + "email", + "profile", + "https://www.googleapis.com/auth/calendar.readonly", + "https://www.googleapis.com/auth/gmail.metadata", + "https://www.googleapis.com/auth/gmail.send", + ], + linkedAt: new Date("2026-04-04T15:00:00.000Z"), + lastUsedAt: new Date("2026-04-04T16:00:00.000Z"), + tokenExpired: false, + source: "platform_credentials", + ...overrides, + }; +} + +describe("milady Google connector service", () => { + beforeEach(() => { + mockListConnections.mockReset(); + mockGetValidTokenByPlatformWithConnectionId.mockReset(); + mockInitiateAuth.mockReset(); + mockRevokeConnection.mockReset(); + mockGoogleFetchWithToken.mockReset(); + mockGetProvider.mockReset(); + mockIsProviderConfigured.mockReset(); + mockDbLimit.mockReset(); + mockDbWhere.mockClear(); + mockDbFrom.mockClear(); + mockDbSelect.mockClear(); + + mockGetProvider.mockReturnValue({ id: "google" }); + mockIsProviderConfigured.mockReturnValue(true); + mockDbLimit.mockResolvedValue([ + { + token_expires_at: new Date("2026-04-05T00:00:00.000Z"), + refresh_token_secret_id: "refresh-secret-1", + }, + ]); + mockGetValidTokenByPlatformWithConnectionId.mockResolvedValue({ + token: { + accessToken: "google-access-token", + }, + connectionId: "conn-google-1", + }); + }); + + test("reports managed Google connector status from the active user-scoped connection", async () => { + mockListConnections.mockResolvedValue([createConnection()]); + + const status = await getManagedGoogleConnectorStatus({ + organizationId: "org-1", + userId: "user-1", + }); + + expect(status).toEqual({ + provider: "google", + mode: "cloud_managed", + configured: true, + connected: true, + reason: "connected", + identity: { + id: "google-user-1", + email: "founder@example.com", + name: "Founder Example", + avatarUrl: "https://example.com/avatar.png", + }, + grantedCapabilities: [ + "google.basic_identity", + "google.calendar.read", + "google.gmail.triage", + "google.gmail.send", + ], + grantedScopes: [ + "openid", + "email", + "profile", + "https://www.googleapis.com/auth/calendar.readonly", + "https://www.googleapis.com/auth/gmail.metadata", + "https://www.googleapis.com/auth/gmail.send", + ], + expiresAt: "2026-04-05T00:00:00.000Z", + hasRefreshToken: true, + connectionId: "conn-google-1", + linkedAt: "2026-04-04T15:00:00.000Z", + lastUsedAt: "2026-04-04T16:00:00.000Z", + }); + }); + + test("initiates managed Google auth with the requested Milady capability scopes", async () => { + mockInitiateAuth.mockResolvedValue({ + authUrl: "https://accounts.google.com/o/oauth2/v2/auth?state=managed-google", + }); + + const result = await initiateManagedGoogleConnection({ + organizationId: "org-1", + userId: "user-1", + redirectUrl: "https://www.elizacloud.ai/auth/success?platform=google", + capabilities: ["google.calendar.read", "google.gmail.triage", "google.gmail.send"], + }); + + expect(mockInitiateAuth).toHaveBeenCalledWith({ + organizationId: "org-1", + userId: "user-1", + platform: "google", + redirectUrl: "https://www.elizacloud.ai/auth/success?platform=google", + scopes: [ + "openid", + "email", + "profile", + "https://www.googleapis.com/auth/calendar.readonly", + "https://www.googleapis.com/auth/gmail.metadata", + "https://www.googleapis.com/auth/gmail.send", + ], + }); + expect(result.mode).toBe("cloud_managed"); + expect(result.requestedCapabilities).toEqual([ + "google.basic_identity", + "google.calendar.read", + "google.gmail.triage", + "google.gmail.send", + ]); + }); + + test("normalizes Google Calendar events into the Milady managed feed shape", async () => { + mockGoogleFetchWithToken.mockResolvedValueOnce( + new Response( + JSON.stringify({ + items: [ + { + id: "event-1", + summary: "Founder sync", + description: "Review the launch plan", + location: "HQ", + status: "confirmed", + htmlLink: "https://calendar.google.com/event?eid=event-1", + start: { + dateTime: "2026-04-04T10:00:00-07:00", + timeZone: "America/Los_Angeles", + }, + end: { + dateTime: "2026-04-04T10:30:00-07:00", + timeZone: "America/Los_Angeles", + }, + organizer: { + email: "founder@example.com", + displayName: "Founder Example", + }, + attendees: [ + { + email: "teammate@example.com", + displayName: "Teammate", + responseStatus: "accepted", + }, + ], + }, + ], + }), + ), + ); + + const feed = await fetchManagedGoogleCalendarFeed({ + organizationId: "org-1", + userId: "user-1", + calendarId: "primary", + timeMin: "2026-04-04T00:00:00.000Z", + timeMax: "2026-04-05T00:00:00.000Z", + timeZone: "America/Los_Angeles", + }); + + expect(feed.calendarId).toBe("primary"); + expect(feed.events).toHaveLength(1); + expect(feed.events[0]).toMatchObject({ + externalId: "event-1", + calendarId: "primary", + title: "Founder sync", + description: "Review the launch plan", + location: "HQ", + timezone: "America/Los_Angeles", + }); + }); + + test("classifies Gmail triage messages using the connected Google identity", async () => { + mockListConnections.mockResolvedValue([createConnection()]); + mockGoogleFetchWithToken + .mockResolvedValueOnce( + new Response( + JSON.stringify({ + messages: [{ id: "msg-1", threadId: "thread-1" }], + }), + ), + ) + .mockResolvedValueOnce( + new Response( + JSON.stringify({ + id: "msg-1", + threadId: "thread-1", + labelIds: ["INBOX", "UNREAD", "IMPORTANT"], + snippet: "Can you review the plan today?", + internalDate: "1775327400000", + historyId: "history-1", + sizeEstimate: 1234, + payload: { + headers: [ + { name: "Subject", value: "Project sync" }, + { name: "From", value: "CEO Example " }, + { name: "To", value: "founder@example.com" }, + { name: "Reply-To", value: "ceo@example.com" }, + { name: "Message-Id", value: "" }, + ], + }, + }), + ), + ); + + const triage = await fetchManagedGoogleGmailTriage({ + organizationId: "org-1", + userId: "user-1", + maxResults: 5, + }); + + expect(triage.messages).toHaveLength(1); + expect(triage.messages[0]).toMatchObject({ + externalId: "msg-1", + threadId: "thread-1", + subject: "Project sync", + fromEmail: "ceo@example.com", + replyTo: "ceo@example.com", + isUnread: true, + isImportant: true, + likelyReplyNeeded: true, + }); + expect(triage.messages[0]?.triageReason).toContain("unread"); + }); + + test("sends Gmail replies with sanitized RFC822 headers", async () => { + mockGoogleFetchWithToken.mockResolvedValueOnce(new Response(null, { status: 200 })); + + await sendManagedGoogleReply({ + organizationId: "org-1", + userId: "user-1", + to: ["founder@example.com"], + cc: ["ops@example.com"], + subject: "Project sync", + bodyText: "Reviewing it now.", + inReplyTo: "", + references: "", + }); + + expect(mockGoogleFetchWithToken).toHaveBeenCalledTimes(1); + const [, url, options] = mockGoogleFetchWithToken.mock.calls[0] as [ + string, + string, + { body?: string }, + ]; + expect(url).toBe("https://gmail.googleapis.com/gmail/v1/users/me/messages/send"); + const payload = JSON.parse(String(options.body)) as { raw: string }; + const decoded = Buffer.from(payload.raw, "base64url").toString("utf-8"); + expect(decoded).toContain("To: founder@example.com"); + expect(decoded).toContain("Cc: ops@example.com"); + expect(decoded).toContain("Subject: Re: Project sync"); + expect(decoded).toContain("In-Reply-To: "); + expect(decoded).toContain("References: "); + expect(decoded).toContain("Reviewing it now."); + }); + + test("disconnects the preferred active Google connection for the user", async () => { + mockListConnections.mockResolvedValue([ + createConnection({ id: "conn-google-1" }), + createConnection({ + id: "conn-google-2", + status: "revoked", + linkedAt: new Date("2026-04-03T15:00:00.000Z"), + }), + ]); + + await disconnectManagedGoogleConnection({ + organizationId: "org-1", + userId: "user-1", + }); + + expect(mockRevokeConnection).toHaveBeenCalledWith({ + organizationId: "org-1", + connectionId: "conn-google-1", + }); + }); +}); diff --git a/packages/tests/unit/milady-google-routes.test.ts b/packages/tests/unit/milady-google-routes.test.ts new file mode 100644 index 000000000..a64e34fe3 --- /dev/null +++ b/packages/tests/unit/milady-google-routes.test.ts @@ -0,0 +1,226 @@ +import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { NextRequest } from "next/server"; +import { jsonRequest } from "./api/route-test-helpers"; + +const mockRequireAuthOrApiKeyWithOrg = mock(); +const mockGetStatus = mock(); +const mockInitiateConnection = mock(); +const mockDisconnectConnection = mock(); +const mockFetchCalendarFeed = mock(); +const mockCreateCalendarEvent = mock(); +const mockFetchGmailTriage = mock(); +const mockSendReply = mock(); + +mock.module("@/lib/auth", () => ({ + requireAuthOrApiKeyWithOrg: mockRequireAuthOrApiKeyWithOrg, +})); + +mock.module("@/lib/services/milady-google-connector", () => ({ + getManagedGoogleConnectorStatus: mockGetStatus, + initiateManagedGoogleConnection: mockInitiateConnection, + disconnectManagedGoogleConnection: mockDisconnectConnection, + fetchManagedGoogleCalendarFeed: mockFetchCalendarFeed, + createManagedGoogleCalendarEvent: mockCreateCalendarEvent, + fetchManagedGoogleGmailTriage: mockFetchGmailTriage, + sendManagedGoogleReply: mockSendReply, + MiladyGoogleConnectorError: class MiladyGoogleConnectorError extends Error { + constructor( + public readonly status: number, + message: string, + ) { + super(message); + this.name = "MiladyGoogleConnectorError"; + } + }, +})); + +import { POST as postCalendarEvent } from "@/app/api/v1/milady/google/calendar/events/route"; +import { GET as getCalendarFeed } from "@/app/api/v1/milady/google/calendar/feed/route"; +import { POST as postConnectInitiate } from "@/app/api/v1/milady/google/connect/initiate/route"; +import { POST as postDisconnect } from "@/app/api/v1/milady/google/disconnect/route"; +import { POST as postReplySend } from "@/app/api/v1/milady/google/gmail/reply-send/route"; +import { GET as getGmailTriage } from "@/app/api/v1/milady/google/gmail/triage/route"; +import { GET as getStatus } from "@/app/api/v1/milady/google/status/route"; + +describe("Milady managed Google routes", () => { + beforeEach(() => { + mockRequireAuthOrApiKeyWithOrg.mockReset(); + mockGetStatus.mockReset(); + mockInitiateConnection.mockReset(); + mockDisconnectConnection.mockReset(); + mockFetchCalendarFeed.mockReset(); + mockCreateCalendarEvent.mockReset(); + mockFetchGmailTriage.mockReset(); + mockSendReply.mockReset(); + + mockRequireAuthOrApiKeyWithOrg.mockResolvedValue({ + user: { + id: "user-1", + organization_id: "org-1", + }, + }); + }); + + test("GET /api/v1/milady/google/status returns the managed connector status", async () => { + mockGetStatus.mockResolvedValue({ + provider: "google", + mode: "cloud_managed", + configured: true, + connected: true, + reason: "connected", + identity: { email: "founder@example.com" }, + grantedCapabilities: ["google.basic_identity"], + grantedScopes: ["openid", "email", "profile"], + expiresAt: null, + hasRefreshToken: true, + connectionId: "conn-1", + linkedAt: "2026-04-04T15:00:00.000Z", + lastUsedAt: "2026-04-04T16:00:00.000Z", + }); + + const response = await getStatus( + new NextRequest("https://example.com/api/v1/milady/google/status"), + ); + + expect(response.status).toBe(200); + expect(await response.json()).toMatchObject({ + provider: "google", + mode: "cloud_managed", + connected: true, + connectionId: "conn-1", + }); + }); + + test("POST /api/v1/milady/google/connect/initiate validates capabilities and delegates to the service", async () => { + mockInitiateConnection.mockResolvedValue({ + provider: "google", + mode: "cloud_managed", + requestedCapabilities: ["google.basic_identity", "google.calendar.read"], + redirectUri: "https://www.elizacloud.ai/auth/success?platform=google", + authUrl: "https://accounts.google.com/o/oauth2/v2/auth?state=managed-google", + }); + + const response = await postConnectInitiate( + jsonRequest("https://example.com/api/v1/milady/google/connect/initiate", "POST", { + redirectUrl: "https://www.elizacloud.ai/auth/success?platform=google", + capabilities: ["google.calendar.read"], + }), + ); + + expect(response.status).toBe(200); + expect(mockInitiateConnection).toHaveBeenCalledWith({ + organizationId: "org-1", + userId: "user-1", + redirectUrl: "https://www.elizacloud.ai/auth/success?platform=google", + capabilities: ["google.calendar.read"], + }); + }); + + test("GET /api/v1/milady/google/calendar/feed requires an explicit time window", async () => { + const response = await getCalendarFeed( + new NextRequest("https://example.com/api/v1/milady/google/calendar/feed?calendarId=primary"), + ); + + expect(response.status).toBe(400); + expect(await response.json()).toEqual({ + error: "timeMin and timeMax are required.", + }); + }); + + test("POST /api/v1/milady/google/calendar/events creates calendar events through the service", async () => { + mockCreateCalendarEvent.mockResolvedValue({ + event: { + externalId: "event-1", + calendarId: "primary", + title: "Founder sync", + description: "", + location: "", + status: "confirmed", + startAt: "2026-04-04T19:00:00.000Z", + endAt: "2026-04-04T19:30:00.000Z", + isAllDay: false, + timezone: "UTC", + htmlLink: null, + conferenceLink: null, + organizer: null, + attendees: [], + metadata: {}, + }, + }); + + const response = await postCalendarEvent( + jsonRequest("https://example.com/api/v1/milady/google/calendar/events", "POST", { + title: "Founder sync", + startAt: "2026-04-04T19:00:00.000Z", + endAt: "2026-04-04T19:30:00.000Z", + timeZone: "UTC", + }), + ); + + expect(response.status).toBe(201); + expect(mockCreateCalendarEvent).toHaveBeenCalledWith({ + organizationId: "org-1", + userId: "user-1", + calendarId: "primary", + title: "Founder sync", + description: undefined, + location: undefined, + startAt: "2026-04-04T19:00:00.000Z", + endAt: "2026-04-04T19:30:00.000Z", + timeZone: "UTC", + attendees: undefined, + }); + }); + + test("GET /api/v1/milady/google/gmail/triage rejects non-positive maxResults", async () => { + const response = await getGmailTriage( + new NextRequest("https://example.com/api/v1/milady/google/gmail/triage?maxResults=0"), + ); + + expect(response.status).toBe(400); + expect(await response.json()).toEqual({ + error: "maxResults must be a positive integer.", + }); + }); + + test("POST /api/v1/milady/google/gmail/reply-send validates the payload and delegates to the service", async () => { + mockSendReply.mockResolvedValue(undefined); + + const response = await postReplySend( + jsonRequest("https://example.com/api/v1/milady/google/gmail/reply-send", "POST", { + to: ["founder@example.com"], + subject: "Project sync", + bodyText: "Reviewing it now.", + inReplyTo: "", + }), + ); + + expect(response.status).toBe(200); + expect(mockSendReply).toHaveBeenCalledWith({ + organizationId: "org-1", + userId: "user-1", + to: ["founder@example.com"], + cc: undefined, + subject: "Project sync", + bodyText: "Reviewing it now.", + inReplyTo: "", + references: null, + }); + }); + + test("POST /api/v1/milady/google/disconnect disconnects the current Google connection", async () => { + mockDisconnectConnection.mockResolvedValue(undefined); + + const response = await postDisconnect( + jsonRequest("https://example.com/api/v1/milady/google/disconnect", "POST", {}), + ); + + expect(response.status).toBe(200); + expect(await response.json()).toEqual({ ok: true }); + expect(mockDisconnectConnection).toHaveBeenCalledWith({ + organizationId: "org-1", + userId: "user-1", + connectionId: null, + }); + }); +}); From d122a3692c8268a047353958571c0b32a4cd93f1 Mon Sep 17 00:00:00 2001 From: Shaw Date: Sun, 5 Apr 2026 01:28:50 -0700 Subject: [PATCH 06/11] cloud: checkpoint local oauth and runtime changes --- .gitignore | 1 + .../auth/connection-success/route.ts | 118 +++++++++ .../connections/[platform]/initiate/route.ts | 85 +++++++ app/api/eliza-app/connections/route.ts | 70 ++++++ app/api/eliza-app/webhook/blooio/route.ts | 1 + app/api/eliza-app/webhook/discord/route.ts | 1 + app/api/eliza-app/webhook/telegram/route.ts | 2 + app/api/mcp/tools/airtable.ts | 2 + app/api/mcp/tools/asana.ts | 2 + app/api/mcp/tools/dropbox.ts | 2 + app/api/mcp/tools/github.ts | 2 + app/api/mcp/tools/google.ts | 2 + app/api/mcp/tools/jira.ts | 2 + app/api/mcp/tools/linear.ts | 2 + app/api/mcp/tools/linkedin.ts | 2 + app/api/mcp/tools/notion.ts | 2 + app/api/mcp/tools/salesforce.ts | 2 + app/api/mcp/tools/twitter.ts | 2 + app/api/mcp/tools/zoom.ts | 2 + app/api/mcps/airtable/[transport]/route.ts | 9 + app/api/mcps/asana/[transport]/route.ts | 9 + app/api/mcps/dropbox/[transport]/route.ts | 9 + app/api/mcps/github/[transport]/route.ts | 9 + app/api/mcps/google/[transport]/route.ts | 9 + app/api/mcps/jira/[transport]/route.ts | 14 +- app/api/mcps/linear/[transport]/route.ts | 9 + app/api/mcps/linkedin/[transport]/route.ts | 9 + app/api/mcps/microsoft/[transport]/route.ts | 9 + app/api/mcps/notion/[transport]/route.ts | 9 + app/api/mcps/salesforce/[transport]/route.ts | 9 + app/api/mcps/twitter/[transport]/route.ts | 9 + app/api/mcps/zoom/[transport]/route.ts | 14 +- app/api/v1/oauth/[platform]/callback/route.ts | 5 +- app/api/v1/oauth/[platform]/initiate/route.ts | 43 ++-- app/api/v1/oauth/connect/route.ts | 58 +++-- app/api/v1/oauth/connections/[id]/route.ts | 51 ++-- .../v1/oauth/connections/[id]/token/route.ts | 24 +- app/api/v1/oauth/connections/route.ts | 26 +- app/api/v1/oauth/status/route.ts | 20 +- app/api/v1/oauth/token/[platform]/route.ts | 33 ++- packages/lib/eliza/runtime-factory.ts | 4 +- .../eliza-app/connection-enforcement.ts | 44 ++-- .../connection-adapters/generic-adapter.ts | 1 + packages/lib/services/oauth/oauth-service.ts | 99 ++++++-- packages/lib/services/oauth/types.ts | 6 + .../tests/gateway-manager.test.ts | 172 ++++++++++++- packages/tests/e2e/setup-server.ts | 27 +- packages/tests/infrastructure/index.ts | 4 +- packages/tests/infrastructure/test-runtime.ts | 6 +- .../tests/integration/connection-apis.test.ts | 23 +- packages/tests/load-env.ts | 6 +- .../unit/discord-automation-oauth.test.ts | 235 ++++++++++++++++++ .../eliza-app/connection-enforcement.test.ts | 17 +- .../connection-success-route.test.ts | 16 ++ .../unit/eliza-app/connections-route.test.ts | 117 +++++++++ packages/tests/unit/mcp-google-tools.test.ts | 10 + .../tests/unit/oauth/oauth-service.test.ts | 49 ++++ tsconfig.json | 50 +++- 58 files changed, 1401 insertions(+), 174 deletions(-) create mode 100644 app/api/eliza-app/connections/[platform]/initiate/route.ts create mode 100644 app/api/eliza-app/connections/route.ts create mode 100644 packages/tests/unit/discord-automation-oauth.test.ts create mode 100644 packages/tests/unit/eliza-app/connections-route.test.ts diff --git a/.gitignore b/.gitignore index 12f650779..2e16ee365 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ node_modules # next.js /.next/ /.next-build/ +/.next-test-*/ /out/ # production diff --git a/app/api/eliza-app/auth/connection-success/route.ts b/app/api/eliza-app/auth/connection-success/route.ts index d200899bb..88fd20308 100644 --- a/app/api/eliza-app/auth/connection-success/route.ts +++ b/app/api/eliza-app/auth/connection-success/route.ts @@ -9,6 +9,14 @@ const PLATFORM_MESSAGES: Record = { web: "close this tab. your chat is ready.", }; +const PROVIDER_LABELS: Record = { + google: "Google", + microsoft: "Microsoft", + twitter: "X", + github: "GitHub", + slack: "Slack", +}; + function buildHtml(platform: string): string { const instruction = PLATFORM_MESSAGES[platform] ?? PLATFORM_MESSAGES.web; @@ -71,7 +79,117 @@ function buildHtml(platform: string): string { `; } +function buildElizaAppHtml(provider: string, connectionId: string | null): string { + const providerLabel = PROVIDER_LABELS[provider] ?? "Your account"; + const payload = JSON.stringify({ + type: "eliza-app-oauth-complete", + provider, + connectionId, + connected: true, + }); + + return ` + + + + + ${providerLabel} connected + + + +
+
+

${providerLabel} connected.

+

You can return to Eliza App now. If this window does not close automatically, close it manually.

+ +
+ + +`; +} + export async function GET(request: NextRequest) { + const source = request.nextUrl.searchParams.get("source"); + if (source === "eliza-app") { + const provider = request.nextUrl.searchParams.get("platform") || "connection"; + const connectionId = request.nextUrl.searchParams.get("connection_id"); + + return new NextResponse(buildElizaAppHtml(provider, connectionId), { + status: 200, + headers: { "Content-Type": "text/html; charset=utf-8" }, + }); + } + const platform = request.nextUrl.searchParams.get("platform") || "web"; if (platform === "web") { diff --git a/app/api/eliza-app/connections/[platform]/initiate/route.ts b/app/api/eliza-app/connections/[platform]/initiate/route.ts new file mode 100644 index 000000000..d7b3d6f29 --- /dev/null +++ b/app/api/eliza-app/connections/[platform]/initiate/route.ts @@ -0,0 +1,85 @@ +import { NextRequest, NextResponse } from "next/server"; +import { elizaAppSessionService } from "@/lib/services/eliza-app"; +import { oauthService } from "@/lib/services/oauth"; +import { getProvider } from "@/lib/services/oauth/provider-registry"; + +interface InitiateBody { + returnPath?: string; + scopes?: string[]; +} + +function sanitizeReturnPath(path: string | undefined): string { + if (!path || !path.startsWith("/")) { + return "/connected"; + } + + return path; +} + +export async function POST( + request: NextRequest, + { params }: { params: Promise<{ platform: string }> }, +): Promise { + const authHeader = request.headers.get("Authorization"); + if (!authHeader) { + return NextResponse.json( + { error: "Authorization header required", code: "UNAUTHORIZED" }, + { status: 401 }, + ); + } + + const session = await elizaAppSessionService.validateAuthHeader(authHeader); + if (!session) { + return NextResponse.json( + { error: "Invalid or expired session", code: "INVALID_SESSION" }, + { status: 401 }, + ); + } + + const { platform } = await params; + const normalizedPlatform = platform.toLowerCase(); + const provider = getProvider(normalizedPlatform); + if (!provider) { + return NextResponse.json( + { error: "Unsupported platform", code: "PLATFORM_NOT_SUPPORTED" }, + { status: 400 }, + ); + } + + let body: InitiateBody = {}; + try { + body = (await request.json()) as InitiateBody; + } catch { + // Empty body is fine. + } + + const returnPath = sanitizeReturnPath(body.returnPath); + const redirectUrl = `/api/eliza-app/auth/connection-success?source=eliza-app&return_path=${encodeURIComponent(returnPath)}`; + + try { + const result = await oauthService.initiateAuth({ + organizationId: session.organizationId, + userId: session.userId, + platform: normalizedPlatform, + redirectUrl, + scopes: body.scopes, + }); + + return NextResponse.json({ + authUrl: result.authUrl, + state: result.state, + provider: { + id: provider.id, + name: provider.name, + }, + }); + } catch (error) { + return NextResponse.json( + { + error: error instanceof Error ? error.message : "Failed to initiate OAuth", + code: "INITIATE_FAILED", + }, + { status: 500 }, + ); + } +} diff --git a/app/api/eliza-app/connections/route.ts b/app/api/eliza-app/connections/route.ts new file mode 100644 index 000000000..e8b58ec31 --- /dev/null +++ b/app/api/eliza-app/connections/route.ts @@ -0,0 +1,70 @@ +import { NextRequest, NextResponse } from "next/server"; +import { elizaAppSessionService } from "@/lib/services/eliza-app"; +import { getProvider } from "@/lib/services/oauth/provider-registry"; + +function getRequestedPlatform(request: NextRequest): string | null { + const platform = request.nextUrl.searchParams.get("platform")?.toLowerCase() || "google"; + return getProvider(platform) ? platform : null; +} + +export async function GET(request: NextRequest): Promise { + const authHeader = request.headers.get("Authorization"); + if (!authHeader) { + return NextResponse.json( + { error: "Authorization header required", code: "UNAUTHORIZED" }, + { status: 401 }, + ); + } + + const session = await elizaAppSessionService.validateAuthHeader(authHeader); + if (!session) { + return NextResponse.json( + { error: "Invalid or expired session", code: "INVALID_SESSION" }, + { status: 401 }, + ); + } + + const platform = getRequestedPlatform(request); + if (!platform) { + return NextResponse.json( + { error: "Unsupported platform", code: "PLATFORM_NOT_SUPPORTED" }, + { status: 400 }, + ); + } + + try { + const { oauthService } = await import("@/lib/services/oauth"); + const connections = await oauthService.listConnections({ + organizationId: session.organizationId, + userId: session.userId, + platform, + }); + + const active = connections.find((connection) => connection.status === "active"); + const expired = connections.find((connection) => connection.status === "expired"); + const current = active ?? expired ?? null; + + return NextResponse.json({ + platform, + connected: Boolean(active), + status: active ? "active" : expired ? "expired" : "not_connected", + email: current?.email ?? null, + scopes: current?.scopes ?? [], + linkedAt: current?.linkedAt?.toISOString() ?? null, + connectionId: current?.id ?? null, + message: active + ? null + : expired + ? "Connection expired. Reconnect Google to keep Gmail and Calendar working." + : "Not connected yet.", + }); + } catch (error) { + return NextResponse.json( + { + error: error instanceof Error ? error.message : "Failed to load connection status", + code: "CONNECTION_STATUS_FAILED", + }, + { status: 500 }, + ); + } +} diff --git a/app/api/eliza-app/webhook/blooio/route.ts b/app/api/eliza-app/webhook/blooio/route.ts index cf7d73f6c..680722090 100644 --- a/app/api/eliza-app/webhook/blooio/route.ts +++ b/app/api/eliza-app/webhook/blooio/route.ts @@ -169,6 +169,7 @@ async function handleIncomingMessage(event: BlooioWebhookEvent): Promise const hasRequiredConnection = await connectionEnforcementService.hasRequiredConnection( organization.id, + userWithOrg.id, ); if (!hasRequiredConnection) { const nudgeText = await connectionEnforcementService.generateNudgeResponse({ diff --git a/app/api/eliza-app/webhook/telegram/route.ts b/app/api/eliza-app/webhook/telegram/route.ts index 36ac7df5d..01555cf69 100644 --- a/app/api/eliza-app/webhook/telegram/route.ts +++ b/app/api/eliza-app/webhook/telegram/route.ts @@ -173,6 +173,7 @@ async function handleMessage(message: Message): Promise { const hasRequiredConnection = await connectionEnforcementService.hasRequiredConnection( organization.id, + userWithOrg.id, ); if (!hasRequiredConnection) { const nudgeText = await connectionEnforcementService.generateNudgeResponse({ @@ -393,6 +394,7 @@ async function handleCommand(message: Message & { text: string }): Promise const creditBalance = user.organization.credit_balance || "0.00"; const hasRequiredConnection = await connectionEnforcementService.hasRequiredConnection( user.organization.id, + user.id, ); const connectionStatus = hasRequiredConnection ? "✅ Data integration connected" diff --git a/app/api/mcp/tools/airtable.ts b/app/api/mcp/tools/airtable.ts index 54580e79a..ddf5a62af 100644 --- a/app/api/mcp/tools/airtable.ts +++ b/app/api/mcp/tools/airtable.ts @@ -16,6 +16,7 @@ async function getAirtableToken(): Promise { try { const result = await oauthService.getValidTokenByPlatform({ organizationId: user.organization_id, + userId: user.id, platform: "airtable", }); return result.accessToken; @@ -68,6 +69,7 @@ export function registerAirtableTools(server: McpServer): void { const { user } = getAuthContext(); const connections = await oauthService.listConnections({ organizationId: user.organization_id, + userId: user.id, platform: "airtable", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcp/tools/asana.ts b/app/api/mcp/tools/asana.ts index 676236707..b8c3ea3fb 100644 --- a/app/api/mcp/tools/asana.ts +++ b/app/api/mcp/tools/asana.ts @@ -23,6 +23,7 @@ async function getAsanaToken(): Promise { try { const result = await oauthService.getValidTokenByPlatform({ organizationId: user.organization_id, + userId: user.id, platform: "asana", }); return result.accessToken; @@ -77,6 +78,7 @@ export function registerAsanaTools(server: McpServer): void { const { user } = getAuthContext(); const connections = await oauthService.listConnections({ organizationId: user.organization_id, + userId: user.id, platform: "asana", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcp/tools/dropbox.ts b/app/api/mcp/tools/dropbox.ts index abc4aa9a0..65e40890e 100644 --- a/app/api/mcp/tools/dropbox.ts +++ b/app/api/mcp/tools/dropbox.ts @@ -16,6 +16,7 @@ async function getDropboxToken(): Promise { try { const result = await oauthService.getValidTokenByPlatform({ organizationId: user.organization_id, + userId: user.id, platform: "dropbox", }); return result.accessToken; @@ -72,6 +73,7 @@ export function registerDropboxTools(server: McpServer): void { const { user } = getAuthContext(); const connections = await oauthService.listConnections({ organizationId: user.organization_id, + userId: user.id, platform: "dropbox", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcp/tools/github.ts b/app/api/mcp/tools/github.ts index 0d48f53d4..48887ffb8 100644 --- a/app/api/mcp/tools/github.ts +++ b/app/api/mcp/tools/github.ts @@ -16,6 +16,7 @@ async function getGitHubToken(): Promise { try { const result = await oauthService.getValidTokenByPlatform({ organizationId: user.organization_id, + userId: user.id, platform: "github", }); return result.accessToken; @@ -77,6 +78,7 @@ export function registerGitHubTools(server: McpServer): void { const { user } = getAuthContext(); const connections = await oauthService.listConnections({ organizationId: user.organization_id, + userId: user.id, platform: "github", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcp/tools/google.ts b/app/api/mcp/tools/google.ts index a5a270e6c..f9a0dadbd 100644 --- a/app/api/mcp/tools/google.ts +++ b/app/api/mcp/tools/google.ts @@ -20,6 +20,7 @@ async function getGoogleToken(): Promise { try { const result = await oauthService.getValidTokenByPlatform({ organizationId: user.organization_id, + userId: user.id, platform: "google", }); return result.accessToken; @@ -49,6 +50,7 @@ export function registerGoogleTools(server: McpServer): void { const { user } = getAuthContext(); const connections = await oauthService.listConnections({ organizationId: user.organization_id, + userId: user.id, platform: "google", }); diff --git a/app/api/mcp/tools/jira.ts b/app/api/mcp/tools/jira.ts index 3d0005092..aa8081932 100644 --- a/app/api/mcp/tools/jira.ts +++ b/app/api/mcp/tools/jira.ts @@ -20,6 +20,7 @@ async function getJiraToken(): Promise { try { const result = await oauthService.getValidTokenByPlatform({ organizationId: user.organization_id, + userId: user.id, platform: "jira", }); return result.accessToken; @@ -122,6 +123,7 @@ export function registerJiraTools(server: McpServer): void { const { user } = getAuthContext(); const connections = await oauthService.listConnections({ organizationId: user.organization_id, + userId: user.id, platform: "jira", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcp/tools/linear.ts b/app/api/mcp/tools/linear.ts index 7c10dc6d7..9ea9f4342 100644 --- a/app/api/mcp/tools/linear.ts +++ b/app/api/mcp/tools/linear.ts @@ -16,6 +16,7 @@ async function getLinearToken(): Promise { try { const result = await oauthService.getValidTokenByPlatform({ organizationId: user.organization_id, + userId: user.id, platform: "linear", }); return result.accessToken; @@ -76,6 +77,7 @@ export function registerLinearTools(server: McpServer): void { const { user } = getAuthContext(); const connections = await oauthService.listConnections({ organizationId: user.organization_id, + userId: user.id, platform: "linear", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcp/tools/linkedin.ts b/app/api/mcp/tools/linkedin.ts index 1c18b3655..0b5cc5568 100644 --- a/app/api/mcp/tools/linkedin.ts +++ b/app/api/mcp/tools/linkedin.ts @@ -22,6 +22,7 @@ async function getLinkedInToken(): Promise { try { const result = await oauthService.getValidTokenByPlatform({ organizationId: user.organization_id, + userId: user.id, platform: "linkedin", }); return result.accessToken; @@ -90,6 +91,7 @@ export function registerLinkedInTools(server: McpServer): void { const { user } = getAuthContext(); const connections = await oauthService.listConnections({ organizationId: user.organization_id, + userId: user.id, platform: "linkedin", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcp/tools/notion.ts b/app/api/mcp/tools/notion.ts index f7b65e19f..c3d3dd9f0 100644 --- a/app/api/mcp/tools/notion.ts +++ b/app/api/mcp/tools/notion.ts @@ -16,6 +16,7 @@ async function getNotionToken(): Promise { try { const result = await oauthService.getValidTokenByPlatform({ organizationId: user.organization_id, + userId: user.id, platform: "notion", }); return result.accessToken; @@ -67,6 +68,7 @@ export function registerNotionTools(server: McpServer): void { const { user } = getAuthContext(); const connections = await oauthService.listConnections({ organizationId: user.organization_id, + userId: user.id, platform: "notion", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcp/tools/salesforce.ts b/app/api/mcp/tools/salesforce.ts index db2dea076..998acab46 100644 --- a/app/api/mcp/tools/salesforce.ts +++ b/app/api/mcp/tools/salesforce.ts @@ -25,6 +25,7 @@ async function getSalesforceToken(): Promise { try { const result = await oauthService.getValidTokenByPlatform({ organizationId: user.organization_id, + userId: user.id, platform: "salesforce", }); return result.accessToken; @@ -127,6 +128,7 @@ export function registerSalesforceTools(server: McpServer): void { const { user } = getAuthContext(); const connections = await oauthService.listConnections({ organizationId: user.organization_id, + userId: user.id, platform: "salesforce", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcp/tools/twitter.ts b/app/api/mcp/tools/twitter.ts index 8363d316c..1c50455fe 100644 --- a/app/api/mcp/tools/twitter.ts +++ b/app/api/mcp/tools/twitter.ts @@ -34,6 +34,7 @@ async function getTwitterClient(): Promise { try { result = await oauthService.getValidTokenByPlatform({ organizationId: user.organization_id, + userId: user.id, platform: "twitter", }); } catch (error) { @@ -303,6 +304,7 @@ export function registerTwitterTools(server: McpServer): void { const { user } = getAuthContext(); const connections = await oauthService.listConnections({ organizationId: user.organization_id, + userId: user.id, platform: "twitter", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcp/tools/zoom.ts b/app/api/mcp/tools/zoom.ts index b87ba5421..75854e0c5 100644 --- a/app/api/mcp/tools/zoom.ts +++ b/app/api/mcp/tools/zoom.ts @@ -18,6 +18,7 @@ async function getZoomToken(): Promise { try { const result = await oauthService.getValidTokenByPlatform({ organizationId: user.organization_id, + userId: user.id, platform: "zoom", }); return result.accessToken; @@ -72,6 +73,7 @@ export function registerZoomTools(server: McpServer): void { const { user } = getAuthContext(); const connections = await oauthService.listConnections({ organizationId: user.organization_id, + userId: user.id, platform: "zoom", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcps/airtable/[transport]/route.ts b/app/api/mcps/airtable/[transport]/route.ts index 237ccd181..dc36868ad 100644 --- a/app/api/mcps/airtable/[transport]/route.ts +++ b/app/api/mcps/airtable/[transport]/route.ts @@ -38,8 +38,10 @@ async function getAirtableMcpHandler() { const { z } = await import("zod3"); async function getAirtableToken(organizationId: string): Promise { + const user = getAuthUser(); const result = await oauthService.getValidTokenByPlatform({ organizationId, + userId: user.id, platform: "airtable", }); return result.accessToken; @@ -75,6 +77,12 @@ async function getAirtableMcpHandler() { return ctx.user.organization_id; } + function getAuthUser() { + const ctx = authContextStorage.getStore(); + if (!ctx) throw new Error("Not authenticated"); + return ctx.user; + } + function jsonResult(data: object) { return { content: [{ type: "text" as const, text: JSON.stringify(data) }] }; } @@ -94,6 +102,7 @@ async function getAirtableMcpHandler() { const orgId = getOrgId(); const connections = await oauthService.listConnections({ organizationId: orgId, + userId: getAuthUser().id, platform: "airtable", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcps/asana/[transport]/route.ts b/app/api/mcps/asana/[transport]/route.ts index b4f8d6a3e..483ad376f 100644 --- a/app/api/mcps/asana/[transport]/route.ts +++ b/app/api/mcps/asana/[transport]/route.ts @@ -46,8 +46,10 @@ async function getAsanaMcpHandler() { const API_BASE = "https://app.asana.com/api/1.0"; async function getAsanaToken(organizationId: string): Promise { + const user = getAuthUser(); const result = await oauthService.getValidTokenByPlatform({ organizationId, + userId: user.id, platform: "asana", }); return result.accessToken; @@ -94,6 +96,12 @@ async function getAsanaMcpHandler() { return ctx.user.organization_id; } + function getAuthUser() { + const ctx = authContextStorage.getStore(); + if (!ctx) throw new Error("Not authenticated"); + return ctx.user; + } + function jsonResult(data: object) { return { content: [{ type: "text" as const, text: JSON.stringify(data) }] }; } @@ -113,6 +121,7 @@ async function getAsanaMcpHandler() { const orgId = getOrgId(); const connections = await oauthService.listConnections({ organizationId: orgId, + userId: getAuthUser().id, platform: "asana", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcps/dropbox/[transport]/route.ts b/app/api/mcps/dropbox/[transport]/route.ts index e3fb287c3..2683e3f26 100644 --- a/app/api/mcps/dropbox/[transport]/route.ts +++ b/app/api/mcps/dropbox/[transport]/route.ts @@ -38,8 +38,10 @@ async function getDropboxMcpHandler() { const { z } = await import("zod3"); async function getDropboxToken(organizationId: string): Promise { + const user = getAuthUser(); const result = await oauthService.getValidTokenByPlatform({ organizationId, + userId: user.id, platform: "dropbox", }); return result.accessToken; @@ -78,6 +80,12 @@ async function getDropboxMcpHandler() { return ctx.user.organization_id; } + function getAuthUser() { + const ctx = authContextStorage.getStore(); + if (!ctx) throw new Error("Not authenticated"); + return ctx.user; + } + function jsonResult(data: object) { return { content: [{ type: "text" as const, text: JSON.stringify(data) }] }; } @@ -97,6 +105,7 @@ async function getDropboxMcpHandler() { const orgId = getOrgId(); const connections = await oauthService.listConnections({ organizationId: orgId, + userId: getAuthUser().id, platform: "dropbox", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcps/github/[transport]/route.ts b/app/api/mcps/github/[transport]/route.ts index ea5f899f5..e4583ba9d 100644 --- a/app/api/mcps/github/[transport]/route.ts +++ b/app/api/mcps/github/[transport]/route.ts @@ -34,8 +34,10 @@ async function getGitHubMcpHandler() { const { z } = await import("zod"); async function getGitHubToken(organizationId: string): Promise { + const user = getAuthUser(); const result = await oauthService.getValidTokenByPlatform({ organizationId, + userId: user.id, platform: "github", }); return result.accessToken; @@ -71,6 +73,12 @@ async function getGitHubMcpHandler() { return ctx.user.organization_id; } + function getAuthUser() { + const ctx = authContextStorage.getStore(); + if (!ctx) throw new Error("Not authenticated"); + return ctx.user; + } + function jsonResult(data: object) { return { content: [{ type: "text" as const, text: JSON.stringify(data) }] }; } @@ -98,6 +106,7 @@ async function getGitHubMcpHandler() { const orgId = getOrgId(); const connections = await oauthService.listConnections({ organizationId: orgId, + userId: getAuthUser().id, platform: "github", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcps/google/[transport]/route.ts b/app/api/mcps/google/[transport]/route.ts index 558a4e31c..d9819ea83 100644 --- a/app/api/mcps/google/[transport]/route.ts +++ b/app/api/mcps/google/[transport]/route.ts @@ -53,10 +53,18 @@ async function getGoogleMcpHandler() { return ctx.user.organization_id; } + function getAuthUser() { + const ctx = authContextStorage.getStore(); + if (!ctx) throw new Error("Not authenticated"); + return ctx.user; + } + async function getGoogleToken(organizationId: string): Promise { try { + const user = getAuthUser(); const result = await oauthService.getValidTokenByPlatform({ organizationId, + userId: user.id, platform: "google", }); return result.accessToken; @@ -106,6 +114,7 @@ async function getGoogleMcpHandler() { const orgId = getOrgId(); const connections = await oauthService.listConnections({ organizationId: orgId, + userId: getAuthUser().id, platform: "google", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcps/jira/[transport]/route.ts b/app/api/mcps/jira/[transport]/route.ts index 6e260db0b..9935a6f86 100644 --- a/app/api/mcps/jira/[transport]/route.ts +++ b/app/api/mcps/jira/[transport]/route.ts @@ -43,7 +43,12 @@ async function getJiraMcpHandler() { const { z } = await import("zod3"); async function getJiraToken(organizationId: string): Promise { - const result = await oauthService.getValidTokenByPlatform({ organizationId, platform: "jira" }); + const user = getAuthUser(); + const result = await oauthService.getValidTokenByPlatform({ + organizationId, + userId: user.id, + platform: "jira", + }); return result.accessToken; } @@ -147,6 +152,12 @@ async function getJiraMcpHandler() { return ctx.user.organization_id; } + function getAuthUser() { + const ctx = authContextStorage.getStore(); + if (!ctx) throw new Error("Not authenticated"); + return ctx.user; + } + function jsonResult(data: object) { return { content: [{ type: "text" as const, text: JSON.stringify(data) }] }; } @@ -165,6 +176,7 @@ async function getJiraMcpHandler() { const orgId = getOrgId(); const connections = await oauthService.listConnections({ organizationId: orgId, + userId: getAuthUser().id, platform: "jira", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcps/linear/[transport]/route.ts b/app/api/mcps/linear/[transport]/route.ts index 0d3caaa3f..d06ef5645 100644 --- a/app/api/mcps/linear/[transport]/route.ts +++ b/app/api/mcps/linear/[transport]/route.ts @@ -34,8 +34,10 @@ async function getLinearMcpHandler() { const { z } = await import("zod"); async function getLinearToken(organizationId: string): Promise { + const user = getAuthUser(); const result = await oauthService.getValidTokenByPlatform({ organizationId, + userId: user.id, platform: "linear", }); return result.accessToken; @@ -79,6 +81,12 @@ async function getLinearMcpHandler() { return ctx.user.organization_id; } + function getAuthUser() { + const ctx = authContextStorage.getStore(); + if (!ctx) throw new Error("Not authenticated"); + return ctx.user; + } + function jsonResult(data: object) { return { content: [{ type: "text" as const, text: JSON.stringify(data) }] }; } @@ -97,6 +105,7 @@ async function getLinearMcpHandler() { const orgId = getOrgId(); const connections = await oauthService.listConnections({ organizationId: orgId, + userId: getAuthUser().id, platform: "linear", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcps/linkedin/[transport]/route.ts b/app/api/mcps/linkedin/[transport]/route.ts index cc344a45f..0e0072e23 100644 --- a/app/api/mcps/linkedin/[transport]/route.ts +++ b/app/api/mcps/linkedin/[transport]/route.ts @@ -44,8 +44,10 @@ async function getLinkedInMcpHandler() { const { z } = await import("zod3"); async function getLinkedInToken(organizationId: string): Promise { + const user = getAuthUser(); const result = await oauthService.getValidTokenByPlatform({ organizationId, + userId: user.id, platform: "linkedin", }); return result.accessToken; @@ -57,6 +59,12 @@ async function getLinkedInMcpHandler() { return ctx.user.organization_id; } + function getAuthUser() { + const ctx = authContextStorage.getStore(); + if (!ctx) throw new Error("Not authenticated"); + return ctx.user; + } + async function linkedinFetch(orgId: string, path: string, options: RequestInit = {}) { const token = await getLinkedInToken(orgId); const url = path.startsWith("http") ? path : `${LINKEDIN_REST_BASE}${path}`; @@ -115,6 +123,7 @@ async function getLinkedInMcpHandler() { const orgId = getOrgId(); const connections = await oauthService.listConnections({ organizationId: orgId, + userId: getAuthUser().id, platform: "linkedin", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcps/microsoft/[transport]/route.ts b/app/api/mcps/microsoft/[transport]/route.ts index 8cb026066..ba23f29ec 100644 --- a/app/api/mcps/microsoft/[transport]/route.ts +++ b/app/api/mcps/microsoft/[transport]/route.ts @@ -37,8 +37,10 @@ async function getMicrosoftMcpHandler() { const { z } = await import("zod/v3"); async function getMicrosoftToken(organizationId: string): Promise { + const user = getAuthUser(); const result = await oauthService.getValidTokenByPlatform({ organizationId, + userId: user.id, platform: "microsoft", }); return result.accessToken; @@ -67,6 +69,12 @@ async function getMicrosoftMcpHandler() { return ctx.user.organization_id; } + function getAuthUser() { + const ctx = authContextStorage.getStore(); + if (!ctx) throw new Error("Not authenticated"); + return ctx.user; + } + function jsonResult(data: object) { return { content: [{ type: "text" as const, text: JSON.stringify(data) }] }; } @@ -85,6 +93,7 @@ async function getMicrosoftMcpHandler() { const orgId = getOrgId(); const connections = await oauthService.listConnections({ organizationId: orgId, + userId: getAuthUser().id, platform: "microsoft", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcps/notion/[transport]/route.ts b/app/api/mcps/notion/[transport]/route.ts index 80823e751..6fae7535c 100644 --- a/app/api/mcps/notion/[transport]/route.ts +++ b/app/api/mcps/notion/[transport]/route.ts @@ -34,8 +34,10 @@ async function getNotionMcpHandler() { const { z } = await import("zod3"); async function getNotionToken(organizationId: string): Promise { + const user = getAuthUser(); const result = await oauthService.getValidTokenByPlatform({ organizationId, + userId: user.id, platform: "notion", }); return result.accessToken; @@ -70,6 +72,12 @@ async function getNotionMcpHandler() { return ctx.user.organization_id; } + function getAuthUser() { + const ctx = authContextStorage.getStore(); + if (!ctx) throw new Error("Not authenticated"); + return ctx.user; + } + function jsonResult(data: object) { return { content: [{ type: "text" as const, text: JSON.stringify(data) }] }; } @@ -88,6 +96,7 @@ async function getNotionMcpHandler() { const orgId = getOrgId(); const connections = await oauthService.listConnections({ organizationId: orgId, + userId: getAuthUser().id, platform: "notion", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcps/salesforce/[transport]/route.ts b/app/api/mcps/salesforce/[transport]/route.ts index 034cc5902..ced6ab3c1 100644 --- a/app/api/mcps/salesforce/[transport]/route.ts +++ b/app/api/mcps/salesforce/[transport]/route.ts @@ -47,8 +47,10 @@ async function getSalesforceMcpHandler() { const { z } = await import("zod3"); async function getSalesforceToken(organizationId: string): Promise { + const user = getAuthUser(); const result = await oauthService.getValidTokenByPlatform({ organizationId, + userId: user.id, platform: "salesforce", }); return result.accessToken; @@ -127,6 +129,12 @@ async function getSalesforceMcpHandler() { return ctx.user.organization_id; } + function getAuthUser() { + const ctx = authContextStorage.getStore(); + if (!ctx) throw new Error("Not authenticated"); + return ctx.user; + } + function jsonResult(data: object) { return { content: [{ type: "text" as const, text: JSON.stringify(data) }] }; } @@ -146,6 +154,7 @@ async function getSalesforceMcpHandler() { const orgId = getOrgId(); const connections = await oauthService.listConnections({ organizationId: orgId, + userId: getAuthUser().id, platform: "salesforce", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcps/twitter/[transport]/route.ts b/app/api/mcps/twitter/[transport]/route.ts index 4fbbf5e2e..bddf817c2 100644 --- a/app/api/mcps/twitter/[transport]/route.ts +++ b/app/api/mcps/twitter/[transport]/route.ts @@ -57,8 +57,10 @@ async function getTwitterMcpHandler() { ); } + const user = getAuthUser(); const result = await oauthService.getValidTokenByPlatform({ organizationId, + userId: user.id, platform: "twitter", }); @@ -76,6 +78,12 @@ async function getTwitterMcpHandler() { return ctx.user.organization_id; } + function getAuthUser() { + const ctx = authContextStorage.getStore(); + if (!ctx) throw new Error("Not authenticated"); + return ctx.user; + } + function jsonResult(data: object) { return { content: [{ type: "text" as const, text: JSON.stringify(data) }] }; } @@ -136,6 +144,7 @@ async function getTwitterMcpHandler() { const orgId = getOrgId(); const connections = await oauthService.listConnections({ organizationId: orgId, + userId: getAuthUser().id, platform: "twitter", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/mcps/zoom/[transport]/route.ts b/app/api/mcps/zoom/[transport]/route.ts index f706d6765..71a07f55d 100644 --- a/app/api/mcps/zoom/[transport]/route.ts +++ b/app/api/mcps/zoom/[transport]/route.ts @@ -40,7 +40,12 @@ async function getZoomMcpHandler() { const { z } = await import("zod3"); async function getZoomToken(organizationId: string): Promise { - const result = await oauthService.getValidTokenByPlatform({ organizationId, platform: "zoom" }); + const user = getAuthUser(); + const result = await oauthService.getValidTokenByPlatform({ + organizationId, + userId: user.id, + platform: "zoom", + }); return result.accessToken; } @@ -50,6 +55,12 @@ async function getZoomMcpHandler() { return ctx.user.organization_id; } + function getAuthUser() { + const ctx = authContextStorage.getStore(); + if (!ctx) throw new Error("Not authenticated"); + return ctx.user; + } + async function zoomFetch(orgId: string, path: string, options: RequestInit = {}) { const token = await getZoomToken(orgId); const url = `${ZOOM_API_BASE}${path}`; @@ -94,6 +105,7 @@ async function getZoomMcpHandler() { const orgId = getOrgId(); const connections = await oauthService.listConnections({ organizationId: orgId, + userId: getAuthUser().id, platform: "zoom", }); const active = connections.find((c) => c.status === "active"); diff --git a/app/api/v1/oauth/[platform]/callback/route.ts b/app/api/v1/oauth/[platform]/callback/route.ts index c275baf6f..81ca9add0 100644 --- a/app/api/v1/oauth/[platform]/callback/route.ts +++ b/app/api/v1/oauth/[platform]/callback/route.ts @@ -184,7 +184,10 @@ async function handleCallback( entitySettingsCache.invalidateUser(result.userId), edgeRuntimeCache.bumpMcpVersion(result.organizationId), incrementOAuthVersion(result.organizationId, platformLower), - connectionEnforcementService.invalidateRequiredConnectionCache(result.organizationId), + connectionEnforcementService.invalidateRequiredConnectionCache( + result.organizationId, + result.userId, + ), ]); } catch (e) { logger.warn(`[OAuth ${platform}] Cache invalidation failed`, { diff --git a/app/api/v1/oauth/[platform]/initiate/route.ts b/app/api/v1/oauth/[platform]/initiate/route.ts index 5f97a6708..6e155f838 100644 --- a/app/api/v1/oauth/[platform]/initiate/route.ts +++ b/app/api/v1/oauth/[platform]/initiate/route.ts @@ -9,6 +9,7 @@ import type { NextRequest } from "next/server"; import { NextResponse } from "next/server"; +import { ApiError } from "@/lib/api/errors"; import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; import { withRateLimit } from "@/lib/middleware/rate-limit"; import { getProvider, isProviderConfigured } from "@/lib/services/oauth/provider-registry"; @@ -35,6 +36,7 @@ async function handleInitiate(request: NextRequest, context?: RouteParams): Prom } const { platform } = await context.params; const platformLower = platform.toLowerCase(); + let organizationId: string | undefined; // Get provider configuration const provider = getProvider(platformLower); @@ -85,29 +87,28 @@ async function handleInitiate(request: NextRequest, context?: RouteParams): Prom ); } - // Authenticate request - const { user } = await requireAuthOrApiKeyWithOrg(request); - - // Parse request body - let body: InitiateRequestBody = {}; try { - body = (await request.json()) as InitiateRequestBody; - } catch { - // Empty body is fine, use defaults - } + const { user } = await requireAuthOrApiKeyWithOrg(request); + organizationId = user.organization_id; - const redirectUrl = body.redirectUrl || "/dashboard/settings?tab=connections"; - const scopes = body.scopes || provider.defaultScopes || []; + let body: InitiateRequestBody = {}; + try { + body = (await request.json()) as InitiateRequestBody; + } catch { + // Empty body is fine, use defaults + } - logger.info(`[OAuth ${platform}] Initiating auth`, { - organizationId: user.organization_id, - userId: user.id, - scopeCount: scopes.length, - }); + const redirectUrl = body.redirectUrl || "/dashboard/settings?tab=connections"; + const scopes = body.scopes || provider.defaultScopes || []; + + logger.info(`[OAuth ${platform}] Initiating auth`, { + organizationId, + userId: user.id, + scopeCount: scopes.length, + }); - try { const result = await initiateOAuth2(provider, { - organizationId: user.organization_id, + organizationId, userId: user.id, redirectUrl, scopes, @@ -123,10 +124,14 @@ async function handleInitiate(request: NextRequest, context?: RouteParams): Prom }); } catch (error) { logger.error(`[OAuth ${platform}] Failed to initiate auth`, { - organizationId: user.organization_id, + organizationId, error: error instanceof Error ? error.message : String(error), }); + if (error instanceof ApiError) { + return NextResponse.json(error.toJSON(), { status: error.status }); + } + return NextResponse.json( { error: "INITIATE_FAILED", diff --git a/app/api/v1/oauth/connect/route.ts b/app/api/v1/oauth/connect/route.ts index 6108404a5..a97b55e66 100644 --- a/app/api/v1/oauth/connect/route.ts +++ b/app/api/v1/oauth/connect/route.ts @@ -6,6 +6,7 @@ */ import { NextRequest, NextResponse } from "next/server"; +import { ApiError } from "@/lib/api/errors"; import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; import { internalErrorResponse, @@ -29,36 +30,41 @@ function isValidString(value: unknown): value is string { } export async function POST(request: NextRequest) { - const { user } = await requireAuthOrApiKeyWithOrg(request); + let organizationId: string | undefined; + let platform: string | undefined; - let body: ConnectRequestBody; try { - body = (await request.json()) as ConnectRequestBody; - } catch { - return NextResponse.json(validationErrorResponse("Invalid JSON body"), { status: 400 }); - } + const { user } = await requireAuthOrApiKeyWithOrg(request); + organizationId = user.organization_id; - if (!isValidString(body.platform)) { - return NextResponse.json( - validationErrorResponse("platform is required and must be a non-empty string"), - { status: 400 }, - ); - } + let body: ConnectRequestBody; + try { + body = (await request.json()) as ConnectRequestBody; + } catch { + return NextResponse.json(validationErrorResponse("Invalid JSON body"), { status: 400 }); + } - // Sanitize platform - lowercase and max 50 chars - body.platform = body.platform.toLowerCase().slice(0, 50); + if (!isValidString(body.platform)) { + return NextResponse.json( + validationErrorResponse("platform is required and must be a non-empty string"), + { status: 400 }, + ); + } - logger.info("[API] POST /api/v1/oauth/connect", { - organizationId: user.organization_id, - platform: body.platform, - hasScopes: !!body.scopes, - }); + // Sanitize platform - lowercase and max 50 chars + body.platform = body.platform.toLowerCase().slice(0, 50); + platform = body.platform; + + logger.info("[API] POST /api/v1/oauth/connect", { + organizationId, + platform, + hasScopes: !!body.scopes, + }); - try { const result = await oauthService.initiateAuth({ - organizationId: user.organization_id, + organizationId, userId: user.id, - platform: body.platform, + platform, redirectUrl: body.redirectUrl, scopes: body.scopes, }); @@ -66,11 +72,15 @@ export async function POST(request: NextRequest) { return NextResponse.json(result); } catch (error) { logger.error("[API] POST /api/v1/oauth/connect error", { - organizationId: user.organization_id, - platform: body.platform, + organizationId, + platform, error: error instanceof Error ? error.message : String(error), }); + if (error instanceof ApiError) { + return NextResponse.json(error.toJSON(), { status: error.status }); + } + if (error instanceof OAuthError) { return NextResponse.json(error.toResponse(), { status: error.httpStatus }); } diff --git a/app/api/v1/oauth/connections/[id]/route.ts b/app/api/v1/oauth/connections/[id]/route.ts index b82a39ddc..0fe107ee4 100644 --- a/app/api/v1/oauth/connections/[id]/route.ts +++ b/app/api/v1/oauth/connections/[id]/route.ts @@ -4,6 +4,7 @@ */ import { NextRequest, NextResponse } from "next/server"; +import { ApiError } from "@/lib/api/errors"; import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; import { Errors, internalErrorResponse, OAuthError, oauthService } from "@/lib/services/oauth"; import { invalidateOAuthState } from "@/lib/services/oauth/invalidation"; @@ -13,17 +14,20 @@ export const dynamic = "force-dynamic"; export const maxDuration = 30; export async function GET(request: NextRequest, { params }: { params: Promise<{ id: string }> }) { - const { user } = await requireAuthOrApiKeyWithOrg(request); const { id: connectionId } = await params; - - logger.debug("[API] GET /api/v1/oauth/connections/:id", { - organizationId: user.organization_id, - connectionId, - }); + let organizationId: string | undefined; try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + organizationId = user.organization_id; + + logger.debug("[API] GET /api/v1/oauth/connections/:id", { + organizationId, + connectionId, + }); + const connection = await oauthService.getConnection({ - organizationId: user.organization_id, + organizationId, connectionId, }); @@ -41,11 +45,15 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{ }); } catch (error) { logger.error("[API] GET /api/v1/oauth/connections/:id error", { - organizationId: user.organization_id, + organizationId, connectionId, error: error instanceof Error ? error.message : String(error), }); + if (error instanceof ApiError) { + return NextResponse.json(error.toJSON(), { status: error.status }); + } + if (error instanceof OAuthError) { return NextResponse.json(error.toResponse(), { status: error.httpStatus }); } @@ -58,30 +66,39 @@ export async function DELETE( request: NextRequest, { params }: { params: Promise<{ id: string }> }, ) { - const { user } = await requireAuthOrApiKeyWithOrg(request); const { id: connectionId } = await params; - - logger.info("[API] DELETE /api/v1/oauth/connections/:id", { - organizationId: user.organization_id, - connectionId, - }); + let organizationId: string | undefined; + let userId: string | undefined; try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + organizationId = user.organization_id; + userId = user.id; + + logger.info("[API] DELETE /api/v1/oauth/connections/:id", { + organizationId, + connectionId, + }); + await oauthService.revokeConnection({ - organizationId: user.organization_id, + organizationId, connectionId, }); - await invalidateOAuthState(user.organization_id, "oauth", user.id, { skipVersionBump: true }); + await invalidateOAuthState(organizationId, "oauth", userId, { skipVersionBump: true }); return NextResponse.json({ success: true }); } catch (error) { logger.error("[API] DELETE /api/v1/oauth/connections/:id error", { - organizationId: user.organization_id, + organizationId, connectionId, error: error instanceof Error ? error.message : String(error), }); + if (error instanceof ApiError) { + return NextResponse.json(error.toJSON(), { status: error.status }); + } + if (error instanceof OAuthError) { return NextResponse.json(error.toResponse(), { status: error.httpStatus }); } diff --git a/app/api/v1/oauth/connections/[id]/token/route.ts b/app/api/v1/oauth/connections/[id]/token/route.ts index e03034576..e85cc88fe 100644 --- a/app/api/v1/oauth/connections/[id]/token/route.ts +++ b/app/api/v1/oauth/connections/[id]/token/route.ts @@ -6,6 +6,7 @@ */ import { NextRequest, NextResponse } from "next/server"; +import { ApiError } from "@/lib/api/errors"; import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; import { internalErrorResponse, OAuthError, oauthService } from "@/lib/services/oauth"; import { logger } from "@/lib/utils/logger"; @@ -14,17 +15,20 @@ export const dynamic = "force-dynamic"; export const maxDuration = 30; export async function GET(request: NextRequest, { params }: { params: Promise<{ id: string }> }) { - const { user } = await requireAuthOrApiKeyWithOrg(request); const { id: connectionId } = await params; - - logger.debug("[API] GET /api/v1/oauth/connections/:id/token", { - organizationId: user.organization_id, - connectionId, - }); + let organizationId: string | undefined; try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + organizationId = user.organization_id; + + logger.debug("[API] GET /api/v1/oauth/connections/:id/token", { + organizationId, + connectionId, + }); + const token = await oauthService.getValidToken({ - organizationId: user.organization_id, + organizationId, connectionId, }); @@ -38,11 +42,15 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{ }); } catch (error) { logger.error("[API] GET /api/v1/oauth/connections/:id/token error", { - organizationId: user.organization_id, + organizationId, connectionId, error: error instanceof Error ? error.message : String(error), }); + if (error instanceof ApiError) { + return NextResponse.json(error.toJSON(), { status: error.status }); + } + if (error instanceof OAuthError) { return NextResponse.json(error.toResponse(), { status: error.httpStatus }); } diff --git a/app/api/v1/oauth/connections/route.ts b/app/api/v1/oauth/connections/route.ts index daaeec311..9978fa26a 100644 --- a/app/api/v1/oauth/connections/route.ts +++ b/app/api/v1/oauth/connections/route.ts @@ -5,6 +5,7 @@ */ import { NextRequest, NextResponse } from "next/server"; +import { ApiError } from "@/lib/api/errors"; import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; import { internalErrorResponse, OAuthError, oauthService } from "@/lib/services/oauth"; import { logger } from "@/lib/utils/logger"; @@ -13,19 +14,22 @@ export const dynamic = "force-dynamic"; export const maxDuration = 30; export async function GET(request: NextRequest) { - const { user } = await requireAuthOrApiKeyWithOrg(request); - const { searchParams } = new URL(request.url); const platform = searchParams.get("platform") || undefined; - - logger.debug("[API] GET /api/v1/oauth/connections", { - organizationId: user.organization_id, - platform, - }); + let organizationId: string | undefined; try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + organizationId = user.organization_id; + + logger.debug("[API] GET /api/v1/oauth/connections", { + organizationId, + platform, + }); + const connections = await oauthService.listConnections({ - organizationId: user.organization_id, + organizationId, + userId: user.id, platform, }); @@ -38,11 +42,15 @@ export async function GET(request: NextRequest) { }); } catch (error) { logger.error("[API] GET /api/v1/oauth/connections error", { - organizationId: user.organization_id, + organizationId, platform, error: error instanceof Error ? error.message : String(error), }); + if (error instanceof ApiError) { + return NextResponse.json(error.toJSON(), { status: error.status }); + } + if (error instanceof OAuthError) { return NextResponse.json(error.toResponse(), { status: error.httpStatus }); } diff --git a/app/api/v1/oauth/status/route.ts b/app/api/v1/oauth/status/route.ts index 3764f6615..34366226e 100644 --- a/app/api/v1/oauth/status/route.ts +++ b/app/api/v1/oauth/status/route.ts @@ -1,5 +1,6 @@ import type { NextRequest } from "next/server"; import { NextResponse } from "next/server"; +import { ApiError } from "@/lib/api/errors"; import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; import { blooioAutomationService } from "@/lib/services/blooio-automation"; import { oauthService } from "@/lib/services/oauth"; @@ -16,10 +17,14 @@ type LegacyServiceStatus = { error?: string; }; -async function getGoogleStatus(organizationId: string): Promise { +async function getGoogleStatus( + organizationId: string, + userId: string, +): Promise { try { const connections = await oauthService.listConnections({ organizationId, + userId, platform: "google", }); @@ -77,11 +82,14 @@ async function getBlooioStatus(organizationId: string): Promise { - const { user } = await requireAuthOrApiKeyWithOrg(request); + let organizationId: string | undefined; try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + organizationId = user.organization_id; + const services = await Promise.all([ - getGoogleStatus(user.organization_id), + getGoogleStatus(user.organization_id, user.id), getTwilioStatus(user.organization_id), getBlooioStatus(user.organization_id), ]); @@ -89,10 +97,14 @@ export async function GET(request: NextRequest): Promise { return NextResponse.json({ services }); } catch (error) { logger.error("[OAuth Status] Failed to build legacy status response", { - organizationId: user.organization_id, + organizationId, error: error instanceof Error ? error.message : String(error), }); + if (error instanceof ApiError) { + return NextResponse.json(error.toJSON(), { status: error.status }); + } + return NextResponse.json({ error: "Failed to fetch OAuth status" }, { status: 500 }); } } diff --git a/app/api/v1/oauth/token/[platform]/route.ts b/app/api/v1/oauth/token/[platform]/route.ts index ecf0ed0e1..7a066d315 100644 --- a/app/api/v1/oauth/token/[platform]/route.ts +++ b/app/api/v1/oauth/token/[platform]/route.ts @@ -6,6 +6,7 @@ */ import { NextRequest, NextResponse } from "next/server"; +import { ApiError } from "@/lib/api/errors"; import { requireAuthOrApiKeyWithOrg } from "@/lib/auth"; import { Errors, @@ -23,22 +24,26 @@ export async function GET( request: NextRequest, { params }: { params: Promise<{ platform: string }> }, ) { - const { user } = await requireAuthOrApiKeyWithOrg(request); const { platform } = await params; + let organizationId: string | undefined; - logger.debug("[API] GET /api/v1/oauth/token/:platform", { - organizationId: user.organization_id, - platform, - }); + try { + const { user } = await requireAuthOrApiKeyWithOrg(request); + organizationId = user.organization_id; - if (!isValidProvider(platform)) { - const error = Errors.platformNotSupported(platform); - return NextResponse.json(error.toResponse(), { status: error.httpStatus }); - } + logger.debug("[API] GET /api/v1/oauth/token/:platform", { + organizationId, + platform, + }); + + if (!isValidProvider(platform)) { + const error = Errors.platformNotSupported(platform); + return NextResponse.json(error.toResponse(), { status: error.httpStatus }); + } - try { const { token, connectionId } = await oauthService.getValidTokenByPlatformWithConnectionId({ - organizationId: user.organization_id, + organizationId, + userId: user.id, platform, }); @@ -53,11 +58,15 @@ export async function GET( }); } catch (error) { logger.error("[API] GET /api/v1/oauth/token/:platform error", { - organizationId: user.organization_id, + organizationId, platform, error: error instanceof Error ? error.message : String(error), }); + if (error instanceof ApiError) { + return NextResponse.json(error.toJSON(), { status: error.status }); + } + if (error instanceof OAuthError) { return NextResponse.json(error.toResponse(), { status: error.httpStatus }); } diff --git a/packages/lib/eliza/runtime-factory.ts b/packages/lib/eliza/runtime-factory.ts index 6e444361b..66316e05d 100644 --- a/packages/lib/eliza/runtime-factory.ts +++ b/packages/lib/eliza/runtime-factory.ts @@ -1006,14 +1006,14 @@ export class RuntimeFactory { } } +export const runtimeFactory = RuntimeFactory.getInstance(); + export function getRuntimeCacheStats(): { runtime: { size: number; maxSize: number }; } { return runtimeFactory.getCacheStats(); } -export const runtimeFactory = RuntimeFactory.getInstance(); - export async function invalidateRuntime(agentId: string): Promise { return runtimeFactory.invalidateRuntime(agentId); } diff --git a/packages/lib/services/eliza-app/connection-enforcement.ts b/packages/lib/services/eliza-app/connection-enforcement.ts index 69612d6b0..434064f9c 100644 --- a/packages/lib/services/eliza-app/connection-enforcement.ts +++ b/packages/lib/services/eliza-app/connection-enforcement.ts @@ -139,18 +139,21 @@ function formatConversationHistory(messages: ConversationMessage[]): string { .join("\n")}\n`; } -function getConversationKey(organizationId: string): string { - return `connection-enforcement:conversation:${organizationId}`; +function getConversationKey(organizationId: string, userId: string): string { + return `connection-enforcement:conversation:${organizationId}:${userId}`; } -function getConnectionStatusKey(organizationId: string): string { - return `connection-enforcement:required-connection:${organizationId}`; +function getConnectionStatusKey(organizationId: string, userId: string): string { + return `connection-enforcement:required-connection:${organizationId}:${userId}`; } -async function loadConversationState(organizationId: string): Promise { +async function loadConversationState( + organizationId: string, + userId: string, +): Promise { try { return ( - (await cache.get(getConversationKey(organizationId))) ?? { + (await cache.get(getConversationKey(organizationId, userId))) ?? { messageCount: 0, messages: [], } @@ -162,6 +165,7 @@ async function loadConversationState(organizationId: string): Promise { try { @@ -169,10 +173,15 @@ async function saveConversationState( ...state, messages: state.messages.slice(-MAX_HISTORY_MESSAGES), }; - await cache.set(getConversationKey(organizationId), normalizedState, CONVERSATION_TTL_SECONDS); + await cache.set( + getConversationKey(organizationId, userId), + normalizedState, + CONVERSATION_TTL_SECONDS, + ); } catch (error) { logger.warn("[ConnectionEnforcement] Failed to persist conversation state", { organizationId, + userId, error: error instanceof Error ? error.message : String(error), }); } @@ -335,15 +344,15 @@ function detectProviderFromMessage(message: string): RequiredPlatform | null { } class ConnectionEnforcementService { - async hasRequiredConnection(organizationId: string): Promise { + async hasRequiredConnection(organizationId: string, userId: string): Promise { try { - const cacheKey = getConnectionStatusKey(organizationId); + const cacheKey = getConnectionStatusKey(organizationId, userId); const cached = await cache.get(cacheKey); if (typeof cached === "boolean") { return cached; } - const connectedPlatforms = await oauthService.getConnectedPlatforms(organizationId); + const connectedPlatforms = await oauthService.getConnectedPlatforms(organizationId, userId); const hasRequired = connectedPlatforms.some((platform) => (REQUIRED_PLATFORMS as readonly string[]).includes(platform), ); @@ -353,18 +362,25 @@ class ConnectionEnforcementService { } catch (error) { logger.error("[ConnectionEnforcement] Failed to check connections", { organizationId, + userId, error: error instanceof Error ? error.message : String(error), }); return true; } } - async invalidateRequiredConnectionCache(organizationId: string): Promise { + async invalidateRequiredConnectionCache(organizationId: string, userId?: string): Promise { try { - await cache.del(getConnectionStatusKey(organizationId)); + if (userId) { + await cache.del(getConnectionStatusKey(organizationId, userId)); + return; + } + + await cache.delPattern(`connection-enforcement:required-connection:${organizationId}:*`); } catch (error) { logger.warn("[ConnectionEnforcement] Failed to invalidate connection cache", { organizationId, + userId, error: error instanceof Error ? error.message : String(error), }); } @@ -372,7 +388,7 @@ class ConnectionEnforcementService { async generateNudgeResponse(params: NudgeParams): Promise { const { userMessage, platform, organizationId, userId } = params; - const state = await loadConversationState(organizationId); + const state = await loadConversationState(organizationId, userId); const conversationHistory = formatConversationHistory(state.messages); const detectedProvider = detectProviderFromMessage(userMessage); const isFirstInteraction = state.messageCount === 0; @@ -409,7 +425,7 @@ class ConnectionEnforcementService { state.messages.push({ role: "user", content: userMessage }); state.messages.push({ role: "assistant", content: responseForHistory }); state.messageCount += 1; - await saveConversationState(organizationId, state); + await saveConversationState(organizationId, userId, state); return response; } diff --git a/packages/lib/services/oauth/connection-adapters/generic-adapter.ts b/packages/lib/services/oauth/connection-adapters/generic-adapter.ts index b45e535a5..c077b47e1 100644 --- a/packages/lib/services/oauth/connection-adapters/generic-adapter.ts +++ b/packages/lib/services/oauth/connection-adapters/generic-adapter.ts @@ -108,6 +108,7 @@ export function createGenericAdapter(platform: string): ConnectionAdapter { return credentials.map((cred) => ({ id: cred.id, + userId: cred.user_id || undefined, platform, platformUserId: cred.platform_user_id, email: cred.platform_email || undefined, diff --git a/packages/lib/services/oauth/oauth-service.ts b/packages/lib/services/oauth/oauth-service.ts index 2d2b8f732..f67c688f4 100644 --- a/packages/lib/services/oauth/oauth-service.ts +++ b/packages/lib/services/oauth/oauth-service.ts @@ -27,6 +27,68 @@ import type { const DEFAULT_REDIRECT = "/dashboard/settings?tab=connections"; const STATE_TTL = 600; // 10 minutes +export function sortConnectionsByRecency(connections: OAuthConnection[]): OAuthConnection[] { + return [...connections].sort((a, b) => { + const aTime = a.lastUsedAt?.getTime() || a.linkedAt.getTime(); + const bTime = b.lastUsedAt?.getTime() || b.linkedAt.getTime(); + return bTime - aTime; + }); +} + +export function getMostRecentActiveConnection( + connections: OAuthConnection[], +): OAuthConnection | null { + const active = connections.filter((c) => c.status === "active"); + if (active.length === 0) return null; + return active.reduce((most, conn) => { + const mostTime = most.lastUsedAt?.getTime() || most.linkedAt.getTime(); + const connTime = conn.lastUsedAt?.getTime() || conn.linkedAt.getTime(); + return connTime > mostTime ? conn : most; + }); +} + +export function getPreferredActiveConnection( + connections: OAuthConnection[], + userId?: string, +): OAuthConnection | null { + if (!userId) { + return getMostRecentActiveConnection(connections); + } + + const ownedConnection = getMostRecentActiveConnection( + connections.filter((connection) => connection.userId === userId), + ); + if (ownedConnection) { + return ownedConnection; + } + + return getMostRecentActiveConnection( + connections.filter((connection) => connection.userId === undefined), + ); +} + +export function scopeConnectionsForUser( + connections: OAuthConnection[], + userId?: string, +): OAuthConnection[] { + if (!userId) { + return sortConnectionsByRecency(connections); + } + + const ownedConnections = sortConnectionsByRecency( + connections.filter((connection) => connection.userId === userId), + ); + const sharedConnections = sortConnectionsByRecency( + connections.filter((connection) => connection.userId === undefined), + ); + + if (ownedConnections.length > 0) { + return [...ownedConnections, ...sharedConnections]; + } + + return sharedConnections; +} + class OAuthService { /** List all available OAuth providers with configuration status */ listProviders(): OAuthProviderInfo[] { @@ -101,7 +163,7 @@ class OAuthService { /** List all OAuth connections for an organization */ async listConnections(params: ListConnectionsParams): Promise { - const { organizationId, platform } = params; + const { organizationId, platform, userId } = params; const adapters = platform ? [getAdapter(platform)].filter(Boolean) : getAllAdapters(); const results = await Promise.allSettled( adapters.map((a) => a!.listConnections(organizationId)), @@ -115,7 +177,7 @@ class OAuthService { return []; }); - return this.sortConnectionsByRecency(connections); + return scopeConnectionsForUser(connections, userId); } /** Get a single connection by ID */ @@ -181,13 +243,13 @@ class OAuthService { async getValidTokenByPlatformWithConnectionId( params: GetTokenByPlatformParams, ): Promise<{ token: TokenResult; connectionId: string }> { - const { organizationId, platform } = params; + const { organizationId, platform, userId } = params; const adapter = getAdapter(platform); if (!adapter) throw Errors.platformNotSupported(platform); const connections = await adapter.listConnections(organizationId); - const activeConnection = this.getMostRecentActive(connections); + const activeConnection = getPreferredActiveConnection(connections, userId); if (!activeConnection) throw Errors.platformNotConnected(platform); const token = await this.getValidToken({ @@ -199,17 +261,21 @@ class OAuthService { } /** Check if a platform has an active connection */ - async isPlatformConnected(organizationId: string, platform: string): Promise { + async isPlatformConnected( + organizationId: string, + platform: string, + userId?: string, + ): Promise { const adapter = getAdapter(platform); if (!adapter) return false; const connections = await adapter.listConnections(organizationId); - return connections.some((c) => c.status === "active"); + return getPreferredActiveConnection(connections, userId) !== null; } /** Get all platforms with active connections */ - async getConnectedPlatforms(organizationId: string): Promise { - const connections = await this.listConnections({ organizationId }); + async getConnectedPlatforms(organizationId: string, userId?: string): Promise { + const connections = await this.listConnections({ organizationId, userId }); return [...new Set(connections.filter((c) => c.status === "active").map((c) => c.platform))]; } @@ -227,23 +293,8 @@ class OAuthService { } return null; } - - private getMostRecentActive(connections: OAuthConnection[]): OAuthConnection | null { - const active = connections.filter((c) => c.status === "active"); - if (active.length === 0) return null; - return active.reduce((most, conn) => { - const mostTime = most.lastUsedAt?.getTime() || most.linkedAt.getTime(); - const connTime = conn.lastUsedAt?.getTime() || conn.linkedAt.getTime(); - return connTime > mostTime ? conn : most; - }); - } - private sortConnectionsByRecency(connections: OAuthConnection[]): OAuthConnection[] { - return connections.sort((a, b) => { - const aTime = a.lastUsedAt?.getTime() || a.linkedAt.getTime(); - const bTime = b.lastUsedAt?.getTime() || b.linkedAt.getTime(); - return bTime - aTime; - }); + return sortConnectionsByRecency(connections); } } diff --git a/packages/lib/services/oauth/types.ts b/packages/lib/services/oauth/types.ts index 999319614..c7a347fb0 100644 --- a/packages/lib/services/oauth/types.ts +++ b/packages/lib/services/oauth/types.ts @@ -35,6 +35,8 @@ export interface OAuthProviderInfo { export interface OAuthConnection { /** Unique connection identifier */ id: string; + /** Cloud user that owns the connection when user-scoped */ + userId?: string; /** Platform identifier (e.g., 'google', 'twitter') */ platform: string; /** User ID on the platform */ @@ -113,6 +115,8 @@ export interface InitiateAuthResult { export interface ListConnectionsParams { /** Organization to list connections for */ organizationId: string; + /** Optional user scope within the organization */ + userId?: string; /** Optional platform filter */ platform?: string; } @@ -133,6 +137,8 @@ export interface GetTokenParams { export interface GetTokenByPlatformParams { /** Organization owning the connection */ organizationId: string; + /** Optional user scope within the organization */ + userId?: string; /** Platform identifier */ platform: string; } diff --git a/packages/services/gateway-discord/tests/gateway-manager.test.ts b/packages/services/gateway-discord/tests/gateway-manager.test.ts index 4dad73651..d0739a841 100644 --- a/packages/services/gateway-discord/tests/gateway-manager.test.ts +++ b/packages/services/gateway-discord/tests/gateway-manager.test.ts @@ -8,10 +8,50 @@ * Integration tests should cover the full flow. */ -import { afterEach, describe, expect, test } from "bun:test"; +import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; + +mock.module("discord.js", () => ({ + Attachment: class MockAttachment {}, + Client: class MockDiscordClient {}, + Events: {}, + GatewayIntentBits: { + Guilds: 1, + GuildMessages: 2, + MessageContent: 4, + GuildMessageReactions: 8, + DirectMessages: 16, + }, + Interaction: class MockInteraction {}, + Message: class MockMessage {}, + MessageFlags: { + IsVoiceMessage: 1 << 13, + }, + MessageReaction: class MockMessageReaction {}, + Partials: {}, + Role: class MockRole {}, +})); + +mock.module("@upstash/redis", () => ({ + Redis: class MockRedis { + static fromEnv() { + return new MockRedis(); + } + }, +})); + +mock.module("hashring", () => ({ + default: class MockHashRing { + constructor(_nodes: string[]) {} + + range(_key: string, _count: number) { + return []; + } + }, +})); // Store original env const originalEnv = { ...process.env }; +const originalFetch = globalThis.fetch; describe("GatewayManager helper functions", () => { describe("parseIntEnv logic", () => { @@ -294,6 +334,136 @@ describe("GatewayManager failover logic", () => { }); }); +describe("GatewayManager managed Milady guild routing", () => { + beforeEach(() => { + globalThis.fetch = originalFetch; + }); + + afterEach(() => { + globalThis.fetch = originalFetch; + }); + + function createMentions(userIds: string[], repliedUserId?: string, everyone = false) { + const users = userIds.map((id) => ({ id })); + return { + users: { + has: (id: string) => users.some((user) => user.id === id), + some: (predicate: (user: { id: string }) => boolean) => users.some(predicate), + }, + repliedUser: repliedUserId ? { id: repliedUserId } : null, + everyone, + }; + } + + async function createGatewayManager() { + const { GatewayManager } = await import("../src/gateway-manager"); + const manager = new GatewayManager({ + podName: "pod-1", + elizaCloudUrl: "https://cloud.example", + gatewayBootstrapSecret: "gateway-secret", + project: "test-project", + }); + (manager as any).accessToken = "internal-jwt"; + (manager as any).elizaAppClient = { + user: { + id: "bot-1", + }, + }; + return manager; + } + + test("ignores managed guild messages that also target another mentioned user", async () => { + const fetchMock = mock(); + globalThis.fetch = fetchMock as typeof fetch; + const manager = await createGatewayManager(); + + await (manager as any).handleManagedMiladyGuildMessage({ + guildId: "guild-1", + channelId: "channel-1", + id: "message-1", + content: "<@bot-1> can you help <@user-2>?", + mentions: createMentions(["bot-1", "user-2"]), + channel: { + sendTyping: mock(), + }, + author: { + id: "discord-user-1", + username: "owner", + globalName: "Owner Person", + displayAvatarURL: () => "https://cdn.discordapp.com/avatar.png", + }, + member: { + displayName: "Owner Person", + }, + reply: mock(), + }); + + expect(fetchMock).not.toHaveBeenCalled(); + }); + + test("routes managed guild messages that mention only the bot", async () => { + const fetchMock = mock(async (_input: RequestInfo | URL, init?: RequestInit) => { + const body = JSON.parse(String(init?.body)); + expect(body).toEqual({ + guildId: "guild-1", + channelId: "channel-1", + messageId: "message-1", + content: "hello there", + sender: { + id: "discord-user-1", + username: "owner", + displayName: "Owner Person", + avatar: "https://cdn.discordapp.com/avatar.png", + }, + }); + return new Response( + JSON.stringify({ + handled: true, + replyText: "hello from agent", + }), + { + status: 200, + headers: { + "Content-Type": "application/json", + }, + }, + ); + }); + globalThis.fetch = fetchMock as typeof fetch; + const manager = await createGatewayManager(); + const sendTyping = mock(async () => {}); + const reply = mock(async () => {}); + + await (manager as any).handleManagedMiladyGuildMessage({ + guildId: "guild-1", + channelId: "channel-1", + id: "message-1", + content: "<@bot-1> hello there", + mentions: createMentions(["bot-1"]), + channel: { + sendTyping, + }, + author: { + id: "discord-user-1", + username: "owner", + globalName: "Owner Person", + displayAvatarURL: () => "https://cdn.discordapp.com/avatar.png", + }, + member: { + displayName: "Owner Person", + }, + reply, + }); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(sendTyping).toHaveBeenCalledTimes(1); + expect(reply).toHaveBeenCalledWith({ + content: "hello from agent", + allowedMentions: { repliedUser: false }, + }); + }); +}); + describe("GatewayManager constants", () => { test("bot poll interval is 30 seconds", () => { const BOT_POLL_INTERVAL_MS = 30_000; diff --git a/packages/tests/e2e/setup-server.ts b/packages/tests/e2e/setup-server.ts index 7860a7349..7e72bb4b0 100644 --- a/packages/tests/e2e/setup-server.ts +++ b/packages/tests/e2e/setup-server.ts @@ -1,6 +1,8 @@ import type { Subprocess } from "bun"; -const SERVER_URL = process.env.TEST_BASE_URL || "http://localhost:3000"; +const TEST_SERVER_PORT = process.env.TEST_SERVER_PORT || "3000"; +const SERVER_URL = process.env.TEST_BASE_URL || `http://localhost:${TEST_SERVER_PORT}`; +const TEST_SERVER_DIST_DIR = process.env.TEST_SERVER_DIST_DIR || `.next-test-${TEST_SERVER_PORT}`; const HEALTH_ENDPOINT = `${SERVER_URL}/api/health`; // Cold Next.js webpack boots can take noticeably longer after large test suites // or when the first request has to compile the health route. @@ -20,6 +22,7 @@ let serverProcess: Subprocess | null = null; let startedServer = false; let serverStartupPromise: Promise | null = null; let serverExitError: Error | null = null; +let detectedPeerServerStartup = false; async function isServerRunning(): Promise { const controller = new AbortController(); @@ -73,6 +76,12 @@ function pipeServerLogs( text.includes("Local:") || text.includes("Error")) ) { + if ( + label === "stderr" && + (text.includes("Unable to acquire lock") || text.includes("EADDRINUSE")) + ) { + detectedPeerServerStartup = true; + } console.log(`[E2E Server:${label}] ${text}`); } } @@ -86,7 +95,7 @@ function pipeServerLogs( } function watchServerExit(process: Subprocess): void { - void process.exited.then((code) => { + void process.exited.then(async (code) => { if (serverProcess !== process) { return; } @@ -97,6 +106,11 @@ function watchServerExit(process: Subprocess): void { } if (code !== 0 && code !== 15) { + await Bun.sleep(250); + if (detectedPeerServerStartup || (await isServerRunning())) { + console.warn("[E2E Server] Detected another worker-owned dev server; waiting for health"); + return; + } serverExitError = new Error(`E2E server exited with code ${code}`); console.error(`[E2E Server] ${serverExitError.message}`); } @@ -122,6 +136,7 @@ async function stopServer(): Promise { startedServer = false; serverStartupPromise = null; serverExitError = null; + detectedPeerServerStartup = false; if (proc) { // Kill the entire process group so child processes (webpack, etc.) also die @@ -146,7 +161,7 @@ async function stopServer(): Promise { // Always wait for the port to be released, even without a process — // something else may still hold the port. - await waitForPortRelease(3000); + await waitForPortRelease(Number(TEST_SERVER_PORT)); } export async function ensureServer(): Promise { @@ -173,17 +188,19 @@ export async function ensureServer(): Promise { } // Ensure the port is free before spawning. - await waitForPortRelease(3000); + await waitForPortRelease(Number(TEST_SERVER_PORT)); startedServer = true; serverExitError = null; + detectedPeerServerStartup = false; serverProcess = Bun.spawn(["bun", "run", TEST_SERVER_SCRIPT], { cwd: process.cwd(), stdio: ["ignore", "pipe", "pipe"], env: { ...process.env, NODE_ENV: "development", - PORT: "3000", + NEXT_DIST_DIR: TEST_SERVER_DIST_DIR, + PORT: TEST_SERVER_PORT, }, }); diff --git a/packages/tests/infrastructure/index.ts b/packages/tests/infrastructure/index.ts index ba16dac9f..ae78dc68a 100644 --- a/packages/tests/infrastructure/index.ts +++ b/packages/tests/infrastructure/index.ts @@ -2,6 +2,8 @@ * Test Infrastructure Exports */ +// Test runtime - direct access to production RuntimeFactory +export { AgentMode } from "../../lib/eliza/agent-mode-types"; // HTTP/SSE test utilities export { assertStreamingOrder, @@ -38,11 +40,9 @@ export { type TestOrganization, type TestUser, } from "./test-data-factory"; -// Test runtime - direct access to production RuntimeFactory export { // Test internals for race condition testing _testing, - AgentMode, buildUserContext, // Test helpers createTestRuntime, diff --git a/packages/tests/infrastructure/test-runtime.ts b/packages/tests/infrastructure/test-runtime.ts index f02523897..85fc4b228 100644 --- a/packages/tests/infrastructure/test-runtime.ts +++ b/packages/tests/infrastructure/test-runtime.ts @@ -11,6 +11,7 @@ import { stringToUuid, type UUID } from "@elizaos/core"; import { v4 as uuidv4 } from "uuid"; import type { DebugRenderView, DebugTrace } from "../../lib/debug"; +import { AgentMode } from "../../lib/eliza/agent-mode-types"; // Import server check - this must complete before tests run import { serverReady } from "../e2e/setup-server"; @@ -31,7 +32,6 @@ export { listDebugTraces, renderDebugTrace, } from "../../lib/debug"; -export { AgentMode } from "../../lib/eliza/agent-mode-types"; // Re-export the production RuntimeFactory directly export { _testing, @@ -43,8 +43,8 @@ export { } from "../../lib/eliza/runtime-factory"; // Re-export types from the production code export type { UserContext } from "../../lib/eliza/user-context"; +export { AgentMode }; -import type { AgentMode } from "../../lib/eliza/agent-mode-types"; // Import for type inference import type { runtimeFactory as RuntimeFactoryType } from "../../lib/eliza/runtime-factory"; import type { UserContext } from "../../lib/eliza/user-context"; @@ -210,7 +210,7 @@ export function buildUserContext( ); } - const mode = options.agentMode || ("ASSISTANT" as AgentMode); + const mode = options.agentMode || AgentMode.ASSISTANT; return { userId: testData.user.id, diff --git a/packages/tests/integration/connection-apis.test.ts b/packages/tests/integration/connection-apis.test.ts index aa6c47c5c..71e6dba27 100644 --- a/packages/tests/integration/connection-apis.test.ts +++ b/packages/tests/integration/connection-apis.test.ts @@ -7,6 +7,7 @@ import { afterAll, beforeAll, describe, expect, it } from "bun:test"; import { Client } from "pg"; +import { apiKeysService } from "@/lib/services/api-keys"; import { blooioAutomationService } from "@/lib/services/blooio-automation"; import { twilioAutomationService } from "@/lib/services/twilio-automation"; import { @@ -361,17 +362,21 @@ describe.skipIf(!TEST_DB_URL)("Connection APIs E2E Tests", () => { }); it("should handle expired API key", async () => { - // Create an expired API key - const expiredKeyId = await client.query( - `INSERT INTO api_keys (id, name, key, key_hash, key_prefix, organization_id, user_id, is_active, expires_at) - VALUES ($1, 'Expired Key', 'ek_expired_123', 'hash123', 'ek_expired_', $2, $3, true, NOW() - INTERVAL '1 day') - RETURNING key`, - ["00000000-0000-0000-0000-000000000001", testData.organization.id, testData.user.id], + const { apiKey: expiredKey, plainKey: expiredKeyValue } = await apiKeysService.create({ + name: `Expired Key ${crypto.randomUUID()}`, + organization_id: testData.organization.id, + user_id: testData.user.id, + is_active: true, + }); + + await client.query( + `UPDATE api_keys SET expires_at = NOW() - INTERVAL '1 day' WHERE id = $1`, + [expiredKey.id], ); const response = await fetch(`${BASE_URL}/api/v1/oauth/connections?platform=google`, { headers: { - Authorization: `Bearer ${expiredKeyId.rows[0].key}`, + Authorization: `Bearer ${expiredKeyValue}`, }, }); @@ -379,9 +384,7 @@ describe.skipIf(!TEST_DB_URL)("Connection APIs E2E Tests", () => { expect([401, 403]).toContain(response.status); // Cleanup - await client.query(`DELETE FROM api_keys WHERE id = $1`, [ - "00000000-0000-0000-0000-000000000001", - ]); + await client.query(`DELETE FROM api_keys WHERE id = $1`, [expiredKey.id]); }); it("should handle empty request body for connect endpoints", async () => { diff --git a/packages/tests/load-env.ts b/packages/tests/load-env.ts index 2fd48ef6f..fe6220a91 100644 --- a/packages/tests/load-env.ts +++ b/packages/tests/load-env.ts @@ -22,10 +22,14 @@ if (process.env.SKIP_DB_DEPENDENT === "1") { } else { const shouldPreferLocalDockerDb = process.env.CI !== "true" && process.env.DISABLE_LOCAL_DOCKER_DB_FALLBACK !== "1"; + const localDockerDatabaseUrl = getLocalDockerDatabaseUrl({ + ...process.env, + LOCAL_DOCKER_DB_HOST: process.env.LOCAL_DOCKER_DB_HOST || "localhost", + }); const testDatabaseUrl = process.env.TEST_DATABASE_URL || - (shouldPreferLocalDockerDb ? getLocalDockerDatabaseUrl(process.env) : process.env.DATABASE_URL); + (shouldPreferLocalDockerDb ? localDockerDatabaseUrl : process.env.DATABASE_URL); if (testDatabaseUrl) { process.env.TEST_DATABASE_URL = testDatabaseUrl; diff --git a/packages/tests/unit/discord-automation-oauth.test.ts b/packages/tests/unit/discord-automation-oauth.test.ts new file mode 100644 index 000000000..bf9a59e2e --- /dev/null +++ b/packages/tests/unit/discord-automation-oauth.test.ts @@ -0,0 +1,235 @@ +import { afterAll, beforeAll, beforeEach, describe, expect, mock, test } from "bun:test"; + +const mockDiscordGuildUpsert = mock(); +const mockDiscordChannelUpsert = mock(); +const mockLogger = { + info: mock(), + warn: mock(), + error: mock(), + debug: mock(), +}; + +mock.module("@/db/repositories/discord-guilds", () => ({ + discordGuildsRepository: { + upsert: mockDiscordGuildUpsert, + }, +})); + +mock.module("@/db/repositories/discord-channels", () => ({ + discordChannelsRepository: { + upsert: mockDiscordChannelUpsert, + }, +})); + +mock.module("@/lib/utils/logger", () => ({ + logger: mockLogger, +})); + +const originalEnv = { ...process.env }; +const originalFetch = globalThis.fetch; + +let discordAutomationService: typeof import("@/lib/services/discord-automation").discordAutomationService; + +function jsonResponse(body: unknown, status = 200): Response { + return new Response(JSON.stringify(body), { + status, + headers: { + "Content-Type": "application/json", + }, + }); +} + +function textResponse(body: string, status = 200): Response { + return new Response(body, { status }); +} + +describe("discordAutomationService.handleBotOAuthCallback", () => { + beforeAll(async () => { + process.env.DISCORD_CLIENT_ID = "discord-app-1"; + process.env.DISCORD_CLIENT_SECRET = "discord-secret"; + process.env.DISCORD_BOT_TOKEN = "discord-bot-token"; + process.env.NEXT_PUBLIC_APP_URL = "https://cloud.example"; + + ({ discordAutomationService } = await import("@/lib/services/discord-automation")); + }); + + beforeEach(() => { + mockDiscordGuildUpsert.mockReset(); + mockDiscordChannelUpsert.mockReset(); + mockLogger.info.mockReset(); + mockLogger.warn.mockReset(); + mockLogger.error.mockReset(); + mockLogger.debug.mockReset(); + }); + + afterAll(() => { + process.env = { ...originalEnv }; + globalThis.fetch = originalFetch; + }); + + test("rejects managed installs when the Discord user does not own the target server", async () => { + const fetchMock = mock(async (input: RequestInfo | URL) => { + const url = String(input); + if (url.endsWith("/oauth2/token")) { + return jsonResponse({ access_token: "oauth-access-token" }); + } + if (url.endsWith("/users/@me")) { + return jsonResponse({ + id: "discord-user-1", + username: "owner", + global_name: "Owner Person", + avatar: null, + }); + } + if (url.endsWith("/users/@me/guilds")) { + return jsonResponse([ + { + id: "guild-1", + name: "Guild One", + icon: null, + owner: false, + permissions: "8", + features: [], + }, + ]); + } + throw new Error(`Unexpected fetch: ${url}`); + }); + globalThis.fetch = fetchMock as typeof fetch; + + const result = await discordAutomationService.handleBotOAuthCallback({ + code: "oauth-code", + guildId: "guild-1", + oauthState: { + organizationId: "org-1", + userId: "user-1", + returnUrl: "https://cloud.example/dashboard/settings?tab=agents", + nonce: "nonce-1", + flow: "milady-managed", + agentId: "agent-1", + }, + }); + + expect(result).toEqual({ + success: false, + error: "Discord account must own the server", + }); + expect(fetchMock).toHaveBeenCalledTimes(3); + expect(mockDiscordGuildUpsert).not.toHaveBeenCalled(); + expect(mockDiscordChannelUpsert).not.toHaveBeenCalled(); + }); + + test("stores the guild, refreshes channels, and applies the requested nickname for successful managed installs", async () => { + const fetchMock = mock(async (input: RequestInfo | URL, init?: RequestInit) => { + const url = String(input); + if (url.endsWith("/oauth2/token")) { + return jsonResponse({ access_token: "oauth-access-token" }); + } + if (url.endsWith("/users/@me")) { + return jsonResponse({ + id: "discord-user-1", + username: "owner", + global_name: "Owner Person", + avatar: "avatar-hash", + }); + } + if (url.endsWith("/users/@me/guilds")) { + return jsonResponse([ + { + id: "guild-1", + name: "Guild One", + icon: "guild-icon", + owner: true, + permissions: "8", + features: ["COMMUNITY"], + }, + ]); + } + if (url.endsWith("/guilds/guild-1")) { + return jsonResponse({ + id: "guild-1", + name: "Guild One", + icon: "guild-icon", + }); + } + if (url.endsWith("/guilds/guild-1/channels")) { + return jsonResponse([ + { + id: "text-1", + name: "general", + type: 0, + parent_id: null, + position: 1, + guild_id: "guild-1", + }, + { + id: "voice-1", + name: "Voice", + type: 2, + parent_id: null, + position: 2, + guild_id: "guild-1", + }, + ]); + } + if (url.endsWith("/guilds/guild-1/members/@me")) { + expect(init?.method).toBe("PATCH"); + expect(init?.body).toBe( + JSON.stringify({ + nick: "Milady Cloud Agent With A Long Na", + }), + ); + return textResponse("", 204); + } + throw new Error(`Unexpected fetch: ${url}`); + }); + globalThis.fetch = fetchMock as typeof fetch; + + const result = await discordAutomationService.handleBotOAuthCallback({ + code: "oauth-code", + guildId: "guild-1", + oauthState: { + organizationId: "org-1", + userId: "user-1", + returnUrl: "https://cloud.example/dashboard/settings?tab=agents", + nonce: "nonce-2", + flow: "milady-managed", + agentId: "agent-1", + botNickname: "Milady Cloud Agent With A Long Name", + }, + }); + + expect(result).toEqual({ + success: true, + guildId: "guild-1", + guildName: "Guild One", + discordUser: { + id: "discord-user-1", + username: "owner", + globalName: "Owner Person", + avatar: "avatar-hash", + }, + }); + expect(mockDiscordGuildUpsert).toHaveBeenCalledWith({ + organization_id: "org-1", + guild_id: "guild-1", + guild_name: "Guild One", + icon_hash: "guild-icon", + owner_id: "discord-user-1", + bot_permissions: "67193856", + }); + expect(mockDiscordChannelUpsert).toHaveBeenCalledTimes(1); + expect(mockDiscordChannelUpsert).toHaveBeenCalledWith({ + organization_id: "org-1", + guild_id: "guild-1", + channel_id: "text-1", + channel_name: "general", + channel_type: 0, + parent_id: null, + position: 1, + can_send_messages: true, + is_nsfw: false, + }); + expect(fetchMock).toHaveBeenCalledTimes(6); + }); +}); diff --git a/packages/tests/unit/eliza-app/connection-enforcement.test.ts b/packages/tests/unit/eliza-app/connection-enforcement.test.ts index df7d10da7..b1b99aa21 100644 --- a/packages/tests/unit/eliza-app/connection-enforcement.test.ts +++ b/packages/tests/unit/eliza-app/connection-enforcement.test.ts @@ -87,14 +87,21 @@ describe("connection enforcement", () => { test("caches required connection checks and invalidates them after OAuth", async () => { mockGetConnectedPlatforms.mockResolvedValueOnce(["google"]).mockResolvedValueOnce([]); - await expect(connectionEnforcementService.hasRequiredConnection("org-1")).resolves.toBe(true); - await expect(connectionEnforcementService.hasRequiredConnection("org-1")).resolves.toBe(true); + await expect( + connectionEnforcementService.hasRequiredConnection("org-1", "user-1"), + ).resolves.toBe(true); + await expect( + connectionEnforcementService.hasRequiredConnection("org-1", "user-1"), + ).resolves.toBe(true); expect(mockGetConnectedPlatforms).toHaveBeenCalledTimes(1); + expect(mockGetConnectedPlatforms).toHaveBeenCalledWith("org-1", "user-1"); - await connectionEnforcementService.invalidateRequiredConnectionCache("org-1"); + await connectionEnforcementService.invalidateRequiredConnectionCache("org-1", "user-1"); - await expect(connectionEnforcementService.hasRequiredConnection("org-1")).resolves.toBe(false); + await expect( + connectionEnforcementService.hasRequiredConnection("org-1", "user-1"), + ).resolves.toBe(false); expect(mockGetConnectedPlatforms).toHaveBeenCalledTimes(2); }); @@ -152,7 +159,7 @@ describe("connection enforcement", () => { }); const storedConversation = Array.from(cacheStore.entries()).find(([key]) => - key.includes("connection-enforcement:conversation:org-3"), + key.includes("connection-enforcement:conversation:org-3:user-3"), )?.[1] as | { messages: Array<{ role: "user" | "assistant"; content: string }>; diff --git a/packages/tests/unit/eliza-app/connection-success-route.test.ts b/packages/tests/unit/eliza-app/connection-success-route.test.ts index 604f5e438..f9e1d0561 100644 --- a/packages/tests/unit/eliza-app/connection-success-route.test.ts +++ b/packages/tests/unit/eliza-app/connection-success-route.test.ts @@ -27,4 +27,20 @@ describe("connection success route", () => { expect(body).toContain("you're connected."); expect(body).toContain("head back to Telegram and send me a message."); }); + + test("renders popup-safe success page for Eliza App OAuth completions", async () => { + const response = await connectionSuccessGet( + new NextRequest( + "https://elizacloud.ai/api/eliza-app/auth/connection-success?source=eliza-app&platform=google&connection_id=conn-123", + ), + ); + + expect(response.status).toBe(200); + expect(response.headers.get("content-type")).toContain("text/html"); + + const body = await response.text(); + expect(body).toContain("Google connected."); + expect(body).toContain("eliza-app-oauth-complete"); + expect(body).toContain("conn-123"); + }); }); diff --git a/packages/tests/unit/eliza-app/connections-route.test.ts b/packages/tests/unit/eliza-app/connections-route.test.ts new file mode 100644 index 000000000..423a84bc5 --- /dev/null +++ b/packages/tests/unit/eliza-app/connections-route.test.ts @@ -0,0 +1,117 @@ +import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { NextRequest } from "next/server"; + +const mockValidateAuthHeader = mock(); +const mockListConnections = mock(); +const mockInitiateAuth = mock(); +const mockGetProvider = mock((platform: string) => + platform === "google" + ? { + id: "google", + name: "Google", + } + : null, +); + +mock.module("@/lib/services/eliza-app", () => ({ + elizaAppSessionService: { + validateAuthHeader: mockValidateAuthHeader, + }, +})); + +mock.module("@/lib/services/oauth", () => ({ + oauthService: { + listConnections: mockListConnections, + initiateAuth: mockInitiateAuth, + }, +})); + +mock.module("@/lib/services/oauth/provider-registry", () => ({ + getProvider: mockGetProvider, +})); + +describe("Eliza App connections routes", () => { + beforeEach(() => { + mockValidateAuthHeader.mockReset(); + mockListConnections.mockReset(); + mockInitiateAuth.mockReset(); + mockGetProvider.mockClear(); + + mockValidateAuthHeader.mockResolvedValue({ + userId: "user-1", + organizationId: "org-1", + }); + }); + + test("returns user-scoped Google connection status", async () => { + const { GET } = await import("@/app/api/eliza-app/connections/route"); + + mockListConnections.mockResolvedValue([ + { + id: "conn-1", + userId: "user-1", + platform: "google", + platformUserId: "google-user-1", + email: "user@example.com", + status: "active", + scopes: ["gmail.send", "calendar.events"], + linkedAt: new Date("2026-04-04T12:00:00Z"), + tokenExpired: false, + source: "platform_credentials", + }, + ]); + + const response = await GET( + new NextRequest("https://elizacloud.ai/api/eliza-app/connections?platform=google", { + headers: { Authorization: "Bearer session-token" }, + }), + ); + + expect(response.status).toBe(200); + expect(mockListConnections).toHaveBeenCalledWith({ + organizationId: "org-1", + userId: "user-1", + platform: "google", + }); + + const json = await response.json(); + expect(json.connected).toBe(true); + expect(json.status).toBe("active"); + expect(json.email).toBe("user@example.com"); + }); + + test("initiates Google OAuth with Eliza App callback bridge", async () => { + const { POST } = await import("@/app/api/eliza-app/connections/[platform]/initiate/route"); + + mockInitiateAuth.mockResolvedValue({ + authUrl: "https://accounts.google.com/o/oauth2/v2/auth?state=test-state", + state: "test-state", + }); + + const response = await POST( + new NextRequest("https://elizacloud.ai/api/eliza-app/connections/google/initiate", { + method: "POST", + headers: { + Authorization: "Bearer session-token", + "Content-Type": "application/json", + }, + body: JSON.stringify({ returnPath: "/connected" }), + }), + { params: Promise.resolve({ platform: "google" }) }, + ); + + expect(response.status).toBe(200); + expect(mockInitiateAuth).toHaveBeenCalledWith({ + organizationId: "org-1", + userId: "user-1", + platform: "google", + redirectUrl: + "/api/eliza-app/auth/connection-success?source=eliza-app&return_path=%2Fconnected", + scopes: undefined, + }); + + const json = await response.json(); + expect(json.authUrl).toContain("accounts.google.com"); + expect(json.provider.name).toBe("Google"); + }); +}); diff --git a/packages/tests/unit/mcp-google-tools.test.ts b/packages/tests/unit/mcp-google-tools.test.ts index 09f678b38..f3a82776c 100644 --- a/packages/tests/unit/mcp-google-tools.test.ts +++ b/packages/tests/unit/mcp-google-tools.test.ts @@ -182,6 +182,11 @@ describe("Google MCP Tools", () => { expect(p.email).toBe("user@test.com"); expect(p.scopes).toContain("gmail.send"); expect(p.linkedAt).toBe("2026-01-01T00:00:00.000Z"); + expect(mockOAuth.listConnections).toHaveBeenCalledWith({ + organizationId: "org-1", + userId: "u-org-1", + platform: "google", + }); }); test("returns connected=false when no active connection", async () => { @@ -229,6 +234,11 @@ describe("Google MCP Tools", () => { expect(p.success).toBe(true); expect(p.messageId).toBe("msg-123"); expect(p.threadId).toBe("thread-456"); + expect(mockOAuth.getValidTokenByPlatform).toHaveBeenCalledWith({ + organizationId: "org-1", + userId: "u-org-1", + platform: "google", + }); }); test("handles CC and BCC recipients", async () => { diff --git a/packages/tests/unit/oauth/oauth-service.test.ts b/packages/tests/unit/oauth/oauth-service.test.ts index 3d13385d9..77383e8ce 100644 --- a/packages/tests/unit/oauth/oauth-service.test.ts +++ b/packages/tests/unit/oauth/oauth-service.test.ts @@ -5,6 +5,10 @@ */ import { describe, expect, it } from "bun:test"; +import { + getPreferredActiveConnection, + scopeConnectionsForUser, +} from "@/lib/services/oauth/oauth-service"; import { OAUTH_PROVIDERS } from "@/lib/services/oauth/provider-registry"; import type { OAuthConnection, OAuthProviderInfo } from "@/lib/services/oauth/types"; @@ -281,6 +285,51 @@ describe("OAuth Service Logic", () => { expect(activePlatforms.length).toBe(2); }); }); + + describe("User-scoped connection selection", () => { + it("should prefer user-owned connections before shared org connections", () => { + const connections: OAuthConnection[] = [ + createMockConnection("active", "shared", { + linkedAt: new Date("2026-01-01T00:00:00Z"), + }), + createMockConnection("active", "owned", { + userId: "user-1", + linkedAt: new Date("2025-01-01T00:00:00Z"), + }), + createMockConnection("active", "other-user", { + userId: "user-2", + linkedAt: new Date("2027-01-01T00:00:00Z"), + }), + ]; + + const preferred = getPreferredActiveConnection(connections, "user-1"); + + expect(preferred?.id).toBe("owned"); + }); + + it("should exclude other users while still exposing shared connections", () => { + const connections: OAuthConnection[] = [ + createMockConnection("active", "shared", { + linkedAt: new Date("2026-01-01T00:00:00Z"), + }), + createMockConnection("active", "owned", { + userId: "user-1", + linkedAt: new Date("2025-01-01T00:00:00Z"), + }), + createMockConnection("active", "other-user", { + userId: "user-2", + linkedAt: new Date("2027-01-01T00:00:00Z"), + }), + ]; + + const scoped = scopeConnectionsForUser(connections, "user-1"); + + expect(scoped.map((connection: OAuthConnection) => connection.id)).toEqual([ + "owned", + "shared", + ]); + }); + }); }); // Helper function to create mock connections diff --git a/tsconfig.json b/tsconfig.json index e3135eb90..53eba3229 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,7 +1,11 @@ { "compilerOptions": { "target": "ES2020", - "lib": ["dom", "dom.iterable", "esnext"], + "lib": [ + "dom", + "dom.iterable", + "esnext" + ], "allowJs": true, "skipLibCheck": true, "strict": true, @@ -14,21 +18,39 @@ "jsx": "react-jsx", "incremental": true, "tsBuildInfoFile": ".tsbuildinfo", - "types": ["node"], + "types": [ + "node" + ], "plugins": [ { "name": "next" } ], "paths": { - "@/lib/*": ["./packages/lib/*"], - "@/db/*": ["./packages/db/*"], - "@/tests/*": ["./packages/tests/*"], - "@/types/*": ["./packages/types/*"], - "@/*": ["./*"], - "@/components/*": ["./packages/ui/src/components/*"], - "@elizaos/cloud-ui": ["./packages/ui/src/index.ts"], - "@elizaos/cloud-ui/*": ["./packages/ui/src/*"] + "@/lib/*": [ + "./packages/lib/*" + ], + "@/db/*": [ + "./packages/db/*" + ], + "@/tests/*": [ + "./packages/tests/*" + ], + "@/types/*": [ + "./packages/types/*" + ], + "@/*": [ + "./*" + ], + "@/components/*": [ + "./packages/ui/src/components/*" + ], + "@elizaos/cloud-ui": [ + "./packages/ui/src/index.ts" + ], + "@elizaos/cloud-ui/*": [ + "./packages/ui/src/*" + ] } }, "include": [ @@ -46,7 +68,13 @@ ".next-build/types/**/*.ts", ".next-build/dev/types/**/*.ts", ".next-dev/types/**/*.ts", - ".next-dev/dev/types/**/*.ts" + ".next-dev/dev/types/**/*.ts", + ".next-test-*/types/**/*.ts", + ".next-test-*/dev/types/**/*.ts", + ".next-test-3303/types/**/*.ts", + ".next-test-3303/dev/types/**/*.ts", + ".next-test-3300/types/**/*.ts", + ".next-test-3300/dev/types/**/*.ts" ], "exclude": [ "node_modules", From ebeec227fcac50991c78fbf48aad522d69ad6b47 Mon Sep 17 00:00:00 2001 From: Shaw Date: Sun, 5 Apr 2026 01:59:17 -0700 Subject: [PATCH 07/11] cloud: separate owner and agent Google connections --- .../v1/milady/google/calendar/events/route.ts | 2 + .../v1/milady/google/calendar/feed/route.ts | 5 ++ .../milady/google/connect/initiate/route.ts | 2 + app/api/v1/milady/google/disconnect/route.ts | 2 + .../milady/google/gmail/reply-send/route.ts | 2 + .../v1/milady/google/gmail/triage/route.ts | 5 ++ app/api/v1/milady/google/status/route.ts | 5 ++ .../lib/services/milady-google-connector.ts | 40 ++++++++- .../connection-adapters/generic-adapter.ts | 6 ++ packages/lib/services/oauth/oauth-service.ts | 35 +++++--- .../lib/services/oauth/providers/oauth2.ts | 35 ++++++-- packages/lib/services/oauth/types.ts | 10 +++ .../unit/milady-google-connector.test.ts | 85 +++++++++++++++++-- .../tests/unit/milady-google-routes.test.ts | 43 +++++++++- 14 files changed, 246 insertions(+), 31 deletions(-) diff --git a/app/api/v1/milady/google/calendar/events/route.ts b/app/api/v1/milady/google/calendar/events/route.ts index fb5f71dc1..63115cc07 100644 --- a/app/api/v1/milady/google/calendar/events/route.ts +++ b/app/api/v1/milady/google/calendar/events/route.ts @@ -17,6 +17,7 @@ const attendeeSchema = z.object({ }); const requestSchema = z.object({ + side: z.enum(["owner", "agent"]).optional(), calendarId: z.string().trim().min(1).optional(), title: z.string().trim().min(1), description: z.string().optional(), @@ -42,6 +43,7 @@ export async function POST(request: NextRequest) { await createManagedGoogleCalendarEvent({ organizationId: user.organization_id, userId: user.id, + side: parsed.data.side ?? "owner", calendarId: parsed.data.calendarId ?? "primary", title: parsed.data.title, description: parsed.data.description, diff --git a/app/api/v1/milady/google/calendar/feed/route.ts b/app/api/v1/milady/google/calendar/feed/route.ts index e6e1ce051..1bcce07d9 100644 --- a/app/api/v1/milady/google/calendar/feed/route.ts +++ b/app/api/v1/milady/google/calendar/feed/route.ts @@ -13,11 +13,15 @@ export async function GET(request: NextRequest) { try { const { user } = await requireAuthOrApiKeyWithOrg(request); const searchParams = request.nextUrl.searchParams; + const rawSide = searchParams.get("side"); const calendarId = searchParams.get("calendarId")?.trim() || "primary"; const timeMin = searchParams.get("timeMin")?.trim(); const timeMax = searchParams.get("timeMax")?.trim(); const timeZone = searchParams.get("timeZone")?.trim() || "UTC"; + if (rawSide !== null && rawSide !== "owner" && rawSide !== "agent") { + return NextResponse.json({ error: "side must be owner or agent." }, { status: 400 }); + } if (!timeMin || !timeMax) { return NextResponse.json({ error: "timeMin and timeMax are required." }, { status: 400 }); } @@ -26,6 +30,7 @@ export async function GET(request: NextRequest) { await fetchManagedGoogleCalendarFeed({ organizationId: user.organization_id, userId: user.id, + side: rawSide === "agent" ? "agent" : "owner", calendarId, timeMin, timeMax, diff --git a/app/api/v1/milady/google/connect/initiate/route.ts b/app/api/v1/milady/google/connect/initiate/route.ts index 3c7410d99..dc8b15d59 100644 --- a/app/api/v1/milady/google/connect/initiate/route.ts +++ b/app/api/v1/milady/google/connect/initiate/route.ts @@ -11,6 +11,7 @@ export const dynamic = "force-dynamic"; export const maxDuration = 30; const requestSchema = z.object({ + side: z.enum(["owner", "agent"]).optional(), redirectUrl: z.string().trim().min(1).optional(), capabilities: z .array( @@ -39,6 +40,7 @@ export async function POST(request: NextRequest) { await initiateManagedGoogleConnection({ organizationId: user.organization_id, userId: user.id, + side: parsed.data.side ?? "owner", redirectUrl: parsed.data.redirectUrl, capabilities: parsed.data.capabilities, }), diff --git a/app/api/v1/milady/google/disconnect/route.ts b/app/api/v1/milady/google/disconnect/route.ts index b0c411e15..4083aa1a3 100644 --- a/app/api/v1/milady/google/disconnect/route.ts +++ b/app/api/v1/milady/google/disconnect/route.ts @@ -11,6 +11,7 @@ export const dynamic = "force-dynamic"; export const maxDuration = 30; const requestSchema = z.object({ + side: z.enum(["owner", "agent"]).optional(), connectionId: z.string().uuid().nullable().optional(), }); @@ -28,6 +29,7 @@ export async function POST(request: NextRequest) { await disconnectManagedGoogleConnection({ organizationId: user.organization_id, userId: user.id, + side: parsed.data.side ?? "owner", connectionId: parsed.data.connectionId ?? null, }); return NextResponse.json({ ok: true }); diff --git a/app/api/v1/milady/google/gmail/reply-send/route.ts b/app/api/v1/milady/google/gmail/reply-send/route.ts index 81d1dfa2c..90a691bea 100644 --- a/app/api/v1/milady/google/gmail/reply-send/route.ts +++ b/app/api/v1/milady/google/gmail/reply-send/route.ts @@ -11,6 +11,7 @@ export const dynamic = "force-dynamic"; export const maxDuration = 30; const requestSchema = z.object({ + side: z.enum(["owner", "agent"]).optional(), to: z.array(z.string().email()).min(1), cc: z.array(z.string().email()).optional(), subject: z.string().trim().min(1), @@ -33,6 +34,7 @@ export async function POST(request: NextRequest) { await sendManagedGoogleReply({ organizationId: user.organization_id, userId: user.id, + side: parsed.data.side ?? "owner", to: parsed.data.to, cc: parsed.data.cc, subject: parsed.data.subject, diff --git a/app/api/v1/milady/google/gmail/triage/route.ts b/app/api/v1/milady/google/gmail/triage/route.ts index f94c52e14..763d89205 100644 --- a/app/api/v1/milady/google/gmail/triage/route.ts +++ b/app/api/v1/milady/google/gmail/triage/route.ts @@ -12,7 +12,11 @@ export const maxDuration = 30; export async function GET(request: NextRequest) { try { const { user } = await requireAuthOrApiKeyWithOrg(request); + const rawSide = request.nextUrl.searchParams.get("side"); const rawMaxResults = request.nextUrl.searchParams.get("maxResults"); + if (rawSide !== null && rawSide !== "owner" && rawSide !== "agent") { + return NextResponse.json({ error: "side must be owner or agent." }, { status: 400 }); + } const maxResults = rawMaxResults && rawMaxResults.trim().length > 0 ? Number.parseInt(rawMaxResults, 10) : 12; if (!Number.isFinite(maxResults) || maxResults <= 0) { @@ -26,6 +30,7 @@ export async function GET(request: NextRequest) { await fetchManagedGoogleGmailTriage({ organizationId: user.organization_id, userId: user.id, + side: rawSide === "agent" ? "agent" : "owner", maxResults, }), ); diff --git a/app/api/v1/milady/google/status/route.ts b/app/api/v1/milady/google/status/route.ts index 7bf93c3a7..96c219241 100644 --- a/app/api/v1/milady/google/status/route.ts +++ b/app/api/v1/milady/google/status/route.ts @@ -12,10 +12,15 @@ export const maxDuration = 30; export async function GET(request: NextRequest) { try { const { user } = await requireAuthOrApiKeyWithOrg(request); + const rawSide = request.nextUrl.searchParams.get("side"); + if (rawSide !== null && rawSide !== "owner" && rawSide !== "agent") { + return NextResponse.json({ error: "side must be owner or agent." }, { status: 400 }); + } return NextResponse.json( await getManagedGoogleConnectorStatus({ organizationId: user.organization_id, userId: user.id, + side: rawSide === "agent" ? "agent" : "owner", }), ); } catch (error) { diff --git a/packages/lib/services/milady-google-connector.ts b/packages/lib/services/milady-google-connector.ts index d7ae4bff0..e62eb8063 100644 --- a/packages/lib/services/milady-google-connector.ts +++ b/packages/lib/services/milady-google-connector.ts @@ -4,6 +4,7 @@ import { platformCredentials } from "@/db/schemas/platform-credentials"; import { oauthService } from "@/lib/services/oauth"; import { getPreferredActiveConnection } from "@/lib/services/oauth/oauth-service"; import { getProvider, isProviderConfigured } from "@/lib/services/oauth/provider-registry"; +import type { OAuthConnectionRole } from "@/lib/services/oauth/types"; import { applyTimeZone, googleFetchWithToken, @@ -40,6 +41,7 @@ export type MiladyGoogleCapability = export interface ManagedGoogleConnectorStatus { provider: "google"; + side: OAuthConnectionRole; mode: "cloud_managed"; configured: boolean; connected: boolean; @@ -267,17 +269,26 @@ async function getConnectionRow( return row ?? null; } -async function getScopedGoogleConnections(args: { organizationId: string; userId: string }) { +async function getScopedGoogleConnections(args: { + organizationId: string; + userId: string; + side: OAuthConnectionRole; +}) { return oauthService.listConnections({ organizationId: args.organizationId, userId: args.userId, platform: "google", + connectionRole: args.side, }); } -async function getActiveGoogleConnectionRecord(args: { organizationId: string; userId: string }) { +async function getActiveGoogleConnectionRecord(args: { + organizationId: string; + userId: string; + side: OAuthConnectionRole; +}) { const connections = await getScopedGoogleConnections(args); - const activeConnection = getPreferredActiveConnection(connections, args.userId); + const activeConnection = getPreferredActiveConnection(connections, args.userId, args.side); const latestConnection = connections[0] ?? null; const activeRow = activeConnection ? await getConnectionRow(args.organizationId, activeConnection.id) @@ -299,6 +310,7 @@ async function getActiveGoogleConnectionRecord(args: { organizationId: string; u async function getGoogleAccessToken(args: { organizationId: string; userId: string; + side: OAuthConnectionRole; }): Promise<{ accessToken: string; connectionId: string }> { try { return await oauthService @@ -306,6 +318,7 @@ async function getGoogleAccessToken(args: { organizationId: args.organizationId, userId: args.userId, platform: "google", + connectionRole: args.side, }) .then((result) => ({ accessToken: result.token.accessToken, @@ -320,6 +333,7 @@ async function getGoogleAccessToken(args: { async function googleFetch(args: { organizationId: string; userId: string; + side: OAuthConnectionRole; url: string; options?: RequestInit; }): Promise { @@ -653,6 +667,7 @@ function normalizeReplySubject(subject: string): string { export async function getManagedGoogleConnectorStatus(args: { organizationId: string; userId: string; + side: OAuthConnectionRole; }): Promise { const provider = getProvider("google"); const configured = provider ? isProviderConfigured(provider) : false; @@ -660,6 +675,7 @@ export async function getManagedGoogleConnectorStatus(args: { if (!configured) { return { provider: "google", + side: args.side, mode: "cloud_managed", configured: false, connected: false, @@ -683,6 +699,7 @@ export async function getManagedGoogleConnectorStatus(args: { if (!currentConnection) { return { provider: "google", + side: args.side, mode: "cloud_managed", configured: true, connected: false, @@ -707,6 +724,7 @@ export async function getManagedGoogleConnectorStatus(args: { return { provider: "google", + side: args.side, mode: "cloud_managed", configured: true, connected, @@ -730,6 +748,7 @@ export async function getManagedGoogleConnectorStatus(args: { export async function initiateManagedGoogleConnection(args: { organizationId: string; userId: string; + side: OAuthConnectionRole; redirectUrl?: string; capabilities?: MiladyGoogleCapability[]; }) { @@ -740,9 +759,11 @@ export async function initiateManagedGoogleConnection(args: { platform: "google", redirectUrl: args.redirectUrl, scopes: capabilitiesToScopes(requestedCapabilities), + connectionRole: args.side, }); return { provider: "google" as const, + side: args.side, mode: "cloud_managed" as const, requestedCapabilities, redirectUri: args.redirectUrl ?? "/auth/success?platform=google", @@ -753,13 +774,14 @@ export async function initiateManagedGoogleConnection(args: { export async function disconnectManagedGoogleConnection(args: { organizationId: string; userId: string; + side: OAuthConnectionRole; connectionId?: string | null; }): Promise { const connections = await getScopedGoogleConnections(args); const activeConnection = (args.connectionId ? connections.find((connection) => connection.id === args.connectionId) - : getPreferredActiveConnection(connections, args.userId)) ?? + : getPreferredActiveConnection(connections, args.userId, args.side)) ?? connections[0] ?? null; if (!activeConnection) { @@ -774,6 +796,7 @@ export async function disconnectManagedGoogleConnection(args: { export async function fetchManagedGoogleCalendarFeed(args: { organizationId: string; userId: string; + side: OAuthConnectionRole; calendarId: string; timeMin: string; timeMax: string; @@ -794,6 +817,7 @@ export async function fetchManagedGoogleCalendarFeed(args: { const response = await googleFetch({ organizationId: args.organizationId, userId: args.userId, + side: args.side, url: `${GOOGLE_CALENDAR_EVENTS_ENDPOINT}/${encodeURIComponent(args.calendarId)}/events?${params.toString()}`, }); const parsed = (await response.json()) as { items?: GoogleCalendarApiEvent[] }; @@ -809,6 +833,7 @@ export async function fetchManagedGoogleCalendarFeed(args: { export async function createManagedGoogleCalendarEvent(args: { organizationId: string; userId: string; + side: OAuthConnectionRole; calendarId: string; title: string; description?: string; @@ -825,6 +850,7 @@ export async function createManagedGoogleCalendarEvent(args: { const response = await googleFetch({ organizationId: args.organizationId, userId: args.userId, + side: args.side, url: `${GOOGLE_CALENDAR_EVENTS_ENDPOINT}/${encodeURIComponent(args.calendarId)}/events?conferenceDataVersion=1`, options: { method: "POST", @@ -850,12 +876,14 @@ export async function createManagedGoogleCalendarEvent(args: { export async function fetchManagedGoogleGmailTriage(args: { organizationId: string; userId: string; + side: OAuthConnectionRole; maxResults: number; }): Promise<{ messages: ManagedGoogleGmailMessage[]; syncedAt: string }> { const maxResults = Math.min(Math.max(args.maxResults, 1), 50); const connectorStatus = await getManagedGoogleConnectorStatus({ organizationId: args.organizationId, userId: args.userId, + side: args.side, }); const selfEmail = connectorStatus.identity && typeof connectorStatus.identity.email === "string" @@ -870,6 +898,7 @@ export async function fetchManagedGoogleGmailTriage(args: { const listResponse = await googleFetch({ organizationId: args.organizationId, userId: args.userId, + side: args.side, url: `${GOOGLE_GMAIL_MESSAGES_ENDPOINT}?${listParams.toString()}`, }); const listed = (await listResponse.json()) as GoogleGmailListResponse; @@ -885,6 +914,7 @@ export async function fetchManagedGoogleGmailTriage(args: { const response = await googleFetch({ organizationId: args.organizationId, userId: args.userId, + side: args.side, url: `${GOOGLE_GMAIL_MESSAGES_ENDPOINT}/${encodeURIComponent(messageId)}?${params.toString()}`, }); const parsed = (await response.json()) as GoogleGmailMetadataResponse; @@ -907,6 +937,7 @@ export async function fetchManagedGoogleGmailTriage(args: { export async function sendManagedGoogleReply(args: { organizationId: string; userId: string; + side: OAuthConnectionRole; to: string[]; cc?: string[]; subject: string; @@ -930,6 +961,7 @@ export async function sendManagedGoogleReply(args: { await googleFetch({ organizationId: args.organizationId, userId: args.userId, + side: args.side, url: GOOGLE_GMAIL_SEND_ENDPOINT, options: { method: "POST", diff --git a/packages/lib/services/oauth/connection-adapters/generic-adapter.ts b/packages/lib/services/oauth/connection-adapters/generic-adapter.ts index c077b47e1..54fe57ee0 100644 --- a/packages/lib/services/oauth/connection-adapters/generic-adapter.ts +++ b/packages/lib/services/oauth/connection-adapters/generic-adapter.ts @@ -109,6 +109,12 @@ export function createGenericAdapter(platform: string): ConnectionAdapter { return credentials.map((cred) => ({ id: cred.id, userId: cred.user_id || undefined, + connectionRole: + cred.source_context && + typeof cred.source_context === "object" && + (cred.source_context as Record).miladyGoogleSide === "agent" + ? "agent" + : "owner", platform, platformUserId: cred.platform_user_id, email: cred.platform_email || undefined, diff --git a/packages/lib/services/oauth/oauth-service.ts b/packages/lib/services/oauth/oauth-service.ts index f67c688f4..392166225 100644 --- a/packages/lib/services/oauth/oauth-service.ts +++ b/packages/lib/services/oauth/oauth-service.ts @@ -20,6 +20,7 @@ import type { InitiateAuthResult, ListConnectionsParams, OAuthConnection, + OAuthConnectionRole, OAuthProviderInfo, TokenResult, } from "./types"; @@ -50,36 +51,44 @@ export function getMostRecentActiveConnection( export function getPreferredActiveConnection( connections: OAuthConnection[], userId?: string, + connectionRole?: OAuthConnectionRole, ): OAuthConnection | null { + const scopedByRole = connectionRole + ? connections.filter((connection) => connection.connectionRole === connectionRole) + : connections; if (!userId) { - return getMostRecentActiveConnection(connections); + return getMostRecentActiveConnection(scopedByRole); } const ownedConnection = getMostRecentActiveConnection( - connections.filter((connection) => connection.userId === userId), + scopedByRole.filter((connection) => connection.userId === userId), ); if (ownedConnection) { return ownedConnection; } return getMostRecentActiveConnection( - connections.filter((connection) => connection.userId === undefined), + scopedByRole.filter((connection) => connection.userId === undefined), ); } export function scopeConnectionsForUser( connections: OAuthConnection[], userId?: string, + connectionRole?: OAuthConnectionRole, ): OAuthConnection[] { + const scopedByRole = connectionRole + ? connections.filter((connection) => connection.connectionRole === connectionRole) + : connections; if (!userId) { - return sortConnectionsByRecency(connections); + return sortConnectionsByRecency(scopedByRole); } const ownedConnections = sortConnectionsByRecency( - connections.filter((connection) => connection.userId === userId), + scopedByRole.filter((connection) => connection.userId === userId), ); const sharedConnections = sortConnectionsByRecency( - connections.filter((connection) => connection.userId === undefined), + scopedByRole.filter((connection) => connection.userId === undefined), ); if (ownedConnections.length > 0) { @@ -104,7 +113,7 @@ class OAuthService { /** Initiate OAuth flow for a platform */ async initiateAuth(params: InitiateAuthParams): Promise { - const { organizationId, userId, platform, redirectUrl, scopes } = params; + const { organizationId, userId, platform, redirectUrl, scopes, connectionRole } = params; const provider = getProvider(platform); if (!provider) throw Errors.platformNotSupported(platform); @@ -122,6 +131,7 @@ class OAuthService { userId, redirectUrl, scopes, + connectionRole, }); return { authUrl: result.authUrl, state: result.state }; } @@ -163,7 +173,7 @@ class OAuthService { /** List all OAuth connections for an organization */ async listConnections(params: ListConnectionsParams): Promise { - const { organizationId, platform, userId } = params; + const { organizationId, platform, userId, connectionRole } = params; const adapters = platform ? [getAdapter(platform)].filter(Boolean) : getAllAdapters(); const results = await Promise.allSettled( adapters.map((a) => a!.listConnections(organizationId)), @@ -177,7 +187,7 @@ class OAuthService { return []; }); - return scopeConnectionsForUser(connections, userId); + return scopeConnectionsForUser(connections, userId, connectionRole); } /** Get a single connection by ID */ @@ -243,13 +253,13 @@ class OAuthService { async getValidTokenByPlatformWithConnectionId( params: GetTokenByPlatformParams, ): Promise<{ token: TokenResult; connectionId: string }> { - const { organizationId, platform, userId } = params; + const { organizationId, platform, userId, connectionRole } = params; const adapter = getAdapter(platform); if (!adapter) throw Errors.platformNotSupported(platform); const connections = await adapter.listConnections(organizationId); - const activeConnection = getPreferredActiveConnection(connections, userId); + const activeConnection = getPreferredActiveConnection(connections, userId, connectionRole); if (!activeConnection) throw Errors.platformNotConnected(platform); const token = await this.getValidToken({ @@ -265,12 +275,13 @@ class OAuthService { organizationId: string, platform: string, userId?: string, + connectionRole?: OAuthConnectionRole, ): Promise { const adapter = getAdapter(platform); if (!adapter) return false; const connections = await adapter.listConnections(organizationId); - return getPreferredActiveConnection(connections, userId) !== null; + return getPreferredActiveConnection(connections, userId, connectionRole) !== null; } /** Get all platforms with active connections */ diff --git a/packages/lib/services/oauth/providers/oauth2.ts b/packages/lib/services/oauth/providers/oauth2.ts index 813284ba9..1a9611333 100644 --- a/packages/lib/services/oauth/providers/oauth2.ts +++ b/packages/lib/services/oauth/providers/oauth2.ts @@ -14,6 +14,7 @@ import { secretsService } from "@/lib/services/secrets"; import { logger } from "@/lib/utils/logger"; import type { OAuthProviderConfig, UserInfoMapping } from "../provider-registry"; import { getCallbackUrl, getClientId, getClientSecret, getNestedValue } from "../provider-registry"; +import type { OAuthConnectionRole } from "../types"; const STATE_TTL_SECONDS = 600; // 10 minutes @@ -26,6 +27,7 @@ interface OAuth2State { providerId: string; redirectUrl: string; scopes: string[]; + connectionRole?: OAuthConnectionRole; createdAt: number; codeVerifier?: string; } @@ -103,6 +105,7 @@ export async function initiateOAuth2( userId: string; redirectUrl?: string; scopes?: string[]; + connectionRole?: OAuthConnectionRole; }, ): Promise { const clientId = getClientId(provider); @@ -138,6 +141,7 @@ export async function initiateOAuth2( providerId: provider.id, redirectUrl, scopes, + connectionRole: params.connectionRole, createdAt: Date.now(), codeVerifier, }; @@ -207,6 +211,7 @@ export async function handleOAuth2Callback( await cache.del(stateKey); const { organizationId, userId, redirectUrl, scopes, codeVerifier } = stateData; + const connectionRole = stateData.connectionRole ?? "owner"; // Exchange code for tokens const tokens = await exchangeCodeForTokens(provider, code, codeVerifier); @@ -226,6 +231,7 @@ export async function handleOAuth2Callback( provider, organizationId, userId, + connectionRole, tokens, userInfo, scopes, @@ -562,10 +568,12 @@ async function storeConnection( provider: OAuthProviderConfig, organizationId: string, userId: string, + connectionRole: OAuthConnectionRole, tokens: TokenResponse, userInfo: ExtractedUserInfo, scopes: string[], ): Promise { + const connectionUserId = connectionRole === "agent" ? null : userId; const audit = { actorType: "user" as const, actorId: userId, @@ -752,11 +760,18 @@ async function storeConnection( }) .from(platformCredentials) .where( - and( - eq(platformCredentials.organization_id, organizationId), - eq(platformCredentials.user_id, userId), - eq(platformCredentials.platform, providerPlatform), - ), + connectionUserId + ? and( + eq(platformCredentials.organization_id, organizationId), + eq(platformCredentials.user_id, connectionUserId), + eq(platformCredentials.platform, providerPlatform), + ) + : and( + eq(platformCredentials.organization_id, organizationId), + sql`${platformCredentials.user_id} IS NULL`, + eq(platformCredentials.platform, providerPlatform), + sql`COALESCE(${platformCredentials.source_context}->>'miladyGoogleSide', 'owner') = ${connectionRole}`, + ), ) .orderBy( desc( @@ -947,7 +962,7 @@ async function storeConnection( .insert(platformCredentials) .values({ organization_id: organizationId, - user_id: userId, + user_id: connectionUserId, platform: providerPlatform, platform_user_id: userInfo.id, platform_username: userInfo.username || undefined, @@ -961,6 +976,9 @@ async function storeConnection( scopes, profile_data: userInfo.raw, source_type: "web", + source_context: { + miladyGoogleSide: connectionRole, + }, linked_at: new Date(), }) .onConflictDoUpdate({ @@ -971,7 +989,7 @@ async function storeConnection( ], setWhere: sql`${platformCredentials.user_id} IS NULL OR ${platformCredentials.user_id} = ${userId}`, set: { - user_id: userId, + user_id: connectionUserId, platform_username: userInfo.username || undefined, platform_display_name: userInfo.displayName || undefined, platform_avatar_url: userInfo.avatarUrl || undefined, @@ -982,6 +1000,9 @@ async function storeConnection( token_expires_at: tokenExpiresAt, scopes, profile_data: userInfo.raw, + source_context: { + miladyGoogleSide: connectionRole, + }, linked_at: new Date(), updated_at: new Date(), }, diff --git a/packages/lib/services/oauth/types.ts b/packages/lib/services/oauth/types.ts index c7a347fb0..7edd3afca 100644 --- a/packages/lib/services/oauth/types.ts +++ b/packages/lib/services/oauth/types.ts @@ -11,6 +11,8 @@ export type OAuthConnectionStatus = "pending" | "active" | "expired" | "revoked" export type OAuthConnectionSource = "platform_credentials" | "secrets"; +export type OAuthConnectionRole = "owner" | "agent"; + /** * Provider information returned by the list providers endpoint. */ @@ -37,6 +39,8 @@ export interface OAuthConnection { id: string; /** Cloud user that owns the connection when user-scoped */ userId?: string; + /** Logical role Milady uses for the connection */ + connectionRole?: OAuthConnectionRole; /** Platform identifier (e.g., 'google', 'twitter') */ platform: string; /** User ID on the platform */ @@ -95,6 +99,8 @@ export interface InitiateAuthParams { redirectUrl?: string; /** Specific scopes to request (overrides defaults) */ scopes?: string[]; + /** Logical Milady-side role for the connection */ + connectionRole?: OAuthConnectionRole; } /** @@ -119,6 +125,8 @@ export interface ListConnectionsParams { userId?: string; /** Optional platform filter */ platform?: string; + /** Optional logical role filter */ + connectionRole?: OAuthConnectionRole; } /** @@ -141,6 +149,8 @@ export interface GetTokenByPlatformParams { userId?: string; /** Platform identifier */ platform: string; + /** Optional logical role filter */ + connectionRole?: OAuthConnectionRole; } /** diff --git a/packages/tests/unit/milady-google-connector.test.ts b/packages/tests/unit/milady-google-connector.test.ts index 8515ffd1f..c9933dc3e 100644 --- a/packages/tests/unit/milady-google-connector.test.ts +++ b/packages/tests/unit/milady-google-connector.test.ts @@ -1,6 +1,10 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import type { OAuthConnection } from "@/lib/services/oauth/types"; +afterAll(() => { + mock.restore(); +}); + const mockListConnections = mock(); const mockGetValidTokenByPlatformWithConnectionId = mock(); const mockInitiateAuth = mock(); @@ -41,9 +45,21 @@ mock.module("@/lib/services/oauth", () => ({ })); mock.module("@/lib/services/oauth/oauth-service", () => ({ - getPreferredActiveConnection: (connections: OAuthConnection[], userId?: string) => + getPreferredActiveConnection: ( + connections: OAuthConnection[], + userId?: string, + connectionRole?: "owner" | "agent", + ) => + connections.find( + (connection) => + connection.status === "active" && + (!userId || connection.userId === userId) && + (!connectionRole || connection.connectionRole === connectionRole), + ) ?? connections.find( - (connection) => connection.status === "active" && (!userId || connection.userId === userId), + (connection) => + connection.status === "active" && + (!connectionRole || connection.connectionRole === connectionRole), ) ?? connections.find((connection) => connection.status === "active") ?? null, @@ -74,6 +90,7 @@ function createConnection(overrides: Partial = {}): OAuthConnec return { id: "conn-google-1", userId: "user-1", + connectionRole: "owner", platform: "google", platformUserId: "google-user-1", email: "founder@example.com", @@ -127,16 +144,18 @@ describe("milady Google connector service", () => { }); }); - test("reports managed Google connector status from the active user-scoped connection", async () => { + test("reports managed Google connector status from the active owner connection", async () => { mockListConnections.mockResolvedValue([createConnection()]); const status = await getManagedGoogleConnectorStatus({ organizationId: "org-1", userId: "user-1", + side: "owner", }); expect(status).toEqual({ provider: "google", + side: "owner", mode: "cloud_managed", configured: true, connected: true, @@ -167,6 +186,46 @@ describe("milady Google connector service", () => { linkedAt: "2026-04-04T15:00:00.000Z", lastUsedAt: "2026-04-04T16:00:00.000Z", }); + expect(mockListConnections).toHaveBeenCalledWith({ + organizationId: "org-1", + userId: "user-1", + platform: "google", + connectionRole: "owner", + }); + }); + + test("reports managed Google connector status from the agent-side connection", async () => { + mockListConnections.mockResolvedValue([ + createConnection({ + id: "conn-google-agent", + userId: null, + connectionRole: "agent", + email: "milady-agent@example.com", + username: "milady-agent", + displayName: "Milady Agent", + }), + ]); + + const status = await getManagedGoogleConnectorStatus({ + organizationId: "org-1", + userId: "user-1", + side: "agent", + }); + + expect(status.side).toBe("agent"); + expect(status.connectionId).toBe("conn-google-agent"); + expect(status.identity).toEqual({ + id: "google-user-1", + email: "milady-agent@example.com", + name: "Milady Agent", + avatarUrl: "https://example.com/avatar.png", + }); + expect(mockListConnections).toHaveBeenCalledWith({ + organizationId: "org-1", + userId: "user-1", + platform: "google", + connectionRole: "agent", + }); }); test("initiates managed Google auth with the requested Milady capability scopes", async () => { @@ -177,6 +236,7 @@ describe("milady Google connector service", () => { const result = await initiateManagedGoogleConnection({ organizationId: "org-1", userId: "user-1", + side: "agent", redirectUrl: "https://www.elizacloud.ai/auth/success?platform=google", capabilities: ["google.calendar.read", "google.gmail.triage", "google.gmail.send"], }); @@ -194,7 +254,9 @@ describe("milady Google connector service", () => { "https://www.googleapis.com/auth/gmail.metadata", "https://www.googleapis.com/auth/gmail.send", ], + connectionRole: "agent", }); + expect(result.side).toBe("agent"); expect(result.mode).toBe("cloud_managed"); expect(result.requestedCapabilities).toEqual([ "google.basic_identity", @@ -244,6 +306,7 @@ describe("milady Google connector service", () => { const feed = await fetchManagedGoogleCalendarFeed({ organizationId: "org-1", userId: "user-1", + side: "owner", calendarId: "primary", timeMin: "2026-04-04T00:00:00.000Z", timeMax: "2026-04-05T00:00:00.000Z", @@ -298,6 +361,7 @@ describe("milady Google connector service", () => { const triage = await fetchManagedGoogleGmailTriage({ organizationId: "org-1", userId: "user-1", + side: "owner", maxResults: 5, }); @@ -321,6 +385,7 @@ describe("milady Google connector service", () => { await sendManagedGoogleReply({ organizationId: "org-1", userId: "user-1", + side: "owner", to: ["founder@example.com"], cc: ["ops@example.com"], subject: "Project sync", @@ -346,11 +411,16 @@ describe("milady Google connector service", () => { expect(decoded).toContain("Reviewing it now."); }); - test("disconnects the preferred active Google connection for the user", async () => { + test("disconnects the preferred active Google connection for the requested side", async () => { mockListConnections.mockResolvedValue([ - createConnection({ id: "conn-google-1" }), + createConnection({ + id: "conn-google-agent", + userId: null, + connectionRole: "agent", + }), createConnection({ id: "conn-google-2", + connectionRole: "owner", status: "revoked", linkedAt: new Date("2026-04-03T15:00:00.000Z"), }), @@ -359,11 +429,12 @@ describe("milady Google connector service", () => { await disconnectManagedGoogleConnection({ organizationId: "org-1", userId: "user-1", + side: "agent", }); expect(mockRevokeConnection).toHaveBeenCalledWith({ organizationId: "org-1", - connectionId: "conn-google-1", + connectionId: "conn-google-agent", }); }); }); diff --git a/packages/tests/unit/milady-google-routes.test.ts b/packages/tests/unit/milady-google-routes.test.ts index a64e34fe3..bf0651dc8 100644 --- a/packages/tests/unit/milady-google-routes.test.ts +++ b/packages/tests/unit/milady-google-routes.test.ts @@ -1,4 +1,4 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; import { jsonRequest } from "./api/route-test-helpers"; @@ -38,6 +38,11 @@ import { POST as postCalendarEvent } from "@/app/api/v1/milady/google/calendar/e import { GET as getCalendarFeed } from "@/app/api/v1/milady/google/calendar/feed/route"; import { POST as postConnectInitiate } from "@/app/api/v1/milady/google/connect/initiate/route"; import { POST as postDisconnect } from "@/app/api/v1/milady/google/disconnect/route"; + +afterAll(() => { + mock.restore(); +}); + import { POST as postReplySend } from "@/app/api/v1/milady/google/gmail/reply-send/route"; import { GET as getGmailTriage } from "@/app/api/v1/milady/google/gmail/triage/route"; import { GET as getStatus } from "@/app/api/v1/milady/google/status/route"; @@ -64,6 +69,7 @@ describe("Milady managed Google routes", () => { test("GET /api/v1/milady/google/status returns the managed connector status", async () => { mockGetStatus.mockResolvedValue({ provider: "google", + side: "owner", mode: "cloud_managed", configured: true, connected: true, @@ -85,12 +91,43 @@ describe("Milady managed Google routes", () => { expect(response.status).toBe(200); expect(await response.json()).toMatchObject({ provider: "google", + side: "owner", mode: "cloud_managed", connected: true, connectionId: "conn-1", }); }); + test("GET /api/v1/milady/google/status forwards an explicit side", async () => { + mockGetStatus.mockResolvedValue({ + provider: "google", + side: "agent", + mode: "cloud_managed", + configured: true, + connected: false, + reason: "disconnected", + identity: null, + grantedCapabilities: [], + grantedScopes: [], + expiresAt: null, + hasRefreshToken: false, + connectionId: null, + linkedAt: null, + lastUsedAt: null, + }); + + const response = await getStatus( + new NextRequest("https://example.com/api/v1/milady/google/status?side=agent"), + ); + + expect(response.status).toBe(200); + expect(mockGetStatus).toHaveBeenCalledWith({ + organizationId: "org-1", + userId: "user-1", + side: "agent", + }); + }); + test("POST /api/v1/milady/google/connect/initiate validates capabilities and delegates to the service", async () => { mockInitiateConnection.mockResolvedValue({ provider: "google", @@ -111,6 +148,7 @@ describe("Milady managed Google routes", () => { expect(mockInitiateConnection).toHaveBeenCalledWith({ organizationId: "org-1", userId: "user-1", + side: "owner", redirectUrl: "https://www.elizacloud.ai/auth/success?platform=google", capabilities: ["google.calendar.read"], }); @@ -161,6 +199,7 @@ describe("Milady managed Google routes", () => { expect(mockCreateCalendarEvent).toHaveBeenCalledWith({ organizationId: "org-1", userId: "user-1", + side: "owner", calendarId: "primary", title: "Founder sync", description: undefined, @@ -199,6 +238,7 @@ describe("Milady managed Google routes", () => { expect(mockSendReply).toHaveBeenCalledWith({ organizationId: "org-1", userId: "user-1", + side: "owner", to: ["founder@example.com"], cc: undefined, subject: "Project sync", @@ -220,6 +260,7 @@ describe("Milady managed Google routes", () => { expect(mockDisconnectConnection).toHaveBeenCalledWith({ organizationId: "org-1", userId: "user-1", + side: "owner", connectionId: null, }); }); From 14f269ccdc1b65650ea01c00c59bb84aef1ca621 Mon Sep 17 00:00:00 2001 From: Shaw Date: Sun, 5 Apr 2026 02:27:47 -0700 Subject: [PATCH 08/11] cloud: checkpoint verified local changes --- packages/db/schemas/platform-credentials.ts | 1 + .../tests/e2e/gateway-service.e2e.test.ts | 6 +- .../tests/gateway-manager.test.ts | 6 +- .../gateway-discord/tests/logger.test.ts | 56 +-- .../unit/api/v1-agents-service-route.test.ts | 6 +- .../unit/api/v1-app-agents-route.test.ts | 6 +- .../unit/api/v1-generation-routes.test.ts | 11 +- ...v1-process-provisioning-jobs-route.test.ts | 6 +- .../tests/unit/api/v1-public-routes.test.ts | 6 +- packages/tests/unit/auth-pair-route.test.ts | 6 +- .../unit/blooio-automation-service.test.ts | 420 +++++++++--------- .../compat-auth-and-restart-route.test.ts | 6 +- .../unit/compat-availability-route.test.ts | 6 +- .../tests/unit/compat-error-handler.test.ts | 6 +- .../unit/compat-routes-error-handling.test.ts | 6 +- .../unit/discord-automation-oauth.test.ts | 8 +- .../unit/docker-ssh-cloud-deploy.test.ts | 6 +- .../unit/eliza-app/connections-route.test.ts | 6 +- .../admin-metrics-api.test.ts | 6 +- .../compute-metrics-cron.test.ts | 6 +- .../user-metrics-contract.test.ts | 6 +- packages/tests/unit/field-encryption.test.ts | 6 +- .../tests/unit/headscale-ip-route.test.ts | 6 +- .../managed-discord-eliza-app-route.test.ts | 1 + packages/tests/unit/mcp-google-tools.test.ts | 6 +- packages/tests/unit/mcp-hubspot-tools.test.ts | 6 +- packages/tests/unit/mcp-twitter-tools.test.ts | 6 +- .../unit/milady-agent-discord-routes.test.ts | 6 +- .../tests/unit/milady-create-routes.test.ts | 6 +- .../unit/milady-google-connector.test.ts | 195 ++++---- .../milaidy-agent-routes-followups.test.ts | 6 +- .../milaidy-sandbox-bridge-security.test.ts | 6 +- .../milaidy-sandbox-service-followups.test.ts | 6 +- .../unit/provisioning-jobs-followups.test.ts | 6 +- packages/tests/unit/proxy-pricing.test.ts | 6 +- packages/tests/unit/service-jwt.test.ts | 6 +- .../unit/v1-milaidy-provision-route.test.ts | 6 +- tsconfig.json | 42 +- 38 files changed, 508 insertions(+), 406 deletions(-) diff --git a/packages/db/schemas/platform-credentials.ts b/packages/db/schemas/platform-credentials.ts index b35ea0935..42dc15f16 100644 --- a/packages/db/schemas/platform-credentials.ts +++ b/packages/db/schemas/platform-credentials.ts @@ -126,6 +126,7 @@ export const platformCredentials = pgTable( channel_id?: string; message_id?: string; referrer?: string; + miladyGoogleSide?: "owner" | "agent"; }>(), // Raw profile data from OAuth provider diff --git a/packages/services/gateway-discord/tests/e2e/gateway-service.e2e.test.ts b/packages/services/gateway-discord/tests/e2e/gateway-service.e2e.test.ts index b1d7ae28a..d4f54edc6 100644 --- a/packages/services/gateway-discord/tests/e2e/gateway-service.e2e.test.ts +++ b/packages/services/gateway-discord/tests/e2e/gateway-service.e2e.test.ts @@ -5,9 +5,13 @@ * Uses mocks for external dependencies (Discord API, Redis, Eliza Cloud). */ -import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; import { Hono } from "hono"; +afterAll(() => { + mock.restore(); +}); + // ============================================ // Mock Setup // ============================================ diff --git a/packages/services/gateway-discord/tests/gateway-manager.test.ts b/packages/services/gateway-discord/tests/gateway-manager.test.ts index d0739a841..7253167ee 100644 --- a/packages/services/gateway-discord/tests/gateway-manager.test.ts +++ b/packages/services/gateway-discord/tests/gateway-manager.test.ts @@ -8,7 +8,11 @@ * Integration tests should cover the full flow. */ -import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; + +afterAll(() => { + mock.restore(); +}); mock.module("discord.js", () => ({ Attachment: class MockAttachment {}, diff --git a/packages/services/gateway-discord/tests/logger.test.ts b/packages/services/gateway-discord/tests/logger.test.ts index ca665c130..684b8f452 100644 --- a/packages/services/gateway-discord/tests/logger.test.ts +++ b/packages/services/gateway-discord/tests/logger.test.ts @@ -8,6 +8,13 @@ import { afterEach, beforeEach, describe, expect, spyOn, test } from "bun:test"; // Store original env const originalEnv = { ...process.env }; +let loggerImportVersion = 0; + +async function importLoggerFresh() { + loggerImportVersion += 1; + // @ts-expect-error Bun supports cache-busting query imports in tests. + return import(`../src/logger?test=${loggerImportVersion}`); +} describe("logger", () => { let consoleLogSpy: ReturnType; @@ -28,15 +35,12 @@ describe("logger", () => { consoleLogSpy.mockRestore(); consoleWarnSpy.mockRestore(); consoleErrorSpy.mockRestore(); - // Clear module cache to reload with new env - const modulePath = require.resolve("../src/logger"); - delete require.cache[modulePath]; }); describe("log level filtering", () => { test("debug level logs all messages", async () => { process.env.LOG_LEVEL = "debug"; - const { logger } = await import("../src/logger"); + const { logger } = await importLoggerFresh(); logger.debug("debug message"); logger.info("info message"); @@ -50,9 +54,7 @@ describe("logger", () => { test("info level filters debug messages", async () => { process.env.LOG_LEVEL = "info"; - const modulePath = require.resolve("../src/logger"); - delete require.cache[modulePath]; - const { logger } = await import("../src/logger"); + const { logger } = await importLoggerFresh(); logger.debug("debug message"); logger.info("info message"); @@ -66,9 +68,7 @@ describe("logger", () => { test("warn level filters debug and info messages", async () => { process.env.LOG_LEVEL = "warn"; - const modulePath = require.resolve("../src/logger"); - delete require.cache[modulePath]; - const { logger } = await import("../src/logger"); + const { logger } = await importLoggerFresh(); logger.debug("debug message"); logger.info("info message"); @@ -82,9 +82,7 @@ describe("logger", () => { test("error level only logs errors", async () => { process.env.LOG_LEVEL = "error"; - const modulePath = require.resolve("../src/logger"); - delete require.cache[modulePath]; - const { logger } = await import("../src/logger"); + const { logger } = await importLoggerFresh(); logger.debug("debug message"); logger.info("info message"); @@ -98,9 +96,7 @@ describe("logger", () => { test("defaults to info level when LOG_LEVEL not set", async () => { delete process.env.LOG_LEVEL; - const modulePath = require.resolve("../src/logger"); - delete require.cache[modulePath]; - const { logger } = await import("../src/logger"); + const { logger } = await importLoggerFresh(); logger.debug("debug message"); logger.info("info message"); @@ -112,9 +108,7 @@ describe("logger", () => { describe("message formatting", () => { test("formats message as JSON with timestamp, level, and message", async () => { process.env.LOG_LEVEL = "info"; - const modulePath = require.resolve("../src/logger"); - delete require.cache[modulePath]; - const { logger } = await import("../src/logger"); + const { logger } = await importLoggerFresh(); logger.info("test message"); @@ -130,9 +124,7 @@ describe("logger", () => { test("includes metadata in formatted message", async () => { process.env.LOG_LEVEL = "info"; - const modulePath = require.resolve("../src/logger"); - delete require.cache[modulePath]; - const { logger } = await import("../src/logger"); + const { logger } = await importLoggerFresh(); logger.info("test message", { connectionId: "conn-123", count: 42 }); @@ -145,9 +137,7 @@ describe("logger", () => { test("warn logs to console.warn", async () => { process.env.LOG_LEVEL = "warn"; - const modulePath = require.resolve("../src/logger"); - delete require.cache[modulePath]; - const { logger } = await import("../src/logger"); + const { logger } = await importLoggerFresh(); logger.warn("warning message", { reason: "test" }); @@ -161,9 +151,7 @@ describe("logger", () => { test("error logs to console.error", async () => { process.env.LOG_LEVEL = "error"; - const modulePath = require.resolve("../src/logger"); - delete require.cache[modulePath]; - const { logger } = await import("../src/logger"); + const { logger } = await importLoggerFresh(); logger.error("error message", { error: "something went wrong" }); @@ -179,9 +167,7 @@ describe("logger", () => { describe("edge cases", () => { test("handles undefined metadata", async () => { process.env.LOG_LEVEL = "info"; - const modulePath = require.resolve("../src/logger"); - delete require.cache[modulePath]; - const { logger } = await import("../src/logger"); + const { logger } = await importLoggerFresh(); logger.info("message without meta"); @@ -195,9 +181,7 @@ describe("logger", () => { test("handles empty metadata object", async () => { process.env.LOG_LEVEL = "info"; - const modulePath = require.resolve("../src/logger"); - delete require.cache[modulePath]; - const { logger } = await import("../src/logger"); + const { logger } = await importLoggerFresh(); logger.info("message with empty meta", {}); @@ -209,9 +193,7 @@ describe("logger", () => { test("handles complex nested metadata", async () => { process.env.LOG_LEVEL = "info"; - const modulePath = require.resolve("../src/logger"); - delete require.cache[modulePath]; - const { logger } = await import("../src/logger"); + const { logger } = await importLoggerFresh(); logger.info("complex meta", { nested: { deep: { value: true } }, diff --git a/packages/tests/unit/api/v1-agents-service-route.test.ts b/packages/tests/unit/api/v1-agents-service-route.test.ts index ea345d701..737ed4c10 100644 --- a/packages/tests/unit/api/v1-agents-service-route.test.ts +++ b/packages/tests/unit/api/v1-agents-service-route.test.ts @@ -1,6 +1,10 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { jsonRequest } from "./route-test-helpers"; +afterAll(() => { + mock.restore(); +}); + const mockRequireServiceKey = mock(); const mockCreateAgent = mock(); const mockEnqueueMiladyProvision = mock(); diff --git a/packages/tests/unit/api/v1-app-agents-route.test.ts b/packages/tests/unit/api/v1-app-agents-route.test.ts index 67f1d4031..128ef18a6 100644 --- a/packages/tests/unit/api/v1-app-agents-route.test.ts +++ b/packages/tests/unit/api/v1-app-agents-route.test.ts @@ -1,6 +1,10 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { jsonRequest } from "./route-test-helpers"; +afterAll(() => { + mock.restore(); +}); + const mockRequireAuthOrApiKeyWithOrg = mock(); const mockCharacterCreate = mock(); const mockFindByTokenAddress = mock(); diff --git a/packages/tests/unit/api/v1-generation-routes.test.ts b/packages/tests/unit/api/v1-generation-routes.test.ts index 5f494f2ea..e23898611 100644 --- a/packages/tests/unit/api/v1-generation-routes.test.ts +++ b/packages/tests/unit/api/v1-generation-routes.test.ts @@ -1,4 +1,4 @@ -import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterEach, beforeAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { creditsModuleRuntimeShim } from "@/tests/support/bun-partial-module-shims"; import { jsonRequest } from "./route-test-helpers"; @@ -125,6 +125,8 @@ mock.module("@/lib/pricing", () => ({ VIDEO_GENERATION_FALLBACK_COST: 1, calculateCost: mockCalculateCost, estimateTokens: mockEstimateTokens, + getProviderFromModel: (model: string) => + model.startsWith("openai") || model.startsWith("gpt") ? "openai" : "fal", })); mock.module("@/lib/models", () => ({ @@ -181,7 +183,8 @@ import { OPTIONS as generateImageOptions, } from "@/app/api/v1/generate-image/route"; import { POST as generatePrompts } from "@/app/api/v1/generate-prompts/route"; -import { POST as generateVideo } from "@/app/api/v1/generate-video/route"; + +let generateVideo: typeof import("@/app/api/v1/generate-video/route").POST; const authenticatedUser = { id: "user-1", @@ -194,6 +197,10 @@ const authenticatedUser = { }, }; +beforeAll(async () => { + ({ POST: generateVideo } = await import("@/app/api/v1/generate-video/route")); +}); + beforeEach(() => { process.env.FAL_KEY = "fal_test"; diff --git a/packages/tests/unit/api/v1-process-provisioning-jobs-route.test.ts b/packages/tests/unit/api/v1-process-provisioning-jobs-route.test.ts index fa17d19a5..11c4b5762 100644 --- a/packages/tests/unit/api/v1-process-provisioning-jobs-route.test.ts +++ b/packages/tests/unit/api/v1-process-provisioning-jobs-route.test.ts @@ -1,6 +1,10 @@ -import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; +afterAll(() => { + mock.restore(); +}); + const mockProcessPendingJobs = mock(); mock.module("@/lib/services/provisioning-jobs", () => ({ diff --git a/packages/tests/unit/api/v1-public-routes.test.ts b/packages/tests/unit/api/v1-public-routes.test.ts index 271e4ff37..86bcb1e6d 100644 --- a/packages/tests/unit/api/v1-public-routes.test.ts +++ b/packages/tests/unit/api/v1-public-routes.test.ts @@ -1,9 +1,13 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; import { flushMicrotasks, jsonRequest, routeParams } from "./route-test-helpers"; +afterAll(() => { + mock.restore(); +}); + const mockRequireAuth = mock(); const mockRequireAuthWithOrg = mock(); const mockRequireAuthOrApiKey = mock(); diff --git a/packages/tests/unit/auth-pair-route.test.ts b/packages/tests/unit/auth-pair-route.test.ts index 2b6f6518d..8c2af395d 100644 --- a/packages/tests/unit/auth-pair-route.test.ts +++ b/packages/tests/unit/auth-pair-route.test.ts @@ -1,6 +1,10 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { jsonRequest } from "./api/route-test-helpers"; +afterAll(() => { + mock.restore(); +}); + const mockValidateToken = mock(); const mockFindByIdAndOrg = mock(); diff --git a/packages/tests/unit/blooio-automation-service.test.ts b/packages/tests/unit/blooio-automation-service.test.ts index be377457d..1eb79c335 100644 --- a/packages/tests/unit/blooio-automation-service.test.ts +++ b/packages/tests/unit/blooio-automation-service.test.ts @@ -11,7 +11,6 @@ */ import { beforeEach, describe, expect, it, mock } from "bun:test"; -import { blooioAutomationService } from "@/lib/services/blooio-automation"; // Mock external dependencies const _mockSecretsService = { @@ -27,284 +26,285 @@ const _mockValidateBlooioChatId = mock(() => true); // These would typically be mocked at module level, but for now we test observable behavior -describe.skipIf(!process.env.DATABASE_URL || process.env.SKIP_DB_DEPENDENT === "1")( - "BlooioAutomationService", - () => { - const testOrgId = "11111111-1111-1111-1111-111111111111"; - const _testUserId = "22222222-2222-2222-2222-222222222222"; - const _testApiKey = "bloo_test_api_key_123"; +const shouldRunBlooioTests = + Boolean(process.env.DATABASE_URL) && process.env.SKIP_DB_DEPENDENT !== "1"; +const blooioAutomationService = shouldRunBlooioTests + ? (await import("@/lib/services/blooio-automation")).blooioAutomationService + : null; - beforeEach(() => { - // Clear the status cache before each test - blooioAutomationService.invalidateStatusCache(testOrgId); - }); - - describe("validateApiKey", () => { - it("returns invalid when API key is empty", async () => { - const result = await blooioAutomationService.validateApiKey(""); - expect(result.valid).toBe(false); - expect(result.error).toBe("API key is required"); - }); +describe.skipIf(!shouldRunBlooioTests)("BlooioAutomationService", () => { + const testOrgId = "11111111-1111-1111-1111-111111111111"; + const _testUserId = "22222222-2222-2222-2222-222222222222"; + const _testApiKey = "bloo_test_api_key_123"; - it("returns invalid when API key is whitespace", async () => { - const result = await blooioAutomationService.validateApiKey(" "); - expect(result.valid).toBe(false); - expect(result.error).toBe("API key is required"); - }); + beforeEach(() => { + // Clear the status cache before each test + blooioAutomationService.invalidateStatusCache(testOrgId); + }); - // Note: Full validation tests would require mocking blooioApiRequest + describe("validateApiKey", () => { + it("returns invalid when API key is empty", async () => { + const result = await blooioAutomationService.validateApiKey(""); + expect(result.valid).toBe(false); + expect(result.error).toBe("API key is required"); }); - describe("invalidateStatusCache", () => { - it("clears cache for organization", async () => { - // This is a simple method that clears the internal cache - // We test it doesn't throw - expect(() => { - blooioAutomationService.invalidateStatusCache(testOrgId); - }).not.toThrow(); - }); - - it("handles multiple invalidations", () => { - expect(() => { - blooioAutomationService.invalidateStatusCache(testOrgId); - blooioAutomationService.invalidateStatusCache(testOrgId); - blooioAutomationService.invalidateStatusCache("33333333-3000-3000-3000-333333333333"); - }).not.toThrow(); - }); + it("returns invalid when API key is whitespace", async () => { + const result = await blooioAutomationService.validateApiKey(" "); + expect(result.valid).toBe(false); + expect(result.error).toBe("API key is required"); }); - describe("getWebhookUrl", () => { - it("returns correct webhook URL format", () => { - const url = blooioAutomationService.getWebhookUrl(testOrgId); - expect(url).toContain("/api/webhooks/blooio/"); - expect(url).toContain(testOrgId); - }); + // Note: Full validation tests would require mocking blooioApiRequest + }); - it("includes organization ID in URL", () => { - const orgId = "44444444-4444-4444-4444-444444444444"; - const url = blooioAutomationService.getWebhookUrl(orgId); - expect(url).toContain(orgId); - }); + describe("invalidateStatusCache", () => { + it("clears cache for organization", async () => { + // This is a simple method that clears the internal cache + // We test it doesn't throw + expect(() => { + blooioAutomationService.invalidateStatusCache(testOrgId); + }).not.toThrow(); }); - describe("Chat ID Validation", () => { - // These tests validate the chat ID normalization and validation logic + it("handles multiple invalidations", () => { + expect(() => { + blooioAutomationService.invalidateStatusCache(testOrgId); + blooioAutomationService.invalidateStatusCache(testOrgId); + blooioAutomationService.invalidateStatusCache("33333333-3000-3000-3000-333333333333"); + }).not.toThrow(); + }); + }); - describe("normalizes chat IDs", () => { - it("handles phone numbers", async () => { - // sendMessage normalizes chat IDs before sending - // This tests the format validation behavior - const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { - text: "test", - }); + describe("getWebhookUrl", () => { + it("returns correct webhook URL format", () => { + const url = blooioAutomationService.getWebhookUrl(testOrgId); + expect(url).toContain("/api/webhooks/blooio/"); + expect(url).toContain(testOrgId); + }); - // Will fail because no API key, but tests normalization path - expect(result.success).toBe(false); - }); + it("includes organization ID in URL", () => { + const orgId = "44444444-4444-4444-4444-444444444444"; + const url = blooioAutomationService.getWebhookUrl(orgId); + expect(url).toContain(orgId); + }); + }); - it("handles email addresses", async () => { - const result = await blooioAutomationService.sendMessage(testOrgId, "user@example.com", { - text: "test", - }); + describe("Chat ID Validation", () => { + // These tests validate the chat ID normalization and validation logic - expect(result.success).toBe(false); - // Would fail at send, not validation + describe("normalizes chat IDs", () => { + it("handles phone numbers", async () => { + // sendMessage normalizes chat IDs before sending + // This tests the format validation behavior + const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { + text: "test", }); - it("handles comma-separated chat IDs", async () => { - // Test multiple recipients - const result = await blooioAutomationService.sendMessage( - testOrgId, - "+15551234567, +15559876543", - { text: "test" }, - ); - - expect(result.success).toBe(false); - }); + // Will fail because no API key, but tests normalization path + expect(result.success).toBe(false); }); - }); - describe("sendMessage", () => { - it("returns error when Blooio is not configured", async () => { - // getApiKey returns null when not configured - const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { - text: "Hello", + it("handles email addresses", async () => { + const result = await blooioAutomationService.sendMessage(testOrgId, "user@example.com", { + text: "test", }); expect(result.success).toBe(false); - expect(result.error).toContain("not configured"); + // Would fail at send, not validation }); - it("validates chat ID format", async () => { - // Invalid chat ID format - would fail validation if API key existed - const result = await blooioAutomationService.sendMessage(testOrgId, "invalid-chat-id", { - text: "Hello", - }); + it("handles comma-separated chat IDs", async () => { + // Test multiple recipients + const result = await blooioAutomationService.sendMessage( + testOrgId, + "+15551234567, +15559876543", + { text: "test" }, + ); - // May return "not configured" or validation error depending on order expect(result.success).toBe(false); }); + }); + }); - it("handles group chat IDs", async () => { - // Group IDs start with grp_ - const result = await blooioAutomationService.sendMessage(testOrgId, "grp_abc123", { - text: "Hello group", - }); + describe("sendMessage", () => { + it("returns error when Blooio is not configured", async () => { + // getApiKey returns null when not configured + const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { + text: "Hello", + }); - // Will fail due to no config, but validates format is accepted - expect(result.success).toBe(false); + expect(result.success).toBe(false); + expect(result.error).toContain("not configured"); + }); + + it("validates chat ID format", async () => { + // Invalid chat ID format - would fail validation if API key existed + const result = await blooioAutomationService.sendMessage(testOrgId, "invalid-chat-id", { + text: "Hello", }); - it("handles message with attachments", async () => { - const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { - text: "Check this out", - attachments: [{ url: "https://example.com/image.jpg", name: "image.jpg" }], - }); + // May return "not configured" or validation error depending on order + expect(result.success).toBe(false); + }); - expect(result.success).toBe(false); + it("handles group chat IDs", async () => { + // Group IDs start with grp_ + const result = await blooioAutomationService.sendMessage(testOrgId, "grp_abc123", { + text: "Hello group", }); - it("handles message with metadata", async () => { - const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { - text: "With metadata", - metadata: { source: "test", timestamp: Date.now() }, - }); + // Will fail due to no config, but validates format is accepted + expect(result.success).toBe(false); + }); - expect(result.success).toBe(false); + it("handles message with attachments", async () => { + const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { + text: "Check this out", + attachments: [{ url: "https://example.com/image.jpg", name: "image.jpg" }], }); + + expect(result.success).toBe(false); }); - describe("isConfigured", () => { - it("returns false when no API key is stored", async () => { - const result = await blooioAutomationService.isConfigured(testOrgId); - // Without mocking secrets service, this depends on actual secrets - expect(typeof result).toBe("boolean"); + it("handles message with metadata", async () => { + const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { + text: "With metadata", + metadata: { source: "test", timestamp: Date.now() }, }); + + expect(result.success).toBe(false); }); + }); - describe("getConnectionStatus", () => { - it("returns unconfigured status when no API key", async () => { - const status = await blooioAutomationService.getConnectionStatus(testOrgId); + describe("isConfigured", () => { + it("returns false when no API key is stored", async () => { + const result = await blooioAutomationService.isConfigured(testOrgId); + // Without mocking secrets service, this depends on actual secrets + expect(typeof result).toBe("boolean"); + }); + }); - // Without mock, depends on actual config - expect(status).toHaveProperty("connected"); - expect(status).toHaveProperty("configured"); - expect(typeof status.connected).toBe("boolean"); - expect(typeof status.configured).toBe("boolean"); - }); + describe("getConnectionStatus", () => { + it("returns unconfigured status when no API key", async () => { + const status = await blooioAutomationService.getConnectionStatus(testOrgId); - it("caches status for performance", async () => { - // First call - const status1 = await blooioAutomationService.getConnectionStatus(testOrgId); + // Without mock, depends on actual config + expect(status).toHaveProperty("connected"); + expect(status).toHaveProperty("configured"); + expect(typeof status.connected).toBe("boolean"); + expect(typeof status.configured).toBe("boolean"); + }); - // Second call should use cache - const status2 = await blooioAutomationService.getConnectionStatus(testOrgId); + it("caches status for performance", async () => { + // First call + const status1 = await blooioAutomationService.getConnectionStatus(testOrgId); - // Results should be identical - expect(status1.connected).toBe(status2.connected); - expect(status1.configured).toBe(status2.configured); - }); + // Second call should use cache + const status2 = await blooioAutomationService.getConnectionStatus(testOrgId); - it("respects skipCache option", async () => { - // First call - const _status1 = await blooioAutomationService.getConnectionStatus(testOrgId); + // Results should be identical + expect(status1.connected).toBe(status2.connected); + expect(status1.configured).toBe(status2.configured); + }); - // Second call with skipCache - const status2 = await blooioAutomationService.getConnectionStatus(testOrgId, { - skipCache: true, - }); + it("respects skipCache option", async () => { + // First call + const _status1 = await blooioAutomationService.getConnectionStatus(testOrgId); - // Both should work without error - expect(status2).toHaveProperty("connected"); + // Second call with skipCache + const status2 = await blooioAutomationService.getConnectionStatus(testOrgId, { + skipCache: true, }); - it("returns fromNumber when available", async () => { - const status = await blooioAutomationService.getConnectionStatus(testOrgId); + // Both should work without error + expect(status2).toHaveProperty("connected"); + }); - // fromNumber is optional - if (status.connected) { - expect(typeof status.fromNumber === "string" || status.fromNumber === undefined).toBe( - true, - ); - } - }); + it("returns fromNumber when available", async () => { + const status = await blooioAutomationService.getConnectionStatus(testOrgId); + + // fromNumber is optional + if (status.connected) { + expect(typeof status.fromNumber === "string" || status.fromNumber === undefined).toBe(true); + } }); + }); - describe("Error Handling", () => { - it("handles empty organization ID gracefully", async () => { - const status = await blooioAutomationService.getConnectionStatus( - "00000000-0000-0000-0000-000000000000", - ); - expect(status).toHaveProperty("connected"); - }); + describe("Error Handling", () => { + it("handles empty organization ID gracefully", async () => { + const status = await blooioAutomationService.getConnectionStatus( + "00000000-0000-0000-0000-000000000000", + ); + expect(status).toHaveProperty("connected"); + }); - it("handles special characters in organization ID", async () => { - const status = await blooioAutomationService.getConnectionStatus( - "00000000-0000-0000-0000-000000000001", - ); - expect(status).toHaveProperty("connected"); - }); + it("handles special characters in organization ID", async () => { + const status = await blooioAutomationService.getConnectionStatus( + "00000000-0000-0000-0000-000000000001", + ); + expect(status).toHaveProperty("connected"); }); + }); - describe("Credential Methods", () => { - describe("getApiKey", () => { - it("returns null when no API key stored", async () => { - // This tests the fallback behavior - const apiKey = await blooioAutomationService.getApiKey( - "55555555-5555-5555-5555-555555555555", - ); - // May return null or env var fallback - expect(apiKey === null || typeof apiKey === "string").toBe(true); - }); + describe("Credential Methods", () => { + describe("getApiKey", () => { + it("returns null when no API key stored", async () => { + // This tests the fallback behavior + const apiKey = await blooioAutomationService.getApiKey( + "55555555-5555-5555-5555-555555555555", + ); + // May return null or env var fallback + expect(apiKey === null || typeof apiKey === "string").toBe(true); }); + }); - describe("getWebhookSecret", () => { - it("returns null when no webhook secret stored", async () => { - const secret = await blooioAutomationService.getWebhookSecret( - "55555555-5555-5555-5555-555555555555", - ); - expect(secret === null || typeof secret === "string").toBe(true); - }); + describe("getWebhookSecret", () => { + it("returns null when no webhook secret stored", async () => { + const secret = await blooioAutomationService.getWebhookSecret( + "55555555-5555-5555-5555-555555555555", + ); + expect(secret === null || typeof secret === "string").toBe(true); }); + }); - describe("getFromNumber", () => { - it("returns null when no from number stored", async () => { - const fromNumber = await blooioAutomationService.getFromNumber( - "55555555-5555-5555-5555-555555555555", - ); - expect(fromNumber === null || typeof fromNumber === "string").toBe(true); - }); + describe("getFromNumber", () => { + it("returns null when no from number stored", async () => { + const fromNumber = await blooioAutomationService.getFromNumber( + "55555555-5555-5555-5555-555555555555", + ); + expect(fromNumber === null || typeof fromNumber === "string").toBe(true); }); }); + }); - describe("Message Request Handling", () => { - it("handles text-only message", async () => { - const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { - text: "Simple text message", - }); - expect(result).toHaveProperty("success"); + describe("Message Request Handling", () => { + it("handles text-only message", async () => { + const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { + text: "Simple text message", }); + expect(result).toHaveProperty("success"); + }); - it("handles message with typing indicator", async () => { - const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { - text: "Message with typing", - use_typing_indicator: true, - }); - expect(result).toHaveProperty("success"); + it("handles message with typing indicator", async () => { + const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { + text: "Message with typing", + use_typing_indicator: true, }); + expect(result).toHaveProperty("success"); + }); - it("handles message with idempotency key", async () => { - const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { - text: "Idempotent message", - idempotencyKey: "unique-key-123", - }); - expect(result).toHaveProperty("success"); + it("handles message with idempotency key", async () => { + const result = await blooioAutomationService.sendMessage(testOrgId, "+15551234567", { + text: "Idempotent message", + idempotencyKey: "unique-key-123", }); + expect(result).toHaveProperty("success"); }); - }, -); + }); +}); -describe("BlooioAutomationService Integration-style Tests", () => { +describe.skipIf(!shouldRunBlooioTests)("BlooioAutomationService Integration-style Tests", () => { // These tests would require database and secrets service to be available // They test the full flow without external Blooio API calls diff --git a/packages/tests/unit/compat-auth-and-restart-route.test.ts b/packages/tests/unit/compat-auth-and-restart-route.test.ts index debb15930..941821b9e 100644 --- a/packages/tests/unit/compat-auth-and-restart-route.test.ts +++ b/packages/tests/unit/compat-auth-and-restart-route.test.ts @@ -1,7 +1,11 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; import { routeParams } from "./api/route-test-helpers"; +afterAll(() => { + mock.restore(); +}); + const mockRequireServiceKey = mock(); const mockAuthenticateWaifuBridge = mock(); const mockRequireAuthOrApiKeyWithOrg = mock(); diff --git a/packages/tests/unit/compat-availability-route.test.ts b/packages/tests/unit/compat-availability-route.test.ts index d6a96db91..8c4bd0000 100644 --- a/packages/tests/unit/compat-availability-route.test.ts +++ b/packages/tests/unit/compat-availability-route.test.ts @@ -1,6 +1,10 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; +afterAll(() => { + mock.restore(); +}); + const mockFindAll = mock(); const mockValidateServiceKey = mock(); const mockAuthenticateWaifuBridge = mock(); diff --git a/packages/tests/unit/compat-error-handler.test.ts b/packages/tests/unit/compat-error-handler.test.ts index f5da36d78..eb36ce51b 100644 --- a/packages/tests/unit/compat-error-handler.test.ts +++ b/packages/tests/unit/compat-error-handler.test.ts @@ -9,7 +9,11 @@ * - Non-Error throws get generic 500 */ -import { describe, expect, mock, test } from "bun:test"; +import { afterAll, describe, expect, mock, test } from "bun:test"; + +afterAll(() => { + mock.restore(); +}); // Mock logger before importing error-handler mock.module("@/lib/utils/logger", () => ({ diff --git a/packages/tests/unit/compat-routes-error-handling.test.ts b/packages/tests/unit/compat-routes-error-handling.test.ts index de2acad9c..ff0ca894b 100644 --- a/packages/tests/unit/compat-routes-error-handling.test.ts +++ b/packages/tests/unit/compat-routes-error-handling.test.ts @@ -7,10 +7,14 @@ * - POST /agents auto-provision warning field (item 4) */ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; import { routeParams } from "./api/route-test-helpers"; +afterAll(() => { + mock.restore(); +}); + // ── Mocks ──────────────────────────────────────────────────────────── const mockRequireServiceKey = mock(); const mockAuthenticateWaifuBridge = mock(); diff --git a/packages/tests/unit/discord-automation-oauth.test.ts b/packages/tests/unit/discord-automation-oauth.test.ts index bf9a59e2e..6cb3c9784 100644 --- a/packages/tests/unit/discord-automation-oauth.test.ts +++ b/packages/tests/unit/discord-automation-oauth.test.ts @@ -1,5 +1,9 @@ import { afterAll, beforeAll, beforeEach, describe, expect, mock, test } from "bun:test"; +afterAll(() => { + mock.restore(); +}); + const mockDiscordGuildUpsert = mock(); const mockDiscordChannelUpsert = mock(); const mockLogger = { @@ -95,7 +99,7 @@ describe("discordAutomationService.handleBotOAuthCallback", () => { } throw new Error(`Unexpected fetch: ${url}`); }); - globalThis.fetch = fetchMock as typeof fetch; + globalThis.fetch = fetchMock as unknown as typeof fetch; const result = await discordAutomationService.handleBotOAuthCallback({ code: "oauth-code", @@ -183,7 +187,7 @@ describe("discordAutomationService.handleBotOAuthCallback", () => { } throw new Error(`Unexpected fetch: ${url}`); }); - globalThis.fetch = fetchMock as typeof fetch; + globalThis.fetch = fetchMock as unknown as typeof fetch; const result = await discordAutomationService.handleBotOAuthCallback({ code: "oauth-code", diff --git a/packages/tests/unit/docker-ssh-cloud-deploy.test.ts b/packages/tests/unit/docker-ssh-cloud-deploy.test.ts index 324cc1073..3f79127a8 100644 --- a/packages/tests/unit/docker-ssh-cloud-deploy.test.ts +++ b/packages/tests/unit/docker-ssh-cloud-deploy.test.ts @@ -6,7 +6,11 @@ * never leaked in error messages or logs. */ -import { afterEach, beforeEach, describe, expect, test } from "bun:test"; +import { afterAll, afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; + +afterAll(() => { + mock.restore(); +}); // Import the real `redact` object directly. Other test files in the batch // may call `mock.module("@/lib/utils/logger", ...)` without including `redact`, diff --git a/packages/tests/unit/eliza-app/connections-route.test.ts b/packages/tests/unit/eliza-app/connections-route.test.ts index 423a84bc5..dffc59040 100644 --- a/packages/tests/unit/eliza-app/connections-route.test.ts +++ b/packages/tests/unit/eliza-app/connections-route.test.ts @@ -1,4 +1,4 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; const mockValidateAuthHeader = mock(); @@ -31,6 +31,10 @@ mock.module("@/lib/services/oauth/provider-registry", () => ({ })); describe("Eliza App connections routes", () => { + afterAll(() => { + mock.restore(); + }); + beforeEach(() => { mockValidateAuthHeader.mockReset(); mockListConnections.mockReset(); diff --git a/packages/tests/unit/engagement-metrics/admin-metrics-api.test.ts b/packages/tests/unit/engagement-metrics/admin-metrics-api.test.ts index 7c8a1bc6a..d7bc79b17 100644 --- a/packages/tests/unit/engagement-metrics/admin-metrics-api.test.ts +++ b/packages/tests/unit/engagement-metrics/admin-metrics-api.test.ts @@ -9,9 +9,13 @@ * - Error handling for service failures */ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; +afterAll(() => { + mock.restore(); +}); + // ─── Mock Setup ────────────────────────────────────────────────────────────── const mockOverview = { diff --git a/packages/tests/unit/engagement-metrics/compute-metrics-cron.test.ts b/packages/tests/unit/engagement-metrics/compute-metrics-cron.test.ts index ec5d22d66..d0edca21b 100644 --- a/packages/tests/unit/engagement-metrics/compute-metrics-cron.test.ts +++ b/packages/tests/unit/engagement-metrics/compute-metrics-cron.test.ts @@ -7,9 +7,13 @@ * - Error handling */ -import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; +afterAll(() => { + mock.restore(); +}); + // ─── Mock Setup ────────────────────────────────────────────────────────────── const mockComputeDailyMetrics = mock((_date: Date) => Promise.resolve()); diff --git a/packages/tests/unit/engagement-metrics/user-metrics-contract.test.ts b/packages/tests/unit/engagement-metrics/user-metrics-contract.test.ts index 855327e35..5f7459bd5 100644 --- a/packages/tests/unit/engagement-metrics/user-metrics-contract.test.ts +++ b/packages/tests/unit/engagement-metrics/user-metrics-contract.test.ts @@ -10,7 +10,11 @@ * through the API route tests. */ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; + +afterAll(() => { + mock.restore(); +}); // ─── Mock Setup ────────────────────────────────────────────────────────────── diff --git a/packages/tests/unit/field-encryption.test.ts b/packages/tests/unit/field-encryption.test.ts index 77ad339fc..f44d291f9 100644 --- a/packages/tests/unit/field-encryption.test.ts +++ b/packages/tests/unit/field-encryption.test.ts @@ -1,6 +1,10 @@ -import { beforeEach, describe, expect, it, mock } from "bun:test"; +import { afterAll, beforeEach, describe, expect, it, mock } from "bun:test"; import crypto from "crypto"; +afterAll(() => { + mock.restore(); +}); + type OrgEncryptionKeyRow = { id: string; organization_id: string; diff --git a/packages/tests/unit/headscale-ip-route.test.ts b/packages/tests/unit/headscale-ip-route.test.ts index ab4931174..ae8bbda43 100644 --- a/packages/tests/unit/headscale-ip-route.test.ts +++ b/packages/tests/unit/headscale-ip-route.test.ts @@ -1,7 +1,11 @@ -import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; import { routeParams } from "./api/route-test-helpers"; +afterAll(() => { + mock.restore(); +}); + const mockFindById = mock(); const savedHeadscaleInternalToken = process.env.HEADSCALE_INTERNAL_TOKEN; diff --git a/packages/tests/unit/managed-discord-eliza-app-route.test.ts b/packages/tests/unit/managed-discord-eliza-app-route.test.ts index 574141e52..649a8ba5d 100644 --- a/packages/tests/unit/managed-discord-eliza-app-route.test.ts +++ b/packages/tests/unit/managed-discord-eliza-app-route.test.ts @@ -56,6 +56,7 @@ describe("managed Discord Eliza App routing route", () => { }); afterAll(() => { + mock.restore(); process.env = { ...originalEnv }; }); diff --git a/packages/tests/unit/mcp-google-tools.test.ts b/packages/tests/unit/mcp-google-tools.test.ts index f3a82776c..4afede5eb 100644 --- a/packages/tests/unit/mcp-google-tools.test.ts +++ b/packages/tests/unit/mcp-google-tools.test.ts @@ -5,7 +5,7 @@ * Real: all handler logic, helpers, mappers, error formatting. */ -import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; import { authContextStorage } from "@/app/api/mcp/lib/context"; import type { ListConnectionsParams, OAuthConnection } from "@/lib/services/oauth/types"; @@ -117,6 +117,10 @@ function parse(result: GoogleToolHandlerResult) { // ══════════════════════════════════════════════════════════════════════════════ describe("Google MCP Tools", () => { + afterAll(() => { + mock.restore(); + }); + beforeEach(() => { setupMockFetch(); mockOAuth.getValidTokenByPlatform.mockReset(); diff --git a/packages/tests/unit/mcp-hubspot-tools.test.ts b/packages/tests/unit/mcp-hubspot-tools.test.ts index 5281284ab..989852726 100644 --- a/packages/tests/unit/mcp-hubspot-tools.test.ts +++ b/packages/tests/unit/mcp-hubspot-tools.test.ts @@ -9,7 +9,7 @@ * - Network error handling */ -import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; import { authContextStorage } from "@/app/api/mcp/lib/context"; import type { GetTokenByPlatformParams, @@ -107,6 +107,10 @@ function createMockAuth(orgId: string = "test-org-123") { } describe("HubSpot MCP Tools", () => { + afterAll(() => { + mock.restore(); + }); + beforeEach(() => { setupMockFetch(); mockOAuthService.getValidTokenByPlatform.mockReset(); diff --git a/packages/tests/unit/mcp-twitter-tools.test.ts b/packages/tests/unit/mcp-twitter-tools.test.ts index cc7131375..284ff49c0 100644 --- a/packages/tests/unit/mcp-twitter-tools.test.ts +++ b/packages/tests/unit/mcp-twitter-tools.test.ts @@ -5,10 +5,14 @@ * Real: all handler logic, helpers, mappers, error formatting. */ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { authContextStorage } from "@/app/api/mcp/lib/context"; import type { OAuthConnection } from "@/lib/services/oauth/types"; +afterAll(() => { + mock.restore(); +}); + function twitterOAuthFixture( o: Partial & Pick, ): OAuthConnection { diff --git a/packages/tests/unit/milady-agent-discord-routes.test.ts b/packages/tests/unit/milady-agent-discord-routes.test.ts index a615220de..c28ee7b7d 100644 --- a/packages/tests/unit/milady-agent-discord-routes.test.ts +++ b/packages/tests/unit/milady-agent-discord-routes.test.ts @@ -1,7 +1,11 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; import { jsonRequest, routeParams } from "./api/route-test-helpers"; +afterAll(() => { + mock.restore(); +}); + const mockRequireAuthOrApiKeyWithOrg = mock(); const mockGetAgent = mock(); const mockGetStatus = mock(); diff --git a/packages/tests/unit/milady-create-routes.test.ts b/packages/tests/unit/milady-create-routes.test.ts index ed8587888..821196f4f 100644 --- a/packages/tests/unit/milady-create-routes.test.ts +++ b/packages/tests/unit/milady-create-routes.test.ts @@ -1,8 +1,12 @@ -import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; import { mockMiladyPricingMinimumDepositForRouteTests } from "../helpers/mock-milady-pricing-for-route-tests"; import { jsonRequest } from "./api/route-test-helpers"; +afterAll(() => { + mock.restore(); +}); + const mockRequireAuthOrApiKeyWithOrg = mock(); const mockRequireServiceKey = mock(); const mockAuthenticateWaifuBridge = mock(); diff --git a/packages/tests/unit/milady-google-connector.test.ts b/packages/tests/unit/milady-google-connector.test.ts index c9933dc3e..e048ebe0f 100644 --- a/packages/tests/unit/milady-google-connector.test.ts +++ b/packages/tests/unit/milady-google-connector.test.ts @@ -1,21 +1,20 @@ -import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; import type { OAuthConnection } from "@/lib/services/oauth/types"; -afterAll(() => { - mock.restore(); -}); - const mockListConnections = mock(); const mockGetValidTokenByPlatformWithConnectionId = mock(); const mockInitiateAuth = mock(); const mockRevokeConnection = mock(); -const mockGoogleFetchWithToken = mock(); -const mockGetProvider = mock(); -const mockIsProviderConfigured = mock(); const mockDbLimit = mock(); const mockDbWhere = mock(() => ({ limit: mockDbLimit })); const mockDbFrom = mock(() => ({ where: mockDbWhere })); const mockDbSelect = mock(() => ({ from: mockDbFrom })); +const originalFetch = globalThis.fetch; +const providerEnvKeys = ["GOOGLE_CLIENT_ID", "GOOGLE_CLIENT_SECRET"] as const; +let savedProviderEnv: Record<(typeof providerEnvKeys)[number], string | undefined> = { + GOOGLE_CLIENT_ID: undefined, + GOOGLE_CLIENT_SECRET: undefined, +}; mock.module("drizzle-orm", () => ({ and: (...args: unknown[]) => args, @@ -65,18 +64,6 @@ mock.module("@/lib/services/oauth/oauth-service", () => ({ null, })); -mock.module("@/lib/services/oauth/provider-registry", () => ({ - getProvider: mockGetProvider, - isProviderConfigured: mockIsProviderConfigured, -})); - -mock.module("@/lib/utils/google-mcp-shared", () => ({ - applyTimeZone: (dateTime: string, timeZone: string | undefined) => - timeZone ? { dateTime, timeZone } : { dateTime }, - googleFetchWithToken: mockGoogleFetchWithToken, - sanitizeHeaderValue: (value: string) => value.replace(/[\r\n]/g, ""), -})); - import { disconnectManagedGoogleConnection, fetchManagedGoogleCalendarFeed, @@ -86,6 +73,28 @@ import { sendManagedGoogleReply, } from "@/lib/services/milady-google-connector"; +// Drop the top-level module mocks after importing the service under test so +// they don't leak into later test files loaded by the same Bun process. +mock.restore(); + +function saveProviderEnv() { + savedProviderEnv = { + GOOGLE_CLIENT_ID: process.env.GOOGLE_CLIENT_ID, + GOOGLE_CLIENT_SECRET: process.env.GOOGLE_CLIENT_SECRET, + }; +} + +function restoreProviderEnv() { + for (const key of providerEnvKeys) { + const value = savedProviderEnv[key]; + if (value === undefined) { + delete process.env[key]; + } else { + process.env[key] = value; + } + } +} + function createConnection(overrides: Partial = {}): OAuthConnection { return { id: "conn-google-1", @@ -116,20 +125,17 @@ function createConnection(overrides: Partial = {}): OAuthConnec describe("milady Google connector service", () => { beforeEach(() => { + saveProviderEnv(); + process.env.GOOGLE_CLIENT_ID = "google-client-id"; + process.env.GOOGLE_CLIENT_SECRET = "google-client-secret"; mockListConnections.mockReset(); mockGetValidTokenByPlatformWithConnectionId.mockReset(); mockInitiateAuth.mockReset(); mockRevokeConnection.mockReset(); - mockGoogleFetchWithToken.mockReset(); - mockGetProvider.mockReset(); - mockIsProviderConfigured.mockReset(); mockDbLimit.mockReset(); mockDbWhere.mockClear(); mockDbFrom.mockClear(); mockDbSelect.mockClear(); - - mockGetProvider.mockReturnValue({ id: "google" }); - mockIsProviderConfigured.mockReturnValue(true); mockDbLimit.mockResolvedValue([ { token_expires_at: new Date("2026-04-05T00:00:00.000Z"), @@ -144,6 +150,11 @@ describe("milady Google connector service", () => { }); }); + afterEach(() => { + globalThis.fetch = originalFetch; + restoreProviderEnv(); + }); + test("reports managed Google connector status from the active owner connection", async () => { mockListConnections.mockResolvedValue([createConnection()]); @@ -267,41 +278,42 @@ describe("milady Google connector service", () => { }); test("normalizes Google Calendar events into the Milady managed feed shape", async () => { - mockGoogleFetchWithToken.mockResolvedValueOnce( - new Response( - JSON.stringify({ - items: [ - { - id: "event-1", - summary: "Founder sync", - description: "Review the launch plan", - location: "HQ", - status: "confirmed", - htmlLink: "https://calendar.google.com/event?eid=event-1", - start: { - dateTime: "2026-04-04T10:00:00-07:00", - timeZone: "America/Los_Angeles", - }, - end: { - dateTime: "2026-04-04T10:30:00-07:00", - timeZone: "America/Los_Angeles", - }, - organizer: { - email: "founder@example.com", - displayName: "Founder Example", - }, - attendees: [ - { - email: "teammate@example.com", - displayName: "Teammate", - responseStatus: "accepted", + globalThis.fetch = mock( + async () => + new Response( + JSON.stringify({ + items: [ + { + id: "event-1", + summary: "Founder sync", + description: "Review the launch plan", + location: "HQ", + status: "confirmed", + htmlLink: "https://calendar.google.com/event?eid=event-1", + start: { + dateTime: "2026-04-04T10:00:00-07:00", + timeZone: "America/Los_Angeles", }, - ], - }, - ], - }), - ), - ); + end: { + dateTime: "2026-04-04T10:30:00-07:00", + timeZone: "America/Los_Angeles", + }, + organizer: { + email: "founder@example.com", + displayName: "Founder Example", + }, + attendees: [ + { + email: "teammate@example.com", + displayName: "Teammate", + responseStatus: "accepted", + }, + ], + }, + ], + }), + ), + ) as unknown as typeof fetch; const feed = await fetchManagedGoogleCalendarFeed({ organizationId: "org-1", @@ -327,36 +339,36 @@ describe("milady Google connector service", () => { test("classifies Gmail triage messages using the connected Google identity", async () => { mockListConnections.mockResolvedValue([createConnection()]); - mockGoogleFetchWithToken - .mockResolvedValueOnce( - new Response( + globalThis.fetch = mock(async (url: string | URL | Request) => { + const urlString = url.toString(); + if (urlString.includes("/messages?")) { + return new Response( JSON.stringify({ messages: [{ id: "msg-1", threadId: "thread-1" }], }), - ), - ) - .mockResolvedValueOnce( - new Response( - JSON.stringify({ - id: "msg-1", - threadId: "thread-1", - labelIds: ["INBOX", "UNREAD", "IMPORTANT"], - snippet: "Can you review the plan today?", - internalDate: "1775327400000", - historyId: "history-1", - sizeEstimate: 1234, - payload: { - headers: [ - { name: "Subject", value: "Project sync" }, - { name: "From", value: "CEO Example " }, - { name: "To", value: "founder@example.com" }, - { name: "Reply-To", value: "ceo@example.com" }, - { name: "Message-Id", value: "" }, - ], - }, - }), - ), + ); + } + return new Response( + JSON.stringify({ + id: "msg-1", + threadId: "thread-1", + labelIds: ["INBOX", "UNREAD", "IMPORTANT"], + snippet: "Can you review the plan today?", + internalDate: "1775327400000", + historyId: "history-1", + sizeEstimate: 1234, + payload: { + headers: [ + { name: "Subject", value: "Project sync" }, + { name: "From", value: "CEO Example " }, + { name: "To", value: "founder@example.com" }, + { name: "Reply-To", value: "ceo@example.com" }, + { name: "Message-Id", value: "" }, + ], + }, + }), ); + }) as unknown as typeof fetch; const triage = await fetchManagedGoogleGmailTriage({ organizationId: "org-1", @@ -380,7 +392,8 @@ describe("milady Google connector service", () => { }); test("sends Gmail replies with sanitized RFC822 headers", async () => { - mockGoogleFetchWithToken.mockResolvedValueOnce(new Response(null, { status: 200 })); + const fetchMock = mock(async () => new Response(null, { status: 200 })); + globalThis.fetch = fetchMock as unknown as typeof fetch; await sendManagedGoogleReply({ organizationId: "org-1", @@ -394,12 +407,8 @@ describe("milady Google connector service", () => { references: "", }); - expect(mockGoogleFetchWithToken).toHaveBeenCalledTimes(1); - const [, url, options] = mockGoogleFetchWithToken.mock.calls[0] as [ - string, - string, - { body?: string }, - ]; + expect(fetchMock).toHaveBeenCalledTimes(1); + const [url, options] = fetchMock.mock.calls[0] as [string, { body?: string }]; expect(url).toBe("https://gmail.googleapis.com/gmail/v1/users/me/messages/send"); const payload = JSON.parse(String(options.body)) as { raw: string }; const decoded = Buffer.from(payload.raw, "base64url").toString("utf-8"); diff --git a/packages/tests/unit/milaidy-agent-routes-followups.test.ts b/packages/tests/unit/milaidy-agent-routes-followups.test.ts index 3b700170b..966e043b3 100644 --- a/packages/tests/unit/milaidy-agent-routes-followups.test.ts +++ b/packages/tests/unit/milaidy-agent-routes-followups.test.ts @@ -1,8 +1,12 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; import { mockMiladyPricingMinimumDepositForRouteTests } from "../helpers/mock-milady-pricing-for-route-tests"; import { jsonRequest, routeParams } from "./api/route-test-helpers"; +afterAll(() => { + mock.restore(); +}); + const mockRequireAuthOrApiKeyWithOrg = mock(); const mockRequireServiceKey = mock(); const mockAuthenticateWaifuBridge = mock(); diff --git a/packages/tests/unit/milaidy-sandbox-bridge-security.test.ts b/packages/tests/unit/milaidy-sandbox-bridge-security.test.ts index 52b7813a0..380b1e629 100644 --- a/packages/tests/unit/milaidy-sandbox-bridge-security.test.ts +++ b/packages/tests/unit/milaidy-sandbox-bridge-security.test.ts @@ -1,4 +1,8 @@ -import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; + +afterAll(() => { + mock.restore(); +}); const mockFindRunningSandbox = mock(); const mockFindByIdAndOrg = mock(); diff --git a/packages/tests/unit/milaidy-sandbox-service-followups.test.ts b/packages/tests/unit/milaidy-sandbox-service-followups.test.ts index d5921e2ec..98add4550 100644 --- a/packages/tests/unit/milaidy-sandbox-service-followups.test.ts +++ b/packages/tests/unit/milaidy-sandbox-service-followups.test.ts @@ -1,4 +1,8 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; + +afterAll(() => { + mock.restore(); +}); const mockTransaction = mock(); diff --git a/packages/tests/unit/provisioning-jobs-followups.test.ts b/packages/tests/unit/provisioning-jobs-followups.test.ts index 393abeefe..0f53546c8 100644 --- a/packages/tests/unit/provisioning-jobs-followups.test.ts +++ b/packages/tests/unit/provisioning-jobs-followups.test.ts @@ -1,4 +1,8 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; + +afterAll(() => { + mock.restore(); +}); const mockJobsRepository = { claimPendingJobs: mock(), diff --git a/packages/tests/unit/proxy-pricing.test.ts b/packages/tests/unit/proxy-pricing.test.ts index d271b5fbd..5c7f7f500 100644 --- a/packages/tests/unit/proxy-pricing.test.ts +++ b/packages/tests/unit/proxy-pricing.test.ts @@ -1,4 +1,8 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; + +afterAll(() => { + mock.restore(); +}); const mockCacheGet = mock(); const mockCacheSet = mock(); diff --git a/packages/tests/unit/service-jwt.test.ts b/packages/tests/unit/service-jwt.test.ts index 9cd6d3fb8..562f4b336 100644 --- a/packages/tests/unit/service-jwt.test.ts +++ b/packages/tests/unit/service-jwt.test.ts @@ -2,7 +2,7 @@ * Unit tests for service JWT verification. */ -import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; import * as jose from "jose"; import { isServiceJwtEnabled, verifyServiceJwt } from "@/lib/auth/service-jwt"; @@ -20,6 +20,10 @@ const TEST_SECRET = "test-jwt-secret-for-waifu-core-bridge"; describe("Service JWT Auth", () => { const saved: Record = {}; + afterAll(() => { + mock.restore(); + }); + beforeEach(() => { saved.MILADY_SERVICE_JWT_SECRET = process.env.MILADY_SERVICE_JWT_SECRET; process.env.MILADY_SERVICE_JWT_SECRET = TEST_SECRET; diff --git a/packages/tests/unit/v1-milaidy-provision-route.test.ts b/packages/tests/unit/v1-milaidy-provision-route.test.ts index 868657a47..208b0f4b8 100644 --- a/packages/tests/unit/v1-milaidy-provision-route.test.ts +++ b/packages/tests/unit/v1-milaidy-provision-route.test.ts @@ -1,9 +1,13 @@ -import { beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; import { mockMiladyPricingMinimumDepositForRouteTests } from "../helpers/mock-milady-pricing-for-route-tests"; import { routeParams } from "./api/route-test-helpers"; +afterAll(() => { + mock.restore(); +}); + const mockRequireAuthOrApiKeyWithOrg = mock(); const mockAssertSafeOutboundUrl = mock(); const mockGetAgentForWrite = mock(); diff --git a/tsconfig.json b/tsconfig.json index 53eba3229..7439fdf4f 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,11 +1,7 @@ { "compilerOptions": { "target": "ES2020", - "lib": [ - "dom", - "dom.iterable", - "esnext" - ], + "lib": ["dom", "dom.iterable", "esnext"], "allowJs": true, "skipLibCheck": true, "strict": true, @@ -18,39 +14,21 @@ "jsx": "react-jsx", "incremental": true, "tsBuildInfoFile": ".tsbuildinfo", - "types": [ - "node" - ], + "types": ["node"], "plugins": [ { "name": "next" } ], "paths": { - "@/lib/*": [ - "./packages/lib/*" - ], - "@/db/*": [ - "./packages/db/*" - ], - "@/tests/*": [ - "./packages/tests/*" - ], - "@/types/*": [ - "./packages/types/*" - ], - "@/*": [ - "./*" - ], - "@/components/*": [ - "./packages/ui/src/components/*" - ], - "@elizaos/cloud-ui": [ - "./packages/ui/src/index.ts" - ], - "@elizaos/cloud-ui/*": [ - "./packages/ui/src/*" - ] + "@/lib/*": ["./packages/lib/*"], + "@/db/*": ["./packages/db/*"], + "@/tests/*": ["./packages/tests/*"], + "@/types/*": ["./packages/types/*"], + "@/*": ["./*"], + "@/components/*": ["./packages/ui/src/components/*"], + "@elizaos/cloud-ui": ["./packages/ui/src/index.ts"], + "@elizaos/cloud-ui/*": ["./packages/ui/src/*"] } }, "include": [ From acaee1fe40f7d1f8aef80f9347988196b3583013 Mon Sep 17 00:00:00 2001 From: Shaw Date: Sun, 5 Apr 2026 03:53:24 -0700 Subject: [PATCH 09/11] cloud: isolate oauth test module mocks --- .../eliza-app/connection-enforcement.test.ts | 13 +++++++------ .../unit/eliza-app/connections-route.test.ts | 19 ++++++++++--------- packages/tests/unit/mcp-google-tools.test.ts | 17 +++++++++-------- packages/tests/unit/mcp-hubspot-tools.test.ts | 16 ++++++++-------- packages/tests/unit/mcp-twitter-tools.test.ts | 14 ++++++++------ 5 files changed, 42 insertions(+), 37 deletions(-) diff --git a/packages/tests/unit/eliza-app/connection-enforcement.test.ts b/packages/tests/unit/eliza-app/connection-enforcement.test.ts index b1b99aa21..38e11d696 100644 --- a/packages/tests/unit/eliza-app/connection-enforcement.test.ts +++ b/packages/tests/unit/eliza-app/connection-enforcement.test.ts @@ -1,4 +1,4 @@ -import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { beforeAll, beforeEach, describe, expect, mock, test } from "bun:test"; const mockGenerateText = mock(); const mockGetConnectedPlatforms = mock(); @@ -44,12 +44,13 @@ mock.module("@/lib/utils/logger", () => ({ }, })); -import { - connectionEnforcementService, - detectProviderFromMessage, -} from "@/lib/services/eliza-app/connection-enforcement"; +let connectionEnforcementService: typeof import("@/lib/services/eliza-app/connection-enforcement").connectionEnforcementService; +let detectProviderFromMessage: typeof import("@/lib/services/eliza-app/connection-enforcement").detectProviderFromMessage; -afterEach(() => { +beforeAll(async () => { + ({ connectionEnforcementService, detectProviderFromMessage } = await import( + "@/lib/services/eliza-app/connection-enforcement" + )); mock.restore(); }); diff --git a/packages/tests/unit/eliza-app/connections-route.test.ts b/packages/tests/unit/eliza-app/connections-route.test.ts index dffc59040..128f33ce6 100644 --- a/packages/tests/unit/eliza-app/connections-route.test.ts +++ b/packages/tests/unit/eliza-app/connections-route.test.ts @@ -1,4 +1,4 @@ -import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; +import { beforeAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; const mockValidateAuthHeader = mock(); @@ -30,11 +30,16 @@ mock.module("@/lib/services/oauth/provider-registry", () => ({ getProvider: mockGetProvider, })); -describe("Eliza App connections routes", () => { - afterAll(() => { - mock.restore(); - }); +let GET: typeof import("@/app/api/eliza-app/connections/route").GET; +let POST: typeof import("@/app/api/eliza-app/connections/[platform]/initiate/route").POST; + +beforeAll(async () => { + ({ GET } = await import("@/app/api/eliza-app/connections/route")); + ({ POST } = await import("@/app/api/eliza-app/connections/[platform]/initiate/route")); + mock.restore(); +}); +describe("Eliza App connections routes", () => { beforeEach(() => { mockValidateAuthHeader.mockReset(); mockListConnections.mockReset(); @@ -48,8 +53,6 @@ describe("Eliza App connections routes", () => { }); test("returns user-scoped Google connection status", async () => { - const { GET } = await import("@/app/api/eliza-app/connections/route"); - mockListConnections.mockResolvedValue([ { id: "conn-1", @@ -85,8 +88,6 @@ describe("Eliza App connections routes", () => { }); test("initiates Google OAuth with Eliza App callback bridge", async () => { - const { POST } = await import("@/app/api/eliza-app/connections/[platform]/initiate/route"); - mockInitiateAuth.mockResolvedValue({ authUrl: "https://accounts.google.com/o/oauth2/v2/auth?state=test-state", state: "test-state", diff --git a/packages/tests/unit/mcp-google-tools.test.ts b/packages/tests/unit/mcp-google-tools.test.ts index 4afede5eb..069c28cae 100644 --- a/packages/tests/unit/mcp-google-tools.test.ts +++ b/packages/tests/unit/mcp-google-tools.test.ts @@ -5,7 +5,7 @@ * Real: all handler logic, helpers, mappers, error formatting. */ -import { afterAll, afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterEach, beforeAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { authContextStorage } from "@/app/api/mcp/lib/context"; import type { ListConnectionsParams, OAuthConnection } from "@/lib/services/oauth/types"; @@ -73,6 +73,8 @@ const mockOAuth = { mock.module("@/lib/services/oauth", () => ({ oauthService: mockOAuth })); +let registerGoogleTools: typeof import("@/app/api/mcp/tools/google").registerGoogleTools; + // ── Test helpers ──────────────────────────────────────────────────────────── type AnyFn = (...args: unknown[]) => unknown; @@ -97,7 +99,6 @@ async function callTool( args: Record = {}, orgId = "org-1", ): Promise { - const { registerGoogleTools } = await import("@/app/api/mcp/tools/google"); let handler: AnyFn | undefined; const mockServer = { registerTool: (n: string, _s: unknown, h: AnyFn) => { @@ -116,11 +117,12 @@ function parse(result: GoogleToolHandlerResult) { // ══════════════════════════════════════════════════════════════════════════════ -describe("Google MCP Tools", () => { - afterAll(() => { - mock.restore(); - }); +beforeAll(async () => { + ({ registerGoogleTools } = await import("@/app/api/mcp/tools/google")); + mock.restore(); +}); +describe("Google MCP Tools", () => { beforeEach(() => { setupMockFetch(); mockOAuth.getValidTokenByPlatform.mockReset(); @@ -146,8 +148,7 @@ describe("Google MCP Tools", () => { describe("Registration", () => { test("exports registerGoogleTools", async () => { - const mod = await import("@/app/api/mcp/tools/google"); - expect(typeof mod.registerGoogleTools).toBe("function"); + expect(typeof registerGoogleTools).toBe("function"); }); test("registers all expected tools", async () => { diff --git a/packages/tests/unit/mcp-hubspot-tools.test.ts b/packages/tests/unit/mcp-hubspot-tools.test.ts index 989852726..5b241efdf 100644 --- a/packages/tests/unit/mcp-hubspot-tools.test.ts +++ b/packages/tests/unit/mcp-hubspot-tools.test.ts @@ -9,7 +9,7 @@ * - Network error handling */ -import { afterAll, afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { afterEach, beforeAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { authContextStorage } from "@/app/api/mcp/lib/context"; import type { GetTokenByPlatformParams, @@ -95,6 +95,8 @@ mock.module("@/lib/services/oauth", () => ({ oauthService: mockOAuthService, })); +let registerHubSpotTools: typeof import("@/app/api/mcp/tools/hubspot").registerHubSpotTools; + // Create mock auth context function createMockAuth(orgId: string = "test-org-123") { return { @@ -106,11 +108,12 @@ function createMockAuth(orgId: string = "test-org-123") { } as any; } -describe("HubSpot MCP Tools", () => { - afterAll(() => { - mock.restore(); - }); +beforeAll(async () => { + ({ registerHubSpotTools } = await import("@/app/api/mcp/tools/hubspot")); + mock.restore(); +}); +describe("HubSpot MCP Tools", () => { beforeEach(() => { setupMockFetch(); mockOAuthService.getValidTokenByPlatform.mockReset(); @@ -127,14 +130,11 @@ describe("HubSpot MCP Tools", () => { describe("Module Registration", () => { test("registerHubSpotTools is exported", async () => { - const { registerHubSpotTools } = await import("@/app/api/mcp/tools/hubspot"); expect(registerHubSpotTools).toBeDefined(); expect(typeof registerHubSpotTools).toBe("function"); }); test("registers all expected tools", async () => { - const { registerHubSpotTools } = await import("@/app/api/mcp/tools/hubspot"); - const registeredTools: string[] = []; const mockServer = { registerTool: (name: string, _schema: any, _handler: any) => { diff --git a/packages/tests/unit/mcp-twitter-tools.test.ts b/packages/tests/unit/mcp-twitter-tools.test.ts index 284ff49c0..7c1fec1be 100644 --- a/packages/tests/unit/mcp-twitter-tools.test.ts +++ b/packages/tests/unit/mcp-twitter-tools.test.ts @@ -5,14 +5,10 @@ * Real: all handler logic, helpers, mappers, error formatting. */ -import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; +import { beforeAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { authContextStorage } from "@/app/api/mcp/lib/context"; import type { OAuthConnection } from "@/lib/services/oauth/types"; -afterAll(() => { - mock.restore(); -}); - function twitterOAuthFixture( o: Partial & Pick, ): OAuthConnection { @@ -226,6 +222,8 @@ const mockOAuth = { mock.module("@/lib/services/oauth", () => ({ oauthService: mockOAuth })); +let registerTwitterTools: typeof import("@/app/api/mcp/tools/twitter").registerTwitterTools; + // ── Test helpers ───────────────────────────────────────────────────────────── type AnyFn = (...args: unknown[]) => unknown; @@ -250,7 +248,6 @@ async function callTool( args: Record = {}, orgId = "org-1", ): Promise { - const { registerTwitterTools } = await import("@/app/api/mcp/tools/twitter"); let handler: AnyFn | undefined; const mockServer = { registerTool: (n: string, _s: unknown, h: AnyFn) => { @@ -269,6 +266,11 @@ function parse(result: TwitterToolHandlerResult) { // ══════════════════════════════════════════════════════════════════════════════ +beforeAll(async () => { + ({ registerTwitterTools } = await import("@/app/api/mcp/tools/twitter")); + mock.restore(); +}); + describe("Twitter MCP Tools", () => { beforeEach(() => { resetTwitterMocks(); From 4dfdfe156f713918bf1375445deb0bbca10dd9fb Mon Sep 17 00:00:00 2001 From: Shaw Date: Sun, 5 Apr 2026 04:57:48 -0700 Subject: [PATCH 10/11] cloud: harden oauth test isolation --- .../unit/blooio-automation-service.test.ts | 12 ++++++-- .../unit/eliza-app/connections-route.test.ts | 13 --------- .../unit/milady-google-connector.test.ts | 17 +++++++---- .../tests/unit/milady-google-routes.test.ts | 28 ++++++++++++------- 4 files changed, 38 insertions(+), 32 deletions(-) diff --git a/packages/tests/unit/blooio-automation-service.test.ts b/packages/tests/unit/blooio-automation-service.test.ts index 1eb79c335..3e7e61efc 100644 --- a/packages/tests/unit/blooio-automation-service.test.ts +++ b/packages/tests/unit/blooio-automation-service.test.ts @@ -12,6 +12,9 @@ import { beforeEach, describe, expect, it, mock } from "bun:test"; +type BlooioAutomationService = + typeof import("@/lib/services/blooio-automation")["blooioAutomationService"]; + // Mock external dependencies const _mockSecretsService = { create: mock(() => Promise.resolve()), @@ -28,9 +31,12 @@ const _mockValidateBlooioChatId = mock(() => true); const shouldRunBlooioTests = Boolean(process.env.DATABASE_URL) && process.env.SKIP_DB_DEPENDENT !== "1"; -const blooioAutomationService = shouldRunBlooioTests - ? (await import("@/lib/services/blooio-automation")).blooioAutomationService - : null; +// `describe.skipIf` ensures these tests never execute when the service is unavailable. +const blooioAutomationService = ( + shouldRunBlooioTests + ? (await import("@/lib/services/blooio-automation")).blooioAutomationService + : undefined +) as BlooioAutomationService; describe.skipIf(!shouldRunBlooioTests)("BlooioAutomationService", () => { const testOrgId = "11111111-1111-1111-1111-111111111111"; diff --git a/packages/tests/unit/eliza-app/connections-route.test.ts b/packages/tests/unit/eliza-app/connections-route.test.ts index 128f33ce6..0348c1f67 100644 --- a/packages/tests/unit/eliza-app/connections-route.test.ts +++ b/packages/tests/unit/eliza-app/connections-route.test.ts @@ -4,14 +4,6 @@ import { NextRequest } from "next/server"; const mockValidateAuthHeader = mock(); const mockListConnections = mock(); const mockInitiateAuth = mock(); -const mockGetProvider = mock((platform: string) => - platform === "google" - ? { - id: "google", - name: "Google", - } - : null, -); mock.module("@/lib/services/eliza-app", () => ({ elizaAppSessionService: { @@ -26,10 +18,6 @@ mock.module("@/lib/services/oauth", () => ({ }, })); -mock.module("@/lib/services/oauth/provider-registry", () => ({ - getProvider: mockGetProvider, -})); - let GET: typeof import("@/app/api/eliza-app/connections/route").GET; let POST: typeof import("@/app/api/eliza-app/connections/[platform]/initiate/route").POST; @@ -44,7 +32,6 @@ describe("Eliza App connections routes", () => { mockValidateAuthHeader.mockReset(); mockListConnections.mockReset(); mockInitiateAuth.mockReset(); - mockGetProvider.mockClear(); mockValidateAuthHeader.mockResolvedValue({ userId: "user-1", diff --git a/packages/tests/unit/milady-google-connector.test.ts b/packages/tests/unit/milady-google-connector.test.ts index e048ebe0f..640fc614b 100644 --- a/packages/tests/unit/milady-google-connector.test.ts +++ b/packages/tests/unit/milady-google-connector.test.ts @@ -209,7 +209,7 @@ describe("milady Google connector service", () => { mockListConnections.mockResolvedValue([ createConnection({ id: "conn-google-agent", - userId: null, + userId: undefined, connectionRole: "agent", email: "milady-agent@example.com", username: "milady-agent", @@ -392,7 +392,13 @@ describe("milady Google connector service", () => { }); test("sends Gmail replies with sanitized RFC822 headers", async () => { - const fetchMock = mock(async () => new Response(null, { status: 200 })); + let sentUrl: string | undefined; + let sentBody: string | undefined; + const fetchMock = mock(async (url: string | URL | Request, options?: RequestInit) => { + sentUrl = url.toString(); + sentBody = typeof options?.body === "string" ? options.body : undefined; + return new Response(null, { status: 200 }); + }); globalThis.fetch = fetchMock as unknown as typeof fetch; await sendManagedGoogleReply({ @@ -408,9 +414,8 @@ describe("milady Google connector service", () => { }); expect(fetchMock).toHaveBeenCalledTimes(1); - const [url, options] = fetchMock.mock.calls[0] as [string, { body?: string }]; - expect(url).toBe("https://gmail.googleapis.com/gmail/v1/users/me/messages/send"); - const payload = JSON.parse(String(options.body)) as { raw: string }; + expect(sentUrl).toBe("https://gmail.googleapis.com/gmail/v1/users/me/messages/send"); + const payload = JSON.parse(String(sentBody)) as { raw: string }; const decoded = Buffer.from(payload.raw, "base64url").toString("utf-8"); expect(decoded).toContain("To: founder@example.com"); expect(decoded).toContain("Cc: ops@example.com"); @@ -424,7 +429,7 @@ describe("milady Google connector service", () => { mockListConnections.mockResolvedValue([ createConnection({ id: "conn-google-agent", - userId: null, + userId: undefined, connectionRole: "agent", }), createConnection({ diff --git a/packages/tests/unit/milady-google-routes.test.ts b/packages/tests/unit/milady-google-routes.test.ts index bf0651dc8..e8a43e80e 100644 --- a/packages/tests/unit/milady-google-routes.test.ts +++ b/packages/tests/unit/milady-google-routes.test.ts @@ -1,4 +1,4 @@ -import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; +import { beforeAll, beforeEach, describe, expect, mock, test } from "bun:test"; import { NextRequest } from "next/server"; import { jsonRequest } from "./api/route-test-helpers"; @@ -34,19 +34,27 @@ mock.module("@/lib/services/milady-google-connector", () => ({ }, })); -import { POST as postCalendarEvent } from "@/app/api/v1/milady/google/calendar/events/route"; -import { GET as getCalendarFeed } from "@/app/api/v1/milady/google/calendar/feed/route"; -import { POST as postConnectInitiate } from "@/app/api/v1/milady/google/connect/initiate/route"; -import { POST as postDisconnect } from "@/app/api/v1/milady/google/disconnect/route"; +let postCalendarEvent: typeof import("@/app/api/v1/milady/google/calendar/events/route").POST; +let getCalendarFeed: typeof import("@/app/api/v1/milady/google/calendar/feed/route").GET; +let postConnectInitiate: typeof import("@/app/api/v1/milady/google/connect/initiate/route").POST; +let postDisconnect: typeof import("@/app/api/v1/milady/google/disconnect/route").POST; +let postReplySend: typeof import("@/app/api/v1/milady/google/gmail/reply-send/route").POST; +let getGmailTriage: typeof import("@/app/api/v1/milady/google/gmail/triage/route").GET; +let getStatus: typeof import("@/app/api/v1/milady/google/status/route").GET; -afterAll(() => { +beforeAll(async () => { + ({ POST: postCalendarEvent } = await import("@/app/api/v1/milady/google/calendar/events/route")); + ({ GET: getCalendarFeed } = await import("@/app/api/v1/milady/google/calendar/feed/route")); + ({ POST: postConnectInitiate } = await import( + "@/app/api/v1/milady/google/connect/initiate/route" + )); + ({ POST: postDisconnect } = await import("@/app/api/v1/milady/google/disconnect/route")); + ({ POST: postReplySend } = await import("@/app/api/v1/milady/google/gmail/reply-send/route")); + ({ GET: getGmailTriage } = await import("@/app/api/v1/milady/google/gmail/triage/route")); + ({ GET: getStatus } = await import("@/app/api/v1/milady/google/status/route")); mock.restore(); }); -import { POST as postReplySend } from "@/app/api/v1/milady/google/gmail/reply-send/route"; -import { GET as getGmailTriage } from "@/app/api/v1/milady/google/gmail/triage/route"; -import { GET as getStatus } from "@/app/api/v1/milady/google/status/route"; - describe("Milady managed Google routes", () => { beforeEach(() => { mockRequireAuthOrApiKeyWithOrg.mockReset(); From 31ebdcadb6ba721a815fc2ad938b3aa7847a2baf Mon Sep 17 00:00:00 2001 From: Shaw Date: Sun, 5 Apr 2026 05:02:16 -0700 Subject: [PATCH 11/11] cloud: clean up managed google connector test mocks --- .../unit/milady-google-connector.test.ts | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/packages/tests/unit/milady-google-connector.test.ts b/packages/tests/unit/milady-google-connector.test.ts index 640fc614b..36424796e 100644 --- a/packages/tests/unit/milady-google-connector.test.ts +++ b/packages/tests/unit/milady-google-connector.test.ts @@ -43,27 +43,6 @@ mock.module("@/lib/services/oauth", () => ({ }, })); -mock.module("@/lib/services/oauth/oauth-service", () => ({ - getPreferredActiveConnection: ( - connections: OAuthConnection[], - userId?: string, - connectionRole?: "owner" | "agent", - ) => - connections.find( - (connection) => - connection.status === "active" && - (!userId || connection.userId === userId) && - (!connectionRole || connection.connectionRole === connectionRole), - ) ?? - connections.find( - (connection) => - connection.status === "active" && - (!connectionRole || connection.connectionRole === connectionRole), - ) ?? - connections.find((connection) => connection.status === "active") ?? - null, -})); - import { disconnectManagedGoogleConnection, fetchManagedGoogleCalendarFeed,