diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 1b68b4222..c31bb7df2 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -671,6 +671,8 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } + originalRef := ref + sha, err := OptionalParam[string](args, "sha") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -681,7 +683,7 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool return utils.NewToolResultError("failed to get GitHub client"), nil, nil } - rawOpts, err := resolveGitReference(ctx, client, owner, repo, ref, sha) + rawOpts, fallbackUsed, err := resolveGitReference(ctx, client, owner, repo, ref, sha) if err != nil { return utils.NewToolResultError(fmt.Sprintf("failed to resolve git reference: %s", err)), nil, nil } @@ -747,6 +749,12 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool } } + // main branch ref passed in ref parameter but it doesn't exist - default branch was used + var successNote string + if fallbackUsed { + successNote = fmt.Sprintf(" Note: the provided ref '%s' does not exist, default branch '%s' was used instead.", originalRef, rawOpts.Ref) + } + // Determine if content is text or binary isTextContent := strings.HasPrefix(contentType, "text/") || contentType == "application/json" || @@ -762,9 +770,9 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool } // Include SHA in the result metadata if fileSHA != "" { - return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded text file (SHA: %s)", fileSHA), result), nil, nil + return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded text file (SHA: %s)", fileSHA)+successNote, result), nil, nil } - return utils.NewToolResultResource("successfully downloaded text file", result), nil, nil + return utils.NewToolResultResource("successfully downloaded text file"+successNote, result), nil, nil } result := &mcp.ResourceContents{ @@ -774,9 +782,9 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool } // Include SHA in the result metadata if fileSHA != "" { - return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded binary file (SHA: %s)", fileSHA), result), nil, nil + return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded binary file (SHA: %s)", fileSHA)+successNote, result), nil, nil } - return utils.NewToolResultResource("successfully downloaded binary file", result), nil, nil + return utils.NewToolResultResource("successfully downloaded binary file"+successNote, result), nil, nil } // Raw API call failed @@ -1876,15 +1884,15 @@ func looksLikeSHA(s string) bool { // // Any unexpected (non-404) errors during the resolution process are returned // immediately. All API errors are logged with rich context to aid diagnostics. -func resolveGitReference(ctx context.Context, githubClient *github.Client, owner, repo, ref, sha string) (*raw.ContentOpts, error) { +func resolveGitReference(ctx context.Context, githubClient *github.Client, owner, repo, ref, sha string) (*raw.ContentOpts, bool, error) { // 1) If SHA explicitly provided, it's the highest priority. if sha != "" { - return &raw.ContentOpts{Ref: "", SHA: sha}, nil + return &raw.ContentOpts{Ref: "", SHA: sha}, false, nil } // 1a) If sha is empty but ref looks like a SHA, return it without changes if looksLikeSHA(ref) { - return &raw.ContentOpts{Ref: "", SHA: ref}, nil + return &raw.ContentOpts{Ref: "", SHA: ref}, false, nil } originalRef := ref // Keep original ref for clearer error messages down the line. @@ -1893,16 +1901,16 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner var reference *github.Reference var resp *github.Response var err error + var fallbackUsed bool switch { case originalRef == "": // 2a) If ref is empty, determine the default branch. - repoInfo, resp, err := githubClient.Repositories.Get(ctx, owner, repo) + reference, err = resolveDefaultBranch(ctx, githubClient, owner, repo) if err != nil { - _, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get repository info", resp, err) - return nil, fmt.Errorf("failed to get repository info: %w", err) + return nil, false, err // Error is already wrapped in resolveDefaultBranch. } - ref = fmt.Sprintf("refs/heads/%s", repoInfo.GetDefaultBranch()) + ref = reference.GetRef() case strings.HasPrefix(originalRef, "refs/"): // 2b) Already fully qualified. The reference will be fetched at the end. case strings.HasPrefix(originalRef, "heads/") || strings.HasPrefix(originalRef, "tags/"): @@ -1928,19 +1936,26 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner ghErr2, isGhErr2 := err.(*github.ErrorResponse) if isGhErr2 && ghErr2.Response.StatusCode == http.StatusNotFound { if originalRef == "main" { - return nil, fmt.Errorf("could not find branch or tag 'main'. Some repositories use 'master' as the default branch name") + reference, err = resolveDefaultBranch(ctx, githubClient, owner, repo) + if err != nil { + return nil, false, err // Error is already wrapped in resolveDefaultBranch. + } + // Update ref to the actual default branch ref so the note can be generated + ref = reference.GetRef() + fallbackUsed = true + break } - return nil, fmt.Errorf("could not resolve ref %q as a branch or a tag", originalRef) + return nil, false, fmt.Errorf("could not resolve ref %q as a branch or a tag", originalRef) } // The tag lookup failed for a different reason. _, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get reference (tag)", resp, err) - return nil, fmt.Errorf("failed to get reference for tag '%s': %w", originalRef, err) + return nil, false, fmt.Errorf("failed to get reference for tag '%s': %w", originalRef, err) } } else { // The branch lookup failed for a different reason. _, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get reference (branch)", resp, err) - return nil, fmt.Errorf("failed to get reference for branch '%s': %w", originalRef, err) + return nil, false, fmt.Errorf("failed to get reference for branch '%s': %w", originalRef, err) } } } @@ -1949,15 +1964,48 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner reference, resp, err = githubClient.Git.GetRef(ctx, owner, repo, ref) if err != nil { if ref == "refs/heads/main" { - return nil, fmt.Errorf("could not find branch 'main'. Some repositories use 'master' as the default branch name") + reference, err = resolveDefaultBranch(ctx, githubClient, owner, repo) + if err != nil { + return nil, false, err // Error is already wrapped in resolveDefaultBranch. + } + // Update ref to the actual default branch ref so the note can be generated + ref = reference.GetRef() + fallbackUsed = true + } else { + _, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get final reference", resp, err) + return nil, false, fmt.Errorf("failed to get final reference for %q: %w", ref, err) } - _, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get final reference", resp, err) - return nil, fmt.Errorf("failed to get final reference for %q: %w", ref, err) } } sha = reference.GetObject().GetSHA() - return &raw.ContentOpts{Ref: ref, SHA: sha}, nil + return &raw.ContentOpts{Ref: ref, SHA: sha}, fallbackUsed, nil +} + +func resolveDefaultBranch(ctx context.Context, githubClient *github.Client, owner, repo string) (*github.Reference, error) { + repoInfo, resp, err := githubClient.Repositories.Get(ctx, owner, repo) + if err != nil { + _, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get repository info", resp, err) + return nil, fmt.Errorf("failed to get repository info: %w", err) + } + + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + + defaultBranch := repoInfo.GetDefaultBranch() + + defaultRef, resp, err := githubClient.Git.GetRef(ctx, owner, repo, "heads/"+defaultBranch) + if err != nil { + _, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get default branch reference", resp, err) + return nil, fmt.Errorf("failed to get default branch reference: %w", err) + } + + if resp != nil && resp.Body != nil { + defer func() { _ = resp.Body.Close() }() + } + + return defaultRef, nil } // ListStarredRepositories creates a tool to list starred repositories for the authenticated user or a specified user. diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index 6c56d104e..8b5dab098 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -69,6 +69,7 @@ func Test_GetFileContents(t *testing.T) { expectedResult interface{} expectedErrMsg string expectStatus int + expectedMsg string // optional: expected message text to verify in result }{ { name: "successful text content fetch", @@ -290,6 +291,70 @@ func Test_GetFileContents(t *testing.T) { MIMEType: "text/markdown", }, }, + { + name: "successful text content fetch with note when ref falls back to default branch", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"name": "repo", "default_branch": "develop"}`)) + }), + ), + mock.WithRequestMatchHandler( + mock.GetReposGitRefByOwnerByRepoByRef, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Request for "refs/heads/main" -> 404 (doesn't exist) + // Request for "refs/heads/develop" (default branch) -> 200 + switch { + case strings.Contains(r.URL.Path, "heads/main"): + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + case strings.Contains(r.URL.Path, "heads/develop"): + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ref": "refs/heads/develop", "object": {"sha": "abc123def456"}}`)) + default: + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + } + }), + ), + mock.WithRequestMatchHandler( + mock.GetReposContentsByOwnerByRepoByPath, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + fileContent := &github.RepositoryContent{ + Name: github.Ptr("README.md"), + Path: github.Ptr("README.md"), + SHA: github.Ptr("abc123"), + Type: github.Ptr("file"), + } + contentBytes, _ := json.Marshal(fileContent) + _, _ = w.Write(contentBytes) + }), + ), + mock.WithRequestMatchHandler( + raw.GetRawReposContentsByOwnerByRepoBySHAByPath, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/markdown") + _, _ = w.Write(mockRawContent) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "README.md", + "ref": "main", + }, + expectError: false, + expectedResult: mcp.ResourceContents{ + URI: "repo://owner/repo/abc123def456/contents/README.md", + Text: "# Test Repository\n\nThis is a test repository.", + MIMEType: "text/markdown", + }, + expectedMsg: " Note: the provided ref 'main' does not exist, default branch 'refs/heads/develop' was used instead.", + }, { name: "content fetch fails", mockedClient: mock.NewMockedHTTPClient( @@ -358,6 +423,14 @@ func Test_GetFileContents(t *testing.T) { // Handle both text and blob resources resource := getResourceResult(t, result) assert.Equal(t, expected, *resource) + + // If expectedMsg is set, verify the message text + if tc.expectedMsg != "" { + require.Len(t, result.Content, 2) + textContent, ok := result.Content[0].(*mcp.TextContent) + require.True(t, ok, "expected Content[0] to be TextContent") + assert.Contains(t, textContent.Text, tc.expectedMsg) + } case []*github.RepositoryContent: // Directory content fetch returns a text result (JSON array) textContent := getTextResult(t, result) @@ -3288,7 +3361,7 @@ func Test_resolveGitReference(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockSetup()) - opts, err := resolveGitReference(ctx, client, owner, repo, tc.ref, tc.sha) + opts, _, err := resolveGitReference(ctx, client, owner, repo, tc.ref, tc.sha) if tc.expectError { require.Error(t, err)