diff --git a/internal/github/github.go b/internal/github/github.go index 9f350c7240..470e1c8d4d 100644 --- a/internal/github/github.go +++ b/internal/github/github.go @@ -135,6 +135,27 @@ func (c *Client) CreatePullRequest(ctx context.Context, repo *Repository, remote return pullRequestMetadata, nil } +// GetLabels fetches the labels for an issue. +func (c *Client) GetLabels(ctx context.Context, number int) ([]string, error) { + slog.Info("Getting labels", "number", number) + labels, _, err := c.Issues.ListLabelsByIssue(ctx, c.repo.Owner, c.repo.Name, number, nil) + if err != nil { + return nil, err + } + var labelNames []string + for _, label := range labels { + labelNames = append(labelNames, *label.Name) + } + return labelNames, nil +} + +// ReplaceLabels replaces all labels for an issue. +func (c *Client) ReplaceLabels(ctx context.Context, number int, labels []string) error { + slog.Info("Replacing labels", "number", number, "labels", labels) + _, _, err := c.Issues.ReplaceLabelsForIssue(ctx, c.repo.Owner, c.repo.Name, number, labels) + return err +} + // AddLabelsToIssue adds labels to an existing issue in a GitHub repository. func (c *Client) AddLabelsToIssue(ctx context.Context, repo *Repository, number int, labels []string) error { slog.Info("Labels added to issue", "number", number, "labels", labels) @@ -165,3 +186,56 @@ func FetchGitHubRepoFromRemote(repo gitrepo.Repository) (*Repository, error) { return nil, fmt.Errorf("could not find an 'origin' remote pointing to a GitHub https URL") } + +// SearchPullRequests searches for pull requests in the repository using the provided raw query. +func (c *Client) SearchPullRequests(ctx context.Context, query string) ([]*PullRequest, error) { + var prs []*PullRequest + opts := &github.SearchOptions{ + ListOptions: github.ListOptions{PerPage: 100}, + } + for { + result, resp, err := c.Search.Issues(ctx, query, opts) + if err != nil { + return nil, err + } + for _, issue := range result.Issues { + if issue.IsPullRequest() { + pr, _, err := c.PullRequests.Get(ctx, c.repo.Owner, c.repo.Name, issue.GetNumber()) + if err != nil { + return nil, err + } + prs = append(prs, pr) + } + } + if resp.NextPage == 0 { + break + } + opts.Page = resp.NextPage + } + return prs, nil +} + +// GetPullRequest gets a pull request by its number. +func (c *Client) GetPullRequest(ctx context.Context, number int) (*PullRequest, error) { + pr, _, err := c.PullRequests.Get(ctx, c.repo.Owner, c.repo.Name, number) + return pr, err +} + +// CreateRelease creates a tag and release in the repository at the given commitish. +func (c *Client) CreateRelease(ctx context.Context, tagName, name, body, commitish string) (*github.RepositoryRelease, error) { + r, _, err := c.Repositories.CreateRelease(ctx, c.repo.Owner, c.repo.Name, &github.RepositoryRelease{ + TagName: &tagName, + Name: &name, + Body: &body, + TargetCommitish: &commitish, + }) + return r, err +} + +// CreateIssueComment adds a comment to the issue number provided. +func (c *Client) CreateIssueComment(ctx context.Context, number int, comment string) error { + _, _, err := c.Issues.CreateComment(ctx, c.repo.Owner, c.repo.Name, number, &github.IssueComment{ + Body: &comment, + }) + return err +} diff --git a/internal/github/github_test.go b/internal/github/github_test.go index 603301466c..dca237ab76 100644 --- a/internal/github/github_test.go +++ b/internal/github/github_test.go @@ -471,3 +471,466 @@ func TestAddLabelsToIssue(t *testing.T) { }) } } + +func TestGetLabels(t *testing.T) { + t.Parallel() + for _, test := range []struct { + name string + handler http.HandlerFunc + issueNum int + wantLabels []string + wantErr bool + wantErrSubstr string + }{ + { + name: "get labels from an issue", + handler: func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("unexpected method: got %s, want %s", r.Method, http.MethodGet) + } + wantPath := "/repos/owner/repo/issues/7/labels" + if r.URL.Path != wantPath { + t.Errorf("unexpected path: got %s, want %s", r.URL.Path, wantPath) + } + fmt.Fprint(w, `[{"name": "label1"}, {"name": "label2"}]`) + }, + issueNum: 7, + wantLabels: []string{"label1", "label2"}, + }, + { + name: "GitHub API error", + handler: func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) }, + issueNum: 7, + wantErr: true, + wantErrSubstr: "500", + }, + } { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(test.handler) + defer server.Close() + + repo := &Repository{Owner: "owner", Name: "repo"} + client, err := newClientWithHTTP("fake-token", repo, server.Client()) + if err != nil { + t.Fatalf("newClientWithHTTP() error = %v", err) + } + client.BaseURL, _ = url.Parse(server.URL + "/") + + gotLabels, err := client.GetLabels(context.Background(), test.issueNum) + + if test.wantErr { + if err == nil { + t.Errorf("GetLabels() should return an error") + } + if !strings.Contains(err.Error(), test.wantErrSubstr) { + t.Errorf("GetLabels() err = %v, want error containing %q", err, test.wantErrSubstr) + } + return + } + + if err != nil { + t.Errorf("GetLabels() err = %v, want nil", err) + } + if diff := cmp.Diff(test.wantLabels, gotLabels); diff != "" { + t.Errorf("GetLabels() labels mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestReplaceLabels(t *testing.T) { + t.Parallel() + for _, test := range []struct { + name string + handler http.HandlerFunc + issueNum int + labels []string + wantErr bool + wantErrSubstr string + }{ + { + name: "replace labels for an issue", + handler: func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + t.Errorf("unexpected method: got %s, want %s", r.Method, http.MethodPut) + } + wantPath := "/repos/owner/repo/issues/7/labels" + if r.URL.Path != wantPath { + t.Errorf("unexpected path: got %s, want %s", r.URL.Path, wantPath) + } + var labels []string + if err := json.NewDecoder(r.Body).Decode(&labels); err != nil { + t.Fatalf("failed to decode request body: %v", err) + } + expectedBody := []string{"new-label", "another-label"} + if diff := cmp.Diff(expectedBody, labels); diff != "" { + t.Errorf("ReplaceLabels() request body mismatch (-want +got):\n%s", diff) + } + fmt.Fprint(w, `[]`) + }, + issueNum: 7, + labels: []string{"new-label", "another-label"}, + }, + { + name: "GitHub API error", + handler: func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) }, + issueNum: 7, + labels: []string{"some-label"}, + wantErr: true, + wantErrSubstr: "500", + }, + } { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(test.handler) + defer server.Close() + + repo := &Repository{Owner: "owner", Name: "repo"} + client, err := newClientWithHTTP("fake-token", repo, server.Client()) + if err != nil { + t.Fatalf("newClientWithHTTP() error = %v", err) + } + client.BaseURL, _ = url.Parse(server.URL + "/") + + err = client.ReplaceLabels(context.Background(), test.issueNum, test.labels) + + if test.wantErr { + if err == nil { + t.Errorf("ReplaceLabels() should return an error") + } + if !strings.Contains(err.Error(), test.wantErrSubstr) { + t.Errorf("ReplaceLabels() err = %v, want error containing %q", err, test.wantErrSubstr) + } + return + } + + if err != nil { + t.Errorf("ReplaceLabels() err = %v, want nil", err) + } + }) + } +} + +func TestSearchPullRequests(t *testing.T) { + t.Parallel() + for _, test := range []struct { + name string + query string + handler http.HandlerFunc + wantPRs []*PullRequest + wantErr bool + wantErrSubstr string + }{ + { + name: "Success with single page", + query: "is:pr is:open author:app/dependabot", + handler: func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/search/issues") { + if r.URL.Query().Get("q") != "is:pr is:open author:app/dependabot" { + t.Errorf("unexpected query: got %q", r.URL.Query().Get("q")) + } + fmt.Fprint(w, `{"items": [{"number": 1, "pull_request": {}}]}`) + } else if strings.HasPrefix(r.URL.Path, "/repos/owner/repo/pulls/1") { + fmt.Fprint(w, `{"number": 1, "title": "PR 1"}`) + } else { + w.WriteHeader(http.StatusNotFound) + } + }, + wantPRs: []*PullRequest{ + {Number: github.Ptr(1), Title: github.Ptr("PR 1")}, + }, + }, + { + name: "Success with multiple pages", + query: "is:pr is:open", + handler: func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/search/issues") { + if r.URL.Query().Get("page") == "2" { + fmt.Fprint(w, `{"items": [{"number": 2, "pull_request": {}}]}`) + } else { + w.Header().Set("Link", `; rel="next"`) + fmt.Fprint(w, `{"items": [{"number": 1, "pull_request": {}}]}`) + } + } else if strings.HasPrefix(r.URL.Path, "/repos/owner/repo/pulls/1") { + fmt.Fprint(w, `{"number": 1, "title": "PR 1"}`) + } else if strings.HasPrefix(r.URL.Path, "/repos/owner/repo/pulls/2") { + fmt.Fprint(w, `{"number": 2, "title": "PR 2"}`) + } else { + w.WriteHeader(http.StatusNotFound) + } + }, + wantPRs: []*PullRequest{ + {Number: github.Ptr(1), Title: github.Ptr("PR 1")}, + {Number: github.Ptr(2), Title: github.Ptr("PR 2")}, + }, + }, + { + name: "Search API error", + query: "is:pr", + handler: func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/search/issues") { + w.WriteHeader(http.StatusInternalServerError) + } + }, + wantErr: true, + wantErrSubstr: "500", + }, + { + name: "Get PR API error", + query: "is:pr", + handler: func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/search/issues") { + fmt.Fprint(w, `{"items": [{"number": 1, "pull_request": {}}]}`) + } else if strings.HasPrefix(r.URL.Path, "/repos/owner/repo/pulls/1") { + w.WriteHeader(http.StatusInternalServerError) + } + }, + wantErr: true, + wantErrSubstr: "500", + }, + } { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(test.handler) + defer server.Close() + + repo := &Repository{Owner: "owner", Name: "repo"} + client, err := newClientWithHTTP("fake-token", repo, server.Client()) + if err != nil { + t.Fatalf("newClientWithHTTP() error = %v", err) + } + client.BaseURL, _ = url.Parse(server.URL + "/") + + prs, err := client.SearchPullRequests(context.Background(), test.query) + + if test.wantErr { + if err == nil { + t.Errorf("SearchPullRequests() err = nil, want error containing %q", test.wantErrSubstr) + } else if !strings.Contains(err.Error(), test.wantErrSubstr) { + t.Errorf("SearchPullRequests() err = %v, want error containing %q", err, test.wantErrSubstr) + } + } else if err != nil { + t.Errorf("SearchPullRequests() err = %v, want nil", err) + } + + if diff := cmp.Diff(test.wantPRs, prs); diff != "" { + t.Errorf("SearchPullRequests() prs mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestGetPullRequest(t *testing.T) { + t.Parallel() + for _, test := range []struct { + name string + number int + handler http.HandlerFunc + wantPR *PullRequest + wantErr bool + wantErrSubstr string + }{ + { + name: "Success", + number: 42, + handler: func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("unexpected method: got %s, want %s", r.Method, http.MethodGet) + } + wantPath := "/repos/owner/repo/pulls/42" + if r.URL.Path != wantPath { + t.Errorf("unexpected path: got %s, want %s", r.URL.Path, wantPath) + } + fmt.Fprint(w, `{"number": 42, "title": "The Answer"}`) + }, + wantPR: &PullRequest{Number: github.Ptr(42), Title: github.Ptr("The Answer")}, + }, + { + name: "Not Found", + number: 43, + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }, + wantErr: true, + wantErrSubstr: "404", + }, + } { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(test.handler) + defer server.Close() + + repo := &Repository{Owner: "owner", Name: "repo"} + client, err := newClientWithHTTP("fake-token", repo, server.Client()) + if err != nil { + t.Fatalf("newClientWithHTTP() error = %v", err) + } + client.BaseURL, _ = url.Parse(server.URL + "/") + + pr, err := client.GetPullRequest(context.Background(), test.number) + + if test.wantErr { + if err == nil { + t.Errorf("GetPullRequest() err = nil, want error containing %q", test.wantErrSubstr) + } else if !strings.Contains(err.Error(), test.wantErrSubstr) { + t.Errorf("GetPullRequest() err = %v, want error containing %q", err, test.wantErrSubstr) + } + } else if err != nil { + t.Errorf("GetPullRequest() err = %v, want nil", err) + } + + if diff := cmp.Diff(test.wantPR, pr); diff != "" { + t.Errorf("GetPullRequest() pr mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestCreateRelease(t *testing.T) { + t.Parallel() + for _, test := range []struct { + name string + tagName string + releaseName string + body string + commitish string + handler http.HandlerFunc + wantRelease *github.RepositoryRelease + wantErr bool + wantErrSubstr string + }{ + { + name: "Success", + tagName: "v1.0.0", + releaseName: "Version 1.0.0", + body: "Initial release", + commitish: "main", + handler: func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("unexpected method: got %s, want %s", r.Method, http.MethodPost) + } + wantPath := "/repos/owner/repo/releases" + if r.URL.Path != wantPath { + t.Errorf("unexpected path: got %s, want %s", r.URL.Path, wantPath) + } + var newRelease github.RepositoryRelease + if err := json.NewDecoder(r.Body).Decode(&newRelease); err != nil { + t.Fatalf("failed to decode request body: %v", err) + } + if *newRelease.TagName != "v1.0.0" { + t.Errorf("unexpected tag name: got %q, want %q", *newRelease.TagName, "v1.0.0") + } + fmt.Fprint(w, `{"tag_name": "v1.0.0", "name": "Version 1.0.0"}`) + }, + wantRelease: &github.RepositoryRelease{TagName: github.Ptr("v1.0.0"), Name: github.Ptr("Version 1.0.0")}, + }, + { + name: "API Error", + tagName: "v1.0.0", + releaseName: "Version 1.0.0", + body: "Initial release", + commitish: "main", + handler: func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) }, + wantErr: true, + wantErrSubstr: "500", + }, + } { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(test.handler) + defer server.Close() + + repo := &Repository{Owner: "owner", Name: "repo"} + client, err := newClientWithHTTP("fake-token", repo, server.Client()) + if err != nil { + t.Fatalf("newClientWithHTTP() error = %v", err) + } + client.BaseURL, _ = url.Parse(server.URL + "/") + + release, err := client.CreateRelease(context.Background(), test.tagName, test.releaseName, test.body, test.commitish) + + if test.wantErr { + if err == nil { + t.Errorf("CreateRelease() err = nil, want error containing %q", test.wantErrSubstr) + } else if !strings.Contains(err.Error(), test.wantErrSubstr) { + t.Errorf("CreateRelease() err = %v, want error containing %q", err, test.wantErrSubstr) + } + } else if err != nil { + t.Errorf("CreateRelease() err = %v, want nil", err) + } + + if diff := cmp.Diff(test.wantRelease, release); diff != "" { + t.Errorf("CreateRelease() release mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestCreateIssueComment(t *testing.T) { + t.Parallel() + for _, test := range []struct { + name string + number int + body string + handler http.HandlerFunc + wantErr bool + wantErrSubstr string + }{ + { + name: "Success", + number: 123, + body: "This is a comment.", + handler: func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("unexpected method: got %s, want %s", r.Method, http.MethodPost) + } + wantPath := "/repos/owner/repo/issues/123/comments" + if r.URL.Path != wantPath { + t.Errorf("unexpected path: got %s, want %s", r.URL.Path, wantPath) + } + var comment github.IssueComment + if err := json.NewDecoder(r.Body).Decode(&comment); err != nil { + t.Fatalf("failed to decode request body: %v", err) + } + if *comment.Body != "This is a comment." { + t.Errorf("unexpected body: got %q, want %q", *comment.Body, "This is a comment.") + } + w.WriteHeader(http.StatusCreated) + }, + }, + { + name: "API Error", + number: 123, + body: "This is a comment.", + handler: func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) }, + wantErr: true, + wantErrSubstr: "500", + }, + } { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(test.handler) + defer server.Close() + + repo := &Repository{Owner: "owner", Name: "repo"} + client, err := newClientWithHTTP("fake-token", repo, server.Client()) + if err != nil { + t.Fatalf("newClientWithHTTP() error = %v", err) + } + client.BaseURL, _ = url.Parse(server.URL + "/") + + err = client.CreateIssueComment(context.Background(), test.number, test.body) + + if test.wantErr { + if err == nil { + t.Errorf("CreateComment() err = nil, want error containing %q", test.wantErrSubstr) + } else if !strings.Contains(err.Error(), test.wantErrSubstr) { + t.Errorf("CreateComment() err = %v, want error containing %q", err, test.wantErrSubstr) + } + } else if err != nil { + t.Errorf("CreateComment() err = %v, want nil", err) + } + }) + } +}