diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d0a6fd4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,46 @@ +# Binary (anchor to repo root so cmd/codecrucible/ is not ignored) +/codecrucible +/bin/ + +# Go +coverage.out +coverage.html + +# Environment +.env +.env.bak +.env.* + +# IDE / Editor +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS +.DS_Store +Thumbs.db + +# Python (legacy) +.venv/ +__pycache__/ +*.pyc +.pytest_cache/ + +# Build artifacts +dist/ + +# GEPA / experiment artifacts +scripts/ +output/ +.desloppify/ + +# Misc +~/ + +# agent +.claude/ +.amp/ +.perles* +.beads* diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..df7a4af --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,40 @@ +# Agent Instructions + +This project uses **bd** (beads) for issue tracking. Run `bd onboard` to get started. + +## Quick Reference + +```bash +bd ready # Find available work +bd show # View issue details +bd update --status in_progress # Claim work +bd close # Complete work +bd sync # Sync with git +``` + +## Landing the Plane (Session Completion) + +**When ending a work session**, you MUST complete ALL steps below. Work is NOT complete until `git push` succeeds. + +**MANDATORY WORKFLOW:** + +1. **File issues for remaining work** - Create issues for anything that needs follow-up +2. **Run quality gates** (if code changed) - Tests, linters, builds +3. **Update issue status** - Close finished work, update in-progress items +4. **PUSH TO REMOTE** - This is MANDATORY: + ```bash + git pull --rebase + bd sync + git push + git status # MUST show "up to date with origin" + ``` +5. **Clean up** - Clear stashes, prune remote branches +6. **Verify** - All changes committed AND pushed +7. 
**Hand off** - Provide context for next session + +**CRITICAL RULES:** +- Work is NOT complete until `git push` succeeds +- NEVER stop before pushing - that leaves work stranded locally +- NEVER say "ready to push when you are" - YOU must push +- If push fails, resolve and retry until it succeeds + diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..0278d36 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,32 @@ +# Stage 1: Build +FROM golang:1.23-alpine AS builder + +RUN apk add --no-cache git + +WORKDIR /build + +# Cache module downloads. +COPY go.mod go.sum ./ +RUN go mod download + +# Copy source and build. +COPY . . + +ARG VERSION=dev +ARG COMMIT=none +ARG DATE=unknown + +RUN CGO_ENABLED=0 GOOS=linux go build \ + -ldflags "-X 'github.com/block/codecrucible/internal/cli.version=${VERSION}' \ + -X 'github.com/block/codecrucible/internal/cli.commit=${COMMIT}' \ + -X 'github.com/block/codecrucible/internal/cli.date=${DATE}'" \ + -o codecrucible ./cmd/codecrucible + +# Stage 2: Runtime +FROM gcr.io/distroless/static-debian12:nonroot + +COPY --from=builder /build/codecrucible /usr/local/bin/codecrucible +COPY --from=builder /build/prompts /prompts + +ENTRYPOINT ["codecrucible"] +CMD ["scan", "--prompts-dir", "/prompts/default"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..17a5f3e --- /dev/null +++ b/Makefile @@ -0,0 +1,53 @@ +BINARY := codecrucible +MODULE := github.com/block/codecrucible +CMD := ./cmd/codecrucible + +VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") +COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo "none") +DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ") + +LDFLAGS := -X '$(MODULE)/internal/cli.version=$(VERSION)' \ + -X '$(MODULE)/internal/cli.commit=$(COMMIT)' \ + -X '$(MODULE)/internal/cli.date=$(DATE)' + +.PHONY: build test lint clean coverage docker-build docker-test fmt vet + +build: + go build -ldflags "$(LDFLAGS)" -o $(BINARY) $(CMD) + +test: + go test 
-race -count=1 ./... + +lint: + @if command -v golangci-lint >/dev/null 2>&1; then \ + golangci-lint run ./...; \ + else \ + echo "golangci-lint not installed, running go vet"; \ + go vet ./...; \ + fi + +coverage: + go test -race -coverprofile=coverage.out -covermode=atomic ./... + go tool cover -func=coverage.out + @echo "---" + @echo "HTML report: go tool cover -html=coverage.out -o coverage.html" + +fmt: + gofmt -w . + +vet: + go vet ./... + +docker-build: + docker build \ + --build-arg VERSION=$(VERSION) \ + --build-arg COMMIT=$(COMMIT) \ + --build-arg DATE=$(DATE) \ + -t $(BINARY):$(VERSION) \ + -t $(BINARY):latest . + +docker-test: + docker run --rm $(BINARY):latest --version + +clean: + rm -f $(BINARY) coverage.out coverage.html diff --git a/README.md b/README.md index 96ed047..fef99fc 100644 --- a/README.md +++ b/README.md @@ -1,36 +1,649 @@ -# codecrucible README +# codecrucible -Congrats, project leads! You got a new project to grow! +A purpose-built Go CLI tool that analyzes Git repositories for security vulnerabilities using LLM-based analysis and produces [SARIF v2.1.0](https://docs.oasis-open.org/sarif/sarif/v2.1.0/sarif-v2.1.0.html) output for GitHub Code Scanning integration. -This stub is meant to help you form a strong community around your work. It's yours to adapt, and may -diverge from this initial structure. Just keep the files seeded in this repo, and the rest is yours to evolve! +## Overview -## Introduction +codecrucible replaces the original Python/repomix pipeline with a single static Go binary. Key improvements: -Orient users to the project here. This is a good place to start with an assumption -that the user knows very little - so start with the Big Picture and show how this -project fits into it. 
+- **No Node.js/Python runtime** — single binary, distroless Docker image (<50MB) +- **Structured output enforcement** — JSON Schema (`response_format`) for GPT/Gemini, `tool_use` for Claude +- **Per-phase LLM configuration** — run feature detection, analysis, and audit on different models, providers, keys, and params +- **Streaming responses** — SSE for Anthropic keeps long generations alive past edge idle timeouts +- **Token-aware chunking** — large repos are split into budget-safe chunks with cross-file manifests +- **Retry with backoff** — exponential backoff on 429/5xx, `Retry-After` header respect +- **Aggressive filtering** — test, vendor, binary, and doc file exclusion saves 20–40% token budget +- **Valid SARIF every time** — even on LLM failure, partial results produce schema-valid SARIF -Then maybe a dive into what this project does. +## Quick Start -Diagrams and other visuals are helpful here. Perhaps code snippets showing usage. +```bash +# Build +make build -Project leads should complete, alongside this `README`: +# Databricks-backed scan +export DATABRICKS_HOST=https://your-workspace.databricks.com +export DATABRICKS_TOKEN=your-token -* [CODEOWNERS](./CODEOWNERS) - set project lead(s) -* [CONTRIBUTING.md](./CONTRIBUTING.md) - Fill out how to: install prereqs, build, test, run, access CI, chat, discuss, file issues -* [Bug-report.md](.github/ISSUE_TEMPLATE/bug-report.md) - Fill out `Assignees` add codeowners @names -* [config.yml](.github/ISSUE_TEMPLATE/config.yml) - remove "(/add your discord channel..)" and replace the url with your Discord channel if applicable +./codecrucible scan /path/to/repo --output results.sarif -The other files in this template repo may be used as-is: +# Direct Anthropic API scan +export ANTHROPIC_API_KEY=your-anthropic-key +./codecrucible scan /path/to/repo --provider anthropic --model claude-sonnet-4-6 -* [GOVERNANCE.md](./GOVERNANCE.md) -* [LICENSE](./LICENSE) +# Anthropic API scan with adaptive thinking always enabled 
+./codecrucible scan /path/to/repo \ + --provider anthropic \ + --model claude-sonnet-4-6 \ + --model-params '{"thinking":{"type":"enabled","budget_tokens":4096}}' -## Project Resources +# Or use Claude Code CLI auth (SSO/login) with no API key +claude auth status +./codecrucible scan /path/to/repo --provider anthropic --model claude-sonnet-4-6 -| Resource | Description | -| ------------------------------------------ | ------------------------------------------------------------------------------ | -| [CODEOWNERS](./CODEOWNERS) | Outlines the project lead(s) | -| [GOVERNANCE.md](./GOVERNANCE.md) | Project governance | -| [LICENSE](./LICENSE) | Apache License, Version 2.0 | +# In Claude CLI auth mode, Anthropic beta headers are forwarded via `claude --betas` +# (for example: --custom-headers "anthropic-beta: context-1m-2025-08-07"). + +# Direct OpenAI API scan +export OPENAI_API_KEY=your-openai-key +./codecrucible scan /path/to/repo --provider openai --model gpt-5.2 + +# Direct Google Gemini scan (OpenAI-compat endpoint) +export GOOGLE_API_KEY=your-google-key +./codecrucible scan /path/to/repo --provider google --model gemini-3-pro + +# Mix providers per phase: opus for analysis, gemini for audit +export ANTHROPIC_API_KEY=your-anthropic-key +export GOOGLE_API_KEY=your-google-key +./codecrucible scan /path/to/repo \ + --model claude-opus-4-6 \ + --audit-provider google --audit-model gemini-3-pro \ + --fd-model gemini-3-flash --fd-provider google + +# Preview scope without making API calls +./codecrucible scan /path/to/repo --dry-run + +# Scan specific paths in a monorepo +./codecrucible scan /path/to/repo --paths src/ --paths lib/ +``` + +When `--output results.sarif` writes to a file, CodeCrucible also writes +phase artifacts beside it: `results.feature-detection.json`, +`results.analysis.sarif`, and `results.audit.sarif` (when audit runs). Use +`--phase-output-dir DIR` to choose an explicit artifact directory, including +for stdout workflows. 
+ +## Installation + +### From Source + +```bash +git clone https://github.com/block/codecrucible.git && cd codecrucible +make build +``` + +### Docker + +```bash +make docker-build +docker run --rm \ + -e DATABRICKS_HOST=https://your-workspace.databricks.com \ + -e DATABRICKS_TOKEN=your-token \ + -v /path/to/repo:/repo:ro \ + codecrucible:latest scan /repo --output /dev/stdout +``` + +## Usage + +``` +codecrucible scan [repo-path] [flags] +``` + +### Per-phase LLM selection + +Four symmetric flag families. The unprefixed flags configure the analysis +phase; feature-detection, audit, and context-compress inherit any knob +they don't set. + +| analysis | feature-detection *(alias: `--fd-*`)* | audit | context-compress *(alias: `--cc-*`)* | +|-----------------------|------------------------------------------|-------------------------|-----------------------------------------| +| `--model` | `--feature-detection-model` | `--audit-model` | `--context-compress-model` | +| `--provider` | `--feature-detection-provider` | `--audit-provider` | `--context-compress-provider` | +| `--api-key` | `--feature-detection-api-key` | `--audit-api-key` | `--context-compress-api-key` | +| `--base-url` | `--feature-detection-base-url` | `--audit-base-url` | `--context-compress-base-url` | +| `--model-params` | `--feature-detection-model-params` | `--audit-model-params` | `--context-compress-model-params` | + +Providers: `anthropic`, `openai`, `google`, `ollama`, `openai-compat`, `databricks`. Auto-detected from env vars when unset. + +### Custom / Local LLMs + +Use `--provider ollama` for Ollama (no API key needed, defaults to `localhost:11434`): + +```bash +codecrucible scan --provider ollama --model llama3.1:70b --context-limit 131072 . +``` + +Use `--provider openai-compat` for any OpenAI-compatible API (vLLM, LM Studio, text-generation-inference): + +```bash +codecrucible scan --provider openai-compat --model my-model --base-url http://localhost:8000 . 
+``` + +### Per-Phase Overrides + +Each pipeline phase (analysis, feature-detection, audit, context-compress) can use a different provider/model. Per-phase flags are hidden from `--help` for clarity but work as documented: + +```bash +# Cheap model for feature detection, expensive model for analysis +codecrucible scan --model claude-opus-4-6 --feature-detection-model claude-sonnet-4-6 . +``` + +Per-phase flags follow the pattern `--{phase}-{flag}` (e.g. `--audit-model`, `--audit-provider`, `--audit-api-key`, `--audit-base-url`). Short aliases: `--fd-*` for feature-detection, `--cc-*` for context-compress. + +### Everything else + +``` + --audit-batch-size int split audit into N-finding batches (default 25) + --audit-confidence-threshold float reject findings below this confidence (default 0.3) + --base-url string override default provider URL + --compress compress whitespace in source files to save tokens + --concurrency int max parallel chunks (default 3) + --context-budget-pct int % of context window for supplementary context (default 15, max 40) + --context-limit int override model context window in tokens (0 = model default) + --context-source strings supplementary context: name=X,type=,location=Y + --custom-headers strings extra HTTP headers, format 'Name: Value' + --custom-requirements string additional requirements appended to the prompt + --dry-run preview scope and cost without API calls + --exclude strings glob patterns to exclude + --fail-on-severity float exit code 2 if any finding >= this severity (0-10) + --include strings glob patterns to force-include + --include-docs include documentation files in analysis + --include-tests include test files in analysis + --max-cost float maximum cost budget in dollars (default 25) + --max-file-size int exclude files larger than this (default 102400) + --max-output-tokens int override model max output tokens (0 = model default) + -o, --output string write SARIF to file (default: stdout) + --phase-output-dir string 
write per-phase artifacts to this directory + --paths strings paths within the repo to analyze + --prompts-dir string prompt set directory (default: prompts/default) + --request-timeout int HTTP timeout in seconds (0 = default 600s) + --skip-audit skip CWE-specific audit phase + --skip-feature-detection skip feature detection pre-pass + +Global Flags: + --config string config file (default: .codecrucible.yaml) + --verbose enable debug logging +``` + +## Prompt Sets + +The `prompts/` directory contains multiple prompt sets, each a complete set of YAML templates that control how the LLM analyzes code. The default set is `prompts/default/`. + +To use a different prompt set: + +```bash +codecrucible scan --prompts-dir prompts/carlini . +``` + +Available sets: + +| Set | Description | +|-----|-------------| +| `default` | General-purpose security analysis (used when no `--prompts-dir` is specified) | +| `carlini` | Slim Carlini-style adversarial CTF-researcher prompts | +| `carlini-curated` | Carlini v2 with targeted suppression rules and line-precision guidance | +| `exploit-proof` | Language-agnostic set that requires a concrete exploit per finding | +| `exploit-proof-c-kernel` | Kernel / driver / systems C — syscalls, copy{in,out}, locking, refcounts | +| `exploit-proof-c-userland` | C/C++ userland daemons, setuid binaries, parsers | +| `exploit-proof-rust` | Rust (`unsafe`, FFI, serde, integer casts, web frameworks) | +| `exploit-proof-solidity` | Solidity / Vyper — reentrancy, oracle manipulation, access control | +| `exploit-proof-web-go` | Go web services (net/http, gin, chi, echo, fiber, gRPC) | +| `exploit-proof-web-java` | JVM web apps (Spring, Jakarta EE, Quarkus, Micronaut, Ktor) | +| `exploit-proof-web-js` | Node.js backends and JS/TS frontends (Express, Next.js, React, Vue) | +| `exploit-proof-web-python` | Python web apps (Django, Flask, FastAPI, Starlette, Tornado, aiohttp) | +| `nano-analyzer` | Terse attacker-first voice adapted from 
weareaisle/nano-analyzer | + +See [SKILLS.md](SKILLS.md) for a fuller walkthrough of when to reach for each set. + +Each prompt set directory must contain: `security_analysis_base.yaml`, `analysis_sections.yaml`, `feature_detection.yaml`, `audit.yaml`, `cwe_deep_analysis.yaml`, and optionally `context_compress.yaml`. + +## Configuration + +Configuration follows a priority chain: **CLI flags > environment variables > config file > defaults**. + +### Environment Variables + +**Ambient credentials** — cascade to any phase that doesn't set its own key: + +| Variable | Description | +|----------|-------------| +| `DATABRICKS_HOST` | Databricks workspace URL | +| `DATABRICKS_TOKEN` | Bearer token for API authentication | +| `DATABRICKS_ENDPOINT` | Model serving endpoint (overrides `--model`) | +| `ANTHROPIC_API_KEY` | Anthropic API key (optional if Claude Code CLI is installed and logged in) | +| `OPENAI_API_KEY` | OpenAI API key | +| `GOOGLE_API_KEY` / `GEMINI_API_KEY` | Google AI Studio API key | +| `CODECRUCIBLE_PROVIDER` | Provider override (`databricks`, `anthropic`, `openai`, `google`) | +| `CODECRUCIBLE_MODEL_PARAMS` | JSON object merged into model request body | + +**Per-phase overrides** — `PHASES_{PHASE}_{KEY}` where `{PHASE}` is `ANALYSIS`, +`FEATURE_DETECTION`, `AUDIT`, or `CONTEXT_COMPRESS`: + +| Variable | Maps to | +|----------|---------| +| `PHASES_AUDIT_PROVIDER` | `--audit-provider` | +| `PHASES_AUDIT_MODEL` | `--audit-model` | +| `PHASES_AUDIT_API_KEY` | `--audit-api-key` | +| `PHASES_AUDIT_MODEL_PARAMS_JSON` | `--audit-model-params` | +| `PHASES_AUDIT_BASE_URL` | override the provider's default base URL (proxies, Azure, Vertex) | +| `PHASES_AUDIT_ENDPOINT` | Databricks serving-endpoint override for this phase | +| `PHASES_AUDIT_REQUEST_TIMEOUT` | per-phase HTTP timeout in seconds | +| `PHASES_AUDIT_CONTEXT_LIMIT` | per-phase context window override | +| `PHASES_AUDIT_MAX_OUTPUT_TOKENS` | per-phase max output override | + +Same keys with `PHASES_ANALYSIS_*`, 
`PHASES_FEATURE_DETECTION_*`, and +`PHASES_CONTEXT_COMPRESS_*`. Handy for wrapper scripts that inject +per-phase config via env. + +### Config File + +Create `.codecrucible.yaml` in your repo root or home directory: + +```yaml +model: claude-sonnet-4-6 +provider: databricks +include-tests: false +include-docs: false +max-cost: 25 +fail-on-severity: 7.0 +concurrency: 3 +model-params: + thinking: + type: enabled + budget_tokens: 4096 +skip-audit: false +audit-confidence-threshold: 0.3 +exclude: + - "*.generated.go" + - "**/generated/**" +``` + +### Per-Phase Configuration + +The pipeline has three LLM phases: **feature detection** (gating, skipped for +small repos), **analysis** (the main loop), and **audit** (validation). Each +can run on a different provider, model, API key, and params. + +The flat keys above (`model`, `provider`, `model-params`) configure the +analysis phase and are **inherited** by the other two. A `phases:` block +overrides selectively: + +```yaml +# Analysis: claude-opus with extended thinking. Slow, thorough, expensive. +model: claude-opus-4-6 +provider: anthropic +model-params: + thinking: + type: enabled + budget_tokens: 8192 + +phases: + # Feature detection is a cheap gating pass on a file manifest. A small + # fast model is plenty. Skipped entirely when the repo fits in one chunk. + feature-detection: + provider: google + model: gemini-3-flash + api-key: ${GOOGLE_API_KEY} + # NOT setting model-params inherits thinking-mode from analysis, which + # gemini would reject. Setting any params replaces the inherited set + # wholesale (see inheritance rules below) — so put something gemini + # actually wants: + model-params: + max_tokens: 2048 + + # Audit is a validation pass — short, structured output. Dropping + # thinking-mode and capping max_tokens keeps it fast without hurting + # quality. 
+ audit: + model: claude-sonnet-4-6 + model-params: + max_tokens: 8192 + # provider, api-key: inherited from analysis (anthropic + its key) +``` + +**Inheritance rules** + +- Any per-phase field left at its zero value inherits from the analysis phase. +- `model-params` inherits on empty; a phase that sets its own params gets + **exactly** those params (replace, not merge) — so you can drop inherited + keys. +- `--context-limit` / `--max-output-tokens` inherit per-phase too. Previously + they only applied to the main model; now `--audit-model gemini-3-pro` with + `--context-limit 500000` gives audit the override as well. + +**Provider resolution**, per phase: explicit `--{phase}-provider`, else the +model registry's hint for that phase's model, else Databricks ambient env +(Databricks proxies all providers), else whichever direct-provider key is +set, else `databricks`. + +### Supplementary Context + +Security review in isolation is pattern-matching. Security review with the +API spec, the threat model, and the sibling repo that implements the other +side of an RPC contract is *understanding*. Supplementary context feeds that +material to the analysis and audit prompts so the model can distinguish +"unvalidated input" from "input validated by the gateway three hops upstream". 
+ +**Source types:** + +| type | `location` is… | notes | +|----------|--------------------------------------|----------------------------------------------| +| `path` | filesystem path (file or directory) | directories go through the ingest walker — `.gitignore`, binary-skip, and `include`/`exclude` globs all apply | +| `repo` | git clone URL | shallow-cloned to a temp dir, then treated as `path` | +| `url` | HTTP(S) URL | 4 MiB cap, HTML stripped to text, non-HTTP schemes and private-IP redirects refused | +| `inline` | the content itself | for short notes — "admin routes are mTLS-gated" | + +**Budget discipline.** Supplementary context shares the context window with +the scan target, so it's capped at `--context-budget-pct` (default 15%, hard +ceiling 40%). When sources exceed the cap: + +1. **Priority packing** (always): sources are sorted by `priority` descending + and packed greedily. The last source that doesn't fully fit is truncated + with a `[... N tokens truncated ...]` marker; anything after is dropped. +2. **LLM compression** (opt-in per source): sources with `compress: true` + that exceed their fair share of the budget go through a one-shot + `context-compress` pre-pass that summarises them down. Runs once per scan + on the `phases.context-compress` model — typically a cheap flash/haiku. + +Why this approach: relevance-filtering and embedding retrieval both assume +you can pick different context per chunk, but supplementary context is shared +across all chunks — the API spec is relevant everywhere. Priority packing is +deterministic and free; LLM compression handles the "200-page API reference" +case without adding an embedding-model dependency. 
+ +```yaml +context-sources: + - name: "Payments API Spec" + type: path + location: ../api-contracts/payments/openapi.yaml + priority: 100 + phases: [analysis, audit] # empty/omitted = both phases + + - name: "Auth SDK" + type: repo + location: git@github.com:org/auth-sdk.git + priority: 80 + include: ["**/*.go"] + compress: true # squeeze via LLM if over budget + + - name: "Threat model" + type: url + location: https://wiki.internal/threat-model/payments + priority: 90 + phases: [audit] # only the auditor needs this + + - name: "Review notes" + type: inline + location: | + The /admin endpoints sit behind mTLS at the gateway. Findings + there must demonstrate gateway bypass, not just handler weakness. + priority: 70 + +context-budget-pct: 15 + +phases: + context-compress: + model: claude-haiku-4-5 # compression is a writing task, not analysis +``` + +On the CLI (scalar fields only — use the config file for globs and phase lists): + +```bash +./codecrucible scan ./target \ + --context-source 'name=spec,type=path,location=../contracts/api.yaml,priority=100' \ + --context-source 'name=notes,type=inline,location=admin is mTLS-gated,priority=50' \ + --context-budget-pct 20 \ + --cc-model claude-haiku-4-5 +``` + +**Guardrails.** Load failures (404, clone error, missing file) log a warning +and skip that source — the scan continues. If context consumes so much of the +window that less than 5000 tokens remain for actual source code, the scan +aborts before any LLM call with a clear error. + +### Model Params + +`model-params` is merged into the top level of the request body — use it for +provider-specific knobs (thinking budgets, reasoning effort, custom safety +settings). + +- YAML map form in config files; JSON string form on the CLI and in env vars. +- When both are present, the JSON string deep-merges onto the map; JSON wins + on conflict. +- Not forwarded when Anthropic falls back to Claude CLI auth. 
+- Unknown keys are passed through; the provider rejects what it doesn't + understand. + +## Supported Models + +| Model | Provider | Context Limit | Max Output | Structured Output | +|-------|----------|--------------|------------|-------------------| +| claude-sonnet-4-6 | anthropic | 200K | 16K | tool_use | +| claude-opus-4-6 | anthropic | 200K | 32K | tool_use | +| claude-opus-4-7 | anthropic | 1M | 128K | tool_use | +| gpt-5.2 | openai | 400K | 16K | response_format JSON Schema | +| gpt-5.4 | openai | 1M | 128K | response_format JSON Schema | +| gpt-5.5 | openai | 1M | 128K | response_format JSON Schema | +| gpt-5.4-mini | openai | 400K | 128K | response_format JSON Schema | +| gpt-5.4-nano | openai | 400K | 128K | response_format JSON Schema | +| gemini-3-pro | google | 1M | 64K | response_format JSON Schema | +| gemini-3-flash | google | 1M | 64K | response_format JSON Schema | + +Gemini goes through Google's OpenAI-compat endpoint +(`generativelanguage.googleapis.com/v1beta/openai`) — Bearer auth, OpenAI +request/response shapes. Gemini-specific features (grounding, code execution) +are not available through this path; set `phases.{phase}.base-url` to point +at Vertex or a proxy if you need them. + +All providers are also reachable via Databricks model serving when +`DATABRICKS_HOST`/`DATABRICKS_TOKEN` are set. + +Unknown models get conservative defaults (128K context, 8192 max output, +unstructured). Override with `--context-limit` / `--max-output-tokens`. + +### Adding models via config + +The table above is the *built-in* registry compiled into the binary. You can +extend or override it from the config file under a `models:` key — useful when +a new model ships, when you run behind a proxy with a custom endpoint, or when +you want to retune pricing / context limits without recompiling. + +```yaml +models: + # Extend: a model the binary doesn't know about yet. 
+ - name: claude-sonnet-4-8 + provider: anthropic + input_price_per_million: 3.0 + output_price_per_million: 15.0 + context_limit: 1000000 + max_output_tokens: 64000 + tokenizer_encoding: claude + supports_structured_output: true + + # Override: change a built-in's pricing without forking. + - name: claude-sonnet-4-6 + provider: anthropic + input_price_per_million: 1.5 # negotiated rate + output_price_per_million: 7.5 + context_limit: 200000 + max_output_tokens: 16384 + tokenizer_encoding: claude + supports_structured_output: true + + # Azure / self-hosted: point at a non-standard endpoint. + - name: azure-gpt-5 + provider: openai-compat + endpoint: deployments/my-azure-deploy/chat/completions + context_limit: 400000 + max_output_tokens: 16384 + tokenizer_encoding: o200k_base +``` + +Entries are keyed by `name`: a user entry sharing a name with a built-in +replaces it wholesale (case-insensitive), a new name extends the registry. +Empty `endpoint` defaults to `/invocations` to match the built-in +convention (Databricks serving path; other providers ignore it). `name` is +required; other fields follow the same YAML schema as the built-in registry. + + +## Architecture + +``` +Repository on disk + │ + ▼ +┌──────────────┐ +│ Walker │ filepath.WalkDir + .gitignore +├──────────────┤ +│ Filter │ Test/vendor/binary/doc exclusion +├──────────────┤ +│ Flattener │ Repomix-compatible XML with line numbers +├──────────────┤ +│ Chunker │ Token-budget-aware splitting by directory +└──────┬───────┘ + │ + │ PhaseConfig: each box has its own (provider, model, + │ api-key, base-url, model-params, timeout). Unset + │ fields inherit from analysis. Resolved once at + │ startup by config.ResolvePhases. + │ + │ repo > one chunk? 
+ ▼ +┌─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┐ ┌──────────────────────────┐ + │ phases.feature-detection │ +│ FEATURE DETECTION (optional) │◀───│ --fd-provider │ + │ --fd-model │ +│ file manifest → features │ │ --fd-api-key │ + │ --fd-model-params │ +└─ ─ ─ ─ ─ ─ ─ ─ ┬ ─ ─ ─ ─ ─ ─ ─ ─┘ └──────────────────────────┘ + │ enabled features → prunes analysis_sections + ▼ +┌─────────────────────────────────┐ ┌──────────────────────────┐ +│ │ │ phases.analysis │ +│ ANALYSIS (main loop) │◀───│ --provider │ +│ │ │ --model │ +│ chunk 1 ──▶ findings │ │ --api-key │ +│ chunk 2 ──▶ findings } merge │ │ --model-params │ +│ chunk N ──▶ findings │ └──────────────────────────┘ +│ (concurrent, --concurrency) │ +│ │ +└────────────────┬────────────────┘ + │ initial findings + CWE categories + ▼ +┌─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┐ ┌──────────────────────────┐ + │ phases.audit │ +│ AUDIT (optional, --skip-audit)│◀───│ --audit-provider │ + │ --audit-model │ +│ findings + code + CWE prompts │ │ --audit-api-key │ +│ → confirm / reject / refine │ │ --audit-model-params │ + └──────────────────────────┘ +└─ ─ ─ ─ ─ ─ ─ ─ ┬ ─ ─ ─ ─ ─ ─ ─ ─┘ + │ audited findings + ▼ +┌─────────────────────────────────┐ +│ SARIF build + merge + dedup │ +└────────────────┬────────────────┘ + ▼ + SARIF v2.1.0 output + + phase artifacts +``` + +Each phase builds its own `llm.Client` from its `PhaseConfig` — there is no +shared client object. A fresh client per phase means the learned endpoint +constraints (`noForcedToolChoice`, `dropTemperature`) reset at phase +boundaries; at most one extra 400→retry per phase when the model rejects +a feature. 
+ +### Project Structure + +``` +cmd/codecrucible/ CLI entry point +internal/ + cli/ Cobra commands (scan, list-models, init), pipeline orchestration + config/ Viper config, model registry, per-phase resolution (phase.go) + ingest/ File walker, filter, XML flattener, import graph + chunk/ Token counting (tiktoken), budget-aware chunking + supctx/ Supplementary-context loaders, priority packing, LLM compression + llm/ HTTP client, prompt templates, JSON schema + sarif/ SARIF types, builder, merger, post-processor, contract tests + logging/ slog-based structured logging +prompts/ Prompt sets (each subdirectory is a complete set of YAML templates) + default/ Default prompt set + carlini/ Carlini-style adversarial prompts + carlini-curated/ Carlini v2 with targeted suppression + exploit-proof/ Language-agnostic concrete-exploit gate + exploit-proof-c-kernel/ Kernel / systems C + exploit-proof-c-userland/ C/C++ userland, setuid, parsers + exploit-proof-rust/ Rust (unsafe, FFI, web frameworks) + exploit-proof-solidity/ Solidity / Vyper smart contracts + exploit-proof-web-go/ Go web services + exploit-proof-web-java/ JVM web applications + exploit-proof-web-js/ JS/TS backends and frontends + exploit-proof-web-python/ Python web applications + nano-analyzer/ Terse attacker-first voice (nano-analyzer port) +testdata/fixtures/ LLM response fixtures for contract tests +``` + +## Exit Codes + +| Code | Meaning | +|------|---------| +| 0 | Success (no findings above threshold) | +| 1 | Error (pipeline failure) | +| 2 | Findings exceed `--fail-on-severity` threshold | + +## CI Integration + +```yaml +# GitHub Actions example +- name: Security scan + run: | + ./codecrucible scan . \ + --output results.sarif \ + --fail-on-severity 7.0 + +- name: Upload SARIF + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: results.sarif +``` + +## Utility Commands + +```bash +# List available models for a provider (alias: list-endpoints). 
+# Provider auto-detected from env; override with --provider. +./codecrucible list-models +./codecrucible list-models --provider anthropic +./codecrucible list-models --provider openai-compat --base-url http://localhost:8000 + +# Write a commented .codecrucible.yaml to the current directory (or given path). +# A worked example, not a schema dump — shows the per-phase override shape. +./codecrucible init +./codecrucible init --force path/to/config.yaml +``` + +## Development + +```bash +make build # Build binary +make test # Run tests with race detector +make lint # golangci-lint (or go vet fallback) +make coverage # Coverage report +make fmt # Format all Go files +make vet # go vet +``` + +## License + +See [LICENSE](LICENSE) for the full license text. diff --git a/SKILLS.md b/SKILLS.md new file mode 100644 index 0000000..56cd7b5 --- /dev/null +++ b/SKILLS.md @@ -0,0 +1,95 @@ +# Prompt Sets + +A "prompt set" is a complete bundle of YAML templates that drives every +LLM phase (feature detection, analysis, audit, CWE deep-dive, optional +context compression) with a consistent persona, threat model, and output +contract. The active set is chosen with `--prompts-dir`; the shipped sets +live under `prompts/` and are swappable without touching Go code. 
+ +```bash +codecrucible scan /path/to/repo --prompts-dir prompts/exploit-proof-web-go +``` + +Each directory contains: + +- `security_analysis_base.yaml` — system prompt and framing (persona, threat + model, reporting rules, schema pointer) +- `analysis_sections.yaml` — the enumerated vulnerability classes the + analyzer walks through; pruned by feature detection on multi-chunk scans +- `feature_detection.yaml` — file-manifest pre-pass that decides which + sections are worth keeping +- `audit.yaml` — second-pass validator that challenges each finding +- `cwe_deep_analysis.yaml` — CWE-specific deep dives used by the audit +- `context_compress.yaml` *(optional)* — one-shot compressor for large + supplementary-context sources + +If a prompt set is missing `context_compress.yaml`, `--cc-*` compression +falls back to whichever template is available at `prompts/default/`. + +## Choosing a set + +| Set | When to reach for it | +|-----|----------------------| +| `default` | First-pass, language-agnostic, two-stage (recall-tuned analyzer + adversarial audit). The right starting point for most repos. | +| `carlini` | Slim CTF-researcher persona. Maximizes code-to-instruction ratio — use when the context budget is tight and you want the model thinking like an attacker, not a linter. | +| `carlini-curated` | Carlini with added suppression rules and tighter line-precision guidance. Lower false-positive rate at modest token cost. | +| `exploit-proof` | Language-agnostic with a *concrete exploit required* per finding. The model's own exploitation reasoning replaces the deny-list — good precision without hand-curated rules. | +| `exploit-proof-c-kernel` | Kernel / driver / protocol parser C. Persona knows syscalls, `copy_{from,to}_user`, mbuf/skbuff chains, locking, RCU, refcount lifecycles. | +| `exploit-proof-c-userland` | C/C++ userland — setuid binaries, privileged daemons, IPC endpoints, file-format parsers. Memory safety + privilege boundaries. 
| +| `exploit-proof-rust` | Rust. Focuses the model on `unsafe`, FFI boundaries, serde deserialization, integer casts, `Send`/`Sync` mistakes, panics as DoS. | +| `exploit-proof-solidity` | Solidity / Vyper. Reentrancy, flash-loan oracle manipulation, access control across upgrades, silent casts. | +| `exploit-proof-web-go` | Go web services — net/http, gin, chi, echo, fiber, gRPC. Go-specific footguns (loop-var capture, nil-interface, concurrent map access). | +| `exploit-proof-web-java` | JVM web — Spring (MVC/WebFlux), Jakarta EE, Quarkus, Micronaut, Ktor. JPA raw queries, readObject, Jackson default-typing, JNDI, XXE. | +| `exploit-proof-web-js` | Node.js backends and JS/TS frontends (Express, Next.js, React, Vue, Angular). Template sinks, DOM sinks, dynamic `require`, shell out. | +| `exploit-proof-web-python` | Django / Flask / FastAPI / Starlette / Tornado / aiohttp. ORM query builders, template engines, `pickle.loads`, subprocess, missing authn/authz. | +| `nano-analyzer` | Terse attacker-first walkthrough adapted from [weareaisle/nano-analyzer](https://github.com/weareaisle/nano-analyzer). Five questions per function, show-your-work style. Pairs well with larger-context reasoning models. | + +The `exploit-proof-*` language variants share the "concrete exploit or +it doesn't ship" gate from the base `exploit-proof` set but swap the +persona and the trace-to-sink list for language-specific hotspots. Pick +the variant that matches the repo's primary language; fall back to the +language-agnostic `exploit-proof` for polyglot codebases. + +## Authoring a new set + +Copy an existing directory and edit in place: + +```bash +cp -r prompts/default prompts/my-set +# tweak system_message, analysis_sections, audit criteria, ... +codecrucible scan ./target --prompts-dir prompts/my-set +``` + +Required files: `security_analysis_base.yaml`, `analysis_sections.yaml`, +`feature_detection.yaml`, `audit.yaml`, `cwe_deep_analysis.yaml`. 
The loader +errors loudly if any of these are missing. + +Things to keep intact when authoring: + +- **Schema conformance.** The analyzer returns JSON matching + `llm.SecurityAnalysisSchema`; the prompt must keep asking for it. Changing + the shape means changing `internal/llm` as well. +- **`{repo_name}` / `{xml_content}` placeholders.** Templated in by + `llm.PromptParams` — drop them and the scan target never reaches the model. +- **Audit cross-reference.** The audit expects analyzer findings to name a + *source*, a *sink*, and the *absent control*. Prompts that produce + free-form prose break the audit's ability to challenge individual + findings. +- **Recall bias in the analyzer, precision bias in the audit.** This is the + two-pass contract. A prompt set that pre-filters aggressively in the + analyzer gives the audit nothing to remove and hurts overall recall. + +## Prompt-set-aware flags + +Nothing about per-phase flags, supplementary context, or model selection is +tied to a specific prompt set — they compose. 
A common pattern is a cheap +flash model for feature detection, the main thinking model for analysis +against a specialist set, and a smaller model for audit: + +```bash +codecrucible scan ./target \ + --prompts-dir prompts/exploit-proof-web-go \ + --fd-provider google --fd-model gemini-3-flash \ + --model claude-opus-4-7 \ + --audit-model claude-sonnet-4-6 +``` diff --git a/cmd/codecrucible/main.go b/cmd/codecrucible/main.go new file mode 100644 index 0000000..17e9804 --- /dev/null +++ b/cmd/codecrucible/main.go @@ -0,0 +1,113 @@ +package main + +import ( + "fmt" + "log" + "net/http" + _ "net/http/pprof" + "os" + "runtime" + "runtime/pprof" + "runtime/trace" + + "github.com/block/codecrucible/internal/cli" +) + +type listenAndServeFunc func(addr string, handler http.Handler) error + +func main() { + if err := run(cli.Execute, http.ListenAndServe); err != nil { + log.Fatalf("codecrucible failed: %v", err) + } +} + +func run(executeCLI func(), listenAndServe listenAndServeFunc) error { + maybeStartPprofServer(os.Getenv("PPROF_ADDR"), listenAndServe) + + stopCPUProfile, err := maybeStartCPUProfile(os.Getenv("CPUPROFILE")) + if err != nil { + return err + } + defer stopCPUProfile() + + stopTrace, err := maybeStartTrace(os.Getenv("TRACEFILE")) + if err != nil { + return err + } + defer stopTrace() + + executeCLI() + + return maybeWriteMemProfile(os.Getenv("MEMPROFILE")) +} + +func maybeStartPprofServer(addr string, listenAndServe listenAndServeFunc) { + if addr == "" || listenAndServe == nil { + return + } + + go func() { + log.Printf("pprof listening on %s", addr) + if err := listenAndServe(addr, nil); err != nil { + log.Printf("pprof server stopped: %v", err) + } + }() +} + +func maybeStartCPUProfile(cpuFile string) (func(), error) { + if cpuFile == "" { + return func() {}, nil + } + + f, err := os.Create(cpuFile) + if err != nil { + return nil, fmt.Errorf("creating CPU profile: %w", err) + } + if err := pprof.StartCPUProfile(f); err != nil { + _ = f.Close() + return 
nil, fmt.Errorf("starting CPU profile: %w", err) + } + + return func() { + pprof.StopCPUProfile() + _ = f.Close() + }, nil +} + +func maybeStartTrace(traceFile string) (func(), error) { + if traceFile == "" { + return func() {}, nil + } + + f, err := os.Create(traceFile) + if err != nil { + return nil, fmt.Errorf("creating trace file: %w", err) + } + if err := trace.Start(f); err != nil { + _ = f.Close() + return nil, fmt.Errorf("starting trace: %w", err) + } + + return func() { + trace.Stop() + _ = f.Close() + }, nil +} + +func maybeWriteMemProfile(memFile string) error { + if memFile == "" { + return nil + } + + f, err := os.Create(memFile) + if err != nil { + return fmt.Errorf("creating memory profile: %w", err) + } + defer f.Close() + + runtime.GC() + if err := pprof.WriteHeapProfile(f); err != nil { + return fmt.Errorf("writing memory profile: %w", err) + } + return nil +} diff --git a/cmd/codecrucible/main_test.go b/cmd/codecrucible/main_test.go new file mode 100644 index 0000000..4aad24f --- /dev/null +++ b/cmd/codecrucible/main_test.go @@ -0,0 +1,114 @@ +package main + +import ( + "errors" + "net/http" + "os" + "path/filepath" + "testing" + "time" +) + +func TestRunExecutesCLIWithoutProfiles(t *testing.T) { + t.Setenv("PPROF_ADDR", "") + t.Setenv("CPUPROFILE", "") + t.Setenv("TRACEFILE", "") + t.Setenv("MEMPROFILE", "") + + executed := false + listenCalled := false + + err := run(func() { + executed = true + }, func(addr string, handler http.Handler) error { + listenCalled = true + return nil + }) + if err != nil { + t.Fatalf("run() error = %v", err) + } + if !executed { + t.Fatal("expected CLI callback to execute") + } + if listenCalled { + t.Fatal("pprof listener should not be called when PPROF_ADDR is unset") + } +} + +func TestMaybeStartCPUProfile_CreatesFile(t *testing.T) { + profilePath := filepath.Join(t.TempDir(), "cpu.prof") + + stop, err := maybeStartCPUProfile(profilePath) + if err != nil { + t.Fatalf("maybeStartCPUProfile() error = %v", err) + } 
+ + // Generate a little CPU activity before stopping the profile. + var sink int + for i := 0; i < 100_000; i++ { + sink += i % 7 + } + _ = sink + + stop() + + info, err := os.Stat(profilePath) + if err != nil { + t.Fatalf("stat profile file: %v", err) + } + if info.Size() == 0 { + t.Fatal("expected CPU profile file to be non-empty") + } +} + +func TestMaybeStartTrace_CreatesFile(t *testing.T) { + tracePath := filepath.Join(t.TempDir(), "trace.out") + + stop, err := maybeStartTrace(tracePath) + if err != nil { + t.Fatalf("maybeStartTrace() error = %v", err) + } + stop() + + info, err := os.Stat(tracePath) + if err != nil { + t.Fatalf("stat trace file: %v", err) + } + if info.Size() == 0 { + t.Fatal("expected trace file to be non-empty") + } +} + +func TestMaybeWriteMemProfile_CreatesFile(t *testing.T) { + memPath := filepath.Join(t.TempDir(), "mem.prof") + + if err := maybeWriteMemProfile(memPath); err != nil { + t.Fatalf("maybeWriteMemProfile() error = %v", err) + } + + info, err := os.Stat(memPath) + if err != nil { + t.Fatalf("stat memory profile file: %v", err) + } + if info.Size() == 0 { + t.Fatal("expected memory profile file to be non-empty") + } +} + +func TestMaybeStartPprofServer_StartsListener(t *testing.T) { + started := make(chan struct{}, 1) + + maybeStartPprofServer("127.0.0.1:9999", func(addr string, handler http.Handler) error { + if addr != "127.0.0.1:9999" { + t.Errorf("listener addr = %q, want %q", addr, "127.0.0.1:9999") + } + started <- struct{}{} + return errors.New("stop") + }) + + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for pprof listener start") + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8099ea2 --- /dev/null +++ b/go.mod @@ -0,0 +1,25 @@ +module github.com/block/codecrucible + +go 1.23.0 + +require ( + github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 + github.com/spf13/cobra v1.10.2 + github.com/spf13/viper v1.21.0 + go.yaml.in/yaml/v3 
v3.0.4 +) + +require ( + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + golang.org/x/sys v0.29.0 // indirect + golang.org/x/text v0.28.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..88f2088 --- /dev/null +++ b/go.sum @@ -0,0 +1,60 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod 
h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 h1:OkMGxebDjyw0ULyrTYWeN0UNCCkmCWfjPnIA2W6oviI= +github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06/go.mod h1:+ePHsJ1keEjQtpvf9HHw0f4ZeJ0TLRsxhunSI2hYJSs= +github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= +github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod 
h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 
v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/chunk/chunker.go b/internal/chunk/chunker.go new file mode 100644 index 0000000..03c1aac --- /dev/null +++ b/internal/chunk/chunker.go @@ -0,0 +1,579 @@ +package chunk + +import ( + "fmt" + "log/slog" + "path/filepath" + "sort" + "strings" + + "github.com/block/codecrucible/internal/ingest" +) + +// Chunk represents a self-contained XML document with a subset of repository files. +type Chunk struct { + XML string // Self-contained XML document for this chunk. + Index int // 0-based index of this chunk. + Total int // Total number of chunks. + Paths []string // File paths included in this chunk. + Manifest []string // ALL repository file paths (for cross-file context). + Tokens int // Estimated token count for this chunk's XML. + RelatedSummaries []string // One-line summaries of files related to this chunk but not included. +} + +// ChunkOptions provides optional import-aware grouping hints. +type ChunkOptions struct { + // ImportGraph maps file path → list of local imports (from ingest.ResolveImports). + ImportGraph map[string][]string + // ExportSummaries maps file path → short one-line summary of exports. + ExportSummaries map[string]string +} + +// Chunker splits a FlattenResult into token-budget-safe chunks. +type Chunker interface { + Chunk(input ingest.FlattenResult, budget int, opts *ChunkOptions) ([]Chunk, error) +} + +// defaultChunker implements Chunker with directory-proximity grouping. +type defaultChunker struct { + counter *TokenCounter + logger *slog.Logger +} + +// NewChunker creates a Chunker that groups files by directory proximity +// under the given token budget, preserving file boundaries. 
+func NewChunker(counter *TokenCounter, logger *slog.Logger) Chunker { + if logger == nil { + logger = slog.Default() + } + return &defaultChunker{ + counter: counter, + logger: logger, + } +} + +// fileEntry holds precomputed data for a single file during chunking. +type fileEntry struct { + path string + content string // prebuilt XML + tokens int // cached token count + imports []string // from ImportGraph + priority int // computed score +} + +// Chunk splits the FlattenResult into chunks that each fit within the token budget. +// Files are grouped by import graph (if provided) or directory proximity. +// File boundaries are never broken. +// Files that individually exceed the budget are skipped with a warning. +// Every chunk's Manifest field contains ALL repository file paths. +func (c *defaultChunker) Chunk(input ingest.FlattenResult, budget int, opts *ChunkOptions) ([]Chunk, error) { + if budget <= 0 { + return nil, fmt.Errorf("chunk: budget must be positive, got %d", budget) + } + + // Collect all file paths for the manifest. + allPaths := sortedPaths(input.FileMap) + + // Handle empty input — return a single empty chunk. + if len(allPaths) == 0 { + tokens := input.Tokens + if tokens == 0 { + tokens = c.counter.Count(input.XML) + } + return []Chunk{{ + XML: input.XML, + Index: 0, + Total: 1, + Paths: nil, + Manifest: nil, + Tokens: tokens, + }}, nil + } + + // Check if the entire flattened XML fits within budget. + // Reuse precomputed count from FlattenResult if available. + // Skip this fast-path when input.XML is empty but FileMap is populated + // (e.g. overflow-recovery re-chunking passes only FileMap) — otherwise + // Count("") returns 0, the fast-path fires, and we emit an empty chunk. 
+ if input.XML != "" { + totalTokens := input.Tokens + if totalTokens == 0 { + totalTokens = c.counter.Count(input.XML) + } + if totalTokens <= budget { + return []Chunk{{ + XML: input.XML, + Index: 0, + Total: 1, + Paths: allPaths, + Manifest: allPaths, + Tokens: totalTokens, + }}, nil + } + } + + // Normalise options. + if opts == nil { + opts = &ChunkOptions{} + } + + // Build fileEntry structs with cached token counts and priority scores. + entryMap := make(map[string]*fileEntry, len(allPaths)) + entries := make([]*fileEntry, 0, len(allPaths)) + for _, p := range allPaths { + content, ok := input.FileMap[p] + if !ok { + continue + } + xml := BuildFileXML(p, content) + e := &fileEntry{ + path: p, + content: xml, + tokens: c.counter.Count(xml), + imports: opts.ImportGraph[p], + priority: computePriority(p), + } + entryMap[p] = e + entries = append(entries, e) + } + + // Sort by priority descending, then path alphabetically for determinism. + sort.Slice(entries, func(i, j int) bool { + if entries[i].priority != entries[j].priority { + return entries[i].priority > entries[j].priority + } + return entries[i].path < entries[j].path + }) + + // Estimate overhead for the chunk wrapper XML (metadata, tags, paths listing). + // The manifest is handled by the prompt template, not embedded in chunk XML. + // For the section, estimate a representative per-chunk subset rather + // than the full repo — each chunk holds at most budget/avgFileTokens files. 
+ avgFileTokens := sumTokens(entries) / max(len(entries), 1) + filesPerChunk := budget / max(avgFileTokens, 1) + if filesPerChunk > len(entries) { + filesPerChunk = len(entries) + } + representativePaths := allPaths[:min(filesPerChunk, len(allPaths))] + overheadTokens := heuristicCount(buildChunkWrapper(0, 1, representativePaths, nil)) + + effectiveBudget := budget - overheadTokens + if effectiveBudget <= 0 { + return nil, fmt.Errorf("chunk: budget %d too small after overhead %d", budget, overheadTokens) + } + + // Identify shared files: high priority + small enough to duplicate. + sharedBudget := effectiveBudget / 10 // 10% of chunk budget + const maxSharedPerChunk = 3 + const sharedTokenCap = 500 + const sharedPriorityMin = 80 + + var sharedFiles []*fileEntry + sharedTokenTotal := 0 + for _, e := range entries { + if e.priority >= sharedPriorityMin && e.tokens <= sharedTokenCap { + if sharedTokenTotal+e.tokens <= sharedBudget && len(sharedFiles) < maxSharedPerChunk { + sharedFiles = append(sharedFiles, e) + sharedTokenTotal += e.tokens + } + } + } + + sharedSet := make(map[string]bool, len(sharedFiles)) + for _, sf := range sharedFiles { + sharedSet[sf.path] = true + } + + hasImportGraph := len(opts.ImportGraph) > 0 + + // Group files into chunks. + assigned := make(map[string]bool) + var groups [][]*fileEntry + + if hasImportGraph { + // Seed-and-grow: BFS through imports. + for _, seed := range entries { + if assigned[seed.path] || sharedSet[seed.path] { + continue + } + if seed.tokens > effectiveBudget-sharedTokenTotal { + c.logger.Warn("file exceeds chunk budget, skipping", + "path", seed.path, + "file_tokens", seed.tokens, + "budget", effectiveBudget, + ) + assigned[seed.path] = true + continue + } + + group := []*fileEntry{seed} + assigned[seed.path] = true + groupTokens := seed.tokens + sharedTokenTotal + + // BFS through imports. 
+ queue := make([]string, len(seed.imports)) + copy(queue, seed.imports) + for len(queue) > 0 { + imp := queue[0] + queue = queue[1:] + if assigned[imp] || sharedSet[imp] { + continue + } + ie, ok := entryMap[imp] + if !ok { + continue + } + if groupTokens+ie.tokens > effectiveBudget { + continue + } + group = append(group, ie) + assigned[imp] = true + groupTokens += ie.tokens + queue = append(queue, ie.imports...) + } + + // Fill remaining budget with directory-proximity files. + seedDir := filepath.Dir(seed.path) + for _, e := range entries { + if assigned[e.path] || sharedSet[e.path] { + continue + } + if filepath.Dir(e.path) != seedDir { + continue + } + if groupTokens+e.tokens > effectiveBudget { + continue + } + group = append(group, e) + assigned[e.path] = true + groupTokens += e.tokens + } + + groups = append(groups, group) + } + + // Pick up any remaining unassigned, non-shared files. + var remaining []*fileEntry + for _, e := range entries { + if !assigned[e.path] && !sharedSet[e.path] { + remaining = append(remaining, e) + } + } + if len(remaining) > 0 { + groups = append(groups, c.packGreedy(remaining, effectiveBudget-sharedTokenTotal)...) + } + } else { + // No import graph: directory-proximity sort + greedy packing (original behaviour). + sort.Slice(entries, func(i, j int) bool { + di := filepath.Dir(entries[i].path) + dj := filepath.Dir(entries[j].path) + if di != dj { + return di < dj + } + return entries[i].path < entries[j].path + }) + + var packable []*fileEntry + for _, e := range entries { + if sharedSet[e.path] { + continue + } + if e.tokens > effectiveBudget-sharedTokenTotal { + c.logger.Warn("file exceeds chunk budget, skipping", + "path", e.path, + "file_tokens", e.tokens, + "budget", effectiveBudget, + ) + continue + } + packable = append(packable, e) + } + groups = c.packGreedy(packable, effectiveBudget-sharedTokenTotal) + } + + // Merge small groups to maximize context window utilization. 
+ // After import-graph BFS or greedy grouping, some groups may be well under + // the budget. Combining them reduces the number of LLM calls and avoids + // duplicating prompt overhead per call. + if len(groups) > 1 { + preCount := len(groups) + mergeBudget := effectiveBudget - sharedTokenTotal + groups = c.mergeGroups(groups, mergeBudget) + if len(groups) < preCount { + c.logger.Info("merged chunks to maximize context utilization", + "before", preCount, + "after", len(groups), + ) + } + } + + // If all files were skipped, return a single empty chunk. + if len(groups) == 0 { + return []Chunk{{ + XML: buildChunkXML(0, 1, nil, nil), + Index: 0, + Total: 1, + Paths: nil, + Manifest: allPaths, + Tokens: overheadTokens, + }}, nil + } + + // Build a map of which chunk each file ended up in. + fileToChunk := make(map[string]int) + for i, group := range groups { + for _, e := range group { + fileToChunk[e.path] = i + } + } + + // Build chunk XML documents. + total := len(groups) + chunks := make([]Chunk, total) + for i, group := range groups { + // Prepend shared files to each group (if they fit). + fullGroup := make([]*fileEntry, 0, len(sharedFiles)+len(group)) + for _, sf := range sharedFiles { + // Don't duplicate if already in group. + already := false + for _, g := range group { + if g.path == sf.path { + already = true + break + } + } + if !already { + fullGroup = append(fullGroup, sf) + } + } + fullGroup = append(fullGroup, group...) + + // Sort paths within chunk for deterministic output. + sort.Slice(fullGroup, func(a, b int) bool { + return fullGroup[a].path < fullGroup[b].path + }) + + paths := make([]string, len(fullGroup)) + var filesXML strings.Builder + for j, e := range fullGroup { + paths[j] = e.path + filesXML.WriteString(e.content) + } + + // Compute cross-chunk summaries. 
+ var relatedSummaries []string + if len(opts.ExportSummaries) > 0 { + chunkPaths := make(map[string]bool, len(paths)) + for _, p := range paths { + chunkPaths[p] = true + } + seen := make(map[string]bool) + for _, e := range fullGroup { + for _, imp := range e.imports { + if !chunkPaths[imp] && !seen[imp] { + seen[imp] = true + if summary, ok := opts.ExportSummaries[imp]; ok { + relatedSummaries = append(relatedSummaries, imp+": "+summary) + } + } + } + } + if len(relatedSummaries) > 50 { + relatedSummaries = relatedSummaries[:50] + } + } + + xml := buildChunkXMLWithFiles(i, total, paths, nil, filesXML.String()) + chunks[i] = Chunk{ + XML: xml, + Index: i, + Total: total, + Paths: paths, + Manifest: allPaths, + Tokens: c.counter.Count(xml), + RelatedSummaries: relatedSummaries, + } + } + + return chunks, nil +} + +// mergeGroups combines small file-entry groups to maximize context window utilization. +// Groups are merged greedily: adjacent groups are combined as long as their total +// token count fits within the budget. +func (c *defaultChunker) mergeGroups(groups [][]*fileEntry, budget int) [][]*fileEntry { + if len(groups) <= 1 { + return groups + } + + var merged [][]*fileEntry + current := groups[0] + currentTokens := sumTokens(current) + + for i := 1; i < len(groups); i++ { + nextTokens := sumTokens(groups[i]) + if currentTokens+nextTokens <= budget { + current = append(current, groups[i]...) + currentTokens += nextTokens + } else { + merged = append(merged, current) + current = groups[i] + currentTokens = nextTokens + } + } + merged = append(merged, current) + + return merged +} + +// sumTokens returns the total token count across all file entries in a group. +func sumTokens(entries []*fileEntry) int { + total := 0 + for _, e := range entries { + total += e.tokens + } + return total +} + +// packGreedy groups file entries using greedy bin packing. 
+func (c *defaultChunker) packGreedy(entries []*fileEntry, effectiveBudget int) [][]*fileEntry { + var groups [][]*fileEntry + var current []*fileEntry + currentTokens := 0 + + for _, e := range entries { + if currentTokens+e.tokens > effectiveBudget && len(current) > 0 { + groups = append(groups, current) + current = nil + currentTokens = 0 + } + current = append(current, e) + currentTokens += e.tokens + } + if len(current) > 0 { + groups = append(groups, current) + } + + return groups +} + +// computePriority assigns a priority score to a file based on its path. +func computePriority(p string) int { + dir := filepath.Dir(p) + base := filepath.Base(p) + name := strings.TrimSuffix(base, filepath.Ext(base)) + + // Check directory-based priorities. + parts := strings.Split(filepath.ToSlash(dir), "/") + for _, part := range parts { + switch part { + case "routes", "handlers", "controllers", "middleware", "api": + return 100 + } + } + + // Check filename-based priorities. + switch name { + case "main", "server", "app", "index": + return 90 + } + + // Check lib/utils/auth/security directories. + for _, part := range parts { + switch part { + case "lib", "utils", "auth", "security": + return 80 + } + } + + // Config/data files. + switch base { + case "package.json", "tsconfig.json", "go.mod", "go.sum", + "Cargo.toml", "requirements.txt", "Makefile", ".gitignore", + "Dockerfile", "docker-compose.yml": + return 10 + } + + return 50 +} + +// sortedPaths returns the sorted keys of a FileMap. +func sortedPaths(fm ingest.FileMap) []string { + if len(fm) == 0 { + return nil + } + paths := make([]string, 0, len(fm)) + for p := range fm { + paths = append(paths, p) + } + sort.Strings(paths) + return paths +} + +// BuildFileXML produces the XML element for a single file with numbered lines, +// matching the format used by the flattener. 
func BuildFileXML(path, content string) string {
	var b strings.Builder
	// The path is escaped with the attribute escaper before being formatted
	// into the opening line.
	escapedPath := ingest.EscapeXMLAttr(path)
	// NOTE(review): this format string has no verb for escapedPath and no
	// element markup; the tag text looks stripped from this copy of the
	// source — confirm against the repository before relying on it.
	b.WriteString(fmt.Sprintf("\n", escapedPath))

	// Empty content produces no numbered lines, only the opening and closing writes.
	if content != "" {
		lines := ingest.SplitLines(content)
		totalLines := len(lines)
		// Width of the largest line number, so every number is right-aligned
		// to the same column.
		padding := len(fmt.Sprintf("%d", totalLines))

		for i, line := range lines {
			num := i + 1 // line numbers are 1-based
			b.WriteString(fmt.Sprintf("%*d | %s\n", padding, num, ingest.EscapeXMLContent(line)))
		}
	}

	// NOTE(review): presumably the closing tag; only a newline survives in
	// this copy — confirm against the repository.
	b.WriteString("\n")
	return b.String()
}

// buildChunkWrapper produces the wrapper XML without file content, for overhead estimation.
// It calls buildChunkXMLWithFiles with an empty filesContent.
func buildChunkWrapper(index, total int, paths, manifest []string) string {
	return buildChunkXMLWithFiles(index, total, paths, manifest, "")
}

// buildChunkXML produces a chunk XML document for the given paths and manifest
// with no file content: it passes an empty filesContent to
// buildChunkXMLWithFiles, exactly as buildChunkWrapper does. Nothing is
// generated from paths beyond what buildChunkXMLWithFiles does with them.
func buildChunkXML(index, total int, paths, manifest []string) string {
	return buildChunkXMLWithFiles(index, total, paths, manifest, "")
}

// buildChunkXMLWithFiles produces a complete self-contained chunk XML document.
+func buildChunkXMLWithFiles(index, total int, paths, manifest []string, filesContent string) string { + var b strings.Builder + b.Grow(len(filesContent) + 512) + + b.WriteString("\n") + b.WriteString(fmt.Sprintf("\n", index, total)) + + if len(paths) > 0 { + b.WriteString("\n") + for _, p := range paths { + b.WriteString(ingest.EscapeXMLContent(p)) + b.WriteByte('\n') + } + b.WriteString("\n") + } + + if len(manifest) > 0 { + b.WriteString("\n") + for _, p := range manifest { + b.WriteString(ingest.EscapeXMLContent(p)) + b.WriteByte('\n') + } + b.WriteString("\n") + } + + if filesContent != "" { + b.WriteString("\n") + b.WriteString(filesContent) + b.WriteString("\n") + } + + b.WriteString("\n") + return b.String() +} diff --git a/internal/chunk/chunker_test.go b/internal/chunk/chunker_test.go new file mode 100644 index 0000000..2ac1f37 --- /dev/null +++ b/internal/chunk/chunker_test.go @@ -0,0 +1,727 @@ +package chunk + +import ( + "fmt" + "log/slog" + "strings" + "testing" + + "github.com/block/codecrucible/internal/ingest" +) + +// helper to create a FlattenResult from files. 
+func makeFlattenResult(files []ingest.SourceFile) ingest.FlattenResult { + return ingest.Flatten(files, ingest.FlattenConfig{}) +} + +func newTestChunker(encoding string) Chunker { + tc := NewTokenCounter(encoding, slog.Default()) + return NewChunker(tc, slog.Default()) +} + +func TestChunk_EmptyInput(t *testing.T) { + ch := newTestChunker("cl100k_base") + input := makeFlattenResult(nil) + + chunks, err := ch.Chunk(input, 10000, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(chunks) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(chunks)) + } + + c := chunks[0] + if c.Index != 0 { + t.Errorf("Index = %d, want 0", c.Index) + } + if c.Total != 1 { + t.Errorf("Total = %d, want 1", c.Total) + } + if c.Paths != nil { + t.Errorf("Paths should be nil for empty input, got %v", c.Paths) + } + if c.Manifest != nil { + t.Errorf("Manifest should be nil for empty input, got %v", c.Manifest) + } + if c.Tokens <= 0 { + t.Errorf("Tokens should be > 0 (empty XML still has structure), got %d", c.Tokens) + } +} + +func TestChunk_SingleFileWithinBudget(t *testing.T) { + ch := newTestChunker("cl100k_base") + files := []ingest.SourceFile{ + {Path: "main.go", Content: "package main\n\nfunc main() {}\n", LineCount: 3, Language: "go"}, + } + input := makeFlattenResult(files) + + chunks, err := ch.Chunk(input, 100000, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(chunks) != 1 { + t.Fatalf("expected 1 chunk for small input, got %d", len(chunks)) + } + + c := chunks[0] + if c.Index != 0 { + t.Errorf("Index = %d, want 0", c.Index) + } + if c.Total != 1 { + t.Errorf("Total = %d, want 1", c.Total) + } + if len(c.Paths) != 1 || c.Paths[0] != "main.go" { + t.Errorf("Paths = %v, want [main.go]", c.Paths) + } + if len(c.Manifest) != 1 || c.Manifest[0] != "main.go" { + t.Errorf("Manifest = %v, want [main.go]", c.Manifest) + } + // Single chunk should use the original flattened XML. 
+ if c.XML != input.XML { + t.Error("single chunk should return original flattened XML") + } +} + +func TestChunk_MultipleFilesWithinBudget(t *testing.T) { + ch := newTestChunker("cl100k_base") + files := []ingest.SourceFile{ + {Path: "a.go", Content: "package a\n", LineCount: 1, Language: "go"}, + {Path: "b.go", Content: "package b\n", LineCount: 1, Language: "go"}, + {Path: "c.go", Content: "package c\n", LineCount: 1, Language: "go"}, + } + input := makeFlattenResult(files) + + chunks, err := ch.Chunk(input, 100000, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(chunks) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(chunks)) + } + + c := chunks[0] + if len(c.Paths) != 3 { + t.Errorf("Paths should have 3 entries, got %d", len(c.Paths)) + } + if len(c.Manifest) != 3 { + t.Errorf("Manifest should have 3 entries, got %d", len(c.Manifest)) + } +} + +func TestChunk_OversizedRepoProducesMultipleChunks(t *testing.T) { + ch := newTestChunker("cl100k_base") + + // Create many files that will exceed a small budget. + var files []ingest.SourceFile + for i := 0; i < 50; i++ { + content := fmt.Sprintf("package pkg%d\n\n// This is file %d with some content to consume tokens.\nfunc Handler%d() string {\n\treturn \"handler %d result\"\n}\n", i, i, i, i) + files = append(files, ingest.SourceFile{ + Path: fmt.Sprintf("pkg%d/handler.go", i), + Content: content, + LineCount: 6, + Language: "go", + }) + } + input := makeFlattenResult(files) + + // Use a small budget that forces splitting. + totalTokens := NewTokenCounter("cl100k_base", slog.Default()).Count(input.XML) + budget := totalTokens / 3 + + chunks, err := ch.Chunk(input, budget, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(chunks) < 2 { + t.Fatalf("expected multiple chunks, got %d (total tokens: %d, budget: %d)", len(chunks), totalTokens, budget) + } + + // Verify each chunk is under budget. 
+ tc := NewTokenCounter("cl100k_base", slog.Default()) + for i, c := range chunks { + tokens := tc.Count(c.XML) + if tokens > budget { + t.Errorf("chunk %d has %d tokens, exceeds budget %d", i, tokens, budget) + } + } +} + +func TestChunk_FileBoundariesPreserved(t *testing.T) { + ch := newTestChunker("cl100k_base") + + // Use a budget that forces 2 chunks (very small). + // We need to make files big enough that they get split. + bigFiles := []ingest.SourceFile{ + {Path: "alpha.go", Content: strings.Repeat("// alpha line\n", 500), LineCount: 500, Language: "go"}, + {Path: "beta.go", Content: strings.Repeat("// beta line\n", 500), LineCount: 500, Language: "go"}, + } + input := makeFlattenResult(bigFiles) + + tc := NewTokenCounter("cl100k_base", slog.Default()) + totalTokens := tc.Count(input.XML) + budget := totalTokens / 2 + + chunks, err := ch.Chunk(input, budget, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify no file appears in two chunks (file boundaries preserved). + seen := make(map[string]int) + for i, c := range chunks { + for _, p := range c.Paths { + if prev, ok := seen[p]; ok { + t.Errorf("file %q appears in chunk %d and chunk %d (file boundary broken)", p, prev, i) + } + seen[p] = i + } + } +} + +func TestChunk_DirectoryProximityGrouping(t *testing.T) { + ch := newTestChunker("cl100k_base") + + // Create files in different directories with enough content to force splitting. 
+ files := []ingest.SourceFile{ + {Path: "api/handler_a.go", Content: strings.Repeat("// api handler a\n", 200), LineCount: 200, Language: "go"}, + {Path: "api/handler_b.go", Content: strings.Repeat("// api handler b\n", 200), LineCount: 200, Language: "go"}, + {Path: "db/query_a.go", Content: strings.Repeat("// db query a\n", 200), LineCount: 200, Language: "go"}, + {Path: "db/query_b.go", Content: strings.Repeat("// db query b\n", 200), LineCount: 200, Language: "go"}, + } + input := makeFlattenResult(files) + + tc := NewTokenCounter("cl100k_base", slog.Default()) + totalTokens := tc.Count(input.XML) + // Budget that forces ~2 chunks. + budget := totalTokens / 2 + + chunks, err := ch.Chunk(input, budget, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(chunks) < 2 { + t.Skipf("budget %d didn't force split (total: %d), skipping proximity test", budget, totalTokens) + } + + // Find which chunk each file is in. + fileChunk := make(map[string]int) + for i, c := range chunks { + for _, p := range c.Paths { + fileChunk[p] = i + } + } + + // Files in the same directory should be in the same chunk when possible. + if chA, ok1 := fileChunk["api/handler_a.go"]; ok1 { + if chB, ok2 := fileChunk["api/handler_b.go"]; ok2 { + if chA != chB { + t.Logf("api/ files split across chunks %d and %d (acceptable if budget-constrained)", chA, chB) + } + } + } + + // At minimum, files should be sorted by directory in their chunk paths. + for _, c := range chunks { + for i := 1; i < len(c.Paths); i++ { + // Paths should be in directory order. + if c.Paths[i] < c.Paths[i-1] { + t.Errorf("paths not in directory order: %q before %q", c.Paths[i-1], c.Paths[i]) + } + } + } +} + +func TestChunk_SingleFileExceedsBudget_SkippedWithWarning(t *testing.T) { + ch := newTestChunker("cl100k_base") + + // Create one large file and one small file. 
+ files := []ingest.SourceFile{ + {Path: "huge.go", Content: strings.Repeat("// very long line of code\n", 5000), LineCount: 5000, Language: "go"}, + {Path: "small.go", Content: "package small\n", LineCount: 1, Language: "go"}, + } + input := makeFlattenResult(files) + + // Budget that fits small.go but not huge.go. + budget := 500 + + chunks, err := ch.Chunk(input, budget, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // huge.go should be skipped; small.go should be in a chunk. + for _, c := range chunks { + for _, p := range c.Paths { + if p == "huge.go" { + t.Error("huge.go should be skipped (exceeds budget)") + } + } + } + + // small.go should appear in some chunk. + found := false + for _, c := range chunks { + for _, p := range c.Paths { + if p == "small.go" { + found = true + } + } + } + if !found { + t.Error("small.go should appear in a chunk") + } +} + +func TestChunk_ManifestContainsAllPaths(t *testing.T) { + ch := newTestChunker("cl100k_base") + + // Create files that will be split across multiple chunks. + var files []ingest.SourceFile + for i := 0; i < 30; i++ { + files = append(files, ingest.SourceFile{ + Path: fmt.Sprintf("dir%d/file.go", i), + Content: strings.Repeat(fmt.Sprintf("// content for file %d\n", i), 100), + LineCount: 100, + Language: "go", + }) + } + input := makeFlattenResult(files) + + tc := NewTokenCounter("cl100k_base", slog.Default()) + totalTokens := tc.Count(input.XML) + budget := totalTokens / 3 + + chunks, err := ch.Chunk(input, budget, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(chunks) < 2 { + t.Skipf("didn't produce multiple chunks, skipping manifest test") + } + + // Every chunk's Manifest should contain ALL file paths. 
+ allPaths := make(map[string]bool) + for _, f := range files { + allPaths[f.Path] = true + } + + for i, c := range chunks { + manifestSet := make(map[string]bool) + for _, p := range c.Manifest { + manifestSet[p] = true + } + + for p := range allPaths { + if !manifestSet[p] { + t.Errorf("chunk %d Manifest missing path %q", i, p) + } + } + } +} + +func TestChunk_MetadataCorrect(t *testing.T) { + ch := newTestChunker("cl100k_base") + + // Create enough files to produce multiple chunks. + var files []ingest.SourceFile + for i := 0; i < 40; i++ { + files = append(files, ingest.SourceFile{ + Path: fmt.Sprintf("pkg%d/main.go", i), + Content: strings.Repeat(fmt.Sprintf("// line %d content padding\n", i), 100), + LineCount: 100, + Language: "go", + }) + } + input := makeFlattenResult(files) + + tc := NewTokenCounter("cl100k_base", slog.Default()) + totalTokens := tc.Count(input.XML) + budget := totalTokens / 4 + + chunks, err := ch.Chunk(input, budget, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + for i, c := range chunks { + if c.Index != i { + t.Errorf("chunk %d has Index=%d", i, c.Index) + } + if c.Total != len(chunks) { + t.Errorf("chunk %d has Total=%d, want %d", i, c.Total, len(chunks)) + } + } +} + +func TestChunk_XMLContainsMetadata(t *testing.T) { + ch := newTestChunker("cl100k_base") + + var files []ingest.SourceFile + for i := 0; i < 30; i++ { + files = append(files, ingest.SourceFile{ + Path: fmt.Sprintf("pkg%d/file.go", i), + Content: strings.Repeat("// padding content\n", 200), + LineCount: 200, + Language: "go", + }) + } + input := makeFlattenResult(files) + + tc := NewTokenCounter("cl100k_base", slog.Default()) + totalTokens := tc.Count(input.XML) + budget := totalTokens / 3 + + chunks, err := ch.Chunk(input, budget, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(chunks) < 2 { + t.Skip("didn't produce multiple chunks") + } + + for i, c := range chunks { + // Check for metadata in XML. 
+ expectedMeta := fmt.Sprintf(``, c.Index, c.Total) + if !strings.Contains(c.XML, expectedMeta) { + t.Errorf("chunk %d XML missing metadata tag %q", i, expectedMeta) + } + + // Check for wrapper. + if !strings.Contains(c.XML, "") { + t.Errorf("chunk %d XML missing wrapper", i) + } + if !strings.Contains(c.XML, "") { + t.Errorf("chunk %d XML missing closing tag", i) + } + + // Manifest is no longer embedded in chunk XML (handled by prompt template). + if strings.Contains(c.XML, "") { + t.Errorf("chunk %d XML should not contain (handled by prompt)", i) + } + + // Check for section. + if len(c.Paths) > 0 && !strings.Contains(c.XML, "") { + t.Errorf("chunk %d XML missing section", i) + } + + // Check for section. + if len(c.Paths) > 0 && !strings.Contains(c.XML, "") { + t.Errorf("chunk %d XML missing section", i) + } + } +} + +func TestChunk_InvalidBudget(t *testing.T) { + ch := newTestChunker("cl100k_base") + input := makeFlattenResult(nil) + + _, err := ch.Chunk(input, 0, nil) + if err == nil { + t.Error("expected error for zero budget") + } + + _, err = ch.Chunk(input, -100, nil) + if err == nil { + t.Error("expected error for negative budget") + } +} + +func TestChunk_AllFilesExceedBudget(t *testing.T) { + ch := newTestChunker("cl100k_base") + + // Create files that all exceed a very small budget. + files := []ingest.SourceFile{ + {Path: "big1.go", Content: strings.Repeat("// big content\n", 500), LineCount: 500, Language: "go"}, + {Path: "big2.go", Content: strings.Repeat("// big content\n", 500), LineCount: 500, Language: "go"}, + } + input := makeFlattenResult(files) + + // Very small budget. + chunks, err := ch.Chunk(input, 100, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Should return a chunk (possibly empty of files) rather than error. + if len(chunks) == 0 { + t.Error("should return at least one chunk even when all files skipped") + } + + // The chunk should have no file paths (all skipped). 
+ for _, c := range chunks { + if len(c.Paths) > 0 { + t.Errorf("all files should be skipped, but chunk has paths: %v", c.Paths) + } + } +} + +func TestChunk_EachChunkUnderBudget(t *testing.T) { + ch := newTestChunker("cl100k_base") + + // Use 20 files with significant content each so that budget/3 still leaves + // room for manifest overhead per chunk. + var files []ingest.SourceFile + for i := 0; i < 20; i++ { + files = append(files, ingest.SourceFile{ + Path: fmt.Sprintf("src/pkg%d/impl.go", i), + Content: strings.Repeat(fmt.Sprintf("// implementation %d content\n", i), 50), + LineCount: 50, + Language: "go", + }) + } + input := makeFlattenResult(files) + + tc := NewTokenCounter("cl100k_base", slog.Default()) + totalTokens := tc.Count(input.XML) + budget := totalTokens / 3 + + chunks, err := ch.Chunk(input, budget, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(chunks) < 2 { + t.Fatalf("expected multiple chunks, got %d", len(chunks)) + } + + for i, c := range chunks { + chunkTokens := tc.Count(c.XML) + if chunkTokens > budget { + t.Errorf("chunk %d has %d tokens, exceeds budget %d", i, chunkTokens, budget) + } + } +} + +func TestChunk_HeuristicFallback(t *testing.T) { + // Test with heuristic (empty encoding, simulating Gemini). + ch := newTestChunker("") + files := []ingest.SourceFile{ + {Path: "main.go", Content: "package main\n\nfunc main() {}\n", LineCount: 3, Language: "go"}, + } + input := makeFlattenResult(files) + + chunks, err := ch.Chunk(input, 100000, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(chunks) != 1 { + t.Errorf("expected 1 chunk, got %d", len(chunks)) + } + if chunks[0].Tokens <= 0 { + t.Error("tokens should be positive even with heuristic counting") + } +} + +func TestChunk_ChunkXMLSelfContained(t *testing.T) { + ch := newTestChunker("cl100k_base") + + // Create enough content to force multiple chunks. 
+ var files []ingest.SourceFile + for i := 0; i < 50; i++ { + files = append(files, ingest.SourceFile{ + Path: fmt.Sprintf("dir%d/file.go", i), + Content: strings.Repeat(fmt.Sprintf("// content %d\n", i), 100), + LineCount: 100, + Language: "go", + }) + } + input := makeFlattenResult(files) + + tc := NewTokenCounter("cl100k_base", slog.Default()) + totalTokens := tc.Count(input.XML) + budget := totalTokens / 3 + + chunks, err := ch.Chunk(input, budget, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(chunks) < 2 { + t.Skip("didn't produce multiple chunks") + } + + for i, c := range chunks { + // Each chunk XML should be self-contained — start with and end with . + if !strings.HasPrefix(c.XML, "\n") { + t.Errorf("chunk %d XML doesn't start with ", i) + } + if !strings.HasSuffix(c.XML, "\n") { + t.Errorf("chunk %d XML doesn't end with ", i) + } + + // Should contain file content. + for _, p := range c.Paths { + if !strings.Contains(c.XML, p) { + t.Errorf("chunk %d XML doesn't contain its file path %q", i, p) + } + } + } +} + +func TestChunk_SingleChunkUsesOriginalXML(t *testing.T) { + ch := newTestChunker("cl100k_base") + files := []ingest.SourceFile{ + {Path: "main.go", Content: "package main\n", LineCount: 1, Language: "go"}, + } + input := makeFlattenResult(files) + + chunks, err := ch.Chunk(input, 100000, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(chunks) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(chunks)) + } + + // The single-chunk path returns the original flattened XML (no wrapping). + if chunks[0].XML != input.XML { + t.Error("single chunk should use original flattened XML verbatim") + } +} + +// Regression: overflow-recovery re-chunking passes a FlattenResult with only +// FileMap populated (XML=""). The fast-path used to fire on Count("")==0 and +// emit a single chunk with empty XML and Tokens=0, sending no code to the LLM. 
+func TestChunk_FileMapOnlyInput_SkipsFastPath(t *testing.T) { + ch := newTestChunker("cl100k_base") + fm := ingest.FileMap{ + "a.go": "package a\nfunc A() {}\n", + "b.go": "package b\nfunc B() {}\n", + "c.go": "package c\nfunc C() {}\n", + } + input := ingest.FlattenResult{FileMap: fm} // XML="", Tokens=0 + + chunks, err := ch.Chunk(input, 100000, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(chunks) == 0 { + t.Fatal("expected at least one chunk") + } + for i, c := range chunks { + if c.XML == "" { + t.Errorf("chunk %d: XML is empty", i) + } + if c.Tokens == 0 { + t.Errorf("chunk %d: Tokens is 0", i) + } + for _, p := range c.Paths { + if !strings.Contains(c.XML, p) { + t.Errorf("chunk %d: XML missing file %s", i, p) + } + } + } +} + +func TestChunk_MergesSmallImportGroups(t *testing.T) { + ch := newTestChunker("cl100k_base") + + // Create files each in their own directory with no import connections. + // BFS will create many 1-file groups since there are no directory neighbors + // or import links. Merging should combine them into fewer, larger chunks. + var files []ingest.SourceFile + importGraph := make(map[string][]string) + for i := 0; i < 20; i++ { + path := fmt.Sprintf("service%d/main.go", i) + files = append(files, ingest.SourceFile{ + Path: path, + Content: strings.Repeat(fmt.Sprintf("// service %d code\n", i), 50), + LineCount: 50, + Language: "go", + }) + importGraph[path] = nil + } + input := makeFlattenResult(files) + + tc := NewTokenCounter("cl100k_base", slog.Default()) + totalTokens := tc.Count(input.XML) + // Budget ~1/3 of total forces chunking but allows merging groups together. + budget := totalTokens / 3 + + opts := &ChunkOptions{ImportGraph: importGraph} + chunks, err := ch.Chunk(input, budget, opts) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Without merging: 20 groups (one per isolated file). + // With merging: should be ~3-4 chunks. 
+ if len(chunks) > 6 { + t.Errorf("expected merging to reduce chunk count, got %d (total=%d, budget=%d)", len(chunks), totalTokens, budget) + } + + // Verify all files are present. + allPaths := make(map[string]bool) + for _, c := range chunks { + for _, p := range c.Paths { + allPaths[p] = true + } + } + if len(allPaths) != 20 { + t.Errorf("expected 20 files in merged chunks, got %d", len(allPaths)) + } + + // Verify chunks are under budget. + for i, c := range chunks { + chunkTokens := tc.Count(c.XML) + if chunkTokens > budget { + t.Errorf("merged chunk %d has %d tokens, exceeds budget %d", i, chunkTokens, budget) + } + } +} + +func TestEscapeXMLContent(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {"no special chars", "hello world", "hello world"}, + {"empty", "", ""}, + {"ampersand", "a & b", "a & b"}, + {"less than", "a < b", "a < b"}, + {"greater than", "a > b", "a > b"}, + {"all three", " & ", "<a> & <b>"}, + {"adjacent", "&<>", "&<>"}, + {"quotes untouched", `"hello" 'world'`, `"hello" 'world'`}, + {"unicode preserved", "héllo & wörld", "héllo & wörld"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ingest.EscapeXMLContent(tt.in); got != tt.want { + t.Errorf("EscapeXMLContent(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestChunk_PathsAreSorted(t *testing.T) { + ch := newTestChunker("cl100k_base") + files := []ingest.SourceFile{ + {Path: "z.go", Content: "package z\n", LineCount: 1, Language: "go"}, + {Path: "a.go", Content: "package a\n", LineCount: 1, Language: "go"}, + {Path: "m.go", Content: "package m\n", LineCount: 1, Language: "go"}, + } + input := makeFlattenResult(files) + + chunks, err := ch.Chunk(input, 100000, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(chunks) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(chunks)) + } + + paths := chunks[0].Paths + for i := 1; i < len(paths); i++ { + if paths[i] < paths[i-1] { + 
t.Errorf("paths not sorted: %q before %q", paths[i-1], paths[i]) + } + } +} diff --git a/internal/chunk/token.go b/internal/chunk/token.go new file mode 100644 index 0000000..1557a5f --- /dev/null +++ b/internal/chunk/token.go @@ -0,0 +1,78 @@ +package chunk + +import ( + "log/slog" + "math" +) + +// TokenCounter counts tokens using a fast content-aware heuristic. +// The pipeline's 20% tokenizer safety margin on the context limit absorbs +// estimation error, making expensive BPE encoding unnecessary. +type TokenCounter struct { + logger *slog.Logger +} + +// NewTokenCounter creates a TokenCounter. The encoding parameter is accepted +// for API compatibility but is not used; all counting uses the fast heuristic. +func NewTokenCounter(encoding string, logger *slog.Logger) *TokenCounter { + if logger == nil { + logger = slog.Default() + } + + return &TokenCounter{ + logger: logger, + } +} + +// Count returns the estimated number of tokens in the given text. +// Uses a fast heuristic (3-4 chars per token + 10% safety margin) that avoids +// the extreme cost of BPE encoding through tiktoken-go's regexp2 engine. +func (tc *TokenCounter) Count(text string) int { + if text == "" { + return 0 + } + return heuristicCount(text) +} + +// heuristicCount estimates token count using a content-aware chars/token ratio +// plus a 10% safety margin. Code tokenizes much more densely than prose — every +// brace, paren, operator, and separator becomes its own token — so a fixed 4.0 +// undercounts C/Rust/Go by 25-35%. +func heuristicCount(text string) int { + chars := len(text) // byte length ≈ char length for ASCII-heavy source code + ratio := charsPerTokenRatio(text) + raw := float64(chars) / ratio + return int(math.Ceil(raw * 1.1)) +} + +// charsPerTokenRatio samples the first 4KB to classify the text and pick a +// chars/token ratio. Punctuation density is a strong signal: prose runs ~2%, +// source code runs 12-20%. Returns a ratio between 3.0 (dense code) and 4.0 +// (prose). 
The sample is small to keep counting cheap even on large files. +func charsPerTokenRatio(text string) float64 { + sample := text + if len(sample) > 4096 { + sample = sample[:4096] + } + if len(sample) == 0 { + return 4.0 + } + var punct int + for i := 0; i < len(sample); i++ { + switch sample[i] { + case '{', '}', '(', ')', '[', ']', ';', ',', '.', ':', + '<', '>', '=', '-', '+', '*', '/', '&', '|', '!', + '"', '\'', '#', '%', '^', '~', '?', '@', '\\': + punct++ + } + } + density := float64(punct) / float64(len(sample)) + switch { + case density > 0.12: + return 3.0 + case density > 0.06: + return 3.5 + default: + return 4.0 + } +} diff --git a/internal/chunk/token_test.go b/internal/chunk/token_test.go new file mode 100644 index 0000000..8953f37 --- /dev/null +++ b/internal/chunk/token_test.go @@ -0,0 +1,111 @@ +package chunk + +import ( + "log/slog" + "strings" + "testing" +) + +func TestNewTokenCounter_NilLogger(t *testing.T) { + tc := NewTokenCounter("cl100k_base", nil) + if tc.logger == nil { + t.Error("nil logger should be replaced with default") + } +} + +func TestCount_EmptyString(t *testing.T) { + tc := NewTokenCounter("cl100k_base", slog.Default()) + if got := tc.Count(""); got != 0 { + t.Errorf("Count(\"\") = %d, want 0", got) + } +} + +func TestCount_NonEmpty(t *testing.T) { + tc := NewTokenCounter("cl100k_base", slog.Default()) + got := tc.Count("hello world") + if got <= 0 { + t.Errorf("Count(\"hello world\") should be > 0, got %d", got) + } +} + +func TestHeuristicCount_Prose(t *testing.T) { + // Alphabetic text has zero punctuation → 4.0 chars/token ratio. 
+ tests := []struct { + name string + text string + want int // ceil(len(bytes)/4.0 * 1.1) + }{ + {"4 chars", "abcd", 2}, // 4/4=1, *1.1=1.1, ceil=2 + {"8 chars", "abcdefgh", 3}, // 8/4=2, *1.1=2.2, ceil=3 + {"12 chars", "abcdefghijkl", 4}, // 12/4=3, *1.1=3.3, ceil=4 + {"1 char", "a", 1}, // 1/4=0.25, *1.1=0.275, ceil=1 + {"empty", "", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := heuristicCount(tt.text) + if got != tt.want { + t.Errorf("heuristicCount(%q) = %d, want %d", tt.text, got, tt.want) + } + }) + } +} + +func TestHeuristicCount_CodeDenserThanProse(t *testing.T) { + // Same byte length, different punctuation density → code should + // produce a higher token estimate. + prose := "the quick brown fox jumps over the lazy dog again and again today" + code := "if(a->b){x[i]=f(y,z);r=*p&&q;}else{g();h(m|n);}/*t*/return(v!=w);" + if len(prose) != len(code) { + t.Fatalf("test inputs must be equal length: prose=%d code=%d", len(prose), len(code)) + } + proseTokens := heuristicCount(prose) + codeTokens := heuristicCount(code) + if codeTokens <= proseTokens { + t.Errorf("code should estimate higher than prose at equal byte length: code=%d prose=%d", codeTokens, proseTokens) + } +} + +func TestCharsPerTokenRatio(t *testing.T) { + tests := []struct { + name string + text string + want float64 + }{ + {"empty", "", 4.0}, + {"prose", "the quick brown fox jumps over the lazy dog and runs through the field", 4.0}, + {"light punct", "one. two. three. four. five six seven eight nine.", 3.5}, + {"c code", "void f(int *p) { if (p != NULL) { *p = (a + b) * c; } }", 3.0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := charsPerTokenRatio(tt.text) + if got != tt.want { + t.Errorf("charsPerTokenRatio(%q) = %v, want %v", tt.text, got, tt.want) + } + }) + } +} + +func TestCount_LargeInput(t *testing.T) { + tc := NewTokenCounter("cl100k_base", slog.Default()) + + // Generate a large input. 
+ var b strings.Builder + for i := 0; i < 1000; i++ { + b.WriteString("func handler(w http.ResponseWriter, r *http.Request) {\n") + b.WriteString("\tw.Write([]byte(\"hello\"))\n") + b.WriteString("}\n\n") + } + + tokens := tc.Count(b.String()) + if tokens <= 0 { + t.Errorf("large input should have positive token count, got %d", tokens) + } + + // Each handler is ~84 bytes → ~23 heuristic tokens. 1000 handlers → ~23000. + if tokens < 5000 || tokens > 50000 { + t.Errorf("token count %d seems unreasonable for 1000 handler functions", tokens) + } +} diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go new file mode 100644 index 0000000..eca63bf --- /dev/null +++ b/internal/cli/cli_test.go @@ -0,0 +1,202 @@ +package cli + +import ( + "bytes" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestRootCommand_Version(t *testing.T) { + cmd := NewRootCommand() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"--version"}) + + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "codecrucible version") { + t.Errorf("expected version output, got: %s", output) + } +} + +func TestRootCommand_Help(t *testing.T) { + cmd := NewRootCommand() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"--help"}) + + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute failed: %v", err) + } + + output := buf.String() + expectedFlags := []string{"--verbose", "--config"} + for _, f := range expectedFlags { + if !strings.Contains(output, f) { + t.Errorf("expected %q in help output, got: %s", f, output) + } + } +} + +func TestScanCommand_Help(t *testing.T) { + cmd := NewRootCommand() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"scan", "--help"}) + + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute failed: %v", err) + } + + output := buf.String() + expectedFlags := []string{ + "--paths", + "--model", + 
"--fail-on-severity", + "--max-cost", + "--dry-run", + "--include-tests", + "--include-docs", + "--compress", + "--custom-requirements", + "--custom-headers", + "--output", + "--phase-output-dir", + "--prompts-dir", + "--include", + "--exclude", + } + for _, f := range expectedFlags { + if !strings.Contains(output, f) { + t.Errorf("expected %q in scan --help output, got: %s", f, output) + } + } +} + +func TestScanCommand_DryRun(t *testing.T) { + cmd := NewRootCommand() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetErr(buf) + cmd.SetArgs([]string{"scan", "--dry-run"}) + + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute failed: %v", err) + } +} + +func TestRootCommand_UnknownFlag(t *testing.T) { + cmd := NewRootCommand() + cmd.SetArgs([]string{"--nonexistent-flag"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error for unknown flag, got nil") + } + + if !strings.Contains(err.Error(), "unknown flag") { + t.Errorf("expected 'unknown flag' in error, got: %v", err) + } +} + +func TestScanCommand_UnknownFlag(t *testing.T) { + cmd := NewRootCommand() + cmd.SetArgs([]string{"scan", "--bad-flag"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error for unknown scan flag, got nil") + } + + if !strings.Contains(err.Error(), "unknown flag") { + t.Errorf("expected 'unknown flag' in error, got: %v", err) + } +} + +func TestRootCommand_UnknownSubcommand(t *testing.T) { + cmd := NewRootCommand() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetErr(buf) + cmd.SetArgs([]string{"nonexistent"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error for unknown subcommand, got nil") + } +} + +func TestScanCommand_GlobalFlagsInherited(t *testing.T) { + cmd := NewRootCommand() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"scan", "--help"}) + + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute failed: %v", err) + } + + output := buf.String() + // Global flags should be listed under scan 
--help as well. + globalFlags := []string{"--verbose", "--config"} + for _, f := range globalFlags { + if !strings.Contains(output, f) { + t.Errorf("expected global flag %q in scan help output", f) + } + } +} + +func TestScanCommand_OutputShortFlag(t *testing.T) { + cmd := NewRootCommand() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"scan", "--help"}) + + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute failed: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "-o") { + t.Errorf("expected -o shorthand for --output in scan help") + } +} + +func TestScanCommand_MalformedConfigReturnsError(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "bad.yaml") + if err := os.WriteFile(cfgPath, []byte(":\n bad: [yaml\n unclosed"), 0644); err != nil { + t.Fatalf("writing config file: %v", err) + } + + cmd := NewRootCommand() + cmd.SetArgs([]string{"--config", cfgPath, "scan", "--dry-run"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error for malformed config file, got nil") + } +} + +func TestVersionInfo_ContainsExpectedFields(t *testing.T) { + cmd := NewRootCommand() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"--version"}) + + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute failed: %v", err) + } + + output := buf.String() + for _, field := range []string{"commit:", "built:"} { + if !strings.Contains(output, field) { + t.Errorf("expected %q in version output, got: %s", field, output) + } + } +} diff --git a/internal/cli/endpoints.go b/internal/cli/endpoints.go new file mode 100644 index 0000000..65f89b0 --- /dev/null +++ b/internal/cli/endpoints.go @@ -0,0 +1,443 @@ +package cli + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "text/tabwriter" + "time" + + "github.com/block/codecrucible/internal/config" + "github.com/spf13/cobra" +) + +func newListEndpointsCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: 
"list-models", + Aliases: []string{"list-endpoints"}, + Short: "List available models/endpoints for a provider", + Long: `Queries a model provider for its available models and prints a table. + +Provider is chosen by --provider if set. Otherwise, Databricks is used when +DATABRICKS_HOST + DATABRICKS_TOKEN are set; failing that, the first of +ANTHROPIC_API_KEY / OPENAI_API_KEY / GOOGLE_API_KEY present is used. + +Supported providers: databricks, anthropic, openai, google, ollama.`, + RunE: runListEndpoints, + } + cmd.Flags().String("provider", "", "provider to query (databricks, anthropic, openai, google, ollama)") + cmd.Flags().String("base-url", "", "override provider base URL (useful for OpenAI-compat endpoints)") + return cmd +} + +// modelEntry is the unified row format rendered by listEndpointsTable. +type modelEntry struct { + Name string // identifier a user would pass to --model / DATABRICKS_ENDPOINT + State string // READY / ACTIVE / "-" (not all providers expose a state) + Model string // underlying model name; often equal to Name for direct providers + Usage string // hint for how to target this entry on subsequent scans +} + +// ===== Command dispatcher ============================================ + +func runListEndpoints(cmd *cobra.Command, args []string) error { + cfg, err := config.Load(v) + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + provider, err := resolveListProvider(cmd, cfg) + if err != nil { + return err + } + + var entries []modelEntry + switch provider { + case "databricks": + entries, err = listDatabricks(cmd.Context(), cfg) + case "anthropic": + entries, err = listAnthropic(cmd.Context(), cfg, cmdStringFlag(cmd, "base-url")) + case "openai": + entries, err = listOpenAI(cmd.Context(), cfg, cmdStringFlag(cmd, "base-url")) + case "google": + entries, err = listGoogle(cmd.Context(), cfg, cmdStringFlag(cmd, "base-url")) + case "ollama": + entries, err = listOllama(cmd.Context(), cmdStringFlag(cmd, "base-url")) + 
default: + return fmt.Errorf("unsupported provider %q (supported: databricks, anthropic, openai, google, ollama)", provider) + } + if err != nil { + return err + } + + if len(entries) == 0 { + fmt.Fprintf(cmd.OutOrStdout(), "No models returned for provider %q.\n", provider) + return nil + } + + w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 4, 2, ' ', 0) + fmt.Fprintln(w, "ENDPOINT\tSTATE\tMODEL\tUSAGE") + for _, e := range entries { + fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", e.Name, emptyDash(e.State), emptyDash(e.Model), emptyDash(e.Usage)) + } + return w.Flush() +} + +// resolveListProvider picks a provider: explicit flag > databricks (if creds +// present, preserving existing behavior) > first provider with an API key. +func resolveListProvider(cmd *cobra.Command, cfg *config.Config) (string, error) { + if p := strings.TrimSpace(cmdStringFlag(cmd, "provider")); p != "" { + return strings.ToLower(p), nil + } + if p := strings.TrimSpace(cfg.Provider); p != "" { + return strings.ToLower(p), nil + } + // If either Databricks env var is set, route to the Databricks lister so + // its specific "DATABRICKS_HOST/TOKEN is not set" error surfaces rather + // than the generic no-provider message. 
+ if cfg.DatabricksHost != "" || cfg.DatabricksToken != "" { + return "databricks", nil + } + if cfg.AnthropicAPIKey != "" { + return "anthropic", nil + } + if cfg.OpenAIAPIKey != "" { + return "openai", nil + } + if cfg.GoogleAPIKey != "" { + return "google", nil + } + return "", fmt.Errorf("no provider specified and no credentials detected; pass --provider or set one of ANTHROPIC_API_KEY / OPENAI_API_KEY / GOOGLE_API_KEY / (DATABRICKS_HOST + DATABRICKS_TOKEN)") +} + +func cmdStringFlag(cmd *cobra.Command, name string) string { + if cmd == nil || cmd.Flags() == nil { + return "" + } + s, _ := cmd.Flags().GetString(name) + return s +} + +func emptyDash(s string) string { + if s == "" { + return "-" + } + return s +} + +// ===== Databricks =================================================== + +// databricksEndpointList is the response from the Databricks serving-endpoints API. +type databricksEndpointList struct { + Endpoints []databricksEndpoint `json:"endpoints"` +} + +type databricksEndpoint struct { + Name string `json:"name"` + State databricksState `json:"state"` + Config databricksEndpointCfg `json:"config"` + Creator string `json:"creator"` +} + +type databricksState struct { + Ready string `json:"ready"` + ConfigUpdate string `json:"config_update"` +} + +type databricksEndpointCfg struct { + ServedEntities []databricksServedEntity `json:"served_entities"` + ServedModels []databricksServedModel `json:"served_models"` +} + +type databricksServedEntity struct { + Name string `json:"name"` + ExternalModel *databricksExternalModel `json:"external_model"` + FoundationModel *databricksFoundationModel `json:"foundation_model"` +} + +type databricksExternalModel struct { + Name string `json:"name"` + Provider string `json:"provider"` +} + +type databricksFoundationModel struct { + Name string `json:"name"` +} + +type databricksServedModel struct { + Name string `json:"name"` + ModelName string `json:"model_name"` +} + +func listDatabricks(ctx context.Context, cfg 
*config.Config) ([]modelEntry, error) { + if cfg.DatabricksHost == "" { + return nil, fmt.Errorf("DATABRICKS_HOST is not set") + } + if cfg.DatabricksToken == "" { + return nil, fmt.Errorf("DATABRICKS_TOKEN is not set") + } + + host := strings.TrimRight(cfg.DatabricksHost, "/") + url := host + "/api/2.0/serving-endpoints" + + body, err := httpGetJSON(ctx, url, http.Header{ + "Authorization": []string{"Bearer " + cfg.DatabricksToken}, + }) + if err != nil { + return nil, fmt.Errorf("querying Databricks: %w", err) + } + + var result databricksEndpointList + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("parsing response: %w", err) + } + + entries := make([]modelEntry, 0, len(result.Endpoints)) + for _, ep := range result.Endpoints { + entries = append(entries, modelEntry{ + Name: ep.Name, + State: ep.State.Ready, + Model: describeModel(ep), + Usage: usageHint(ep.Name), + }) + } + return entries, nil +} + +// describeModel extracts a human-readable model name from the endpoint config. +func describeModel(ep databricksEndpoint) string { + for _, e := range ep.Config.ServedEntities { + if e.ExternalModel != nil { + return e.ExternalModel.Provider + "/" + e.ExternalModel.Name + } + if e.FoundationModel != nil { + return e.FoundationModel.Name + } + if e.Name != "" { + return e.Name + } + } + for _, m := range ep.Config.ServedModels { + if m.ModelName != "" { + return m.ModelName + } + } + return "-" +} + +// usageHint shows how to use this endpoint with codecrucible (Databricks flavor). +func usageHint(name string) string { + if _, ok := config.LookupModel(name); ok { + return "auto-detected" + } + return "DATABRICKS_ENDPOINT=" + name +} + +// ===== Anthropic ==================================================== + +// anthropicModelList is the response from GET /v1/models. 
+type anthropicModelList struct { + Data []struct { + ID string `json:"id"` + DisplayName string `json:"display_name"` + Type string `json:"type"` + } `json:"data"` +} + +func listAnthropic(ctx context.Context, cfg *config.Config, baseURL string) ([]modelEntry, error) { + if cfg.AnthropicAPIKey == "" { + return nil, fmt.Errorf("ANTHROPIC_API_KEY is not set") + } + base := firstNonEmpty(baseURL, cfg.BaseURL, "https://api.anthropic.com") + url := strings.TrimRight(base, "/") + "/v1/models" + + body, err := httpGetJSON(ctx, url, http.Header{ + "x-api-key": []string{cfg.AnthropicAPIKey}, + "anthropic-version": []string{"2023-06-01"}, + }) + if err != nil { + return nil, fmt.Errorf("querying Anthropic: %w", err) + } + + var list anthropicModelList + if err := json.Unmarshal(body, &list); err != nil { + return nil, fmt.Errorf("parsing response: %w", err) + } + + entries := make([]modelEntry, 0, len(list.Data)) + for _, m := range list.Data { + entries = append(entries, modelEntry{ + Name: m.ID, + State: "available", + Model: firstNonEmpty(m.DisplayName, m.ID), + Usage: "--provider anthropic --model " + m.ID, + }) + } + return entries, nil +} + +// ===== OpenAI ======================================================= + +// openaiModelList is the response from GET /v1/models. 
+type openaiModelList struct { + Data []struct { + ID string `json:"id"` + OwnedBy string `json:"owned_by"` + } `json:"data"` +} + +func listOpenAI(ctx context.Context, cfg *config.Config, baseURL string) ([]modelEntry, error) { + if cfg.OpenAIAPIKey == "" { + return nil, fmt.Errorf("OPENAI_API_KEY is not set") + } + base := firstNonEmpty(baseURL, cfg.BaseURL, "https://api.openai.com") + url := strings.TrimRight(base, "/") + "/v1/models" + + body, err := httpGetJSON(ctx, url, http.Header{ + "Authorization": []string{"Bearer " + cfg.OpenAIAPIKey}, + }) + if err != nil { + return nil, fmt.Errorf("querying OpenAI: %w", err) + } + + var list openaiModelList + if err := json.Unmarshal(body, &list); err != nil { + return nil, fmt.Errorf("parsing response: %w", err) + } + + entries := make([]modelEntry, 0, len(list.Data)) + for _, m := range list.Data { + entries = append(entries, modelEntry{ + Name: m.ID, + State: "available", + Model: firstNonEmpty(m.OwnedBy, m.ID), + Usage: "--provider openai --model " + m.ID, + }) + } + return entries, nil +} + +// ===== Google (OpenAI-compat layer) ================================= + +func listGoogle(ctx context.Context, cfg *config.Config, baseURL string) ([]modelEntry, error) { + if cfg.GoogleAPIKey == "" { + return nil, fmt.Errorf("GOOGLE_API_KEY is not set") + } + // Google's OpenAI-compat layer mirrors OpenAI's /models endpoint, mounted + // under /v1beta/openai/. 
+ base := firstNonEmpty(baseURL, cfg.BaseURL, "https://generativelanguage.googleapis.com/v1beta/openai") + url := strings.TrimRight(base, "/") + "/models" + + body, err := httpGetJSON(ctx, url, http.Header{ + "Authorization": []string{"Bearer " + cfg.GoogleAPIKey}, + }) + if err != nil { + return nil, fmt.Errorf("querying Google: %w", err) + } + + var list openaiModelList + if err := json.Unmarshal(body, &list); err != nil { + return nil, fmt.Errorf("parsing response: %w", err) + } + + entries := make([]modelEntry, 0, len(list.Data)) + for _, m := range list.Data { + entries = append(entries, modelEntry{ + Name: m.ID, + State: "available", + Model: firstNonEmpty(m.OwnedBy, m.ID), + Usage: "--provider google --model " + m.ID, + }) + } + return entries, nil +} + +// ===== Ollama ======================================================= + +// ollamaTags is the response from GET /api/tags on an Ollama server. +type ollamaTags struct { + Models []struct { + Name string `json:"name"` + Model string `json:"model"` + } `json:"models"` +} + +func listOllama(ctx context.Context, baseURL string) ([]modelEntry, error) { + base := firstNonEmpty(baseURL, "http://localhost:11434") + url := strings.TrimRight(base, "/") + "/api/tags" + + body, err := httpGetJSON(ctx, url, nil) + if err != nil { + return nil, fmt.Errorf("querying Ollama: %w", err) + } + + var tags ollamaTags + if err := json.Unmarshal(body, &tags); err != nil { + return nil, fmt.Errorf("parsing response: %w", err) + } + + entries := make([]modelEntry, 0, len(tags.Models)) + for _, m := range tags.Models { + name := firstNonEmpty(m.Name, m.Model) + entries = append(entries, modelEntry{ + Name: name, + State: "local", + Model: name, + Usage: "--provider ollama --model " + name, + }) + } + return entries, nil +} + +// ===== Shared helpers =============================================== + +// httpGetJSON performs an authenticated GET and returns the raw body, +// raising a formatted error on any non-2xx response. 
// httpGetJSON issues a GET request to url, adding every value in headers
// (which may be nil), and returns the raw response body.
//
// The request is bound to ctx and to a 30-second client timeout. Transport
// errors are returned as-is. A non-2xx status yields an error carrying the
// status code and a short excerpt of the response body.
func httpGetJSON(ctx context.Context, url string, headers http.Header) ([]byte, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("creating request: %w", err)
	}
	for key, values := range headers {
		for _, value := range values {
			req.Header.Add(key, value)
		}
	}

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading response: %w", err)
	}

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Only an excerpt of the body goes into the error; truncateStr keeps
		// the excerpt from ending in a split multi-byte character.
		return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, truncateStr(string(body), 300))
	}
	return body, nil
}

// firstNonEmpty returns the first argument that is not the empty string, or ""
// when every argument is empty (or none are given).
func firstNonEmpty(values ...string) string {
	for _, s := range values {
		if s != "" {
			return s
		}
	}
	return ""
}

// truncateStr returns s unchanged when it is at most maxLen bytes long;
// otherwise it returns a prefix of at most maxLen bytes followed by "...".
//
// The cut never lands inside a multi-byte UTF-8 sequence: when byte maxLen is
// a continuation byte the prefix is shortened (by up to three bytes) to the
// start of that character. A negative maxLen is treated as zero.
//
// The name avoids a collision with another truncate helper in this package.
func truncateStr(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	cut := maxLen
	// UTF-8 continuation bytes look like 10xxxxxx. A character has at most
	// three of them, so three steps back are always enough for valid input
	// and bound the work on invalid input.
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut] + "..."
}
%v", err) + } +} + +func TestRunListEndpoints_MissingToken(t *testing.T) { + oldV := v + t.Cleanup(func() { v = oldV }) + + v = viper.New() + v.Set("databricks-host", "https://example.com") + + err := runListEndpoints(&cobra.Command{}, nil) + if err == nil || !strings.Contains(err.Error(), "DATABRICKS_TOKEN is not set") { + t.Fatalf("expected missing token error, got %v", err) + } +} + +func TestDescribeModel_PrefersExternalModel(t *testing.T) { + endpoint := databricksEndpoint{ + Config: databricksEndpointCfg{ + ServedEntities: []databricksServedEntity{ + { + ExternalModel: &databricksExternalModel{Provider: "openai", Name: "gpt-5.2"}, + }, + }, + }, + } + + got := describeModel(endpoint) + if got != "openai/gpt-5.2" { + t.Fatalf("describeModel() = %q, want %q", got, "openai/gpt-5.2") + } +} + +func TestUsageHint(t *testing.T) { + if got := usageHint("gpt-5.2"); got != "auto-detected" { + t.Fatalf("usageHint() for known model = %q, want %q", got, "auto-detected") + } + if got := usageHint("custom-endpoint"); got != "DATABRICKS_ENDPOINT=custom-endpoint" { + t.Fatalf("usageHint() for unknown model = %q", got) + } +} + +func TestTruncateStr(t *testing.T) { + if got := truncateStr("short", 10); got != "short" { + t.Fatalf("truncateStr short = %q, want %q", got, "short") + } + if got := truncateStr("abcdefghijklmnopqrstuvwxyz", 5); got != "abcde..." { + t.Fatalf("truncateStr long = %q, want %q", got, "abcde...") + } +} diff --git a/internal/cli/init.go b/internal/cli/init.go new file mode 100644 index 0000000..b3185c0 --- /dev/null +++ b/internal/cli/init.go @@ -0,0 +1,155 @@ +package cli + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" +) + +func newInitCommand() *cobra.Command { + var force bool + cmd := &cobra.Command{ + Use: "init [path]", + Short: "Write a commented .codecrucible.yaml config file", + Long: `Writes a commented example config to .codecrucible.yaml (or the given path). 
+ +The file is a worked example, not a schema dump — it shows the per-phase +provider/model override shape and the flags that actually matter when running +against large repos with thinking models. Delete what you don't need.`, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + path := ".codecrucible.yaml" + if len(args) == 1 { + path = args[0] + } + if !force { + if _, err := os.Stat(path); err == nil { + return fmt.Errorf("%s already exists (use --force to overwrite)", path) + } + } + if err := os.WriteFile(path, []byte(configTemplate), 0o644); err != nil { + return fmt.Errorf("writing %s: %w", path, err) + } + fmt.Fprintf(cmd.OutOrStdout(), "wrote %s\n", path) + return nil + }, + } + cmd.Flags().BoolVarP(&force, "force", "f", false, "overwrite an existing file") + return cmd +} + +// configTemplate is a worked example, not a schema dump. Shows the per-phase +// override shape and the flags that bite when running large-context scans. +// Everything is optional — flat top-level fields are inherited by all phases. +const configTemplate = `# codecrucible config — loaded from ./.codecrucible.yaml by default, +# or pass --config path/to/file.yaml. Flags > env vars > this file > defaults. + +# --- Scope --------------------------------------------------------------- + +paths: + - src/ +exclude: + - "**/third-party/**" + - "**/vendor/**" +# include-tests: false +# include-docs: false +# max-file-size: 102400 # bytes; files larger than this are skipped + +# --- Cost / concurrency -------------------------------------------------- + +max-cost: 100.0 # dollar ceiling; scan aborts if estimated cost exceeds +concurrency: 4 # parallel chunk requests +# dry-run: true # print plan without LLM calls + +# --- Model sizing -------------------------------------------------------- +# These matter for models not in the built-in registry (unknown models get +# weak defaults: 100k context, 4096 output). Set both. 
+ +context-limit: 200000 +max-output-tokens: 16384 +request-timeout: 900 # seconds; bump for large-context + thinking models + +# --- Provider (flat form — inherited by all phases) ---------------------- +# Providers: anthropic, openai, google, ollama, openai-compat, databricks +# API keys come from env: ANTHROPIC_API_KEY, OPENAI_API_KEY, GOOGLE_API_KEY, +# DATABRICKS_HOST + DATABRICKS_TOKEN. Provider auto-detects from whichever +# is set unless pinned here. ollama and openai-compat don't require API keys. + +provider: anthropic +model: claude-sonnet-4-6 +# base-url: "" # override default provider URL (required for openai-compat) +# custom-headers: +# - "anthropic-beta: context-1m-2025-08-07" + +# model-params is raw JSON merged into every request body. tool_choice and +# response_format are REPLACED; everything else deep-merges. Prefer +# --max-output-tokens over "max_tokens" here — the chunker reads the former. +# model-params: +# thinking: +# type: enabled +# budget_tokens: 10000 + +# --- Per-phase overrides ------------------------------------------------- +# Any field left unset inherits from the flat fields above. Analysis is the +# main scan pass; feature-detection is a cheap pre-pass (also calibrates +# the tokenizer — don't skip it on large repos); audit re-checks findings. +# +# Env form: PHASES_AUDIT_PROVIDER, PHASES_AUDIT_MODEL, PHASES_AUDIT_API_KEY, ... + +phases: + # analysis: + # model: claude-opus-4-6 + # context-limit: 500000 + # max-output-tokens: 32000 + # headers: + # - "anthropic-beta: context-1m-2025-08-07" + + # feature-detection: + # # Cheap model is fine here — output is a small feature list. But if + # # you use a different model than analysis, its tokenizer calibration + # # is discarded (different tokenizer = invalid measurement). + # provider: openai + # model: gpt-5.2 + # api-key: ${OPENAI_API_KEY} + + # audit: + # # Example: run audit on a different provider entirely. 
+ # provider: google + # model: gemini-3-pro + # api-key: ${GOOGLE_API_KEY} + # # Drop analysis-phase model-params that don't apply here: + # model-params: {} + +# --- Audit tuning -------------------------------------------------------- + +# skip-audit: false +audit-confidence-threshold: 0.3 # reject findings below this (0.0-1.0) + +# --- Model registry ------------------------------------------------------ +# Extend or override the built-in model registry without recompiling. Keyed +# by name (case-insensitive): matching a built-in replaces it wholesale, a +# new name extends. Empty endpoint defaults to /invocations. +# +# models: +# - name: claude-sonnet-4-8 # a model the binary doesn't know yet +# provider: anthropic +# input_price_per_million: 3.0 +# output_price_per_million: 15.0 +# context_limit: 1000000 +# max_output_tokens: 64000 +# tokenizer_encoding: claude +# supports_structured_output: true +# - name: azure-gpt-5 # self-hosted / Azure deployment +# provider: openai-compat +# endpoint: deployments/my-azure-deploy/chat/completions +# context_limit: 400000 +# max_output_tokens: 16384 +# tokenizer_encoding: o200k_base + +# --- Output -------------------------------------------------------------- + +# output: results.sarif +# phase-output-dir: ./codecrucible-artifacts # optional; default is sidecars next to output +# prompts-dir: ./prompts/default # prompt set directory (see prompts/ for available sets) +` diff --git a/internal/cli/phase_artifacts.go b/internal/cli/phase_artifacts.go new file mode 100644 index 0000000..69ce98f --- /dev/null +++ b/internal/cli/phase_artifacts.go @@ -0,0 +1,113 @@ +package cli + +import ( + "encoding/json" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + + "github.com/block/codecrucible/internal/config" + "github.com/block/codecrucible/internal/sarif" +) + +type phaseArtifactWriter struct { + paths map[string]string +} + +type featureDetectionArtifact struct { + Phase string `json:"phase"` + Status string 
// canDerivePhaseArtifactSidecars reports whether output names a destination
// that sidecar artifact paths can be derived from. It is false for an empty or
// whitespace-only value, for "-", and for paths that clean to /dev/stdout or
// /dev/stderr; anything else is treated as a usable file path.
func canDerivePhaseArtifactSidecars(output string) bool {
	trimmed := strings.TrimSpace(output)
	switch trimmed {
	case "", "-":
		return false
	}

	switch filepath.Clean(trimmed) {
	case "/dev/stdout", "/dev/stderr":
		return false
	default:
		return true
	}
}
// writeArtifactFile stores data at path, first creating any missing parent
// directories. Directories are created with mode 0700 and the file is written
// with mode 0600. Errors from either step are wrapped with the path involved.
func writeArtifactFile(path string, data []byte) error {
	parent := filepath.Dir(path)
	if mkErr := os.MkdirAll(parent, 0700); mkErr != nil {
		return fmt.Errorf("creating artifact directory %q: %w", parent, mkErr)
	}

	writeErr := os.WriteFile(path, data, 0600)
	if writeErr == nil {
		return nil
	}
	return fmt.Errorf("writing artifact file %q: %w", path, writeErr)
}
"/dev/stdout", + PhaseOutputDir: dir, + }) + + if !w.Enabled() { + t.Fatal("writer should be enabled when --phase-output-dir is set") + } + if got, want := w.Path("analysis"), filepath.Join(dir, "analysis.sarif"); got != want { + t.Fatalf("analysis path = %q, want %q", got, want) + } +} + +func TestPhaseArtifactWriter_DisabledForStdoutWithoutDirectory(t *testing.T) { + for _, output := range []string{"", "-", "/dev/stdout", "/dev/stderr"} { + t.Run(output, func(t *testing.T) { + w := newPhaseArtifactWriter(&config.Config{Output: output}) + if w.Enabled() { + t.Fatalf("writer should be disabled for output %q", output) + } + }) + } +} + +func TestPhaseArtifactWriter_WritesArtifacts(t *testing.T) { + dir := t.TempDir() + w := newPhaseArtifactWriter(&config.Config{PhaseOutputDir: dir}) + + if err := w.WriteFeatureDetection(featureDetectionArtifact{ + Phase: "feature-detection", + Status: "completed", + Repo: "repo", + Provider: "openai", + Model: "gpt-5.5", + DetectedFeatures: []string{"web"}, + TokenCorrection: 1.25, + }); err != nil { + t.Fatalf("WriteFeatureDetection: %v", err) + } + if err := w.WriteSARIF("analysis", sarif.SARIFDocument{ + Version: "2.1.0", + Runs: []sarif.SARIFRun{{Results: []sarif.SARIFResult{}}}, + }); err != nil { + t.Fatalf("WriteSARIF: %v", err) + } + + fdData, err := os.ReadFile(filepath.Join(dir, "feature-detection.json")) + if err != nil { + t.Fatalf("read feature artifact: %v", err) + } + var fd map[string]any + if err := json.Unmarshal(fdData, &fd); err != nil { + t.Fatalf("feature artifact is not JSON: %v", err) + } + if fd["status"] != "completed" { + t.Fatalf("feature status = %v, want completed", fd["status"]) + } + + sarifData, err := os.ReadFile(filepath.Join(dir, "analysis.sarif")) + if err != nil { + t.Fatalf("read analysis artifact: %v", err) + } + var doc sarif.SARIFDocument + if err := json.Unmarshal(sarifData, &doc); err != nil { + t.Fatalf("analysis artifact is not SARIF JSON: %v", err) + } + if doc.Version != "2.1.0" { + 
t.Fatalf("analysis SARIF version = %q, want 2.1.0", doc.Version) + } +} diff --git a/internal/cli/root.go b/internal/cli/root.go new file mode 100644 index 0000000..ed0639e --- /dev/null +++ b/internal/cli/root.go @@ -0,0 +1,86 @@ +package cli + +import ( + "fmt" + "io" + "log/slog" + "os" + + "github.com/block/codecrucible/internal/config" + "github.com/block/codecrucible/internal/logging" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +// Build-time variables set via ldflags. +var ( + version = "dev" + commit = "none" + date = "unknown" +) + +var ( + cfgFile string + verbose bool + v *viper.Viper + logger *slog.Logger + + exitFunc = os.Exit + stderrWriter io.Writer = os.Stderr +) + +// NewRootCommand creates the root cobra command with all subcommands attached. +func NewRootCommand() *cobra.Command { + rootCmd := &cobra.Command{ + Use: "codecrucible", + Short: "Security analysis tool for Git repositories", + Long: `codecrucible is a purpose-built CLI tool that analyzes Git repositories +for security vulnerabilities using LLM-based analysis and produces SARIF output.`, + Version: fmt.Sprintf("%s (commit: %s, built: %s)", version, commit, date), + SilenceUsage: true, + SilenceErrors: true, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + var err error + v, err = config.SetupViper(cfgFile) + if err != nil { + return err + } + // Bind the persistent flags to viper so flag values override config/env. + _ = v.BindPFlag("verbose", cmd.Root().PersistentFlags().Lookup("verbose")) + + logger = logging.NewLogger(v.GetBool("verbose")) + slog.SetDefault(logger) + return nil + }, + } + + rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default: .codecrucible.yaml)") + rootCmd.PersistentFlags().BoolVar(&verbose, "verbose", false, "enable verbose/debug logging") + + // Cobra auto-attaches `completion` (shell tab-completion generator). + // Hide it — the name collides confusingly with "LLM completion" in + // this domain. 
Still invocable; just not in --help. + rootCmd.CompletionOptions.HiddenDefaultCmd = true + + rootCmd.AddCommand(newScanCommand()) + rootCmd.AddCommand(newListEndpointsCommand()) + rootCmd.AddCommand(newInitCommand()) + + return rootCmd +} + +// Execute runs the root command. +func Execute() { + exitCode := executeCommand(NewRootCommand(), stderrWriter) + if exitCode != 0 { + exitFunc(exitCode) + } +} + +func executeCommand(rootCmd *cobra.Command, stderr io.Writer) int { + if err := rootCmd.Execute(); err != nil { + fmt.Fprintf(stderr, "Error: %v\n", err) + return 1 + } + return 0 +} diff --git a/internal/cli/root_execute_test.go b/internal/cli/root_execute_test.go new file mode 100644 index 0000000..9472471 --- /dev/null +++ b/internal/cli/root_execute_test.go @@ -0,0 +1,99 @@ +package cli + +import ( + "bytes" + "errors" + "os" + "strings" + "testing" + + "github.com/spf13/cobra" +) + +func TestExecuteCommand_Success(t *testing.T) { + cmd := &cobra.Command{ + Use: "test", + RunE: func(cmd *cobra.Command, args []string) error { + return nil + }, + } + + stderr := new(bytes.Buffer) + code := executeCommand(cmd, stderr) + if code != 0 { + t.Fatalf("executeCommand() code = %d, want 0", code) + } + if stderr.Len() != 0 { + t.Fatalf("unexpected stderr output: %q", stderr.String()) + } +} + +func TestExecuteCommand_Failure(t *testing.T) { + cmd := &cobra.Command{ + Use: "test", + RunE: func(cmd *cobra.Command, args []string) error { + return errors.New("boom") + }, + } + + stderr := new(bytes.Buffer) + code := executeCommand(cmd, stderr) + if code != 1 { + t.Fatalf("executeCommand() code = %d, want 1", code) + } + if !strings.Contains(stderr.String(), "Error: boom") { + t.Fatalf("expected error output, got %q", stderr.String()) + } +} + +func TestExecute_DoesNotExitOnSuccess(t *testing.T) { + oldArgs := os.Args + oldExit := exitFunc + oldStderr := stderrWriter + t.Cleanup(func() { + os.Args = oldArgs + exitFunc = oldExit + stderrWriter = oldStderr + }) + + os.Args = 
[]string{"codecrucible", "--help"} + stderrWriter = new(bytes.Buffer) + + exited := false + exitFunc = func(code int) { + exited = true + } + + Execute() + if exited { + t.Fatal("Execute() should not call exitFunc on success") + } +} + +func TestExecute_ExitsOnFailure(t *testing.T) { + oldArgs := os.Args + oldExit := exitFunc + oldStderr := stderrWriter + t.Cleanup(func() { + os.Args = oldArgs + exitFunc = oldExit + stderrWriter = oldStderr + }) + + os.Args = []string{"codecrucible", "--bad-flag"} + stderr := new(bytes.Buffer) + stderrWriter = stderr + + exitCode := 0 + exitFunc = func(code int) { + exitCode = code + } + + Execute() + if exitCode != 1 { + t.Fatalf("exit code = %d, want 1", exitCode) + } + if !strings.Contains(stderr.String(), "Error:") { + t.Fatalf("expected stderr to contain error prefix, got %q", stderr.String()) + } +} diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go new file mode 100644 index 0000000..d0caa41 --- /dev/null +++ b/internal/cli/root_test.go @@ -0,0 +1,37 @@ +package cli + +import "testing" + +func TestNewRootCommand_HasSubcommands(t *testing.T) { + cmd := NewRootCommand() + + got := map[string]bool{} + aliases := map[string]bool{} + for _, sub := range cmd.Commands() { + got[sub.Name()] = true + for _, a := range sub.Aliases { + aliases[a] = true + } + } + + for _, want := range []string{"scan", "list-models"} { + if !got[want] { + t.Fatalf("expected subcommand %q to be registered", want) + } + } + // list-endpoints is preserved as a backwards-compatible alias of list-models. 
+ if !aliases["list-endpoints"] { + t.Fatalf("expected list-endpoints to remain registered as an alias") + } +} + +func TestNewRootCommand_HasPersistentFlags(t *testing.T) { + cmd := NewRootCommand() + flags := cmd.PersistentFlags() + + for _, name := range []string{"config", "verbose"} { + if flags.Lookup(name) == nil { + t.Fatalf("expected persistent flag %q", name) + } + } +} diff --git a/internal/cli/scan.go b/internal/cli/scan.go new file mode 100644 index 0000000..bf581ff --- /dev/null +++ b/internal/cli/scan.go @@ -0,0 +1,772 @@ +package cli + +import ( + "encoding/json" + "errors" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/block/codecrucible/internal/chunk" + "github.com/block/codecrucible/internal/config" + "github.com/block/codecrucible/internal/ingest" + "github.com/block/codecrucible/internal/llm" + "github.com/block/codecrucible/internal/sarif" + "github.com/spf13/cobra" +) + +func newScanCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "scan [repo-path]", + Short: "Scan a repository for security vulnerabilities", + Long: `Scan analyzes one or more repository paths using an LLM-based security analysis +pipeline and produces SARIF output suitable for GitHub Code Scanning integration.`, + Args: cobra.MaximumNArgs(1), + RunE: runScan, + } + + // Repository targeting + cmd.Flags().StringSlice("paths", nil, "paths within the repository to analyze (can be specified multiple times)") + cmd.Flags().StringSlice("include", nil, "glob patterns for files to include") + cmd.Flags().StringSlice("exclude", nil, "glob patterns for files to exclude") + + // Per-phase LLM selection: symmetric flag families (one per phaseFlagSets entry), defined once. + // Analysis-phase flags carry no prefix (they're the ones you reach + // for first). Unset --audit-* / --feature-detection-* / --context-compress-* inherit from + // the analysis values.
+ for _, p := range phaseFlagSets { + cmd.Flags().String(p.prefix+"model", "", p.what+" model"+p.inherit) + cmd.Flags().String(p.prefix+"provider", "", p.what+" provider: anthropic, openai, google, ollama, openai-compat, databricks"+p.inherit) + cmd.Flags().String(p.prefix+"api-key", "", p.what+" API key (optional for ollama/openai-compat)"+p.inherit) + cmd.Flags().String(p.prefix+"base-url", "", p.what+" base URL override"+p.inherit) + cmd.Flags().String(p.prefix+"model-params", "", p.what+" model request params as JSON (merged into request body)"+p.inherit) + // Hide per-phase flags from default --help to reduce noise. + // They still work; use --help-all or the README for full docs. + if p.prefix != "" { + for leaf := range phaseLeaves { + _ = cmd.Flags().MarkHidden(p.prefix + leaf) + } + } + // Short-form alias: distinct hidden flags binding to the same + // viper key in bindScanFlags. Cobra has command aliases but not + // flag aliases, so this is the idiom. + if p.alias != "" { + for leaf := range phaseLeaves { + cmd.Flags().String(p.alias+leaf, "", "") + _ = cmd.Flags().MarkHidden(p.alias + leaf) + } + } + } + + // Analysis behaviour + cmd.Flags().Float64("fail-on-severity", 0, "exit non-zero if any finding meets or exceeds this severity (0-10)") + cmd.Flags().Float64("max-cost", 25, "maximum cost budget in dollars (0 = unlimited)") + cmd.Flags().Bool("dry-run", false, "show what would be analyzed without making LLM calls") + cmd.Flags().String("custom-requirements", "", "additional analysis requirements to include in the prompt") + cmd.Flags().StringSlice("context-source", nil, "supplementary context as key=value pairs: name=X,type=,location=Y,priority=N,compress=true (repeatable)") + cmd.Flags().Int("context-budget-pct", 15, "percentage of context window reserved for supplementary context (max 40)") + cmd.Flags().String("prompts-dir", "", "prompt set directory containing YAML templates (default: prompts/default)") + 
cmd.Flags().StringSlice("custom-headers", nil, "additional HTTP headers for LLM requests, format 'Name: Value' (repeatable)") + + // Content options + cmd.Flags().Bool("include-tests", false, "include test files in the analysis") + cmd.Flags().Bool("include-docs", false, "include documentation files in the analysis") + cmd.Flags().Bool("compress", false, "compress whitespace in source files to save tokens") + cmd.Flags().Int("max-file-size", 102400, "exclude files larger than this size in bytes (0 = no limit)") + + // Model limits + cmd.Flags().Int("context-limit", 0, "override the model's context window size in tokens (0 = use model registry default)") + cmd.Flags().Int("max-output-tokens", 0, "override the model's max output tokens (0 = use model registry default)") + cmd.Flags().Int("request-timeout", 0, "per-request HTTP timeout in seconds (0 = default 600s)") + + // Phase gates + cmd.Flags().Bool("skip-feature-detection", false, "skip the feature detection pre-pass (faster for small repos)") + cmd.Flags().Bool("skip-audit", false, "skip the CWE-specific audit phase (faster but less accurate)") + cmd.Flags().Float64("audit-confidence-threshold", 0.3, "reject findings below this confidence score (0.0-1.0)") + cmd.Flags().Int("audit-batch-size", 25, "split audit into batches of N findings (0 = single call). Default keeps each call under typical server connection-age limits (~10-12min)") + cmd.Flags().Int("concurrency", 3, "max number of chunks to analyze in parallel") + + // Output + cmd.Flags().StringP("output", "o", "", "write output to file (default: stdout)") + cmd.Flags().String("phase-output-dir", "", "write per-phase artifacts to this directory (default: sidecars next to --output)") + + // Bind scan flags to viper + cmd.PreRun = func(cmd *cobra.Command, args []string) { + bindScanFlags(cmd) + } + + return cmd +} + +// phaseFlagSets keeps the per-phase flag families symmetric. Add a +// knob once, get it on every phase.
The analysis phase has an empty prefix: +// --model, --provider, --api-key, --base-url, --model-params are its flags. The others +// get their name as prefix. +// +// viperKey maps the CLI form to the config.Phases.{phase}.{leaf} path so a +// config-file phases: block, a PHASES_AUDIT_MODEL env var, and an +// --audit-model flag all land in the same place. +// +// The analysis phase is odd: its model/provider also bind to the legacy +// flat cfg.Model/cfg.Provider keys so existing config files and env vars +// (CODECRUCIBLE_PROVIDER, etc.) keep working. ResolvePhases reads the +// flat keys when seeding the analysis PhaseConfig, so either path ends up +// in the same slot. +var phaseFlagSets = []struct { + prefix string // CLI flag prefix + alias string // short prefix; flags registered hidden, bound to the same viper key + viperKey string // middle segment of the viper key phases.{viperKey}.{leaf} + what string // help text subject + inherit string // help text suffix +}{ + {"", "", "analysis", "analysis-phase", " (inherited by other phases unless overridden)"}, + {"feature-detection-", "fd-", "feature-detection", "feature-detection-phase", " (default: inherit from analysis; alias: --fd-*)"}, + {"audit-", "", "audit", "audit-phase", " (default: inherit from analysis)"}, + {"context-compress-", "cc-", "context-compress", "context-compression-phase", " (default: inherit from analysis; alias: --cc-*)"}, +} + +// phaseLeaves are the per-phase knobs. CLI name → PhaseConfig mapstructure +// leaf. model-params gets the -json suffix because the CLI form is a JSON +// string while the config-file form is a native YAML map — same split as +// the legacy global model-params / model-params-json pair. +var phaseLeaves = map[string]string{ + "model": "model", + "provider": "provider", + "api-key": "api-key", + "base-url": "base-url", + "model-params": "model-params-json", +} + +func bindScanFlags(cmd *cobra.Command) { + // Flat flags map to same-named viper keys.
+ flags := []string{ + "paths", "fail-on-severity", "max-cost", "dry-run", + "include-tests", "include-docs", "compress", "custom-requirements", + "output", "phase-output-dir", "prompts-dir", "include", "exclude", "custom-headers", + "skip-feature-detection", "concurrency", "max-file-size", + "context-limit", "max-output-tokens", "request-timeout", + "skip-audit", "audit-confidence-threshold", "audit-batch-size", + "context-budget-pct", + } + for _, f := range flags { + _ = v.BindPFlag(f, cmd.Flags().Lookup(f)) + } + // context-source flag populates the raw-string slice; config.Load parses + // each into a ContextSource and appends to any declared in the config file. + _ = v.BindPFlag("context-sources-raw", cmd.Flags().Lookup("context-source")) + + // Per-phase families: one loop over every phaseFlagSets entry and every phaseLeaves knob. + // BindPFlag stores key → *pflag.Flag in a map; a second bind to the + // same key overwrites the first. So bind the alias only when the + // user actually set it — otherwise the long form stays wired. + // bindScanFlags runs from PreRun, after argv parsing, so .Changed + // is accurate here. + for _, p := range phaseFlagSets { + for cli, leaf := range phaseLeaves { + key := "phases." + p.viperKey + "." + leaf + _ = v.BindPFlag(key, cmd.Flags().Lookup(p.prefix+cli)) + if p.alias != "" { + if af := cmd.Flags().Lookup(p.alias + cli); af != nil && af.Changed { + _ = v.BindPFlag(key, af) + } + } + } + } + + // Legacy flat-key aliases for the analysis phase so existing + // config files (model:, provider:) and env vars keep working. + // ResolvePhases reads these into Phases.Analysis. + _ = v.BindPFlag("model", cmd.Flags().Lookup("model")) + _ = v.BindPFlag("provider", cmd.Flags().Lookup("provider")) + _ = v.BindPFlag("model-params-json", cmd.Flags().Lookup("model-params")) + _ = v.BindPFlag("base-url", cmd.Flags().Lookup("base-url")) + // Same for the legacy per-phase model-only flags.
+ _ = v.BindPFlag("feature-detection-model", cmd.Flags().Lookup("feature-detection-model")) + _ = v.BindPFlag("audit-model", cmd.Flags().Lookup("audit-model")) +} + +// exitCodeFindings is the exit code when findings exceed --fail-on-severity. +const exitCodeFindings = 2 + +func runScan(cmd *cobra.Command, args []string) error { + cfg, err := config.Load(v) + if err != nil { + return err + } + if err := config.ResolvePhases(cfg); err != nil { + return fmt.Errorf("loading config: %w", err) + } + + // Determine repo root. + repoRoot := "." + if len(args) > 0 { + repoRoot = args[0] + } + repoRoot, err = filepath.Abs(repoRoot) + if err != nil { + return fmt.Errorf("resolving repo path: %w", err) + } + + slog.Info("scan starting", + "repo", repoRoot, + "paths", cfg.Paths, + "model", cfg.Model, + "dry_run", cfg.DryRun, + ) + + // Phase configs are fully resolved by this point — registry lookup, + // per-phase overrides, provider detection, key cascade all done. + // Alias the analysis phase's model config to keep downstream token + // math and call sites reading the familiar name. 
+ analysis := &cfg.Phases.Analysis + modelCfg := analysis.ModelCfg + + // --- Stage 1: Ingest --- + files, err := ingestFiles(repoRoot, cfg) + if err != nil { + return err + } + + // --- Stage 2: Filter --- + filtered, stats := ingest.FilterFiles(files, ingest.FilterConfig{ + IncludeTests: cfg.IncludeTests, + IncludeDocs: cfg.IncludeDocs, + Include: cfg.Include, + Exclude: cfg.Exclude, + MaxFileSize: cfg.MaxFileSize, + }) + + slog.Info("ingestion complete", + "files_found", stats.Total, + "files_kept", stats.Kept, + ) + + // --- Stage 3: Build FileMap (defer full XML generation) --- + flattenCfg := ingest.FlattenConfig{Compress: cfg.Compress} + flatResult := ingest.FlattenFileMapOnly(filtered) + + // --- Stage 3.5: Build import graph and export summaries --- + importGraph := ingest.ResolveImports(filtered) + exportSummaries := buildExportSummaries(filtered) + + slog.Info("import graph built", + "files_with_imports", len(importGraph), + "files_with_summaries", len(exportSummaries), + ) + + // --- Stage 4: Count tokens and chunk --- + counter := chunk.NewTokenCounter(modelCfg.Encoding, slog.Default()) + // Streaming count: iterate FileMap one file at a time so peak memory is + // max(single file XML) rather than the full concatenated document. + totalTokens := streamingTokenCount(flatResult.FileMap, counter, flattenCfg) + flatResult.Tokens = totalTokens + + // Only build the full XML when the repo likely fits in a single chunk. + // The chunker returns this verbatim for single-chunk, but for multi-chunk + // it rebuilds per-file XML from FileMap anyway — so building ~2x N of + // XML upfront would be pure waste. + if totalTokens <= modelCfg.ContextLimit { + flatResult.BuildFullXML(filtered, flattenCfg) + } + + // Estimate cost. + analysisCost := float64(totalTokens) * modelCfg.InputPricePerM / 1_000_000 + + // Estimate audit phase cost (assumes all repo tokens as context in the worst case). 
+ var auditCostEstimate float64 + if !cfg.SkipAudit { + auditCostEstimate = float64(totalTokens) * cfg.Phases.Audit.ModelCfg.InputPricePerM / 1_000_000 + } + + estimatedCost := analysisCost + auditCostEstimate + + slog.Info("analysis scope", + "files", len(filtered), + "tokens", totalTokens, + "model", modelCfg.Name, + "context_limit", modelCfg.ContextLimit, + "estimated_analysis_cost", fmt.Sprintf("$%.4f", analysisCost), + "estimated_audit_cost", fmt.Sprintf("$%.4f", auditCostEstimate), + "estimated_total_input_cost", fmt.Sprintf("$%.4f", estimatedCost), + ) + + // Handle empty repo. + if len(filtered) == 0 { + slog.Info("no source files after filtering, producing empty SARIF") + return outputEmptySARIF(cfg) + } + + if cfg.DryRun { + fmt.Printf("Dry run — analysis scope:\n") + fmt.Printf(" Files: %d (of %d total)\n", stats.Kept, stats.Total) + fmt.Printf(" Tokens: %d\n", totalTokens) + fmt.Printf(" Model: %s (context limit: %d)\n", modelCfg.Name, modelCfg.ContextLimit) + fmt.Printf(" Estimated analysis input cost: $%.4f\n", analysisCost) + if !cfg.SkipAudit { + fmt.Printf(" Estimated audit input cost: $%.4f (model: %s)\n", auditCostEstimate, cfg.Phases.Audit.ModelCfg.Name) + } + fmt.Printf(" Estimated total input cost: $%.4f\n", estimatedCost) + if totalTokens > modelCfg.ContextLimit { + chunks := (totalTokens / modelCfg.ContextLimit) + 1 + fmt.Printf(" Will require ~%d chunks\n", chunks) + } + return nil + } + + // Check max cost. + if cfg.MaxCost > 0 && estimatedCost > cfg.MaxCost { + return fmt.Errorf("estimated cost $%.4f exceeds --max-cost $%.2f; aborting (use --dry-run to preview)", estimatedCost, cfg.MaxCost) + } + + // --- Stage 5: Prepare LLM --- + client, endpoint, err := buildPhaseClient(*analysis, cfg) + if err != nil { + return err + } + + slog.Info("LLM provider configured", "provider", analysis.Provider, "model", modelCfg.Name) + + // Load prompt templates. 
+ promptLoader, err := resolvePromptLoader(cfg.PromptsDir) + if err != nil { + return fmt.Errorf("loading prompts: %w", err) + } + + schema := llm.SecurityAnalysisSchema() + outputMode := llm.OutputModeForModel(modelCfg.Name) + repoName := filepath.Base(repoRoot) + artifacts := newPhaseArtifactWriter(cfg) + if artifacts.Enabled() { + slog.Info("phase artifacts enabled", + "feature_detection", artifacts.Path("feature-detection"), + "analysis", artifacts.Path("analysis"), + "audit", artifacts.Path("audit"), + ) + } + + // --- Stage 5.25: Load & pack supplementary context --- + analysisCtx, auditCtx, err := loadSupplementaryContext(cmd.Context(), cfg, counter, promptLoader, modelCfg.ContextLimit) + if err != nil { + return err + } + + // --- Stage 5.5: Estimate whether we need chunking --- + // Measure prompt overhead assuming all sections (worst-case) to decide if + // feature detection is worth the extra LLM round-trip. + worstCaseMsgs, err := promptLoader.AssembleMessages(llm.PromptParams{ + RepoName: repoName, + XML: "", + Schema: string(*schema), + ChunkTotal: 1, + CustomRequirements: cfg.CustomRequirements, + EnabledFeatures: nil, // nil = all sections included + SupplementaryContext: analysisCtx.Rendered, + }) + if err != nil { + return fmt.Errorf("measuring prompt overhead: %w", err) + } + worstCaseOverhead := 0 + for _, msg := range worstCaseMsgs { + worstCaseOverhead += counter.Count(msg.Content) + } + if outputMode == llm.OutputModeToolUse && schema != nil { + worstCaseOverhead += counter.Count(string(*schema)) + } + + // If the repo fits in a single chunk even with all sections, skip feature + // detection entirely — it's a full LLM round-trip for no benefit. 
+ const tokenizerSafetyMargin = 0.20 + outputReserve := modelCfg.MaxOutputTokens + worstCaseEffective := int(float64(modelCfg.ContextLimit) * (1 - tokenizerSafetyMargin)) + worstCaseBudget := worstCaseEffective - worstCaseOverhead - outputReserve + + var detectedFeatures []string + var tokenCorrection float64 + if !cfg.SkipFeatureDetection && totalTokens > worstCaseBudget { + // Multi-chunk scenario: feature detection trims sections and saves tokens. + fd := &cfg.Phases.FeatureDetection + fdClient, fdEndpoint, fdErr := buildPhaseClient(*fd, cfg) + if fdErr != nil { + slog.Warn("failed to build feature-detection client; falling back to analysis client", + "error", fdErr, "provider", fd.Provider) + fdClient, fdEndpoint = client, endpoint + fd = analysis + } else if fd.ModelCfg.Name != modelCfg.Name || fd.Provider != analysis.Provider { + slog.Info("feature detection uses separate configuration", + "provider", fd.Provider, "model", fd.ModelCfg.Name) + } + fdOutputMode := llm.OutputModeForModel(fd.ModelCfg.Name) + var fdCorrection float64 + detectedFeatures, fdCorrection, err = runFeatureDetection(cmd.Context(), filtered, repoName, fdClient, fdEndpoint, fd.ModelCfg, promptLoader, fdOutputMode, fd.ModelParams, counter) + if err != nil { + if wErr := artifacts.WriteFeatureDetection(featureDetectionArtifact{ + Phase: "feature-detection", + Status: "failed", + Repo: repoName, + Provider: fd.Provider, + Model: fd.ModelCfg.Name, + Error: err.Error(), + Fallback: "all_sections", + }); wErr != nil { + return fmt.Errorf("writing feature-detection artifact: %w", wErr) + } + slog.Warn("feature detection failed, using all sections", "error", err) + } else { + if wErr := artifacts.WriteFeatureDetection(featureDetectionArtifact{ + Phase: "feature-detection", + Status: "completed", + Repo: repoName, + Provider: fd.Provider, + Model: fd.ModelCfg.Name, + DetectedFeatures: detectedFeatures, + TokenCorrection: fdCorrection, + }); wErr != nil { + return fmt.Errorf("writing 
feature-detection artifact: %w", wErr) + } + slog.Info("feature detection complete", "features", detectedFeatures) + } + // Calibration is only valid when feature detection hit the same model + // as chunk analysis — a different model means a different tokenizer. + if fd.ModelCfg.Name == modelCfg.Name { + tokenCorrection = fdCorrection + } + } else if cfg.SkipFeatureDetection { + if wErr := artifacts.WriteFeatureDetection(featureDetectionArtifact{ + Phase: "feature-detection", + Status: "skipped", + Repo: repoName, + Reason: "--skip-feature-detection", + }); wErr != nil { + return fmt.Errorf("writing feature-detection artifact: %w", wErr) + } + slog.Info("feature detection skipped (--skip-feature-detection)") + } else { + if wErr := artifacts.WriteFeatureDetection(featureDetectionArtifact{ + Phase: "feature-detection", + Status: "skipped", + Repo: repoName, + Reason: "repo fits in single chunk", + }); wErr != nil { + return fmt.Errorf("writing feature-detection artifact: %w", wErr) + } + slog.Info("feature detection skipped (repo fits in single chunk)") + } + + // Release SourceFile slice — FileMap retains the content strings via + // shared Go string backing bytes; the slice of structs is no longer needed. + filtered = nil + + // Measure actual prompt token overhead with the resolved features. + // If features are nil (same as worst-case), reuse the already-computed value + // to avoid a redundant BPE encoding pass. 
+ promptOverhead := worstCaseOverhead + if detectedFeatures != nil { + measureMsgs, mErr := promptLoader.AssembleMessages(llm.PromptParams{ + RepoName: repoName, + XML: "", + Schema: string(*schema), + ChunkTotal: 1, + CustomRequirements: cfg.CustomRequirements, + EnabledFeatures: detectedFeatures, + SupplementaryContext: analysisCtx.Rendered, + }) + if mErr != nil { + return fmt.Errorf("measuring prompt overhead: %w", mErr) + } + promptOverhead = 0 + for _, msg := range measureMsgs { + promptOverhead += counter.Count(msg.Content) + } + toolOverhead := 0 + if outputMode == llm.OutputModeToolUse && schema != nil { + toolOverhead = counter.Count(string(*schema)) + } + promptOverhead += toolOverhead + } + + // For multi-chunk scenarios, the prompt includes a manifest of all other + // file paths (injected by AssembleMessages when ChunkTotal > 1). Account + // for this in the overhead so the chunk budget isn't over-allocated. + // Cap the manifest to 10% of the context limit so it doesn't crowd out + // the actual file content in large repos. + manifestBudget := modelCfg.ContextLimit / 10 + if totalTokens > worstCaseBudget { + allPaths := make([]string, 0, len(flatResult.FileMap)) + for p := range flatResult.FileMap { + allPaths = append(allPaths, p) + } + manifestTokens := counter.Count(strings.Join(allPaths, "\n")) + if manifestTokens > manifestBudget { + manifestTokens = manifestBudget + } + promptOverhead += manifestTokens + } + + slog.Info("prompt overhead measured", + "total_overhead", promptOverhead, + ) + + // Chunk if needed. + // Reserve budget for output tokens, measured prompt overhead, and a safety + // margin for tokenizer variance. The anthropic-tokenizer-go library uses + // Claude 2-era BPE vocabulary which undercounts ~17% vs Claude 4's actual + // tokenizer. Also accounts for API-side overhead (tool definitions, message + // framing, internal prompt formatting). 
+ effectiveLimit := int(float64(modelCfg.ContextLimit) * (1 - tokenizerSafetyMargin)) + chunkBudget := effectiveLimit - promptOverhead - outputReserve + // When supplementary context is configured, a too-small chunk budget + // means the user has starved the scan target. Fail loudly rather than + // producing tiny chunks with poor coverage. + const minChunkBudget = 5000 + if analysisCtx.Tokens > 0 && chunkBudget < minChunkBudget { + return fmt.Errorf("supplementary context (%d tokens) leaves only %d tokens for source code; reduce --context-budget-pct or enable compress:true on large sources", + analysisCtx.Tokens, chunkBudget) + } + if chunkBudget <= 0 { + chunkBudget = modelCfg.ContextLimit / 2 + } + + // Apply the measured correction from feature detection. Only shrink — + // if the heuristic overcounted (correction < 1), we're already safe. + // Clamp to 2× to guard against noisy measurements where framing overhead + // dominates a small feature-detection payload. + if tokenCorrection > 1.0 { + clamp := tokenCorrection + if clamp > 2.0 { + clamp = 2.0 + } + adjusted := int(float64(chunkBudget) / clamp) + slog.Info("shrinking chunk budget to match measured tokenizer density", + "before", chunkBudget, + "after", adjusted, + "correction_factor", fmt.Sprintf("%.3f", tokenCorrection), + ) + chunkBudget = adjusted + } + + chunker := chunk.NewChunker(counter, slog.Default()) + chunkOpts := &chunk.ChunkOptions{ + ImportGraph: importGraph, + ExportSummaries: exportSummaries, + } + chunks, err := chunker.Chunk(flatResult, chunkBudget, chunkOpts) + if err != nil { + return fmt.Errorf("chunking: %w", err) + } + + slog.Info("chunking complete", + "chunks", len(chunks), + "chunk_budget", chunkBudget, + "effective_limit", effectiveLimit, + ) + + // Release the full XML string — the chunker has either returned it + // verbatim (single-chunk) or built per-file XML from FileMap + // (multi-chunk). Holding it further wastes ~2x source-size bytes. 
+ flatResult.XML = "" + + // --- Stage 5.75: Analyze chunks in parallel --- + maxConcurrency := cfg.Concurrency + if maxConcurrency <= 0 { + maxConcurrency = 3 + } + + type chunkResult struct { + index int + doc sarif.SARIFDocument + usage llm.TokenUsage + cost float64 + err error + } + + results := make([]chunkResult, len(chunks)) + var wg sync.WaitGroup + sem := make(chan struct{}, maxConcurrency) + + for i, c := range chunks { + wg.Add(1) + go func(idx int, ch chunk.Chunk) { + defer wg.Done() + sem <- struct{}{} // acquire + defer func() { <-sem }() // release + + doc, usage, cost, aErr := analyzeChunk( + cmd.Context(), ch, repoName, client, endpoint, + modelCfg, promptLoader, schema, outputMode, + cfg.CustomRequirements, analysisCtx.Rendered, detectedFeatures, flatResult.FileMap, analysis.ModelParams, + ) + results[idx] = chunkResult{index: idx, doc: doc, usage: usage, cost: cost, err: aErr} + }(i, c) + } + wg.Wait() + + // Collect results in order. + var sarifDocs []sarif.SARIFDocument + var totalUsage llm.TokenUsage + var totalCost float64 + var overflowPaths []string + for i, r := range results { + if errors.Is(r.err, llm.ErrContextLengthExceeded) { + overflowPaths = append(overflowPaths, chunks[i].Paths...) + continue + } + sarifDocs = append(sarifDocs, r.doc) + totalUsage.PromptTokens += r.usage.PromptTokens + totalUsage.CompletionTokens += r.usage.CompletionTokens + totalCost += r.cost + } + + // Recovery pass: any chunk that hit the server-side context limit gets + // re-chunked at 60% budget and retried once. Overflow 400s fail fast + // (no generation), so the wasted cost is just one cheap round-trip. + // One round is enough when calibration is working — if 60% still + // overflows, the heuristic is off by >40% and we'd rather fail loudly. 
+ if len(overflowPaths) > 0 { + retryBudget := chunkBudget * 6 / 10 + slog.Warn("re-chunking files from overflowed chunks at reduced budget", + "files", len(overflowPaths), + "original_budget", chunkBudget, + "retry_budget", retryBudget, + ) + + retryMap := make(ingest.FileMap, len(overflowPaths)) + for _, p := range overflowPaths { + if content, ok := flatResult.FileMap[p]; ok { + retryMap[p] = content + } + } + retryChunks, rErr := chunker.Chunk(ingest.FlattenResult{FileMap: retryMap}, retryBudget, chunkOpts) + if rErr != nil { + slog.Error("overflow recovery re-chunking failed; files skipped", "error", rErr, "files", len(overflowPaths)) + } else { + slog.Info("overflow recovery pass starting", "chunks", len(retryChunks)) + for _, rc := range retryChunks { + doc, usage, cost, aErr := analyzeChunk( + cmd.Context(), rc, repoName, client, endpoint, + modelCfg, promptLoader, schema, outputMode, + cfg.CustomRequirements, analysisCtx.Rendered, detectedFeatures, flatResult.FileMap, analysis.ModelParams, + ) + if errors.Is(aErr, llm.ErrContextLengthExceeded) { + slog.Error("overflow recovery chunk still exceeds context; files skipped", + "files", len(rc.Paths), + "estimated_tokens", rc.Tokens, + "hint", "lower --context-limit or raise the tokenizer safety margin", + ) + continue + } + sarifDocs = append(sarifDocs, doc) + totalUsage.PromptTokens += usage.PromptTokens + totalUsage.CompletionTokens += usage.CompletionTokens + totalCost += cost + } + } + } + + // --- Stage 6: Merge --- + merged := sarif.Merge(sarifDocs) + + // --- Stage 6.5: Post-process (dedup + deprioritize non-source) --- + merged = sarif.PostProcess(merged) + + slog.Info("initial analysis complete", + "total_findings", len(merged.Runs[0].Results), + "total_rules", len(merged.Runs[0].Tool.Driver.Rules), + "prompt_tokens", totalUsage.PromptTokens, + "completion_tokens", totalUsage.CompletionTokens, + "total_cost", fmt.Sprintf("$%.4f", totalCost), + ) + if err := artifacts.WriteSARIF("analysis", merged); 
err != nil { + return fmt.Errorf("writing analysis artifact: %w", err) + } + + // --- Stage 6.75: Audit phase (CWE-specific scrutiny) --- + if !cfg.SkipAudit && len(merged.Runs[0].Results) > 0 { + audit := &cfg.Phases.Audit + auditClient, auditEndpoint, auditErr := buildPhaseClient(*audit, cfg) + if auditErr != nil { + slog.Warn("failed to build audit client; falling back to analysis client", + "error", auditErr, "provider", audit.Provider) + auditClient, auditEndpoint = client, endpoint + audit = analysis + } else if audit.ModelCfg.Name != modelCfg.Name || audit.Provider != analysis.Provider { + slog.Info("audit phase uses separate configuration", + "provider", audit.Provider, "model", audit.ModelCfg.Name) + } + auditOutputMode := llm.OutputModeForModel(audit.ModelCfg.Name) + + auditedDoc, auditUsage, auditCost := runAuditPhase( + cmd.Context(), merged, repoName, + auditClient, auditEndpoint, audit.ModelCfg, promptLoader, + auditOutputMode, flatResult.FileMap, cfg.AuditConfidenceThreshold, audit.ModelParams, auditCtx.Rendered, + cfg.AuditBatchSize, !cfg.IncludeTests, counter, + ) + if auditedDoc != nil { + merged = *auditedDoc + totalUsage.PromptTokens += auditUsage.PromptTokens + totalUsage.CompletionTokens += auditUsage.CompletionTokens + totalCost += auditCost + if err := artifacts.WriteSARIF("audit", merged); err != nil { + return fmt.Errorf("writing audit artifact: %w", err) + } + } + } else if cfg.SkipAudit { + slog.Info("audit phase skipped (--skip-audit)") + } else { + slog.Info("audit phase skipped (no findings to audit)") + } + + slog.Info("analysis complete", + "total_findings", len(merged.Runs[0].Results), + "total_rules", len(merged.Runs[0].Tool.Driver.Rules), + "prompt_tokens", totalUsage.PromptTokens, + "completion_tokens", totalUsage.CompletionTokens, + "total_cost", fmt.Sprintf("$%.4f", totalCost), + ) + + // --- Stage 7: Output --- + output, err := json.MarshalIndent(merged, "", " ") + if err != nil { + return fmt.Errorf("marshaling SARIF: %w", 
err) + } + + if cfg.Output != "" { + if err := os.WriteFile(cfg.Output, output, 0600); err != nil { + return fmt.Errorf("writing output file: %w", err) + } + slog.Info("SARIF written", "path", cfg.Output) + } else { + fmt.Println(string(output)) + } + + // Check --fail-on-severity threshold. + if cfg.FailOnSeverity > 0 { + ruleSeverity := make(map[string]float64, len(merged.Runs[0].Tool.Driver.Rules)) + for _, rule := range merged.Runs[0].Tool.Driver.Rules { + if sevStr, ok := rule.Properties["security-severity"].(string); ok { + var sev float64 + if _, err := fmt.Sscanf(sevStr, "%f", &sev); err == nil { + ruleSeverity[rule.ID] = sev + } + } + } + var maxSev float64 + for _, result := range merged.Runs[0].Results { + if sev, ok := ruleSeverity[result.RuleID]; ok && sev > maxSev { + maxSev = sev + } + } + if maxSev >= cfg.FailOnSeverity { + slog.Info("findings exceed severity threshold", + "threshold", cfg.FailOnSeverity, + "max_severity", maxSev, + ) + exitFunc(exitCodeFindings) + return nil + } + } + + return nil +} diff --git a/internal/cli/scan_analyze.go b/internal/cli/scan_analyze.go new file mode 100644 index 0000000..8dac2c7 --- /dev/null +++ b/internal/cli/scan_analyze.go @@ -0,0 +1,375 @@ +package cli + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/block/codecrucible/internal/chunk" + "github.com/block/codecrucible/internal/config" + "github.com/block/codecrucible/internal/ingest" + "github.com/block/codecrucible/internal/llm" + "github.com/block/codecrucible/internal/sarif" +) + +// analyzeChunk processes a single chunk: assembles the prompt, calls the LLM, parses +// the response, and builds a per-chunk SARIF document. Safe for concurrent use. 
+func analyzeChunk( + ctx context.Context, + c chunk.Chunk, + repoName string, + client llm.Client, + endpoint string, + modelCfg config.ModelConfig, + promptLoader *llm.PromptLoader, + schema *json.RawMessage, + outputMode llm.OutputMode, + customRequirements string, + supContext string, + enabledFeatures []string, + fileMap ingest.FileMap, + modelParams map[string]any, +) (sarif.SARIFDocument, llm.TokenUsage, float64, error) { + start := time.Now() + slog.Info("analyzing chunk", + "chunk", fmt.Sprintf("%d/%d", c.Index+1, c.Total), + "files", len(c.Paths), + "tokens", c.Tokens, + ) + + // Build manifest of files not in this chunk. + var otherPaths []string + if c.Total > 1 { + pathSet := make(map[string]bool, len(c.Paths)) + for _, p := range c.Paths { + pathSet[p] = true + } + for _, p := range c.Manifest { + if !pathSet[p] { + otherPaths = append(otherPaths, p) + } + } + } + + // Build manifest with summaries for cross-chunk context. + // Cap manifest size to ~10% of context limit (in characters, ~4 chars/token) + // to avoid crowding out file content in large repos. + manifestCharBudget := modelCfg.ContextLimit / 10 * 4 + manifest := capManifest(otherPaths, manifestCharBudget) + if len(c.RelatedSummaries) > 0 { + // Prepend related summaries before the raw path list. + var enriched []string + enriched = append(enriched, "Related files (summaries):") + enriched = append(enriched, c.RelatedSummaries...) + enriched = append(enriched, "") + enriched = append(enriched, "Other files:") + enriched = append(enriched, manifest...) + manifest = enriched + } + + // Expand per-chunk placeholders in custom requirements so that prompt + // templates can reference the files in this specific chunk, the chunk + // index/total, and the detected feature set. 
+ expandedRequirements := customRequirements + if strings.Contains(expandedRequirements, "{") { + expandedRequirements = strings.ReplaceAll(expandedRequirements, "{chunk_files}", strings.Join(c.Paths, "\n")) + expandedRequirements = strings.ReplaceAll(expandedRequirements, "{chunk_index}", fmt.Sprintf("%d", c.Index+1)) + expandedRequirements = strings.ReplaceAll(expandedRequirements, "{chunk_total}", fmt.Sprintf("%d", c.Total)) + expandedRequirements = strings.ReplaceAll(expandedRequirements, "{detected_features}", strings.Join(enabledFeatures, ", ")) + } + + messages, err := promptLoader.AssembleMessages(llm.PromptParams{ + RepoName: repoName, + XML: c.XML, + Schema: string(*schema), + Manifest: manifest, + ChunkIndex: c.Index, + ChunkTotal: c.Total, + CustomRequirements: expandedRequirements, + EnabledFeatures: enabledFeatures, + SupplementaryContext: supContext, + }) + if err != nil { + slog.Error("failed to assemble prompt", "chunk", c.Index, "error", err) + doc := sarif.Build(sarif.AnalysisResult{RepoName: repoName}, nil, sarif.BuilderConfig{}) + doc.Runs[0].Invocations = []sarif.SARIFInvocation{{ + ExecutionSuccessful: false, + ToolExecutionNotifications: []sarif.SARIFNotification{{ + Level: "error", + Message: sarif.SARIFMessage{Text: fmt.Sprintf("chunk %d/%d: failed to assemble prompt: %v", c.Index+1, c.Total, err)}, + }}, + }} + return doc, llm.TokenUsage{}, 0, nil + } + + chunkLabel := fmt.Sprintf("analysis chunk %d/%d", c.Index+1, c.Total) + resp, err := client.ChatCompletion(ctx, llm.ChatRequest{ + Label: chunkLabel, + Endpoint: endpoint, + Model: modelCfg.Name, + Messages: messages, + Temperature: modelCfg.Temperature, + MaxTokens: modelCfg.MaxOutputTokens, + ResponseSchema: schema, + OutputMode: outputMode, + ModelParams: modelParams, + }) + if err != nil { + // Context overflow is recoverable by the caller (split and retry). + // Signal it distinctly instead of burying it in a SARIF notification. 
+ if errors.Is(err, llm.ErrContextLengthExceeded) { + slog.Warn("chunk exceeded context window; caller will split and retry", + "chunk", c.Index, + "estimated_tokens", c.Tokens, + ) + return sarif.SARIFDocument{}, llm.TokenUsage{}, 0, err + } + slog.Error("chunk analysis failed", "chunk", c.Index, "error", err) + doc := sarif.Build(sarif.AnalysisResult{RepoName: repoName}, nil, sarif.BuilderConfig{}) + doc.Runs[0].Invocations = []sarif.SARIFInvocation{{ + ExecutionSuccessful: false, + ToolExecutionNotifications: []sarif.SARIFNotification{{ + Level: "error", + Message: sarif.SARIFMessage{Text: fmt.Sprintf("chunk %d/%d failed: %v", c.Index+1, c.Total, err)}, + }}, + }} + return doc, llm.TokenUsage{}, 0, nil + } + + usage := resp.Usage + chunkCost := float64(usage.PromptTokens)*modelCfg.InputPricePerM/1_000_000 + + float64(usage.CompletionTokens)*modelCfg.OutputPricePerM/1_000_000 + + elapsed := time.Since(start) + attrs := []any{ + "chunk", fmt.Sprintf("%d/%d", c.Index+1, c.Total), + "elapsed", elapsed.Round(time.Millisecond), + "prompt_tokens", usage.PromptTokens, + "completion_tokens", usage.CompletionTokens, + "max_output_tokens", modelCfg.MaxOutputTokens, + "finish_reason", resp.FinishReason, + "cost", fmt.Sprintf("$%.4f", chunkCost), + } + if secs := elapsed.Seconds(); secs > 0 { + attrs = append(attrs, "tokens_per_sec", fmt.Sprintf("%.1f", float64(usage.CompletionTokens)/secs)) + } + // Streaming-only: ttft is prompt-processing + thinking (everything before + // the first visible byte), gen_time is the pure output phase. 
+ if resp.TimeToFirstToken > 0 { + attrs = append(attrs, + "ttft", resp.TimeToFirstToken.Round(time.Millisecond), + "gen_time", resp.GenerationTime.Round(time.Millisecond), + ) + } + if usage.ThinkingChars > 0 { + attrs = append(attrs, "thinking_chars", usage.ThinkingChars) + } + if usage.CacheReadTokens > 0 || usage.CacheCreationTokens > 0 { + attrs = append(attrs, + "cache_read_tokens", usage.CacheReadTokens, + "cache_creation_tokens", usage.CacheCreationTokens, + ) + } + if resp.Model != "" && resp.Model != modelCfg.Name { + // API sometimes resolves aliases to dated snapshot names. + attrs = append(attrs, "resolved_model", resp.Model) + } + slog.Info("chunk analysis complete", attrs...) + + // Warn if output was truncated — findings may be incomplete. + if resp.FinishReason == "length" { + slog.Warn("LLM output was truncated (finish_reason=length), findings may be incomplete", + "chunk", fmt.Sprintf("%d/%d", c.Index+1, c.Total), + "completion_tokens", usage.CompletionTokens, + "max_output_tokens", modelCfg.MaxOutputTokens, + ) + } + + // Parse LLM response. Three escalating attempts: + // 1. Direct unmarshal. + // 2. Local repair (strip fences, extract JSON, coerce string→[]). + // 3. Ask the model to reformat its own output against the schema. + // Only the third costs money, and only runs when 1+2 both fail. + var result sarif.AnalysisResult + parseErr := json.Unmarshal([]byte(resp.Content), &result) + if parseErr != nil { + if repaired, changed := llm.RepairJSON(resp.Content); changed { + if err := json.Unmarshal([]byte(repaired), &result); err == nil { + slog.Info("recovered malformed LLM response via local repair", "chunk", c.Index) + parseErr = nil + } + } + } + if parseErr != nil && resp.FinishReason == "length" { + // Truncation, not drift. The JSON was cut mid-object when it hit + // max_tokens. Model-repair gets the same cap — it would re-truncate + // or emit {"security_issues":[]} to fit, which is worse than honest + // failure. 
Point at the flag that actually fixes this. + slog.Error("output truncated at max_tokens; repair would hit the same cap", + "chunk", c.Index, + "completion_tokens", usage.CompletionTokens, + "max_output_tokens", modelCfg.MaxOutputTokens, + "fix", "increase --max-output-tokens (thinking tokens count against this limit)", + ) + doc := sarif.Build(sarif.AnalysisResult{RepoName: repoName}, nil, sarif.BuilderConfig{}) + doc.Runs[0].Invocations = []sarif.SARIFInvocation{{ + ExecutionSuccessful: false, + ToolExecutionNotifications: []sarif.SARIFNotification{{ + Level: "error", + Message: sarif.SARIFMessage{Text: fmt.Sprintf( + "chunk %d/%d: output truncated at max_tokens=%d (finish_reason=length). "+ + "Increase --max-output-tokens; thinking tokens count against this limit.", + c.Index+1, c.Total, modelCfg.MaxOutputTokens)}, + }}, + }} + return doc, usage, chunkCost, nil + } + if parseErr != nil { + slog.Warn("local JSON repair failed; asking model to reformat", + "chunk", c.Index, + "parse_error", parseErr, + ) + repairResp, repairErr := client.ChatCompletion(ctx, llm.ChatRequest{ + Label: chunkLabel + " repair", + Endpoint: endpoint, + Model: modelCfg.Name, + Messages: []llm.Message{{ + Role: "user", + Content: "The following output failed to parse against the required schema " + + "(error: " + parseErr.Error() + "). 
" + + "Return ONLY the corrected JSON object, nothing else:\n\n" + resp.Content, + }}, + Temperature: modelCfg.Temperature, + MaxTokens: modelCfg.MaxOutputTokens, + ResponseSchema: schema, + OutputMode: outputMode, + ModelParams: modelParams, + }) + if repairErr == nil { + usage.PromptTokens += repairResp.Usage.PromptTokens + usage.CompletionTokens += repairResp.Usage.CompletionTokens + chunkCost += float64(repairResp.Usage.PromptTokens)*modelCfg.InputPricePerM/1_000_000 + + float64(repairResp.Usage.CompletionTokens)*modelCfg.OutputPricePerM/1_000_000 + repaired, _ := llm.RepairJSON(repairResp.Content) + if err := json.Unmarshal([]byte(repaired), &result); err == nil { + slog.Info("recovered malformed LLM response via model reformat", "chunk", c.Index) + parseErr = nil + } + } + } + if parseErr != nil { + slog.Error("failed to parse LLM response after repair attempts", "chunk", c.Index, "error", parseErr) + doc := sarif.Build(sarif.AnalysisResult{RepoName: repoName}, nil, sarif.BuilderConfig{}) + doc.Runs[0].Invocations = []sarif.SARIFInvocation{{ + ExecutionSuccessful: false, + ToolExecutionNotifications: []sarif.SARIFNotification{{ + Level: "error", + Message: sarif.SARIFMessage{Text: fmt.Sprintf("chunk %d/%d: failed to parse LLM response: %v", c.Index+1, c.Total, parseErr)}, + }}, + }} + return doc, usage, chunkCost, nil + } + + doc := sarif.Build(result, sarif.FileMap(fileMap), sarif.BuilderConfig{ + ToolVersion: version, + }) + return doc, usage, chunkCost, nil +} + +// runFeatureDetection performs a lightweight LLM call to detect which security-relevant +// features the codebase uses. Returns nil (not an error) if detection fails, causing +// the analysis to fall back to including all sections. 
+func runFeatureDetection( + ctx context.Context, + files []ingest.SourceFile, + repoName string, + client llm.Client, + endpoint string, + modelCfg config.ModelConfig, + promptLoader *llm.PromptLoader, + outputMode llm.OutputMode, + modelParams map[string]any, + counter *chunk.TokenCounter, +) ([]string, float64, error) { + // Build file manifest, capped to fit within the model's context. + // Reserve ~50% of context for the manifest, rest for samples + prompt overhead. + manifestCharBudget := modelCfg.ContextLimit / 2 * 4 // tokens → chars + allPaths := make([]string, len(files)) + for i, f := range files { + allPaths[i] = f.Path + } + manifest := capManifest(allPaths, manifestCharBudget) + + // Build representative code samples (capped at ~2000 tokens). + fileEntries := make([]llm.FileEntry, len(files)) + for i, f := range files { + fileEntries[i] = llm.FileEntry{Path: f.Path, Content: f.Content} + } + samples := llm.BuildFeatureSamples(fileEntries, 2000) + + messages, err := promptLoader.AssembleFeatureDetectionMessages(llm.FeaturePromptParams{ + RepoName: repoName, + Manifest: manifest, + Samples: samples, + }) + if err != nil { + return nil, 0, fmt.Errorf("assembling feature detection prompt: %w", err) + } + + featureSchema := llm.FeatureDetectionSchema() + + // Local estimate of everything we're about to send: messages + tool schema. + // The API's PromptTokens in the response is ground truth for the same + // payload — the ratio between them is a measured correction factor for + // this model and this repo's content, which we can apply to chunk sizing. 
+ localEstimate := 0 + for _, m := range messages { + localEstimate += counter.Count(m.Content) + } + if outputMode == llm.OutputModeToolUse && featureSchema != nil { + localEstimate += counter.Count(string(*featureSchema)) + } + + slog.Info("running feature detection pre-pass") + + resp, err := client.ChatCompletion(ctx, llm.ChatRequest{ + Label: "feature-detection", + Endpoint: endpoint, + Model: modelCfg.Name, + Messages: messages, + Temperature: modelCfg.Temperature, + MaxTokens: modelCfg.MaxOutputTokens, + ResponseSchema: featureSchema, + OutputMode: outputMode, + ModelParams: modelParams, + }) + if err != nil { + return nil, 0, fmt.Errorf("feature detection LLM call: %w", err) + } + + // Correction factor: actual/estimated. 1.0 means the heuristic is exact; + // >1.0 means we undercount and need to shrink chunks; <1.0 means we're + // conservative already. 0 signals "no calibration available". + var correction float64 + if localEstimate > 0 && resp.Usage.PromptTokens > 0 { + correction = float64(resp.Usage.PromptTokens) / float64(localEstimate) + slog.Info("tokenizer calibration measured", + "local_estimate", localEstimate, + "api_actual", resp.Usage.PromptTokens, + "correction_factor", fmt.Sprintf("%.3f", correction), + ) + } + + var result struct { + DetectedFeatures []string `json:"detected_features"` + } + if err := json.Unmarshal([]byte(resp.Content), &result); err != nil { + return nil, correction, fmt.Errorf("parsing feature detection response: %w", err) + } + + return result.DetectedFeatures, correction, nil +} diff --git a/internal/cli/scan_audit.go b/internal/cli/scan_audit.go new file mode 100644 index 0000000..db779c3 --- /dev/null +++ b/internal/cli/scan_audit.go @@ -0,0 +1,460 @@ +package cli + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + + "github.com/block/codecrucible/internal/chunk" + "github.com/block/codecrucible/internal/config" + "github.com/block/codecrucible/internal/ingest" + 
"github.com/block/codecrucible/internal/llm" + "github.com/block/codecrucible/internal/sarif" +) + +// AuditResult represents the structured output from the audit phase LLM call. +type AuditResult struct { + AuditedFindings []AuditedFinding `json:"audited_findings"` + NewFindings []NewFinding `json:"new_findings"` + AuditSummary string `json:"audit_summary"` +} + +// AuditedFinding is the audit verdict for a single initial finding. +type AuditedFinding struct { + OriginalIssue string `json:"original_issue"` + FilePath string `json:"file_path"` + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` + Verdict string `json:"verdict"` + Confidence float64 `json:"confidence"` + RefinedSeverity float64 `json:"refined_severity"` + RefinedTechnicalDetails string `json:"refined_technical_details"` + RefinedCWEID string `json:"refined_cwe_id"` + Justification string `json:"justification"` +} + +// NewFinding is an additional finding discovered during the audit phase. +type NewFinding struct { + Issue string `json:"issue"` + FilePath string `json:"file_path"` + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` + TechnicalDetails string `json:"technical_details"` + Severity float64 `json:"severity"` + CWEID string `json:"cwe_id"` + Confidence float64 `json:"confidence"` +} + +// runAuditPhase performs a CWE-specific scrutiny pass on the initial findings. +// It sends the findings + relevant code + CWE prompts to the LLM for validation. +// Returns the audited SARIF document, token usage, and cost. +// Returns nil doc (not error) if the audit fails, allowing fallback to unaudited results. 
+func runAuditPhase( + ctx context.Context, + doc sarif.SARIFDocument, + repoName string, + client llm.Client, + endpoint string, + modelCfg config.ModelConfig, + promptLoader *llm.PromptLoader, + outputMode llm.OutputMode, + fileMap ingest.FileMap, + confidenceThreshold float64, + modelParams map[string]any, + supContext string, + batchSize int, + productionOnly bool, + counter *chunk.TokenCounter, +) (*sarif.SARIFDocument, llm.TokenUsage, float64) { + slog.Info("starting audit phase", + "findings_to_audit", len(doc.Runs[0].Results), + "batch_size", batchSize, + ) + + // Extract initial findings as AnalysisResult for JSON serialization. + run := doc.Runs[0] + ruleByID := make(map[string]sarif.SARIFRule, len(run.Tool.Driver.Rules)) + for _, rule := range run.Tool.Driver.Rules { + ruleByID[rule.ID] = rule + } + + // claimToVerify wraps an analysis-phase finding as an unverified hypothesis + // for the audit phase. The field names are deliberate: the audit prompt + // must treat `unverified_exploit_sketch` as a claim to test against the + // source code, not as a conclusion to accept. `issue` is kept to allow the + // audit schema's `original_issue` output field to round-trip back into the + // (file, line, issue) match key that applyAuditVerdicts uses. 
+ type claimToVerify struct { + Issue string `json:"issue"` + FilePath string `json:"file_path"` + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` + UnverifiedExploitSketch string `json:"unverified_exploit_sketch"` + Severity float64 `json:"severity"` + CWEID string `json:"cwe_id"` + } + + var findings []claimToVerify + for _, result := range run.Results { + rule := ruleByID[result.RuleID] + var filePath string + var startLine, endLine int + if len(result.Locations) > 0 { + filePath = result.Locations[0].PhysicalLocation.ArtifactLocation.URI + if result.Locations[0].PhysicalLocation.Region != nil { + startLine = result.Locations[0].PhysicalLocation.Region.StartLine + endLine = result.Locations[0].PhysicalLocation.Region.EndLine + } + } + + var severity float64 + if sevStr, ok := rule.Properties["security-severity"].(string); ok { + if _, err := fmt.Sscanf(sevStr, "%f", &severity); err != nil { + slog.Debug("failed to parse security severity", "value", sevStr, "error", err) + } + } + + findings = append(findings, claimToVerify{ + Issue: rule.ShortDescription.Text, + FilePath: filePath, + StartLine: startLine, + EndLine: endLine, + UnverifiedExploitSketch: result.Message.Text, + Severity: severity, + CWEID: sarif.CWEForRule(rule), + }) + } + + // Batch boundaries. 0 (or oversized) = one call, same as before. + if batchSize <= 0 || batchSize >= len(findings) { + batchSize = len(findings) + } + numBatches := (len(findings) + batchSize - 1) / batchSize + + auditSchema := llm.AuditSchema() + + // auditBatch runs one self-contained audit call: just this batch's + // findings, just their CWE IDs, just the files they reference. Each call + // is independently valid so one batch failing doesn't poison the rest. + // Files shared across batches get sent more than once — a deliberate + // tradeoff: more tokens, but each call is stateless and can be retried + // without coordination. 
+ auditBatch := func(batch []claimToVerify, label string) (AuditResult, llm.TokenUsage, float64, error) { + var cweIDs []string + filesNeeded := make(map[string]bool) + for _, f := range batch { + if f.CWEID != "" { + cweIDs = append(cweIDs, f.CWEID) + } + if f.FilePath != "" { + filesNeeded[f.FilePath] = true + } + } + + // Wrap the batch as { "claims_to_verify": [...] } so the prompt can + // refer to the input as claims (unverified hypotheses) rather than + // findings (settled facts). The wrapper framing is anti-anchoring: + // it discourages the audit model from accepting conclusory phrases + // in `unverified_exploit_sketch` without checking the source code. + findingsJSON, err := json.Marshal(map[string]any{"claims_to_verify": batch}) + if err != nil { + return AuditResult{}, llm.TokenUsage{}, 0, fmt.Errorf("marshal findings: %w", err) + } + + var codeCtx strings.Builder + for path := range filesNeeded { + if content, ok := fileMap[path]; ok { + fmt.Fprintf(&codeCtx, "\n%s\n\n\n", path, content) + } + } + + messages, err := promptLoader.AssembleAuditMessages(llm.AuditParams{ + RepoName: repoName, + FindingsJSON: string(findingsJSON), + CodeContext: codeCtx.String(), + CWEIDs: cweIDs, + Schema: string(*auditSchema), + ProductionOnly: productionOnly, + SupplementaryContext: supContext, + }) + if err != nil { + return AuditResult{}, llm.TokenUsage{}, 0, fmt.Errorf("assemble audit prompt: %w", err) + } + + var estTokens int + for _, m := range messages { + estTokens += counter.Count(m.Content) + } + slog.Info("running audit batch", + "label", label, + "findings", len(batch), + "cwe_categories", len(cweIDs), + "files_in_context", len(filesNeeded), + "estimated_tokens", estTokens, + "estimated_input_cost", fmt.Sprintf("$%.4f", float64(estTokens)*modelCfg.InputPricePerM/1_000_000), + ) + + resp, err := client.ChatCompletion(ctx, llm.ChatRequest{ + Label: label, + Endpoint: endpoint, + Model: modelCfg.Name, + Messages: messages, + Temperature: modelCfg.Temperature, 
+ MaxTokens: modelCfg.MaxOutputTokens, + ResponseSchema: auditSchema, + OutputMode: outputMode, + ModelParams: modelParams, + }) + if err != nil { + return AuditResult{}, llm.TokenUsage{}, 0, fmt.Errorf("LLM call: %w", err) + } + + u := resp.Usage + c := float64(u.PromptTokens)*modelCfg.InputPricePerM/1_000_000 + + float64(u.CompletionTokens)*modelCfg.OutputPricePerM/1_000_000 + slog.Info("audit batch complete", + "label", label, + "prompt_tokens", u.PromptTokens, + "completion_tokens", u.CompletionTokens, + "cost", fmt.Sprintf("$%.4f", c), + ) + + var r AuditResult + if err := json.Unmarshal([]byte(resp.Content), &r); err != nil { + // Try local repair (strip markdown fences, extract JSON object) + // before giving up — same pattern as the analysis phase. + if repaired, changed := llm.RepairJSON(resp.Content); changed { + if err2 := json.Unmarshal([]byte(repaired), &r); err2 == nil { + slog.Info("recovered malformed audit response via local repair", "label", label) + return r, u, c, nil + } + } + return AuditResult{}, u, c, fmt.Errorf("parse audit response: %w", err) + } + return r, u, c, nil + } + + // Sequential — the point is to keep each request under the server's + // connection-age limit, not to go faster. + var auditResult AuditResult + var usage llm.TokenUsage + var cost float64 + for i := 0; i < len(findings); i += batchSize { + end := i + batchSize + if end > len(findings) { + end = len(findings) + } + label := "audit" + if numBatches > 1 { + label = fmt.Sprintf("audit %d/%d", i/batchSize+1, numBatches) + } + r, u, c, err := auditBatch(findings[i:end], label) + usage.PromptTokens += u.PromptTokens + usage.CompletionTokens += u.CompletionTokens + cost += c + if err != nil { + slog.Error("audit batch failed; findings in this batch will not be audited", + "label", label, "error", err, "findings", end-i) + continue + } + auditResult.AuditedFindings = append(auditResult.AuditedFindings, r.AuditedFindings...) 
+ auditResult.NewFindings = append(auditResult.NewFindings, r.NewFindings...) + if r.AuditSummary != "" { + if auditResult.AuditSummary != "" { + auditResult.AuditSummary += "\n\n" + } + auditResult.AuditSummary += r.AuditSummary + } + } + + if len(auditResult.AuditedFindings) == 0 && len(auditResult.NewFindings) == 0 { + slog.Error("all audit batches failed") + return nil, usage, cost + } + + // Apply audit verdicts to produce the final SARIF document. + auditedDoc := applyAuditVerdicts(doc, auditResult, fileMap, confidenceThreshold) + + slog.Info("audit phase complete", + "audited", len(auditResult.AuditedFindings), + "new_findings", len(auditResult.NewFindings), + "summary", auditResult.AuditSummary, + ) + + return &auditedDoc, usage, cost +} + +// applyAuditVerdicts takes the original SARIF document and the audit results, +// and produces a new SARIF document with findings filtered, refined, and enriched. +func applyAuditVerdicts( + doc sarif.SARIFDocument, + audit AuditResult, + fileMap ingest.FileMap, + confidenceThreshold float64, +) sarif.SARIFDocument { + run := doc.Runs[0] + + // Build audit lookup by (file_path, start_line, original_issue). + // Using original_issue prevents collisions when multiple findings + // share the same location (e.g. SQL injection + log forging on one line). + type auditKey struct { + filePath string + startLine int + originalIssue string + } + auditByKey := make(map[auditKey]AuditedFinding) + for _, af := range audit.AuditedFindings { + key := auditKey{filePath: af.FilePath, startLine: af.StartLine, originalIssue: af.OriginalIssue} + auditByKey[key] = af + } + + // Process existing results: apply verdicts. 
+ var keptResults []sarif.SARIFResult + ruleByID := make(map[string]sarif.SARIFRule, len(run.Tool.Driver.Rules)) + for _, rule := range run.Tool.Driver.Rules { + ruleByID[rule.ID] = rule + } + + rejected := 0 + refined := 0 + escalated := 0 + confirmed := 0 + + for _, result := range run.Results { + var filePath string + var startLine int + if len(result.Locations) > 0 { + filePath = result.Locations[0].PhysicalLocation.ArtifactLocation.URI + if result.Locations[0].PhysicalLocation.Region != nil { + startLine = result.Locations[0].PhysicalLocation.Region.StartLine + } + } + + rule := ruleByID[result.RuleID] + key := auditKey{filePath: filePath, startLine: startLine, originalIssue: rule.ShortDescription.Text} + af, found := auditByKey[key] + + if !found { + // No audit verdict for this finding — keep as-is. + keptResults = append(keptResults, result) + continue + } + + // Reject findings below confidence threshold. + if af.Verdict == "rejected" || af.Confidence < confidenceThreshold { + rejected++ + slog.Debug("audit: rejected finding", + "issue", af.OriginalIssue, + "file", af.FilePath, + "confidence", af.Confidence, + "reason", af.Justification, + ) + continue + } + + // Apply refinements. + switch af.Verdict { + case "refined": + refined++ + case "escalated": + escalated++ + default: + confirmed++ + } + + // Update the result with refined details. + if af.RefinedTechnicalDetails != "" { + result.Message = sarif.SARIFMessage{ + Text: fmt.Sprintf("%s\n\n[Audit confidence: %.0f%%] %s", + af.RefinedTechnicalDetails, af.Confidence*100, af.Justification), + } + } + + // Update rule severity if refined. + if rule, ok := ruleByID[result.RuleID]; ok && af.RefinedSeverity > 0 { + rule.Properties["security-severity"] = fmt.Sprintf("%.1f", af.RefinedSeverity) + result.Level = severityLevelScan(af.RefinedSeverity) + ruleByID[result.RuleID] = rule + } + + keptResults = append(keptResults, result) + } + + // Add new findings from the audit phase. 
+ newCount := 0 + for _, nf := range audit.NewFindings { + if nf.Confidence < confidenceThreshold { + continue + } + + issue := sarif.SecurityIssue{ + Issue: nf.Issue, + FilePath: nf.FilePath, + StartLine: nf.StartLine, + EndLine: nf.EndLine, + TechnicalDetails: fmt.Sprintf("%s\n\n[Audit confidence: %.0f%%]", nf.TechnicalDetails, nf.Confidence*100), + Severity: nf.Severity, + CWEID: nf.CWEID, + } + + newDoc := sarif.Build(sarif.AnalysisResult{ + SecurityIssues: []sarif.SecurityIssue{issue}, + }, sarif.FileMap(fileMap), sarif.BuilderConfig{ToolVersion: version}) + + if len(newDoc.Runs) > 0 && len(newDoc.Runs[0].Results) > 0 { + keptResults = append(keptResults, newDoc.Runs[0].Results...) + for _, rule := range newDoc.Runs[0].Tool.Driver.Rules { + ruleByID[rule.ID] = rule + } + newCount++ + } + } + + slog.Info("audit verdicts applied", + "confirmed", confirmed, + "refined", refined, + "escalated", escalated, + "rejected", rejected, + "new", newCount, + ) + + // Rebuild rules slice from the map. + var rules []sarif.SARIFRule + usedRuleIDs := make(map[string]bool) + for _, r := range keptResults { + usedRuleIDs[r.RuleID] = true + } + for id, rule := range ruleByID { + if usedRuleIDs[id] { + rules = append(rules, rule) + } + } + if rules == nil { + rules = []sarif.SARIFRule{} + } + if keptResults == nil { + keptResults = []sarif.SARIFResult{} + } + + run.Results = keptResults + run.Tool.Driver.Rules = rules + doc.Runs[0] = run + + return doc +} + +// severityLevelScan maps a numeric severity to a SARIF level (scan package version). 
+func severityLevelScan(sev float64) string { + switch { + case sev <= 0: + return "none" + case sev < 4.0: + return "note" + case sev < 7.0: + return "warning" + default: + return "error" + } +} diff --git a/internal/cli/scan_helpers.go b/internal/cli/scan_helpers.go new file mode 100644 index 0000000..67c98e4 --- /dev/null +++ b/internal/cli/scan_helpers.go @@ -0,0 +1,383 @@ +package cli + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "net/textproto" + "os" + "strings" + "time" + + "github.com/block/codecrucible/internal/chunk" + "github.com/block/codecrucible/internal/config" + "github.com/block/codecrucible/internal/ingest" + "github.com/block/codecrucible/internal/llm" + "github.com/block/codecrucible/internal/sarif" + "github.com/block/codecrucible/internal/supctx" +) + +// providerPreset describes how to build a client for a given provider. +// The preset controls defaults; any value set explicitly on PhaseConfig wins. +type providerPreset struct { + baseURL string // default base URL (empty = must be provided) + keyEnv string // env var name for the error message (empty = auth not required) + authRequired bool // whether an API key is mandatory + // wireProvider is the provider string passed to the LLM HTTP client, + // which controls URL path and request body format. Providers that use + // OpenAI-compatible APIs set this to "openai". + wireProvider string +} + +// providerPresets maps --provider values to their defaults. Databricks is +// handled separately because it uses host+token env vars rather than a +// single API key. 
+var providerPresets = map[string]providerPreset{ + "anthropic": { + baseURL: "https://api.anthropic.com", + keyEnv: "ANTHROPIC_API_KEY", + authRequired: true, + wireProvider: "anthropic", + }, + "openai": { + baseURL: "https://api.openai.com", + keyEnv: "OPENAI_API_KEY", + authRequired: true, + wireProvider: "openai", + }, + "google": { + baseURL: "https://generativelanguage.googleapis.com/v1beta/openai", + keyEnv: "GOOGLE_API_KEY", + authRequired: true, + wireProvider: "google", + }, + "ollama": { + baseURL: "http://localhost:11434", + authRequired: false, + wireProvider: "openai", + }, + "openai-compat": { + authRequired: false, + wireProvider: "openai", + }, +} + +// buildPhaseClient constructs an LLM client from a resolved PhaseConfig. +// Everything that varies per-phase (provider, key, timeout, headers, +// base URL, model params) comes from pc. Only Databricks host/token stay +// on cfg — Databricks is an all-provider proxy so per-phase Databricks +// credentials don't really make sense; if you want a different workspace +// per phase, set pc.BaseURL. +// +// Returns (client, endpoint, err). endpoint is empty for direct providers +// where the model name goes in the request body; for Databricks it is the +// serving-endpoint path segment. +func buildPhaseClient(pc config.PhaseConfig, cfg *config.Config) (llm.Client, string, error) { + headers, err := parseCustomHeaders(pc.Headers) + if err != nil { + return nil, "", err + } + + // 0 → 0s → llm.NewClient falls through to its own default (600s). + timeout := time.Duration(pc.RequestTimeout) * time.Second + + // Anthropic has a no-key fallback: the locally-installed claude CLI + // can proxy requests using the user's desktop session. Only Anthropic + // offers this, so it stays a special case rather than table-driven. 
+ if pc.Provider == "anthropic" && pc.APIKey == "" { + if len(pc.ModelParams) > 0 { + slog.Warn("model params are ignored with Claude CLI auth; set an API key to send model params to the Anthropic API") + } + client, err := llm.NewClaudeCLIClient(llm.ClientConfig{ + Provider: "anthropic", + Headers: headers, + Timeout: timeout, + Logger: slog.Default(), + }) + if err != nil { + return nil, "", fmt.Errorf("no API key and Claude CLI auth unavailable (provider=anthropic): %w", err) + } + slog.Info("using Claude CLI authentication for Anthropic requests") + return client, "", nil + } + + // Known providers — use preset defaults, allow overrides. + if preset, ok := providerPresets[pc.Provider]; ok { + if preset.authRequired && pc.APIKey == "" { + return nil, "", fmt.Errorf("no API key for provider=%s (set %s, or phases..api-key)", pc.Provider, preset.keyEnv) + } + baseURL := pc.BaseURL + if baseURL == "" { + if preset.baseURL == "" { + return nil, "", fmt.Errorf("provider=%s requires --base-url (no default URL)", pc.Provider) + } + baseURL = preset.baseURL + } + client := llm.NewClient(llm.ClientConfig{ + BaseURL: baseURL, + Token: pc.APIKey, // empty string is fine for no-auth providers + Provider: preset.wireProvider, + Headers: headers, + MaxRetries: 3, + Timeout: timeout, + Logger: slog.Default(), + }) + return client, "", nil + } + + // Databricks (and any unrecognised provider string, as before). 
+ if cfg.DatabricksHost == "" { + return nil, "", fmt.Errorf("DATABRICKS_HOST is not set (provider=%s)", pc.Provider) + } + if cfg.DatabricksToken == "" { + return nil, "", fmt.Errorf("DATABRICKS_TOKEN is not set (provider=%s)", pc.Provider) + } + baseURL := pc.BaseURL + if baseURL == "" { + baseURL = cfg.DatabricksHost + "/serving-endpoints" + } + client := llm.NewClient(llm.ClientConfig{ + BaseURL: baseURL, + Token: cfg.DatabricksToken, + Provider: "databricks", + Headers: headers, + MaxRetries: 3, + Timeout: timeout, + Logger: slog.Default(), + }) + endpoint := pc.Endpoint + if endpoint == "" { + endpoint = pc.ModelCfg.Endpoint + } + // The LLM client's buildURL appends "/invocations", so strip it if + // present to avoid double-appending (the registry stores + // "model/invocations"). + return client, strings.TrimSuffix(endpoint, "/invocations"), nil +} + +// parseCustomHeaders parses header entries in "Name: Value" format. +func parseCustomHeaders(raw []string) (http.Header, error) { + headers := make(http.Header) + for _, entry := range raw { + trimmed := strings.TrimSpace(entry) + if trimmed == "" { + continue + } + + parts := strings.SplitN(trimmed, ":", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid custom header %q: expected format 'Name: Value'", entry) + } + + name := textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(parts[0])) + value := strings.TrimSpace(parts[1]) + if name == "" { + return nil, fmt.Errorf("invalid custom header %q: header name is empty", entry) + } + if value == "" { + return nil, fmt.Errorf("invalid custom header %q: header value is empty", entry) + } + + headers.Add(name, value) + } + + return headers, nil +} + +// resolveModel looks up the model in the registry or returns defaults. 
+func resolveModel(name string) config.ModelConfig { + if name == "" { + return config.DefaultModel() + } + if m, ok := config.LookupModel(name); ok { + return m + } + slog.Warn("model not in registry, using defaults", "model", name) + return config.UnknownModelDefaults(name) +} + +// resolvePromptLoader creates a PromptLoader from --prompts-dir or the default location. +func resolvePromptLoader(promptsDir string) (*llm.PromptLoader, error) { + if promptsDir != "" { + return llm.NewPromptLoader(os.DirFS(promptsDir)), nil + } + // Default: look for prompts/default/ directory relative to CWD. + if info, err := os.Stat("prompts/default"); err == nil && info.IsDir() { + return llm.NewPromptLoader(os.DirFS("prompts/default")), nil + } + return nil, fmt.Errorf("prompts directory not found; use --prompts-dir to specify a prompt set (e.g. prompts/default)") +} + +// capManifest truncates a list of file paths to fit within a character budget. +// When truncated, appends a note indicating how many paths were omitted. +func capManifest(paths []string, charBudget int) []string { + if charBudget <= 0 { + return nil + } + total := 0 + for i, p := range paths { + total += len(p) + 1 // +1 for newline + if total > charBudget { + omitted := len(paths) - i + return append(paths[:i], fmt.Sprintf("... and %d more files", omitted)) + } + } + return paths +} + +// outputEmptySARIF produces a valid SARIF document with zero findings. 
+func outputEmptySARIF(cfg *config.Config) error { + doc := sarif.Build(sarif.AnalysisResult{}, nil, sarif.BuilderConfig{ + ToolVersion: version, + }) + doc.Runs[0].Invocations = []sarif.SARIFInvocation{{ + ExecutionSuccessful: true, + ToolExecutionNotifications: []sarif.SARIFNotification{{ + Level: "note", + Message: sarif.SARIFMessage{Text: "no source files found after filtering"}, + }}, + }} + + output, err := json.MarshalIndent(doc, "", " ") + if err != nil { + return fmt.Errorf("marshaling SARIF: %w", err) + } + + if cfg.Output != "" { + return os.WriteFile(cfg.Output, output, 0644) + } + fmt.Println(string(output)) + return nil +} + +// maxContextBudgetPct is the hard ceiling on how much of the context window +// supplementary sources may claim. Above this the actual scan target starves. +const maxContextBudgetPct = 40 + +// loadSupplementaryContext fetches, optionally compresses, and packs the +// configured context sources into per-phase rendered blocks. Returned +// PackResults carry the token count so the caller can fold it into +// promptOverhead — that's what keeps chunk-budget math honest. 
+func loadSupplementaryContext( + ctx context.Context, + cfg *config.Config, + counter *chunk.TokenCounter, + promptLoader *llm.PromptLoader, + contextLimit int, +) (analysisCtx, auditCtx supctx.PackResult, err error) { + if len(cfg.ContextSources) == 0 { + return + } + + pct := cfg.ContextBudgetPct + if pct <= 0 { + pct = 15 + } + if pct > maxContextBudgetPct { + return analysisCtx, auditCtx, + fmt.Errorf("--context-budget-pct %d exceeds maximum %d; supplementary context would starve the scan target", pct, maxContextBudgetPct) + } + budget := contextLimit * pct / 100 + + slog.Info("loading supplementary context", + "sources", len(cfg.ContextSources), "budget_tokens", budget, "budget_pct", pct) + + srcs := make([]supctx.Source, len(cfg.ContextSources)) + for i, cs := range cfg.ContextSources { + srcs[i] = supctx.Source{ + Name: cs.Name, + Type: cs.Type, + Location: cs.Location, + Priority: cs.Priority, + Compress: cs.Compress, + Phases: cs.Phases, + Include: cs.Include, + Exclude: cs.Exclude, + } + } + + loaded := supctx.LoadAll(ctx, srcs, counter) + if len(loaded) == 0 { + slog.Warn("no supplementary context loaded (all sources empty or failed)") + return + } + + // Run the optional compression pre-pass when any source opted in. 
+ if anyCompress(loaded) { + cc := &cfg.Phases.ContextCompress + ccClient, _, ccErr := buildPhaseClient(*cc, cfg) + if ccErr != nil { + slog.Warn("context-compress client build failed; skipping compression", "error", ccErr) + } else { + cp, cpErr := promptLoader.LoadContextCompressPrompt() + if cpErr != nil { + slog.Warn("failed to load context compress prompt; skipping compression", "error", cpErr) + } else { + compressor := supctx.Compressor{ + Client: ccClient, + Prompt: *cp, + Counter: counter, + Model: cc.ModelCfg.Name, + } + loaded = compressor.Compress(ctx, loaded, budget) + } + } + } + + analysisCtx = supctx.Pack(supctx.FilterPhase(loaded, "analysis"), budget, counter) + auditCtx = supctx.Pack(supctx.FilterPhase(loaded, "audit"), budget, counter) + + logPack("analysis", analysisCtx) + logPack("audit", auditCtx) + + return analysisCtx, auditCtx, nil +} + +// streamingTokenCount estimates the total token count of the flattened XML by +// iterating FileMap entries one at a time. Each per-file XML string is built, +// counted, and discarded, so peak memory is max(single file XML) rather than +// sum(all file XML). The result closely matches counter.Count(fullXML) because +// the heuristic token counter is linear and additive. +func streamingTokenCount(fm ingest.FileMap, counter *chunk.TokenCounter, cfg ingest.FlattenConfig) int { + paths := make([]string, 0, len(fm)) + for p := range fm { + paths = append(paths, p) + } + + // Envelope: header + directory structure + wrapper. + envelope := ingest.EnvelopeXML(paths, cfg) + total := counter.Count(envelope) + + // Per-file: build XML one at a time, count, discard. 
+ for _, p := range paths { + fileXML := chunk.BuildFileXML(p, fm[p]) + total += counter.Count(fileXML) + } + + return total +} + +func anyCompress(loaded []supctx.Loaded) bool { + for _, l := range loaded { + if l.Compress { + return true + } + } + return false +} + +func logPack(phase string, r supctx.PackResult) { + if r.Tokens == 0 { + return + } + slog.Info("supplementary context packed", "phase", phase, "tokens", r.Tokens, + "dropped", r.Dropped, "truncated", r.Truncated) + for _, d := range r.Dropped { + slog.Warn("context source dropped (over budget)", "phase", phase, "source", d) + } + if r.Truncated != "" { + slog.Warn("context source truncated", "phase", phase, "source", r.Truncated) + } +} diff --git a/internal/cli/scan_ingest.go b/internal/cli/scan_ingest.go new file mode 100644 index 0000000..94bf8af --- /dev/null +++ b/internal/cli/scan_ingest.go @@ -0,0 +1,227 @@ +package cli + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/block/codecrucible/internal/config" + "github.com/block/codecrucible/internal/ingest" +) + +// ingestFiles walks the repo and optionally restricts to --paths. +func ingestFiles(repoRoot string, cfg *config.Config) ([]ingest.SourceFile, error) { + // Always walk from repo root so .gitignore behavior is consistent whether + // or not --paths is provided. + files, err := ingest.WalkDir(repoRoot) + if err != nil { + return nil, err + } + + if len(cfg.Paths) == 0 { + return files, nil + } + + normalizedPaths, err := normalizeScanPaths(repoRoot, cfg.Paths) + if err != nil { + return nil, err + } + + filtered := make([]ingest.SourceFile, 0, len(files)) + for _, f := range files { + if pathMatchesAnyPrefix(f.Path, normalizedPaths) { + filtered = append(filtered, f) + } + } + + return filtered, nil +} + +// normalizeScanPaths validates and normalizes --paths values to repo-relative, +// slash-separated prefixes. 
func normalizeScanPaths(repoRoot string, paths []string) ([]string, error) {
	var (
		result       = make([]string, 0, len(paths))
		visited      = make(map[string]struct{}, len(paths))
		parentPrefix = ".." + string(filepath.Separator)
	)

	for _, raw := range paths {
		target := filepath.Clean(filepath.Join(repoRoot, raw))

		rel, err := filepath.Rel(repoRoot, target)
		if err != nil {
			return nil, fmt.Errorf("walking path %s: %w", raw, err)
		}
		// A relative form that is ".." or starts with "../" points outside
		// the repository.
		if rel == ".." || strings.HasPrefix(rel, parentPrefix) {
			return nil, fmt.Errorf("walking path %s: path escapes repository root", raw)
		}

		// Every requested path must exist and be a directory.
		fi, err := os.Stat(target)
		switch {
		case err != nil:
			return nil, fmt.Errorf("walking path %s: %w", raw, err)
		case !fi.IsDir():
			return nil, fmt.Errorf("walking path %s: root path is not a directory: %s", raw, target)
		}

		key := filepath.ToSlash(filepath.Clean(rel))
		if key == "." {
			// The repo root subsumes every other prefix, so stop here.
			return []string{"."}, nil
		}

		// Keep the first occurrence of each prefix, preserving input order.
		if _, dup := visited[key]; !dup {
			visited[key] = struct{}{}
			result = append(result, key)
		}
	}

	return result, nil
}

// pathMatchesAnyPrefix reports whether path, once cleaned and converted to
// slash-separated form, equals one of the repo-relative prefixes or lies
// beneath it. The prefix "." matches every path.
func pathMatchesAnyPrefix(path string, prefixes []string) bool {
	candidate := filepath.ToSlash(filepath.Clean(path))

	for _, prefix := range prefixes {
		switch {
		case prefix == ".",
			candidate == prefix,
			// The "/" suffix keeps "firmware" from matching "firmware2/...".
			strings.HasPrefix(candidate, prefix+"/"):
			return true
		}
	}

	return false
}

// buildExportSummaries generates one-line summaries of each file's exports/purpose
// using heuristic parsing. Used for cross-chunk context.
+func buildExportSummaries(files []ingest.SourceFile) map[string]string { + summaries := make(map[string]string) + for _, f := range files { + summary := summarizeFile(f) + if summary != "" { + summaries[f.Path] = summary + } + } + return summaries +} + +// summarizeFile produces a one-line heuristic summary of a file's exports. +func summarizeFile(f ingest.SourceFile) string { + lines := strings.Split(f.Content, "\n") + var exports []string + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + + // JS/TS exports + if strings.HasPrefix(trimmed, "export ") { + if strings.Contains(trimmed, "function ") || strings.Contains(trimmed, "const ") || + strings.Contains(trimmed, "class ") || strings.Contains(trimmed, "interface ") || + strings.Contains(trimmed, "type ") || strings.Contains(trimmed, "default ") { + // Extract the name + name := extractExportName(trimmed) + if name != "" { + exports = append(exports, name) + } + } + } + + // Go exported functions/types + if f.Language == "go" { + if strings.HasPrefix(trimmed, "func ") { + name := extractGoFuncName(trimmed) + if name != "" && name[0] >= 'A' && name[0] <= 'Z' { + exports = append(exports, name+"()") + } + } else if strings.HasPrefix(trimmed, "type ") && (strings.Contains(trimmed, " struct") || strings.Contains(trimmed, " interface")) { + parts := strings.Fields(trimmed) + if len(parts) >= 2 && parts[1][0] >= 'A' && parts[1][0] <= 'Z' { + exports = append(exports, parts[1]) + } + } + } + + // Python: def and class at module level (no indentation) + if f.Language == "python" && !strings.HasPrefix(line, " ") && !strings.HasPrefix(line, "\t") { + if strings.HasPrefix(trimmed, "def ") || strings.HasPrefix(trimmed, "class ") { + name := extractPyName(trimmed) + if name != "" && !strings.HasPrefix(name, "_") { + exports = append(exports, name) + } + } + } + + if len(exports) >= 8 { + break + } + } + + if len(exports) == 0 { + return "" + } + return strings.Join(exports, ", ") +} + +func 
extractExportName(line string) string { + // "export function foo(" → "foo()" + // "export const bar =" → "bar" + // "export class Baz" → "Baz" + for _, keyword := range []string{"function ", "const ", "let ", "var ", "class ", "interface ", "type ", "default "} { + idx := strings.Index(line, keyword) + if idx < 0 { + continue + } + rest := line[idx+len(keyword):] + rest = strings.TrimSpace(rest) + // Take until space, paren, equals, or brace. + end := strings.IndexAny(rest, " (={<:") + if end > 0 { + name := rest[:end] + if keyword == "function " { + return name + "()" + } + return name + } + if len(rest) > 0 { + return rest + } + } + return "" +} + +func extractGoFuncName(line string) string { + // "func FooBar(" → "FooBar" + // "func (s *Server) Handle(" → "Handle" + rest := strings.TrimPrefix(line, "func ") + if strings.HasPrefix(rest, "(") { + // Method: skip receiver. + closeIdx := strings.Index(rest, ")") + if closeIdx < 0 { + return "" + } + rest = strings.TrimSpace(rest[closeIdx+1:]) + } + end := strings.IndexByte(rest, '(') + if end > 0 { + return rest[:end] + } + return "" +} + +func extractPyName(line string) string { + // "def foo(..." 
→ "foo" + // "class Bar:" → "Bar" + for _, prefix := range []string{"def ", "class "} { + if strings.HasPrefix(line, prefix) { + rest := strings.TrimPrefix(line, prefix) + end := strings.IndexAny(rest, "(: ") + if end > 0 { + return rest[:end] + } + } + } + return "" +} diff --git a/internal/cli/scan_test.go b/internal/cli/scan_test.go new file mode 100644 index 0000000..5129ae7 --- /dev/null +++ b/internal/cli/scan_test.go @@ -0,0 +1,934 @@ +package cli + +import ( + "bytes" + "encoding/json" + "os" + "path/filepath" + "reflect" + "sort" + "strings" + "testing" + + "github.com/block/codecrucible/internal/chunk" + "github.com/block/codecrucible/internal/config" + "github.com/block/codecrucible/internal/ingest" + "github.com/block/codecrucible/internal/sarif" +) + +// createTestRepo creates a temp directory with a few Go source files for testing. +func createTestRepo(t *testing.T) string { + t.Helper() + dir := t.TempDir() + + // Create a source file. + srcDir := filepath.Join(dir, "src") + if err := os.MkdirAll(srcDir, 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(srcDir, "main.go"), []byte("package main\n\nfunc main() {}\n"), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(srcDir, "handler.go"), []byte("package main\n\nfunc handler() {}\n"), 0644); err != nil { + t.Fatal(err) + } + + // Create a prompts/default directory so the scan command can find templates. + promptsDir := filepath.Join(dir, "prompts", "default") + if err := os.MkdirAll(promptsDir, 0755); err != nil { + t.Fatal(err) + } + basePrompt := `system_message: "You are a security analyst." +analysis_intro: "Analyze the code." +infrastructure_note: "" +analysis_requirements_header: "" +custom_requirements_placeholder: "" +repo_info: "Repo: {repo_name}\n{xml_content}" +critical_instructions: "Be thorough." 
+json_formatting_rules: "Return JSON: {schema}" +` + if err := os.WriteFile(filepath.Join(promptsDir, "security_analysis_base.yaml"), []byte(basePrompt), 0644); err != nil { + t.Fatal(err) + } + + return dir +} + +func TestScanCommand_DryRunWithRepo(t *testing.T) { + dir := createTestRepo(t) + + // Change to the test repo dir so prompts are found. + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + defer func() { + if err := os.Chdir(origDir); err != nil { + t.Fatalf("restore working directory: %v", err) + } + }() + if err := os.Chdir(dir); err != nil { + t.Fatalf("chdir to test repo: %v", err) + } + + cmd := NewRootCommand() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetErr(buf) + cmd.SetArgs([]string{"scan", "--dry-run", dir}) + + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute failed: %v", err) + } + + // Dry run should print scope info to stdout. + // Note: output goes to os.Stdout from fmt.Printf, not cmd's buffer. + // Just verify no error occurred. +} + +func TestScanCommand_EmptyRepoProducesValidSARIF(t *testing.T) { + // Create an empty directory (no source files). + dir := t.TempDir() + outFile := filepath.Join(dir, "results.sarif") + + cmd := NewRootCommand() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetErr(buf) + cmd.SetArgs([]string{"scan", "--output", outFile, dir}) + + err := cmd.Execute() + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + // Verify output file is valid SARIF. 
+ data, err := os.ReadFile(outFile) + if err != nil { + t.Fatalf("reading output: %v", err) + } + + var doc sarif.SARIFDocument + if err := json.Unmarshal(data, &doc); err != nil { + t.Fatalf("invalid SARIF JSON: %v", err) + } + + if doc.Version != "2.1.0" { + t.Errorf("version: got %q, want %q", doc.Version, "2.1.0") + } + if len(doc.Runs) != 1 { + t.Fatalf("expected 1 run, got %d", len(doc.Runs)) + } + if len(doc.Runs[0].Results) != 0 { + t.Errorf("expected 0 results for empty repo, got %d", len(doc.Runs[0].Results)) + } + // Should have a notification about no source files. + if len(doc.Runs[0].Invocations) > 0 && len(doc.Runs[0].Invocations[0].ToolExecutionNotifications) > 0 { + msg := doc.Runs[0].Invocations[0].ToolExecutionNotifications[0].Message.Text + if !strings.Contains(msg, "no source files") { + t.Errorf("expected notification about no source files, got: %s", msg) + } + } +} + +func TestScanCommand_MissingCredentialsReturnsError(t *testing.T) { + dir := createTestRepo(t) + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + defer func() { + if err := os.Chdir(origDir); err != nil { + t.Fatalf("restore working directory: %v", err) + } + }() + if err := os.Chdir(dir); err != nil { + t.Fatalf("chdir to test repo: %v", err) + } + + // Clear all provider env vars so no provider can be configured. + t.Setenv("DATABRICKS_HOST", "") + t.Setenv("DATABRICKS_TOKEN", "") + t.Setenv("ANTHROPIC_API_KEY", "") + t.Setenv("OPENAI_API_KEY", "") + t.Setenv("PATH", t.TempDir()) + + cmd := NewRootCommand() + cmd.SetArgs([]string{"scan", dir}) + + err = cmd.Execute() + if err == nil { + t.Fatal("expected error for missing credentials") + } + // With no API keys and no Claude CLI on PATH, auth setup should fail. + // Exact wording varies by which provider the resolver picked; match + // on the universal bit. 
+ if !strings.Contains(strings.ToLower(err.Error()), "api key") && !strings.Contains(err.Error(), "is not set") { + t.Errorf("expected missing-credentials error, got: %v", err) + } +} + +func TestBuildLLMClient_AnthropicFallsBackToClaudeCLI(t *testing.T) { + dir := t.TempDir() + claudePath := filepath.Join(dir, "claude") + script := "#!/bin/sh\nexit 0\n" + if err := os.WriteFile(claudePath, []byte(script), 0755); err != nil { + t.Fatalf("write fake claude: %v", err) + } + t.Setenv("PATH", dir) + + client, endpoint, err := buildPhaseClient(config.PhaseConfig{ + Provider: "anthropic", + ModelCfg: config.ModelConfig{Name: "claude-sonnet-4-6"}, + // APIKey deliberately empty — exercises the CLI fallback path. + }, &config.Config{}) + if err != nil { + t.Fatalf("buildPhaseClient returned error: %v", err) + } + if client == nil { + t.Fatal("expected non-nil client") + } + if endpoint != "" { + t.Fatalf("endpoint = %q, want empty", endpoint) + } +} + +func TestScanCommand_MaxCostAbort(t *testing.T) { + // The default models have $0 pricing, so max-cost only triggers with + // non-zero pricing. We test the check in resolveModel + cost estimation. + // Instead, we verify the max-cost flag is accepted and dry-run shows cost. 
+ dir := createTestRepo(t) + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + defer func() { + if err := os.Chdir(origDir); err != nil { + t.Fatalf("restore working directory: %v", err) + } + }() + if err := os.Chdir(dir); err != nil { + t.Fatalf("chdir to test repo: %v", err) + } + + cmd := NewRootCommand() + cmd.SetArgs([]string{"scan", "--max-cost", "100", "--dry-run", dir}) + + err = cmd.Execute() + if err != nil { + t.Fatalf("Execute failed: %v", err) + } +} + +func TestResolveModel_Default(t *testing.T) { + m := resolveModel("") + if m.Name != "claude-sonnet-4-6" { + t.Errorf("expected default model claude-sonnet-4-6, got %s", m.Name) + } +} + +func TestResolveModel_Known(t *testing.T) { + m := resolveModel("gpt-5.2") + if m.Name != "gpt-5.2" { + t.Errorf("expected gpt-5.2, got %s", m.Name) + } + if m.ContextLimit != 400000 { + t.Errorf("expected context limit 400000, got %d", m.ContextLimit) + } +} + +func TestResolveModel_Unknown(t *testing.T) { + m := resolveModel("custom-model-v2") + if m.Name != "custom-model-v2" { + t.Errorf("expected custom-model-v2, got %s", m.Name) + } + if m.ContextLimit != 128000 { + t.Errorf("expected default context limit 128000, got %d", m.ContextLimit) + } +} + +func TestIngestFiles_PathsUseRepoRootGitignore(t *testing.T) { + repo := t.TempDir() + + if err := os.WriteFile(filepath.Join(repo, ".gitignore"), []byte("firmware/generated/\n"), 0644); err != nil { + t.Fatalf("write .gitignore: %v", err) + } + if err := os.MkdirAll(filepath.Join(repo, "firmware", "generated"), 0755); err != nil { + t.Fatalf("mkdir generated: %v", err) + } + if err := os.MkdirAll(filepath.Join(repo, "firmware", "src"), 0755); err != nil { + t.Fatalf("mkdir src: %v", err) + } + if err := os.WriteFile(filepath.Join(repo, "firmware", "generated", "ignored.go"), []byte("package generated\n"), 0644); err != nil { + t.Fatalf("write ignored.go: %v", err) + } + if err := os.WriteFile(filepath.Join(repo, "firmware", "src", 
"kept.go"), []byte("package main\n"), 0644); err != nil { + t.Fatalf("write kept.go: %v", err) + } + + files, err := ingestFiles(repo, &config.Config{Paths: []string{"firmware/"}}) + if err != nil { + t.Fatalf("ingestFiles: %v", err) + } + + if len(files) != 1 { + got := make([]string, len(files)) + for i := range files { + got[i] = files[i].Path + } + t.Fatalf("expected 1 file after gitignore + paths filtering, got %d: %v", len(files), got) + } + if files[0].Path != "firmware/src/kept.go" { + t.Fatalf("expected firmware/src/kept.go, got %s", files[0].Path) + } +} + +func TestIngestFiles_PathsDoNotDuplicateOverlappingEntries(t *testing.T) { + repo := t.TempDir() + + if err := os.MkdirAll(filepath.Join(repo, "firmware", "src"), 0755); err != nil { + t.Fatalf("mkdir src: %v", err) + } + if err := os.WriteFile(filepath.Join(repo, "firmware", "src", "main.go"), []byte("package main\n"), 0644); err != nil { + t.Fatalf("write main.go: %v", err) + } + + files, err := ingestFiles(repo, &config.Config{Paths: []string{"firmware", "firmware/src"}}) + if err != nil { + t.Fatalf("ingestFiles: %v", err) + } + + if len(files) != 1 { + got := make([]string, len(files)) + for i := range files { + got[i] = files[i].Path + } + t.Fatalf("expected 1 deduplicated file, got %d: %v", len(files), got) + } + if files[0].Path != "firmware/src/main.go" { + t.Fatalf("expected firmware/src/main.go, got %s", files[0].Path) + } +} + +func TestIngestAndFilter_PathsAndFiltersMatrix(t *testing.T) { + repo := t.TempDir() + + writeRepoFile := func(relPath, content string) { + t.Helper() + fullPath := filepath.Join(repo, relPath) + if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil { + t.Fatalf("mkdir %s: %v", filepath.Dir(relPath), err) + } + if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil { + t.Fatalf("write %s: %v", relPath, err) + } + } + + writeRepoFile(".gitignore", "firmware/generated/\n") + writeRepoFile("firmware/main.c", "int main() { return 0; }\n") + 
writeRepoFile("firmware/src/boot.c", "int boot() { return 1; }\n") + writeRepoFile("firmware/third-party/direct.c", "int dep() { return 2; }\n") + writeRepoFile("firmware/third-party/nested/deep.c", "int deep() { return 3; }\n") + writeRepoFile("firmware/generated/should_be_ignored.c", "int ignored() { return 4; }\n") + writeRepoFile("app/main.go", "package main\n") + + testCases := []struct { + name string + paths []string + include []string + exclude []string + want []string + }{ + { + name: "no_paths_exclude_recursive_directory", + exclude: []string{"firmware/third-party/**"}, + want: []string{"app/main.go", "firmware/main.c", "firmware/src/boot.c"}, + }, + { + name: "firmware_path_with_recursive_exclude", + paths: []string{"firmware"}, + exclude: []string{"firmware/third-party/**"}, + want: []string{"firmware/main.c", "firmware/src/boot.c"}, + }, + { + name: "include_overrides_recursive_exclude_within_path_scope", + paths: []string{"firmware"}, + include: []string{"firmware/third-party/nested/deep.c"}, + exclude: []string{"firmware/third-party/**"}, + want: []string{"firmware/main.c", "firmware/src/boot.c", "firmware/third-party/nested/deep.c"}, + }, + { + name: "subpath_scope_respected_with_exclude", + paths: []string{"firmware/src"}, + exclude: []string{"firmware/third-party/**"}, + want: []string{"firmware/src/boot.c"}, + }, + { + name: "exclude_can_eliminate_all_files_in_selected_path", + paths: []string{"firmware/third-party"}, + exclude: []string{"firmware/third-party/**"}, + want: []string{}, + }, + { + name: "overlapping_paths_with_broad_exclude_and_specific_include", + paths: []string{"firmware", "firmware/src"}, + include: []string{"firmware/src/boot.c"}, + exclude: []string{"firmware/**"}, + want: []string{"firmware/src/boot.c"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + files, err := ingestFiles(repo, &config.Config{Paths: tc.paths}) + if err != nil { + t.Fatalf("ingestFiles: %v", err) + } + + kept, _ := 
ingest.FilterFiles(files, ingest.FilterConfig{ + IncludeTests: true, + IncludeDocs: true, + Include: tc.include, + Exclude: tc.exclude, + MaxFileSize: 0, + }) + + got := make([]string, len(kept)) + for i, f := range kept { + got[i] = f.Path + } + sort.Strings(got) + + want := make([]string, len(tc.want)) + copy(want, tc.want) + sort.Strings(want) + + if !reflect.DeepEqual(got, want) { + t.Fatalf("kept files mismatch\n got: %v\n want: %v", got, want) + } + }) + } +} + +func TestParseCustomHeaders_Valid(t *testing.T) { + headers, err := parseCustomHeaders([]string{ + "anthropic-beta: context-1m-2025-08-07", + "x-feature-flag: enabled", + "anthropic-beta: another-beta", + }) + if err != nil { + t.Fatalf("parseCustomHeaders returned error: %v", err) + } + + betaValues := headers.Values("Anthropic-Beta") + if len(betaValues) != 2 { + t.Fatalf("expected 2 Anthropic-Beta values, got %d: %v", len(betaValues), betaValues) + } + if betaValues[0] != "context-1m-2025-08-07" || betaValues[1] != "another-beta" { + t.Fatalf("unexpected Anthropic-Beta values: %v", betaValues) + } + + if got := headers.Get("X-Feature-Flag"); got != "enabled" { + t.Fatalf("expected X-Feature-Flag=enabled, got %q", got) + } +} + +func TestParseCustomHeaders_Invalid(t *testing.T) { + testCases := []struct { + name string + headers []string + }{ + {name: "missing_separator", headers: []string{"invalid"}}, + {name: "empty_name", headers: []string{": value"}}, + {name: "empty_value", headers: []string{"x-test: "}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := parseCustomHeaders(tc.headers) + if err == nil { + t.Fatalf("expected parseCustomHeaders error for %v", tc.headers) + } + }) + } +} + +func TestSeverityLevelScan(t *testing.T) { + testCases := []struct { + name string + sev float64 + want string + }{ + {name: "zero", sev: 0.0, want: "none"}, + {name: "negative", sev: -1.5, want: "none"}, + {name: "low", sev: 2.0, want: "note"}, + {name: 
"boundary_note_warning", sev: 3.9, want: "note"}, + {name: "medium_low", sev: 4.0, want: "warning"}, + {name: "medium", sev: 5.5, want: "warning"}, + {name: "boundary_warning_error", sev: 6.9, want: "warning"}, + {name: "high", sev: 7.0, want: "error"}, + {name: "critical", sev: 9.8, want: "error"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := severityLevelScan(tc.sev) + if got != tc.want { + t.Errorf("severityLevelScan(%v) = %q, want %q", tc.sev, got, tc.want) + } + }) + } +} + +func TestCWEForRuleScan(t *testing.T) { + testCases := []struct { + name string + rule sarif.SARIFRule + want string + }{ + { + name: "from_relationship", + rule: sarif.SARIFRule{ + Relationships: []sarif.SARIFRelationship{{ + Target: sarif.SARIFRelationshipTarget{ID: "CWE-89"}, + }}, + }, + want: "CWE-89", + }, + { + name: "relationship_not_cwe_prefix", + rule: sarif.SARIFRule{ + Relationships: []sarif.SARIFRelationship{{ + Target: sarif.SARIFRelationshipTarget{ID: "OWASP-A01"}, + }}, + }, + want: "", + }, + { + name: "from_tags", + rule: sarif.SARIFRule{ + Properties: map[string]any{ + "tags": []string{"security", "external/cwe/cwe-79"}, + }, + }, + want: "CWE-79", + }, + { + name: "tags_wrong_type", + rule: sarif.SARIFRule{ + Properties: map[string]any{ + "tags": "not-a-slice", + }, + }, + want: "", + }, + { + name: "relationship_wins_over_tags", + rule: sarif.SARIFRule{ + Relationships: []sarif.SARIFRelationship{{ + Target: sarif.SARIFRelationshipTarget{ID: "CWE-22"}, + }}, + Properties: map[string]any{ + "tags": []string{"external/cwe/cwe-79"}, + }, + }, + want: "CWE-22", + }, + { + name: "empty_rule", + rule: sarif.SARIFRule{}, + want: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := sarif.CWEForRule(tc.rule) + if got != tc.want { + t.Errorf("sarif.CWEForRule() = %q, want %q", got, tc.want) + } + }) + } +} + +func TestExtractExportName(t *testing.T) { + testCases := []struct { + name string + line 
string + want string + }{ + {name: "function", line: "export function foo() {", want: "foo()"}, + {name: "async_function", line: "export async function fetchData(url) {", want: "fetchData()"}, + {name: "const", line: "export const bar = 42;", want: "bar"}, + {name: "let", line: "export let mutable = true;", want: "mutable"}, + {name: "class", line: "export class Baz {", want: "Baz"}, + {name: "interface", line: "export interface Props {", want: "Props"}, + {name: "type_alias", line: "export type ID = string;", want: "ID"}, + {name: "generic_type", line: "export type Result = T | Error;", want: "Result"}, + {name: "default_class", line: "export default class App {", want: "App"}, + {name: "no_keyword_match", line: "export { foo, bar };", want: ""}, + {name: "empty_after_keyword", line: "export const ", want: ""}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := extractExportName(tc.line) + if got != tc.want { + t.Errorf("extractExportName(%q) = %q, want %q", tc.line, got, tc.want) + } + }) + } +} + +func TestExtractPyName(t *testing.T) { + testCases := []struct { + name string + line string + want string + }{ + {name: "def", line: "def foo():", want: "foo"}, + {name: "def_with_args", line: "def process_items(items, *, key=None):", want: "process_items"}, + {name: "class", line: "class Bar:", want: "Bar"}, + {name: "class_with_base", line: "class Child(Parent):", want: "Child"}, + {name: "class_space_before_colon", line: "class Thing :", want: "Thing"}, + {name: "no_match", line: "import os", want: ""}, + {name: "def_no_paren", line: "def ", want: ""}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := extractPyName(tc.line) + if got != tc.want { + t.Errorf("extractPyName(%q) = %q, want %q", tc.line, got, tc.want) + } + }) + } +} + +func TestCapManifest(t *testing.T) { + t.Run("fits_within_budget", func(t *testing.T) { + paths := []string{"a.go", "b.go", "c.go"} + got := capManifest(paths, 100) + 
if !reflect.DeepEqual(got, paths) { + t.Errorf("got %v, want %v", got, paths) + } + }) + + t.Run("zero_budget", func(t *testing.T) { + got := capManifest([]string{"a.go"}, 0) + if got != nil { + t.Errorf("got %v, want nil", got) + } + }) + + t.Run("negative_budget", func(t *testing.T) { + got := capManifest([]string{"a.go"}, -10) + if got != nil { + t.Errorf("got %v, want nil", got) + } + }) + + t.Run("truncates_and_appends_note", func(t *testing.T) { + // capManifest mutates the input via append; use a fresh slice. + paths := []string{"aaaa", "bbbb", "cccc", "dddd"} + // Each path costs len+1; "aaaa" = 5, "bbbb" = 10. Budget 9 admits + // only the first path before the second blows past it. + got := capManifest(paths, 9) + if len(got) != 2 { + t.Fatalf("expected 2 entries (1 kept + note), got %d: %v", len(got), got) + } + if got[0] != "aaaa" { + t.Errorf("first kept = %q, want %q", got[0], "aaaa") + } + if !strings.Contains(got[1], "3 more files") { + t.Errorf("omitted note = %q, want to contain %q", got[1], "3 more files") + } + }) + + t.Run("empty_input", func(t *testing.T) { + got := capManifest(nil, 100) + if len(got) != 0 { + t.Errorf("got %v, want empty", got) + } + }) +} + +func TestResolvePromptLoader(t *testing.T) { + t.Run("explicit_dir", func(t *testing.T) { + dir := t.TempDir() + loader, err := resolvePromptLoader(dir) + if err != nil { + t.Fatalf("resolvePromptLoader(%q) returned error: %v", dir, err) + } + if loader == nil { + t.Fatal("expected non-nil loader") + } + }) + + t.Run("default_cwd_prompts", func(t *testing.T) { + dir := t.TempDir() + if err := os.MkdirAll(filepath.Join(dir, "prompts", "default"), 0755); err != nil { + t.Fatal(err) + } + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + defer func() { + if err := os.Chdir(origDir); err != nil { + t.Fatalf("restore working directory: %v", err) + } + }() + if err := os.Chdir(dir); err != nil { + t.Fatalf("chdir: %v", err) + } + + loader, err := 
resolvePromptLoader("") + if err != nil { + t.Fatalf("resolvePromptLoader(\"\") returned error: %v", err) + } + if loader == nil { + t.Fatal("expected non-nil loader") + } + }) + + t.Run("not_found", func(t *testing.T) { + dir := t.TempDir() + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + defer func() { + if err := os.Chdir(origDir); err != nil { + t.Fatalf("restore working directory: %v", err) + } + }() + if err := os.Chdir(dir); err != nil { + t.Fatalf("chdir: %v", err) + } + + _, err = resolvePromptLoader("") + if err == nil { + t.Fatal("expected error when prompts directory not found") + } + if !strings.Contains(err.Error(), "prompts directory not found") { + t.Errorf("unexpected error message: %v", err) + } + }) +} + +// mkSARIFResult builds a SARIFResult with a single physical location. Used by +// applyAuditVerdicts tests so fixtures stay readable. +func mkSARIFResult(ruleID, file string, line int, msg string) sarif.SARIFResult { + return sarif.SARIFResult{ + RuleID: ruleID, + Level: "warning", + Message: sarif.SARIFMessage{Text: msg}, + Locations: []sarif.SARIFLocation{{ + PhysicalLocation: sarif.SARIFPhysicalLocation{ + ArtifactLocation: sarif.SARIFArtifactLocation{URI: file}, + Region: &sarif.SARIFRegion{StartLine: line}, + }, + }}, + } +} + +func TestApplyAuditVerdicts_RejectsAndConfirms(t *testing.T) { + doc := sarif.SARIFDocument{ + Runs: []sarif.SARIFRun{{ + Tool: sarif.SARIFTool{Driver: sarif.SARIFDriver{ + Rules: []sarif.SARIFRule{ + {ID: "R1", Properties: map[string]any{}}, + {ID: "R2", Properties: map[string]any{}}, + }, + }}, + Results: []sarif.SARIFResult{ + mkSARIFResult("R1", "src/a.go", 10, "finding A"), + mkSARIFResult("R2", "src/b.go", 20, "finding B"), + }, + }}, + } + audit := AuditResult{ + AuditedFindings: []AuditedFinding{ + {FilePath: "src/a.go", StartLine: 10, Verdict: "rejected", Confidence: 0.9}, + {FilePath: "src/b.go", StartLine: 20, Verdict: "confirmed", Confidence: 0.95}, + }, + } + + out := 
applyAuditVerdicts(doc, audit, ingest.FileMap{}, 0.5) + + if len(out.Runs[0].Results) != 1 { + t.Fatalf("expected 1 kept result, got %d", len(out.Runs[0].Results)) + } + if out.Runs[0].Results[0].RuleID != "R2" { + t.Errorf("kept RuleID = %q, want %q", out.Runs[0].Results[0].RuleID, "R2") + } + // Only rules for kept results should survive. + if len(out.Runs[0].Tool.Driver.Rules) != 1 { + t.Fatalf("expected 1 kept rule, got %d", len(out.Runs[0].Tool.Driver.Rules)) + } + if out.Runs[0].Tool.Driver.Rules[0].ID != "R2" { + t.Errorf("kept rule ID = %q, want %q", out.Runs[0].Tool.Driver.Rules[0].ID, "R2") + } +} + +func TestApplyAuditVerdicts_ConfidenceThreshold(t *testing.T) { + doc := sarif.SARIFDocument{ + Runs: []sarif.SARIFRun{{ + Tool: sarif.SARIFTool{Driver: sarif.SARIFDriver{ + Rules: []sarif.SARIFRule{{ID: "R1", Properties: map[string]any{}}}, + }}, + Results: []sarif.SARIFResult{ + mkSARIFResult("R1", "src/a.go", 10, "finding A"), + }, + }}, + } + audit := AuditResult{ + AuditedFindings: []AuditedFinding{ + // Confirmed but below threshold — should be dropped. 
+ {FilePath: "src/a.go", StartLine: 10, Verdict: "confirmed", Confidence: 0.3}, + }, + } + + out := applyAuditVerdicts(doc, audit, ingest.FileMap{}, 0.7) + + if len(out.Runs[0].Results) != 0 { + t.Fatalf("expected 0 results after confidence cutoff, got %d", len(out.Runs[0].Results)) + } +} + +func TestApplyAuditVerdicts_RefinesSeverityAndMessage(t *testing.T) { + doc := sarif.SARIFDocument{ + Runs: []sarif.SARIFRun{{ + Tool: sarif.SARIFTool{Driver: sarif.SARIFDriver{ + Rules: []sarif.SARIFRule{{ID: "R1", Properties: map[string]any{}}}, + }}, + Results: []sarif.SARIFResult{ + mkSARIFResult("R1", "src/a.go", 10, "original"), + }, + }}, + } + audit := AuditResult{ + AuditedFindings: []AuditedFinding{{ + FilePath: "src/a.go", + StartLine: 10, + Verdict: "refined", + Confidence: 0.9, + RefinedSeverity: 8.5, + RefinedTechnicalDetails: "new details", + Justification: "because", + }}, + } + + out := applyAuditVerdicts(doc, audit, ingest.FileMap{}, 0.5) + + if len(out.Runs[0].Results) != 1 { + t.Fatalf("expected 1 result, got %d", len(out.Runs[0].Results)) + } + result := out.Runs[0].Results[0] + if result.Level != "error" { + t.Errorf("Level = %q, want %q (severity 8.5)", result.Level, "error") + } + if !strings.Contains(result.Message.Text, "new details") { + t.Errorf("message %q missing refined details", result.Message.Text) + } + if !strings.Contains(result.Message.Text, "90%") { + t.Errorf("message %q missing confidence percentage", result.Message.Text) + } + rule := out.Runs[0].Tool.Driver.Rules[0] + if rule.Properties["security-severity"] != "8.5" { + t.Errorf("rule security-severity = %v, want %q", rule.Properties["security-severity"], "8.5") + } +} + +func TestApplyAuditVerdicts_UnauditedKeptAsIs(t *testing.T) { + doc := sarif.SARIFDocument{ + Runs: []sarif.SARIFRun{{ + Tool: sarif.SARIFTool{Driver: sarif.SARIFDriver{ + Rules: []sarif.SARIFRule{{ID: "R1", Properties: map[string]any{}}}, + }}, + Results: []sarif.SARIFResult{ + mkSARIFResult("R1", "src/a.go", 10, 
"untouched"), + }, + }}, + } + + out := applyAuditVerdicts(doc, AuditResult{}, ingest.FileMap{}, 0.5) + + if len(out.Runs[0].Results) != 1 { + t.Fatalf("expected 1 kept result, got %d", len(out.Runs[0].Results)) + } + if out.Runs[0].Results[0].Message.Text != "untouched" { + t.Errorf("message = %q, want unchanged", out.Runs[0].Results[0].Message.Text) + } +} + +func TestApplyAuditVerdicts_NewFindings(t *testing.T) { + doc := sarif.SARIFDocument{ + Runs: []sarif.SARIFRun{{ + Tool: sarif.SARIFTool{Driver: sarif.SARIFDriver{Rules: []sarif.SARIFRule{}}}, + Results: []sarif.SARIFResult{}, + }}, + } + audit := AuditResult{ + NewFindings: []NewFinding{ + { + Issue: "SQL injection", + FilePath: "src/db.go", + StartLine: 42, + EndLine: 44, + TechnicalDetails: "raw query with user input", + Severity: 8.0, + CWEID: "CWE-89", + Confidence: 0.95, + }, + { + // Below threshold — should be filtered out. + Issue: "maybe issue", + FilePath: "src/x.go", + StartLine: 1, + Confidence: 0.2, + }, + }, + } + + out := applyAuditVerdicts(doc, audit, ingest.FileMap{}, 0.5) + + if len(out.Runs[0].Results) != 1 { + t.Fatalf("expected 1 new result, got %d", len(out.Runs[0].Results)) + } + if len(out.Runs[0].Tool.Driver.Rules) != 1 { + t.Fatalf("expected 1 new rule, got %d", len(out.Runs[0].Tool.Driver.Rules)) + } +} + +func TestStreamingTokenCount_MatchesFullXMLCount(t *testing.T) { + files := []ingest.SourceFile{ + {Path: "cmd/main.go", Content: "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"hello\")\n}\n", LineCount: 7, Language: "go"}, + {Path: "internal/handler.go", Content: "package internal\n\ntype Handler struct {\n\tDB *sql.DB\n}\n\nfunc (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {\n\tw.Write([]byte(\"ok\"))\n}\n", LineCount: 9, Language: "go"}, + {Path: "config.yaml", Content: "server:\n port: 8080\n host: localhost\n", LineCount: 3, Language: "yaml"}, + } + + cfg := ingest.FlattenConfig{Compress: false} + full := ingest.Flatten(files, cfg) + 
counter := chunk.NewTokenCounter("cl100k_base", nil) + + fullCount := counter.Count(full.XML) + streamCount := streamingTokenCount(full.FileMap, counter, cfg) + + // The streaming count sums envelope + per-file independently. The full + // count includes inter-file blank lines and wrapper that + // the streaming count approximates. Allow 5% divergence. + diff := fullCount - streamCount + if diff < 0 { + diff = -diff + } + maxDrift := fullCount / 20 // 5% + if maxDrift < 5 { + maxDrift = 5 + } + if diff > maxDrift { + t.Errorf("streaming count %d diverges from full XML count %d by %d (max allowed %d)", + streamCount, fullCount, diff, maxDrift) + } +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..de7956e --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,302 @@ +package config + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/spf13/viper" +) + +// Config holds the application configuration loaded from Viper. 
type Config struct {
	// Global options
	Verbose bool `mapstructure:"verbose"`

	// Scan options.
	//
	// Defaults registered by SetDefaults: max-cost 25, concurrency 3,
	// max-file-size 102400 (100KB). A context-limit or max-output-tokens
	// of 0 means "use the model registry default".
	//
	// ModelParamsJSON is a JSON object (also bound to the
	// CODECRUCIBLE_MODEL_PARAMS env var) that Load merges on top of
	// ModelParams: nested objects are merged key by key, any other value
	// from the JSON replaces the existing one.
	Paths                 []string       `mapstructure:"paths"`
	Model                 string         `mapstructure:"model"`
	FailOnSeverity        float64        `mapstructure:"fail-on-severity"`
	MaxCost               float64        `mapstructure:"max-cost"`
	DryRun                bool           `mapstructure:"dry-run"`
	IncludeTests          bool           `mapstructure:"include-tests"`
	IncludeDocs           bool           `mapstructure:"include-docs"`
	Compress              bool           `mapstructure:"compress"`
	CustomRequirements    string         `mapstructure:"custom-requirements"`
	SkipFeatureDetection  bool           `mapstructure:"skip-feature-detection"`
	FeatureDetectionModel string         `mapstructure:"feature-detection-model"`
	Concurrency           int            `mapstructure:"concurrency"`
	Output                string         `mapstructure:"output"`
	PhaseOutputDir        string         `mapstructure:"phase-output-dir"`
	PromptsDir            string         `mapstructure:"prompts-dir"`
	Include               []string       `mapstructure:"include"`
	Exclude               []string       `mapstructure:"exclude"`
	CustomHeaders         []string       `mapstructure:"custom-headers"`
	BaseURL               string         `mapstructure:"base-url"`
	ModelParams           map[string]any `mapstructure:"model-params"`
	ModelParamsJSON       string         `mapstructure:"model-params-json"`
	MaxFileSize           int            `mapstructure:"max-file-size"`
	ContextLimit          int            `mapstructure:"context-limit"`
	MaxOutputTokens       int            `mapstructure:"max-output-tokens"`
	RequestTimeout        int            `mapstructure:"request-timeout"` // seconds; 0 = client default

	// Audit phase. SetDefaults registers audit-confidence-threshold 0.3 and
	// audit-batch-size 25.
	SkipAudit                bool    `mapstructure:"skip-audit"`
	AuditModel               string  `mapstructure:"audit-model"`
	AuditConfidenceThreshold float64 `mapstructure:"audit-confidence-threshold"`
	AuditBatchSize           int     `mapstructure:"audit-batch-size"`

	// Supplementary context: reference material injected into analysis and
	// audit prompts. See internal/supctx.
	//
	// Load parses every ContextSourcesRaw entry and appends the result to
	// ContextSources, so after Load the slice holds config-file sources
	// first, followed by CLI-form sources.
	ContextSources    []ContextSource `mapstructure:"context-sources"`
	ContextBudgetPct  int             `mapstructure:"context-budget-pct"`  // % of context window reserved for sources (default 15, max 40)
	ContextSourcesRaw []string        `mapstructure:"context-sources-raw"` // CLI form: "name=X,type=Y,location=Z,priority=N"

	// Provider selection (also bound to the CODECRUCIBLE_PROVIDER env var).
	Provider string `mapstructure:"provider"` // "databricks", "anthropic", "openai", "google" (auto-detected if empty)

	// Databricks options (from env vars DATABRICKS_HOST, DATABRICKS_TOKEN
	// and DATABRICKS_ENDPOINT)
	DatabricksHost     string `mapstructure:"databricks-host"`
	DatabricksToken    string `mapstructure:"databricks-token"`
	DatabricksEndpoint string `mapstructure:"databricks-endpoint"`

	// Anthropic options (from env var ANTHROPIC_API_KEY)
	AnthropicAPIKey string `mapstructure:"anthropic-api-key"`

	// OpenAI options (from env var OPENAI_API_KEY)
	OpenAIAPIKey string `mapstructure:"openai-api-key"`

	// Google options (from env var GOOGLE_API_KEY or GEMINI_API_KEY)
	GoogleAPIKey string `mapstructure:"google-api-key"`

	// Phases carries per-phase (provider, model, api_key, model_params,
	// ...) tuples. The flat fields above act as defaults for the analysis
	// phase; feature-detection and audit inherit from analysis. See
	// ResolvePhases. Populated by config file (phases.audit.provider: ...)
	// or env (PHASES_AUDIT_PROVIDER=...).
	Phases Phases `mapstructure:"phases"`

	// Models extends (or overrides) the built-in model registry. Each entry
	// is passed to RegisterModel at Load time, keyed by Name — so user
	// entries sharing a built-in name replace the built-in wholesale. Lets
	// operators add a new provider endpoint or retune pricing / context
	// limits without recompiling. See RegisterUserModels.
	Models []ModelConfig `mapstructure:"models"`
}

// Load reads configuration from the given Viper instance.
// Priority chain: flags > env > config file > defaults.
// The caller is responsible for setting up flag bindings and calling
// viper.ReadInConfig() before calling Load.
+func Load(v *viper.Viper) (*Config, error) { + var cfg Config + if err := v.Unmarshal(&cfg); err != nil { + return nil, fmt.Errorf("unmarshaling config: %w", err) + } + + if cfg.ModelParams == nil { + cfg.ModelParams = map[string]any{} + } + if strings.TrimSpace(cfg.ModelParamsJSON) != "" { + var override map[string]any + if err := json.Unmarshal([]byte(cfg.ModelParamsJSON), &override); err != nil { + return nil, fmt.Errorf("parsing model-params-json: %w", err) + } + if override == nil { + override = map[string]any{} + } + mergeMaps(cfg.ModelParams, override) + } + + // Parse CLI-form context sources and append to any from the config file. + for _, raw := range cfg.ContextSourcesRaw { + cs, err := parseContextSourceKV(raw) + if err != nil { + return nil, fmt.Errorf("parsing --context-source %q: %w", raw, err) + } + cfg.ContextSources = append(cfg.ContextSources, cs) + } + + // Merge user-declared models into the global registry. Entries sharing a + // name with a built-in replace it; new names extend. Running this on + // every Load is a no-op when cfg.Models is empty, and idempotent (same + // entry twice) when it isn't. + if err := RegisterUserModels(cfg.Models); err != nil { + return nil, fmt.Errorf("registering models from config: %w", err) + } + + return &cfg, nil +} + +// SetDefaults configures the default values for all config keys. 
+func SetDefaults(v *viper.Viper) { + v.SetDefault("verbose", false) + v.SetDefault("model", "") + v.SetDefault("fail-on-severity", float64(0)) + v.SetDefault("max-cost", float64(25)) //Set to $25 as a default + v.SetDefault("dry-run", false) + v.SetDefault("include-tests", false) + v.SetDefault("include-docs", false) + v.SetDefault("compress", false) + v.SetDefault("custom-requirements", "") + v.SetDefault("skip-feature-detection", false) + v.SetDefault("feature-detection-model", "") + v.SetDefault("concurrency", 3) + v.SetDefault("output", "") + v.SetDefault("phase-output-dir", "") + v.SetDefault("prompts-dir", "") + v.SetDefault("custom-headers", []string{}) + v.SetDefault("model-params", map[string]any{}) + v.SetDefault("model-params-json", "") + v.SetDefault("provider", "") + v.SetDefault("base-url", "") + v.SetDefault("max-file-size", 102400) // 100KB + v.SetDefault("context-limit", 0) // 0 = use model registry default + v.SetDefault("max-output-tokens", 0) // 0 = use model registry default + v.SetDefault("request-timeout", 0) // 0 = use client default (600s) + v.SetDefault("skip-audit", false) + v.SetDefault("audit-model", "") + v.SetDefault("audit-confidence-threshold", 0.3) + v.SetDefault("audit-batch-size", 25) + v.SetDefault("context-budget-pct", 15) + v.SetDefault("context-sources-raw", []string{}) +} + +// BindEnvVars binds environment variables to Viper keys. +func BindEnvVars(v *viper.Viper) { + // "-" → "_" lets flat keys like context-limit map to CONTEXT_LIMIT. + // "." → "_" lets nested keys like phases.audit.provider map to + // PHASES_AUDIT_PROVIDER. No existing key contains a dot, so this is + // additive. 
+ v.SetEnvKeyReplacer(strings.NewReplacer("-", "_", ".", "_")) + v.AutomaticEnv() + + // Explicit bindings for Databricks env vars + _ = v.BindEnv("databricks-host", "DATABRICKS_HOST") + _ = v.BindEnv("databricks-token", "DATABRICKS_TOKEN") + _ = v.BindEnv("databricks-endpoint", "DATABRICKS_ENDPOINT") + + // Direct-provider key bindings — these populate the legacy flat fields + // which ResolvePhases then cascades to any phase that didn't set its own. + _ = v.BindEnv("anthropic-api-key", "ANTHROPIC_API_KEY") + _ = v.BindEnv("openai-api-key", "OPENAI_API_KEY") + _ = v.BindEnv("google-api-key", "GOOGLE_API_KEY", "GEMINI_API_KEY") + _ = v.BindEnv("provider", "CODECRUCIBLE_PROVIDER") + _ = v.BindEnv("model-params-json", "CODECRUCIBLE_MODEL_PARAMS") + + // Per-phase nested keys: AutomaticEnv only picks up env vars for keys + // viper already knows about. Unmarshal into a nested struct doesn't + // pre-register the leaves, so bind the env-shaped ones explicitly. + // Config-file users don't need this — only the env path does. + for _, phase := range []string{"analysis", "feature-detection", "audit", "context-compress"} { + for _, leaf := range []string{ + "provider", "model", "api-key", "base-url", "endpoint", + "model-params-json", "request-timeout", + "context-limit", "max-output-tokens", + } { + _ = v.BindEnv("phases." + phase + "." + leaf) + } + } +} + +func mergeMaps(dst, src map[string]any) { + for k, v := range src { + existing, ok := dst[k] + if !ok { + dst[k] = v + continue + } + + existingMap, existingIsMap := existing.(map[string]any) + srcMap, srcIsMap := v.(map[string]any) + if existingIsMap && srcIsMap { + mergeMaps(existingMap, srcMap) + dst[k] = existingMap + continue + } + + dst[k] = v + } +} + +// SetupViper initializes a Viper instance with defaults, env bindings, +// and optional config file. This does NOT bind CLI flags — that is +// done in the CLI layer. +// Returns an error if a config file exists but cannot be parsed. 
+func SetupViper(configFile string) (*viper.Viper, error) { + v := viper.New() + SetDefaults(v) + BindEnvVars(v) + + if configFile != "" { + v.SetConfigFile(configFile) + } else { + v.SetConfigName(".codecrucible") + v.SetConfigType("yaml") + v.AddConfigPath(".") + v.AddConfigPath("$HOME") + } + + if err := v.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); !ok { + return nil, fmt.Errorf("reading config file: %w", err) + } + } + + return v, nil +} + +// ContextSource mirrors supctx.Source for config unmarshal. Defined here so +// the config package stays free of internal dependencies. +type ContextSource struct { + Name string `mapstructure:"name"` + Type string `mapstructure:"type"` // "path" | "repo" | "url" | "inline" + Location string `mapstructure:"location"` // path, git URL, HTTP URL, or literal text + Priority int `mapstructure:"priority"` + Compress bool `mapstructure:"compress"` + Phases []string `mapstructure:"phases"` // empty = all phases + Include []string `mapstructure:"include"` + Exclude []string `mapstructure:"exclude"` +} + +// parseContextSourceKV parses the CLI form "name=X,type=Y,location=Z,priority=N". +// Only scalar fields are supported on the CLI — use the config file for +// include/exclude globs and phase lists. 
+func parseContextSourceKV(raw string) (ContextSource, error) { + var cs ContextSource + for _, pair := range strings.Split(raw, ",") { + k, v, ok := strings.Cut(pair, "=") + if !ok { + return cs, fmt.Errorf("expected key=value, got %q", pair) + } + k = strings.TrimSpace(k) + v = strings.TrimSpace(v) + switch k { + case "name": + cs.Name = v + case "type": + cs.Type = v + case "location": + cs.Location = v + case "priority": + p, err := strconv.Atoi(v) + if err != nil { + return cs, fmt.Errorf("priority must be an integer: %w", err) + } + cs.Priority = p + case "compress": + b, err := strconv.ParseBool(v) + if err != nil { + return cs, fmt.Errorf("compress must be a boolean: %w", err) + } + cs.Compress = b + default: + return cs, fmt.Errorf("unknown key %q", k) + } + } + if cs.Type == "" || cs.Location == "" { + return cs, fmt.Errorf("type and location are required") + } + return cs, nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..894f0e2 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,443 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/spf13/viper" +) + +func TestSetDefaults(t *testing.T) { + v := viper.New() + SetDefaults(v) + + tests := []struct { + key string + expected any + }{ + {"verbose", false}, + {"dry-run", false}, + {"include-tests", false}, + {"include-docs", false}, + {"compress", false}, + {"fail-on-severity", float64(0)}, + {"max-cost", float64(25)}, + } + + for _, tt := range tests { + got := v.Get(tt.key) + if got != tt.expected { + t.Errorf("default for %q: got %v (%T), want %v (%T)", tt.key, got, got, tt.expected, tt.expected) + } + } + + if got := v.GetStringSlice("custom-headers"); len(got) != 0 { + t.Errorf("default for %q: got %v, want empty slice", "custom-headers", got) + } + if got := v.GetString("model-params-json"); got != "" { + t.Errorf("default for %q: got %q, want empty", "model-params-json", got) + } + if got := 
v.GetString("phase-output-dir"); got != "" { + t.Errorf("default for %q: got %q, want empty", "phase-output-dir", got) + } +} + +func TestLoad_FromDefaults(t *testing.T) { + v := viper.New() + SetDefaults(v) + + cfg, err := Load(v) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if cfg.Verbose { + t.Error("Verbose should default to false") + } + if cfg.DryRun { + t.Error("DryRun should default to false") + } +} + +func TestBindEnvVars_DatabricksHost(t *testing.T) { + v := viper.New() + SetDefaults(v) + BindEnvVars(v) + + t.Setenv("DATABRICKS_HOST", "https://test.databricks.com") + + cfg, err := Load(v) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if cfg.DatabricksHost != "https://test.databricks.com" { + t.Errorf("DatabricksHost: got %q, want %q", cfg.DatabricksHost, "https://test.databricks.com") + } +} + +func TestBindEnvVars_DatabricksToken(t *testing.T) { + v := viper.New() + SetDefaults(v) + BindEnvVars(v) + + t.Setenv("DATABRICKS_TOKEN", "test-token-123") + + cfg, err := Load(v) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if cfg.DatabricksToken != "test-token-123" { + t.Errorf("DatabricksToken: got %q, want %q", cfg.DatabricksToken, "test-token-123") + } +} + +func TestBindEnvVars_DatabricksEndpoint(t *testing.T) { + v := viper.New() + SetDefaults(v) + BindEnvVars(v) + + t.Setenv("DATABRICKS_ENDPOINT", "my-endpoint/invocations") + + cfg, err := Load(v) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if cfg.DatabricksEndpoint != "my-endpoint/invocations" { + t.Errorf("DatabricksEndpoint: got %q, want %q", cfg.DatabricksEndpoint, "my-endpoint/invocations") + } +} + +func TestBindEnvVars_ModelParamsJSON(t *testing.T) { + v := viper.New() + SetDefaults(v) + BindEnvVars(v) + + t.Setenv("CODECRUCIBLE_MODEL_PARAMS", `{"thinking":{"type":"enabled","budget_tokens":2048}}`) + + cfg, err := Load(v) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + thinkingRaw, ok := cfg.ModelParams["thinking"] + if 
!ok { + t.Fatalf("expected thinking in model params, got %+v", cfg.ModelParams) + } + thinking, ok := thinkingRaw.(map[string]any) + if !ok { + t.Fatalf("expected thinking object, got %T", thinkingRaw) + } + if thinking["type"] != "enabled" { + t.Errorf("thinking.type: got %v, want enabled", thinking["type"]) + } +} + +func TestLoad_ModelParamsJSON_MergesWithConfigObject(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.yaml") + cfgBody := `model-params: + thinking: + type: enabled + budget_tokens: 1024 + extra: + keep: true +` + if err := os.WriteFile(cfgPath, []byte(cfgBody), 0644); err != nil { + t.Fatalf("writing config file: %v", err) + } + + v, err := SetupViper(cfgPath) + if err != nil { + t.Fatalf("SetupViper failed: %v", err) + } + + v.Set("model-params-json", `{"thinking":{"budget_tokens":4096,"mode":"adaptive"}}`) + + cfg, err := Load(v) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + thinkingRaw := cfg.ModelParams["thinking"] + thinking, ok := thinkingRaw.(map[string]any) + if !ok { + t.Fatalf("expected thinking object, got %T", thinkingRaw) + } + if thinking["type"] != "enabled" { + t.Errorf("thinking.type: got %v, want enabled", thinking["type"]) + } + if thinking["budget_tokens"] != float64(4096) { + t.Errorf("thinking.budget_tokens: got %v, want 4096", thinking["budget_tokens"]) + } + if thinking["mode"] != "adaptive" { + t.Errorf("thinking.mode: got %v, want adaptive", thinking["mode"]) + } + + extraRaw := cfg.ModelParams["extra"] + extra, ok := extraRaw.(map[string]any) + if !ok { + t.Fatalf("expected extra object, got %T", extraRaw) + } + if extra["keep"] != true { + t.Errorf("extra.keep: got %v, want true", extra["keep"]) + } +} + +func TestPriorityChain_FlagOverridesEnvOverridesFile(t *testing.T) { + // Create a temp config file with concurrency=10. 
+ dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.yaml") + if err := os.WriteFile(cfgPath, []byte("concurrency: 10\nverbose: true\n"), 0644); err != nil { + t.Fatalf("writing config file: %v", err) + } + + // Step 1: Config file provides concurrency=10. + v := viper.New() + SetDefaults(v) + BindEnvVars(v) + v.SetConfigFile(cfgPath) + if err := v.ReadInConfig(); err != nil { + t.Fatalf("ReadInConfig: %v", err) + } + + cfg, err := Load(v) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if cfg.Concurrency != 10 { + t.Errorf("Step 1 (file): Concurrency got %d, want %d", cfg.Concurrency, 10) + } + + // Step 2: Env var overrides config file. + t.Setenv("CONCURRENCY", "5") + cfg, err = Load(v) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if cfg.Concurrency != 5 { + t.Errorf("Step 2 (env): Concurrency got %d, want %d", cfg.Concurrency, 5) + } + + // Step 3: Flag (via Set) overrides env. + v.Set("concurrency", 7) + cfg, err = Load(v) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if cfg.Concurrency != 7 { + t.Errorf("Step 3 (flag): Concurrency got %d, want %d", cfg.Concurrency, 7) + } +} + +func TestSetupViper_NoConfigFile(t *testing.T) { + // SetupViper should not error when no config file exists. + v, err := SetupViper("") + if err != nil { + t.Fatalf("SetupViper failed: %v", err) + } + if v == nil { + t.Fatal("expected non-nil viper instance") + } + + // Defaults should be set. 
+ if v.GetBool("verbose") != false { + t.Error("default verbose should be false") + } +} + +func TestSetupViper_WithConfigFile(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "test-config.yaml") + if err := os.WriteFile(cfgPath, []byte("verbose: true\nmodel: test-model\n"), 0644); err != nil { + t.Fatalf("writing config file: %v", err) + } + + v, err := SetupViper(cfgPath) + if err != nil { + t.Fatalf("SetupViper failed: %v", err) + } + + if !v.GetBool("verbose") { + t.Error("expected verbose=true from config file") + } + if v.GetString("model") != "test-model" { + t.Errorf("model: got %q, want %q", v.GetString("model"), "test-model") + } +} + +func TestPriorityChain_ConfigFileOverridesDefaults(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.yaml") + if err := os.WriteFile(cfgPath, []byte("dry-run: true\n"), 0644); err != nil { + t.Fatalf("writing config file: %v", err) + } + + v, err := SetupViper(cfgPath) + if err != nil { + t.Fatalf("SetupViper failed: %v", err) + } + + cfg, err := Load(v) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if !cfg.DryRun { + t.Error("expected DryRun=true from config file (overriding default false)") + } +} + +func TestSetupViper_MalformedConfigFileReturnsError(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "bad-config.yaml") + // Write invalid YAML to trigger a parse error. 
+ if err := os.WriteFile(cfgPath, []byte(":\n bad: [yaml\n unclosed"), 0644); err != nil { + t.Fatalf("writing config file: %v", err) + } + + _, err := SetupViper(cfgPath) + if err == nil { + t.Fatal("expected error for malformed config file, got nil") + } +} + +func TestParseContextSourceKV(t *testing.T) { + cs, err := parseContextSourceKV("name=spec,type=path,location=/tmp/api.yaml,priority=100,compress=true") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cs.Name != "spec" || cs.Type != "path" || cs.Location != "/tmp/api.yaml" || cs.Priority != 100 || !cs.Compress { + t.Errorf("parsed struct wrong: %+v", cs) + } + + if _, err := parseContextSourceKV("name=x,type=path"); err == nil { + t.Error("expected error for missing location") + } + if _, err := parseContextSourceKV("name=x,location=y"); err == nil { + t.Error("expected error for missing type") + } + if _, err := parseContextSourceKV("name=x,type=path,location=y,priority=abc"); err == nil { + t.Error("expected error for non-integer priority") + } + if _, err := parseContextSourceKV("name=x,type=path,location=y,unknown=z"); err == nil { + t.Error("expected error for unknown key") + } +} + +func TestLoad_RegistersModelsFromConfigFile(t *testing.T) { + // Isolate the global registry from other tests. + saved := make([]ModelConfig, len(defaultModels)) + copy(saved, defaultModels) + defer func() { defaultModels = saved }() + + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.yaml") + cfgBody := ` +models: + # Extend: a local Ollama model. + - name: llama-4-405b + provider: ollama + input_price_per_million: 0.0 + output_price_per_million: 0.0 + context_limit: 131072 + max_output_tokens: 8192 + tokenizer_encoding: cl100k_base + supports_structured_output: false + + # Override: retune a built-in. 
+ - name: claude-sonnet-4-6 + provider: anthropic + endpoint: claude-sonnet-4-6/invocations + input_price_per_million: 2.0 + output_price_per_million: 10.0 + context_limit: 200000 + max_output_tokens: 16384 + tokenizer_encoding: claude + supports_structured_output: true +` + if err := os.WriteFile(cfgPath, []byte(cfgBody), 0644); err != nil { + t.Fatalf("writing config: %v", err) + } + + v, err := SetupViper(cfgPath) + if err != nil { + t.Fatalf("SetupViper: %v", err) + } + cfg, err := Load(v) + if err != nil { + t.Fatalf("Load: %v", err) + } + + if len(cfg.Models) != 2 { + t.Fatalf("cfg.Models: got %d entries, want 2", len(cfg.Models)) + } + + // Extension round-trip: the new model is in the registry, and empty + // Endpoint was defaulted from Name. + llama, ok := LookupModel("llama-4-405b") + if !ok { + t.Fatal("llama-4-405b not found in registry after Load") + } + if llama.Endpoint != "llama-4-405b/invocations" { + t.Errorf("llama Endpoint: got %q, want default", llama.Endpoint) + } + if llama.Provider != "ollama" { + t.Errorf("llama Provider: got %q, want ollama", llama.Provider) + } + + // Override round-trip: built-in pricing was replaced. 
+ sonnet, _ := LookupModel("claude-sonnet-4-6") + if sonnet.InputPricePerM != 2.0 { + t.Errorf("claude-sonnet-4-6 override InputPricePerM: got %v, want 2.0", sonnet.InputPricePerM) + } + if sonnet.OutputPricePerM != 10.0 { + t.Errorf("claude-sonnet-4-6 override OutputPricePerM: got %v, want 10.0", sonnet.OutputPricePerM) + } +} + +func TestLoad_EmptyModelNameErrors(t *testing.T) { + saved := make([]ModelConfig, len(defaultModels)) + copy(saved, defaultModels) + defer func() { defaultModels = saved }() + + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.yaml") + cfgBody := ` +models: + - provider: openai-compat + context_limit: 128000 +` + if err := os.WriteFile(cfgPath, []byte(cfgBody), 0644); err != nil { + t.Fatalf("writing config: %v", err) + } + + v, err := SetupViper(cfgPath) + if err != nil { + t.Fatalf("SetupViper: %v", err) + } + if _, err := Load(v); err == nil { + t.Fatal("expected Load to error for models entry without name") + } +} + +func TestLoad_ContextSourcesRaw(t *testing.T) { + v := viper.New() + SetDefaults(v) + v.Set("context-sources-raw", []string{ + "name=a,type=inline,location=hello", + "name=b,type=path,location=/tmp,priority=50", + }) + cfg, err := Load(v) + if err != nil { + t.Fatalf("Load error: %v", err) + } + if len(cfg.ContextSources) != 2 { + t.Fatalf("expected 2 sources, got %d", len(cfg.ContextSources)) + } + if cfg.ContextSources[0].Name != "a" || cfg.ContextSources[1].Priority != 50 { + t.Errorf("sources parsed wrong: %+v", cfg.ContextSources) + } +} diff --git a/internal/config/models.go b/internal/config/models.go new file mode 100644 index 0000000..d5086c9 --- /dev/null +++ b/internal/config/models.go @@ -0,0 +1,264 @@ +package config + +import ( + "fmt" + "strings" + "sync" +) + +// ModelConfig describes a model's capabilities, pricing, and endpoint. 
//
// Both yaml and mapstructure tags are set so the same struct literal in
// defaultModels can also be populated by Viper from a user's config file
// (Viper uses mapstructure for Unmarshal).
type ModelConfig struct {
	Name                     string  `yaml:"name" mapstructure:"name"`
	Provider                 string  `yaml:"provider" mapstructure:"provider"` // "databricks", "anthropic", "openai", "google"
	Endpoint                 string  `yaml:"endpoint" mapstructure:"endpoint"`
	InputPricePerM           float64 `yaml:"input_price_per_million" mapstructure:"input_price_per_million"`
	OutputPricePerM          float64 `yaml:"output_price_per_million" mapstructure:"output_price_per_million"`
	ContextLimit             int     `yaml:"context_limit" mapstructure:"context_limit"`
	MaxOutputTokens          int     `yaml:"max_output_tokens" mapstructure:"max_output_tokens"`
	Temperature              float64 `yaml:"temperature" mapstructure:"temperature"`
	Encoding                 string  `yaml:"tokenizer_encoding" mapstructure:"tokenizer_encoding"`
	SupportsStructuredOutput bool    `yaml:"supports_structured_output" mapstructure:"supports_structured_output"`
}

// modelsMu protects defaultModels for concurrent access.
var modelsMu sync.RWMutex

// defaultModels is the model registry. It starts out holding the built-in
// entries below and is extended or overridden in place by RegisterModel.
// The first entry is what DefaultModel returns.
var defaultModels = []ModelConfig{
	{
		Name:                     "claude-sonnet-4-6",
		Provider:                 "anthropic",
		Endpoint:                 "claude-sonnet-4-6/invocations",
		InputPricePerM:           3.0,
		OutputPricePerM:          15.0,
		ContextLimit:             200000,
		MaxOutputTokens:          16384,
		Temperature:              0.0,
		Encoding:                 "claude",
		SupportsStructuredOutput: true,
	},
	{
		Name:                     "claude-opus-4-6",
		Provider:                 "anthropic",
		Endpoint:                 "claude-opus-4-6/invocations",
		InputPricePerM:           5.0,
		OutputPricePerM:          25.0,
		ContextLimit:             200000,
		MaxOutputTokens:          32768,
		Temperature:              0.0,
		Encoding:                 "claude",
		SupportsStructuredOutput: true,
	},
	{
		Name:                     "claude-opus-4-7",
		Provider:                 "anthropic",
		Endpoint:                 "claude-opus-4-7/invocations",
		InputPricePerM:           5.0,
		OutputPricePerM:          25.0,
		ContextLimit:             1000000,
		MaxOutputTokens:          128000,
		Temperature:              0.0,
		Encoding:                 "claude",
		SupportsStructuredOutput: true,
	},
	{
		Name:                     "gpt-5.2",
		Provider:                 "openai",
		Endpoint:                 "gpt-5.2/invocations",
		InputPricePerM:           1.75,
		OutputPricePerM:          14.0,
		ContextLimit:             400000,
		MaxOutputTokens:          16384,
		Temperature:              0.0,
		Encoding:                 "o200k_base",
		SupportsStructuredOutput: true,
	},
	{
		Name:                     "gpt-5.4",
		Provider:                 "openai",
		Endpoint:                 "gpt-5.4/invocations",
		InputPricePerM:           2.50,
		OutputPricePerM:          15.0,
		ContextLimit:             1000000,
		MaxOutputTokens:          128000,
		Temperature:              0.0,
		Encoding:                 "o200k_base",
		SupportsStructuredOutput: true,
	},
	{
		Name:                     "gpt-5.5",
		Provider:                 "openai",
		Endpoint:                 "gpt-5.5/invocations",
		InputPricePerM:           5.00,
		OutputPricePerM:          30.0,
		ContextLimit:             1000000,
		MaxOutputTokens:          128000,
		Temperature:              1.0,
		Encoding:                 "o200k_base",
		SupportsStructuredOutput: true,
	},
	{
		Name:                     "gpt-5.4-mini",
		Provider:                 "openai",
		Endpoint:                 "gpt-5.4-mini/invocations",
		InputPricePerM:           0.75,
		OutputPricePerM:          4.50,
		ContextLimit:             400000,
		MaxOutputTokens:          128000,
		Temperature:              0.0,
		Encoding:                 "o200k_base",
		SupportsStructuredOutput: true,
	},
	{
		Name:                     "gpt-5.4-nano",
		Provider:                 "openai",
		Endpoint:                 "gpt-5.4-nano/invocations",
		InputPricePerM:           0.20,
		OutputPricePerM:          1.25,
		ContextLimit:             400000,
		MaxOutputTokens:          128000,
		Temperature:              0.0,
		Encoding:                 "o200k_base",
		SupportsStructuredOutput: true,
	},
	{
		Name:            "gemini-3-pro",
		Provider:        "google",
		Endpoint:        "gemini-3-pro/invocations",
		InputPricePerM:  2.0,
		OutputPricePerM: 12.0,
		ContextLimit:    1048576,
		MaxOutputTokens: 65536,
		Temperature:     0.0,
		Encoding:        "cl100k_base",
		// Google's OpenAI-compat endpoint accepts response_format json_schema.
		SupportsStructuredOutput: true,
	},
	{
		Name:                     "gemini-3-flash",
		Provider:                 "google",
		Endpoint:                 "gemini-3-flash/invocations",
		InputPricePerM:           0.15,
		OutputPricePerM:          0.60,
		ContextLimit:             1048576,
		MaxOutputTokens:          65536,
		Temperature:              0.0,
		Encoding:                 "cl100k_base",
		SupportsStructuredOutput: true,
	},
}

// RegisterModel adds or replaces a model in the registry. If a model with the
// same name already exists (case-insensitive), it is replaced in place, so its
// position in the registry is preserved. This allows callers to extend the
// built-in registry at startup without forking the code.
//
// RegisterModel performs no validation of m; see RegisterUserModels for the
// validating entry point.
func RegisterModel(m ModelConfig) {
	modelsMu.Lock()
	defer modelsMu.Unlock()
	lower := strings.ToLower(m.Name)
	for i, existing := range defaultModels {
		if strings.ToLower(existing.Name) == lower {
			defaultModels[i] = m
			return
		}
	}
	defaultModels = append(defaultModels, m)
}

// RegisterUserModels registers user-defined model configs into the registry.
// Each entry must have a non-empty Name. Entries with a name that matches a
// built-in are replaced (case-insensitive), so this is also how users override
// built-in pricing / context limits without forking. Empty Endpoint defaults
// to "<Name>/invocations" to match the built-in convention (Databricks serving
// path; other providers ignore it).
//
// Registration is all-or-nothing: every entry is validated before the
// registry is touched, so an error means nothing was registered. The returned
// error names the index of the first offending entry.
func RegisterUserModels(models []ModelConfig) error {
	// Validate everything up front so a bad entry cannot leave the registry
	// half-updated with the entries that preceded it.
	for i, m := range models {
		if strings.TrimSpace(m.Name) == "" {
			return fmt.Errorf("models[%d]: name is required", i)
		}
	}
	for _, m := range models {
		// m is a per-iteration copy, so defaulting Endpoint here does not
		// modify the caller's slice.
		if m.Endpoint == "" {
			m.Endpoint = m.Name + "/invocations"
		}
		RegisterModel(m)
	}
	return nil
}

// DefaultModelRegistry returns a copy of the current registry: the built-in
// model configs plus anything added or overridden through RegisterModel.
// Mutating the returned slice does not affect the registry.
func DefaultModelRegistry() []ModelConfig {
	modelsMu.RLock()
	defer modelsMu.RUnlock()
	out := make([]ModelConfig, len(defaultModels))
	copy(out, defaultModels)
	return out
}

// LookupModel finds a model by name using case-insensitive partial matching.
// An exact (case-insensitive) name match always wins. Otherwise it matches if
// either the query contains a known model name, or a known model name
// contains the query; among several candidates the longest model name wins,
// with ties going to the entry declared first. This handles prefixed names
// like "databricks-claude-opus-4-6" matching "claude-opus-4-6".
//
// An empty or whitespace-only name never matches. The second result reports
// whether a model was found; when it is false the returned ModelConfig is the
// zero value.
func LookupModel(name string) (ModelConfig, bool) {
	// strings.Contains(s, "") is true for every s, so without this guard an
	// empty query would "match" every entry and return the longest name.
	if strings.TrimSpace(name) == "" {
		return ModelConfig{}, false
	}

	modelsMu.RLock()
	defer modelsMu.RUnlock()
	lower := strings.ToLower(name)
	// Exact match first.
	for _, m := range defaultModels {
		if strings.ToLower(m.Name) == lower {
			return m, true
		}
	}
	// Partial match: query contains model name, or model name contains query.
	// Prefer the longest model name so a more specific entry beats a shorter
	// one that also matches.
	var best ModelConfig
	bestLen := 0
	found := false
	for _, m := range defaultModels {
		mLower := strings.ToLower(m.Name)
		if strings.Contains(lower, mLower) || strings.Contains(mLower, lower) {
			// Strict ">" keeps the first-declared entry on equal length.
			if len(m.Name) > bestLen {
				best = m
				bestLen = len(m.Name)
				found = true
			}
		}
	}
	return best, found
}

// LookupModelByEndpoint finds a model by its exact (case-sensitive) endpoint.
// The first registry entry with that endpoint is returned.
func LookupModelByEndpoint(endpoint string) (ModelConfig, bool) {
	modelsMu.RLock()
	defer modelsMu.RUnlock()
	for _, m := range defaultModels {
		if m.Endpoint == endpoint {
			return m, true
		}
	}
	return ModelConfig{}, false
}

// DefaultModel returns the default model: the first registry entry
// (claude-sonnet-4-6 unless it has been overridden).
+func DefaultModel() ModelConfig { + modelsMu.RLock() + defer modelsMu.RUnlock() + return defaultModels[0] +} + +// UnknownModelDefaults returns conservative defaults for a model not in the +// registry. Uses 128K context and 8192 output tokens — reasonable for most +// frontier models released since 2024. Structured output is disabled since +// we can't know if the model supports it. +func UnknownModelDefaults(name string) ModelConfig { + return ModelConfig{ + Name: name, + Endpoint: name + "/invocations", + ContextLimit: 128000, + MaxOutputTokens: 8192, + Temperature: 0.0, + Encoding: "cl100k_base", + SupportsStructuredOutput: false, + } +} diff --git a/internal/config/models_test.go b/internal/config/models_test.go new file mode 100644 index 0000000..4a3e56d --- /dev/null +++ b/internal/config/models_test.go @@ -0,0 +1,410 @@ +package config + +import "testing" + +func TestDefaultModelRegistry_ContainsExpectedModels(t *testing.T) { + registry := DefaultModelRegistry() + + expected := []struct { + name string + contextLimit int + }{ + {"claude-sonnet-4-6", 200000}, + {"claude-opus-4-6", 200000}, + {"claude-opus-4-7", 1000000}, + {"gpt-5.2", 400000}, + {"gpt-5.4", 1000000}, + {"gpt-5.5", 1000000}, + {"gpt-5.4-mini", 400000}, + {"gpt-5.4-nano", 400000}, + {"gemini-3-pro", 1048576}, + {"gemini-3-flash", 1048576}, + } + + if len(registry) != len(expected) { + t.Fatalf("registry length: got %d, want %d", len(registry), len(expected)) + } + + for i, exp := range expected { + if registry[i].Name != exp.name { + t.Errorf("registry[%d].Name: got %q, want %q", i, registry[i].Name, exp.name) + } + if registry[i].ContextLimit != exp.contextLimit { + t.Errorf("registry[%d].ContextLimit: got %d, want %d", i, registry[i].ContextLimit, exp.contextLimit) + } + } +} + +func TestDefaultModelRegistry_ReturnsCopy(t *testing.T) { + r1 := DefaultModelRegistry() + r1[0].Name = "mutated" + + r2 := DefaultModelRegistry() + if r2[0].Name == "mutated" { + t.Error("DefaultModelRegistry 
should return a copy, but mutation leaked") + } +} + +func TestLookupModel_ExactName(t *testing.T) { + tests := []struct { + query string + wantName string + wantEndpoint string + wantFound bool + }{ + {"claude-sonnet-4-6", "claude-sonnet-4-6", "claude-sonnet-4-6/invocations", true}, + {"claude-opus-4-6", "claude-opus-4-6", "claude-opus-4-6/invocations", true}, + {"gpt-5.2", "gpt-5.2", "gpt-5.2/invocations", true}, + {"gpt-5.5", "gpt-5.5", "gpt-5.5/invocations", true}, + {"gemini-3-pro", "gemini-3-pro", "gemini-3-pro/invocations", true}, + {"nonexistent-model", "", "", false}, + } + + for _, tt := range tests { + t.Run(tt.query, func(t *testing.T) { + m, found := LookupModel(tt.query) + if found != tt.wantFound { + t.Fatalf("found: got %v, want %v", found, tt.wantFound) + } + if !found { + return + } + if m.Name != tt.wantName { + t.Errorf("Name: got %q, want %q", m.Name, tt.wantName) + } + if m.Endpoint != tt.wantEndpoint { + t.Errorf("Endpoint: got %q, want %q", m.Endpoint, tt.wantEndpoint) + } + }) + } +} + +func TestLookupModel_CaseInsensitive(t *testing.T) { + tests := []struct { + query string + wantName string + }{ + {"Claude-Sonnet-4-6", "claude-sonnet-4-6"}, + {"CLAUDE-SONNET-4-6", "claude-sonnet-4-6"}, + {"GPT-5.2", "gpt-5.2"}, + {"GPT-5.5", "gpt-5.5"}, + {"Gemini-3-Pro", "gemini-3-pro"}, + } + + for _, tt := range tests { + t.Run(tt.query, func(t *testing.T) { + m, found := LookupModel(tt.query) + if !found { + t.Fatalf("expected to find model for query %q", tt.query) + } + if m.Name != tt.wantName { + t.Errorf("Name: got %q, want %q", m.Name, tt.wantName) + } + }) + } +} + +func TestLookupModel_PartialMatch(t *testing.T) { + tests := []struct { + query string + wantName string + }{ + {"sonnet", "claude-sonnet-4-6"}, + // Two opus entries: longest-match prefers the first declared since + // names are the same length; claude-opus-4-6 comes first in the registry. 
+ {"opus", "claude-opus-4-6"}, + // Multiple gpt entries: longest-match picks gpt-5.4-mini (first of the + // 12-char entries in declaration order). Exact GPT queries still hit + // the exact-match fast path. + {"gpt", "gpt-5.4-mini"}, + // Two gemini entries: longest-match picks flash (14 chars vs 12). + // Exact queries (gemini-3-pro) still hit the exact-match fast path. + {"gemini", "gemini-3-flash"}, + } + + for _, tt := range tests { + t.Run(tt.query, func(t *testing.T) { + m, found := LookupModel(tt.query) + if !found { + t.Fatalf("expected to find model for query %q", tt.query) + } + if m.Name != tt.wantName { + t.Errorf("Name: got %q, want %q", m.Name, tt.wantName) + } + }) + } +} + +func TestLookupModelByEndpoint(t *testing.T) { + tests := []struct { + endpoint string + wantName string + wantFound bool + }{ + {"claude-sonnet-4-6/invocations", "claude-sonnet-4-6", true}, + {"claude-opus-4-6/invocations", "claude-opus-4-6", true}, + {"gpt-5.2/invocations", "gpt-5.2", true}, + {"gpt-5.5/invocations", "gpt-5.5", true}, + {"gemini-3-pro/invocations", "gemini-3-pro", true}, + {"nonexistent/invocations", "", false}, + } + + for _, tt := range tests { + t.Run(tt.endpoint, func(t *testing.T) { + m, found := LookupModelByEndpoint(tt.endpoint) + if found != tt.wantFound { + t.Fatalf("found: got %v, want %v", found, tt.wantFound) + } + if !found { + return + } + if m.Name != tt.wantName { + t.Errorf("Name: got %q, want %q", m.Name, tt.wantName) + } + }) + } +} + +func TestLookupModel_GPT55Capabilities(t *testing.T) { + m, found := LookupModel("gpt-5.5") + if !found { + t.Fatal("gpt-5.5 not found") + } + + if m.Provider != "openai" { + t.Errorf("Provider: got %q, want openai", m.Provider) + } + if m.InputPricePerM != 5.0 { + t.Errorf("InputPricePerM: got %f, want 5.0", m.InputPricePerM) + } + if m.OutputPricePerM != 30.0 { + t.Errorf("OutputPricePerM: got %f, want 30.0", m.OutputPricePerM) + } + if m.ContextLimit != 1000000 { + t.Errorf("ContextLimit: got %d, want 
1000000", m.ContextLimit) + } + if m.MaxOutputTokens != 128000 { + t.Errorf("MaxOutputTokens: got %d, want 128000", m.MaxOutputTokens) + } + if m.Temperature != 1.0 { + t.Errorf("Temperature: got %f, want 1.0", m.Temperature) + } + if m.Encoding != "o200k_base" { + t.Errorf("Encoding: got %q, want o200k_base", m.Encoding) + } + if !m.SupportsStructuredOutput { + t.Error("SupportsStructuredOutput: got false, want true") + } +} + +func TestDefaultModel_IsClaudeSonnet46(t *testing.T) { + m := DefaultModel() + if m.Name != "claude-sonnet-4-6" { + t.Errorf("Name: got %q, want %q", m.Name, "claude-sonnet-4-6") + } + if m.Endpoint != "claude-sonnet-4-6/invocations" { + t.Errorf("Endpoint: got %q, want %q", m.Endpoint, "claude-sonnet-4-6/invocations") + } +} + +func TestUnknownModelDefaults(t *testing.T) { + m := UnknownModelDefaults("my-custom-model") + + if m.Name != "my-custom-model" { + t.Errorf("Name: got %q, want %q", m.Name, "my-custom-model") + } + if m.Endpoint != "my-custom-model/invocations" { + t.Errorf("Endpoint: got %q, want %q", m.Endpoint, "my-custom-model/invocations") + } + if m.ContextLimit != 128000 { + t.Errorf("ContextLimit: got %d, want %d", m.ContextLimit, 128000) + } + if m.MaxOutputTokens != 8192 { + t.Errorf("MaxOutputTokens: got %d, want %d", m.MaxOutputTokens, 8192) + } + if m.Temperature != 0.0 { + t.Errorf("Temperature: got %f, want %f", m.Temperature, 0.0) + } + if m.Encoding != "cl100k_base" { + t.Errorf("Encoding: got %q, want %q", m.Encoding, "cl100k_base") + } + if m.SupportsStructuredOutput { + t.Error("SupportsStructuredOutput: got true, want false") + } +} + +func TestRegisterModel_AddsNew(t *testing.T) { + // Save and restore the registry to avoid polluting other tests. 
+ saved := make([]ModelConfig, len(defaultModels)) + copy(saved, defaultModels) + defer func() { defaultModels = saved }() + + RegisterModel(ModelConfig{ + Name: "my-custom-model", + ContextLimit: 256000, + MaxOutputTokens: 16384, + }) + + m, found := LookupModel("my-custom-model") + if !found { + t.Fatal("registered model not found") + } + if m.ContextLimit != 256000 { + t.Errorf("ContextLimit: got %d, want %d", m.ContextLimit, 256000) + } +} + +func TestRegisterModel_ReplacesExisting(t *testing.T) { + saved := make([]ModelConfig, len(defaultModels)) + copy(saved, defaultModels) + defer func() { defaultModels = saved }() + + original, _ := LookupModel("claude-sonnet-4-6") + if original.MaxOutputTokens != 16384 { + t.Fatalf("precondition: expected 16384, got %d", original.MaxOutputTokens) + } + + RegisterModel(ModelConfig{ + Name: "claude-sonnet-4-6", + ContextLimit: 200000, + MaxOutputTokens: 32768, + }) + + updated, found := LookupModel("claude-sonnet-4-6") + if !found { + t.Fatal("replaced model not found") + } + if updated.MaxOutputTokens != 32768 { + t.Errorf("MaxOutputTokens: got %d, want %d", updated.MaxOutputTokens, 32768) + } + + // Registry size should not have grown. + if len(defaultModels) != len(saved) { + t.Errorf("registry grew from %d to %d after replace", len(saved), len(defaultModels)) + } +} + +func TestRegisterUserModels_AddsAndOverrides(t *testing.T) { + saved := make([]ModelConfig, len(defaultModels)) + copy(saved, defaultModels) + defer func() { defaultModels = saved }() + + err := RegisterUserModels([]ModelConfig{ + // New entry — extends the registry. + { + Name: "acme-llama-70b", + Provider: "openai-compat", + InputPricePerM: 0.5, + OutputPricePerM: 1.5, + ContextLimit: 131072, + MaxOutputTokens: 8192, + Encoding: "cl100k_base", + }, + // Override of a built-in — same Name, different pricing. 
+ { + Name: "claude-sonnet-4-6", + Provider: "anthropic", + InputPricePerM: 1.5, // halved vs built-in + OutputPricePerM: 7.5, + ContextLimit: 200000, + MaxOutputTokens: 16384, + Encoding: "claude", + }, + }) + if err != nil { + t.Fatalf("RegisterUserModels: %v", err) + } + + added, ok := LookupModel("acme-llama-70b") + if !ok { + t.Fatal("user-defined model not registered") + } + if added.Endpoint != "acme-llama-70b/invocations" { + t.Errorf("empty Endpoint should default to /invocations, got %q", added.Endpoint) + } + if added.Provider != "openai-compat" { + t.Errorf("Provider: got %q, want %q", added.Provider, "openai-compat") + } + + overridden, _ := LookupModel("claude-sonnet-4-6") + if overridden.InputPricePerM != 1.5 { + t.Errorf("user override did not take effect: InputPricePerM got %v, want 1.5", overridden.InputPricePerM) + } + + // Size grew by exactly one (the new one). The override replaced in place. + if len(defaultModels) != len(saved)+1 { + t.Errorf("registry grew by %d, want 1", len(defaultModels)-len(saved)) + } +} + +func TestRegisterUserModels_PreservesExplicitEndpoint(t *testing.T) { + saved := make([]ModelConfig, len(defaultModels)) + copy(saved, defaultModels) + defer func() { defaultModels = saved }() + + err := RegisterUserModels([]ModelConfig{{ + Name: "azure-gpt-5", + Provider: "openai-compat", + Endpoint: "deployments/my-azure-deploy/chat/completions", + }}) + if err != nil { + t.Fatalf("RegisterUserModels: %v", err) + } + + m, _ := LookupModel("azure-gpt-5") + if m.Endpoint != "deployments/my-azure-deploy/chat/completions" { + t.Errorf("Endpoint: got %q, want the explicit value", m.Endpoint) + } +} + +func TestRegisterUserModels_RejectsEmptyName(t *testing.T) { + saved := make([]ModelConfig, len(defaultModels)) + copy(saved, defaultModels) + defer func() { defaultModels = saved }() + + err := RegisterUserModels([]ModelConfig{ + {Name: "ok-model", ContextLimit: 100000}, + {Name: " ", ContextLimit: 100000}, // whitespace-only name + }) + 
if err == nil { + t.Fatal("expected error for empty name, got nil") + } +} + +func TestDefaultModelRegistry_FieldValues(t *testing.T) { + tests := []struct { + name string + maxOutput int + encoding string + structured bool + temperature float64 + }{ + {"claude-sonnet-4-6", 16384, "claude", true, 0.0}, + {"claude-opus-4-6", 32768, "claude", true, 0.0}, + {"gpt-5.2", 16384, "o200k_base", true, 0.0}, + {"gpt-5.5", 128000, "o200k_base", true, 1.0}, + {"gemini-3-pro", 65536, "cl100k_base", true, 0.0}, + {"gemini-3-flash", 65536, "cl100k_base", true, 0.0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m, found := LookupModel(tt.name) + if !found { + t.Fatalf("model %q not found", tt.name) + } + if m.MaxOutputTokens != tt.maxOutput { + t.Errorf("MaxOutputTokens: got %d, want %d", m.MaxOutputTokens, tt.maxOutput) + } + if m.Encoding != tt.encoding { + t.Errorf("Encoding: got %q, want %q", m.Encoding, tt.encoding) + } + if m.SupportsStructuredOutput != tt.structured { + t.Errorf("SupportsStructuredOutput: got %v, want %v", m.SupportsStructuredOutput, tt.structured) + } + if m.Temperature != tt.temperature { + t.Errorf("Temperature: got %f, want %f", m.Temperature, tt.temperature) + } + }) + } +} diff --git a/internal/config/phase.go b/internal/config/phase.go new file mode 100644 index 0000000..2cb6070 --- /dev/null +++ b/internal/config/phase.go @@ -0,0 +1,357 @@ +package config + +import ( + "encoding/json" + "fmt" + "log/slog" + "strings" +) + +// PhaseConfig carries everything one pipeline phase needs to construct its +// own LLM client and ChatRequest, independent of the other phases. +// +// The zero value inherits: any field left at its zero value on the +// feature-detection or audit phase is filled from the analysis phase by +// ResolvePhases. For ModelParams, "zero" means len == 0 — a phase that +// genuinely needs to clear inherited params can set {"_":""} or similar, +// but in practice nobody wants that. 
+// +// ModelCfg is an output: ResolvePhases populates it from the registry with +// ContextLimit / MaxOutputTokens overrides applied. Callers read ModelCfg +// for pricing, tokenizer encoding, and request limits; they do not set it. +type PhaseConfig struct { + // Provider selects the HTTP client flavour: anthropic, openai, google, + // databricks. Empty inherits (analysis phase: auto-detects from env). + Provider string `mapstructure:"provider"` + + // Model is the model name sent in the request body. Drives registry + // lookup for ModelCfg. + Model string `mapstructure:"model"` + + // APIKey authenticates this phase's client. Per-phase so you can hit + // Anthropic for analysis and Google for audit without either key + // leaking into the wrong Authorization header. + APIKey string `mapstructure:"api-key"` + + // BaseURL overrides the hardcoded per-provider default. Use for proxies, + // Azure OpenAI, Vertex-vs-AI-Studio, local mock servers. Empty = default. + BaseURL string `mapstructure:"base-url"` + + // Endpoint is the Databricks serving-endpoint path segment. Ignored by + // other providers. + Endpoint string `mapstructure:"endpoint"` + + // ModelParams is merged into the top level of the request body (see + // llm.marshalWithModelParams). Per-phase so e.g. the audit phase can + // drop thinking-mode params that only the analysis phase needs. + ModelParams map[string]any `mapstructure:"model-params"` + + // ModelParamsJSON is the CLI/env string form — parsed and merged into + // ModelParams by ResolvePhases. Config-file users set ModelParams as a + // YAML map directly and leave this empty. + ModelParamsJSON string `mapstructure:"model-params-json"` + + // RequestTimeout is the per-request HTTP timeout in seconds. Covers the + // full body read, so for streaming responses it bounds total generation + // time. 0 = inherit (analysis: client default 600s). 
+ RequestTimeout int `mapstructure:"request-timeout"` + + // Headers are extra HTTP headers in "Name: Value" form, parsed by the + // CLI layer. + Headers []string `mapstructure:"headers"` + + // ContextLimit overrides the registry's context window for chunk-budget + // math. Only the analysis phase's value affects chunking, but we carry + // it per-phase so a fresh registry lookup for an audit model doesn't + // silently drop the override (the prior code had that bug). + ContextLimit int `mapstructure:"context-limit"` + + // MaxOutputTokens overrides the registry's max output. Becomes + // ChatRequest.MaxTokens. + MaxOutputTokens int `mapstructure:"max-output-tokens"` + + // ModelCfg is the resolved registry entry with overrides applied. + // Populated by ResolvePhases; never set directly. + ModelCfg ModelConfig `mapstructure:"-"` +} + +// Phases groups the three pipeline phase configs. Lives on Config under +// mapstructure:"phases" so config-file users write: +// +// phases: +// analysis: +// model: claude-opus-4-6 +// audit: +// provider: google +// model: gemini-3-pro +// api-key: ${GOOGLE_API_KEY} +// +// Env var form (see BindEnvVars replacer): PHASES_AUDIT_PROVIDER, +// PHASES_AUDIT_MODEL, PHASES_AUDIT_API_KEY, PHASES_AUDIT_MODEL_PARAMS_JSON. +type Phases struct { + Analysis PhaseConfig `mapstructure:"analysis"` + FeatureDetection PhaseConfig `mapstructure:"feature-detection"` + Audit PhaseConfig `mapstructure:"audit"` + ContextCompress PhaseConfig `mapstructure:"context-compress"` +} + +// ResolvePhases fills Config.Phases from legacy globals and applies the +// inheritance cascade. Call once after Load. +// +// Cascade, in order: +// 1. Analysis phase is seeded from legacy flat fields (cfg.Model, +// cfg.Provider, cfg.ModelParams, ...). Anything already set on +// cfg.Phases.Analysis (config file, per-phase env) wins. +// 2. FeatureDetection and Audit inherit any zero-valued field from the +// resolved Analysis phase. 
Legacy per-phase flags (cfg.AuditModel, +// cfg.FeatureDetectionModel) act as overrides. +// 3. Each phase's ModelCfg is looked up from the registry and patched +// with that phase's ContextLimit / MaxOutputTokens overrides. This +// fixes the prior bug where --context-limit only patched the main +// model and a separate --audit-model got fresh registry values. +// 4. Provider auto-detect runs per-phase: registry hint, then which key +// is set on this phase, then ambient Databricks env. --provider (the +// legacy global) acts as a default, not a per-phase override — so +// setting phases.audit.provider beats --provider for the audit phase +// but --provider still applies to phases that don't set their own. +func ResolvePhases(cfg *Config) error { + // ── 1. Analysis: legacy globals as the base layer ────────────────── + base := PhaseConfig{ + Provider: cfg.Provider, + Model: cfg.Model, + BaseURL: cfg.BaseURL, + ModelParams: cfg.ModelParams, + RequestTimeout: cfg.RequestTimeout, + Headers: cfg.CustomHeaders, + ContextLimit: cfg.ContextLimit, + MaxOutputTokens: cfg.MaxOutputTokens, + Endpoint: cfg.DatabricksEndpoint, + // APIKey: left empty — picked per-provider in step 4. + } + overlay(&base, &cfg.Phases.Analysis) + if err := parsePhaseParams(&base, "analysis"); err != nil { + return err + } + cfg.Phases.Analysis = base + + // ── 2. Secondary phases: inherit from analysis, then overlay ─────── + secondaries := []struct { + name string + legacyModel string // --audit-model / --feature-detection-model + dst *PhaseConfig + }{ + {"feature-detection", cfg.FeatureDetectionModel, &cfg.Phases.FeatureDetection}, + {"audit", cfg.AuditModel, &cfg.Phases.Audit}, + {"context-compress", "", &cfg.Phases.ContextCompress}, + } + for _, s := range secondaries { + pc := base // value copy — inherit everything + // ModelParams is a map: the value copy aliases the same backing + // store. 
Break the alias so a phase that sets its own params + // doesn't mutate analysis's, and so downstream mergeMaps has a + // distinct target. + pc.ModelParams = cloneParams(base.ModelParams) + + // Legacy per-phase model flag is an override only when set — + // otherwise the inherited analysis model stands. + if s.legacyModel != "" { + pc.Model = s.legacyModel + } + + overlay(&pc, s.dst) + if err := parsePhaseParams(&pc, s.name); err != nil { + return err + } + *s.dst = pc + } + + // ── 3. Resolve registry + apply overrides, per phase ─────────────── + for name, pc := range allPhases(cfg) { + pc.ModelCfg = lookupOrDefault(pc.Model, name) + if pc.ContextLimit > 0 { + slog.Info("overriding model context limit", + "phase", name, "registry", pc.ModelCfg.ContextLimit, "override", pc.ContextLimit) + pc.ModelCfg.ContextLimit = pc.ContextLimit + } + if pc.MaxOutputTokens > 0 { + slog.Info("overriding model max output tokens", + "phase", name, "registry", pc.ModelCfg.MaxOutputTokens, "override", pc.MaxOutputTokens) + pc.ModelCfg.MaxOutputTokens = pc.MaxOutputTokens + } + } + + // ── 4. Provider + API key resolution, per phase ──────────────────── + for name, pc := range allPhases(cfg) { + if pc.Provider == "" { + pc.Provider = detectProvider(pc, cfg) + } + if pc.APIKey == "" { + pc.APIKey = ambientKey(pc.Provider, cfg) + } + slog.Debug("phase resolved", + "phase", name, "provider", pc.Provider, "model", pc.ModelCfg.Name, + "has_api_key", pc.APIKey != "", "base_url", pc.BaseURL) + } + + return nil +} + +// overlay copies every non-zero field from src onto dst. This is how a +// phase-specific config (from config file or PHASES_* env) beats the +// inherited/legacy value. Field-by-field because reflect would be overkill +// for a dozen fields and hides what "non-zero" means for each type. 
+func overlay(dst, src *PhaseConfig) { + if src.Provider != "" { + dst.Provider = src.Provider + } + if src.Model != "" { + dst.Model = src.Model + } + if src.APIKey != "" { + dst.APIKey = src.APIKey + } + if src.BaseURL != "" { + dst.BaseURL = src.BaseURL + } + if src.Endpoint != "" { + dst.Endpoint = src.Endpoint + } + if len(src.ModelParams) > 0 { + // Replace, don't merge — a phase that sets its own params gets + // exactly those params. Inheritance already gave dst the analysis + // params as a starting point; overlaying means "I want these + // instead." Merge-on-overlay would make it impossible to drop an + // inherited key. + dst.ModelParams = cloneParams(src.ModelParams) + } + if src.ModelParamsJSON != "" { + dst.ModelParamsJSON = src.ModelParamsJSON + } + if src.RequestTimeout > 0 { + dst.RequestTimeout = src.RequestTimeout + } + if len(src.Headers) > 0 { + dst.Headers = src.Headers + } + if src.ContextLimit > 0 { + dst.ContextLimit = src.ContextLimit + } + if src.MaxOutputTokens > 0 { + dst.MaxOutputTokens = src.MaxOutputTokens + } +} + +// parsePhaseParams decodes the JSON-string form of model params and merges +// it into the map form. The string form wins on conflict — it's the +// CLI/env override path. +func parsePhaseParams(pc *PhaseConfig, phase string) error { + if pc.ModelParams == nil { + pc.ModelParams = map[string]any{} + } + s := strings.TrimSpace(pc.ModelParamsJSON) + if s == "" { + return nil + } + var override map[string]any + if err := json.Unmarshal([]byte(s), &override); err != nil { + return fmt.Errorf("parsing %s model-params-json: %w", phase, err) + } + mergeMaps(pc.ModelParams, override) + return nil +} + +// detectProvider picks a provider when the phase didn't set one explicitly. 
+// Precedence: Databricks ambient env first (host and token both set —
+// Databricks proxies all providers, so if it's configured, route through
+// it even when the registry names a direct provider), then the registry
+// hint for this phase's model, then whichever direct-provider key is set
+// on cfg (anthropic, openai, google, in that order), then anthropic as the
+// credential-less default.
+func detectProvider(pc *PhaseConfig, cfg *Config) string {
+	// Databricks wins regardless of the registry hint. This preserves the
+	// prior resolveProvider behaviour. Host alone is not enough: both the
+	// host and the token must be present.
+	if cfg.DatabricksHost != "" && cfg.DatabricksToken != "" {
+		return "databricks"
+	}
+	// Registry knows which provider serves this model.
+	if pc.ModelCfg.Provider != "" {
+		return pc.ModelCfg.Provider
+	}
+	// No hint: pick the first direct provider that has a key configured.
+	switch {
+	case cfg.AnthropicAPIKey != "":
+		return "anthropic"
+	case cfg.OpenAIAPIKey != "":
+		return "openai"
+	case cfg.GoogleAPIKey != "":
+		return "google"
+	}
+	// No credentials detected — fall back to anthropic (Claude CLI can
+	// authenticate without an API key) rather than databricks which
+	// requires host+token env vars.
+	return "anthropic"
+}
+
+// ambientKey returns the global provider-specific key from cfg when the
+// phase didn't set its own. Databricks auth is host+token, not a single
+// key, so it's handled in the client builder instead. Any provider other
+// than anthropic, openai or google yields the empty string.
+func ambientKey(provider string, cfg *Config) string {
+	switch provider {
+	case "anthropic":
+		return cfg.AnthropicAPIKey
+	case "openai":
+		return cfg.OpenAIAPIKey
+	case "google":
+		return cfg.GoogleAPIKey
+	}
+	return ""
+}
+
+// lookupOrDefault resolves a model name to its ModelConfig. An empty name
+// yields DefaultModel(); a name the registry recognises yields the registry
+// entry; anything else logs a warning (tagged with phase) and yields
+// UnknownModelDefaults(name).
+func lookupOrDefault(name, phase string) ModelConfig {
+	if name == "" {
+		return DefaultModel()
+	}
+	if m, ok := LookupModel(name); ok {
+		// Preserve the user-supplied name: the registry may have matched by
+		// substring (databricks-claude-opus → claude-opus-4), and the API
+		// wants the name the user typed.
+		m.Name = name
+		return m
+	}
+	slog.Warn("model not in registry, using defaults", "phase", phase, "model", name)
+	return UnknownModelDefaults(name)
+}
+
+// cloneParams returns a deep copy of m so that phases never share mutable
+// parameter state. Nested map[string]any and []any values are copied
+// recursively; every other value is copied as-is. A nil map stays nil.
+func cloneParams(m map[string]any) map[string]any {
+	if m == nil {
+		return nil
+	}
+	out := make(map[string]any, len(m))
+	for k, v := range m {
+		out[k] = cloneParamValue(v)
+	}
+	return out
+}
+
+// cloneParamValue deep-copies v when it is a nested map or slice and returns
+// it unchanged otherwise. Slices need the same treatment as maps: a slice of
+// maps would otherwise stay shared between the source and the clone.
+func cloneParamValue(v any) any {
+	switch val := v.(type) {
+	case map[string]any:
+		return cloneParams(val)
+	case []any:
+		if val == nil {
+			return val
+		}
+		out := make([]any, len(val))
+		for i, elem := range val {
+			out[i] = cloneParamValue(elem)
+		}
+		return out
+	default:
+		return v
+	}
+}
+
+// allPhases iterates every phase by name and pointer, for the resolve
+// passes that treat them uniformly. The order is fixed: analysis,
+// feature-detection, audit, context-compress. Iteration stops as soon as
+// yield returns false.
+func allPhases(cfg *Config) func(yield func(string, *PhaseConfig) bool) {
+	return func(yield func(string, *PhaseConfig) bool) {
+		if !yield("analysis", &cfg.Phases.Analysis) {
+			return
+		}
+		if !yield("feature-detection", &cfg.Phases.FeatureDetection) {
+			return
+		}
+		if !yield("audit", &cfg.Phases.Audit) {
+			return
+		}
+		yield("context-compress", &cfg.Phases.ContextCompress)
+	}
+}
diff --git a/internal/config/phase_test.go b/internal/config/phase_test.go
new file mode 100644
index 0000000..0458011
--- /dev/null
+++ b/internal/config/phase_test.go
@@ -0,0 +1,347 @@
+package config
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestResolvePhases_InheritFromLegacy(t *testing.T) {
+	// Nothing per-phase set: every phase should end up identical,
+	// seeded from the legacy flat fields.
+	cfg := &Config{
+		Model:           "claude-opus-4-6",
+		Provider:        "anthropic",
+		AnthropicAPIKey: "sk-ant-legacy",
+		ModelParams:     map[string]any{"temperature": 0.0},
+		RequestTimeout:  1800,
+		ContextLimit:    500000,
+		MaxOutputTokens: 128000,
+	}
+	if err := ResolvePhases(cfg); err != nil {
+		t.Fatalf("ResolvePhases: %v", err)
+	}
+
+	for name, pc := range allPhases(cfg) {
+		if pc.Provider != "anthropic" {
+			t.Errorf("%s.Provider = %q, want anthropic", name, pc.Provider)
+		}
+		if pc.APIKey != "sk-ant-legacy" {
+			t.Errorf("%s.APIKey = %q, want sk-ant-legacy", name, pc.APIKey)
+		}
+		if pc.RequestTimeout != 1800 {
+			t.Errorf("%s.RequestTimeout = %d, want 1800", name, pc.RequestTimeout)
+		}
+		if pc.ModelCfg.ContextLimit != 500000 {
+			t.Errorf("%s.ModelCfg.ContextLimit = %d, want 500000 (override applied)", name, pc.ModelCfg.ContextLimit)
+		}
+		if pc.ModelCfg.MaxOutputTokens != 128000 {
+			t.Errorf("%s.ModelCfg.MaxOutputTokens = %d, want 128000 (override applied)", name, pc.ModelCfg.MaxOutputTokens)
+		}
+	}
+}
+
+func TestResolvePhases_PerPhaseOverride(t *testing.T) {
+	// Audit phase sets its own provider + key + model; analysis and
+	// feature-detection stay on the legacy values. This is the headline
+	// use case: audit on a different provider entirely.
+	cfg := &Config{
+		Model:           "claude-opus-4-6",
+		AnthropicAPIKey: "sk-ant-analysis",
+		Phases: Phases{
+			Audit: PhaseConfig{
+				Provider: "google",
+				Model:    "gemini-3-pro",
+				APIKey:   "goog-audit",
+			},
+		},
+	}
+	if err := ResolvePhases(cfg); err != nil {
+		t.Fatalf("ResolvePhases: %v", err)
+	}
+
+	if cfg.Phases.Analysis.Provider != "anthropic" {
+		t.Errorf("analysis.Provider = %q, want anthropic (from registry)", cfg.Phases.Analysis.Provider)
+	}
+	if cfg.Phases.Analysis.APIKey != "sk-ant-analysis" {
+		t.Errorf("analysis.APIKey = %q, want sk-ant-analysis", cfg.Phases.Analysis.APIKey)
+	}
+
+	if cfg.Phases.Audit.Provider != "google" {
+		t.Errorf("audit.Provider = %q, want google", cfg.Phases.Audit.Provider)
+	}
+	if cfg.Phases.Audit.APIKey != "goog-audit" {
+		t.Errorf("audit.APIKey = %q, want goog-audit (per-phase, not inherited)", cfg.Phases.Audit.APIKey)
+	}
+	if cfg.Phases.Audit.ModelCfg.Name != "gemini-3-pro" {
+		t.Errorf("audit.ModelCfg.Name = %q, want gemini-3-pro", cfg.Phases.Audit.ModelCfg.Name)
+	}
+	// ContextLimit should come from the gemini registry entry, not
+	// inherited from claude.
+	if cfg.Phases.Audit.ModelCfg.ContextLimit != 1048576 {
+		t.Errorf("audit.ModelCfg.ContextLimit = %d, want 1048576 (gemini registry)", cfg.Phases.Audit.ModelCfg.ContextLimit)
+	}
+
+	// Feature-detection inherits everything from analysis.
+	if cfg.Phases.FeatureDetection.Provider != "anthropic" {
+		t.Errorf("fd.Provider = %q, want anthropic (inherited)", cfg.Phases.FeatureDetection.Provider)
+	}
+}
+
+func TestResolvePhases_LegacyAuditModelStillWorks(t *testing.T) {
+	// --audit-model (the legacy flag) should still override the model for
+	// the audit phase without touching provider/key.
+	cfg := &Config{
+		Model:           "claude-opus-4-6",
+		AnthropicAPIKey: "sk-ant",
+		AuditModel:      "claude-sonnet-4-6",
+	}
+	if err := ResolvePhases(cfg); err != nil {
+		t.Fatalf("ResolvePhases: %v", err)
+	}
+
+	if cfg.Phases.Audit.ModelCfg.Name != "claude-sonnet-4-6" {
+		t.Errorf("audit model = %q, want claude-sonnet-4-6", cfg.Phases.Audit.ModelCfg.Name)
+	}
+	if cfg.Phases.Audit.Provider != "anthropic" {
+		t.Errorf("audit provider = %q, want anthropic (inherited)", cfg.Phases.Audit.Provider)
+	}
+	if cfg.Phases.Audit.APIKey != "sk-ant" {
+		t.Errorf("audit key = %q, want sk-ant (inherited)", cfg.Phases.Audit.APIKey)
+	}
+}
+
+func TestResolvePhases_ModelParamsNotAliased(t *testing.T) {
+	// Mutating the audit phase's ModelParams after resolve must not
+	// mutate the analysis phase's — the inheritance must have broken
+	// the map alias.
+	cfg := &Config{
+		Model:           "claude-opus-4-6",
+		AnthropicAPIKey: "k",
+		ModelParams:     map[string]any{"max_tokens": 1000},
+	}
+	if err := ResolvePhases(cfg); err != nil {
+		t.Fatalf("ResolvePhases: %v", err)
+	}
+
+	cfg.Phases.Audit.ModelParams["max_tokens"] = 9999
+
+	if cfg.Phases.Analysis.ModelParams["max_tokens"] != 1000 {
+		t.Errorf("analysis.ModelParams mutated by audit write: %v", cfg.Phases.Analysis.ModelParams)
+	}
+	if cfg.Phases.FeatureDetection.ModelParams["max_tokens"] != 1000 {
+		t.Errorf("fd.ModelParams mutated by audit write: %v", cfg.Phases.FeatureDetection.ModelParams)
+	}
+}
+
+func TestCloneParams_NestedSliceNotAliased(t *testing.T) {
+	// Slices of maps (e.g. a JSON array of objects) must be copied too,
+	// otherwise a write through the clone leaks into the source.
+	src := map[string]any{
+		"tools": []any{map[string]any{"name": "search"}},
+	}
+	dup := cloneParams(src)
+
+	dup["tools"].([]any)[0].(map[string]any)["name"] = "changed"
+
+	got := src["tools"].([]any)[0].(map[string]any)["name"]
+	if got != "search" {
+		t.Errorf("source mutated through clone: name = %v, want search", got)
+	}
+}
+
+func TestResolvePhases_PerPhaseModelParamsReplace(t *testing.T) {
+	// A phase that sets its own model-params gets exactly those params,
+	// not a merge with the inherited ones. The whole point of per-phase
+	// params is being able to DROP an inherited key (e.g. thinking-mode
+	// params that only analysis needs).
+	cfg := &Config{
+		Model:           "claude-opus-4-6",
+		AnthropicAPIKey: "k",
+		ModelParams:     map[string]any{"thinking": map[string]any{"type": "enabled"}, "max_tokens": 32000},
+		Phases: Phases{
+			Audit: PhaseConfig{
+				ModelParams: map[string]any{"max_tokens": 8000},
+			},
+		},
+	}
+	if err := ResolvePhases(cfg); err != nil {
+		t.Fatalf("ResolvePhases: %v", err)
+	}
+
+	want := map[string]any{"max_tokens": 8000}
+	if !reflect.DeepEqual(cfg.Phases.Audit.ModelParams, want) {
+		t.Errorf("audit.ModelParams = %v, want %v (replace, not merge — thinking key should be gone)",
+			cfg.Phases.Audit.ModelParams, want)
+	}
+	// Analysis keeps its own.
+	if _, ok := cfg.Phases.Analysis.ModelParams["thinking"]; !ok {
+		t.Error("analysis.ModelParams lost thinking key")
+	}
+}
+
+func TestResolvePhases_ModelParamsJSON(t *testing.T) {
+	cfg := &Config{
+		Model:           "claude-opus-4-6",
+		AnthropicAPIKey: "k",
+		Phases: Phases{
+			Audit: PhaseConfig{
+				ModelParamsJSON: `{"max_tokens": 4096, "tool_choice": {"type": "auto"}}`,
+			},
+		},
+	}
+	if err := ResolvePhases(cfg); err != nil {
+		t.Fatalf("ResolvePhases: %v", err)
+	}
+
+	if cfg.Phases.Audit.ModelParams["max_tokens"] != float64(4096) {
+		t.Errorf("audit max_tokens = %v, want 4096", cfg.Phases.Audit.ModelParams["max_tokens"])
+	}
+	tc, _ := cfg.Phases.Audit.ModelParams["tool_choice"].(map[string]any)
+	if tc["type"] != "auto" {
+		t.Errorf("audit tool_choice = %v, want {type:auto}", cfg.Phases.Audit.ModelParams["tool_choice"])
+	}
+}
+
+func TestResolvePhases_ModelParamsJSON_Invalid(t *testing.T) {
+	cfg := &Config{
+		Phases: Phases{
+			Audit: PhaseConfig{ModelParamsJSON: `{not json`},
+		},
+	}
+	err := ResolvePhases(cfg)
+	if err == nil {
+		t.Fatal("expected error for invalid JSON, got nil")
+	}
+}
+
+func TestResolvePhases_GoogleAmbientKey(t *testing.T) {
+	// No per-phase key, no Anthropic key, but GOOGLE_API_KEY is set and
+	// the model is gemini → provider should auto-detect to google and
+	// the key should cascade.
+	cfg := &Config{
+		Model:        "gemini-3-flash",
+		GoogleAPIKey: "goog-ambient",
+	}
+	if err := ResolvePhases(cfg); err != nil {
+		t.Fatalf("ResolvePhases: %v", err)
+	}
+
+	if cfg.Phases.Analysis.Provider != "google" {
+		t.Errorf("provider = %q, want google (from registry)", cfg.Phases.Analysis.Provider)
+	}
+	if cfg.Phases.Analysis.APIKey != "goog-ambient" {
+		t.Errorf("key = %q, want goog-ambient", cfg.Phases.Analysis.APIKey)
+	}
+}
+
+func TestResolvePhases_ContextLimitFix(t *testing.T) {
+	// Prior bug: --context-limit only patched the main modelCfg; a
+	// separate --audit-model got fresh registry values without the
+	// override. Now the override inherits per-phase unless the phase
+	// sets its own.
+	cfg := &Config{
+		Model:           "claude-opus-4-6",
+		AnthropicAPIKey: "k",
+		ContextLimit:    777777,
+		AuditModel:      "claude-sonnet-4-6", // different model, same override
+	}
+	if err := ResolvePhases(cfg); err != nil {
+		t.Fatalf("ResolvePhases: %v", err)
+	}
+
+	if cfg.Phases.Analysis.ModelCfg.ContextLimit != 777777 {
+		t.Errorf("analysis context = %d, want 777777", cfg.Phases.Analysis.ModelCfg.ContextLimit)
+	}
+	if cfg.Phases.Audit.ModelCfg.ContextLimit != 777777 {
+		t.Errorf("audit context = %d, want 777777 (override inherited despite different model)", cfg.Phases.Audit.ModelCfg.ContextLimit)
+	}
+}
+
+func TestResolvePhases_DatabricksProxiesAll(t *testing.T) {
+	// When Databricks env is set, it wins over the registry's provider
+	// hint — Databricks proxies all models. This is the prior behaviour
+	// of resolveProvider and must be preserved.
+	cfg := &Config{
+		Model:           "claude-opus-4-6", // registry says anthropic
+		DatabricksHost:  "https://dbx.example.com",
+		DatabricksToken: "dbx-tok",
+	}
+	if err := ResolvePhases(cfg); err != nil {
+		t.Fatalf("ResolvePhases: %v", err)
+	}
+
+	if cfg.Phases.Analysis.Provider != "databricks" {
+		t.Errorf("provider = %q, want databricks (proxies anthropic)", cfg.Phases.Analysis.Provider)
+	}
+}
+
+func TestDetectProvider(t *testing.T) {
+	tests := []struct {
+		name string
+		pc   PhaseConfig
+		cfg  Config
+		want string
+	}{
+		{
+			name: "registry hint wins when no databricks env",
+			pc:   PhaseConfig{ModelCfg: ModelConfig{Provider: "anthropic"}},
+			cfg:  Config{},
+			want: "anthropic",
+		},
+		{
+			name: "databricks env overrides registry hint",
+			pc:   PhaseConfig{ModelCfg: ModelConfig{Provider: "anthropic"}},
+			cfg:  Config{DatabricksHost: "https://dbx", DatabricksToken: "tok"},
+			want: "databricks",
+		},
+		{
+			name: "databricks env without registry hint",
+			pc:   PhaseConfig{},
+			cfg:  Config{DatabricksHost: "https://dbx", DatabricksToken: "tok"},
+			want: "databricks",
+		},
+		{
+			name: "databricks host alone is not enough",
+			pc:   PhaseConfig{},
+			cfg:  Config{DatabricksHost: "https://dbx", AnthropicAPIKey: "sk-ant"},
+			want: "anthropic",
+		},
+		{
+			name: "anthropic key",
+			pc:   PhaseConfig{},
+			cfg:  Config{AnthropicAPIKey: "sk-ant"},
+			want: "anthropic",
+		},
+		{
+			name: "openai key",
+			pc:   PhaseConfig{},
+			cfg:  Config{OpenAIAPIKey: "sk-oai"},
+			want: "openai",
+		},
+		{
+			name: "google key",
+			pc:   PhaseConfig{},
+			cfg:  Config{GoogleAPIKey: "goog"},
+			want: "google",
+		},
+		{
+			name: "anthropic beats openai when both set",
+			pc:   PhaseConfig{},
+			cfg:  Config{AnthropicAPIKey: "sk-ant", OpenAIAPIKey: "sk-oai"},
+			want: "anthropic",
+		},
+		{
+			name: "fallback to anthropic when nothing set",
+			pc:   PhaseConfig{},
+			cfg:  Config{},
+			want: "anthropic",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := detectProvider(&tt.pc, &tt.cfg); got != tt.want {
+				t.Errorf("detectProvider() = %q, want %q", got, tt.want)
+			}
+		})
+	}
+}
diff --git a/internal/ingest/filter.go b/internal/ingest/filter.go
new file mode 100644
index 0000000..133a5a5
--- /dev/null
+++ b/internal/ingest/filter.go
@@ -0,0 +1,457 @@
+package ingest
+
+import (
+	"log/slog"
+	"path/filepath"
+	"strings"
+)
+
+// DefaultMaxFileSize is the default maximum file size in bytes (100KB).
+// Most application source files are well under this; files exceeding it are
+// typically generated code, fixtures, or data files that waste token budget.
+const DefaultMaxFileSize = 100 * 1024
+
+// FilterConfig controls which files are kept or excluded by FilterFiles.
+type FilterConfig struct {
+	IncludeTests bool     // When true, test files are not excluded.
+	IncludeDocs  bool     // When true, .md files are not excluded.
+	Include      []string // Custom glob patterns to force-include.
+	Exclude      []string // Custom glob patterns to additionally exclude.
+	MaxFileSize  int      // Maximum file size in bytes (0 = no limit).
+}
+
+// FilterStats reports how many files were excluded per category.
+type FilterStats struct {
+	Tests     int
+	Vendor    int
+	Binary    int
+	Docs      int
+	Custom    int
+	LowValue  int
+	Oversized int
+	Kept      int
+	Total     int
+}
+
+// testFilePatterns matches individual test file naming conventions.
+var testFilePatterns = []string{
+	"*_test.go",
+	"test_*.py",
+	"*_spec.ts",
+	"*_spec.js",
+	"*_spec.tsx",
+	"*_spec.jsx",
+	"*_test.py",
+	"*_test.ts",
+	"*_test.js",
+	"test_*.go",
+	"test_*.ts",
+	"test_*.js",
+}
+
+// testDirNames are directory names that indicate test content.
+var testDirNames = map[string]bool{
+	"test":      true,
+	"tests":     true,
+	"__tests__": true,
+	"spec":      true,
+	"__mocks__": true,
+}
+
+// vendorDirNames are directory names that indicate vendored dependencies.
+// isInVendorDir compares every slash-separated component of a path against
+// this set, so a match at any depth marks the file as vendored.
+var vendorDirNames = map[string]bool{
+	"vendor":       true,
+	"node_modules": true,
+	"__pycache__":  true,
+	".venv":        true,
+	"venv":         true,
+}
+
+// lowValueDirNames are directories containing CI/CD, tooling, or data files
+// that produce noisy, low-value security findings and dilute LLM analysis.
+// isLowValueFile compares every path component against this set, so a match
+// at any depth excludes the file.
+var lowValueDirNames = map[string]bool{
+	".github":        true,
+	".circleci":      true,
+	".dependabot":    true,
+	".gitlab":        true,
+	".husky":         true,
+	".vscode":        true,
+	".idea":          true,
+	".zap":           true,
+	"encryptionkeys": true,
+}
+
+// lowValueExtensions are file extensions unlikely to contain exploitable vulnerabilities.
+// isLowValueFile lower-cases the extension before the lookup, so keys here
+// must be lower-case and include the leading dot.
+var lowValueExtensions = map[string]bool{
+	".key":   true,
+	".pem":   true,
+	".crt":   true,
+	".pub":   true,
+	".cert":  true,
+	".tsv":   true,
+	".csv":   true,
+	".ipynb": true,
+}
+
+// lowValueFilenames are specific filenames that are build/config artifacts.
+// isLowValueFile looks up the base name exactly as it appears in the path,
+// so matching is case-sensitive and independent of the directory.
+var lowValueFilenames = map[string]bool{
+	"Gruntfile.js":       true,
+	"Gulpfile.js":        true,
+	"Cakefile":           true,
+	"Rakefile":           true,
+	"ctf.key":            true,
+	".npmrc":             true,
+	".eslintrc.js":       true,
+	".eslintrc.json":     true,
+	".prettierrc":        true,
+	".prettierrc.json":   true,
+	".babelrc":           true,
+	"jest.config.js":     true,
+	"jest.config.ts":     true,
+	"karma.conf.js":      true,
+	"protractor.conf.js": true,
+	"package-lock.json":  true,
+	"yarn.lock":          true,
+	"pnpm-lock.yaml":     true,
+	"composer.lock":      true,
+	"Gemfile.lock":       true,
+	"Pipfile.lock":       true,
+	"poetry.lock":        true,
+	"Cargo.lock":         true,
+}
+
+// binaryExtensions are file extensions that indicate binary/generated content.
+// classifyFile lower-cases the extension before the lookup, so keys here
+// must be lower-case and include the leading dot. Files whose extension is
+// not listed can still be classified as binary by isBinaryFile's content
+// check.
+var binaryExtensions = map[string]bool{
+	".exe":    true,
+	".dll":    true,
+	".so":     true,
+	".dylib":  true,
+	".a":      true,
+	".o":      true,
+	".obj":    true,
+	".lib":    true,
+	".png":    true,
+	".jpg":    true,
+	".jpeg":   true,
+	".gif":    true,
+	".bmp":    true,
+	".ico":    true,
+	".svg":    true,
+	".webp":   true,
+	".zip":    true,
+	".tar":    true,
+	".gz":     true,
+	".bz2":    true,
+	".xz":     true,
+	".7z":     true,
+	".rar":    true,
+	".jar":    true,
+	".war":    true,
+	".class":  true,
+	".pyc":    true,
+	".pyo":    true,
+	".wasm":   true,
+	".pdf":    true,
+	".doc":    true,
+	".docx":   true,
+	".xls":    true,
+	".xlsx":   true,
+	".ppt":    true,
+	".pptx":   true,
+	".ttf":    true,
+	".otf":    true,
+	".woff":   true,
+	".woff2":  true,
+	".eot":    true,
+	".mp3":    true,
+	".mp4":    true,
+	".wav":    true,
+	".avi":    true,
+	".mov":    true,
+	".db":     true,
+	".sqlite": true,
+}
+
+// alwaysIncludeExtensions are extensions that should never be filtered out,
+// regardless of their location (e.g., even inside test directories).
+// classifyFile consults this set first, before custom include/exclude
+// patterns, the size limit, and the binary checks, so none of those can
+// exclude a file with one of these extensions.
+var alwaysIncludeExtensions = map[string]bool{
+	".proto":   true,
+	".sql":     true,
+	".graphql": true,
+	".gql":     true,
+}
+
+// FilterFiles applies heuristic filters to a list of source files and returns
+// the files that pass all filters. It logs summary statistics about excluded files.
+func FilterFiles(files []SourceFile, cfg FilterConfig) ([]SourceFile, FilterStats) { + var stats FilterStats + stats.Total = len(files) + + if len(files) == 0 { + return []SourceFile{}, stats + } + + var kept []SourceFile + + for _, f := range files { + category := classifyFile(f, cfg) + switch category { + case categoryKept: + kept = append(kept, f) + stats.Kept++ + case categoryTest: + stats.Tests++ + case categoryVendor: + stats.Vendor++ + case categoryBinary: + stats.Binary++ + case categoryDocs: + stats.Docs++ + case categoryCustom: + stats.Custom++ + case categoryLowValue: + stats.LowValue++ + case categoryOversized: + stats.Oversized++ + } + } + + if kept == nil { + kept = []SourceFile{} + } + + logFilterStats(stats) + return kept, stats +} + +// fileCategory represents why a file was kept or excluded. +type fileCategory int + +const ( + categoryKept fileCategory = iota + categoryTest + categoryVendor + categoryBinary + categoryDocs + categoryCustom + categoryLowValue + categoryOversized +) + +// classifyFile determines whether a file should be kept or excluded, and why. +func classifyFile(f SourceFile, cfg FilterConfig) fileCategory { + ext := strings.ToLower(filepath.Ext(f.Path)) + + // Always-include extensions bypass all filters. + if alwaysIncludeExtensions[ext] { + return categoryKept + } + + // Custom include patterns: if any match, the file is force-included. + if matchesAnyGlob(f.Path, cfg.Include) { + return categoryKept + } + + // Custom exclude patterns: checked before default heuristics. + if matchesAnyGlob(f.Path, cfg.Exclude) { + return categoryCustom + } + + // Oversized file check. + if cfg.MaxFileSize > 0 && len(f.Content) > cfg.MaxFileSize { + return categoryOversized + } + + // Binary extension check. + if binaryExtensions[ext] { + return categoryBinary + } + + // Binary content check (null byte in first 512 bytes). + if isBinaryFile(f.Content) { + return categoryBinary + } + + // Vendor directory check. 
+ if isInVendorDir(f.Path) { + return categoryVendor + } + + // Low-value file check (CI/CD, tooling, key files). + if isLowValueFile(f.Path) { + return categoryLowValue + } + + // Test file/directory check (skipped if IncludeTests is set). + if !cfg.IncludeTests && isTestFile(f.Path) { + return categoryTest + } + + // Documentation check (skipped if IncludeDocs is set). + if !cfg.IncludeDocs && ext == ".md" { + return categoryDocs + } + + return categoryKept +} + +// isTestFile checks whether a file path looks like a test file, +// either by filename pattern or by being inside a test directory. +func isTestFile(path string) bool { + base := filepath.Base(path) + + // Check filename patterns. + for _, pattern := range testFilePatterns { + if matched, _ := filepath.Match(pattern, base); matched { + return true + } + } + + // Check if any path component is a test directory. + parts := strings.Split(filepath.ToSlash(path), "/") + for _, part := range parts { + if testDirNames[part] { + return true + } + } + + return false +} + +// isInVendorDir checks whether a file is inside a vendor directory. +func isInVendorDir(path string) bool { + parts := strings.Split(filepath.ToSlash(path), "/") + for _, part := range parts { + if vendorDirNames[part] { + return true + } + } + return false +} + +// isBinaryFile checks file content for null bytes in the first 512 bytes. +func isBinaryFile(content string) bool { + limit := 512 + if len(content) < limit { + limit = len(content) + } + return strings.Contains(content[:limit], "\x00") +} + +// matchesAnyGlob returns true if path matches any of the given glob patterns. +// Patterns are matched against the full relative path and also against just +// the filename component. +func matchesAnyGlob(path string, patterns []string) bool { + if len(patterns) == 0 { + return false + } + + // Normalize to forward slashes for consistent matching. 
+ normalized := filepath.ToSlash(path) + base := filepath.Base(path) + + for _, pattern := range patterns { + pattern = filepath.ToSlash(pattern) + + if pattern == "**" { + return true + } + + // Match against full path. + if matched, _ := filepath.Match(pattern, normalized); matched { + return true + } + + // Match against base filename. + if matched, _ := filepath.Match(pattern, base); matched { + return true + } + + // Support ** prefix patterns by checking if path ends with pattern suffix. + if strings.HasPrefix(pattern, "**/") { + suffix := pattern[3:] + // Check if any path suffix matches. + if matched, _ := filepath.Match(suffix, base); matched { + return true + } + // Check path segments. + parts := strings.Split(normalized, "/") + for i := range parts { + subpath := strings.Join(parts[i:], "/") + if matched, _ := filepath.Match(suffix, subpath); matched { + return true + } + } + } + + // Support trailing /** patterns as recursive directory matches, + // e.g. firmware/third-party/** should match all descendants. + if strings.HasSuffix(pattern, "/**") { + dirPattern := strings.TrimSuffix(pattern, "/**") + + // Check the full path and every ancestor directory against + // the directory pattern, so wildcard segments still work. + if matched, _ := filepath.Match(dirPattern, normalized); matched { + return true + } + parts := strings.Split(normalized, "/") + for i := 1; i < len(parts); i++ { + ancestor := strings.Join(parts[:i], "/") + if matched, _ := filepath.Match(dirPattern, ancestor); matched { + return true + } + } + } + } + + return false +} + +// isLowValueFile checks whether a file is in a low-value directory, has a +// low-value extension, or has a low-value filename (CI/CD, tooling, key files). 
+func isLowValueFile(path string) bool { + base := filepath.Base(path) + ext := strings.ToLower(filepath.Ext(path)) + + // Check filename + if lowValueFilenames[base] { + return true + } + + // Check extension + if lowValueExtensions[ext] { + return true + } + + // Check if any path component is a low-value directory + parts := strings.Split(filepath.ToSlash(path), "/") + for _, part := range parts { + if lowValueDirNames[part] { + return true + } + } + + return false +} + +// logFilterStats logs a summary of filter results. +func logFilterStats(stats FilterStats) { + excluded := stats.Total - stats.Kept + if excluded == 0 { + slog.Info("file filter: all files kept", "total", stats.Total) + return + } + + slog.Info("file filter summary", + "total", stats.Total, + "kept", stats.Kept, + "excluded", excluded, + "tests", stats.Tests, + "vendor", stats.Vendor, + "binary", stats.Binary, + "docs", stats.Docs, + "custom", stats.Custom, + "low_value", stats.LowValue, + "oversized", stats.Oversized, + ) +} diff --git a/internal/ingest/filter_test.go b/internal/ingest/filter_test.go new file mode 100644 index 0000000..35f40d7 --- /dev/null +++ b/internal/ingest/filter_test.go @@ -0,0 +1,962 @@ +package ingest + +import ( + "strings" + "testing" +) + +// makeFile creates a SourceFile with the given path and dummy content. +func makeFile(path string) SourceFile { + return SourceFile{ + Path: path, + Content: "package main\n", + LineCount: 1, + Language: InferLanguage(path), + } +} + +// makeFileWithContent creates a SourceFile with specific content. 
+func makeFileWithContent(path, content string) SourceFile { + return SourceFile{ + Path: path, + Content: content, + LineCount: countLines(content), + Language: InferLanguage(path), + } +} + +func TestFilterFiles_EmptyInput(t *testing.T) { + kept, stats := FilterFiles(nil, FilterConfig{}) + if len(kept) != 0 { + t.Errorf("expected empty slice, got %d files", len(kept)) + } + if stats.Total != 0 { + t.Errorf("expected total 0, got %d", stats.Total) + } + + kept, stats = FilterFiles([]SourceFile{}, FilterConfig{}) + if len(kept) != 0 { + t.Errorf("expected empty slice, got %d files", len(kept)) + } + if stats.Total != 0 { + t.Errorf("expected total 0, got %d", stats.Total) + } +} + +func TestFilterFiles_DefaultExclusions_TestFiles(t *testing.T) { + files := []SourceFile{ + makeFile("main.go"), + makeFile("handler_test.go"), + makeFile("test_helper.py"), + makeFile("component_spec.ts"), + makeFile("widget_spec.js"), + makeFile("component_spec.tsx"), + makeFile("widget_spec.jsx"), + makeFile("app_test.py"), + makeFile("app_test.ts"), + makeFile("app_test.js"), + makeFile("test_util.go"), + makeFile("test_util.ts"), + makeFile("test_util.js"), + } + + kept, stats := FilterFiles(files, FilterConfig{}) + + if len(kept) != 1 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 1 file kept, got %d: %v", len(kept), paths) + } + if kept[0].Path != "main.go" { + t.Errorf("expected main.go, got %q", kept[0].Path) + } + if stats.Tests != 12 { + t.Errorf("expected 12 test files excluded, got %d", stats.Tests) + } +} + +func TestFilterFiles_DefaultExclusions_TestDirectories(t *testing.T) { + files := []SourceFile{ + makeFile("src/main.go"), + makeFile("test/helper.go"), + makeFile("tests/integration.py"), + makeFile("__tests__/component.test.js"), + makeFile("spec/models/user_spec.rb"), + makeFile("src/__tests__/deep/nested.js"), + } + + kept, stats := FilterFiles(files, FilterConfig{}) + + if len(kept) != 1 { + paths := 
make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 1 file kept, got %d: %v", len(kept), paths) + } + if kept[0].Path != "src/main.go" { + t.Errorf("expected src/main.go, got %q", kept[0].Path) + } + if stats.Tests != 5 { + t.Errorf("expected 5 test directory files excluded, got %d", stats.Tests) + } +} + +func TestFilterFiles_DefaultExclusions_VendorDirs(t *testing.T) { + files := []SourceFile{ + makeFile("main.go"), + makeFile("vendor/github.com/pkg/dep.go"), + makeFile("node_modules/express/index.js"), + makeFile("__pycache__/module.cpython-39.py"), + makeFile(".venv/lib/python3.9/site-packages/pip.py"), + makeFile("venv/lib/python3.9/site-packages/pip.py"), + } + + kept, stats := FilterFiles(files, FilterConfig{}) + + if len(kept) != 1 { + t.Fatalf("expected 1 file kept, got %d", len(kept)) + } + if kept[0].Path != "main.go" { + t.Errorf("expected main.go, got %q", kept[0].Path) + } + if stats.Vendor != 5 { + t.Errorf("expected 5 vendor files excluded, got %d", stats.Vendor) + } +} + +func TestFilterFiles_DefaultExclusions_BinaryExtensions(t *testing.T) { + files := []SourceFile{ + makeFile("main.go"), + makeFile("app.exe"), + makeFile("lib.dll"), + makeFile("lib.so"), + makeFile("lib.dylib"), + makeFile("image.png"), + makeFile("photo.jpg"), + makeFile("icon.gif"), + makeFile("archive.zip"), + makeFile("package.tar"), + makeFile("compressed.gz"), + } + + kept, stats := FilterFiles(files, FilterConfig{}) + + if len(kept) != 1 { + t.Fatalf("expected 1 file kept, got %d", len(kept)) + } + if kept[0].Path != "main.go" { + t.Errorf("expected main.go, got %q", kept[0].Path) + } + if stats.Binary != 10 { + t.Errorf("expected 10 binary files excluded, got %d", stats.Binary) + } +} + +func TestFilterFiles_DefaultExclusions_Docs(t *testing.T) { + files := []SourceFile{ + makeFile("main.go"), + makeFile("README.md"), + makeFile("docs/guide.md"), + makeFile("CHANGELOG.md"), + } + + kept, stats := FilterFiles(files, 
FilterConfig{}) + + if len(kept) != 1 { + t.Fatalf("expected 1 file kept, got %d", len(kept)) + } + if kept[0].Path != "main.go" { + t.Errorf("expected main.go, got %q", kept[0].Path) + } + if stats.Docs != 3 { + t.Errorf("expected 3 docs excluded, got %d", stats.Docs) + } +} + +func TestFilterFiles_AlwaysInclude_Proto(t *testing.T) { + files := []SourceFile{ + makeFile("api/service.proto"), + makeFile("test/api.proto"), // in test dir, but .proto + makeFile("vendor/google/protobuf/descriptor.proto"), // in vendor dir, but .proto + } + + kept, _ := FilterFiles(files, FilterConfig{}) + + if len(kept) != 3 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected all 3 .proto files kept, got %d: %v", len(kept), paths) + } +} + +func TestFilterFiles_AlwaysInclude_SQL(t *testing.T) { + files := []SourceFile{ + makeFile("db/migrations/001.sql"), + makeFile("test/fixtures.sql"), // in test dir, but .sql + } + + kept, _ := FilterFiles(files, FilterConfig{}) + + if len(kept) != 2 { + t.Fatalf("expected 2 .sql files kept, got %d", len(kept)) + } +} + +func TestFilterFiles_AlwaysInclude_GraphQL(t *testing.T) { + files := []SourceFile{ + makeFile("schema/types.graphql"), + makeFile("test/query.graphql"), // in test dir, but .graphql + makeFile("api/schema.gql"), + } + + kept, _ := FilterFiles(files, FilterConfig{}) + + if len(kept) != 3 { + t.Fatalf("expected 3 graphql files kept, got %d", len(kept)) + } +} + +func TestFilterFiles_IncludeTestsFlag(t *testing.T) { + files := []SourceFile{ + makeFile("main.go"), + makeFile("main_test.go"), + makeFile("test/helper.go"), + makeFile("__tests__/component.js"), + } + + kept, stats := FilterFiles(files, FilterConfig{IncludeTests: true}) + + if len(kept) != 4 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 4 files kept with --include-tests, got %d: %v", len(kept), paths) + } + if stats.Tests != 0 { + t.Errorf("expected 0 
test exclusions with --include-tests, got %d", stats.Tests) + } +} + +func TestFilterFiles_IncludeDocsFlag(t *testing.T) { + files := []SourceFile{ + makeFile("main.go"), + makeFile("README.md"), + makeFile("docs/guide.md"), + } + + kept, stats := FilterFiles(files, FilterConfig{IncludeDocs: true}) + + if len(kept) != 3 { + t.Fatalf("expected 3 files kept with --include-docs, got %d", len(kept)) + } + if stats.Docs != 0 { + t.Errorf("expected 0 doc exclusions with --include-docs, got %d", stats.Docs) + } +} + +func TestFilterFiles_CustomIncludeOverridesExclusion(t *testing.T) { + files := []SourceFile{ + makeFile("main.go"), + makeFile("main_test.go"), + makeFile("vendor/important.go"), + } + + // Custom include pattern should override default exclusions. + kept, _ := FilterFiles(files, FilterConfig{ + Include: []string{"*_test.go", "vendor/important.go"}, + }) + + if len(kept) != 3 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 3 files kept with custom include, got %d: %v", len(kept), paths) + } +} + +func TestFilterFiles_CustomExcludeAddsExclusions(t *testing.T) { + files := []SourceFile{ + makeFile("main.go"), + makeFile("generated.pb.go"), + makeFile("config.yaml"), + } + + kept, stats := FilterFiles(files, FilterConfig{ + Exclude: []string{"*.pb.go", "*.yaml"}, + }) + + if len(kept) != 1 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 1 file kept, got %d: %v", len(kept), paths) + } + if kept[0].Path != "main.go" { + t.Errorf("expected main.go, got %q", kept[0].Path) + } + if stats.Custom != 2 { + t.Errorf("expected 2 custom exclusions, got %d", stats.Custom) + } +} + +func TestFilterFiles_CustomExcludeDoesNotOverrideAlwaysInclude(t *testing.T) { + files := []SourceFile{ + makeFile("schema.proto"), + makeFile("migration.sql"), + makeFile("types.graphql"), + } + + // Custom exclude should NOT override always-include extensions. 
+ kept, _ := FilterFiles(files, FilterConfig{ + Exclude: []string{"*.proto", "*.sql", "*.graphql"}, + }) + + if len(kept) != 3 { + t.Fatalf("expected 3 always-include files kept despite custom exclude, got %d", len(kept)) + } +} + +func TestFilterFiles_BinaryContentDetection(t *testing.T) { + // File with null byte in first 512 bytes. + binaryContent := strings.Repeat("A", 100) + "\x00" + strings.Repeat("B", 100) + files := []SourceFile{ + makeFile("main.go"), + makeFileWithContent("mystery.dat", binaryContent), + } + + kept, stats := FilterFiles(files, FilterConfig{}) + + if len(kept) != 1 { + t.Fatalf("expected 1 file kept, got %d", len(kept)) + } + if kept[0].Path != "main.go" { + t.Errorf("expected main.go, got %q", kept[0].Path) + } + if stats.Binary != 1 { + t.Errorf("expected 1 binary file excluded, got %d", stats.Binary) + } +} + +func TestFilterFiles_BinaryContentBoundary(t *testing.T) { + // Null byte at exactly position 511 (last checked byte) — binary. + data511 := strings.Repeat("A", 511) + "\x00" + strings.Repeat("B", 100) + // Null byte at position 512 (outside check range) — not binary. 
+ data512 := strings.Repeat("A", 512) + "\x00" + strings.Repeat("B", 100) + + files := []SourceFile{ + makeFileWithContent("binary.dat", data511), + makeFileWithContent("text.dat", data512), + } + + kept, stats := FilterFiles(files, FilterConfig{}) + + if len(kept) != 1 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 1 file kept, got %d: %v", len(kept), paths) + } + if kept[0].Path != "text.dat" { + t.Errorf("expected text.dat, got %q", kept[0].Path) + } + if stats.Binary != 1 { + t.Errorf("expected 1 binary exclusion, got %d", stats.Binary) + } +} + +func TestFilterFiles_EdgeCase_NoExtension(t *testing.T) { + files := []SourceFile{ + makeFile("Makefile"), + makeFile("Dockerfile"), + makeFile("Procfile"), + makeFile(".env"), + makeFile(".gitignore"), + makeFile("LICENSE"), + } + + kept, _ := FilterFiles(files, FilterConfig{}) + + // All of these should pass through — none match binary/test/vendor/docs patterns. + if len(kept) != 6 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 6 files kept, got %d: %v", len(kept), paths) + } +} + +func TestFilterFiles_EdgeCase_DotFiles(t *testing.T) { + files := []SourceFile{ + makeFile(".env"), + makeFile(".dockerignore"), + makeFile(".eslintrc.js"), // low-value: build tooling config + makeFile(".prettierrc"), // low-value: build tooling config + } + + kept, stats := FilterFiles(files, FilterConfig{}) + + if len(kept) != 2 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 2 dot-files kept (.env, .dockerignore), got %d: %v", len(kept), paths) + } + if stats.LowValue != 2 { + t.Errorf("expected 2 low-value files, got %d", stats.LowValue) + } +} + +func TestFilterFiles_EdgeCase_Makefile(t *testing.T) { + files := []SourceFile{ + makeFile("Makefile"), + makeFile("src/Makefile"), + makeFile("GNUmakefile"), + } + + kept, _ := FilterFiles(files, 
FilterConfig{}) + + if len(kept) != 3 { + t.Fatalf("expected 3 Makefiles kept, got %d", len(kept)) + } +} + +func TestFilterFiles_EdgeCase_Dockerfile(t *testing.T) { + files := []SourceFile{ + makeFile("Dockerfile"), + makeFile("Dockerfile.prod"), + makeFile("docker/Dockerfile.dev"), + } + + kept, _ := FilterFiles(files, FilterConfig{}) + + if len(kept) != 3 { + t.Fatalf("expected 3 Dockerfiles kept, got %d", len(kept)) + } +} + +func TestFilterFiles_Stats_Correct(t *testing.T) { + files := []SourceFile{ + makeFile("main.go"), // kept + makeFile("main_test.go"), // test + makeFile("vendor/dep.go"), // vendor + makeFile("image.png"), // binary + makeFile("README.md"), // docs + makeFile("schema.proto"), // always-include + makeFile("test/fixtures.sql"), // always-include (despite test dir) + } + + kept, stats := FilterFiles(files, FilterConfig{}) + + if stats.Total != 7 { + t.Errorf("total: got %d, want 7", stats.Total) + } + if stats.Kept != 3 { + t.Errorf("kept: got %d, want 3 (main.go, schema.proto, test/fixtures.sql)", stats.Kept) + } + if stats.Tests != 1 { + t.Errorf("tests: got %d, want 1", stats.Tests) + } + if stats.Vendor != 1 { + t.Errorf("vendor: got %d, want 1", stats.Vendor) + } + if stats.Binary != 1 { + t.Errorf("binary: got %d, want 1", stats.Binary) + } + if stats.Docs != 1 { + t.Errorf("docs: got %d, want 1", stats.Docs) + } + if len(kept) != 3 { + t.Errorf("kept files: got %d, want 3", len(kept)) + } +} + +func TestFilterFiles_StatsTotal(t *testing.T) { + files := []SourceFile{ + makeFile("main.go"), + makeFile("main_test.go"), + makeFile("vendor/dep.go"), + makeFile("image.png"), + makeFile("README.md"), + } + + _, stats := FilterFiles(files, FilterConfig{ + Exclude: []string{"nonexistent"}, + }) + + // Verify that total = kept + tests + vendor + binary + docs + custom + oversized. 
+ sum := stats.Kept + stats.Tests + stats.Vendor + stats.Binary + stats.Docs + stats.Custom + stats.Oversized + if sum != stats.Total { + t.Errorf("stats don't add up: kept(%d) + tests(%d) + vendor(%d) + binary(%d) + docs(%d) + custom(%d) + oversized(%d) = %d, but total = %d", + stats.Kept, stats.Tests, stats.Vendor, stats.Binary, stats.Docs, stats.Custom, stats.Oversized, sum, stats.Total) + } +} + +func TestFilterFiles_CustomGlobDoubleStarPattern(t *testing.T) { + files := []SourceFile{ + makeFile("src/main.go"), + makeFile("src/generated/model.go"), + makeFile("lib/generated/types.go"), + } + + kept, stats := FilterFiles(files, FilterConfig{ + Exclude: []string{"**/generated/*.go"}, + }) + + if len(kept) != 1 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 1 file kept, got %d: %v", len(kept), paths) + } + if kept[0].Path != "src/main.go" { + t.Errorf("expected src/main.go, got %q", kept[0].Path) + } + if stats.Custom != 2 { + t.Errorf("expected 2 custom exclusions, got %d", stats.Custom) + } +} + +func TestFilterFiles_CustomGlobTrailingDoubleStarPattern(t *testing.T) { + files := []SourceFile{ + makeFile("firmware/main.c"), + makeFile("firmware/third-party/direct.c"), + makeFile("firmware/third-party/nested/deep.c"), + } + + kept, stats := FilterFiles(files, FilterConfig{ + Exclude: []string{"firmware/third-party/**"}, + }) + + if len(kept) != 1 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 1 file kept, got %d: %v", len(kept), paths) + } + if kept[0].Path != "firmware/main.c" { + t.Errorf("expected firmware/main.c, got %q", kept[0].Path) + } + if stats.Custom != 2 { + t.Errorf("expected 2 custom exclusions, got %d", stats.Custom) + } +} + +func TestFilterFiles_AllFlagsEnabled(t *testing.T) { + files := []SourceFile{ + makeFile("main.go"), + makeFile("main_test.go"), + makeFile("README.md"), + makeFile("vendor/dep.go"), // still excluded 
+ makeFile("image.png"), // still excluded (binary extension) + } + + kept, stats := FilterFiles(files, FilterConfig{ + IncludeTests: true, + IncludeDocs: true, + }) + + if len(kept) != 3 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 3 files kept, got %d: %v", len(kept), paths) + } + if stats.Tests != 0 { + t.Errorf("expected 0 test exclusions, got %d", stats.Tests) + } + if stats.Docs != 0 { + t.Errorf("expected 0 doc exclusions, got %d", stats.Docs) + } + if stats.Vendor != 1 { + t.Errorf("expected 1 vendor exclusion, got %d", stats.Vendor) + } + if stats.Binary != 1 { + t.Errorf("expected 1 binary exclusion, got %d", stats.Binary) + } +} + +func TestFilterFiles_MixedCategories(t *testing.T) { + // A comprehensive test with files from every category. + files := []SourceFile{ + // Kept + makeFile("cmd/main.go"), + makeFile("internal/handler.go"), + makeFile("config.yaml"), + makeFile("Dockerfile"), + makeFile("Makefile"), + makeFile(".env"), + // Always-include + makeFile("api/service.proto"), + makeFile("db/schema.sql"), + makeFile("api/types.graphql"), + // Test files + makeFile("internal/handler_test.go"), + makeFile("test/integration.py"), + makeFile("__tests__/component.js"), + // Vendor + makeFile("vendor/lib/dep.go"), + makeFile("node_modules/pkg/index.js"), + // Binary + makeFile("assets/logo.png"), + makeFile("bin/app.exe"), + // Docs + makeFile("README.md"), + makeFile("docs/guide.md"), + } + + kept, stats := FilterFiles(files, FilterConfig{}) + + expectedKept := 9 // 6 regular + 3 always-include + if len(kept) != expectedKept { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected %d files kept, got %d: %v", expectedKept, len(kept), paths) + } + if stats.Tests != 3 { + t.Errorf("tests: got %d, want 3", stats.Tests) + } + if stats.Vendor != 2 { + t.Errorf("vendor: got %d, want 2", stats.Vendor) + } + if stats.Binary != 2 { + 
t.Errorf("binary: got %d, want 2", stats.Binary) + } + if stats.Docs != 2 { + t.Errorf("docs: got %d, want 2", stats.Docs) + } +} + +func TestFilterFiles_PreservesFileOrder(t *testing.T) { + files := []SourceFile{ + makeFile("z.go"), + makeFile("a.go"), + makeFile("m.go"), + } + + kept, _ := FilterFiles(files, FilterConfig{}) + + if len(kept) != 3 { + t.Fatalf("expected 3 files, got %d", len(kept)) + } + if kept[0].Path != "z.go" || kept[1].Path != "a.go" || kept[2].Path != "m.go" { + t.Errorf("filter should preserve order: got %v", []string{kept[0].Path, kept[1].Path, kept[2].Path}) + } +} + +func TestFilterFiles_PreservesContent(t *testing.T) { + content := "package main\n\nfunc main() {\n\tprintln(\"hello\")\n}\n" + files := []SourceFile{ + makeFileWithContent("main.go", content), + } + + kept, _ := FilterFiles(files, FilterConfig{}) + + if len(kept) != 1 { + t.Fatalf("expected 1 file, got %d", len(kept)) + } + if kept[0].Content != content { + t.Errorf("content not preserved: got %q, want %q", kept[0].Content, content) + } + if kept[0].LineCount != 5 { + t.Errorf("line count not preserved: got %d, want 5", kept[0].LineCount) + } +} + +func TestIsTestFile(t *testing.T) { + tests := []struct { + path string + want bool + }{ + {"main_test.go", true}, + {"test_helper.py", true}, + {"component_spec.ts", true}, + {"widget_spec.js", true}, + {"test/helper.go", true}, + {"tests/integration.py", true}, + {"__tests__/component.js", true}, + {"spec/models/user.rb", true}, + {"src/__tests__/nested.js", true}, + {"main.go", false}, + {"testing.go", false}, + {"testdata/fixture.go", false}, + {"contest/entry.go", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + got := isTestFile(tt.path) + if got != tt.want { + t.Errorf("isTestFile(%q) = %v, want %v", tt.path, got, tt.want) + } + }) + } +} + +func TestIsInVendorDir(t *testing.T) { + tests := []struct { + path string + want bool + }{ + {"vendor/dep.go", true}, + {"node_modules/pkg/index.js", 
true}, + {"__pycache__/module.pyc", true}, + {".venv/lib/site.py", true}, + {"venv/lib/site.py", true}, + {"src/main.go", false}, + {"vendoring/tool.go", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + got := isInVendorDir(tt.path) + if got != tt.want { + t.Errorf("isInVendorDir(%q) = %v, want %v", tt.path, got, tt.want) + } + }) + } +} + +func TestIsBinaryFile(t *testing.T) { + tests := []struct { + name string + content string + want bool + }{ + {"empty", "", false}, + {"text", "hello world\n", false}, + {"null at start", "\x00hello", true}, + {"null in middle", "hello\x00world", true}, + {"null at byte 511", strings.Repeat("A", 511) + "\x00" + strings.Repeat("B", 100), true}, + {"null at byte 512", strings.Repeat("A", 512) + "\x00" + strings.Repeat("B", 100), false}, + {"large text no null", strings.Repeat("hello\n", 1000), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isBinaryFile(tt.content) + if got != tt.want { + t.Errorf("isBinaryFile(%q...) 
= %v, want %v", tt.name, got, tt.want) + } + }) + } +} + +func TestMatchesAnyGlob(t *testing.T) { + tests := []struct { + name string + path string + patterns []string + want bool + }{ + {"empty patterns", "main.go", nil, false}, + {"exact match", "main.go", []string{"main.go"}, true}, + {"extension glob", "main.go", []string{"*.go"}, true}, + {"no match", "main.go", []string{"*.py"}, false}, + {"path match", "src/main.go", []string{"src/main.go"}, true}, + {"double star prefix", "src/gen/model.go", []string{"**/gen/*.go"}, true}, + {"base match", "deep/nested/file.pb.go", []string{"*.pb.go"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := matchesAnyGlob(tt.path, tt.patterns) + if got != tt.want { + t.Errorf("matchesAnyGlob(%q, %v) = %v, want %v", tt.path, tt.patterns, got, tt.want) + } + }) + } +} + +func TestFilterFiles_VendorProtoAlwaysIncluded(t *testing.T) { + // .proto inside vendor should still be included. + files := []SourceFile{ + makeFile("vendor/google/api/annotations.proto"), + makeFile("vendor/regular.go"), + } + + kept, stats := FilterFiles(files, FilterConfig{}) + + if len(kept) != 1 { + t.Fatalf("expected 1 file kept, got %d", len(kept)) + } + if kept[0].Path != "vendor/google/api/annotations.proto" { + t.Errorf("expected proto file kept, got %q", kept[0].Path) + } + if stats.Vendor != 1 { + t.Errorf("expected 1 vendor exclusion, got %d", stats.Vendor) + } +} + +func TestFilterFiles_NilConfigDefaults(t *testing.T) { + // Zero-value FilterConfig should apply all default exclusions. 
+ files := []SourceFile{ + makeFile("main.go"), + makeFile("main_test.go"), + makeFile("vendor/dep.go"), + makeFile("README.md"), + } + + kept, _ := FilterFiles(files, FilterConfig{}) + + if len(kept) != 1 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 1 file with zero-value config, got %d: %v", len(kept), paths) + } + if kept[0].Path != "main.go" { + t.Errorf("expected main.go, got %q", kept[0].Path) + } +} + +func TestFilterFiles_CustomIncludeAndExcludeCombined(t *testing.T) { + files := []SourceFile{ + makeFile("main.go"), + makeFile("main_test.go"), // test excluded by default + makeFile("integration_test.go"), // force-included via custom include + makeFile("config.yaml"), // excluded via custom exclude + } + + kept, stats := FilterFiles(files, FilterConfig{ + Include: []string{"integration_test.go"}, + Exclude: []string{"*.yaml"}, + }) + + if len(kept) != 2 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 2 files kept, got %d: %v", len(kept), paths) + } + + keptPaths := make(map[string]bool) + for _, f := range kept { + keptPaths[f.Path] = true + } + if !keptPaths["main.go"] { + t.Error("expected main.go to be kept") + } + if !keptPaths["integration_test.go"] { + t.Error("expected integration_test.go to be force-included") + } + if stats.Tests != 1 { + t.Errorf("expected 1 test exclusion (main_test.go), got %d", stats.Tests) + } + if stats.Custom != 1 { + t.Errorf("expected 1 custom exclusion (config.yaml), got %d", stats.Custom) + } +} + +func TestFilterFiles_MaxFileSize_ExcludesLargeFiles(t *testing.T) { + small := strings.Repeat("x", 1000) + large := strings.Repeat("x", 200000) // 200KB + + files := []SourceFile{ + makeFileWithContent("small.go", small), + makeFileWithContent("huge.tsx", large), + makeFileWithContent("big_fixture.json", large), + makeFileWithContent("normal.ts", small), + } + + kept, stats := FilterFiles(files, 
FilterConfig{MaxFileSize: 102400}) + + if len(kept) != 2 { + paths := make([]string, len(kept)) + for i, f := range kept { + paths[i] = f.Path + } + t.Fatalf("expected 2 files kept, got %d: %v", len(kept), paths) + } + if stats.Oversized != 2 { + t.Errorf("expected 2 oversized exclusions, got %d", stats.Oversized) + } +} + +func TestFilterFiles_MaxFileSize_ZeroDisablesLimit(t *testing.T) { + large := strings.Repeat("x", 500000) + + files := []SourceFile{ + makeFileWithContent("big.go", large), + } + + kept, stats := FilterFiles(files, FilterConfig{MaxFileSize: 0}) + + if len(kept) != 1 { + t.Fatalf("expected 1 file kept with no size limit, got %d", len(kept)) + } + if stats.Oversized != 0 { + t.Errorf("expected 0 oversized with no limit, got %d", stats.Oversized) + } +} + +func TestFilterFiles_MaxFileSize_CustomIncludeBypassesLimit(t *testing.T) { + large := strings.Repeat("x", 200000) + + files := []SourceFile{ + makeFileWithContent("important.go", large), + } + + kept, _ := FilterFiles(files, FilterConfig{ + MaxFileSize: 102400, + Include: []string{"important.go"}, + }) + + if len(kept) != 1 { + t.Fatal("expected force-included file to bypass size limit") + } +} + +func TestFilterFiles_MaxFileSize_AlwaysIncludeBypassesLimit(t *testing.T) { + large := strings.Repeat("x", 200000) + + files := []SourceFile{ + makeFileWithContent("schema.proto", large), + makeFileWithContent("migrations.sql", large), + } + + kept, _ := FilterFiles(files, FilterConfig{MaxFileSize: 102400}) + + if len(kept) != 2 { + t.Fatalf("expected 2 always-include files kept despite size, got %d", len(kept)) + } +} + +func TestFilterFiles_MaxFileSize_BoundaryValues(t *testing.T) { + exactly := strings.Repeat("x", 1024) + overBy1 := strings.Repeat("x", 1025) + + files := []SourceFile{ + makeFileWithContent("exact.go", exactly), + makeFileWithContent("over.go", overBy1), + } + + kept, stats := FilterFiles(files, FilterConfig{MaxFileSize: 1024}) + + if len(kept) != 1 { + t.Fatalf("expected 1 file kept 
at boundary, got %d", len(kept)) + } + if kept[0].Path != "exact.go" { + t.Errorf("expected exact.go kept, got %q", kept[0].Path) + } + if stats.Oversized != 1 { + t.Errorf("expected 1 oversized, got %d", stats.Oversized) + } +} diff --git a/internal/ingest/flattener.go b/internal/ingest/flattener.go new file mode 100644 index 0000000..14a20bc --- /dev/null +++ b/internal/ingest/flattener.go @@ -0,0 +1,313 @@ +package ingest + +import ( + "sort" + "strconv" + "strings" +) + +// FlattenResult holds the output of the Flatten operation. +type FlattenResult struct { + XML string // Repomix-compatible XML string. + FileMap FileMap // In-memory file contents for snippet extraction. + Tokens int // Token count (set to 0 here; counted by chunker later). +} + +// FlattenConfig controls XML output formatting. +type FlattenConfig struct { + Compress bool // When true, strips unnecessary whitespace from XML output. +} + +// Flatten converts a slice of SourceFiles into a repomix-compatible XML document +// with numbered lines, and builds an in-memory FileMap for snippet extraction. +func Flatten(files []SourceFile, cfg FlattenConfig) FlattenResult { + fm := make(FileMap, len(files)) + for _, f := range files { + fm[f.Path] = f.Content + } + + xml := buildXML(files, cfg) + + return FlattenResult{ + XML: xml, + FileMap: fm, + Tokens: 0, + } +} + +// FlattenFileMapOnly builds only the FileMap without generating the full XML +// document. Use this when the caller will use streaming token counting and +// may not need the full XML (e.g. multi-chunk repos where the chunker +// rebuilds per-file XML from FileMap anyway). +func FlattenFileMapOnly(files []SourceFile) FlattenResult { + fm := make(FileMap, len(files)) + for _, f := range files { + fm[f.Path] = f.Content + } + return FlattenResult{FileMap: fm} +} + +// BuildFullXML generates the complete repomix-format XML and stores it in the +// FlattenResult. 
Call this after FlattenFileMapOnly when the streaming token +// count confirms the repo fits in a single chunk. +func (fr *FlattenResult) BuildFullXML(files []SourceFile, cfg FlattenConfig) { + fr.XML = buildXML(files, cfg) +} + +// buildXML produces the full repomix-compatible XML document. +func buildXML(files []SourceFile, cfg FlattenConfig) string { + if len(files) == 0 { + return emptyXML(cfg) + } + + var b strings.Builder + // Pre-allocate a rough estimate to avoid repeated allocations. + b.Grow(estimateXMLSize(files)) + + writeHeader(&b, cfg) + writeDirectoryStructure(&b, files, cfg) + writeFiles(&b, files, cfg) + + return b.String() +} + +// emptyXML produces a valid document for zero files. +func emptyXML(cfg FlattenConfig) string { + var b strings.Builder + writeHeader(&b, cfg) + writeDirectoryStructure(&b, nil, cfg) + + if cfg.Compress { + b.WriteString("\n") + } else { + b.WriteString("\n\n") + } + return b.String() +} + +// writeHeader writes the repomix preamble and file_summary block. +func writeHeader(b *strings.Builder, cfg FlattenConfig) { + // gap returns "\n" in non-compressed mode to add blank lines between sections. + gap := "\n" + if cfg.Compress { + gap = "" + } + + b.WriteString("This file is a merged representation of the entire codebase, combined into a single document by Repomix.\n") + b.WriteString(gap) + b.WriteString("\n") + b.WriteString("\n") + b.WriteString("This file contains a packed representation of the entire repository's contents.\n") + b.WriteString("It is designed to be easily consumable by AI systems for analysis, code review,\n") + b.WriteString("or other automated processes.\n") + b.WriteString("\n") + b.WriteString(gap) + b.WriteString("\n") + b.WriteString("The content is organized as follows:\n") + b.WriteString("1. This summary section\n") + b.WriteString("2. Repository structure\n") + b.WriteString("3. 
Repository files, each preceded by its file path as an XML tag\n") + b.WriteString(gap) + b.WriteString("Each file's content is preceded by a line number prefix for reference.\n") + b.WriteString("\n") + b.WriteString(gap) + b.WriteString("\n") + b.WriteString("- Use the file path to understand the repository structure\n") + b.WriteString("- Use line numbers for precise code references\n") + b.WriteString("- Cross-reference files to understand dependencies and relationships\n") + b.WriteString("\n") + b.WriteString(gap) + b.WriteString("\n") + b.WriteString("- Some binary files may have been excluded\n") + b.WriteString("- Files are sorted by path for consistent ordering\n") + b.WriteString("\n") + b.WriteString("\n") + b.WriteString(gap) +} + +// writeDirectoryStructure writes the directory_structure block. +func writeDirectoryStructure(b *strings.Builder, files []SourceFile, cfg FlattenConfig) { + b.WriteString("\n") + if len(files) > 0 { + // Sort paths for deterministic output. + paths := make([]string, len(files)) + for i, f := range files { + paths[i] = f.Path + } + sort.Strings(paths) + for _, p := range paths { + b.WriteString(p) + b.WriteByte('\n') + } + } + if cfg.Compress { + b.WriteString("\n") + } else { + b.WriteString("\n\n") + } +} + +// writeFiles writes the block with numbered lines for each file. +func writeFiles(b *strings.Builder, files []SourceFile, cfg FlattenConfig) { + // Sort files by path for deterministic output. + sorted := make([]SourceFile, len(files)) + copy(sorted, files) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].Path < sorted[j].Path + }) + + b.WriteString("\n") + + for i, f := range sorted { + writeFileElement(b, f) + // Add blank line between files in non-compressed mode, but not after the last. + if !cfg.Compress && i < len(sorted)-1 { + b.WriteByte('\n') + } + } + + b.WriteString("\n") +} + +// writeFileElement writes a single element with numbered lines. 
+func writeFileElement(b *strings.Builder, f SourceFile) { + escapedPath := EscapeXMLAttr(f.Path) + b.WriteString("\n") + + writeNumberedLines(b, f.Content) + + b.WriteString("\n") +} + +// writeNumberedLines writes the file content with line numbers in repomix format. +// Format: "{padded_number} | {line}\n" where padding is based on total line count. +func writeNumberedLines(b *strings.Builder, content string) { + if content == "" { + return + } + + lines := SplitLines(content) + totalLines := len(lines) + padding := len(strconv.Itoa(totalLines)) + + for i, line := range lines { + num := i + 1 + b.WriteString(padLeft(strconv.Itoa(num), padding)) + b.WriteString(" | ") + b.WriteString(EscapeXMLContent(line)) + b.WriteByte('\n') + } +} + +// SplitLines splits content into lines, handling the trailing newline edge case. +// A file with content "a\nb\n" has 2 lines: ["a", "b"]. +// A file with content "a\nb" (no trailing newline) has 2 lines: ["a", "b"]. +// An empty string returns nil. +func SplitLines(content string) []string { + if content == "" { + return nil + } + // Remove a single trailing newline to avoid an empty phantom line. + trimmed := strings.TrimSuffix(content, "\n") + if trimmed == "" { + // Content was just "\n" — treat as a single empty line. + return []string{""} + } + return strings.Split(trimmed, "\n") +} + +// EnvelopeXML returns the header and directory-structure XML without any file +// content. Used by the streaming token counter to measure envelope overhead +// without building the full document. +func EnvelopeXML(paths []string, cfg FlattenConfig) string { + var b strings.Builder + writeHeader(&b, cfg) + + // Write directory structure from paths directly (no SourceFile needed). 
// EscapeXMLContent escapes the characters that are special in XML text
// content: '&' becomes "&amp;", '<' becomes "&lt;" and '>' becomes "&gt;".
// All other bytes are copied through unchanged, so non-ASCII text (and even
// invalid UTF-8) is preserved exactly, matching the no-escape fast path.
func EscapeXMLContent(s string) string {
	// Fast path: if no special chars, return as-is.
	if !strings.ContainsAny(s, "&<>") {
		return s
	}
	var b strings.Builder
	b.Grow(len(s) + 10) // slight overallocation for escapes
	// Byte-wise iteration is safe: the special characters are ASCII and can
	// never occur inside a multi-byte UTF-8 sequence.
	for i := 0; i < len(s); i++ {
		switch c := s[i]; c {
		case '&':
			b.WriteString("&amp;")
		case '<':
			b.WriteString("&lt;")
		case '>':
			b.WriteString("&gt;")
		default:
			b.WriteByte(c)
		}
	}
	return b.String()
}

// EscapeXMLAttr escapes the characters that are special in XML attribute
// values: the three handled by EscapeXMLContent plus '"' ("&quot;") and '\''
// ("&apos;"), so the result is safe inside either quoting style. All other
// bytes are copied through unchanged.
func EscapeXMLAttr(s string) string {
	if !strings.ContainsAny(s, "&<>\"'") {
		return s
	}
	var b strings.Builder
	b.Grow(len(s) + 10)
	for i := 0; i < len(s); i++ {
		switch c := s[i]; c {
		case '&':
			b.WriteString("&amp;")
		case '<':
			b.WriteString("&lt;")
		case '>':
			b.WriteString("&gt;")
		case '"':
			b.WriteString("&quot;")
		case '\'':
			b.WriteString("&apos;")
		default:
			b.WriteByte(c)
		}
	}
	return b.String()
}
+ size += 50 + len(f.Content) + f.LineCount*8 + } + return size +} diff --git a/internal/ingest/flattener_test.go b/internal/ingest/flattener_test.go new file mode 100644 index 0000000..9037b2f --- /dev/null +++ b/internal/ingest/flattener_test.go @@ -0,0 +1,881 @@ +package ingest + +import ( + "fmt" + "strings" + "testing" +) + +func TestFlatten_EmptyInput(t *testing.T) { + result := Flatten(nil, FlattenConfig{}) + + if len(result.FileMap) != 0 { + t.Errorf("FileMap should be empty, got %d entries", len(result.FileMap)) + } + if result.Tokens != 0 { + t.Errorf("Tokens should be 0, got %d", result.Tokens) + } + // Should still produce valid XML structure. + if !strings.Contains(result.XML, "") { + t.Error("empty XML should contain element") + } + if !strings.Contains(result.XML, "") { + t.Error("empty XML should contain closing element") + } + if !strings.Contains(result.XML, "") { + t.Error("empty XML should contain directory_structure element") + } +} + +func TestFlatten_EmptySlice(t *testing.T) { + result := Flatten([]SourceFile{}, FlattenConfig{}) + + if len(result.FileMap) != 0 { + t.Errorf("FileMap should be empty, got %d entries", len(result.FileMap)) + } + if !strings.Contains(result.XML, "") { + t.Error("empty XML should contain element") + } +} + +func TestFlatten_SingleFile(t *testing.T) { + files := []SourceFile{ + { + Path: "main.go", + Content: "package main\n\nfunc main() {}\n", + LineCount: 3, + Language: "go", + }, + } + + result := Flatten(files, FlattenConfig{}) + + // Check FileMap. + if content, ok := result.FileMap["main.go"]; !ok { + t.Error("FileMap should contain main.go") + } else if content != "package main\n\nfunc main() {}\n" { + t.Errorf("FileMap content mismatch: got %q", content) + } + + // Check XML structure. 
+ if !strings.Contains(result.XML, ``) { + t.Error("XML should contain file element with path attribute") + } + if !strings.Contains(result.XML, "") { + t.Error("XML should contain closing file tag") + } + if !strings.Contains(result.XML, "1 | package main") { + t.Error("XML should contain numbered line 1") + } + if !strings.Contains(result.XML, "2 | ") { + t.Error("XML should contain numbered line 2 (empty line)") + } + if !strings.Contains(result.XML, "3 | func main() {}") { + t.Error("XML should contain numbered line 3") + } +} + +func TestFlatten_LineNumberPadding(t *testing.T) { + // Create a file with 12 lines to test padding (2-digit numbers). + var content strings.Builder + for i := 1; i <= 12; i++ { + fmt.Fprintf(&content, "line %d\n", i) + } + + files := []SourceFile{ + { + Path: "big.txt", + Content: content.String(), + LineCount: 12, + Language: "", + }, + } + + result := Flatten(files, FlattenConfig{}) + + // Single-digit lines should be padded. + if !strings.Contains(result.XML, " 1 | line 1") { + t.Error("line 1 should be padded: ' 1 | line 1'") + } + if !strings.Contains(result.XML, " 9 | line 9") { + t.Error("line 9 should be padded: ' 9 | line 9'") + } + // Two-digit lines should not be padded. + if !strings.Contains(result.XML, "10 | line 10") { + t.Error("line 10 should not be padded: '10 | line 10'") + } + if !strings.Contains(result.XML, "12 | line 12") { + t.Error("line 12 should not be padded: '12 | line 12'") + } +} + +func TestFlatten_LineNumberPadding_ThreeDigits(t *testing.T) { + // Create a file with 100 lines. + var content strings.Builder + for i := 1; i <= 100; i++ { + fmt.Fprintf(&content, "L%d\n", i) + } + + files := []SourceFile{ + { + Path: "hundred.txt", + Content: content.String(), + LineCount: 100, + Language: "", + }, + } + + result := Flatten(files, FlattenConfig{}) + + // Check 3-digit padding. 
+ if !strings.Contains(result.XML, " 1 | L1") { + t.Errorf("line 1 should have 2 spaces of padding for 3-digit width") + } + if !strings.Contains(result.XML, " 10 | L10") { + t.Errorf("line 10 should have 1 space of padding for 3-digit width") + } + if !strings.Contains(result.XML, "100 | L100") { + t.Errorf("line 100 should have no padding") + } +} + +func TestFlatten_FileMap_AllFilesPresent(t *testing.T) { + files := []SourceFile{ + {Path: "a.go", Content: "package a\n", LineCount: 1, Language: "go"}, + {Path: "b/c.go", Content: "package c\n", LineCount: 1, Language: "go"}, + {Path: "d/e/f.py", Content: "print('hi')\n", LineCount: 1, Language: "python"}, + } + + result := Flatten(files, FlattenConfig{}) + + if len(result.FileMap) != 3 { + t.Fatalf("FileMap should have 3 entries, got %d", len(result.FileMap)) + } + + expected := map[string]string{ + "a.go": "package a\n", + "b/c.go": "package c\n", + "d/e/f.py": "print('hi')\n", + } + + for path, want := range expected { + got, ok := result.FileMap[path] + if !ok { + t.Errorf("FileMap missing key %q", path) + continue + } + if got != want { + t.Errorf("FileMap[%q] = %q, want %q", path, got, want) + } + } +} + +func TestFlatten_NoTrailingNewline(t *testing.T) { + files := []SourceFile{ + { + Path: "no_newline.txt", + Content: "line one\nline two", + LineCount: 2, + Language: "", + }, + } + + result := Flatten(files, FlattenConfig{}) + + // Should have exactly 2 lines. + if !strings.Contains(result.XML, "1 | line one") { + t.Error("should contain line 1") + } + if !strings.Contains(result.XML, "2 | line two") { + t.Error("should contain line 2") + } + + // Should NOT have a line 3. 
+ if strings.Contains(result.XML, "3 |") { + t.Error("should not have a phantom line 3 from missing trailing newline") + } +} + +func TestFlatten_TrailingNewline(t *testing.T) { + files := []SourceFile{ + { + Path: "with_newline.txt", + Content: "line one\nline two\n", + LineCount: 2, + Language: "", + }, + } + + result := Flatten(files, FlattenConfig{}) + + // Should have exactly 2 lines. + if !strings.Contains(result.XML, "1 | line one") { + t.Error("should contain line 1") + } + if !strings.Contains(result.XML, "2 | line two") { + t.Error("should contain line 2") + } + if strings.Contains(result.XML, "3 |") { + t.Error("should not have phantom line 3") + } +} + +func TestFlatten_OnlyNewlineContent(t *testing.T) { + files := []SourceFile{ + { + Path: "just_newline.txt", + Content: "\n", + LineCount: 1, + Language: "", + }, + } + + result := Flatten(files, FlattenConfig{}) + + // Content "\n" should produce one empty line. + if !strings.Contains(result.XML, "1 | \n") { + t.Error("content '\\n' should produce a single empty numbered line") + } +} + +func TestFlatten_XMLSpecialCharacterEscaping(t *testing.T) { + files := []SourceFile{ + { + Path: "special.html", + Content: "
& foo > bar
\n", + LineCount: 1, + Language: "html", + }, + } + + result := Flatten(files, FlattenConfig{}) + + // Content should have XML special chars escaped. + if !strings.Contains(result.XML, "<div class=\"test\">&amp; foo > bar</div>") { + // Find the actual line for debugging. + for _, line := range strings.Split(result.XML, "\n") { + if strings.Contains(line, "1 |") && strings.Contains(line, "div") { + t.Errorf("XML content escaping incorrect, got line: %q", line) + return + } + } + t.Error("could not find the file content line in XML output") + } +} + +func TestFlatten_XMLEscaping_Ampersand(t *testing.T) { + files := []SourceFile{ + {Path: "amp.txt", Content: "a & b\n", LineCount: 1}, + } + result := Flatten(files, FlattenConfig{}) + if !strings.Contains(result.XML, "1 | a & b") { + t.Errorf("& should be escaped to & in XML content") + } +} + +func TestFlatten_XMLEscaping_LessThan(t *testing.T) { + files := []SourceFile{ + {Path: "lt.txt", Content: "a < b\n", LineCount: 1}, + } + result := Flatten(files, FlattenConfig{}) + if !strings.Contains(result.XML, "1 | a < b") { + t.Errorf("< should be escaped to < in XML content") + } +} + +func TestFlatten_XMLEscaping_GreaterThan(t *testing.T) { + files := []SourceFile{ + {Path: "gt.txt", Content: "a > b\n", LineCount: 1}, + } + result := Flatten(files, FlattenConfig{}) + if !strings.Contains(result.XML, "1 | a > b") { + t.Errorf("> should be escaped to > in XML content") + } +} + +func TestFlatten_XMLAttrEscaping_Path(t *testing.T) { + files := []SourceFile{ + {Path: "dir/file\"name.txt", Content: "content\n", LineCount: 1}, + } + result := Flatten(files, FlattenConfig{}) + if !strings.Contains(result.XML, ``) { + t.Errorf("double quotes in path should be escaped in attribute") + } +} + +func TestFlatten_UnicodePreservation(t *testing.T) { + files := []SourceFile{ + { + Path: "unicode.txt", + Content: "Hello 世界\nこんにちは\n🎉 emoji\n", + LineCount: 3, + Language: "", + }, + } + + result := Flatten(files, FlattenConfig{}) + + if 
!strings.Contains(result.XML, "1 | Hello 世界") { + t.Error("Chinese characters should be preserved") + } + if !strings.Contains(result.XML, "2 | こんにちは") { + t.Error("Japanese characters should be preserved") + } + if !strings.Contains(result.XML, "3 | 🎉 emoji") { + t.Error("emoji should be preserved") + } + + // FileMap should also preserve Unicode. + if content, ok := result.FileMap["unicode.txt"]; ok { + if !strings.Contains(content, "世界") { + t.Error("FileMap should preserve Unicode content") + } + } else { + t.Error("FileMap should contain unicode.txt") + } +} + +func TestFlatten_LargeFile(t *testing.T) { + // Create a file with 10,000+ lines. + var content strings.Builder + lineCount := 10500 + for i := 1; i <= lineCount; i++ { + fmt.Fprintf(&content, "line number %d with some content\n", i) + } + + files := []SourceFile{ + { + Path: "large.txt", + Content: content.String(), + LineCount: lineCount, + Language: "", + }, + } + + result := Flatten(files, FlattenConfig{}) + + // Check that all lines are numbered. + // Line 1 should be padded to 5 digits. + if !strings.Contains(result.XML, " 1 | line number 1 with some content") { + t.Error("first line should have 4 spaces padding (5-digit width)") + } + if !strings.Contains(result.XML, "10500 | line number 10500 with some content") { + t.Error("last line should be present with no padding") + } + + // Verify FileMap. + if _, ok := result.FileMap["large.txt"]; !ok { + t.Error("FileMap should contain large.txt") + } + + // Verify tokens is 0 (set by chunker later). + if result.Tokens != 0 { + t.Errorf("Tokens should be 0, got %d", result.Tokens) + } +} + +func TestFlatten_Compress(t *testing.T) { + files := []SourceFile{ + {Path: "a.go", Content: "package a\n", LineCount: 1, Language: "go"}, + } + + normal := Flatten(files, FlattenConfig{Compress: false}) + compressed := Flatten(files, FlattenConfig{Compress: true}) + + // Compressed output should be shorter (no blank lines between sections). 
+ if len(compressed.XML) >= len(normal.XML) { + t.Errorf("compressed XML (%d bytes) should be shorter than normal (%d bytes)", + len(compressed.XML), len(normal.XML)) + } + + // Both should contain the file content. + if !strings.Contains(compressed.XML, "1 | package a") { + t.Error("compressed XML should still contain numbered lines") + } + if !strings.Contains(compressed.XML, ``) { + t.Error("compressed XML should still contain file element") + } + + // Normal should have blank lines between sections. + if !strings.Contains(normal.XML, "\n\n") { + t.Error("normal XML should have blank line after ") + } + + // Compressed should NOT have blank lines between sections. + if strings.Contains(compressed.XML, "\n\n") { + t.Error("compressed XML should not have blank line after ") + } +} + +func TestFlatten_Compress_ReducesSize(t *testing.T) { + // Create multiple files. + files := []SourceFile{ + {Path: "a.go", Content: "package a\n\nfunc A() {}\n", LineCount: 3, Language: "go"}, + {Path: "b.go", Content: "package b\n\nfunc B() {}\n", LineCount: 3, Language: "go"}, + {Path: "c.go", Content: "package c\n\nfunc C() {}\n", LineCount: 3, Language: "go"}, + } + + normal := Flatten(files, FlattenConfig{Compress: false}) + compressed := Flatten(files, FlattenConfig{Compress: true}) + + if len(compressed.XML) >= len(normal.XML) { + t.Errorf("compressed (%d bytes) should be smaller than normal (%d bytes)", + len(compressed.XML), len(normal.XML)) + } +} + +func TestFlatten_DirectoryStructure(t *testing.T) { + files := []SourceFile{ + {Path: "b/handler.go", Content: "package b\n", LineCount: 1, Language: "go"}, + {Path: "a/main.go", Content: "package a\n", LineCount: 1, Language: "go"}, + {Path: "c.txt", Content: "hello\n", LineCount: 1, Language: ""}, + } + + result := Flatten(files, FlattenConfig{}) + + // Directory structure should list paths sorted. 
+ dsStart := strings.Index(result.XML, "") + dsEnd := strings.Index(result.XML, "") + if dsStart == -1 || dsEnd == -1 { + t.Fatal("missing directory_structure element") + } + ds := result.XML[dsStart:dsEnd] + + aPos := strings.Index(ds, "a/main.go") + bPos := strings.Index(ds, "b/handler.go") + cPos := strings.Index(ds, "c.txt") + + if aPos == -1 || bPos == -1 || cPos == -1 { + t.Fatalf("directory structure missing paths, got: %s", ds) + } + if !(aPos < bPos && bPos < cPos) { + t.Error("directory structure paths should be sorted alphabetically") + } +} + +func TestFlatten_FilesSortedByPath(t *testing.T) { + files := []SourceFile{ + {Path: "z.go", Content: "package z\n", LineCount: 1, Language: "go"}, + {Path: "a.go", Content: "package a\n", LineCount: 1, Language: "go"}, + {Path: "m.go", Content: "package m\n", LineCount: 1, Language: "go"}, + } + + result := Flatten(files, FlattenConfig{}) + + aPos := strings.Index(result.XML, ``) + mPos := strings.Index(result.XML, ``) + zPos := strings.Index(result.XML, ``) + + if aPos == -1 || mPos == -1 || zPos == -1 { + t.Fatal("not all file elements found in XML") + } + if !(aPos < mPos && mPos < zPos) { + t.Error("files should be sorted by path in XML output") + } +} + +func TestFlatten_HeaderPresent(t *testing.T) { + files := []SourceFile{ + {Path: "test.go", Content: "package test\n", LineCount: 1, Language: "go"}, + } + + result := Flatten(files, FlattenConfig{}) + + if !strings.HasPrefix(result.XML, "This file is a merged representation") { + t.Error("XML should start with repomix header text") + } + if !strings.Contains(result.XML, "") { + t.Error("XML should contain file_summary element") + } + if !strings.Contains(result.XML, "") { + t.Error("XML should contain purpose element") + } + if !strings.Contains(result.XML, "") { + t.Error("XML should contain file_format element") + } + if !strings.Contains(result.XML, "") { + t.Error("XML should contain usage_guidelines element") + } +} + +func 
TestFlatten_TokensAlwaysZero(t *testing.T) { + files := []SourceFile{ + {Path: "a.go", Content: "package a\n", LineCount: 1, Language: "go"}, + } + + result := Flatten(files, FlattenConfig{}) + if result.Tokens != 0 { + t.Errorf("Tokens should always be 0 (set by chunker), got %d", result.Tokens) + } +} + +func TestFlatten_MultipleFiles(t *testing.T) { + files := []SourceFile{ + {Path: "cmd/main.go", Content: "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"hello\")\n}\n", LineCount: 7, Language: "go"}, + {Path: "lib/util.go", Content: "package lib\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n", LineCount: 5, Language: "go"}, + } + + result := Flatten(files, FlattenConfig{}) + + // Both files should be in FileMap. + if len(result.FileMap) != 2 { + t.Fatalf("FileMap should have 2 entries, got %d", len(result.FileMap)) + } + + // Both files should appear in XML. + if !strings.Contains(result.XML, ``) { + t.Error("XML should contain cmd/main.go") + } + if !strings.Contains(result.XML, ``) { + t.Error("XML should contain lib/util.go") + } +} + +func TestFlatten_EmptyFileContent(t *testing.T) { + files := []SourceFile{ + {Path: "empty.txt", Content: "", LineCount: 0, Language: ""}, + } + + result := Flatten(files, FlattenConfig{}) + + // File element should exist but have no numbered lines. + if !strings.Contains(result.XML, ``) { + t.Error("XML should contain empty file element") + } + + // Should go directly from opening tag to closing tag. + idx := strings.Index(result.XML, ``) + if idx == -1 { + t.Fatal("could not find file element") + } + // After the opening tag and newline, the next line should be . 
+ afterTag := result.XML[idx+len(``)+1:] + if !strings.HasPrefix(afterTag, "") { + t.Errorf("empty file should have no content between tags, got: %q", afterTag[:min(50, len(afterTag))]) + } +} + +func TestSplitLines(t *testing.T) { + tests := []struct { + name string + content string + want []string + }{ + {"empty", "", nil}, + {"single line no newline", "hello", []string{"hello"}}, + {"single line with newline", "hello\n", []string{"hello"}}, + {"two lines with newline", "a\nb\n", []string{"a", "b"}}, + {"two lines no trailing", "a\nb", []string{"a", "b"}}, + {"blank lines", "a\n\nb\n", []string{"a", "", "b"}}, + {"only newline", "\n", []string{""}}, + {"multiple blank lines", "\n\n\n", []string{"", "", ""}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SplitLines(tt.content) + if tt.want == nil { + // SplitLines returns nil for empty content via early return. + if len(got) != 0 { + t.Errorf("SplitLines(%q) = %v, want nil/empty", tt.content, got) + } + return + } + if len(got) != len(tt.want) { + t.Errorf("SplitLines(%q) has %d lines, want %d: got %v", tt.content, len(got), len(tt.want), got) + return + } + for i, w := range tt.want { + if got[i] != w { + t.Errorf("SplitLines(%q)[%d] = %q, want %q", tt.content, i, got[i], w) + } + } + }) + } +} + +func TestPadLeft(t *testing.T) { + tests := []struct { + s string + width int + want string + }{ + {"1", 1, "1"}, + {"1", 3, " 1"}, + {"10", 3, " 10"}, + {"100", 3, "100"}, + {"1000", 3, "1000"}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%s_w%d", tt.s, tt.width), func(t *testing.T) { + got := padLeft(tt.s, tt.width) + if got != tt.want { + t.Errorf("padLeft(%q, %d) = %q, want %q", tt.s, tt.width, got, tt.want) + } + }) + } +} + +func TestEscapeXMLContent(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"no special chars", "hello world", "hello world"}, + {"ampersand", "a & b", "a & b"}, + {"less than", "a < b", "a < b"}, + 
{"greater than", "a > b", "a > b"}, + {"all special", "&value", "<tag>&value</tag>"}, + {"empty string", "", ""}, + {"unicode", "Hello 世界", "Hello 世界"}, + {"multiple amps", "a && b && c", "a && b && c"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := EscapeXMLContent(tt.input) + if got != tt.want { + t.Errorf("EscapeXMLContent(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestEscapeXMLAttr(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"no special chars", "path/to/file.go", "path/to/file.go"}, + {"double quote", `path/"file".go`, `path/"file".go`}, + {"single quote", "it's", "it's"}, + {"ampersand", "a&b", "a&b"}, + {"mixed", `<"path">`, `<"path">`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := EscapeXMLAttr(tt.input) + if got != tt.want { + t.Errorf("EscapeXMLAttr(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestFlatten_LineNumberAccuracy(t *testing.T) { + // Verify exact line number format for specific files. 
+ tests := []struct { + name string + content string + wantLines []string + }{ + { + name: "single line", + content: "hello\n", + wantLines: []string{ + "1 | hello", + }, + }, + { + name: "three lines", + content: "a\nb\nc\n", + wantLines: []string{ + "1 | a", + "2 | b", + "3 | c", + }, + }, + { + name: "ten lines - padding kicks in", + content: "1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n", + wantLines: []string{ + " 1 | 1", + " 2 | 2", + " 9 | 9", + "10 | 10", + }, + }, + { + name: "no trailing newline", + content: "first\nsecond", + wantLines: []string{ + "1 | first", + "2 | second", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + files := []SourceFile{ + {Path: "test.txt", Content: tt.content, LineCount: countLines(tt.content)}, + } + result := Flatten(files, FlattenConfig{}) + + for _, want := range tt.wantLines { + if !strings.Contains(result.XML, want) { + t.Errorf("expected line %q not found in XML output", want) + } + } + }) + } +} + +func TestFlatten_XMLWellFormedness(t *testing.T) { + files := []SourceFile{ + {Path: "a.go", Content: "package a\n", LineCount: 1, Language: "go"}, + {Path: "b.go", Content: "package b\n", LineCount: 1, Language: "go"}, + } + + result := Flatten(files, FlattenConfig{}) + + // Count opening and closing tags. + pairs := []struct { + open string + close string + }{ + {"", ""}, + {"", ""}, + {"", ""}, + {"", ""}, + {"", ""}, + {"", ""}, + {"", ""}, + } + + for _, p := range pairs { + openCount := strings.Count(result.XML, p.open) + closeCount := strings.Count(result.XML, p.close) + if openCount != closeCount { + t.Errorf("mismatched tags: %s (%d) vs %s (%d)", p.open, openCount, p.close, closeCount) + } + if openCount == 0 { + t.Errorf("missing element: %s", p.open) + } + } + + // Check file tags match. 
+ fileOpens := strings.Count(result.XML, "") + if fileOpens != fileCloses { + t.Errorf("mismatched file tags: %d opens vs %d closes", fileOpens, fileCloses) + } + if fileOpens != 2 { + t.Errorf("expected 2 file elements, got %d", fileOpens) + } +} + +func TestFlatten_Compress_StillValidStructure(t *testing.T) { + files := []SourceFile{ + {Path: "a.go", Content: "package a\n", LineCount: 1, Language: "go"}, + } + + result := Flatten(files, FlattenConfig{Compress: true}) + + // All required elements should still be present. + required := []string{ + "", "", + "", "", + "", "", + ``, "", + "1 | package a", + } + + for _, r := range required { + if !strings.Contains(result.XML, r) { + t.Errorf("compressed XML missing required element: %q", r) + } + } +} + +func TestFlatten_Compress_NoBlanks(t *testing.T) { + files := []SourceFile{ + {Path: "a.go", Content: "package a\n", LineCount: 1, Language: "go"}, + {Path: "b.go", Content: "package b\n", LineCount: 1, Language: "go"}, + } + + result := Flatten(files, FlattenConfig{Compress: true}) + + // Should not have double newlines. + if strings.Contains(result.XML, "\n\n") { + // Find where the double newline is for debugging. + idx := strings.Index(result.XML, "\n\n") + start := idx - 30 + if start < 0 { + start = 0 + } + end := idx + 30 + if end > len(result.XML) { + end = len(result.XML) + } + t.Errorf("compressed XML should not have double newlines, found at position %d: %q", idx, result.XML[start:end]) + } +} + +func TestFlatten_LargeFile_LineCount(t *testing.T) { + // Verify all 10K lines are present. + lineCount := 10000 + var content strings.Builder + for i := 1; i <= lineCount; i++ { + fmt.Fprintf(&content, "L%d\n", i) + } + + files := []SourceFile{ + {Path: "big.txt", Content: content.String(), LineCount: lineCount}, + } + + result := Flatten(files, FlattenConfig{}) + + // Count the numbered lines in the file element. 
+ fileStart := strings.Index(result.XML, ``) + fileEnd := strings.Index(result.XML, "") + if fileStart == -1 || fileEnd == -1 { + t.Fatal("could not find file element boundaries") + } + + fileContent := result.XML[fileStart:fileEnd] + // Count lines that match the number pattern. + lines := strings.Split(fileContent, "\n") + numberedCount := 0 + for _, line := range lines { + // Skip the opening tag line. + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "<") { + continue + } + numberedCount++ + } + + if numberedCount != lineCount { + t.Errorf("expected %d numbered lines, got %d", lineCount, numberedCount) + } +} + +func TestFlatten_MixedContent(t *testing.T) { + // Realistic mix of files with different characteristics. + files := []SourceFile{ + {Path: "main.go", Content: "package main\n\nfunc main() {}\n", LineCount: 3, Language: "go"}, + {Path: "config.yaml", Content: "key: value\nlist:\n - item1\n - item2\n", LineCount: 4, Language: "yaml"}, + {Path: "README.md", Content: "# Project\n\nDescription with tags & special chars.\n", LineCount: 3, Language: "markdown"}, + {Path: "empty.txt", Content: "", LineCount: 0, Language: ""}, + } + + result := Flatten(files, FlattenConfig{}) + + // Verify all files are in FileMap. + if len(result.FileMap) != 4 { + t.Errorf("FileMap should have 4 entries, got %d", len(result.FileMap)) + } + + // Verify HTML tags in README are escaped. + if !strings.Contains(result.XML, "<html>") { + t.Error("HTML tags in content should be escaped") + } + if !strings.Contains(result.XML, "& special") { + t.Error("& in content should be escaped") + } +} diff --git a/internal/ingest/imports.go b/internal/ingest/imports.go new file mode 100644 index 0000000..d807bae --- /dev/null +++ b/internal/ingest/imports.go @@ -0,0 +1,171 @@ +package ingest + +import ( + "path" + "path/filepath" + "regexp" + "strings" +) + +// jsImportFrom matches: import ... from '...' or "..." 
+var jsImportFrom = regexp.MustCompile(`(?m)(?:import|export)\s+.*?\s+from\s+['"](\.\.?/[^'"]+)['"]`) + +// jsRequire matches: require('...') or require("...") +var jsRequire = regexp.MustCompile(`(?:require|import)\s*\(\s*['"](\.\.?/[^'"]+)['"]\s*\)`) + +// jsSideEffectImport matches: import './foo' or import "../bar" +var jsSideEffectImport = regexp.MustCompile(`(?m)^import\s+['"](\.\.?/[^'"]+)['"]`) + +// pyRelativeImport matches: from . import x, from .foo import x, from ..foo import x +var pyRelativeImport = regexp.MustCompile(`(?m)^from\s+(\.+\w*(?:\.\w+)*)\s+import`) + +// jsExtensions are tried when a JS/TS import has no extension. +var jsExtensions = []string{".ts", ".tsx", ".js", ".jsx"} + +// jsIndexFiles are tried when a JS/TS import might refer to a directory. +var jsIndexFiles = []string{"/index.ts", "/index.tsx", "/index.js", "/index.jsx"} + +// ParseImports extracts local/relative import paths from a source file. +// path is the repo-root-relative file path; content is the file's text. +// Returns repo-root-relative paths that the file imports. +// This is best-effort; it never panics on malformed input. +func ParseImports(filePath string, content string) []string { + ext := strings.ToLower(filepath.Ext(filePath)) + switch ext { + case ".ts", ".tsx", ".js", ".jsx": + return parseJSImports(filePath, content) + case ".py": + return parsePyImports(filePath, content) + default: + // Go uses module paths, not relative imports — no parser needed. 
+ return nil + } +} + +func parseJSImports(filePath string, content string) []string { + dir := path.Dir(filePath) + seen := make(map[string]bool) + var result []string + + addImport := func(raw string) { + resolved := resolveJSImport(dir, raw) + for _, p := range resolved { + if !seen[p] { + seen[p] = true + result = append(result, p) + } + } + } + + for _, m := range jsImportFrom.FindAllStringSubmatch(content, -1) { + addImport(m[1]) + } + for _, m := range jsRequire.FindAllStringSubmatch(content, -1) { + addImport(m[1]) + } + for _, m := range jsSideEffectImport.FindAllStringSubmatch(content, -1) { + addImport(m[1]) + } + + return result +} + +func resolveJSImport(dir, importPath string) []string { + resolved := path.Join(dir, importPath) + + ext := path.Ext(resolved) + if ext != "" { + return []string{resolved} + } + + var candidates []string + for _, e := range jsExtensions { + candidates = append(candidates, resolved+e) + } + for _, idx := range jsIndexFiles { + candidates = append(candidates, resolved+idx) + } + return candidates +} + +func parsePyImports(filePath string, content string) []string { + dir := path.Dir(filePath) + seen := make(map[string]bool) + var result []string + + for _, m := range pyRelativeImport.FindAllStringSubmatch(content, -1) { + modulePart := m[1] // e.g. ".", ".foo", "..foo.bar" + + // Count leading dots. + dots := 0 + for _, c := range modulePart { + if c == '.' { + dots++ + } else { + break + } + } + + // Start from the importing file's directory, go up (dots-1) levels. + base := dir + for i := 1; i < dots; i++ { + base = path.Dir(base) + } + + // The rest after the dots is the module path. + rest := modulePart[dots:] + if rest == "" { + // "from . import x" — refers to __init__.py in current package. 
+ p := path.Join(base, "__init__.py") + if !seen[p] { + seen[p] = true + result = append(result, p) + } + continue + } + + // Convert dotted module name to path: foo.bar -> foo/bar + parts := strings.Split(rest, ".") + modPath := path.Join(append([]string{base}, parts...)...) + + // Could be a module file or a package directory. + candidates := []string{ + modPath + ".py", + path.Join(modPath, "__init__.py"), + } + for _, p := range candidates { + if !seen[p] { + seen[p] = true + result = append(result, p) + } + } + } + + return result +} + +// ResolveImports takes all source files and returns a map from each file path +// to its local import paths, filtered to only include paths that actually +// exist in the file set. +func ResolveImports(files []SourceFile) map[string][]string { + known := make(map[string]bool, len(files)) + for _, f := range files { + known[f.Path] = true + } + + result := make(map[string][]string) + for _, f := range files { + candidates := ParseImports(f.Path, f.Content) + var resolved []string + for _, c := range candidates { + if known[c] { + resolved = append(resolved, c) + } + } + if len(resolved) > 0 { + result[f.Path] = resolved + } + } + + return result +} diff --git a/internal/ingest/imports_test.go b/internal/ingest/imports_test.go new file mode 100644 index 0000000..155dd9e --- /dev/null +++ b/internal/ingest/imports_test.go @@ -0,0 +1,296 @@ +package ingest + +import ( + "sort" + "testing" +) + +func TestParseImports_JSImportFrom(t *testing.T) { + content := ` +import { foo } from './utils' +import bar from '../lib/bar' +import * as baz from './components/baz.ts' +` + got := ParseImports("src/app.ts", content) + + want := map[string]bool{ + "src/utils.ts": true, + "src/utils.tsx": true, + "src/utils.js": true, + "src/utils.jsx": true, + "src/utils/index.ts": true, + "src/utils/index.tsx": true, + "src/utils/index.js": true, + "src/utils/index.jsx": true, + "lib/bar.ts": true, + "lib/bar.tsx": true, + "lib/bar.js": true, + 
"lib/bar.jsx": true, + "lib/bar/index.ts": true, + "lib/bar/index.tsx": true, + "lib/bar/index.js": true, + "lib/bar/index.jsx": true, + "src/components/baz.ts": true, + } + + for _, p := range got { + if !want[p] { + t.Errorf("unexpected import path: %q", p) + } + } + + // Check that baz.ts (with explicit extension) resolves exactly. + found := false + for _, p := range got { + if p == "src/components/baz.ts" { + found = true + } + } + if !found { + t.Error("expected src/components/baz.ts in results") + } +} + +func TestParseImports_JSRequire(t *testing.T) { + content := ` +const x = require('./config') +const y = require("../shared/types") +` + got := ParseImports("src/index.js", content) + + hasConfig := false + hasTypes := false + for _, p := range got { + if p == "src/config.js" || p == "src/config.ts" { + hasConfig = true + } + if p == "shared/types.js" || p == "shared/types.ts" { + hasTypes = true + } + } + if !hasConfig { + t.Error("expected config import candidates") + } + if !hasTypes { + t.Error("expected shared/types import candidates") + } +} + +func TestParseImports_JSDynamicImport(t *testing.T) { + content := `const mod = import('./lazy-module')` + got := ParseImports("src/app.ts", content) + + found := false + for _, p := range got { + if p == "src/lazy-module.ts" { + found = true + } + } + if !found { + t.Error("expected src/lazy-module.ts in dynamic import results") + } +} + +func TestParseImports_JSIgnoresNpmPackages(t *testing.T) { + content := ` +import React from 'react' +import { useState } from 'react' +const express = require('express') +import lodash from 'lodash/fp' +import { foo } from './local' +` + got := ParseImports("src/app.tsx", content) + + for _, p := range got { + if p == "react" || p == "express" || p == "lodash/fp" { + t.Errorf("should not include npm package: %q", p) + } + } + + found := false + for _, p := range got { + if p == "src/local.ts" || p == "src/local.tsx" { + found = true + } + } + if !found { + t.Error("expected 
local import candidates") + } +} + +func TestParseImports_GoReturnsEmpty(t *testing.T) { + content := ` +package main + +import ( + "fmt" + "net/http" + "github.com/pkg/errors" +) +` + got := ParseImports("cmd/main.go", content) + + if len(got) != 0 { + t.Errorf("expected empty imports for Go, got %v", got) + } +} + +func TestParseImports_PythonRelativeImports(t *testing.T) { + content := ` +from . import utils +from .models import User +from ..shared import helpers +` + got := ParseImports("src/app/views.py", content) + + want := map[string]bool{ + "src/app/__init__.py": true, + "src/app/models.py": true, + "src/app/models/__init__.py": true, + "src/shared.py": true, + "src/shared/__init__.py": true, + } + + for _, p := range got { + if !want[p] { + t.Errorf("unexpected python import path: %q", p) + } + } + + // Verify the "from . import" yields __init__.py. + found := false + for _, p := range got { + if p == "src/app/__init__.py" { + found = true + } + } + if !found { + t.Error("expected src/app/__init__.py for 'from . 
import utils'") + } +} + +func TestParseImports_PythonIgnoresAbsoluteImports(t *testing.T) { + content := ` +import os +import sys +from collections import defaultdict +from .local import thing +` + got := ParseImports("pkg/module.py", content) + + for _, p := range got { + if p == "os" || p == "sys" || p == "collections" { + t.Errorf("should not include stdlib import: %q", p) + } + } + + found := false + for _, p := range got { + if p == "pkg/local.py" || p == "pkg/local/__init__.py" { + found = true + } + } + if !found { + t.Error("expected pkg/local.py or pkg/local/__init__.py") + } +} + +func TestResolveImports_FiltersToExistingPaths(t *testing.T) { + files := []SourceFile{ + { + Path: "src/app.ts", + Content: `import { foo } from './utils'`, + Language: "typescript", + }, + { + Path: "src/utils.ts", + Content: `export const foo = 1`, + Language: "typescript", + }, + { + Path: "src/other.ts", + Content: `import { bar } from './nonexistent'`, + Language: "typescript", + }, + } + + result := ResolveImports(files) + + // src/app.ts imports ./utils -> src/utils.ts exists. + appImports, ok := result["src/app.ts"] + if !ok { + t.Fatal("expected src/app.ts in result map") + } + found := false + for _, p := range appImports { + if p == "src/utils.ts" { + found = true + } + } + if !found { + t.Errorf("expected src/utils.ts in resolved imports, got %v", appImports) + } + + // src/other.ts imports ./nonexistent -> nothing matches. 
+ if otherImports, ok := result["src/other.ts"]; ok { + t.Errorf("expected no resolved imports for src/other.ts, got %v", otherImports) + } +} + +func TestResolveImports_PythonCrossFile(t *testing.T) { + files := []SourceFile{ + { + Path: "pkg/views.py", + Content: `from .models import User`, + Language: "python", + }, + { + Path: "pkg/models.py", + Content: `class User: pass`, + Language: "python", + }, + } + + result := ResolveImports(files) + + imports, ok := result["pkg/views.py"] + if !ok { + t.Fatal("expected pkg/views.py in result map") + } + + sort.Strings(imports) + if len(imports) != 1 || imports[0] != "pkg/models.py" { + t.Errorf("expected [pkg/models.py], got %v", imports) + } +} + +func TestParseImports_JSSideEffectImport(t *testing.T) { + content := `import './polyfills' +import "../setup" +` + got := ParseImports("src/main.ts", content) + + hasPolyfills := false + hasSetup := false + for _, p := range got { + if p == "src/polyfills.ts" || p == "src/polyfills.js" { + hasPolyfills = true + } + if p == "setup.ts" || p == "setup.js" { + hasSetup = true + } + } + if !hasPolyfills { + t.Error("expected polyfills import candidates") + } + if !hasSetup { + t.Error("expected setup import candidates") + } +} + +func TestParseImports_UnknownExtensionReturnsNil(t *testing.T) { + got := ParseImports("readme.md", "# Hello\nSome text") + if got != nil { + t.Errorf("expected nil for unknown extension, got %v", got) + } +} diff --git a/internal/ingest/walker.go b/internal/ingest/walker.go new file mode 100644 index 0000000..6043e6b --- /dev/null +++ b/internal/ingest/walker.go @@ -0,0 +1,308 @@ +package ingest + +import ( + "bytes" + "fmt" + "io/fs" + "log/slog" + "os" + "path/filepath" + "strings" + "unicode/utf8" + + ignore "github.com/sabhiram/go-gitignore" +) + +// SourceFile represents a single source file with metadata. 
+type SourceFile struct { + Path string // Relative path from repo root (e.g., "src/api/handlers.go") + Content string // Full file content (UTF-8 validated) + LineCount int // Number of lines + Language string // Inferred language (e.g., "go", "python") +} + +// FileMap maps file paths to their content for snippet extraction. +type FileMap map[string]string + +// langExtensions maps file extensions to language names. +var langExtensions = map[string]string{ + ".go": "go", + ".py": "python", + ".ts": "typescript", + ".tsx": "typescript", + ".js": "javascript", + ".jsx": "javascript", + ".java": "java", + ".rb": "ruby", + ".rs": "rust", + ".c": "c", + ".h": "c", + ".cpp": "cpp", + ".cc": "cpp", + ".hpp": "cpp", + ".cs": "csharp", + ".swift": "swift", + ".kt": "kotlin", + ".kts": "kotlin", + ".scala": "scala", + ".php": "php", + ".sh": "shell", + ".bash": "shell", + ".zsh": "shell", + ".yaml": "yaml", + ".yml": "yaml", + ".json": "json", + ".xml": "xml", + ".html": "html", + ".htm": "html", + ".css": "css", + ".scss": "scss", + ".sql": "sql", + ".proto": "protobuf", + ".r": "r", + ".R": "r", + ".pl": "perl", + ".pm": "perl", + ".lua": "lua", + ".dart": "dart", + ".tf": "terraform", + ".hcl": "hcl", + ".md": "markdown", + ".toml": "toml", + ".ini": "ini", + ".cfg": "ini", + ".dockerfile": "dockerfile", + ".ex": "elixir", + ".exs": "elixir", + ".erl": "erlang", + ".hs": "haskell", + ".ml": "ocaml", + ".vue": "vue", + ".svelte": "svelte", +} + +// InferLanguage returns the language name for a given file path based on extension. +// Returns empty string for unrecognized extensions. +func InferLanguage(path string) string { + // Handle Dockerfile specially (no extension). 
+ base := filepath.Base(path) + lower := strings.ToLower(base) + if lower == "dockerfile" || strings.HasPrefix(lower, "dockerfile.") { + return "dockerfile" + } + if lower == "makefile" || lower == "gnumakefile" { + return "makefile" + } + + ext := strings.ToLower(filepath.Ext(path)) + if lang, ok := langExtensions[ext]; ok { + return lang + } + return "" +} + +// isBinaryContent checks whether data looks like binary content by searching +// for a null byte in the first 512 bytes. +func isBinaryContent(data []byte) bool { + limit := 512 + if len(data) < limit { + limit = len(data) + } + return bytes.Contains(data[:limit], []byte{0}) +} + +// sanitizeUTF8 replaces invalid UTF-8 bytes with the Unicode replacement character. +func sanitizeUTF8(data []byte) string { + if utf8.Valid(data) { + return string(data) + } + var buf strings.Builder + buf.Grow(len(data)) + for len(data) > 0 { + r, size := utf8.DecodeRune(data) + if r == utf8.RuneError && size <= 1 { + buf.WriteRune(utf8.RuneError) + data = data[1:] + } else { + buf.WriteRune(r) + data = data[size:] + } + } + return buf.String() +} + +// countLines returns the number of lines in s. +// An empty string has 0 lines, a string with no newline has 1 line. +func countLines(s string) int { + if s == "" { + return 0 + } + n := strings.Count(s, "\n") + if !strings.HasSuffix(s, "\n") { + n++ + } + return n +} + +// gitignoreStack tracks a stack of compiled .gitignore matchers for nested +// .gitignore support. Each entry is associated with the directory depth at which +// the .gitignore was found. +type gitignoreStack struct { + entries []gitignoreEntry +} + +type gitignoreEntry struct { + matcher *ignore.GitIgnore + depth int +} + +func (s *gitignoreStack) push(matcher *ignore.GitIgnore, depth int) { + s.entries = append(s.entries, gitignoreEntry{matcher: matcher, depth: depth}) +} + +// popAbove removes matchers that were added at a depth greater than the given depth. 
+func (s *gitignoreStack) popAbove(depth int) { + for len(s.entries) > 0 && s.entries[len(s.entries)-1].depth > depth { + s.entries = s.entries[:len(s.entries)-1] + } +} + +// isIgnored checks whether the given relative path is matched by any active +// .gitignore in the stack. Paths are checked against all active matchers. +func (s *gitignoreStack) isIgnored(relPath string, isDir bool) bool { + pathToCheck := relPath + if isDir { + pathToCheck = relPath + "/" + } + for _, entry := range s.entries { + if entry.matcher.MatchesPath(pathToCheck) { + return true + } + } + return false +} + +// WalkDir recursively walks the directory tree rooted at root, returning all +// source files that are not ignored by .gitignore, not symlinks, and not binary. +// Permission errors are logged as warnings but do not halt the walk. +func WalkDir(root string) ([]SourceFile, error) { + root, err := filepath.Abs(root) + if err != nil { + return nil, err + } + + // Verify the root directory exists. + info, err := os.Stat(root) + if err != nil { + return nil, fmt.Errorf("cannot access root directory: %w", err) + } + if !info.IsDir() { + return nil, fmt.Errorf("root path is not a directory: %s", root) + } + + var files []SourceFile + stack := &gitignoreStack{} + + // Load root .gitignore if present. + rootGitignore := filepath.Join(root, ".gitignore") + if gi, err := ignore.CompileIgnoreFile(rootGitignore); err == nil { + stack.push(gi, 0) + } + + err = filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + slog.Warn("permission error, skipping", "path", path, "error", err) + return nil + } + + // filepath.Rel cannot fail here since both root and path are absolute, + // but propagate the error to avoid silent corruption if assumptions change. + relPath, relErr := filepath.Rel(root, path) + if relErr != nil { + return relErr + } + + // Skip the root directory itself. + if relPath == "." 
{ + return nil + } + + depth := strings.Count(relPath, string(filepath.Separator)) + + // Pop gitignore entries from deeper directories we've left. + stack.popAbove(depth) + + // Skip .git directory. + if d.IsDir() && d.Name() == ".git" { + return fs.SkipDir + } + + // Check symlinks via Lstat. + info, lstatErr := os.Lstat(path) + if lstatErr != nil { + slog.Warn("cannot stat file, skipping", "path", relPath, "error", lstatErr) + return nil + } + if info.Mode()&os.ModeSymlink != 0 { + slog.Debug("skipping symlink", "path", relPath) + if info.IsDir() { + return fs.SkipDir + } + return nil + } + + // Check gitignore patterns. + if stack.isIgnored(relPath, d.IsDir()) { + slog.Debug("ignored by .gitignore", "path", relPath) + if d.IsDir() { + return fs.SkipDir + } + return nil + } + + // For directories, check for nested .gitignore. + if d.IsDir() { + nestedGitignore := filepath.Join(path, ".gitignore") + if gi, loadErr := ignore.CompileIgnoreFile(nestedGitignore); loadErr == nil { + stack.push(gi, depth+1) + } + return nil + } + + // Skip .gitignore files — they are metadata, not source. + if d.Name() == ".gitignore" { + return nil + } + + // Read file content. + data, readErr := os.ReadFile(path) + if readErr != nil { + slog.Warn("cannot read file, skipping", "path", relPath, "error", readErr) + return nil + } + + // Skip binary files. + if isBinaryContent(data) { + slog.Debug("skipping binary file", "path", relPath) + return nil + } + + // Sanitize UTF-8. 
+ content := sanitizeUTF8(data) + + files = append(files, SourceFile{ + Path: relPath, + Content: content, + LineCount: countLines(content), + Language: InferLanguage(relPath), + }) + + return nil + }) + + if err != nil { + return nil, err + } + + return files, nil +} diff --git a/internal/ingest/walker_test.go b/internal/ingest/walker_test.go new file mode 100644 index 0000000..bb35c9b --- /dev/null +++ b/internal/ingest/walker_test.go @@ -0,0 +1,544 @@ +package ingest + +import ( + "os" + "path/filepath" + "runtime" + "sort" + "strings" + "testing" +) + +// writeFile is a test helper that creates a file with the given content in dir. +func writeFile(t *testing.T, dir, name, content string) { + t.Helper() + full := filepath.Join(dir, name) + if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(full, []byte(content), 0o644); err != nil { + t.Fatal(err) + } +} + +func filePaths(files []SourceFile) []string { + out := make([]string, len(files)) + for i, f := range files { + out[i] = f.Path + } + sort.Strings(out) + return out +} + +func TestWalkDir_SimpleDirectory(t *testing.T) { + root := t.TempDir() + writeFile(t, root, "main.go", "package main\n") + writeFile(t, root, "lib/utils.go", "package lib\n") + writeFile(t, root, "lib/helpers.py", "def helper(): pass\n") + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir: %v", err) + } + + paths := filePaths(files) + expected := []string{"lib/helpers.py", "lib/utils.go", "main.go"} + if len(paths) != len(expected) { + t.Fatalf("got %d files %v, want %d files %v", len(paths), paths, len(expected), expected) + } + for i, p := range paths { + if p != expected[i] { + t.Errorf("file %d: got %q, want %q", i, p, expected[i]) + } + } +} + +func TestWalkDir_GitignoreRespected(t *testing.T) { + root := t.TempDir() + writeFile(t, root, ".gitignore", "*.log\nbuild/\n") + writeFile(t, root, "main.go", "package main\n") + writeFile(t, root, "app.log", "some log\n") 
+ writeFile(t, root, "build/output.go", "package build\n") + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir: %v", err) + } + + paths := filePaths(files) + if len(paths) != 1 { + t.Fatalf("got %v, want only [main.go]", paths) + } + if paths[0] != "main.go" { + t.Errorf("got %q, want %q", paths[0], "main.go") + } +} + +func TestWalkDir_NestedGitignoreOverrides(t *testing.T) { + root := t.TempDir() + writeFile(t, root, ".gitignore", "*.tmp\n") + writeFile(t, root, "main.go", "package main\n") + writeFile(t, root, "cache.tmp", "cached\n") + // Nested .gitignore adds its own pattern. + writeFile(t, root, "sub/.gitignore", "*.dat\n") + writeFile(t, root, "sub/code.go", "package sub\n") + writeFile(t, root, "sub/data.dat", "binary data\n") + writeFile(t, root, "sub/notes.tmp", "temp notes\n") + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir: %v", err) + } + + paths := filePaths(files) + // cache.tmp ignored by root .gitignore + // sub/data.dat ignored by sub/.gitignore + // sub/notes.tmp ignored by root .gitignore + expected := []string{"main.go", "sub/code.go"} + if len(paths) != len(expected) { + t.Fatalf("got %v, want %v", paths, expected) + } + for i, p := range paths { + if p != expected[i] { + t.Errorf("file %d: got %q, want %q", i, p, expected[i]) + } + } +} + +func TestWalkDir_SymlinksSkipped(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink test not reliable on windows") + } + + root := t.TempDir() + writeFile(t, root, "real.go", "package main\n") + + // Create a file symlink. + target := filepath.Join(root, "real.go") + link := filepath.Join(root, "link.go") + if err := os.Symlink(target, link); err != nil { + t.Fatal(err) + } + + // Create a directory symlink. 
+ subdir := filepath.Join(root, "subdir") + if err := os.Mkdir(subdir, 0o755); err != nil { + t.Fatal(err) + } + writeFile(t, root, "subdir/inner.go", "package subdir\n") + dirLink := filepath.Join(root, "linkeddir") + if err := os.Symlink(subdir, dirLink); err != nil { + t.Fatal(err) + } + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir: %v", err) + } + + paths := filePaths(files) + // Only real.go and subdir/inner.go should appear. + // link.go and linkeddir/* should be skipped. + expected := []string{"real.go", "subdir/inner.go"} + if len(paths) != len(expected) { + t.Fatalf("got %v, want %v", paths, expected) + } + for i, p := range paths { + if p != expected[i] { + t.Errorf("file %d: got %q, want %q", i, p, expected[i]) + } + } +} + +func TestWalkDir_BinaryFileDetection(t *testing.T) { + root := t.TempDir() + writeFile(t, root, "source.go", "package main\n") + + // Write a binary file with null byte in first 512 bytes. + binaryContent := make([]byte, 256) + binaryContent[0] = 0x89 + binaryContent[1] = 0x50 // PNG-like header + binaryContent[10] = 0x00 // null byte + if err := os.WriteFile(filepath.Join(root, "image.png"), binaryContent, 0o644); err != nil { + t.Fatal(err) + } + + // Write a text file that happens to be large but has no null bytes. + bigText := strings.Repeat("hello world\n", 1000) + writeFile(t, root, "large.txt", bigText) + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir: %v", err) + } + + paths := filePaths(files) + // image.png should be excluded (binary), source.go and large.txt included. + expected := []string{"large.txt", "source.go"} + if len(paths) != len(expected) { + t.Fatalf("got %v, want %v", paths, expected) + } +} + +func TestWalkDir_EmptyDirectories(t *testing.T) { + root := t.TempDir() + // Create empty subdirectories. 
+ if err := os.MkdirAll(filepath.Join(root, "empty1"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(filepath.Join(root, "empty2", "nested"), 0o755); err != nil { + t.Fatal(err) + } + writeFile(t, root, "file.go", "package main\n") + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir: %v", err) + } + + paths := filePaths(files) + if len(paths) != 1 || paths[0] != "file.go" { + t.Fatalf("got %v, want [file.go]", paths) + } +} + +func TestWalkDir_DeeplyNested(t *testing.T) { + root := t.TempDir() + + // Create a directory structure 12 levels deep. + deepPath := "a/b/c/d/e/f/g/h/i/j/k/l" + writeFile(t, root, filepath.Join(deepPath, "deep.go"), "package deep\n") + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir: %v", err) + } + + if len(files) != 1 { + t.Fatalf("got %d files, want 1", len(files)) + } + if files[0].Path != filepath.Join(deepPath, "deep.go") { + t.Errorf("got path %q, want %q", files[0].Path, filepath.Join(deepPath, "deep.go")) + } +} + +func TestWalkDir_PermissionDenied(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("permission test not reliable on windows") + } + if os.Getuid() == 0 { + t.Skip("running as root, permission test not meaningful") + } + + root := t.TempDir() + writeFile(t, root, "readable.go", "package main\n") + + // Create an unreadable directory. + unreadable := filepath.Join(root, "secret") + if err := os.MkdirAll(unreadable, 0o755); err != nil { + t.Fatal(err) + } + writeFile(t, root, "secret/hidden.go", "package secret\n") + if err := os.Chmod(unreadable, 0o000); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := os.Chmod(unreadable, 0o755); err != nil { + t.Fatalf("restoring unreadable dir permissions: %v", err) + } + }) + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir should not fail on permission errors: %v", err) + } + + paths := filePaths(files) + // readable.go should still be returned; the walk should continue. 
+ if len(paths) != 1 || paths[0] != "readable.go" { + t.Fatalf("got %v, want [readable.go]", paths) + } +} + +func TestWalkDir_MixedEncodings(t *testing.T) { + root := t.TempDir() + + // Valid UTF-8. + writeFile(t, root, "utf8.go", "// café résumé\npackage main\n") + + // Latin-1 bytes (not valid UTF-8): 0xE9 is 'é' in Latin-1 but invalid standalone UTF-8. + latin1Content := []byte("// caf\xe9\npackage main\n") + if err := os.WriteFile(filepath.Join(root, "latin1.go"), latin1Content, 0o644); err != nil { + t.Fatal(err) + } + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir: %v", err) + } + + if len(files) != 2 { + t.Fatalf("got %d files, want 2", len(files)) + } + + // Both files should have content (no panic, no error). + for _, f := range files { + if f.Content == "" { + t.Errorf("file %q has empty content", f.Path) + } + if f.LineCount == 0 { + t.Errorf("file %q has 0 lines", f.Path) + } + } + + // The Latin-1 file should have replacement characters for invalid bytes. 
+ for _, f := range files { + if f.Path == "latin1.go" { + if !strings.Contains(f.Content, "\uFFFD") { + t.Error("expected replacement character in latin1.go content") + } + } + } +} + +func TestInferLanguage(t *testing.T) { + tests := []struct { + path string + want string + }{ + {"main.go", "go"}, + {"script.py", "python"}, + {"app.ts", "typescript"}, + {"component.tsx", "typescript"}, + {"index.js", "javascript"}, + {"App.jsx", "javascript"}, + {"Service.java", "java"}, + {"helper.rb", "ruby"}, + {"lib.rs", "rust"}, + {"code.c", "c"}, + {"code.cpp", "cpp"}, + {"code.cs", "csharp"}, + {"app.swift", "swift"}, + {"app.kt", "kotlin"}, + {"test.scala", "scala"}, + {"page.php", "php"}, + {"run.sh", "shell"}, + {"config.yaml", "yaml"}, + {"config.yml", "yaml"}, + {"data.json", "json"}, + {"layout.html", "html"}, + {"style.css", "css"}, + {"query.sql", "sql"}, + {"schema.proto", "protobuf"}, + {"Dockerfile", "dockerfile"}, + {"Dockerfile.prod", "dockerfile"}, + {"Makefile", "makefile"}, + {"unknown.xyz", ""}, + {"noext", ""}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + got := InferLanguage(tt.path) + if got != tt.want { + t.Errorf("InferLanguage(%q) = %q, want %q", tt.path, got, tt.want) + } + }) + } +} + +func TestWalkDir_SourceFileMetadata(t *testing.T) { + root := t.TempDir() + writeFile(t, root, "main.go", "package main\n\nfunc main() {}\n") + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir: %v", err) + } + + if len(files) != 1 { + t.Fatalf("got %d files, want 1", len(files)) + } + + f := files[0] + if f.Path != "main.go" { + t.Errorf("Path = %q, want %q", f.Path, "main.go") + } + if f.Content != "package main\n\nfunc main() {}\n" { + t.Errorf("Content = %q, want %q", f.Content, "package main\n\nfunc main() {}\n") + } + if f.LineCount != 3 { + t.Errorf("LineCount = %d, want 3", f.LineCount) + } + if f.Language != "go" { + t.Errorf("Language = %q, want %q", f.Language, "go") + } +} + +func 
TestWalkDir_GitDirectorySkipped(t *testing.T) { + root := t.TempDir() + writeFile(t, root, "main.go", "package main\n") + writeFile(t, root, ".git/objects/pack", "binary pack\n") + writeFile(t, root, ".git/HEAD", "ref: refs/heads/main\n") + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir: %v", err) + } + + paths := filePaths(files) + if len(paths) != 1 || paths[0] != "main.go" { + t.Fatalf("got %v, want [main.go]", paths) + } +} + +func TestIsBinaryContent(t *testing.T) { + tests := []struct { + name string + data []byte + binary bool + }{ + {"empty", []byte{}, false}, + {"text", []byte("hello world\n"), false}, + {"null at start", []byte{0x00, 0x41, 0x42}, true}, + {"null in middle", append([]byte("hello"), append([]byte{0x00}, []byte("world")...)...), true}, + {"null at byte 511", func() []byte { + b := make([]byte, 512) + for i := range b { + b[i] = 'A' + } + b[511] = 0x00 + return b + }(), true}, + {"null at byte 512", func() []byte { + b := make([]byte, 513) + for i := range b { + b[i] = 'A' + } + b[512] = 0x00 + return b + }(), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isBinaryContent(tt.data) + if got != tt.binary { + t.Errorf("isBinaryContent = %v, want %v", got, tt.binary) + } + }) + } +} + +func TestCountLines(t *testing.T) { + tests := []struct { + input string + want int + }{ + {"", 0}, + {"a", 1}, + {"a\n", 1}, + {"a\nb\n", 2}, + {"a\nb", 2}, + {"a\nb\nc\n", 3}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := countLines(tt.input) + if got != tt.want { + t.Errorf("countLines(%q) = %d, want %d", tt.input, got, tt.want) + } + }) + } +} + +func TestSanitizeUTF8(t *testing.T) { + // Valid UTF-8 passes through unchanged. + valid := "hello café" + if got := sanitizeUTF8([]byte(valid)); got != valid { + t.Errorf("sanitizeUTF8 changed valid UTF-8: got %q", got) + } + + // Invalid bytes get replaced with U+FFFD. 
+ invalid := []byte("caf\xe9 world") + got := sanitizeUTF8(invalid) + if !strings.Contains(got, "\uFFFD") { + t.Errorf("expected replacement character, got %q", got) + } + if !strings.Contains(got, "caf") || !strings.Contains(got, " world") { + t.Errorf("valid parts should be preserved, got %q", got) + } +} + +func TestWalkDir_GitignoreDirectoryPattern(t *testing.T) { + root := t.TempDir() + writeFile(t, root, ".gitignore", "vendor/\nnode_modules/\n") + writeFile(t, root, "main.go", "package main\n") + writeFile(t, root, "vendor/dep.go", "package vendor\n") + writeFile(t, root, "node_modules/pkg/index.js", "module.exports = {}\n") + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir: %v", err) + } + + paths := filePaths(files) + if len(paths) != 1 || paths[0] != "main.go" { + t.Fatalf("got %v, want [main.go]", paths) + } +} + +func TestWalkDir_NonexistentRoot(t *testing.T) { + _, err := WalkDir("/nonexistent/path/that/does/not/exist") + if err == nil { + t.Fatal("expected error for nonexistent root, got nil") + } +} + +func TestWalkDir_EmptyRoot(t *testing.T) { + root := t.TempDir() + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir: %v", err) + } + + if len(files) != 0 { + t.Fatalf("got %d files, want 0", len(files)) + } +} + +func TestWalkDir_BinaryAtExactly512Boundary(t *testing.T) { + root := t.TempDir() + + // Null byte at position 511 (last checked byte) — should be detected as binary. + data511 := make([]byte, 600) + for i := range data511 { + data511[i] = 'A' + } + data511[511] = 0x00 + if err := os.WriteFile(filepath.Join(root, "boundary.bin"), data511, 0o644); err != nil { + t.Fatal(err) + } + + // Null byte at position 512 (just outside checked range) — should NOT be binary. 
+ data512 := make([]byte, 600) + for i := range data512 { + data512[i] = 'A' + } + data512[512] = 0x00 + if err := os.WriteFile(filepath.Join(root, "not_binary.txt"), data512, 0o644); err != nil { + t.Fatal(err) + } + + files, err := WalkDir(root) + if err != nil { + t.Fatalf("WalkDir: %v", err) + } + + paths := filePaths(files) + if len(paths) != 1 || paths[0] != "not_binary.txt" { + t.Fatalf("got %v, want [not_binary.txt]", paths) + } +} diff --git a/internal/llm/claudecli.go b/internal/llm/claudecli.go new file mode 100644 index 0000000..4ac06c9 --- /dev/null +++ b/internal/llm/claudecli.go @@ -0,0 +1,223 @@ +package llm + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "net/textproto" + "os/exec" + "strings" + "time" +) + +var lookPathClaude = exec.LookPath + +// NewClaudeCLIClient creates an LLM client backed by the local Claude Code CLI. +// This allows Anthropic models to run using the user's Claude login instead of +// requiring a separate ANTHROPIC_API_KEY. 
+func NewClaudeCLIClient(cfg ClientConfig) (Client, error) { + if cfg.Timeout <= 0 { + cfg.Timeout = 600 * time.Second + } + if cfg.Logger == nil { + cfg.Logger = slog.Default() + } + + betas, err := extractClaudeCLIBetas(cfg.Headers) + if err != nil { + return nil, err + } + + claudePath, err := lookPathClaude("claude") + if err != nil { + return nil, fmt.Errorf("claude command not found in PATH: %w", err) + } + + return &claudeCLIClient{ + claudePath: claudePath, + timeout: cfg.Timeout, + logger: cfg.Logger, + betas: betas, + }, nil +} + +type claudeCLIClient struct { + claudePath string + timeout time.Duration + logger *slog.Logger + betas []string +} + +type claudeCLIResult struct { + Type string `json:"type"` + Subtype string `json:"subtype"` + IsError bool `json:"is_error"` + Result string `json:"result"` + StructuredOutput json.RawMessage `json:"structured_output"` + StopReason string `json:"stop_reason"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` + ModelUsage map[string]json.RawMessage `json:"modelUsage"` +} + +func (c *claudeCLIClient) ChatCompletion(ctx context.Context, req ChatRequest) (*ChatResponse, error) { + systemPrompt, userPrompt := splitClaudeCLIMessages(req.Messages) + + args := []string{ + "-p", + "--output-format", "json", + "--tools", "", + "--permission-mode", "bypassPermissions", + } + if req.Model != "" { + args = append(args, "--model", req.Model) + } + if systemPrompt != "" { + args = append(args, "--system-prompt", systemPrompt) + } + if req.ResponseSchema != nil { + args = append(args, "--json-schema", string(*req.ResponseSchema)) + } + if len(c.betas) > 0 { + args = append(args, "--betas") + args = append(args, c.betas...) 
+ } + + stdout, stderr, err := c.run(ctx, args, userPrompt) + if err != nil { + combined := strings.TrimSpace(strings.Join([]string{stdout, stderr}, "\n")) + if isClaudeCLIContextLengthError(combined) { + return nil, ErrContextLengthExceeded + } + if combined == "" { + combined = err.Error() + } + return nil, fmt.Errorf("llm: claude CLI request failed: %s", combined) + } + + var result claudeCLIResult + if err := json.Unmarshal([]byte(stdout), &result); err != nil { + return nil, fmt.Errorf("llm: parsing claude CLI response: %w", err) + } + if result.IsError { + msg := strings.TrimSpace(result.Result) + if msg == "" { + msg = "unknown Claude CLI error" + } + if isClaudeCLIContextLengthError(msg) { + return nil, ErrContextLengthExceeded + } + return nil, fmt.Errorf("llm: claude CLI error: %s", msg) + } + + content := result.Result + if len(result.StructuredOutput) > 0 && string(result.StructuredOutput) != "null" { + content = string(result.StructuredOutput) + } + + return &ChatResponse{ + Content: content, + Usage: TokenUsage{PromptTokens: result.Usage.InputTokens, CompletionTokens: result.Usage.OutputTokens}, + FinishReason: mapClaudeCLIStopReason(result.StopReason), + Model: firstClaudeCLIModel(result.ModelUsage, req.Model), + }, nil +} + +func (c *claudeCLIClient) run(ctx context.Context, args []string, input string) (stdout string, stderr string, err error) { + if _, hasDeadline := ctx.Deadline(); !hasDeadline && c.timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, c.timeout) + defer cancel() + } + + cmd := exec.CommandContext(ctx, c.claudePath, args...) 
// extractClaudeCLIBetas collects the beta feature names carried in
// Anthropic-Beta headers so they can be passed to the CLI's --betas flag.
//
// Header names are compared in canonical form. A single header value may hold
// several comma-separated names; each is trimmed and empty names are dropped.
// Any other header that has at least one value is rejected with an error.
// It returns nil, nil when headers is empty.
func extractClaudeCLIBetas(headers http.Header) ([]string, error) {
	if len(headers) == 0 {
		return nil, nil
	}

	var betas []string
	for name, values := range headers {
		canonical := textproto.CanonicalMIMEHeaderKey(name)
		if canonical != "Anthropic-Beta" {
			if len(values) > 0 {
				return nil, fmt.Errorf("claude CLI auth only supports Anthropic-Beta custom headers; unsupported header %q", canonical)
			}
			continue
		}
		for _, value := range values {
			// The HTTP header allows "a,b"; the CLI wants one argument each.
			for _, beta := range strings.Split(value, ",") {
				if beta = strings.TrimSpace(beta); beta != "" {
					betas = append(betas, beta)
				}
			}
		}
	}

	return betas, nil
}

// mapClaudeCLIStopReason translates the CLI's stop reason into the finish
// reason used by ChatResponse: "max_tokens" becomes "length"; every other
// value is passed through unchanged.
func mapClaudeCLIStopReason(stopReason string) string {
	if stopReason == "max_tokens" {
		return "length"
	}
	return stopReason
}

// firstClaudeCLIModel picks the model name to report for a response.
//
// It returns fallback when modelUsage is empty or lists fallback itself.
// Otherwise it returns the lexicographically smallest key, so the result does
// not depend on Go's randomized map iteration order when several models are
// listed.
func firstClaudeCLIModel(modelUsage map[string]json.RawMessage, fallback string) string {
	if len(modelUsage) == 0 {
		return fallback
	}
	if _, ok := modelUsage[fallback]; ok {
		return fallback
	}
	smallest, seen := "", false
	for model := range modelUsage {
		if !seen || model < smallest {
			smallest, seen = model, true
		}
	}
	return smallest
}
strings.ToLower(msg) + return strings.Contains(lower, "context length") || + strings.Contains(lower, "too many tokens") || + strings.Contains(lower, "prompt is too long") || + strings.Contains(lower, "input is too long") +} diff --git a/internal/llm/claudecli_test.go b/internal/llm/claudecli_test.go new file mode 100644 index 0000000..0821f8b --- /dev/null +++ b/internal/llm/claudecli_test.go @@ -0,0 +1,140 @@ +package llm + +import ( + "context" + "encoding/json" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestNewClaudeCLIClient_UnsupportedHeaders(t *testing.T) { + _, err := NewClaudeCLIClient(ClientConfig{ + Headers: http.Header{ + "X-Feature-Flag": []string{"enabled"}, + }, + }) + if err == nil { + t.Fatal("expected unsupported header error") + } + if !strings.Contains(err.Error(), "unsupported header") { + t.Fatalf("error = %v, want unsupported header", err) + } +} + +func TestClaudeCLIClient_ChatCompletionStructuredOutput(t *testing.T) { + dir := t.TempDir() + argsFile := filepath.Join(dir, "args.txt") + stdinFile := filepath.Join(dir, "stdin.txt") + claudePath := filepath.Join(dir, "claude") + + script := `#!/bin/sh +printf '%s\n' "$@" >"$TEST_ARGS_FILE" +cat >"$TEST_STDIN_FILE" +cat <<'JSON' +{"type":"result","subtype":"success","is_error":false,"result":"","structured_output":{"ok":true},"stop_reason":"max_tokens","usage":{"input_tokens":123,"output_tokens":45},"modelUsage":{"claude-sonnet-4-6":{"inputTokens":123,"outputTokens":45}}} +JSON +` + if err := os.WriteFile(claudePath, []byte(script), 0755); err != nil { + t.Fatalf("write fake claude: %v", err) + } + + t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH")) + t.Setenv("TEST_ARGS_FILE", argsFile) + t.Setenv("TEST_STDIN_FILE", stdinFile) + + client, err := NewClaudeCLIClient(ClientConfig{ + Headers: http.Header{ + "Anthropic-Beta": []string{"context-1m-2025-08-07", "other-beta"}, + }, + Timeout: 5 * time.Second, + }) + if err != nil { + 
t.Fatalf("NewClaudeCLIClient returned error: %v", err) + } + + schema := json.RawMessage(`{"type":"object","properties":{"ok":{"type":"boolean"}},"required":["ok"]}`) + resp, err := client.ChatCompletion(context.Background(), ChatRequest{ + Model: "claude-sonnet-4-6", + Messages: []Message{{Role: "system", Content: "system prompt"}, {Role: "user", Content: "user prompt"}}, + ResponseSchema: &schema, + }) + if err != nil { + t.Fatalf("ChatCompletion returned error: %v", err) + } + + if resp.Content != `{"ok":true}` { + t.Fatalf("Content = %q, want structured output JSON", resp.Content) + } + if resp.FinishReason != "length" { + t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "length") + } + if resp.Model != "claude-sonnet-4-6" { + t.Fatalf("Model = %q, want %q", resp.Model, "claude-sonnet-4-6") + } + if resp.Usage.PromptTokens != 123 || resp.Usage.CompletionTokens != 45 { + t.Fatalf("Usage = %+v, want prompt=123 completion=45", resp.Usage) + } + + argsData, err := os.ReadFile(argsFile) + if err != nil { + t.Fatalf("read args file: %v", err) + } + args := string(argsData) + for _, want := range []string{ + "-p", + "--output-format", + "json", + "--tools", + "--permission-mode", + "bypassPermissions", + "--model", + "claude-sonnet-4-6", + "--system-prompt", + "system prompt", + "--json-schema", + string(schema), + "--betas", + "context-1m-2025-08-07", + "other-beta", + } { + if !strings.Contains(args, want+"\n") { + t.Fatalf("expected args to contain %q, got:\n%s", want, args) + } + } + + stdinData, err := os.ReadFile(stdinFile) + if err != nil { + t.Fatalf("read stdin file: %v", err) + } + if got := string(stdinData); got != "user prompt" { + t.Fatalf("stdin = %q, want %q", got, "user prompt") + } +} + +func TestClaudeCLIClient_ContextLengthError(t *testing.T) { + dir := t.TempDir() + claudePath := filepath.Join(dir, "claude") + script := "#!/bin/sh\necho 'prompt is too long for this model' >&2\nexit 1\n" + if err := os.WriteFile(claudePath, []byte(script), 
0755); err != nil { + t.Fatalf("write fake claude: %v", err) + } + + t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH")) + + client, err := NewClaudeCLIClient(ClientConfig{Timeout: 5 * time.Second}) + if err != nil { + t.Fatalf("NewClaudeCLIClient returned error: %v", err) + } + + _, err = client.ChatCompletion(context.Background(), ChatRequest{ + Model: "claude-sonnet-4-6", + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err != ErrContextLengthExceeded { + t.Fatalf("error = %v, want %v", err, ErrContextLengthExceeded) + } +} diff --git a/internal/llm/client.go b/internal/llm/client.go new file mode 100644 index 0000000..1a31551 --- /dev/null +++ b/internal/llm/client.go @@ -0,0 +1,1159 @@ +package llm + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "math" + "math/rand/v2" + "net" + "net/http" + "strconv" + "strings" + "sync/atomic" + "time" +) + +// Message represents a single chat message. +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ChatRequest holds the parameters for a chat completion call. +type ChatRequest struct { + // Label is a caller-supplied tag for log correlation (e.g. "chunk 3/9", + // "feature-detection"). Carried into every log line inside the retry + // loop so concurrent requests can be distinguished. + Label string `json:"-"` + Endpoint string `json:"-"` // Model serving endpoint name. + Model string `json:"-"` // Model name (used by Anthropic/OpenAI direct APIs). + Messages []Message `json:"messages"` // System + user messages. + Temperature float64 `json:"temperature"` // Sampling temperature. + MaxTokens int `json:"max_tokens"` // Max output tokens. + ResponseSchema *json.RawMessage `json:"-"` // JSON Schema for response_format enforcement. + OutputMode OutputMode `json:"-"` // How to enforce structured output. 
+ ModelParams map[string]any `json:"-"` // Provider-specific request body params (merged at top level). +} + +// ChatResponse holds the result of a chat completion call. +type ChatResponse struct { + Content string `json:"content"` // Raw response content. + Usage TokenUsage `json:"usage"` // Input/output token counts. + FinishReason string `json:"finish_reason"` // "stop", "length", "content_filter", etc. + Model string `json:"model"` // Model that generated the response. + + // Streaming-only timing. Zero for non-streaming responses. + // TimeToFirstToken is headers-received → first content_block_delta: + // that's server-side prompt processing + thinking (the model does all + // its thinking before the first visible output token). + // GenerationTime is first delta → message_stop: the pure output phase. + // CompletionTokens / GenerationTime is the real emit rate, free of + // prompt-processing and thinking overhead. + TimeToFirstToken time.Duration `json:"-"` + GenerationTime time.Duration `json:"-"` +} + +// TokenUsage tracks input and output token consumption. +type TokenUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + // Anthropic prompt-caching fields. Creation is billed at ~1.25× input + // rate; reads at ~0.1×. Zero for providers that don't report them. + CacheCreationTokens int `json:"cache_creation_tokens,omitempty"` + CacheReadTokens int `json:"cache_read_tokens,omitempty"` + // Characters (not tokens — Anthropic doesn't break the count down) in + // thinking blocks. Shows how much of CompletionTokens was reasoning vs + // the output you actually see. Streaming only. + ThinkingChars int `json:"thinking_chars,omitempty"` +} + +// OutputMode specifies how to enforce structured output. +type OutputMode int + +const ( + // OutputModeNone sends no structured output enforcement. + OutputModeNone OutputMode = iota + // OutputModeJSONSchema uses response_format with JSON Schema (OpenAI/GPT endpoints). 
// ErrContextLengthExceeded is returned when the model reports context length exceeded.
// Callers should re-chunk and retry rather than retrying the same request.
var ErrContextLengthExceeded = errors.New("llm: context length exceeded")

// errNonRetryable wraps an error to indicate it should not be retried.
// Error returns the wrapped error's message and Unwrap exposes the wrapped
// error to errors.Is and errors.As.
type errNonRetryable struct {
	err error
}

func (e *errNonRetryable) Error() string { return e.err.Error() }
func (e *errNonRetryable) Unwrap() error { return e.err }

// errToolChoiceIncompatible indicates the endpoint rejected forced tool_choice
// for the selected model. Callers may retry without forcing tool choice.
type errToolChoiceIncompatible struct {
	err error
}

func (e *errToolChoiceIncompatible) Error() string { return e.err.Error() }
func (e *errToolChoiceIncompatible) Unwrap() error { return e.err }

// errTemperatureIncompatible indicates the endpoint rejected the requested
// temperature value (e.g. adaptive-thinking models that require temperature=1).
// Callers may retry with temperature omitted so the API picks its own default.
type errTemperatureIncompatible struct {
	err error
}

func (e *errTemperatureIncompatible) Error() string { return e.err.Error() }
func (e *errTemperatureIncompatible) Unwrap() error { return e.err }

// errMaxTokensParamIncompatible indicates the endpoint rejected `max_tokens`
// and requires `max_completion_tokens` instead (OpenAI GPT-5 family and
// reasoning-series models). Callers may retry using the new field name.
type errMaxTokensParamIncompatible struct {
	err error
}

func (e *errMaxTokensParamIncompatible) Error() string { return e.err.Error() }
func (e *errMaxTokensParamIncompatible) Unwrap() error { return e.err }
Models +// that can't be forced to use a tool (tool_choice:auto only) read this at the +// exact moment they're populating fields — it's the cheapest place to prevent +// per-field drift like "none found" where [] belongs. +const securityAnalysisToolDescription = "Submit the security analysis results. " + + "You MUST call this tool — do not respond with plain text. " + + "security_issues and public_api_routes MUST be arrays: use [] when empty, never a string. " + + "Every issue requires file_path, numeric start_line/end_line, and a severity between 0 and 10." + +// Client abstracts LLM interaction for testability. +type Client interface { + // ChatCompletion sends a structured chat request and returns typed output. + ChatCompletion(ctx context.Context, req ChatRequest) (*ChatResponse, error) +} + +// ClientConfig holds configuration for the HTTP LLM client. +type ClientConfig struct { + BaseURL string // Base URL for the API (e.g., "https://host.databricks.com/serving-endpoints"). + Token string // Bearer token for authentication. + Provider string // "databricks", "anthropic", "openai" — controls URL construction and request format. + Headers http.Header // Additional headers to include in every request. + MaxRetries int // Maximum number of retries for transient errors (default: 3). + Timeout time.Duration // Per-request timeout (default: 600s). + Logger *slog.Logger // Structured logger. +} + +// httpClient implements Client using HTTP against an OpenAI-compatible API. +type httpClient struct { + baseURL string + token string + provider string + maxRetries int + timeout time.Duration + logger *slog.Logger + extraHeaders http.Header + client *http.Client + backoffFunc func(ctx context.Context, attempt int, retryAfter time.Duration) // override for testing + + // Learned constraints: set once when the API rejects a request feature, + // then applied to all subsequent requests to avoid wasted round-trips. 
+ noForcedToolChoice atomic.Bool + dropTemperature atomic.Bool + useMaxCompletionTokens atomic.Bool +} + +// NewClient creates a new LLM HTTP client. +func NewClient(cfg ClientConfig) Client { + if cfg.MaxRetries <= 0 { + cfg.MaxRetries = 3 + } + if cfg.Timeout <= 0 { + cfg.Timeout = 600 * time.Second + } + if cfg.Logger == nil { + cfg.Logger = slog.Default() + } + + transport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + } + + return &httpClient{ + baseURL: strings.TrimRight(cfg.BaseURL, "/"), + token: cfg.Token, + provider: cfg.Provider, + maxRetries: cfg.MaxRetries, + timeout: cfg.Timeout, + logger: cfg.Logger, + extraHeaders: cfg.Headers.Clone(), + client: &http.Client{Timeout: cfg.Timeout, Transport: transport}, + } +} + +// ChatCompletion sends a chat completion request with retry logic for transient errors. +func (c *httpClient) ChatCompletion(ctx context.Context, req ChatRequest) (*ChatResponse, error) { + url := c.buildURL(req.Endpoint) + + var lastErr error + disableForcedToolChoice := c.noForcedToolChoice.Load() + dropTemperature := c.dropTemperature.Load() + useMaxCompletionTokens := c.useMaxCompletionTokens.Load() + + for attempt := 0; attempt <= c.maxRetries; attempt++ { + if attempt > 0 { + c.logger.Info("retrying LLM request", + "label", req.Label, + "attempt", attempt, + "max_retries", c.maxRetries, + ) + } + + body, err := c.buildRequestBody(req, !disableForcedToolChoice, dropTemperature, useMaxCompletionTokens) + if err != nil { + return nil, fmt.Errorf("llm: building request body: %w", err) + } + + resp, err := c.doRequest(ctx, url, req.Label, body) + if err != nil { + // Context cancellation is not retryable. 
+ if ctx.Err() != nil { + return nil, fmt.Errorf("llm: context cancelled: %w", ctx.Err()) + } + lastErr = err + c.logger.Warn("LLM request failed", + "label", req.Label, + "attempt", attempt, + "error", err, + ) + if attempt < c.maxRetries { + c.backoff(ctx, attempt, 0) + } + continue + } + + chatResp, retryAfter, err := c.handleResponse(resp) + if err != nil { + // Some endpoints reject forced tool_choice for specific models. + // Retry once without forcing tool use while still providing tools. + var toolChoiceErr *errToolChoiceIncompatible + if errors.As(err, &toolChoiceErr) && req.OutputMode == OutputModeToolUse && !disableForcedToolChoice { + disableForcedToolChoice = true + c.noForcedToolChoice.Store(true) + lastErr = err + c.logger.Warn("endpoint rejected forced tool_choice; retrying without force", + "label", req.Label, + "provider", c.provider, + "model", req.Model, + ) + continue + } + + // Same adaptive pattern for temperature constraints: drop the field + // and let the API use its model-specific default. + var tempErr *errTemperatureIncompatible + if errors.As(err, &tempErr) && !dropTemperature { + dropTemperature = true + c.dropTemperature.Store(true) + lastErr = err + c.logger.Warn("endpoint rejected temperature; retrying without it", + "label", req.Label, + "provider", c.provider, + "model", req.Model, + ) + continue + } + + // OpenAI GPT-5 family and reasoning models require + // `max_completion_tokens` in place of `max_tokens`. + var maxTokErr *errMaxTokensParamIncompatible + if errors.As(err, &maxTokErr) && !useMaxCompletionTokens { + useMaxCompletionTokens = true + c.useMaxCompletionTokens.Store(true) + lastErr = err + c.logger.Warn("endpoint rejected max_tokens; retrying with max_completion_tokens", + "label", req.Label, + "provider", c.provider, + "model", req.Model, + ) + continue + } + + // Context length exceeded is not retryable — signal re-chunking. 
+ if errors.Is(err, ErrContextLengthExceeded) { + return nil, err + } + + // Non-retryable errors (4xx client errors) should fail immediately. + var nonRetryable *errNonRetryable + if errors.As(err, &nonRetryable) { + return nil, nonRetryable.err + } + + lastErr = err + statusCode := 0 + if resp != nil { + statusCode = resp.StatusCode + } + c.logger.Warn("LLM response error", + "label", req.Label, + "attempt", attempt, + "status", statusCode, + "error", err, + ) + + if attempt < c.maxRetries { + c.backoff(ctx, attempt, retryAfter) + } + continue + } + + return chatResp, nil + } + + return nil, fmt.Errorf("llm: exhausted %d retries: %w", c.maxRetries, lastErr) +} + +func (c *httpClient) buildURL(endpoint string) string { + switch c.provider { + case "anthropic": + return c.baseURL + "/v1/messages" + case "openai": + return c.baseURL + "/v1/chat/completions" + case "google": + // Google's OpenAI-compat layer: request body, response body, and + // Bearer auth all ride the existing non-Anthropic paths. The only + // divergence is the URL — /chat/completions directly under the + // compat base, no /v1/ segment. + return c.baseURL + "/chat/completions" + default: // "databricks" or empty + if endpoint == "" { + return c.baseURL + "/chat/completions" + } + return c.baseURL + "/" + endpoint + "/invocations" + } +} + +// apiRequest is the OpenAI-compatible request body. 
+type apiRequest struct { + Model string `json:"model,omitempty"` + Messages []Message `json:"messages"` + Temperature *float64 `json:"temperature,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + ResponseFormat *responseFormat `json:"response_format,omitempty"` + Tools []toolDefinition `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` +} + +type responseFormat struct { + Type string `json:"type"` + JSONSchema *json.RawMessage `json:"json_schema,omitempty"` +} + +// toolDefinition supports both OpenAI function-calling format and +// Anthropic/Databricks tool_use format. Only one of Function or +// the Anthropic fields (Name, Description, InputSchema) should be set. +type toolDefinition struct { + // OpenAI format + Type string `json:"type"` + Function *toolFunction `json:"function,omitempty"` + // Anthropic/Databricks format + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + InputSchema *json.RawMessage `json:"input_schema,omitempty"` +} + +type toolFunction struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters *json.RawMessage `json:"parameters"` +} + +func (c *httpClient) buildRequestBody(req ChatRequest, forceToolChoice, dropTemperature, useMaxCompletionTokens bool) ([]byte, error) { + if c.provider == "anthropic" { + return c.buildAnthropicRequestBody(req, forceToolChoice, dropTemperature) + } + + apiReq := apiRequest{ + Model: req.Model, + Messages: req.Messages, + } + if !dropTemperature { + t := req.Temperature + apiReq.Temperature = &t + } + if useMaxCompletionTokens { + apiReq.MaxCompletionTokens = req.MaxTokens + } else { + apiReq.MaxTokens = req.MaxTokens + } + + switch req.OutputMode { + case OutputModeJSONSchema: + if req.ResponseSchema != nil { + apiReq.ResponseFormat = &responseFormat{ + Type: "json_schema", + JSONSchema: req.ResponseSchema, + } + } + case OutputModeToolUse: 
+ if req.ResponseSchema != nil { + // Extract the inner "schema" object from the response_format envelope + // (which has {"name": ..., "strict": ..., "schema": {...}}) to get + // the raw JSON Schema with a top-level "type" field. + inputSchema := extractInnerSchema(req.ResponseSchema) + apiReq.Tools = []toolDefinition{ + { + // OpenAI function-calling format (required by Databricks proxy). + Type: "function", + Function: &toolFunction{ + Name: "security_analysis", + Description: securityAnalysisToolDescription, + Parameters: inputSchema, + }, + }, + } + // Force the model to use the tool when supported. + if forceToolChoice { + apiReq.ToolChoice = map[string]any{"type": "function", "function": map[string]string{"name": "security_analysis"}} + } + } + } + + return marshalWithModelParams(apiReq, requestParamsWithout(req.ModelParams, dropTemperature, "temperature")) +} + +// extractInnerSchema extracts the inner "schema" object from the response_format +// envelope used by OpenAI's JSON Schema mode. The envelope has the shape: +// +// {"name": "...", "strict": true, "schema": { "type": "object", ... }} +// +// For Anthropic tool_use, we need just the inner schema (the part with "type": "object"). +// If extraction fails, returns the original raw message as-is. 
// extractInnerSchema extracts the inner "schema" object from the response_format
// envelope used by OpenAI's JSON Schema mode. The envelope has the shape:
//
//	{"name": "...", "strict": true, "schema": { "type": "object", ... }}
//
// Tool definitions need just the inner schema (the part with "type": "object").
//
// A nil input yields nil. If extraction fails — the input is not valid JSON,
// has no "schema" member, or the member is JSON null — the original raw
// message is returned as-is.
func extractInnerSchema(raw *json.RawMessage) *json.RawMessage {
	if raw == nil {
		return nil
	}
	var envelope struct {
		Schema json.RawMessage `json:"schema"`
	}
	if err := json.Unmarshal(*raw, &envelope); err != nil {
		return raw
	}
	// A JSON null decodes into the RawMessage as the literal bytes `null`, so
	// an emptiness check alone would hand `null` to the API as the tool schema.
	if len(envelope.Schema) == 0 || string(envelope.Schema) == "null" {
		return raw
	}
	return &envelope.Schema
}
// flexString unmarshals both a plain JSON string and an array of content
// parts (as returned by Gemini via the Databricks proxy). The array
// format looks like: [{"type":"text","text":"..."}]. In that case, all
// "text" fields are concatenated.
type flexString string

// UnmarshalJSON implements json.Unmarshaler. It accepts a JSON string, an
// array of content-part objects (their "text" fields are concatenated in
// order), or null (which yields the empty string). Any other JSON value, and
// empty input, is reported as an error.
func (f *flexString) UnmarshalJSON(data []byte) error {
	// encoding/json never passes empty input, but a direct call can; without
	// this guard the data[0] accesses below would panic.
	if len(data) == 0 {
		return fmt.Errorf("flexString: empty JSON input")
	}

	switch data[0] {
	case '"':
		// Fast path: plain string (OpenAI / Claude).
		var s string
		if err := json.Unmarshal(data, &s); err != nil {
			return err
		}
		*f = flexString(s)
		return nil

	case '[':
		// Slow path: array of content parts (Gemini).
		var parts []struct {
			Type string `json:"type"`
			Text string `json:"text"`
		}
		if err := json.Unmarshal(data, &parts); err != nil {
			return err
		}
		var b strings.Builder
		for _, p := range parts {
			b.WriteString(p.Text) // parts without text contribute nothing
		}
		*f = flexString(b.String())
		return nil
	}

	// null → empty string
	if string(data) == "null" {
		*f = ""
		return nil
	}

	return fmt.Errorf("flexString: unexpected JSON token %q", data[0])
}
A non-nil error with retryAfter > 0 indicates a retryable error. +func (c *httpClient) handleResponse(resp *http.Response) (*ChatResponse, time.Duration, error) { + defer resp.Body.Close() + + // Streaming responses (Anthropic only for now). Dispatch on Content-Type so + // a model_params override of stream:false falls through to the blocking + // path below without any further plumbing. Only successful responses + // stream; 4xx/5xx still return a JSON error body. + if resp.StatusCode == http.StatusOK && + strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") { + return c.readAnthropicStream(resp.Body) + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, 0, fmt.Errorf("reading response body: %w", err) + } + + // Parse Retry-After header for rate-limited responses. + retryAfter := parseRetryAfter(resp.Header.Get("Retry-After")) + + // Retryable: 429 (rate limit) and 5xx (server errors). + if resp.StatusCode == http.StatusTooManyRequests || (resp.StatusCode >= 500 && resp.StatusCode < 600) { + return nil, retryAfter, fmt.Errorf("llm: retryable error (status %d): %s", resp.StatusCode, truncate(string(respBody), 200)) + } + + // Non-retryable client errors (400, 401, 403, 404, etc.). 
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 { + if isToolChoiceIncompatibleError(string(respBody)) { + return nil, 0, &errToolChoiceIncompatible{err: fmt.Errorf("llm: request failed (status %d): %s", resp.StatusCode, truncate(string(respBody), 200))} + } + if isTemperatureIncompatibleError(string(respBody)) { + return nil, 0, &errTemperatureIncompatible{err: fmt.Errorf("llm: request failed (status %d): %s", resp.StatusCode, truncate(string(respBody), 200))} + } + if isMaxTokensParamIncompatibleError(string(respBody)) { + return nil, 0, &errMaxTokensParamIncompatible{err: fmt.Errorf("llm: request failed (status %d): %s", resp.StatusCode, truncate(string(respBody), 200))} + } + + // Check for context_length_exceeded in OpenAI error response. + var apiResp apiResponse + if json.Unmarshal(respBody, &apiResp) == nil && apiResp.Error != nil { + if strings.Contains(apiResp.Error.Code, "context_length_exceeded") || + strings.Contains(apiResp.Error.Message, "context_length_exceeded") || + strings.Contains(apiResp.Error.Message, "maximum context length") { + return nil, 0, ErrContextLengthExceeded + } + } + // Check for Anthropic error format. + var anthErr struct { + Error *anthropicError `json:"error"` + } + if json.Unmarshal(respBody, &anthErr) == nil && anthErr.Error != nil { + if strings.Contains(anthErr.Error.Message, "too long") || strings.Contains(anthErr.Error.Message, "too many tokens") { + return nil, 0, ErrContextLengthExceeded + } + } + return nil, 0, &errNonRetryable{err: fmt.Errorf("llm: request failed (status %d): %s", resp.StatusCode, truncate(string(respBody), 200))} + } + + // Anthropic response parsing. + if c.provider == "anthropic" { + return c.parseAnthropicResponse(respBody) + } + + // Parse successful OpenAI-compatible response. + // Parse errors on a 200 are non-retryable — the server won't return a + // different format on retry. 
+ var apiResp apiResponse + if err := json.Unmarshal(respBody, &apiResp); err != nil { + return nil, 0, &errNonRetryable{err: fmt.Errorf("llm: parsing response: %w", err)} + } + + if len(apiResp.Choices) == 0 { + return nil, 0, &errNonRetryable{err: fmt.Errorf("llm: empty choices in response")} + } + + choice := apiResp.Choices[0] + content := string(choice.Message.Content) + + // If tool_use mode, extract content from tool call arguments. + if content == "" && len(choice.Message.ToolCalls) > 0 { + content = choice.Message.ToolCalls[0].Function.Arguments + } + + return &ChatResponse{ + Content: content, + Usage: TokenUsage{PromptTokens: apiResp.Usage.PromptTokens, CompletionTokens: apiResp.Usage.CompletionTokens}, + FinishReason: choice.FinishReason, + Model: apiResp.Model, + }, 0, nil +} + +// anthropicRequest is the Anthropic Messages API request body. +type anthropicRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + System string `json:"system,omitempty"` + MaxTokens int `json:"max_tokens"` + Temperature *float64 `json:"temperature,omitempty"` + Tools []anthropicTool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type anthropicTool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema *json.RawMessage `json:"input_schema"` +} + +// anthropicResponse is the Anthropic Messages API response body. 
+type anthropicResponse struct { + ID string `json:"id"` + Content []anthropicBlock `json:"content"` + Model string `json:"model"` + StopReason string `json:"stop_reason"` + Usage anthropicUsage `json:"usage"` + Error *anthropicError `json:"error,omitempty"` +} + +type anthropicBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input json.RawMessage `json:"input,omitempty"` +} + +type anthropicUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` +} + +type anthropicError struct { + Type string `json:"type"` + Message string `json:"message"` +} + +func (c *httpClient) buildAnthropicRequestBody(req ChatRequest, forceToolChoice, dropTemperature bool) ([]byte, error) { + // Extract system message - Anthropic requires it as a separate field. + var system string + var messages []Message + for _, m := range req.Messages { + if m.Role == "system" { + system = m.Content + } else { + messages = append(messages, m) + } + } + + apiReq := anthropicRequest{ + Model: req.Model, + Messages: messages, + System: system, + MaxTokens: req.MaxTokens, + // Stream even though we don't need incremental output: SSE keepalive + // frames hold the TCP connection open through long generations. Without + // this, slow models trip the server-side idle timeout (~10–15min) and + // the edge sends RST before the full response is ready. 
+ Stream: true, + } + if !dropTemperature { + t := req.Temperature + apiReq.Temperature = &t + } + + if req.OutputMode == OutputModeToolUse && req.ResponseSchema != nil { + inputSchema := extractInnerSchema(req.ResponseSchema) + apiReq.Tools = []anthropicTool{ + { + Name: "security_analysis", + Description: securityAnalysisToolDescription, + InputSchema: inputSchema, + }, + } + if forceToolChoice { + // {"type":"any"} forces tool use without naming a specific tool. + // With only one tool defined, it's functionally equivalent to + // {"type":"tool","name":"..."}. + // + // Thinking-enabled models reject any forced form of tool_choice + // (any, tool, or named) with + // "Thinking may not be enabled when tool_choice forces tool use". + // In that case the request loop (see processAnthropicResponse → + // isToolChoiceIncompatibleError) disables forceToolChoice and + // retries, which drops this field entirely — Anthropic's default + // with tools defined is tool_choice=auto, which is accepted + // alongside thinking. 
+ apiReq.ToolChoice = map[string]string{"type": "any"} + } + } + + return marshalWithModelParams(apiReq, requestParamsWithout(req.ModelParams, dropTemperature, "temperature")) +} + +func marshalWithModelParams(base any, modelParams map[string]any) ([]byte, error) { + raw, err := json.Marshal(base) + if err != nil { + return nil, err + } + if len(modelParams) == 0 { + return raw, nil + } + + var payload map[string]any + if err := json.Unmarshal(raw, &payload); err != nil { + return nil, err + } + mergeRequestParams(payload, modelParams) + return json.Marshal(payload) +} + +func requestParamsWithout(modelParams map[string]any, drop bool, keys ...string) map[string]any { + if !drop || len(modelParams) == 0 || len(keys) == 0 { + return modelParams + } + + dropKeys := make(map[string]bool, len(keys)) + needsCopy := false + for _, key := range keys { + dropKeys[key] = true + if _, ok := modelParams[key]; ok { + needsCopy = true + } + } + if !needsCopy { + return modelParams + } + + out := make(map[string]any, len(modelParams)) + for key, value := range modelParams { + if !dropKeys[key] { + out[key] = value + } + } + return out +} + +// atomicRequestKeys are request-body keys whose value the user's model-params +// must fully replace, not deep-merge. Merging these produces payloads the API +// rejects (e.g. Anthropic tool_choice {"type":"auto","name":"x"} is invalid). 
// atomicRequestKeys lists the request-body keys that user-supplied
// model-params replace wholesale instead of being deep-merged. Merging these
// produces payloads the API rejects (e.g. Anthropic tool_choice
// {"type":"auto","name":"x"} is invalid).
var atomicRequestKeys = map[string]bool{
	"tool_choice":     true,
	"response_format": true,
}

// mergeRequestParams overlays src onto dst in place.
//
// For each key in src: when the key is not atomic and both sides hold a JSON
// object (map[string]any), the two objects are merged recursively. In every
// other case — atomic key, key absent from dst, or either side not an object —
// the src value replaces whatever dst had.
func mergeRequestParams(dst, src map[string]any) {
	for key, incoming := range src {
		if !atomicRequestKeys[key] {
			if current, present := dst[key]; present {
				currentObj, currentOK := current.(map[string]any)
				incomingObj, incomingOK := incoming.(map[string]any)
				if currentOK && incomingOK {
					// currentObj is the map already stored in dst, so merging
					// into it updates dst without a further assignment.
					mergeRequestParams(currentObj, incomingObj)
					continue
				}
			}
		}
		dst[key] = incoming
	}
}
+ finishReason := resp.StopReason + switch finishReason { + case "end_turn": + finishReason = "stop" + case "max_tokens": + finishReason = "length" + case "tool_use": + finishReason = "stop" + } + + return &ChatResponse{ + Content: content, + Usage: TokenUsage{ + PromptTokens: resp.Usage.InputTokens, + CompletionTokens: resp.Usage.OutputTokens, + CacheCreationTokens: resp.Usage.CacheCreationInputTokens, + CacheReadTokens: resp.Usage.CacheReadInputTokens, + }, + FinishReason: finishReason, + Model: resp.Model, + } +} + +// readAnthropicStream consumes an SSE response from /v1/messages and +// reconstructs the same anthropicResponse the non-streaming path would have +// received, so callers see no difference. +// +// Anthropic's stream is a sequence of typed events. The ones we act on: +// +// message_start carries model name and usage.input_tokens +// content_block_start opens a text / tool_use / thinking block +// content_block_delta appends text_delta.text or input_json_delta.partial_json +// content_block_stop finalises the current block (commit tool_use input JSON) +// message_delta carries stop_reason and usage.output_tokens +// message_stop end of stream +// ping keepalive — ignore (these are why we stream at all) +// error mid-stream failure — propagate as retryable +// +// thinking blocks are dropped: mapAnthropicResponse only reads text and +// tool_use, so there's no point accumulating them. 
+func (c *httpClient) readAnthropicStream(body io.Reader) (*ChatResponse, time.Duration, error) { + streamStart := time.Now() + var firstDeltaAt time.Time // zero until the first content_block_delta + var thinkingChars int + + sc := bufio.NewScanner(body) + sc.Buffer(make([]byte, 0, 64*1024), 1024*1024) + + var resp anthropicResponse + var cur *anthropicBlock // nil between content_block_start/stop + var toolInput strings.Builder // accumulates input_json_delta fragments + var ev struct { // reused per data: line; Decode zeroes it + Type string `json:"type"` + Message anthropicResponse `json:"message"` // message_start + Index int `json:"index"` // content_block_* + ContentBlock anthropicBlock `json:"content_block"` // content_block_start + Delta struct { + Type string `json:"type"` + Text string `json:"text"` // text_delta + PartialJSON string `json:"partial_json"` // input_json_delta + Thinking string `json:"thinking"` // thinking_delta + StopReason string `json:"stop_reason"` // message_delta.delta + } `json:"delta"` + Usage anthropicUsage `json:"usage"` // message_delta (top-level, not in delta) + Error anthropicError `json:"error"` + } + + flush := func() { + if cur == nil { + return + } + if cur.Type == "tool_use" { + // partial_json fragments concatenate to a complete JSON document. + // Leave validation to the caller that actually parses it. + cur.Input = json.RawMessage(toolInput.String()) + } + resp.Content = append(resp.Content, *cur) + cur = nil + toolInput.Reset() + } + + for sc.Scan() { + line := sc.Bytes() + // SSE framing: "event:" lines name the event, "data:" lines carry the + // payload. The payload's own "type" field duplicates the event name, + // so we key off that and ignore the event: line entirely. 
+ if !bytes.HasPrefix(line, []byte("data:")) { + continue + } + payload := bytes.TrimSpace(line[len("data:"):]) + if len(payload) == 0 { + continue + } + + if err := json.Unmarshal(payload, &ev); err != nil { + return nil, 0, fmt.Errorf("llm: parsing stream event: %w", err) + } + + switch ev.Type { + case "message_start": + resp.ID = ev.Message.ID + resp.Model = ev.Message.Model + resp.Usage.InputTokens = ev.Message.Usage.InputTokens + resp.Usage.CacheCreationInputTokens = ev.Message.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens = ev.Message.Usage.CacheReadInputTokens + + case "content_block_start": + flush() + if ev.ContentBlock.Type == "thinking" { + break // don't track; next start/stop will flush nil harmlessly + } + b := ev.ContentBlock + cur = &b + + case "content_block_delta": + if firstDeltaAt.IsZero() { + firstDeltaAt = time.Now() + } + if ev.Delta.Type == "thinking_delta" { + thinkingChars += len(ev.Delta.Thinking) + break // block itself isn't tracked, but count the volume + } + if cur == nil { + break // delta for a block we chose not to track + } + switch ev.Delta.Type { + case "text_delta": + cur.Text += ev.Delta.Text + case "input_json_delta": + toolInput.WriteString(ev.Delta.PartialJSON) + } + + case "content_block_stop": + flush() + + case "message_delta": + resp.StopReason = ev.Delta.StopReason + resp.Usage.OutputTokens = ev.Usage.OutputTokens + + case "message_stop": + flush() + out := mapAnthropicResponse(&resp) + out.Usage.ThinkingChars = thinkingChars + if !firstDeltaAt.IsZero() { + out.TimeToFirstToken = firstDeltaAt.Sub(streamStart) + out.GenerationTime = time.Since(firstDeltaAt) + } + return out, 0, nil + + case "error": + // overloaded_error mid-stream is the common case; plain error + // keeps it retryable via the ChatCompletion loop. 
+ return nil, 0, fmt.Errorf("llm: stream error (%s): %s", ev.Error.Type, ev.Error.Message) + } + } + + if err := sc.Err(); err != nil { + return nil, 0, fmt.Errorf("llm: reading stream: %w", err) + } + // Scanner hit EOF without message_stop — the connection was cut. Surface + // what we have so the retry log is informative, but don't return a partial + // result: tool_use input is likely truncated mid-JSON. + return nil, 0, fmt.Errorf("llm: stream ended without message_stop (got %d blocks, stop_reason=%q)", len(resp.Content), resp.StopReason) +} + +func isToolChoiceIncompatibleError(body string) bool { + lower := strings.ToLower(body) + if !strings.Contains(lower, "tool_choice") { + return false + } + if strings.Contains(lower, "not compatible with this model") { + return true + } + // Anthropic rejects thinking combined with tool_choice forcing tool use + // (type=any, type=tool, or a named tool). The retry path disables forced + // tool_choice and re-issues the request with tool_choice=auto, which is + // accepted alongside thinking. + if strings.Contains(lower, "thinking") && strings.Contains(lower, "forces tool use") { + return true + } + return strings.Contains(lower, "does not support") && strings.Contains(lower, "tool") +} + +func isTemperatureIncompatibleError(body string) bool { + lower := strings.ToLower(body) + if !strings.Contains(lower, "temperature") { + return false + } + return strings.Contains(lower, "may only be set") || + strings.Contains(lower, "must be") || + strings.Contains(lower, "unsupported value") || + strings.Contains(lower, "does not support") || + strings.Contains(lower, "not supported") || + strings.Contains(lower, "deprecated") +} + +// isMaxTokensParamIncompatibleError detects OpenAI GPT-5 / reasoning-model +// errors that require `max_completion_tokens` in place of `max_tokens`. +// Example message: "Unsupported parameter: 'max_tokens' is not supported +// with this model. Use 'max_completion_tokens' instead." 
// truncate shortens s to at most maxLen bytes and appends "..." when anything
// was cut off. Strings that already fit are returned unchanged.
//
// The cut never lands inside a multi-byte UTF-8 sequence: it backs up to the
// start of the rune that straddles maxLen, so valid input stays valid UTF-8.
// A negative maxLen is treated as zero.
func truncate(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step back until
	// the cut sits on a rune boundary.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, successResponse("ok", "gpt-4", "stop", 1, 1)) + })) + defer srv.Close() + + c := NewClient(ClientConfig{ + BaseURL: srv.URL, + Token: "my-secret-token", + Timeout: 5 * time.Second, + }) + + _, err := c.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := "Bearer my-secret-token" + if gotAuth != want { + t.Errorf("Authorization = %q, want %q", gotAuth, want) + } +} + +func TestChatCompletion_AnthropicCustomHeaders(t *testing.T) { + var gotBeta []string + var gotVersion string + var gotFeature string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotBeta = r.Header.Values("Anthropic-Beta") + gotVersion = r.Header.Get("Anthropic-Version") + gotFeature = r.Header.Get("X-Feature-Flag") + + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{ + "id":"msg_1", + "content":[{"type":"text","text":"ok"}], + "model":"claude-opus-4-6", + "stop_reason":"end_turn", + "usage":{"input_tokens":10,"output_tokens":5} + }`) + })) + defer srv.Close() + + c := NewClient(ClientConfig{ + BaseURL: srv.URL, + Provider: "anthropic", + Token: "anthropic-key", + Headers: http.Header{ + "Anthropic-Beta": []string{"context-1m-2025-08-07", "other-beta"}, + "Anthropic-Version": []string{"2099-01-01"}, + "X-Feature-Flag": []string{"enabled"}, + }, + Timeout: 5 * time.Second, + }) + + resp, err := c.ChatCompletion(context.Background(), ChatRequest{ + Model: "claude-opus-4-6", + Messages: []Message{{Role: "user", Content: "hi"}}, + Temperature: 0, + MaxTokens: 64, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "ok" { + t.Fatalf("Content = %q, want %q", 
resp.Content, "ok") + } + + if len(gotBeta) != 2 { + t.Fatalf("Anthropic-Beta values = %v, want 2 values", gotBeta) + } + if gotBeta[0] != "context-1m-2025-08-07" || gotBeta[1] != "other-beta" { + t.Fatalf("Anthropic-Beta values = %v, want [context-1m-2025-08-07 other-beta]", gotBeta) + } + if gotVersion != "2099-01-01" { + t.Fatalf("Anthropic-Version = %q, want %q", gotVersion, "2099-01-01") + } + if gotFeature != "enabled" { + t.Fatalf("X-Feature-Flag = %q, want %q", gotFeature, "enabled") + } +} + +func TestChatCompletion_AnthropicStream(t *testing.T) { + var gotStream bool + + sse := func(w http.ResponseWriter, event, data string) { + fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, data) + w.(http.Flusher).Flush() + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + gotStream, _ = body["stream"].(bool) + + w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") + w.WriteHeader(http.StatusOK) + + sse(w, "message_start", `{"type":"message_start","message":{"id":"msg_1","model":"claude-opus-4-6","stop_reason":null,"usage":{"input_tokens":42,"output_tokens":0}}}`) + sse(w, "ping", `{"type":"ping"}`) + // Thinking block: must be skipped, not treated as content. + sse(w, "content_block_start", `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`) + sse(w, "content_block_delta", `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":"hmm"}}`) + sse(w, "content_block_stop", `{"type":"content_block_stop","index":0}`) + // Text block split across deltas. 
+ sse(w, "content_block_start", `{"type":"content_block_start","index":1,"content_block":{"type":"text","text":""}}`) + sse(w, "content_block_delta", `{"type":"content_block_delta","index":1,"delta":{"type":"text_delta","text":"hel"}}`) + sse(w, "ping", `{"type":"ping"}`) + sse(w, "content_block_delta", `{"type":"content_block_delta","index":1,"delta":{"type":"text_delta","text":"lo"}}`) + sse(w, "content_block_stop", `{"type":"content_block_stop","index":1}`) + sse(w, "message_delta", `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`) + sse(w, "message_stop", `{"type":"message_stop"}`) + })) + defer srv.Close() + + c := NewClient(ClientConfig{ + BaseURL: srv.URL, + Provider: "anthropic", + Token: "k", + Timeout: 5 * time.Second, + }) + + resp, err := c.ChatCompletion(context.Background(), ChatRequest{ + Model: "claude-opus-4-6", + Messages: []Message{{Role: "user", Content: "hi"}}, + MaxTokens: 64, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !gotStream { + t.Fatal("request body did not include stream:true") + } + if resp.Content != "hello" { + t.Fatalf("Content = %q, want %q", resp.Content, "hello") + } + if resp.FinishReason != "stop" { + t.Fatalf("FinishReason = %q, want %q (end_turn should map to stop)", resp.FinishReason, "stop") + } + if resp.Usage.PromptTokens != 42 || resp.Usage.CompletionTokens != 7 { + t.Fatalf("Usage = %+v, want {42 7}", resp.Usage) + } +} + +func TestChatCompletion_AnthropicStream_ToolUse(t *testing.T) { + sse := func(w http.ResponseWriter, event, data string) { + fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, data) + w.(http.Flusher).Flush() + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + sse(w, "message_start", `{"type":"message_start","message":{"id":"msg_1","model":"m","usage":{"input_tokens":10,"output_tokens":0}}}`) + // 
tool_use input arrives as partial_json fragments that must concatenate + // to valid JSON. Split mid-token to prove we don't parse incrementally. + sse(w, "content_block_start", `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"tu_1","name":"security_analysis","input":{}}}`) + sse(w, "content_block_delta", `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"find"}}`) + sse(w, "content_block_delta", `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"ings\":["}}`) + sse(w, "content_block_delta", `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"1,2]}"}}`) + sse(w, "content_block_stop", `{"type":"content_block_stop","index":0}`) + sse(w, "message_delta", `{"type":"message_delta","delta":{"stop_reason":"tool_use"},"usage":{"output_tokens":3}}`) + sse(w, "message_stop", `{"type":"message_stop"}`) + })) + defer srv.Close() + + c := NewClient(ClientConfig{BaseURL: srv.URL, Provider: "anthropic", Token: "k", Timeout: 5 * time.Second}) + schema := json.RawMessage(`{"type":"object"}`) + resp, err := c.ChatCompletion(context.Background(), ChatRequest{ + Model: "m", + Messages: []Message{{Role: "user", Content: "go"}}, + MaxTokens: 64, + OutputMode: OutputModeToolUse, + ResponseSchema: &schema, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != `{"findings":[1,2]}` { + t.Fatalf("Content = %q, want reassembled tool input", resp.Content) + } + if resp.FinishReason != "stop" { + t.Fatalf("FinishReason = %q, want %q (tool_use should map to stop)", resp.FinishReason, "stop") + } +} + +func TestChatCompletion_AnthropicStream_Truncated(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "event: message_start\ndata: 
{\"type\":\"message_start\",\"message\":{\"id\":\"m\",\"model\":\"m\",\"usage\":{\"input_tokens\":1,\"output_tokens\":0}}}\n\n") + fmt.Fprint(w, "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n") + // hang up mid-stream — no message_stop + })) + defer srv.Close() + + c := NewClient(ClientConfig{BaseURL: srv.URL, Provider: "anthropic", Token: "k", Timeout: 5 * time.Second, MaxRetries: 0}) + _, err := c.ChatCompletion(context.Background(), ChatRequest{ + Model: "m", + Messages: []Message{{Role: "user", Content: "x"}}, + MaxTokens: 64, + }) + if err == nil { + t.Fatal("expected error on truncated stream, got nil") + } + if !strings.Contains(err.Error(), "message_stop") { + t.Fatalf("error = %q, want mention of message_stop", err) + } +} + +func TestChatCompletion_FinishReasonLength(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, successResponse("partial", "gpt-4", "length", 100, 50)) + })) + defer srv.Close() + + c := NewClient(ClientConfig{BaseURL: srv.URL, Token: "t", Timeout: 5 * time.Second}) + + resp, err := c.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.FinishReason != "length" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "length") + } +} + +func TestChatCompletion_TokenUsageTracked(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, successResponse("ok", "gpt-4", "stop", 250, 75)) + })) + defer srv.Close() + + c := NewClient(ClientConfig{BaseURL: srv.URL, Token: "t", Timeout: 5 * time.Second}) + + resp, err := c.ChatCompletion(context.Background(), ChatRequest{ + Messages: 
[]Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Usage.PromptTokens != 250 { + t.Errorf("PromptTokens = %d, want 250", resp.Usage.PromptTokens) + } + if resp.Usage.CompletionTokens != 75 { + t.Errorf("CompletionTokens = %d, want 75", resp.Usage.CompletionTokens) + } +} + +func TestChatCompletion_Retry429(t *testing.T) { + var attempts atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := attempts.Add(1) + if n <= 2 { + w.WriteHeader(http.StatusTooManyRequests) + fmt.Fprint(w, `{"error": "rate limited"}`) + return + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, successResponse("ok", "gpt-4", "stop", 1, 1)) + })) + defer srv.Close() + + // Use a custom client with very short backoff by replacing the httpClient's backoff. + // Since we can't easily replace backoff, we rely on the real implementation but with + // a short enough test that the exponential backoff (1s, 2s) completes. + // Instead, we'll directly access the httpClient and override its backoff timing + // by creating a wrapper. The simplest approach: just verify it eventually succeeds. + c := NewClient(ClientConfig{ + BaseURL: srv.URL, + Token: "t", + MaxRetries: 3, + Timeout: 30 * time.Second, + }) + + // Override the internal client to have a very short timeout so backoff doesn't block. 
+ hc := c.(*httpClient) + hc.backoffFunc = func(_ context.Context, _ int, _ time.Duration) {} // no-op backoff for test speed + + resp, err := hc.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "ok" { + t.Errorf("Content = %q, want %q", resp.Content, "ok") + } + if int(attempts.Load()) != 3 { + t.Errorf("attempts = %d, want 3", attempts.Load()) + } +} + +func TestChatCompletion_Retry5xx(t *testing.T) { + statusCodes := []int{500, 502, 503} + for _, code := range statusCodes { + t.Run(fmt.Sprintf("status_%d", code), func(t *testing.T) { + var attempts atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := attempts.Add(1) + if n == 1 { + w.WriteHeader(code) + fmt.Fprint(w, `{"error": "server error"}`) + return + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, successResponse("recovered", "gpt-4", "stop", 1, 1)) + })) + defer srv.Close() + + c := NewClient(ClientConfig{ + BaseURL: srv.URL, + Token: "t", + MaxRetries: 3, + Timeout: 30 * time.Second, + }) + hc := c.(*httpClient) + hc.backoffFunc = func(_ context.Context, _ int, _ time.Duration) {} + + resp, err := hc.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "recovered" { + t.Errorf("Content = %q, want %q", resp.Content, "recovered") + } + }) + } +} + +func TestChatCompletion_RetryAfterHeader(t *testing.T) { + var attempts atomic.Int32 + var observedDelay time.Duration + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := attempts.Add(1) + if n == 1 { + w.Header().Set("Retry-After", "3") + w.WriteHeader(http.StatusTooManyRequests) + fmt.Fprint(w, `{"error": "rate limited"}`) + return + } + 
w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, successResponse("ok", "gpt-4", "stop", 1, 1)) + })) + defer srv.Close() + + c := NewClient(ClientConfig{ + BaseURL: srv.URL, + Token: "t", + MaxRetries: 3, + Timeout: 30 * time.Second, + }) + hc := c.(*httpClient) + hc.backoffFunc = func(_ context.Context, _ int, retryAfter time.Duration) { + observedDelay = retryAfter + } + + resp, err := hc.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "ok" { + t.Errorf("Content = %q, want %q", resp.Content, "ok") + } + if observedDelay != 3*time.Second { + t.Errorf("retry-after delay = %v, want 3s", observedDelay) + } +} + +func TestChatCompletion_MaxRetriesExhausted(t *testing.T) { + var attempts atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts.Add(1) + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, `{"error": "always failing"}`) + })) + defer srv.Close() + + c := NewClient(ClientConfig{ + BaseURL: srv.URL, + Token: "t", + MaxRetries: 3, + Timeout: 30 * time.Second, + }) + hc := c.(*httpClient) + hc.backoffFunc = func(_ context.Context, _ int, _ time.Duration) {} + + _, err := hc.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err == nil { + t.Fatal("expected error after max retries, got nil") + } + if !strings.Contains(err.Error(), "exhausted 3 retries") { + t.Errorf("error = %q, want it to contain 'exhausted 3 retries'", err.Error()) + } + // 1 initial + 3 retries = 4 total attempts + if int(attempts.Load()) != 4 { + t.Errorf("attempts = %d, want 4", attempts.Load()) + } +} + +func TestChatCompletion_ContextLengthExceeded(t *testing.T) { + tests := []struct { + name string + respBody string + }{ + { + name: "code_field", + respBody: `{"error": {"message": "some error", 
"type": "invalid_request_error", "code": "context_length_exceeded"}}`, + }, + { + name: "message_field", + respBody: `{"error": {"message": "context_length_exceeded: too many tokens", "type": "invalid_request_error", "code": ""}}`, + }, + { + name: "maximum_context_length", + respBody: `{"error": {"message": "This model's maximum context length is 8192 tokens", "type": "invalid_request_error", "code": ""}}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, tt.respBody) + })) + defer srv.Close() + + c := NewClient(ClientConfig{ + BaseURL: srv.URL, + Token: "t", + MaxRetries: 3, + Timeout: 5 * time.Second, + }) + + _, err := c.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err != ErrContextLengthExceeded { + t.Errorf("err = %v, want ErrContextLengthExceeded", err) + } + }) + } +} + +func TestChatCompletion_Timeout(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(3 * time.Second) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, successResponse("late", "gpt-4", "stop", 1, 1)) + })) + defer srv.Close() + + c := NewClient(ClientConfig{ + BaseURL: srv.URL, + Token: "t", + MaxRetries: 1, + Timeout: 200 * time.Millisecond, + }) + hc := c.(*httpClient) + hc.maxRetries = 0 // no retries — just fail on first timeout + + _, err := hc.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err == nil { + t.Fatal("expected timeout error, got nil") + } +} + +func TestChatCompletion_ContextCancellation(t *testing.T) { + var attempts atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts.Add(1) + 
w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, `{"error": "fail"}`) + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + c := NewClient(ClientConfig{ + BaseURL: srv.URL, + Token: "t", + MaxRetries: 10, + Timeout: 30 * time.Second, + }) + hc := c.(*httpClient) + hc.backoffFunc = func(ctx context.Context, _ int, _ time.Duration) { + // Cancel context during backoff to simulate cancellation stopping retry loop + cancel() + // Wait for context to be done + <-ctx.Done() + } + + _, err := hc.ChatCompletion(ctx, ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err == nil { + t.Fatal("expected error from cancelled context, got nil") + } + if !strings.Contains(err.Error(), "context") { + t.Errorf("error = %q, want it to mention 'context'", err.Error()) + } +} + +func TestChatCompletion_ResponseFormatJSONSchema(t *testing.T) { + var receivedBody []byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, successResponse("ok", "gpt-4", "stop", 1, 1)) + })) + defer srv.Close() + + schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`) + c := NewClient(ClientConfig{BaseURL: srv.URL, Token: "t", Timeout: 5 * time.Second}) + + _, err := c.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + OutputMode: OutputModeJSONSchema, + ResponseSchema: &schema, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var reqBody map[string]json.RawMessage + if err := json.Unmarshal(receivedBody, &reqBody); err != nil { + t.Fatalf("failed to parse request body: %v", err) + } + + rfRaw, ok := reqBody["response_format"] + if !ok { + t.Fatal("request body missing 'response_format' field") + } + + var rf map[string]interface{} + if err := json.Unmarshal(rfRaw, &rf); err != 
nil { + t.Fatalf("failed to parse response_format: %v", err) + } + if rf["type"] != "json_schema" { + t.Errorf("response_format.type = %v, want 'json_schema'", rf["type"]) + } + if rf["json_schema"] == nil { + t.Error("response_format.json_schema is nil, want schema") + } +} + +func TestChatCompletion_ModelParamsMergedIntoRequest(t *testing.T) { + var receivedBody []byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, successResponse("ok", "gpt-4", "stop", 1, 1)) + })) + defer srv.Close() + + c := NewClient(ClientConfig{BaseURL: srv.URL, Token: "t", Timeout: 5 * time.Second}) + + _, err := c.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + ModelParams: map[string]any{ + "thinking": map[string]any{"type": "enabled", "budget_tokens": 2048}, + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var reqBody map[string]any + if err := json.Unmarshal(receivedBody, &reqBody); err != nil { + t.Fatalf("failed to parse request body: %v", err) + } + thinkingRaw, ok := reqBody["thinking"] + if !ok { + t.Fatalf("request missing thinking model param: %v", reqBody) + } + thinking, ok := thinkingRaw.(map[string]any) + if !ok { + t.Fatalf("thinking param is %T, want object", thinkingRaw) + } + if thinking["type"] != "enabled" { + t.Errorf("thinking.type = %v, want enabled", thinking["type"]) + } +} + +func TestChatCompletion_RetriesOpenAIWithoutTemperature(t *testing.T) { + var attempts atomic.Int32 + var retriedBody []byte + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if attempts.Add(1) == 1 { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, `{"error":{"message":"Unsupported value: 'temperature' does not support 0 with this model. 
Only the default (1) value is supported.","type":"invalid_request_error","param":"temperature"}}`) + return + } + + retriedBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, successResponse("ok", "gpt-5.5", "stop", 1, 1)) + })) + defer srv.Close() + + modelParams := map[string]any{"temperature": 0.0} + c := NewClient(ClientConfig{ + BaseURL: srv.URL, + Provider: "openai", + Token: "t", + Timeout: 5 * time.Second, + }) + + resp, err := c.ChatCompletion(context.Background(), ChatRequest{ + Model: "gpt-5.5", + Messages: []Message{{Role: "user", Content: "hi"}}, + Temperature: 0.0, + MaxTokens: 100, + ModelParams: modelParams, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "ok" { + t.Fatalf("Content = %q, want ok", resp.Content) + } + if got := attempts.Load(); got != 2 { + t.Fatalf("attempts = %d, want 2", got) + } + + var reqBody map[string]any + if err := json.Unmarshal(retriedBody, &reqBody); err != nil { + t.Fatalf("failed to parse retried request body: %v", err) + } + if _, ok := reqBody["temperature"]; ok { + t.Fatalf("retried request still contains temperature: %v", reqBody["temperature"]) + } + if _, ok := modelParams["temperature"]; !ok { + t.Fatal("request retry mutated caller-owned model params") + } +} + +func TestChatCompletion_ModelParamsDeepMerge(t *testing.T) { + var receivedBody []byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, successResponse("ok", "gpt-4", "stop", 1, 1)) + })) + defer srv.Close() + + schema := json.RawMessage(`{"name":"security_analysis","strict":true,"schema":{"type":"object","properties":{"result":{"type":"string"}}}}`) + c := NewClient(ClientConfig{BaseURL: srv.URL, Token: "t", Timeout: 5 * time.Second}) + + _, err := c.ChatCompletion(context.Background(), ChatRequest{ + Messages: 
[]Message{{Role: "user", Content: "hi"}}, + OutputMode: OutputModeToolUse, + ResponseSchema: &schema, + ModelParams: map[string]any{ + "tool_choice": map[string]any{"type": "auto"}, + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var reqBody map[string]any + if err := json.Unmarshal(receivedBody, &reqBody); err != nil { + t.Fatalf("failed to parse request body: %v", err) + } + tcRaw, ok := reqBody["tool_choice"] + if !ok { + t.Fatalf("request missing tool_choice: %v", reqBody) + } + toolChoice, ok := tcRaw.(map[string]any) + if !ok { + t.Fatalf("tool_choice is %T, want object", tcRaw) + } + if toolChoice["type"] != "auto" { + t.Errorf("tool_choice.type = %v, want auto", toolChoice["type"]) + } + // tool_choice is an atomic key: model-params must fully replace the base + // value, not deep-merge into it (merged values are rejected by the API). + if _, hasFn := toolChoice["function"]; hasFn { + t.Errorf("tool_choice.function should not survive replacement: %v", toolChoice) + } + if _, hasName := toolChoice["name"]; hasName { + t.Errorf("tool_choice.name should not survive replacement: %v", toolChoice) + } +} + +func TestChatCompletion_ToolUseForClaude(t *testing.T) { + var receivedBody []byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + // Respond with tool_call content + fmt.Fprint(w, `{ + "choices": [{"message": {"content": "", "tool_calls": [{"function": {"arguments": "{\"result\":\"ok\"}"}}]}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1}, + "model": "claude-3" + }`) + })) + defer srv.Close() + + schema := json.RawMessage(`{"type":"object","properties":{"result":{"type":"string"}}}`) + c := NewClient(ClientConfig{BaseURL: srv.URL, Token: "t", Timeout: 5 * time.Second}) + + resp, err := c.ChatCompletion(context.Background(), ChatRequest{ + Messages: 
[]Message{{Role: "user", Content: "hi"}}, + OutputMode: OutputModeToolUse, + ResponseSchema: &schema, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify tool_calls arguments were extracted as content + if resp.Content != `{"result":"ok"}` { + t.Errorf("Content = %q, want %q", resp.Content, `{"result":"ok"}`) + } + + var reqBody map[string]json.RawMessage + if err := json.Unmarshal(receivedBody, &reqBody); err != nil { + t.Fatalf("failed to parse request body: %v", err) + } + + toolsRaw, ok := reqBody["tools"] + if !ok { + t.Fatal("request body missing 'tools' field") + } + + var tools []map[string]interface{} + if err := json.Unmarshal(toolsRaw, &tools); err != nil { + t.Fatalf("failed to parse tools: %v", err) + } + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + if tools[0]["type"] != "function" { + t.Errorf("tool type = %v, want 'function'", tools[0]["type"]) + } + fn, hasFn := tools[0]["function"].(map[string]interface{}) + if !hasFn { + t.Fatal("tool missing 'function' field (OpenAI format)") + } + if fn["name"] != "security_analysis" { + t.Errorf("function name = %v, want 'security_analysis'", fn["name"]) + } + if _, hasParams := fn["parameters"]; !hasParams { + t.Error("function missing 'parameters' field") + } + + // Verify tool_choice is set + if _, hasTC := reqBody["tool_choice"]; !hasTC { + t.Error("request body missing 'tool_choice' field") + } + + // Verify no response_format is set + if _, hasRF := reqBody["response_format"]; hasRF { + t.Error("request body should not have 'response_format' for tool_use mode") + } +} + +func TestChatCompletion_ToolUseFallbackWhenToolChoiceRejected(t *testing.T) { + var attempts int32 + var requestBodies [][]byte + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + requestBodies = append(requestBodies, append([]byte(nil), body...)) + + n := atomic.AddInt32(&attempts, 1) + if n == 1 { + 
w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, `{"error":{"message":"tool_choice forces tool use is not compatible with this model."}}`) + return + } + + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{ + "choices": [{"message": {"content": "", "tool_calls": [{"function": {"arguments": "{\"result\":\"ok\"}"}}]}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1}, + "model": "claude-sonnet-4-6" + }`) + })) + defer srv.Close() + + schema := json.RawMessage(`{"type":"object","properties":{"result":{"type":"string"}}}`) + c := NewClient(ClientConfig{BaseURL: srv.URL, Token: "t", Timeout: 5 * time.Second}) + + resp, err := c.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + OutputMode: OutputModeToolUse, + ResponseSchema: &schema, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != `{"result":"ok"}` { + t.Errorf("Content = %q, want %q", resp.Content, `{"result":"ok"}`) + } + if attempts != 2 { + t.Fatalf("attempts = %d, want 2", attempts) + } + if len(requestBodies) != 2 { + t.Fatalf("expected 2 request bodies, got %d", len(requestBodies)) + } + + var first map[string]json.RawMessage + if err := json.Unmarshal(requestBodies[0], &first); err != nil { + t.Fatalf("failed to parse first request body: %v", err) + } + if _, hasToolChoice := first["tool_choice"]; !hasToolChoice { + t.Fatal("first request missing tool_choice") + } + + var second map[string]json.RawMessage + if err := json.Unmarshal(requestBodies[1], &second); err != nil { + t.Fatalf("failed to parse second request body: %v", err) + } + if _, hasToolChoice := second["tool_choice"]; hasToolChoice { + t.Fatal("second request should not include tool_choice fallback") + } + if _, hasTools := second["tools"]; !hasTools { + t.Fatal("second request should keep tools in fallback") + } +} + +func TestChatCompletion_UnstructuredFallback(t *testing.T) { + var receivedBody 
[]byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, successResponse("unstructured output", "gemini-pro", "stop", 1, 1)) + })) + defer srv.Close() + + schema := json.RawMessage(`{"type":"object"}`) + c := NewClient(ClientConfig{BaseURL: srv.URL, Token: "t", Timeout: 5 * time.Second}) + + resp, err := c.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + OutputMode: OutputModeNone, + ResponseSchema: &schema, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "unstructured output" { + t.Errorf("Content = %q, want %q", resp.Content, "unstructured output") + } + + var reqBody map[string]json.RawMessage + if err := json.Unmarshal(receivedBody, &reqBody); err != nil { + t.Fatalf("failed to parse request body: %v", err) + } + + if _, hasRF := reqBody["response_format"]; hasRF { + t.Error("request body should not have 'response_format' for OutputModeNone") + } + if _, hasTools := reqBody["tools"]; hasTools { + t.Error("request body should not have 'tools' for OutputModeNone") + } +} + +func TestChatCompletion_EndpointURL(t *testing.T) { + tests := []struct { + name string + endpoint string + wantPath string + }{ + {"empty_endpoint", "", "/chat/completions"}, + {"named_endpoint", "my-model", "/my-model/invocations"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var gotPath string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, successResponse("ok", "gpt-4", "stop", 1, 1)) + })) + defer srv.Close() + + c := NewClient(ClientConfig{BaseURL: srv.URL, Token: "t", Timeout: 5 * time.Second}) + + _, err := c.ChatCompletion(context.Background(), ChatRequest{ + Endpoint: 
tt.endpoint, + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gotPath != tt.wantPath { + t.Errorf("path = %q, want %q", gotPath, tt.wantPath) + } + }) + } +} + +func TestParseRetryAfter(t *testing.T) { + tests := []struct { + name string + header string + want time.Duration + }{ + {"empty", "", 0}, + {"seconds", "5", 5 * time.Second}, + {"zero", "0", 0}, + {"negative", "-1", 0}, + {"non_numeric", "abc", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseRetryAfter(tt.header) + if got != tt.want { + t.Errorf("parseRetryAfter(%q) = %v, want %v", tt.header, got, tt.want) + } + }) + } +} + +// --- schema.go tests --- + +func TestOutputModeForModel(t *testing.T) { + tests := []struct { + name string + modelName string + want OutputMode + }{ + {"gpt4", "gpt-4", OutputModeJSONSchema}, + {"gpt4_turbo", "gpt-4-turbo", OutputModeJSONSchema}, + {"gpt35", "gpt-3.5-turbo", OutputModeJSONSchema}, + {"GPT_uppercase", "GPT-4", OutputModeJSONSchema}, + {"o1_model", "o1-preview", OutputModeJSONSchema}, + {"o3_model", "o3-mini", OutputModeJSONSchema}, + {"claude3", "claude-3-sonnet", OutputModeToolUse}, + {"claude_uppercase", "Claude-3-Opus", OutputModeToolUse}, + {"anthropic", "anthropic.claude-v2", OutputModeToolUse}, + {"gemini_pro", "gemini-pro", OutputModeJSONSchema}, + {"gemini_15", "gemini-1.5-pro", OutputModeJSONSchema}, + {"unknown_model", "llama-3-70b", OutputModeNone}, + {"empty_string", "", OutputModeNone}, + {"mixtral", "mixtral-8x7b", OutputModeNone}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := OutputModeForModel(tt.modelName) + if got != tt.want { + t.Errorf("OutputModeForModel(%q) = %v, want %v", tt.modelName, got, tt.want) + } + }) + } +} + +func TestSecurityAnalysisSchema_ValidJSON(t *testing.T) { + schema := SecurityAnalysisSchema() + if schema == nil { + t.Fatal("SecurityAnalysisSchema() returned nil") + } + + 
var parsed map[string]interface{} + if err := json.Unmarshal(*schema, &parsed); err != nil { + t.Fatalf("SecurityAnalysisSchema() returned invalid JSON: %v", err) + } + + if parsed["name"] != "security_analysis" { + t.Errorf("schema name = %v, want 'security_analysis'", parsed["name"]) + } + if parsed["strict"] != true { + t.Errorf("schema strict = %v, want true", parsed["strict"]) + } + + schemaObj, ok := parsed["schema"].(map[string]interface{}) + if !ok { + t.Fatal("schema.schema is not an object") + } + if schemaObj["type"] != "object" { + t.Errorf("schema.schema.type = %v, want 'object'", schemaObj["type"]) + } + + required, ok := schemaObj["required"].([]interface{}) + if !ok { + t.Fatal("schema.schema.required is not an array") + } + requiredFields := make(map[string]bool) + for _, f := range required { + requiredFields[f.(string)] = true + } + + expectedFields := []string{"repo_name", "description", "public_api_routes", "security_issues", "security_risk", "risk_justification"} + for _, f := range expectedFields { + if !requiredFields[f] { + t.Errorf("missing required field: %s", f) + } + } +} + +func TestExtractInnerSchema(t *testing.T) { + // Envelope format (used by response_format JSON Schema mode). + envelope := json.RawMessage(`{"name":"test","strict":true,"schema":{"type":"object","properties":{"x":{"type":"string"}}}}`) + inner := extractInnerSchema(&envelope) + var parsed map[string]interface{} + if err := json.Unmarshal(*inner, &parsed); err != nil { + t.Fatalf("failed to parse inner schema: %v", err) + } + if parsed["type"] != "object" { + t.Errorf("inner schema type = %v, want 'object'", parsed["type"]) + } + + // Already a raw schema (no envelope). 
+ raw := json.RawMessage(`{"type":"object","properties":{}}`) + result := extractInnerSchema(&raw) + var parsed2 map[string]interface{} + if err := json.Unmarshal(*result, &parsed2); err != nil { + t.Fatalf("failed to parse raw schema: %v", err) + } + if parsed2["type"] != "object" { + t.Errorf("raw schema type = %v, want 'object'", parsed2["type"]) + } + + // Nil input. + if extractInnerSchema(nil) != nil { + t.Error("expected nil for nil input") + } +} + +func TestFlexString_PlainString(t *testing.T) { + body := `{"choices":[{"message":{"content":"hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1},"model":"gpt-4"}` + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, body) + })) + defer srv.Close() + + c := NewClient(ClientConfig{BaseURL: srv.URL, Token: "t", Timeout: 5 * time.Second}) + resp, err := c.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "hello" { + t.Errorf("Content = %q, want %q", resp.Content, "hello") + } +} + +func TestFlexString_ArrayOfParts(t *testing.T) { + // Gemini-style: content is an array of parts. 
+ body := `{"choices":[{"message":{"content":[{"type":"text","text":"hello from gemini"}]},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1},"model":"gemini-3-pro"}` + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, body) + })) + defer srv.Close() + + c := NewClient(ClientConfig{BaseURL: srv.URL, Token: "t", Timeout: 5 * time.Second}) + resp, err := c.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "hello from gemini" { + t.Errorf("Content = %q, want %q", resp.Content, "hello from gemini") + } +} + +func TestFlexString_NullContent(t *testing.T) { + body := `{"choices":[{"message":{"content":null,"tool_calls":[{"function":{"arguments":"{\"ok\":true}"}}]},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1},"model":"gpt-4"}` + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, body) + })) + defer srv.Close() + + c := NewClient(ClientConfig{BaseURL: srv.URL, Token: "t", Timeout: 5 * time.Second}) + resp, err := c.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != `{"ok":true}` { + t.Errorf("Content = %q, want tool_calls fallback", resp.Content) + } +} + +func TestChatCompletion_ParseErrorNotRetried(t *testing.T) { + var attempts atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts.Add(1) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `not valid json at all`) + })) + defer srv.Close() + + c := NewClient(ClientConfig{ + BaseURL: srv.URL, + Token: 
"t", + MaxRetries: 3, + Timeout: 5 * time.Second, + }) + hc := c.(*httpClient) + hc.backoffFunc = func(_ context.Context, _ int, _ time.Duration) {} + + _, err := hc.ChatCompletion(context.Background(), ChatRequest{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "parsing response") { + t.Errorf("error = %q, want it to contain 'parsing response'", err.Error()) + } + // Should NOT retry — only 1 attempt. + if int(attempts.Load()) != 1 { + t.Errorf("attempts = %d, want 1 (no retries for parse errors)", attempts.Load()) + } +} + +func TestErrorTypes_ErrorAndUnwrap(t *testing.T) { + base := errors.New("underlying cause") + tests := []struct { + name string + err error + }{ + {"errNonRetryable", &errNonRetryable{err: base}}, + {"errToolChoiceIncompatible", &errToolChoiceIncompatible{err: base}}, + {"errTemperatureIncompatible", &errTemperatureIncompatible{err: base}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.err.Error(); got != "underlying cause" { + t.Errorf("Error() = %q, want %q", got, "underlying cause") + } + if !errors.Is(tt.err, base) { + t.Errorf("errors.Is() = false, want true (Unwrap should return wrapped error)") + } + if unwrapped := errors.Unwrap(tt.err); unwrapped != base { + t.Errorf("Unwrap() = %v, want %v", unwrapped, base) + } + }) + } +} diff --git a/internal/llm/prompt.go b/internal/llm/prompt.go new file mode 100644 index 0000000..b4c2217 --- /dev/null +++ b/internal/llm/prompt.go @@ -0,0 +1,613 @@ +package llm + +import ( + "fmt" + "io/fs" + "strings" + "sync" + + "go.yaml.in/yaml/v3" +) + +// BasePrompt represents the parsed security_analysis_base.yaml template. 
type BasePrompt struct {
	// SystemMessage becomes the content of the "system" message.
	SystemMessage string `yaml:"system_message"`
	// AnalysisIntro is the first text written into the user message.
	AnalysisIntro string `yaml:"analysis_intro"`
	// InfrastructureNote is parsed from the template. AssembleMessages does not
	// write it into the prompt.
	InfrastructureNote string `yaml:"infrastructure_note"`
	// AnalysisRequirementsHeader is written immediately before the analysis
	// sections, when any sections were loaded.
	AnalysisRequirementsHeader string `yaml:"analysis_requirements_header"`
	// CustomRequirementsPlaceholder is parsed from the template. AssembleMessages
	// does not read it; custom requirements are appended under a fixed header.
	CustomRequirementsPlaceholder string `yaml:"custom_requirements_placeholder"`
	// RepoInfo may contain the {repo_name} and {xml_content} placeholders.
	RepoInfo string `yaml:"repo_info"`
	// CriticalInstructions is written after the repo info.
	CriticalInstructions string `yaml:"critical_instructions"`
	// JSONFormattingRules may contain the {schema} placeholder; it closes the
	// user message.
	JSONFormattingRules string `yaml:"json_formatting_rules"`
}

// AnalysisSectionsFile represents the parsed analysis_sections.yaml.
type AnalysisSectionsFile struct {
	// Sections holds every entry found under the top-level "sections" key,
	// indexed by its YAML key.
	Sections map[string]AnalysisSection `yaml:"sections"`
}

// AnalysisSection is a single dynamic analysis section.
type AnalysisSection struct {
	Title    string   `yaml:"title"`    // Rendered as a "### " heading in the user message.
	Features []string `yaml:"features"` // Feature tags matched by sectionEnabled; empty means always included.
	Content  string   `yaml:"content"`  // Body text written directly under the heading.
}

// PromptParams holds the dynamic values used when assembling prompt messages.
type PromptParams struct {
	RepoName           string   // Substituted for {repo_name} in the repo info template.
	XML                string   // Flattened repo XML content; substituted for {xml_content}.
	Schema             string   // JSON schema string; substituted for {schema}.
	Manifest           []string // All file paths in the full repo (for chunk context).
	ChunkIndex         int      // 0-based chunk index; shown 1-based in the chunk note.
	ChunkTotal         int      // Total number of chunks; the chunk note is emitted only when > 1.
	CustomRequirements string   // User-provided requirements from --custom-requirements.
	EnabledFeatures    []string // If non-empty, only include sections whose features overlap.
	// SupplementaryContext is the packed output of internal/supctx — docs,
	// API specs, sibling-repo snippets — shown to the model as reference
	// material. Findings must not be reported against it.
	SupplementaryContext string
}

// FeaturePromptParams holds the dynamic values for feature detection prompt assembly.
type FeaturePromptParams struct {
	RepoName string   // Substituted for {repo_name} in the user prompt template.
	Manifest []string // All file paths in the repo.
	Samples  string   // Representative code samples (small subset of files).
}

// FeatureDetectionPrompt represents the parsed feature_detection.yaml template.
type FeatureDetectionPrompt struct {
	// SystemMessage becomes the content of the "system" message.
	SystemMessage string `yaml:"system_message"`
	// UserPromptTemplate may contain the {repo_name}, {xml_content} and
	// {schema} placeholders.
	UserPromptTemplate string `yaml:"user_prompt_template"`
}

// AuditPrompt represents the parsed audit.yaml template.
type AuditPrompt struct {
	// SystemMessage may contain the {production_only_gate} placeholder.
	SystemMessage string `yaml:"system_message"`
	// UserPromptTemplate may contain the {repo_name}, {findings_json},
	// {cwe_analysis_prompts}, {code_context}, {production_only_gate},
	// {supplementary_context} and {schema} placeholders.
	UserPromptTemplate string `yaml:"user_prompt_template"`
	// JSONFormattingRules may contain the {schema} placeholder; it is appended
	// to the user message after a newline.
	JSONFormattingRules string `yaml:"json_formatting_rules"`
	// ProductionOnlyGate is injected at {production_only_gate} when the scan
	// excluded test files (--include-tests not set). Adversarial pre-filter
	// that forces the auditor to prove production reachability.
	ProductionOnlyGate string `yaml:"production_only_gate"`
}

// CWEPromptsFile represents the parsed cwe_deep_analysis.yaml.
type CWEPromptsFile struct {
	// CWEPrompts is looked up by upper-cased CWE identifier (for example
	// "CWE-89") when the audit prompt is assembled.
	CWEPrompts map[string]CWEPrompt `yaml:"cwe_prompts"`
}

// CWEPrompt is a single CWE-specific deep analysis prompt.
type CWEPrompt struct {
	Title                   string   `yaml:"title"`                     // Shown in the "### <id>: <title>" heading.
	AnalysisPrompt          string   `yaml:"analysis_prompt"`           // Rendered under "Deep Analysis Guidance".
	ValidationChecks        []string `yaml:"validation_checks"`         // Rendered as a bullet list when non-empty.
	FalsePositiveIndicators []string `yaml:"false_positive_indicators"` // Rendered as a bullet list when non-empty.
}

// AuditParams holds the dynamic values used when assembling audit prompt messages.
type AuditParams struct {
	RepoName       string   // Substituted for {repo_name}.
	FindingsJSON   string   // Initial findings as JSON.
	CodeContext    string   // Relevant source code for the findings.
	CWEIDs         []string // CWE IDs to include analysis prompts for.
	Schema         string   // JSON schema string for the audit response.
	ProductionOnly bool     // When true, inject the production_only_gate section.
	// SupplementaryContext is the packed output of internal/supctx for the
	// audit phase.
	SupplementaryContext string
}

// PromptLoader loads and assembles prompt templates from an fs.FS.
// Templates are cached after first load to avoid repeated disk I/O and YAML parsing.
+type PromptLoader struct { + fsys fs.FS + mu sync.Mutex // protects cached* fields for concurrent access + + // Cached templates (populated on first load). + cachedBase *BasePrompt + cachedSections *AnalysisSectionsFile + cachedFeature *FeatureDetectionPrompt + cachedAudit *AuditPrompt + cachedCWE *CWEPromptsFile + cachedCompress *ContextCompressPrompt +} + +// ContextCompressPrompt is the parsed context_compress.yaml template used by +// the supctx compression pre-pass. +type ContextCompressPrompt struct { + SystemMessage string `yaml:"system_message"` + UserPromptTemplate string `yaml:"user_prompt_template"` +} + +// NewPromptLoader creates a PromptLoader that reads templates from fsys. +// The caller provides either os.DirFS("prompts/default") or an embed.FS sub-directory. +func NewPromptLoader(fsys fs.FS) *PromptLoader { + return &PromptLoader{fsys: fsys} +} + +// LoadBasePrompt reads and parses the security_analysis_base.yaml template. +// The result is cached after the first successful load. +func (l *PromptLoader) LoadBasePrompt() (*BasePrompt, error) { + l.mu.Lock() + defer l.mu.Unlock() + if l.cachedBase != nil { + return l.cachedBase, nil + } + + data, err := fs.ReadFile(l.fsys, "security_analysis_base.yaml") + if err != nil { + return nil, fmt.Errorf("loading base prompt: %w", err) + } + + var bp BasePrompt + if err := yaml.Unmarshal(data, &bp); err != nil { + return nil, fmt.Errorf("parsing base prompt YAML: %w", err) + } + l.cachedBase = &bp + return &bp, nil +} + +// LoadAnalysisSections reads and parses the analysis_sections.yaml file. +// The result is cached after the first successful load. 
+func (l *PromptLoader) LoadAnalysisSections() (*AnalysisSectionsFile, error) { + l.mu.Lock() + defer l.mu.Unlock() + if l.cachedSections != nil { + return l.cachedSections, nil + } + + data, err := fs.ReadFile(l.fsys, "analysis_sections.yaml") + if err != nil { + return nil, fmt.Errorf("loading analysis sections: %w", err) + } + + var af AnalysisSectionsFile + if err := yaml.Unmarshal(data, &af); err != nil { + return nil, fmt.Errorf("parsing analysis sections YAML: %w", err) + } + l.cachedSections = &af + return &af, nil +} + +// AssembleMessages builds the system and user messages for a security analysis +// LLM call using the base prompt template and the given parameters. +func (l *PromptLoader) AssembleMessages(params PromptParams) ([]Message, error) { + bp, err := l.LoadBasePrompt() + if err != nil { + return nil, err + } + + var user strings.Builder + + // 1. Analysis intro. + user.WriteString(bp.AnalysisIntro) + user.WriteString("\n") + + // 2. Chunk manifest note (only for multi-chunk). + if params.ChunkTotal > 1 { + fmt.Fprintf(&user, "NOTE: This is chunk %d of %d. Files in this chunk are shown below. Other files in the repository:\n%s\nFocus your analysis on the files shown in this chunk.\n\n", + params.ChunkIndex+1, + params.ChunkTotal, + strings.Join(params.Manifest, "\n"), + ) + } + + // 3. Analysis requirements header and sections. + sections, err := l.LoadAnalysisSections() + if err == nil && len(sections.Sections) > 0 { + user.WriteString(bp.AnalysisRequirementsHeader) + user.WriteString("\n") + for _, section := range sections.Sections { + if !sectionEnabled(section, params.EnabledFeatures) { + continue + } + fmt.Fprintf(&user, "### %s\n%s\n", section.Title, section.Content) + } + user.WriteString("\n") + } + + // 4. Custom requirements. + if params.CustomRequirements != "" { + fmt.Fprintf(&user, "ADDITIONAL REQUIREMENTS:\n%s\n\n", params.CustomRequirements) + } + + // 4.5. Supplementary context. 
Placed before the repo so the model reads + // reference material first, then the actual scan target — matching how a + // human reviewer would skim the API spec before diving into handlers. + if params.SupplementaryContext != "" { + fmt.Fprintf(&user, "SUPPLEMENTARY CONTEXT — reference material, do NOT report findings against it:\n%s\n\n", params.SupplementaryContext) + } + + // 5. Repo info with placeholders replaced. + repoInfo := strings.ReplaceAll(bp.RepoInfo, "{repo_name}", params.RepoName) + repoInfo = strings.ReplaceAll(repoInfo, "{xml_content}", params.XML) + user.WriteString(repoInfo) + user.WriteString("\n") + + // 6. Critical instructions. + user.WriteString(bp.CriticalInstructions) + user.WriteString("\n") + + // 7. JSON formatting rules with schema replaced. + jsonRules := strings.ReplaceAll(bp.JSONFormattingRules, "{schema}", params.Schema) + user.WriteString(jsonRules) + + return []Message{ + {Role: "system", Content: bp.SystemMessage}, + {Role: "user", Content: user.String()}, + }, nil +} + +// LoadFeatureDetectionPrompt reads and parses the feature_detection.yaml template. +// The result is cached after the first successful load. +func (l *PromptLoader) LoadFeatureDetectionPrompt() (*FeatureDetectionPrompt, error) { + l.mu.Lock() + defer l.mu.Unlock() + if l.cachedFeature != nil { + return l.cachedFeature, nil + } + + data, err := fs.ReadFile(l.fsys, "feature_detection.yaml") + if err != nil { + return nil, fmt.Errorf("loading feature detection prompt: %w", err) + } + + var fd FeatureDetectionPrompt + if err := yaml.Unmarshal(data, &fd); err != nil { + return nil, fmt.Errorf("parsing feature detection prompt YAML: %w", err) + } + l.cachedFeature = &fd + return &fd, nil +} + +// AssembleFeatureDetectionMessages builds the system and user messages for a +// feature detection LLM call using the feature_detection.yaml template. 
+func (l *PromptLoader) AssembleFeatureDetectionMessages(params FeaturePromptParams) ([]Message, error) { + fd, err := l.LoadFeatureDetectionPrompt() + if err != nil { + return nil, err + } + + // Build the XML content from manifest + samples. + var xmlContent strings.Builder + xmlContent.WriteString("\n") + for _, path := range params.Manifest { + fmt.Fprintf(&xmlContent, " %s\n", path) + } + xmlContent.WriteString("\n\n") + if params.Samples != "" { + xmlContent.WriteString(params.Samples) + } + + schemaRaw := FeatureDetectionSchema() + schemaStr := string(*schemaRaw) + + userContent := fd.UserPromptTemplate + userContent = strings.ReplaceAll(userContent, "{repo_name}", params.RepoName) + userContent = strings.ReplaceAll(userContent, "{xml_content}", xmlContent.String()) + userContent = strings.ReplaceAll(userContent, "{schema}", schemaStr) + + return []Message{ + {Role: "system", Content: fd.SystemMessage}, + {Role: "user", Content: userContent}, + }, nil +} + +// sectionEnabled returns true if a section should be included given the enabled features. +// A section is included if: +// - enabledFeatures is empty (no filtering), +// - the section has no required features (always included), or +// - at least one of the section's features is in enabledFeatures. +func sectionEnabled(section AnalysisSection, enabledFeatures []string) bool { + if len(enabledFeatures) == 0 { + return true + } + if len(section.Features) == 0 { + return true + } + enabled := make(map[string]bool, len(enabledFeatures)) + for _, f := range enabledFeatures { + enabled[f] = true + } + for _, f := range section.Features { + if enabled[f] { + return true + } + } + return false +} + +// FileEntry represents a file with its path and content for feature sample building. +type FileEntry struct { + Path string + Content string +} + +// BuildFeatureSamples selects a representative subset of files to include as samples +// for feature detection. 
It prioritizes dependency manifests and entrypoint files, +// and caps total output at roughly maxTokens * 4 characters. +func BuildFeatureSamples(files []FileEntry, maxTokens int) string { + maxChars := maxTokens * 4 + + // Dependency manifest filenames to always include. + manifests := map[string]bool{ + "package.json": true, + "go.mod": true, + "go.sum": false, // skip, too large and not informative + "requirements.txt": true, + "Cargo.toml": true, + "Gemfile": true, + "pom.xml": true, + "build.gradle": true, + "composer.json": true, + "pyproject.toml": true, + "setup.py": true, + "Pipfile": true, + "pubspec.yaml": true, + "Package.swift": true, + "mix.exs": true, + "project.clj": true, + } + + var out strings.Builder + written := 0 + + addFile := func(path, content string, maxLines int) bool { + if written >= maxChars { + return false + } + var trimmed string + if maxLines > 0 { + lines := strings.SplitN(content, "\n", maxLines+1) + if len(lines) > maxLines { + lines = lines[:maxLines] + } + trimmed = strings.Join(lines, "\n") + } else { + trimmed = content + } + entry := fmt.Sprintf("\n%s\n\n", path, trimmed) + if written+len(entry) > maxChars { + remaining := maxChars - written + if remaining < 100 { + return false + } + entry = entry[:remaining] + } + out.WriteString(entry) + written += len(entry) + return written < maxChars + } + + // Pass 1: dependency manifests (full content). + for _, f := range files { + base := baseName(f.Path) + if include, ok := manifests[base]; ok && include { + if !addFile(f.Path, f.Content, 0) { + return out.String() + } + } + } + + // Pass 2: entrypoints and routers (first ~50 lines). + for _, f := range files { + if isEntrypoint(f.Path) { + if !addFile(f.Path, f.Content, 50) { + return out.String() + } + } + } + + return out.String() +} + +// baseName returns the last component of a slash-separated path. 
+func baseName(path string) string { + if i := strings.LastIndex(path, "/"); i >= 0 { + return path[i+1:] + } + return path +} + +// LoadAuditPrompt reads and parses the audit.yaml template. +// The result is cached after the first successful load. +func (l *PromptLoader) LoadAuditPrompt() (*AuditPrompt, error) { + l.mu.Lock() + defer l.mu.Unlock() + if l.cachedAudit != nil { + return l.cachedAudit, nil + } + + data, err := fs.ReadFile(l.fsys, "audit.yaml") + if err != nil { + return nil, fmt.Errorf("loading audit prompt: %w", err) + } + + var ap AuditPrompt + if err := yaml.Unmarshal(data, &ap); err != nil { + return nil, fmt.Errorf("parsing audit prompt YAML: %w", err) + } + l.cachedAudit = &ap + return &ap, nil +} + +// LoadCWEPrompts reads and parses the cwe_deep_analysis.yaml file. +// The result is cached after the first successful load. +func (l *PromptLoader) LoadCWEPrompts() (*CWEPromptsFile, error) { + l.mu.Lock() + defer l.mu.Unlock() + if l.cachedCWE != nil { + return l.cachedCWE, nil + } + + data, err := fs.ReadFile(l.fsys, "cwe_deep_analysis.yaml") + if err != nil { + return nil, fmt.Errorf("loading CWE prompts: %w", err) + } + + var cf CWEPromptsFile + if err := yaml.Unmarshal(data, &cf); err != nil { + return nil, fmt.Errorf("parsing CWE prompts YAML: %w", err) + } + l.cachedCWE = &cf + return &cf, nil +} + +// LoadContextCompressPrompt reads and parses context_compress.yaml. +// Missing file returns a usable fallback so the feature degrades gracefully +// when users run with a custom --prompts-dir that predates this template. +func (l *PromptLoader) LoadContextCompressPrompt() (*ContextCompressPrompt, error) { + l.mu.Lock() + defer l.mu.Unlock() + if l.cachedCompress != nil { + return l.cachedCompress, nil + } + data, err := fs.ReadFile(l.fsys, "context_compress.yaml") + if err != nil { + l.cachedCompress = &ContextCompressPrompt{ + SystemMessage: "You are a technical writer specialising in security documentation. 
Compress reference material while preserving security-relevant detail.", + UserPromptTemplate: "Summarise the following reference material named \"{source_name}\" into approximately {target_tokens} tokens. Preserve authentication flows, trust boundaries, input validation rules, and known issues. Omit boilerplate.\n\n{content}", + } + return l.cachedCompress, nil + } + var cp ContextCompressPrompt + if err := yaml.Unmarshal(data, &cp); err != nil { + return nil, fmt.Errorf("parsing context compress prompt YAML: %w", err) + } + l.cachedCompress = &cp + return &cp, nil +} + +// AssembleAuditMessages builds the system and user messages for the audit phase. +func (l *PromptLoader) AssembleAuditMessages(params AuditParams) ([]Message, error) { + ap, err := l.LoadAuditPrompt() + if err != nil { + return nil, err + } + + // Build CWE-specific analysis section. + cweSection, err := l.buildCWESection(params.CWEIDs) + if err != nil { + // Non-fatal: proceed without CWE prompts. + cweSection = "(CWE-specific prompts not available)" + } + + prodGate := "" + if params.ProductionOnly { + prodGate = ap.ProductionOnlyGate + } + + userContent := ap.UserPromptTemplate + userContent = strings.ReplaceAll(userContent, "{repo_name}", params.RepoName) + userContent = strings.ReplaceAll(userContent, "{findings_json}", params.FindingsJSON) + userContent = strings.ReplaceAll(userContent, "{cwe_analysis_prompts}", cweSection) + userContent = strings.ReplaceAll(userContent, "{code_context}", params.CodeContext) + userContent = strings.ReplaceAll(userContent, "{production_only_gate}", prodGate) + + supSection := "" + if params.SupplementaryContext != "" { + supSection = "=== SUPPLEMENTARY CONTEXT (reference material — do NOT report findings against it) ===\n\n" + + params.SupplementaryContext + } + userContent = strings.ReplaceAll(userContent, "{supplementary_context}", supSection) + + system := strings.ReplaceAll(ap.SystemMessage, "{production_only_gate}", prodGate) + + // Append JSON 
formatting rules with schema. + jsonRules := strings.ReplaceAll(ap.JSONFormattingRules, "{schema}", params.Schema) + userContent += "\n" + jsonRules + + // Append schema to user content. + userContent = strings.ReplaceAll(userContent, "{schema}", params.Schema) + + return []Message{ + {Role: "system", Content: system}, + {Role: "user", Content: userContent}, + }, nil +} + +// buildCWESection assembles the CWE-specific analysis prompts for the given CWE IDs. +func (l *PromptLoader) buildCWESection(cweIDs []string) (string, error) { + cf, err := l.LoadCWEPrompts() + if err != nil { + return "", err + } + + // Normalize CWE IDs: extract "CWE-89" from "CWE-89: SQL Injection". + seen := make(map[string]bool) + var normalized []string + for _, id := range cweIDs { + // Extract just the CWE-NNN part. + key := id + if idx := strings.Index(id, ":"); idx > 0 { + key = strings.TrimSpace(id[:idx]) + } + key = strings.ToUpper(strings.TrimSpace(key)) + if !seen[key] { + seen[key] = true + normalized = append(normalized, key) + } + } + + var out strings.Builder + for _, key := range normalized { + prompt, ok := cf.CWEPrompts[key] + if !ok { + continue + } + + fmt.Fprintf(&out, "### %s: %s\n\n", key, prompt.Title) + fmt.Fprintf(&out, "**Deep Analysis Guidance:**\n%s\n\n", prompt.AnalysisPrompt) + + if len(prompt.ValidationChecks) > 0 { + out.WriteString("**Validation Checks:**\n") + for _, check := range prompt.ValidationChecks { + fmt.Fprintf(&out, " - %s\n", check) + } + out.WriteString("\n") + } + + if len(prompt.FalsePositiveIndicators) > 0 { + out.WriteString("**False Positive Indicators:**\n") + for _, indicator := range prompt.FalsePositiveIndicators { + fmt.Fprintf(&out, " - %s\n", indicator) + } + out.WriteString("\n") + } + } + + if out.Len() == 0 { + return "(No CWE-specific prompts matched the findings)", nil + } + + return out.String(), nil +} + +// isEntrypoint returns true if the file path looks like an entrypoint or router file. 
+func isEntrypoint(path string) bool { + base := strings.ToLower(baseName(path)) + lower := strings.ToLower(path) + + entrypointNames := []string{ + "main.go", "main.py", "main.ts", "main.js", "main.rs", "main.rb", + "server.go", "server.py", "server.ts", "server.js", + "app.go", "app.py", "app.ts", "app.js", "app.rb", + "index.ts", "index.js", + } + for _, name := range entrypointNames { + if base == name { + return true + } + } + + routerDirs := []string{"/routes/", "/handlers/", "/controllers/", "/api/", "/middleware/"} + for _, dir := range routerDirs { + if strings.Contains(lower, dir) { + return true + } + } + + return false +} diff --git a/internal/llm/prompt_test.go b/internal/llm/prompt_test.go new file mode 100644 index 0000000..d51a8e9 --- /dev/null +++ b/internal/llm/prompt_test.go @@ -0,0 +1,1086 @@ +package llm + +import ( + "strings" + "testing" + "testing/fstest" +) + +const testBaseYAML = `system_message: > + You are a senior security engineer. + +analysis_intro: > + Analyze the following repository. + +infrastructure_note: > + IMPORTANT - INFRASTRUCTURE ANALYSIS. + +analysis_requirements_header: > + ANALYSIS REQUIREMENTS: + +custom_requirements_placeholder: "{custom_requirements_section}" + +repo_info: > + The repository name is: {repo_name}. + Repomix XML Content: + --- + {xml_content} + --- + +critical_instructions: > + CRITICAL INSTRUCTIONS: + Report all vulnerabilities. + +json_formatting_rules: > + Respond with ONLY valid JSON. + JSON Schema: + --- + {schema} + --- +` + +const testSectionsYAML = `sections: + public_api: + title: "PUBLIC API SECURITY" + features: ["public_api", "websockets"] + content: > + Identify all public-facing endpoints. + code_quality: + title: "CODE QUALITY" + features: [] + content: > + Report risky code patterns. 
+` + +func testFS() fstest.MapFS { + return fstest.MapFS{ + "security_analysis_base.yaml": &fstest.MapFile{Data: []byte(testBaseYAML)}, + "analysis_sections.yaml": &fstest.MapFile{Data: []byte(testSectionsYAML)}, + } +} + +func TestLoadBasePrompt(t *testing.T) { + loader := NewPromptLoader(testFS()) + + bp, err := loader.LoadBasePrompt() + if err != nil { + t.Fatalf("LoadBasePrompt() error: %v", err) + } + + if !strings.Contains(bp.SystemMessage, "senior security engineer") { + t.Errorf("SystemMessage = %q, want to contain 'senior security engineer'", bp.SystemMessage) + } + if !strings.Contains(bp.AnalysisIntro, "Analyze the following") { + t.Errorf("AnalysisIntro = %q, want to contain 'Analyze the following'", bp.AnalysisIntro) + } + if !strings.Contains(bp.RepoInfo, "{repo_name}") { + t.Errorf("RepoInfo = %q, want to contain '{repo_name}'", bp.RepoInfo) + } + if !strings.Contains(bp.JSONFormattingRules, "{schema}") { + t.Errorf("JSONFormattingRules = %q, want to contain '{schema}'", bp.JSONFormattingRules) + } +} + +func TestLoadAnalysisSections(t *testing.T) { + loader := NewPromptLoader(testFS()) + + af, err := loader.LoadAnalysisSections() + if err != nil { + t.Fatalf("LoadAnalysisSections() error: %v", err) + } + + if len(af.Sections) != 2 { + t.Fatalf("got %d sections, want 2", len(af.Sections)) + } + + api, ok := af.Sections["public_api"] + if !ok { + t.Fatal("missing 'public_api' section") + } + if api.Title != "PUBLIC API SECURITY" { + t.Errorf("public_api.Title = %q, want %q", api.Title, "PUBLIC API SECURITY") + } + if len(api.Features) != 2 { + t.Errorf("public_api.Features length = %d, want 2", len(api.Features)) + } + + cq, ok := af.Sections["code_quality"] + if !ok { + t.Fatal("missing 'code_quality' section") + } + if len(cq.Features) != 0 { + t.Errorf("code_quality.Features length = %d, want 0", len(cq.Features)) + } +} + +func TestAssembleMessages_BasicSubstitution(t *testing.T) { + loader := NewPromptLoader(testFS()) + + msgs, err := 
loader.AssembleMessages(PromptParams{ + RepoName: "my-repo", + XML: "contents", + Schema: `{"type":"object"}`, + ChunkTotal: 1, + }) + if err != nil { + t.Fatalf("AssembleMessages() error: %v", err) + } + + if len(msgs) != 2 { + t.Fatalf("got %d messages, want 2", len(msgs)) + } + + if msgs[0].Role != "system" { + t.Errorf("msgs[0].Role = %q, want %q", msgs[0].Role, "system") + } + if !strings.Contains(msgs[0].Content, "senior security engineer") { + t.Errorf("system message missing expected content") + } + + if msgs[1].Role != "user" { + t.Errorf("msgs[1].Role = %q, want %q", msgs[1].Role, "user") + } + + user := msgs[1].Content + if !strings.Contains(user, "my-repo") { + t.Error("user message missing repo name substitution") + } + if !strings.Contains(user, "contents") { + t.Error("user message missing XML content substitution") + } + if !strings.Contains(user, `{"type":"object"}`) { + t.Error("user message missing schema substitution") + } +} + +func TestAssembleMessages_IncludesAnalysisSections(t *testing.T) { + loader := NewPromptLoader(testFS()) + + msgs, err := loader.AssembleMessages(PromptParams{ + RepoName: "test-repo", + XML: "", + Schema: "{}", + ChunkTotal: 1, + }) + if err != nil { + t.Fatalf("AssembleMessages() error: %v", err) + } + + user := msgs[1].Content + if !strings.Contains(user, "ANALYSIS REQUIREMENTS:") { + t.Error("user message missing ANALYSIS REQUIREMENTS header") + } + if !strings.Contains(user, "PUBLIC API SECURITY") { + t.Error("user message missing PUBLIC API SECURITY section title") + } + if !strings.Contains(user, "Identify all public-facing endpoints") { + t.Error("user message missing public_api section content") + } + if !strings.Contains(user, "CODE QUALITY") { + t.Error("user message missing CODE QUALITY section title") + } + if !strings.Contains(user, "Report risky code patterns") { + t.Error("user message missing code_quality section content") + } +} + +func TestAssembleMessages_SectionsBeforeRepoInfo(t *testing.T) { + 
loader := NewPromptLoader(testFS()) + + msgs, err := loader.AssembleMessages(PromptParams{ + RepoName: "order-test", + XML: "", + Schema: "{}", + ChunkTotal: 1, + }) + if err != nil { + t.Fatalf("AssembleMessages() error: %v", err) + } + + user := msgs[1].Content + sectionsIdx := strings.Index(user, "ANALYSIS REQUIREMENTS:") + repoInfoIdx := strings.Index(user, "order-test") + if sectionsIdx < 0 || repoInfoIdx < 0 { + t.Fatal("missing expected content in user message") + } + if sectionsIdx > repoInfoIdx { + t.Error("analysis sections should appear before repo info in the prompt") + } +} + +func TestAssembleMessages_GracefulWithoutSectionsFile(t *testing.T) { + noSectionsFS := fstest.MapFS{ + "security_analysis_base.yaml": &fstest.MapFile{Data: []byte(testBaseYAML)}, + } + loader := NewPromptLoader(noSectionsFS) + + msgs, err := loader.AssembleMessages(PromptParams{ + RepoName: "no-sections", + XML: "", + Schema: "{}", + ChunkTotal: 1, + }) + if err != nil { + t.Fatalf("AssembleMessages() should not fail without sections file: %v", err) + } + + user := msgs[1].Content + if !strings.Contains(user, "no-sections") { + t.Error("user message missing repo name") + } +} + +func TestAssembleMessages_ChunkManifest(t *testing.T) { + loader := NewPromptLoader(testFS()) + + msgs, err := loader.AssembleMessages(PromptParams{ + RepoName: "chunked-repo", + XML: "", + Schema: "{}", + Manifest: []string{"file1.go", "file2.go", "file3.go"}, + ChunkIndex: 1, + ChunkTotal: 3, + }) + if err != nil { + t.Fatalf("AssembleMessages() error: %v", err) + } + + user := msgs[1].Content + if !strings.Contains(user, "chunk 2 of 3") { + t.Error("user message missing chunk indicator") + } + if !strings.Contains(user, "file1.go") { + t.Error("user message missing manifest file1.go") + } + if !strings.Contains(user, "file2.go") { + t.Error("user message missing manifest file2.go") + } + if !strings.Contains(user, "Focus your analysis on the files shown in this chunk") { + t.Error("user message 
missing chunk focus instruction") + } +} + +func TestAssembleMessages_SingleChunk_NoManifest(t *testing.T) { + loader := NewPromptLoader(testFS()) + + msgs, err := loader.AssembleMessages(PromptParams{ + RepoName: "single-repo", + XML: "", + Schema: "{}", + ChunkTotal: 1, + }) + if err != nil { + t.Fatalf("AssembleMessages() error: %v", err) + } + + user := msgs[1].Content + if strings.Contains(user, "chunk") { + t.Error("single-chunk message should not contain chunk info") + } + if strings.Contains(user, "Focus your analysis") { + t.Error("single-chunk message should not contain chunk focus instruction") + } +} + +func TestAssembleMessages_CustomRequirements(t *testing.T) { + loader := NewPromptLoader(testFS()) + + msgs, err := loader.AssembleMessages(PromptParams{ + RepoName: "custom-repo", + XML: "", + Schema: "{}", + ChunkTotal: 1, + CustomRequirements: "Focus on auth bypass vulnerabilities.", + }) + if err != nil { + t.Fatalf("AssembleMessages() error: %v", err) + } + + user := msgs[1].Content + if !strings.Contains(user, "ADDITIONAL REQUIREMENTS:") { + t.Error("user message missing ADDITIONAL REQUIREMENTS header") + } + if !strings.Contains(user, "Focus on auth bypass vulnerabilities.") { + t.Error("user message missing custom requirements content") + } +} + +func TestAssembleMessages_SupplementaryContext(t *testing.T) { + loader := NewPromptLoader(testFS()) + + msgs, err := loader.AssembleMessages(PromptParams{ + RepoName: "repo", + XML: "code", + Schema: "{}", + ChunkTotal: 1, + CustomRequirements: "custom reqs here", + SupplementaryContext: "openapi: 3.0", + }) + if err != nil { + t.Fatalf("AssembleMessages() error: %v", err) + } + + user := msgs[1].Content + if !strings.Contains(user, "SUPPLEMENTARY CONTEXT") { + t.Error("user message missing SUPPLEMENTARY CONTEXT header") + } + if !strings.Contains(user, "openapi: 3.0") { + t.Error("user message missing supplementary content") + } + + // Order: custom-requirements < supplementary-context < repo-info. 
+ reqIdx := strings.Index(user, "custom reqs here") + supIdx := strings.Index(user, "openapi: 3.0") + repoIdx := strings.Index(user, "code") + if !(reqIdx < supIdx && supIdx < repoIdx) { + t.Errorf("section order wrong: reqs=%d sup=%d repo=%d", reqIdx, supIdx, repoIdx) + } +} + +func TestAssembleMessages_NoSupplementaryContext(t *testing.T) { + loader := NewPromptLoader(testFS()) + msgs, _ := loader.AssembleMessages(PromptParams{ + RepoName: "r", XML: "x", Schema: "{}", ChunkTotal: 1, + }) + if strings.Contains(msgs[1].Content, "SUPPLEMENTARY CONTEXT") { + t.Error("header should be omitted when SupplementaryContext is empty") + } +} + +func TestAssembleMessages_NoCustomRequirements(t *testing.T) { + loader := NewPromptLoader(testFS()) + + msgs, err := loader.AssembleMessages(PromptParams{ + RepoName: "no-custom-repo", + XML: "", + Schema: "{}", + ChunkTotal: 1, + }) + if err != nil { + t.Fatalf("AssembleMessages() error: %v", err) + } + + user := msgs[1].Content + if strings.Contains(user, "ADDITIONAL REQUIREMENTS:") { + t.Error("user message should not contain ADDITIONAL REQUIREMENTS when none provided") + } +} + +func TestAssembleMessages_MissingTemplateFile(t *testing.T) { + emptyFS := fstest.MapFS{} + loader := NewPromptLoader(emptyFS) + + _, err := loader.AssembleMessages(PromptParams{ + RepoName: "missing", + XML: "", + Schema: "{}", + ChunkTotal: 1, + }) + if err == nil { + t.Fatal("AssembleMessages() with missing template should return error") + } + if !strings.Contains(err.Error(), "loading base prompt") { + t.Errorf("error = %q, want to contain 'loading base prompt'", err.Error()) + } +} + +func TestAssembleMessages_AllParamsEmpty(t *testing.T) { + loader := NewPromptLoader(testFS()) + + msgs, err := loader.AssembleMessages(PromptParams{}) + if err != nil { + t.Fatalf("AssembleMessages() with empty params should not panic/error: %v", err) + } + + if len(msgs) != 2 { + t.Fatalf("got %d messages, want 2", len(msgs)) + } + if msgs[0].Role != "system" { + 
t.Errorf("msgs[0].Role = %q, want %q", msgs[0].Role, "system") + } + if msgs[1].Role != "user" { + t.Errorf("msgs[1].Role = %q, want %q", msgs[1].Role, "user") + } +} + +func TestLoadBasePrompt_MissingFile(t *testing.T) { + emptyFS := fstest.MapFS{} + loader := NewPromptLoader(emptyFS) + + _, err := loader.LoadBasePrompt() + if err == nil { + t.Fatal("LoadBasePrompt() with missing file should return error") + } + if !strings.Contains(err.Error(), "loading base prompt") { + t.Errorf("error = %q, want to contain 'loading base prompt'", err.Error()) + } +} + +func TestLoadAnalysisSections_MissingFile(t *testing.T) { + emptyFS := fstest.MapFS{} + loader := NewPromptLoader(emptyFS) + + _, err := loader.LoadAnalysisSections() + if err == nil { + t.Fatal("LoadAnalysisSections() with missing file should return error") + } + if !strings.Contains(err.Error(), "loading analysis sections") { + t.Errorf("error = %q, want to contain 'loading analysis sections'", err.Error()) + } +} + +// --- Feature detection prompt tests --- + +const testFeatureYAML = `system_message: > + You are a feature detector. + +user_prompt_template: > + Repo: {repo_name} + Content: + {xml_content} + Schema: + {schema} +` + +func testFeatureFS() fstest.MapFS { + return fstest.MapFS{ + "feature_detection.yaml": &fstest.MapFile{Data: []byte(testFeatureYAML)}, + } +} + +func TestLoadFeatureDetectionPrompt(t *testing.T) { + loader := NewPromptLoader(testFeatureFS()) + + fd, err := loader.LoadFeatureDetectionPrompt() + if err != nil { + t.Fatalf("LoadFeatureDetectionPrompt() error: %v", err) + } + + if !strings.Contains(fd.SystemMessage, "feature detector") { + t.Errorf("SystemMessage = %q, want to contain 'feature detector'", fd.SystemMessage) + } + if !strings.Contains(fd.UserPromptTemplate, "{repo_name}") { + t.Errorf("UserPromptTemplate = %q, want to contain '{repo_name}'", fd.UserPromptTemplate) + } + + // Second call should return the cached instance. 
+ fd2, err := loader.LoadFeatureDetectionPrompt() + if err != nil { + t.Fatalf("cached LoadFeatureDetectionPrompt() error: %v", err) + } + if fd != fd2 { + t.Error("expected cached pointer to be returned on second call") + } +} + +func TestLoadFeatureDetectionPrompt_MissingFile(t *testing.T) { + loader := NewPromptLoader(fstest.MapFS{}) + + _, err := loader.LoadFeatureDetectionPrompt() + if err == nil { + t.Fatal("LoadFeatureDetectionPrompt() with missing file should return error") + } + if !strings.Contains(err.Error(), "loading feature detection prompt") { + t.Errorf("error = %q, want to contain 'loading feature detection prompt'", err.Error()) + } +} + +func TestAssembleFeatureDetectionMessages(t *testing.T) { + loader := NewPromptLoader(testFeatureFS()) + + msgs, err := loader.AssembleFeatureDetectionMessages(FeaturePromptParams{ + RepoName: "feat-repo", + Manifest: []string{"src/main.go", "src/db.go"}, + Samples: `package main`, + }) + if err != nil { + t.Fatalf("AssembleFeatureDetectionMessages() error: %v", err) + } + + if len(msgs) != 2 { + t.Fatalf("got %d messages, want 2", len(msgs)) + } + if msgs[0].Role != "system" { + t.Errorf("msgs[0].Role = %q, want %q", msgs[0].Role, "system") + } + if !strings.Contains(msgs[0].Content, "feature detector") { + t.Error("system message missing expected content") + } + + user := msgs[1].Content + if !strings.Contains(user, "feat-repo") { + t.Error("user message missing repo name substitution") + } + if !strings.Contains(user, "") { + t.Error("user message missing wrapper") + } + if !strings.Contains(user, "src/main.go") { + t.Error("user message missing manifest path src/main.go") + } + if !strings.Contains(user, "src/db.go") { + t.Error("user message missing manifest path src/db.go") + } + if !strings.Contains(user, "package main") { + t.Error("user message missing sample content") + } + if strings.Contains(user, "{schema}") { + t.Error("user message has unsubstituted {schema} placeholder") + } + if 
strings.Contains(user, "{xml_content}") { + t.Error("user message has unsubstituted {xml_content} placeholder") + } +} + +func TestAssembleFeatureDetectionMessages_MissingTemplate(t *testing.T) { + loader := NewPromptLoader(fstest.MapFS{}) + + _, err := loader.AssembleFeatureDetectionMessages(FeaturePromptParams{ + RepoName: "x", + }) + if err == nil { + t.Fatal("AssembleFeatureDetectionMessages() with missing template should return error") + } +} + +// --- sectionEnabled table-driven tests --- + +func TestSectionEnabled(t *testing.T) { + tests := []struct { + name string + section AnalysisSection + enabledFeatures []string + want bool + }{ + { + name: "no filter includes all", + section: AnalysisSection{Features: []string{"db"}}, + enabledFeatures: nil, + want: true, + }, + { + name: "empty filter includes all", + section: AnalysisSection{Features: []string{"db"}}, + enabledFeatures: []string{}, + want: true, + }, + { + name: "section with no features always included", + section: AnalysisSection{Features: nil}, + enabledFeatures: []string{"web"}, + want: true, + }, + { + name: "overlapping feature included", + section: AnalysisSection{Features: []string{"db", "sql"}}, + enabledFeatures: []string{"web", "sql"}, + want: true, + }, + { + name: "no overlap excluded", + section: AnalysisSection{Features: []string{"db", "sql"}}, + enabledFeatures: []string{"web", "api"}, + want: false, + }, + { + name: "single match", + section: AnalysisSection{Features: []string{"auth"}}, + enabledFeatures: []string{"auth"}, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sectionEnabled(tt.section, tt.enabledFeatures) + if got != tt.want { + t.Errorf("sectionEnabled() = %v, want %v", got, tt.want) + } + }) + } +} + +// --- baseName table-driven tests --- + +func TestBaseName(t *testing.T) { + tests := []struct { + path string + want string + }{ + {"foo/bar/baz.go", "baz.go"}, + {"baz.go", "baz.go"}, + {"foo/bar/", ""}, + {"", ""}, + 
{"/leading/slash.txt", "slash.txt"}, + {"single/", ""}, + {"a/b/c/d/e.js", "e.js"}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + got := baseName(tt.path) + if got != tt.want { + t.Errorf("baseName(%q) = %q, want %q", tt.path, got, tt.want) + } + }) + } +} + +// --- isEntrypoint table-driven tests --- + +func TestIsEntrypoint(t *testing.T) { + tests := []struct { + path string + want bool + }{ + {"main.go", true}, + {"cmd/server/main.go", true}, + {"src/Main.GO", true}, // case-insensitive on base name + {"app.py", true}, + {"index.js", true}, + {"server.ts", true}, + {"src/routes/users.go", true}, + {"internal/handlers/auth.go", true}, + {"pkg/controllers/admin.rb", true}, + {"src/api/v1.go", true}, + {"lib/middleware/cors.js", true}, + {"src/ROUTES/upper.go", true}, // case-insensitive on path + {"util.go", false}, + {"src/helpers/format.go", false}, + {"README.md", false}, + {"main_test.go", false}, + {"application.go", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + got := isEntrypoint(tt.path) + if got != tt.want { + t.Errorf("isEntrypoint(%q) = %v, want %v", tt.path, got, tt.want) + } + }) + } +} + +// --- BuildFeatureSamples tests --- + +func TestBuildFeatureSamples_PrioritizesManifests(t *testing.T) { + files := []FileEntry{ + {Path: "src/util.go", Content: "package util"}, + {Path: "package.json", Content: `{"name":"app","dependencies":{}}`}, + {Path: "go.mod", Content: "module example.com/app\ngo 1.22"}, + } + + out := BuildFeatureSamples(files, 10000) + + if !strings.Contains(out, ``) { + t.Error("output missing package.json manifest") + } + if !strings.Contains(out, ``) { + t.Error("output missing go.mod manifest") + } + if !strings.Contains(out, "module example.com/app") { + t.Error("output missing go.mod content") + } + if strings.Contains(out, "src/util.go") { + t.Error("non-entrypoint non-manifest file should not be included") + } +} + +func TestBuildFeatureSamples_IncludesEntrypoints(t 
*testing.T) { + files := []FileEntry{ + {Path: "cmd/server/main.go", Content: "package main\nfunc main(){}"}, + {Path: "src/routes/users.go", Content: "package routes\nfunc Users(){}"}, + {Path: "docs/README.md", Content: "# readme"}, + } + + out := BuildFeatureSamples(files, 10000) + + if !strings.Contains(out, ``) { + t.Error("output missing main.go entrypoint") + } + if !strings.Contains(out, ``) { + t.Error("output missing routes/ entrypoint") + } + if strings.Contains(out, "README.md") { + t.Error("non-entrypoint file should not be included") + } +} + +func TestBuildFeatureSamples_SkipsGoSum(t *testing.T) { + files := []FileEntry{ + {Path: "go.sum", Content: "example.com/pkg v1.0.0 h1:abc"}, + {Path: "go.mod", Content: "module x"}, + } + + out := BuildFeatureSamples(files, 10000) + + if strings.Contains(out, "go.sum") { + t.Error("go.sum should be skipped") + } + if !strings.Contains(out, "go.mod") { + t.Error("go.mod should be included") + } +} + +func TestBuildFeatureSamples_RespectsCharBudget(t *testing.T) { + big := strings.Repeat("x", 5000) + files := []FileEntry{ + {Path: "package.json", Content: big}, + {Path: "go.mod", Content: big}, + {Path: "Cargo.toml", Content: big}, + } + + // maxTokens=100 → maxChars=400, so only part of the first file fits. + out := BuildFeatureSamples(files, 100) + + if len(out) > 400 { + t.Errorf("output length = %d, want <= 400", len(out)) + } +} + +func TestBuildFeatureSamples_TruncatesEntrypointLines(t *testing.T) { + var lines []string + for i := 0; i < 100; i++ { + lines = append(lines, "line") + } + content := strings.Join(lines, "\n") + + files := []FileEntry{ + {Path: "main.go", Content: content}, + } + + out := BuildFeatureSamples(files, 10000) + + // Entrypoints are capped at 50 lines. 
+ got := strings.Count(out, "line") + if got > 50 { + t.Errorf("entrypoint content has %d 'line' occurrences, want <= 50", got) + } +} + +func TestBuildFeatureSamples_Empty(t *testing.T) { + out := BuildFeatureSamples(nil, 1000) + if out != "" { + t.Errorf("BuildFeatureSamples(nil) = %q, want empty", out) + } +} + +// --- Audit prompt tests --- + +const testAuditYAML = `system_message: > + You are an auditor. {production_only_gate} + +user_prompt_template: > + Repo: {repo_name} + Findings: + {findings_json} + CWE Prompts: + {cwe_analysis_prompts} + Code: + {code_context} + Gate: {production_only_gate} + +json_formatting_rules: > + Respond with JSON matching: + {schema} + +production_only_gate: > + PRODUCTION-ONLY: prove reachability from production entrypoints. +` + +const testCWEYAML = `cwe_prompts: + CWE-89: + title: "SQL Injection" + analysis_prompt: "Trace user input to SQL queries." + validation_checks: + - "Input reaches query unsanitized" + - "No parameterized queries" + false_positive_indicators: + - "Input is constant" + CWE-79: + title: "XSS" + analysis_prompt: "Check output encoding." 
+ validation_checks: [] + false_positive_indicators: [] +` + +func testAuditFS() fstest.MapFS { + return fstest.MapFS{ + "audit.yaml": &fstest.MapFile{Data: []byte(testAuditYAML)}, + "cwe_deep_analysis.yaml": &fstest.MapFile{Data: []byte(testCWEYAML)}, + } +} + +func TestLoadAuditPrompt(t *testing.T) { + loader := NewPromptLoader(testAuditFS()) + + ap, err := loader.LoadAuditPrompt() + if err != nil { + t.Fatalf("LoadAuditPrompt() error: %v", err) + } + + if !strings.Contains(ap.SystemMessage, "auditor") { + t.Errorf("SystemMessage = %q, want to contain 'auditor'", ap.SystemMessage) + } + if !strings.Contains(ap.UserPromptTemplate, "{findings_json}") { + t.Errorf("UserPromptTemplate missing {findings_json} placeholder") + } + if !strings.Contains(ap.ProductionOnlyGate, "PRODUCTION-ONLY") { + t.Errorf("ProductionOnlyGate = %q, want to contain 'PRODUCTION-ONLY'", ap.ProductionOnlyGate) + } + + // Second call should return cached instance. + ap2, err := loader.LoadAuditPrompt() + if err != nil { + t.Fatalf("cached LoadAuditPrompt() error: %v", err) + } + if ap != ap2 { + t.Error("expected cached pointer on second call") + } +} + +func TestLoadAuditPrompt_MissingFile(t *testing.T) { + loader := NewPromptLoader(fstest.MapFS{}) + + _, err := loader.LoadAuditPrompt() + if err == nil { + t.Fatal("LoadAuditPrompt() with missing file should return error") + } + if !strings.Contains(err.Error(), "loading audit prompt") { + t.Errorf("error = %q, want to contain 'loading audit prompt'", err.Error()) + } +} + +func TestLoadCWEPrompts(t *testing.T) { + loader := NewPromptLoader(testAuditFS()) + + cf, err := loader.LoadCWEPrompts() + if err != nil { + t.Fatalf("LoadCWEPrompts() error: %v", err) + } + + if len(cf.CWEPrompts) != 2 { + t.Fatalf("got %d CWE prompts, want 2", len(cf.CWEPrompts)) + } + + sqli, ok := cf.CWEPrompts["CWE-89"] + if !ok { + t.Fatal("missing CWE-89 entry") + } + if sqli.Title != "SQL Injection" { + t.Errorf("CWE-89.Title = %q, want %q", sqli.Title, "SQL 
Injection") + } + if len(sqli.ValidationChecks) != 2 { + t.Errorf("CWE-89.ValidationChecks length = %d, want 2", len(sqli.ValidationChecks)) + } + if len(sqli.FalsePositiveIndicators) != 1 { + t.Errorf("CWE-89.FalsePositiveIndicators length = %d, want 1", len(sqli.FalsePositiveIndicators)) + } + + // Second call should return cached instance. + cf2, err := loader.LoadCWEPrompts() + if err != nil { + t.Fatalf("cached LoadCWEPrompts() error: %v", err) + } + if cf != cf2 { + t.Error("expected cached pointer on second call") + } +} + +func TestLoadCWEPrompts_MissingFile(t *testing.T) { + loader := NewPromptLoader(fstest.MapFS{}) + + _, err := loader.LoadCWEPrompts() + if err == nil { + t.Fatal("LoadCWEPrompts() with missing file should return error") + } + if !strings.Contains(err.Error(), "loading CWE prompts") { + t.Errorf("error = %q, want to contain 'loading CWE prompts'", err.Error()) + } +} + +func TestAssembleAuditMessages_ProductionOnlyTrue(t *testing.T) { + loader := NewPromptLoader(testAuditFS()) + + msgs, err := loader.AssembleAuditMessages(AuditParams{ + RepoName: "audit-repo", + FindingsJSON: `[{"id":1}]`, + CodeContext: "func vuln(){}", + CWEIDs: []string{"CWE-89"}, + Schema: `{"type":"array"}`, + ProductionOnly: true, + }) + if err != nil { + t.Fatalf("AssembleAuditMessages() error: %v", err) + } + + if len(msgs) != 2 { + t.Fatalf("got %d messages, want 2", len(msgs)) + } + + system := msgs[0].Content + user := msgs[1].Content + + if !strings.Contains(system, "PRODUCTION-ONLY") { + t.Error("system message missing production gate when ProductionOnly=true") + } + if !strings.Contains(user, "PRODUCTION-ONLY") { + t.Error("user message missing production gate when ProductionOnly=true") + } + if strings.Contains(user, "{production_only_gate}") { + t.Error("user message has unsubstituted {production_only_gate} placeholder") + } + if !strings.Contains(user, "audit-repo") { + t.Error("user message missing repo name substitution") + } + if !strings.Contains(user, 
`[{"id":1}]`) { + t.Error("user message missing findings JSON") + } + if !strings.Contains(user, "func vuln(){}") { + t.Error("user message missing code context") + } + if !strings.Contains(user, `{"type":"array"}`) { + t.Error("user message missing schema substitution") + } + if !strings.Contains(user, "SQL Injection") { + t.Error("user message missing CWE-89 prompt content") + } +} + +func TestAssembleAuditMessages_ProductionOnlyFalse(t *testing.T) { + loader := NewPromptLoader(testAuditFS()) + + msgs, err := loader.AssembleAuditMessages(AuditParams{ + RepoName: "audit-repo", + FindingsJSON: `[]`, + CodeContext: "code", + CWEIDs: nil, + Schema: "{}", + ProductionOnly: false, + }) + if err != nil { + t.Fatalf("AssembleAuditMessages() error: %v", err) + } + + system := msgs[0].Content + user := msgs[1].Content + + if strings.Contains(system, "PRODUCTION-ONLY") { + t.Error("system message should not contain production gate when ProductionOnly=false") + } + if strings.Contains(user, "PRODUCTION-ONLY") { + t.Error("user message should not contain production gate when ProductionOnly=false") + } + if strings.Contains(system, "{production_only_gate}") { + t.Error("system message has unsubstituted {production_only_gate} placeholder") + } + if strings.Contains(user, "{production_only_gate}") { + t.Error("user message has unsubstituted {production_only_gate} placeholder") + } +} + +func TestAssembleAuditMessages_MissingAuditTemplate(t *testing.T) { + loader := NewPromptLoader(fstest.MapFS{}) + + _, err := loader.AssembleAuditMessages(AuditParams{RepoName: "x"}) + if err == nil { + t.Fatal("AssembleAuditMessages() with missing template should return error") + } +} + +func TestAssembleAuditMessages_MissingCWEFile(t *testing.T) { + // Audit template present but CWE file missing: should not fail. 
+ fsys := fstest.MapFS{ + "audit.yaml": &fstest.MapFile{Data: []byte(testAuditYAML)}, + } + loader := NewPromptLoader(fsys) + + msgs, err := loader.AssembleAuditMessages(AuditParams{ + RepoName: "x", + CWEIDs: []string{"CWE-89"}, + }) + if err != nil { + t.Fatalf("AssembleAuditMessages() should not fail without CWE file: %v", err) + } + + user := msgs[1].Content + if !strings.Contains(user, "CWE-specific prompts not available") { + t.Error("user message missing fallback text for unavailable CWE prompts") + } +} + +// --- buildCWESection tests --- + +func TestBuildCWESection_MatchesAndFormats(t *testing.T) { + loader := NewPromptLoader(testAuditFS()) + + out, err := loader.buildCWESection([]string{"CWE-89"}) + if err != nil { + t.Fatalf("buildCWESection() error: %v", err) + } + + if !strings.Contains(out, "### CWE-89: SQL Injection") { + t.Error("output missing CWE-89 header") + } + if !strings.Contains(out, "Trace user input to SQL queries") { + t.Error("output missing analysis prompt") + } + if !strings.Contains(out, "Validation Checks:") { + t.Error("output missing validation checks header") + } + if !strings.Contains(out, "Input reaches query unsanitized") { + t.Error("output missing validation check item") + } + if !strings.Contains(out, "False Positive Indicators:") { + t.Error("output missing false positive header") + } + if !strings.Contains(out, "Input is constant") { + t.Error("output missing false positive indicator") + } +} + +func TestBuildCWESection_NormalizesIDs(t *testing.T) { + loader := NewPromptLoader(testAuditFS()) + + out, err := loader.buildCWESection([]string{ + "cwe-89: SQL Injection", + "CWE-89", // duplicate after normalization + " CWE-79 ", // whitespace + }) + if err != nil { + t.Fatalf("buildCWESection() error: %v", err) + } + + if strings.Count(out, "### CWE-89") != 1 { + t.Error("CWE-89 should appear exactly once after dedup") + } + if !strings.Contains(out, "### CWE-79: XSS") { + t.Error("output missing normalized CWE-79 entry") + } 
+} + +func TestBuildCWESection_NoMatches(t *testing.T) { + loader := NewPromptLoader(testAuditFS()) + + out, err := loader.buildCWESection([]string{"CWE-999"}) + if err != nil { + t.Fatalf("buildCWESection() error: %v", err) + } + + if !strings.Contains(out, "No CWE-specific prompts matched") { + t.Errorf("output = %q, want fallback message", out) + } +} + +func TestBuildCWESection_EmptyInput(t *testing.T) { + loader := NewPromptLoader(testAuditFS()) + + out, err := loader.buildCWESection(nil) + if err != nil { + t.Fatalf("buildCWESection() error: %v", err) + } + + if !strings.Contains(out, "No CWE-specific prompts matched") { + t.Errorf("output = %q, want fallback message for empty input", out) + } +} diff --git a/internal/llm/repair.go b/internal/llm/repair.go new file mode 100644 index 0000000..d0bfc24 --- /dev/null +++ b/internal/llm/repair.go @@ -0,0 +1,125 @@ +package llm + +import ( + "encoding/json" + "regexp" + "strings" +) + +// RepairJSON attempts to extract a valid JSON object from LLM output that +// didn't quite follow the schema. It handles the common failure modes: +// markdown fences, prose preamble, and string-where-array-expected. +// Returns the repaired JSON and true if any repair was applied. +// Callers should try the original parse first and only call this on failure. +func RepairJSON(raw string) (string, bool) { + repaired := raw + changed := false + + // Strip markdown code fences: ```json ... ``` or ``` ... ```. + if s, ok := stripCodeFence(repaired); ok { + repaired, changed = s, true + } + + // If the model wrote prose around the JSON, extract the outermost object. + if s, ok := extractJSONObject(repaired); ok { + repaired, changed = s, true + } + + // Coerce string values into empty arrays for fields the schema declares + // as arrays — models sometimes write "none found" instead of []. 
+ if s, ok := coerceStringsToEmptyArrays(repaired); ok { + repaired, changed = s, true + } + + return repaired, changed +} + +var fenceRE = regexp.MustCompile("(?s)```(?:json)?\\s*\\n(.*?)\\n```") + +func stripCodeFence(s string) (string, bool) { + trimmed := strings.TrimSpace(s) + if !strings.HasPrefix(trimmed, "```") { + return s, false + } + if m := fenceRE.FindStringSubmatch(trimmed); len(m) == 2 { + return m[1], true + } + return s, false +} + +// extractJSONObject finds the first { and its matching } using a depth counter. +// It ignores braces inside string literals so a path like "a/{b}/c" doesn't +// confuse the balance. Returns the substring and true if it's not already the +// entire input. +func extractJSONObject(s string) (string, bool) { + start := strings.IndexByte(s, '{') + if start < 0 { + return s, false + } + depth := 0 + inStr := false + esc := false + for i := start; i < len(s); i++ { + c := s[i] + if esc { + esc = false + continue + } + if c == '\\' { + esc = true + continue + } + if c == '"' { + inStr = !inStr + continue + } + if inStr { + continue + } + switch c { + case '{': + depth++ + case '}': + depth-- + if depth == 0 { + extracted := s[start : i+1] + if extracted == strings.TrimSpace(s) { + return s, false + } + return extracted, true + } + } + } + return s, false +} + +// arrayFields lists the AnalysisResult fields that must be arrays. Models +// occasionally emit a string like "no issues" for these when the answer is +// empty; we coerce those to []. 
+var arrayFields = []string{"security_issues", "public_api_routes"} + +func coerceStringsToEmptyArrays(s string) (string, bool) { + var payload map[string]any + if err := json.Unmarshal([]byte(s), &payload); err != nil { + return s, false + } + changed := false + for _, f := range arrayFields { + v, ok := payload[f] + if !ok { + continue + } + if _, isString := v.(string); isString { + payload[f] = []any{} + changed = true + } + } + if !changed { + return s, false + } + out, err := json.Marshal(payload) + if err != nil { + return s, false + } + return string(out), true +} diff --git a/internal/llm/repair_test.go b/internal/llm/repair_test.go new file mode 100644 index 0000000..c85d116 --- /dev/null +++ b/internal/llm/repair_test.go @@ -0,0 +1,118 @@ +package llm + +import ( + "encoding/json" + "testing" +) + +func TestRepairJSON_CleanInputUnchanged(t *testing.T) { + in := `{"security_issues":[],"repo_name":"x"}` + out, changed := RepairJSON(in) + if changed { + t.Errorf("clean input should not be changed: out=%q", out) + } +} + +func TestRepairJSON_StripsMarkdownFence(t *testing.T) { + in := "```json\n{\"security_issues\":[]}\n```" + out, changed := RepairJSON(in) + if !changed { + t.Fatal("expected repair to trigger") + } + var m map[string]any + if err := json.Unmarshal([]byte(out), &m); err != nil { + t.Fatalf("repaired output should parse: %v\nout=%q", err, out) + } +} + +func TestRepairJSON_StripsPlainFence(t *testing.T) { + in := "```\n{\"a\":1}\n```" + out, changed := RepairJSON(in) + if !changed || out != `{"a":1}` { + t.Fatalf("got changed=%v out=%q", changed, out) + } +} + +func TestRepairJSON_ExtractsFromProse(t *testing.T) { + in := `Here is my analysis: + +{"security_issues":[],"repo_name":"x"} + +Let me know if you need anything else.` + out, changed := RepairJSON(in) + if !changed { + t.Fatal("expected repair to trigger") + } + var m map[string]any + if err := json.Unmarshal([]byte(out), &m); err != nil { + t.Fatalf("repaired output should parse: 
%v\nout=%q", err, out) + } + if m["repo_name"] != "x" { + t.Errorf("repo_name = %v, want x", m["repo_name"]) + } +} + +func TestRepairJSON_CoercesStringToEmptyArray(t *testing.T) { + // The exact failure from the Java scan: security_issues as a string. + in := `{"repo_name":"x","security_issues":"No security issues were found in this chunk.","public_api_routes":[]}` + out, changed := RepairJSON(in) + if !changed { + t.Fatal("expected repair to trigger") + } + var m struct { + SecurityIssues []any `json:"security_issues"` + } + if err := json.Unmarshal([]byte(out), &m); err != nil { + t.Fatalf("repaired output should parse into []any: %v\nout=%q", err, out) + } + if len(m.SecurityIssues) != 0 { + t.Errorf("SecurityIssues = %v, want empty slice", m.SecurityIssues) + } +} + +func TestRepairJSON_CoercesBothArrayFields(t *testing.T) { + in := `{"security_issues":"none","public_api_routes":"not applicable"}` + out, changed := RepairJSON(in) + if !changed { + t.Fatal("expected repair to trigger") + } + var m struct { + SecurityIssues []any `json:"security_issues"` + PublicAPIRoutes []any `json:"public_api_routes"` + } + if err := json.Unmarshal([]byte(out), &m); err != nil { + t.Fatalf("repaired output should parse: %v\nout=%q", err, out) + } +} + +func TestRepairJSON_LeavesRealArraysAlone(t *testing.T) { + in := `{"security_issues":[{"issue":"x","file_path":"a.go","start_line":1,"end_line":2,"technical_details":"d","severity":5.0,"cwe_id":"CWE-79"}]}` + out, changed := RepairJSON(in) + if changed { + t.Errorf("valid array should not be coerced: out=%q", out) + } +} + +func TestExtractJSONObject_IgnoresBracesInStrings(t *testing.T) { + // A path with braces inside a string literal must not confuse depth tracking. 
+ in := `prose {"k":"a/{b}/c","n":{"m":1}} trailing` + out, changed := extractJSONObject(in) + if !changed { + t.Fatal("expected extraction") + } + var m map[string]any + if err := json.Unmarshal([]byte(out), &m); err != nil { + t.Fatalf("extracted object should parse: %v\nout=%q", err, out) + } + if m["k"] != "a/{b}/c" { + t.Errorf("k = %v, want a/{b}/c", m["k"]) + } +} + +func TestExtractJSONObject_NoObject(t *testing.T) { + in := "just some prose with no braces" + _, changed := extractJSONObject(in) + if changed { + t.Error("no-brace input should not change") + } +} diff --git a/internal/llm/schema.go b/internal/llm/schema.go new file mode 100644 index 0000000..320d64c --- /dev/null +++ b/internal/llm/schema.go @@ -0,0 +1,278 @@ +package llm + +import ( + "encoding/json" + "strings" +) + +// SecurityAnalysisSchema returns the JSON Schema for the security analysis response format. +// This matches the SECURITY_ANALYSIS_SCHEMA from the current Python tool. +func SecurityAnalysisSchema() *json.RawMessage { + schema := json.RawMessage(securityAnalysisSchemaJSON) + return &schema +} + +const securityAnalysisSchemaJSON = `{ + "name": "security_analysis", + "strict": true, + "schema": { + "type": "object", + "required": ["repo_name", "description", "public_api_routes", "security_issues", "security_risk", "risk_justification"], + "additionalProperties": false, + "properties": { + "repo_name": { + "type": "string", + "description": "Name of the repository or project being analyzed" + }, + "description": { + "type": "string", + "description": "Brief description of what the codebase does" + }, + "public_api_routes": { + "type": "array", + "description": "List of public API routes found in the codebase", + "items": { + "type": "object", + "required": ["route", "citation"], + "additionalProperties": false, + "properties": { + "route": { + "type": "string", + "description": "The API route path" + }, + "citation": { + "type": "string", + "description": "File and line number 
where the route is defined" + } + } + } + }, + "security_issues": { + "type": "array", + "description": "List of security issues found in the codebase", + "items": { + "type": "object", + "required": ["issue", "file_path", "start_line", "end_line", "technical_details", "severity", "cwe_id"], + "additionalProperties": false, + "properties": { + "issue": { + "type": "string", + "description": "Short title describing the security issue" + }, + "file_path": { + "type": "string", + "description": "Relative path to the file containing the issue" + }, + "start_line": { + "type": "integer", + "description": "Starting line number of the vulnerable code" + }, + "end_line": { + "type": "integer", + "description": "Ending line number of the vulnerable code" + }, + "technical_details": { + "type": "string", + "description": "Detailed technical explanation of the vulnerability and remediation" + }, + "severity": { + "type": "number", + "description": "Severity score from 0 (info) to 10 (critical)" + }, + "cwe_id": { + "type": "string", + "description": "CWE identifier (e.g., CWE-89 for SQL injection)" + } + } + } + }, + "security_risk": { + "type": "number", + "description": "Overall security risk score from 0 (no risk) to 10 (critical risk)" + }, + "risk_justification": { + "type": "string", + "description": "Justification for the overall security risk score" + } + } + } +}` + +// FeatureDetectionSchema returns the JSON Schema for the feature detection response format. +// This is used in the first pass of a two-pass analysis to detect which security features +// the codebase uses. 
// FeatureDetectionSchema returns the JSON Schema payload for the feature
// detection response format, used by the first pass of the two-pass analysis
// to report which security-relevant features a codebase uses.
//
// The returned pointer refers to a json.RawMessage converted from the
// featureDetectionSchemaJSON constant on every call.
func FeatureDetectionSchema() *json.RawMessage {
	raw := new(json.RawMessage)
	*raw = json.RawMessage(featureDetectionSchemaJSON)
	return raw
}

// featureDetectionSchemaJSON is the named, strict schema envelope for feature
// detection: a single required "detected_features" array whose items are
// restricted to a fixed set of feature identifiers.
const featureDetectionSchemaJSON = `{
  "name": "feature_detection",
  "strict": true,
  "schema": {
    "type": "object",
    "required": ["detected_features"],
    "additionalProperties": false,
    "properties": {
      "detected_features": {
        "type": "array",
        "description": "Security-relevant features detected in the codebase",
        "items": {
          "type": "string",
          "enum": [
            "public_api", "authentication", "authorization", "database_operations",
            "user_input_handling", "file_operations", "cryptography", "session_management",
            "external_api_calls", "third_party_dependencies", "infrastructure_as_code",
            "blockchain_crypto_finance", "websockets", "graphql", "grpc",
            "xml_processing", "deserialization", "template_rendering",
            "shell_command_execution", "sensitive_data_handling"
          ]
        }
      }
    }
  }
}`
+func AuditSchema() *json.RawMessage { + schema := json.RawMessage(auditSchemaJSON) + return &schema +} + +const auditSchemaJSON = `{ + "name": "security_audit", + "strict": true, + "schema": { + "type": "object", + "required": ["audited_findings", "new_findings", "audit_summary"], + "additionalProperties": false, + "properties": { + "audited_findings": { + "type": "array", + "description": "Audit verdicts for each initial finding", + "items": { + "type": "object", + "required": ["original_issue", "file_path", "start_line", "end_line", "verdict", "confidence", "refined_severity", "refined_technical_details", "refined_cwe_id", "justification"], + "additionalProperties": false, + "properties": { + "original_issue": { + "type": "string", + "description": "The original issue title from the initial finding" + }, + "file_path": { + "type": "string", + "description": "File path of the finding" + }, + "start_line": { + "type": "integer", + "description": "Starting line number" + }, + "end_line": { + "type": "integer", + "description": "Ending line number" + }, + "verdict": { + "type": "string", + "enum": ["confirmed", "refined", "rejected", "escalated"], + "description": "Audit verdict: confirmed, refined, rejected, or escalated" + }, + "confidence": { + "type": "number", + "description": "Confidence score from 0.0 (unlikely) to 1.0 (certain)" + }, + "refined_severity": { + "type": "number", + "description": "Refined severity score (0-10), adjusted based on deeper analysis" + }, + "refined_technical_details": { + "type": "string", + "description": "Refined technical details incorporating deep CWE analysis" + }, + "refined_cwe_id": { + "type": "string", + "description": "Refined CWE identifier, may differ from initial if more specific CWE applies" + }, + "justification": { + "type": "string", + "description": "Justification for the verdict, explaining why the finding was confirmed, refined, rejected, or escalated" + } + } + } + }, + "new_findings": { + "type": "array", + 
"description": "Additional findings discovered during the deep CWE analysis", + "items": { + "type": "object", + "required": ["issue", "file_path", "start_line", "end_line", "technical_details", "severity", "cwe_id", "confidence"], + "additionalProperties": false, + "properties": { + "issue": { + "type": "string", + "description": "Short title describing the new security issue" + }, + "file_path": { + "type": "string", + "description": "Relative path to the file containing the issue" + }, + "start_line": { + "type": "integer", + "description": "Starting line number" + }, + "end_line": { + "type": "integer", + "description": "Ending line number" + }, + "technical_details": { + "type": "string", + "description": "Detailed technical explanation" + }, + "severity": { + "type": "number", + "description": "Severity score from 0 to 10" + }, + "cwe_id": { + "type": "string", + "description": "CWE identifier" + }, + "confidence": { + "type": "number", + "description": "Confidence score from 0.0 to 1.0" + } + } + } + }, + "audit_summary": { + "type": "string", + "description": "Brief summary of the audit results including counts of confirmed, refined, rejected, escalated, and new findings" + } + } + } +}` + +// OutputModeForModel returns the appropriate OutputMode for a model name. +// OpenAI/GPT models use JSON Schema response_format, Claude uses tool_use, +// and other models fall back to unstructured output. +func OutputModeForModel(modelName string) OutputMode { + lower := strings.ToLower(modelName) + + // OpenAI / GPT models support response_format JSON Schema. + if strings.Contains(lower, "gpt") || strings.Contains(lower, "o1") || strings.Contains(lower, "o3") { + return OutputModeJSONSchema + } + + // Claude models use tool_use for structured output. + if strings.Contains(lower, "claude") || strings.Contains(lower, "anthropic") { + return OutputModeToolUse + } + + // Gemini via the OpenAI-compat endpoint accepts response_format with + // json_schema. 
The native generateContent API uses a different + // schema dialect, but we don't speak that — the compat layer handles + // translation. + if strings.Contains(lower, "gemini") { + return OutputModeJSONSchema + } + + return OutputModeNone +} diff --git a/internal/llm/schema_test.go b/internal/llm/schema_test.go new file mode 100644 index 0000000..cd78d26 --- /dev/null +++ b/internal/llm/schema_test.go @@ -0,0 +1,96 @@ +package llm + +import ( + "encoding/json" + "testing" +) + +func TestSecurityAnalysisSchema_BasicShape(t *testing.T) { + raw := SecurityAnalysisSchema() + payload := decodeSchemaPayload(t, raw) + + if payload["name"] != "security_analysis" { + t.Fatalf("schema name = %v, want %q", payload["name"], "security_analysis") + } + if payload["strict"] != true { + t.Fatalf("strict = %v, want true", payload["strict"]) + } + + schemaObj, ok := payload["schema"].(map[string]any) + if !ok { + t.Fatalf("schema payload missing object body") + } + required, ok := schemaObj["required"].([]any) + if !ok || len(required) == 0 { + t.Fatalf("required fields missing from schema") + } +} + +func TestFeatureDetectionSchema_HasDetectedFeatures(t *testing.T) { + raw := FeatureDetectionSchema() + payload := decodeSchemaPayload(t, raw) + + if payload["name"] != "feature_detection" { + t.Fatalf("schema name = %v, want %q", payload["name"], "feature_detection") + } + + schemaObj := payload["schema"].(map[string]any) + properties := schemaObj["properties"].(map[string]any) + detected := properties["detected_features"].(map[string]any) + items := detected["items"].(map[string]any) + enumValues := items["enum"].([]any) + if len(enumValues) == 0 { + t.Fatal("detected_features enum must not be empty") + } +} + +func TestAuditSchema_BasicShape(t *testing.T) { + raw := AuditSchema() + payload := decodeSchemaPayload(t, raw) + + if payload["name"] != "security_audit" { + t.Fatalf("schema name = %v, want %q", payload["name"], "security_audit") + } + + schemaObj := 
// decodeSchemaPayload is a test helper that unmarshals a schema raw message
// into a generic map. It fails the calling test immediately when raw is nil or
// when the bytes are not a JSON object.
func decodeSchemaPayload(t *testing.T, raw *json.RawMessage) map[string]any {
	t.Helper()

	if raw == nil {
		t.Fatal("schema raw message is nil")
	}

	var decoded map[string]any
	err := json.Unmarshal(*raw, &decoded)
	if err != nil {
		t.Fatalf("unmarshal schema: %v", err)
	}
	return decoded
}
// NewLoggerWithWriter creates a structured slog.Logger that writes to w.
//
// With verbose set, records are emitted as JSON at DEBUG level and above;
// otherwise they are emitted as text at INFO level and above.
func NewLoggerWithWriter(verbose bool, w io.Writer) *slog.Logger {
	var handler slog.Handler
	if verbose {
		handler = slog.NewJSONHandler(w, &slog.HandlerOptions{Level: slog.LevelDebug})
	} else {
		handler = slog.NewTextHandler(w, &slog.HandlerOptions{Level: slog.LevelInfo})
	}
	return slog.New(handler)
}
logger.Info("info message") + + output := buf.String() + if !strings.Contains(output, "info message") { + t.Errorf("expected info message in non-verbose output, got: %s", output) + } +} + +func TestNewLoggerWithWriter_NonVerboseIsText(t *testing.T) { + var buf bytes.Buffer + logger := NewLoggerWithWriter(false, &buf) + + logger.Info("test message") + + output := buf.String() + // Text handler output should not be valid JSON. + var parsed map[string]any + if err := json.Unmarshal([]byte(output), &parsed); err == nil { + t.Errorf("non-verbose output should not be JSON, got: %s", output) + } +} + +func TestNewLogger_ReturnsSlogLogger(t *testing.T) { + logger := NewLogger(false) + if logger == nil { + t.Fatal("expected non-nil logger") + } + + // Verify it's a valid *slog.Logger by checking it implements the interface. + var _ *slog.Logger = logger +} diff --git a/internal/sarif/builder.go b/internal/sarif/builder.go new file mode 100644 index 0000000..8f9932d --- /dev/null +++ b/internal/sarif/builder.go @@ -0,0 +1,242 @@ +package sarif + +import ( + "fmt" + "log/slog" + "regexp" + "strings" +) + +const ( + sarifSchema = "https://raw.githubusercontent.com/oasis-tcs/sarif-spec/main/sarif-2.1/schema/sarif-schema-2.1.0.json" + sarifVersion = "2.1.0" + rulePrefix = "codecrucible." +) + +// FileMap maps relative file paths to their content for snippet extraction. +type FileMap map[string]string + +// BuilderConfig holds configuration for the SARIF builder. 
+type BuilderConfig struct { + ToolName string // default: "codecrucible" + ToolVersion string // default: "dev" + Logger *slog.Logger +} + +func (c BuilderConfig) toolName() string { + if c.ToolName != "" { + return c.ToolName + } + return "codecrucible" +} + +func (c BuilderConfig) toolVersion() string { + if c.ToolVersion != "" { + return c.ToolVersion + } + return "dev" +} + +func (c BuilderConfig) logger() *slog.Logger { + if c.Logger != nil { + return c.Logger + } + return slog.Default() +} + +// Build converts an AnalysisResult into a SARIF v2.1.0 document. +func Build(result AnalysisResult, fileMap FileMap, cfg BuilderConfig) SARIFDocument { + if fileMap == nil { + fileMap = FileMap{} + } + + log := cfg.logger() + + // Deduplicate rules: issue text → rule ID. + type ruleEntry struct { + rule SARIFRule + index int + } + rulesByID := make(map[string]*ruleEntry) + var rules []SARIFRule + var results []SARIFResult + + // Track CWE taxa for the taxonomy section. + cweTaxa := make(map[string]SARIFTaxon) + + for _, issue := range result.SecurityIssues { + ruleID := slugify(issue.Issue) + + if _, exists := rulesByID[ruleID]; !exists { + rule := SARIFRule{ + ID: ruleID, + ShortDescription: SARIFMessage{Text: issue.Issue}, + Properties: map[string]any{}, + } + rule.Properties["security-severity"] = fmt.Sprintf("%.1f", issue.Severity) + + // Always include "security" tag; add CWE tag in GitHub's expected format. + tags := []string{"security"} + cweTag := extractCWETag(issue.CWEID) + if cweTag != "" { + tags = append(tags, "external/cwe/"+cweTag) + + // Add a relationship to the CWE taxonomy. + cweID := strings.ToUpper(cweTag) // "CWE-89" + rule.Relationships = []SARIFRelationship{{ + Target: SARIFRelationshipTarget{ + ID: cweID, + ToolComponent: SARIFToolComponentRef{Name: "CWE"}, + }, + Kinds: []string{"superset"}, + }} + + // Record the taxon for the taxonomy section. 
+ if _, seen := cweTaxa[cweID]; !seen { + cweTaxa[cweID] = SARIFTaxon{ + ID: cweID, + ShortDescription: SARIFMessage{Text: issue.CWEID}, + } + } + } + rule.Properties["tags"] = tags + + rulesByID[ruleID] = &ruleEntry{rule: rule, index: len(rules)} + rules = append(rules, rule) + } + + r := SARIFResult{ + RuleID: ruleID, + Level: severityLevel(issue.Severity), + Message: SARIFMessage{Text: issue.TechnicalDetails}, + } + + if issue.FilePath != "" { + loc := SARIFLocation{ + PhysicalLocation: SARIFPhysicalLocation{ + ArtifactLocation: SARIFArtifactLocation{URI: issue.FilePath}, + }, + } + + if issue.StartLine > 0 { + region := &SARIFRegion{StartLine: issue.StartLine} + if issue.EndLine > 0 { + region.EndLine = issue.EndLine + } + snippet := extractSnippet(fileMap, issue.FilePath, issue.StartLine, issue.EndLine, log) + if snippet != "" { + region.Snippet = &SARIFSnippet{Text: snippet} + } + loc.PhysicalLocation.Region = region + } + + r.Locations = []SARIFLocation{loc} + } + + results = append(results, r) + } + + // Guarantee non-nil slices for clean JSON output. + if rules == nil { + rules = []SARIFRule{} + } + if results == nil { + results = []SARIFResult{} + } + + run := SARIFRun{ + Tool: SARIFTool{ + Driver: SARIFDriver{ + Name: cfg.toolName(), + Version: cfg.toolVersion(), + Rules: rules, + }, + }, + Results: results, + Invocations: []SARIFInvocation{ + {ExecutionSuccessful: true}, + }, + } + + // Add the CWE taxonomy if any CWE references were found. + if len(cweTaxa) > 0 { + taxa := make([]SARIFTaxon, 0, len(cweTaxa)) + for _, t := range cweTaxa { + taxa = append(taxa, t) + } + run.Taxonomies = []SARIFTaxonomy{{ + Name: "CWE", + Organization: "MITRE", + ShortDescription: SARIFMessage{Text: "Common Weakness Enumeration"}, + Taxa: taxa, + }} + } + + return SARIFDocument{ + Schema: sarifSchema, + Version: sarifVersion, + Runs: []SARIFRun{run}, + } +} + +// severityLevel maps a numeric severity (0–10) to a SARIF level string. 
// severityLevel translates a numeric severity score into a SARIF level string:
//
//	sev <= 0         -> "none"
//	0 < sev < 4.0    -> "note"
//	4.0 <= sev < 7.0 -> "warning"
//	anything else    -> "error"
func severityLevel(sev float64) string {
	if sev <= 0 {
		return "none"
	}
	if sev < 4.0 {
		return "note"
	}
	if sev < 7.0 {
		return "warning"
	}
	return "error"
}
severity float64 + want string + }{ + {0, "none"}, + {0.1, "note"}, + {1.0, "note"}, + {3.9, "note"}, + {4.0, "warning"}, + {5.0, "warning"}, + {6.9, "warning"}, + {7.0, "error"}, + {9.0, "error"}, + {10.0, "error"}, + } + for _, tt := range tests { + got := severityLevel(tt.severity) + if got != tt.want { + t.Errorf("severityLevel(%v) = %q, want %q", tt.severity, got, tt.want) + } + } +} + +// --------------------------------------------------------------------------- +// slugify +// --------------------------------------------------------------------------- + +func TestSlugify(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"SQL Injection Vulnerability", "codecrucible.sql-injection-vulnerability"}, + {"XSS (Reflected)", "codecrucible.xss-reflected"}, + {" leading/trailing ", "codecrucible.leading-trailing"}, + {"multiple---hyphens", "codecrucible.multiple-hyphens"}, + {"UPPER CASE", "codecrucible.upper-case"}, + {"special!@#chars$%^&*()", "codecrucible.special-chars"}, + {"a", "codecrucible.a"}, + } + for _, tt := range tests { + got := slugify(tt.input) + if got != tt.want { + t.Errorf("slugify(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +// --------------------------------------------------------------------------- +// extractCWETag +// --------------------------------------------------------------------------- + +func TestExtractCWETag(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"CWE-89: SQL Injection", "cwe-89"}, + {"CWE-79", "cwe-79"}, + {"cwe-22: Path Traversal", "cwe-22"}, + {"no cwe here", ""}, + {"", ""}, + } + for _, tt := range tests { + got := extractCWETag(tt.input) + if got != tt.want { + t.Errorf("extractCWETag(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +// --------------------------------------------------------------------------- +// Rule deduplication +// --------------------------------------------------------------------------- + +func 
TestBuild_RuleDeduplication(t *testing.T) { + result := AnalysisResult{ + SecurityIssues: []SecurityIssue{ + {Issue: "SQL Injection", FilePath: "a.go", StartLine: 1, Severity: 9, CWEID: "CWE-89"}, + {Issue: "SQL Injection", FilePath: "b.go", StartLine: 5, Severity: 9, CWEID: "CWE-89"}, + }, + } + doc := Build(result, nil, BuilderConfig{}) + + rules := doc.Runs[0].Tool.Driver.Rules + if len(rules) != 1 { + t.Fatalf("expected 1 deduplicated rule, got %d", len(rules)) + } + + results := doc.Runs[0].Results + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + for _, r := range results { + if r.RuleID != rules[0].ID { + t.Errorf("result ruleId %q != rule id %q", r.RuleID, rules[0].ID) + } + } +} + +// --------------------------------------------------------------------------- +// Snippet population +// --------------------------------------------------------------------------- + +func TestBuild_SnippetPopulation(t *testing.T) { + fm := FileMap{ + "src/main.go": "line1\nline2\nline3\nline4\nline5\n", + } + result := AnalysisResult{ + SecurityIssues: []SecurityIssue{ + {Issue: "Bug", FilePath: "src/main.go", StartLine: 2, EndLine: 4, Severity: 5}, + }, + } + doc := Build(result, fm, BuilderConfig{}) + + loc := doc.Runs[0].Results[0].Locations[0].PhysicalLocation + if loc.Region == nil { + t.Fatal("expected region") + } + if loc.Region.Snippet == nil { + t.Fatal("expected snippet") + } + want := "line2\nline3\nline4" + if loc.Region.Snippet.Text != want { + t.Errorf("snippet = %q, want %q", loc.Region.Snippet.Text, want) + } +} + +// --------------------------------------------------------------------------- +// Missing file in FileMap — snippet is nil, no panic +// --------------------------------------------------------------------------- + +func TestBuild_MissingFileNoSnippet(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + + result := AnalysisResult{ + SecurityIssues: []SecurityIssue{ + 
{Issue: "Bug", FilePath: "missing.go", StartLine: 1, EndLine: 2, Severity: 3}, + }, + } + doc := Build(result, FileMap{}, BuilderConfig{Logger: logger}) + + loc := doc.Runs[0].Results[0].Locations[0].PhysicalLocation + if loc.Region.Snippet != nil { + t.Errorf("expected nil snippet for missing file, got %+v", loc.Region.Snippet) + } + + if !bytes.Contains(buf.Bytes(), []byte("file not found in FileMap")) { + t.Error("expected warning log about missing file") + } +} + +// --------------------------------------------------------------------------- +// Empty SecurityIssues → valid SARIF with zero findings +// --------------------------------------------------------------------------- + +func TestBuild_EmptyIssues(t *testing.T) { + doc := Build(AnalysisResult{}, nil, BuilderConfig{}) + + if len(doc.Runs) != 1 { + t.Fatalf("expected 1 run, got %d", len(doc.Runs)) + } + if len(doc.Runs[0].Results) != 0 { + t.Errorf("expected 0 results, got %d", len(doc.Runs[0].Results)) + } + if len(doc.Runs[0].Tool.Driver.Rules) != 0 { + t.Errorf("expected 0 rules, got %d", len(doc.Runs[0].Tool.Driver.Rules)) + } + + // Must marshal cleanly (non-nil slices). + data, err := json.Marshal(doc) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + if !json.Valid(data) { + t.Error("output is not valid JSON") + } +} + +// --------------------------------------------------------------------------- +// CWE tags in rule properties +// --------------------------------------------------------------------------- + +func TestBuild_CWETags(t *testing.T) { + result := AnalysisResult{ + SecurityIssues: []SecurityIssue{ + {Issue: "Injection", Severity: 8, CWEID: "CWE-89: SQL Injection", FilePath: "a.go", StartLine: 1}, + }, + } + doc := Build(result, nil, BuilderConfig{}) + + rule := doc.Runs[0].Tool.Driver.Rules[0] + + // Verify tags include "security" and CWE in GitHub's expected format. 
+ tags, ok := rule.Properties["tags"] + if !ok { + t.Fatal("expected tags property") + } + tagSlice, ok := tags.([]string) + if !ok { + t.Fatalf("tags is %T, want []string", tags) + } + if len(tagSlice) != 2 { + t.Fatalf("tags length = %d, want 2, got %v", len(tagSlice), tagSlice) + } + if tagSlice[0] != "security" { + t.Errorf("tags[0] = %q, want %q", tagSlice[0], "security") + } + if tagSlice[1] != "external/cwe/cwe-89" { + t.Errorf("tags[1] = %q, want %q", tagSlice[1], "external/cwe/cwe-89") + } + + // Verify relationship to CWE taxonomy. + if len(rule.Relationships) != 1 { + t.Fatalf("expected 1 relationship, got %d", len(rule.Relationships)) + } + rel := rule.Relationships[0] + if rel.Target.ID != "CWE-89" { + t.Errorf("relationship target ID = %q, want %q", rel.Target.ID, "CWE-89") + } + if rel.Target.ToolComponent.Name != "CWE" { + t.Errorf("relationship toolComponent = %q, want %q", rel.Target.ToolComponent.Name, "CWE") + } + + // Verify CWE taxonomy is included. + if len(doc.Runs[0].Taxonomies) != 1 { + t.Fatalf("expected 1 taxonomy, got %d", len(doc.Runs[0].Taxonomies)) + } + tax := doc.Runs[0].Taxonomies[0] + if tax.Name != "CWE" { + t.Errorf("taxonomy name = %q, want %q", tax.Name, "CWE") + } + if len(tax.Taxa) != 1 || tax.Taxa[0].ID != "CWE-89" { + t.Errorf("taxonomy taxa = %v, want [{ID: CWE-89}]", tax.Taxa) + } +} + +// --------------------------------------------------------------------------- +// security-severity property is set as string +// --------------------------------------------------------------------------- + +func TestBuild_SecuritySeverityProperty(t *testing.T) { + result := AnalysisResult{ + SecurityIssues: []SecurityIssue{ + {Issue: "Bug", Severity: 7.5, FilePath: "a.go", StartLine: 1}, + }, + } + doc := Build(result, nil, BuilderConfig{}) + + rule := doc.Runs[0].Tool.Driver.Rules[0] + val, ok := rule.Properties["security-severity"] + if !ok { + t.Fatal("expected security-severity property") + } + s, ok := val.(string) + if !ok { + 
t.Fatalf("security-severity is %T, want string", val) + } + if s != "7.5" { + t.Errorf("security-severity = %q, want %q", s, "7.5") + } +} + +// --------------------------------------------------------------------------- +// Full Build produces valid JSON with correct schema/version +// --------------------------------------------------------------------------- + +func TestBuild_ValidJSON(t *testing.T) { + fm := FileMap{ + "app/handler.go": "package app\n\nfunc Handle() {}\n", + } + result := AnalysisResult{ + RepoName: "test-repo", + Description: "A test repository", + SecurityIssues: []SecurityIssue{ + { + Issue: "Hardcoded Secret", + FilePath: "app/handler.go", + StartLine: 3, + EndLine: 3, + TechnicalDetails: "Secret found in source", + Severity: 8.0, + CWEID: "CWE-798", + }, + }, + } + doc := Build(result, fm, BuilderConfig{ToolName: "myTool", ToolVersion: "1.0.0"}) + + data, err := json.MarshalIndent(doc, "", " ") + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + // Re-parse to verify structure. 
+ var parsed SARIFDocument + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + + if parsed.Schema != sarifSchema { + t.Errorf("schema = %q, want %q", parsed.Schema, sarifSchema) + } + if parsed.Version != sarifVersion { + t.Errorf("version = %q, want %q", parsed.Version, sarifVersion) + } + if parsed.Runs[0].Tool.Driver.Name != "myTool" { + t.Errorf("tool name = %q, want %q", parsed.Runs[0].Tool.Driver.Name, "myTool") + } + if parsed.Runs[0].Tool.Driver.Version != "1.0.0" { + t.Errorf("tool version = %q, want %q", parsed.Runs[0].Tool.Driver.Version, "1.0.0") + } +} + +// --------------------------------------------------------------------------- +// Schema and version constants +// --------------------------------------------------------------------------- + +func TestBuild_SchemaAndVersion(t *testing.T) { + doc := Build(AnalysisResult{}, nil, BuilderConfig{}) + + if doc.Schema != "https://raw.githubusercontent.com/oasis-tcs/sarif-spec/main/sarif-2.1/schema/sarif-schema-2.1.0.json" { + t.Errorf("unexpected schema: %s", doc.Schema) + } + if doc.Version != "2.1.0" { + t.Errorf("unexpected version: %s", doc.Version) + } +} + +// --------------------------------------------------------------------------- +// Nil/empty AnalysisResult fields — no panics +// --------------------------------------------------------------------------- + +func TestBuild_NilFields(t *testing.T) { + cases := []struct { + name string + result AnalysisResult + }{ + {"zero value", AnalysisResult{}}, + {"nil issues", AnalysisResult{SecurityIssues: nil}}, + {"nil routes", AnalysisResult{PublicAPIRoutes: nil}}, + {"empty issues", AnalysisResult{SecurityIssues: []SecurityIssue{}}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + doc := Build(tc.result, nil, BuilderConfig{}) + data, err := json.Marshal(doc) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + if !json.Valid(data) { + t.Error("output is not valid JSON") 
+ } + }) + } +} + +// --------------------------------------------------------------------------- +// Invocation metadata +// --------------------------------------------------------------------------- + +func TestBuild_Invocation(t *testing.T) { + doc := Build(AnalysisResult{}, nil, BuilderConfig{}) + + invocations := doc.Runs[0].Invocations + if len(invocations) != 1 { + t.Fatalf("expected 1 invocation, got %d", len(invocations)) + } + if !invocations[0].ExecutionSuccessful { + t.Error("expected executionSuccessful = true") + } +} + +// --------------------------------------------------------------------------- +// Default config values +// --------------------------------------------------------------------------- + +func TestBuild_DefaultConfig(t *testing.T) { + doc := Build(AnalysisResult{}, nil, BuilderConfig{}) + + driver := doc.Runs[0].Tool.Driver + if driver.Name != "codecrucible" { + t.Errorf("default tool name = %q, want %q", driver.Name, "codecrucible") + } + if driver.Version != "dev" { + t.Errorf("default tool version = %q, want %q", driver.Version, "dev") + } +} diff --git a/internal/sarif/contract_test.go b/internal/sarif/contract_test.go new file mode 100644 index 0000000..b3047a4 --- /dev/null +++ b/internal/sarif/contract_test.go @@ -0,0 +1,196 @@ +package sarif + +import ( + "encoding/json" + "os" + "path/filepath" + "regexp" + "strings" + "testing" +) + +const fixtureDir = "../../testdata/fixtures/llm-responses" + +// stripMarkdownFences removes ```json ... ``` wrapping from LLM output. 
// markdownFenceRe matches text that consists entirely of one fenced block:
// an opening ``` with an optional "json" tag, the payload (captured), and a
// closing ```. Compiled once at package scope instead of on every call.
var markdownFenceRe = regexp.MustCompile("(?s)^```(?:json)?\\s*\n?(.*?)\\s*```$")

// stripMarkdownFences removes ```json ... ``` wrapping from LLM output.
//
// The input is whitespace-trimmed before matching. When the trimmed text is a
// single fenced block, the payload between the fences is returned without its
// surrounding whitespace; otherwise data is returned unchanged (untrimmed).
func stripMarkdownFences(data []byte) []byte {
	trimmed := strings.TrimSpace(string(data))
	if m := markdownFenceRe.FindStringSubmatch(trimmed); len(m) == 2 {
		return []byte(m[1])
	}
	return data
}
read fixture %s: %v", tt.file, err) + } + + data := raw + if tt.stripMarkdown { + data = stripMarkdownFences(raw) + } + + var result AnalysisResult + unmarshalErr := json.Unmarshal(data, &result) + + if tt.wantUnmarshalOK && unmarshalErr != nil { + t.Fatalf("expected unmarshal to succeed, got: %v", unmarshalErr) + } + if !tt.wantUnmarshalOK && unmarshalErr == nil { + t.Fatal("expected unmarshal to fail, but it succeeded") + } + + // Even on unmarshal failure, Build() must not panic and must + // produce valid SARIF from the zero/partial result. + doc := Build(result, nil, BuilderConfig{}) + + sarifJSON, err := json.Marshal(doc) + if err != nil { + t.Fatalf("json.Marshal of SARIF document failed: %v", err) + } + if !json.Valid(sarifJSON) { + t.Fatal("Build() produced invalid JSON") + } + + // Verify SARIF schema and version. + var parsed SARIFDocument + if err := json.Unmarshal(sarifJSON, &parsed); err != nil { + t.Fatalf("failed to re-parse SARIF output: %v", err) + } + if parsed.Schema != sarifSchema { + t.Errorf("schema = %q, want %q", parsed.Schema, sarifSchema) + } + if parsed.Version != sarifVersion { + t.Errorf("version = %q, want %q", parsed.Version, sarifVersion) + } + if len(parsed.Runs) != 1 { + t.Fatalf("expected 1 run, got %d", len(parsed.Runs)) + } + + // Verify counts when expected. + if tt.wantIssues >= 0 { + got := len(parsed.Runs[0].Results) + if got != tt.wantIssues { + t.Errorf("result count = %d, want %d", got, tt.wantIssues) + } + } + if tt.wantRoutes >= 0 && tt.wantUnmarshalOK { + got := len(result.PublicAPIRoutes) + if got != tt.wantRoutes { + t.Errorf("route count = %d, want %d", got, tt.wantRoutes) + } + } + }) + } +} + +// TestContract_MarkdownStripping verifies that the fence-stripping helper +// correctly unwraps markdown-wrapped JSON so it can be parsed. 
+func TestContract_MarkdownStripping(t *testing.T) { + path := filepath.Join(fixtureDir, "malformed_markdown_wrapped.json") + raw, err := os.ReadFile(path) + if err != nil { + t.Fatalf("failed to read fixture: %v", err) + } + + // Raw content should NOT parse as valid JSON. + if json.Valid(raw) { + t.Fatal("expected raw markdown-wrapped content to be invalid JSON") + } + + // After stripping, it should parse. + stripped := stripMarkdownFences(raw) + if !json.Valid(stripped) { + t.Fatal("expected stripped content to be valid JSON") + } + + var result AnalysisResult + if err := json.Unmarshal(stripped, &result); err != nil { + t.Fatalf("unmarshal of stripped content failed: %v", err) + } + if len(result.SecurityIssues) != 1 { + t.Errorf("expected 1 security issue, got %d", len(result.SecurityIssues)) + } +} diff --git a/internal/sarif/merge.go b/internal/sarif/merge.go new file mode 100644 index 0000000..b96ae98 --- /dev/null +++ b/internal/sarif/merge.go @@ -0,0 +1,167 @@ +package sarif + +import "strings" + +// Merge combines multiple SARIF documents (one per chunk) into a single document. +// It deduplicates rules by normalized ID and results by (ruleID, filePath, startLine). +// Partial failure notifications from all chunks are preserved. +func Merge(docs []SARIFDocument) SARIFDocument { + if len(docs) == 0 { + return SARIFDocument{ + Schema: sarifSchema, + Version: sarifVersion, + Runs: []SARIFRun{{ + Tool: SARIFTool{Driver: SARIFDriver{ + Name: "codecrucible", + Version: "dev", + Rules: []SARIFRule{}, + }}, + Results: []SARIFResult{}, + Invocations: []SARIFInvocation{{ExecutionSuccessful: true}}, + }}, + } + } + + if len(docs) == 1 { + return docs[0] + } + + // Use tool metadata from the first document. + first := docs[0].Runs[0] + toolName := first.Tool.Driver.Name + toolVersion := first.Tool.Driver.Version + infoURI := first.Tool.Driver.InformationURI + + // Collect and deduplicate rules. 
+ ruleIndex := make(map[string]int) // normalized rule ID → index in rules slice + var rules []SARIFRule + + // Collect and deduplicate results. + type resultKey struct { + ruleID string + filePath string + startLine int + } + seenResults := make(map[resultKey]bool) + var results []SARIFResult + + // Collect notifications from all chunks. + var notifications []SARIFNotification + allSuccessful := true + + // Collect and deduplicate CWE taxa across chunks. + taxaIndex := make(map[string]bool) + var taxa []SARIFTaxon + + for _, doc := range docs { + if len(doc.Runs) == 0 { + continue + } + run := doc.Runs[0] + + // Merge taxonomies. + for _, taxonomy := range run.Taxonomies { + for _, taxon := range taxonomy.Taxa { + if !taxaIndex[taxon.ID] { + taxaIndex[taxon.ID] = true + taxa = append(taxa, taxon) + } + } + } + + // Merge rules. + for _, rule := range run.Tool.Driver.Rules { + normID := normalizeRuleID(rule.ID) + if _, exists := ruleIndex[normID]; !exists { + // Store with normalized ID. + ruleCopy := rule + ruleCopy.ID = normID + ruleIndex[normID] = len(rules) + rules = append(rules, ruleCopy) + } + } + + // Merge results. + for _, r := range run.Results { + normID := normalizeRuleID(r.RuleID) + + var fp string + var sl int + if len(r.Locations) > 0 { + fp = r.Locations[0].PhysicalLocation.ArtifactLocation.URI + if r.Locations[0].PhysicalLocation.Region != nil { + sl = r.Locations[0].PhysicalLocation.Region.StartLine + } + } + + key := resultKey{ruleID: normID, filePath: fp, startLine: sl} + if seenResults[key] { + continue + } + seenResults[key] = true + + rCopy := r + rCopy.RuleID = normID + results = append(results, rCopy) + } + + // Merge invocations. + for _, inv := range run.Invocations { + if !inv.ExecutionSuccessful { + allSuccessful = false + } + notifications = append(notifications, inv.ToolExecutionNotifications...) 
+ } + } + + if rules == nil { + rules = []SARIFRule{} + } + if results == nil { + results = []SARIFResult{} + } + + inv := SARIFInvocation{ + ExecutionSuccessful: allSuccessful, + ToolExecutionNotifications: notifications, + } + + run := SARIFRun{ + Tool: SARIFTool{Driver: SARIFDriver{ + Name: toolName, + Version: toolVersion, + InformationURI: infoURI, + Rules: rules, + }}, + Results: results, + Invocations: []SARIFInvocation{inv}, + } + + if len(taxa) > 0 { + run.Taxonomies = []SARIFTaxonomy{{ + Name: "CWE", + Organization: "MITRE", + ShortDescription: SARIFMessage{Text: "Common Weakness Enumeration"}, + Taxa: taxa, + }} + } + + return SARIFDocument{ + Schema: sarifSchema, + Version: sarifVersion, + Runs: []SARIFRun{run}, + } +} + +// normalizeRuleID normalizes a rule ID for deduplication: +// strips trailing punctuation, lowercases, then re-slugifies. +func normalizeRuleID(id string) string { + // If it already has the prefix, strip it to get the raw text. + text := strings.TrimPrefix(id, rulePrefix) + // Convert slug back to words for normalization. + text = strings.ReplaceAll(text, "-", " ") + // Strip trailing punctuation. + text = strings.TrimRight(text, "., ;:!?") + // Re-slugify. 
+ return slugify(text) +} diff --git a/internal/sarif/merge_test.go b/internal/sarif/merge_test.go new file mode 100644 index 0000000..4f241ef --- /dev/null +++ b/internal/sarif/merge_test.go @@ -0,0 +1,278 @@ +package sarif + +import ( + "encoding/json" + "testing" +) + +func makeDoc(rules []SARIFRule, results []SARIFResult, invocations []SARIFInvocation) SARIFDocument { + if rules == nil { + rules = []SARIFRule{} + } + if results == nil { + results = []SARIFResult{} + } + if invocations == nil { + invocations = []SARIFInvocation{{ExecutionSuccessful: true}} + } + return SARIFDocument{ + Schema: sarifSchema, + Version: sarifVersion, + Runs: []SARIFRun{{ + Tool: SARIFTool{Driver: SARIFDriver{ + Name: "codecrucible", + Version: "1.0", + Rules: rules, + }}, + Results: results, + Invocations: invocations, + }}, + } +} + +func TestMerge_Empty(t *testing.T) { + doc := Merge(nil) + if len(doc.Runs) != 1 { + t.Fatalf("expected 1 run, got %d", len(doc.Runs)) + } + if len(doc.Runs[0].Results) != 0 { + t.Errorf("expected 0 results, got %d", len(doc.Runs[0].Results)) + } + if len(doc.Runs[0].Tool.Driver.Rules) != 0 { + t.Errorf("expected 0 rules, got %d", len(doc.Runs[0].Tool.Driver.Rules)) + } +} + +func TestMerge_SingleDocument(t *testing.T) { + doc := makeDoc( + []SARIFRule{{ID: "codecrucible.xss", ShortDescription: SARIFMessage{Text: "XSS"}}}, + []SARIFResult{{RuleID: "codecrucible.xss", Level: "error", Message: SARIFMessage{Text: "found xss"}}}, + nil, + ) + merged := Merge([]SARIFDocument{doc}) + + if len(merged.Runs[0].Tool.Driver.Rules) != 1 { + t.Errorf("expected 1 rule, got %d", len(merged.Runs[0].Tool.Driver.Rules)) + } + if len(merged.Runs[0].Results) != 1 { + t.Errorf("expected 1 result, got %d", len(merged.Runs[0].Results)) + } +} + +func TestMerge_NoOverlappingRules(t *testing.T) { + doc1 := makeDoc( + []SARIFRule{{ID: "codecrucible.xss", ShortDescription: SARIFMessage{Text: "XSS"}}}, + []SARIFResult{{ + RuleID: "codecrucible.xss", Level: "error", + Message: 
SARIFMessage{Text: "xss"}, + Locations: []SARIFLocation{{PhysicalLocation: SARIFPhysicalLocation{ArtifactLocation: SARIFArtifactLocation{URI: "a.go"}, Region: &SARIFRegion{StartLine: 1}}}}, + }}, + nil, + ) + doc2 := makeDoc( + []SARIFRule{{ID: "codecrucible.sqli", ShortDescription: SARIFMessage{Text: "SQLi"}}}, + []SARIFResult{{ + RuleID: "codecrucible.sqli", Level: "warning", + Message: SARIFMessage{Text: "sqli"}, + Locations: []SARIFLocation{{PhysicalLocation: SARIFPhysicalLocation{ArtifactLocation: SARIFArtifactLocation{URI: "b.go"}, Region: &SARIFRegion{StartLine: 5}}}}, + }}, + nil, + ) + + merged := Merge([]SARIFDocument{doc1, doc2}) + if len(merged.Runs[0].Tool.Driver.Rules) != 2 { + t.Errorf("expected 2 rules, got %d", len(merged.Runs[0].Tool.Driver.Rules)) + } + if len(merged.Runs[0].Results) != 2 { + t.Errorf("expected 2 results, got %d", len(merged.Runs[0].Results)) + } +} + +func TestMerge_OverlappingRulesDedup(t *testing.T) { + doc1 := makeDoc( + []SARIFRule{{ID: "codecrucible.sql-injection", ShortDescription: SARIFMessage{Text: "SQL Injection"}}}, + []SARIFResult{{ + RuleID: "codecrucible.sql-injection", Level: "error", + Message: SARIFMessage{Text: "injection in a"}, + Locations: []SARIFLocation{{PhysicalLocation: SARIFPhysicalLocation{ArtifactLocation: SARIFArtifactLocation{URI: "a.go"}, Region: &SARIFRegion{StartLine: 10}}}}, + }}, + nil, + ) + // Same rule but slightly different slug (trailing period). 
+ doc2 := makeDoc( + []SARIFRule{{ID: "codecrucible.sql-injection-", ShortDescription: SARIFMessage{Text: "SQL Injection."}}}, + []SARIFResult{{ + RuleID: "codecrucible.sql-injection-", Level: "error", + Message: SARIFMessage{Text: "injection in b"}, + Locations: []SARIFLocation{{PhysicalLocation: SARIFPhysicalLocation{ArtifactLocation: SARIFArtifactLocation{URI: "b.go"}, Region: &SARIFRegion{StartLine: 20}}}}, + }}, + nil, + ) + + merged := Merge([]SARIFDocument{doc1, doc2}) + if len(merged.Runs[0].Tool.Driver.Rules) != 1 { + t.Errorf("expected 1 deduplicated rule, got %d", len(merged.Runs[0].Tool.Driver.Rules)) + } + if len(merged.Runs[0].Results) != 2 { + t.Errorf("expected 2 results (different locations), got %d", len(merged.Runs[0].Results)) + } +} + +func TestMerge_ResultDedup(t *testing.T) { + result := SARIFResult{ + RuleID: "codecrucible.xss", Level: "error", + Message: SARIFMessage{Text: "found xss"}, + Locations: []SARIFLocation{{PhysicalLocation: SARIFPhysicalLocation{ArtifactLocation: SARIFArtifactLocation{URI: "handler.go"}, Region: &SARIFRegion{StartLine: 42}}}}, + } + doc1 := makeDoc( + []SARIFRule{{ID: "codecrucible.xss", ShortDescription: SARIFMessage{Text: "XSS"}}}, + []SARIFResult{result}, + nil, + ) + doc2 := makeDoc( + []SARIFRule{{ID: "codecrucible.xss", ShortDescription: SARIFMessage{Text: "XSS"}}}, + []SARIFResult{result}, + nil, + ) + + merged := Merge([]SARIFDocument{doc1, doc2}) + if len(merged.Runs[0].Results) != 1 { + t.Errorf("expected 1 deduplicated result, got %d", len(merged.Runs[0].Results)) + } +} + +func TestMerge_EmptyChunk(t *testing.T) { + doc1 := makeDoc( + []SARIFRule{{ID: "codecrucible.xss", ShortDescription: SARIFMessage{Text: "XSS"}}}, + []SARIFResult{{RuleID: "codecrucible.xss", Level: "error", Message: SARIFMessage{Text: "xss"}}}, + nil, + ) + doc2 := makeDoc(nil, nil, nil) // empty chunk + + merged := Merge([]SARIFDocument{doc1, doc2}) + if len(merged.Runs[0].Tool.Driver.Rules) != 1 { + t.Errorf("expected 1 rule, got 
%d", len(merged.Runs[0].Tool.Driver.Rules)) + } + if len(merged.Runs[0].Results) != 1 { + t.Errorf("expected 1 result, got %d", len(merged.Runs[0].Results)) + } +} + +func TestMerge_PartialFailureNotifications(t *testing.T) { + doc1 := makeDoc(nil, nil, []SARIFInvocation{{ + ExecutionSuccessful: true, + ToolExecutionNotifications: nil, + }}) + doc2 := makeDoc(nil, nil, []SARIFInvocation{{ + ExecutionSuccessful: false, + ToolExecutionNotifications: []SARIFNotification{{ + Level: "error", + Message: SARIFMessage{Text: "chunk 2 failed after retries"}, + }}, + }}) + + merged := Merge([]SARIFDocument{doc1, doc2}) + inv := merged.Runs[0].Invocations[0] + if inv.ExecutionSuccessful { + t.Error("expected ExecutionSuccessful=false when any chunk failed") + } + if len(inv.ToolExecutionNotifications) != 1 { + t.Fatalf("expected 1 notification, got %d", len(inv.ToolExecutionNotifications)) + } + if inv.ToolExecutionNotifications[0].Message.Text != "chunk 2 failed after retries" { + t.Errorf("unexpected notification: %s", inv.ToolExecutionNotifications[0].Message.Text) + } +} + +func TestMerge_RuleNormalization(t *testing.T) { + tests := []struct { + name string + ids []string + want int // expected number of merged rules + }{ + { + name: "trailing punctuation", + ids: []string{"codecrucible.sql-injection", "codecrucible.sql-injection-"}, + want: 1, + }, + { + name: "different rules stay separate", + ids: []string{"codecrucible.xss", "codecrucible.sqli"}, + want: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var docs []SARIFDocument + for i, id := range tt.ids { + docs = append(docs, makeDoc( + []SARIFRule{{ID: id, ShortDescription: SARIFMessage{Text: id}}}, + []SARIFResult{{ + RuleID: id, Level: "warning", + Message: SARIFMessage{Text: "test"}, + Locations: []SARIFLocation{{PhysicalLocation: SARIFPhysicalLocation{ArtifactLocation: SARIFArtifactLocation{URI: "file.go"}, Region: &SARIFRegion{StartLine: i + 1}}}}, + }}, + nil, + )) + } + 
merged := Merge(docs) + if got := len(merged.Runs[0].Tool.Driver.Rules); got != tt.want { + t.Errorf("rules: got %d, want %d", got, tt.want) + } + }) + } +} + +func TestMerge_ProducesValidJSON(t *testing.T) { + doc1 := makeDoc( + []SARIFRule{{ID: "codecrucible.test", ShortDescription: SARIFMessage{Text: "Test"}, Properties: map[string]any{"security-severity": "5.0"}}}, + []SARIFResult{{RuleID: "codecrucible.test", Level: "warning", Message: SARIFMessage{Text: "details"}}}, + nil, + ) + merged := Merge([]SARIFDocument{doc1}) + + b, err := json.MarshalIndent(merged, "", " ") + if err != nil { + t.Fatalf("failed to marshal merged SARIF: %v", err) + } + + var check SARIFDocument + if err := json.Unmarshal(b, &check); err != nil { + t.Fatalf("failed to unmarshal merged SARIF: %v", err) + } + if check.Version != sarifVersion { + t.Errorf("version: got %q, want %q", check.Version, sarifVersion) + } +} + +func TestMerge_InvocationAllSuccessful(t *testing.T) { + doc1 := makeDoc(nil, nil, []SARIFInvocation{{ExecutionSuccessful: true}}) + doc2 := makeDoc(nil, nil, []SARIFInvocation{{ExecutionSuccessful: true}}) + + merged := Merge([]SARIFDocument{doc1, doc2}) + if !merged.Runs[0].Invocations[0].ExecutionSuccessful { + t.Error("expected ExecutionSuccessful=true when all chunks succeed") + } +} + +func TestNormalizeRuleID(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"codecrucible.sql-injection", "codecrucible.sql-injection"}, + {"codecrucible.sql-injection-", "codecrucible.sql-injection"}, + {"codecrucible.xss", "codecrucible.xss"}, + {"codecrucible.some-issue-with-trailing-period-", "codecrucible.some-issue-with-trailing-period"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := normalizeRuleID(tt.input) + if got != tt.want { + t.Errorf("normalizeRuleID(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/sarif/postprocess.go b/internal/sarif/postprocess.go new file mode 100644 index 
0000000..e5ec061 --- /dev/null +++ b/internal/sarif/postprocess.go @@ -0,0 +1,241 @@ +package sarif + +import ( + "log/slog" + "path" + "strconv" + "strings" +) + +// nonSourceExtensions lists file extensions that are not considered source code. +var nonSourceExtensions = map[string]bool{ + ".json": true, + ".md": true, + ".txt": true, + ".yaml": true, + ".yml": true, + ".toml": true, + ".cfg": true, + ".ini": true, + ".key": true, + ".pem": true, + ".crt": true, + ".lock": true, +} + +// nonSourceFilenames lists specific filenames that are not considered source code. +var nonSourceFilenames = map[string]bool{ + "package.json": true, + "package-lock.json": true, + "go.sum": true, + "yarn.lock": true, + "Gemfile.lock": true, +} + +// PostProcess deduplicates findings, deprioritizes non-source file findings, +// and removes orphaned rules/taxa from the SARIF document. +func PostProcess(doc SARIFDocument) SARIFDocument { + if len(doc.Runs) == 0 { + return doc + } + + run := doc.Runs[0] + + // Build a rule lookup: ruleID → SARIFRule. + ruleByID := make(map[string]SARIFRule, len(run.Tool.Driver.Rules)) + for _, rule := range run.Tool.Driver.Rules { + ruleByID[rule.ID] = rule + } + + // --- Step 1: Deduplicate results by (file URI, startLine, CWE) --- + run.Results = deduplicateResults(run.Results, ruleByID) + + // --- Step 2: Remove low-severity non-source file findings --- + run.Results = deprioritizeNonSource(run.Results) + + // --- Step 3: Clean up orphaned rules and taxa --- + run.Tool.Driver.Rules, run.Taxonomies = cleanOrphans(run.Results, ruleByID, run.Taxonomies) + + doc.Runs[0] = run + return doc +} + +// dedupKey uniquely identifies a finding for deduplication. +type dedupKey struct { + fileURI string + startLine int + cwe string +} + +// deduplicateResults keeps only the highest-severity result per (file, startLine, CWE). 
+func deduplicateResults(results []SARIFResult, ruleByID map[string]SARIFRule) []SARIFResult { + bestIndex := make(map[dedupKey]int) // key → index into deduped slice + var deduped []SARIFResult + + for _, r := range results { + var fileURI string + var startLine int + if len(r.Locations) > 0 { + fileURI = r.Locations[0].PhysicalLocation.ArtifactLocation.URI + if r.Locations[0].PhysicalLocation.Region != nil { + startLine = r.Locations[0].PhysicalLocation.Region.StartLine + } + } + + cwe := CWEForRule(ruleByID[r.RuleID]) + key := dedupKey{fileURI: fileURI, startLine: startLine, cwe: cwe} + + if idx, exists := bestIndex[key]; exists { + existingSev := ruleSeverity(ruleByID[deduped[idx].RuleID]) + newSev := ruleSeverity(ruleByID[r.RuleID]) + if newSev > existingSev { + deduped[idx] = r + } + } else { + bestIndex[key] = len(deduped) + deduped = append(deduped, r) + } + } + + removed := len(results) - len(deduped) + if removed > 0 { + slog.Info("postprocess: deduplicated results", "removed", removed, "remaining", len(deduped)) + } + + if deduped == nil { + deduped = []SARIFResult{} + } + return deduped +} + +// deprioritizeNonSource removes findings in non-source files that have +// severity "note" or "warning", keeping only "error"-level findings. +// This prevents low-value config/data findings from drowning out real +// vulnerabilities in route handlers and source code. 
+func deprioritizeNonSource(results []SARIFResult) []SARIFResult { + var kept []SARIFResult + removed := 0 + for _, r := range results { + if len(r.Locations) == 0 { + kept = append(kept, r) + continue + } + uri := r.Locations[0].PhysicalLocation.ArtifactLocation.URI + if isNonSourceFile(uri) && r.Level != "error" { + removed++ + continue + } + kept = append(kept, r) + } + if removed > 0 { + slog.Info("postprocess: removed low-severity non-source findings", "removed", removed, "remaining", len(kept)) + } + if kept == nil { + kept = []SARIFResult{} + } + return kept +} + +// cleanOrphans removes rules not referenced by any result and taxa not referenced +// by any remaining rule. +func cleanOrphans(results []SARIFResult, ruleByID map[string]SARIFRule, taxonomies []SARIFTaxonomy) ([]SARIFRule, []SARIFTaxonomy) { + // Determine which rule IDs are still referenced. + usedRuleIDs := make(map[string]bool, len(results)) + for _, r := range results { + usedRuleIDs[r.RuleID] = true + } + + // Keep only referenced rules and collect their CWE references. + var rules []SARIFRule + usedCWEs := make(map[string]bool) + for id := range usedRuleIDs { + rule, ok := ruleByID[id] + if !ok { + continue + } + rules = append(rules, rule) + for _, rel := range rule.Relationships { + usedCWEs[rel.Target.ID] = true + } + } + if rules == nil { + rules = []SARIFRule{} + } + + // Filter taxa in taxonomies. + var filteredTaxonomies []SARIFTaxonomy + for _, taxonomy := range taxonomies { + var filteredTaxa []SARIFTaxon + for _, taxon := range taxonomy.Taxa { + if usedCWEs[taxon.ID] { + filteredTaxa = append(filteredTaxa, taxon) + } + } + if len(filteredTaxa) > 0 { + taxonomyCopy := taxonomy + taxonomyCopy.Taxa = filteredTaxa + filteredTaxonomies = append(filteredTaxonomies, taxonomyCopy) + } + } + + return rules, filteredTaxonomies +} + +// CWEForRule extracts the CWE identifier from a rule's relationships or tags. +// It returns a string like "CWE-89" or "" if no CWE is found. 
+func CWEForRule(rule SARIFRule) string { + // Try relationships first. + if len(rule.Relationships) > 0 { + id := rule.Relationships[0].Target.ID + if strings.HasPrefix(id, "CWE-") { + return id + } + } + + // Fall back to tags in properties. + if tags, ok := rule.Properties["tags"]; ok { + if tagSlice, ok := tags.([]string); ok { + for _, tag := range tagSlice { + if strings.HasPrefix(tag, "external/cwe/cwe-") { + num := strings.TrimPrefix(tag, "external/cwe/cwe-") + return "CWE-" + num + } + } + } + } + + return "" +} + +// ruleSeverity extracts the numeric security-severity from a rule's properties. +func ruleSeverity(rule SARIFRule) float64 { + if rule.Properties == nil { + return 0 + } + raw, ok := rule.Properties["security-severity"] + if !ok { + return 0 + } + switch v := raw.(type) { + case string: + f, err := strconv.ParseFloat(v, 64) + if err != nil { + return 0 + } + return f + case float64: + return v + default: + return 0 + } +} + +// isNonSourceFile returns true if the URI points to a non-source file. +func isNonSourceFile(uri string) bool { + base := path.Base(uri) + if nonSourceFilenames[base] { + return true + } + ext := path.Ext(uri) + return nonSourceExtensions[ext] +} diff --git a/internal/sarif/postprocess_test.go b/internal/sarif/postprocess_test.go new file mode 100644 index 0000000..ae2c935 --- /dev/null +++ b/internal/sarif/postprocess_test.go @@ -0,0 +1,270 @@ +package sarif + +import "testing" + +// helper to build a minimal SARIF document for testing. 
+func testDoc(rules []SARIFRule, results []SARIFResult, taxa []SARIFTaxon) SARIFDocument { + var taxonomies []SARIFTaxonomy + if len(taxa) > 0 { + taxonomies = []SARIFTaxonomy{{ + Name: "CWE", + Organization: "MITRE", + ShortDescription: SARIFMessage{Text: "Common Weakness Enumeration"}, + Taxa: taxa, + }} + } + return SARIFDocument{ + Schema: sarifSchema, + Version: sarifVersion, + Runs: []SARIFRun{{ + Tool: SARIFTool{Driver: SARIFDriver{ + Name: "codecrucible", + Version: "dev", + Rules: rules, + }}, + Results: results, + Taxonomies: taxonomies, + }}, + } +} + +func makeRule(id string, severity string, cweID string) SARIFRule { + rule := SARIFRule{ + ID: id, + ShortDescription: SARIFMessage{Text: id}, + Properties: map[string]any{ + "security-severity": severity, + "tags": []string{"security", "external/cwe/" + cweID}, + }, + } + if cweID != "" { + upper := "CWE-" + cweID[len("cwe-"):] + rule.Relationships = []SARIFRelationship{{ + Target: SARIFRelationshipTarget{ + ID: upper, + ToolComponent: SARIFToolComponentRef{Name: "CWE"}, + }, + Kinds: []string{"superset"}, + }} + } + return rule +} + +func makeResult(ruleID, fileURI string, startLine int, level string) SARIFResult { + return SARIFResult{ + RuleID: ruleID, + Level: level, + Message: SARIFMessage{Text: "test finding"}, + Locations: []SARIFLocation{{ + PhysicalLocation: SARIFPhysicalLocation{ + ArtifactLocation: SARIFArtifactLocation{URI: fileURI}, + Region: &SARIFRegion{StartLine: startLine}, + }, + }}, + } +} + +func TestDedup_SameFileSameLineSameCWE_KeepsHighestSeverity(t *testing.T) { + ruleA := makeRule("rule-low", "3.0", "cwe-89") + ruleB := makeRule("rule-high", "9.0", "cwe-89") + + results := []SARIFResult{ + makeResult("rule-low", "src/app.go", 10, "note"), + makeResult("rule-high", "src/app.go", 10, "error"), + } + + doc := testDoc([]SARIFRule{ruleA, ruleB}, results, []SARIFTaxon{ + {ID: "CWE-89", ShortDescription: SARIFMessage{Text: "SQL Injection"}}, + }) + + got := PostProcess(doc) + run := 
got.Runs[0] + + if len(run.Results) != 1 { + t.Fatalf("expected 1 result after dedup, got %d", len(run.Results)) + } + if run.Results[0].RuleID != "rule-high" { + t.Errorf("expected highest severity rule kept (rule-high), got %s", run.Results[0].RuleID) + } +} + +func TestDedup_SameFileSameLineDifferentCWE_KeepsBoth(t *testing.T) { + ruleA := makeRule("rule-sqli", "7.0", "cwe-89") + ruleB := makeRule("rule-xss", "5.0", "cwe-79") + + results := []SARIFResult{ + makeResult("rule-sqli", "src/app.go", 10, "error"), + makeResult("rule-xss", "src/app.go", 10, "warning"), + } + + doc := testDoc([]SARIFRule{ruleA, ruleB}, results, []SARIFTaxon{ + {ID: "CWE-89", ShortDescription: SARIFMessage{Text: "SQL Injection"}}, + {ID: "CWE-79", ShortDescription: SARIFMessage{Text: "XSS"}}, + }) + + got := PostProcess(doc) + run := got.Runs[0] + + if len(run.Results) != 2 { + t.Fatalf("expected 2 results (different CWEs), got %d", len(run.Results)) + } +} + +func TestDeprioritize_NonSourceFile_LowSeverityRemoved(t *testing.T) { + rule := makeRule("rule-secret", "3.0", "cwe-798") + results := []SARIFResult{ + makeResult("rule-secret", "package.json", 5, "warning"), + } + + doc := testDoc([]SARIFRule{rule}, results, []SARIFTaxon{ + {ID: "CWE-798", ShortDescription: SARIFMessage{Text: "Hardcoded Credentials"}}, + }) + + got := PostProcess(doc) + run := got.Runs[0] + + if len(run.Results) != 0 { + t.Fatalf("expected 0 results (low-severity non-source removed), got %d", len(run.Results)) + } +} + +func TestDeprioritize_NonSourceFile_ErrorLevelKept(t *testing.T) { + rule := makeRule("rule-secret", "8.0", "cwe-798") + results := []SARIFResult{ + makeResult("rule-secret", "package.json", 5, "error"), + } + + doc := testDoc([]SARIFRule{rule}, results, []SARIFTaxon{ + {ID: "CWE-798", ShortDescription: SARIFMessage{Text: "Hardcoded Credentials"}}, + }) + + got := PostProcess(doc) + run := got.Runs[0] + + if len(run.Results) != 1 { + t.Fatalf("expected 1 result (error-level non-source kept), got 
%d", len(run.Results)) + } + if run.Results[0].Level != "error" { + t.Errorf("expected level 'error' preserved, got %q", run.Results[0].Level) + } +} + +func TestDeprioritize_SourceFileKeepsLevel(t *testing.T) { + rule := makeRule("rule-sqli", "9.0", "cwe-89") + results := []SARIFResult{ + makeResult("rule-sqli", "routes/login.ts", 42, "error"), + } + + doc := testDoc([]SARIFRule{rule}, results, []SARIFTaxon{ + {ID: "CWE-89", ShortDescription: SARIFMessage{Text: "SQL Injection"}}, + }) + + got := PostProcess(doc) + run := got.Runs[0] + + if len(run.Results) != 1 { + t.Fatalf("expected 1 result, got %d", len(run.Results)) + } + if run.Results[0].Level != "error" { + t.Errorf("expected level 'error' for source file, got %q", run.Results[0].Level) + } +} + +func TestOrphanedRuleCleanup(t *testing.T) { + ruleA := makeRule("rule-used", "7.0", "cwe-89") + ruleB := makeRule("rule-orphan", "5.0", "cwe-79") + + results := []SARIFResult{ + makeResult("rule-used", "src/app.go", 10, "error"), + } + + doc := testDoc([]SARIFRule{ruleA, ruleB}, results, []SARIFTaxon{ + {ID: "CWE-89", ShortDescription: SARIFMessage{Text: "SQL Injection"}}, + {ID: "CWE-79", ShortDescription: SARIFMessage{Text: "XSS"}}, + }) + + got := PostProcess(doc) + run := got.Runs[0] + + if len(run.Tool.Driver.Rules) != 1 { + t.Fatalf("expected 1 rule after orphan cleanup, got %d", len(run.Tool.Driver.Rules)) + } + if run.Tool.Driver.Rules[0].ID != "rule-used" { + t.Errorf("expected rule-used to survive, got %s", run.Tool.Driver.Rules[0].ID) + } + + // CWE-79 taxon should also be removed. 
+ if len(run.Taxonomies) != 1 { + t.Fatalf("expected 1 taxonomy, got %d", len(run.Taxonomies)) + } + if len(run.Taxonomies[0].Taxa) != 1 { + t.Fatalf("expected 1 taxon after cleanup, got %d", len(run.Taxonomies[0].Taxa)) + } + if run.Taxonomies[0].Taxa[0].ID != "CWE-89" { + t.Errorf("expected CWE-89 taxon to survive, got %s", run.Taxonomies[0].Taxa[0].ID) + } +} + +func TestCweForRule(t *testing.T) { + tests := []struct { + name string + rule SARIFRule + want string + }{ + { + name: "from relationship", + rule: SARIFRule{ + Relationships: []SARIFRelationship{{ + Target: SARIFRelationshipTarget{ID: "CWE-89"}, + }}, + }, + want: "CWE-89", + }, + { + name: "relationship without CWE prefix falls through", + rule: SARIFRule{ + Relationships: []SARIFRelationship{{ + Target: SARIFRelationshipTarget{ID: "OWASP-A01"}, + }}, + }, + want: "", + }, + { + name: "from tag when no relationship", + rule: SARIFRule{ + Properties: map[string]any{ + "tags": []string{"security", "external/cwe/cwe-79"}, + }, + }, + want: "CWE-79", + }, + { + name: "tags wrong type ignored", + rule: SARIFRule{ + Properties: map[string]any{"tags": "not-a-slice"}, + }, + want: "", + }, + { + name: "no tags no relationships", + rule: SARIFRule{}, + want: "", + }, + { + name: "tags present but no cwe tag", + rule: SARIFRule{ + Properties: map[string]any{ + "tags": []string{"security", "performance"}, + }, + }, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CWEForRule(tt.rule); got != tt.want { + t.Errorf("CWEForRule() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/internal/sarif/types.go b/internal/sarif/types.go new file mode 100644 index 0000000..2c0f816 --- /dev/null +++ b/internal/sarif/types.go @@ -0,0 +1,156 @@ +package sarif + +// AnalysisResult represents the structured output from the LLM security analysis. 
type AnalysisResult struct {
	// Each field is decoded from the snake_case JSON key named in its tag.
	// None of the tags use omitempty, so zero values are written out when
	// this struct is marshaled.
	RepoName          string          `json:"repo_name"`
	Description       string          `json:"description"`
	PublicAPIRoutes   []APIRoute      `json:"public_api_routes"`
	SecurityIssues    []SecurityIssue `json:"security_issues"`
	SecurityRisk      float64         `json:"security_risk"` // NOTE(review): the numeric scale is not defined in this file — confirm against the prompt schema
	RiskJustification string          `json:"risk_justification"`
}

// APIRoute describes a discovered public API endpoint.
type APIRoute struct {
	Route    string `json:"route"`
	Citation string `json:"citation"` // presumably where the route was found in the code — TODO confirm
}

// SecurityIssue describes a single finding from the LLM analysis.
type SecurityIssue struct {
	Issue            string  `json:"issue"`
	FilePath         string  `json:"file_path"`
	StartLine        int     `json:"start_line"`
	EndLine          int     `json:"end_line"`
	TechnicalDetails string  `json:"technical_details"`
	Severity         float64 `json:"severity"`
	CWEID            string  `json:"cwe_id"` // NOTE(review): expected spelling (e.g. "CWE-89" vs "89") is not visible here — confirm
}

// ---------------------------------------------------------------------------
// SARIF v2.1.0 output types
// ---------------------------------------------------------------------------

// SARIFDocument is the top-level SARIF v2.1.0 envelope.
type SARIFDocument struct {
	Schema  string     `json:"$schema"`
	Version string     `json:"version"`
	Runs    []SARIFRun `json:"runs"`
}

// SARIFRun groups tool information, results, and invocations.
type SARIFRun struct {
	Tool SARIFTool `json:"tool"`
	// Results has no omitempty: a nil slice marshals as JSON null, while an
	// empty non-nil slice marshals as [].
	Results []SARIFResult `json:"results"`
	// Invocations and Taxonomies are left out of the JSON when empty.
	Invocations []SARIFInvocation `json:"invocations,omitempty"`
	Taxonomies  []SARIFTaxonomy   `json:"taxonomies,omitempty"`
}

// SARIFTool describes the analysis tool.
type SARIFTool struct {
	Driver SARIFDriver `json:"driver"`
}

// SARIFDriver holds tool metadata and the set of rules.
type SARIFDriver struct {
	Name           string `json:"name"`
	Version        string `json:"version"`
	InformationURI string `json:"informationUri,omitempty"`
	// Rules has no omitempty: a nil slice marshals as JSON null, an empty
	// non-nil slice as [].
	Rules []SARIFRule `json:"rules"`
}

// SARIFRule defines a unique finding category.
+type SARIFRule struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + ShortDescription SARIFMessage `json:"shortDescription"` + FullDescription *SARIFMessage `json:"fullDescription,omitempty"` + Help *SARIFMessage `json:"help,omitempty"` + Properties map[string]any `json:"properties,omitempty"` + Relationships []SARIFRelationship `json:"relationships,omitempty"` +} + +// SARIFRelationship links a rule to an external taxonomy (e.g., CWE). +type SARIFRelationship struct { + Target SARIFRelationshipTarget `json:"target"` + Kinds []string `json:"kinds"` +} + +// SARIFRelationshipTarget identifies a taxonomy entry. +type SARIFRelationshipTarget struct { + ID string `json:"id"` + GUID string `json:"guid,omitempty"` + ToolComponent SARIFToolComponentRef `json:"toolComponent"` +} + +// SARIFToolComponentRef references a tool component (taxonomy) by name. +type SARIFToolComponentRef struct { + Name string `json:"name"` +} + +// SARIFTaxonomy describes an external taxonomy like CWE. +type SARIFTaxonomy struct { + Name string `json:"name"` + Organization string `json:"organization"` + ShortDescription SARIFMessage `json:"shortDescription"` + Taxa []SARIFTaxon `json:"taxa"` +} + +// SARIFTaxon is a single entry in a taxonomy. +type SARIFTaxon struct { + ID string `json:"id"` + ShortDescription SARIFMessage `json:"shortDescription"` +} + +// SARIFMessage is a simple text message wrapper used throughout SARIF. +type SARIFMessage struct { + Text string `json:"text"` +} + +// SARIFResult is a single finding referencing a rule. +type SARIFResult struct { + RuleID string `json:"ruleId"` + Level string `json:"level"` + Message SARIFMessage `json:"message"` + Locations []SARIFLocation `json:"locations,omitempty"` +} + +// SARIFLocation wraps a physical location. +type SARIFLocation struct { + PhysicalLocation SARIFPhysicalLocation `json:"physicalLocation"` +} + +// SARIFPhysicalLocation points to a file and optional region. 
+type SARIFPhysicalLocation struct { + ArtifactLocation SARIFArtifactLocation `json:"artifactLocation"` + Region *SARIFRegion `json:"region,omitempty"` +} + +// SARIFArtifactLocation identifies a file by URI. +type SARIFArtifactLocation struct { + URI string `json:"uri"` +} + +// SARIFRegion identifies a range of lines within a file. +type SARIFRegion struct { + StartLine int `json:"startLine"` + EndLine int `json:"endLine,omitempty"` + Snippet *SARIFSnippet `json:"snippet,omitempty"` +} + +// SARIFSnippet holds extracted source text. +type SARIFSnippet struct { + Text string `json:"text"` +} + +// SARIFInvocation records metadata about a tool execution. +type SARIFInvocation struct { + ExecutionSuccessful bool `json:"executionSuccessful"` + ToolExecutionNotifications []SARIFNotification `json:"toolExecutionNotifications,omitempty"` +} + +// SARIFNotification represents a runtime message from the tool. +type SARIFNotification struct { + Level string `json:"level"` + Message SARIFMessage `json:"message"` +} diff --git a/internal/supctx/compress.go b/internal/supctx/compress.go new file mode 100644 index 0000000..f541f0d --- /dev/null +++ b/internal/supctx/compress.go @@ -0,0 +1,84 @@ +package supctx + +import ( + "context" + "fmt" + "log/slog" + "strings" + + "github.com/block/codecrucible/internal/llm" +) + +// Compressor holds everything the compress pre-pass needs: an LLM client +// wired to the context-compress phase config, the prompt template, and a +// counter for re-measuring the result. +type Compressor struct { + Client llm.Client + Prompt llm.ContextCompressPrompt + Counter TokenCounter + Model string // sent in ChatRequest.Model +} + +// Compress squeezes sources that exceed their fair share of the budget. Only +// sources with Compress==true are touched; the rest pass through unchanged. +// Failures log and fall back to the original content — packing will then +// truncate instead. 
+// +// "Fair share" is budget / len(sources): crude, but it means a single huge +// doc doesn't starve a set of small ones just because it was listed first. +func (c *Compressor) Compress(ctx context.Context, loaded []Loaded, budget int) []Loaded { + if budget <= 0 || len(loaded) == 0 { + return loaded + } + share := budget / len(loaded) + if share < 200 { + share = 200 // floor — below this the summary is useless + } + + out := make([]Loaded, len(loaded)) + copy(out, loaded) + + for i := range out { + l := &out[i] + if !l.Compress || l.Tokens <= share { + continue + } + slog.Info("compressing context source", "name", l.Name, "tokens", l.Tokens, "target", share) + + compressed, err := c.compressOne(ctx, l.Name, l.Content, share) + if err != nil { + slog.Warn("context compression failed, will truncate instead", + "name", l.Name, "error", err) + continue + } + l.Content = compressed + l.Tokens = c.Counter.Count(compressed) + slog.Info("context source compressed", "name", l.Name, "tokens", l.Tokens) + } + return out +} + +func (c *Compressor) compressOne(ctx context.Context, name, content string, target int) (string, error) { + user := c.Prompt.UserPromptTemplate + user = strings.ReplaceAll(user, "{source_name}", name) + user = strings.ReplaceAll(user, "{target_tokens}", fmt.Sprintf("%d", target)) + user = strings.ReplaceAll(user, "{content}", content) + + resp, err := c.Client.ChatCompletion(ctx, llm.ChatRequest{ + Label: fmt.Sprintf("context-compress %s", name), + Model: c.Model, + Messages: []llm.Message{ + {Role: "system", Content: c.Prompt.SystemMessage}, + {Role: "user", Content: user}, + }, + // Give the model headroom above target — it can't count its own + // tokens precisely, and truncating a summary mid-sentence is worse + // than going slightly over. Pack will enforce the hard budget. 
+ MaxTokens: target * 2, + OutputMode: llm.OutputModeNone, + }) + if err != nil { + return "", err + } + return strings.TrimSpace(resp.Content), nil +} diff --git a/internal/supctx/pack.go b/internal/supctx/pack.go new file mode 100644 index 0000000..1c9a304 --- /dev/null +++ b/internal/supctx/pack.go @@ -0,0 +1,111 @@ +package supctx + +import ( + "fmt" + "sort" + "strings" +) + +// PackResult is the rendered prompt block plus accounting. +type PackResult struct { + // Rendered is the final string ready to splice into the prompt. Empty + // when no sources applied or budget was zero. + Rendered string + + // Tokens is the counted size of Rendered. This is what the caller adds + // to promptOverhead so chunk-budget math stays honest. + Tokens int + + // Dropped lists sources that didn't fit at all. + Dropped []string + + // Truncated names the one source (if any) that was partially included. + Truncated string +} + +// Pack greedily fills the budget highest-priority-first. The last source that +// doesn't fully fit is truncated with a marker; anything after is dropped. +// +// Budget ≤ 0 returns an empty result — callers use this to cleanly disable +// context injection without special-casing. +func Pack(loaded []Loaded, budget int, counter TokenCounter) PackResult { + var res PackResult + if budget <= 0 || len(loaded) == 0 { + for _, l := range loaded { + res.Dropped = append(res.Dropped, l.Name) + } + return res + } + + // Stable sort so equal-priority sources keep declaration order. + sorted := make([]Loaded, len(loaded)) + copy(sorted, loaded) + sort.SliceStable(sorted, func(i, j int) bool { + return sorted[i].Priority > sorted[j].Priority + }) + + var b strings.Builder + remaining := budget + + for i, l := range sorted { + block := wrapSource(l.Name, l.Content) + tokens := counter.Count(block) + + if tokens <= remaining { + b.WriteString(block) + remaining -= tokens + continue + } + + // Doesn't fit whole. 
If there's meaningful room left, truncate; + // otherwise drop this and everything after. + if remaining > 50 { + truncated := truncateToTokens(l.Content, remaining, counter) + marker := fmt.Sprintf("\n[... ~%d tokens truncated ...]", l.Tokens-counter.Count(truncated)) + b.WriteString(wrapSource(l.Name, truncated+marker)) + res.Truncated = l.Name + remaining = 0 + } else { + res.Dropped = append(res.Dropped, l.Name) + } + + // Everything after is dropped. + for _, rest := range sorted[i+1:] { + res.Dropped = append(res.Dropped, rest.Name) + } + break + } + + res.Rendered = b.String() + res.Tokens = counter.Count(res.Rendered) + return res +} + +// wrapSource puts a named envelope around content so the model can attribute +// what it reads to the right source. +func wrapSource(name, content string) string { + return fmt.Sprintf("\n%s\n\n", name, content) +} + +// truncateToTokens cuts content to roughly target tokens by binary-searching +// on byte length. The heuristic counter is monotone in length so this +// converges in log(len) iterations without repeatedly counting the full text. +func truncateToTokens(content string, target int, counter TokenCounter) string { + if counter.Count(content) <= target { + return content + } + lo, hi := 0, len(content) + for lo < hi { + mid := (lo + hi + 1) / 2 + if counter.Count(content[:mid]) <= target { + lo = mid + } else { + hi = mid - 1 + } + } + // Back off to the last newline so we don't cut mid-line. + if i := strings.LastIndexByte(content[:lo], '\n'); i > 0 { + lo = i + } + return content[:lo] +} diff --git a/internal/supctx/pack_test.go b/internal/supctx/pack_test.go new file mode 100644 index 0000000..8756cbb --- /dev/null +++ b/internal/supctx/pack_test.go @@ -0,0 +1,137 @@ +package supctx + +import ( + "strings" + "testing" +) + +// charCounter counts bytes — deterministic and trivially predictable for tests. 
+type charCounter struct{} + +func (charCounter) Count(s string) int { return len(s) } + +func TestPack_EmptyInputs(t *testing.T) { + r := Pack(nil, 1000, charCounter{}) + if r.Rendered != "" || r.Tokens != 0 || len(r.Dropped) != 0 { + t.Fatalf("expected empty result, got %+v", r) + } + + r = Pack([]Loaded{{Name: "a", Content: "hello", Tokens: 5}}, 0, charCounter{}) + if r.Rendered != "" || len(r.Dropped) != 1 || r.Dropped[0] != "a" { + t.Fatalf("zero budget should drop all sources, got %+v", r) + } +} + +func TestPack_AllFit(t *testing.T) { + loaded := []Loaded{ + {Name: "spec", Content: "api spec here", Priority: 100}, + {Name: "notes", Content: "review notes", Priority: 50}, + } + r := Pack(loaded, 10_000, charCounter{}) + if len(r.Dropped) != 0 || r.Truncated != "" { + t.Fatalf("expected no drops/truncation, got %+v", r) + } + if !strings.Contains(r.Rendered, "api spec here") || !strings.Contains(r.Rendered, "review notes") { + t.Fatalf("rendered missing content: %q", r.Rendered) + } + // Higher priority should appear first. + specIdx := strings.Index(r.Rendered, "spec") + notesIdx := strings.Index(r.Rendered, "notes") + if specIdx > notesIdx { + t.Fatalf("priority order wrong: spec@%d notes@%d", specIdx, notesIdx) + } +} + +func TestPack_PriorityOrder(t *testing.T) { + loaded := []Loaded{ + {Name: "low", Content: strings.Repeat("x", 200), Priority: 10}, + {Name: "high", Content: strings.Repeat("y", 200), Priority: 100}, + } + // Budget fits one wrapped source (~230 chars with envelope) but not both. 
+ r := Pack(loaded, 250, charCounter{}) + if !strings.Contains(r.Rendered, "high") { + t.Fatalf("high-priority source should be packed: %+v", r) + } + // low should be dropped or truncated — not fully included + if strings.Count(r.Rendered, "x") >= 200 { + t.Fatalf("low-priority source should not be fully included") + } +} + +func TestPack_Truncation(t *testing.T) { + big := strings.Repeat("line of content here\n", 100) // ~2100 chars + loaded := []Loaded{ + {Name: "big", Content: big, Tokens: len(big), Priority: 100}, + } + r := Pack(loaded, 500, charCounter{}) + if r.Truncated != "big" { + t.Fatalf("expected truncation of 'big', got %+v", r) + } + if !strings.Contains(r.Rendered, "truncated") { + t.Fatalf("truncation marker missing: %q", r.Rendered) + } + if r.Tokens > 600 { // some slop for the envelope+marker + t.Fatalf("rendered exceeds budget tolerance: %d tokens", r.Tokens) + } +} + +func TestPack_DropsRemainder(t *testing.T) { + loaded := []Loaded{ + {Name: "a", Content: strings.Repeat("a", 300), Priority: 100}, + {Name: "b", Content: strings.Repeat("b", 300), Priority: 50}, + {Name: "c", Content: strings.Repeat("c", 300), Priority: 10}, + } + // Budget fits 'a' wrapped (~330) and partially 'b', not 'c'. 
+ r := Pack(loaded, 400, charCounter{}) + + found := map[string]bool{} + for _, d := range r.Dropped { + found[d] = true + } + if !found["c"] { + t.Fatalf("expected 'c' to be dropped, got dropped=%v", r.Dropped) + } + if found["a"] { + t.Fatalf("'a' (highest priority) should not be dropped") + } +} + +func TestPack_StablePriorityTies(t *testing.T) { + loaded := []Loaded{ + {Name: "first", Content: "aaa", Priority: 50}, + {Name: "second", Content: "bbb", Priority: 50}, + } + r := Pack(loaded, 1000, charCounter{}) + if strings.Index(r.Rendered, "first") > strings.Index(r.Rendered, "second") { + t.Fatalf("equal-priority sources should keep declaration order") + } +} + +func TestFilterPhase(t *testing.T) { + loaded := []Loaded{ + {Name: "all", Phases: nil}, + {Name: "analysis-only", Phases: []string{"analysis"}}, + {Name: "audit-only", Phases: []string{"audit"}}, + } + + a := FilterPhase(loaded, "analysis") + if len(a) != 2 || a[0].Name != "all" || a[1].Name != "analysis-only" { + t.Fatalf("analysis filter wrong: %+v", a) + } + + au := FilterPhase(loaded, "audit") + if len(au) != 2 || au[0].Name != "all" || au[1].Name != "audit-only" { + t.Fatalf("audit filter wrong: %+v", au) + } +} + +func TestTruncateToTokens_NewlineBackoff(t *testing.T) { + content := "line one\nline two\nline three\nline four\n" + out := truncateToTokens(content, 20, charCounter{}) + if strings.HasSuffix(out, "line") || strings.Contains(out, "line t\n") { + t.Fatalf("should back off to newline boundary, got %q", out) + } + if !strings.HasSuffix(out, "\n") && out != "line one" { + t.Logf("truncated to: %q", out) // soft check — depends on exact budget + } +} diff --git a/internal/supctx/source.go b/internal/supctx/source.go new file mode 100644 index 0000000..03d7e83 --- /dev/null +++ b/internal/supctx/source.go @@ -0,0 +1,310 @@ +// Package supctx loads supplementary context (docs, API specs, sibling repos, +// internal notes) and packs it into a token-budget-safe block for injection +// into 
analysis and audit prompts. +// +// The package is deliberately decoupled from the prompt assembler: it produces +// a rendered string plus accounting metadata, and the caller decides where in +// the prompt that string lands. +package supctx + +import ( + "context" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "regexp" + "sort" + "strings" + "time" + + "github.com/block/codecrucible/internal/ingest" +) + +// Source describes one piece of supplementary context as configured by the +// user. The four types share a single struct so config unmarshal stays flat. +type Source struct { + // Name labels the source in the rendered prompt block and in logs. + Name string `mapstructure:"name"` + + // Type is one of "path", "repo", "url", "inline". + Type string `mapstructure:"type"` + + // Location is interpreted per-type: filesystem path, git clone URL, + // HTTP(S) URL, or the literal content for inline. + Location string `mapstructure:"location"` + + // Priority orders packing when the combined sources exceed the budget. + // Higher wins; ties broken by declaration order. + Priority int `mapstructure:"priority"` + + // Compress marks the source as eligible for LLM pre-compression when it + // alone would exceed its fair share of the budget. + Compress bool `mapstructure:"compress"` + + // Phases limits which scan phases receive this source. Empty means all + // LLM phases (currently "analysis" and "audit" — feature-detection is + // never fed supplementary context). + Phases []string `mapstructure:"phases"` + + // Include/Exclude are glob filters for "path" and "repo" types, reusing + // the ingest package's FilterConfig semantics. + Include []string `mapstructure:"include"` + Exclude []string `mapstructure:"exclude"` +} + +// Loaded is a source whose content has been fetched and token-counted but not +// yet packed. Priority is carried through so Pack can sort. 
+type Loaded struct { + Name string + Content string + Tokens int + Priority int + Compress bool + Phases []string +} + +// TokenCounter is the subset of chunk.TokenCounter this package needs. +// Defined locally to avoid an import cycle if chunk ever wants supctx. +type TokenCounter interface { + Count(text string) int +} + +// fetchTimeout bounds git clone and HTTP fetch. Failures skip the source +// rather than aborting the scan. +const fetchTimeout = 60 * time.Second + +// LoadAll fetches every source concurrently-ish (sequential for now — the +// dominant cost is the LLM calls downstream, not loading a few files) and +// returns them token-counted. Load errors are logged and the source dropped; +// the scan proceeds with whatever loaded successfully. +func LoadAll(ctx context.Context, srcs []Source, counter TokenCounter) []Loaded { + var out []Loaded + for i, s := range srcs { + if s.Name == "" { + s.Name = fmt.Sprintf("context-%d", i+1) + } + content, err := loadOne(ctx, s) + if err != nil { + slog.Warn("skipping context source", "name", s.Name, "type", s.Type, "error", err) + continue + } + if strings.TrimSpace(content) == "" { + slog.Warn("context source is empty, skipping", "name", s.Name) + continue + } + out = append(out, Loaded{ + Name: s.Name, + Content: content, + Tokens: counter.Count(content), + Priority: s.Priority, + Compress: s.Compress, + Phases: s.Phases, + }) + } + return out +} + +func loadOne(ctx context.Context, s Source) (string, error) { + switch strings.ToLower(s.Type) { + case "inline": + return s.Location, nil + case "path": + return loadPath(s.Location, s.Include, s.Exclude) + case "repo": + return loadRepo(ctx, s.Location, s.Include, s.Exclude) + case "url": + return loadURL(ctx, s.Location) + default: + return "", fmt.Errorf("unknown context source type %q", s.Type) + } +} + +// loadPath handles both single files and directories. 
Directories go through +// the ingest walker/filter so .gitignore, binary-skip, and glob filters all +// apply exactly as they do for the scan target. +func loadPath(location string, include, exclude []string) (string, error) { + info, err := os.Stat(location) + if err != nil { + return "", err + } + if !info.IsDir() { + data, err := os.ReadFile(location) + if err != nil { + return "", err + } + return renderFiles([]ingest.SourceFile{{ + Path: filepath.Base(location), + Content: string(data), + }}), nil + } + + files, err := ingest.WalkDir(location) + if err != nil { + return "", err + } + filtered, _ := ingest.FilterFiles(files, ingest.FilterConfig{ + IncludeTests: true, // user explicitly pointed at this dir — don't second-guess + IncludeDocs: true, + Include: include, + Exclude: exclude, + }) + return renderFiles(filtered), nil +} + +// loadRepo shallow-clones into a temp dir, delegates to loadPath, then cleans +// up. The temp dir lives for the lifetime of the process (not just this call) +// because the returned string references no files — content is already read. +func loadRepo(ctx context.Context, gitURL string, include, exclude []string) (string, error) { + tmp, err := os.MkdirTemp("", "ri-ctx-*") + if err != nil { + return "", err + } + defer os.RemoveAll(tmp) + + cctx, cancel := context.WithTimeout(ctx, fetchTimeout) + defer cancel() + + cmd := exec.CommandContext(cctx, "git", "clone", "--depth", "1", "--quiet", gitURL, tmp) + cmd.Stderr = nil + if out, err := cmd.CombinedOutput(); err != nil { + return "", fmt.Errorf("git clone: %w: %s", err, strings.TrimSpace(string(out))) + } + return loadPath(tmp, include, exclude) +} + +// allowPrivateHosts disables the SSRF check on initial URLs. Set to true in +// tests that use httptest.Server (which binds to 127.0.0.1). +var allowPrivateHosts = false + +// loadURL fetches over HTTP(S), refusing other schemes and requests/redirects +// to private IP ranges. 
HTML responses get a crude tag-strip so the prompt +// isn't half angle brackets. +func loadURL(ctx context.Context, rawURL string) (string, error) { + u, err := url.Parse(rawURL) + if err != nil { + return "", err + } + if u.Scheme != "http" && u.Scheme != "https" { + return "", fmt.Errorf("unsupported URL scheme %q", u.Scheme) + } + if !allowPrivateHosts && isPrivateHost(u.Hostname()) { + return "", fmt.Errorf("request to private address %s refused", u.Host) + } + + client := &http.Client{ + Timeout: fetchTimeout, + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + // Always check redirects — even in test mode, a redirect to a + // private IP from a public URL is suspicious. + if isPrivateHost(req.URL.Hostname()) { + return fmt.Errorf("redirect to private address %s refused", req.URL.Host) + } + return nil + }, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + if err != nil { + return "", err + } + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("HTTP %d", resp.StatusCode) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) // 4 MiB cap + if err != nil { + return "", err + } + + content := string(body) + if strings.Contains(resp.Header.Get("Content-Type"), "text/html") { + content = stripHTML(content) + } + return content, nil +} + +// isPrivateHost returns true for loopback, link-local, and RFC1918 ranges. +// Best-effort SSRF guard — a determined attacker with control of the config +// file already has shell via the repo loader's git-clone anyway. 
+func isPrivateHost(host string) bool { + ips, err := net.LookupIP(host) + if err != nil { + return false // let the request fail naturally + } + for _, ip := range ips { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsPrivate() { + return true + } + } + return false +} + +var ( + scriptRe = regexp.MustCompile(`(?is)]*>.*?`) + styleRe = regexp.MustCompile(`(?is)]*>.*?`) + tagRe = regexp.MustCompile(`<[^>]+>`) + wsRe = regexp.MustCompile(`[ \t]+`) + nlRe = regexp.MustCompile(`\n{3,}`) +) + +// stripHTML removes script/style blocks, then all remaining tags, then +// collapses whitespace. Good enough for wiki pages and rendered markdown; +// not a real HTML parser but avoids a dependency. +func stripHTML(s string) string { + s = scriptRe.ReplaceAllString(s, "") + s = styleRe.ReplaceAllString(s, "") + s = tagRe.ReplaceAllString(s, " ") + s = wsRe.ReplaceAllString(s, " ") + s = nlRe.ReplaceAllString(s, "\n\n") + return strings.TrimSpace(s) +} + +// renderFiles wraps a set of files in a light XML-ish envelope so the model +// can tell where one file ends and the next begins. Intentionally simpler than +// the repomix format used for the scan target — no line numbers, no directory +// tree — because supplementary context is reference material, not the thing +// being audited. +func renderFiles(files []ingest.SourceFile) string { + sort.Slice(files, func(i, j int) bool { return files[i].Path < files[j].Path }) + var b strings.Builder + for _, f := range files { + fmt.Fprintf(&b, "\n%s\n\n", f.Path, f.Content) + } + return b.String() +} + +// AppliesTo reports whether a loaded source should be injected into the named +// phase. Empty Phases means "all". +func (l Loaded) AppliesTo(phase string) bool { + if len(l.Phases) == 0 { + return true + } + for _, p := range l.Phases { + if p == phase { + return true + } + } + return false +} + +// FilterPhase returns the subset of loaded sources that apply to phase. 
+func FilterPhase(loaded []Loaded, phase string) []Loaded { + var out []Loaded + for _, l := range loaded { + if l.AppliesTo(phase) { + out = append(out, l) + } + } + return out +} diff --git a/internal/supctx/source_test.go b/internal/supctx/source_test.go new file mode 100644 index 0000000..b094fa3 --- /dev/null +++ b/internal/supctx/source_test.go @@ -0,0 +1,135 @@ +package supctx + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestMain(m *testing.M) { + // httptest.Server binds to 127.0.0.1 which the SSRF check blocks. + allowPrivateHosts = true + os.Exit(m.Run()) +} + +func TestLoadAll_Inline(t *testing.T) { + srcs := []Source{ + {Name: "notes", Type: "inline", Location: "admin endpoints are behind mTLS"}, + } + loaded := LoadAll(context.Background(), srcs, charCounter{}) + if len(loaded) != 1 || loaded[0].Content != "admin endpoints are behind mTLS" { + t.Fatalf("inline load failed: %+v", loaded) + } + if loaded[0].Tokens != len(loaded[0].Content) { + t.Fatalf("tokens not counted") + } +} + +func TestLoadAll_Path_SingleFile(t *testing.T) { + tmp := t.TempDir() + f := filepath.Join(tmp, "spec.yaml") + if err := os.WriteFile(f, []byte("openapi: 3.0.0"), 0o644); err != nil { + t.Fatal(err) + } + loaded := LoadAll(context.Background(), []Source{ + {Name: "spec", Type: "path", Location: f}, + }, charCounter{}) + if len(loaded) != 1 || !strings.Contains(loaded[0].Content, "openapi: 3.0.0") { + t.Fatalf("path load failed: %+v", loaded) + } + if !strings.Contains(loaded[0].Content, ``) { + t.Fatalf("file envelope missing: %q", loaded[0].Content) + } +} + +func TestLoadAll_Path_Directory(t *testing.T) { + tmp := t.TempDir() + if err := os.WriteFile(filepath.Join(tmp, "a.go"), []byte("package a"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmp, "b.go"), []byte("package b"), 0o644); err != nil { + t.Fatal(err) + } + loaded := LoadAll(context.Background(), []Source{ 
+ {Name: "pkg", Type: "path", Location: tmp}, + }, charCounter{}) + if len(loaded) != 1 { + t.Fatalf("expected 1 loaded source, got %d", len(loaded)) + } + if !strings.Contains(loaded[0].Content, "package a") || !strings.Contains(loaded[0].Content, "package b") { + t.Fatalf("directory load missing files: %q", loaded[0].Content) + } +} + +func TestLoadAll_URL(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte("threat model content")) + })) + defer srv.Close() + + loaded := LoadAll(context.Background(), []Source{ + {Name: "tm", Type: "url", Location: srv.URL}, + }, charCounter{}) + if len(loaded) != 1 || loaded[0].Content != "threat model content" { + t.Fatalf("url load failed: %+v", loaded) + } +} + +func TestLoadAll_URL_HTMLStrip(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(`

Title

Content here

`)) + })) + defer srv.Close() + + loaded := LoadAll(context.Background(), []Source{ + {Name: "wiki", Type: "url", Location: srv.URL}, + }, charCounter{}) + if len(loaded) != 1 { + t.Fatalf("expected 1 source, got %d", len(loaded)) + } + c := loaded[0].Content + if strings.Contains(c, "` — insufficient + if user contains ``). + validation_checks: + - "Verify triple-mustache / unescaped template tags are not applied to user data." + - "Confirm React / Vue / Angular frameworks' safe-HTML escapes are not bypassed with user data." + - "Check inline script JSON embedding uses a serializer that escapes `` (serialize-javascript with isJSON=true)." + - "Inspect client-side innerHTML / v-html / dangerouslySetInnerHTML for user-reachable sources." + false_positive_indicators: + - "Value is sanitized via DOMPurify (with a restrictive config)." + - "Value is a number / boolean / enum that cannot contain script." + - "Framework's auto-escape is active and the flagged site is a plain interpolation." + + CWE-89: + title: "SQL Injection" + analysis_prompt: > + Template literals used to build SQL (`\`SELECT * FROM u WHERE id = + ${id}\``). `knex.raw(user)`, `sequelize.query(user)` without + replacements, TypeORM `.query(\`... ${user} ...\`)`, raw `pg.query` + / `mysql2.query` with string concat. Prisma is generally safe but + flag `$queryRaw\`... ${user}\`` (note: Prisma's tagged template with + backticks auto-parameterizes, but `$queryRawUnsafe` does not — + distinguish these). Column / table / ORDER BY / LIMIT positions + never parameterize — they need an allowlist. + validation_checks: + - "Verify SQL construction uses placeholder binding ($1, ?, :name) with a params array, not string interpolation." + - "Confirm Prisma usage: tagged-template $queryRaw is safe; $queryRawUnsafe is not." + - "Check identifier positions (column, table, ORDER BY) come from an allowlist." 
+ false_positive_indicators: + - "Interpolated value is a constant or a numeric path parameter validated by an int parser." + - "Query builder exclusively uses parameter binding; interpolation is only on non-value clauses." + + CWE-943: + title: "NoSQL Injection (Mongo / Redis Operator Injection)" + analysis_prompt: > + MongoDB: `User.find(req.body)` / `User.findOne(req.query)` passes the + entire object including operator fields (`$ne`, `$gt`, `$where`, + `$regex`). Attacker sends `{username: {$ne: null}, password: {$ne: + null}}` to bypass auth. Also `$where` with a user-controlled JS + string executes on the server. Redis: `EVAL` with user-concatenated + Lua scripts. DynamoDB: user-controlled filter expressions. + validation_checks: + - "Verify query objects are constructed field-by-field from validated primitive values, not spread from req.body / req.query." + - "Confirm `$where` / `$function` / Lua EVAL never include user data." + - "Inspect sanitization middleware (express-mongo-sanitize) is installed and applied." + false_positive_indicators: + - "Route uses a DTO (NestJS ValidationPipe, Zod, Joi) that coerces each field to a primitive before query construction." + + CWE-78: + title: "OS Command Injection" + analysis_prompt: > + `child_process.exec(user)` and `execSync(user)` always spawn a shell. + `spawn` / `execFile` with `shell: true` and user input is also + injection. Template-string commands: `exec(\`convert ${user} out.png\`)`. + Even with shell:false, passing user-controlled arg that starts with + `-` (argument injection, e.g. `ssh -oProxyCommand=...`). Piping user + into a shell via stdin of a shell process. + validation_checks: + - "Verify `exec` / `execSync` are not used with non-literal arguments; prefer `execFile` / `spawn` with shell:false." + - "Confirm args arrays use a validated allowlist; user values cannot introduce flags." + - "Flag every `shell: true` with non-literal arg." 
+ false_positive_indicators: + - "Command and all args are hardcoded literals." + - "User input is escaped with shell-quote and shell:true is intentional and documented." + + CWE-1321: + title: "Prototype Pollution" + analysis_prompt: > + Recursive deep-merge / deep-clone / deep-assign operations that copy + attacker-JSON keys `__proto__`, `constructor`, `prototype` into a + base object. Libraries historically vulnerable: lodash `_.merge`, + `_.defaultsDeep`, `_.set` (pre-patched versions); `deepmerge`, + `hoek`, custom reduce-based merges. Express body-parser / qs default + parses `?a.b.c=x` into a nested object — combined with a merge, + this is end-to-end RCE-enabling. Impact ranges from DoS to gadget- + based RCE depending on downstream consumers (e.g. template engines + reading prototype properties, CI / build tools accepting JSON). + validation_checks: + - "Verify deep-merge operations reject keys __proto__, constructor, prototype, or use Object.create(null) bases." + - "Confirm lodash version is post-4.17.21 (merges are patched) AND that the code doesn't bypass the patch with _.setWith / _.set." + - "Inspect `qs`/`body-parser` config — disable nested object parsing where unused." + false_positive_indicators: + - "Merge source is a validated schema output (Zod, Joi, Yup); arbitrary keys are stripped upstream." + - "Merge target uses Object.create(null) or Map, not plain {}." + + CWE-502: + title: "Deserialization of Untrusted Data" + analysis_prompt: > + `node-serialize`, `serialize-javascript` deserializing user data, or + any package deserializing function bodies. JSON.parse with a reviver + that constructs classes. `js-yaml` load (not `safeLoad`) used with + tag-based type reconstruction. `vm.runInNewContext` over user JSON + that's effectively code. Kafka / Rabbit / Redis consumers + deserializing arbitrary payloads. + validation_checks: + - "Verify no JS deserializer reconstructs function / class instances from attacker data." 
+ - "Confirm js-yaml 3.x uses safeLoad (js-yaml 4.x removed safeLoad; its load is safe by default)." + false_positive_indicators: + - "Data is plain JSON.parse to primitives / simple objects with no reviver." + + CWE-22: + title: "Path Traversal" + analysis_prompt: > + `fs.readFile(userPath)`, `res.sendFile(userPath)`, `path.join(base, + user)` without post-resolve prefix check (join does not stop `../`). + Multer destination functions using user data. Archive extraction + (unzipper, tar, adm-zip) without filtering entry paths (zip-slip). + Express `res.sendFile` has `root` option — verify it's used and the + resolved path stays inside. + validation_checks: + - "Verify paths are resolved (path.resolve) and the result checked to start with an allowed base." + - "Confirm archive extractors filter `..` / absolute paths from entries." + - "Check res.sendFile uses the `root` option and rejects resolution failures." + false_positive_indicators: + - "Filename is a server-generated UUID / hash." + - "Path is inside a chrooted / containerized path-restricted environment." + + CWE-918: + title: "Server-Side Request Forgery (SSRF)" + analysis_prompt: > + `fetch(userUrl)`, `axios.get(userUrl)`, `got(userUrl)`, `http.request` + with user-controlled host. Common pitfalls: URL.parse + hostname + check allows DNS rebinding (resolve once, use IP for the request); + redirect-follow defaults to on (got/axios), so a public URL can + 301→169.254.169.254. Custom image proxies, URL preview cards, + webhook delivery, avatar fetchers. + validation_checks: + - "Verify URL scheme is restricted to http(s) and host is allowlisted or blocklisted against private ranges." + - "Confirm DNS is resolved once and the IP is used (no TOCTOU via rebinding)." + - "Check redirects are disabled or each hop is re-validated." + false_positive_indicators: + - "URL is a compile-time constant or from a trusted config value." + - "Request goes through a dedicated egress proxy that enforces destination policy."
+ + CWE-601: + title: "Open Redirect" + analysis_prompt: > + `res.redirect(req.query.next)`, `window.location = params.get('url')`, + Next.js `router.push(userUrl)`. Attacker crafts + `/login?next=https://evil.com`, victim clicks, gets phished after auth. Safe patterns use + an allowlist of relative paths or hostnames. + validation_checks: + - "Verify redirect targets are validated against an allowlist OR constrained to same-origin paths (starts with `/` and no `//`)." + - "Confirm `//evil.com` protocol-relative form is rejected." + false_positive_indicators: + - "Target is a constant / from an allowlist." + + CWE-287: + title: "Improper Authentication" + analysis_prompt: > + `jwt.verify(token, secret, { algorithms })` missing `algorithms` — + jsonwebtoken pre-fix defaulted to permissive. `jwt.verify` with + `algorithms: ['HS256', 'RS256']` allows an RSA public key to be + used as an HMAC secret (classic RS256→HS256 confusion). Password + comparison with `===`; bcrypt comparison ignoring the async + callback. Session secrets weak / hardcoded. OAuth state parameter + missing or not verified. + validation_checks: + - "Verify jwt.verify pins algorithms to a single expected value." + - "Confirm password comparison uses bcrypt.compare / argon2.verify (constant-time)." + - "Check session cookies set httpOnly / secure / sameSite." + - "Inspect OAuth state / nonce validation." + false_positive_indicators: + - "JWT verification uses `jose` library's jwtVerify which pins algorithm by key type." + + CWE-862: + title: "Missing Authorization" + analysis_prompt: > + Express routes without auth middleware. NestJS controllers without + `@UseGuards(AuthGuard)`. Next.js API routes without auth check. + DRF-style ViewSets loading resources by id without ownership filter. + Mongoose `findById(req.params.id)` without user filter. + validation_checks: + - "Verify every route reading or mutating data has an auth middleware applied."
+ - "Confirm per-resource ownership check is present on id-based lookups." + - "Inspect NestJS guards apply at controller AND method levels where needed." + false_positive_indicators: + - "Resource is intentionally public." + - "Auth is enforced globally via a middleware attached at the router level." + + CWE-639: + title: "Authorization Bypass Through User-Controlled Key (IDOR)" + analysis_prompt: > + Routes accepting an id in path/query/body and returning/mutating the + resource without ownership validation. Batch endpoints processing + arrays of ids without per-item check. Mongoose / Prisma `findUnique` + / `findFirst` with user input and no tenant filter. + validation_checks: + - "Confirm every id-based lookup scopes to the authenticated user / tenant." + - "Verify batch endpoints check per-item authorization." + false_positive_indicators: + - "Row-level security / policies enforce tenant isolation." + + CWE-352: + title: "Cross-Site Request Forgery (CSRF)" + analysis_prompt: > + Cookie-based auth without CSRF protection (no csurf / no SameSite / + no origin check). GET endpoints performing state changes. CORS with + `origin: true, credentials: true` reflects origin. Double-cookie + implementations where the token is set via a readable cookie. + validation_checks: + - "Verify cookie-authenticated mutating routes require a CSRF token or SameSite cookies." + - "Confirm CORS does not reflect arbitrary origins when credentials are allowed." + false_positive_indicators: + - "API exclusively uses Bearer-token auth (Authorization header)." + - "SameSite=Strict cookies + no GET mutations." + + CWE-915: + title: "Mass Assignment" + analysis_prompt: > + `User.update(req.body)`, `new Model(req.body).save()`, Mongoose + schemas with `strict: false`, Sequelize `model.update(req.body)` with + no `fields` allowlist, Prisma `prisma.user.update({ data: req.body })`, + NestJS ValidationPipe without `whitelist: true`. 
+ validation_checks: + - "Verify update / create calls receive an allowlist of fields, not req.body directly." + - "Confirm DTOs reject unknown properties (class-validator + ValidationPipe whitelist)." + - "Inspect Mongoose schemas for strict: false." + false_positive_indicators: + - "Inputs pass through a strict DTO with known fields only." + + CWE-1336: + title: "Server-Side Template Injection (SSTI)" + analysis_prompt: > + `pug.compile(userTemplate)`, `Handlebars.compile(userTemplate)`, + `nunjucks.renderString(userTemplate)`, `ejs.render(userTemplate)`, + any template engine fed a user-controlled TEMPLATE source (not + just user-controlled data placed into a constant template). Impact: + arbitrary JS execution via `process.mainModule.require` / similar + sandbox-escape payloads. + validation_checks: + - "Verify template sources are compile-time constants; user data is always context values." + false_positive_indicators: + - "User data is only passed as template context; templates are static files or hardcoded strings." + + CWE-327: + title: "Use of Broken or Risky Cryptographic Algorithm" + analysis_prompt: > + `Math.random()` for security tokens (use `crypto.randomBytes`). + `crypto.createHash('md5' / 'sha1')` for passwords. `crypto. + createCipher` (legacy, derives IV from password). ECB mode via + algorithm 'aes-256-ecb'. Static IVs, time-based IVs. RSA + signing without hash binding. + validation_checks: + - "Verify random tokens use crypto.randomBytes or crypto.randomUUID, not Math.random." + - "Confirm password hashing uses bcrypt / argon2 / scrypt; not md5 / sha1 / sha256-alone." + - "Inspect cipher use: AEAD (GCM / ChaCha20-Poly1305) with random IVs." + false_positive_indicators: + - "MD5/SHA1 used for non-security purposes (cache key, ETag)." + + CWE-798: + title: "Use of Hard-coded Credentials" + analysis_prompt: > + JWT secrets, API keys, database URLs with passwords, encryption keys + committed to source. 
.env files committed; config files with + plaintext secrets. + validation_checks: + - "Verify all secrets come from environment or a secret manager at runtime." + - "Confirm .env / secrets files are in .gitignore." + false_positive_indicators: + - "Value is a documented placeholder (`changeme`, `your-secret`) or a public identifier (publishable Stripe key, OAuth client id)." + + CWE-400: + title: "Uncontrolled Resource Consumption / ReDoS" + analysis_prompt: > + Regular expressions applied to user input with catastrophic + backtracking patterns (nested quantifiers over overlapping alternatives, + e.g. `^(a+)+$`, `(.*a)+$`). Common culprits: email validators, + custom password-policy regexes, URL parsing. Tooling hint: regexes + with `.*` / `.+` inside `(...)+` or `(...)*` are suspicious. Also + look for unbounded loops over user-controlled arrays without size + caps. + validation_checks: + - "Verify regexes applied to untrusted input are linear-time (safe-regex / RE2 usage)." + - "Confirm input size is capped before regex / JSON.parse / loops." + false_positive_indicators: + - "Regex is simple (no nested quantifiers) or input is length-bounded." + + CWE-20: + title: "Improper Input Validation" + analysis_prompt: > + Catch-all. Prefer specific CWEs when they fit. + validation_checks: + - "Verify all request values are validated with a schema library (zod, joi, class-validator) before use." + false_positive_indicators: + - "Values are validated by a DTO pipeline." diff --git a/prompts/exploit-proof-web-js/feature_detection.yaml b/prompts/exploit-proof-web-js/feature_detection.yaml new file mode 100644 index 0000000..8a2b114 --- /dev/null +++ b/prompts/exploit-proof-web-js/feature_detection.yaml @@ -0,0 +1,85 @@ +system_message: > + You are a JavaScript / TypeScript web application analysis expert. Quickly + scan the code and identify which security-relevant features are present. + Return findings in strict JSON format. 
+ +user_prompt_template: > + Analyze the following JS/TS application to identify security-relevant + features. + + + Repository: {repo_name} + + + Repomix XML Content: + --- + {xml_content} + --- + + + FEATURE CATEGORIES: + 1. **express_routes** - Express / Koa / Fastify / Hapi route handlers + (`app.get`, `router.post`, `fastify.register`). + 2. **nestjs_controllers** - NestJS `@Controller`, `@Get`, `@Post`, DTOs, + guards, pipes. + 3. **nextjs_handlers** - Next.js `pages/api/*`, `app/**/route.ts`, + `getServerSideProps`, `getStaticProps`, middleware.ts. + 4. **websocket_handlers** - `socket.io`, `ws`, `uWebSockets.js`, NestJS + gateways. + 5. **graphql_resolvers** - Apollo Server, Mercurius, GraphQL Yoga, + schema-first / code-first resolvers. + 6. **authentication** - passport.js strategies, next-auth, custom JWT + verify, bcrypt / argon2 usage, session middleware (express-session, + cookie-session). + 7. **authorization** - casbin, CASL, NestJS guards, middleware checking + roles / permissions, row-level ownership checks. + 8. **sql_orm** - Sequelize, TypeORM, Prisma, knex, mikro-orm, raw `pg` + / `mysql2` queries. + 9. **nosql** - Mongoose / raw MongoDB driver, Redis, DynamoDB, Cosmos + DB operations. + 10. **template_rendering** - pug, handlebars, ejs, nunjucks, mustache, + Marko, React SSR. + 11. **serialization** - JSON.parse with reviver, custom deserializers, + yaml parsing, `serialize-javascript`, `node-serialize` (RCE). + 12. **command_execution** - `child_process.exec / execSync / spawn / + execFile / fork`, `node:vm`. + 13. **dynamic_require** - `require(variable)`, dynamic `import(...)`, + loading plugins by name from config / user input. + 14. **file_operations** - `fs.readFile / writeFile / createReadStream`, + `res.sendFile`, `path.join` with user input, multer uploads, archive + extraction (unzipper, tar). + 15. **external_http** - `fetch`, `axios`, `got`, `node-fetch`, `http.request` + with user-controlled URLs (SSRF surface). + 16. 
**cryptography** - `crypto` stdlib, `bcrypt` / `argon2`, `jsonwebtoken`, + `jose`, custom signing / encryption. + 17. **csrf_protection** - `csurf`, `@nestjs/csrf`, custom double-submit + cookie, origin / referer checks. + 18. **cors** - `cors` middleware, custom `Access-Control-*` setting. + 19. **dom_sinks** - Client-side code using `innerHTML`, `outerHTML`, + `document.write`, React `dangerouslySetInnerHTML`, Vue `v-html`, + Angular bypass sanitizer. + 20. **prototype_pollution_sinks** - `_.merge`, `_.set`, `_.defaultsDeep`, + custom deep-merge / deep-clone, `Object.assign` with recursive + descent over user JSON. + + + INSTRUCTIONS: + - Only include features that are ACTUALLY PRESENT in the code. + - Look for concrete function calls, not just imports. + - Be conservative. + + + Return ONLY a valid JSON object: + + + JSON Schema: + --- + {schema} + --- + + + IMPORTANT: + - Return ONLY the raw JSON object + - Do NOT wrap in markdown code blocks + - Do NOT include explanatory text + - The entire response must be valid, parseable JSON diff --git a/prompts/exploit-proof-web-js/security_analysis_base.yaml b/prompts/exploit-proof-web-js/security_analysis_base.yaml new file mode 100644 index 0000000..ba46006 --- /dev/null +++ b/prompts/exploit-proof-web-js/security_analysis_base.yaml @@ -0,0 +1,52 @@ +# exploit-proof: JavaScript / TypeScript web applications (Node.js backends +# and browser / React / Vue / Angular frontends). + +system_message: > + You are an elite offensive security researcher reviewing a JavaScript / + TypeScript web application. The code ships Express, Koa, Fastify, NestJS, + Next.js, React, Vue, Angular, or similar. Your job is to find REAL, + EXPLOITABLE vulnerabilities — not theoretical risks or best-practice gaps. 
+ Trace request data (query, body, params, headers, cookies, files, WS + messages) to sinks: `eval`, `Function(...)`, `child_process.exec*`, + dynamic `require`, raw SQL / NoSQL queries, template rendering with + `|safe`-style escapes, DOM sinks (`innerHTML`, `dangerouslySetInnerHTML`, + `v-html`, bypass of Angular's DomSanitizer). Check every route for + missing authn / authz. Return findings in strict JSON format. + +analysis_intro: > + Find all exploitable security vulnerabilities in the JS/TS code below. + +analysis_requirements_header: "" + +repo_info: > + Repository: {repo_name} + + {xml_content} + +critical_instructions: > + For each vulnerability: + - file_path: relative path (e.g. src/routes/user.ts) + - start_line / end_line: the EXACT LINE performing the vulnerable + operation (the unescaped innerHTML, the eval, the prototype merge), + not the route definition above it. + - severity: 0.0-10.0 by exploitability and impact. + - cwe_id: most specific CWE as 'CWE-NNN: Name'. Prefer specific CWEs + (CWE-79 incl. DOM XSS, CWE-89, CWE-1321 prototype pollution, CWE-78, + CWE-502, CWE-918 SSRF, CWE-862, CWE-1336 SSTI, CWE-943 NoSQL + injection) over plain CWE-20. + - technical_details: a concrete exploit proof — the exact HTTP request + (method, path, headers, body) or crafted input, step-by-step trace + through the code, and the primitive obtained. If you cannot write a + specific curl invocation or XHR request, drop the finding. + + QUALITY GATE: Could a pentester reproduce this in five minutes with + curl / browser devtools using only the technical_details? If no, drop it. + + Report each vulnerability ONCE at its most specific location. + +json_formatting_rules: > + Return ONLY a valid JSON object. No markdown fences, no commentary. + Escape backslashes in strings (\ becomes \\).
+ + JSON Schema: + {schema} diff --git a/prompts/exploit-proof-web-python/analysis_sections.yaml b/prompts/exploit-proof-web-python/analysis_sections.yaml new file mode 100644 index 0000000..f21b67d --- /dev/null +++ b/prompts/exploit-proof-web-python/analysis_sections.yaml @@ -0,0 +1,84 @@ +sections: + offensive_audit: + title: "PYTHON WEB ATTACK METHODOLOGY" + features: [] + content: > + For every route / handler / task in this chunk: + + 1. MAP ENTRY POINTS. Identify every request source: path params + (Django `path(''`, Flask ``, FastAPI `Path(...)`), + query (`request.GET`, `request.args`, `Query(...)`), body (JSON, + form, multipart), headers, cookies, uploaded files, websockets, + background job args, Celery task args. + + 2. TRACE TO SINKS. + - SQL: `.raw()`, `.extra()`, `.execute()`, `cursor.execute(f"...")`, + SQLAlchemy `text(f"...")`, `connection.execute(sql_string)`, + string formatting / f-strings / `%` into any query. + - Template injection: `render_template_string(user_input)`, + Jinja `Template(user).render(...)`, Django `Template(user)`. + - Code execution: `eval`, `exec`, `compile`, `__import__(user)`, + `importlib.import_module(user)`. + - Deserialization: `pickle.loads`, `pickle.load`, `yaml.load` + (without SafeLoader), `marshal.loads`, `shelve.open` on + untrusted data. + - Command: `subprocess.*` with `shell=True`, `os.system`, + `os.popen`, `commands.*`, `eval`ing a string constructed + from input. + - SSRF: `requests.get(url)`, `urllib.request.urlopen(url)`, + `httpx.get(url)`, `aiohttp.ClientSession().get(url)`, + internal host / metadata-endpoint reachability. + - XSS: `|safe`, `{% autoescape off %}`, `mark_safe`, `Markup`, + `HttpResponse(user)`, `{{ x|safe }}`, `innerHTML` in inline + script templates. + - Path traversal: `open(user)`, `send_file(user)`, + `FileResponse(user)`, `os.path.join(base, user)` without + prefix check after `realpath`. + + 3. AUTHN / AUTHZ ON EVERY VIEW. 
For Django: is `@login_required` + applied, or `LoginRequiredMixin`, or DRF `IsAuthenticated`? For + Flask: is `@login_required` from flask-login on every sensitive + view? For FastAPI: is the auth dependency present, and does it + actually validate? For every view that accepts a resource ID, + check ownership — `Order.objects.get(pk=pk)` without + `user=request.user` is IDOR. + + 4. LOGIC FLAWS. + - Mass assignment: `User.objects.create(**request.POST.dict())`, + `form.save()` without specifying fields, Pydantic models + without `extra='forbid'`. + - Algorithm confusion: `jwt.decode(token, key, algorithms=None)` + or `verify=False`, HS256/RS256 mixing. + - Race conditions on state: `if balance >= amount: balance -= + amount` without select_for_update / transaction. + - Type juggling: comparing `request.POST['role'] == 'admin'` + where the attacker sends a list. + + 5. CLIENT-SIDE (if templates render JS). Look for inline + `` — user data in a + JS context needs `|json_script` / explicit JSON encoding, not + the default HTML escape. `|safe` on anything touching user data + is a red flag. + + 6. CSRF. Django CSRF middleware off / `@csrf_exempt` on state- + changing views. Flask-WTF CSRF disabled. FastAPI / DRF view + accepting cookies without CSRF token validation. GET endpoints + performing state changes. + + 7. CRYPTO / SECRETS. `SECRET_KEY` hardcoded or committed. `random` + module used for tokens (use `secrets` instead). `hashlib.md5` / + `.sha1` for passwords (use `argon2` / `bcrypt` via passlib). + `check_password` replaced with string equality. + + 8. FRAMEWORK-SPECIFIC FOOTGUNS. + - Django: `RawSQL`, `cursor.execute(f"...")`, template tags + using `format_html` without escaping, `allow_unsafe_eval`, + `DEBUG=True` in prod (infoleak), middleware order (CSRF + before auth). + - Flask: `send_from_directory` with user-supplied filename + without `secure_filename` (path traversal), Jinja + autoescape on `.html` only by default. 
+ - FastAPI / Pydantic: `model_dump(exclude_unset=False)` mass- + assigning admin fields; `response_model` missing. + - Celery / RQ: tasks accepting pickled args from untrusted + brokers. diff --git a/prompts/exploit-proof-web-python/audit.yaml b/prompts/exploit-proof-web-python/audit.yaml new file mode 100644 index 0000000..a23cd03 --- /dev/null +++ b/prompts/exploit-proof-web-python/audit.yaml @@ -0,0 +1,85 @@ +system_message: > + You are verifying inbound Python-web vulnerability reports. For each finding, + determine whether it is actually exploitable by examining the source code. + Be skeptical — reject anything you cannot construct a concrete HTTP request + for. Assign confidence 0.0-1.0. + {production_only_gate} + +user_prompt_template: > + Repository: {repo_name} + + === CLAIMS TO VERIFY (unverified leads from the analysis phase) === + {findings_json} + + INDEPENDENT VERIFICATION (READ BEFORE FORMING ANY VERDICT): + + The JSON above is a `{"claims_to_verify": [...]}` wrapper. Each claim + contains an `unverified_exploit_sketch` — a HYPOTHESIS from the analysis + phase, NOT a verified fact. The analysis phase is tuned for recall; its + narrative often hedges, self-rejects, or cites framework defaults that + may not actually apply at the claimed line. + + Verify every claim against the SOURCE CODE below. + + - Do NOT accept phrases like "Django ORM parameterizes", "CSRF + middleware covers this", "the DRF serializer filters this", "only + an admin can reach this" without checking the source lines yourself + (find the actual middleware, the actual serializer fields, the + actual permission class). + + - A REJECTED verdict MUST quote the specific source line(s) that block + the exploit. Paraphrasing the unverified_exploit_sketch is NOT a + valid justification. "The finding itself acknowledges it is not + exploitable" is NEVER a valid reason to reject — quote the code. 
+ + - When unverified_exploit_sketch names an HTTP request, walk that + request through the code yourself (URL router → middleware → view → + query). If blocked, name the line. If not blocked, the finding + stands. + + === CWE GUIDANCE === + {cwe_analysis_prompts} + + === SOURCE CODE === + {code_context} + + {supplementary_context} + + For each finding, return a verdict: + - CONFIRMED: real, exploitable. The technical_details include the exact + HTTP request (method, path, headers, body) and the resulting primitive. + - REFINED: real but severity / CWE / details need adjustment. + - REJECTED: false positive — one-sentence reason (e.g. "the ORM method + used is .filter() not .raw(), which auto-parameterizes"). + - ESCALATED: worse than initially reported. + + Also report NEW vulnerabilities found while reviewing. Assign confidence + (0.0-1.0) to every finding. + + {production_only_gate} + + AUDIT SUMMARY DISCIPLINE: in `audit_summary`, describe each REJECTED / + REFINED finding in terms of the code guard that blocks or downgrades it + (e.g. "bounded by check at line N", "reachability gated by permission + check at line M"). Never use phrases like "self-retracted", "submitter + acknowledged", "by design", or "informational" without naming the + specific source line that supports the characterization. + + Return results as JSON per the schema below. + + {schema} + +json_formatting_rules: > + Return ONLY valid JSON. No markdown fences, no commentary. + Escape backslashes (\ becomes \\). + + JSON Schema: + {schema} + +production_only_gate: > + IMPORTANT: Findings must be reachable from production code. REJECT if: + - only present in tests / fixtures / management commands not exposed to + untrusted input; + - gated by `if DEBUG:` / `if settings.DEBUG:` / `if app.debug:` which is + disabled in production; + - in unused views that are not routed in urls.py / blueprints. 
diff --git a/prompts/exploit-proof-web-python/cwe_deep_analysis.yaml b/prompts/exploit-proof-web-python/cwe_deep_analysis.yaml new file mode 100644 index 0000000..c096402 --- /dev/null +++ b/prompts/exploit-proof-web-python/cwe_deep_analysis.yaml @@ -0,0 +1,263 @@ +# CWE-Specific Deep Analysis — Python web applications. + +cwe_prompts: + CWE-89: + title: "SQL Injection" + analysis_prompt: > + Check every location where SQL is constructed: `cursor.execute(f"...")`, + `.execute("...", format_args)`, Django `.raw()` / `.extra()`, + SQLAlchemy `text(f"...")`, Peewee `RawQuery`, `connection.execute` + with an f-string. Every `.raw()` or `text()` with string interpolation + is a bug unless every interpolated value is a compile-time constant. + Check for interpolation into ORDER BY / column-name / table-name + positions — parameterization can't help there; an allowlist must. + Inspect for second-order SQLi: values stored via ORM and later + interpolated into raw queries. Django `.extra()` is legacy and + dangerous; flag any use. + validation_checks: + - "Verify every query with user input uses parameter binding (%s / ? / :name), not f-strings or format()." + - "Confirm ORDER BY / column / table identifiers come from an allowlist, not user input." + - "Inspect .raw() and text() for string interpolation of any value derived from request data." + - "Check ORM-returned values that are later used in raw queries." + false_positive_indicators: + - "All interpolated values are compile-time constants." + - "The query uses parameter binding exclusively; the flagged % / format is on a non-value clause." + - "The value is an int from path('') / validated Pydantic int and cannot contain quotes." + + CWE-79: + title: "Cross-Site Scripting (XSS)" + analysis_prompt: > + Django: `|safe`, `mark_safe(user)`, `format_html` with user input in a + position not auto-escaped, `{% autoescape off %}` blocks. 
Flask / + Jinja: `{{ x|safe }}`, `Markup(user)`, rendering `.txt` / `.json` + templates (Jinja only autoescapes HTML-like extensions). FastAPI: + `HTMLResponse(user)` without escape. Inline script contexts: + `<script>var data = "{{ user_value }}";</script>` — HTML-escape doesn't make it + JS-safe; the correct pattern is `|json_script` (Django) or explicit + `json.dumps` then JS-escape. DOM XSS in served JS: check inline + scripts reading `location.hash` / `location.search` / `postMessage` + into `innerHTML`. + validation_checks: + - "Verify `|safe`, `mark_safe`, `Markup` are not applied to any user-derived value." + - "Confirm JSON data embedded in inline scripts uses |json_script or json.dumps followed by JS-escaping." + - "Check that non-HTML template extensions explicitly enable autoescape." + - "Inspect DOM sinks in inline JS for user-controlled sources." + false_positive_indicators: + - "Value is a sanitizer output (bleach.clean with restrictive tags/attrs)." + - "Value is a server-generated int / bool / UUID with no user string component." + + CWE-1336: + title: "Server-Side Template Injection (SSTI)" + analysis_prompt: > + `render_template_string(user_template)`, `Template(user).render(ctx)`, + `Environment().from_string(user)`, Django `Template(user).render`. Any + path where the TEMPLATE itself (not the context) is user-controlled + grants arbitrary Python execution via `{{ ''.__class__.__mro__[...] }}` + or similar sandbox escapes. Modern Jinja sandboxes are defeatable. + validation_checks: + - "Verify template strings are always compile-time constants; user input goes into the context dict, not the template source." + - "Confirm no code path passes user data to .from_string() / Template(...) / render_template_string()." + false_positive_indicators: + - "The template string is a constant; user data is passed as a named context variable." + + CWE-502: + title: "Deserialization of Untrusted Data" + analysis_prompt: > + `pickle.loads` / `pickle.load` on any untrusted source is RCE.
`yaml.load` + without `Loader=SafeLoader` is RCE (use `yaml.safe_load`). `marshal.loads`, + `shelve.open`, `jsonpickle.decode`. Celery tasks configured with + `task_serializer='pickle'` or `accept_content=['pickle']` trust the + broker fully — flag if the broker is not pre-auth-only. Custom + `__reduce__` methods on classes loadable from user data. + validation_checks: + - "Verify pickle / marshal / shelve are never used with data from HTTP requests, Celery brokers, Redis, or files authored by users." + - "Confirm yaml.load uses SafeLoader or is replaced by yaml.safe_load." + - "Inspect Celery config: JSON serializer preferred; pickle only with a fully trusted broker." + false_positive_indicators: + - "The pickled data is signed with itsdangerous / HMAC using a server-only key and verified before load." + - "The data source is a trusted internal service on a mutual-TLS channel." + + CWE-78: + title: "OS Command Injection" + analysis_prompt: > + `os.system(user)`, `os.popen(user)`, `subprocess.run(user, shell=True)`, + `subprocess.Popen(user, shell=True)`, `commands.getoutput` (legacy). Any + `shell=True` with user input is injection unless the input is escaped + with `shlex.quote` AND that's intentional. `shell=False` with + `subprocess.run([cmd, user_arg])` is safer but still vulnerable to + argument injection (user_arg starting with `-`). `eval(user)` and + `exec(user)` are also in this neighborhood. + validation_checks: + - "Verify subprocess calls use shell=False and a list of args, not a single string." + - "Confirm user values used as args cannot introduce new flags (prefixed with `--` separator or validated)." + - "Flag every `shell=True` / os.system / os.popen with non-literal arg." + false_positive_indicators: + - "Args are validated against a strict allowlist (e.g. only 'pdf' or 'png')." + - "User input is passed through shlex.quote and shell=True is documented as intentional." 
+ + CWE-22: + title: "Path Traversal" + analysis_prompt: > + `open(user)` without base-prefix validation. `send_file(user)`, + `FileResponse(user)`. `os.path.join(base, user)` — note that join + does NOT prevent `../` and treats an absolute second arg as replacing + the base. `send_from_directory` in Flask is safe by default (uses + safe_join) but custom helpers may not be. Archive extraction (zipfile, + tarfile, shutil.unpack_archive) without checking entry paths is + zip-slip / tar-slip. + validation_checks: + - "Verify paths resolve under the intended base via realpath + startswith check." + - "Confirm archive extraction validates entry names (no `..` or absolute paths) before extraction." + - "Inspect uses of `os.path.join` with user input — ensure the result is re-validated." + false_positive_indicators: + - "werkzeug.utils.secure_filename is applied and the result is appended to a fixed base." + - "Filename is a UUID or database-row ID, not user-supplied." + + CWE-918: + title: "Server-Side Request Forgery (SSRF)" + analysis_prompt: > + Any `requests.get(user_url)`, `httpx.get(user_url)`, `urllib.request.urlopen`, + `aiohttp.ClientSession().get(user_url)`, or `urllib3.PoolManager` call + taking a URL from user input. Attackers can target internal resources: + 169.254.169.254 (cloud metadata), localhost services, internal + corporate networks. Flag naive protocol checks (startswith('http')) + that don't prevent redirect to `file://` / `gopher://`. Also flag + custom image proxies, URL previewers, webhook deliverers. + validation_checks: + - "Verify URL validation: scheme allowlist (http / https only) AND host allowlist OR blocklist of private ranges (127/8, 10/8, 172.16/12, 192.168/16, 169.254/16, fc00::/7, ::1)." + - "Confirm DNS re-resolution between check and use is prevented (resolve once, pass IP)." + - "Check redirects are bounded and re-validated after each hop." + false_positive_indicators: + - "The URL is constructed from a trusted internal config value."
+ - "The request goes to a single fixed external service with no user-controlled host." + + CWE-287: + title: "Improper Authentication" + analysis_prompt: > + JWT: `jwt.decode(token, key, algorithms=['HS256'])` with a weak secret, + `algorithms=None` or missing (alg:none attack), `verify=False`. Password + comparison with `==` (timing attack — use `hmac.compare_digest` or + `check_password`). MD5/SHA1 password hashing. Login endpoints without + rate limiting. Session cookies without Secure / HttpOnly / SameSite. + Custom auth middleware that skips validation on certain paths. + validation_checks: + - "Verify JWT decode pins algorithms to an expected list and rejects alg:none." + - "Confirm password check uses a constant-time comparison and a modern hash (argon2, bcrypt)." + - "Check session cookies set Secure + HttpOnly + SameSite." + - "Inspect login endpoints for rate limiting / lockout." + false_positive_indicators: + - "JWT verification is handled by a library (pyjwt with explicit algorithms, django-rest-framework-simplejwt default settings)." + - "Password comparison uses check_password / argon2.verify." + + CWE-862: + title: "Missing Authorization" + analysis_prompt: > + Views returning resources without ownership checks. `Order.objects.get( + pk=pk)` without filtering by `user=request.user`. DRF ViewSets with + no `get_queryset` restriction. FastAPI endpoints with auth dependency + but no per-object check. GraphQL resolvers that load objects by ID + globally. Background tasks with user IDs in args that don't re-verify. + validation_checks: + - "Verify every resource lookup filters by the authenticated user / tenant." + - "Confirm DRF ViewSet get_queryset restricts to request.user." + - "Inspect GraphQL node resolvers for per-object authorization." + false_positive_indicators: + - "Resources are intentionally public (catalog items, blog posts)." + - "Row-level security at the DB layer enforces tenant isolation." 
+ + CWE-639: + title: "Authorization Bypass Through User-Controlled Key (IDOR)" + analysis_prompt: > + Endpoints accepting an ID from the client (path, query, body) and + returning / mutating the resource without verifying ownership. Classic + signs: `get_object_or_404(Order, pk=pk)` without user filter; + `Model.objects.get(id=user_input)`. Batch endpoints where auth is + checked on the wrapper but not per-item. + validation_checks: + - "Confirm every resource access validates ownership or membership." + - "Verify batch endpoints check authorization per-item." + false_positive_indicators: + - "Row-level security / policies enforce per-tenant access at the DB layer." + - "Resource is intentionally public." + + CWE-352: + title: "Cross-Site Request Forgery (CSRF)" + analysis_prompt: > + Django: `@csrf_exempt` on state-changing views, CSRF middleware removed + from settings, DRF using SessionAuthentication without requiring CSRF. + Flask: Flask-WTF CSRFProtect not installed, forms without + `{{ csrf_token() }}` / `{{ form.csrf_token }}`. FastAPI: cookie-based auth without CSRF token validation. GET + endpoints performing mutations. + validation_checks: + - "Verify CSRF middleware is global and no @csrf_exempt is applied to mutating views." + - "Confirm cookie-authenticated APIs validate a CSRF token on mutating methods." + - "Check that no GET endpoint performs state changes." + false_positive_indicators: + - "Endpoint uses Bearer token authentication only (no cookies)." + - "SameSite=Strict cookies + no GET-based mutations." + + CWE-915: + title: "Mass Assignment / Improperly Controlled Modification of Attributes" + analysis_prompt: > + `User.objects.create(**request.POST.dict())`, `form.save()` on a + ModelForm with no `fields` specified, Pydantic models without + `extra='forbid'` used to bulk-update a DB row. DRF serializers with + `fields = '__all__'` exposing admin fields (is_staff, is_superuser, + balance). Calling `model.__dict__.update(data)`. 
+ validation_checks: + - "Verify serializers / forms specify an explicit field allowlist." + - "Confirm Pydantic models use extra='forbid' or equivalent; sensitive fields (is_admin, role) are not in the public schema." + - "Inspect `update(**dict)` calls for attacker-controlled keys." + false_positive_indicators: + - "Serializer explicitly excludes sensitive fields." + - "Input is validated against a Pydantic model that cannot include privileged fields." + + CWE-327: + title: "Use of Broken or Risky Cryptographic Algorithm" + analysis_prompt: > + `hashlib.md5` / `.sha1` for passwords or security tokens. `random` + module for security-sensitive values (use `secrets`). PyCryptodome / + cryptography using DES, 3DES, RC4, AES-ECB, or static IVs. RSA + PKCS1v15 signing without hash binding. + validation_checks: + - "Verify password hashing uses argon2 / bcrypt / scrypt — not md5 / sha1 / sha256-alone." + - "Confirm random tokens use `secrets`, not `random`." + - "Inspect cipher construction: AES-GCM / ChaCha20-Poly1305, not ECB / static IV." + false_positive_indicators: + - "MD5/SHA1 used for non-security purposes (cache keys, content-addressed dedup)." + + CWE-798: + title: "Use of Hard-coded Credentials" + analysis_prompt: > + SECRET_KEY committed in settings.py, API keys / tokens in source files, + database URLs with embedded passwords, test credentials pointing at + real services. + validation_checks: + - "Verify SECRET_KEY, database URLs, and API keys come from environment or a secret manager." + - "Confirm no private keys / TLS client certs are committed." + false_positive_indicators: + - "Value is a well-known placeholder ('changeme', 'your-secret-here')." + - "Value is a public identifier (publishable Stripe key)." + + CWE-434: + title: "Unrestricted Upload of File with Dangerous Type" + analysis_prompt: > + Upload handlers trusting Content-Type / extension without magic-byte + validation. Saving files into a web-served directory where they + execute. 
Preserving user-supplied filenames (path traversal, + encoded execution). Archive extraction without zip-slip protection. + validation_checks: + - "Verify uploads validate magic bytes (python-magic) against an allowlist." + - "Confirm uploaded files are stored outside the web root or in a directory configured to not execute." + - "Check filenames are regenerated (UUID) and not derived from user input." + false_positive_indicators: + - "File type is allowlisted AND the upload path is not executed by the web server." + + CWE-20: + title: "Improper Input Validation" + analysis_prompt: > + Catch-all. Prefer specific CWEs above when they fit. + validation_checks: + - "Verify user values are validated against a type / range / allowlist before use." + false_positive_indicators: + - "Value is validated by a Pydantic model, Django form, or DRF serializer." diff --git a/prompts/exploit-proof-web-python/feature_detection.yaml b/prompts/exploit-proof-web-python/feature_detection.yaml new file mode 100644 index 0000000..39105b8 --- /dev/null +++ b/prompts/exploit-proof-web-python/feature_detection.yaml @@ -0,0 +1,82 @@ +system_message: > + You are a Python web application analysis expert. Quickly scan the code and + identify which security-relevant features and technologies are present. + Return findings in strict JSON format. + +user_prompt_template: > + Analyze the following Python web application to identify security-relevant + features. + + + Repository: {repo_name} + + + Repomix XML Content: + --- + {xml_content} + --- + + + FEATURE CATEGORIES: + 1. **django_views** - Django view functions / classes, URL routes, DRF + viewsets / serializers. + 2. **flask_routes** - Flask `@app.route` / Blueprint handlers. + 3. **fastapi_routes** - FastAPI / Starlette path operations, Pydantic + request models. + 4. **websocket_handlers** - Django Channels, FastAPI / Starlette WebSocket + endpoints, aiohttp WebSocket. + 5. 
**authentication** - Login flow, password hashing (passlib, bcrypt, + argon2), `@login_required`, custom auth backends, JWT encoding / + decoding. + 6. **authorization** - Permission classes, `user_passes_test`, role + checks, Django Guardian, DRF `IsAdminUser`, row-level ownership + checks. + 7. **orm_queries** - Django ORM, SQLAlchemy, Peewee, Tortoise ORM, + `.raw()` / `.execute()` / `text()` usage. + 8. **raw_sql** - `cursor.execute`, `connection.execute` with formatted + SQL strings. + 9. **template_rendering** - Jinja2, Django templates, Mako, `render_template`, + `render_template_string`, `Template(...).render`. + 10. **serialization** - `pickle.loads`, `yaml.load` (unsafe), `marshal`, + `shelve`, `json` (safe — only flag if custom hooks invoked). + 11. **command_execution** - `subprocess.*`, `os.system`, `os.popen`, + `commands.*`, `eval`, `exec`, `compile`. + 12. **file_operations** - `open(user)`, `send_file`, `FileResponse`, + path construction with user input, upload handling. + 13. **external_http** - `requests`, `httpx`, `urllib`, `aiohttp.ClientSession` + making outbound calls (SSRF surface). + 14. **cryptography** - `cryptography` library, `hashlib`, `hmac`, + `secrets`, `random` (if used for security tokens). + 15. **session_management** - Django sessions, Flask-Session, itsdangerous, + custom cookie signing. + 16. **csrf_protection** - Django CSRF middleware presence, Flask-WTF + CSRFProtect, `@csrf_exempt` usage, DRF SessionAuthentication. + 17. **cors** - `django-cors-headers`, `flask-cors`, FastAPI CORSMiddleware, + custom Access-Control-* header setting. + 18. **background_tasks** - Celery tasks, RQ jobs, Dramatiq, APScheduler — + especially if args come from user input. + 19. **graphql** - Graphene, Strawberry, Ariadne schemas and resolvers. + 20. **third_party_integrations** - OAuth / SSO libraries, payment SDKs, + email / SMS sending, cloud storage uploads. + + + INSTRUCTIONS: + - Only include features ACTUALLY PRESENT in the code. 
+ - Count a library as present only when it is actually called in the code, not merely imported. + - Be conservative. + + + Return ONLY a valid JSON object: + + + JSON Schema: + --- + {schema} + --- + + + IMPORTANT: + - Return ONLY the raw JSON object + - Do NOT wrap in markdown code blocks + - Do NOT include explanatory text + - The entire response must be valid, parseable JSON diff --git a/prompts/exploit-proof-web-python/security_analysis_base.yaml b/prompts/exploit-proof-web-python/security_analysis_base.yaml new file mode 100644 index 0000000..b15e778 --- /dev/null +++ b/prompts/exploit-proof-web-python/security_analysis_base.yaml @@ -0,0 +1,53 @@ +# exploit-proof: Python web applications (Django / Flask / FastAPI / Starlette) + +system_message: > + You are an elite offensive security researcher reviewing a Python web + application. The code uses Django, Flask, FastAPI, Starlette, Tornado, + aiohttp, or a similar framework. Your job is to find REAL, EXPLOITABLE + vulnerabilities — not theoretical risks or best-practice gaps. Think like + an attacker: trace request data (query, body, path, headers, cookies, + files) into every sink (ORM query, template render, subprocess, file + open, pickle.loads, eval, external HTTP). Check every view for missing + authentication / authorization. Chain low-severity issues into high-impact + exploits. Return findings in strict JSON format. + +analysis_intro: > + Find all exploitable security vulnerabilities in the Python code below. + +analysis_requirements_header: "" + +repo_info: > + Repository: {repo_name} + + {xml_content} + +critical_instructions: > + For each vulnerability: + - file_path: relative path (e.g. app/views/auth.py) + - start_line / end_line: integer line numbers pointing at the EXACT LINE + performing the vulnerable operation (the raw SQL string, the + pickle.loads call, the render_template_string with user data), not + the decorator / route registration above. + - severity: 0.0-10.0 by real-world exploitability and impact. 
+ - cwe_id: most specific CWE as 'CWE-NNN: Name'. Prefer specific CWEs + (CWE-89 SQLi, CWE-79 XSS, CWE-1336 SSTI, CWE-502 deserialization, + CWE-78 command injection, CWE-22 path traversal, CWE-918 SSRF, + CWE-862 missing authz, CWE-639 IDOR) over CWE-20. + - technical_details: you MUST include a concrete exploit proof — the + exact HTTP request (method, path, headers, body) or input an attacker + would send, what the app does with it step by step, and what the + attacker gains. If you cannot write a specific request that triggers + the vulnerability, do not report it. + + QUALITY GATE: Could a pentester reproduce this with curl and five minutes + reading the finding? If no, drop it. + + Report each vulnerability ONCE at its most specific location (the sink, + not every middleware / decorator above it). + +json_formatting_rules: > + Return ONLY a valid JSON object. No markdown fences, no commentary. + Escape backslashes in strings (\ becomes \\). + + JSON Schema: + {schema} diff --git a/prompts/exploit-proof/analysis_sections.yaml b/prompts/exploit-proof/analysis_sections.yaml new file mode 100644 index 0000000..895259f --- /dev/null +++ b/prompts/exploit-proof/analysis_sections.yaml @@ -0,0 +1,15 @@ +# Slim Carlini-style analysis sections +# Single always-on section. No feature gating — the model decides what matters. + +sections: + offensive_audit: + title: "ATTACK METHODOLOGY" + features: [] # Always included + content: > + For every file in this chunk: + 1. Map entry points: HTTP params, headers, cookies, file content, URL paths, WebSocket messages, GraphQL variables. + 2. Trace each input to sinks: SQL queries, shell commands, file paths, HTML output, crypto operations, auth decisions. Flag unvalidated flows. + 3. Check authorization on EVERY endpoint that takes a resource ID: does it verify the authenticated user owns that resource? Missing ownership checks = IDOR. + 4. 
Look for logic flaws: race conditions, mass assignment, type confusion, algorithm confusion (JWT alg=none), prototype pollution. + 5. Check client-side JS for DOM XSS: innerHTML/outerHTML/document.write with user-controlled data. + 6. Check state-changing endpoints (POST/PUT/DELETE) for CSRF protection. diff --git a/prompts/exploit-proof/audit.yaml b/prompts/exploit-proof/audit.yaml new file mode 100644 index 0000000..42da517 --- /dev/null +++ b/prompts/exploit-proof/audit.yaml @@ -0,0 +1,80 @@ +# Slim Carlini-style audit prompt +# Verification pass: "Is this actually exploitable?" + +system_message: > + You are verifying inbound vulnerability reports. For each finding, determine whether + it is actually exploitable by examining the source code. Be skeptical — reject anything + you cannot construct a concrete exploit for. Assign confidence 0.0-1.0. + {production_only_gate} + +user_prompt_template: > + Repository: {repo_name} + + === CLAIMS TO VERIFY (unverified leads from the analysis phase) === + {findings_json} + + INDEPENDENT VERIFICATION (READ BEFORE FORMING ANY VERDICT): + + The JSON above is a `{"claims_to_verify": [...]}` wrapper. Each claim + contains an `unverified_exploit_sketch` — a HYPOTHESIS from the analysis + phase, NOT a verified fact. The analysis phase is tuned for recall; its + narrative often hedges, self-rejects, or cites defensive guards that + may not actually exist in the code at the claimed line. + + Verify every claim against the SOURCE CODE below. + + - Do NOT accept phrases like "requires X privilege", "defended by Y", + "the caller validates at line Z", "only reachable if W", or any + statement that the issue "is not exploitable" without checking the + source lines yourself. + + - A REJECTED verdict MUST quote the specific source line(s) that + block the exploit. Paraphrasing the unverified_exploit_sketch is + NOT a valid justification. 
"The finding itself acknowledges it is + not exploitable" is NEVER a valid reason to reject — quote the code + that makes the finding false. + + - When unverified_exploit_sketch names an attack flow, walk that + flow in the source yourself. If blocked, name the line. If not + blocked, the finding stands. + + === CWE GUIDANCE === + {cwe_analysis_prompts} + + === SOURCE CODE === + {code_context} + + {supplementary_context} + + For each finding, return a verdict: + - CONFIRMED: real, exploitable vulnerability + - REFINED: real but severity/CWE/details need adjustment + - REJECTED: false positive — explain why in one sentence + - ESCALATED: worse than initially reported + + Also report any NEW vulnerabilities you discover while reviewing the code. + Assign confidence (0.0-1.0) to every finding. + + {production_only_gate} + + AUDIT SUMMARY DISCIPLINE: in `audit_summary`, describe each REJECTED / + REFINED finding in terms of the code guard that blocks or downgrades it + (e.g. "bounded by check at line N", "reachability gated by permission + check at line M"). Never use phrases like "self-retracted", "submitter + acknowledged", "by design", or "informational" without naming the + specific source line that supports the characterization. + + Return results as JSON per the schema below. + + {schema} + +json_formatting_rules: > + Return ONLY valid JSON. No markdown fences, no commentary. + Escape backslashes (\ becomes \\). + + JSON Schema: + {schema} + +production_only_gate: > + IMPORTANT: Findings must be reachable from production code paths. If a vulnerability + only exists in test code and cannot affect production, REJECT it. diff --git a/prompts/exploit-proof/cwe_deep_analysis.yaml b/prompts/exploit-proof/cwe_deep_analysis.yaml new file mode 100644 index 0000000..ac4aa70 --- /dev/null +++ b/prompts/exploit-proof/cwe_deep_analysis.yaml @@ -0,0 +1,576 @@ +# CWE-Specific Deep Analysis Prompts +# Each entry provides focused analysis guidance for a specific CWE category. 
+# Used during the audit phase to scrutinize initial findings. + +cwe_prompts: + CWE-89: + title: "SQL Injection" + analysis_prompt: > + Examine every location where SQL statements are constructed or executed. Look beyond obvious string + concatenation (e.g., "SELECT * FROM users WHERE id=" + userId) and inspect ORM-level raw query methods + such as Django's .raw(), .extra(), SQLAlchemy's text(), or ActiveRecord's .find_by_sql(). Check whether + parameterized queries or prepared statements are used consistently—a single missed binding is sufficient + for exploitation. Investigate stored procedures for internal dynamic SQL (EXECUTE IMMEDIATE, sp_executesql + with concatenation). Identify second-order injection where user input is stored safely but later + retrieved and interpolated into a new query without re-sanitization. Examine batch and bulk insert + operations where developers sometimes fall back to string formatting for performance. Check for blind + injection vectors where boolean conditions or time-based delays could leak data even when query results + are not directly returned. Inspect any query-builder patterns where column names, table names, or ORDER BY + clauses are derived from user input, as these cannot be parameterized and require strict allowlisting. + validation_checks: + - "Confirm that all user-supplied values pass through parameterized query bindings and are never interpolated via string concatenation, f-strings, or format()." + - "Verify that ORM raw query interfaces (.raw(), text(), $queryRaw) use placeholder syntax (?, :name, $1) rather than inline values." + - "Check that dynamic identifiers (column names, table names, sort directions) are validated against a strict allowlist, not merely escaped." + - "Inspect stored procedures for internal use of EXECUTE IMMEDIATE or EXEC with concatenated parameters." 
+ - "Trace data flow for second-order injection: values stored in the database that are later used in query construction without re-parameterization." + - "Examine LIKE clauses for proper escaping of wildcard characters (%, _) in addition to parameterization." + false_positive_indicators: + - "Query uses parameterized bindings exclusively and the flagged concatenation only constructs non-value clauses (e.g., static table aliases)." + - "The interpolated value is an integer derived from an internal enum or constant, never from user input." + - "An ORM query builder is used end-to-end with no raw SQL fragments, and the framework auto-parameterizes all values." + + CWE-78: + title: "OS Command Injection" + analysis_prompt: > + Scrutinize every invocation of system-level command execution functions including subprocess.Popen, + subprocess.run, os.system, os.popen, child_process.exec, Runtime.exec, and backtick operators. + Determine whether shell=True (or its equivalent) is set, which enables shell metacharacter interpretation + and dramatically increases injection surface. When shell=False, verify that arguments are passed as a list + rather than a single string, and confirm no element of that list is derived from unsanitized user input. + Look for argument injection where a user-controlled value is passed as a flag (e.g., --output=/etc/passwd) + even when the command itself is static. Check for template strings or f-strings used to build command + strings. Examine whether shlex.quote(), escapeshellarg(), or equivalent escaping is applied, and verify + it matches the target shell (bash vs cmd.exe vs PowerShell). Investigate indirect command injection via + environment variables (PATH manipulation, LD_PRELOAD) or config files that influence command behavior. + Look for commands that accept filenames which could contain shell metacharacters (semicolons, pipes, + backticks) unless the filename is properly quoted. 
+ validation_checks: + - "Verify that subprocess calls use shell=False and pass arguments as a list, not a single string." + - "Confirm that user-supplied values are never interpolated into command strings, even when shell=False." + - "Check for argument injection by verifying user input cannot introduce new flags (e.g., values starting with -- are rejected or preceded by --)." + - "Inspect all uses of os.system(), os.popen(), and backtick operators—these always invoke a shell and are inherently dangerous with user input." + - "Verify that any escaping function used (shlex.quote, escapeshellarg) is appropriate for the target platform and shell." + - "Check whether environment variables (PATH, LD_PRELOAD, IFS) can be influenced by user input before command execution." + false_positive_indicators: + - "The command string is entirely composed of hardcoded literals with no user-controlled components." + - "Arguments are passed as a list with shell=False, and all list elements are validated against a strict allowlist of known-safe values." + - "The subprocess call is wrapped in a sandboxed environment (container, seccomp) that restricts executable paths." + + CWE-79: + title: "Cross-Site Scripting (XSS)" + analysis_prompt: > + Analyze every location where user-supplied data is rendered into HTML responses, JavaScript contexts, + or DOM manipulation operations. Determine whether the template engine in use (Jinja2, Handlebars, ERB, + Razor, Thymeleaf) applies auto-escaping by default and whether it has been explicitly disabled via + filters like |safe, {{{triple-braces}}}, raw, or <%== %>. Check for reflected XSS where query parameters + or form data flow directly into response bodies. Identify stored XSS where user input persisted in a + database is later rendered without encoding. 
Examine DOM-based XSS through client-side code using + innerHTML, outerHTML, document.write, jQuery's .html(), or React's dangerouslySetInnerHTML—verify + whether the value passed to these sinks originates from user-controllable sources (location.hash, + location.search, postMessage, document.referrer). Check for context-dependent encoding failures: + data placed inside JavaScript strings requires JS escaping, data in HTML attributes requires attribute + encoding, and data in URLs requires URL encoding—HTML entity encoding alone is insufficient in these + contexts. Inspect Content Security Policy headers for unsafe-inline or unsafe-eval directives that + weaken XSS mitigations. Look for mutation XSS (mXSS) vectors in sanitizer bypass scenarios. + validation_checks: + - "Confirm the template engine has auto-escaping enabled globally and that |safe, raw, or equivalent bypass filters are justified and applied only to trusted content." + - "Verify that innerHTML, outerHTML, document.write, and jQuery .html() are never called with user-controlled data; prefer textContent or .text()." + - "Check that React components do not use dangerouslySetInnerHTML with unsanitized input, and that any sanitization uses a battle-tested library (DOMPurify)." + - "Inspect context-dependent encoding: JavaScript string contexts use JS-specific escaping, URL contexts use encodeURIComponent, and HTML attribute contexts use attribute encoding." + - "Verify Content-Security-Policy headers are present and do not include unsafe-inline or unsafe-eval in script-src." + - "Check for DOM sources (location.hash, location.search, postMessage handlers) flowing into DOM sinks without sanitization." + false_positive_indicators: + - "The template engine auto-escapes by default and the flagged output does not use any bypass filter." + - "The value rendered is an integer, boolean, or enum derived from server-side logic with no user-controlled string component." 
+ - "A mature HTML sanitizer (DOMPurify, Bleach) is applied to the value before rendering, and the sanitizer configuration does not allow script or event handler attributes." + - "The output context is plain text (Content-Type: text/plain) or the response includes a restrictive CSP that blocks inline script execution." + + CWE-22: + title: "Path Traversal" + analysis_prompt: > + Examine every file system operation where a file path is constructed using user-supplied input—including + file reads, writes, deletes, directory listings, and archive extraction. Check whether path.join(), + os.path.join(), or equivalent functions are used and whether the developer understands that these + functions do not prevent traversal (e.g., os.path.join("/safe/dir", "../../../etc/passwd") resolves + upward). Verify that the resolved absolute path is validated to fall within the intended base directory + using os.path.realpath() or equivalent, and that this check happens after symlink resolution. Look for + null byte injection (%00) in languages/runtimes where null bytes can truncate path strings (older PHP, + some C libraries). Check for URL-encoded traversal sequences (%2e%2e%2f, %252e%252e%252f for double + encoding) that may bypass naive input filters. Inspect archive extraction (zip, tar) for entries + containing ../ paths (zip slip vulnerability). Examine whether the application follows symlinks that + could point outside the allowed directory. Check for path traversal in cloud storage key construction + (S3, GCS) where the "directory" abstraction is purely conventional. + validation_checks: + - "Verify that the resolved absolute path (after realpath/canonical resolution) is checked to start with the intended base directory prefix." + - "Confirm that path validation occurs after symlink resolution to prevent symlink-based escapes." + - "Check that archive extraction validates each entry's path does not contain ../ sequences or resolve outside the target directory (zip slip)." 
+ - "Inspect for double-encoding and null-byte sequences that could bypass traversal filters." + - "Verify that user-supplied filenames are sanitized to remove or reject path separator characters (/, \\) and traversal sequences." + - "Check that cloud storage key construction validates user input does not include ../ or absolute paths." + false_positive_indicators: + - "The path is constructed from a server-generated identifier (UUID, database ID) with no user-controlled path components." + - "The application uses a chroot, container mount, or OS-level sandboxing that restricts file system access regardless of path." + - "A framework-level file-serving function (e.g., Django's static file handler, Express static middleware) handles path resolution securely by design." + + CWE-287: + title: "Improper Authentication" + analysis_prompt: > + Review the authentication architecture end-to-end, from credential submission through session + establishment. Check for endpoints or routes that lack authentication middleware entirely—look for + route registration patterns where auth middleware is applied selectively and verify no critical endpoint + is accidentally excluded. Examine token validation logic for JWT implementations: verify signature + validation is enforced (alg:none attack), the correct algorithm is pinned (RS256 vs HS256 confusion), + expiration (exp) and not-before (nbf) claims are checked, and the issuer/audience are validated. + Look for timing side-channels in credential comparison—string equality (== or ===) leaks information + via response time; use constant-time comparison (hmac.compare_digest, crypto.timingSafeEqual). Check + for credential stuffing vulnerabilities: are there rate limits, account lockout policies, or CAPTCHA + on login endpoints? Examine password reset flows for token predictability, expiration enforcement, and + token invalidation after use. 
Verify that session tokens have sufficient entropy (at least 128 bits) + and are transmitted securely (Secure, HttpOnly, SameSite cookie flags). Check multi-factor authentication + bypass vectors such as backup code brute-forcing, TOTP window tolerance, and fallback to weaker methods. + validation_checks: + - "Verify that every route serving sensitive data or performing state changes has authentication middleware applied, with no accidental exclusions." + - "Confirm JWT validation enforces algorithm pinning (rejects alg:none), checks exp/nbf/iss/aud claims, and uses the correct key for the algorithm." + - "Check that credential comparison uses constant-time functions (hmac.compare_digest, crypto.timingSafeEqual) to prevent timing attacks." + - "Verify login endpoints implement rate limiting, account lockout, or progressive delays to mitigate credential stuffing." + - "Confirm password reset tokens are cryptographically random, expire within a short window, and are single-use." + - "Check that session cookies set Secure, HttpOnly, and SameSite attributes appropriately." + false_positive_indicators: + - "The endpoint is intentionally public (login page, health check, public API documentation) and is documented as such." + - "Authentication is handled by an upstream reverse proxy, API gateway, or service mesh that is verified to enforce it before requests reach the application." + - "The flagged timing difference is in a non-security-sensitive comparison (e.g., comparing a session token to check cache hit, not for authentication)." + + CWE-862: + title: "Missing Authorization" + analysis_prompt: > + Examine every endpoint and function that accesses or modifies resources to determine whether + authorization checks are performed after authentication. Distinguish between authentication (who + you are) and authorization (what you're allowed to do)—a valid session token does not imply + permission to access any resource. 
Look for horizontal privilege escalation where User A can + access User B's data by changing a resource identifier in the request. Verify that resource + ownership is validated: when a user requests /api/orders/123, confirm the application checks + that order 123 belongs to the authenticated user, not just that the user is authenticated. + Check middleware ordering—authorization middleware must run after authentication middleware, and + both must execute before the handler. Look for inconsistencies between UI restrictions and API + enforcement: if a button is hidden in the frontend, verify the corresponding API endpoint still + enforces the check server-side. Examine role-based access control for role hierarchy bypass (can + a manager access admin endpoints by manipulating role claims?). Inspect batch operations and + GraphQL resolvers where authorization may be checked on the top-level query but not on nested + resource access. Check for parameter pollution where duplicate parameters override authorization + context. + validation_checks: + - "Verify that every endpoint performing data access checks the authenticated user's permission to access the specific resource, not just that they are logged in." + - "Confirm that authorization middleware executes after authentication and before the route handler, with no ordering gaps." + - "Check for horizontal privilege escalation by verifying resource ownership validation (e.g., WHERE user_id = :currentUser) on all data queries." + - "Inspect batch and bulk endpoints to confirm authorization is checked per-item, not just on the batch request." + - "Verify that GraphQL nested resolvers enforce authorization independently and do not rely solely on parent resolver checks." + - "Confirm that role-based checks prevent role hierarchy bypass (e.g., a user cannot escalate by modifying a role claim in a token or request)." 
+ false_positive_indicators: + - "Authorization is enforced at a data layer (row-level security, database policies) that the application code does not need to duplicate." + - "The endpoint is intentionally accessible to all authenticated users (e.g., a shared dashboard, public directory lookup)." + - "An API gateway or service mesh enforces authorization rules before the request reaches the application." + + CWE-639: + title: "Authorization Bypass Through User-Controlled Key (IDOR)" + analysis_prompt: > + Identify every endpoint where a resource identifier is supplied by the client—path parameters + (/users/:id/profile), query parameters (?invoice_id=456), or request body fields. Determine + whether the application verifies that the authenticated user is authorized to access the + resource identified by that key. Check whether resource identifiers are predictable (sequential + integers, auto-increment IDs) versus opaque (UUIDv4, cryptographic hashes), recognizing that + opaque identifiers are defense-in-depth but not a substitute for authorization checks. Examine + batch and list endpoints where an array of IDs is submitted—verify authorization is enforced on + each individual ID, not just the first or the request as a whole. Look for indirect reference + patterns where the application maps a user-facing identifier to an internal one; verify the + mapping itself is scoped to the authenticated user. Check for IDOR in file download endpoints, + export functions, and webhook configurations where changing an ID could expose another tenant's + data. Inspect GraphQL queries and mutations for node/relay-style ID parameters that resolve + resources across tenant boundaries without permission checks. + validation_checks: + - "Confirm that every endpoint accepting a resource ID from the client validates the authenticated user's ownership or permission for that specific resource." + - "Verify that batch endpoints check authorization per-item, not just on the request level." 
+ - "Check whether sequential/predictable identifiers are used and whether enumeration is feasible—recommend UUIDs as defense-in-depth." + - "Inspect file download and export endpoints for IDOR by verifying the resource belongs to the requesting user/tenant." + - "Verify that indirect reference maps (user-facing ID → internal ID) are scoped to the authenticated user's data." + false_positive_indicators: + - "The resource is intentionally public (e.g., a published blog post, product catalog item) and contains no sensitive data." + - "Row-level security or tenant isolation at the database layer ensures cross-tenant access is impossible regardless of application logic." + - "The identifier is a cryptographically signed token (HMAC'd ID) that the user cannot forge for other resources." + + CWE-352: + title: "Cross-Site Request Forgery (CSRF)" + analysis_prompt: > + Identify every state-changing endpoint (POST, PUT, PATCH, DELETE) and verify that CSRF + protections are in place. Check whether CSRF tokens are included in forms and validated + server-side—inspect the token generation for sufficient entropy and verify tokens are bound + to the user session. Examine SameSite cookie attributes: SameSite=Lax prevents CSRF for + top-level navigations with unsafe methods but does not protect GET endpoints with side effects; + SameSite=Strict provides stronger protection but may break legitimate cross-site navigation. + Check for GET endpoints that perform state changes (logout, delete, settings modification), + as these are vulnerable regardless of CSRF token protections. Verify that custom header + requirements (X-Requested-With, X-CSRF-Token) are enforced and cannot be set cross-origin + without CORS misconfiguration. Examine CORS configuration for overly permissive + Access-Control-Allow-Origin (wildcard or reflecting the Origin header) combined with + Access-Control-Allow-Credentials: true, which undermines CSRF protections. 
Check login + endpoints for login CSRF where an attacker forces the victim to authenticate as the attacker. + Inspect APIs that accept both cookie-based and token-based authentication—cookie-based auth + requires CSRF protection while Bearer token auth does not. + validation_checks: + - "Verify that all state-changing endpoints (POST, PUT, PATCH, DELETE) require and validate a CSRF token or rely on SameSite cookie attributes." + - "Confirm that no GET endpoint performs state-changing operations (data modification, logout, account changes)." + - "Check that CORS configuration does not reflect arbitrary Origin values when credentials are allowed." + - "Verify that CSRF tokens are cryptographically random, bound to the user session, and validated on every state-changing request." + - "Inspect login endpoints for login CSRF protections." + - "Confirm that APIs using cookie-based authentication enforce CSRF protections, even if token-based auth endpoints do not." + false_positive_indicators: + - "The endpoint exclusively uses Bearer token authentication (Authorization header) and does not accept cookies for authentication." + - "SameSite=Strict or SameSite=Lax cookies are used and the endpoint does not perform state changes via GET." + - "The API is accessed only by non-browser clients (mobile apps, server-to-server) and does not accept cookie-based sessions." + + CWE-502: + title: "Deserialization of Untrusted Data" + analysis_prompt: > + Search for all deserialization operations that process data from untrusted sources. In Python, + flag any use of pickle.loads(), pickle.load(), shelve, or yaml.load() (without Loader=SafeLoader) + on user-controlled input—these allow arbitrary code execution via crafted payloads. 
In Java, + identify ObjectInputStream.readObject() usage and check whether the input stream originates from + untrusted data (network, file upload, message queue); verify whether look-ahead deserialization + filters (ObjectInputFilter) or allowlist-based approaches are used. In PHP, check for + unserialize() on user input and whether allowed_classes is restricted. In .NET, examine + BinaryFormatter, NetDataContractSerializer, and SoapFormatter usage—all are dangerous with + untrusted data. For JSON deserialization, check for type-discriminated deserialization + (Jackson's enableDefaultTyping, Newtonsoft.Json's TypeNameHandling) that can instantiate + arbitrary classes. Examine message queue consumers and RPC frameworks that deserialize incoming + messages. Check for XML deserialization (XmlSerializer with user-controlled type, XStream + without allowlisting) and YAML deserialization (yaml.load vs yaml.safe_load in Python, + YAML.load in Ruby). Verify that any custom deserialization logic validates types before + instantiation. + validation_checks: + - "Confirm that pickle, shelve, and marshal are never used to deserialize data from untrusted sources; recommend JSON or Protocol Buffers instead." + - "Verify that yaml.load() is called with Loader=SafeLoader or replaced with yaml.safe_load() for all untrusted input." + - "Check that Java ObjectInputStream usage includes ObjectInputFilter allowlisting or is replaced with a safe alternative." + - "Verify that Jackson does not use enableDefaultTyping() or @JsonTypeInfo with untrusted input unless a PolymorphicTypeValidator allowlist is configured." + - "Inspect .NET code for BinaryFormatter or NetDataContractSerializer usage and confirm they are not used with untrusted data." + - "Check that PHP unserialize() calls specify allowed_classes as an array of safe types or false." 
+ false_positive_indicators: + - "The deserialized data originates from a trusted internal source (signed message, encrypted cache) with integrity verification." + - "The application uses safe deserialization formats (JSON.parse without reviver, Protocol Buffers, FlatBuffers) that do not support arbitrary object instantiation." + - "An allowlist-based deserialization filter is in place that restricts instantiable classes to a known-safe set." + + CWE-327: + title: "Use of a Broken or Risky Cryptographic Algorithm" + analysis_prompt: > + Identify all cryptographic operations and evaluate the algorithms and parameters used. Flag + MD5 and SHA-1 when used for security-sensitive purposes (password hashing, digital signatures, + HMAC key derivation, certificate verification)—note that using MD5/SHA-1 for non-security checksums + like cache keys or deduplication is generally acceptable. Check for symmetric encryption using + DES, 3DES, RC4, or Blowfish, which are considered weak; AES with a key of at least 128 bits is + the minimum standard. Examine the block cipher mode: ECB mode is always inappropriate for + multi-block data as it leaks patterns; verify CBC is used with proper IV randomization (not + a fixed or zero IV) and consider whether authenticated encryption (GCM, CCM) is more + appropriate. Check RSA key sizes—keys below 2048 bits are considered weak, and 4096 bits is + recommended for long-term security. Look for custom or hand-rolled cryptographic + implementations instead of vetted libraries (OpenSSL, libsodium, BouncyCastle). Check for + static or hardcoded IVs/nonces in encryption operations, which destroy the security guarantees + of CBC, CTR, and GCM modes. Verify that random number generation uses cryptographically secure + sources (os.urandom, crypto.randomBytes, SecureRandom) and not math.random or rand(). + validation_checks: + - "Confirm that MD5 and SHA-1 are not used for password hashing, digital signatures, or any security-sensitive integrity verification."
+ - "Verify that symmetric encryption uses AES-128 or AES-256 with an authenticated mode (GCM preferred) and not DES, 3DES, RC4, or ECB mode." + - "Check that RSA keys are at least 2048 bits and that elliptic curve parameters use standard curves (P-256, P-384, Curve25519)." + - "Verify that IVs and nonces are generated from a CSPRNG and are unique per encryption operation—never static, zero, or derived from predictable values." + - "Confirm that all random number generation for security purposes (tokens, keys, nonces) uses a CSPRNG, not Math.random() or similar." + false_positive_indicators: + - "MD5/SHA-1 is used for non-security purposes such as cache key generation, content deduplication, or ETags where collision resistance is not a security requirement." + - "The weak algorithm is used only for compatibility with a legacy system and is documented as a known risk with a migration plan." + - "The cryptographic operation uses a well-maintained library with secure defaults and the flagged concern is about an internal implementation detail handled by the library." + + CWE-798: + title: "Use of Hard-coded Credentials" + analysis_prompt: > + Scan the entire codebase for hardcoded secrets including API keys, passwords, authentication + tokens, database connection strings with embedded credentials, private keys, and certificates. + Look for common patterns: variables named password, secret, api_key, token, or credential + assigned to string literals; base64-encoded strings that decode to credentials; configuration + files (application.yml, settings.py, .env committed to version control) containing plaintext + secrets. Check for credentials in test files that might point to real services rather than + mocks. Examine Dockerfiles and docker-compose files for ENV or ARG directives with embedded + secrets. Look for credentials in CI/CD configuration files (.github/workflows, .gitlab-ci.yml, + Jenkinsfile). 
Verify that the application retrieves secrets from environment variables, a + secret manager (Vault, AWS Secrets Manager, GCP Secret Manager), or an encrypted configuration + system rather than source code. Check for private keys (RSA, SSH, TLS) committed to the + repository. Inspect comment blocks and documentation for credentials left as examples that + match production patterns. Check git history for secrets that were committed and later removed + but remain in the repository history. + validation_checks: + - "Verify that no string literals matching credential patterns (API keys, passwords, tokens) appear in source code, configuration files, or environment setup scripts." + - "Confirm that .env files, private keys, and certificate files are listed in .gitignore and are not committed to the repository." + - "Check that Dockerfiles and CI/CD configurations do not embed secrets in ENV, ARG, or RUN commands—use build-time secret mounts or runtime injection." + - "Verify that test fixtures use mock credentials or test-specific service accounts, not production secrets." + - "Confirm the application loads secrets from environment variables or a secrets manager at runtime, not from hardcoded values." + false_positive_indicators: + - "The string is a well-known placeholder or example value (e.g., 'changeme', 'your-api-key-here', 'sk_test_...') used in documentation or templates." + - "The credential is for a local development-only service (localhost database, test container) with no access to production data." + - "The value is a public identifier (e.g., a publishable Stripe key, OAuth client ID) that is intended to be embedded in client-side code." + - "The hardcoded value is a cryptographic constant (salt prefix, algorithm identifier) that is not itself a secret." 
+ + CWE-434: + title: "Unrestricted Upload of File with Dangerous Type" + analysis_prompt: > + Examine every file upload endpoint and the full lifecycle of uploaded files—from receipt through + storage to retrieval. Check whether the application validates file type by extension, MIME type + (Content-Type header), or actual file content (magic bytes)—Content-Type headers are trivially + spoofed and extensions can be bypassed with double extensions (.php.jpg) or null bytes. Verify + that validation uses an allowlist of permitted types rather than a denylist of dangerous ones. + Inspect whether uploaded files are stored within the web root where they could be directly + accessed and executed by the web server (e.g., a .php file uploaded to a publicly served + directory). Check whether the application re-generates filenames using random identifiers or + preserves the original filename, which could contain path traversal sequences or special + characters. Examine image processing pipelines for ImageTragick-style vulnerabilities where + crafted image files trigger command execution. Verify that the storage location has restrictive + permissions and that the web server is configured to not execute files in the upload directory + (disable PHP, CGI handlers). Check for polyglot files that are valid in multiple formats + (GIFAR—a file that is both a GIF and a JAR). Inspect file size limits to prevent denial of + service via large uploads. + validation_checks: + - "Verify that file type validation uses an allowlist of permitted MIME types and extensions, not a denylist." + - "Confirm that file type is validated by inspecting magic bytes (file content), not just the Content-Type header or extension." + - "Check that uploaded files are stored outside the web root or in a location where the web server will not execute them." + - "Verify that uploaded filenames are regenerated (e.g., UUID-based) and the original filename is not used for storage paths." 
+ - "Confirm file size limits are enforced to prevent denial-of-service via oversized uploads." + - "Check that image processing libraries are up to date and not vulnerable to known exploits (ImageTragick, libpng overflow)." + false_positive_indicators: + - "Files are stored in a blob store (S3, GCS) served via signed URLs with Content-Disposition: attachment, preventing browser execution." + - "The web server configuration explicitly disables script execution in the upload directory (e.g., RemoveHandler, php_flag engine off)." + - "The application only accepts uploads from authenticated internal services, not end users." + + CWE-918: + title: "Server-Side Request Forgery (SSRF)" + analysis_prompt: > + Identify every location where the server makes HTTP requests, DNS lookups, or network connections + using URLs or hostnames that are partially or fully user-controlled. This includes URL parameters + for webhooks, image fetching, URL previews, PDF generation from URLs, and import-from-URL + features. Verify that the application validates the destination to prevent requests to internal + network addresses (10.x.x.x, 172.16-31.x.x, 192.168.x.x, 127.0.0.1, ::1, 169.254.169.254) + and cloud metadata endpoints (169.254.169.254 for AWS/GCP, 169.254.169.254/metadata for Azure). + Check for DNS rebinding attacks where a hostname resolves to a public IP during validation but + to an internal IP during the actual request—the mitigation is to resolve DNS once and use the + IP directly. Examine whether the application follows HTTP redirects, as a redirect can bounce + a request from an allowed external host to an internal target. Look for SSRF via protocol + schemes: file://, gopher://, dict://, and other non-HTTP schemes can be used to interact with + internal services. Check for partial SSRF where the attacker controls only part of the URL + (path, query string) but the host is fixed—this can still reach internal endpoints if the + host is an internal proxy. 
Verify that URL parsing is consistent (the parser used for + validation must behave identically to the HTTP client). + validation_checks: + - "Verify that user-supplied URLs are validated against an allowlist of permitted hosts or domains, not just a denylist of internal ranges." + - "Confirm that requests to RFC 1918 addresses, loopback, link-local, and cloud metadata endpoints (169.254.169.254) are blocked." + - "Check that DNS resolution is performed once and the resolved IP is validated before the request is made, preventing DNS rebinding." + - "Verify that HTTP redirect following is disabled or that each redirect target is re-validated against the same rules." + - "Confirm that only http:// and https:// schemes are permitted—block file://, gopher://, dict://, and other schemes." + - "Check that URL parsing for validation and request execution uses the same library to prevent parser differential attacks." + false_positive_indicators: + - "The URL is constructed entirely server-side from trusted configuration with no user-controlled components." + - "The request is made through an egress proxy that enforces network-level restrictions on internal access." + - "The user-controlled portion is limited to a path or query parameter appended to a hardcoded base URL, and the base URL points to an external service." + + CWE-117: + title: "Improper Output Neutralization for Logs" + analysis_prompt: > + Examine all logging statements that include user-supplied data. Check whether newline characters + (\n, \r, \r\n) in user input could be used to inject forged log entries—an attacker could craft + input that, when logged, appears as a legitimate log line, potentially masking an attack or + creating false audit trails. Verify that structured logging frameworks (logfmt, JSON-structured + logging) are used instead of unstructured text logs, as structured formats naturally contain + user data within defined fields. 
Inspect log format strings for injection: if the logging + framework uses format specifiers (e.g., Python's % formatting with logging), check that user + input cannot introduce additional format specifiers that cause crashes or information leakage. + Look for CRLF injection in HTTP response headers that are subsequently logged, which could + corrupt log parsers. Check whether sensitive data (passwords, tokens, PII) is inadvertently + included in log messages. Examine whether log aggregation systems (Splunk, ELK, CloudWatch) + could interpret injected content as commands or structured query syntax. Verify that exception + logging does not capture and log full request bodies or headers containing sensitive data. + validation_checks: + - "Verify that user-supplied data included in log messages is sanitized to remove or escape newline and carriage return characters." + - "Confirm that structured logging (JSON, logfmt) is used, ensuring user data is contained within designated fields." + - "Check that log format strings do not allow user input to introduce additional format specifiers (%s, %d, {})." + - "Verify that sensitive data (passwords, tokens, session IDs, PII) is redacted or masked before logging." + - "Inspect exception handlers to confirm they do not log full request bodies or headers containing credentials." + false_positive_indicators: + - "The application uses a structured logging framework (e.g., zap, logrus, structlog) that automatically serializes values into defined fields." + - "The logged value is an internal identifier (numeric ID, enum) that cannot contain user-controlled characters." + - "Log output is consumed only by a system that safely handles embedded newlines (e.g., JSON log parser)." + + CWE-362: + title: "Concurrent Execution Using Shared Resource with Improper Synchronization (Race Condition / TOCTOU)" + analysis_prompt: > + Identify check-then-act patterns where a condition is verified and then a dependent action is + taken without atomicity guarantees.
In file system operations, look for patterns where a file's + existence, permissions, or content is checked (stat, access, exists) and then the file is + opened, read, or modified—an attacker may alter the file between the check and the use. In + database operations, look for read-modify-write sequences without proper locking or atomic + operations (SELECT followed by UPDATE without SELECT FOR UPDATE, optimistic locking, or + serializable isolation). Examine concurrent request handling for race conditions in balance + checks, inventory decrements, rate limiting, and coupon redemption—these often require atomic + operations (compare-and-swap, database constraints, Redis WATCH/MULTI). Check for TOCTOU in + authentication flows where a permission check and the protected action are separate operations. + Look for shared mutable state accessed from multiple threads or goroutines without proper + synchronization (mutexes, channels, atomic operations). Examine lazy initialization patterns + for double-checked locking correctness. Check for race conditions in cache operations where + multiple requests might compute and store the same value with inconsistent results. + validation_checks: + - "Verify that file system check-then-act operations (exists/stat → open/write) are replaced with atomic operations (open with O_CREAT|O_EXCL, rename for atomic writes)." + - "Confirm that database read-modify-write sequences use SELECT FOR UPDATE, optimistic locking (version columns), or serializable transactions." + - "Check that financial operations (balance checks, transfers, inventory decrements) use atomic database operations or constraints, not application-level check-then-act." + - "Verify that shared mutable state is protected by appropriate synchronization primitives (mutexes, channels, atomic variables)." + - "Inspect rate limiting implementations for race conditions that allow burst bypass (use atomic counters or Redis INCR, not get-then-set)." 
+ false_positive_indicators: + - "The check-then-act pattern operates on immutable data or data that only the current user/session can modify." + - "The operation is idempotent, making the race condition inconsequential (e.g., cache warm-up with identical results)." + - "Database-level constraints (UNIQUE, CHECK, serializable isolation) provide atomicity guarantees regardless of application logic." + + CWE-611: + title: "Improper Restriction of XML External Entity Reference (XXE)" + analysis_prompt: > + Identify all XML parsing operations and determine whether external entity processing is disabled. + In Java, check SAXParserFactory, DocumentBuilderFactory, XMLInputFactory, TransformerFactory, and + SchemaFactory for proper feature configuration—specifically that FEATURE_EXTERNAL_ENTITIES, + FEATURE_EXTERNAL_PARAMETER_ENTITIES, and FEATURE_DISALLOW_DOCTYPE_DECL are set correctly. In + Python, check that defusedxml is used instead of xml.etree.ElementTree, xml.dom.minidom, or lxml + with default settings. In PHP, verify that libxml_disable_entity_loader(true) is called before + parsing (on PHP versions before 8.0, where the function is not yet deprecated) or that LIBXML_NOENT is not passed to simplexml_load_string/DOMDocument::loadXML. In + .NET, check that XmlReaderSettings.DtdProcessing is set to Prohibit or Ignore and that + XmlResolver is set to null. Look for XXE in less obvious locations: SOAP endpoints, SVG file + processing, DOCX/XLSX parsing (these are ZIP files containing XML), RSS/Atom feed parsing, and + SAML assertion processing. Check for billion laughs (XML bomb) attacks via recursive entity + expansion even when external entities are disabled—entity expansion limits should be configured. + Verify that XSLT processing does not allow the document() function or external imports. + validation_checks: + - "Verify that XML parsers are configured to disable external entity processing (DTD, external general entities, external parameter entities)."
+ - "Confirm that DOCTYPE declarations are disallowed or that entity expansion is limited to prevent billion laughs attacks." + - "Check that XML parsing in file upload handlers (SVG, DOCX, XLSX, RSS) applies the same XXE protections as direct XML endpoints." + - "Verify that SAML and SOAP processing libraries are configured to reject external entities." + - "Inspect XSLT processors to confirm that document() and external imports are disabled." + false_positive_indicators: + - "The XML parser is explicitly configured to disable DTD processing and external entities (verified in the parser initialization code)." + - "The application uses a defused XML library (defusedxml in Python, safe defaults in modern .NET) that blocks XXE by default." + - "The XML input is generated internally by the application and never contains user-supplied content." + + CWE-94: + title: "Improper Control of Generation of Code (Code Injection)" + analysis_prompt: > + Search for all dynamic code execution functions: eval(), exec(), Function() constructor, + setTimeout/setInterval with string arguments, vm.runInNewContext(), and + importlib.import_module() with user-controlled module names. Determine whether any user input + flows into these functions—even indirect flows through database values, configuration files, + or environment variables that users can influence. Examine template engines for Server-Side + Template Injection (SSTI): test whether user input rendered through Jinja2, Mako, Freemarker, + Twig, Velocity, or Pebble can break out of the template context and execute arbitrary code + (e.g., {{config.__class__.__init__.__globals__}} in Jinja2). Check for expression language + injection in frameworks like Spring (SpEL), Struts (OGNL), or JSP (EL). Look for dynamic + import or require statements where the module path is user-controlled, enabling arbitrary + module loading. 
Examine code generation patterns where user input is interpolated into source + code that is then compiled and executed (e.g., dynamic SQL generation that is then eval'd, + or Groovy/Python script generation). Check for reflection-based invocation where user input + determines which method or class is instantiated. Verify that sandboxing attempts (Python's + restricted exec, JavaScript vm module) are recognized as bypassable and not relied upon as + the sole mitigation. + validation_checks: + - "Verify that eval(), exec(), and Function() constructor are never called with user-controlled input; if used, confirm input is entirely server-generated." + - "Check that template engines render user input as data, not as template syntax—user input should never be part of the template string itself." + - "Inspect for SSTI by confirming that user input is passed as template variables, not concatenated into template source code." + - "Verify that dynamic import/require paths are validated against an allowlist of permitted modules." + - "Check that expression language contexts (SpEL, OGNL, EL) do not evaluate user-supplied expressions." + - "Confirm that setTimeout and setInterval are not called with string arguments containing user input." + false_positive_indicators: + - "The eval/exec call processes only server-generated code with no user input in the data flow." + - "The template engine auto-escapes and the user input is passed as a context variable, never interpolated into the template source." + - "A strict Content Security Policy prevents execution of injected scripts on the client side (for client-side eval concerns)." + + CWE-200: + title: "Exposure of Sensitive Information" + analysis_prompt: > + Review all error handling paths, response bodies, and response headers for unintentional + information disclosure. 
Check whether uncaught exceptions return full stack traces, framework + version numbers, file paths, or database connection strings to the client—verify that + production error handlers return generic messages while logging details server-side. Examine + API responses for over-fetching: are full database records returned when only specific fields + are needed (e.g., returning a user object with password_hash, email, SSN alongside the + requested display_name)? Check for debug endpoints or development tools left enabled in + production (Django DEBUG=True, Express error handler with stack traces, Spring Boot Actuator + endpoints without authentication, PHP phpinfo()). Inspect HTTP response headers for server + software versions (Server, X-Powered-By, X-AspNet-Version) and remove or genericize them. + Look for sensitive data in URLs (tokens, passwords, API keys in query strings) that may be + logged by proxies, browsers, and analytics tools. Examine client-side source code and source + maps for embedded secrets, internal API endpoints, or sensitive business logic. Check for + directory listing enabled on web servers and verbose 404 pages that enumerate valid paths. + validation_checks: + - "Verify that production error handlers return generic error messages and do not expose stack traces, file paths, or internal details to clients." + - "Confirm that API responses include only the fields necessary for the consumer—no password hashes, internal IDs, or PII leakage." + - "Check that debug endpoints (phpinfo, Actuator, debug toolbar) are disabled or protected by authentication in production." + - "Verify that response headers do not reveal server software versions (Server, X-Powered-By, X-AspNet-Version)." + - "Confirm that sensitive data (tokens, credentials) is not included in URL query strings where it may be logged." + - "Check that client-side source maps are not deployed to production or are access-controlled." 
+ false_positive_indicators: + - "The detailed error response is returned only in development/staging environments, gated by an environment variable or configuration flag." + - "The information exposed is already public (e.g., open-source framework version, public API documentation)." + - "The endpoint is restricted to authenticated internal users (admin panel) with appropriate access controls." + + CWE-306: + title: "Missing Authentication for Critical Function" + analysis_prompt: > + Identify all endpoints and functions that perform critical operations—user management (create, + delete, promote), financial transactions, configuration changes, data export, system + administration, and deployment triggers—and verify that authentication is strictly enforced. + Distinguish this from CWE-862 (Missing Authorization): this CWE concerns operations accessible + with no identity verification at all, not insufficient permission checks. Look for admin + panels, management APIs, and internal tools that assume network-level isolation provides + sufficient protection (e.g., "only accessible from the internal network" without authentication). + Check for API endpoints that rely on API key authentication where the key is a shared secret + rather than per-user credentials. Examine health check, metrics, and monitoring endpoints + (Prometheus /metrics, Spring Boot Actuator) for exposure of sensitive operational data without + authentication. Look for deployment hooks, webhook receivers, and callback URLs that accept + unauthenticated requests to trigger critical operations. Verify that password reset, account + recovery, and email verification endpoints cannot be abused to perform privileged actions + without proper identity verification. Check for debug or maintenance endpoints that bypass + normal authentication flows. + validation_checks: + - "Verify that all administrative and management endpoints require authentication, not just network-level access restrictions." 
+ - "Confirm that webhook receivers and callback URLs validate the sender's identity (signature verification, shared secret, mutual TLS)." + - "Check that monitoring and metrics endpoints (/metrics, /health with details, Actuator) do not expose sensitive data without authentication." + - "Verify that deployment triggers and CI/CD webhook endpoints require authentication or cryptographic signature verification." + - "Confirm that internal-only endpoints are protected by authentication in addition to network controls, following defense-in-depth." + false_positive_indicators: + - "The endpoint is a public health check returning only 'ok' status with no sensitive operational details." + - "Network-level isolation is enforced by infrastructure (service mesh, VPC, firewall rules) AND the endpoint serves non-sensitive data." + - "The endpoint is part of the authentication flow itself (login, token exchange) and is intentionally unauthenticated." + + CWE-522: + title: "Insufficiently Protected Credentials" + analysis_prompt: > + Examine how the application stores, transmits, and handles user credentials throughout their + lifecycle. For password storage, verify that a modern adaptive hashing algorithm is used—bcrypt, + scrypt, or Argon2id with appropriate cost parameters (bcrypt work factor ≥ 10, Argon2id with + recommended memory/time/parallelism). Flag any use of MD5, SHA-1, SHA-256, or SHA-512 for + password hashing, even with a salt, as these fast hash functions are vulnerable to GPU-based + brute-force attacks. Verify that each password is hashed with a unique, randomly generated salt + (not a global or predictable salt). Check for plaintext password storage in databases, + configuration files, logs, or temporary files. Examine credential transmission to confirm + passwords are only sent over TLS—check for HTTP login forms, unencrypted API endpoints, and + FTP/SMTP credentials transmitted in cleartext. 
Look for password logging in authentication + handlers, audit trails, or error messages. Verify that API tokens and session credentials are + stored with equivalent protections—tokens in databases should be hashed (SHA-256 is acceptable + for tokens, unlike passwords). Check for credential exposure in memory dumps, core files, or + swap space—verify that sensitive strings are zeroed after use where the language permits. + Examine password change and reset flows to ensure old password hashes are not retained. + validation_checks: + - "Verify that passwords are hashed using bcrypt (work factor ≥ 10), scrypt, or Argon2id—not MD5, SHA-1, SHA-256, or any fast hash." + - "Confirm that each password hash uses a unique, cryptographically random salt, not a shared or predictable value." + - "Check that passwords are never stored in plaintext in any storage layer (database, config files, logs, temporary files)." + - "Verify that credentials are transmitted only over TLS and that login forms/API endpoints enforce HTTPS." + - "Confirm that API tokens and session identifiers stored in the database are hashed (SHA-256 is acceptable for tokens)." + - "Check that authentication handlers and error logging do not include plaintext passwords in log output." + false_positive_indicators: + - "SHA-256 is used to hash API tokens or session identifiers (not passwords)—this is appropriate as tokens have high entropy and do not need adaptive hashing." + - "The application delegates credential storage to an identity provider (OAuth, SAML, LDAP) and does not store passwords locally." + - "MD5/SHA-1 is used in a non-credential context (file checksums, cache keys) that does not involve password or secret storage." 
diff --git a/prompts/exploit-proof/feature_detection.yaml b/prompts/exploit-proof/feature_detection.yaml new file mode 100644 index 0000000..77c47af --- /dev/null +++ b/prompts/exploit-proof/feature_detection.yaml @@ -0,0 +1,74 @@ +# Feature Detection Prompt for Multi-Shot Analysis +# This prompt is used in the first phase to detect which features are present in the codebase + +system_message: > + You are a code analysis expert. Your job is to quickly scan a codebase and identify + which security-relevant features and technologies are present. Return your findings + in strict JSON format. + +user_prompt_template: > + Analyze the following repository structure and code to identify which security-relevant + features and technologies are present. + + + Your goal is to detect what types of functionality exist in this codebase, so that we + can perform a targeted security analysis. + + + Repository: {repo_name} + + + Repomix XML Content: + --- + {xml_content} + --- + + + Please identify which of the following features are present in this codebase: + + + FEATURE CATEGORIES: + 1. **public_api** - HTTP/REST endpoints, API routes, web handlers, or any publicly accessible interfaces + 2. **authentication** - Login systems, password handling, token generation, session creation, auth middleware + 3. **authorization** - Access control, permission checks, role-based access, authorization middleware + 4. **database_operations** - SQL queries, database connections, ORM usage, data persistence + 5. **user_input_handling** - Form processing, query parameters, request body parsing, user-provided data + 6. **file_operations** - File uploads, downloads, file system access, path operations + 7. **cryptography** - Encryption, decryption, hashing, signing, key management, crypto operations + 8. **session_management** - Session handling, cookies, session stores, session tokens + 9. **external_api_calls** - HTTP clients, API requests to external services, webhook calls + 10. 
**third_party_dependencies** - External libraries, packages, dependencies (check imports and package files) + 11. **infrastructure_as_code** - Terraform files (.tf), CloudFormation, Kubernetes configs, deployment configs + 12. **blockchain_crypto_finance** - Blockchain contracts, cryptocurrency handling, wallet operations, financial transactions + 13. **websockets** - WebSocket connections, real-time communication, socket handlers + 14. **graphql** - GraphQL schemas, resolvers, mutations, queries + 15. **grpc** - gRPC service definitions, protobuf files, RPC handlers + 16. **xml_processing** - XML parsing, SOAP services, XML document handling + 17. **deserialization** - JSON/XML/pickle deserialization, object unmarshaling + 18. **template_rendering** - Template engines (Jinja2, EJS, etc.), server-side rendering, HTML generation + 19. **shell_command_execution** - System command execution, shell operations, subprocess calls + 20. **sensitive_data_handling** - PII, passwords, tokens, API keys, secrets, credit card data + + + INSTRUCTIONS: + - Only include features that are ACTUALLY PRESENT in the code (not just imported but unused) + - Look for concrete evidence: function definitions, route handlers, database queries, API calls, etc. 
+ - If you see imports but no actual usage, don't include that feature + - Be conservative: only report features with clear evidence + - Look at file names, imports, function definitions, and actual code implementation + + + Return ONLY a valid JSON object according to the following schema: + + + JSON Schema: + --- + {schema} + --- + + + IMPORTANT: + - Return ONLY the raw JSON object + - Do NOT wrap in markdown code blocks + - Do NOT include explanatory text + - The entire response must be valid, parseable JSON diff --git a/prompts/exploit-proof/security_analysis_base.yaml b/prompts/exploit-proof/security_analysis_base.yaml new file mode 100644 index 0000000..02b5330 --- /dev/null +++ b/prompts/exploit-proof/security_analysis_base.yaml @@ -0,0 +1,46 @@ +# Slim Carlini v3: exploit-proof gate +# Instead of listing what NOT to report, require a concrete exploit for each finding. +# The model's own exploitation reasoning serves as the precision filter. + +system_message: > + You are an elite offensive security researcher competing in a CTF. Your job is + to find REAL, EXPLOITABLE vulnerabilities — not theoretical risks or best-practice + gaps. Think like an attacker: trace user input to dangerous sinks, check every + endpoint for missing authorization, and chain low-severity issues into high-impact + exploits. Return findings in strict JSON format. + +analysis_intro: > + Find all exploitable security vulnerabilities in the code below. + +analysis_requirements_header: "" + +repo_info: > + Repository: {repo_name} + + {xml_content} + +critical_instructions: > + For each vulnerability you find: + - file_path: relative path (e.g. routes/auth.js) + - start_line / end_line: integer line numbers from the XML content. Point at the + EXACT LINE containing the vulnerable operation (the SQL query, the innerHTML + write, the exec call), not the function signature or route registration above it. 
+ - severity: 0.0-10.0 based on real-world exploitability and impact + - cwe_id: most specific CWE formatted as 'CWE-NNN: Name' + - technical_details: you MUST include a concrete exploit proof — the exact HTTP + request (method, path, headers, body) or input an attacker would send, what + the server does with it step by step, and what the attacker gains. If you + cannot write a specific request that triggers the vulnerability, do not report it. + + QUALITY GATE: Before including any finding, ask yourself: "Could I hand this report + to a pentester and have them reproduce it in under 5 minutes with just curl/browser + devtools?" If no, drop it. + + Report each vulnerability ONCE at its most specific location. + +json_formatting_rules: > + Return ONLY a valid JSON object. No markdown fences, no commentary. + Escape backslashes in strings (\ becomes \\). + + JSON Schema: + {schema} diff --git a/prompts/nano-analyzer/analysis_sections.yaml b/prompts/nano-analyzer/analysis_sections.yaml new file mode 100644 index 0000000..2ea22dd --- /dev/null +++ b/prompts/nano-analyzer/analysis_sections.yaml @@ -0,0 +1,45 @@ +# nano-analyzer keeps methodology minimal: the 5-question loop over every +# function is the entire methodology. One always-on section. + +sections: + zero_day_hunt: + title: "ZERO-DAY HUNT METHODOLOGY" + features: [] + content: > + For every function in this chunk, walk the same loop: + + 1. IDENTIFY UNTRUSTED INPUT. Which parameter(s) come from outside the + trust boundary — network, file, IPC, user-supplied arg? Name them + explicitly. If no parameter is untrusted, skip this function + unless it operates on data another untrusted function stored. + + 2. ASK THE 5 QUESTIONS: + a. Can any parameter be NULL, too large, negative, or otherwise + invalid when called with malformed input? + b. Are there copies into fixed-size buffers without size + validation? + c. 
Can integer arithmetic overflow, wrap, or produce negative + values that are then used as sizes or indices? + d. Are tagged unions / variant types accessed without verifying + the type discriminator first? + e. Are return values from fallible operations checked before + use? + + 3. RESOLVE THE CONSTANTS. If a size, limit, or bound is a named + constant (`MAX_BUF_SIZE`, `EVP_MAX_MD_SIZE`, `PATH_MAX`), note + the resolved numeric value, not just the name. "A bound exists" + is not a verified bound. + + 4. VERIFY DEFENSES, DON'T INVENT THEM. If you believe a defense + prevents a finding, point at the specific line / function that + implements it and show the arithmetic. Don't rely on + "something upstream probably checks this." + + 5. DEPRIORITIZE NON-BUGS. Skip: static helpers whose callers are + all safe; allocation wrappers; platform-specific dead code; + theoretical UB without a concrete attack; defense-in-depth + observations with no attack path. + + 6. REPORT OR MOVE ON. If a concrete bug survives all checks, report + it at its sink line. Otherwise, move to the next function. Do + not pad findings with weak signals. diff --git a/prompts/nano-analyzer/audit.yaml b/prompts/nano-analyzer/audit.yaml new file mode 100644 index 0000000..5336c66 --- /dev/null +++ b/prompts/nano-analyzer/audit.yaml @@ -0,0 +1,121 @@ +# nano-analyzer-style triage / audit. Skeptical by default; demands that +# claimed defenses be verified rather than assumed. + +system_message: > + A first-pass scanner flagged these findings. Your job is to triage each + one: is it real, is it attacker-reachable, does a concrete defense prevent + it? Be skeptical — most scanner findings are false positives. + + {production_only_gate} + +user_prompt_template: > + Audit target: {repo_name} + + === CLAIMS TO VERIFY (unverified leads from the analysis phase) === + {findings_json} + + The JSON above is a `{"claims_to_verify": [...]}` wrapper. 
Each claim + contains an `unverified_exploit_sketch` — a HYPOTHESIS from the analysis + phase, NOT a verified fact. Verify against the SOURCE CODE below. + + === CWE GUIDANCE === + {cwe_analysis_prompts} + + === SOURCE CODE === + {code_context} + + {supplementary_context} + + + RULES: + - VALID (CONFIRMED): the bug is real AND an external attacker can + trigger it to cause meaningful harm (crash, code exec, data + corruption, auth bypass). The attacker must control the input that + triggers the bug. Confidence 0.7-1.0. + - INVALID (REJECTED): the bug pattern does not exist, OR it is not + attacker-reachable (only trusted internal callers), OR a concrete + defense prevents it, OR it is a code-quality issue not a security + vulnerability (e.g. data race on diagnostic state, missing NULL + check on internal-only API, undefined behavior only in debug + builds). Confidence 0.0-0.3. + - REFINED: the bug is real but severity / CWE / location / mechanism + needs correction. Confidence 0.5-0.8 after refinement. + - ESCALATED: the bug is worse than reported — name the amplifier / + chain. + + + SKEPTICISM DISCIPLINE (apply to every finding): + + CRUX. State the single key fact the verdict depends on. E.g. "whether + `len` is bounded before the memcpy at line 42", or "whether callers + of parse_attr() can pass untrusted type tags". + + VERIFY DEFENSES. When you cite any defense — a size limit, a NULL + check, a type-tag validation — verify it actually works: + - Look up the actual numeric values of named constants. + - Do the arithmetic. "There exists a bound" is NOT the same as + "the bound is sufficient." + - Show your work. Never skip the verification step. + + FOLLOW CONSTANTS. A constant name is not a verified bound. If the + guard is `if (len > MAX_BUF)`, find MAX_BUF's definition. If it's + defined via another macro, resolve that too. Only resolved numeric + values count. + + ABSENCE OF DEFENSE. 
If the bug pattern exists, the input is + untrusted, and you searched for a defense but did not find one, + lean VALID (CONFIRMED) — there is no UNCERTAIN verdict. Not having + verified every upstream caller is not a reason to REJECT; only cite a + defense if you can name the specific function / line. + + DON'T INVENT DEFENSES. Vague references to "assumptions in this + codebase" or "other code probably handles this" are not valid + defenses. If you cannot point to it in the provided code, it does + not exist. + + DON'T SELF-CONTRADICT. If your analysis reaches a conclusion, don't + then find reasons to reverse it in the same response. If you verify + a defense and it's insufficient, that is your answer. + + DON'T TRUST THE SKETCH. The unverified_exploit_sketch may self-reject + ("requires privilege X", "the caller validates") — those are + hypotheses to check against code, not conclusions. A REJECTED verdict + must quote the specific line(s) that block the exploit; paraphrasing + the sketch is NOT a valid justification. + + + For each finding, return: + - verdict: confirmed | refined | rejected | escalated + - confidence: 0.0-1.0 + - refined_severity, refined_cwe_id, refined_technical_details: your + corrected version; for REJECTED findings, refined_technical_details + should quote the specific code that closes the gate. + - justification: 2-3 sentences explaining the verdict. Must include + the crux and (for REJECTED) the code evidence. + + AUDIT SUMMARY DISCIPLINE: in `audit_summary`, describe each REJECTED / + REFINED finding in terms of the code guard that blocks or downgrades it + (e.g. "bounded by check at line N", "reachability gated by permission + check at line M"). Never use phrases like "self-retracted", "submitter + acknowledged", "by design", or "informational" without naming the + specific source line that supports the characterization. + + Return results as JSON per the schema below. + + {schema} + +json_formatting_rules: > + Return ONLY valid JSON.
No markdown fences, no commentary. + Escape backslashes (\ becomes \\). + + JSON Schema: + {schema} + +production_only_gate: > + IMPORTANT: Findings must be reachable from production code paths + triggerable by an untrusted party. REJECT if: + - the bug is in test / fixture / example / sample code not shipped; + - the path is gated behind compile-time debug flags not enabled in + release; + - the only callers are trusted internal code with no untrusted input + reaching them. diff --git a/prompts/nano-analyzer/cwe_deep_analysis.yaml b/prompts/nano-analyzer/cwe_deep_analysis.yaml new file mode 100644 index 0000000..c266e40 --- /dev/null +++ b/prompts/nano-analyzer/cwe_deep_analysis.yaml @@ -0,0 +1,128 @@ +# nano-analyzer-style CWE guidance. Terse — nano focuses on a small set of +# bug classes and relies on skepticism discipline rather than exhaustive +# per-CWE prose. Each entry emphasizes: verify defenses with real numbers, +# don't invent upstream assumptions, deprioritize theoretical UB. + +cwe_prompts: + CWE-120: + title: "Classic Buffer Overflow" + analysis_prompt: > + Any copy into a fixed-size buffer where the source length is attacker- + influenced. Check the destination's declared size (resolve named + constants — `PATH_MAX`, `EVP_MAX_MD_SIZE`, macro definitions) against + the maximum source length the code enforces at the copy site, NOT at + some claimed "upstream" validation. + validation_checks: + - "Name the destination buffer and its resolved numeric size." + - "Name the source length and the bound enforced at the copy site." + - "If a bound is cited via a caller / wrapper, point at the specific function and verify its arithmetic." + false_positive_indicators: + - "Copy length is bounded at the call site by a constant ≤ destination size, verified numerically." + - "Destination is heap-allocated to exactly the required size after validating the length." 
+ + CWE-787: + title: "Out-of-Bounds Write" + analysis_prompt: > + Writes at offsets or lengths derived from tainted input without a + bound that is actually sufficient. Includes memcpy / ptr::write / + indexed writes / struct casts to a type larger than the allocation. + validation_checks: + - "Verify every bounded write has a check whose constant resolves to a value ≤ destination size." + - "For indexed writes, verify the index is range-checked against the allocation, not just non-negative." + false_positive_indicators: + - "Destination is allocated after length validation to exactly the required size." + - "Index is structurally constrained by a typed enum with a bounded range." + + CWE-125: + title: "Out-of-Bounds Read" + analysis_prompt: > + Reads past the end of buffers derived from tainted length / offset + fields. Also copy-to-user / write-to-socket of partially-initialized + structs (info leak — check whether source was zeroed before partial + fill). + validation_checks: + - "Verify every read offset / length is bounded by the source buffer's allocated size." + - "For struct-to-socket writes, confirm the struct is fully initialized or explicitly zeroed." + false_positive_indicators: + - "Reads are bounded by prior length validation in the same function." + - "Struct uses kzalloc / memset / explicit field init before write." + + CWE-190: + title: "Integer Overflow / Underflow" + analysis_prompt: > + Arithmetic on tainted values (`len + hdr`, `count * elem_size`, + `remaining - consumed`) computed before the result is validated or + passed to malloc / memcpy / array indexing. Signed → unsigned + promotion flipping a negative value to huge size_t. Casts that + silently truncate (`x as u16`, `(uint32_t)len`). + validation_checks: + - "Identify every arithmetic site whose inputs are tainted." + - "Verify overflow-safe wrappers (checked_add, __builtin_*_overflow, reallocarray) are used OR the inputs are bounded to values that provably cannot overflow." 
+ - "For underflows: verify `remaining = total - consumed` is guarded by `total >= consumed`." + false_positive_indicators: + - "Overflow-safe helpers are used with checked error handling." + - "Input is bounded to a small constant that structurally cannot overflow." + + CWE-476: + title: "NULL Pointer Dereference" + analysis_prompt: > + Fallible operations (allocator, lookup, parse) whose return value is + dereferenced without a NULL check. Where failure is signalled some + other way, check for ERR_PTR / Option::None handling too. + validation_checks: + - "Identify every allocator / lookup / parse call (or wrapper around one) that can return NULL or ERR_PTR." + - "Verify each is checked before the next field access or call through it." + false_positive_indicators: + - "Function is documented / enforced to not return NULL on the reached path." + - "Caller goes through a NULL-tolerant wrapper or uses an infallible allocator (__GFP_NOFAIL / abort-on-null helper)." + + CWE-843: + title: "Type Confusion / Tagged Union Misuse" + analysis_prompt: > + Union or variant types accessed without checking the discriminator. + Classic pattern: `av->value.str_val->length` where `av->type` could + be NUMBER / BOOL / STR and the code always reads `str_val`. + validation_checks: + - "For every union member access, verify the discriminator field is checked first." + - "If the discriminator is set by tainted input, verify the access path is guarded per-tag." + false_positive_indicators: + - "Discriminator is immediately checked via switch/if before the member access." + - "Type is statically guaranteed by the call context (internal helper called with a known type)." + + CWE-252: + title: "Unchecked Return Value" + analysis_prompt: > + Fallible function whose failure is ignored, leading to downstream use + of invalid state. Most commonly paired with CWE-476 (deref the NULL + that came from an ignored failure) but also applies to partially- + successful operations.
+ validation_checks: + - "Identify fallible return-value sites not checked by the caller." + - "Trace the effect: does the ignored failure lead to a concrete bug (NULL deref, uninitialized use, resource leak with attacker-reachable trigger)?" + false_positive_indicators: + - "The return is intentionally ignored and the next use is NULL/failure-safe." + + CWE-416: + title: "Use After Free" + analysis_prompt: > + Pointer dereferenced after the owning object was freed. Common + patterns: error paths that free then continue using, callbacks / + timers / RCU readers holding pointers that other paths free, C++ + iterator invalidation, moved-from object use in Rust FFI. + validation_checks: + - "For every free/drop, identify any remaining pointer copies in scope and verify they are not used after the free." + - "For async callbacks holding a pointer, verify the callback is synchronized with the free (cancel_work_sync / del_timer_sync / synchronize_rcu)." + false_positive_indicators: + - "Pointer is NULL'd after free and subsequent paths are NULL-safe." + - "Free happens in a destructor reached only when refcount = 0." + + CWE-20: + title: "Improper Input Validation" + analysis_prompt: > + Catch-all for tainted values reaching sensitive operations without + validation when no more-specific CWE fits. Use sparingly — prefer the + specific CWEs above when they apply. + validation_checks: + - "Verify user-supplied values are range-checked / type-checked / structurally constrained before use." + false_positive_indicators: + - "Value is constrained by a prior typed parse (enum, int with range check)." diff --git a/prompts/nano-analyzer/feature_detection.yaml b/prompts/nano-analyzer/feature_detection.yaml new file mode 100644 index 0000000..1d340f5 --- /dev/null +++ b/prompts/nano-analyzer/feature_detection.yaml @@ -0,0 +1,83 @@ +# nano-analyzer's CONTEXT_GEN_PROMPT adapted to codecrucible's feature- +# detection slot. 
Instead of categorical feature buckets, it asks the model +# to identify the ATTACK SURFACE — entry points, tainted variables, fixed- +# size buffers, tagged unions, public vs static APIs — that the scanner +# should prioritize. + +system_message: > + You are preparing a security briefing for a vulnerability researcher. + Identify which attack-surface categories exist in this code so the + scanner can focus its analysis. Return findings in strict JSON format. + +user_prompt_template: > + Analyze the following code to identify the security-relevant attack + surface. Name actual variables and constants. Do not report vulnerabilities + — only describe the attack surface. + + + Repository: {repo_name} + + + Repomix XML Content: + --- + {xml_content} + --- + + + ATTACK-SURFACE CATEGORIES (nano-analyzer style — bug-class oriented, not + framework oriented): + + 1. **untrusted_input_entry** - Functions that receive data from outside + the trust boundary: network parsers, file readers, IPC handlers, CLI + arg parsers, syscall arguments, ioctl handlers. Name the entry + functions and the parameters that carry tainted data. + 2. **fixed_size_buffers** - Stack arrays and structs with fixed-size + char/byte buffers. Name them with their declared sizes resolved to + numeric values where named constants are involved + (e.g. `header[EVP_MAX_MD_SIZE=64]`). + 3. **size_arithmetic** - Code computing sizes / offsets / counts from + attacker-influenced inputs before passing them to memcpy / malloc / + array indexing. Locate the arithmetic sites. + 4. **tagged_unions_or_variants** - Union types or variant structs where a + type discriminator field selects which member is valid. Note whether + members are accessed with or without checking the tag. + 5. **fallible_returns** - Functions returning NULL / -1 / ERR_PTR on + failure whose callers may not check the return. Allocators, lookups, + parse helpers, syscalls. + 6.
**public_vs_static_api** - Which functions are externally reachable + (public API, syscall handlers, router-exposed) vs internal helpers. + For static helpers, are all call sites safe? + 7. **refcounted_or_lifecycle** - Objects with explicit acquire/release, + init/cleanup pairs, or asynchronous callbacks holding pointers across + free boundaries. Note lifecycle-sensitive fields. + 8. **concurrency_state** - Fields shared across threads / interrupt + contexts / event handlers without obvious synchronization. + 9. **dangerous_sinks** - memcpy / strcpy / sprintf / exec* / eval / + system / raw SQL / template rendering / deserialization / dlopen + that take tainted or potentially-tainted data. + 10. **external_dependencies** - Third-party library calls whose failure + modes or input validation semantics affect the surface. + + + INSTRUCTIONS: + - Only include categories with CONCRETE evidence in the code. + - Name actual variables / structs / constants from the source. + - When a size is defined by a named constant, resolve it to the + numeric value if visible (use provided code, don't invent). + - Do not find vulnerabilities — just surface. + + + Return ONLY a valid JSON object according to the following schema: + + + JSON Schema: + --- + {schema} + --- + + + IMPORTANT: + - Return ONLY the raw JSON object + - Do NOT wrap in markdown code blocks + - Do NOT include explanatory text + - The entire response must be valid, parseable JSON diff --git a/prompts/nano-analyzer/security_analysis_base.yaml b/prompts/nano-analyzer/security_analysis_base.yaml new file mode 100644 index 0000000..c353e3a --- /dev/null +++ b/prompts/nano-analyzer/security_analysis_base.yaml @@ -0,0 +1,85 @@ +# nano-analyzer style prompt set. +# Adapted from weareaisle/nano-analyzer (github.com/weareaisle/nano-analyzer) +# to codecrucible's 5-file prompt schema. 
Preserves the terse, attacker-first +# voice: trace untrusted input, ask 5 questions per function, deprioritize +# obvious non-issues, show your work. + +system_message: > + You are a security researcher hunting for zero-day vulnerabilities. Analyze + the code step by step, tracing how untrusted data flows into each function. + For every function, ask yourself: + + 1. Can any parameter be NULL, too large, negative, or otherwise invalid when + this function is called with malformed input? + 2. Are there copies into fixed-size buffers without size validation? + 3. Can integer arithmetic overflow, wrap, or produce negative values that + are then used as sizes or indices? + 4. Are tagged unions / variant types accessed without verifying the type + discriminator first? + 5. Are return values from fallible operations checked before use? + + Focus on bugs that an external attacker can trigger through untrusted input. + Deprioritize static helpers with safe call sites, allocation wrappers, + platform-specific dead code, and theoretical issues. + + Reason first. Then output findings in strict JSON per the schema. + +analysis_intro: > + Find zero-day vulnerabilities in the code below. Report each bug ONCE at + its most specific location. Reasoning goes in your scratch thinking; the + JSON output is the contract. + +analysis_requirements_header: "" + +repo_info: > + Repository: {repo_name} + + {xml_content} + +critical_instructions: > + For every finding: + - file_path: the relative path of the file containing the bug + (e.g. net/parser.c). + - start_line / end_line: integer line numbers from the XML content, + pointing at the EXACT LINE performing the vulnerable operation — not + the function signature above it. + - severity: 0.0-10.0. Map from nano severities: + critical=9-10, high=7-8.9, medium=4-6.9, low=1-3.9, + informational=0.1-0.9. Calibrate by real-world attacker reach, not + pattern severity in the abstract. + - cwe_id: most specific CWE as 'CWE-NNN: Name'. 
Typical nano-hunt CWEs: + CWE-120/787 (buffer overflow / OOB write), CWE-125 (OOB read), + CWE-190 (integer overflow), CWE-476 (NULL deref), CWE-843 (type + confusion), CWE-252 (unchecked return), CWE-416 (UAF), CWE-20 + (generic input validation). + - technical_details: name the function, the untrusted input, the exact + data flow from entry point to sink, and the concrete consequence. For + numeric bounds, include resolved values — "buf[EVP_MAX_MD_SIZE] + where EVP_MAX_MD_SIZE=64" — not just the constant name. + + SKEPTICISM RULES (apply before reporting): + - A defense only counts if you can name the specific function/line that + implements it AND the numeric values make it sufficient. "There exists + a bound" is NOT the same as "the bound is sufficient." Do the + arithmetic. + - A constant name is not a verified bound — only its resolved numeric + value is. + - Vague references to "assumptions in this codebase" or "other code + probably handles this" are not valid defenses. If you cannot point to + it in the provided code, it does not exist. + - If your analysis reaches a conclusion, do not contradict yourself in + the same response. Verify once, decide. + + DEPRIORITIZE (do not report unless you can articulate a concrete attack): + - Static helpers whose call sites are all safe; + - Allocation wrappers that just forward arguments; + - Platform-specific dead code (#ifdef'd out on the build path); + - Theoretical UB that compilers have defined in practice; + - Defense-in-depth gaps with no concrete exploit path. + +json_formatting_rules: > + Return ONLY a valid JSON object conforming to the schema. No markdown + fences, no commentary. Escape backslashes in strings (\ becomes \\). 
+ + JSON Schema: + {schema} diff --git a/testdata/fixtures/llm-responses/malformed_empty_arrays.json b/testdata/fixtures/llm-responses/malformed_empty_arrays.json new file mode 100644 index 0000000..af59412 --- /dev/null +++ b/testdata/fixtures/llm-responses/malformed_empty_arrays.json @@ -0,0 +1,8 @@ +{ + "repo_name": "example/nulls", + "description": "Response with null arrays", + "public_api_routes": null, + "security_issues": null, + "security_risk": 0.0, + "risk_justification": "No issues found." +} diff --git a/testdata/fixtures/llm-responses/malformed_extra_fields.json b/testdata/fixtures/llm-responses/malformed_extra_fields.json new file mode 100644 index 0000000..351a973 --- /dev/null +++ b/testdata/fixtures/llm-responses/malformed_extra_fields.json @@ -0,0 +1,23 @@ +{ + "repo_name": "example/extra-fields", + "description": "Response with extra unknown fields", + "public_api_routes": [], + "security_issues": [ + { + "issue": "Path Traversal", + "file_path": "internal/fileserver/serve.go", + "start_line": 20, + "end_line": 28, + "technical_details": "User-supplied filename is used in os.Open without sanitization.", + "severity": 7.0, + "cwe_id": "CWE-22", + "confidence": "high", + "model_name": "gpt-4o" + } + ], + "security_risk": 7.0, + "risk_justification": "Path traversal allows reading arbitrary files from the server.", + "confidence": 0.95, + "model_name": "gpt-4o", + "analysis_duration_ms": 1234 +} diff --git a/testdata/fixtures/llm-responses/malformed_markdown_wrapped.json b/testdata/fixtures/llm-responses/malformed_markdown_wrapped.json new file mode 100644 index 0000000..edcc409 --- /dev/null +++ b/testdata/fixtures/llm-responses/malformed_markdown_wrapped.json @@ -0,0 +1,20 @@ +```json +{ + "repo_name": "example/markdown-wrapped", + "description": "Response wrapped in markdown code fences", + "public_api_routes": [], + "security_issues": [ + { + "issue": "Server-Side Request Forgery", + "file_path": "internal/proxy/handler.go", + "start_line": 55, 
+ "end_line": 62, + "technical_details": "User-controlled URL is fetched without validation, enabling SSRF attacks.", + "severity": 8.0, + "cwe_id": "CWE-918" + } + ], + "security_risk": 8.0, + "risk_justification": "SSRF vulnerability allows access to internal services." +} +``` diff --git a/testdata/fixtures/llm-responses/malformed_missing_fields.json b/testdata/fixtures/llm-responses/malformed_missing_fields.json new file mode 100644 index 0000000..c527891 --- /dev/null +++ b/testdata/fixtures/llm-responses/malformed_missing_fields.json @@ -0,0 +1,16 @@ +{ + "repo_name": "example/incomplete", + "description": "Response missing security_risk and risk_justification", + "public_api_routes": [], + "security_issues": [ + { + "issue": "Hardcoded Credentials", + "file_path": "config/database.go", + "start_line": 10, + "end_line": 12, + "technical_details": "Database password is hardcoded in source file.", + "severity": 8.0, + "cwe_id": "CWE-798" + } + ] +} diff --git a/testdata/fixtures/llm-responses/malformed_null_values.json b/testdata/fixtures/llm-responses/malformed_null_values.json new file mode 100644 index 0000000..d78419b --- /dev/null +++ b/testdata/fixtures/llm-responses/malformed_null_values.json @@ -0,0 +1,18 @@ +{ + "repo_name": "example/null-values", + "description": "Response with null text fields", + "public_api_routes": [], + "security_issues": [ + { + "issue": null, + "file_path": null, + "start_line": 1, + "end_line": 5, + "technical_details": "Some details about the finding.", + "severity": 6.0, + "cwe_id": "CWE-200" + } + ], + "security_risk": 6.0, + "risk_justification": "Moderate risk due to information disclosure." 
+} diff --git a/testdata/fixtures/llm-responses/malformed_wrong_types.json b/testdata/fixtures/llm-responses/malformed_wrong_types.json new file mode 100644 index 0000000..bf0db60 --- /dev/null +++ b/testdata/fixtures/llm-responses/malformed_wrong_types.json @@ -0,0 +1,18 @@ +{ + "repo_name": "example/wrong-types", + "description": "Response with wrong field types", + "public_api_routes": [], + "security_issues": [ + { + "issue": "Buffer Overflow", + "file_path": "src/parser.c", + "start_line": "42", + "end_line": 50, + "technical_details": "Unbounded memcpy from user input.", + "severity": "critical", + "cwe_id": "CWE-120" + } + ], + "security_risk": 9.0, + "risk_justification": "Critical memory corruption vulnerability." +} diff --git a/testdata/fixtures/llm-responses/valid_claude_single_finding.json b/testdata/fixtures/llm-responses/valid_claude_single_finding.json new file mode 100644 index 0000000..7dd8e56 --- /dev/null +++ b/testdata/fixtures/llm-responses/valid_claude_single_finding.json @@ -0,0 +1,18 @@ +{ + "repo_name": "example/webapp", + "description": "A Go web application with SQL database backend", + "public_api_routes": [], + "security_issues": [ + { + "issue": "SQL Injection", + "file_path": "internal/db/queries.go", + "start_line": 42, + "end_line": 48, + "technical_details": "User input is concatenated directly into SQL query string without parameterization, allowing arbitrary SQL execution.", + "severity": 8.5, + "cwe_id": "CWE-89: Improper Neutralization of Special Elements used in an SQL Command" + } + ], + "security_risk": 8.5, + "risk_justification": "Critical SQL injection vulnerability allows data exfiltration and potential remote code execution via database." 
+} diff --git a/testdata/fixtures/llm-responses/valid_gemini_zero_findings.json b/testdata/fixtures/llm-responses/valid_gemini_zero_findings.json new file mode 100644 index 0000000..e193261 --- /dev/null +++ b/testdata/fixtures/llm-responses/valid_gemini_zero_findings.json @@ -0,0 +1,8 @@ +{ + "repo_name": "utils/stringlib", + "description": "A small utility library for string manipulation with no external dependencies", + "public_api_routes": [], + "security_issues": [], + "security_risk": 0.5, + "risk_justification": "No security issues identified. The library performs only pure string transformations with no I/O or network access." +} diff --git a/testdata/fixtures/llm-responses/valid_gpt_multiple_findings.json b/testdata/fixtures/llm-responses/valid_gpt_multiple_findings.json new file mode 100644 index 0000000..ca3c2ac --- /dev/null +++ b/testdata/fixtures/llm-responses/valid_gpt_multiple_findings.json @@ -0,0 +1,36 @@ +{ + "repo_name": "frontend/dashboard", + "description": "React dashboard application with authentication", + "public_api_routes": [], + "security_issues": [ + { + "issue": "Cross-Site Scripting (XSS)", + "file_path": "src/components/UserProfile.jsx", + "start_line": 115, + "end_line": 120, + "technical_details": "User-supplied HTML is rendered via dangerouslySetInnerHTML without sanitization.", + "severity": 7.0, + "cwe_id": "CWE-79: Improper Neutralization of Input During Web Page Generation" + }, + { + "issue": "Use of Weak Cryptographic Algorithm", + "file_path": "src/utils/crypto.js", + "start_line": 22, + "end_line": 30, + "technical_details": "MD5 is used for password hashing instead of a modern algorithm like bcrypt or argon2.", + "severity": 5.5, + "cwe_id": "CWE-327: Use of a Broken or Risky Cryptographic Algorithm" + }, + { + "issue": "Missing Authentication for Critical Function", + "file_path": "src/api/admin.js", + "start_line": 8, + "end_line": 15, + "technical_details": "Admin endpoint /api/admin/users lacks authentication 
middleware, allowing unauthenticated access to user management.", + "severity": 9.0, + "cwe_id": "CWE-306: Missing Authentication for Critical Function" + } + ], + "security_risk": 9.0, + "risk_justification": "Multiple vulnerabilities including a critical missing authentication issue on admin endpoints." +} diff --git a/testdata/fixtures/llm-responses/valid_with_api_routes.json b/testdata/fixtures/llm-responses/valid_with_api_routes.json new file mode 100644 index 0000000..e3eb4fc --- /dev/null +++ b/testdata/fixtures/llm-responses/valid_with_api_routes.json @@ -0,0 +1,40 @@ +{ + "repo_name": "backend/payment-service", + "description": "Payment processing microservice with REST API", + "public_api_routes": [ + { + "route": "POST /api/v1/payments", + "citation": "cmd/server/routes.go:35" + }, + { + "route": "GET /api/v1/payments/{id}", + "citation": "cmd/server/routes.go:42" + }, + { + "route": "POST /api/v1/refunds", + "citation": "cmd/server/routes.go:50" + } + ], + "security_issues": [ + { + "issue": "Insecure Direct Object Reference", + "file_path": "internal/handlers/payment.go", + "start_line": 67, + "end_line": 75, + "technical_details": "Payment lookup uses user-supplied ID without verifying ownership, allowing access to other users' payment records.", + "severity": 7.5, + "cwe_id": "CWE-639: Authorization Bypass Through User-Controlled Key" + }, + { + "issue": "Insufficient Logging of Security Events", + "file_path": "internal/handlers/refund.go", + "start_line": 30, + "end_line": 35, + "technical_details": "Refund operations are not logged, making it impossible to audit or detect fraudulent refund activity.", + "severity": 4.0, + "cwe_id": "CWE-778: Insufficient Logging" + } + ], + "security_risk": 7.5, + "risk_justification": "IDOR vulnerability on payment endpoint poses significant risk to data confidentiality." +}