refactor: replace orchestrator/verifier chain with direct LiteLLM calls
Drop the three-layer Claude subprocess orchestration (local model →
Claude verifier → cloud escalation). Skills now call LiteLLM directly
and return plain text to Claude Code, which decides what to do with it.
- Delete executor, orchestrator, verifier, result, attempts packages
- Simplify LiteLLMExecutor: Run(Request)→Result becomes Complete(model,sys,user)→(string,int64,error)
- Replace ExecutorFn with CompleteFunc in all 6 skill configs
- Rewrite all skill handlers to call Complete and return {"text","model","duration_ms"}
- Simplify config/models: remove Verifier/LlamaSwapURL, add ModelFor
- Bump version to v0.5.0
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -8,7 +8,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mathiasbq/supervisor/internal/brain"
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
)
|
||||
|
||||
@@ -52,38 +51,32 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
task = brainCtx + "\n---\n\n" + task
|
||||
}
|
||||
|
||||
if s.cfg.ExecutorFn == nil {
|
||||
if s.cfg.CompleteFunc == nil {
|
||||
return nil, fmt.Errorf("no executor configured")
|
||||
}
|
||||
t0 := time.Now()
|
||||
result, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||
SkillPrompt: s.cfg.SkillPrompt,
|
||||
TaskPrompt: task,
|
||||
Model: model,
|
||||
Tools: "Read,Bash",
|
||||
})
|
||||
text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a.SessionID != "" && s.cfg.SessionsDir != "" {
|
||||
msg := text
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200]
|
||||
}
|
||||
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
|
||||
SessionID: a.SessionID,
|
||||
Timestamp: time.Now(),
|
||||
Skill: "debug",
|
||||
Phase: "debug",
|
||||
ProjectRoot: a.ProjectRoot,
|
||||
Attempts: session.AttemptsFrom(result.Attempts),
|
||||
FinalStatus: result.Status,
|
||||
ModelUsed: result.ModelUsed,
|
||||
FinalStatus: "ok",
|
||||
ModelUsed: model,
|
||||
DurationMs: time.Since(t0).Milliseconds(),
|
||||
Message: result.Message,
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
b, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal result: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/debug"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -33,29 +32,22 @@ func TestDebugRequiresError(t *testing.T) {
|
||||
assert.ErrorContains(t, err, "error")
|
||||
}
|
||||
|
||||
func TestDebugCallsExecutor(t *testing.T) {
|
||||
called := false
|
||||
func TestDebugCallsCompleteFunc(t *testing.T) {
|
||||
var capturedTask string
|
||||
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
called = true
|
||||
capturedTask = req.TaskPrompt
|
||||
return iexec.Result{
|
||||
Status: "pass", Phase: "debug", Skill: "debug",
|
||||
RunnerOutput: "HYPOTHESIS 1 (likelihood: high): nil map access\nVERIFY: go test ./... → expected: panic line reference",
|
||||
Verified: false, ModelUsed: "self", Message: "3 hypotheses for: panic nil pointer at foo.go:42",
|
||||
}, nil
|
||||
fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
|
||||
capturedTask = user
|
||||
return "HYPOTHESIS 1 (high): nil map access. Verify: go test ./...", 90, nil
|
||||
}
|
||||
|
||||
sk := debug.New(debug.Config{SkillPrompt: "debug rules", ExecutorFn: fakeFn, SessionsDir: t.TempDir()})
|
||||
sk := debug.New(debug.Config{SkillPrompt: "debug rules", CompleteFunc: fakeFn, SessionsDir: t.TempDir()})
|
||||
out, err := sk.Handle(context.Background(), "debug", json.RawMessage(
|
||||
`{"project_root":"/tmp/proj","error":"panic: nil pointer dereference at foo.go:42","context":"occurs on startup"}`,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
assert.Contains(t, capturedTask, "panic: nil pointer dereference")
|
||||
assert.Contains(t, capturedTask, "occurs on startup")
|
||||
|
||||
var result iexec.Result
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &result))
|
||||
assert.Equal(t, "debug", result.Phase)
|
||||
assert.Contains(t, result["text"], "nil map access")
|
||||
}
|
||||
|
||||
@@ -5,20 +5,19 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/registry"
|
||||
)
|
||||
|
||||
// ExecutorFn is the function signature for running a worker subprocess.
|
||||
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
|
||||
// CompleteFunc is the function used to call a local model.
|
||||
type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
|
||||
|
||||
// Config holds dependencies for the debug skill.
|
||||
type Config struct {
|
||||
SkillPrompt string
|
||||
DefaultModel string
|
||||
ExecutorFn ExecutorFn
|
||||
CompleteFunc CompleteFunc
|
||||
SessionsDir string
|
||||
IngestBaseURL string // optional: base URL of ingestion server for brain context
|
||||
IngestBaseURL string
|
||||
}
|
||||
|
||||
// Skill implements the debug MCP tool.
|
||||
@@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
return []registry.ToolDef{
|
||||
{
|
||||
Name: "debug",
|
||||
Description: "Analyse an error and return 3-5 hypotheses ordered by likelihood, each with a concrete verification step.",
|
||||
Description: "Consult a local model to analyse an error and return hypotheses ordered by likelihood, each with a concrete verification step.",
|
||||
InputSchema: schema(
|
||||
[]string{"project_root", "error"},
|
||||
map[string]any{
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
)
|
||||
|
||||
@@ -34,7 +33,6 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
model = s.cfg.DefaultModel
|
||||
}
|
||||
|
||||
// Read session log entries (empty slice if no log exists yet).
|
||||
entries, err := session.Read(s.cfg.SessionsDir, a.SessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read session log: %w", err)
|
||||
@@ -46,39 +44,33 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
}
|
||||
|
||||
taskPrompt := fmt.Sprintf(
|
||||
"SESSION_ID: %s\n\nSESSION_LOG:\n%s\n\nReview this session log. Identify what is novel or worth preserving as organizational knowledge. Write structured entries to brain/raw/ via brain_write. Return JSON result when done.",
|
||||
"SESSION_ID: %s\n\nSESSION_LOG:\n%s\n\nReview this session log. Identify what is novel or worth preserving as organizational knowledge. Provide structured insights.",
|
||||
a.SessionID, string(logJSON),
|
||||
)
|
||||
|
||||
if s.cfg.ExecutorFn == nil {
|
||||
if s.cfg.CompleteFunc == nil {
|
||||
return nil, fmt.Errorf("no executor configured")
|
||||
}
|
||||
t0 := time.Now()
|
||||
result, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||
SkillPrompt: s.cfg.SkillPrompt,
|
||||
TaskPrompt: taskPrompt,
|
||||
Model: model,
|
||||
Tools: "Bash,Read,Write",
|
||||
})
|
||||
text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, taskPrompt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("retrospective worker: %w", err)
|
||||
return nil, fmt.Errorf("retrospective model: %w", err)
|
||||
}
|
||||
|
||||
msg := text
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200]
|
||||
}
|
||||
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
|
||||
SessionID: a.SessionID,
|
||||
Timestamp: time.Now(),
|
||||
Skill: "retrospective",
|
||||
Phase: "retrospective",
|
||||
Attempts: session.AttemptsFrom(result.Attempts),
|
||||
FinalStatus: result.Status,
|
||||
ModelUsed: result.ModelUsed,
|
||||
FinalStatus: "ok",
|
||||
ModelUsed: model,
|
||||
DurationMs: time.Since(t0).Milliseconds(),
|
||||
Message: result.Message,
|
||||
Message: msg,
|
||||
})
|
||||
|
||||
b, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal result: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/retrospective"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -20,20 +19,14 @@ func TestHandle_Retrospective_RequiresSessionID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandle_Retrospective_BuildsPromptWithSessionLog(t *testing.T) {
|
||||
var capturedReq iexec.Request
|
||||
var capturedTask string
|
||||
s := retrospective.New(retrospective.Config{
|
||||
SkillPrompt: "retrospective discipline",
|
||||
DefaultModel: "ollama/test",
|
||||
SessionsDir: t.TempDir(), // empty dir, no session file — that's OK, session.Read returns nil
|
||||
ExecutorFn: func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
capturedReq = req
|
||||
return iexec.Result{
|
||||
Status: "pass",
|
||||
Phase: "retrospective",
|
||||
Skill: "retrospective",
|
||||
Verified: true,
|
||||
Message: "wrote 2 entries to brain",
|
||||
}, nil
|
||||
SessionsDir: t.TempDir(),
|
||||
CompleteFunc: func(_ context.Context, _, _, user string) (string, int64, error) {
|
||||
capturedTask = user
|
||||
return "Key insight: the team resolved a tricky nil pointer issue via careful logging.", 75, nil
|
||||
},
|
||||
})
|
||||
|
||||
@@ -41,9 +34,8 @@ func TestHandle_Retrospective_BuildsPromptWithSessionLog(t *testing.T) {
|
||||
out, err := s.Handle(context.Background(), "retrospective", args)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result iexec.Result
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &result))
|
||||
assert.Equal(t, "pass", result.Status)
|
||||
assert.Contains(t, capturedReq.SkillPrompt, "retrospective discipline")
|
||||
assert.Contains(t, capturedReq.TaskPrompt, "empty-session")
|
||||
assert.Contains(t, result["text"], "nil pointer")
|
||||
assert.Contains(t, capturedTask, "empty-session")
|
||||
}
|
||||
|
||||
@@ -5,19 +5,18 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/registry"
|
||||
)
|
||||
|
||||
// ExecutorFn allows injecting a test double for the subprocess executor.
|
||||
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
|
||||
// CompleteFunc is the function used to call a local model.
|
||||
type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
|
||||
|
||||
// Config holds retrospective skill configuration.
|
||||
type Config struct {
|
||||
SkillPrompt string // content of retrospective.md
|
||||
DefaultModel string // model to use when not specified in args
|
||||
SessionsDir string // path to brain/sessions/
|
||||
ExecutorFn ExecutorFn // injected executor
|
||||
SkillPrompt string
|
||||
DefaultModel string
|
||||
SessionsDir string
|
||||
CompleteFunc CompleteFunc
|
||||
}
|
||||
|
||||
// Skill implements registry.Skill for the retrospective tool.
|
||||
@@ -36,7 +35,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
return []registry.ToolDef{
|
||||
{
|
||||
Name: "retrospective",
|
||||
Description: "Run a retrospective on a completed session. Reads the session log, identifies novel learnings, and writes structured entries to the brain for ingestion. Call at the end of each coding session.",
|
||||
Description: "Consult a local model to analyse a completed session and identify what is novel or worth preserving as organizational knowledge.",
|
||||
InputSchema: json.RawMessage(`{
|
||||
"type": "object",
|
||||
"required": ["session_id"],
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mathiasbq/supervisor/internal/brain"
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
)
|
||||
|
||||
@@ -53,39 +52,32 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
task = brainCtx + "\n---\n\n" + task
|
||||
}
|
||||
|
||||
if s.cfg.ExecutorFn == nil {
|
||||
if s.cfg.CompleteFunc == nil {
|
||||
return nil, fmt.Errorf("no executor configured")
|
||||
}
|
||||
t0 := time.Now()
|
||||
result, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||
SkillPrompt: s.cfg.SkillPrompt,
|
||||
TaskPrompt: task,
|
||||
Model: model,
|
||||
Tools: "Read,Bash",
|
||||
})
|
||||
text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a.SessionID != "" && s.cfg.SessionsDir != "" {
|
||||
msg := text
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200]
|
||||
}
|
||||
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
|
||||
SessionID: a.SessionID,
|
||||
Timestamp: time.Now(),
|
||||
Skill: "review",
|
||||
Phase: "review",
|
||||
ProjectRoot: a.ProjectRoot,
|
||||
Attempts: session.AttemptsFrom(result.Attempts),
|
||||
FinalStatus: result.Status,
|
||||
FilePath: result.FilePath,
|
||||
ModelUsed: result.ModelUsed,
|
||||
FinalStatus: "ok",
|
||||
ModelUsed: model,
|
||||
DurationMs: time.Since(t0).Milliseconds(),
|
||||
Message: result.Message,
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
b, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal result: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/review"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -33,29 +32,22 @@ func TestReviewRequiresFiles(t *testing.T) {
|
||||
assert.ErrorContains(t, err, "files")
|
||||
}
|
||||
|
||||
func TestReviewCallsExecutor(t *testing.T) {
|
||||
called := false
|
||||
func TestReviewCallsCompleteFunc(t *testing.T) {
|
||||
var capturedTask string
|
||||
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
called = true
|
||||
capturedTask = req.TaskPrompt
|
||||
return iexec.Result{
|
||||
Status: "pass", Phase: "review", Skill: "review",
|
||||
Verified: true, ModelUsed: "self", Message: "2 warnings found",
|
||||
}, nil
|
||||
fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
|
||||
capturedTask = user
|
||||
return "2 warnings found: missing error handling at line 42", 80, nil
|
||||
}
|
||||
|
||||
sk := review.New(review.Config{SkillPrompt: "review rules", ExecutorFn: fakeFn, SessionsDir: t.TempDir()})
|
||||
sk := review.New(review.Config{SkillPrompt: "review rules", CompleteFunc: fakeFn, SessionsDir: t.TempDir()})
|
||||
out, err := sk.Handle(context.Background(), "review", json.RawMessage(
|
||||
`{"project_root":"/tmp/proj","files":["internal/foo/foo.go"],"context":"PR: add Foo helper"}`,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
assert.Contains(t, capturedTask, "internal/foo/foo.go")
|
||||
assert.Contains(t, capturedTask, "PR: add Foo helper")
|
||||
|
||||
var result iexec.Result
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &result))
|
||||
assert.Equal(t, "pass", result.Status)
|
||||
assert.Equal(t, "review", result.Phase)
|
||||
assert.Contains(t, result["text"], "2 warnings found")
|
||||
}
|
||||
|
||||
@@ -5,20 +5,19 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/registry"
|
||||
)
|
||||
|
||||
// ExecutorFn is the function signature for running a worker subprocess.
|
||||
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
|
||||
// CompleteFunc is the function used to call a local model.
|
||||
type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
|
||||
|
||||
// Config holds dependencies for the review skill.
|
||||
type Config struct {
|
||||
SkillPrompt string
|
||||
DefaultModel string
|
||||
ExecutorFn ExecutorFn
|
||||
CompleteFunc CompleteFunc
|
||||
SessionsDir string
|
||||
IngestBaseURL string // optional: base URL of ingestion server for brain context
|
||||
IngestBaseURL string
|
||||
}
|
||||
|
||||
// Skill implements the review MCP tool.
|
||||
@@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
return []registry.ToolDef{
|
||||
{
|
||||
Name: "review",
|
||||
Description: "Perform a structured code review of the specified files. Returns findings with severity levels.",
|
||||
Description: "Consult a local model for a structured code review of the specified files. Returns findings with severity levels.",
|
||||
InputSchema: schema(
|
||||
[]string{"project_root", "files"},
|
||||
map[string]any{
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mathiasbq/supervisor/internal/brain"
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
)
|
||||
|
||||
@@ -57,39 +56,32 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
task = brainCtx + "\n---\n\n" + task
|
||||
}
|
||||
|
||||
if s.cfg.ExecutorFn == nil {
|
||||
if s.cfg.CompleteFunc == nil {
|
||||
return nil, fmt.Errorf("no executor configured")
|
||||
}
|
||||
t0 := time.Now()
|
||||
result, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||
SkillPrompt: s.cfg.SkillPrompt,
|
||||
TaskPrompt: task,
|
||||
Model: model,
|
||||
Tools: "Read,Write",
|
||||
})
|
||||
text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a.SessionID != "" && s.cfg.SessionsDir != "" {
|
||||
msg := text
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200]
|
||||
}
|
||||
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
|
||||
SessionID: a.SessionID,
|
||||
Timestamp: time.Now(),
|
||||
Skill: "spec",
|
||||
Phase: "spec",
|
||||
ProjectRoot: a.ProjectRoot,
|
||||
Attempts: session.AttemptsFrom(result.Attempts),
|
||||
FinalStatus: result.Status,
|
||||
FilePath: result.FilePath,
|
||||
ModelUsed: result.ModelUsed,
|
||||
FinalStatus: "ok",
|
||||
ModelUsed: model,
|
||||
DurationMs: time.Since(t0).Milliseconds(),
|
||||
Message: result.Message,
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
b, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal result: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/spec"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -33,29 +32,22 @@ func TestSpecRequiresRequirements(t *testing.T) {
|
||||
assert.ErrorContains(t, err, "requirements")
|
||||
}
|
||||
|
||||
func TestSpecCallsExecutor(t *testing.T) {
|
||||
called := false
|
||||
func TestSpecCallsCompleteFunc(t *testing.T) {
|
||||
var capturedTask string
|
||||
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
called = true
|
||||
capturedTask = req.TaskPrompt
|
||||
return iexec.Result{
|
||||
Status: "pass", Phase: "spec", Skill: "spec",
|
||||
FilePath: "/tmp/proj/docs/login-spec.md",
|
||||
Verified: true, ModelUsed: "self", Message: "spec written: login feature",
|
||||
}, nil
|
||||
fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
|
||||
capturedTask = user
|
||||
return "# OAuth2 Login Spec\n\n## Overview\nImplement OAuth2 login flow.", 110, nil
|
||||
}
|
||||
|
||||
sk := spec.New(spec.Config{SkillPrompt: "spec rules", ExecutorFn: fakeFn, SessionsDir: t.TempDir()})
|
||||
sk := spec.New(spec.Config{SkillPrompt: "spec rules", CompleteFunc: fakeFn, SessionsDir: t.TempDir()})
|
||||
out, err := sk.Handle(context.Background(), "spec", json.RawMessage(
|
||||
`{"project_root":"/tmp/proj","requirements":"add OAuth2 login","output_path":"docs/login-spec.md"}`,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
assert.Contains(t, capturedTask, "OAuth2 login")
|
||||
assert.Contains(t, capturedTask, "docs/login-spec.md")
|
||||
|
||||
var result iexec.Result
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &result))
|
||||
assert.Equal(t, "spec", result.Phase)
|
||||
assert.Contains(t, result["text"], "OAuth2 Login Spec")
|
||||
}
|
||||
|
||||
@@ -5,20 +5,19 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/registry"
|
||||
)
|
||||
|
||||
// ExecutorFn is the function signature for running a worker subprocess.
|
||||
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
|
||||
// CompleteFunc is the function used to call a local model.
|
||||
type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
|
||||
|
||||
// Config holds dependencies for the spec skill.
|
||||
type Config struct {
|
||||
SkillPrompt string
|
||||
DefaultModel string
|
||||
ExecutorFn ExecutorFn
|
||||
CompleteFunc CompleteFunc
|
||||
SessionsDir string
|
||||
IngestBaseURL string // optional: base URL of ingestion server for brain context
|
||||
IngestBaseURL string
|
||||
}
|
||||
|
||||
// Skill implements the spec MCP tool.
|
||||
@@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
return []registry.ToolDef{
|
||||
{
|
||||
Name: "spec",
|
||||
Description: "Generate a structured implementation spec from requirements. Writes the spec to output_path in the project.",
|
||||
Description: "Consult a local model to draft a structured implementation spec from requirements. Returns the spec text.",
|
||||
InputSchema: schema(
|
||||
[]string{"project_root", "requirements"},
|
||||
map[string]any{
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mathiasbq/supervisor/internal/brain"
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
)
|
||||
|
||||
@@ -51,7 +50,7 @@ func (s *Skill) handleRed(ctx context.Context, raw json.RawMessage) (json.RawMes
|
||||
if brainCtx != "" {
|
||||
task = brainCtx + "\n---\n\n" + task
|
||||
}
|
||||
return s.execute(ctx, task)
|
||||
return s.complete(ctx, s.resolveModel(args.Model), task)
|
||||
}
|
||||
|
||||
type greenArgs struct {
|
||||
@@ -80,11 +79,11 @@ func (s *Skill) handleGreen(ctx context.Context, raw json.RawMessage) (json.RawM
|
||||
task = session.PrependHistory(s.cfg.SessionsDir, args.SessionID, "green", task)
|
||||
|
||||
t0 := time.Now()
|
||||
result, err := s.execute(ctx, task)
|
||||
result, err := s.complete(ctx, s.resolveModel(args.Model), task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.logAttempt(args.SessionID, args.ProjectRoot, "tdd", "green", t0, result)
|
||||
s.logEntry(args.SessionID, args.ProjectRoot, "tdd", "green", s.resolveModel(args.Model), t0, result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -118,11 +117,11 @@ func (s *Skill) handleRefactor(ctx context.Context, raw json.RawMessage) (json.R
|
||||
task = session.PrependHistory(s.cfg.SessionsDir, args.SessionID, "refactor", task)
|
||||
|
||||
t0 := time.Now()
|
||||
result, err := s.execute(ctx, task)
|
||||
result, err := s.complete(ctx, s.resolveModel(args.Model), task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.logAttempt(args.SessionID, args.ProjectRoot, "tdd", "refactor", t0, result)
|
||||
s.logEntry(args.SessionID, args.ProjectRoot, "tdd", "refactor", s.resolveModel(args.Model), t0, result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -133,31 +132,32 @@ func (s *Skill) resolveModel(override string) string {
|
||||
return s.cfg.DefaultModel
|
||||
}
|
||||
|
||||
// execute calls ExecutorFn and returns the marshaled result.
|
||||
func (s *Skill) execute(ctx context.Context, task string) (json.RawMessage, error) {
|
||||
if s.cfg.ExecutorFn == nil {
|
||||
// complete calls CompleteFunc and returns the text as JSON.
|
||||
func (s *Skill) complete(ctx context.Context, model, task string) (json.RawMessage, error) {
|
||||
if s.cfg.CompleteFunc == nil {
|
||||
return nil, fmt.Errorf("no executor configured")
|
||||
}
|
||||
req := iexec.Request{
|
||||
SkillPrompt: s.cfg.SkillPrompt,
|
||||
TaskPrompt: task,
|
||||
}
|
||||
result, err := s.cfg.ExecutorFn(ctx, req)
|
||||
text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(result)
|
||||
return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
|
||||
}
|
||||
|
||||
// logAttempt writes a session.Entry for a completed phase if session_id is set.
|
||||
// raw is the marshaled Result returned by execute; we unmarshal to extract fields.
|
||||
func (s *Skill) logAttempt(sessionID, projectRoot, skill, phase string, t0 time.Time, raw json.RawMessage) {
|
||||
// logEntry writes a session.Entry for a completed phase if session_id is set.
|
||||
func (s *Skill) logEntry(sessionID, projectRoot, skill, phase, model string, t0 time.Time, raw json.RawMessage) {
|
||||
if sessionID == "" || s.cfg.SessionsDir == "" {
|
||||
return
|
||||
}
|
||||
var result iexec.Result
|
||||
if err := json.Unmarshal(raw, &result); err != nil {
|
||||
return
|
||||
var msg string
|
||||
var result struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &result); err == nil && len(result.Text) > 0 {
|
||||
msg = result.Text
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200]
|
||||
}
|
||||
}
|
||||
_ = session.Append(s.cfg.SessionsDir, sessionID, session.Entry{
|
||||
SessionID: sessionID,
|
||||
@@ -165,11 +165,9 @@ func (s *Skill) logAttempt(sessionID, projectRoot, skill, phase string, t0 time.
|
||||
Skill: skill,
|
||||
Phase: phase,
|
||||
ProjectRoot: projectRoot,
|
||||
Attempts: session.AttemptsFrom(result.Attempts),
|
||||
FinalStatus: result.Status,
|
||||
FilePath: result.FilePath,
|
||||
ModelUsed: result.ModelUsed,
|
||||
FinalStatus: "ok",
|
||||
ModelUsed: model,
|
||||
DurationMs: time.Since(t0).Milliseconds(),
|
||||
Message: result.Message,
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/tdd"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -14,8 +13,7 @@ import (
|
||||
|
||||
func TestTDDSkillTools(t *testing.T) {
|
||||
skill := tdd.New(tdd.Config{
|
||||
SystemPrompt: "supervisor rules",
|
||||
SkillPrompt: "tdd rules",
|
||||
SkillPrompt: "tdd rules",
|
||||
})
|
||||
tools := skill.Tools()
|
||||
names := make([]string, len(tools))
|
||||
@@ -26,19 +24,19 @@ func TestTDDSkillTools(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTDDSkillHandleUnknown(t *testing.T) {
|
||||
skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"})
|
||||
skill := tdd.New(tdd.Config{SkillPrompt: "t"})
|
||||
_, err := skill.Handle(context.Background(), "tdd_unknown", json.RawMessage(`{}`))
|
||||
assert.ErrorContains(t, err, "unknown tool")
|
||||
}
|
||||
|
||||
func TestTDDRedRequiresProjectRoot(t *testing.T) {
|
||||
skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"})
|
||||
skill := tdd.New(tdd.Config{SkillPrompt: "t"})
|
||||
_, err := skill.Handle(context.Background(), "tdd_red", json.RawMessage(`{"spec":"add two numbers"}`))
|
||||
assert.ErrorContains(t, err, "project_root")
|
||||
}
|
||||
|
||||
func TestTDDRedRequiresSpec(t *testing.T) {
|
||||
skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"})
|
||||
skill := tdd.New(tdd.Config{SkillPrompt: "t"})
|
||||
_, err := skill.Handle(context.Background(), "tdd_red", json.RawMessage(`{"project_root":"/tmp/proj"}`))
|
||||
assert.ErrorContains(t, err, "spec")
|
||||
}
|
||||
@@ -51,35 +49,49 @@ func TestTDDGreenInjectsSessionHistory(t *testing.T) {
|
||||
Message: "wrote failing test for Foo",
|
||||
}))
|
||||
|
||||
var capturedPrompt string
|
||||
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
capturedPrompt = req.TaskPrompt
|
||||
return iexec.Result{Status: "pass", Phase: "green", Skill: "tdd", Verified: true, ModelUsed: "self", Message: "ok"}, nil
|
||||
var capturedTask string
|
||||
fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
|
||||
capturedTask = user
|
||||
return "here is my suggestion", 100, nil
|
||||
}
|
||||
|
||||
sk := tdd.New(tdd.Config{SkillPrompt: "tdd", ExecutorFn: fakeFn, SessionsDir: sessDir})
|
||||
sk := tdd.New(tdd.Config{SkillPrompt: "tdd", CompleteFunc: fakeFn, SessionsDir: sessDir})
|
||||
_, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage(
|
||||
`{"project_root":"/tmp","test_path":"internal/foo/foo_test.go","test_cmd":"go test ./...","session_id":"sess-1"}`,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, capturedPrompt, "## Session history")
|
||||
assert.Contains(t, capturedPrompt, "wrote failing test for Foo")
|
||||
assert.Contains(t, capturedTask, "## Session history")
|
||||
assert.Contains(t, capturedTask, "wrote failing test for Foo")
|
||||
}
|
||||
|
||||
func TestTDDGreenNoHistoryWhenSessionIDEmpty(t *testing.T) {
|
||||
var capturedPrompt string
|
||||
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
capturedPrompt = req.TaskPrompt
|
||||
return iexec.Result{Status: "pass", Phase: "green", Skill: "tdd", Verified: true, ModelUsed: "self", Message: "ok"}, nil
|
||||
var capturedTask string
|
||||
fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
|
||||
capturedTask = user
|
||||
return "suggestion", 50, nil
|
||||
}
|
||||
|
||||
sk := tdd.New(tdd.Config{SkillPrompt: "tdd", ExecutorFn: fakeFn, SessionsDir: t.TempDir()})
|
||||
sk := tdd.New(tdd.Config{SkillPrompt: "tdd", CompleteFunc: fakeFn, SessionsDir: t.TempDir()})
|
||||
_, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage(
|
||||
`{"project_root":"/tmp","test_path":"internal/foo/foo_test.go"}`,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, capturedPrompt, "## Session history")
|
||||
assert.NotContains(t, capturedTask, "## Session history")
|
||||
}
|
||||
|
||||
// Ensure require is used (avoids import error).
|
||||
var _ = require.New
|
||||
func TestTDDGreenReturnsTextJSON(t *testing.T) {
|
||||
fakeFn := func(_ context.Context, _, _, _ string) (string, int64, error) {
|
||||
return "write a func that adds two ints", 42, nil
|
||||
}
|
||||
|
||||
sk := tdd.New(tdd.Config{SkillPrompt: "tdd", CompleteFunc: fakeFn})
|
||||
raw, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage(
|
||||
`{"project_root":"/tmp","test_path":"foo_test.go"}`,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(raw, &result))
|
||||
assert.Equal(t, "write a func that adds two ints", result["text"])
|
||||
assert.Equal(t, float64(42), result["duration_ms"])
|
||||
}
|
||||
|
||||
@@ -4,17 +4,15 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/registry"
|
||||
)
|
||||
|
||||
// ExecutorFn allows injecting a test double for the executor.
|
||||
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
|
||||
// CompleteFunc is the function used to call a local model.
|
||||
type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
|
||||
|
||||
type Config struct {
|
||||
SystemPrompt string
|
||||
SkillPrompt string
|
||||
ExecutorFn ExecutorFn // nil = no executor (tests that don't reach execute())
|
||||
CompleteFunc CompleteFunc // nil = no executor (tests that don't reach execute())
|
||||
DefaultModel string
|
||||
SessionsDir string // optional: path to brain/sessions/ for history injection
|
||||
IngestBaseURL string // optional: base URL of ingestion server for brain context
|
||||
@@ -44,7 +42,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
return []registry.ToolDef{
|
||||
{
|
||||
Name: "tdd_red",
|
||||
Description: "Write a failing test for the described behavior. Verifies the test fails before returning.",
|
||||
Description: "Consult a local model for help writing a failing test for the described behavior.",
|
||||
InputSchema: schema(
|
||||
[]string{"project_root", "spec"},
|
||||
map[string]any{
|
||||
@@ -57,7 +55,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
},
|
||||
{
|
||||
Name: "tdd_green",
|
||||
Description: "Write minimal implementation to make the test at test_path pass.",
|
||||
Description: "Consult a local model for implementation ideas to make the test at test_path pass.",
|
||||
InputSchema: schema(
|
||||
[]string{"project_root", "test_path"},
|
||||
map[string]any{
|
||||
@@ -71,7 +69,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
},
|
||||
{
|
||||
Name: "tdd_refactor",
|
||||
Description: "Refactor the implementation at impl_path while keeping tests green.",
|
||||
Description: "Consult a local model for refactoring suggestions for impl_path while keeping tests green.",
|
||||
InputSchema: schema(
|
||||
[]string{"project_root", "test_path", "impl_path"},
|
||||
map[string]any{
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
)
|
||||
|
||||
@@ -28,7 +27,7 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
if a.SessionID == "" {
|
||||
return nil, fmt.Errorf("session_id is required")
|
||||
}
|
||||
if s.cfg.ExecutorFn == nil {
|
||||
if s.cfg.CompleteFunc == nil {
|
||||
return nil, fmt.Errorf("no executor configured")
|
||||
}
|
||||
|
||||
@@ -42,53 +41,47 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
return nil, fmt.Errorf("read session log: %w", err)
|
||||
}
|
||||
|
||||
// ── Step 1: Reader agent ─────────────────────────────────────────────────
|
||||
// ── Step 1: Reader ────────────────────────────────────────────────────────
|
||||
history := session.FormatHistory(entries, "")
|
||||
readerTask := fmt.Sprintf(
|
||||
"role: reader\nsession_id: %s\nbrain_dir: %s\n\n%s",
|
||||
a.SessionID, s.cfg.BrainDir, history,
|
||||
)
|
||||
readerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||
SkillPrompt: s.cfg.ReaderPrompt,
|
||||
TaskPrompt: readerTask,
|
||||
Model: model,
|
||||
Tools: "Read",
|
||||
})
|
||||
readerText, _, err := s.cfg.CompleteFunc(ctx, model, s.cfg.ReaderPrompt, readerTask)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reader agent: %w", err)
|
||||
return nil, fmt.Errorf("reader: %w", err)
|
||||
}
|
||||
|
||||
// ── Step 2: Writer agent (receives reader candidates) ────────────────────
|
||||
// ── Step 2: Writer (receives reader output) ───────────────────────────────
|
||||
t0 := time.Now()
|
||||
writerTask := fmt.Sprintf(
|
||||
"role: writer\nsession_id: %s\nbrain_dir: %s\n\nreader_summary: %s\nreader_candidates:\n%s",
|
||||
a.SessionID, s.cfg.BrainDir, readerResult.Message, readerResult.RunnerOutput,
|
||||
"role: writer\nsession_id: %s\nbrain_dir: %s\n\nreader_analysis:\n%s",
|
||||
a.SessionID, s.cfg.BrainDir, readerText,
|
||||
)
|
||||
writerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||
SkillPrompt: s.cfg.WriterPrompt,
|
||||
TaskPrompt: writerTask,
|
||||
Model: model,
|
||||
Tools: "Read,Write",
|
||||
})
|
||||
writerText, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.WriterPrompt, writerTask)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("writer agent: %w", err)
|
||||
return nil, fmt.Errorf("writer: %w", err)
|
||||
}
|
||||
|
||||
msg := writerText
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200]
|
||||
}
|
||||
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
|
||||
SessionID: a.SessionID,
|
||||
Timestamp: time.Now(),
|
||||
Skill: "trainer",
|
||||
Phase: "trainer",
|
||||
Attempts: session.AttemptsFrom(writerResult.Attempts),
|
||||
FinalStatus: writerResult.Status,
|
||||
ModelUsed: writerResult.ModelUsed,
|
||||
FinalStatus: "ok",
|
||||
ModelUsed: model,
|
||||
DurationMs: time.Since(t0).Milliseconds(),
|
||||
Message: writerResult.Message,
|
||||
Message: msg,
|
||||
})
|
||||
|
||||
b, err := json.Marshal(writerResult)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal result: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
return json.Marshal(map[string]any{
|
||||
"reader_analysis": readerText,
|
||||
"writer_output": writerText,
|
||||
"model": model,
|
||||
"duration_ms": dur,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/trainer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -31,52 +30,44 @@ func TestTrainerRequiresSessionID(t *testing.T) {
|
||||
func TestTrainerCallsReaderThenWriter(t *testing.T) {
|
||||
sessDir := t.TempDir()
|
||||
require.NoError(t, session.Append(sessDir, "sess-1", session.Entry{
|
||||
SessionID: "sess-1", Skill: "tdd", Phase: "red", FinalStatus: "pass",
|
||||
SessionID: "sess-1", Skill: "tdd", Phase: "red", FinalStatus: "ok",
|
||||
Message: "wrote failing test", FilePath: "internal/foo/foo_test.go",
|
||||
}))
|
||||
|
||||
callCount := 0
|
||||
var readerTask, writerTask string
|
||||
|
||||
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
fakeFn := func(_ context.Context, _, sys, user string) (string, int64, error) {
|
||||
callCount++
|
||||
if callCount == 1 {
|
||||
// reader call
|
||||
readerTask = req.TaskPrompt
|
||||
return iexec.Result{
|
||||
Status: "pass", Phase: "trainer", Skill: "trainer",
|
||||
RunnerOutput: `[{"type":"sft","moment":"first-pass clean TDD","score":4}]`,
|
||||
Verified: true, ModelUsed: "self", Message: "1 sft candidate found",
|
||||
}, nil
|
||||
readerTask = user
|
||||
return "1 sft candidate found: first-pass clean TDD", 60, nil
|
||||
}
|
||||
// writer call
|
||||
writerTask = req.TaskPrompt
|
||||
return iexec.Result{
|
||||
Status: "pass", Phase: "trainer", Skill: "trainer",
|
||||
FilePath: sessDir + "/training-data/sft/sess-1.jsonl",
|
||||
Verified: true, ModelUsed: "self", Message: "1 sft pair written",
|
||||
}, nil
|
||||
writerTask = user
|
||||
return "written 1 knowledge entry to brain/knowledge/tdd-patterns.md", 70, nil
|
||||
}
|
||||
|
||||
sk := trainer.New(trainer.Config{
|
||||
ReaderPrompt: "reader rules",
|
||||
WriterPrompt: "writer rules",
|
||||
ExecutorFn: fakeFn,
|
||||
CompleteFunc: fakeFn,
|
||||
SessionsDir: sessDir,
|
||||
BrainDir: t.TempDir(),
|
||||
})
|
||||
out, err := sk.Handle(context.Background(), "trainer", json.RawMessage(`{"session_id":"sess-1"}`))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, callCount, "executor must be called exactly twice: reader then writer")
|
||||
assert.Equal(t, 2, callCount, "complete must be called exactly twice: reader then writer")
|
||||
assert.Contains(t, readerTask, "role: reader")
|
||||
assert.Contains(t, readerTask, "sess-1")
|
||||
assert.Contains(t, readerTask, "wrote failing test") // session history in reader prompt
|
||||
assert.Contains(t, readerTask, "wrote failing test")
|
||||
assert.Contains(t, writerTask, "role: writer")
|
||||
assert.Contains(t, writerTask, "sft candidate") // reader output passed to writer
|
||||
assert.Contains(t, writerTask, "sft candidate")
|
||||
|
||||
var result iexec.Result
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &result))
|
||||
assert.Equal(t, "trainer", result.Phase)
|
||||
assert.Equal(t, "pass", result.Status)
|
||||
assert.Contains(t, result["reader_analysis"], "sft candidate")
|
||||
assert.Contains(t, result["writer_output"], "knowledge entry")
|
||||
}
|
||||
|
||||
@@ -5,21 +5,20 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/registry"
|
||||
)
|
||||
|
||||
// ExecutorFn is the function signature for running a worker subprocess.
|
||||
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
|
||||
// CompleteFunc is the function used to call a local model.
|
||||
type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
|
||||
|
||||
// Config holds dependencies for the trainer skill.
|
||||
type Config struct {
|
||||
ReaderPrompt string
|
||||
WriterPrompt string
|
||||
DefaultModel string
|
||||
ExecutorFn ExecutorFn
|
||||
CompleteFunc CompleteFunc
|
||||
SessionsDir string
|
||||
BrainDir string // root of brain/ directory; writer writes to BrainDir/training-data/
|
||||
BrainDir string // root of brain/ directory
|
||||
}
|
||||
|
||||
// Skill implements the trainer MCP tool.
|
||||
@@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
return []registry.ToolDef{
|
||||
{
|
||||
Name: "trainer",
|
||||
Description: "Extract SFT and DPO training pairs from a session log. Runs a reader→writer chain: reader identifies learning moments, writer formats and writes pairs to brain/training-data/.",
|
||||
Description: "Consult a local model to identify learning moments from a session log and suggest knowledge to preserve in the brain.",
|
||||
InputSchema: schema(
|
||||
[]string{"session_id"},
|
||||
map[string]any{
|
||||
|
||||
Reference in New Issue
Block a user