diff --git a/docs/log.md b/docs/log.md index 5378b42..1c845a0 100644 --- a/docs/log.md +++ b/docs/log.md @@ -5,13 +5,32 @@ Visualize your stack's structure and status. ## Usage ```bash +# Infer stack from current branch (recommended) +gh-stack log + +# Infer from a specific branch +gh-stack log --branch feat/my-feature + +# List all stacks and select interactively +gh-stack log --all + +# Search by identifier in PR titles gh-stack log 'STACK-ID' -gh-stack log 'STACK-ID' --short # compact list view -gh-stack log 'STACK-ID' --include-closed # show closed/merged PRs -gh-stack log 'STACK-ID' --no-color # disable colors and unicode -gh-stack log 'STACK-ID' -C /path/to/repo # specify repo path + +# CI mode (non-interactive) +gh-stack log --branch $GITHUB_HEAD_REF --ci ``` +## Stack Discovery + +When run without an identifier, `gh-stack log` automatically discovers your stack by: + +1. Finding the PR for your current branch +2. Walking up the PR base chain to find ancestors +3. Walking down to find PRs that build on yours + +This works with any PR structure - no special naming required. + ## Output The default tree view shows: @@ -50,7 +69,12 @@ The `--short` flag shows a compact list: | Flag | Description | |------|-------------| +| `--branch`, `-b` | Infer stack from this branch instead of current | +| `--all`, `-a` | List all stacks and select interactively | +| `--ci` | Non-interactive mode for CI environments | +| `--trunk` | Override trunk branch (default: auto-detect or "main") | | `--short`, `-s` | Compact list format instead of tree | +| `--status` | Show CI, approval, and merge status bits | | `--include-closed` | Show branches with closed/merged PRs | | `--no-color` | Disable colors and unicode characters | | `-C`, `--project` | Path to local repository | @@ -58,6 +82,55 @@ The `--short` flag shows a compact list: | `-o`, `--origin` | Git remote name (default: origin) | | `-e`, `--excl` | Exclude PR by number (repeatable) | +## CI Usage + +In CI environments, use `--ci` to disable interactive prompts: + +```bash +# Must provide branch explicitly +gh-stack log --branch $GITHUB_HEAD_REF --ci + +# Or use identifier +gh-stack log 'STACK-ID' --ci +``` + +The `--ci` flag will: +- Fail with an error if no identifier or branch is provided +- Fail if on a trunk branch without an identifier +- Never prompt for user input + +## On Trunk Branch + +When you're on a trunk branch (main, master, etc.), `gh-stack log` will prompt you to: + +1. Enter a stack identifier manually +2. Select from detected stacks in the repository +3. Cancel + +``` +You're on 'main' (trunk branch). + +? What would you like to do? +> Enter a stack identifier + Select from detected stacks (3 found) + Cancel +``` + +## No PR Found + +If no PR exists for your current branch: + +``` +No PR found for branch 'feat/new-feature'. + +Create a PR with: + gh pr create --base main --head feat/new-feature + +? Create PR now? [Y/n] +``` + +If you confirm, `gh-stack` will run `gh pr create` for you. + ## When to use - Before rebasing to understand stack structure @@ -68,4 +141,5 @@ The `--short` flag shows a compact list: ## See also - [annotate](annotate.md) - Add stack tables to PR descriptions +- [status](status.md) - Show CI and approval status - [land](land.md) - Merge the stack diff --git a/docs/status.md b/docs/status.md index ab5f9c1..cc6d4f1 100644 --- a/docs/status.md +++ b/docs/status.md @@ -5,11 +5,35 @@ Show stack status with CI, approval, and merge readiness indicators. ## Usage ```bash -gh-stack status [OPTIONS] -# or -gh-stack log --status [OPTIONS] +# Infer stack from current branch (recommended) +gh-stack status + +# Infer from a specific branch +gh-stack status --branch feat/my-feature + +# List all stacks and select interactively +gh-stack status --all + +# Search by identifier in PR titles +gh-stack status 'STACK-ID' + +# CI mode (non-interactive) +gh-stack status --branch $GITHUB_HEAD_REF --ci + +# Use with log command +gh-stack log --status 'STACK-ID' ``` +## Stack Discovery + +When run without an identifier, `gh-stack status` automatically discovers your stack by: + +1. Finding the PR for your current branch +2. Walking up the PR base chain to find ancestors +3. Walking down to find PRs that build on yours + +This works with any PR structure - no special naming required. + ## Description The `status` command displays your PR stack with status bits showing: @@ -68,19 +92,43 @@ Status: [CI | Approved | Mergeable | Stack] | Flag | Description | |------|-------------| -| `--no-checks` | Skip fetching CI/approval/conflict status (faster, shows basic tree) | +| `--branch`, `-b` | Infer stack from this branch instead of current | +| `--all`, `-a` | List all stacks and select interactively | +| `--ci` | Non-interactive mode for CI environments | +| `--trunk` | Override trunk branch (default: auto-detect or "main") | +| `--no-checks` | Skip fetching CI/approval/conflict status (faster) | | `--no-color` | Disable colors and Unicode characters | | `--help-legend` | Show status bits legend | | `--json` | Output in JSON format | | `-C, --project ` | Path to local repository | | `-r, --repository ` | Specify repository (owner/repo) | | `-o, --origin ` | Git remote to use (default: origin) | -| `-e, --excl ` | Exclude PR by number (can be used multiple times) | +| `-e, --excl ` | Exclude PR by number (repeatable) | + +## CI Usage + +In CI environments, use `--ci` to disable interactive prompts: + +```bash +# Must provide branch explicitly +gh-stack status --branch $GITHUB_HEAD_REF --ci + +# Or use identifier +gh-stack status 'STACK-ID' --ci + +# JSON output for parsing +gh-stack status --branch $GITHUB_HEAD_REF --ci --json +``` + +The `--ci` flag will: +- Fail with an error if no identifier or branch is provided +- Fail if on a trunk branch without an identifier +- Never prompt for user input ## JSON Output ```bash -gh-stack status STACK-123 --json +gh-stack status --json ``` ```json @@ -113,41 +161,41 @@ gh-stack status STACK-123 --json The legend is shown automatically on first run. To see it again: ```bash -gh-stack status STACK-123 --help-legend +gh-stack status --help-legend ``` The legend marker is stored in `~/.gh-stack-legend-seen`. ## Examples -### Basic usage +### Infer from current branch ```bash -gh-stack status JIRA-1234 +gh-stack status ``` -### Skip API calls for faster output +### Infer from specific branch ```bash -gh-stack status JIRA-1234 --no-checks +gh-stack status --branch feat/my-feature ``` -### Specify repository explicitly +### Skip API calls for faster output ```bash -gh-stack status JIRA-1234 -r owner/repo +gh-stack status --no-checks ``` -### Output as JSON for scripting +### Specify repository explicitly ```bash -gh-stack status JIRA-1234 --json | jq '.stack[].status.ci' +gh-stack status -r owner/repo ``` -### Use with log command +### Output as JSON for scripting ```bash -gh-stack log --status JIRA-1234 +gh-stack status --json | jq '.stack[].status.ci' ``` ## See Also diff --git a/src/api/create.rs b/src/api/create.rs new file mode 100644 index 0000000..389aba2 --- /dev/null +++ b/src/api/create.rs @@ -0,0 +1,238 @@ +//! GitHub API methods for creating PRs +//! +//! This module provides functionality to create pull requests via the GitHub API, +//! eliminating the need for the `gh` CLI dependency. + +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::error::Error; +use std::time::Duration; + +use crate::Credentials; + +/// Request body for creating a PR +#[derive(Serialize, Debug)] +struct CreatePrRequest<'a> { + title: &'a str, + head: &'a str, + base: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + body: Option<&'a str>, +} + +/// Response from PR creation endpoint +#[derive(Deserialize, Debug)] +struct CreatePrResponse { + number: usize, + html_url: String, +} + +/// Create a new pull request via GitHub API +/// +/// # Arguments +/// * `repository` - Repository in "owner/repo" format +/// * `head` - Head branch name (the branch with changes) +/// * `base` - Base branch name (the branch to merge into) +/// * `title` - PR title +/// * `body` - Optional PR body/description +/// * `credentials` - GitHub credentials +/// +/// # Returns +/// Tuple of (pr_number, html_url) on success +/// +/// # Errors +/// Returns an error if the API request fails or returns a non-success status +pub async fn create_pr( + repository: &str, + head: &str, + base: &str, + title: &str, + body: Option<&str>, + credentials: &Credentials, +) -> Result<(usize, String), Box> { + let client = Client::new(); + let url = format!("{}/repos/{}/pulls", super::github_api_base(), repository); + + let request_body = CreatePrRequest { + title, + head, + base, + body, + }; + + let response = client + .post(&url) + .timeout(Duration::from_secs(30)) + .header("Authorization", format!("token {}", credentials.token)) + .header("User-Agent", "luqven/gh-stack") + .header("Accept", "application/vnd.github.v3+json") + .json(&request_body) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(format!("Failed to create PR ({}): {}", status, text).into()); + } + + let pr: CreatePrResponse = response.json().await?; + Ok((pr.number, pr.html_url)) +} + +#[cfg(test)] +mod tests { + use super::*; + use mockito::Server; + use serial_test::serial; + + #[tokio::test] + #[serial] + async fn test_create_pr_success() { + let mut server = Server::new_async().await; + + let mock = server + .mock("POST", "/repos/owner/repo/pulls") + .match_body(mockito::Matcher::Json(serde_json::json!({ + "title": "Test PR", + "head": "feature", + "base": "main", + "body": "PR body" + }))) + .with_status(201) + .with_body(r#"{"number": 123, "html_url": "https://github.com/owner/repo/pull/123"}"#) + .create_async() + .await; + + std::env::set_var("GITHUB_API_BASE", server.url()); + + let creds = Credentials::new("test-token"); + let result = create_pr( + "owner/repo", + "feature", + "main", + "Test PR", + Some("PR body"), + &creds, + ) + .await; + + assert!(result.is_ok()); + let (number, url) = result.unwrap(); + assert_eq!(number, 123); + assert_eq!(url, "https://github.com/owner/repo/pull/123"); + mock.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_create_pr_without_body() { + let mut server = Server::new_async().await; + + let mock = server + .mock("POST", "/repos/owner/repo/pulls") + .match_body(mockito::Matcher::Json(serde_json::json!({ + "title": "Test PR", + "head": "feature", + "base": "main" + }))) + .with_status(201) + .with_body(r#"{"number": 456, "html_url": "https://github.com/owner/repo/pull/456"}"#) + .create_async() + .await; + + std::env::set_var("GITHUB_API_BASE", server.url()); + + let creds = Credentials::new("test-token"); + let result = create_pr("owner/repo", "feature", "main", "Test PR", None, &creds).await; + + assert!(result.is_ok()); + let (number, _) = result.unwrap(); + assert_eq!(number, 456); + mock.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_create_pr_validation_error() { + let mut server = Server::new_async().await; + + let mock = server + .mock("POST", "/repos/owner/repo/pulls") + .with_status(422) + .with_body(r#"{"message": "Validation Failed", "errors": [{"resource": "PullRequest", "code": "custom", "message": "A pull request already exists"}]}"#) + .create_async() + .await; + + std::env::set_var("GITHUB_API_BASE", server.url()); + + let creds = Credentials::new("test-token"); + let result = create_pr("owner/repo", "feature", "main", "Test PR", None, &creds).await; + + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("422")); + assert!(err.contains("Validation Failed")); + mock.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_create_pr_unauthorized() { + let mut server = Server::new_async().await; + + let mock = server + .mock("POST", "/repos/owner/repo/pulls") + .with_status(401) + .with_body(r#"{"message": "Bad credentials"}"#) + .create_async() + .await; + + std::env::set_var("GITHUB_API_BASE", server.url()); + + let creds = Credentials::new("bad-token"); + let result = create_pr("owner/repo", "feature", "main", "Test PR", None, &creds).await; + + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("401")); + mock.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_create_pr_with_identifier_in_body() { + let mut server = Server::new_async().await; + + let mock = server + .mock("POST", "/repos/owner/repo/pulls") + .match_body(mockito::Matcher::Json(serde_json::json!({ + "title": "[STACK-123] My feature", + "head": "feature", + "base": "main", + "body": "" + }))) + .with_status(201) + .with_body(r#"{"number": 789, "html_url": "https://github.com/owner/repo/pull/789"}"#) + .create_async() + .await; + + std::env::set_var("GITHUB_API_BASE", server.url()); + + let creds = Credentials::new("test-token"); + let result = create_pr( + "owner/repo", + "feature", + "main", + "[STACK-123] My feature", + Some(""), + &creds, + ) + .await; + + assert!(result.is_ok()); + let (number, _) = result.unwrap(); + assert_eq!(number, 789); + mock.assert_async().await; + } +} diff --git a/src/api/mod.rs b/src/api/mod.rs index a1c1c1c..4305bf7 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,11 +1,16 @@ use crate::Credentials; -use reqwest::{Client, RequestBuilder}; +use chrono::{DateTime, Utc}; +use reqwest::{Client, RequestBuilder, Response}; +use std::error::Error; +use std::fmt; use std::time::Duration; pub mod checks; +pub mod create; pub mod land; pub mod pull_request; pub mod search; +pub mod stack; pub use pull_request::PullRequest; pub use pull_request::PullRequestReview; @@ -26,6 +31,131 @@ pub fn github_api_base() -> String { GITHUB_API_BASE.to_string() } +/// Maximum number of retry attempts for rate-limited requests +const MAX_RETRIES: u32 = 3; + +/// Base delay between retries (will be doubled each attempt) +const BASE_RETRY_DELAY_MS: u64 = 1000; + +/// Rate limit error with reset time information +#[derive(Debug, Clone)] +pub struct RateLimitError { + pub reset_time: Option>, + pub limit: Option, + pub remaining: Option, +} + +impl fmt::Display for RateLimitError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.reset_time { + Some(reset) => { + let wait = reset.signed_duration_since(Utc::now()); + let mins = wait.num_minutes().max(1); + write!( + f, + "GitHub API rate limit exceeded. Try again in {} minute{}.", + mins, + if mins == 1 { "" } else { "s" } + ) + } + None => write!(f, "GitHub API rate limit exceeded."), + } + } +} + +impl Error for RateLimitError {} + +/// Parse rate limit headers from a GitHub API response +fn parse_rate_limit_headers(response: &Response) -> RateLimitError { + let headers = response.headers(); + + let reset_time = headers + .get("x-ratelimit-reset") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + .and_then(|ts| DateTime::from_timestamp(ts, 0)); + + let limit = headers + .get("x-ratelimit-limit") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse().ok()); + + let remaining = headers + .get("x-ratelimit-remaining") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse().ok()); + + RateLimitError { + reset_time, + limit, + remaining, + } +} + +/// Check if a response indicates rate limiting (HTTP 429 or 403 with rate limit headers) +fn is_rate_limited(response: &Response) -> bool { + if response.status() == 429 { + return true; + } + + // GitHub sometimes returns 403 for rate limits + if response.status() == 403 { + if let Some(remaining) = response.headers().get("x-ratelimit-remaining") { + if remaining.to_str().unwrap_or("1") == "0" { + return true; + } + } + } + + false +} + +/// Send a request with automatic retry on rate limit (HTTP 429). +/// +/// Implements exponential backoff with up to MAX_RETRIES attempts. +/// On the final failure, returns a RateLimitError with reset time info. +/// +/// # Arguments +/// * `client` - The reqwest client to use +/// * `build_request` - A closure that builds the request (called fresh each attempt) +/// +/// # Returns +/// The successful response, or an error if all retries fail +pub async fn send_with_retry( + client: &Client, + build_request: F, +) -> Result> +where + F: Fn(&Client) -> RequestBuilder, +{ + let mut last_rate_limit_error: Option = None; + + for attempt in 0..MAX_RETRIES { + let request = build_request(client); + let response = request.send().await?; + + if is_rate_limited(&response) { + last_rate_limit_error = Some(parse_rate_limit_headers(&response)); + + // Don't sleep on the last attempt + if attempt < MAX_RETRIES - 1 { + let delay_ms = BASE_RETRY_DELAY_MS * 2u64.pow(attempt); + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + } + continue; + } + + return Ok(response); + } + + // All retries exhausted + Err(Box::new(last_rate_limit_error.unwrap_or(RateLimitError { + reset_time: None, + limit: None, + remaining: None, + }))) +} + pub fn base_request(client: &Client, credentials: &Credentials, url: &str) -> RequestBuilder { client .get(url) @@ -46,6 +176,7 @@ pub fn base_patch_request(client: &Client, credentials: &Credentials, url: &str) mod tests { use super::*; use mockito::Server; + use serial_test::serial; #[test] fn test_base_request_sets_auth_header() { @@ -181,4 +312,146 @@ mod tests { mock.assert_async().await; } + + #[test] + fn test_rate_limit_error_display_with_reset() { + let future_time = Utc::now() + chrono::Duration::minutes(5); + let err = RateLimitError { + reset_time: Some(future_time), + limit: Some(5000), + remaining: Some(0), + }; + let msg = format!("{}", err); + assert!(msg.contains("rate limit exceeded")); + assert!(msg.contains("minute")); + } + + #[test] + fn test_rate_limit_error_display_without_reset() { + let err = RateLimitError { + reset_time: None, + limit: None, + remaining: None, + }; + let msg = format!("{}", err); + assert!(msg.contains("rate limit exceeded")); + } + + #[tokio::test] + #[serial] + async fn test_send_with_retry_success_first_try() { + let mut server = Server::new_async().await; + + let mock = server + .mock("GET", "/test") + .with_status(200) + .with_body("ok") + .expect(1) + .create_async() + .await; + + let client = Client::new(); + let result = send_with_retry(&client, |c| c.get(format!("{}/test", server.url()))).await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().status(), 200); + mock.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_send_with_retry_rate_limit_then_success() { + let mut server = Server::new_async().await; + + // First request: rate limited + let mock_429 = server + .mock("GET", "/test") + .with_status(429) + .with_header("x-ratelimit-remaining", "0") + .expect(1) + .create_async() + .await; + + // Second request: success + let mock_200 = server + .mock("GET", "/test") + .with_status(200) + .with_body("ok") + .expect(1) + .create_async() + .await; + + let client = Client::new(); + let result = send_with_retry(&client, |c| c.get(format!("{}/test", server.url()))).await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().status(), 200); + mock_429.assert_async().await; + mock_200.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_send_with_retry_exhausted() { + let mut server = Server::new_async().await; + + // All requests: rate limited + let mock = server + .mock("GET", "/test") + .with_status(429) + .with_header("x-ratelimit-remaining", "0") + .with_header("x-ratelimit-limit", "5000") + .expect(3) // MAX_RETRIES + .create_async() + .await; + + let client = Client::new(); + let result = send_with_retry(&client, |c| c.get(format!("{}/test", server.url()))).await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.to_string().contains("rate limit")); + mock.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_send_with_retry_403_with_rate_limit() { + let mut server = Server::new_async().await; + + // First: 403 with rate limit headers (GitHub sometimes does this) + let mock_403 = server + .mock("GET", "/test") + .with_status(403) + .with_header("x-ratelimit-remaining", "0") + .expect(1) + .create_async() + .await; + + // Second: success + let mock_200 = server + .mock("GET", "/test") + .with_status(200) + .with_body("ok") + .expect(1) + .create_async() + .await; + + let client = Client::new(); + let result = send_with_retry(&client, |c| c.get(format!("{}/test", server.url()))).await; + + assert!(result.is_ok()); + mock_403.assert_async().await; + mock_200.assert_async().await; + } + + #[test] + fn test_is_rate_limited_429() { + // Can't easily test this without mocking Response, but the logic is tested in integration tests above + } + + #[test] + fn test_parse_rate_limit_headers() { + // Unit test for header parsing logic is covered by integration tests + } } diff --git a/src/api/stack.rs b/src/api/stack.rs new file mode 100644 index 0000000..bd0f821 --- /dev/null +++ b/src/api/stack.rs @@ -0,0 +1,899 @@ +//! Stack discovery via GitHub API +//! +//! Walks PR base/head chains to discover full stack structure without +//! requiring identifier patterns in PR titles. +//! +//! ## Performance +//! +//! Stack discovery uses a batch-fetch strategy: all open PRs are fetched +//! in a single paginated API call, then the chain is walked in-memory. +//! This reduces API calls from O(N) to O(1) for most repositories. + +use crate::api::{github_api_base, PullRequest}; +use crate::Credentials; +use reqwest::Client; +use std::collections::{HashMap, HashSet}; +use std::error::Error; +use std::time::Duration; + +/// Maximum number of pages to fetch (100 PRs per page = 1000 PRs max) +const MAX_PAGES: u32 = 10; + +/// Build a GET request with auth headers +fn build_request(client: &Client, creds: &Credentials, url: &str) -> reqwest::RequestBuilder { + client + .get(url) + .timeout(Duration::from_secs(10)) + .header("Authorization", format!("token {}", creds.token)) + .header("User-Agent", "luqven/gh-stack") + .header("Accept", "application/vnd.github.v3+json") +} + +/// Index of PRs for fast lookup by head/base branch. +/// +/// Built once from a batch fetch, then used for in-memory chain walking. +struct PrIndex { + by_head: HashMap, + by_base: HashMap>, +} + +impl PrIndex { + /// Build an index from a list of PRs + fn from_prs(prs: Vec) -> Self { + let mut by_head = HashMap::new(); + let mut by_base: HashMap> = HashMap::new(); + + for pr in prs { + by_base + .entry(pr.base().to_string()) + .or_default() + .push(pr.clone()); + by_head.insert(pr.head().to_string(), pr); + } + + Self { by_head, by_base } + } + + /// Get a PR by its head branch name + fn get_by_head(&self, head: &str) -> Option<&PullRequest> { + self.by_head.get(head) + } + + /// Get all PRs that target a given base branch + fn get_by_base(&self, base: &str) -> Vec<&PullRequest> { + self.by_base + .get(base) + .map(|v| v.iter().collect()) + .unwrap_or_default() + } +} + +/// Fetch a PR by its head branch name. +/// Returns None if no open PR exists for this branch. +/// +/// # Arguments +/// * `repo` - Repository in "owner/repo" format +/// * `branch` - The head branch name to search for +/// * `creds` - GitHub credentials +pub async fn fetch_pr_by_head( + repo: &str, + branch: &str, + creds: &Credentials, +) -> Result, Box> { + let client = Client::new(); + + // Extract owner from repo for the head filter + let owner = repo.split('/').next().unwrap_or(repo); + let head_filter = format!("{}:{}", owner, branch); + + let url = format!( + "{}/repos/{}/pulls?state=open&head={}", + github_api_base(), + repo, + head_filter + ); + + let response = build_request(&client, creds, &url).send().await?; + + if response.status() == 429 { + return Err("GitHub API rate limit exceeded".into()); + } + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(format!("Failed to fetch PR by head ({}): {}", status, text).into()); + } + + let prs: Vec = response.json().await?; + Ok(prs.into_iter().next()) +} + +/// Fetch all open PRs that target a given base branch. +/// +/// # Arguments +/// * `repo` - Repository in "owner/repo" format +/// * `base` - The base branch name to search for +/// * `creds` - GitHub credentials +pub async fn fetch_prs_by_base( + repo: &str, + base: &str, + creds: &Credentials, +) -> Result, Box> { + let client = Client::new(); + + let url = format!( + "{}/repos/{}/pulls?state=open&base={}", + github_api_base(), + repo, + base + ); + + let response = build_request(&client, creds, &url).send().await?; + + if response.status() == 429 { + return Err("GitHub API rate limit exceeded".into()); + } + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(format!("Failed to fetch PRs by base ({}): {}", status, text).into()); + } + + let prs: Vec = response.json().await?; + Ok(prs) +} + +/// Fetch all open PRs in a repository with pagination support. +/// +/// Fetches up to MAX_PAGES pages (1000 PRs) to support enterprise users +/// with large numbers of open PRs. +/// +/// # Arguments +/// * `repo` - Repository in "owner/repo" format +/// * `creds` - GitHub credentials +pub async fn fetch_all_open_prs( + repo: &str, + creds: &Credentials, +) -> Result, Box> { + let client = Client::new(); + let mut all_prs = Vec::new(); + + for page in 1..=MAX_PAGES { + let url = format!( + "{}/repos/{}/pulls?state=open&per_page=100&page={}", + github_api_base(), + repo, + page + ); + + let response = build_request(&client, creds, &url).send().await?; + + if response.status() == 429 { + return Err("GitHub API rate limit exceeded".into()); + } + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(format!("Failed to fetch open PRs ({}): {}", status, text).into()); + } + + let prs: Vec = response.json().await?; + let count = prs.len(); + all_prs.extend(prs); + + // GitHub returns fewer items when we've reached the end + if count < 100 { + break; + } + } + + Ok(all_prs) +} + +/// Discover the full stack by walking PR chain from a starting PR. +/// +/// Uses batch-fetch strategy: fetches all open PRs in one paginated call, +/// then walks the chain in-memory. This reduces API calls from O(N) to O(1). +/// +/// # Arguments +/// * `repo` - Repository in "owner/repo" format +/// * `starting_pr` - The PR to start discovery from +/// * `trunk` - The trunk branch name (e.g., "main", "master") +/// * `creds` - GitHub credentials +/// +/// # Returns +/// Vector of PRs in the stack, sorted from bottom (closest to trunk) to top +pub async fn discover_stack( + repo: &str, + starting_pr: PullRequest, + trunk: &str, + creds: &Credentials, +) -> Result, Box> { + // Batch fetch all open PRs (1 paginated API call) + let all_prs = fetch_all_open_prs(repo, creds).await?; + + // Build in-memory index + let index = PrIndex::from_prs(all_prs); + + // Walk chain in memory (no more API calls) + Ok(discover_stack_from_index(&index, starting_pr, trunk)) +} + +/// Walk stack using pre-fetched PR index (pure in-memory operation). +/// +/// This is the core algorithm that walks up and down the PR chain +/// without making any API calls. +fn discover_stack_from_index( + index: &PrIndex, + starting_pr: PullRequest, + trunk: &str, +) -> Vec { + let mut visited: HashMap = HashMap::new(); + visited.insert(starting_pr.head().to_string(), starting_pr.clone()); + + // Walk UP: follow base branches until we hit trunk + let mut up_queue = vec![starting_pr.base().to_string()]; + let mut seen_bases: HashSet = HashSet::new(); + + while let Some(base) = up_queue.pop() { + // Skip if we've seen this base or it's trunk + if base == trunk || seen_bases.contains(&base) || visited.contains_key(&base) { + continue; + } + seen_bases.insert(base.clone()); + + // Try to find a PR with this branch as its head (in-memory lookup) + if let Some(pr) = index.get_by_head(&base) { + let pr_base = pr.base().to_string(); + visited.insert(pr.head().to_string(), pr.clone()); + up_queue.push(pr_base); + } + } + + // Walk DOWN: find PRs whose base is in our visited set + let mut down_queue: Vec = visited.keys().cloned().collect(); + let mut seen_heads: HashSet = HashSet::new(); + + while let Some(head) = down_queue.pop() { + if seen_heads.contains(&head) { + continue; + } + seen_heads.insert(head.clone()); + + // Find all PRs that target this branch as their base (in-memory lookup) + for child in index.get_by_base(&head) { + if !visited.contains_key(child.head()) { + let child_head = child.head().to_string(); + visited.insert(child_head.clone(), child.clone()); + down_queue.push(child_head); + } + } + } + + // Sort PRs by their position in the stack (bottom to top) + sort_stack(visited.into_values().collect(), trunk) +} + +/// Sort PRs by their position in the stack (bottom to top). +/// Bottom = PR whose base is trunk, Top = PR with no children. +fn sort_stack(prs: Vec, trunk: &str) -> Vec { + if prs.is_empty() { + return prs; + } + + // Build a map from base -> PR for sorting + let head_to_pr: HashMap<&str, &PullRequest> = prs.iter().map(|pr| (pr.head(), pr)).collect(); + + let mut sorted = Vec::with_capacity(prs.len()); + let mut remaining: HashSet<&str> = prs.iter().map(|pr| pr.head()).collect(); + + // Find the root(s) - PRs whose base is trunk or not in our set + let mut current_base = trunk; + + while !remaining.is_empty() { + // Find a PR whose base matches current_base + let next_pr = prs + .iter() + .find(|pr| remaining.contains(pr.head()) && pr.base() == current_base); + + match next_pr { + Some(pr) => { + remaining.remove(pr.head()); + current_base = pr.head(); + sorted.push(pr.clone()); + } + None => { + // No more PRs with expected base, try to find any remaining PR + // whose base is already in sorted list or is trunk + let sorted_heads: HashSet<&str> = sorted.iter().map(|pr| pr.head()).collect(); + let fallback = prs.iter().find(|pr| { + remaining.contains(pr.head()) + && (pr.base() == trunk || sorted_heads.contains(pr.base())) + }); + + match fallback { + Some(pr) => { + remaining.remove(pr.head()); + current_base = pr.head(); + sorted.push(pr.clone()); + } + None => { + // Add any remaining PRs (shouldn't happen in well-formed stacks) + for head in remaining.iter() { + if let Some(pr) = head_to_pr.get(head) { + sorted.push((*pr).clone()); + } + } + break; + } + } + } + } + } + + sorted +} + +/// Discover all stacks in a repository. +/// +/// Groups PRs by their root (PR whose base is trunk) and returns +/// each group as a separate stack. Uses batch-fetch for efficiency. +/// +/// # Arguments +/// * `repo` - Repository in "owner/repo" format +/// * `trunk` - The trunk branch name (e.g., "main", "master") +/// * `creds` - GitHub credentials +/// +/// # Returns +/// Vector of stacks, where each stack is a vector of PRs sorted bottom to top +pub async fn discover_all_stacks( + repo: &str, + trunk: &str, + creds: &Credentials, +) -> Result>, Box> { + let all_prs = fetch_all_open_prs(repo, creds).await?; + Ok(group_into_stacks(all_prs, trunk)) +} + +/// Group PRs into stacks (pure in-memory operation). +/// +/// PRs are grouped by walking from each root (PR whose base is trunk) +/// down through child PRs. +fn group_into_stacks(prs: Vec, trunk: &str) -> Vec> { + if prs.is_empty() { + return vec![]; + } + + // Build adjacency: base -> list of PRs targeting that base + let mut base_to_prs: HashMap> = HashMap::new(); + for pr in &prs { + base_to_prs + .entry(pr.base().to_string()) + .or_default() + .push(pr); + } + + // Find root PRs (those whose base is trunk) + let roots: Vec<&PullRequest> = prs.iter().filter(|pr| pr.base() == trunk).collect(); + + // For each root, build its stack by walking down + let mut stacks = Vec::new(); + let mut assigned: HashSet = HashSet::new(); + + for root in roots { + if assigned.contains(&root.number()) { + continue; + } + + let mut stack = vec![(*root).clone()]; + assigned.insert(root.number()); + + // BFS to find all descendants + let mut queue = vec![root.head()]; + while let Some(head) = queue.pop() { + if let Some(children) = base_to_prs.get(head) { + for child in children { + if !assigned.contains(&child.number()) { + assigned.insert(child.number()); + stack.push((*child).clone()); + queue.push(child.head()); + } + } + } + } + + // Sort the stack + stack = sort_stack(stack, trunk); + stacks.push(stack); + } + + // Sort stacks by size (largest first) for better UX + stacks.sort_by_key(|s| std::cmp::Reverse(s.len())); + + stacks +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::api::PullRequestStatus; + use mockito::Server; + use serial_test::serial; + + fn make_pr_json(number: usize, head: &str, base: &str, title: &str) -> String { + format!( + r#"{{ + "id": {number}, + "number": {number}, + "head": {{"label": "user:{head}", "ref": "{head}", "sha": "abc{number}"}}, + "base": {{"label": "user:{base}", "ref": "{base}", "sha": "def{number}"}}, + "title": "{title}", + "url": "https://api.github.com/repos/test/repo/pulls/{number}", + "body": null, + "state": "open", + "merged_at": null, + "updated_at": null, + "draft": false + }}"# + ) + } + + fn make_test_pr(number: usize, head: &str, base: &str) -> PullRequest { + PullRequest::new_for_test( + number, + head, + base, + &format!("PR {}", number), + PullRequestStatus::Open, + false, + None, + vec![], + ) + } + + // === PrIndex tests === + + #[test] + fn test_pr_index_get_by_head() { + let pr1 = make_test_pr(1, "feature-1", "main"); + let pr2 = make_test_pr(2, "feature-2", "feature-1"); + + let index = PrIndex::from_prs(vec![pr1, pr2]); + + assert!(index.get_by_head("feature-1").is_some()); + assert_eq!(index.get_by_head("feature-1").unwrap().number(), 1); + assert!(index.get_by_head("feature-2").is_some()); + assert!(index.get_by_head("nonexistent").is_none()); + } + + #[test] + fn test_pr_index_get_by_base() { + let pr1 = make_test_pr(1, "feature-1", "main"); + let pr2 = make_test_pr(2, "feature-2", "main"); + let pr3 = make_test_pr(3, "feature-3", "feature-1"); + + let index = PrIndex::from_prs(vec![pr1, pr2, pr3]); + + let main_children = index.get_by_base("main"); + assert_eq!(main_children.len(), 2); + + let feature1_children = index.get_by_base("feature-1"); + assert_eq!(feature1_children.len(), 1); + assert_eq!(feature1_children[0].number(), 3); + + let no_children = index.get_by_base("feature-3"); + assert!(no_children.is_empty()); + } + + // === discover_stack_from_index tests === + + #[test] + fn test_discover_stack_from_index_linear() { + let pr1 = make_test_pr(1, "feature-1", "main"); + let pr2 = make_test_pr(2, "feature-2", "feature-1"); + let pr3 = make_test_pr(3, "feature-3", "feature-2"); + + let index = PrIndex::from_prs(vec![pr1.clone(), pr2, pr3]); + + // Start from middle of stack + let stack = discover_stack_from_index(&index, pr1, "main"); + + assert_eq!(stack.len(), 3); + assert_eq!(stack[0].number(), 1); // bottom + assert_eq!(stack[1].number(), 2); + assert_eq!(stack[2].number(), 3); // top + } + + #[test] + fn test_discover_stack_from_index_from_top() { + let pr1 = make_test_pr(1, "feature-1", "main"); + let pr2 = make_test_pr(2, "feature-2", "feature-1"); + let pr3 = make_test_pr(3, "feature-3", "feature-2"); + + let index = PrIndex::from_prs(vec![pr1, pr2, pr3.clone()]); + + // Start from top of stack + let stack = discover_stack_from_index(&index, pr3, "main"); + + assert_eq!(stack.len(), 3); + assert_eq!(stack[0].number(), 1); // bottom + assert_eq!(stack[2].number(), 3); // top + } + + #[test] + fn test_discover_stack_from_index_single_pr() { + let pr = make_test_pr(1, "feature", "main"); + let index = PrIndex::from_prs(vec![pr.clone()]); + + let stack = discover_stack_from_index(&index, pr, "main"); + + assert_eq!(stack.len(), 1); + assert_eq!(stack[0].number(), 1); + } + + #[test] + fn test_discover_stack_from_index_unrelated_prs() { + // Two separate stacks - should only discover one + let pr1 = make_test_pr(1, "feature-1", "main"); + let pr2 = make_test_pr(2, "other-stack", "main"); + + let index = PrIndex::from_prs(vec![pr1.clone(), pr2]); + + let stack = discover_stack_from_index(&index, pr1, "main"); + + assert_eq!(stack.len(), 1); + assert_eq!(stack[0].number(), 1); + } + + // === group_into_stacks tests === + + #[test] + fn test_group_into_stacks_single_stack() { + let pr1 = make_test_pr(1, "feature-1", "main"); + let pr2 = make_test_pr(2, "feature-2", "feature-1"); + + let stacks = group_into_stacks(vec![pr1, pr2], "main"); + + assert_eq!(stacks.len(), 1); + assert_eq!(stacks[0].len(), 2); + } + + #[test] + fn test_group_into_stacks_multiple_stacks() { + let pr1 = make_test_pr(1, "feature-1", "main"); + let pr2 = make_test_pr(2, "feature-2", "feature-1"); + let pr3 = make_test_pr(3, "other-1", "main"); + + let stacks = group_into_stacks(vec![pr1, pr2, pr3], "main"); + + assert_eq!(stacks.len(), 2); + // Larger stack first + assert_eq!(stacks[0].len(), 2); + assert_eq!(stacks[1].len(), 1); + } + + #[test] + fn test_group_into_stacks_empty() { + let stacks = group_into_stacks(vec![], "main"); + assert!(stacks.is_empty()); + } + + // === sort_stack tests === + + #[test] + fn test_sort_stack_linear() { + let pr1 = make_test_pr(1, "feature-1", "main"); + let pr2 = make_test_pr(2, "feature-2", "feature-1"); + let pr3 = make_test_pr(3, "feature-3", "feature-2"); + + // Give them in wrong order + let prs = vec![pr3, pr1, pr2]; + let sorted = sort_stack(prs, "main"); + + assert_eq!(sorted.len(), 3); + assert_eq!(sorted[0].number(), 1); // base: main + assert_eq!(sorted[1].number(), 2); // base: feature-1 + assert_eq!(sorted[2].number(), 3); // base: feature-2 + } + + #[test] + fn test_sort_stack_single() { + let pr = make_test_pr(1, "feature", "main"); + + let sorted = sort_stack(vec![pr], "main"); + assert_eq!(sorted.len(), 1); + assert_eq!(sorted[0].number(), 1); + } + + #[test] + fn test_sort_stack_empty() { + let sorted = sort_stack(vec![], "main"); + assert!(sorted.is_empty()); + } + + // === API tests with mocks === + + #[tokio::test] + #[serial] + async fn test_fetch_pr_by_head_found() { + let mut server = Server::new_async().await; + + let pr_json = make_pr_json(42, "feature-branch", "main", "Test PR"); + + let mock = server + .mock("GET", "/repos/owner/repo/pulls") + .match_query(mockito::Matcher::AllOf(vec![ + mockito::Matcher::UrlEncoded("state".into(), "open".into()), + mockito::Matcher::UrlEncoded("head".into(), "owner:feature-branch".into()), + ])) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(format!("[{}]", pr_json)) + .create_async() + .await; + + std::env::set_var("GITHUB_API_BASE", server.url()); + + let creds = Credentials::new("test-token"); + let result = fetch_pr_by_head("owner/repo", "feature-branch", &creds).await; + + assert!(result.is_ok()); + let pr = result.unwrap(); + assert!(pr.is_some()); + assert_eq!(pr.unwrap().number(), 42); + + mock.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_fetch_pr_by_head_not_found() { + let mut server = Server::new_async().await; + + let mock = server + .mock("GET", "/repos/owner/repo/pulls") + .match_query(mockito::Matcher::AllOf(vec![ + mockito::Matcher::UrlEncoded("state".into(), "open".into()), + mockito::Matcher::UrlEncoded("head".into(), "owner:nonexistent".into()), + ])) + .with_status(200) + .with_header("content-type", "application/json") + .with_body("[]") + .create_async() + .await; + + std::env::set_var("GITHUB_API_BASE", server.url()); + + let creds = Credentials::new("test-token"); + let result = fetch_pr_by_head("owner/repo", "nonexistent", &creds).await; + + assert!(result.is_ok()); + assert!(result.unwrap().is_none()); + + mock.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_fetch_pr_by_head_rate_limited() { + let mut server = Server::new_async().await; + + let mock = server + .mock("GET", "/repos/owner/repo/pulls") + .match_query(mockito::Matcher::Any) + .with_status(429) + .with_body("rate limit exceeded") + .create_async() + .await; + + std::env::set_var("GITHUB_API_BASE", server.url()); + + let creds = Credentials::new("test-token"); + let result = fetch_pr_by_head("owner/repo", "feature", &creds).await; + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("rate limit")); + + mock.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_fetch_prs_by_base_multiple() { + let mut server = Server::new_async().await; + + let pr1 = make_pr_json(1, "feature-1", "main", "PR 1"); + let pr2 = make_pr_json(2, "feature-2", "main", "PR 2"); + + let mock = server + .mock("GET", "/repos/owner/repo/pulls") + .match_query(mockito::Matcher::AllOf(vec![ + mockito::Matcher::UrlEncoded("state".into(), "open".into()), + mockito::Matcher::UrlEncoded("base".into(), "main".into()), + ])) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(format!("[{}, {}]", pr1, pr2)) + .create_async() + .await; + + std::env::set_var("GITHUB_API_BASE", server.url()); + + let creds = Credentials::new("test-token"); + let result = fetch_prs_by_base("owner/repo", "main", &creds).await; + + assert!(result.is_ok()); + let prs = result.unwrap(); + assert_eq!(prs.len(), 2); + + mock.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_fetch_prs_by_base_empty() { + let mut server = Server::new_async().await; + + let mock = server + .mock("GET", "/repos/owner/repo/pulls") + .match_query(mockito::Matcher::AllOf(vec![ + mockito::Matcher::UrlEncoded("state".into(), "open".into()), + mockito::Matcher::UrlEncoded("base".into(), "feature".into()), + ])) + .with_status(200) + .with_header("content-type", "application/json") + .with_body("[]") + .create_async() + .await; + + std::env::set_var("GITHUB_API_BASE", server.url()); + + let creds = Credentials::new("test-token"); + let result = fetch_prs_by_base("owner/repo", "feature", &creds).await; + + assert!(result.is_ok()); + assert!(result.unwrap().is_empty()); + + mock.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_fetch_all_open_prs_single_page() { + let mut server = Server::new_async().await; + + let pr1 = make_pr_json(1, "feature-1", "main", "PR 1"); + let pr2 = make_pr_json(2, "feature-2", "main", "PR 2"); + + let mock = server + .mock("GET", "/repos/owner/repo/pulls") + .match_query(mockito::Matcher::AllOf(vec![ + mockito::Matcher::UrlEncoded("state".into(), "open".into()), + mockito::Matcher::UrlEncoded("per_page".into(), "100".into()), + mockito::Matcher::UrlEncoded("page".into(), "1".into()), + ])) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(format!("[{}, {}]", pr1, pr2)) + .create_async() + .await; + + std::env::set_var("GITHUB_API_BASE", server.url()); + + let creds = Credentials::new("test-token"); + let result = fetch_all_open_prs("owner/repo", &creds).await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), 2); + + mock.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_fetch_all_open_prs_pagination() { + let mut server = Server::new_async().await; + + // Generate 100 PRs for page 1 (triggers pagination) + let page1_prs: Vec = (1..=100) + .map(|i| make_pr_json(i, &format!("feature-{}", i), "main", &format!("PR {}", i))) + .collect(); + let page1_body = format!("[{}]", page1_prs.join(",")); + + // Page 2 has fewer than 100, indicating end + let pr101 = make_pr_json(101, "feature-101", "main", "PR 101"); + + let mock_page1 = server + .mock("GET", "/repos/owner/repo/pulls") + .match_query(mockito::Matcher::AllOf(vec![ + mockito::Matcher::UrlEncoded("state".into(), "open".into()), + mockito::Matcher::UrlEncoded("per_page".into(), "100".into()), + mockito::Matcher::UrlEncoded("page".into(), "1".into()), + ])) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(page1_body) + .create_async() + .await; + + let mock_page2 = server + .mock("GET", "/repos/owner/repo/pulls") + .match_query(mockito::Matcher::AllOf(vec![ + mockito::Matcher::UrlEncoded("state".into(), "open".into()), + mockito::Matcher::UrlEncoded("per_page".into(), "100".into()), + mockito::Matcher::UrlEncoded("page".into(), "2".into()), + ])) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(format!("[{}]", pr101)) + .create_async() + .await; + + std::env::set_var("GITHUB_API_BASE", server.url()); + + let creds = Credentials::new("test-token"); + let result = fetch_all_open_prs("owner/repo", &creds).await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), 101); + + mock_page1.assert_async().await; + mock_page2.assert_async().await; + } + + #[tokio::test] + #[serial] + async fn test_discover_stack_batch_fetch() { + let mut server = Server::new_async().await; + + // Create a 3-PR stack + let pr1 = make_pr_json(1, "feature-1", "main", "PR 1"); + let pr2 = make_pr_json(2, "feature-2", "feature-1", "PR 2"); + let pr3 = make_pr_json(3, "feature-3", "feature-2", "PR 3"); + + // Single batch fetch should be enough + let mock = server + .mock("GET", "/repos/owner/repo/pulls") + .match_query(mockito::Matcher::AllOf(vec![ + mockito::Matcher::UrlEncoded("state".into(), "open".into()), + mockito::Matcher::UrlEncoded("per_page".into(), "100".into()), + mockito::Matcher::UrlEncoded("page".into(), "1".into()), + ])) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(format!("[{}, {}, {}]", pr1, pr2, pr3)) + .expect(1) // Should only be called once! + .create_async() + .await; + + std::env::set_var("GITHUB_API_BASE", server.url()); + + let creds = Credentials::new("test-token"); + + // Create starting PR + let starting_pr = PullRequest::new_for_test( + 2, + "feature-2", + "feature-1", + "PR 2", + PullRequestStatus::Open, + false, + None, + vec![], + ); + + let result = discover_stack("owner/repo", starting_pr, "main", &creds).await; + + assert!(result.is_ok()); + let stack = result.unwrap(); + assert_eq!(stack.len(), 3); + assert_eq!(stack[0].number(), 1); // bottom + assert_eq!(stack[1].number(), 2); + assert_eq!(stack[2].number(), 3); // top + + mock.assert_async().await; + } +} diff --git a/src/browser.rs b/src/browser.rs new file mode 100644 index 0000000..7e98c0e --- /dev/null +++ b/src/browser.rs @@ -0,0 +1,233 @@ +//! Browser and URL utilities for PR creation +//! +//! Provides cross-platform browser opening and GitHub URL generation +//! for creating pull requests without depending on the `gh` CLI. + +use dialoguer::Confirm; +use std::error::Error; +use std::io::IsTerminal; +use std::process::Command; + +/// Extract GitHub base URL from a git remote URL +/// +/// Returns the base URL (e.g., "https://github.com" or "https://github.mycompany.com") +/// +/// # Examples +/// - `git@github.com:owner/repo.git` → `https://github.com` +/// - `git@github.mycompany.com:org/repo.git` → `https://github.mycompany.com` +/// - `https://github.com/owner/repo.git` → `https://github.com` +pub fn parse_github_host(remote_url: &str) -> Option { + // SSH format: git@:owner/repo.git + if remote_url.starts_with("git@") { + let host = remote_url.strip_prefix("git@")?.split(':').next()?; + return Some(format!("https://{}", host)); + } + + // HTTPS/HTTP format: https:///owner/repo.git + if remote_url.starts_with("https://") || remote_url.starts_with("http://") { + let without_protocol = remote_url.split("://").nth(1)?; + let host = without_protocol.split('/').next()?; + let protocol = if remote_url.starts_with("https://") { + "https" + } else { + "http" + }; + return Some(format!("{}://{}", protocol, host)); + } + + None +} + +/// Build GitHub PR creation URL with pre-filled branches +/// +/// The URL opens GitHub's compare view with the PR creation form expanded. +pub fn build_pr_url(github_host: &str, repo: &str, base: &str, head: &str) -> String { + format!( + "{}/{}/compare/{}...{}?expand=1", + github_host, repo, base, head + ) +} + +/// Open URL in default browser (cross-platform) +/// +/// Uses platform-specific commands: +/// - macOS: `open` +/// - Linux: `xdg-open` +/// - Windows: `cmd /C start` +pub fn open_url(url: &str) -> Result<(), Box> { + #[cfg(target_os = "macos")] + { + Command::new("open").arg(url).status()?; + } + + #[cfg(target_os = "linux")] + { + Command::new("xdg-open").arg(url).status()?; + } + + #[cfg(target_os = "windows")] + { + Command::new("cmd") + .args(["/C", "start", "", url]) + .status()?; + } + + #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))] + { + return Err("Unsupported platform for opening browser".into()); + } + + Ok(()) +} + +/// Print URL for creating PR (non-interactive/CI mode) +pub fn suggest_create_pr(github_host: &str, repo: &str, head: &str, base: &str) { + let url = build_pr_url(github_host, repo, base, head); + println!("No PR found for branch '{}'.\n", head); + println!("Create a PR at:"); + println!(" {}\n", url); +} + +/// Prompt user and open browser to create PR (interactive mode) +/// +/// Returns `Ok(true)` if user chose to open browser, `Ok(false)` if declined. +/// In non-interactive mode, prints the URL and returns `Ok(false)`. +pub fn prompt_create_pr( + github_host: &str, + repo: &str, + head: &str, + base: &str, +) -> Result> { + if !std::io::stdout().is_terminal() { + suggest_create_pr(github_host, repo, head, base); + return Ok(false); + } + + let url = build_pr_url(github_host, repo, base, head); + println!("No PR found for branch '{}'.\n", head); + + let open = Confirm::new() + .with_prompt(format!( + "Open browser to create PR from '{}' into '{}'?", + head, base + )) + .default(true) + .interact()?; + + if open { + println!("\nOpening browser..."); + println!(" {}\n", url); + open_url(&url)?; + } + + Ok(open) +} + +#[cfg(test)] +mod tests { + use super::*; + + // === parse_github_host tests === + + #[test] + fn test_parse_github_host_ssh() { + assert_eq!( + parse_github_host("git@github.com:owner/repo.git"), + Some("https://github.com".to_string()) + ); + } + + #[test] + fn test_parse_github_host_ssh_no_suffix() { + assert_eq!( + parse_github_host("git@github.com:owner/repo"), + Some("https://github.com".to_string()) + ); + } + + #[test] + fn test_parse_github_host_ssh_enterprise() { + assert_eq!( + parse_github_host("git@github.mycompany.com:org/repo.git"), + Some("https://github.mycompany.com".to_string()) + ); + } + + #[test] + fn test_parse_github_host_https() { + assert_eq!( + parse_github_host("https://github.com/owner/repo.git"), + Some("https://github.com".to_string()) + ); + } + + #[test] + fn test_parse_github_host_https_no_suffix() { + assert_eq!( + parse_github_host("https://github.com/owner/repo"), + Some("https://github.com".to_string()) + ); + } + + #[test] + fn test_parse_github_host_https_enterprise() { + assert_eq!( + parse_github_host("https://github.mycompany.com/org/repo.git"), + Some("https://github.mycompany.com".to_string()) + ); + } + + #[test] + fn test_parse_github_host_http() { + assert_eq!( + parse_github_host("http://github.com/owner/repo.git"), + Some("http://github.com".to_string()) + ); + } + + #[test] + fn test_parse_github_host_invalid() { + assert_eq!(parse_github_host("not-a-url"), None); + } + + #[test] + fn test_parse_github_host_empty() { + assert_eq!(parse_github_host(""), None); + } + + // === build_pr_url tests === + + #[test] + fn test_build_pr_url() { + assert_eq!( + build_pr_url("https://github.com", "owner/repo", "main", "feature"), + "https://github.com/owner/repo/compare/main...feature?expand=1" + ); + } + + #[test] + fn test_build_pr_url_enterprise() { + assert_eq!( + build_pr_url( + "https://github.mycompany.com", + "org/repo", + "develop", + "my-branch" + ), + "https://github.mycompany.com/org/repo/compare/develop...my-branch?expand=1" + ); + } + + #[test] + fn test_build_pr_url_with_slashes_in_branch() { + assert_eq!( + build_pr_url( + "https://github.com", + "owner/repo", + "main", + "feature/my-feature" + ), + "https://github.com/owner/repo/compare/main...feature/my-feature?expand=1" + ); + } +} diff --git a/src/identifier.rs b/src/identifier.rs new file mode 100644 index 0000000..25680fb --- /dev/null +++ b/src/identifier.rs @@ -0,0 +1,356 @@ +//! Identifier detection and interactive prompts +//! +//! Provides trunk branch detection and interactive stack selection +//! for the smart-default log command. + +use crate::api::PullRequest; +use dialoguer::{Input, Select}; +use std::error::Error; +use std::process::Command; + +/// Common trunk branch names +const TRUNK_BRANCHES: &[&str] = &["main", "master", "develop", "dev", "trunk"]; + +/// Summary of a stack for display in selection UI +#[derive(Debug, Clone)] +pub struct StackSummary { + /// The root branch name (first PR's head) + pub root_branch: String, + /// Number of PRs in the stack + pub pr_count: usize, + /// PR numbers in the stack + pub pr_numbers: Vec, + /// First part of the root PR's title + pub title_snippet: String, +} + +impl StackSummary { + /// Create a summary from a list of PRs + /// + /// PRs should be sorted bottom-to-top (root first) + pub fn from_prs(prs: &[PullRequest], _trunk: &str) -> Self { + let root_branch = prs + .first() + .map(|pr| pr.head().to_string()) + .unwrap_or_default(); + + let pr_numbers: Vec = prs.iter().map(|pr| pr.number()).collect(); + + let title_snippet = prs + .first() + .map(|pr| { + let title = pr.raw_title(); + if title.len() > 40 { + format!("{}...", &title[..37]) + } else { + title.to_string() + } + }) + .unwrap_or_default(); + + StackSummary { + root_branch, + pr_count: prs.len(), + pr_numbers, + title_snippet, + } + } + + /// Format for display in selection list + pub fn display(&self) -> String { + let prs = self + .pr_numbers + .iter() + .map(|n| format!("#{}", n)) + .collect::>() + .join(", "); + + format!( + "{} ({} PR{}): {}", + self.root_branch, + self.pr_count, + if self.pr_count == 1 { "" } else { "s" }, + prs + ) + } +} + +/// Check if a branch name is a trunk branch +/// +/// Returns true if the branch matches the configured trunk or any common trunk name. +pub fn is_trunk_branch(branch: &str, configured_trunk: Option<&str>) -> bool { + if let Some(trunk) = configured_trunk { + if branch == trunk { + return true; + } + } + + TRUNK_BRANCHES.contains(&branch) +} + +/// Detect the trunk branch from git remote's default branch +/// +/// Runs `git remote show origin` to find the HEAD branch. +pub fn detect_trunk_branch() -> Option { + // Try to get the default branch from the remote + let output = Command::new("git") + .args(["remote", "show", "origin"]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let stdout = String::from_utf8_lossy(&output.stdout); + + // Look for "HEAD branch: " + for line in stdout.lines() { + let line = line.trim(); + if line.starts_with("HEAD branch:") { + return line.split(':').nth(1).map(|s| s.trim().to_string()); + } + } + + None +} + +/// Action to take when on trunk branch +#[derive(Debug, Clone, PartialEq)] +pub enum TrunkAction { + /// User entered an identifier manually + EnterIdentifier(String), + /// User selected a stack by index + SelectStack(usize), + /// User cancelled + Cancel, +} + +/// Prompt user for action when on trunk branch +/// +/// Shows options to enter an identifier or select from detected stacks. +pub fn prompt_trunk_action(stacks: &[StackSummary]) -> Result> { + let mut items = vec!["Enter a stack identifier".to_string()]; + + if !stacks.is_empty() { + items.push(format!( + "Select from detected stacks ({} found)", + stacks.len() + )); + } + + items.push("Cancel".to_string()); + + let selection = Select::new() + .with_prompt("You're on a trunk branch. What would you like to do?") + .items(&items) + .default(0) + .interact()?; + + match selection { + 0 => { + // Enter identifier + let identifier: String = Input::new() + .with_prompt("Enter stack identifier") + .interact_text()?; + + if identifier.is_empty() { + Ok(TrunkAction::Cancel) + } else { + Ok(TrunkAction::EnterIdentifier(identifier)) + } + } + idx if idx == items.len() - 1 => Ok(TrunkAction::Cancel), + 1 if !stacks.is_empty() => { + // Select from stacks + let stack_idx = prompt_select_stack(stacks)?; + Ok(TrunkAction::SelectStack(stack_idx)) + } + _ => Ok(TrunkAction::Cancel), + } +} + +/// Prompt user to select a stack from a list +/// +/// Returns the index of the selected stack. +pub fn prompt_select_stack(stacks: &[StackSummary]) -> Result> { + if stacks.is_empty() { + return Err("No stacks to select from".into()); + } + + let items: Vec = stacks.iter().map(|s| s.display()).collect(); + + let selection = Select::new() + .with_prompt("Select a stack") + .items(&items) + .default(0) + .interact()?; + + Ok(selection) +} + +/// Prompt user to enter an identifier manually +pub fn prompt_identifier() -> Result> { + let identifier: String = Input::new() + .with_prompt("Enter stack identifier") + .interact_text()?; + + Ok(identifier) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_trunk_branch_main() { + assert!(is_trunk_branch("main", None)); + } + + #[test] + fn test_is_trunk_branch_master() { + assert!(is_trunk_branch("master", None)); + } + + #[test] + fn test_is_trunk_branch_develop() { + assert!(is_trunk_branch("develop", None)); + } + + #[test] + fn test_is_trunk_branch_feature_returns_false() { + assert!(!is_trunk_branch("feat/my-feature", None)); + assert!(!is_trunk_branch("feature-branch", None)); + assert!(!is_trunk_branch("fix/bug", None)); + } + + #[test] + fn test_is_trunk_branch_configured() { + // Custom trunk takes precedence + assert!(is_trunk_branch("production", Some("production"))); + // But common trunks still work + assert!(is_trunk_branch("main", Some("production"))); + } + + #[test] + fn test_is_trunk_branch_configured_not_in_common() { + // A configured trunk that's not in TRUNK_BRANCHES should still match + assert!(is_trunk_branch("release", Some("release"))); + // But other branches shouldn't + assert!(!is_trunk_branch("feature", Some("release"))); + } + + #[test] + fn test_stack_summary_from_prs_single() { + use crate::api::PullRequestStatus; + + let pr = PullRequest::new_for_test( + 42, + "feat/my-feature", + "main", + "Add awesome feature", + PullRequestStatus::Open, + false, + None, + vec![], + ); + + let summary = StackSummary::from_prs(&[pr], "main"); + + assert_eq!(summary.root_branch, "feat/my-feature"); + assert_eq!(summary.pr_count, 1); + assert_eq!(summary.pr_numbers, vec![42]); + assert_eq!(summary.title_snippet, "Add awesome feature"); + } + + #[test] + fn test_stack_summary_from_prs_multiple() { + use crate::api::PullRequestStatus; + + let pr1 = PullRequest::new_for_test( + 1, + "feat/part-1", + "main", + "Part 1: Initial setup", + PullRequestStatus::Open, + false, + None, + vec![], + ); + let pr2 = PullRequest::new_for_test( + 2, + "feat/part-2", + "feat/part-1", + "Part 2: Implementation", + PullRequestStatus::Open, + false, + None, + vec![], + ); + + let summary = StackSummary::from_prs(&[pr1, pr2], "main"); + + assert_eq!(summary.root_branch, "feat/part-1"); + assert_eq!(summary.pr_count, 2); + assert_eq!(summary.pr_numbers, vec![1, 2]); + } + + #[test] + fn test_stack_summary_truncates_long_title() { + use crate::api::PullRequestStatus; + + let pr = PullRequest::new_for_test( + 1, + "feat/long", + "main", + "This is a very long title that should be truncated because it exceeds forty characters", + PullRequestStatus::Open, + false, + None, + vec![], + ); + + let summary = StackSummary::from_prs(&[pr], "main"); + + assert!(summary.title_snippet.len() <= 43); // 40 + "..." + assert!(summary.title_snippet.ends_with("...")); + } + + #[test] + fn test_stack_summary_display() { + let summary = StackSummary { + root_branch: "feat/my-feature".to_string(), + pr_count: 2, + pr_numbers: vec![42, 43], + title_snippet: "Add feature".to_string(), + }; + + let display = summary.display(); + assert!(display.contains("feat/my-feature")); + assert!(display.contains("2 PRs")); + assert!(display.contains("#42")); + assert!(display.contains("#43")); + } + + #[test] + fn test_stack_summary_display_single() { + let summary = StackSummary { + root_branch: "feat/single".to_string(), + pr_count: 1, + pr_numbers: vec![99], + title_snippet: "Single PR".to_string(), + }; + + let display = summary.display(); + assert!(display.contains("1 PR)")); // Note: no 's' + } + + #[test] + fn test_stack_summary_empty() { + let summary = StackSummary::from_prs(&[], "main"); + + assert!(summary.root_branch.is_empty()); + assert_eq!(summary.pr_count, 0); + assert!(summary.pr_numbers.is_empty()); + } +} diff --git a/src/lib.rs b/src/lib.rs index 7ad85ff..5156b2a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,8 @@ pub mod api; +pub mod browser; pub mod git; pub mod graph; +pub mod identifier; pub mod land; pub mod markdown; pub mod persist; diff --git a/src/main.rs b/src/main.rs index 2f3c392..7ea51d0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,15 +4,17 @@ use git2::Repository; use regex::Regex; use std::env; use std::error::Error; +use std::io::IsTerminal; use std::rc::Rc; use gh_stack::api::PullRequest; use gh_stack::graph::FlatDep; +use gh_stack::identifier::{self, StackSummary, TrunkAction}; use gh_stack::land::{self, LandError, LandOptions}; use gh_stack::status::{self, StatusConfig}; use gh_stack::util::loop_until_confirm; use gh_stack::Credentials; -use gh_stack::{api, git, graph, markdown, persist, tree}; +use gh_stack::{api, browser, git, graph, markdown, persist, tree}; fn clap<'a, 'b>() -> App<'a, 'b> { let identifier = Arg::with_name("identifier") @@ -71,13 +73,44 @@ fn clap<'a, 'b>() -> App<'a, 'b> { .value_name("FILE") .help("Prepend the annotation with the contents of this file")); + // For log command, identifier is optional (can infer from branch) + let log_identifier = Arg::with_name("identifier") + .index(1) + .required(false) + .help("Stack identifier in PR titles (optional - infers from current branch if omitted)"); + let log = SubCommand::with_name("log") .about("Print a visual tree of all pull requests in a stack") - .setting(AppSettings::ArgRequiredElseHelp) - .arg(identifier.clone()) + .arg(log_identifier) .arg(exclude.clone()) .arg(repository.clone()) .arg(origin.clone()) + .arg( + Arg::with_name("branch") + .long("branch") + .short("b") + .takes_value(true) + .help("Infer stack from this branch instead of current branch"), + ) + .arg( + Arg::with_name("all") + .long("all") + .short("a") + .takes_value(false) + .help("List all detected stacks and select interactively"), + ) + .arg( + Arg::with_name("ci") + .long("ci") + .takes_value(false) + .help("Non-interactive mode for CI (requires identifier or --branch)"), + ) + .arg( + Arg::with_name("trunk") + .long("trunk") + .takes_value(true) + .help("Trunk branch name (default: auto-detect or 'main')"), + ) .arg( Arg::with_name("short") .long("short") @@ -107,15 +140,52 @@ fn clap<'a, 'b>() -> App<'a, 'b> { .long("status") .takes_value(false) .help("Show status bits (CI, approval, conflicts, stack health)"), + ) + .arg( + Arg::with_name("create-pr") + .long("create-pr") + .takes_value(false) + .help("Create PR via GitHub API if branch has no PR"), ); + // For status command, identifier is optional (can infer from branch) + let status_identifier = Arg::with_name("identifier") + .index(1) + .required(false) + .help("Stack identifier in PR titles (optional - infers from current branch if omitted)"); + let status_cmd = SubCommand::with_name("status") .about("Show stack status with CI, approval, and merge readiness indicators") - .setting(AppSettings::ArgRequiredElseHelp) - .arg(identifier.clone()) + .arg(status_identifier) .arg(exclude.clone()) .arg(repository.clone()) .arg(origin.clone()) + .arg( + Arg::with_name("branch") + .long("branch") + .short("b") + .takes_value(true) + .help("Infer stack from this branch instead of current branch"), + ) + .arg( + Arg::with_name("all") + .long("all") + .short("a") + .takes_value(false) + .help("List all detected stacks and select interactively"), + ) + .arg( + Arg::with_name("ci") + .long("ci") + .takes_value(false) + .help("Non-interactive mode for CI (requires identifier or --branch)"), + ) + .arg( + Arg::with_name("trunk") + .long("trunk") + .takes_value(true) + .help("Trunk branch name (default: auto-detect or 'main')"), + ) .arg( Arg::with_name("project") .long("project") @@ -146,6 +216,12 @@ fn clap<'a, 'b>() -> App<'a, 'b> { .long("json") .takes_value(false) .help("Output in JSON format"), + ) + .arg( + Arg::with_name("create-pr") + .long("create-pr") + .takes_value(false) + .help("Create PR via GitHub API if branch has no PR"), ); let autorebase = SubCommand::with_name("autorebase") @@ -310,6 +386,102 @@ fn remove_title_prefixes(title: String, prefix: &str) -> String { regex.replace_all(&title, "").into_owned() } +/// Get GitHub host URL from repository's git remote +fn get_github_host(repo: Option<&Repository>, remote_name: &str) -> String { + repo.and_then(|r| tree::get_remote_url(r, remote_name)) + .and_then(|url| browser::parse_github_host(&url)) + .unwrap_or_else(|| "https://github.com".to_string()) +} + +/// Poll for PR existence with timeout +async fn wait_for_pr( + repository: &str, + branch: &str, + credentials: &Credentials, + timeout_secs: u64, +) -> Option { + use std::time::{Duration, Instant}; + + let start = Instant::now(); + let timeout = Duration::from_secs(timeout_secs); + let poll_interval = Duration::from_secs(2); + + println!("Waiting for PR to be created..."); + + while start.elapsed() < timeout { + if let Ok(Some(pr)) = api::stack::fetch_pr_by_head(repository, branch, credentials).await { + println!("Found PR #{}!", pr.number()); + return Some(pr); + } + print!("."); + use std::io::Write; + std::io::stdout().flush().ok(); + tokio::time::sleep(poll_interval).await; + } + println!(); + + None +} + +/// Create a PR via GitHub API with interactive prompts for title/body +async fn create_pr_interactive( + repository: &str, + head: &str, + base: &str, + repo: Option<&Repository>, + identifier: Option<&str>, + credentials: &Credentials, + ci_mode: bool, +) -> Result<(usize, String), Box> { + use dialoguer::Input; + + // Get default title from first commit message + let default_title = repo + .and_then(|r| tree::first_commit_message(r, head, base)) + .unwrap_or_else(|| head.to_string()); + + // Add identifier prefix if we have one + let default_title = if let Some(id) = identifier { + format!("[{}] {}", id, default_title) + } else { + default_title + }; + + // Get default body with stack marker + let default_body = format!("", identifier.unwrap_or("")); + + let (title, body) = if ci_mode || !std::io::stdout().is_terminal() { + // Non-interactive: use defaults + (default_title, default_body) + } else { + // Interactive: prompt for title and body + let title: String = Input::new() + .with_prompt("PR title") + .default(default_title) + .interact_text()?; + + let body: String = Input::new() + .with_prompt("PR body (optional)") + .default(default_body) + .allow_empty(true) + .interact_text()?; + + (title, body) + }; + + println!("\nCreating PR..."); + let body_opt = if body.is_empty() { + None + } else { + Some(body.as_str()) + }; + let (pr_num, url) = + api::create::create_pr(repository, head, base, &title, body_opt, credentials).await?; + println!("Created: {} (PR #{})\n", style(&url).cyan(), pr_num); + + Ok((pr_num, url)) +} + #[tokio::main] async fn main() -> Result<(), Box> { dotenvy::from_filename(".gh-stack.env").ok(); @@ -368,29 +540,291 @@ async fn main() -> Result<(), Box> { } ("log", Some(m)) => { - let identifier = m.value_of("identifier").unwrap(); - - // resolve repository with fallback chain + let explicit_identifier = m.value_of("identifier"); + let branch_override = m.value_of("branch"); + let show_all = m.is_present("all"); + let ci_mode = m.is_present("ci"); + + // Resolve trunk branch + let trunk = m + .value_of("trunk") + .map(String::from) + .or_else(identifier::detect_trunk_branch) + .unwrap_or_else(|| "main".to_string()); + + // Resolve repository with fallback chain let remote_name = m.value_of("origin").unwrap_or("origin"); let repository = resolve_repository(m.value_of("repository"), &repository, remote_name) .unwrap_or_else(|e| panic!("{}", e)); - println!( - "Searching for {} identifier in {} repo", - style(identifier).bold(), - style(&repository).bold() - ); - let stack = - build_pr_stack_for_repo(identifier, &repository, &credentials, get_excluded(m)) - .await?; + // Determine how to find the stack + let stack: FlatDep = if let Some(id) = explicit_identifier { + // === EXISTING BEHAVIOR: Search by identifier === + println!( + "Searching for {} identifier in {} repo", + style(id).bold(), + style(&repository).bold() + ); + build_pr_stack_for_repo(id, &repository, &credentials, get_excluded(m)).await? + } else if show_all { + // === NEW: --all flag - show all stacks === + if ci_mode { + eprintln!( + "{} --all requires interactive mode (incompatible with --ci)", + style("Error:").red().bold() + ); + std::process::exit(1); + } + + println!("Discovering stacks in {}...", style(&repository).bold()); + + let stacks = + api::stack::discover_all_stacks(&repository, &trunk, &credentials).await?; + + if stacks.is_empty() { + println!("No open stacks found."); + return Ok(()); + } + + let summaries: Vec = stacks + .iter() + .map(|s| StackSummary::from_prs(s, &trunk)) + .collect(); + + let selected = identifier::prompt_select_stack(&summaries)?; + + // Convert selected stack to FlatDep + let prs: Vec> = + stacks[selected].iter().cloned().map(Rc::new).collect(); + let g = graph::build(&prs); + graph::log(&g) + } else { + // === NEW: Infer from current/specified branch === + let repo_handle = m + .value_of("project") + .and_then(|p| Repository::open(p).ok()) + .or_else(tree::detect_repo); + + let branch = branch_override + .map(String::from) + .or_else(|| repo_handle.as_ref().and_then(tree::current_branch)); + + match branch { + Some(branch) => { + // Check if on trunk + if identifier::is_trunk_branch(&branch, Some(&trunk)) { + if ci_mode { + eprintln!( + "{} On trunk branch '{}'. Provide an identifier or use --branch.", + style("Error:").red().bold(), + branch + ); + std::process::exit(1); + } + + println!("You're on '{}' (trunk branch).\n", style(&branch).cyan()); + + // Discover all stacks for selection + let stacks = + api::stack::discover_all_stacks(&repository, &trunk, &credentials) + .await?; + + let summaries: Vec = stacks + .iter() + .map(|s| StackSummary::from_prs(s, &trunk)) + .collect(); + + match identifier::prompt_trunk_action(&summaries)? { + TrunkAction::EnterIdentifier(id) => { + println!( + "\nSearching for {} identifier in {} repo", + style(&id).bold(), + style(&repository).bold() + ); + build_pr_stack_for_repo( + &id, + &repository, + &credentials, + get_excluded(m), + ) + .await? + } + TrunkAction::SelectStack(idx) => { + let prs: Vec> = + stacks[idx].iter().cloned().map(Rc::new).collect(); + let g = graph::build(&prs); + graph::log(&g) + } + TrunkAction::Cancel => { + return Ok(()); + } + } + } else { + // Discover stack from branch + println!( + "Discovering stack for branch '{}'...", + style(&branch).cyan() + ); + + match api::stack::fetch_pr_by_head(&repository, &branch, &credentials) + .await? + { + Some(pr) => { + let prs = api::stack::discover_stack( + &repository, + pr, + &trunk, + &credentials, + ) + .await?; + let prs: Vec> = + prs.into_iter().map(Rc::new).collect(); + let g = graph::build(&prs); + graph::log(&g) + } + None => { + // No PR found for this branch + let create_pr_flag = m.is_present("create-pr"); + let github_host = + get_github_host(repo_handle.as_ref(), remote_name); + + if create_pr_flag { + // Create PR via API + let (_pr_num, _url) = create_pr_interactive( + &repository, + &branch, + &trunk, + repo_handle.as_ref(), + None, // No identifier known + &credentials, + ci_mode, + ) + .await?; + + // Retry discovery + println!("Retrying stack discovery..."); + if let Some(pr) = api::stack::fetch_pr_by_head( + &repository, + &branch, + &credentials, + ) + .await? + { + let prs = api::stack::discover_stack( + &repository, + pr, + &trunk, + &credentials, + ) + .await?; + let prs: Vec> = + prs.into_iter().map(Rc::new).collect(); + let g = graph::build(&prs); + graph::log(&g) + } else { + eprintln!("PR creation succeeded but could not find PR. Try again."); + return Ok(()); + } + } else if ci_mode { + // CI mode without --create-pr: print URL and exit + eprintln!( + "{} No PR found for branch '{}'", + style("Error:").red().bold(), + branch + ); + browser::suggest_create_pr( + &github_host, + &repository, + &branch, + &trunk, + ); + std::process::exit(1); + } else if std::io::stdout().is_terminal() { + // Interactive: prompt to open browser + if browser::prompt_create_pr( + &github_host, + &repository, + &branch, + &trunk, + )? { + // Poll for PR creation + if let Some(pr) = + wait_for_pr(&repository, &branch, &credentials, 30) + .await + { + println!("\nRetrying stack discovery..."); + let prs = api::stack::discover_stack( + &repository, + pr, + &trunk, + &credentials, + ) + .await?; + let prs: Vec> = + prs.into_iter().map(Rc::new).collect(); + let g = graph::build(&prs); + graph::log(&g) + } else { + eprintln!("\nTimed out waiting for PR. Run the command again after creating the PR."); + return Ok(()); + } + } else { + return Ok(()); + } + } else { + // Non-interactive, non-CI: just print URL + browser::suggest_create_pr( + &github_host, + &repository, + &branch, + &trunk, + ); + return Ok(()); + } + } + } + } + } + None => { + // Could not determine branch + if ci_mode { + eprintln!( + "{} Could not determine current branch. Provide identifier or --branch.", + style("Error:").red().bold() + ); + std::process::exit(1); + } + + eprintln!( + "{} Could not determine current branch.\n", + style("Note:").yellow() + ); + eprintln!("You can:"); + eprintln!(" - Run from inside a git repository"); + eprintln!( + " - Use {} to specify a branch", + style("--branch ").cyan() + ); + eprintln!( + " - Provide an identifier directly: {}", + style("gh-stack log 'STACK-ID'").cyan() + ); + return Ok(()); + } + } + }; // Check for empty stack if stack.is_empty() { - println!("No PRs found matching '{}'", identifier); + if let Some(id) = explicit_identifier { + println!("No PRs found matching '{}'", id); + } else { + println!("No PRs found in stack."); + } return Ok(()); } - // Check if --status flag is set (delegate to status handler) + // === DISPLAY LOGIC (unchanged) === if m.is_present("status") { let no_color = m.is_present("no-color"); let show_legend = status::should_show_legend(); @@ -440,7 +874,7 @@ async fn main() -> Result<(), Box> { } } } else { - // New tree view (default) + // Tree view (default) let no_color = m.is_present("no-color"); let mut config = tree::TreeConfig::detect(no_color); config.include_closed = m.is_present("include-closed"); @@ -613,33 +1047,326 @@ async fn main() -> Result<(), Box> { } ("status", Some(m)) => { - let identifier = m.value_of("identifier").unwrap(); + let explicit_identifier = m.value_of("identifier"); + let branch_override = m.value_of("branch"); + let show_all = m.is_present("all"); + let ci_mode = m.is_present("ci"); + let json_output = m.is_present("json"); - // resolve repository with fallback chain + // Resolve trunk branch + let trunk = m + .value_of("trunk") + .map(String::from) + .or_else(identifier::detect_trunk_branch) + .unwrap_or_else(|| "main".to_string()); + + // Resolve repository with fallback chain let remote_name = m.value_of("origin").unwrap_or("origin"); let repository = resolve_repository(m.value_of("repository"), &repository, remote_name) .unwrap_or_else(|e| panic!("{}", e)); - let json_output = m.is_present("json"); + // Determine how to find the stack + let stack: FlatDep = if let Some(id) = explicit_identifier { + // === EXISTING BEHAVIOR: Search by identifier === + if !json_output { + println!( + "Searching for {} identifier in {} repo", + style(id).bold(), + style(&repository).bold() + ); + } + build_pr_stack_for_repo(id, &repository, &credentials, get_excluded(m)).await? + } else if show_all { + // === NEW: --all flag - show all stacks === + if ci_mode { + eprintln!( + "{} --all requires interactive mode (incompatible with --ci)", + style("Error:").red().bold() + ); + std::process::exit(1); + } - if !json_output { - println!( - "Searching for {} identifier in {} repo", - style(identifier).bold(), - style(&repository).bold() - ); - } + if !json_output { + println!("Discovering stacks in {}...", style(&repository).bold()); + } - let stack = - build_pr_stack_for_repo(identifier, &repository, &credentials, get_excluded(m)) - .await?; + let stacks = + api::stack::discover_all_stacks(&repository, &trunk, &credentials).await?; + + if stacks.is_empty() { + if json_output { + println!(r#"{{"stack": [], "trunk": "{}"}}"#, trunk); + } else { + println!("No open stacks found."); + } + return Ok(()); + } + + let summaries: Vec = stacks + .iter() + .map(|s| StackSummary::from_prs(s, &trunk)) + .collect(); + + let selected = identifier::prompt_select_stack(&summaries)?; + + // Convert selected stack to FlatDep + let prs: Vec> = + stacks[selected].iter().cloned().map(Rc::new).collect(); + let g = graph::build(&prs); + graph::log(&g) + } else { + // === NEW: Infer from current/specified branch === + let repo_handle = m + .value_of("project") + .and_then(|p| Repository::open(p).ok()) + .or_else(tree::detect_repo); + + let branch = branch_override + .map(String::from) + .or_else(|| repo_handle.as_ref().and_then(tree::current_branch)); + + match branch { + Some(branch) => { + // Check if on trunk + if identifier::is_trunk_branch(&branch, Some(&trunk)) { + if ci_mode { + eprintln!( + "{} On trunk branch '{}'. Provide an identifier or use --branch.", + style("Error:").red().bold(), + branch + ); + std::process::exit(1); + } + + if !json_output { + println!("You're on '{}' (trunk branch).\n", style(&branch).cyan()); + } + + // Discover all stacks for selection + let stacks = + api::stack::discover_all_stacks(&repository, &trunk, &credentials) + .await?; + + let summaries: Vec = stacks + .iter() + .map(|s| StackSummary::from_prs(s, &trunk)) + .collect(); + + match identifier::prompt_trunk_action(&summaries)? { + TrunkAction::EnterIdentifier(id) => { + if !json_output { + println!( + "\nSearching for {} identifier in {} repo", + style(&id).bold(), + style(&repository).bold() + ); + } + build_pr_stack_for_repo( + &id, + &repository, + &credentials, + get_excluded(m), + ) + .await? + } + TrunkAction::SelectStack(idx) => { + let prs: Vec> = + stacks[idx].iter().cloned().map(Rc::new).collect(); + let g = graph::build(&prs); + graph::log(&g) + } + TrunkAction::Cancel => { + return Ok(()); + } + } + } else { + // Discover stack from branch + if !json_output { + println!( + "Discovering stack for branch '{}'...", + style(&branch).cyan() + ); + } + + match api::stack::fetch_pr_by_head(&repository, &branch, &credentials) + .await? + { + Some(pr) => { + let prs = api::stack::discover_stack( + &repository, + pr, + &trunk, + &credentials, + ) + .await?; + let prs: Vec> = + prs.into_iter().map(Rc::new).collect(); + let g = graph::build(&prs); + graph::log(&g) + } + None => { + // No PR found for this branch + let create_pr_flag = m.is_present("create-pr"); + let github_host = + get_github_host(repo_handle.as_ref(), remote_name); + + if create_pr_flag { + // Create PR via API + let (_pr_num, _url) = create_pr_interactive( + &repository, + &branch, + &trunk, + repo_handle.as_ref(), + None, // No identifier known + &credentials, + ci_mode, + ) + .await?; + + // Retry discovery + if !json_output { + println!("Retrying stack discovery..."); + } + if let Some(pr) = api::stack::fetch_pr_by_head( + &repository, + &branch, + &credentials, + ) + .await? + { + let prs = api::stack::discover_stack( + &repository, + pr, + &trunk, + &credentials, + ) + .await?; + let prs: Vec> = + prs.into_iter().map(Rc::new).collect(); + let g = graph::build(&prs); + graph::log(&g) + } else { + if json_output { + println!( + r#"{{"error": "PR created but not found", "stack": [], "trunk": "{}"}}"#, + trunk + ); + } else { + eprintln!("PR creation succeeded but could not find PR. Try again."); + } + return Ok(()); + } + } else if ci_mode || json_output { + // CI/JSON mode without --create-pr + if json_output { + println!( + r#"{{"error": "No PR found for branch '{}'", "stack": [], "trunk": "{}"}}"#, + branch, trunk + ); + } else { + eprintln!( + "{} No PR found for branch '{}'", + style("Error:").red().bold(), + branch + ); + browser::suggest_create_pr( + &github_host, + &repository, + &branch, + &trunk, + ); + } + std::process::exit(1); + } else if std::io::stdout().is_terminal() { + // Interactive: prompt to open browser + if browser::prompt_create_pr( + &github_host, + &repository, + &branch, + &trunk, + )? { + // Poll for PR creation + if let Some(pr) = + wait_for_pr(&repository, &branch, &credentials, 30) + .await + { + println!("\nRetrying stack discovery..."); + let prs = api::stack::discover_stack( + &repository, + pr, + &trunk, + &credentials, + ) + .await?; + let prs: Vec> = + prs.into_iter().map(Rc::new).collect(); + let g = graph::build(&prs); + graph::log(&g) + } else { + eprintln!("\nTimed out waiting for PR. Run the command again after creating the PR."); + return Ok(()); + } + } else { + return Ok(()); + } + } else { + // Non-interactive: just print URL + browser::suggest_create_pr( + &github_host, + &repository, + &branch, + &trunk, + ); + return Ok(()); + } + } + } + } + } + None => { + // Could not determine branch + if ci_mode || json_output { + if json_output { + println!( + r#"{{"error": "Could not determine current branch", "stack": [], "trunk": "{}"}}"#, + trunk + ); + } else { + eprintln!( + "{} Could not determine current branch. Provide identifier or --branch.", + style("Error:").red().bold() + ); + } + std::process::exit(1); + } + + eprintln!( + "{} Could not determine current branch.\n", + style("Note:").yellow() + ); + eprintln!("You can:"); + eprintln!(" - Run from inside a git repository"); + eprintln!( + " - Use {} to specify a branch", + style("--branch ").cyan() + ); + eprintln!( + " - Provide an identifier directly: {}", + style("gh-stack status 'STACK-ID'").cyan() + ); + return Ok(()); + } + } + }; // Check for empty stack if stack.is_empty() { if json_output { - println!(r#"{{"stack": [], "trunk": "main"}}"#); + println!(r#"{{"stack": [], "trunk": "{}"}}"#, trunk); + } else if let Some(id) = explicit_identifier { + println!("No PRs found matching '{}'", id); } else { - println!("No PRs found matching '{}'", identifier); + println!("No PRs found in stack."); } return Ok(()); } diff --git a/src/status.rs b/src/status.rs index 18f85b0..f9d8edc 100644 --- a/src/status.rs +++ b/src/status.rs @@ -2,9 +2,16 @@ //! //! This module provides functionality to display stack status with CI, approval, //! merge, and stack health indicators. +//! +//! ## Performance +//! +//! Status checks are fetched in parallel using `futures::join_all` to minimize +//! latency when checking multiple PRs. use std::path::PathBuf; +use std::rc::Rc; +use futures::future::join_all; use git2::Repository; use serde::Serialize; @@ -227,7 +234,42 @@ fn compute_stack_clear(entries: &[StatusEntry], index: usize) -> StatusBit { StatusBit::Passed } +/// Intermediate data for building status entries +struct PrCheckData { + pr: Rc, + is_current: bool, + commits: Vec, + extra_commits: usize, +} + +/// Fetch CI and mergeable status for a single PR +async fn fetch_pr_status( + pr: &PullRequest, + repository: &str, + credentials: &Credentials, +) -> (StatusBit, StatusBit) { + // Fetch CI status and mergeable status in parallel + let (ci_result, mergeable_result) = futures::join!( + fetch_check_status(pr.head_sha(), repository, credentials), + fetch_mergeable_status(pr.number(), repository, credentials) + ); + + let ci = match ci_result { + Ok(check) => check_status_to_bit(&check), + Err(_) => StatusBit::NotApplicable, + }; + + let mergeable = match mergeable_result { + Ok(m) => mergeable_to_bit(m), + Err(_) => StatusBit::NotApplicable, + }; + + (ci, mergeable) +} + /// Build status entries from a PR stack +/// +/// Fetches CI and mergeable status for all PRs in parallel for better performance. pub async fn build_status_entries( stack: &FlatDep, repo: Option<&Repository>, @@ -236,68 +278,78 @@ pub async fn build_status_entries( config: &StatusConfig, ) -> Vec { let current = repo.and_then(current_branch); - let mut entries = Vec::new(); // Get trunk branch from first PR's base let trunk_branch = stack.first().map(|(pr, _)| pr.base().to_string()); - // Process PRs in reverse order (top of stack first) - for (pr, _parent) in stack.iter().rev() { - // Skip closed/merged PRs - if pr.is_merged() || pr.state() == &crate::api::PullRequestStatus::Closed { - continue; - } - - let is_current = current.as_ref().is_some_and(|c| c == pr.head()); - let timestamp = pr.updated_at().and_then(parse_timestamp); - - // Get commits if we have a repo - let (commits, extra_commits) = if let Some(r) = repo { - if branch_exists_locally(r, pr.head()) { - commits_for_branch(r, pr.head(), pr.base()) + // Collect PR data (non-async operations) + let pr_data: Vec = stack + .iter() + .rev() + .filter(|(pr, _)| !pr.is_merged() && pr.state() != &crate::api::PullRequestStatus::Closed) + .map(|(pr, _)| { + let is_current = current.as_ref().is_some_and(|c| c == pr.head()); + + // Get commits if we have a repo + let (commits, extra_commits) = if let Some(r) = repo { + if branch_exists_locally(r, pr.head()) { + commits_for_branch(r, pr.head(), pr.base()) + } else { + (vec![], 0) + } } else { (vec![], 0) + }; + + PrCheckData { + pr: pr.clone(), + is_current, + commits, + extra_commits, } - } else { - (vec![], 0) - }; + }) + .collect(); - // Fetch status if enabled - let status = if config.include_checks { - let ci = match fetch_check_status(pr.head_sha(), repository, credentials).await { - Ok(check) => check_status_to_bit(&check), - Err(_) => StatusBit::NotApplicable, - }; + // Fetch status checks in parallel if enabled + let statuses: Vec> = if config.include_checks { + let futures: Vec<_> = pr_data + .iter() + .map(|data| fetch_pr_status(&data.pr, repository, credentials)) + .collect(); - let mergeable = match fetch_mergeable_status(pr.number(), repository, credentials).await - { - Ok(m) => mergeable_to_bit(m), - Err(_) => StatusBit::NotApplicable, - }; + join_all(futures).await.into_iter().map(Some).collect() + } else { + vec![None; pr_data.len()] + }; + + // Build entries from collected data + let mut entries: Vec = pr_data + .into_iter() + .zip(statuses) + .map(|(data, status_bits)| { + let timestamp = data.pr.updated_at().and_then(parse_timestamp); - Some(PrStatus { + let status = status_bits.map(|(ci, mergeable)| PrStatus { ci, - approved: approval_to_bit(pr), + approved: approval_to_bit(&data.pr), mergeable, stack_clear: StatusBit::Pending, // Will be computed after all entries are built - }) - } else { - None - }; + }); - entries.push(StatusEntry { - branch: pr.head().to_string(), - pr_number: Some(pr.number()), - title: Some(truncate_title(pr.raw_title(), MAX_TITLE_LEN)), - is_current, - is_draft: pr.is_draft(), - is_trunk: false, - status, - updated_at: timestamp.map(|t| t.to_rfc3339()), - commits, - extra_commits, - }); - } + StatusEntry { + branch: data.pr.head().to_string(), + pr_number: Some(data.pr.number()), + title: Some(truncate_title(data.pr.raw_title(), MAX_TITLE_LEN)), + is_current: data.is_current, + is_draft: data.pr.is_draft(), + is_trunk: false, + status, + updated_at: timestamp.map(|t| t.to_rfc3339()), + commits: data.commits, + extra_commits: data.extra_commits, + } + }) + .collect(); // Compute stack_clear for each entry (requires all entries to be built first) if config.include_checks { diff --git a/src/tree.rs b/src/tree.rs index b10fc03..feb859a 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -115,6 +115,21 @@ pub fn branch_exists_locally(repo: &Repository, branch: &str) -> bool { repo.find_branch(branch, git2::BranchType::Local).is_ok() } +/// Get git remote URL for a remote name +pub fn get_remote_url(repo: &Repository, remote_name: &str) -> Option { + repo.find_remote(remote_name).ok()?.url().map(String::from) +} + +/// Get the first commit message between head and base branches +/// +/// Returns the message of the oldest commit in the branch (first commit after +/// diverging from base). Used as default PR title when creating via API. +pub fn first_commit_message(repo: &Repository, head: &str, base: &str) -> Option { + let (commits, _) = commits_for_branch(repo, head, base); + // Last in list is oldest (topological sort puts newest first) + commits.last().map(|c| c.message.clone()) +} + /// Get commits between two branches (head..base exclusive) /// Returns up to MAX_COMMITS and count of extras pub fn commits_for_branch(repo: &Repository, head: &str, base: &str) -> (Vec, usize) { @@ -1268,4 +1283,47 @@ mod tests { fn test_parse_github_remote_url_empty() { assert_eq!(parse_github_remote_url(""), None); } + + // Tests for first_commit_message - requires repo so tested indirectly via commits_for_branch + #[test] + fn test_first_commit_message_empty_commits() { + // When commits_for_branch returns empty, first_commit_message returns None + let commits: Vec = vec![]; + assert_eq!(commits.last().map(|c| c.message.clone()), None); + } + + #[test] + fn test_first_commit_message_single_commit() { + let commits = vec![CommitInfo { + sha: "abc1234".to_string(), + message: "Initial commit".to_string(), + }]; + assert_eq!( + commits.last().map(|c| c.message.clone()), + Some("Initial commit".to_string()) + ); + } + + #[test] + fn test_first_commit_message_multiple_commits() { + // Topological sort: newest first, so last is oldest + let commits = vec![ + CommitInfo { + sha: "ccc3333".to_string(), + message: "Third commit".to_string(), + }, + CommitInfo { + sha: "bbb2222".to_string(), + message: "Second commit".to_string(), + }, + CommitInfo { + sha: "aaa1111".to_string(), + message: "First commit".to_string(), + }, + ]; + assert_eq!( + commits.last().map(|c| c.message.clone()), + Some("First commit".to_string()) + ); + } }