feat(trainer): add trainer MCP skill with reader→writer sub-agent chain

Reader agent scans session logs for SFT/DPO candidates; writer receives
reader output and formats+writes training pairs to brain/training-data/.
Adds trainer-reader.md and trainer-writer.md discipline prompts.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Mathias Bergqvist
2026-04-19 14:06:00 +02:00
parent 7697e901d2
commit 38fcac4cba
7 changed files with 303 additions and 0 deletions

View File

@@ -0,0 +1,80 @@
// internal/skills/trainer/handlers.go
package trainer
import (
"context"
"encoding/json"
"fmt"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/session"
)
type trainArgs struct {
SessionID string `json:"session_id"`
Model string `json:"model"`
}
// Handle dispatches the MCP tool call to the trainer handler.
func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (json.RawMessage, error) {
if tool != "trainer" {
return nil, fmt.Errorf("unknown tool: %s", tool)
}
var a trainArgs
if err := json.Unmarshal(args, &a); err != nil {
return nil, fmt.Errorf("parse args: %w", err)
}
if a.SessionID == "" {
return nil, fmt.Errorf("session_id is required")
}
if s.cfg.ExecutorFn == nil {
return nil, fmt.Errorf("no executor configured")
}
model := a.Model
if model == "" {
model = s.cfg.DefaultModel
}
entries, err := session.Read(s.cfg.SessionsDir, a.SessionID)
if err != nil {
return nil, fmt.Errorf("read session log: %w", err)
}
// ── Step 1: Reader agent ─────────────────────────────────────────────────
history := session.FormatHistory(entries, "")
readerTask := fmt.Sprintf(
"role: reader\nsession_id: %s\nbrain_dir: %s\n\n%s",
a.SessionID, s.cfg.BrainDir, history,
)
readerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{
SkillPrompt: s.cfg.ReaderPrompt,
TaskPrompt: readerTask,
Model: model,
Tools: "Read",
})
if err != nil {
return nil, fmt.Errorf("reader agent: %w", err)
}
// ── Step 2: Writer agent (receives reader candidates) ────────────────────
writerTask := fmt.Sprintf(
"role: writer\nsession_id: %s\nbrain_dir: %s\n\nreader_summary: %s\nreader_candidates:\n%s",
a.SessionID, s.cfg.BrainDir, readerResult.Message, readerResult.RunnerOutput,
)
writerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{
SkillPrompt: s.cfg.WriterPrompt,
TaskPrompt: writerTask,
Model: model,
Tools: "Read,Write",
})
if err != nil {
return nil, fmt.Errorf("writer agent: %w", err)
}
b, err := json.Marshal(writerResult)
if err != nil {
return nil, fmt.Errorf("marshal result: %w", err)
}
return b, nil
}

View File

@@ -0,0 +1,82 @@
// internal/skills/trainer/handlers_test.go
package trainer_test
import (
"context"
"encoding/json"
"testing"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/session"
"github.com/mathiasbq/supervisor/internal/skills/trainer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTrainerToolRegistered(t *testing.T) {
sk := trainer.New(trainer.Config{ReaderPrompt: "r", WriterPrompt: "w"})
names := make([]string, 0)
for _, tool := range sk.Tools() {
names = append(names, tool.Name)
}
assert.Contains(t, names, "trainer")
}
func TestTrainerRequiresSessionID(t *testing.T) {
sk := trainer.New(trainer.Config{ReaderPrompt: "r", WriterPrompt: "w"})
_, err := sk.Handle(context.Background(), "trainer", json.RawMessage(`{}`))
assert.ErrorContains(t, err, "session_id")
}
func TestTrainerCallsReaderThenWriter(t *testing.T) {
sessDir := t.TempDir()
require.NoError(t, session.Append(sessDir, "sess-1", session.Entry{
SessionID: "sess-1", Skill: "tdd", Phase: "red", FinalStatus: "pass",
Message: "wrote failing test", FilePath: "internal/foo/foo_test.go",
}))
callCount := 0
var readerTask, writerTask string
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
callCount++
if callCount == 1 {
// reader call
readerTask = req.TaskPrompt
return iexec.Result{
Status: "pass", Phase: "trainer", Skill: "trainer",
RunnerOutput: `[{"type":"sft","moment":"first-pass clean TDD","score":4}]`,
Verified: true, ModelUsed: "self", Message: "1 sft candidate found",
}, nil
}
// writer call
writerTask = req.TaskPrompt
return iexec.Result{
Status: "pass", Phase: "trainer", Skill: "trainer",
FilePath: sessDir + "/training-data/sft/sess-1.jsonl",
Verified: true, ModelUsed: "self", Message: "1 sft pair written",
}, nil
}
sk := trainer.New(trainer.Config{
ReaderPrompt: "reader rules",
WriterPrompt: "writer rules",
ExecutorFn: fakeFn,
SessionsDir: sessDir,
BrainDir: t.TempDir(),
})
out, err := sk.Handle(context.Background(), "trainer", json.RawMessage(`{"session_id":"sess-1"}`))
require.NoError(t, err)
assert.Equal(t, 2, callCount, "executor must be called exactly twice: reader then writer")
assert.Contains(t, readerTask, "role: reader")
assert.Contains(t, readerTask, "sess-1")
assert.Contains(t, readerTask, "wrote failing test") // session history in reader prompt
assert.Contains(t, writerTask, "role: writer")
assert.Contains(t, writerTask, "sft candidate") // reader output passed to writer
var result iexec.Result
require.NoError(t, json.Unmarshal(out, &result))
assert.Equal(t, "trainer", result.Phase)
assert.Equal(t, "pass", result.Status)
}

View File

@@ -0,0 +1,53 @@
// internal/skills/trainer/skill.go
package trainer
import (
"context"
"encoding/json"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/registry"
)
// ExecutorFn is the function signature for running a worker subprocess.
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
// Config holds dependencies for the trainer skill.
type Config struct {
ReaderPrompt string
WriterPrompt string
DefaultModel string
ExecutorFn ExecutorFn
SessionsDir string
BrainDir string // root of brain/ directory; writer writes to BrainDir/training-data/
}
// Skill implements the trainer MCP tool.
type Skill struct{ cfg Config }
// New creates a new trainer Skill.
func New(cfg Config) *Skill { return &Skill{cfg: cfg} }
// Name returns the skill identifier.
func (s *Skill) Name() string { return "trainer" }
// Tools returns the MCP tool definitions for this skill.
func (s *Skill) Tools() []registry.ToolDef {
schema := func(required []string, props map[string]any) json.RawMessage {
b, _ := json.Marshal(map[string]any{"type": "object", "required": required, "properties": props})
return b
}
return []registry.ToolDef{
{
Name: "trainer",
Description: "Extract SFT and DPO training pairs from a session log. Runs a reader→writer chain: reader identifies learning moments, writer formats and writes pairs to brain/training-data/.",
InputSchema: schema(
[]string{"session_id"},
map[string]any{
"session_id": map[string]any{"type": "string"},
"model": map[string]any{"type": "string"},
},
),
},
}
}