refactor: replace orchestrator/verifier chain with direct LiteLLM calls
Drop the three-layer Claude subprocess orchestration (local model →
Claude verifier → cloud escalation). Skills now call LiteLLM directly
and return plain text to Claude Code, which decides what to do with it.
- Delete executor, orchestrator, verifier, result, attempts packages
- Simplify LiteLLMExecutor: Run(Request)→Result becomes Complete(model,sys,user)→(string,int64,error)
- Replace ExecutorFn with CompleteFunc in all 6 skill configs
- Rewrite all skill handlers to call Complete and return {"text","model","duration_ms"}
- Simplify config/models: remove Verifier/LlamaSwapURL, add ModelFor
- Bump version to v0.5.0
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -37,12 +37,6 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
systemPrompt, err := os.ReadFile(cfg.ConfigDir + "/CLAUDE.md")
|
||||
if err != nil {
|
||||
logger.Error("read supervisor CLAUDE.md", "path", cfg.ConfigDir+"/CLAUDE.md", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
protocolsPrompt, err := os.ReadFile(cfg.ConfigDir + "/protocols.md")
|
||||
if err != nil {
|
||||
logger.Error("read protocols.md", "path", cfg.ConfigDir+"/protocols.md", "err", err)
|
||||
@@ -95,40 +89,7 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
claudeExec := iexec.New(iexec.Config{
|
||||
SystemPrompt: string(systemPrompt),
|
||||
LiteLLMBaseURL: cfg.LiteLLMBaseURL,
|
||||
LiteLLMAPIKey: cfg.LiteLLMAPIKey,
|
||||
})
|
||||
litellmExec := iexec.NewLiteLLM(cfg.LiteLLMBaseURL, cfg.LiteLLMAPIKey, 0)
|
||||
verifier := iexec.NewVerifier("", models.Verifier(), 0)
|
||||
|
||||
buildOrch := func(skill string) func(ctx context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
return func(ctx context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
rawChain := models.ChainFor(skill, req.Model)
|
||||
chain := make([]iexec.ChainEntry, len(rawChain))
|
||||
for i, m := range rawChain {
|
||||
chain[i] = iexec.EntryFor(m)
|
||||
}
|
||||
attempts := make([]iexec.AttemptRecord, 0, len(chain))
|
||||
orch := iexec.NewOrchestrator(chain, litellmExec.Run, claudeExec.Run, verifier, models.LlamaSwapURL(), &attempts)
|
||||
result, err := orch.Run(ctx, req)
|
||||
result.Attempts = attempts // attach orchestration metadata before returning
|
||||
// Log per-attempt verdicts so pass rates are visible in pod logs.
|
||||
for i, a := range attempts {
|
||||
logger.Info("chain attempt",
|
||||
"skill", skill,
|
||||
"attempt", i+1,
|
||||
"model", a.Model,
|
||||
"tier", a.Tier,
|
||||
"verdict", a.Verdict,
|
||||
"duration_ms", a.DurationMs,
|
||||
"warm", a.WarmStart,
|
||||
)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
}
|
||||
litellm := iexec.NewLiteLLM(cfg.LiteLLMBaseURL, cfg.LiteLLMAPIKey, 0)
|
||||
|
||||
tierFn := func(ctx context.Context) tier.Info {
|
||||
return tier.Detect(ctx, "https://api.anthropic.com", cfg.LiteLLMBaseURL)
|
||||
@@ -136,10 +97,9 @@ func main() {
|
||||
|
||||
reg := registry.New()
|
||||
reg.Register(tdd.New(tdd.Config{
|
||||
SystemPrompt: string(systemPrompt),
|
||||
SkillPrompt: prependProtocols(tddPrompt),
|
||||
DefaultModel: models.ChainFor("tdd", "")[0],
|
||||
ExecutorFn: buildOrch("tdd"),
|
||||
DefaultModel: models.ModelFor("tdd", ""),
|
||||
CompleteFunc: litellm.Complete,
|
||||
SessionsDir: cfg.SessionsDir,
|
||||
IngestBaseURL: cfg.IngestBaseURL,
|
||||
}))
|
||||
@@ -154,36 +114,36 @@ func main() {
|
||||
}))
|
||||
reg.Register(retrospective.New(retrospective.Config{
|
||||
SkillPrompt: prependProtocols(retroPrompt),
|
||||
DefaultModel: models.ChainFor("retrospective", "")[0],
|
||||
DefaultModel: models.ModelFor("retrospective", ""),
|
||||
SessionsDir: cfg.SessionsDir,
|
||||
ExecutorFn: buildOrch("retrospective"),
|
||||
CompleteFunc: litellm.Complete,
|
||||
}))
|
||||
reg.Register(review.New(review.Config{
|
||||
SkillPrompt: prependProtocols(reviewPrompt),
|
||||
DefaultModel: models.ChainFor("review", "")[0],
|
||||
ExecutorFn: buildOrch("review"),
|
||||
DefaultModel: models.ModelFor("review", ""),
|
||||
CompleteFunc: litellm.Complete,
|
||||
SessionsDir: cfg.SessionsDir,
|
||||
IngestBaseURL: cfg.IngestBaseURL,
|
||||
}))
|
||||
reg.Register(skilldebug.New(skilldebug.Config{
|
||||
SkillPrompt: prependProtocols(debugPrompt),
|
||||
DefaultModel: models.ChainFor("debug", "")[0],
|
||||
ExecutorFn: buildOrch("debug"),
|
||||
DefaultModel: models.ModelFor("debug", ""),
|
||||
CompleteFunc: litellm.Complete,
|
||||
SessionsDir: cfg.SessionsDir,
|
||||
IngestBaseURL: cfg.IngestBaseURL,
|
||||
}))
|
||||
reg.Register(spec.New(spec.Config{
|
||||
SkillPrompt: prependProtocols(specPrompt),
|
||||
DefaultModel: models.ChainFor("spec", "")[0],
|
||||
ExecutorFn: buildOrch("spec"),
|
||||
DefaultModel: models.ModelFor("spec", ""),
|
||||
CompleteFunc: litellm.Complete,
|
||||
SessionsDir: cfg.SessionsDir,
|
||||
IngestBaseURL: cfg.IngestBaseURL,
|
||||
}))
|
||||
reg.Register(trainer.New(trainer.Config{
|
||||
ReaderPrompt: prependProtocols(trainerReaderPrompt),
|
||||
WriterPrompt: prependProtocols(trainerWriterPrompt),
|
||||
DefaultModel: models.ChainFor("trainer", "")[0],
|
||||
ExecutorFn: buildOrch("trainer"),
|
||||
DefaultModel: models.ModelFor("trainer", ""),
|
||||
CompleteFunc: litellm.Complete,
|
||||
SessionsDir: cfg.SessionsDir,
|
||||
BrainDir: cfg.BrainDir,
|
||||
}))
|
||||
@@ -193,7 +153,7 @@ func main() {
|
||||
mux.Handle("/mcp", srv)
|
||||
|
||||
addr := ":" + cfg.Port
|
||||
logger.Info("supervisor starting", "addr", addr, "version", "v0.4.0")
|
||||
logger.Info("supervisor starting", "addr", addr, "version", "v0.5.0")
|
||||
if err := http.ListenAndServe(addr, mux); err != nil {
|
||||
logger.Error("server stopped", "err", err)
|
||||
os.Exit(1)
|
||||
|
||||
@@ -1,41 +1,25 @@
|
||||
# Model routing chains — three-layer priority:
|
||||
# 1. model param in MCP tool call (caller override — collapses to single entry, no escalation)
|
||||
# 2. per-skill chain here
|
||||
# 3. default_chain fallback
|
||||
|
||||
verifier: claude-sonnet-4-6 # fixed verifier for all local tiers
|
||||
|
||||
llama_swap_url: http://koala:8080 # for warm-state probing
|
||||
# Model selection — only the first entry of each skill's chain is used; skills without a chain use the first default_chain entry.
|
||||
# Override per-call by passing model in the MCP tool args.
|
||||
|
||||
default_chain:
|
||||
- ollama/qwen3-coder-30b-tuned
|
||||
- claude-sonnet-4-6
|
||||
|
||||
skills:
|
||||
tdd:
|
||||
chain:
|
||||
- ollama/qwen3-coder-30b-tuned
|
||||
- claude-sonnet-4-6
|
||||
review:
|
||||
chain:
|
||||
- ollama/devstral-tuned
|
||||
- ollama/gemma4
|
||||
- claude-sonnet-4-6
|
||||
debug:
|
||||
chain:
|
||||
- ollama/deepseek-r1-tuned
|
||||
- claude-sonnet-4-6
|
||||
spec:
|
||||
chain:
|
||||
- ollama/phi4
|
||||
- ollama/gemma4
|
||||
- claude-sonnet-4-6
|
||||
- claude-opus-4-6
|
||||
retrospective:
|
||||
chain:
|
||||
- ollama/qwen3-coder-30b-tuned
|
||||
- claude-sonnet-4-6
|
||||
trainer:
|
||||
chain:
|
||||
- ollama/qwen3-coder-30b-tuned
|
||||
- claude-sonnet-4-6
|
||||
|
||||
@@ -12,8 +12,6 @@ type skillChain struct {
|
||||
}
|
||||
|
||||
type modelsFile struct {
|
||||
Verifier string `yaml:"verifier"`
|
||||
LlamaSwapURL string `yaml:"llama_swap_url"`
|
||||
DefaultChain []string `yaml:"default_chain"`
|
||||
Skills map[string]skillChain `yaml:"skills"`
|
||||
}
|
||||
@@ -34,23 +32,18 @@ func LoadModels(path string) (Models, error) {
|
||||
return Models{data: f}, nil
|
||||
}
|
||||
|
||||
// Verifier returns the model name to use for all local-tier output verification.
|
||||
func (m Models) Verifier() string { return m.data.Verifier }
|
||||
|
||||
// LlamaSwapURL returns the llama-swap base URL for warm-state probing.
|
||||
func (m Models) LlamaSwapURL() string { return m.data.LlamaSwapURL }
|
||||
|
||||
// ChainFor returns the ordered list of model names for a skill.
|
||||
// If override is non-empty, returns a single-entry chain (no escalation).
|
||||
// Falls back to default_chain when the skill has no explicit entry.
|
||||
func (m Models) ChainFor(skill, override string) []string {
|
||||
// ModelFor returns the primary model to use for a skill.
|
||||
// If override is non-empty, it is returned directly.
|
||||
// Falls back to default_chain[0] when the skill has no explicit (non-empty) chain, and returns "" when default_chain is empty too.
|
||||
func (m Models) ModelFor(skill, override string) string {
|
||||
if override != "" {
|
||||
return []string{override}
|
||||
return override
|
||||
}
|
||||
if sc, ok := m.data.Skills[skill]; ok && len(sc.Chain) > 0 {
|
||||
return sc.Chain
|
||||
return sc.Chain[0]
|
||||
}
|
||||
out := make([]string, len(m.data.DefaultChain))
|
||||
copy(out, m.data.DefaultChain)
|
||||
return out
|
||||
if len(m.data.DefaultChain) > 0 {
|
||||
return m.data.DefaultChain[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -11,9 +11,6 @@ import (
|
||||
)
|
||||
|
||||
const testYAML = `
|
||||
verifier: claude-sonnet-4-6
|
||||
llama_swap_url: http://koala:8080
|
||||
|
||||
default_chain:
|
||||
- ollama/qwen3-coder-30b-tuned
|
||||
- claude-sonnet-4-6
|
||||
@@ -37,44 +34,20 @@ func writeModels(t *testing.T, content string) string {
|
||||
return f
|
||||
}
|
||||
|
||||
func TestModelsVerifier(t *testing.T) {
|
||||
func TestModelsModelForSkillWithEntry(t *testing.T) {
|
||||
m, err := config.LoadModels(writeModels(t, testYAML))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "claude-sonnet-4-6", m.Verifier())
|
||||
assert.Equal(t, "ollama/devstral-tuned", m.ModelFor("review", ""))
|
||||
}
|
||||
|
||||
func TestModelsLlamaSwapURL(t *testing.T) {
|
||||
func TestModelsModelForDefaultFallback(t *testing.T) {
|
||||
m, err := config.LoadModels(writeModels(t, testYAML))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "http://koala:8080", m.LlamaSwapURL())
|
||||
assert.Equal(t, "ollama/qwen3-coder-30b-tuned", m.ModelFor("trainer", ""))
|
||||
}
|
||||
|
||||
func TestModelsChainForSkillOverride(t *testing.T) {
|
||||
func TestModelsModelForCallerOverride(t *testing.T) {
|
||||
m, err := config.LoadModels(writeModels(t, testYAML))
|
||||
require.NoError(t, err)
|
||||
|
||||
chain := m.ChainFor("review", "")
|
||||
require.Len(t, chain, 3)
|
||||
assert.Equal(t, "ollama/devstral-tuned", chain[0])
|
||||
assert.Equal(t, "ollama/gemma4", chain[1])
|
||||
assert.Equal(t, "claude-sonnet-4-6", chain[2])
|
||||
}
|
||||
|
||||
func TestModelsChainForDefaultFallback(t *testing.T) {
|
||||
m, err := config.LoadModels(writeModels(t, testYAML))
|
||||
require.NoError(t, err)
|
||||
|
||||
chain := m.ChainFor("trainer", "") // not in skills map
|
||||
require.Len(t, chain, 2)
|
||||
assert.Equal(t, "ollama/qwen3-coder-30b-tuned", chain[0])
|
||||
assert.Equal(t, "claude-sonnet-4-6", chain[1])
|
||||
}
|
||||
|
||||
func TestModelsChainForCallerOverride(t *testing.T) {
|
||||
m, err := config.LoadModels(writeModels(t, testYAML))
|
||||
require.NoError(t, err)
|
||||
|
||||
chain := m.ChainFor("review", "claude-opus-4-6")
|
||||
require.Len(t, chain, 1)
|
||||
assert.Equal(t, "claude-opus-4-6", chain[0])
|
||||
assert.Equal(t, "claude-opus-4-6", m.ModelFor("review", "claude-opus-4-6"))
|
||||
}
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
package exec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config holds executor configuration.
|
||||
type Config struct {
|
||||
ClaudeBinary string // path to claude binary, defaults to "claude"
|
||||
SystemPrompt string // contents of supervisor CLAUDE.md
|
||||
Timeout time.Duration // per-invocation timeout, default 120s
|
||||
LiteLLMBaseURL string // passed to Claude so it can delegate to Ollama
|
||||
LiteLLMAPIKey string // passed to Claude for LiteLLM auth
|
||||
}
|
||||
|
||||
// Request is the input to a single supervisor invocation.
|
||||
type Request struct {
|
||||
SkillPrompt string // skill-specific discipline (e.g. tdd.md contents)
|
||||
TaskPrompt string // the specific task (phase, project_root, spec, model)
|
||||
Model string // resolved model name, passed in task prompt
|
||||
Tools string // comma-separated allowed tools, default "Bash,Read,Write"
|
||||
}
|
||||
|
||||
// Executor spawns a claude instance and captures its structured JSON output.
|
||||
type Executor struct {
|
||||
cfg Config
|
||||
}
|
||||
|
||||
func New(cfg Config) *Executor {
|
||||
if cfg.ClaudeBinary == "" {
|
||||
cfg.ClaudeBinary = "claude"
|
||||
}
|
||||
if cfg.Timeout == 0 {
|
||||
cfg.Timeout = 120 * time.Second
|
||||
}
|
||||
return &Executor{cfg: cfg}
|
||||
}
|
||||
|
||||
func (e *Executor) Run(ctx context.Context, req Request) (Result, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, e.cfg.Timeout)
|
||||
defer cancel()
|
||||
|
||||
tools := req.Tools
|
||||
if tools == "" {
|
||||
tools = "Bash,Read,Write"
|
||||
}
|
||||
|
||||
// Build the full prompt: system rules + skill rules + infra context + task.
|
||||
// LITELLM_API_KEY is injected as a subprocess env var, not in the prompt,
|
||||
// to prevent it appearing in error log output.
|
||||
litellmCtx := fmt.Sprintf("LITELLM_BASE_URL: %s", e.cfg.LiteLLMBaseURL)
|
||||
prompt := strings.Join([]string{
|
||||
e.cfg.SystemPrompt,
|
||||
"---",
|
||||
req.SkillPrompt,
|
||||
"---",
|
||||
litellmCtx,
|
||||
"---",
|
||||
req.TaskPrompt,
|
||||
}, "\n\n")
|
||||
|
||||
args := []string{
|
||||
"--print",
|
||||
"--permission-mode", "bypassPermissions",
|
||||
"--tools", tools,
|
||||
"--json-schema", Schema,
|
||||
"--output-format", "json",
|
||||
}
|
||||
if strings.HasPrefix(req.Model, "claude-") {
|
||||
args = append(args, "--model", req.Model)
|
||||
}
|
||||
args = append(args, prompt)
|
||||
|
||||
cmd := exec.CommandContext(ctx, e.cfg.ClaudeBinary, args...)
|
||||
cmd.Env = append(os.Environ(), "LITELLM_API_KEY="+e.cfg.LiteLLMAPIKey)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return Result{}, fmt.Errorf("timeout after %s", e.cfg.Timeout)
|
||||
}
|
||||
return Result{}, fmt.Errorf("claude exited with error: %w — stderr: %s", err, stderr.String())
|
||||
}
|
||||
|
||||
// --output-format json wraps the response in an envelope; structured output
|
||||
// from --json-schema is in the "structured_output" field.
|
||||
var envelope struct {
|
||||
StructuredOutput *Result `json:"structured_output"`
|
||||
IsError bool `json:"is_error"`
|
||||
Result string `json:"result"` // fallback text result for error messages
|
||||
}
|
||||
if err := json.Unmarshal(stdout.Bytes(), &envelope); err != nil {
|
||||
return Result{}, fmt.Errorf("parse envelope JSON: %w — raw: %s — stderr: %s", err, stdout.String(), stderr.String())
|
||||
}
|
||||
if envelope.StructuredOutput == nil {
|
||||
return Result{}, fmt.Errorf("no structured_output in response — result: %s — stderr: %s", envelope.Result, stderr.String())
|
||||
}
|
||||
if err := envelope.StructuredOutput.Validate(); err != nil {
|
||||
return Result{}, fmt.Errorf("invalid result: %w", err)
|
||||
}
|
||||
return *envelope.StructuredOutput, nil
|
||||
}
|
||||
@@ -1,132 +0,0 @@
|
||||
package exec_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeClaudePath writes a shell script that prints fixed output and returns its path.
|
||||
func fakeClaudePath(t *testing.T, output string, exitCode int) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "claude")
|
||||
var content string
|
||||
if exitCode != 0 {
|
||||
content = "#!/bin/sh\necho 'error' >&2\nexit 1\n"
|
||||
} else {
|
||||
content = "#!/bin/sh\necho '" + output + "'\n"
|
||||
}
|
||||
require.NoError(t, os.WriteFile(script, []byte(content), 0755))
|
||||
return script
|
||||
}
|
||||
|
||||
func TestExecutorParsesValidResult(t *testing.T) {
|
||||
// Fake claude emits the --output-format json envelope that the real CLI produces.
|
||||
// The executor extracts the result from the "structured_output" field.
|
||||
envelope := `{"type":"result","subtype":"success","is_error":false,"structured_output":{"status":"pass","phase":"red","skill":"tdd","file_path":"/tmp/x_test.go","runner_output":"FAIL","verified":true,"model_used":"self","message":"ok"}}`
|
||||
claude := fakeClaudePath(t, envelope, 0)
|
||||
|
||||
ex := iexec.New(iexec.Config{
|
||||
ClaudeBinary: claude,
|
||||
SystemPrompt: "you are a supervisor",
|
||||
Timeout: 5 * time.Second,
|
||||
})
|
||||
|
||||
result, err := ex.Run(context.Background(), iexec.Request{
|
||||
SkillPrompt: "tdd rules",
|
||||
TaskPrompt: "run red phase",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pass", result.Status)
|
||||
assert.True(t, result.Verified)
|
||||
}
|
||||
|
||||
func TestExecutorReturnsErrorOnNonZeroExit(t *testing.T) {
|
||||
claude := fakeClaudePath(t, "", 1)
|
||||
|
||||
ex := iexec.New(iexec.Config{
|
||||
ClaudeBinary: claude,
|
||||
SystemPrompt: "you are a supervisor",
|
||||
Timeout: 5 * time.Second,
|
||||
})
|
||||
|
||||
_, err := ex.Run(context.Background(), iexec.Request{TaskPrompt: "fail"})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestExecutorTimesOut(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "claude")
|
||||
require.NoError(t, os.WriteFile(script, []byte("#!/bin/sh\nsleep 60\n"), 0755))
|
||||
|
||||
ex := iexec.New(iexec.Config{
|
||||
ClaudeBinary: script,
|
||||
SystemPrompt: "you are a supervisor",
|
||||
Timeout: 100 * time.Millisecond,
|
||||
})
|
||||
|
||||
_, err := ex.Run(context.Background(), iexec.Request{TaskPrompt: "slow"})
|
||||
assert.ErrorContains(t, err, "timeout")
|
||||
}
|
||||
|
||||
func TestExecutorPassesModelFlagForCloudModel(t *testing.T) {
|
||||
// The script captures its args to a temp file so we can assert --model was passed.
|
||||
argsFile := filepath.Join(t.TempDir(), "args.txt")
|
||||
envelope := `{"type":"result","subtype":"success","is_error":false,"structured_output":{"status":"pass","phase":"review","skill":"review","file_path":"","runner_output":"","verified":true,"model_used":"claude-sonnet-4-6","message":"ok"}}`
|
||||
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "claude")
|
||||
content := "#!/bin/sh\necho \"$@\" > " + argsFile + "\necho '" + envelope + "'\n"
|
||||
require.NoError(t, os.WriteFile(script, []byte(content), 0755))
|
||||
|
||||
ex := iexec.New(iexec.Config{
|
||||
ClaudeBinary: script,
|
||||
SystemPrompt: "sys",
|
||||
Timeout: 5 * time.Second,
|
||||
})
|
||||
|
||||
_, err := ex.Run(context.Background(), iexec.Request{
|
||||
SkillPrompt: "review rules",
|
||||
TaskPrompt: "do review",
|
||||
Model: "claude-sonnet-4-6",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
argsData, err := os.ReadFile(argsFile)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(argsData), "--model claude-sonnet-4-6")
|
||||
}
|
||||
|
||||
func TestExecutorSkipsModelFlagForLocalModel(t *testing.T) {
|
||||
argsFile := filepath.Join(t.TempDir(), "args.txt")
|
||||
envelope := `{"type":"result","subtype":"success","is_error":false,"structured_output":{"status":"pass","phase":"review","skill":"review","file_path":"","runner_output":"","verified":true,"model_used":"ollama/devstral","message":"ok"}}`
|
||||
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "claude")
|
||||
content := "#!/bin/sh\necho \"$@\" > " + argsFile + "\necho '" + envelope + "'\n"
|
||||
require.NoError(t, os.WriteFile(script, []byte(content), 0755))
|
||||
|
||||
ex := iexec.New(iexec.Config{
|
||||
ClaudeBinary: script,
|
||||
SystemPrompt: "sys",
|
||||
Timeout: 5 * time.Second,
|
||||
})
|
||||
|
||||
_, err := ex.Run(context.Background(), iexec.Request{
|
||||
SkillPrompt: "review rules",
|
||||
TaskPrompt: "do review",
|
||||
Model: "ollama/devstral",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
argsData, err := os.ReadFile(argsFile)
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, string(argsData), "--model")
|
||||
}
|
||||
@@ -9,9 +9,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// LiteLLMExecutor calls a LiteLLM-compatible /v1/chat/completions endpoint.
|
||||
// Local models are expected to return a JSON object matching the Result schema
|
||||
// as their response content — no envelope.
|
||||
// LiteLLMExecutor calls a LiteLLM-compatible /v1/chat/completions endpoint
|
||||
// and returns the raw assistant message text.
|
||||
type LiteLLMExecutor struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
@@ -21,6 +20,9 @@ type LiteLLMExecutor struct {
|
||||
// NewLiteLLM creates a LiteLLMExecutor.
|
||||
// timeout applies to the full HTTP round-trip per call; a zero timeout defaults to 120 seconds.
|
||||
func NewLiteLLM(baseURL, apiKey string, timeout time.Duration) *LiteLLMExecutor {
|
||||
if timeout == 0 {
|
||||
timeout = 120 * time.Second
|
||||
}
|
||||
return &LiteLLMExecutor{
|
||||
baseURL: baseURL,
|
||||
apiKey: apiKey,
|
||||
@@ -46,58 +48,50 @@ type litellmResponse struct {
|
||||
Choices []litellmChoice `json:"choices"`
|
||||
}
|
||||
|
||||
// Run dispatches req to the LiteLLM server and parses the Result from the
|
||||
// assistant message content. Returns an error on network failure, non-200
|
||||
// status, or unparseable/invalid JSON — all of which the Orchestrator treats
|
||||
// as automatic escalation triggers.
|
||||
func (e *LiteLLMExecutor) Run(ctx context.Context, req Request) (Result, error) {
|
||||
// Complete sends system+user messages to the given model and returns the raw
|
||||
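// assistant text along with the request duration in milliseconds, measured from just before the request is sent until the HTTP client returns the response (decoding the body is not included).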
|
||||
func (e *LiteLLMExecutor) Complete(ctx context.Context, model, system, user string) (string, int64, error) {
|
||||
body := litellmRequest{
|
||||
Model: req.Model,
|
||||
Model: model,
|
||||
Messages: []litellmMessage{
|
||||
{Role: "system", Content: req.SkillPrompt},
|
||||
{Role: "user", Content: req.TaskPrompt},
|
||||
{Role: "system", Content: system},
|
||||
{Role: "user", Content: user},
|
||||
},
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return Result{}, fmt.Errorf("litellm: marshal request: %w", err)
|
||||
return "", 0, fmt.Errorf("litellm: marshal request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, e.baseURL+"/v1/chat/completions", bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return Result{}, fmt.Errorf("litellm: create request: %w", err)
|
||||
return "", 0, fmt.Errorf("litellm: create request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if e.apiKey != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+e.apiKey)
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
resp, err := e.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return Result{}, fmt.Errorf("litellm: request failed: %w", err)
|
||||
return "", 0, fmt.Errorf("litellm: request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close() //nolint:errcheck
|
||||
durationMs := time.Since(t0).Milliseconds()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return Result{}, fmt.Errorf("litellm: server returned status %d", resp.StatusCode)
|
||||
return "", 0, fmt.Errorf("litellm: server returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var chatResp litellmResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||||
return Result{}, fmt.Errorf("litellm: decode response: %w", err)
|
||||
return "", 0, fmt.Errorf("litellm: decode response: %w", err)
|
||||
}
|
||||
if len(chatResp.Choices) == 0 {
|
||||
return Result{}, fmt.Errorf("litellm: no choices in response")
|
||||
return "", 0, fmt.Errorf("litellm: no choices in response")
|
||||
}
|
||||
|
||||
content := chatResp.Choices[0].Message.Content
|
||||
var result Result
|
||||
if err := json.Unmarshal([]byte(content), &result); err != nil {
|
||||
return Result{}, fmt.Errorf("litellm: parse result JSON: %w — content: %s", err, content)
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return Result{}, fmt.Errorf("litellm: invalid result: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
return chatResp.Choices[0].Message.Content, durationMs, nil
|
||||
}
|
||||
|
||||
@@ -13,23 +13,11 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func validLiteLLMResult() iexec.Result {
|
||||
return iexec.Result{
|
||||
Status: "pass",
|
||||
Phase: "review",
|
||||
Skill: "review",
|
||||
ModelUsed: "ollama/devstral",
|
||||
Message: "looks good",
|
||||
}
|
||||
}
|
||||
|
||||
func chatResponseFor(t *testing.T, result iexec.Result) []byte {
|
||||
func chatResponse(t *testing.T, content string) []byte {
|
||||
t.Helper()
|
||||
content, err := json.Marshal(result)
|
||||
require.NoError(t, err)
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{"message": map[string]any{"role": "assistant", "content": string(content)}},
|
||||
{"message": map[string]any{"role": "assistant", "content": content}},
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(resp)
|
||||
@@ -37,25 +25,21 @@ func chatResponseFor(t *testing.T, result iexec.Result) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func TestLiteLLMParsesValidResult(t *testing.T) {
|
||||
func TestLiteLLMReturnsText(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/v1/chat/completions", r.URL.Path)
|
||||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(chatResponseFor(t, validLiteLLMResult()))
|
||||
_, _ = w.Write(chatResponse(t, "here is my analysis"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
ex := iexec.NewLiteLLM(srv.URL, "", 5*time.Second)
|
||||
result, err := ex.Run(context.Background(), iexec.Request{
|
||||
SkillPrompt: "review rules",
|
||||
TaskPrompt: "review the code",
|
||||
Model: "ollama/devstral",
|
||||
})
|
||||
text, dur, err := ex.Complete(context.Background(), "ollama/devstral", "system prompt", "user prompt")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pass", result.Status)
|
||||
assert.Equal(t, "review", result.Skill)
|
||||
assert.Equal(t, "here is my analysis", text)
|
||||
assert.GreaterOrEqual(t, dur, int64(0))
|
||||
}
|
||||
|
||||
func TestLiteLLMSendsAuthHeader(t *testing.T) {
|
||||
@@ -63,12 +47,12 @@ func TestLiteLLMSendsAuthHeader(t *testing.T) {
|
||||
assert.Equal(t, "Bearer secret", r.Header.Get("Authorization"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(chatResponseFor(t, validLiteLLMResult()))
|
||||
_, _ = w.Write(chatResponse(t, "ok"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
ex := iexec.NewLiteLLM(srv.URL, "secret", 5*time.Second)
|
||||
_, err := ex.Run(context.Background(), iexec.Request{Model: "x", TaskPrompt: "t", SkillPrompt: "s"})
|
||||
_, _, err := ex.Complete(context.Background(), "model", "sys", "user")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -79,34 +63,28 @@ func TestLiteLLMErrorOnNonOKStatus(t *testing.T) {
|
||||
defer srv.Close()
|
||||
|
||||
ex := iexec.NewLiteLLM(srv.URL, "", 5*time.Second)
|
||||
_, err := ex.Run(context.Background(), iexec.Request{Model: "x", TaskPrompt: "t"})
|
||||
_, _, err := ex.Complete(context.Background(), "model", "sys", "user")
|
||||
assert.ErrorContains(t, err, "503")
|
||||
}
|
||||
|
||||
func TestLiteLLMErrorOnUnparsableJSON(t *testing.T) {
|
||||
func TestLiteLLMErrorOnEmptyChoices(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{"message": map[string]any{"role": "assistant", "content": "not json at all"}},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
_, _ = w.Write(data)
|
||||
_, _ = w.Write([]byte(`{"choices":[]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
ex := iexec.NewLiteLLM(srv.URL, "", 5*time.Second)
|
||||
_, err := ex.Run(context.Background(), iexec.Request{Model: "x", TaskPrompt: "t"})
|
||||
assert.Error(t, err)
|
||||
_, _, err := ex.Complete(context.Background(), "model", "sys", "user")
|
||||
assert.ErrorContains(t, err, "no choices")
|
||||
}
|
||||
|
||||
func TestLiteLLMRespectsContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
cancel()
|
||||
|
||||
ex := iexec.NewLiteLLM("http://invalid.example.com", "", 1*time.Second)
|
||||
_, err := ex.Run(ctx, iexec.Request{Model: "x", TaskPrompt: "t"})
|
||||
_, _, err := ex.Complete(ctx, "model", "sys", "user")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -1,197 +0,0 @@
|
||||
package exec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ChainEntry is one tier in an escalation chain.
|
||||
type ChainEntry struct {
|
||||
Model string // e.g. "ollama/phi4", "claude-sonnet-4-6"
|
||||
Tier string // "local" | "subagent" | "managed"
|
||||
IsCloud bool // true for claude-* models; skips verifier call
|
||||
}
|
||||
|
||||
// EntryFor builds a ChainEntry from a model name string.
|
||||
func EntryFor(model string) ChainEntry {
|
||||
cloud := strings.HasPrefix(model, "claude-")
|
||||
tier := "local"
|
||||
if cloud {
|
||||
tier = "subagent"
|
||||
}
|
||||
return ChainEntry{Model: model, Tier: tier, IsCloud: cloud}
|
||||
}
|
||||
|
||||
// AttemptRecord captures the outcome of one tier attempt for session logging.
|
||||
type AttemptRecord struct {
|
||||
Model string
|
||||
Tier string
|
||||
DurationMs int64
|
||||
WarmStart bool
|
||||
Verdict string // "accept" | "escalate" | "error"
|
||||
Feedback string
|
||||
}
|
||||
|
||||
// VerifierFn is the interface the orchestrator uses to verify local output.
|
||||
type VerifierFn interface {
|
||||
Verify(ctx context.Context, skillPrompt, taskPrompt string, output Result) (Verdict, error)
|
||||
}
|
||||
|
||||
// ExecutorRunFn is the signature of Executor.Run and LiteLLMExecutor.Run.
|
||||
type ExecutorRunFn func(ctx context.Context, req Request) (Result, error)
|
||||
|
||||
// Orchestrator walks an escalation chain, delegating generation and verification.
|
||||
// It implements the ExecutorFn shape expected by skill handlers.
|
||||
type Orchestrator struct {
|
||||
chain []ChainEntry
|
||||
localRun ExecutorRunFn // for local (non-cloud) tiers; may be nil
|
||||
cloudRun ExecutorRunFn // for cloud tiers; may be nil
|
||||
verifier VerifierFn
|
||||
llamaSwapURL string
|
||||
attempts *[]AttemptRecord
|
||||
}
|
||||
|
||||
// NewOrchestrator creates an Orchestrator.
|
||||
// attempts is a pointer to a slice that will be appended to on each tier attempt.
|
||||
// Pass nil for localRun or cloudRun if no tiers of that type exist in the chain.
|
||||
func NewOrchestrator(
|
||||
chain []ChainEntry,
|
||||
localRun ExecutorRunFn,
|
||||
cloudRun ExecutorRunFn,
|
||||
verifier VerifierFn,
|
||||
llamaSwapURL string,
|
||||
attempts *[]AttemptRecord,
|
||||
) *Orchestrator {
|
||||
return &Orchestrator{
|
||||
chain: chain,
|
||||
localRun: localRun,
|
||||
cloudRun: cloudRun,
|
||||
verifier: verifier,
|
||||
llamaSwapURL: llamaSwapURL,
|
||||
attempts: attempts,
|
||||
}
|
||||
}
|
||||
|
||||
// Run walks the escalation chain and returns the first accepted result.
|
||||
// Satisfies the ExecutorFn signature: func(context.Context, Request) (Result, error).
|
||||
func (o *Orchestrator) Run(ctx context.Context, req Request) (Result, error) {
|
||||
taskPrompt := req.TaskPrompt
|
||||
|
||||
for _, entry := range o.chain {
|
||||
warm := o.probeWarm(entry.Model)
|
||||
start := time.Now()
|
||||
|
||||
tierReq := req
|
||||
tierReq.Model = entry.Model
|
||||
tierReq.TaskPrompt = taskPrompt
|
||||
|
||||
if entry.IsCloud {
|
||||
result, genErr := o.cloudRun(ctx, tierReq)
|
||||
dur := time.Since(start).Milliseconds()
|
||||
verdict := "accept"
|
||||
if genErr != nil {
|
||||
verdict = "error"
|
||||
}
|
||||
o.appendAttempt(AttemptRecord{
|
||||
Model: entry.Model,
|
||||
Tier: entry.Tier,
|
||||
DurationMs: dur,
|
||||
WarmStart: warm,
|
||||
Verdict: verdict,
|
||||
})
|
||||
if genErr == nil {
|
||||
return result, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Local tier.
|
||||
result, genErr := o.localRun(ctx, tierReq)
|
||||
dur := time.Since(start).Milliseconds()
|
||||
|
||||
if genErr != nil {
|
||||
o.appendAttempt(AttemptRecord{
|
||||
Model: entry.Model,
|
||||
Tier: entry.Tier,
|
||||
DurationMs: dur,
|
||||
WarmStart: warm,
|
||||
Verdict: "error",
|
||||
Feedback: genErr.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
verdict, verErr := o.verifier.Verify(ctx, req.SkillPrompt, taskPrompt, result)
|
||||
if verErr != nil {
|
||||
// Treat verifier failure as escalate (safe default).
|
||||
o.appendAttempt(AttemptRecord{
|
||||
Model: entry.Model,
|
||||
Tier: entry.Tier,
|
||||
DurationMs: dur,
|
||||
WarmStart: warm,
|
||||
Verdict: "escalate",
|
||||
Feedback: "verifier error: " + verErr.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if verdict.Accept {
|
||||
o.appendAttempt(AttemptRecord{
|
||||
Model: entry.Model,
|
||||
Tier: entry.Tier,
|
||||
DurationMs: dur,
|
||||
WarmStart: warm,
|
||||
Verdict: "accept",
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
o.appendAttempt(AttemptRecord{
|
||||
Model: entry.Model,
|
||||
Tier: entry.Tier,
|
||||
DurationMs: dur,
|
||||
WarmStart: warm,
|
||||
Verdict: "escalate",
|
||||
Feedback: verdict.Feedback,
|
||||
})
|
||||
// Inject verifier feedback into the next tier's task prompt.
|
||||
taskPrompt = taskPrompt + "\n\nPrior attempt feedback: " + verdict.Feedback
|
||||
}
|
||||
|
||||
return Result{}, fmt.Errorf("all tiers exhausted after %d attempt(s)", len(o.chain))
|
||||
}
|
||||
|
||||
func (o *Orchestrator) appendAttempt(rec AttemptRecord) {
|
||||
if o.attempts != nil {
|
||||
*o.attempts = append(*o.attempts, rec)
|
||||
}
|
||||
}
|
||||
|
||||
// probeWarm checks whether the model is currently loaded in llama-swap.
|
||||
// Returns false on any error or if llamaSwapURL is empty.
|
||||
func (o *Orchestrator) probeWarm(model string) bool {
|
||||
if o.llamaSwapURL == "" {
|
||||
return false
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, o.llamaSwapURL+"/v1/models", nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close() //nolint:errcheck
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(string(body), model)
|
||||
}
|
||||
@@ -1,151 +0,0 @@
|
||||
package exec_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// stubRunFn returns preset results sequentially.
|
||||
type stubRunFn struct {
|
||||
calls []stubCall
|
||||
callIdx int
|
||||
}
|
||||
|
||||
type stubCall struct {
|
||||
result iexec.Result
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubRunFn) Run(_ context.Context, _ iexec.Request) (iexec.Result, error) {
|
||||
if s.callIdx >= len(s.calls) {
|
||||
return iexec.Result{}, errors.New("unexpected call")
|
||||
}
|
||||
c := s.calls[s.callIdx]
|
||||
s.callIdx++
|
||||
return c.result, c.err
|
||||
}
|
||||
|
||||
// stubVerifier returns preset verdicts sequentially.
|
||||
type stubVerifier struct {
|
||||
verdicts []iexec.Verdict
|
||||
idx int
|
||||
}
|
||||
|
||||
func (s *stubVerifier) Verify(_ context.Context, _, _ string, _ iexec.Result) (iexec.Verdict, error) {
|
||||
if s.idx >= len(s.verdicts) {
|
||||
return iexec.Verdict{}, errors.New("unexpected verify call")
|
||||
}
|
||||
v := s.verdicts[s.idx]
|
||||
s.idx++
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func okResult(skill string) iexec.Result {
|
||||
return iexec.Result{Status: "pass", Phase: "review", Skill: skill, Message: "ok", ModelUsed: "m"}
|
||||
}
|
||||
|
||||
func TestOrchestratorSingleLocalAccept(t *testing.T) {
|
||||
local := &stubRunFn{calls: []stubCall{{result: okResult("review")}}}
|
||||
verifier := &stubVerifier{verdicts: []iexec.Verdict{{Accept: true}}}
|
||||
|
||||
var attempts []iexec.AttemptRecord
|
||||
orch := iexec.NewOrchestrator(
|
||||
[]iexec.ChainEntry{{Model: "ollama/devstral", Tier: "local", IsCloud: false}},
|
||||
local.Run, nil, verifier, "", &attempts,
|
||||
)
|
||||
|
||||
result, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pass", result.Status)
|
||||
require.Len(t, attempts, 1)
|
||||
assert.Equal(t, "local", attempts[0].Tier)
|
||||
assert.Equal(t, "accept", attempts[0].Verdict)
|
||||
}
|
||||
|
||||
func TestOrchestratorEscalatesOnVerifierReject(t *testing.T) {
|
||||
local := &stubRunFn{calls: []stubCall{
|
||||
{result: iexec.Result{Status: "fail", Phase: "review", Skill: "review", Message: "weak"}},
|
||||
{result: okResult("review")},
|
||||
}}
|
||||
verifier := &stubVerifier{verdicts: []iexec.Verdict{
|
||||
{Accept: false, Feedback: "missing line refs"},
|
||||
{Accept: true},
|
||||
}}
|
||||
|
||||
var attempts []iexec.AttemptRecord
|
||||
orch := iexec.NewOrchestrator(
|
||||
[]iexec.ChainEntry{
|
||||
{Model: "ollama/devstral", Tier: "local", IsCloud: false},
|
||||
{Model: "ollama/gemma4", Tier: "local", IsCloud: false},
|
||||
},
|
||||
local.Run, nil, verifier, "", &attempts,
|
||||
)
|
||||
|
||||
result, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pass", result.Status)
|
||||
require.Len(t, attempts, 2)
|
||||
assert.Equal(t, "escalate", attempts[0].Verdict)
|
||||
assert.Equal(t, "missing line refs", attempts[0].Feedback)
|
||||
assert.Equal(t, "accept", attempts[1].Verdict)
|
||||
}
|
||||
|
||||
func TestOrchestratorEscalatesOnLocalError(t *testing.T) {
|
||||
local := &stubRunFn{calls: []stubCall{
|
||||
{err: errors.New("network failure")},
|
||||
{result: okResult("review")},
|
||||
}}
|
||||
verifier := &stubVerifier{verdicts: []iexec.Verdict{{Accept: true}}}
|
||||
|
||||
var attempts []iexec.AttemptRecord
|
||||
orch := iexec.NewOrchestrator(
|
||||
[]iexec.ChainEntry{
|
||||
{Model: "ollama/devstral", Tier: "local", IsCloud: false},
|
||||
{Model: "ollama/gemma4", Tier: "local", IsCloud: false},
|
||||
},
|
||||
local.Run, nil, verifier, "", &attempts,
|
||||
)
|
||||
|
||||
_, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, attempts, 2)
|
||||
assert.Equal(t, "error", attempts[0].Verdict)
|
||||
assert.Equal(t, "accept", attempts[1].Verdict)
|
||||
}
|
||||
|
||||
func TestOrchestratorCloudTierSelfCertifies(t *testing.T) {
|
||||
cloud := &stubRunFn{calls: []stubCall{{result: okResult("review")}}}
|
||||
verifier := &stubVerifier{} // no verdicts — must not be called
|
||||
|
||||
var attempts []iexec.AttemptRecord
|
||||
orch := iexec.NewOrchestrator(
|
||||
[]iexec.ChainEntry{{Model: "claude-sonnet-4-6", Tier: "subagent", IsCloud: true}},
|
||||
nil, cloud.Run, verifier, "", &attempts,
|
||||
)
|
||||
|
||||
result, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pass", result.Status)
|
||||
require.Len(t, attempts, 1)
|
||||
assert.Equal(t, "subagent", attempts[0].Tier)
|
||||
assert.Equal(t, "accept", attempts[0].Verdict)
|
||||
assert.Equal(t, 0, verifier.idx) // verifier never called
|
||||
}
|
||||
|
||||
func TestOrchestratorAllTiersExhausted(t *testing.T) {
|
||||
local := &stubRunFn{calls: []stubCall{{err: errors.New("unavailable")}}}
|
||||
|
||||
var attempts []iexec.AttemptRecord
|
||||
orch := iexec.NewOrchestrator(
|
||||
[]iexec.ChainEntry{{Model: "ollama/devstral", Tier: "local", IsCloud: false}},
|
||||
local.Run, nil, &stubVerifier{}, "", &attempts,
|
||||
)
|
||||
|
||||
_, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"})
|
||||
assert.ErrorContains(t, err, "all tiers exhausted")
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package exec
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Result is the structured JSON output from every supervisor invocation.
|
||||
// The JSON schema constant is passed to claude via --json-schema so Claude
|
||||
// validates its own output before returning.
|
||||
type Result struct {
|
||||
Status string `json:"status"` // pass | fail | error
|
||||
Phase string `json:"phase"` // red | green | refactor | retrospective | review | debug | spec | trainer
|
||||
Skill string `json:"skill"` // tdd | review | ...
|
||||
FilePath string `json:"file_path"` // absolute path to generated file
|
||||
RunnerOutput string `json:"runner_output"` // raw stdout+stderr from test runner
|
||||
Verified bool `json:"verified"` // based on exit code, never self-report
|
||||
ModelUsed string `json:"model_used"` // model name or "self"
|
||||
Message string `json:"message"` // one sentence summary
|
||||
Attempts []AttemptRecord `json:"attempts,omitempty"` // populated by orchestrator, not Claude
|
||||
}
|
||||
|
||||
var validStatuses = map[string]bool{"pass": true, "fail": true, "error": true}
|
||||
var validPhases = map[string]bool{
|
||||
"red": true,
|
||||
"green": true,
|
||||
"refactor": true,
|
||||
"retrospective": true,
|
||||
"review": true,
|
||||
"debug": true,
|
||||
"spec": true,
|
||||
"trainer": true,
|
||||
}
|
||||
|
||||
func (r Result) Validate() error {
|
||||
var errs []string
|
||||
if !validStatuses[r.Status] {
|
||||
errs = append(errs, "status must be pass|fail|error, got: "+r.Status)
|
||||
}
|
||||
if !validPhases[r.Phase] {
|
||||
errs = append(errs, "phase must be one of red|green|refactor|retrospective|review|debug|spec|trainer, got: "+r.Phase)
|
||||
}
|
||||
if r.Skill == "" {
|
||||
errs = append(errs, "skill is required")
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return errors.New(strings.Join(errs, "; "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Schema is passed to claude --json-schema to enforce structured output.
//
// The status and phase enums mirror the values Result.Validate accepts, so
// output that passes the schema is not later rejected for an unknown phase.
// "attempts" is deliberately absent: it is filled in by the orchestrator,
// not by Claude.
const Schema = `{
  "type": "object",
  "required": ["status","phase","skill","file_path","runner_output","verified","model_used","message"],
  "properties": {
    "status":        {"type": "string", "enum": ["pass","fail","error"]},
    "phase":         {"type": "string", "enum": ["red","green","refactor","retrospective","review","debug","spec","trainer"]},
    "skill":         {"type": "string"},
    "file_path":     {"type": "string"},
    "runner_output": {"type": "string"},
    "verified":      {"type": "boolean"},
    "model_used":    {"type": "string"},
    "message":       {"type": "string"}
  }
}`
|
||||
@@ -1,79 +0,0 @@
|
||||
package exec_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestResultParsesValidJSON(t *testing.T) {
|
||||
raw := `{
|
||||
"status": "pass",
|
||||
"phase": "red",
|
||||
"skill": "tdd",
|
||||
"file_path": "/tmp/foo_test.go",
|
||||
"runner_output": "--- FAIL: TestFoo",
|
||||
"verified": true,
|
||||
"model_used": "self",
|
||||
"message": "test fails as expected"
|
||||
}`
|
||||
var r exec.Result
|
||||
require.NoError(t, json.Unmarshal([]byte(raw), &r))
|
||||
assert.Equal(t, "pass", r.Status)
|
||||
assert.Equal(t, "red", r.Phase)
|
||||
assert.True(t, r.Verified)
|
||||
}
|
||||
|
||||
func TestResultValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result exec.Result
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid pass result",
|
||||
result: exec.Result{
|
||||
Status: "pass", Phase: "red", Skill: "tdd",
|
||||
FilePath: "/tmp/x_test.go", RunnerOutput: "FAIL",
|
||||
Verified: true, ModelUsed: "self", Message: "ok",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty status",
|
||||
result: exec.Result{Phase: "red", Skill: "tdd"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid status",
|
||||
result: exec.Result{Status: "unknown", Phase: "red", Skill: "tdd"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid phase",
|
||||
result: exec.Result{Status: "pass", Phase: "bad", Skill: "tdd"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.result.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAcceptsAllPhases(t *testing.T) {
|
||||
phases := []string{"red", "green", "refactor", "retrospective", "review", "debug", "spec", "trainer"}
|
||||
for _, phase := range phases {
|
||||
r := exec.Result{Status: "pass", Phase: phase, Skill: "test", ModelUsed: "self", Message: "ok"}
|
||||
assert.NoError(t, r.Validate(), "phase %q should be valid", phase)
|
||||
}
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
package exec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Verdict is the output of a Claude verification call.
type Verdict struct {
	Accept   bool   `json:"accept"`
	Feedback string `json:"feedback"` // empty when Accept is true
}

// Verifier runs a focused Claude call to judge local model output.
type Verifier struct {
	claudeBinary string        // executable to invoke
	model        string        // passed as --model when non-empty
	timeout      time.Duration // upper bound for one Verify call
}

// NewVerifier creates a Verifier that calls claude with the given binary
// path and model.
//
// An empty claudeBinary defaults to "claude". A zero or negative timeout
// defaults to 30s; a negative value would otherwise produce a context that
// has already expired, making every Verify call fail immediately.
func NewVerifier(claudeBinary, model string, timeout time.Duration) *Verifier {
	if claudeBinary == "" {
		claudeBinary = "claude"
	}
	if timeout <= 0 {
		timeout = 30 * time.Second
	}
	return &Verifier{
		claudeBinary: claudeBinary,
		model:        model,
		timeout:      timeout,
	}
}
|
||||
|
||||
// Verify asks Claude whether output satisfies the skill discipline's iron laws.
|
||||
// Returns Verdict{Accept: true} to accept or Verdict{Accept: false, Feedback: "..."}
|
||||
// to escalate. Returns an error on subprocess failure or unparseable response.
|
||||
func (v *Verifier) Verify(ctx context.Context, skillPrompt, taskPrompt string, output Result) (Verdict, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, v.timeout)
|
||||
defer cancel()
|
||||
|
||||
outputJSON, err := json.Marshal(output)
|
||||
if err != nil {
|
||||
return Verdict{}, fmt.Errorf("verifier: marshal output: %w", err)
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(`You are a quality verifier for an AI supervisor system.
|
||||
|
||||
Given the skill discipline, the original task, and the generated output, decide whether the output satisfies the discipline's iron laws and output contract.
|
||||
|
||||
Reply with JSON only — no other text:
|
||||
{"accept": true, "feedback": ""}
|
||||
or
|
||||
{"accept": false, "feedback": "<one sentence reason>"}
|
||||
|
||||
## Skill discipline
|
||||
%s
|
||||
|
||||
## Original task
|
||||
%s
|
||||
|
||||
## Generated output
|
||||
%s`, skillPrompt, taskPrompt, string(outputJSON))
|
||||
|
||||
args := []string{
|
||||
"--print",
|
||||
"--permission-mode", "bypassPermissions",
|
||||
}
|
||||
if v.model != "" {
|
||||
args = append(args, "--model", v.model)
|
||||
}
|
||||
args = append(args, prompt)
|
||||
|
||||
cmd := exec.CommandContext(ctx, v.claudeBinary, args...)
|
||||
cmd.Env = os.Environ()
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return Verdict{}, fmt.Errorf("verifier: timeout after %s", v.timeout)
|
||||
}
|
||||
return Verdict{}, fmt.Errorf("verifier: claude exited with error: %w — stderr: %s", err, stderr.String())
|
||||
}
|
||||
|
||||
var verdict Verdict
|
||||
if err := json.Unmarshal(bytes.TrimSpace(stdout.Bytes()), &verdict); err != nil {
|
||||
return Verdict{}, fmt.Errorf("verifier: parse verdict JSON: %w — raw: %s", err, stdout.String())
|
||||
}
|
||||
return verdict, nil
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
package exec_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func fakeVerifierClaude(t *testing.T, verdict iexec.Verdict) string {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(verdict)
|
||||
require.NoError(t, err)
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "claude")
|
||||
content := fmt.Sprintf("#!/bin/sh\necho '%s'\n", string(data))
|
||||
require.NoError(t, os.WriteFile(script, []byte(content), 0755))
|
||||
return script
|
||||
}
|
||||
|
||||
func TestVerifierAccepts(t *testing.T) {
|
||||
claude := fakeVerifierClaude(t, iexec.Verdict{Accept: true, Feedback: ""})
|
||||
v := iexec.NewVerifier(claude, "claude-sonnet-4-6", 5*time.Second)
|
||||
|
||||
verdict, err := v.Verify(context.Background(), "skill rules", "do the task", iexec.Result{
|
||||
Status: "pass", Phase: "review", Skill: "review", Message: "ok",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, verdict.Accept)
|
||||
assert.Empty(t, verdict.Feedback)
|
||||
}
|
||||
|
||||
func TestVerifierEscalates(t *testing.T) {
|
||||
claude := fakeVerifierClaude(t, iexec.Verdict{Accept: false, Feedback: "missing line references"})
|
||||
v := iexec.NewVerifier(claude, "claude-sonnet-4-6", 5*time.Second)
|
||||
|
||||
verdict, err := v.Verify(context.Background(), "skill rules", "do the task", iexec.Result{
|
||||
Status: "pass", Phase: "review", Skill: "review", Message: "incomplete",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, verdict.Accept)
|
||||
assert.Equal(t, "missing line references", verdict.Feedback)
|
||||
}
|
||||
|
||||
func TestVerifierErrorOnUnparsableOutput(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "claude")
|
||||
require.NoError(t, os.WriteFile(script, []byte("#!/bin/sh\necho 'not json'\n"), 0755))
|
||||
|
||||
v := iexec.NewVerifier(script, "claude-sonnet-4-6", 5*time.Second)
|
||||
_, err := v.Verify(context.Background(), "rules", "task", iexec.Result{
|
||||
Status: "pass", Phase: "review", Skill: "review", Message: "ok",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestVerifierErrorOnNonZeroExit(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "claude")
|
||||
require.NoError(t, os.WriteFile(script, []byte("#!/bin/sh\nexit 1\n"), 0755))
|
||||
|
||||
v := iexec.NewVerifier(script, "claude-sonnet-4-6", 5*time.Second)
|
||||
_, err := v.Verify(context.Background(), "rules", "task", iexec.Result{
|
||||
Status: "pass", Phase: "review", Skill: "review", Message: "ok",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
// internal/session/attempts.go
|
||||
package session
|
||||
|
||||
import iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
|
||||
// AttemptsFrom converts exec.AttemptRecord slice to session.Attempt slice
|
||||
// for writing into a session JSONL entry.
|
||||
func AttemptsFrom(records []iexec.AttemptRecord) []Attempt {
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]Attempt, len(records))
|
||||
for i, r := range records {
|
||||
out[i] = Attempt{
|
||||
Attempt: i + 1,
|
||||
Model: r.Model,
|
||||
Tier: r.Tier,
|
||||
DurationMs: r.DurationMs,
|
||||
WarmStart: r.WarmStart,
|
||||
Verdict: r.Verdict,
|
||||
Feedback: r.Feedback,
|
||||
Verified: r.Verdict == "accept",
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package session_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAttemptsFromEmpty(t *testing.T) {
|
||||
result := session.AttemptsFrom(nil)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestAttemptsFromSetsIndex(t *testing.T) {
|
||||
records := []exec.AttemptRecord{
|
||||
{Model: "ollama/phi4", Tier: "local", DurationMs: 1200, WarmStart: true, Verdict: "escalate", Feedback: "too vague"},
|
||||
{Model: "claude-sonnet-4-6", Tier: "subagent", DurationMs: 3400, WarmStart: false, Verdict: "accept"},
|
||||
}
|
||||
result := session.AttemptsFrom(records)
|
||||
require.Len(t, result, 2)
|
||||
|
||||
assert.Equal(t, 1, result[0].Attempt)
|
||||
assert.Equal(t, "ollama/phi4", result[0].Model)
|
||||
assert.Equal(t, "local", result[0].Tier)
|
||||
assert.Equal(t, int64(1200), result[0].DurationMs)
|
||||
assert.True(t, result[0].WarmStart)
|
||||
assert.Equal(t, "escalate", result[0].Verdict)
|
||||
assert.Equal(t, "too vague", result[0].Feedback)
|
||||
assert.False(t, result[0].Verified)
|
||||
|
||||
assert.Equal(t, 2, result[1].Attempt)
|
||||
assert.Equal(t, "claude-sonnet-4-6", result[1].Model)
|
||||
assert.True(t, result[1].Verified)
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mathiasbq/supervisor/internal/brain"
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
)
|
||||
|
||||
@@ -52,38 +51,32 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
task = brainCtx + "\n---\n\n" + task
|
||||
}
|
||||
|
||||
if s.cfg.ExecutorFn == nil {
|
||||
if s.cfg.CompleteFunc == nil {
|
||||
return nil, fmt.Errorf("no executor configured")
|
||||
}
|
||||
t0 := time.Now()
|
||||
result, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||
SkillPrompt: s.cfg.SkillPrompt,
|
||||
TaskPrompt: task,
|
||||
Model: model,
|
||||
Tools: "Read,Bash",
|
||||
})
|
||||
text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a.SessionID != "" && s.cfg.SessionsDir != "" {
|
||||
msg := text
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200]
|
||||
}
|
||||
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
|
||||
SessionID: a.SessionID,
|
||||
Timestamp: time.Now(),
|
||||
Skill: "debug",
|
||||
Phase: "debug",
|
||||
ProjectRoot: a.ProjectRoot,
|
||||
Attempts: session.AttemptsFrom(result.Attempts),
|
||||
FinalStatus: result.Status,
|
||||
ModelUsed: result.ModelUsed,
|
||||
FinalStatus: "ok",
|
||||
ModelUsed: model,
|
||||
DurationMs: time.Since(t0).Milliseconds(),
|
||||
Message: result.Message,
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
b, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal result: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/debug"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -33,29 +32,22 @@ func TestDebugRequiresError(t *testing.T) {
|
||||
assert.ErrorContains(t, err, "error")
|
||||
}
|
||||
|
||||
func TestDebugCallsExecutor(t *testing.T) {
|
||||
called := false
|
||||
func TestDebugCallsCompleteFunc(t *testing.T) {
|
||||
var capturedTask string
|
||||
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
called = true
|
||||
capturedTask = req.TaskPrompt
|
||||
return iexec.Result{
|
||||
Status: "pass", Phase: "debug", Skill: "debug",
|
||||
RunnerOutput: "HYPOTHESIS 1 (likelihood: high): nil map access\nVERIFY: go test ./... → expected: panic line reference",
|
||||
Verified: false, ModelUsed: "self", Message: "3 hypotheses for: panic nil pointer at foo.go:42",
|
||||
}, nil
|
||||
fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
|
||||
capturedTask = user
|
||||
return "HYPOTHESIS 1 (high): nil map access. Verify: go test ./...", 90, nil
|
||||
}
|
||||
|
||||
sk := debug.New(debug.Config{SkillPrompt: "debug rules", ExecutorFn: fakeFn, SessionsDir: t.TempDir()})
|
||||
sk := debug.New(debug.Config{SkillPrompt: "debug rules", CompleteFunc: fakeFn, SessionsDir: t.TempDir()})
|
||||
out, err := sk.Handle(context.Background(), "debug", json.RawMessage(
|
||||
`{"project_root":"/tmp/proj","error":"panic: nil pointer dereference at foo.go:42","context":"occurs on startup"}`,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
assert.Contains(t, capturedTask, "panic: nil pointer dereference")
|
||||
assert.Contains(t, capturedTask, "occurs on startup")
|
||||
|
||||
var result iexec.Result
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &result))
|
||||
assert.Equal(t, "debug", result.Phase)
|
||||
assert.Contains(t, result["text"], "nil map access")
|
||||
}
|
||||
|
||||
@@ -5,20 +5,19 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/registry"
|
||||
)
|
||||
|
||||
// ExecutorFn is the function signature for running a worker subprocess.
|
||||
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
|
||||
// CompleteFunc is the function used to call a local model.
|
||||
type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
|
||||
|
||||
// Config holds dependencies for the debug skill.
|
||||
type Config struct {
|
||||
SkillPrompt string
|
||||
DefaultModel string
|
||||
ExecutorFn ExecutorFn
|
||||
CompleteFunc CompleteFunc
|
||||
SessionsDir string
|
||||
IngestBaseURL string // optional: base URL of ingestion server for brain context
|
||||
IngestBaseURL string
|
||||
}
|
||||
|
||||
// Skill implements the debug MCP tool.
|
||||
@@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
return []registry.ToolDef{
|
||||
{
|
||||
Name: "debug",
|
||||
Description: "Analyse an error and return 3-5 hypotheses ordered by likelihood, each with a concrete verification step.",
|
||||
Description: "Consult a local model to analyse an error and return hypotheses ordered by likelihood, each with a concrete verification step.",
|
||||
InputSchema: schema(
|
||||
[]string{"project_root", "error"},
|
||||
map[string]any{
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
)
|
||||
|
||||
@@ -34,7 +33,6 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
model = s.cfg.DefaultModel
|
||||
}
|
||||
|
||||
// Read session log entries (empty slice if no log exists yet).
|
||||
entries, err := session.Read(s.cfg.SessionsDir, a.SessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read session log: %w", err)
|
||||
@@ -46,39 +44,33 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
}
|
||||
|
||||
taskPrompt := fmt.Sprintf(
|
||||
"SESSION_ID: %s\n\nSESSION_LOG:\n%s\n\nReview this session log. Identify what is novel or worth preserving as organizational knowledge. Write structured entries to brain/raw/ via brain_write. Return JSON result when done.",
|
||||
"SESSION_ID: %s\n\nSESSION_LOG:\n%s\n\nReview this session log. Identify what is novel or worth preserving as organizational knowledge. Provide structured insights.",
|
||||
a.SessionID, string(logJSON),
|
||||
)
|
||||
|
||||
if s.cfg.ExecutorFn == nil {
|
||||
if s.cfg.CompleteFunc == nil {
|
||||
return nil, fmt.Errorf("no executor configured")
|
||||
}
|
||||
t0 := time.Now()
|
||||
result, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||
SkillPrompt: s.cfg.SkillPrompt,
|
||||
TaskPrompt: taskPrompt,
|
||||
Model: model,
|
||||
Tools: "Bash,Read,Write",
|
||||
})
|
||||
text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, taskPrompt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("retrospective worker: %w", err)
|
||||
return nil, fmt.Errorf("retrospective model: %w", err)
|
||||
}
|
||||
|
||||
msg := text
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200]
|
||||
}
|
||||
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
|
||||
SessionID: a.SessionID,
|
||||
Timestamp: time.Now(),
|
||||
Skill: "retrospective",
|
||||
Phase: "retrospective",
|
||||
Attempts: session.AttemptsFrom(result.Attempts),
|
||||
FinalStatus: result.Status,
|
||||
ModelUsed: result.ModelUsed,
|
||||
FinalStatus: "ok",
|
||||
ModelUsed: model,
|
||||
DurationMs: time.Since(t0).Milliseconds(),
|
||||
Message: result.Message,
|
||||
Message: msg,
|
||||
})
|
||||
|
||||
b, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal result: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/retrospective"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -20,20 +19,14 @@ func TestHandle_Retrospective_RequiresSessionID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandle_Retrospective_BuildsPromptWithSessionLog(t *testing.T) {
|
||||
var capturedReq iexec.Request
|
||||
var capturedTask string
|
||||
s := retrospective.New(retrospective.Config{
|
||||
SkillPrompt: "retrospective discipline",
|
||||
DefaultModel: "ollama/test",
|
||||
SessionsDir: t.TempDir(), // empty dir, no session file — that's OK, session.Read returns nil
|
||||
ExecutorFn: func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
capturedReq = req
|
||||
return iexec.Result{
|
||||
Status: "pass",
|
||||
Phase: "retrospective",
|
||||
Skill: "retrospective",
|
||||
Verified: true,
|
||||
Message: "wrote 2 entries to brain",
|
||||
}, nil
|
||||
SessionsDir: t.TempDir(),
|
||||
CompleteFunc: func(_ context.Context, _, _, user string) (string, int64, error) {
|
||||
capturedTask = user
|
||||
return "Key insight: the team resolved a tricky nil pointer issue via careful logging.", 75, nil
|
||||
},
|
||||
})
|
||||
|
||||
@@ -41,9 +34,8 @@ func TestHandle_Retrospective_BuildsPromptWithSessionLog(t *testing.T) {
|
||||
out, err := s.Handle(context.Background(), "retrospective", args)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result iexec.Result
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &result))
|
||||
assert.Equal(t, "pass", result.Status)
|
||||
assert.Contains(t, capturedReq.SkillPrompt, "retrospective discipline")
|
||||
assert.Contains(t, capturedReq.TaskPrompt, "empty-session")
|
||||
assert.Contains(t, result["text"], "nil pointer")
|
||||
assert.Contains(t, capturedTask, "empty-session")
|
||||
}
|
||||
|
||||
@@ -5,19 +5,18 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/registry"
|
||||
)
|
||||
|
||||
// ExecutorFn allows injecting a test double for the subprocess executor.
|
||||
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
|
||||
// CompleteFunc is the function used to call a local model.
|
||||
type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
|
||||
|
||||
// Config holds retrospective skill configuration.
|
||||
type Config struct {
|
||||
SkillPrompt string // content of retrospective.md
|
||||
DefaultModel string // model to use when not specified in args
|
||||
SessionsDir string // path to brain/sessions/
|
||||
ExecutorFn ExecutorFn // injected executor
|
||||
SkillPrompt string
|
||||
DefaultModel string
|
||||
SessionsDir string
|
||||
CompleteFunc CompleteFunc
|
||||
}
|
||||
|
||||
// Skill implements registry.Skill for the retrospective tool.
|
||||
@@ -36,7 +35,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
return []registry.ToolDef{
|
||||
{
|
||||
Name: "retrospective",
|
||||
Description: "Run a retrospective on a completed session. Reads the session log, identifies novel learnings, and writes structured entries to the brain for ingestion. Call at the end of each coding session.",
|
||||
Description: "Consult a local model to analyse a completed session and identify what is novel or worth preserving as organizational knowledge.",
|
||||
InputSchema: json.RawMessage(`{
|
||||
"type": "object",
|
||||
"required": ["session_id"],
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mathiasbq/supervisor/internal/brain"
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
)
|
||||
|
||||
@@ -53,39 +52,32 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
task = brainCtx + "\n---\n\n" + task
|
||||
}
|
||||
|
||||
if s.cfg.ExecutorFn == nil {
|
||||
if s.cfg.CompleteFunc == nil {
|
||||
return nil, fmt.Errorf("no executor configured")
|
||||
}
|
||||
t0 := time.Now()
|
||||
result, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||
SkillPrompt: s.cfg.SkillPrompt,
|
||||
TaskPrompt: task,
|
||||
Model: model,
|
||||
Tools: "Read,Bash",
|
||||
})
|
||||
text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a.SessionID != "" && s.cfg.SessionsDir != "" {
|
||||
msg := text
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200]
|
||||
}
|
||||
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
|
||||
SessionID: a.SessionID,
|
||||
Timestamp: time.Now(),
|
||||
Skill: "review",
|
||||
Phase: "review",
|
||||
ProjectRoot: a.ProjectRoot,
|
||||
Attempts: session.AttemptsFrom(result.Attempts),
|
||||
FinalStatus: result.Status,
|
||||
FilePath: result.FilePath,
|
||||
ModelUsed: result.ModelUsed,
|
||||
FinalStatus: "ok",
|
||||
ModelUsed: model,
|
||||
DurationMs: time.Since(t0).Milliseconds(),
|
||||
Message: result.Message,
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
b, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal result: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/review"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -33,29 +32,22 @@ func TestReviewRequiresFiles(t *testing.T) {
|
||||
assert.ErrorContains(t, err, "files")
|
||||
}
|
||||
|
||||
func TestReviewCallsExecutor(t *testing.T) {
|
||||
called := false
|
||||
func TestReviewCallsCompleteFunc(t *testing.T) {
|
||||
var capturedTask string
|
||||
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
called = true
|
||||
capturedTask = req.TaskPrompt
|
||||
return iexec.Result{
|
||||
Status: "pass", Phase: "review", Skill: "review",
|
||||
Verified: true, ModelUsed: "self", Message: "2 warnings found",
|
||||
}, nil
|
||||
fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
|
||||
capturedTask = user
|
||||
return "2 warnings found: missing error handling at line 42", 80, nil
|
||||
}
|
||||
|
||||
sk := review.New(review.Config{SkillPrompt: "review rules", ExecutorFn: fakeFn, SessionsDir: t.TempDir()})
|
||||
sk := review.New(review.Config{SkillPrompt: "review rules", CompleteFunc: fakeFn, SessionsDir: t.TempDir()})
|
||||
out, err := sk.Handle(context.Background(), "review", json.RawMessage(
|
||||
`{"project_root":"/tmp/proj","files":["internal/foo/foo.go"],"context":"PR: add Foo helper"}`,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
assert.Contains(t, capturedTask, "internal/foo/foo.go")
|
||||
assert.Contains(t, capturedTask, "PR: add Foo helper")
|
||||
|
||||
var result iexec.Result
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &result))
|
||||
assert.Equal(t, "pass", result.Status)
|
||||
assert.Equal(t, "review", result.Phase)
|
||||
assert.Contains(t, result["text"], "2 warnings found")
|
||||
}
|
||||
|
||||
@@ -5,20 +5,19 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/registry"
|
||||
)
|
||||
|
||||
// ExecutorFn is the function signature for running a worker subprocess.
|
||||
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
|
||||
// CompleteFunc is the function used to call a local model.
|
||||
type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
|
||||
|
||||
// Config holds dependencies for the review skill.
|
||||
type Config struct {
|
||||
SkillPrompt string
|
||||
DefaultModel string
|
||||
ExecutorFn ExecutorFn
|
||||
CompleteFunc CompleteFunc
|
||||
SessionsDir string
|
||||
IngestBaseURL string // optional: base URL of ingestion server for brain context
|
||||
IngestBaseURL string
|
||||
}
|
||||
|
||||
// Skill implements the review MCP tool.
|
||||
@@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
return []registry.ToolDef{
|
||||
{
|
||||
Name: "review",
|
||||
Description: "Perform a structured code review of the specified files. Returns findings with severity levels.",
|
||||
Description: "Consult a local model for a structured code review of the specified files. Returns findings with severity levels.",
|
||||
InputSchema: schema(
|
||||
[]string{"project_root", "files"},
|
||||
map[string]any{
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mathiasbq/supervisor/internal/brain"
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
)
|
||||
|
||||
@@ -57,39 +56,32 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
task = brainCtx + "\n---\n\n" + task
|
||||
}
|
||||
|
||||
if s.cfg.ExecutorFn == nil {
|
||||
if s.cfg.CompleteFunc == nil {
|
||||
return nil, fmt.Errorf("no executor configured")
|
||||
}
|
||||
t0 := time.Now()
|
||||
result, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||
SkillPrompt: s.cfg.SkillPrompt,
|
||||
TaskPrompt: task,
|
||||
Model: model,
|
||||
Tools: "Read,Write",
|
||||
})
|
||||
text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a.SessionID != "" && s.cfg.SessionsDir != "" {
|
||||
msg := text
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200]
|
||||
}
|
||||
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
|
||||
SessionID: a.SessionID,
|
||||
Timestamp: time.Now(),
|
||||
Skill: "spec",
|
||||
Phase: "spec",
|
||||
ProjectRoot: a.ProjectRoot,
|
||||
Attempts: session.AttemptsFrom(result.Attempts),
|
||||
FinalStatus: result.Status,
|
||||
FilePath: result.FilePath,
|
||||
ModelUsed: result.ModelUsed,
|
||||
FinalStatus: "ok",
|
||||
ModelUsed: model,
|
||||
DurationMs: time.Since(t0).Milliseconds(),
|
||||
Message: result.Message,
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
b, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal result: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/spec"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -33,29 +32,22 @@ func TestSpecRequiresRequirements(t *testing.T) {
|
||||
assert.ErrorContains(t, err, "requirements")
|
||||
}
|
||||
|
||||
func TestSpecCallsExecutor(t *testing.T) {
|
||||
called := false
|
||||
func TestSpecCallsCompleteFunc(t *testing.T) {
|
||||
var capturedTask string
|
||||
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
called = true
|
||||
capturedTask = req.TaskPrompt
|
||||
return iexec.Result{
|
||||
Status: "pass", Phase: "spec", Skill: "spec",
|
||||
FilePath: "/tmp/proj/docs/login-spec.md",
|
||||
Verified: true, ModelUsed: "self", Message: "spec written: login feature",
|
||||
}, nil
|
||||
fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
|
||||
capturedTask = user
|
||||
return "# OAuth2 Login Spec\n\n## Overview\nImplement OAuth2 login flow.", 110, nil
|
||||
}
|
||||
|
||||
sk := spec.New(spec.Config{SkillPrompt: "spec rules", ExecutorFn: fakeFn, SessionsDir: t.TempDir()})
|
||||
sk := spec.New(spec.Config{SkillPrompt: "spec rules", CompleteFunc: fakeFn, SessionsDir: t.TempDir()})
|
||||
out, err := sk.Handle(context.Background(), "spec", json.RawMessage(
|
||||
`{"project_root":"/tmp/proj","requirements":"add OAuth2 login","output_path":"docs/login-spec.md"}`,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
assert.Contains(t, capturedTask, "OAuth2 login")
|
||||
assert.Contains(t, capturedTask, "docs/login-spec.md")
|
||||
|
||||
var result iexec.Result
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &result))
|
||||
assert.Equal(t, "spec", result.Phase)
|
||||
assert.Contains(t, result["text"], "OAuth2 Login Spec")
|
||||
}
|
||||
|
||||
@@ -5,20 +5,19 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/registry"
|
||||
)
|
||||
|
||||
// ExecutorFn is the function signature for running a worker subprocess.
|
||||
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
|
||||
// CompleteFunc is the function used to call a local model.
|
||||
type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
|
||||
|
||||
// Config holds dependencies for the spec skill.
|
||||
type Config struct {
|
||||
SkillPrompt string
|
||||
DefaultModel string
|
||||
ExecutorFn ExecutorFn
|
||||
CompleteFunc CompleteFunc
|
||||
SessionsDir string
|
||||
IngestBaseURL string // optional: base URL of ingestion server for brain context
|
||||
IngestBaseURL string
|
||||
}
|
||||
|
||||
// Skill implements the spec MCP tool.
|
||||
@@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
return []registry.ToolDef{
|
||||
{
|
||||
Name: "spec",
|
||||
Description: "Generate a structured implementation spec from requirements. Writes the spec to output_path in the project.",
|
||||
Description: "Consult a local model to draft a structured implementation spec from requirements. Returns the spec text.",
|
||||
InputSchema: schema(
|
||||
[]string{"project_root", "requirements"},
|
||||
map[string]any{
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mathiasbq/supervisor/internal/brain"
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
)
|
||||
|
||||
@@ -51,7 +50,7 @@ func (s *Skill) handleRed(ctx context.Context, raw json.RawMessage) (json.RawMes
|
||||
if brainCtx != "" {
|
||||
task = brainCtx + "\n---\n\n" + task
|
||||
}
|
||||
return s.execute(ctx, task)
|
||||
return s.complete(ctx, s.resolveModel(args.Model), task)
|
||||
}
|
||||
|
||||
type greenArgs struct {
|
||||
@@ -80,11 +79,11 @@ func (s *Skill) handleGreen(ctx context.Context, raw json.RawMessage) (json.RawM
|
||||
task = session.PrependHistory(s.cfg.SessionsDir, args.SessionID, "green", task)
|
||||
|
||||
t0 := time.Now()
|
||||
result, err := s.execute(ctx, task)
|
||||
result, err := s.complete(ctx, s.resolveModel(args.Model), task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.logAttempt(args.SessionID, args.ProjectRoot, "tdd", "green", t0, result)
|
||||
s.logEntry(args.SessionID, args.ProjectRoot, "tdd", "green", s.resolveModel(args.Model), t0, result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -118,11 +117,11 @@ func (s *Skill) handleRefactor(ctx context.Context, raw json.RawMessage) (json.R
|
||||
task = session.PrependHistory(s.cfg.SessionsDir, args.SessionID, "refactor", task)
|
||||
|
||||
t0 := time.Now()
|
||||
result, err := s.execute(ctx, task)
|
||||
result, err := s.complete(ctx, s.resolveModel(args.Model), task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.logAttempt(args.SessionID, args.ProjectRoot, "tdd", "refactor", t0, result)
|
||||
s.logEntry(args.SessionID, args.ProjectRoot, "tdd", "refactor", s.resolveModel(args.Model), t0, result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -133,31 +132,32 @@ func (s *Skill) resolveModel(override string) string {
|
||||
return s.cfg.DefaultModel
|
||||
}
|
||||
|
||||
// execute calls ExecutorFn and returns the marshaled result.
|
||||
func (s *Skill) execute(ctx context.Context, task string) (json.RawMessage, error) {
|
||||
if s.cfg.ExecutorFn == nil {
|
||||
// complete calls CompleteFunc and returns the text as JSON.
|
||||
func (s *Skill) complete(ctx context.Context, model, task string) (json.RawMessage, error) {
|
||||
if s.cfg.CompleteFunc == nil {
|
||||
return nil, fmt.Errorf("no executor configured")
|
||||
}
|
||||
req := iexec.Request{
|
||||
SkillPrompt: s.cfg.SkillPrompt,
|
||||
TaskPrompt: task,
|
||||
}
|
||||
result, err := s.cfg.ExecutorFn(ctx, req)
|
||||
text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(result)
|
||||
return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
|
||||
}
|
||||
|
||||
// logAttempt writes a session.Entry for a completed phase if session_id is set.
|
||||
// raw is the marshaled Result returned by execute; we unmarshal to extract fields.
|
||||
func (s *Skill) logAttempt(sessionID, projectRoot, skill, phase string, t0 time.Time, raw json.RawMessage) {
|
||||
// logEntry writes a session.Entry for a completed phase if session_id is set.
|
||||
func (s *Skill) logEntry(sessionID, projectRoot, skill, phase, model string, t0 time.Time, raw json.RawMessage) {
|
||||
if sessionID == "" || s.cfg.SessionsDir == "" {
|
||||
return
|
||||
}
|
||||
var result iexec.Result
|
||||
if err := json.Unmarshal(raw, &result); err != nil {
|
||||
return
|
||||
var msg string
|
||||
var result struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &result); err == nil && len(result.Text) > 0 {
|
||||
msg = result.Text
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200]
|
||||
}
|
||||
}
|
||||
_ = session.Append(s.cfg.SessionsDir, sessionID, session.Entry{
|
||||
SessionID: sessionID,
|
||||
@@ -165,11 +165,9 @@ func (s *Skill) logAttempt(sessionID, projectRoot, skill, phase string, t0 time.
|
||||
Skill: skill,
|
||||
Phase: phase,
|
||||
ProjectRoot: projectRoot,
|
||||
Attempts: session.AttemptsFrom(result.Attempts),
|
||||
FinalStatus: result.Status,
|
||||
FilePath: result.FilePath,
|
||||
ModelUsed: result.ModelUsed,
|
||||
FinalStatus: "ok",
|
||||
ModelUsed: model,
|
||||
DurationMs: time.Since(t0).Milliseconds(),
|
||||
Message: result.Message,
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/tdd"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -14,7 +13,6 @@ import (
|
||||
|
||||
func TestTDDSkillTools(t *testing.T) {
|
||||
skill := tdd.New(tdd.Config{
|
||||
SystemPrompt: "supervisor rules",
|
||||
SkillPrompt: "tdd rules",
|
||||
})
|
||||
tools := skill.Tools()
|
||||
@@ -26,19 +24,19 @@ func TestTDDSkillTools(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTDDSkillHandleUnknown(t *testing.T) {
|
||||
skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"})
|
||||
skill := tdd.New(tdd.Config{SkillPrompt: "t"})
|
||||
_, err := skill.Handle(context.Background(), "tdd_unknown", json.RawMessage(`{}`))
|
||||
assert.ErrorContains(t, err, "unknown tool")
|
||||
}
|
||||
|
||||
func TestTDDRedRequiresProjectRoot(t *testing.T) {
|
||||
skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"})
|
||||
skill := tdd.New(tdd.Config{SkillPrompt: "t"})
|
||||
_, err := skill.Handle(context.Background(), "tdd_red", json.RawMessage(`{"spec":"add two numbers"}`))
|
||||
assert.ErrorContains(t, err, "project_root")
|
||||
}
|
||||
|
||||
func TestTDDRedRequiresSpec(t *testing.T) {
|
||||
skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"})
|
||||
skill := tdd.New(tdd.Config{SkillPrompt: "t"})
|
||||
_, err := skill.Handle(context.Background(), "tdd_red", json.RawMessage(`{"project_root":"/tmp/proj"}`))
|
||||
assert.ErrorContains(t, err, "spec")
|
||||
}
|
||||
@@ -51,35 +49,49 @@ func TestTDDGreenInjectsSessionHistory(t *testing.T) {
|
||||
Message: "wrote failing test for Foo",
|
||||
}))
|
||||
|
||||
var capturedPrompt string
|
||||
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
capturedPrompt = req.TaskPrompt
|
||||
return iexec.Result{Status: "pass", Phase: "green", Skill: "tdd", Verified: true, ModelUsed: "self", Message: "ok"}, nil
|
||||
var capturedTask string
|
||||
fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
|
||||
capturedTask = user
|
||||
return "here is my suggestion", 100, nil
|
||||
}
|
||||
|
||||
sk := tdd.New(tdd.Config{SkillPrompt: "tdd", ExecutorFn: fakeFn, SessionsDir: sessDir})
|
||||
sk := tdd.New(tdd.Config{SkillPrompt: "tdd", CompleteFunc: fakeFn, SessionsDir: sessDir})
|
||||
_, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage(
|
||||
`{"project_root":"/tmp","test_path":"internal/foo/foo_test.go","test_cmd":"go test ./...","session_id":"sess-1"}`,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, capturedPrompt, "## Session history")
|
||||
assert.Contains(t, capturedPrompt, "wrote failing test for Foo")
|
||||
assert.Contains(t, capturedTask, "## Session history")
|
||||
assert.Contains(t, capturedTask, "wrote failing test for Foo")
|
||||
}
|
||||
|
||||
func TestTDDGreenNoHistoryWhenSessionIDEmpty(t *testing.T) {
|
||||
var capturedPrompt string
|
||||
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
capturedPrompt = req.TaskPrompt
|
||||
return iexec.Result{Status: "pass", Phase: "green", Skill: "tdd", Verified: true, ModelUsed: "self", Message: "ok"}, nil
|
||||
var capturedTask string
|
||||
fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
|
||||
capturedTask = user
|
||||
return "suggestion", 50, nil
|
||||
}
|
||||
|
||||
sk := tdd.New(tdd.Config{SkillPrompt: "tdd", ExecutorFn: fakeFn, SessionsDir: t.TempDir()})
|
||||
sk := tdd.New(tdd.Config{SkillPrompt: "tdd", CompleteFunc: fakeFn, SessionsDir: t.TempDir()})
|
||||
_, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage(
|
||||
`{"project_root":"/tmp","test_path":"internal/foo/foo_test.go"}`,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, capturedPrompt, "## Session history")
|
||||
assert.NotContains(t, capturedTask, "## Session history")
|
||||
}
|
||||
|
||||
// Ensure require is used (avoids import error).
|
||||
var _ = require.New
|
||||
func TestTDDGreenReturnsTextJSON(t *testing.T) {
|
||||
fakeFn := func(_ context.Context, _, _, _ string) (string, int64, error) {
|
||||
return "write a func that adds two ints", 42, nil
|
||||
}
|
||||
|
||||
sk := tdd.New(tdd.Config{SkillPrompt: "tdd", CompleteFunc: fakeFn})
|
||||
raw, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage(
|
||||
`{"project_root":"/tmp","test_path":"foo_test.go"}`,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(raw, &result))
|
||||
assert.Equal(t, "write a func that adds two ints", result["text"])
|
||||
assert.Equal(t, float64(42), result["duration_ms"])
|
||||
}
|
||||
|
||||
@@ -4,17 +4,15 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/registry"
|
||||
)
|
||||
|
||||
// ExecutorFn allows injecting a test double for the executor.
|
||||
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
|
||||
// CompleteFunc is the function used to call a local model.
|
||||
type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
|
||||
|
||||
type Config struct {
|
||||
SystemPrompt string
|
||||
SkillPrompt string
|
||||
ExecutorFn ExecutorFn // nil = no executor (tests that don't reach execute())
|
||||
CompleteFunc CompleteFunc // nil = no executor (tests that don't reach execute())
|
||||
DefaultModel string
|
||||
SessionsDir string // optional: path to brain/sessions/ for history injection
|
||||
IngestBaseURL string // optional: base URL of ingestion server for brain context
|
||||
@@ -44,7 +42,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
return []registry.ToolDef{
|
||||
{
|
||||
Name: "tdd_red",
|
||||
Description: "Write a failing test for the described behavior. Verifies the test fails before returning.",
|
||||
Description: "Consult a local model for help writing a failing test for the described behavior.",
|
||||
InputSchema: schema(
|
||||
[]string{"project_root", "spec"},
|
||||
map[string]any{
|
||||
@@ -57,7 +55,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
},
|
||||
{
|
||||
Name: "tdd_green",
|
||||
Description: "Write minimal implementation to make the test at test_path pass.",
|
||||
Description: "Consult a local model for implementation ideas to make the test at test_path pass.",
|
||||
InputSchema: schema(
|
||||
[]string{"project_root", "test_path"},
|
||||
map[string]any{
|
||||
@@ -71,7 +69,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
},
|
||||
{
|
||||
Name: "tdd_refactor",
|
||||
Description: "Refactor the implementation at impl_path while keeping tests green.",
|
||||
Description: "Consult a local model for refactoring suggestions for impl_path while keeping tests green.",
|
||||
InputSchema: schema(
|
||||
[]string{"project_root", "test_path", "impl_path"},
|
||||
map[string]any{
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
)
|
||||
|
||||
@@ -28,7 +27,7 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
if a.SessionID == "" {
|
||||
return nil, fmt.Errorf("session_id is required")
|
||||
}
|
||||
if s.cfg.ExecutorFn == nil {
|
||||
if s.cfg.CompleteFunc == nil {
|
||||
return nil, fmt.Errorf("no executor configured")
|
||||
}
|
||||
|
||||
@@ -42,53 +41,47 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
|
||||
return nil, fmt.Errorf("read session log: %w", err)
|
||||
}
|
||||
|
||||
// ── Step 1: Reader agent ─────────────────────────────────────────────────
|
||||
// ── Step 1: Reader ────────────────────────────────────────────────────────
|
||||
history := session.FormatHistory(entries, "")
|
||||
readerTask := fmt.Sprintf(
|
||||
"role: reader\nsession_id: %s\nbrain_dir: %s\n\n%s",
|
||||
a.SessionID, s.cfg.BrainDir, history,
|
||||
)
|
||||
readerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||
SkillPrompt: s.cfg.ReaderPrompt,
|
||||
TaskPrompt: readerTask,
|
||||
Model: model,
|
||||
Tools: "Read",
|
||||
})
|
||||
readerText, _, err := s.cfg.CompleteFunc(ctx, model, s.cfg.ReaderPrompt, readerTask)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reader agent: %w", err)
|
||||
return nil, fmt.Errorf("reader: %w", err)
|
||||
}
|
||||
|
||||
// ── Step 2: Writer agent (receives reader candidates) ────────────────────
|
||||
// ── Step 2: Writer (receives reader output) ───────────────────────────────
|
||||
t0 := time.Now()
|
||||
writerTask := fmt.Sprintf(
|
||||
"role: writer\nsession_id: %s\nbrain_dir: %s\n\nreader_summary: %s\nreader_candidates:\n%s",
|
||||
a.SessionID, s.cfg.BrainDir, readerResult.Message, readerResult.RunnerOutput,
|
||||
"role: writer\nsession_id: %s\nbrain_dir: %s\n\nreader_analysis:\n%s",
|
||||
a.SessionID, s.cfg.BrainDir, readerText,
|
||||
)
|
||||
writerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{
|
||||
SkillPrompt: s.cfg.WriterPrompt,
|
||||
TaskPrompt: writerTask,
|
||||
Model: model,
|
||||
Tools: "Read,Write",
|
||||
})
|
||||
writerText, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.WriterPrompt, writerTask)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("writer agent: %w", err)
|
||||
return nil, fmt.Errorf("writer: %w", err)
|
||||
}
|
||||
|
||||
msg := writerText
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200]
|
||||
}
|
||||
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
|
||||
SessionID: a.SessionID,
|
||||
Timestamp: time.Now(),
|
||||
Skill: "trainer",
|
||||
Phase: "trainer",
|
||||
Attempts: session.AttemptsFrom(writerResult.Attempts),
|
||||
FinalStatus: writerResult.Status,
|
||||
ModelUsed: writerResult.ModelUsed,
|
||||
FinalStatus: "ok",
|
||||
ModelUsed: model,
|
||||
DurationMs: time.Since(t0).Milliseconds(),
|
||||
Message: writerResult.Message,
|
||||
Message: msg,
|
||||
})
|
||||
|
||||
b, err := json.Marshal(writerResult)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal result: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
return json.Marshal(map[string]any{
|
||||
"reader_analysis": readerText,
|
||||
"writer_output": writerText,
|
||||
"model": model,
|
||||
"duration_ms": dur,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/session"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/trainer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -31,52 +30,44 @@ func TestTrainerRequiresSessionID(t *testing.T) {
|
||||
func TestTrainerCallsReaderThenWriter(t *testing.T) {
|
||||
sessDir := t.TempDir()
|
||||
require.NoError(t, session.Append(sessDir, "sess-1", session.Entry{
|
||||
SessionID: "sess-1", Skill: "tdd", Phase: "red", FinalStatus: "pass",
|
||||
SessionID: "sess-1", Skill: "tdd", Phase: "red", FinalStatus: "ok",
|
||||
Message: "wrote failing test", FilePath: "internal/foo/foo_test.go",
|
||||
}))
|
||||
|
||||
callCount := 0
|
||||
var readerTask, writerTask string
|
||||
|
||||
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) {
|
||||
fakeFn := func(_ context.Context, _, sys, user string) (string, int64, error) {
|
||||
callCount++
|
||||
if callCount == 1 {
|
||||
// reader call
|
||||
readerTask = req.TaskPrompt
|
||||
return iexec.Result{
|
||||
Status: "pass", Phase: "trainer", Skill: "trainer",
|
||||
RunnerOutput: `[{"type":"sft","moment":"first-pass clean TDD","score":4}]`,
|
||||
Verified: true, ModelUsed: "self", Message: "1 sft candidate found",
|
||||
}, nil
|
||||
readerTask = user
|
||||
return "1 sft candidate found: first-pass clean TDD", 60, nil
|
||||
}
|
||||
// writer call
|
||||
writerTask = req.TaskPrompt
|
||||
return iexec.Result{
|
||||
Status: "pass", Phase: "trainer", Skill: "trainer",
|
||||
FilePath: sessDir + "/training-data/sft/sess-1.jsonl",
|
||||
Verified: true, ModelUsed: "self", Message: "1 sft pair written",
|
||||
}, nil
|
||||
writerTask = user
|
||||
return "written 1 knowledge entry to brain/knowledge/tdd-patterns.md", 70, nil
|
||||
}
|
||||
|
||||
sk := trainer.New(trainer.Config{
|
||||
ReaderPrompt: "reader rules",
|
||||
WriterPrompt: "writer rules",
|
||||
ExecutorFn: fakeFn,
|
||||
CompleteFunc: fakeFn,
|
||||
SessionsDir: sessDir,
|
||||
BrainDir: t.TempDir(),
|
||||
})
|
||||
out, err := sk.Handle(context.Background(), "trainer", json.RawMessage(`{"session_id":"sess-1"}`))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, callCount, "executor must be called exactly twice: reader then writer")
|
||||
assert.Equal(t, 2, callCount, "complete must be called exactly twice: reader then writer")
|
||||
assert.Contains(t, readerTask, "role: reader")
|
||||
assert.Contains(t, readerTask, "sess-1")
|
||||
assert.Contains(t, readerTask, "wrote failing test") // session history in reader prompt
|
||||
assert.Contains(t, readerTask, "wrote failing test")
|
||||
assert.Contains(t, writerTask, "role: writer")
|
||||
assert.Contains(t, writerTask, "sft candidate") // reader output passed to writer
|
||||
assert.Contains(t, writerTask, "sft candidate")
|
||||
|
||||
var result iexec.Result
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &result))
|
||||
assert.Equal(t, "trainer", result.Phase)
|
||||
assert.Equal(t, "pass", result.Status)
|
||||
assert.Contains(t, result["reader_analysis"], "sft candidate")
|
||||
assert.Contains(t, result["writer_output"], "knowledge entry")
|
||||
}
|
||||
|
||||
@@ -5,21 +5,20 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/registry"
|
||||
)
|
||||
|
||||
// ExecutorFn is the function signature for running a worker subprocess.
|
||||
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error)
|
||||
// CompleteFunc is the function used to call a local model.
|
||||
type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
|
||||
|
||||
// Config holds dependencies for the trainer skill.
|
||||
type Config struct {
|
||||
ReaderPrompt string
|
||||
WriterPrompt string
|
||||
DefaultModel string
|
||||
ExecutorFn ExecutorFn
|
||||
CompleteFunc CompleteFunc
|
||||
SessionsDir string
|
||||
BrainDir string // root of brain/ directory; writer writes to BrainDir/training-data/
|
||||
BrainDir string // root of brain/ directory
|
||||
}
|
||||
|
||||
// Skill implements the trainer MCP tool.
|
||||
@@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef {
|
||||
return []registry.ToolDef{
|
||||
{
|
||||
Name: "trainer",
|
||||
Description: "Extract SFT and DPO training pairs from a session log. Runs a reader→writer chain: reader identifies learning moments, writer formats and writes pairs to brain/training-data/.",
|
||||
Description: "Consult a local model to identify learning moments from a session log and suggest knowledge to preserve in the brain.",
|
||||
InputSchema: schema(
|
||||
[]string{"session_id"},
|
||||
map[string]any{
|
||||
|
||||
Reference in New Issue
Block a user