feat(trainer): add trainer MCP skill with reader→writer sub-agent chain
Reader agent scans session logs for SFT/DPO candidates; writer receives reader output and formats+writes training pairs to brain/training-data/. Adds trainer-reader.md and trainer-writer.md discipline prompts. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
|||||||
skilldebug "github.com/mathiasbq/supervisor/internal/skills/debug"
|
skilldebug "github.com/mathiasbq/supervisor/internal/skills/debug"
|
||||||
"github.com/mathiasbq/supervisor/internal/skills/review"
|
"github.com/mathiasbq/supervisor/internal/skills/review"
|
||||||
"github.com/mathiasbq/supervisor/internal/skills/spec"
|
"github.com/mathiasbq/supervisor/internal/skills/spec"
|
||||||
|
"github.com/mathiasbq/supervisor/internal/skills/trainer"
|
||||||
"github.com/mathiasbq/supervisor/internal/skills/sessionlog"
|
"github.com/mathiasbq/supervisor/internal/skills/sessionlog"
|
||||||
"github.com/mathiasbq/supervisor/internal/skills/tdd"
|
"github.com/mathiasbq/supervisor/internal/skills/tdd"
|
||||||
"github.com/mathiasbq/supervisor/internal/tier"
|
"github.com/mathiasbq/supervisor/internal/tier"
|
||||||
@@ -72,6 +73,17 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trainerReaderPrompt, err := os.ReadFile(cfg.ConfigDir + "/trainer-reader.md")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("read trainer-reader.md", "path", cfg.ConfigDir+"/trainer-reader.md", "err", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
trainerWriterPrompt, err := os.ReadFile(cfg.ConfigDir + "/trainer-writer.md")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("read trainer-writer.md", "path", cfg.ConfigDir+"/trainer-writer.md", "err", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
executor := iexec.New(iexec.Config{
|
executor := iexec.New(iexec.Config{
|
||||||
SystemPrompt: string(systemPrompt),
|
SystemPrompt: string(systemPrompt),
|
||||||
LiteLLMBaseURL: cfg.LiteLLMBaseURL,
|
LiteLLMBaseURL: cfg.LiteLLMBaseURL,
|
||||||
@@ -123,6 +135,14 @@ func main() {
|
|||||||
ExecutorFn: executor.Run,
|
ExecutorFn: executor.Run,
|
||||||
SessionsDir: cfg.SessionsDir,
|
SessionsDir: cfg.SessionsDir,
|
||||||
}))
|
}))
|
||||||
|
reg.Register(trainer.New(trainer.Config{
|
||||||
|
ReaderPrompt: string(trainerReaderPrompt),
|
||||||
|
WriterPrompt: string(trainerWriterPrompt),
|
||||||
|
DefaultModel: models.Resolve("trainer", ""),
|
||||||
|
ExecutorFn: executor.Run,
|
||||||
|
SessionsDir: cfg.SessionsDir,
|
||||||
|
BrainDir: cfg.BrainDir,
|
||||||
|
}))
|
||||||
|
|
||||||
srv := mcp.NewServer(reg)
|
srv := mcp.NewServer(reg)
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|||||||
@@ -9,3 +9,5 @@ skills:
|
|||||||
review: ollama/devstral-tuned
|
review: ollama/devstral-tuned
|
||||||
debug: ollama/deepseek-r1-tuned
|
debug: ollama/deepseek-r1-tuned
|
||||||
retrospective: ollama/qwen3-coder-30b-tuned
|
retrospective: ollama/qwen3-coder-30b-tuned
|
||||||
|
spec: ollama/qwen3-coder-30b-tuned
|
||||||
|
trainer: ollama/qwen3-coder-30b-tuned
|
||||||
|
|||||||
31
config/supervisor/trainer-reader.md
Normal file
31
config/supervisor/trainer-reader.md
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# Trainer Reader Discipline
|
||||||
|
|
||||||
|
You scan session logs and identify candidate learning moments worth converting to training data.
|
||||||
|
|
||||||
|
## What to look for
|
||||||
|
- **SFT candidates**: the worker did exactly the right thing — a clean pattern worth reinforcing
|
||||||
|
- **DPO candidates**: the worker first produced a wrong or suboptimal response, then corrected — you have both rejected and chosen
|
||||||
|
|
||||||
|
## Scoring (1–5)
|
||||||
|
- 5: novel pattern, clearly correct, generalises across projects
|
||||||
|
- 4: good pattern, correct, somewhat project-specific but still useful
|
||||||
|
- 3: correct but obvious — include only if especially clean
|
||||||
|
- 2 or below: skip — too ambiguous or too context-specific
|
||||||
|
|
||||||
|
## Output contract
|
||||||
|
Return JSON result with:
|
||||||
|
- `status`: "pass" or "error"
|
||||||
|
- `phase`: "trainer"
|
||||||
|
- `skill`: "trainer"
|
||||||
|
- `file_path`: ""
|
||||||
|
- `runner_output`: JSON array of candidates (valid JSON, not markdown):
|
||||||
|
[{"type":"sft","moment":"<what happened>","prompt":"<what was asked>","completion":"<what was done right>","score":4},
|
||||||
|
{"type":"dpo","moment":"<what happened>","prompt":"<what was asked>","chosen":"<correct>","rejected":"<incorrect>","score":3}]
|
||||||
|
- `verified`: true
|
||||||
|
- `message`: "N sft candidates, M dpo candidates found"
|
||||||
|
|
||||||
|
## Rules
|
||||||
|
1. Read all session entries in the task prompt
|
||||||
|
2. Score each entry — only include entries scoring >= 3
|
||||||
|
3. All text fields (`prompt`, `completion`, `chosen`, `rejected`) must be phrased to generalise: no project-specific paths or names
|
||||||
|
4. If no candidates score >= 3, return an empty array `[]` — never force low-quality candidates
|
||||||
35
config/supervisor/trainer-writer.md
Normal file
35
config/supervisor/trainer-writer.md
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# Trainer Writer Discipline
|
||||||
|
|
||||||
|
You receive candidate learning moments from the reader and write clean SFT/DPO training pairs.
|
||||||
|
|
||||||
|
## Quality gate (apply before writing)
|
||||||
|
- SFT: prompt must be phrased so it could come from any project, not just this one
|
||||||
|
- DPO: chosen and rejected must be clearly distinguishable — skip the pair if a human reviewer could not tell which response is better
|
||||||
|
- Never include project-specific paths, variable names, or identifiers in any pair
|
||||||
|
|
||||||
|
## Output contract
|
||||||
|
Return JSON result with:
|
||||||
|
- `status`: "pass" (pairs written or skipped due to quality) or "error" (candidates JSON was malformed)
|
||||||
|
- `phase`: "trainer"
|
||||||
|
- `skill`: "trainer"
|
||||||
|
- `file_path`: path of the last file written (empty if nothing passed quality gate)
|
||||||
|
- `runner_output`: "N SFT pairs written to brain/training-data/sft/, M DPO pairs to brain/training-data/dpo/" or "0 pairs passed quality gate"
|
||||||
|
- `verified`: true if files were written; false if nothing passed
|
||||||
|
- `message`: "N sft + M dpo pairs for session <id>" or "no pairs passed quality gate"
|
||||||
|
|
||||||
|
## File format
|
||||||
|
JSONL — one JSON object per line.
|
||||||
|
|
||||||
|
SFT: `{"prompt": "...", "completion": "..."}`
|
||||||
|
DPO: `{"prompt": "...", "chosen": "...", "rejected": "..."}`
|
||||||
|
|
||||||
|
Write SFT to: `<brain_dir>/training-data/sft/<session_id>.jsonl`
|
||||||
|
Write DPO to: `<brain_dir>/training-data/dpo/<session_id>.jsonl`
|
||||||
|
|
||||||
|
Append to existing files if they exist (don't overwrite).
|
||||||
|
|
||||||
|
## Rules
|
||||||
|
1. Parse the `reader_candidates` JSON from the task prompt
|
||||||
|
2. For each candidate: apply quality gate
|
||||||
|
3. Write passing SFT candidates to sft JSONL, DPO candidates to dpo JSONL
|
||||||
|
4. If nothing passes, return status "pass" with verified: false and message "no pairs passed quality gate"
|
||||||
80
internal/skills/trainer/handlers.go
Normal file
80
internal/skills/trainer/handlers.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
// internal/skills/trainer/handlers.go
|
||||||
|
package trainer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||||
|
"github.com/mathiasbq/supervisor/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
type trainArgs struct {
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle dispatches the MCP tool call to the trainer handler.
|
||||||
|
func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (json.RawMessage, error) {
|
||||||
|
if tool != "trainer" {
|
||||||
|
return nil, fmt.Errorf("unknown tool: %s", tool)
|
||||||
|
}
|
||||||
|
var a trainArgs
|
||||||
|
if err := json.Unmarshal(args, &a); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse args: %w", err)
|
||||||
|
}
|
||||||
|
if a.SessionID == "" {
|
||||||
|
return nil, fmt.Errorf("session_id is required")
|
||||||
|
}
|
||||||
|
if s.cfg.ExecutorFn == nil {
|
||||||
|
return nil, fmt.Errorf("no executor configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
model := a.Model
|
||||||
|
if model == "" {
|
||||||
|
model = s.cfg.DefaultModel
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := session.Read(s.cfg.SessionsDir, a.SessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read session log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Step 1: Reader agent ─────────────────────────────────────────────────
|
||||||
|
history := session.FormatHistory(entries, "")
|
||||||
|
readerTask := fmt.Sprintf(
|
||||||
|
"role: reader\nsession_id: %s\nbrain_dir: %s\n\n%s",
|
||||||
|
a.SessionID, s.cfg.BrainDir, history,
|
||||||
|
)
|
||||||
|
readerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||||
|
SkillPrompt: s.cfg.ReaderPrompt,
|
||||||
|
TaskPrompt: readerTask,
|
||||||
|
Model: model,
|
||||||
|
Tools: "Read",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reader agent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Step 2: Writer agent (receives reader candidates) ────────────────────
|
||||||
|
writerTask := fmt.Sprintf(
|
||||||
|
"role: writer\nsession_id: %s\nbrain_dir: %s\n\nreader_summary: %s\nreader_candidates:\n%s",
|
||||||
|
a.SessionID, s.cfg.BrainDir, readerResult.Message, readerResult.RunnerOutput,
|
||||||
|
)
|
||||||
|
writerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||||
|
SkillPrompt: s.cfg.WriterPrompt,
|
||||||
|
TaskPrompt: writerTask,
|
||||||
|
Model: model,
|
||||||
|
Tools: "Read,Write",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("writer agent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := json.Marshal(writerResult)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal result: %w", err)
|
||||||
|
}
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
82
internal/skills/trainer/handlers_test.go
Normal file
82
internal/skills/trainer/handlers_test.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
// internal/skills/trainer/handlers_test.go
|
||||||
|
package trainer_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||||
|
"github.com/mathiasbq/supervisor/internal/session"
|
||||||
|
"github.com/mathiasbq/supervisor/internal/skills/trainer"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTrainerToolRegistered(t *testing.T) {
|
||||||
|
sk := trainer.New(trainer.Config{ReaderPrompt: "r", WriterPrompt: "w"})
|
||||||
|
names := make([]string, 0)
|
||||||
|
for _, tool := range sk.Tools() {
|
||||||
|
names = append(names, tool.Name)
|
||||||
|
}
|
||||||
|
assert.Contains(t, names, "trainer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTrainerRequiresSessionID(t *testing.T) {
|
||||||
|
sk := trainer.New(trainer.Config{ReaderPrompt: "r", WriterPrompt: "w"})
|
||||||
|
_, err := sk.Handle(context.Background(), "trainer", json.RawMessage(`{}`))
|
||||||
|
assert.ErrorContains(t, err, "session_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTrainerCallsReaderThenWriter(t *testing.T) {
|
||||||
|
sessDir := t.TempDir()
|
||||||
|
require.NoError(t, session.Append(sessDir, "sess-1", session.Entry{
|
||||||
|
SessionID: "sess-1", Skill: "tdd", Phase: "red", FinalStatus: "pass",
|
||||||
|
Message: "wrote failing test", FilePath: "internal/foo/foo_test.go",
|
||||||
|
}))
|
||||||
|
|
||||||
|
callCount := 0
|
||||||
|
var readerTask, writerTask string
|
||||||
|
|
||||||
|
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||||
|
callCount++
|
||||||
|
if callCount == 1 {
|
||||||
|
// reader call
|
||||||
|
readerTask = req.TaskPrompt
|
||||||
|
return iexec.Result{
|
||||||
|
Status: "pass", Phase: "trainer", Skill: "trainer",
|
||||||
|
RunnerOutput: `[{"type":"sft","moment":"first-pass clean TDD","score":4}]`,
|
||||||
|
Verified: true, ModelUsed: "self", Message: "1 sft candidate found",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
// writer call
|
||||||
|
writerTask = req.TaskPrompt
|
||||||
|
return iexec.Result{
|
||||||
|
Status: "pass", Phase: "trainer", Skill: "trainer",
|
||||||
|
FilePath: sessDir + "/training-data/sft/sess-1.jsonl",
|
||||||
|
Verified: true, ModelUsed: "self", Message: "1 sft pair written",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sk := trainer.New(trainer.Config{
|
||||||
|
ReaderPrompt: "reader rules",
|
||||||
|
WriterPrompt: "writer rules",
|
||||||
|
ExecutorFn: fakeFn,
|
||||||
|
SessionsDir: sessDir,
|
||||||
|
BrainDir: t.TempDir(),
|
||||||
|
})
|
||||||
|
out, err := sk.Handle(context.Background(), "trainer", json.RawMessage(`{"session_id":"sess-1"}`))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, callCount, "executor must be called exactly twice: reader then writer")
|
||||||
|
assert.Contains(t, readerTask, "role: reader")
|
||||||
|
assert.Contains(t, readerTask, "sess-1")
|
||||||
|
assert.Contains(t, readerTask, "wrote failing test") // session history in reader prompt
|
||||||
|
assert.Contains(t, writerTask, "role: writer")
|
||||||
|
assert.Contains(t, writerTask, "sft candidate") // reader output passed to writer
|
||||||
|
|
||||||
|
var result iexec.Result
|
||||||
|
require.NoError(t, json.Unmarshal(out, &result))
|
||||||
|
assert.Equal(t, "trainer", result.Phase)
|
||||||
|
assert.Equal(t, "pass", result.Status)
|
||||||
|
}
|
||||||
53
internal/skills/trainer/skill.go
Normal file
53
internal/skills/trainer/skill.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
// internal/skills/trainer/skill.go
|
||||||
|
package trainer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||||
|
"github.com/mathiasbq/supervisor/internal/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExecutorFn is the function signature for running a worker subprocess.
|
||||||
|
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
|
||||||
|
|
||||||
|
// Config holds dependencies for the trainer skill.
|
||||||
|
type Config struct {
|
||||||
|
ReaderPrompt string
|
||||||
|
WriterPrompt string
|
||||||
|
DefaultModel string
|
||||||
|
ExecutorFn ExecutorFn
|
||||||
|
SessionsDir string
|
||||||
|
BrainDir string // root of brain/ directory; writer writes to BrainDir/training-data/
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skill implements the trainer MCP tool.
|
||||||
|
type Skill struct{ cfg Config }
|
||||||
|
|
||||||
|
// New creates a new trainer Skill.
|
||||||
|
func New(cfg Config) *Skill { return &Skill{cfg: cfg} }
|
||||||
|
|
||||||
|
// Name returns the skill identifier.
|
||||||
|
func (s *Skill) Name() string { return "trainer" }
|
||||||
|
|
||||||
|
// Tools returns the MCP tool definitions for this skill.
|
||||||
|
func (s *Skill) Tools() []registry.ToolDef {
|
||||||
|
schema := func(required []string, props map[string]any) json.RawMessage {
|
||||||
|
b, _ := json.Marshal(map[string]any{"type": "object", "required": required, "properties": props})
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return []registry.ToolDef{
|
||||||
|
{
|
||||||
|
Name: "trainer",
|
||||||
|
Description: "Extract SFT and DPO training pairs from a session log. Runs a reader→writer chain: reader identifies learning moments, writer formats and writes pairs to brain/training-data/.",
|
||||||
|
InputSchema: schema(
|
||||||
|
[]string{"session_id"},
|
||||||
|
map[string]any{
|
||||||
|
"session_id": map[string]any{"type": "string"},
|
||||||
|
"model": map[string]any{"type": "string"},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user