From 9a258ca32a795319f11f3791e4309c827ad53fa1 Mon Sep 17 00:00:00 2001
From: Mathias Bergqvist
Date: Mon, 4 May 2026 22:51:01 +0200
Subject: [PATCH] feat(routing): router dispatch wrapper

Composes Fetcher + Policy + Logger + CompleteFunc into a single Run
method. Falls open to Claude on local-model errors unless the context is
already done; defaults to local when brain is unreachable. Skill packages
will wrap Router.Run in an adapter to use it as their CompleteFunc.

Co-Authored-By: Claude Sonnet 4.6
---
Review note: the hunks below were edited by hand (context-aware fallback,
logged LogDecision failures, one extra test); the index hashes are stale.

 internal/routing/router.go      | 106 ++++++++++++++++++++++
 internal/routing/router_test.go | 161 ++++++++++++++++++++++++++++++++
 2 files changed, 267 insertions(+)
 create mode 100644 internal/routing/router.go
 create mode 100644 internal/routing/router_test.go

diff --git a/internal/routing/router.go b/internal/routing/router.go
new file mode 100644
index 0000000..c808f66
--- /dev/null
+++ b/internal/routing/router.go
@@ -0,0 +1,106 @@
+package routing
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+)
+
+// CompleteFunc matches the signature used by every skill package's Config.
+// It returns the completion text, a duration that Run logs as DurationMs, and
+// an error.
+type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
+
+// RunInput captures the per-call inputs the dispatch wrapper needs.
+type RunInput struct {
+	Skill       string // skill name; keys the pass-rate lookup and is logged
+	System      string // system prompt passed to Complete
+	User        string // user prompt passed to Complete
+	SessionID   string // copied into every log entry
+	ProjectRoot string // copied into every log entry
+}
+
+// Router composes a pass-rate fetcher, a decision policy, a session logger,
+// and a LiteLLM client. Run takes a RunInput rather than the
+// (model, system, user) arguments of CompleteFunc, so skill packages need a
+// small adapter around Router.Run to use it as their CompleteFunc.
+type Router struct {
+	Fetcher     *Fetcher
+	Logger      *Logger
+	Policy      Policy
+	LocalModel  string
+	ClaudeModel string
+	Complete    CompleteFunc
+}
+
+// Run executes one skill call: it decides local vs claude, calls Complete with
+// the chosen model, and logs the decision.
+//
+// If the pass rate cannot be fetched, the policy is consulted with a nil pass
+// rate. If the local model fails, Run falls open by retrying once on the
+// Claude model, unless ctx is already cancelled or past its deadline: the
+// caller has stopped waiting, so a retry would only add a wasted call and a
+// misleading claude_fallback log entry.
+//
+// It returns the output, duration and error of the last Complete call made.
+func (r *Router) Run(ctx context.Context, in RunInput) (string, int64, error) {
+	pr, ferr := r.Fetcher.Get(ctx, in.Skill)
+	if ferr != nil {
+		slog.Warn("router: pass-rate unreachable, defaulting to local", "skill", in.Skill, "err", ferr)
+		pr = nil
+	}
+	decision := r.Policy.Decide(pr, CanonicalHash(in.System, in.User))
+
+	model := r.ClaudeModel
+	if decision == DecideLocal {
+		model = r.LocalModel
+	}
+
+	out, ms, err := r.Complete(ctx, model, in.System, in.User)
+	r.logDecision(ctx, LogEntry{
+		SessionID:   in.SessionID,
+		Skill:       in.Skill,
+		Decision:    decision.String(),
+		Message:     fmt.Sprintf("model=%s, pass_rate=%s", model, formatPassRate(pr)),
+		ProjectRoot: in.ProjectRoot,
+		DurationMs:  ms,
+		Failed:      err != nil,
+	})
+	if err == nil || decision != DecideLocal {
+		return out, ms, err
+	}
+
+	if ctxErr := ctx.Err(); ctxErr != nil {
+		slog.Warn("router: local failed, context done, skipping claude fallback", "skill", in.Skill, "err", err, "ctx_err", ctxErr)
+		return out, ms, err
+	}
+
+	slog.Warn("router: local failed, falling open to claude", "skill", in.Skill, "err", err)
+	out, ms, err = r.Complete(ctx, r.ClaudeModel, in.System, in.User)
+	r.logDecision(ctx, LogEntry{
+		SessionID:   in.SessionID,
+		Skill:       in.Skill,
+		Decision:    "claude_fallback",
+		Message:     fmt.Sprintf("model=%s, after-local-error", r.ClaudeModel),
+		ProjectRoot: in.ProjectRoot,
+		DurationMs:  ms,
+		Failed:      err != nil,
+	})
+	return out, ms, err
+}
+
+// logDecision records entry through the session logger. Logging is
+// best-effort: a failure is reported via slog and never fails the skill call.
+func (r *Router) logDecision(ctx context.Context, entry LogEntry) {
+	if err := r.Logger.LogDecision(ctx, entry); err != nil {
+		slog.Warn("router: logging decision failed", "skill", entry.Skill, "decision", entry.Decision, "err", err)
+	}
+}
+
+// formatPassRate renders pr with two decimals, or "null" when it is nil.
+func formatPassRate(pr *float64) string {
+	if pr == nil {
+		return "null"
+	}
+	return fmt.Sprintf("%.2f", *pr)
+}
diff --git a/internal/routing/router_test.go b/internal/routing/router_test.go
new file mode 100644
index 0000000..269fbc0
--- /dev/null
+++ b/internal/routing/router_test.go
@@ -0,0 +1,161 @@
+package routing_test
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"net/http"
+	"net/http/httptest"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/mathiasbq/supervisor/internal/routing"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// fakeLLM is a CompleteFunc stand-in that records every call it receives.
+type fakeLLM struct {
+	mu    sync.Mutex
+	calls []struct{ Model, System, User string }
+	resp  string
+	err   error
+	errOn string // if non-empty, only the named model errors
+}
+
+// Complete records the call, then fails with err when errOn names the model,
+// or for every model when err is set and errOn is empty.
+func (f *fakeLLM) Complete(_ context.Context, model, system, user string) (string, int64, error) {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	f.calls = append(f.calls, struct{ Model, System, User string }{model, system, user})
+	if f.errOn == model {
+		return "", 0, f.err
+	}
+	if f.err != nil && f.errOn == "" {
+		return "", 0, f.err
+	}
+	return f.resp, 100, nil
+}
+
+// newRouter builds a Router wired to a fake brain server that reports passRate.
+// Both returned servers are the same brain instance, closed via t.Cleanup.
+func newRouter(t *testing.T, llm *fakeLLM, passRate float64) (*routing.Router, *httptest.Server, *httptest.Server) {
+	t.Helper()
+	brain := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/pass-rate":
+			_ = json.NewEncoder(w).Encode(map[string]any{"pass_rate": passRate})
+		case "/mcp":
+			_ = json.NewEncoder(w).Encode(map[string]any{"jsonrpc": "2.0", "id": 1, "result": map[string]any{}})
+		}
+	}))
+	t.Cleanup(brain.Close)
+
+	r := &routing.Router{
+		Fetcher:     routing.NewFetcher(brain.URL, "7d", time.Minute),
+		Logger:      routing.NewLogger(brain.URL),
+		Policy:      routing.Policy{Floor: 0.9, Ceil: 0.7},
+		LocalModel:  "qwen35",
+		ClaudeModel: "claude-sonnet-4-6",
+		Complete:    llm.Complete,
+	}
+	return r, brain, brain
+}
+
+func TestRouterRoutesLocalAtHighPassRate(t *testing.T) {
+	llm := &fakeLLM{resp: "ok"}
+	r, _, _ := newRouter(t, llm, 0.95)
+
+	out, _, err := r.Run(context.Background(), routing.RunInput{
+		Skill: "code_review", System: "sys", User: "user", SessionID: "s1", ProjectRoot: "/p",
+	})
+	require.NoError(t, err)
+	assert.Equal(t, "ok", out)
+
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+	require.Len(t, llm.calls, 1)
+	assert.Equal(t, "qwen35", llm.calls[0].Model)
+}
+
+func TestRouterRoutesClaudeAtLowPassRate(t *testing.T) {
+	llm := &fakeLLM{resp: "ok"}
+	r, _, _ := newRouter(t, llm, 0.3)
+
+	_, _, err := r.Run(context.Background(), routing.RunInput{
+		Skill: "code_review", System: "sys", User: "user", SessionID: "s2",
+	})
+	require.NoError(t, err)
+
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+	require.Len(t, llm.calls, 1)
+	assert.Equal(t, "claude-sonnet-4-6", llm.calls[0].Model)
+}
+
+func TestRouterFailsOpenLocalErrorToClaude(t *testing.T) {
+	llm := &fakeLLM{resp: "ok-after-fallback", err: errors.New("local boom"), errOn: "qwen35"}
+	r, _, _ := newRouter(t, llm, 0.95) // would route local
+
+	out, _, err := r.Run(context.Background(), routing.RunInput{
+		Skill: "code_review", System: "sys", User: "user", SessionID: "s3",
+	})
+	require.NoError(t, err)
+	assert.Equal(t, "ok-after-fallback", out)
+
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+	require.Len(t, llm.calls, 2)
+	assert.Equal(t, "qwen35", llm.calls[0].Model)
+	assert.Equal(t, "claude-sonnet-4-6", llm.calls[1].Model)
+}
+
+func TestRouterDefaultsToLocalWhenBrainUnreachable(t *testing.T) {
+	// Brain returns 500 → fetcher errors → router treats pass rate as nil → local.
+	brain := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		http.Error(w, "down", http.StatusInternalServerError)
+	}))
+	defer brain.Close()
+
+	llm := &fakeLLM{resp: "ok"}
+	r := &routing.Router{
+		Fetcher:     routing.NewFetcher(brain.URL, "7d", time.Minute),
+		Logger:      routing.NewLogger(brain.URL),
+		Policy:      routing.Policy{Floor: 0.9, Ceil: 0.7},
+		LocalModel:  "qwen35",
+		ClaudeModel: "claude-sonnet-4-6",
+		Complete:    llm.Complete,
+	}
+
+	_, _, err := r.Run(context.Background(), routing.RunInput{
+		Skill: "code_review", System: "sys", User: "user", SessionID: "s4",
+	})
+	require.NoError(t, err)
+
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+	require.Len(t, llm.calls, 1)
+	assert.Equal(t, "qwen35", llm.calls[0].Model)
+}
+
+func TestRouterSkipsFallbackWhenContextDone(t *testing.T) {
+	llm := &fakeLLM{err: errors.New("local boom"), errOn: "qwen35"}
+	r, _, _ := newRouter(t, llm, 0.95) // would route local
+
+	// A cancelled context means the caller has stopped waiting, so the router
+	// must return the local error instead of spending a Claude call.
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	_, _, err := r.Run(ctx, routing.RunInput{
+		Skill: "code_review", System: "sys", User: "user", SessionID: "s5",
+	})
+	require.Error(t, err)
+
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+	require.Len(t, llm.calls, 1)
+	assert.Equal(t, "qwen35", llm.calls[0].Model)
+}