From bc682724198464a63efe95761c3ca67a6bfceb16 Mon Sep 17 00:00:00 2001 From: mmoulikk Date: Thu, 28 May 2026 14:48:40 +0530 Subject: [PATCH] fix: forward port and sublocation for launchable --- pkg/cmd/gpucreate/gpucreate.go | 103 +++++++++++- pkg/cmd/gpucreate/gpucreate_test.go | 249 +++++++++++++++++++++++++++- pkg/store/workspace.go | 22 ++- 3 files changed, 367 insertions(+), 7 deletions(-) diff --git a/pkg/cmd/gpucreate/gpucreate.go b/pkg/cmd/gpucreate/gpucreate.go index 9950fe86..ebc2bd41 100644 --- a/pkg/cmd/gpucreate/gpucreate.go +++ b/pkg/cmd/gpucreate/gpucreate.go @@ -6,6 +6,8 @@ import ( "fmt" "io" "math/rand/v2" + "net" + "net/http" "net/url" "os" "strconv" @@ -1116,10 +1118,13 @@ func applyLaunchableConfig(cwOptions *store.CreateWorkspacesOptions, launchableI cwOptions.WorkspaceGroupID = wsReq.WorkspaceGroupID } - // Location + // Location / sub-location if wsReq.Location != "" { cwOptions.Location = wsReq.Location } + if wsReq.SubLocation != "" { + cwOptions.SubLocation = wsReq.SubLocation + } // Disk storage — the API may return a bare number (e.g., "256") or with // a unit suffix (e.g., "256Gi"). The server's ParseDiskStorage expects a @@ -1150,6 +1155,10 @@ func applyLaunchableConfig(cwOptions *store.CreateWorkspacesOptions, launchableI cwOptions.PortMappings = portMappings } + if len(wsReq.FirewallRules) > 0 { + cwOptions.FirewallRules = resolveFirewallRulesClientIP(wsReq.FirewallRules, publicIPLookup) + } + // Files from launchable if info.File != nil { cwOptions.Files = []map[string]string{ @@ -1173,6 +1182,98 @@ func applyLaunchableConfig(cwOptions *store.CreateWorkspacesOptions, launchableI cwOptions.Labels = labels } +// resolveFirewallRulesClientIP fills ClientIPs on any "user-ip" rule that +// doesn't already have one, calling lookupIP at most once. Rules are left +// unchanged on lookup failure or unparseable IPs. +func resolveFirewallRulesClientIP(rules []store.CreateFirewallRule, lookupIP func() (string, error)) []store.CreateFirewallRule { + out := make([]store.CreateFirewallRule, len(rules)) + copy(out, rules) + + var ( + ip string + ipErr error + looked bool + ) + for i := range out { + if out[i].AllowedIPs != "user-ip" || len(out[i].ClientIPs) > 0 { + continue + } + if !looked { + ip, ipErr = lookupIP() + looked = true + } + if ipErr != nil || ip == "" { + continue + } + cidr := toHostCIDR(ip) + if cidr == "" { + continue + } + out[i].ClientIPs = []string{cidr} + } + return out +} + +// toHostCIDR returns the single-host CIDR for an IP literal: /32 for IPv4, +// /128 for IPv6. Returns "" if raw isn't a valid IP. +func toHostCIDR(raw string) string { + parsed := net.ParseIP(strings.TrimSpace(raw)) + if parsed == nil { + return "" + } + if v4 := parsed.To4(); v4 != nil { + return v4.String() + "/32" + } + return parsed.String() + "/128" +} + +// publicIPLookup is a var so tests can stub it. +var publicIPLookup = resolvePublicIP + +// publicIPEndpoints are tried in order until one returns a valid IP. +// All return the IP as a plain-text body. +var publicIPEndpoints = []string{ + "https://api.ipify.org", + "https://checkip.amazonaws.com", + "https://ifconfig.me/ip", +} + +func resolvePublicIP() (string, error) { + client := &http.Client{Timeout: 3 * time.Second} + var lastErr error + for _, url := range publicIPEndpoints { + ip, err := fetchPublicIP(client, url) + if err == nil { + return ip, nil + } + lastErr = err + } + if lastErr == nil { + lastErr = fmt.Errorf("no public IP endpoints configured") + } + return "", lastErr +} + +func fetchPublicIP(client *http.Client, url string) (string, error) { + resp, err := client.Get(url) + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + defer resp.Body.Close() //nolint:errcheck // best-effort + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("%s returned status %d", url, resp.StatusCode) + } + body, err := io.ReadAll(io.LimitReader(resp.Body, 64)) + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + ipStr := strings.TrimSpace(string(body)) + if net.ParseIP(ipStr) == nil { + return "", fmt.Errorf("%s returned non-IP response: %q", url, ipStr) + } + return ipStr, nil +} + // normalizeDiskStorage ensures a disk storage value has a Kubernetes quantity suffix. // If the value is purely numeric (e.g., "256"), appends "Gi". Otherwise passes through // as-is, trusting the server's ParseDiskStorage to handle formats like "256Gi", "100G", etc. diff --git a/pkg/cmd/gpucreate/gpucreate_test.go b/pkg/cmd/gpucreate/gpucreate_test.go index be791982..ccffcbcc 100644 --- a/pkg/cmd/gpucreate/gpucreate_test.go +++ b/pkg/cmd/gpucreate/gpucreate_test.go @@ -1,6 +1,9 @@ package gpucreate import ( + "encoding/json" + "net/http" + "net/http/httptest" "strings" "testing" "time" @@ -208,6 +211,11 @@ func TestApplyLaunchableConfig(t *testing.T) { //nolint:funlen // test InstanceType: "n2-standard-4", Storage: "256", Location: "us-west1", + SubLocation: "us-west1-b", + FirewallRules: []store.CreateFirewallRule{ + {Port: "8080", AllowedIPs: "all"}, + {Port: "9000-9100", AllowedIPs: "all"}, + }, }, BuildRequest: store.LaunchableBuildRequest{ VMBuild: &store.VMBuild{ @@ -232,8 +240,9 @@ func TestApplyLaunchableConfig(t *testing.T) { //nolint:funlen // test // Workspace group from launchable assert.Equal(t, "GCP", cwOptions.WorkspaceGroupID) - // Location + // Location / sub-location assert.Equal(t, "us-west1", cwOptions.Location) + assert.Equal(t, "us-west1-b", cwOptions.SubLocation) // Storage with Gi suffix assert.Equal(t, "256Gi", cwOptions.DiskStorage) // Build config @@ -243,6 +252,10 @@ func TestApplyLaunchableConfig(t *testing.T) { //nolint:funlen // test assert.Equal(t, "ls-abc", cwOptions.VMBuild.LifeCycleScriptAttr.ID) // Port mappings assert.Equal(t, map[string]string{"Code-Server": "13337", "OpenClaw": "18789"}, cwOptions.PortMappings) + assert.Equal(t, []store.CreateFirewallRule{ + {Port: "8080", AllowedIPs: "all"}, + {Port: "9000-9100", AllowedIPs: "all"}, + }, cwOptions.FirewallRules) // Files assert.NotNil(t, cwOptions.Files) // LaunchableConfig @@ -330,6 +343,96 @@ func TestApplyLaunchableConfig(t *testing.T) { //nolint:funlen // test assert.Equal(t, "nvcr.io/nvidia/test:latest", cwOptions.CustomContainer.ContainerURL) }) + t.Run("substitutes public IP for user-ip firewall rules", func(t *testing.T) { + orig := publicIPLookup + publicIPLookup = func() (string, error) { return "203.0.113.7", nil } + defer func() { publicIPLookup = orig }() + + cwOptions := &store.CreateWorkspacesOptions{} + info := &store.LaunchableResponse{ + CreateWorkspaceRequest: store.LaunchableWorkspaceRequest{ + InstanceType: "n2-standard-4", + FirewallRules: []store.CreateFirewallRule{ + {Port: "22", AllowedIPs: "user-ip"}, + {Port: "443", AllowedIPs: "all"}, + }, + }, + } + + applyLaunchableConfig(cwOptions, "env-abc", info) + + assert.Equal(t, []store.CreateFirewallRule{ + {Port: "22", AllowedIPs: "user-ip", ClientIPs: []string{"203.0.113.7/32"}}, + {Port: "443", AllowedIPs: "all"}, + }, cwOptions.FirewallRules) + }) + + t.Run("preserves rule when public IP lookup fails", func(t *testing.T) { + rules := []store.CreateFirewallRule{ + {Port: "22", AllowedIPs: "user-ip"}, + } + resolved := resolveFirewallRulesClientIP(rules, func() (string, error) { + return "", assert.AnError + }) + assert.Equal(t, []store.CreateFirewallRule{ + {Port: "22", AllowedIPs: "user-ip"}, + }, resolved) + }) + + t.Run("does not call lookup when no user-ip rules", func(t *testing.T) { + called := false + rules := []store.CreateFirewallRule{ + {Port: "8080", AllowedIPs: "all"}, + } + resolveFirewallRulesClientIP(rules, func() (string, error) { + called = true + return "1.2.3.4", nil + }) + assert.False(t, called, "lookup should be skipped when no user-ip rules exist") + }) + + t.Run("respects pre-existing ClientIPs on user-ip rule", func(t *testing.T) { + rules := []store.CreateFirewallRule{ + {Port: "22", AllowedIPs: "user-ip", ClientIPs: []string{"198.51.100.5/32"}}, + } + resolved := resolveFirewallRulesClientIP(rules, func() (string, error) { + return "203.0.113.7", nil + }) + assert.Equal(t, []string{"198.51.100.5/32"}, resolved[0].ClientIPs) + }) + + t.Run("uses /128 for IPv6 public IPs", func(t *testing.T) { + rules := []store.CreateFirewallRule{ + {Port: "22", AllowedIPs: "user-ip"}, + } + resolved := resolveFirewallRulesClientIP(rules, func() (string, error) { + return "2403:2500:4000:0000:0000:0000:0000:090a", nil + }) + assert.Equal(t, []string{"2403:2500:4000::90a/128"}, resolved[0].ClientIPs) + }) + + t.Run("canonicalizes IPv4-mapped IPv6 to bare IPv4 with /32", func(t *testing.T) { + rules := []store.CreateFirewallRule{ + {Port: "22", AllowedIPs: "user-ip"}, + } + resolved := resolveFirewallRulesClientIP(rules, func() (string, error) { + return "::ffff:203.0.113.7", nil + }) + assert.Equal(t, []string{"203.0.113.7/32"}, resolved[0].ClientIPs) + }) + + t.Run("skips rule when lookup returns garbage", func(t *testing.T) { + rules := []store.CreateFirewallRule{ + {Port: "22", AllowedIPs: "user-ip"}, + } + resolved := resolveFirewallRulesClientIP(rules, func() (string, error) { + return "captive portal", nil + }) + assert.Equal(t, []store.CreateFirewallRule{ + {Port: "22", AllowedIPs: "user-ip"}, + }, resolved) + }) + t.Run("merges with existing labels", func(t *testing.T) { cwOptions := &store.CreateWorkspacesOptions{ Labels: map[string]string{"existingKey": "existingValue"}, @@ -351,6 +454,150 @@ func TestApplyLaunchableConfig(t *testing.T) { //nolint:funlen // test }) } +func TestToHostCIDR(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {"IPv4", "203.0.113.7", "203.0.113.7/32"}, + {"IPv4 with surrounding whitespace", " 203.0.113.7\n", "203.0.113.7/32"}, + {"IPv4-mapped IPv6 canonicalizes to IPv4", "::ffff:203.0.113.7", "203.0.113.7/32"}, + {"IPv6 short form", "2403:2500:4000::90a", "2403:2500:4000::90a/128"}, + {"IPv6 long form canonicalizes", "2403:2500:4000:0000:0000:0000:0000:090a", "2403:2500:4000::90a/128"}, + {"IPv6 loopback", "::1", "::1/128"}, + {"empty", "", ""}, + {"not an IP", "not-an-ip", ""}, + {"captive portal HTML", "", ""}, + {"IPv4 with port appended", "203.0.113.7:443", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, toHostCIDR(tt.in)) + }) + } +} + +// TestLaunchableJSONWireFormat guards against silent regressions in the JSON +// tags on CreateWorkspacesOptions — the original bug was that fields the +// server expected (firewallRules, subLocation) were missing from the request +// body. Asserting on the marshaled bytes catches typos like "firewall_rules" +// or a dropped tag that struct-field assertions can't. +func TestLaunchableJSONWireFormat(t *testing.T) { + cwOptions := &store.CreateWorkspacesOptions{} + info := &store.LaunchableResponse{ + ID: "env-abc", + CreateWorkspaceRequest: store.LaunchableWorkspaceRequest{ + InstanceType: "n2-standard-4", + Location: "us-west1", + SubLocation: "us-west1-b", + FirewallRules: []store.CreateFirewallRule{ + {Port: "8080", AllowedIPs: "all"}, + {Port: "22", AllowedIPs: "user-ip", ClientIPs: []string{"203.0.113.7/32"}}, + }, + }, + } + + applyLaunchableConfig(cwOptions, "env-abc", info) + + body, err := json.Marshal(cwOptions) + assert.NoError(t, err) + s := string(body) + + assert.Contains(t, s, `"subLocation":"us-west1-b"`) + assert.Contains(t, s, `"firewallRules":`) + assert.Contains(t, s, `"port":"8080"`) + assert.Contains(t, s, `"port":"22"`) + assert.Contains(t, s, `"allowedIPs":"all"`) + assert.Contains(t, s, `"allowedIPs":"user-ip"`) + assert.Contains(t, s, `"clientIPs":["203.0.113.7/32"]`) + assert.Contains(t, s, `"launchableConfig":{"id":"env-abc"}`) +} + +func TestFetchPublicIP(t *testing.T) { + client := &http.Client{Timeout: 2 * time.Second} + + t.Run("returns IP for plain text body", func(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("203.0.113.7\n")) + })) + defer s.Close() + + ip, err := fetchPublicIP(client, s.URL) + assert.NoError(t, err) + assert.Equal(t, "203.0.113.7", ip) + }) + + t.Run("rejects non-200 status", func(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "boom", http.StatusInternalServerError) + })) + defer s.Close() + + _, err := fetchPublicIP(client, s.URL) + assert.Error(t, err) + }) + + t.Run("rejects non-IP body (captive portal HTML)", func(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("login required")) + })) + defer s.Close() + + _, err := fetchPublicIP(client, s.URL) + assert.Error(t, err) + }) +} + +func TestResolvePublicIPFallback(t *testing.T) { + failing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "down", http.StatusBadGateway) + })) + defer failing.Close() + + ok := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("198.51.100.42")) + })) + defer ok.Close() + + t.Run("falls through failing endpoints until one succeeds", func(t *testing.T) { + orig := publicIPEndpoints + publicIPEndpoints = []string{failing.URL, failing.URL, ok.URL} + defer func() { publicIPEndpoints = orig }() + + ip, err := resolvePublicIP() + assert.NoError(t, err) + assert.Equal(t, "198.51.100.42", ip) + }) + + t.Run("returns error when all endpoints fail", func(t *testing.T) { + orig := publicIPEndpoints + publicIPEndpoints = []string{failing.URL, failing.URL} + defer func() { publicIPEndpoints = orig }() + + _, err := resolvePublicIP() + assert.Error(t, err) + }) + + t.Run("short-circuits on first success", func(t *testing.T) { + hits := 0 + counted := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + hits++ + _, _ = w.Write([]byte("203.0.113.7")) + })) + defer counted.Close() + + orig := publicIPEndpoints + publicIPEndpoints = []string{counted.URL, failing.URL, ok.URL} + defer func() { publicIPEndpoints = orig }() + + ip, err := resolvePublicIP() + assert.NoError(t, err) + assert.Equal(t, "203.0.113.7", ip) + assert.Equal(t, 1, hits, "later endpoints must not be hit after success") + }) +} + func TestParseInstanceTypesFromFlag(t *testing.T) { tests := []struct { name string diff --git a/pkg/store/workspace.go b/pkg/store/workspace.go index 35f5d4f1..a170e739 100644 --- a/pkg/store/workspace.go +++ b/pkg/store/workspace.go @@ -99,6 +99,7 @@ type CreateWorkspacesOptions struct { ExecsV1 *entity.ExecsV1 `json:"execsV1"` InstanceType string `json:"instanceType"` Location string `json:"location,omitempty"` + SubLocation string `json:"subLocation,omitempty"` DiskStorage string `json:"diskStorage"` BaseImage string `json:"baseImage"` VMOnlyMode bool `json:"vmOnlyMode"` @@ -107,6 +108,7 @@ type CreateWorkspacesOptions struct { DockerCompose *DockerCompose `json:"dockerCompose,omitempty"` OnContainer bool `json:"onContainer,omitempty"` PortMappings map[string]string `json:"portMappings"` + FirewallRules []CreateFirewallRule `json:"firewallRules,omitempty"` Files interface{} `json:"files"` Labels interface{} `json:"labels"` WorkspaceVersion string `json:"workspaceVersion"` @@ -114,6 +116,14 @@ type CreateWorkspacesOptions struct { LaunchableConfig *LaunchableConfig `json:"launchableConfig,omitempty"` } +// CreateFirewallRule mirrors brev-deploy's CreateFirewallRule. AllowedIPs is +// either "all" (open to 0.0.0.0/0) or "user-ip" (open to ClientIPs). +type CreateFirewallRule struct { + Port string `json:"port"` + AllowedIPs string `json:"allowedIPs"` + ClientIPs []string `json:"clientIPs,omitempty"` +} + type LaunchableConfig struct { ID string `json:"id"` } @@ -131,11 +141,13 @@ type LaunchableResponse struct { } type LaunchableWorkspaceRequest struct { - WorkspaceGroupID string `json:"workspaceGroupId,omitempty"` - InstanceType string `json:"instanceType"` - Storage string `json:"storage,omitempty"` - Location string `json:"location,omitempty"` - ImageID string `json:"imageId,omitempty"` + WorkspaceGroupID string `json:"workspaceGroupId,omitempty"` + InstanceType string `json:"instanceType"` + Storage string `json:"storage,omitempty"` + Location string `json:"location,omitempty"` + SubLocation string `json:"subLocation,omitempty"` + ImageID string `json:"imageId,omitempty"` + FirewallRules []CreateFirewallRule `json:"firewallRules,omitempty"` } type LaunchableBuildRequest struct {