feat(routing): router dispatch wrapper
Composes Fetcher + Policy + Logger + CompleteFunc into a single Run method. Falls open to Claude on local-model errors; defaults to local when brain is unreachable. Skill packages will receive Router.Run as their CompleteFunc. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
80
internal/routing/router.go
Normal file
80
internal/routing/router.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CompleteFunc matches the signature used by every skill package's Config.
|
||||||
|
type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
|
||||||
|
|
||||||
|
// RunInput captures the per-call inputs the dispatch wrapper needs.
|
||||||
|
type RunInput struct {
|
||||||
|
Skill string
|
||||||
|
System string
|
||||||
|
User string
|
||||||
|
SessionID string
|
||||||
|
ProjectRoot string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router composes a pass-rate fetcher, a decision policy, a session logger,
|
||||||
|
// and a LiteLLM client. Skill packages receive Router.Run as their CompleteFunc.
|
||||||
|
type Router struct {
|
||||||
|
Fetcher *Fetcher
|
||||||
|
Logger *Logger
|
||||||
|
Policy Policy
|
||||||
|
LocalModel string
|
||||||
|
ClaudeModel string
|
||||||
|
Complete CompleteFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run executes one skill call: decides local vs claude, calls LiteLLM, logs the
|
||||||
|
// decision. On local-side error, falls open by retrying once on the Claude model.
|
||||||
|
func (r *Router) Run(ctx context.Context, in RunInput) (string, int64, error) {
|
||||||
|
pr, ferr := r.Fetcher.Get(ctx, in.Skill)
|
||||||
|
if ferr != nil {
|
||||||
|
slog.Warn("router: pass-rate unreachable, defaulting to local", "skill", in.Skill, "err", ferr)
|
||||||
|
pr = nil
|
||||||
|
}
|
||||||
|
hash := CanonicalHash(in.System, in.User)
|
||||||
|
decision := r.Policy.Decide(pr, hash)
|
||||||
|
|
||||||
|
model := r.ClaudeModel
|
||||||
|
if decision == DecideLocal {
|
||||||
|
model = r.LocalModel
|
||||||
|
}
|
||||||
|
|
||||||
|
out, ms, err := r.Complete(ctx, model, in.System, in.User)
|
||||||
|
_ = r.Logger.LogDecision(ctx, LogEntry{
|
||||||
|
SessionID: in.SessionID,
|
||||||
|
Skill: in.Skill,
|
||||||
|
Decision: decision.String(),
|
||||||
|
Message: fmt.Sprintf("model=%s, pass_rate=%s", model, formatPassRate(pr)),
|
||||||
|
ProjectRoot: in.ProjectRoot,
|
||||||
|
DurationMs: ms,
|
||||||
|
Failed: err != nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil && decision == DecideLocal {
|
||||||
|
slog.Warn("router: local failed, falling open to claude", "skill", in.Skill, "err", err)
|
||||||
|
out, ms, err = r.Complete(ctx, r.ClaudeModel, in.System, in.User)
|
||||||
|
_ = r.Logger.LogDecision(ctx, LogEntry{
|
||||||
|
SessionID: in.SessionID,
|
||||||
|
Skill: in.Skill,
|
||||||
|
Decision: "claude_fallback",
|
||||||
|
Message: fmt.Sprintf("model=%s, after-local-error", r.ClaudeModel),
|
||||||
|
ProjectRoot: in.ProjectRoot,
|
||||||
|
DurationMs: ms,
|
||||||
|
Failed: err != nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out, ms, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatPassRate(pr *float64) string {
|
||||||
|
if pr == nil {
|
||||||
|
return "null"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%.2f", *pr)
|
||||||
|
}
|
||||||
136
internal/routing/router_test.go
Normal file
136
internal/routing/router_test.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package routing_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mathiasbq/supervisor/internal/routing"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeLLM struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
calls []struct{ Model, System, User string }
|
||||||
|
resp string
|
||||||
|
err error
|
||||||
|
errOn string // if non-empty, only the named model errors
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeLLM) Complete(_ context.Context, model, system, user string) (string, int64, error) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
f.calls = append(f.calls, struct{ Model, System, User string }{model, system, user})
|
||||||
|
if f.errOn == model {
|
||||||
|
return "", 0, f.err
|
||||||
|
}
|
||||||
|
if f.err != nil && f.errOn == "" {
|
||||||
|
return "", 0, f.err
|
||||||
|
}
|
||||||
|
return f.resp, 100, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRouter(t *testing.T, llm *fakeLLM, passRate float64) (*routing.Router, *httptest.Server, *httptest.Server) {
|
||||||
|
t.Helper()
|
||||||
|
brain := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/pass-rate":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{"pass_rate": passRate})
|
||||||
|
case "/mcp":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{"jsonrpc": "2.0", "id": 1, "result": map[string]any{}})
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
t.Cleanup(brain.Close)
|
||||||
|
|
||||||
|
r := &routing.Router{
|
||||||
|
Fetcher: routing.NewFetcher(brain.URL, "7d", time.Minute),
|
||||||
|
Logger: routing.NewLogger(brain.URL),
|
||||||
|
Policy: routing.Policy{Floor: 0.9, Ceil: 0.7},
|
||||||
|
LocalModel: "qwen35",
|
||||||
|
ClaudeModel: "claude-sonnet-4-6",
|
||||||
|
Complete: llm.Complete,
|
||||||
|
}
|
||||||
|
return r, brain, brain
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRoutesLocalAtHighPassRate(t *testing.T) {
|
||||||
|
llm := &fakeLLM{resp: "ok"}
|
||||||
|
r, _, _ := newRouter(t, llm, 0.95)
|
||||||
|
|
||||||
|
out, _, err := r.Run(context.Background(), routing.RunInput{
|
||||||
|
Skill: "code_review", System: "sys", User: "user", SessionID: "s1", ProjectRoot: "/p",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "ok", out)
|
||||||
|
|
||||||
|
llm.mu.Lock()
|
||||||
|
defer llm.mu.Unlock()
|
||||||
|
require.Len(t, llm.calls, 1)
|
||||||
|
assert.Equal(t, "qwen35", llm.calls[0].Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRoutesClaudeAtLowPassRate(t *testing.T) {
|
||||||
|
llm := &fakeLLM{resp: "ok"}
|
||||||
|
r, _, _ := newRouter(t, llm, 0.3)
|
||||||
|
|
||||||
|
_, _, err := r.Run(context.Background(), routing.RunInput{
|
||||||
|
Skill: "code_review", System: "sys", User: "user", SessionID: "s2",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
llm.mu.Lock()
|
||||||
|
defer llm.mu.Unlock()
|
||||||
|
require.Len(t, llm.calls, 1)
|
||||||
|
assert.Equal(t, "claude-sonnet-4-6", llm.calls[0].Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterFailsOpenLocalErrorToClaude(t *testing.T) {
|
||||||
|
llm := &fakeLLM{resp: "ok-after-fallback", err: errors.New("local boom"), errOn: "qwen35"}
|
||||||
|
r, _, _ := newRouter(t, llm, 0.95) // would route local
|
||||||
|
|
||||||
|
out, _, err := r.Run(context.Background(), routing.RunInput{
|
||||||
|
Skill: "code_review", System: "sys", User: "user", SessionID: "s3",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "ok-after-fallback", out)
|
||||||
|
|
||||||
|
llm.mu.Lock()
|
||||||
|
defer llm.mu.Unlock()
|
||||||
|
require.Len(t, llm.calls, 2)
|
||||||
|
assert.Equal(t, "qwen35", llm.calls[0].Model)
|
||||||
|
assert.Equal(t, "claude-sonnet-4-6", llm.calls[1].Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterDefaultsToLocalWhenBrainUnreachable(t *testing.T) {
|
||||||
|
// Brain returns 500 → fetcher errors → router treats pass rate as nil → local.
|
||||||
|
brain := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
http.Error(w, "down", http.StatusInternalServerError)
|
||||||
|
}))
|
||||||
|
defer brain.Close()
|
||||||
|
|
||||||
|
llm := &fakeLLM{resp: "ok"}
|
||||||
|
r := &routing.Router{
|
||||||
|
Fetcher: routing.NewFetcher(brain.URL, "7d", time.Minute),
|
||||||
|
Logger: routing.NewLogger(brain.URL),
|
||||||
|
Policy: routing.Policy{Floor: 0.9, Ceil: 0.7},
|
||||||
|
LocalModel: "qwen35",
|
||||||
|
ClaudeModel: "claude-sonnet-4-6",
|
||||||
|
Complete: llm.Complete,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err := r.Run(context.Background(), routing.RunInput{
|
||||||
|
Skill: "code_review", System: "sys", User: "user", SessionID: "s4",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
llm.mu.Lock()
|
||||||
|
defer llm.mu.Unlock()
|
||||||
|
require.Len(t, llm.calls, 1)
|
||||||
|
assert.Equal(t, "qwen35", llm.calls[0].Model)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user