From 6ff541284e4fae2eaa3c752dd7d0807c81337f5e Mon Sep 17 00:00:00 2001 From: Mplan Date: Thu, 11 Jun 2026 00:39:20 +0800 Subject: [PATCH] feat(cli): add pull request creation with AI-generated messages (#2) Add a new `gai pr` subcommand that generates pull request titles and descriptions using AI, then creates the PR via GitHub CLI (`gh`) or Gitea CLI (`tea`). This extends the existing commit-generation system by reusing retry logic and prompt infrastructure, and introduces a `callAI` function that returns raw output (instead of pre-cleaned messages) to support structured PR responses. Reviewed-on: https://git.catpl.top/Mplan/gai/pulls/2 --- index.ts | 288 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/ai.ts | 50 +++++++-- src/pr.ts | 184 ++++++++++++++++++++++++++++++++ src/prompt.ts | 64 ++++++++++- src/types.ts | 10 ++ 5 files changed, 585 insertions(+), 11 deletions(-) create mode 100644 src/pr.ts diff --git a/index.ts b/index.ts index e6ba1e6..0595acb 100644 --- a/index.ts +++ b/index.ts @@ -19,6 +19,20 @@ import { generateCommitMessage } from "./src/ai"; import { copyToClipboard } from "./src/clipboard"; import { BOLD, GREEN, YELLOW, CYAN, RED, DIM, RESET } from "./src/terminal"; import type { Config } from "./src/types"; +import { + getDefaultBranch, + getBranchName, + getBranchCommits, + getBranchDiff, + detectPlatform, + getRemoteHostname, + checkCLI, + checkAuth, + createPR, +} from "./src/pr"; +import type { Platform } from "./src/pr"; +import { PR_SYSTEM_PROMPT, buildPRPrompt } from "./src/prompt"; +import { generatePRMessage } from "./src/ai"; const args = process.argv.slice(2); @@ -31,6 +45,8 @@ ${BOLD}Usage:${RESET} gai commit Generate commit message for staged/changed files gai commit --auto Auto-stage all changed files gai commit -d Generate message without committing + gai pr Create a PR with AI-generated title and body + gai pr --draft Create a draft PR gai config Configure API settings gai --help Show this help message gai --version Show version @@ -277,6 +293,7 @@ interface MenuAction { const MENU_ACTIONS: MenuAction[] = [ { key: "commit", label: "commit", description: "Generate AI commit message" }, + { key: "pr", label: "pr", description: "Create a PR with AI-generated title" }, { key: "config", label: "config", description: "Configure API settings" }, ]; @@ -385,6 +402,8 @@ async function showMenu(): Promise { if (selected.key === "commit") { handleCommit(false, false).then(resolve); + } else if (selected.key === "pr") { + handlePR(false).then(resolve); } else if (selected.key === "config") { handleConfig().then(resolve); } else { @@ -396,6 +415,119 @@ async function showMenu(): Promise { }); } +async function selectPlatform(hostname: string): Promise { + const options = [ + { platform: "github" as Platform, label: "GitHub", desc: "gh CLI" }, + { platform: "gitea" as Platform, label: "Gitea", desc: "tea CLI" }, + ]; + let cursor = 0; + + const headerLines = 4; + + process.stdout.write(`\n Remote: ${CYAN}${hostname}${RESET} — could not auto-detect platform.\n`); + process.stdout.write(` ${DIM}↑/↓ navigate, space/enter select${RESET}\n\n`); + + const totalLines = headerLines + options.length; + + function render() { + for (let i = 0; i < options.length; i++) { + process.stdout.write("\x1b[2K\r"); + const opt = options[i]!; + const pointer = i === cursor ? `${CYAN}❯${RESET} ` : " "; + const dot = i === cursor ? `${GREEN}◉${RESET}` : `${DIM}○${RESET}`; + const name = i === cursor ? `${BOLD}${opt.label}${RESET}` : opt.label; + const desc = i === cursor ? opt.desc : `${DIM}${opt.desc}${RESET}`; + process.stdout.write(`${pointer} ${dot} ${name}${" ".repeat(Math.max(1, 10 - opt.label.length))}${desc}\n`); + } + process.stdout.write(`\x1b[${options.length}A`); + } + + function clearMenu() { + process.stdout.write(`\x1b[${headerLines}A`); + for (let i = 0; i < totalLines; i++) { + process.stdout.write("\r\x1b[2K\n"); + } + process.stdout.write(`\x1b[${totalLines}A`); + } + + if (!process.stdin.isTTY) { + console.error(` ${RED}Error: Platform selection requires a TTY.${RESET}`); + process.exit(1); + } + + const savedRaw = process.stdin.isRaw; + process.stdin.setRawMode(true); + process.stdin.resume(); + process.stdout.write("\x1b[?25l"); + + render(); + + return new Promise((resolve) => { + let escapeBuf = ""; + + function handleSeq(seq: string) { + if (seq === "\x1b[A" || seq === "\x1bOA") { + if (cursor > 0) { + cursor--; + render(); + } + } else if (seq === "\x1b[B" || seq === "\x1bOB") { + if (cursor < options.length - 1) { + cursor++; + render(); + } + } + } + + process.stdin.on("data", (data: Buffer) => { + const key = data.toString(); + + if (key === "\x03") { + process.stdin.setRawMode(savedRaw === true); + process.stdin.pause(); + process.stdin.removeAllListeners("data"); + clearMenu(); + process.stdout.write("\x1b[?25h"); + resolve(null); + return; + } + + if (key === "\x1b" || key.startsWith("\x1b[")) { + escapeBuf = key; + if (key.length >= 3) { + handleSeq(key); + escapeBuf = ""; + } + return; + } + + if (escapeBuf) { + escapeBuf += key; + if (/^[A-Za-z~]$/.test(key)) { + handleSeq(escapeBuf); + escapeBuf = ""; + } else if (escapeBuf.length > 8) { + escapeBuf = ""; + } + return; + } + + if (key === " " || key === "\r") { + const selected = options[cursor]!; + process.stdin.setRawMode(savedRaw === true); + process.stdin.pause(); + process.stdin.removeAllListeners("data"); + + clearMenu(); + process.stdout.write("\x1b[?25h"); + + resolve(selected.platform); + return; + } + }); + }); +} + async function handleCommit(autoMode: boolean, dryRun: boolean): Promise { const config = await loadConfig(); @@ -517,6 +649,156 @@ async function handleCommit(autoMode: boolean, dryRun: boolean): Promise { } } +async function handlePR(draft: boolean): Promise { + const config = await loadConfig(); + + if (!config.apiKey) { + console.error( + ` ${RED}Error: API key not set. Run ${BOLD}gai config${RESET}${RED} to configure.${RESET}`, + ); + process.exit(1); + } + + if (!(await isGitRepo())) { + console.error(` ${RED}Error: Not a git repository.${RESET}`); + process.exit(1); + } + + let platform = await detectPlatform(); + if (!platform) { + const hostname = (await getRemoteHostname()) || "unknown"; + const chosen = await selectPlatform(hostname); + if (!chosen) { + console.log(" Aborted."); + process.exit(0); + } + platform = chosen; + } + + const platformLabel = platform === "github" ? "GitHub" : "Gitea"; + console.log(` Using: ${CYAN}${platformLabel}${RESET}`); + + const cliError = checkCLI(platform); + if (cliError) { + console.error(` ${RED}Error: ${cliError}${RESET}`); + process.exit(1); + } + + const authError = await checkAuth(platform); + if (authError) { + console.error(` ${RED}Error: ${authError}${RESET}`); + process.exit(1); + } + + const baseBranch = await getDefaultBranch(); + const branchName = await getBranchName(); + + if (branchName === baseBranch) { + console.error( + ` ${RED}Error: You are on the default branch (${baseBranch}). Switch to a feature branch first.${RESET}`, + ); + process.exit(1); + } + + console.log( + ` Branch: ${CYAN}${branchName}${RESET} → base: ${CYAN}${baseBranch}${RESET}`, + ); + + const commits = await getBranchCommits(baseBranch); + + if (commits.length === 0) { + console.error( + ` ${RED}Error: No commits on ${branchName} compared to ${baseBranch}. Commit something first.${RESET}`, + ); + process.exit(1); + } + + console.log( + ` ${commits.length} commit${commits.length > 1 ? "s" : ""} on this branch`, + ); + + const diff = await getBranchDiff(baseBranch); + if (!diff) { + console.error(` ${RED}Error: No diff from base branch.${RESET}`); + process.exit(1); + } + + const MAX_DIFF_SIZE = 15000; + const truncatedDiff = + diff.length > MAX_DIFF_SIZE + ? diff.substring(0, MAX_DIFF_SIZE) + "\n... (truncated)" + : diff; + + const repoRoot = await getRepoRoot(); + const projectCtx = await collectProjectContext(repoRoot); + + const userPrompt = buildPRPrompt({ + readme: projectCtx.readme, + packageDescription: projectCtx.packageDescription, + structure: projectCtx.structure, + branchName, + baseBranch, + branchCommits: commits, + diff: truncatedDiff, + }); + + console.log("\n Generating PR title..."); + + let title: string; + let body: string; + try { + const result = await generatePRMessage(config, PR_SYSTEM_PROMPT, userPrompt); + title = result.title; + body = result.body; + } catch (err) { + console.error( + ` ${RED}AI request failed: ${err instanceof Error ? err.message : err}${RESET}`, + ); + process.exit(1); + } + + console.log(`\n ${BOLD}Generated PR:${RESET}`); + console.log(` Title: ${GREEN}${title}${RESET}`); + if (body) { + console.log( + ` Body: ${DIM}${body.replace(/\n/g, "\n ")}${RESET}`, + ); + } + console.log(""); + + const answer = await ask(` Create this PR? [${GREEN}Y${RESET}/n/e] `); + const lower = answer.toLowerCase(); + + if (lower === "n") { + console.log(" Aborted."); + return; + } + + if (lower === "e") { + const newTitle = await ask(" Title: "); + const newBody = await ask(" Body (optional): "); + if (!newTitle.trim()) { + console.log(" Aborted."); + return; + } + title = newTitle; + body = newBody; + } + + console.log(`\n Creating PR...`); + + try { + const url = await createPR(platform, title, body, baseBranch, draft); + console.log(` ${GREEN}${BOLD}✔ PR created!${RESET}`); + console.log(` ${CYAN}${url}${RESET}`); + } catch (err) { + console.error( + ` ${RED}PR creation failed: ${err instanceof Error ? err.message : err}${RESET}`, + ); + process.exit(1); + } +} + async function main() { if (args.includes("--help") || args.includes("-h")) { showHelp(); @@ -547,6 +829,12 @@ async function main() { return; } + if (subcommand === "pr") { + const draft = args.includes("--draft"); + await handlePR(draft); + return; + } + if (!subcommand) { await showMenu(); return; diff --git a/src/ai.ts b/src/ai.ts index 4fe4243..4c96662 100644 --- a/src/ai.ts +++ b/src/ai.ts @@ -39,11 +39,10 @@ async function sleep(ms: number) { return new Promise((resolve) => setTimeout(resolve, ms)); } -export async function generateCommitMessage( +export async function callAI( config: Config, systemPrompt: string, userPrompt: string, - retries = MAX_RETRIES, ): Promise { const url = `${config.apiBase.replace(/\/$/, "")}/chat/completions`; @@ -52,7 +51,7 @@ export async function generateCommitMessage( { role: "user", content: userPrompt }, ]; - for (let attempt = 1; attempt <= retries; attempt++) { + for (let attempt = 1; attempt <= MAX_RETRIES; attempt++) { try { const response = await fetch(url, { method: "POST", @@ -70,7 +69,7 @@ export async function generateCommitMessage( if (!response.ok) { const text = await response.text(); - if (response.status === 429 && attempt < retries) { + if (response.status === 429 && attempt < MAX_RETRIES) { await sleep(RETRY_DELAY * attempt); continue; } @@ -89,7 +88,7 @@ export async function generateCommitMessage( const finishReason = data.choices?.[0]?.finish_reason; if (raw && raw.trim()) { - return cleanMessage(raw); + return raw; } if (finishReason === "length") { @@ -102,22 +101,53 @@ export async function generateCommitMessage( throw new Error("Response blocked by content filter."); } - if (attempt < retries) { + if (attempt < MAX_RETRIES) { await sleep(RETRY_DELAY * attempt); continue; } throw new Error( - `Empty response from AI after ${retries} attempts. finish_reason: ${finishReason ?? "unknown"}`, + `Empty response from AI after ${MAX_RETRIES} attempts. finish_reason: ${finishReason ?? "unknown"}`, ); } catch (err) { - if (attempt >= retries) throw err; + if (attempt >= MAX_RETRIES) throw err; if (err instanceof Error && err.message.startsWith("API error")) throw err; if (err instanceof Error && err.message.includes("max_tokens")) throw err; - if (err instanceof Error && err.message.includes("content filter")) throw err; + if (err instanceof Error && err.message.includes("content filter")) + throw err; await sleep(RETRY_DELAY * attempt); } } - throw new Error("Failed to generate commit message"); + throw new Error("Failed to generate response"); +} + +export async function generateCommitMessage( + config: Config, + systemPrompt: string, + userPrompt: string, +): Promise { + const raw = await callAI(config, systemPrompt, userPrompt); + return cleanMessage(raw); +} + +export async function generatePRMessage( + config: Config, + systemPrompt: string, + userPrompt: string, +): Promise<{ title: string; body: string }> { + const raw = await callAI(config, systemPrompt, userPrompt); + const cleaned = cleanMessage(raw); + + const lines = cleaned.split("\n"); + const title = lines[0]?.trim() || "Update"; + let bodyStart = 1; + + while (bodyStart < lines.length && lines[bodyStart]?.trim() === "") { + bodyStart++; + } + + const body = lines.slice(bodyStart).join("\n").trim(); + + return { title, body }; } diff --git a/src/pr.ts b/src/pr.ts new file mode 100644 index 0000000..20e867f --- /dev/null +++ b/src/pr.ts @@ -0,0 +1,184 @@ +export type Platform = "github" | "gitea"; + +export async function getDefaultBranch(): Promise { + try { + const result = + await Bun.$`git symbolic-ref refs/remotes/origin/HEAD`.quiet().text(); + return result.trim().replace("refs/remotes/origin/", ""); + } catch { + try { + const branches = await Bun.$`git branch -r`.quiet().text(); + for (const line of branches.split("\n")) { + const trimmed = line.trim(); + if (trimmed === "origin/main" || trimmed === "origin/master") { + return trimmed.replace("origin/", ""); + } + } + } catch {} + return "main"; + } +} + +export async function getBranchName(): Promise { + const result = + await Bun.$`git rev-parse --abbrev-ref HEAD`.quiet().text(); + return result.trim(); +} + +export async function getBranchCommits(base: string): Promise { + try { + const result = + await Bun.$`git log --oneline origin/${base}..HEAD`.quiet().text(); + return result.trim().split("\n").filter(Boolean); + } catch { + try { + const result = + await Bun.$`git log --oneline ${base}..HEAD`.quiet().text(); + return result.trim().split("\n").filter(Boolean); + } catch { + return []; + } + } +} + +export async function getBranchDiff(base: string): Promise { + try { + const result = + await Bun.$`git diff ${base}...HEAD`.quiet().text(); + return result.trim(); + } catch { + return ""; + } +} + +function parseRemoteHostname(url: string): string | null { + const hostname = url + .trim() + .toLowerCase() + .replace(/^(https?:\/\/|ssh:\/\/|git:\/\/)/, "") + .replace(/^[^@]+@/, "") + .split(/[:/]/)[0]; + return hostname || null; +} + +export async function detectPlatform(): Promise { + try { + const url = await Bun.$`git remote get-url origin`.quiet().text(); + const hostname = parseRemoteHostname(url); + + if (!hostname) return null; + + if (hostname === "github.com") return "github"; + if (hostname.includes("gitea")) return "gitea"; + return null; + } catch { + return null; + } +} + +export async function getRemoteHostname(): Promise { + try { + const url = await Bun.$`git remote get-url origin`.quiet().text(); + return parseRemoteHostname(url); + } catch { + return null; + } +} + +export function checkCLI(platform: Platform): string | null { + const bin = platform === "github" ? "gh" : "tea"; + const path = Bun.which(bin); + if (!path) { + if (platform === "github") { + return "GitHub CLI (gh) not found. Install: brew install gh"; + } + return "Gitea CLI (tea) not found. Install from: https://gitea.com/gitea/tea"; + } + return null; +} + +export async function checkAuth(platform: Platform): Promise { + if (platform === "github") { + try { + await Bun.$`gh auth status`.quiet(); + return null; + } catch { + return "Not authenticated with GitHub CLI. Run: gh auth login"; + } + } + + try { + const result = await Bun.$`tea logins list`.quiet().text(); + if (result.trim()) return null; + return "Not authenticated with Gitea CLI. Run: tea login add"; + } catch { + return "Not authenticated with Gitea CLI. Run: tea login add"; + } +} + +export async function createPR( + platform: Platform, + title: string, + body: string, + base: string, + draft: boolean, +): Promise { + if (platform === "github") { + const args = [ + "pr", + "create", + "--title", + title, + "--body", + body, + "--base", + base, + ]; + if (draft) args.push("--draft"); + + const proc = Bun.spawn(["gh", ...args], { + stdout: "pipe", + stderr: "pipe", + }); + const exitCode = await proc.exited; + const stdout = await new Response(proc.stdout).text(); + const stderr = await new Response(proc.stderr).text(); + + if (exitCode !== 0) { + throw new Error( + stderr.trim() || `gh pr create failed (exit code ${exitCode})`, + ); + } + + const match = stdout.match(/(https?:\/\/[^\s]+)/); + return match ? match[1] : stdout.trim(); + } + + const args = [ + "pulls", + "create", + "--title", + title, + "--description", + body, + "--base", + base, + ]; + + const proc = Bun.spawn(["tea", ...args], { + stdout: "pipe", + stderr: "pipe", + }); + const exitCode = await proc.exited; + const stdout = await new Response(proc.stdout).text(); + const stderr = await new Response(proc.stderr).text(); + + if (exitCode !== 0) { + throw new Error( + stderr.trim() || `tea pulls create failed (exit code ${exitCode})`, + ); + } + + const match = stdout.match(/(https?:\/\/[^\s]+)/); + return match ? match[1] : stdout.trim(); +} diff --git a/src/prompt.ts b/src/prompt.ts index 806f3d8..4090f1a 100644 --- a/src/prompt.ts +++ b/src/prompt.ts @@ -1,4 +1,4 @@ -import type { ProjectContext } from "./types"; +import type { PRContext, ProjectContext } from "./types"; export const SYSTEM_PROMPT = `You are an expert at writing concise, meaningful git commit messages following the Conventional Commits specification. @@ -54,3 +54,65 @@ export function buildPrompt(context: ProjectContext): string { return parts.join("\n"); } + +export const PR_SYSTEM_PROMPT = `You are an expert at writing clear, concise pull request titles and descriptions. + +Format: + + + + +Rules: +1. Title must be under 72 characters, in imperative mood +2. Follow the Conventional Commits style for the title (e.g., "feat(api): add user authentication") +3. Body should be 2-3 sentences in plain text explaining WHAT was changed and WHY +4. Be specific — avoid vague messages +5. Match the language and style of recent commits if provided +6. If the branch name hints at the type (e.g., "feat/..." or "fix/..."), reflect that in the title +7. Output ONLY the PR text — no markdown, no code blocks, no prefixes`; + +export function buildPRPrompt(context: PRContext): string { + const parts: string[] = []; + + if ( + context.packageDescription || + context.readme || + context.structure + ) { + parts.push("## Project Context"); + if (context.packageDescription) { + parts.push(`Description: ${context.packageDescription}`); + } + if (context.structure) { + parts.push(`Structure: ${context.structure}`); + } + if (context.readme) { + parts.push(`README:\n${context.readme}`); + } + parts.push(""); + } + + parts.push("## Branch Info"); + parts.push(`Branch: ${context.branchName}`); + parts.push(`Target base: ${context.baseBranch}`); + parts.push(""); + + if (context.branchCommits.length > 0) { + parts.push("## Commits on This Branch"); + for (const c of context.branchCommits) { + parts.push(c); + } + parts.push(""); + } + + parts.push("## Changes (diff from base)"); + parts.push("```diff"); + parts.push(context.diff); + parts.push("```"); + parts.push(""); + parts.push( + "Generate a pull request title and brief body for the above changes.", + ); + + return parts.join("\n"); +} diff --git a/src/types.ts b/src/types.ts index cc8e5d1..b81f755 100644 --- a/src/types.ts +++ b/src/types.ts @@ -19,3 +19,13 @@ export interface ProjectContext { recentCommits: string[]; diff: string; } + +export interface PRContext { + readme: string | null; + packageDescription: string | null; + structure: string | null; + branchName: string; + baseBranch: string; + branchCommits: string[]; + diff: string; +}