diff --git a/internal/exec/orchestrator.go b/internal/exec/orchestrator.go new file mode 100644 index 0000000..bddda2b --- /dev/null +++ b/internal/exec/orchestrator.go @@ -0,0 +1,197 @@ +package exec + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// ChainEntry is one tier in an escalation chain. +type ChainEntry struct { + Model string // e.g. "ollama/phi4", "claude-sonnet-4-6" + Tier string // "local" | "subagent" | "managed" + IsCloud bool // true for claude-* models; skips verifier call +} + +// EntryFor builds a ChainEntry from a model name string. +func EntryFor(model string) ChainEntry { + cloud := strings.HasPrefix(model, "claude-") + tier := "local" + if cloud { + tier = "subagent" + } + return ChainEntry{Model: model, Tier: tier, IsCloud: cloud} +} + +// AttemptRecord captures the outcome of one tier attempt for session logging. +type AttemptRecord struct { + Model string + Tier string + DurationMs int64 + WarmStart bool + Verdict string // "accept" | "escalate" | "error" + Feedback string +} + +// VerifierFn is the interface the orchestrator uses to verify local output. +type VerifierFn interface { + Verify(ctx context.Context, skillPrompt, taskPrompt string, output Result) (Verdict, error) +} + +// ExecutorRunFn is the signature of Executor.Run and LiteLLMExecutor.Run. +type ExecutorRunFn func(ctx context.Context, req Request) (Result, error) + +// Orchestrator walks an escalation chain, delegating generation and verification. +// It implements the ExecutorFn shape expected by skill handlers. +type Orchestrator struct { + chain []ChainEntry + localRun ExecutorRunFn // for local (non-cloud) tiers; may be nil + cloudRun ExecutorRunFn // for cloud tiers; may be nil + verifier VerifierFn + llamaSwapURL string + attempts *[]AttemptRecord +} + +// NewOrchestrator creates an Orchestrator. +// attempts is a pointer to a slice that will be appended to on each tier attempt. +// Pass nil for localRun or cloudRun if no tiers of that type exist in the chain. +func NewOrchestrator( + chain []ChainEntry, + localRun ExecutorRunFn, + cloudRun ExecutorRunFn, + verifier VerifierFn, + llamaSwapURL string, + attempts *[]AttemptRecord, +) *Orchestrator { + return &Orchestrator{ + chain: chain, + localRun: localRun, + cloudRun: cloudRun, + verifier: verifier, + llamaSwapURL: llamaSwapURL, + attempts: attempts, + } +} + +// Run walks the escalation chain and returns the first accepted result. +// Satisfies the ExecutorFn signature: func(context.Context, Request) (Result, error). +func (o *Orchestrator) Run(ctx context.Context, req Request) (Result, error) { + taskPrompt := req.TaskPrompt + + for _, entry := range o.chain { + warm := o.probeWarm(entry.Model) + start := time.Now() + + tierReq := req + tierReq.Model = entry.Model + tierReq.TaskPrompt = taskPrompt + + if entry.IsCloud { + result, genErr := o.cloudRun(ctx, tierReq) + dur := time.Since(start).Milliseconds() + verdict := "accept" + if genErr != nil { + verdict = "error" + } + o.appendAttempt(AttemptRecord{ + Model: entry.Model, + Tier: entry.Tier, + DurationMs: dur, + WarmStart: warm, + Verdict: verdict, + }) + if genErr == nil { + return result, nil + } + continue + } + + // Local tier. + result, genErr := o.localRun(ctx, tierReq) + dur := time.Since(start).Milliseconds() + + if genErr != nil { + o.appendAttempt(AttemptRecord{ + Model: entry.Model, + Tier: entry.Tier, + DurationMs: dur, + WarmStart: warm, + Verdict: "error", + Feedback: genErr.Error(), + }) + continue + } + + verdict, verErr := o.verifier.Verify(ctx, req.SkillPrompt, taskPrompt, result) + if verErr != nil { + // Treat verifier failure as escalate (safe default). + o.appendAttempt(AttemptRecord{ + Model: entry.Model, + Tier: entry.Tier, + DurationMs: dur, + WarmStart: warm, + Verdict: "escalate", + Feedback: "verifier error: " + verErr.Error(), + }) + continue + } + + if verdict.Accept { + o.appendAttempt(AttemptRecord{ + Model: entry.Model, + Tier: entry.Tier, + DurationMs: dur, + WarmStart: warm, + Verdict: "accept", + }) + return result, nil + } + + o.appendAttempt(AttemptRecord{ + Model: entry.Model, + Tier: entry.Tier, + DurationMs: dur, + WarmStart: warm, + Verdict: "escalate", + Feedback: verdict.Feedback, + }) + // Inject verifier feedback into the next tier's task prompt. + taskPrompt = taskPrompt + "\n\nPrior attempt feedback: " + verdict.Feedback + } + + return Result{}, fmt.Errorf("all tiers exhausted after %d attempt(s)", len(o.chain)) +} + +func (o *Orchestrator) appendAttempt(rec AttemptRecord) { + if o.attempts != nil { + *o.attempts = append(*o.attempts, rec) + } +} + +// probeWarm checks whether the model is currently loaded in llama-swap. +// Returns false on any error or if llamaSwapURL is empty. +func (o *Orchestrator) probeWarm(model string) bool { + if o.llamaSwapURL == "" { + return false + } + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, o.llamaSwapURL+"/v1/models", nil) + if err != nil { + return false + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() //nolint:errcheck + body, err := io.ReadAll(resp.Body) + if err != nil { + return false + } + return strings.Contains(string(body), model) +} diff --git a/internal/exec/orchestrator_test.go b/internal/exec/orchestrator_test.go new file mode 100644 index 0000000..c0e4774 --- /dev/null +++ b/internal/exec/orchestrator_test.go @@ -0,0 +1,151 @@ +package exec_test + +import ( + "context" + "errors" + "testing" + + iexec "github.com/mathiasbq/supervisor/internal/exec" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// stubRunFn returns preset results sequentially. +type stubRunFn struct { + calls []stubCall + callIdx int +} + +type stubCall struct { + result iexec.Result + err error +} + +func (s *stubRunFn) Run(_ context.Context, _ iexec.Request) (iexec.Result, error) { + if s.callIdx >= len(s.calls) { + return iexec.Result{}, errors.New("unexpected call") + } + c := s.calls[s.callIdx] + s.callIdx++ + return c.result, c.err +} + +// stubVerifier returns preset verdicts sequentially. +type stubVerifier struct { + verdicts []iexec.Verdict + idx int +} + +func (s *stubVerifier) Verify(_ context.Context, _, _ string, _ iexec.Result) (iexec.Verdict, error) { + if s.idx >= len(s.verdicts) { + return iexec.Verdict{}, errors.New("unexpected verify call") + } + v := s.verdicts[s.idx] + s.idx++ + return v, nil +} + +func okResult(skill string) iexec.Result { + return iexec.Result{Status: "pass", Phase: "review", Skill: skill, Message: "ok", ModelUsed: "m"} +} + +func TestOrchestratorSingleLocalAccept(t *testing.T) { + local := &stubRunFn{calls: []stubCall{{result: okResult("review")}}} + verifier := &stubVerifier{verdicts: []iexec.Verdict{{Accept: true}}} + + var attempts []iexec.AttemptRecord + orch := iexec.NewOrchestrator( + []iexec.ChainEntry{{Model: "ollama/devstral", Tier: "local", IsCloud: false}}, + local.Run, nil, verifier, "", &attempts, + ) + + result, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"}) + require.NoError(t, err) + assert.Equal(t, "pass", result.Status) + require.Len(t, attempts, 1) + assert.Equal(t, "local", attempts[0].Tier) + assert.Equal(t, "accept", attempts[0].Verdict) +} + +func TestOrchestratorEscalatesOnVerifierReject(t *testing.T) { + local := &stubRunFn{calls: []stubCall{ + {result: iexec.Result{Status: "fail", Phase: "review", Skill: "review", Message: "weak"}}, + {result: okResult("review")}, + }} + verifier := &stubVerifier{verdicts: []iexec.Verdict{ + {Accept: false, Feedback: "missing line refs"}, + {Accept: true}, + }} + + var attempts []iexec.AttemptRecord + orch := iexec.NewOrchestrator( + []iexec.ChainEntry{ + {Model: "ollama/devstral", Tier: "local", IsCloud: false}, + {Model: "ollama/gemma4", Tier: "local", IsCloud: false}, + }, + local.Run, nil, verifier, "", &attempts, + ) + + result, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"}) + require.NoError(t, err) + assert.Equal(t, "pass", result.Status) + require.Len(t, attempts, 2) + assert.Equal(t, "escalate", attempts[0].Verdict) + assert.Equal(t, "missing line refs", attempts[0].Feedback) + assert.Equal(t, "accept", attempts[1].Verdict) +} + +func TestOrchestratorEscalatesOnLocalError(t *testing.T) { + local := &stubRunFn{calls: []stubCall{ + {err: errors.New("network failure")}, + {result: okResult("review")}, + }} + verifier := &stubVerifier{verdicts: []iexec.Verdict{{Accept: true}}} + + var attempts []iexec.AttemptRecord + orch := iexec.NewOrchestrator( + []iexec.ChainEntry{ + {Model: "ollama/devstral", Tier: "local", IsCloud: false}, + {Model: "ollama/gemma4", Tier: "local", IsCloud: false}, + }, + local.Run, nil, verifier, "", &attempts, + ) + + _, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"}) + require.NoError(t, err) + require.Len(t, attempts, 2) + assert.Equal(t, "error", attempts[0].Verdict) + assert.Equal(t, "accept", attempts[1].Verdict) +} + +func TestOrchestratorCloudTierSelfCertifies(t *testing.T) { + cloud := &stubRunFn{calls: []stubCall{{result: okResult("review")}}} + verifier := &stubVerifier{} // no verdicts — must not be called + + var attempts []iexec.AttemptRecord + orch := iexec.NewOrchestrator( + []iexec.ChainEntry{{Model: "claude-sonnet-4-6", Tier: "subagent", IsCloud: true}}, + nil, cloud.Run, verifier, "", &attempts, + ) + + result, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"}) + require.NoError(t, err) + assert.Equal(t, "pass", result.Status) + require.Len(t, attempts, 1) + assert.Equal(t, "subagent", attempts[0].Tier) + assert.Equal(t, "accept", attempts[0].Verdict) + assert.Equal(t, 0, verifier.idx) // verifier never called +} + +func TestOrchestratorAllTiersExhausted(t *testing.T) { + local := &stubRunFn{calls: []stubCall{{err: errors.New("unavailable")}}} + + var attempts []iexec.AttemptRecord + orch := iexec.NewOrchestrator( + []iexec.ChainEntry{{Model: "ollama/devstral", Tier: "local", IsCloud: false}}, + local.Run, nil, &stubVerifier{}, "", &attempts, + ) + + _, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"}) + assert.ErrorContains(t, err, "all tiers exhausted") +}