Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions aiscan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,17 @@ llm:
model: ""
# Proxy for LLM API requests
proxy: ""
# Custom HTTP headers for LLM API requests
# headers:
# User-Agent: "Version: 5.10.0 openwarp"
# Additional LLM providers for fallback or multi-model routing
# providers:
# - provider: ""
# base_url: ""
# api_key: ""
# model: ""
# proxy: ""
# headers: {}
# timeout: 0
# images: ""

Expand Down
44 changes: 36 additions & 8 deletions cmd/aiscan/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,12 @@ func parseCLI(args []string) (parsedCLI, error) {

option := cli.Option
if cli.Version {
return parsedCLI{Option: option, Mode: cfg.RunModeNoCommand}, nil
return finalizeParsedCLI(parsedCLI{Option: option, Mode: cfg.RunModeNoCommand})
}

mode := selectedMode(parser)
if mode == cfg.RunModeNoCommand {
return parsedCLI{Option: option, Mode: cfg.RunModeNoCommand}, nil
return finalizeParsedCLI(parsedCLI{Option: option, Mode: cfg.RunModeNoCommand})
}

if mode == cfg.RunModeScanner {
Expand All @@ -217,15 +217,15 @@ func parseCLI(args []string) (parsedCLI, error) {
return parsedCLI{}, err
}
scannerArgs := append([]string{scannerName}, scannerRest...)
return parsedCLI{Option: option, Mode: mode, ScannerArgs: scannerArgs}, nil
return finalizeParsedCLI(parsedCLI{Option: option, Mode: mode, ScannerArgs: scannerArgs})
}

if mode == runModeWeb {
return parsedCLI{Option: option, Mode: runModeWeb, WebOpts: cli.Web}, nil
return finalizeParsedCLI(parsedCLI{Option: option, Mode: runModeWeb, WebOpts: cli.Web})
}

ioaArgs := extractIOAArgs(&cli, mode)
return parsedCLI{Option: option, Mode: mode, IOAArgs: ioaArgs}, nil
return finalizeParsedCLI(parsedCLI{Option: option, Mode: mode, IOAArgs: ioaArgs})
}

func parseScannerCLI(scannerName string, rootArgs, scannerRest []string) (parsedCLI, error) {
Expand All @@ -250,7 +250,7 @@ func parseScannerCLI(scannerName string, rootArgs, scannerRest []string) (parsed
option := cli.Option
mergeManualScannerOptions(&option, manual)
if cli.Version {
return parsedCLI{Option: option, Mode: cfg.RunModeNoCommand}, nil
return finalizeParsedCLI(parsedCLI{Option: option, Mode: cfg.RunModeNoCommand})
}
option.Timeout = 3600

Expand All @@ -264,11 +264,35 @@ func parseScannerCLI(scannerName string, rootArgs, scannerRest []string) (parsed
if boolFlagEnabled(scannerArgs, "--debug") {
option.Debug = true
}
return parsedCLI{
return finalizeParsedCLI(parsedCLI{
Option: option,
Mode: cfg.RunModeScanner,
ScannerArgs: append([]string{scannerName}, scannerArgs...),
}, nil
})
}

func finalizeParsedCLI(parsed parsedCLI) (parsedCLI, error) {
headers, err := cfg.ParseHeaderFlags(parsed.Option.LLMHeaderFlags)
if err != nil {
return parsedCLI{}, err
}
parsed.Option.Headers = cfgMergeHeaders(parsed.Option.Headers, headers)
parsed.Option.LLMHeaderFlags = nil
return parsed, nil
}

func cfgMergeHeaders(base, override map[string]string) map[string]string {
if len(base) == 0 && len(override) == 0 {
return nil
}
out := make(map[string]string, len(base)+len(override))
for key, value := range base {
out[key] = value
}
for key, value := range override {
out[key] = value
}
return out
}

func mergeManualScannerOptions(option *cfg.Option, manual cfg.Option) {
Expand All @@ -277,6 +301,9 @@ func mergeManualScannerOptions(option *cfg.Option, manual cfg.Option) {
option.APIKey = cfg.ResolveString(manual.APIKey, option.APIKey)
option.Model = cfg.ResolveString(manual.Model, option.Model)
option.LLMProxy = cfg.ResolveString(manual.LLMProxy, option.LLMProxy)
if len(manual.LLMHeaderFlags) > 0 {
option.LLMHeaderFlags = append(option.LLMHeaderFlags, manual.LLMHeaderFlags...)
}
if manual.AI {
option.AI = true
}
Expand Down Expand Up @@ -413,6 +440,7 @@ var scannerKnownFlags = []knownFlag{
{names: []string{"--model"}, arity: 1, apply: func(o *cfg.Option, v string) { o.Model = v }},
{names: []string{"--proxy"}, arity: 1, apply: func(o *cfg.Option, v string) { o.Proxy = v }},
{names: []string{"--llm-proxy"}, arity: 1, apply: func(o *cfg.Option, v string) { o.LLMProxy = v }},
{names: []string{"--llm-header"}, arity: 1, apply: func(o *cfg.Option, v string) { o.LLMHeaderFlags = append(o.LLMHeaderFlags, v) }},
{names: []string{"--fofa-email"}, arity: 1, apply: func(o *cfg.Option, v string) { o.FofaEmail = v }},
{names: []string{"--fofa-key"}, arity: 1, apply: func(o *cfg.Option, v string) { o.FofaKey = v }},
{names: []string{"--hunter-token"}, arity: 1, apply: func(o *cfg.Option, v string) { o.HunterToken = v }},
Expand Down
42 changes: 42 additions & 0 deletions cmd/aiscan/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,48 @@ func TestParseCLIAgentAcceptsLLMFlags(t *testing.T) {
}
}

func TestParseCLIAcceptsLLMHeaderFlags(t *testing.T) {
parsed, err := parseCLI([]string{
"agent",
"--llm-header", "User-Agent=Version: 5.10.0 openwarp",
"--llm-header", "X-Test=yes",
})
if err != nil {
t.Fatalf("parseCLI() error = %v", err)
}
if got := parsed.Option.Headers["User-Agent"]; got != "Version: 5.10.0 openwarp" {
t.Fatalf("User-Agent header = %q", got)
}
if got := parsed.Option.Headers["X-Test"]; got != "yes" {
t.Fatalf("X-Test header = %q", got)
}
}

func TestParseCLIScanExtractsLLMHeaderFlag(t *testing.T) {
parsed, err := parseCLI([]string{
"scan",
"-i", "127.0.0.1",
"--llm-header", "User-Agent=Version: 5.10.0 openwarp",
"--verify=high",
})
if err != nil {
t.Fatalf("parseCLI() error = %v", err)
}
wantArgs := []string{"scan", "-i", "127.0.0.1", "--verify=high"}
if !reflect.DeepEqual(parsed.ScannerArgs, wantArgs) {
t.Fatalf("scanner args = %#v, want %#v", parsed.ScannerArgs, wantArgs)
}
if got := parsed.Option.Headers["User-Agent"]; got != "Version: 5.10.0 openwarp" {
t.Fatalf("User-Agent header = %q", got)
}
}

func TestParseCLIRejectsInvalidLLMHeaderFlag(t *testing.T) {
if _, err := parseCLI([]string{"agent", "--llm-header", "User Agent=value"}); err == nil {
t.Fatal("parseCLI() error = nil, want invalid header error")
}
}

func TestParseCLIScanExtractsLLMFlags(t *testing.T) {
parsed, err := parseCLI([]string{
"scan",
Expand Down
22 changes: 17 additions & 5 deletions cmd/aiscan/web_full.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,12 @@ func initWebApp(ctx context.Context, configFile string, logger telemetry.Logger)

type webYAMLConfig struct {
LLM struct {
Provider string `yaml:"provider"`
BaseURL string `yaml:"base_url"`
APIKey string `yaml:"api_key"`
Model string `yaml:"model"`
Proxy string `yaml:"proxy"`
Provider string `yaml:"provider"`
BaseURL string `yaml:"base_url"`
APIKey string `yaml:"api_key"`
Model string `yaml:"model"`
Proxy string `yaml:"proxy"`
Headers map[string]string `yaml:"headers,omitempty"`
} `yaml:"llm"`
Cyberhub struct {
URL string `yaml:"url"`
Expand Down Expand Up @@ -198,6 +199,7 @@ func (s *webConfigStore) GetLLMConfig(ctx context.Context) (web.LLMConfig, error
APIKeyConfigured: strings.TrimSpace(c.LLM.APIKey) != "",
Model: c.LLM.Model,
Proxy: c.LLM.Proxy,
Headers: c.LLM.Headers,
}, nil
}

Expand Down Expand Up @@ -226,12 +228,21 @@ func (s *webConfigStore) SaveLLMConfig(ctx context.Context, llmCfg web.LLMConfig
if apiKey == "" {
apiKey = current.LLM.APIKey
}
headers := current.LLM.Headers
if llmCfg.Headers != nil {
var err error
headers, err = cfg.NormalizeHeaderMap(llmCfg.Headers)
if err != nil {
return web.LLMConfig{}, err
}
}

current.LLM.Provider = strings.TrimSpace(llmCfg.Provider)
current.LLM.BaseURL = strings.TrimSpace(llmCfg.BaseURL)
current.LLM.APIKey = apiKey
current.LLM.Model = strings.TrimSpace(llmCfg.Model)
current.LLM.Proxy = strings.TrimSpace(llmCfg.Proxy)
current.LLM.Headers = headers
next, _ := yaml.Marshal(&current)
if dir := filepath.Dir(p); dir != "." && dir != "" {
if err := os.MkdirAll(dir, 0755); err != nil {
Expand All @@ -250,6 +261,7 @@ func (s *webConfigStore) SaveLLMConfig(ctx context.Context, llmCfg web.LLMConfig
APIKeyConfigured: strings.TrimSpace(saved.LLM.APIKey) != "",
Model: saved.LLM.Model,
Proxy: saved.LLM.Proxy,
Headers: saved.LLM.Headers,
}, nil
}

Expand Down
14 changes: 14 additions & 0 deletions core/config/config_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ func generateFromStruct(t reflect.Type, v reflect.Value, indent int) string {
case fieldType.Kind() == reflect.Slice:
b.WriteString(generateSliceComment(prefix, configTag, descTag, fieldType))

case fieldType.Kind() == reflect.Map:
b.WriteString(generateMapComment(prefix, configTag, descTag))

default:
if descTag != "" {
b.WriteString(fmt.Sprintf("%s# %s\n", prefix, descTag))
Expand Down Expand Up @@ -146,6 +149,15 @@ func generateSliceComment(prefix, configTag, descTag string, t reflect.Type) str
return b.String()
}

func generateMapComment(prefix, configTag, descTag string) string {
var b strings.Builder
if descTag != "" {
b.WriteString(fmt.Sprintf("%s# %s\n", prefix, descTag))
}
b.WriteString(fmt.Sprintf("%s# %s: {}\n", prefix, configTag))
return b.String()
}

func formatValue(kind reflect.Kind, defaultVal string) string {
if defaultVal != "" {
switch kind {
Expand All @@ -164,6 +176,8 @@ func formatValue(kind reflect.Kind, defaultVal string) string {
return "0.0"
case reflect.String:
return `""`
case reflect.Map:
return `{}`
default:
return `""`
}
Expand Down
83 changes: 83 additions & 0 deletions core/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,64 @@ ioa:
}
}

func TestLoadConfigLLMHeaders(t *testing.T) {
dir := t.TempDir()
writeTestConfig(t, dir, `
llm:
headers:
User-Agent: "Version: 5.10.0 openwarp"
providers:
- provider: deepseek
api_key: dk-111
model: deepseek-chat
headers:
X-Primary: primary
- provider: openai
api_key: sk-222
model: gpt-4o
headers:
X-Fallback: fallback
`)

var opt Option
if err := LoadConfig(filepath.Join(dir, "aiscan.yaml"), &opt); err != nil {
t.Fatal(err)
}
if got := opt.Headers["User-Agent"]; got != "Version: 5.10.0 openwarp" {
t.Fatalf("top-level User-Agent header = %q", got)
}
if len(opt.Providers) != 2 {
t.Fatalf("providers = %d, want 2", len(opt.Providers))
}
if got := opt.Providers[0].Headers["X-Primary"]; got != "primary" {
t.Fatalf("primary header = %q", got)
}
if got := opt.Providers[1].Headers["X-Fallback"]; got != "fallback" {
t.Fatalf("fallback header = %q", got)
}
}

func TestMergeOptionHeadersCLIWins(t *testing.T) {
dst := Option{LLMOptions: LLMOptions{
Headers: map[string]string{"User-Agent": "cli-agent"},
}}
src := Option{LLMOptions: LLMOptions{
Headers: map[string]string{"user-agent": "config-agent", "X-Config": "yes"},
}}

mergeOption(&dst, &src)

if got := dst.Headers["User-Agent"]; got != "cli-agent" {
t.Fatalf("User-Agent = %q, want CLI value", got)
}
if got := dst.Headers["X-Config"]; got != "yes" {
t.Fatalf("X-Config = %q, want config value", got)
}
if _, ok := dst.Headers["user-agent"]; ok {
t.Fatalf("config user-agent key should be replaced by CLI override: %#v", dst.Headers)
}
}

func TestLoadConfigReconNumericZeroIsExplicit(t *testing.T) {
dir := t.TempDir()
writeTestConfig(t, dir, `
Expand Down Expand Up @@ -620,6 +678,31 @@ func TestProvidersListOnly(t *testing.T) {
}
}

func TestProvidersListPrimaryUsesTopLevelHeaders(t *testing.T) {
option := Option{}
option.Headers = map[string]string{"User-Agent": "top-agent"}
option.Providers = []LLMProviderEntry{
{Provider: "deepseek", APIKey: "key1", Model: "deepseek-chat", Headers: map[string]string{"X-Primary": "yes"}},
{Provider: "openai", APIKey: "key2", Model: "gpt-4o", Headers: map[string]string{"X-Fallback": "yes"}},
}

primary := ProviderConfig(&option)
if primary.Headers["User-Agent"] != "top-agent" || primary.Headers["X-Primary"] != "yes" {
t.Fatalf("primary headers = %#v, want top-level and primary entry headers", primary.Headers)
}

fallbacks := FallbackProviderConfigs(&option)
if len(fallbacks) != 1 {
t.Fatalf("fallbacks = %d, want 1", len(fallbacks))
}
if _, ok := fallbacks[0].Headers["User-Agent"]; ok {
t.Fatalf("fallback should not inherit top-level headers: %#v", fallbacks[0].Headers)
}
if got := fallbacks[0].Headers["X-Fallback"]; got != "yes" {
t.Fatalf("fallback X-Fallback = %q", got)
}
}

func TestProvidersListWithSingleFields(t *testing.T) {
option := Option{}
option.Provider = "anthropic"
Expand Down
Loading