Files
gitea-mcp/internal/tools/pr_files_diff_test.go
Mathias 4c87856aec fix(pr_files_diff): copy per-file diff bytes to break buffer aliasing
splitUnifiedDiff used bytes.Buffer to accumulate each file's diff,
then stored buf.Bytes() into the result map and called buf.Reset()
to start the next file. bytes.Buffer.Bytes() returns the buffer's
internal backing slice; Reset() resets length to 0 but reuses the
same backing array. As a result, every map entry aliased the same
storage, so all files ended up showing the LAST file's diff content.

Fix: copy the bytes into a fresh slice before storing in the map.

Adds TestPRFilesDiffPerFileIsolation as a regression test that
asserts each file entry contains its OWN diff --git header and
none of the other files' headers. Verified failing on the prior
code, passing after the fix.

Closes #25
2026-05-16 23:01:18 +02:00

225 lines
7.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package tools_test
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"gitea.d-ma.be/mathias/gitea-mcp/internal/allowlist"
"gitea.d-ma.be/mathias/gitea-mcp/internal/gitea"
"gitea.d-ma.be/mathias/gitea-mcp/internal/tools"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// buildDiff renders a synthetic unified diff with one section per entry in
// files, in the given order. Every section consists of a "diff --git" line,
// the ---/+++ pair, a single hunk header announcing linesPerFile new lines,
// and linesPerFile copies of the added line "+abcdefghij" (12 bytes each,
// newline included).
func buildDiff(files []string, linesPerFile int) string {
	const (
		// %[1]s is the file path, %[2]d the number of added lines.
		sectionHeader = "diff --git a/%[1]s b/%[1]s\n--- a/%[1]s\n+++ b/%[1]s\n@@ -0,0 +1,%[2]d @@\n"
		addedLine     = "+abcdefghij\n"
	)

	var diff strings.Builder
	for i := range files {
		fmt.Fprintf(&diff, sectionHeader, files[i], linesPerFile)
		diff.WriteString(strings.Repeat(addedLine, linesPerFile))
	}
	return diff.String()
}
// buildFilesJSON returns the JSON array served by the fake /files endpoint:
// one object per entry in files, each with status "modified", the given
// additions count and zero deletions. An empty or nil files yields "[]".
//
// The payload is produced by encoding/json rather than by formatting the
// filename with %q: %q emits Go string-literal escapes (\x01, \a, \U0001F600,
// ...) that are not legal JSON, so unusual filenames would yield a body the
// client cannot decode. Note that json.Marshal escapes '<', '>' and '&' as
// \u00XX sequences and replaces invalid UTF-8 with U+FFFD; both still decode
// as valid JSON.
func buildFilesJSON(files []string, additions int) string {
	// Field order matches the key order of the expected payload.
	type pullRequestFile struct {
		Filename  string `json:"filename"`
		Status    string `json:"status"`
		Additions int    `json:"additions"`
		Deletions int    `json:"deletions"`
	}

	// make (rather than a nil slice) so that no files encodes as "[]", not "null".
	entries := make([]pullRequestFile, len(files))
	for i, f := range files {
		entries[i] = pullRequestFile{Filename: f, Status: "modified", Additions: additions}
	}

	encoded, err := json.Marshal(entries)
	if err != nil {
		// Unreachable for a slice of string/int structs; fail loudly if that
		// invariant is ever broken instead of serving a bogus payload.
		panic("buildFilesJSON: " + err.Error())
	}
	return string(encoded)
}
// newPRFilesDiffServer starts an httptest server that stands in for Gitea for
// pull request 1 of repository o/r. It answers exactly two paths:
//
//   - /api/v1/repos/o/r/pulls/1/files with filesJSON as application/json
//   - /api/v1/repos/o/r/pulls/1.diff  with rawDiff as text/plain
//
// Any other path is reported as a test error on t and answered with 404.
// The caller is responsible for closing the returned server.
func newPRFilesDiffServer(t *testing.T, filesJSON, rawDiff string) *httptest.Server {
	t.Helper()

	type payload struct {
		contentType string
		body        string
	}
	// Built once and only read afterwards, so sharing it between handler
	// invocations is safe.
	routes := map[string]payload{
		"/api/v1/repos/o/r/pulls/1/files": {contentType: "application/json", body: filesJSON},
		"/api/v1/repos/o/r/pulls/1.diff":  {contentType: "text/plain", body: rawDiff},
	}

	handler := func(w http.ResponseWriter, r *http.Request) {
		p, known := routes[r.URL.Path]
		if !known {
			t.Errorf("unexpected request: %s", r.URL.Path)
			w.WriteHeader(http.StatusNotFound)
			return
		}
		w.Header().Set("Content-Type", p.contentType)
		// A failed write to the test client is not actionable here.
		_, _ = w.Write([]byte(p.body))
	}
	return httptest.NewServer(http.HandlerFunc(handler))
}
// TestPRFilesDiffSmall covers the happy path: a pull request whose diff is
// well under the per-file and whole-response size caps must come back in
// full, with no file truncated, nothing omitted, and the addition/deletion
// counts that the fake /files endpoint served.
func TestPRFilesDiffSmall(t *testing.T) {
	// Two files with 10 added lines each: 10 × 12 bytes of hunk body plus a
	// few short header lines per file.
	fileNames := []string{"main.go", "util.go"}
	rawDiff := buildDiff(fileNames, 10)
	filesJSON := buildFilesJSON(fileNames, 10)
	srv := newPRFilesDiffServer(t, filesJSON, rawDiff)
	defer srv.Close()

	tool := tools.NewPRFilesDiff(gitea.NewClient(srv.URL, "tok"), allowlist.New([]string{"o"}))
	result, err := tool.Call(context.Background(), json.RawMessage(`{"owner":"o","name":"r","number":1}`))
	require.NoError(t, err)

	var out struct {
		Files []struct {
			Path      string `json:"path"`
			Diff      string `json:"diff"`
			Truncated bool   `json:"truncated"`
			Additions int    `json:"additions"`
			Deletions int    `json:"deletions"`
		} `json:"files"`
		OmittedFiles      []string `json:"omitted_files"`
		ResponseTruncated bool     `json:"response_truncated"`
	}
	require.NoError(t, json.Unmarshal(result, &out))

	// require, not assert: with the wrong number of entries the per-file
	// checks below are meaningless, so stop here with one clear failure.
	require.Len(t, out.Files, len(fileNames))
	assert.Empty(t, out.OmittedFiles)
	assert.False(t, out.ResponseTruncated)

	// Gather the paths while checking each entry instead of indexing fixed
	// positions, so this code cannot panic whatever the tool returns.
	paths := make([]string, 0, len(out.Files))
	for _, f := range out.Files {
		paths = append(paths, f.Path)
		assert.False(t, f.Truncated, "file %s should not be truncated", f.Path)
		assert.NotEmpty(t, f.Diff)
		assert.Equal(t, 10, f.Additions)
		assert.Equal(t, 0, f.Deletions)
	}
	// Order is not part of the contract being tested here.
	assert.ElementsMatch(t, fileNames, paths)
}
// Regression for issue #25: every file's diff entry must contain its OWN diff,
// not a shared buffer pointing at the last file. Prior bug: splitUnifiedDiff
// flushed bytes.Buffer.Bytes() into the map without copying, so every entry
// aliased the buffer's backing array and showed the last file's content.
func TestPRFilesDiffPerFileIsolation(t *testing.T) {
fileNames := []string{"alpha.go", "beta.go", "gamma.go", "delta.go"}
rawDiff := buildDiff(fileNames, 5)
filesJSON := buildFilesJSON(fileNames, 5)
srv := newPRFilesDiffServer(t, filesJSON, rawDiff)
defer srv.Close()
tool := tools.NewPRFilesDiff(gitea.NewClient(srv.URL, "tok"), allowlist.New([]string{"o"}))
result, err := tool.Call(context.Background(), json.RawMessage(`{"owner":"o","name":"r","number":1}`))
require.NoError(t, err)
var out struct {
Files []struct {
Path string `json:"path"`
Diff string `json:"diff"`
} `json:"files"`
}
require.NoError(t, json.Unmarshal(result, &out))
require.Len(t, out.Files, len(fileNames))
for _, f := range out.Files {
expected := fmt.Sprintf("diff --git a/%s b/%s", f.Path, f.Path)
assert.Contains(t, f.Diff, expected,
"file %s diff must contain its own header, got: %.80q", f.Path, f.Diff)
// No other file's header should leak in.
for _, other := range fileNames {
if other == f.Path {
continue
}
otherHeader := fmt.Sprintf("diff --git a/%s b/%s", other, other)
assert.NotContains(t, f.Diff, otherHeader,
"file %s diff must NOT contain %s's header", f.Path, other)
}
}
}
func TestPRFilesDiffPerFileTruncated(t *testing.T) {
// One file with a 30KB diff (each "+abcdefghij\n" = 12 bytes; 30KB / 12 ≈ 2560 lines).
fileNames := []string{"bigfile.go"}
linesPerFile := 2560 // ~30720 bytes > 20KB cap
rawDiff := buildDiff(fileNames, linesPerFile)
filesJSON := buildFilesJSON(fileNames, linesPerFile)
srv := newPRFilesDiffServer(t, filesJSON, rawDiff)
defer srv.Close()
tool := tools.NewPRFilesDiff(gitea.NewClient(srv.URL, "tok"), allowlist.New([]string{"o"}))
result, err := tool.Call(context.Background(), json.RawMessage(`{"owner":"o","name":"r","number":1}`))
require.NoError(t, err)
var out struct {
Files []struct {
Path string `json:"path"`
Diff string `json:"diff"`
Truncated bool `json:"truncated"`
OmittedLines int `json:"omitted_lines"`
Additions int `json:"additions"`
} `json:"files"`
ResponseTruncated bool `json:"response_truncated"`
}
require.NoError(t, json.Unmarshal(result, &out))
require.Len(t, out.Files, 1)
f := out.Files[0]
assert.Equal(t, "bigfile.go", f.Path)
assert.True(t, f.Truncated, "file should be truncated")
assert.Greater(t, f.OmittedLines, 0, "omitted_lines should be > 0")
assert.LessOrEqual(t, len(f.Diff), 20*1024+200, "diff should be capped near 20KB")
assert.False(t, out.ResponseTruncated)
}
func TestPRFilesDiffResponseCapped(t *testing.T) {
// 25 files × ~10KB diff each = ~250KB raw, well over the 200KB response cap.
// Each file: 850 lines × 12 bytes = 10200 bytes per file.
numFiles := 25
linesPerFile := 850
fileNames := make([]string, numFiles)
for i := range fileNames {
fileNames[i] = fmt.Sprintf("file%02d.go", i)
}
rawDiff := buildDiff(fileNames, linesPerFile)
filesJSON := buildFilesJSON(fileNames, linesPerFile)
srv := newPRFilesDiffServer(t, filesJSON, rawDiff)
defer srv.Close()
tool := tools.NewPRFilesDiff(gitea.NewClient(srv.URL, "tok"), allowlist.New([]string{"o"}))
result, err := tool.Call(context.Background(), json.RawMessage(`{"owner":"o","name":"r","number":1}`))
require.NoError(t, err)
var out struct {
Files []struct {
Path string `json:"path"`
} `json:"files"`
OmittedFiles []string `json:"omitted_files"`
ResponseTruncated bool `json:"response_truncated"`
}
require.NoError(t, json.Unmarshal(result, &out))
assert.True(t, out.ResponseTruncated, "response should be truncated")
assert.NotEmpty(t, out.OmittedFiles, "some files should be omitted")
assert.NotEmpty(t, out.Files, "some files should be included")
// Total files accounted for should equal numFiles.
totalAccountedFor := len(out.Files) + len(out.OmittedFiles)
assert.Equal(t, numFiles, totalAccountedFor)
}
func TestPRFilesDiffAllowlistRejects(t *testing.T) {
tool := tools.NewPRFilesDiff(gitea.NewClient("http://unused", ""), allowlist.New([]string{"allowed"}))
_, err := tool.Call(context.Background(), json.RawMessage(`{"owner":"evil","name":"r","number":1}`))
require.Error(t, err)
}
func TestPRFilesDiffRequiresValidNumber(t *testing.T) {
tool := tools.NewPRFilesDiff(gitea.NewClient("http://unused", ""), allowlist.New([]string{"o"}))
_, err := tool.Call(context.Background(), json.RawMessage(`{"owner":"o","name":"r","number":0}`))
require.Error(t, err)
assert.ErrorIs(t, err, gitea.ErrValidation)
}