Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 68 additions & 20 deletions pkg/github/repositories.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,8 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
originalRef := ref

sha, err := OptionalParam[string](args, "sha")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
Expand All @@ -681,7 +683,7 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
return utils.NewToolResultError("failed to get GitHub client"), nil, nil
}

rawOpts, err := resolveGitReference(ctx, client, owner, repo, ref, sha)
rawOpts, fallbackUsed, err := resolveGitReference(ctx, client, owner, repo, ref, sha)
if err != nil {
return utils.NewToolResultError(fmt.Sprintf("failed to resolve git reference: %s", err)), nil, nil
}
Expand Down Expand Up @@ -747,6 +749,12 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
}
}

// main branch ref passed in ref parameter but it doesn't exist - default branch was used
var successNote string
if fallbackUsed {
successNote = fmt.Sprintf(" Note: the provided ref '%s' does not exist, default branch '%s' was used instead.", originalRef, rawOpts.Ref)
}

// Determine if content is text or binary
isTextContent := strings.HasPrefix(contentType, "text/") ||
contentType == "application/json" ||
Expand All @@ -762,9 +770,9 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
}
// Include SHA in the result metadata
if fileSHA != "" {
return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded text file (SHA: %s)", fileSHA), result), nil, nil
return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded text file (SHA: %s)", fileSHA)+successNote, result), nil, nil
}
return utils.NewToolResultResource("successfully downloaded text file", result), nil, nil
return utils.NewToolResultResource("successfully downloaded text file"+successNote, result), nil, nil
}

result := &mcp.ResourceContents{
Expand All @@ -774,9 +782,9 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
}
// Include SHA in the result metadata
if fileSHA != "" {
return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded binary file (SHA: %s)", fileSHA), result), nil, nil
return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded binary file (SHA: %s)", fileSHA)+successNote, result), nil, nil
}
return utils.NewToolResultResource("successfully downloaded binary file", result), nil, nil
return utils.NewToolResultResource("successfully downloaded binary file"+successNote, result), nil, nil
}

// Raw API call failed
Expand Down Expand Up @@ -1876,15 +1884,15 @@ func looksLikeSHA(s string) bool {
//
// Any unexpected (non-404) errors during the resolution process are returned
// immediately. All API errors are logged with rich context to aid diagnostics.
func resolveGitReference(ctx context.Context, githubClient *github.Client, owner, repo, ref, sha string) (*raw.ContentOpts, error) {
func resolveGitReference(ctx context.Context, githubClient *github.Client, owner, repo, ref, sha string) (*raw.ContentOpts, bool, error) {
// 1) If SHA explicitly provided, it's the highest priority.
if sha != "" {
return &raw.ContentOpts{Ref: "", SHA: sha}, nil
return &raw.ContentOpts{Ref: "", SHA: sha}, false, nil
}

// 1a) If sha is empty but ref looks like a SHA, return it without changes
if looksLikeSHA(ref) {
return &raw.ContentOpts{Ref: "", SHA: ref}, nil
return &raw.ContentOpts{Ref: "", SHA: ref}, false, nil
}

originalRef := ref // Keep original ref for clearer error messages down the line.
Expand All @@ -1893,16 +1901,16 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner
var reference *github.Reference
var resp *github.Response
var err error
var fallbackUsed bool

switch {
case originalRef == "":
// 2a) If ref is empty, determine the default branch.
repoInfo, resp, err := githubClient.Repositories.Get(ctx, owner, repo)
reference, err = resolveDefaultBranch(ctx, githubClient, owner, repo)
if err != nil {
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get repository info", resp, err)
return nil, fmt.Errorf("failed to get repository info: %w", err)
return nil, false, err // Error is already wrapped in resolveDefaultBranch.
}
ref = fmt.Sprintf("refs/heads/%s", repoInfo.GetDefaultBranch())
ref = reference.GetRef()
case strings.HasPrefix(originalRef, "refs/"):
// 2b) Already fully qualified. The reference will be fetched at the end.
case strings.HasPrefix(originalRef, "heads/") || strings.HasPrefix(originalRef, "tags/"):
Expand All @@ -1928,19 +1936,26 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner
ghErr2, isGhErr2 := err.(*github.ErrorResponse)
if isGhErr2 && ghErr2.Response.StatusCode == http.StatusNotFound {
if originalRef == "main" {
return nil, fmt.Errorf("could not find branch or tag 'main'. Some repositories use 'master' as the default branch name")
reference, err = resolveDefaultBranch(ctx, githubClient, owner, repo)
if err != nil {
return nil, false, err // Error is already wrapped in resolveDefaultBranch.
}
// Update ref to the actual default branch ref so the note can be generated
ref = reference.GetRef()
fallbackUsed = true
break
}
return nil, fmt.Errorf("could not resolve ref %q as a branch or a tag", originalRef)
return nil, false, fmt.Errorf("could not resolve ref %q as a branch or a tag", originalRef)
}

// The tag lookup failed for a different reason.
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get reference (tag)", resp, err)
return nil, fmt.Errorf("failed to get reference for tag '%s': %w", originalRef, err)
return nil, false, fmt.Errorf("failed to get reference for tag '%s': %w", originalRef, err)
}
} else {
// The branch lookup failed for a different reason.
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get reference (branch)", resp, err)
return nil, fmt.Errorf("failed to get reference for branch '%s': %w", originalRef, err)
return nil, false, fmt.Errorf("failed to get reference for branch '%s': %w", originalRef, err)
}
}
}
Expand All @@ -1949,15 +1964,48 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner
reference, resp, err = githubClient.Git.GetRef(ctx, owner, repo, ref)
if err != nil {
if ref == "refs/heads/main" {
return nil, fmt.Errorf("could not find branch 'main'. Some repositories use 'master' as the default branch name")
reference, err = resolveDefaultBranch(ctx, githubClient, owner, repo)
if err != nil {
return nil, false, err // Error is already wrapped in resolveDefaultBranch.
}
// Update ref to the actual default branch ref so the note can be generated
ref = reference.GetRef()
fallbackUsed = true
} else {
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get final reference", resp, err)
return nil, false, fmt.Errorf("failed to get final reference for %q: %w", ref, err)
}
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get final reference", resp, err)
return nil, fmt.Errorf("failed to get final reference for %q: %w", ref, err)
}
}

sha = reference.GetObject().GetSHA()
return &raw.ContentOpts{Ref: ref, SHA: sha}, nil
return &raw.ContentOpts{Ref: ref, SHA: sha}, fallbackUsed, nil
}

func resolveDefaultBranch(ctx context.Context, githubClient *github.Client, owner, repo string) (*github.Reference, error) {
repoInfo, resp, err := githubClient.Repositories.Get(ctx, owner, repo)
if err != nil {
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get repository info", resp, err)
return nil, fmt.Errorf("failed to get repository info: %w", err)
}

if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}

defaultBranch := repoInfo.GetDefaultBranch()

defaultRef, resp, err := githubClient.Git.GetRef(ctx, owner, repo, "heads/"+defaultBranch)
if err != nil {
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get default branch reference", resp, err)
return nil, fmt.Errorf("failed to get default branch reference: %w", err)
}

if resp != nil && resp.Body != nil {
defer func() { _ = resp.Body.Close() }()
}

return defaultRef, nil
}

// ListStarredRepositories creates a tool to list starred repositories for the authenticated user or a specified user.
Expand Down
75 changes: 74 additions & 1 deletion pkg/github/repositories_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func Test_GetFileContents(t *testing.T) {
expectedResult interface{}
expectedErrMsg string
expectStatus int
expectedMsg string // optional: expected message text to verify in result
}{
{
name: "successful text content fetch",
Expand Down Expand Up @@ -290,6 +291,70 @@ func Test_GetFileContents(t *testing.T) {
MIMEType: "text/markdown",
},
},
{
name: "successful text content fetch with note when ref falls back to default branch",
mockedClient: mock.NewMockedHTTPClient(
mock.WithRequestMatchHandler(
mock.GetReposByOwnerByRepo,
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"name": "repo", "default_branch": "develop"}`))
}),
),
mock.WithRequestMatchHandler(
mock.GetReposGitRefByOwnerByRepoByRef,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Request for "refs/heads/main" -> 404 (doesn't exist)
// Request for "refs/heads/develop" (default branch) -> 200
switch {
case strings.Contains(r.URL.Path, "heads/main"):
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(`{"message": "Not Found"}`))
case strings.Contains(r.URL.Path, "heads/develop"):
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ref": "refs/heads/develop", "object": {"sha": "abc123def456"}}`))
default:
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(`{"message": "Not Found"}`))
}
}),
),
mock.WithRequestMatchHandler(
mock.GetReposContentsByOwnerByRepoByPath,
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
fileContent := &github.RepositoryContent{
Name: github.Ptr("README.md"),
Path: github.Ptr("README.md"),
SHA: github.Ptr("abc123"),
Type: github.Ptr("file"),
}
contentBytes, _ := json.Marshal(fileContent)
_, _ = w.Write(contentBytes)
}),
),
mock.WithRequestMatchHandler(
raw.GetRawReposContentsByOwnerByRepoBySHAByPath,
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/markdown")
_, _ = w.Write(mockRawContent)
}),
),
),
requestArgs: map[string]interface{}{
"owner": "owner",
"repo": "repo",
"path": "README.md",
"ref": "main",
},
expectError: false,
expectedResult: mcp.ResourceContents{
URI: "repo://owner/repo/abc123def456/contents/README.md",
Text: "# Test Repository\n\nThis is a test repository.",
MIMEType: "text/markdown",
},
expectedMsg: " Note: the provided ref 'main' does not exist, default branch 'refs/heads/develop' was used instead.",
},
{
name: "content fetch fails",
mockedClient: mock.NewMockedHTTPClient(
Expand Down Expand Up @@ -358,6 +423,14 @@ func Test_GetFileContents(t *testing.T) {
// Handle both text and blob resources
resource := getResourceResult(t, result)
assert.Equal(t, expected, *resource)

// If expectedMsg is set, verify the message text
if tc.expectedMsg != "" {
require.Len(t, result.Content, 2)
textContent, ok := result.Content[0].(*mcp.TextContent)
require.True(t, ok, "expected Content[0] to be TextContent")
assert.Contains(t, textContent.Text, tc.expectedMsg)
}
case []*github.RepositoryContent:
// Directory content fetch returns a text result (JSON array)
textContent := getTextResult(t, result)
Expand Down Expand Up @@ -3288,7 +3361,7 @@ func Test_resolveGitReference(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// Setup client with mock
client := github.NewClient(tc.mockSetup())
opts, err := resolveGitReference(ctx, client, owner, repo, tc.ref, tc.sha)
opts, _, err := resolveGitReference(ctx, client, owner, repo, tc.ref, tc.sha)

if tc.expectError {
require.Error(t, err)
Expand Down