|
| 1 | +package cmd |
| 2 | + |
| 3 | +import ( |
| 4 | + "errors" |
| 5 | + "fmt" |
| 6 | + "strconv" |
| 7 | + |
| 8 | + "github.com/cli/go-gh/v2/pkg/api" |
| 9 | + "github.com/github/gh-stack/internal/config" |
| 10 | + "github.com/github/gh-stack/internal/github" |
| 11 | + "github.com/spf13/cobra" |
| 12 | +) |
| 13 | + |
| 14 | +func LinkCmd(cfg *config.Config) *cobra.Command { |
| 15 | + cmd := &cobra.Command{ |
| 16 | + Use: "link <pr-number> <pr-number> [<pr-number>...]", |
| 17 | + Short: "Link PRs into a stack on GitHub without local tracking", |
| 18 | + Long: `Create or update a stack on GitHub from a list of PR numbers. |
| 19 | +
|
| 20 | +This command works entirely via the GitHub API and does not modify |
| 21 | +any local state. It is designed for users who manage branches with |
| 22 | +external tools (e.g. jj) and want to use GitHub stacked PRs without |
| 23 | +adopting local stack tracking. |
| 24 | +
|
| 25 | +PR numbers must be provided in stack order (bottom to top). The first |
| 26 | +PR's base branch is the trunk of the stack, and each subsequent PR |
| 27 | +should target the previous PR's head branch. |
| 28 | +
|
| 29 | +If the PRs are not yet in a stack, a new stack is created. If some of |
| 30 | +the PRs are already in a stack, the existing stack is updated to include |
| 31 | +the new PRs (existing PRs are never removed).`, |
| 32 | + Args: cobra.MinimumNArgs(2), |
| 33 | + RunE: func(cmd *cobra.Command, args []string) error { |
| 34 | + return runLink(cfg, args) |
| 35 | + }, |
| 36 | + } |
| 37 | + |
| 38 | + return cmd |
| 39 | +} |
| 40 | + |
| 41 | +func runLink(cfg *config.Config, args []string) error { |
| 42 | + prNumbers, err := parsePRNumbers(args) |
| 43 | + if err != nil { |
| 44 | + cfg.Errorf("%s", err) |
| 45 | + return ErrInvalidArgs |
| 46 | + } |
| 47 | + |
| 48 | + client, err := cfg.GitHubClient() |
| 49 | + if err != nil { |
| 50 | + cfg.Errorf("failed to create GitHub client: %s", err) |
| 51 | + return ErrAPIFailure |
| 52 | + } |
| 53 | + |
| 54 | + stacks, err := client.ListStacks() |
| 55 | + if err != nil { |
| 56 | + var httpErr *api.HTTPError |
| 57 | + if errors.As(err, &httpErr) && httpErr.StatusCode == 404 { |
| 58 | + cfg.Warningf("Stacked PRs are not enabled for this repository") |
| 59 | + return ErrStacksUnavailable |
| 60 | + } |
| 61 | + cfg.Errorf("failed to list stacks: %v", err) |
| 62 | + return ErrAPIFailure |
| 63 | + } |
| 64 | + |
| 65 | + matchedStack, err := findMatchingStack(stacks, prNumbers) |
| 66 | + if err != nil { |
| 67 | + cfg.Errorf("%s", err) |
| 68 | + return ErrDisambiguate |
| 69 | + } |
| 70 | + |
| 71 | + if matchedStack == nil { |
| 72 | + return createLink(cfg, client, prNumbers) |
| 73 | + } |
| 74 | + |
| 75 | + return updateLink(cfg, client, matchedStack, prNumbers) |
| 76 | +} |
| 77 | + |
| 78 | +// parsePRNumbers converts string args to a validated list of PR numbers. |
| 79 | +// Returns an error if any arg is not a positive integer or if there are duplicates. |
| 80 | +func parsePRNumbers(args []string) ([]int, error) { |
| 81 | + prNumbers := make([]int, 0, len(args)) |
| 82 | + seen := make(map[int]bool, len(args)) |
| 83 | + |
| 84 | + for _, arg := range args { |
| 85 | + n, err := strconv.Atoi(arg) |
| 86 | + if err != nil || n <= 0 { |
| 87 | + return nil, fmt.Errorf("invalid PR number: %q", arg) |
| 88 | + } |
| 89 | + if seen[n] { |
| 90 | + return nil, fmt.Errorf("duplicate PR number: %d", n) |
| 91 | + } |
| 92 | + seen[n] = true |
| 93 | + prNumbers = append(prNumbers, n) |
| 94 | + } |
| 95 | + |
| 96 | + return prNumbers, nil |
| 97 | +} |
| 98 | + |
| 99 | +// findMatchingStack finds a single stack that contains any of the given PR numbers. |
| 100 | +// Returns nil if no stack matches. Returns an error if PRs span multiple stacks. |
| 101 | +func findMatchingStack(stacks []github.RemoteStack, prNumbers []int) (*github.RemoteStack, error) { |
| 102 | + prSet := make(map[int]bool, len(prNumbers)) |
| 103 | + for _, n := range prNumbers { |
| 104 | + prSet[n] = true |
| 105 | + } |
| 106 | + |
| 107 | + var matched *github.RemoteStack |
| 108 | + for i := range stacks { |
| 109 | + for _, n := range stacks[i].PullRequests { |
| 110 | + if prSet[n] { |
| 111 | + if matched != nil && matched.ID != stacks[i].ID { |
| 112 | + return nil, fmt.Errorf("PRs belong to multiple stacks — unstack them first, then re-link") |
| 113 | + } |
| 114 | + matched = &stacks[i] |
| 115 | + break |
| 116 | + } |
| 117 | + } |
| 118 | + } |
| 119 | + |
| 120 | + return matched, nil |
| 121 | +} |
| 122 | + |
| 123 | +// createLink creates a new stack with the given PR numbers. |
| 124 | +func createLink(cfg *config.Config, client github.ClientOps, prNumbers []int) error { |
| 125 | + _, err := client.CreateStack(prNumbers) |
| 126 | + if err != nil { |
| 127 | + var httpErr *api.HTTPError |
| 128 | + if errors.As(err, &httpErr) { |
| 129 | + switch httpErr.StatusCode { |
| 130 | + case 422: |
| 131 | + cfg.Errorf("Cannot create stack: %s", httpErr.Message) |
| 132 | + return ErrAPIFailure |
| 133 | + case 404: |
| 134 | + cfg.Warningf("Stacked PRs are not enabled for this repository") |
| 135 | + return ErrStacksUnavailable |
| 136 | + default: |
| 137 | + cfg.Errorf("Failed to create stack (HTTP %d): %s", httpErr.StatusCode, httpErr.Message) |
| 138 | + return ErrAPIFailure |
| 139 | + } |
| 140 | + } |
| 141 | + cfg.Errorf("Failed to create stack: %v", err) |
| 142 | + return ErrAPIFailure |
| 143 | + } |
| 144 | + |
| 145 | + cfg.Successf("Created stack with %d PRs", len(prNumbers)) |
| 146 | + return nil |
| 147 | +} |
| 148 | + |
| 149 | +// updateLink updates an existing stack with the given PR numbers. |
| 150 | +// The update is additive-only: it errors if any existing PRs would be removed. |
| 151 | +func updateLink(cfg *config.Config, client github.ClientOps, existing *github.RemoteStack, prNumbers []int) error { |
| 152 | + // Check if the input exactly matches the existing stack. |
| 153 | + if slicesEqual(existing.PullRequests, prNumbers) { |
| 154 | + cfg.Successf("Stack with %d PRs is already up to date", len(prNumbers)) |
| 155 | + return nil |
| 156 | + } |
| 157 | + |
| 158 | + // Check that no existing PRs would be removed (additive-only). |
| 159 | + newSet := make(map[int]bool, len(prNumbers)) |
| 160 | + for _, n := range prNumbers { |
| 161 | + newSet[n] = true |
| 162 | + } |
| 163 | + |
| 164 | + var dropped []int |
| 165 | + for _, n := range existing.PullRequests { |
| 166 | + if !newSet[n] { |
| 167 | + dropped = append(dropped, n) |
| 168 | + } |
| 169 | + } |
| 170 | + |
| 171 | + if len(dropped) > 0 { |
| 172 | + cfg.Errorf("Cannot update stack: this would remove %s from the stack", |
| 173 | + formatPRList(dropped)) |
| 174 | + cfg.Printf("Current stack: %s", formatPRList(existing.PullRequests)) |
| 175 | + cfg.Printf("Include all existing PRs in the command to update the stack") |
| 176 | + return ErrInvalidArgs |
| 177 | + } |
| 178 | + |
| 179 | + stackID := strconv.Itoa(existing.ID) |
| 180 | + if err := client.UpdateStack(stackID, prNumbers); err != nil { |
| 181 | + var httpErr *api.HTTPError |
| 182 | + if errors.As(err, &httpErr) { |
| 183 | + switch httpErr.StatusCode { |
| 184 | + case 404: |
| 185 | + // Stack was deleted between list and update — try creating instead. |
| 186 | + cfg.Warningf("Stack was deleted — creating a new one") |
| 187 | + return createLink(cfg, client, prNumbers) |
| 188 | + case 422: |
| 189 | + cfg.Errorf("Cannot update stack: %s", httpErr.Message) |
| 190 | + return ErrAPIFailure |
| 191 | + default: |
| 192 | + cfg.Errorf("Failed to update stack (HTTP %d): %s", httpErr.StatusCode, httpErr.Message) |
| 193 | + return ErrAPIFailure |
| 194 | + } |
| 195 | + } |
| 196 | + cfg.Errorf("Failed to update stack: %v", err) |
| 197 | + return ErrAPIFailure |
| 198 | + } |
| 199 | + |
| 200 | + cfg.Successf("Updated stack to %d PRs", len(prNumbers)) |
| 201 | + return nil |
| 202 | +} |
| 203 | + |
| 204 | +func slicesEqual(a, b []int) bool { |
| 205 | + if len(a) != len(b) { |
| 206 | + return false |
| 207 | + } |
| 208 | + for i := range a { |
| 209 | + if a[i] != b[i] { |
| 210 | + return false |
| 211 | + } |
| 212 | + } |
| 213 | + return true |
| 214 | +} |
| 215 | + |
| 216 | +func formatPRList(numbers []int) string { |
| 217 | + if len(numbers) == 0 { |
| 218 | + return "" |
| 219 | + } |
| 220 | + s := fmt.Sprintf("#%d", numbers[0]) |
| 221 | + for _, n := range numbers[1:] { |
| 222 | + s += fmt.Sprintf(", #%d", n) |
| 223 | + } |
| 224 | + return s |
| 225 | +} |
0 commit comments