feat(exec): add Orchestrator chain walker with verification and warm-state logging
This commit is contained in:
197
internal/exec/orchestrator.go
Normal file
197
internal/exec/orchestrator.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package exec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ChainEntry is one tier in an escalation chain.
|
||||
type ChainEntry struct {
|
||||
Model string // e.g. "ollama/phi4", "claude-sonnet-4-6"
|
||||
Tier string // "local" | "subagent" | "managed"
|
||||
IsCloud bool // true for claude-* models; skips verifier call
|
||||
}
|
||||
|
||||
// EntryFor builds a ChainEntry from a model name string.
|
||||
func EntryFor(model string) ChainEntry {
|
||||
cloud := strings.HasPrefix(model, "claude-")
|
||||
tier := "local"
|
||||
if cloud {
|
||||
tier = "subagent"
|
||||
}
|
||||
return ChainEntry{Model: model, Tier: tier, IsCloud: cloud}
|
||||
}
|
||||
|
||||
// AttemptRecord captures the outcome of one tier attempt for session logging.
|
||||
type AttemptRecord struct {
|
||||
Model string
|
||||
Tier string
|
||||
DurationMs int64
|
||||
WarmStart bool
|
||||
Verdict string // "accept" | "escalate" | "error"
|
||||
Feedback string
|
||||
}
|
||||
|
||||
// VerifierFn is the interface the orchestrator uses to verify local output.
|
||||
type VerifierFn interface {
|
||||
Verify(ctx context.Context, skillPrompt, taskPrompt string, output Result) (Verdict, error)
|
||||
}
|
||||
|
||||
// ExecutorRunFn is the signature of Executor.Run and LiteLLMExecutor.Run.
|
||||
type ExecutorRunFn func(ctx context.Context, req Request) (Result, error)
|
||||
|
||||
// Orchestrator walks an escalation chain, delegating generation and verification.
|
||||
// It implements the ExecutorFn shape expected by skill handlers.
|
||||
type Orchestrator struct {
|
||||
chain []ChainEntry
|
||||
localRun ExecutorRunFn // for local (non-cloud) tiers; may be nil
|
||||
cloudRun ExecutorRunFn // for cloud tiers; may be nil
|
||||
verifier VerifierFn
|
||||
llamaSwapURL string
|
||||
attempts *[]AttemptRecord
|
||||
}
|
||||
|
||||
// NewOrchestrator creates an Orchestrator.
|
||||
// attempts is a pointer to a slice that will be appended to on each tier attempt.
|
||||
// Pass nil for localRun or cloudRun if no tiers of that type exist in the chain.
|
||||
func NewOrchestrator(
|
||||
chain []ChainEntry,
|
||||
localRun ExecutorRunFn,
|
||||
cloudRun ExecutorRunFn,
|
||||
verifier VerifierFn,
|
||||
llamaSwapURL string,
|
||||
attempts *[]AttemptRecord,
|
||||
) *Orchestrator {
|
||||
return &Orchestrator{
|
||||
chain: chain,
|
||||
localRun: localRun,
|
||||
cloudRun: cloudRun,
|
||||
verifier: verifier,
|
||||
llamaSwapURL: llamaSwapURL,
|
||||
attempts: attempts,
|
||||
}
|
||||
}
|
||||
|
||||
// Run walks the escalation chain and returns the first accepted result.
|
||||
// Satisfies the ExecutorFn signature: func(context.Context, Request) (Result, error).
|
||||
func (o *Orchestrator) Run(ctx context.Context, req Request) (Result, error) {
|
||||
taskPrompt := req.TaskPrompt
|
||||
|
||||
for _, entry := range o.chain {
|
||||
warm := o.probeWarm(entry.Model)
|
||||
start := time.Now()
|
||||
|
||||
tierReq := req
|
||||
tierReq.Model = entry.Model
|
||||
tierReq.TaskPrompt = taskPrompt
|
||||
|
||||
if entry.IsCloud {
|
||||
result, genErr := o.cloudRun(ctx, tierReq)
|
||||
dur := time.Since(start).Milliseconds()
|
||||
verdict := "accept"
|
||||
if genErr != nil {
|
||||
verdict = "error"
|
||||
}
|
||||
o.appendAttempt(AttemptRecord{
|
||||
Model: entry.Model,
|
||||
Tier: entry.Tier,
|
||||
DurationMs: dur,
|
||||
WarmStart: warm,
|
||||
Verdict: verdict,
|
||||
})
|
||||
if genErr == nil {
|
||||
return result, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Local tier.
|
||||
result, genErr := o.localRun(ctx, tierReq)
|
||||
dur := time.Since(start).Milliseconds()
|
||||
|
||||
if genErr != nil {
|
||||
o.appendAttempt(AttemptRecord{
|
||||
Model: entry.Model,
|
||||
Tier: entry.Tier,
|
||||
DurationMs: dur,
|
||||
WarmStart: warm,
|
||||
Verdict: "error",
|
||||
Feedback: genErr.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
verdict, verErr := o.verifier.Verify(ctx, req.SkillPrompt, taskPrompt, result)
|
||||
if verErr != nil {
|
||||
// Treat verifier failure as escalate (safe default).
|
||||
o.appendAttempt(AttemptRecord{
|
||||
Model: entry.Model,
|
||||
Tier: entry.Tier,
|
||||
DurationMs: dur,
|
||||
WarmStart: warm,
|
||||
Verdict: "escalate",
|
||||
Feedback: "verifier error: " + verErr.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if verdict.Accept {
|
||||
o.appendAttempt(AttemptRecord{
|
||||
Model: entry.Model,
|
||||
Tier: entry.Tier,
|
||||
DurationMs: dur,
|
||||
WarmStart: warm,
|
||||
Verdict: "accept",
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
o.appendAttempt(AttemptRecord{
|
||||
Model: entry.Model,
|
||||
Tier: entry.Tier,
|
||||
DurationMs: dur,
|
||||
WarmStart: warm,
|
||||
Verdict: "escalate",
|
||||
Feedback: verdict.Feedback,
|
||||
})
|
||||
// Inject verifier feedback into the next tier's task prompt.
|
||||
taskPrompt = taskPrompt + "\n\nPrior attempt feedback: " + verdict.Feedback
|
||||
}
|
||||
|
||||
return Result{}, fmt.Errorf("all tiers exhausted after %d attempt(s)", len(o.chain))
|
||||
}
|
||||
|
||||
func (o *Orchestrator) appendAttempt(rec AttemptRecord) {
|
||||
if o.attempts != nil {
|
||||
*o.attempts = append(*o.attempts, rec)
|
||||
}
|
||||
}
|
||||
|
||||
// probeWarm checks whether the model is currently loaded in llama-swap.
|
||||
// Returns false on any error or if llamaSwapURL is empty.
|
||||
func (o *Orchestrator) probeWarm(model string) bool {
|
||||
if o.llamaSwapURL == "" {
|
||||
return false
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, o.llamaSwapURL+"/v1/models", nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close() //nolint:errcheck
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(string(body), model)
|
||||
}
|
||||
Reference in New Issue
Block a user