feat(tools): pr_files_diff with caps

Returns per-file unified diff for a PR, capped at 20KB/file and 200KB
total response. Files exceeding per-file cap report truncated+omitted_lines;
files that would push the response over 200KB go to omitted_files.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Mathias Bergqvist
2026-05-04 22:57:11 +02:00
parent d3d0fed6b1
commit e95e87e8e3
5 changed files with 435 additions and 0 deletions

View File

@@ -0,0 +1,171 @@
package tools
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"gitea.d-ma.be/mathias/gitea-mcp/internal/allowlist"
"gitea.d-ma.be/mathias/gitea-mcp/internal/gitea"
"gitea.d-ma.be/mathias/gitea-mcp/internal/registry"
)
const (
maxFileDiffBytes = 20 * 1024
maxResponseBytes = 200 * 1024
)
type PRFilesDiff struct {
c *gitea.Client
a *allowlist.Allowlist
}
func NewPRFilesDiff(c *gitea.Client, a *allowlist.Allowlist) *PRFilesDiff {
return &PRFilesDiff{c: c, a: a}
}
func (t *PRFilesDiff) Descriptor() registry.ToolDescriptor {
return registry.ToolDescriptor{
Name: "pr_files_diff",
Description: "Get a pull request's per-file diff with size caps (20KB/file, 200KB total).",
InputSchema: json.RawMessage(`{
"type":"object",
"properties":{
"owner":{"type":"string"},
"name":{"type":"string"},
"number":{"type":"integer","minimum":1}
},
"required":["owner","name","number"]
}`),
}
}
type prFilesDiffArgs struct {
Owner string `json:"owner"`
Name string `json:"name"`
Number int `json:"number"`
}
type prFileDiffEntry struct {
Path string `json:"path"`
Diff string `json:"diff"`
Truncated bool `json:"truncated"`
OmittedLines int `json:"omitted_lines,omitempty"`
Additions int `json:"additions"`
Deletions int `json:"deletions"`
}
func (t *PRFilesDiff) Call(ctx context.Context, raw json.RawMessage) (json.RawMessage, error) {
var args prFilesDiffArgs
if err := parseArgs(raw, &args); err != nil {
return nil, err
}
if err := t.a.Check(args.Owner); err != nil {
return nil, err
}
if args.Number < 1 {
return nil, fmt.Errorf("number must be >= 1: %w", gitea.ErrValidation)
}
files, err := t.c.GetPullRequestFiles(ctx, args.Owner, args.Name, args.Number)
if err != nil {
return nil, err
}
rawDiff, err := t.c.GetPullRequestDiff(ctx, args.Owner, args.Name, args.Number)
if err != nil {
return nil, err
}
// Split unified diff by per-file headers ("diff --git a/path b/path")
perFile := splitUnifiedDiff(rawDiff)
out := struct {
Files []prFileDiffEntry `json:"files"`
OmittedFiles []string `json:"omitted_files,omitempty"`
ResponseTruncated bool `json:"response_truncated"`
}{
Files: make([]prFileDiffEntry, 0, len(files)),
}
totalBytes := 0
for _, f := range files {
// look up the diff for this file (best-effort by path match)
diffBytes, ok := perFile[f.Filename]
if !ok {
diffBytes = []byte{}
}
entry := prFileDiffEntry{
Path: f.Filename,
Additions: f.Additions,
Deletions: f.Deletions,
}
// Per-file cap
if len(diffBytes) > maxFileDiffBytes {
truncated := diffBytes[:maxFileDiffBytes]
omittedLines := bytes.Count(diffBytes[maxFileDiffBytes:], []byte("\n"))
entry.Diff = string(truncated)
entry.Truncated = true
entry.OmittedLines = omittedLines
} else {
entry.Diff = string(diffBytes)
}
// Response cap — if adding this entry would exceed, push to omitted_files
entryEstimate := len(entry.Diff) + 200 // small overhead for path + counts
if totalBytes+entryEstimate > maxResponseBytes {
out.OmittedFiles = append(out.OmittedFiles, f.Filename)
out.ResponseTruncated = true
continue
}
totalBytes += entryEstimate
out.Files = append(out.Files, entry)
}
return textOK(out)
}
// splitUnifiedDiff parses a unified diff and returns a map from filename to that file's
// portion of the diff. The unified diff format starts each file with a line like
// "diff --git a/<path> b/<path>".
func splitUnifiedDiff(d []byte) map[string][]byte {
m := map[string][]byte{}
scanner := bufio.NewScanner(bytes.NewReader(d))
scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024) // allow long diffs
var currentFile string
var current bytes.Buffer
flush := func() {
if currentFile != "" {
m[currentFile] = []byte(current.String())
current.Reset()
}
}
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "diff --git ") {
flush()
// Parse: "diff --git a/<path> b/<path>"
rest := strings.TrimPrefix(line, "diff --git a/")
parts := strings.SplitN(rest, " b/", 2)
if len(parts) == 2 {
currentFile = parts[0]
} else {
currentFile = ""
}
}
if currentFile != "" {
current.WriteString(line)
current.WriteByte('\n')
}
}
flush()
return m
}