diff --git a/packages/control-plane/src/session/durable-object.ts b/packages/control-plane/src/session/durable-object.ts index 8e47a52..017e645 100644 --- a/packages/control-plane/src/session/durable-object.ts +++ b/packages/control-plane/src/session/durable-object.ts @@ -33,13 +33,10 @@ import { import { createSourceControlProvider as createSourceControlProviderImpl, resolveScmProviderFromEnv, - SourceControlProviderError, type SourceControlProvider, type SourceControlAuthContext, type GitPushSpec, } from "../source-control"; -import { resolveHeadBranchForPr } from "../source-control/branch-resolution"; -import { generateBranchName, type ManualPullRequestArtifactMetadata } from "@open-inspect/shared"; import { DEFAULT_MODEL, isValidModel, @@ -62,6 +59,7 @@ import type { import type { SessionRow, ParticipantRow, ArtifactRow, SandboxRow, SandboxCommand } from "./types"; import { SessionRepository } from "./repository"; import { SessionWebSocketManagerImpl, type SessionWebSocketManager } from "./websocket-manager"; +import { SessionPullRequestService } from "./pull-request-service"; import { RepoSecretsStore } from "../db/repo-secrets"; import { GlobalSecretsStore } from "../db/global-secrets"; import { mergeSecrets } from "../db/secrets-validation"; @@ -1985,13 +1983,43 @@ export class SessionDO extends DurableObject { this.log.error("Failed to notify slack-bot after retries", { message_id: messageId }); } + /** + * Get the prompting participant for PR creation. + * Returns the participant who triggered the currently processing message. + */ + private async getPromptingParticipantForPR(): Promise< + | { participant: ParticipantRow; error?: never; status?: never } + | { participant?: never; error: string; status: number } + > { + const processingMessage = this.repository.getProcessingMessageAuthor(); + + if (!processingMessage) { + this.log.warn("PR creation failed: no processing message found"); + return { + error: "No active prompt found. PR creation must be triggered by a user prompt.", + status: 400, + }; + } + + const participant = this.repository.getParticipantById(processingMessage.author_id); + + if (!participant) { + this.log.warn("PR creation failed: participant not found", { + participantId: processingMessage.author_id, + }); + return { error: "User not found. Please re-authenticate.", status: 401 }; + } + + return { participant }; + } + /** * Check if a participant's GitHub token is expired. * Returns true if expired or will expire within buffer time. */ private isGitHubTokenExpired(participant: ParticipantRow, bufferMs = 60000): boolean { if (!participant.github_token_expires_at) { - return false; // No expiration set, assume valid + return false; } return Date.now() + bufferMs >= participant.github_token_expires_at; } @@ -2057,50 +2085,15 @@ export class SessionDO extends DurableObject { } } - /** - * Get the prompting participant for PR creation. - * Returns the participant who triggered the currently processing message. - */ - private async getPromptingParticipantForPR(): Promise< - | { participant: ParticipantRow; error?: never; status?: never } - | { participant?: never; error: string; status: number } - > { - // Find the currently processing message - const processingMessage = this.repository.getProcessingMessageAuthor(); - - if (!processingMessage) { - this.log.warn("PR creation failed: no processing message found"); - return { - error: "No active prompt found. PR creation must be triggered by a user prompt.", - status: 400, - }; - } - - const participantId = processingMessage.author_id; - - // Get the participant record - const participant = this.repository.getParticipantById(participantId); - - if (!participant) { - this.log.warn("PR creation failed: participant not found", { participantId }); - return { error: "User not found. Please re-authenticate.", status: 401 }; - } - - return { participant }; - } - /** * Resolve the prompting participant's OAuth credentials for API-based PR creation. - * Returns `auth: null` when no user OAuth token is available (manual PR fallback). + * Returns auth: null only when user OAuth is not configured; returns an HTTP error for token failures. */ - private async resolvePromptingUserAuthForPR(participant: ParticipantRow): Promise< - | { - participant: ParticipantRow; - auth: SourceControlAuthContext | null; - error?: never; - status?: never; - } - | { participant?: never; auth?: never; error: string; status: number } + private async resolvePromptingUserAuthForPR( + participant: ParticipantRow + ): Promise< + | { auth: SourceControlAuthContext | null; error?: never; status?: never } + | { auth?: never; error: string; status: number } > { let resolvedParticipant = participant; @@ -2108,7 +2101,7 @@ export class SessionDO extends DurableObject { this.log.info("PR creation: prompting user has no OAuth token, using manual fallback", { user_id: resolvedParticipant.user_id, }); - return { participant: resolvedParticipant, auth: null }; + return { auth: null }; } if (this.isGitHubTokenExpired(resolvedParticipant)) { @@ -2120,6 +2113,9 @@ export class SessionDO extends DurableObject { if (refreshed) { resolvedParticipant = refreshed; } else { + this.log.warn("GitHub token refresh failed, returning auth error", { + user_id: resolvedParticipant.user_id, + }); return { error: "Your GitHub token has expired and could not be refreshed. Please re-authenticate.", @@ -2129,7 +2125,7 @@ export class SessionDO extends DurableObject { } if (!resolvedParticipant.github_access_token_encrypted) { - return { participant: resolvedParticipant, auth: null }; + return { auth: null }; } try { @@ -2139,7 +2135,6 @@ export class SessionDO extends DurableObject { ); return { - participant: resolvedParticipant, auth: { authType: "oauth", token: accessToken, @@ -2504,9 +2499,7 @@ export class SessionDO extends DurableObject { /** * Handle PR creation request. - * 1. Resolve prompting participant and branch metadata - * 2. Push branch to remote via provider push auth - * 3. Create PR via OAuth token, or return manual PR URL fallback + * Resolves prompting participant and auth in DO, then delegates PR orchestration. */ private async handleCreatePR(request: Request): Promise { const body = (await request.json()) as { @@ -2530,168 +2523,54 @@ export class SessionDO extends DurableObject { } const promptingParticipant = promptingParticipantResult.participant; - this.log.info("Creating PR", { user_id: promptingParticipant.user_id }); + const authResolution = await this.resolvePromptingUserAuthForPR(promptingParticipant); + if ("error" in authResolution) { + return Response.json({ error: authResolution.error }, { status: authResolution.status }); + } - try { - const sessionId = session.session_name || session.id; - const generatedHeadBranch = generateBranchName(sessionId); - - const initialArtifacts = this.repository.listArtifacts(); - const existingPrArtifact = initialArtifacts.find((artifact) => artifact.type === "pr"); - if (existingPrArtifact) { - return Response.json( - { error: "A pull request has already been created for this session." }, - { status: 409 } - ); - } + const sessionId = session.session_name || session.id; + const webAppUrl = this.env.WEB_APP_URL || this.env.WORKER_URL || ""; + const sessionUrl = webAppUrl + "/session/" + sessionId; - // Generate push auth via provider app credentials (not user token) - // User token (if available) is only used for PR API call below - let pushAuth; - try { - pushAuth = await this.sourceControlProvider.generatePushAuth(); - this.log.info("Generated fresh push auth token"); - } catch (err) { - this.log.error("Failed to generate push auth", { - error: err instanceof Error ? err : String(err), + const pullRequestService = new SessionPullRequestService({ + repository: this.repository, + sourceControlProvider: this.sourceControlProvider, + log: this.log, + generateId: () => generateId(), + pushBranchToRemote: (headBranch, pushSpec) => this.pushBranchToRemote(headBranch, pushSpec), + broadcastArtifactCreated: (artifact) => { + this.broadcast({ + type: "artifact_created", + artifact, }); - const errorMessage = - err instanceof SourceControlProviderError - ? err.message - : "Failed to generate push authentication"; - return Response.json({ error: errorMessage }, { status: 500 }); - } - - // Resolve repository metadata with app auth so this still works for Slack sessions - const appAuth: SourceControlAuthContext = { - authType: "app", - token: pushAuth.token, - }; - const repoInfo = await this.sourceControlProvider.getRepository(appAuth, { - owner: session.repo_owner, - name: session.repo_name, - }); - const baseBranch = body.baseBranch || repoInfo.defaultBranch; - const branchResolution = resolveHeadBranchForPr({ - requestedHeadBranch: body.headBranch, - sessionBranchName: session.branch_name, - generatedBranchName: generatedHeadBranch, - baseBranch, - }); - const headBranch = branchResolution.headBranch; - this.log.info("Resolved PR head branch", { - requested_head_branch: body.headBranch ?? null, - session_branch_name: session.branch_name, - generated_head_branch: generatedHeadBranch, - resolved_head_branch: headBranch, - resolution_source: branchResolution.source, - base_branch: baseBranch, - }); - const pushSpec = this.sourceControlProvider.buildGitPushSpec({ - owner: session.repo_owner, - name: session.repo_name, - sourceRef: "HEAD", - targetBranch: headBranch, - auth: pushAuth, - force: true, - }); - - // Push branch to remote via sandbox (session-layer coordination) - const pushResult = await this.pushBranchToRemote(headBranch, pushSpec); - - if (!pushResult.success) { - return Response.json({ error: pushResult.error }, { status: 500 }); - } - - // Update session with branch name after push succeeds - this.repository.updateSessionBranch(session.id, headBranch); - - // Re-check artifacts after async work to avoid stale reads on retries/interleaving. - const latestArtifacts = this.repository.listArtifacts(); - const latestPrArtifact = latestArtifacts.find((artifact) => artifact.type === "pr"); - if (latestPrArtifact) { - return Response.json( - { error: "A pull request has already been created for this session." }, - { status: 409 } - ); - } - - const authResolution = await this.resolvePromptingUserAuthForPR(promptingParticipant); - if ("error" in authResolution) { - return this.buildManualPrFallbackResponse( - session, - headBranch, - baseBranch, - latestArtifacts, - authResolution.error - ); - } - - if (!authResolution.auth) { - return this.buildManualPrFallbackResponse(session, headBranch, baseBranch, latestArtifacts); - } - - // Append session link footer to agent's PR body - const webAppUrl = this.env.WEB_APP_URL || this.env.WORKER_URL || ""; - const sessionUrl = `${webAppUrl}/session/${sessionId}`; - const fullBody = body.body + `\n\n---\n*Created with [Open-Inspect](${sessionUrl})*`; - - // Create the PR via provider (using the prompting user's OAuth token) - const prResult = await this.sourceControlProvider.createPullRequest(authResolution.auth, { - repository: repoInfo, - title: body.title, - body: fullBody, - sourceBranch: headBranch, - targetBranch: baseBranch, - }); + }, + }); - // Store the PR as an artifact - const artifactId = generateId(); - const now = Date.now(); - this.repository.createArtifact({ - id: artifactId, - type: "pr", - url: prResult.webUrl, - metadata: JSON.stringify({ - number: prResult.id, - state: prResult.state, - head: headBranch, - base: baseBranch, - }), - createdAt: now, - }); + const result = await pullRequestService.createPullRequest({ + ...body, + promptingUserId: promptingParticipant.user_id, + promptingAuth: authResolution.auth, + sessionUrl, + }); - // Broadcast PR creation to all clients - this.broadcast({ - type: "artifact_created", - artifact: { - id: artifactId, - type: "pr", - url: prResult.webUrl, - prNumber: prResult.id, - }, - }); + if (result.kind === "error") { + return Response.json({ error: result.error }, { status: result.status }); + } + if (result.kind === "manual") { return Response.json({ - prNumber: prResult.id, - prUrl: prResult.webUrl, - state: prResult.state, - }); - } catch (error) { - this.log.error("PR creation failed", { - error: error instanceof Error ? error : String(error), + status: "manual", + createPrUrl: result.createPrUrl, + headBranch: result.headBranch, + baseBranch: result.baseBranch, }); - - // Handle SourceControlProviderError with HTTP status - if (error instanceof SourceControlProviderError) { - return Response.json({ error: error.message }, { status: error.httpStatus || 500 }); - } - - return Response.json( - { error: error instanceof Error ? error.message : "Failed to create PR" }, - { status: 500 } - ); } + + return Response.json({ + prNumber: result.prNumber, + prUrl: result.prUrl, + state: result.state, + }); } private parseArtifactMetadata( @@ -2712,121 +2591,6 @@ export class SessionDO extends DurableObject { } } - private getExistingManualBranchArtifact( - artifacts: ArtifactRow[], - headBranch: string - ): { artifact: ArtifactRow; metadata: Record } | null { - for (const artifact of artifacts) { - if (artifact.type !== "branch") { - continue; - } - - const metadata = this.parseArtifactMetadata(artifact); - if (!metadata) { - continue; - } - - if (metadata.mode === "manual_pr" && metadata.head === headBranch) { - return { artifact, metadata }; - } - } - - return null; - } - - private getCreatePrUrlFromManualArtifact( - existing: { artifact: ArtifactRow; metadata: Record }, - fallbackUrl: string - ): string { - const metadataUrl = existing.metadata.createPrUrl; - if (typeof metadataUrl === "string" && metadataUrl.length > 0) { - return metadataUrl; - } - - if (existing.artifact.url && existing.artifact.url.length > 0) { - return existing.artifact.url; - } - - return fallbackUrl; - } - - private buildManualPrFallbackResponse( - session: SessionRow, - headBranch: string, - baseBranch: string, - artifacts: ArtifactRow[], - reason?: string - ): Response { - const manualCreatePrUrl = this.sourceControlProvider.buildManualPullRequestUrl({ - owner: session.repo_owner, - name: session.repo_name, - sourceBranch: headBranch, - targetBranch: baseBranch, - }); - - const existingManualArtifact = this.getExistingManualBranchArtifact(artifacts, headBranch); - if (existingManualArtifact) { - const createPrUrl = this.getCreatePrUrlFromManualArtifact( - existingManualArtifact, - manualCreatePrUrl - ); - this.log.info("Using manual PR fallback", { - head_branch: headBranch, - base_branch: baseBranch, - session_id: session.session_name || session.id, - existing_artifact_id: existingManualArtifact.artifact.id, - reason: reason ?? "missing_oauth_token", - }); - return Response.json({ - status: "manual", - createPrUrl, - headBranch, - baseBranch, - }); - } - - const artifactId = generateId(); - const now = Date.now(); - const metadata: ManualPullRequestArtifactMetadata = { - head: headBranch, - base: baseBranch, - mode: "manual_pr", - createPrUrl: manualCreatePrUrl, - provider: this.sourceControlProvider.name, - }; - this.repository.createArtifact({ - id: artifactId, - type: "branch", - url: manualCreatePrUrl, - metadata: JSON.stringify(metadata), - createdAt: now, - }); - - this.broadcast({ - type: "artifact_created", - artifact: { - id: artifactId, - type: "branch", - url: manualCreatePrUrl, - }, - }); - - this.log.info("Using manual PR fallback", { - head_branch: headBranch, - base_branch: baseBranch, - session_id: session.session_name || session.id, - artifact_id: artifactId, - reason: reason ?? "missing_oauth_token", - }); - - return Response.json({ - status: "manual", - createPrUrl: manualCreatePrUrl, - headBranch, - baseBranch, - }); - } - /** * Generate a WebSocket authentication token for a participant. * diff --git a/packages/control-plane/src/session/pull-request-service.test.ts b/packages/control-plane/src/session/pull-request-service.test.ts new file mode 100644 index 0000000..0e54b81 --- /dev/null +++ b/packages/control-plane/src/session/pull-request-service.test.ts @@ -0,0 +1,250 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import type { Logger } from "../logger"; +import type { SourceControlProvider } from "../source-control"; +import type { ArtifactRow, SessionRow } from "./types"; +import { + SessionPullRequestService, + type CreatePullRequestInput, + type PullRequestRepository, + type PullRequestServiceDeps, +} from "./pull-request-service"; + +function createMockLogger(): Logger { + return { + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + child: vi.fn(() => createMockLogger()), + }; +} + +function createSession(overrides: Partial = {}): SessionRow { + return { + id: "session-1", + session_name: "session-name-1", + title: null, + repo_owner: "acme", + repo_name: "web", + repo_id: 123, + repo_default_branch: "main", + branch_name: null, + base_sha: null, + current_sha: null, + opencode_session_id: null, + model: "anthropic/claude-sonnet-4-5", + reasoning_effort: null, + status: "active", + created_at: 1, + updated_at: 1, + ...overrides, + }; +} + +function createMockProvider() { + return { + name: "github", + generatePushAuth: vi.fn(async () => ({ authType: "app", token: "app-token" as const })), + getRepository: vi.fn(async () => ({ + owner: "acme", + name: "web", + fullName: "acme/web", + defaultBranch: "main", + isPrivate: true, + providerRepoId: 123, + })), + createPullRequest: vi.fn(async () => ({ + id: 42, + webUrl: "https://github.com/acme/web/pull/42", + apiUrl: "https://api.github.com/repos/acme/web/pulls/42", + state: "open" as const, + sourceBranch: "open-inspect/session-name-1", + targetBranch: "main", + })), + buildManualPullRequestUrl: vi.fn( + (config: { sourceBranch: string; targetBranch: string }) => + `https://github.com/acme/web/pull/new/${config.targetBranch}...${config.sourceBranch}` + ), + buildGitPushSpec: vi.fn((config: { targetBranch: string }) => ({ + remoteUrl: "https://example.invalid/repo.git", + redactedRemoteUrl: "https://example.invalid/.git", + refspec: `HEAD:refs/heads/${config.targetBranch}`, + targetBranch: config.targetBranch, + force: true, + })), + } as unknown as SourceControlProvider; +} + +function createInput(overrides: Partial = {}): CreatePullRequestInput { + return { + title: "Test PR", + body: "Body text", + promptingUserId: "user-1", + promptingAuth: null, + sessionUrl: "https://app.example.com/session/session-name-1", + ...overrides, + }; +} + +function createTestHarness() { + const log = createMockLogger(); + const provider = createMockProvider(); + const artifacts: ArtifactRow[] = []; + let session: SessionRow | null = createSession(); + + const repository: PullRequestRepository = { + getSession: () => session, + updateSessionBranch: (sessionId, branchName) => { + if (session && session.id === sessionId) { + session = { ...session, branch_name: branchName }; + } + }, + listArtifacts: () => [...artifacts], + createArtifact: (data) => { + artifacts.unshift({ + id: data.id, + type: data.type, + url: data.url, + metadata: data.metadata, + created_at: data.createdAt, + } as ArtifactRow); + }, + }; + + let idCounter = 0; + const deps: PullRequestServiceDeps = { + repository, + sourceControlProvider: provider, + log, + generateId: () => `id-${++idCounter}`, + pushBranchToRemote: vi.fn(async () => ({ success: true as const })), + broadcastArtifactCreated: vi.fn(), + }; + + const service = new SessionPullRequestService(deps); + + return { + service, + deps, + provider, + artifacts, + setSession: (next: SessionRow | null) => { + session = next; + }, + }; +} + +describe("SessionPullRequestService", () => { + let harness: ReturnType; + + beforeEach(() => { + harness = createTestHarness(); + }); + + it("returns 404 when session is missing", async () => { + harness.setSession(null); + + const result = await harness.service.createPullRequest(createInput()); + + expect(result).toEqual({ kind: "error", status: 404, error: "Session not found" }); + }); + + it("returns 409 when PR artifact already exists", async () => { + harness.artifacts.push({ + id: "artifact-pr-existing", + type: "pr", + url: "https://github.com/acme/web/pull/1", + metadata: null, + created_at: Date.now(), + }); + + const result = await harness.service.createPullRequest(createInput()); + + expect(result).toEqual({ + kind: "error", + status: 409, + error: "A pull request has already been created for this session.", + }); + expect(harness.provider.generatePushAuth).not.toHaveBeenCalled(); + }); + + it("returns 500 when push to remote fails", async () => { + harness.deps.pushBranchToRemote = vi.fn(async () => ({ + success: false as const, + error: "Failed to push branch: timeout", + })); + harness.service = new SessionPullRequestService(harness.deps); + + const result = await harness.service.createPullRequest( + createInput({ promptingAuth: { authType: "oauth", token: "user-token" } }) + ); + + expect(result).toEqual({ + kind: "error", + status: 500, + error: "Failed to push branch: timeout", + }); + }); + + it("returns manual fallback when prompting auth is unavailable", async () => { + const result = await harness.service.createPullRequest(createInput({ promptingAuth: null })); + + expect(result.kind).toBe("manual"); + if (result.kind === "manual") { + expect(result.createPrUrl).toContain("/pull/new/"); + expect(result.headBranch).toBe("open-inspect/session-name-1"); + expect(result.baseBranch).toBe("main"); + } + expect(harness.deps.broadcastArtifactCreated).toHaveBeenCalledTimes(1); + }); + + it("creates PR with OAuth token and stores PR artifact", async () => { + const result = await harness.service.createPullRequest( + createInput({ promptingAuth: { authType: "oauth", token: "user-token" } }) + ); + + expect(result).toEqual({ + kind: "created", + prNumber: 42, + prUrl: "https://github.com/acme/web/pull/42", + state: "open", + }); + expect(harness.provider.createPullRequest).toHaveBeenCalledTimes(1); + const createPrCall = (harness.provider.createPullRequest as ReturnType).mock + .calls[0]; + expect(createPrCall[0]).toEqual({ authType: "oauth", token: "user-token" }); + expect(createPrCall[1].body).toContain( + "*Created with [Open-Inspect](https://app.example.com/session/session-name-1)*" + ); + expect(harness.deps.broadcastArtifactCreated).toHaveBeenCalledWith({ + id: "id-1", + type: "pr", + url: "https://github.com/acme/web/pull/42", + prNumber: 42, + }); + }); + + it("reuses existing manual artifact URL for same branch", async () => { + harness.artifacts.push({ + id: "branch-artifact-1", + type: "branch", + url: "https://github.com/acme/web/pull/new/main...open-inspect/session-name-1", + metadata: JSON.stringify({ + mode: "manual_pr", + head: "open-inspect/session-name-1", + createPrUrl: "https://existing.example.com/manual-pr", + }), + created_at: Date.now(), + }); + + const result = await harness.service.createPullRequest(createInput({ promptingAuth: null })); + + expect(result).toEqual({ + kind: "manual", + createPrUrl: "https://existing.example.com/manual-pr", + headBranch: "open-inspect/session-name-1", + baseBranch: "main", + }); + expect(harness.deps.broadcastArtifactCreated).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/control-plane/src/session/pull-request-service.ts b/packages/control-plane/src/session/pull-request-service.ts new file mode 100644 index 0000000..6af5642 --- /dev/null +++ b/packages/control-plane/src/session/pull-request-service.ts @@ -0,0 +1,351 @@ +import { generateBranchName, type ManualPullRequestArtifactMetadata } from "@open-inspect/shared"; +import type { Logger } from "../logger"; +import { resolveHeadBranchForPr } from "../source-control/branch-resolution"; +import { + SourceControlProviderError, + type SourceControlProvider, + type SourceControlAuthContext, + type GitPushAuthContext, + type GitPushSpec, +} from "../source-control"; +import type { ArtifactRow, SessionRow } from "./types"; + +/** + * Inputs required to create a PR once caller identity/auth are already resolved. + */ +export interface CreatePullRequestInput { + title: string; + body: string; + baseBranch?: string; + headBranch?: string; + promptingUserId: string; + promptingAuth: SourceControlAuthContext | null; + sessionUrl: string; +} + +export type CreatePullRequestResult = + | { + kind: "created"; + prNumber: number; + prUrl: string; + state: "open" | "closed" | "merged" | "draft"; + } + | { kind: "manual"; createPrUrl: string; headBranch: string; baseBranch: string } + | { kind: "error"; status: number; error: string }; + +export type PushBranchResult = { success: true } | { success: false; error: string }; + +/** + * Session persistence operations required by pull request orchestration. + */ +export interface PullRequestRepository { + getSession(): SessionRow | null; + updateSessionBranch(sessionId: string, branchName: string): void; + listArtifacts(): ArtifactRow[]; + createArtifact(data: { + id: string; + type: "pr" | "branch"; + url: string | null; + metadata: string | null; + createdAt: number; + }): void; +} + +/** + * Durable-object adapters that bridge runtime concerns into the service. + */ +export interface PullRequestServiceDeps { + repository: PullRequestRepository; + sourceControlProvider: SourceControlProvider; + log: Logger; + generateId: () => string; + pushBranchToRemote: (headBranch: string, pushSpec: GitPushSpec) => Promise; + broadcastArtifactCreated: (artifact: { + id: string; + type: "pr" | "branch"; + url: string; + prNumber?: number; + }) => void; +} + +/** + * Orchestrates branch push and PR creation for a session. + * Participant lookup and token resolution are handled by SessionDO. + */ +export class SessionPullRequestService { + constructor(private readonly deps: PullRequestServiceDeps) {} + + /** + * Creates a pull request when OAuth auth is available, or falls back + * to a manual PR URL artifact when user OAuth cannot be used. + */ + async createPullRequest(input: CreatePullRequestInput): Promise { + const session = this.deps.repository.getSession(); + if (!session) { + return { kind: "error", status: 404, error: "Session not found" }; + } + + this.deps.log.info("Creating PR", { user_id: input.promptingUserId }); + + try { + const sessionId = session.session_name || session.id; + const generatedHeadBranch = generateBranchName(sessionId); + + const initialArtifacts = this.deps.repository.listArtifacts(); + const existingPrArtifact = initialArtifacts.find((artifact) => artifact.type === "pr"); + if (existingPrArtifact) { + return { + kind: "error", + status: 409, + error: "A pull request has already been created for this session.", + }; + } + + let pushAuth: GitPushAuthContext; + try { + pushAuth = await this.deps.sourceControlProvider.generatePushAuth(); + this.deps.log.info("Generated fresh push auth token"); + } catch (error) { + this.deps.log.error("Failed to generate push auth", { + error: error instanceof Error ? error : String(error), + }); + return { + kind: "error", + status: 500, + error: + error instanceof SourceControlProviderError + ? error.message + : "Failed to generate push authentication", + }; + } + + const appAuth: SourceControlAuthContext = { + authType: "app", + token: pushAuth.token, + }; + + const repoInfo = await this.deps.sourceControlProvider.getRepository(appAuth, { + owner: session.repo_owner, + name: session.repo_name, + }); + const baseBranch = input.baseBranch || repoInfo.defaultBranch; + const branchResolution = resolveHeadBranchForPr({ + requestedHeadBranch: input.headBranch, + sessionBranchName: session.branch_name, + generatedBranchName: generatedHeadBranch, + baseBranch, + }); + const headBranch = branchResolution.headBranch; + this.deps.log.info("Resolved PR head branch", { + requested_head_branch: input.headBranch ?? null, + session_branch_name: session.branch_name, + generated_head_branch: generatedHeadBranch, + resolved_head_branch: headBranch, + resolution_source: branchResolution.source, + base_branch: baseBranch, + }); + const pushSpec = this.deps.sourceControlProvider.buildGitPushSpec({ + owner: session.repo_owner, + name: session.repo_name, + sourceRef: "HEAD", + targetBranch: headBranch, + auth: pushAuth, + force: true, + }); + + const pushResult = await this.deps.pushBranchToRemote(headBranch, pushSpec); + if (!pushResult.success) { + return { kind: "error", status: 500, error: pushResult.error }; + } + + this.deps.repository.updateSessionBranch(session.id, headBranch); + + const latestArtifacts = this.deps.repository.listArtifacts(); + const latestPrArtifact = latestArtifacts.find((artifact) => artifact.type === "pr"); + if (latestPrArtifact) { + return { + kind: "error", + status: 409, + error: "A pull request has already been created for this session.", + }; + } + + if (!input.promptingAuth) { + return this.buildManualPrFallbackResult(session, headBranch, baseBranch, latestArtifacts); + } + + const fullBody = input.body + `\n\n---\n*Created with [Open-Inspect](${input.sessionUrl})*`; + + const prResult = await this.deps.sourceControlProvider.createPullRequest( + input.promptingAuth, + { + repository: repoInfo, + title: input.title, + body: fullBody, + sourceBranch: headBranch, + targetBranch: baseBranch, + } + ); + + const artifactId = this.deps.generateId(); + const now = Date.now(); + this.deps.repository.createArtifact({ + id: artifactId, + type: "pr", + url: prResult.webUrl, + metadata: JSON.stringify({ + number: prResult.id, + state: prResult.state, + head: headBranch, + base: baseBranch, + }), + createdAt: now, + }); + + this.deps.broadcastArtifactCreated({ + id: artifactId, + type: "pr", + url: prResult.webUrl, + prNumber: prResult.id, + }); + + return { + kind: "created", + prNumber: prResult.id, + prUrl: prResult.webUrl, + state: prResult.state, + }; + } catch (error) { + this.deps.log.error("PR creation failed", { + error: error instanceof Error ? error : String(error), + }); + + if (error instanceof SourceControlProviderError) { + return { + kind: "error", + status: error.httpStatus || 500, + error: error.message, + }; + } + + return { + kind: "error", + status: 500, + error: error instanceof Error ? error.message : "Failed to create PR", + }; + } + } + + /** + * Reuses an existing manual PR artifact URL for the same branch when present. + */ + private findExistingManualPrUrl( + artifacts: ArtifactRow[], + headBranch: string, + fallbackUrl: string + ): { artifactId: string; createPrUrl: string } | null { + for (const artifact of artifacts) { + if (artifact.type !== "branch" || !artifact.metadata) { + continue; + } + + try { + const metadata = JSON.parse(artifact.metadata) as Record; + if (metadata.mode !== "manual_pr" || metadata.head !== headBranch) { + continue; + } + + const metadataUrl = metadata.createPrUrl; + let createPrUrl = fallbackUrl; + + if (typeof metadataUrl === "string" && metadataUrl.length > 0) { + createPrUrl = metadataUrl; + } else if (artifact.url && artifact.url.length > 0) { + createPrUrl = artifact.url; + } + + return { + artifactId: artifact.id, + createPrUrl, + }; + } catch (error) { + this.deps.log.warn("Invalid artifact metadata JSON", { + artifact_id: artifact.id, + error: error instanceof Error ? error.message : String(error), + }); + } + } + + return null; + } + + /** + * Creates or reuses a manual PR fallback artifact and returns manual response payload. + */ + private buildManualPrFallbackResult( + session: SessionRow, + headBranch: string, + baseBranch: string, + artifacts: ArtifactRow[] + ): CreatePullRequestResult { + const manualCreatePrUrl = this.deps.sourceControlProvider.buildManualPullRequestUrl({ + owner: session.repo_owner, + name: session.repo_name, + sourceBranch: headBranch, + targetBranch: baseBranch, + }); + + const existing = this.findExistingManualPrUrl(artifacts, headBranch, manualCreatePrUrl); + if (existing) { + this.deps.log.info("Using manual PR fallback", { + head_branch: headBranch, + base_branch: baseBranch, + session_id: session.session_name || session.id, + existing_artifact_id: existing.artifactId, + }); + return { + kind: "manual", + createPrUrl: existing.createPrUrl, + headBranch, + baseBranch, + }; + } + + const artifactId = this.deps.generateId(); + const now = Date.now(); + const metadata: ManualPullRequestArtifactMetadata = { + head: headBranch, + base: baseBranch, + mode: "manual_pr", + createPrUrl: manualCreatePrUrl, + provider: this.deps.sourceControlProvider.name, + }; + this.deps.repository.createArtifact({ + id: artifactId, + type: "branch", + url: manualCreatePrUrl, + metadata: JSON.stringify(metadata), + createdAt: now, + }); + + this.deps.broadcastArtifactCreated({ + id: artifactId, + type: "branch", + url: manualCreatePrUrl, + }); + + this.deps.log.info("Using manual PR fallback", { + head_branch: headBranch, + base_branch: baseBranch, + session_id: session.session_name || session.id, + artifact_id: artifactId, + }); + + return { + kind: "manual", + createPrUrl: manualCreatePrUrl, + headBranch, + baseBranch, + }; + } +} diff --git a/packages/control-plane/test/integration/create-pr.test.ts b/packages/control-plane/test/integration/create-pr.test.ts new file mode 100644 index 0000000..8953a37 --- /dev/null +++ b/packages/control-plane/test/integration/create-pr.test.ts @@ -0,0 +1,275 @@ +import { describe, expect, it } from "vitest"; +import { env, runInDurableObject } from "cloudflare:test"; +import type { SourceControlProvider } from "../../src/source-control"; +import type { SessionDO } from "../../src/session/durable-object"; +import { initSession, queryDO, seedMessage } from "./helpers"; + +describe("POST /internal/create-pr", () => { + it("returns 404 when session is not initialized", async () => { + const id = env.SESSION.newUniqueId(); + const stub = env.SESSION.get(id); + + const res = await stub.fetch("http://internal/internal/create-pr", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + title: "Test PR", + body: "Body from integration test", + }), + }); + + expect(res.status).toBe(404); + const body = await res.json<{ error: string }>(); + expect(body.error).toBe("Session not found"); + }); + + it("returns 400 when no processing message exists", async () => { + const { stub } = await initSession({ userId: "user-1" }); + + const res = await stub.fetch("http://internal/internal/create-pr", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + title: "Test PR", + body: "Body from integration test", + }), + }); + + expect(res.status).toBe(400); + const body = await res.json<{ error: string }>(); + expect(body.error).toBe( + "No active prompt found. PR creation must be triggered by a user prompt." + ); + }); + + it("returns 401 when processing message author cannot be resolved", async () => { + const { stub } = await initSession({ userId: "user-1" }); + + const participants = await queryDO<{ id: string }>( + stub, + "SELECT id FROM participants WHERE user_id = ?", + "user-1" + ); + const ownerParticipantId = participants[0]?.id; + if (!ownerParticipantId) { + throw new Error("Expected owner participant"); + } + + await seedMessage(stub, { + id: "msg-processing-missing-author", + authorId: ownerParticipantId, + content: "Create a PR", + source: "web", + status: "processing", + createdAt: Date.now() - 1000, + startedAt: Date.now() - 500, + }); + + await runInDurableObject(stub, (instance: SessionDO) => { + instance.ctx.storage.sql.exec("PRAGMA foreign_keys = OFF"); + instance.ctx.storage.sql.exec( + "UPDATE messages SET author_id = ? WHERE id = ?", + "participant-does-not-exist", + "msg-processing-missing-author" + ); + instance.ctx.storage.sql.exec("PRAGMA foreign_keys = ON"); + }); + + const res = await stub.fetch("http://internal/internal/create-pr", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + title: "Test PR", + body: "Body from integration test", + }), + }); + + expect(res.status).toBe(401); + const body = await res.json<{ error: string }>(); + expect(body.error).toBe("User not found. Please re-authenticate."); + }); + it("returns 401 when expired OAuth token cannot be refreshed", async () => { + const { stub } = await initSession({ userId: "user-1" }); + + const participants = await queryDO<{ id: string }>( + stub, + "SELECT id FROM participants WHERE user_id = ?", + "user-1" + ); + const ownerParticipantId = participants[0]?.id; + if (!ownerParticipantId) { + throw new Error("Expected owner participant"); + } + + await seedMessage(stub, { + id: "msg-processing-expired-token", + authorId: ownerParticipantId, + content: "Create a PR", + source: "web", + status: "processing", + createdAt: Date.now() - 1000, + startedAt: Date.now() - 500, + }); + + await runInDurableObject(stub, (instance: SessionDO) => { + instance.ctx.storage.sql.exec( + "UPDATE participants SET github_access_token_encrypted = ?, github_refresh_token_encrypted = ?, github_token_expires_at = ? WHERE id = ?", + "invalid-access-token", + "invalid-refresh-token", + Date.now() - 60_000, + ownerParticipantId + ); + }); + + const res = await stub.fetch("http://internal/internal/create-pr", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + title: "Test PR", + body: "Body from integration test", + }), + }); + + expect(res.status).toBe(401); + const body = await res.json<{ error: string }>(); + expect(body.error).toBe( + "Your GitHub token has expired and could not be refreshed. Please re-authenticate." + ); + }); + + it("returns manual fallback and stores branch artifact when prompting user has no OAuth token", async () => { + const { stub } = await initSession({ userId: "user-1" }); + + const participants = await queryDO<{ id: string }>( + stub, + "SELECT id FROM participants WHERE user_id = ?", + "user-1" + ); + const ownerParticipantId = participants[0]?.id; + if (!ownerParticipantId) { + throw new Error("Expected owner participant"); + } + + await seedMessage(stub, { + id: "msg-processing-1", + authorId: ownerParticipantId, + content: "Create a PR", + source: "web", + status: "processing", + createdAt: Date.now() - 1000, + startedAt: Date.now() - 500, + }); + + await runInDurableObject(stub, (instance: SessionDO) => { + const mockProvider = { + name: "github", + generatePushAuth: async () => ({ authType: "app", token: "push-token" as const }), + getRepository: async () => ({ + owner: "acme", + name: "web-app", + fullName: "acme/web-app", + defaultBranch: "main", + isPrivate: true, + providerRepoId: 12345, + }), + createPullRequest: async () => { + throw new Error("createPullRequest should not be called for manual fallback"); + }, + buildManualPullRequestUrl: (config: { + owner: string; + name: string; + sourceBranch: string; + targetBranch: string; + }) => + `https://github.com/${config.owner}/${config.name}/pull/new/${config.targetBranch}...${config.sourceBranch}`, + buildGitPushSpec: (config: { targetBranch: string }) => ({ + remoteUrl: "https://example.invalid/repo.git", + redactedRemoteUrl: "https://example.invalid/.git", + refspec: `HEAD:refs/heads/${config.targetBranch}`, + targetBranch: config.targetBranch, + force: true, + }), + } as unknown as SourceControlProvider; + + ( + instance as unknown as { _sourceControlProvider: SourceControlProvider | null } + )._sourceControlProvider = mockProvider; + }); + + const res = await stub.fetch("http://internal/internal/create-pr", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + title: "Test PR", + body: "Body from integration test", + }), + }); + + expect(res.status).toBe(200); + const body = await res.json<{ + status: string; + createPrUrl: string; + headBranch: string; + baseBranch: string; + }>(); + expect(body.status).toBe("manual"); + expect(body.createPrUrl).toContain("/pull/new/"); + expect(body.headBranch.length).toBeGreaterThan(0); + expect(body.baseBranch).toBe("main"); + + const artifacts = await queryDO<{ type: string; metadata: string | null }>( + stub, + "SELECT type, metadata FROM artifacts ORDER BY created_at DESC LIMIT 1" + ); + expect(artifacts[0]?.type).toBe("branch"); + expect(artifacts[0]?.metadata).toContain('"mode":"manual_pr"'); + }); + + it("returns 409 when a PR artifact already exists", async () => { + const { stub } = await initSession({ userId: "user-1" }); + + const participants = await queryDO<{ id: string }>( + stub, + "SELECT id FROM participants WHERE user_id = ?", + "user-1" + ); + const ownerParticipantId = participants[0]?.id; + if (!ownerParticipantId) { + throw new Error("Expected owner participant"); + } + + await seedMessage(stub, { + id: "msg-processing-2", + authorId: ownerParticipantId, + content: "Create a PR", + source: "web", + status: "processing", + createdAt: Date.now() - 1000, + startedAt: Date.now() - 500, + }); + + await runInDurableObject(stub, (instance: SessionDO) => { + instance.ctx.storage.sql.exec( + "INSERT INTO artifacts (id, type, url, metadata, created_at) VALUES (?, ?, ?, ?, ?)", + "artifact-pr-existing", + "pr", + "https://github.com/acme/web-app/pull/1", + JSON.stringify({ number: 1 }), + Date.now() + ); + }); + + const res = await stub.fetch("http://internal/internal/create-pr", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + title: "Test PR", + body: "Body from integration test", + }), + }); + + expect(res.status).toBe(409); + const body = await res.json<{ error: string }>(); + expect(body.error).toBe("A pull request has already been created for this session."); + }); +});