From 0a5dff812699328a1fc285b868ea7837500e3f26 Mon Sep 17 00:00:00 2001 From: "Luis Gustavo S. Barreto" Date: Tue, 23 Jun 2026 17:23:38 -0300 Subject: [PATCH] feat: add connections, a saved SSH connection manager A personal address book for reaching SSH targets from the browser: saved external hosts (dialed straight through an SSRF guardian, no agent) and ShellHub devices, each with its own auth. Keys come from the vault or a one-off paste, and host keys are pinned on a trust-on-first-use basis. The target kind is fixed at creation. Team connections are a Cloud/Enterprise capability, surfaced as an upsell on the Community edition. --- api/go.mod | 2 + api/go.sum | 4 + api/pkg/responses/known_host.go | 15 + api/routes/connection.go | 149 ++ api/routes/known_host.go | 101 ++ api/routes/routes.go | 12 + api/services/connection.go | 218 +++ api/services/connection_test.go | 154 ++ api/services/errors.go | 20 + api/services/known_host.go | 188 +++ api/services/mocks/mock_service.go | 774 ++++++++- api/services/service.go | 2 + api/store/connection.go | 32 + api/store/known_host.go | 22 + api/store/mocks/mock_query_options.go | 53 + api/store/mocks/mock_store.go | 570 +++++++ api/store/pg/connection.go | 116 ++ api/store/pg/entity/connection.go | 62 + api/store/pg/entity/known_host.go | 56 + api/store/pg/known_host.go | 100 ++ .../pg/migrations/007_connections.tx.down.sql | 1 + .../pg/migrations/007_connections.tx.up.sql | 20 + .../008_connection_device_target.tx.down.sql | 1 + .../008_connection_device_target.tx.up.sql | 1 + .../009_ssh_known_hosts.tx.down.sql | 1 + .../migrations/009_ssh_known_hosts.tx.up.sql | 22 + api/store/pg/query-options.go | 25 + api/store/query-options.go | 3 + api/store/store.go | 2 + go.mod | 2 + go.sum | 4 + openapi/spec/cloud-openapi.yaml | 11 + openapi/spec/community-openapi.yaml | 15 + .../spec/components/schemas/connection.yaml | 68 + .../schemas/connectionCreateRequest.yaml | 41 + .../components/schemas/connectionStatus.yaml | 9 + .../schemas/connectionUpdateRequest.yaml | 33 + .../spec/components/schemas/knownHost.yaml | 52 + .../schemas/knownHostAcceptRequest.yaml | 31 + .../schemas/knownHostScanRequest.yaml | 20 + .../schemas/knownHostScanResult.yaml | 29 + .../components/schemas/teamConnection.yaml | 47 + .../schemas/teamConnectionPrefs.yaml | 30 + .../schemas/teamConnectionPrefsRequest.yaml | 12 + .../schemas/teamConnectionRequest.yaml | 24 + openapi/spec/enterprise-openapi.yaml | 11 + openapi/spec/openapi.yaml | 23 + openapi/spec/paths/api@connections.yaml | 65 + .../spec/paths/api@connections@host-key.yaml | 78 + .../api@connections@host-key@accept.yaml | 39 + .../paths/api@connections@host-key@scan.yaml | 37 + openapi/spec/paths/api@connections@team.yaml | 63 + .../spec/paths/api@connections@team@{id}.yaml | 64 + .../api@connections@team@{id}@prefs.yaml | 62 + .../api@connections@team@{id}@status.yaml | 31 + openapi/spec/paths/api@connections@{id}.yaml | 76 + .../paths/api@connections@{id}@status.yaml | 29 + pkg/api/authorizer/permissions.go | 24 + pkg/api/authorizer/role_test.go | 12 + pkg/api/requests/connection.go | 68 + pkg/api/requests/known_host.go | 48 + pkg/egress/egress.go | 84 + pkg/egress/egress_test.go | 25 + pkg/models/connection.go | 42 + pkg/models/known_host.go | 40 + ssh/go.mod | 2 + ssh/go.sum | 4 + ssh/main.go | 1 + ssh/web/connect.go | 340 ++++ ssh/web/errors.go | 4 + ssh/web/manager.go | 6 +- ssh/web/utils.go | 17 +- ui/apps/console/src/App.tsx | 15 +- ui/apps/console/src/api/connections.ts | 60 + ui/apps/console/src/api/hostKeys.ts | 70 + ui/apps/console/src/api/teamConnections.ts | 86 + .../console/src/components/ConnectDrawer.tsx | 1503 ++++++++++++++--- .../console/src/components/HostKeyModal.tsx | 131 ++ .../src/components/common/DevicePicker.tsx | 108 ++ .../src/components/common/PremiumUpsell.tsx | 57 + .../src/components/common/ProBadge.tsx | 8 + .../console/src/components/layout/Sidebar.tsx | 15 +- .../components/terminal/TerminalInstance.tsx | 34 +- .../src/hooks/useConnectionMutations.ts | 40 + ui/apps/console/src/hooks/useConnections.ts | 46 + ui/apps/console/src/hooks/useHostKeys.ts | 56 + .../src/hooks/useTeamConnectionMutations.ts | 54 + .../console/src/hooks/useTeamConnections.ts | 81 + .../console/src/pages/connections/index.tsx | 621 +++++++ .../pages/devices/__tests__/Devices.test.tsx | 4 + ui/apps/console/src/pages/devices/index.tsx | 11 + .../__tests__/SecureVault.test.tsx | 35 +- ui/apps/console/src/stores/terminalStore.ts | 26 +- .../utils/__tests__/connectionDirty.test.ts | 88 + ui/apps/console/src/utils/connectionDirty.ts | 32 + ui/apps/console/src/utils/ssh-keys.ts | 17 +- 96 files changed, 7330 insertions(+), 317 deletions(-) create mode 100644 api/pkg/responses/known_host.go create mode 100644 api/routes/connection.go create mode 100644 api/routes/known_host.go create mode 100644 api/services/connection.go create mode 100644 api/services/connection_test.go create mode 100644 api/services/known_host.go create mode 100644 api/store/connection.go create mode 100644 api/store/known_host.go create mode 100644 api/store/pg/connection.go create mode 100644 api/store/pg/entity/connection.go create mode 100644 api/store/pg/entity/known_host.go create mode 100644 api/store/pg/known_host.go create mode 100644 api/store/pg/migrations/007_connections.tx.down.sql create mode 100644 api/store/pg/migrations/007_connections.tx.up.sql create mode 100644 api/store/pg/migrations/008_connection_device_target.tx.down.sql create mode 100644 api/store/pg/migrations/008_connection_device_target.tx.up.sql create mode 100644 api/store/pg/migrations/009_ssh_known_hosts.tx.down.sql create mode 100644 api/store/pg/migrations/009_ssh_known_hosts.tx.up.sql create mode 100644 openapi/spec/components/schemas/connection.yaml create mode 100644 openapi/spec/components/schemas/connectionCreateRequest.yaml create mode 100644 openapi/spec/components/schemas/connectionStatus.yaml create mode 100644 openapi/spec/components/schemas/connectionUpdateRequest.yaml create mode 100644 openapi/spec/components/schemas/knownHost.yaml create mode 100644 openapi/spec/components/schemas/knownHostAcceptRequest.yaml create mode 100644 openapi/spec/components/schemas/knownHostScanRequest.yaml create mode 100644 openapi/spec/components/schemas/knownHostScanResult.yaml create mode 100644 openapi/spec/components/schemas/teamConnection.yaml create mode 100644 openapi/spec/components/schemas/teamConnectionPrefs.yaml create mode 100644 openapi/spec/components/schemas/teamConnectionPrefsRequest.yaml create mode 100644 openapi/spec/components/schemas/teamConnectionRequest.yaml create mode 100644 openapi/spec/paths/api@connections.yaml create mode 100644 openapi/spec/paths/api@connections@host-key.yaml create mode 100644 openapi/spec/paths/api@connections@host-key@accept.yaml create mode 100644 openapi/spec/paths/api@connections@host-key@scan.yaml create mode 100644 openapi/spec/paths/api@connections@team.yaml create mode 100644 openapi/spec/paths/api@connections@team@{id}.yaml create mode 100644 openapi/spec/paths/api@connections@team@{id}@prefs.yaml create mode 100644 openapi/spec/paths/api@connections@team@{id}@status.yaml create mode 100644 openapi/spec/paths/api@connections@{id}.yaml create mode 100644 openapi/spec/paths/api@connections@{id}@status.yaml create mode 100644 pkg/api/requests/connection.go create mode 100644 pkg/api/requests/known_host.go create mode 100644 pkg/egress/egress.go create mode 100644 pkg/egress/egress_test.go create mode 100644 pkg/models/connection.go create mode 100644 pkg/models/known_host.go create mode 100644 ssh/web/connect.go create mode 100644 ui/apps/console/src/api/connections.ts create mode 100644 ui/apps/console/src/api/hostKeys.ts create mode 100644 ui/apps/console/src/api/teamConnections.ts create mode 100644 ui/apps/console/src/components/HostKeyModal.tsx create mode 100644 ui/apps/console/src/components/common/DevicePicker.tsx create mode 100644 ui/apps/console/src/components/common/PremiumUpsell.tsx create mode 100644 ui/apps/console/src/components/common/ProBadge.tsx create mode 100644 ui/apps/console/src/hooks/useConnectionMutations.ts create mode 100644 ui/apps/console/src/hooks/useConnections.ts create mode 100644 ui/apps/console/src/hooks/useHostKeys.ts create mode 100644 ui/apps/console/src/hooks/useTeamConnectionMutations.ts create mode 100644 ui/apps/console/src/hooks/useTeamConnections.ts create mode 100644 ui/apps/console/src/pages/connections/index.tsx create mode 100644 ui/apps/console/src/utils/__tests__/connectionDirty.test.ts create mode 100644 ui/apps/console/src/utils/connectionDirty.ts diff --git a/api/go.mod b/api/go.mod index 781ce86566e..9eb8ea96106 100644 --- a/api/go.mod +++ b/api/go.mod @@ -3,6 +3,7 @@ module github.com/shellhub-io/shellhub/api go 1.25.8 require ( + code.dny.dev/ssrf v0.2.0 github.com/cnf/structhash v0.0.0-20201127153200-e1b16c1ebc08 github.com/getkin/kin-openapi v0.140.0 github.com/getsentry/sentry-go v0.47.0 @@ -117,6 +118,7 @@ require ( go.opentelemetry.io/otel v1.41.0 // indirect go.opentelemetry.io/otel/metric v1.41.0 // indirect go.opentelemetry.io/otel/trace v1.41.0 // indirect + golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9 // indirect golang.org/x/net v0.56.0 // indirect golang.org/x/sync v0.21.0 // indirect golang.org/x/sys v0.46.0 // indirect diff --git a/api/go.sum b/api/go.sum index d9da7d5c3b5..1e1490cb90b 100644 --- a/api/go.sum +++ b/api/go.sum @@ -1,3 +1,5 @@ +code.dny.dev/ssrf v0.2.0 h1:wCBP990rQQ1CYfRpW+YK1+8xhwUjv189AQ3WMo1jQaI= +code.dny.dev/ssrf v0.2.0/go.mod h1:B+91l25OnyaLIeCx0WRJN5qfJ/4/ZTZxRXgm0lj/2w8= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= @@ -335,6 +337,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto= golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio= +golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9 h1:yZNXmy+j/JpX19vZkVktWqAo7Gny4PBWYYK3zskGpx4= +golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= diff --git a/api/pkg/responses/known_host.go b/api/pkg/responses/known_host.go new file mode 100644 index 00000000000..80296d202f1 --- /dev/null +++ b/api/pkg/responses/known_host.go @@ -0,0 +1,15 @@ +package responses + +import "github.com/shellhub-io/shellhub/pkg/models" + +// KnownHostScanResult is the presented host key plus its verification status, +// returned by a host-key scan. It lives here (not in the service package) so the +// generated Service mock can reference it without importing the service package, +// which would form an import cycle with the service's own tests. +type KnownHostScanResult struct { + KeyType string `json:"key_type"` + Fingerprint string `json:"fingerprint"` + PublicKey string `json:"public_key"` + Status models.KnownHostStatus `json:"status"` + Stored *models.KnownHost `json:"stored"` +} diff --git a/api/routes/connection.go b/api/routes/connection.go new file mode 100644 index 00000000000..3acfeb56d2e --- /dev/null +++ b/api/routes/connection.go @@ -0,0 +1,149 @@ +package routes + +import ( + "errors" + "net/http" + "strconv" + + "github.com/shellhub-io/shellhub/api/pkg/gateway" + svc "github.com/shellhub-io/shellhub/api/services" + "github.com/shellhub-io/shellhub/pkg/api/requests" +) + +const ( + CreateConnectionURL = "/connections" + ListConnectionsURL = "/connections" + GetConnectionURL = "/connections/:id" + ConnectionStatusURL = "/connections/:id/status" + UpdateConnectionURL = "/connections/:id" + DeleteConnectionURL = "/connections/:id" +) + +const ParamConnectionID = "id" + +func (h *Handler) CreateConnection(c gateway.Context) error { + var req requests.ConnectionCreate + if err := c.Bind(&req); err != nil { + return err + } + + if err := c.Validate(&req); err != nil { + return err + } + + // For external connections, probe at save time. A 422 lets the UI distinguish + // a blocked target (not a permitted address) from an unreachable host (the + // NAT/firewall hint + install-the-agent funnel). Force skips the probe. + if req.Kind == "external" && !req.Force { + reachable, err := h.service.ProbeReachability(c.Ctx(), &requests.ConnectionProbe{Host: req.Host, Port: req.Port}) + switch { + case errors.Is(err, svc.ErrEgressBlocked): + return c.JSON(http.StatusUnprocessableEntity, map[string]string{"error": "blocked"}) + case err != nil: + return err + case !reachable: + return c.JSON(http.StatusUnprocessableEntity, map[string]string{"error": "unreachable"}) + } + } + + connection, err := h.service.CreateConnection(c.Ctx(), &req) + if err != nil { + return err + } + + return c.JSON(http.StatusCreated, connection) +} + +func (h *Handler) ListConnections(c gateway.Context) error { + req := new(requests.ConnectionList) + if err := c.Bind(req); err != nil { + return err + } + + req.Paginator.Normalize() + req.Sorter.Normalize() + + if err := c.Validate(req); err != nil { + return err + } + + connections, count, err := h.service.ListConnections(c.Ctx(), req) + if err != nil { + return err + } + + c.Response().Header().Set("X-Total-Count", strconv.Itoa(count)) + + return c.JSON(http.StatusOK, connections) +} + +func (h *Handler) GetConnection(c gateway.Context) error { + var req requests.ConnectionGet + if err := c.Bind(&req); err != nil { + return err + } + + if err := c.Validate(&req); err != nil { + return err + } + + connection, err := h.service.GetConnection(c.Ctx(), &req) + if err != nil { + return err + } + + return c.JSON(http.StatusOK, connection) +} + +func (h *Handler) ConnectionStatus(c gateway.Context) error { + var req requests.ConnectionGet + if err := c.Bind(&req); err != nil { + return err + } + + if err := c.Validate(&req); err != nil { + return err + } + + online, err := h.service.ConnectionStatus(c.Ctx(), &req) + if err != nil { + return err + } + + return c.JSON(http.StatusOK, map[string]bool{"online": online}) +} + +func (h *Handler) UpdateConnection(c gateway.Context) error { + var req requests.ConnectionUpdate + if err := c.Bind(&req); err != nil { + return err + } + + if err := c.Validate(&req); err != nil { + return err + } + + connection, err := h.service.UpdateConnection(c.Ctx(), &req) + if err != nil { + return err + } + + return c.JSON(http.StatusOK, connection) +} + +func (h *Handler) DeleteConnection(c gateway.Context) error { + var req requests.ConnectionDelete + if err := c.Bind(&req); err != nil { + return err + } + + if err := c.Validate(&req); err != nil { + return err + } + + if err := h.service.DeleteConnection(c.Ctx(), &req); err != nil { + return err + } + + return c.NoContent(http.StatusOK) +} diff --git a/api/routes/known_host.go b/api/routes/known_host.go new file mode 100644 index 00000000000..bf1b674484e --- /dev/null +++ b/api/routes/known_host.go @@ -0,0 +1,101 @@ +package routes + +import ( + "errors" + "net/http" + + "github.com/shellhub-io/shellhub/api/pkg/gateway" + svc "github.com/shellhub-io/shellhub/api/services" + "github.com/shellhub-io/shellhub/pkg/api/requests" +) + +const ( + ScanKnownHostURL = "/connections/host-key/scan" + AcceptKnownHostURL = "/connections/host-key/accept" + GetKnownHostURL = "/connections/host-key" + DeleteKnownHostURL = "/connections/host-key" +) + +func (h *Handler) ScanKnownHost(c gateway.Context) error { + var req requests.KnownHostScan + if err := c.Bind(&req); err != nil { + return err + } + + if err := c.Validate(&req); err != nil { + return err + } + + result, err := h.service.ScanKnownHost(c.Ctx(), &req) + if err != nil { + // Surface a target we can't reach/read as a 422 (not a 500), and let the + // UI tell a blocked address apart from an unreachable host. + switch { + case errors.Is(err, svc.ErrEgressBlocked): + return c.JSON(http.StatusUnprocessableEntity, map[string]string{"error": "blocked"}) + case errors.Is(err, svc.ErrKnownHostUnreachable): + return c.JSON(http.StatusUnprocessableEntity, map[string]string{"error": "unreachable"}) + } + + return err + } + + return c.JSON(http.StatusOK, result) +} + +func (h *Handler) AcceptKnownHost(c gateway.Context) error { + var req requests.KnownHostAccept + if err := c.Bind(&req); err != nil { + return err + } + + if err := c.Validate(&req); err != nil { + return err + } + + knownHost, err := h.service.AcceptKnownHost(c.Ctx(), &req) + if err != nil { + if errors.Is(err, svc.ErrKnownHostInvalidKey) { + return c.JSON(http.StatusUnprocessableEntity, map[string]string{"error": "invalid_key"}) + } + + return err + } + + return c.JSON(http.StatusOK, knownHost) +} + +func (h *Handler) GetKnownHost(c gateway.Context) error { + var req requests.KnownHostGet + if err := c.Bind(&req); err != nil { + return err + } + + if err := c.Validate(&req); err != nil { + return err + } + + knownHost, err := h.service.GetKnownHost(c.Ctx(), &req) + if err != nil { + return err + } + + return c.JSON(http.StatusOK, knownHost) +} + +func (h *Handler) DeleteKnownHost(c gateway.Context) error { + var req requests.KnownHostDelete + if err := c.Bind(&req); err != nil { + return err + } + + if err := c.Validate(&req); err != nil { + return err + } + + if err := h.service.DeleteKnownHost(c.Ctx(), &req); err != nil { + return err + } + + return c.NoContent(http.StatusOK) +} diff --git a/api/routes/routes.go b/api/routes/routes.go index f8c6256e8d0..b3e1a56c553 100644 --- a/api/routes/routes.go +++ b/api/routes/routes.go @@ -151,6 +151,18 @@ func NewRouter(service services.Service, opts ...Option) *echo.Echo { publicAPI.PUT(SetDeviceCustomFieldURL, gateway.Handler(handler.SetDeviceCustomField), routesmiddleware.RequiresPermission(authorizer.DeviceCustomFieldUpdate)) publicAPI.DELETE(DeleteDeviceCustomFieldURL, gateway.Handler(handler.DeleteDeviceCustomField), routesmiddleware.RequiresPermission(authorizer.DeviceCustomFieldUpdate)) + publicAPI.GET(ListConnectionsURL, routesmiddleware.Authorize(gateway.Handler(handler.ListConnections))) + publicAPI.GET(GetConnectionURL, routesmiddleware.Authorize(gateway.Handler(handler.GetConnection))) + publicAPI.GET(ConnectionStatusURL, routesmiddleware.Authorize(gateway.Handler(handler.ConnectionStatus))) + publicAPI.POST(CreateConnectionURL, gateway.Handler(handler.CreateConnection), routesmiddleware.RequiresPermission(authorizer.ConnectionCreate)) + publicAPI.PUT(UpdateConnectionURL, gateway.Handler(handler.UpdateConnection), routesmiddleware.RequiresPermission(authorizer.ConnectionUpdate)) + publicAPI.DELETE(DeleteConnectionURL, gateway.Handler(handler.DeleteConnection), routesmiddleware.RequiresPermission(authorizer.ConnectionDelete)) + + publicAPI.POST(ScanKnownHostURL, routesmiddleware.Authorize(gateway.Handler(handler.ScanKnownHost))) + publicAPI.POST(AcceptKnownHostURL, routesmiddleware.Authorize(gateway.Handler(handler.AcceptKnownHost))) + publicAPI.GET(GetKnownHostURL, routesmiddleware.Authorize(gateway.Handler(handler.GetKnownHost))) + publicAPI.DELETE(DeleteKnownHostURL, routesmiddleware.Authorize(gateway.Handler(handler.DeleteKnownHost))) + publicAPI.GET(URLGetTags, gateway.Handler(handler.GetTags)) publicAPI.POST(URLCreateTag, gateway.Handler(handler.CreateTag), routesmiddleware.RequiresPermission(authorizer.TagCreate)) publicAPI.PATCH(URLUpdateTag, gateway.Handler(handler.UpdateTag), routesmiddleware.RequiresPermission(authorizer.TagUpdate)) diff --git a/api/services/connection.go b/api/services/connection.go new file mode 100644 index 00000000000..40dac47e7b3 --- /dev/null +++ b/api/services/connection.go @@ -0,0 +1,218 @@ +package services + +import ( + "context" + "errors" + + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/pkg/api/requests" + "github.com/shellhub-io/shellhub/pkg/egress" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/shellhub-io/shellhub/pkg/uuid" +) + +type ConnectionService interface { + // CreateConnection creates a personal connection owned by the caller. + CreateConnection(ctx context.Context, req *requests.ConnectionCreate) (*models.Connection, error) + // ListConnections lists the connections owned by the caller. + ListConnections(ctx context.Context, req *requests.ConnectionList) ([]models.Connection, int, error) + // UpdateConnection updates a connection owned by the caller. + UpdateConnection(ctx context.Context, req *requests.ConnectionUpdate) (*models.Connection, error) + // GetConnection returns a single connection owned by the caller. + GetConnection(ctx context.Context, req *requests.ConnectionGet) (*models.Connection, error) + // DeleteConnection deletes a connection owned by the caller. + DeleteConnection(ctx context.Context, req *requests.ConnectionDelete) error + // ConnectionStatus reports whether the connection's target is reachable. For + // external connections it probes the host:port over TCP; for device connections + // it reflects the device's connection state. + ConnectionStatus(ctx context.Context, req *requests.ConnectionGet) (bool, error) + // ProbeReachability reports whether an arbitrary host:port is reachable over + // TCP. Used before saving an external connection to surface NAT/firewall issues. + ProbeReachability(ctx context.Context, req *requests.ConnectionProbe) (bool, error) +} + +// ErrEgressBlocked means the target isn't a permitted connection endpoint: the +// SSRF guardian rejected it (loopback, link-local/metadata, reserved, or a +// private address that isn't allowlisted). Distinct from a host that is simply +// unreachable. It aliases egress.ErrBlocked so route checks (errors.Is) keep working. +var ErrEgressBlocked = egress.ErrBlocked + +// resolveOwnedConnection fetches a connection owned by the caller. A connection +// owned by someone else (or absent) resolves to ErrConnectionNotFound, so we +// never leak the existence of another user's connection. +func (s *service) resolveOwnedConnection(ctx context.Context, tenantID, userID, id string) (*models.Connection, error) { + // id maps to a uuid-typed column; a non-UUID can never match, so short-circuit + // to not-found instead of letting the cast raise a SQL error (500). + if _, err := uuid.Parse(id); err != nil { + return nil, NewErrConnectionNotFound(id, err) + } + + connection, err := s.store.ConnectionResolve( + ctx, + store.ConnectionIDResolver, + id, + s.store.Options().InNamespace(tenantID), + s.store.Options().OwnedBy(userID), + ) + if err != nil { + if errors.Is(err, store.ErrNoDocuments) { + return nil, NewErrConnectionNotFound(id, err) + } + + return nil, err + } + + return connection, nil +} + +// validateDeviceTarget ensures a device-kind connection points at a device that +// exists in the caller's namespace. Without this, a connection could reference a +// device in another namespace and leak its status through ConnectionStatus. +func (s *service) validateDeviceTarget(ctx context.Context, tenantID, kind, deviceUID string) error { + if models.ConnectionKind(kind) != models.ConnectionKindDevice { + return nil + } + + _, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, deviceUID, s.store.Options().InNamespace(tenantID)) + if err != nil { + if errors.Is(err, store.ErrNoDocuments) { + return NewErrDeviceNotFound(models.UID(deviceUID), err) + } + + return err + } + + return nil +} + +func (s *service) CreateConnection(ctx context.Context, req *requests.ConnectionCreate) (*models.Connection, error) { + if _, err := s.store.NamespaceResolve(ctx, store.NamespaceTenantIDResolver, req.TenantID); err != nil { + return nil, NewErrNamespaceNotFound(req.TenantID, err) + } + + if err := s.validateDeviceTarget(ctx, req.TenantID, req.Kind, req.DeviceUID); err != nil { + return nil, err + } + + connection := &models.Connection{ + ID: uuid.Generate(), + TenantID: req.TenantID, + OwnerID: req.UserID, + Label: req.Label, + Kind: models.ConnectionKind(req.Kind), + Username: req.Username, + AuthMethod: req.AuthMethod, + KeyFingerprint: req.KeyFingerprint, + } + applyTarget(connection, req.Kind, req.Host, req.Port, req.DeviceUID) + + if _, err := s.store.ConnectionCreate(ctx, connection); err != nil { + return nil, err + } + + return connection, nil +} + +// applyTarget sets the target fields on the connection based on its kind, +// clearing the ones that don't apply so rows stay clean. +func applyTarget(c *models.Connection, kind, host string, port int, deviceUID string) { + switch models.ConnectionKind(kind) { + case models.ConnectionKindDevice: + c.DeviceUID = deviceUID + c.Host = "" + c.Port = 0 + default: + c.Host = host + c.Port = port + c.DeviceUID = "" + } +} + +func (s *service) UpdateConnection(ctx context.Context, req *requests.ConnectionUpdate) (*models.Connection, error) { + connection, err := s.resolveOwnedConnection(ctx, req.TenantID, req.UserID, req.ID) + if err != nil { + return nil, err + } + + // The target kind is fixed at creation; an external host and a device are + // distinct target shapes, so changing it on update is rejected. + if models.ConnectionKind(req.Kind) != connection.Kind { + return nil, NewErrConnectionKindImmutable() + } + + if err := s.validateDeviceTarget(ctx, req.TenantID, req.Kind, req.DeviceUID); err != nil { + return nil, err + } + + connection.Label = req.Label + connection.Kind = models.ConnectionKind(req.Kind) + connection.Username = req.Username + connection.AuthMethod = req.AuthMethod + connection.KeyFingerprint = req.KeyFingerprint + applyTarget(connection, req.Kind, req.Host, req.Port, req.DeviceUID) + + if err := s.store.ConnectionUpdate(ctx, connection); err != nil { + return nil, err + } + + return connection, nil +} + +func (s *service) ListConnections(ctx context.Context, req *requests.ConnectionList) ([]models.Connection, int, error) { + if req.Sorter.By == "" { + req.Sorter.By = "created_at" + } + + req.Sorter.Tiebreak = "id" + + return s.store.ConnectionList( + ctx, + s.store.Options().InNamespace(req.TenantID), + s.store.Options().OwnedBy(req.UserID), + s.store.Options().Sort(&req.Sorter), + s.store.Options().Paginate(&req.Paginator), + ) +} + +func (s *service) GetConnection(ctx context.Context, req *requests.ConnectionGet) (*models.Connection, error) { + return s.resolveOwnedConnection(ctx, req.TenantID, req.UserID, req.ID) +} + +func (s *service) DeleteConnection(ctx context.Context, req *requests.ConnectionDelete) error { + connection, err := s.resolveOwnedConnection(ctx, req.TenantID, req.UserID, req.ID) + if err != nil { + return err + } + + return s.store.ConnectionDelete(ctx, connection) +} + +func (s *service) ConnectionStatus(ctx context.Context, req *requests.ConnectionGet) (bool, error) { + connection, err := s.resolveOwnedConnection(ctx, req.TenantID, req.UserID, req.ID) + if err != nil { + return false, err + } + + // Device connections reflect the agent's connection state. A resolve failure + // just means we can't confirm it's up, so report offline. + if connection.Kind == models.ConnectionKindDevice { + device, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, connection.DeviceUID, s.store.Options().InNamespace(connection.TenantID)) + if err == nil { + return device.DisconnectedAt == nil, nil + } + + return false, nil + } + + // External connections have no agent, so probe the endpoint over TCP. + return egress.Reachable(ctx, connection.Host, connection.Port), nil +} + +func (s *service) ProbeReachability(ctx context.Context, req *requests.ConnectionProbe) (bool, error) { + reachable, blocked := egress.Probe(ctx, req.Host, req.Port) + if blocked { + return false, ErrEgressBlocked + } + + return reachable, nil +} diff --git a/api/services/connection_test.go b/api/services/connection_test.go new file mode 100644 index 00000000000..9be1c322283 --- /dev/null +++ b/api/services/connection_test.go @@ -0,0 +1,154 @@ +package services + +import ( + "context" + "testing" + + "github.com/shellhub-io/shellhub/api/store" + storemock "github.com/shellhub-io/shellhub/api/store/mocks" + "github.com/shellhub-io/shellhub/pkg/api/requests" + storecache "github.com/shellhub-io/shellhub/pkg/cache" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// A connection is personal: GetConnection resolves it scoped to both the +// namespace and the caller (InNamespace + OwnedBy). When the store finds no such +// row (because it belongs to another user, or doesn't exist) the service returns +// NotFound, never leaking another user's connection. +func TestGetConnection(t *testing.T) { + storeMock := storemock.NewMockStore(t) + queryOptionsMock := storemock.NewMockQueryOptions(t) + storeMock.On("Options").Return(queryOptionsMock) + + ctx := context.TODO() + + const ( + tenantID = "00000000-0000-4000-0000-000000000000" + userID = "60fb0632538a82e62c2c40a1" + connID = "11111111-1111-4111-8111-111111111111" + ) + + owned := &models.Connection{ + ID: connID, + TenantID: tenantID, + OwnerID: userID, + Label: "db-primary", + Kind: models.ConnectionKindExternal, + Host: "10.0.0.5", + Port: 22, + } + + type Expected struct { + connection *models.Connection + err error + } + + cases := []struct { + description string + req *requests.ConnectionGet + requiredMocks func() + expected Expected + }{ + { + description: "fail with not found when the connection is not the caller's", + req: &requests.ConnectionGet{TenantID: tenantID, UserID: userID, ID: connID}, + requiredMocks: func() { + queryOptionsMock.On("InNamespace", tenantID).Return(nil).Once() + queryOptionsMock.On("OwnedBy", userID).Return(nil).Once() + storeMock. + On("ConnectionResolve", ctx, store.ConnectionIDResolver, connID, + mock.MatchedBy(func(opts []store.QueryOption) bool { return len(opts) == 2 })). + Return(nil, store.ErrNoDocuments). + Once() + }, + expected: Expected{ + connection: nil, + err: NewErrConnectionNotFound(connID, store.ErrNoDocuments), + }, + }, + { + description: "success when the connection is owned by the caller", + req: &requests.ConnectionGet{TenantID: tenantID, UserID: userID, ID: connID}, + requiredMocks: func() { + queryOptionsMock.On("InNamespace", tenantID).Return(nil).Once() + queryOptionsMock.On("OwnedBy", userID).Return(nil).Once() + storeMock. + On("ConnectionResolve", ctx, store.ConnectionIDResolver, connID, + mock.MatchedBy(func(opts []store.QueryOption) bool { return len(opts) == 2 })). + Return(owned, nil). + Once() + }, + expected: Expected{connection: owned, err: nil}, + }, + } + + s := NewService(store.Store(storeMock), privateKey, publicKey, storecache.NewNullCache(), clientMock) + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + tc.requiredMocks() + + connection, err := s.GetConnection(ctx, tc.req) + assert.Equal(t, tc.expected, Expected{connection, err}) + }) + } + + storeMock.AssertExpectations(t) +} + +// The target kind is fixed at creation. An update that flips an external host to a +// device (or vice-versa) is rejected before any write, so a saved target can't +// change shape under a stable id. +func TestUpdateConnectionRejectsKindChange(t *testing.T) { + storeMock := storemock.NewMockStore(t) + queryOptionsMock := storemock.NewMockQueryOptions(t) + storeMock.On("Options").Return(queryOptionsMock) + + ctx := context.TODO() + + const ( + tenantID = "00000000-0000-4000-0000-000000000000" + userID = "60fb0632538a82e62c2c40a1" + connID = "11111111-1111-4111-8111-111111111111" + ) + + owned := &models.Connection{ + ID: connID, + TenantID: tenantID, + OwnerID: userID, + Label: "db-primary", + Kind: models.ConnectionKindExternal, + Host: "10.0.0.5", + Port: 22, + } + + queryOptionsMock.On("InNamespace", tenantID).Return(nil).Once() + queryOptionsMock.On("OwnedBy", userID).Return(nil).Once() + storeMock. + On("ConnectionResolve", ctx, store.ConnectionIDResolver, connID, + mock.MatchedBy(func(opts []store.QueryOption) bool { return len(opts) == 2 })). + Return(owned, nil). + Once() + + s := NewService(store.Store(storeMock), privateKey, publicKey, storecache.NewNullCache(), clientMock) + + // Same id, but now pointed at a device. + connection, err := s.UpdateConnection(ctx, &requests.ConnectionUpdate{ + TenantID: tenantID, + UserID: userID, + ID: connID, + Label: "db-primary", + Kind: "device", + DeviceUID: "aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa", + }) + + assert.Nil(t, connection) + assert.Equal(t, NewErrConnectionKindImmutable(), err) + + // The kind guard runs before any device validation or write. + storeMock.AssertNotCalled(t, "DeviceResolve", mock.Anything, mock.Anything, mock.Anything, mock.Anything) + storeMock.AssertNotCalled(t, "ConnectionUpdate", mock.Anything, mock.Anything) + storeMock.AssertExpectations(t) +} diff --git a/api/services/errors.go b/api/services/errors.go index 9c9afcc7969..c6f45ad6cd4 100644 --- a/api/services/errors.go +++ b/api/services/errors.go @@ -130,6 +130,9 @@ var ( ErrSameTags = errors.New("trying to update tags with the same content", ErrLayer, ErrCodeNoContentChange) ErrAPIKeyNotFound = errors.New("APIKey not found", ErrLayer, ErrCodeNotFound) ErrAPIKeyDuplicated = errors.New("APIKey duplicated", ErrLayer, ErrCodeDuplicated) + ErrConnectionNotFound = errors.New("Connection not found", ErrLayer, ErrCodeNotFound) + ErrConnectionKindImmutable = errors.New("connection kind cannot be changed", ErrLayer, ErrCodeInvalid) + ErrKnownHostNotFound = errors.New("known host not found", ErrLayer, ErrCodeNotFound) ErrAuthForbidden = errors.New("user is authenticated but cannot access this resource", ErrLayer, ErrCodeForbidden) ErrRoleForbidden = errors.New("role is forbidden", ErrLayer, ErrCodeForbidden) ErrUserDelete = errors.New("user couldn't be deleted", ErrLayer, ErrCodeInvalid) @@ -205,6 +208,23 @@ func NewErrAPIKeyInvalid(name string) error { return NewErrAuthInvalid(map[string]interface{}{"api-key": name}, nil) } +// NewErrConnectionNotFound returns an error when the Connection is not found. +func NewErrConnectionNotFound(id string, next error) error { + return NewErrNotFound(ErrConnectionNotFound, id, next) +} + +// NewErrConnectionKindImmutable returns an error when an update tries to change +// a connection's kind. The target type is fixed at creation. +func NewErrConnectionKindImmutable() error { + return NewErrInvalid(ErrConnectionKindImmutable, nil, nil) +} + +// NewErrKnownHostNotFound returns an error when no known host is stored for the +// target. +func NewErrKnownHostNotFound(host string, next error) error { + return NewErrNotFound(ErrKnownHostNotFound, host, next) +} + // NewErrAPIKeyDuplicated returns an error when the APIKey name is duplicated. func NewErrAPIKeyDuplicated(conflicts []string) error { return NewErrDuplicated(ErrAPIKeyDuplicated, conflicts, nil) diff --git a/api/services/known_host.go b/api/services/known_host.go new file mode 100644 index 00000000000..4c52ca851bf --- /dev/null +++ b/api/services/known_host.go @@ -0,0 +1,188 @@ +package services + +import ( + "context" + "errors" + "net" + "strconv" + "strings" + "time" + + "github.com/shellhub-io/shellhub/api/pkg/responses" + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/pkg/api/authorizer" + "github.com/shellhub-io/shellhub/pkg/api/requests" + "github.com/shellhub-io/shellhub/pkg/egress" + "github.com/shellhub-io/shellhub/pkg/models" + gossh "golang.org/x/crypto/ssh" +) + +// ErrKnownHostUnreachable means the host key could not be scanned (host down, +// blocked by the egress guardian, or not speaking SSH). +var ErrKnownHostUnreachable = errors.New("could not read the host key") + +// ErrKnownHostInvalidKey means the public key supplied to AcceptKnownHost could +// not be parsed. +var ErrKnownHostInvalidKey = errors.New("invalid host public key") + +type KnownHostService interface { + // ScanKnownHost reads the target's host key and reports it against the stored + // one (unverified / trusted / changed). + ScanKnownHost(ctx context.Context, req *requests.KnownHostScan) (*responses.KnownHostScanResult, error) + // AcceptKnownHost trusts (stores) a host key for the target's scope. + AcceptKnownHost(ctx context.Context, req *requests.KnownHostAccept) (*models.KnownHost, error) + // GetKnownHost returns the stored known host for a target, or nil if none. + GetKnownHost(ctx context.Context, req *requests.KnownHostGet) (*models.KnownHost, error) + // DeleteKnownHost forgets the stored host key for the target's scope. + DeleteKnownHost(ctx context.Context, req *requests.KnownHostDelete) error +} + +// scopeOwner maps a request scope to the owner id used for storage: personal +// records belong to the caller; namespace (team) records are shared (empty). +func scopeOwner(scope, userID string) string { + if scope == "personal" { + return userID + } + + return "" +} + +// scanHostKey opens an SSH handshake to host:port just far enough to capture the +// presented host key, through the SSRF egress guardian. It never authenticates. +func scanHostKey(ctx context.Context, host string, port int) (*responses.KnownHostScanResult, error) { + addr := net.JoinHostPort(host, strconv.Itoa(port)) + + // Reuse the shared SSRF guardian; the host-key handshake wants a longer + // connect timeout than the plain reachability probe, so override it here. + dialer := egress.GuardedDialer(port) + dialer.Timeout = 8 * time.Second + + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + if egress.IsBlocked(err) { + return nil, ErrEgressBlocked + } + + return nil, ErrKnownHostUnreachable + } + defer conn.Close() //nolint:errcheck + + var captured gossh.PublicKey + config := &gossh.ClientConfig{ //nolint:exhaustruct + User: "shellhub-probe", + HostKeyCallback: func(_ string, _ net.Addr, key gossh.PublicKey) error { + captured = key + + return nil + }, + Timeout: 8 * time.Second, + } + + // The handshake fails at authentication (no methods), but HostKeyCallback + // runs first, so the key is captured regardless. + if sshConn, chans, reqs, err := gossh.NewClientConn(conn, addr, config); err == nil { + go gossh.DiscardRequests(reqs) + go func() { + for ch := range chans { + ch.Reject(gossh.Prohibited, "") //nolint:errcheck + } + }() + sshConn.Close() //nolint:errcheck + } + + if captured == nil { + return nil, ErrKnownHostUnreachable + } + + return &responses.KnownHostScanResult{ + KeyType: captured.Type(), + Fingerprint: gossh.FingerprintSHA256(captured), + PublicKey: strings.TrimSpace(string(gossh.MarshalAuthorizedKey(captured))), + }, nil +} + +func (s *service) ScanKnownHost(ctx context.Context, req *requests.KnownHostScan) (*responses.KnownHostScanResult, error) { + result, err := scanHostKey(ctx, req.Host, req.Port) + if err != nil { + return nil, err + } + + stored, err := s.store.KnownHostResolve(ctx, req.TenantID, scopeOwner(req.Scope, req.UserID), req.Host, req.Port) + switch { + case err == nil: + result.Stored = stored + if stored.Fingerprint == result.Fingerprint { + result.Status = models.KnownHostTrusted + } else { + result.Status = models.KnownHostChanged + } + case errors.Is(err, store.ErrNoDocuments): + result.Status = models.KnownHostUnverified + default: + return nil, err + } + + return result, nil +} + +func (s *service) AcceptKnownHost(ctx context.Context, req *requests.KnownHostAccept) (*models.KnownHost, error) { + owner := scopeOwner(req.Scope, req.UserID) + + // Any shared (team) trust write needs operator+: a member must not be able to + // plant or change a host key that every other member's connects will trust. + // Personal records are the caller's own, so any role may write them. + if req.Scope == "namespace" && !req.Role.HasAuthority(authorizer.RoleOperator) { + return nil, ErrAuthForbidden + } + + // Trust only key material the server can parse, and derive the type and + // fingerprint from the key itself — never from the client-supplied fields, so + // a caller can't store a fingerprint that doesn't match the stored key. + parsed, _, _, _, err := gossh.ParseAuthorizedKey([]byte(req.PublicKey)) + if err != nil { + return nil, ErrKnownHostInvalidKey + } + + knownHost := &models.KnownHost{ + TenantID: req.TenantID, + OwnerID: owner, + Host: req.Host, + Port: req.Port, + KeyType: parsed.Type(), + PublicKey: strings.TrimSpace(string(gossh.MarshalAuthorizedKey(parsed))), + Fingerprint: gossh.FingerprintSHA256(parsed), + AcceptedBy: req.UserID, + } + + if err := s.store.KnownHostUpsert(ctx, knownHost); err != nil { + return nil, err + } + + return knownHost, nil +} + +func (s *service) GetKnownHost(ctx context.Context, req *requests.KnownHostGet) (*models.KnownHost, error) { + knownHost, err := s.store.KnownHostResolve(ctx, req.TenantID, scopeOwner(req.Scope, req.UserID), req.Host, req.Port) + if err != nil { + if errors.Is(err, store.ErrNoDocuments) { + return nil, nil //nolint:nilnil + } + + return nil, err + } + + return knownHost, nil +} + +func (s *service) DeleteKnownHost(ctx context.Context, req *requests.KnownHostDelete) error { + if req.Scope == "namespace" && !req.Role.HasAuthority(authorizer.RoleOperator) { + return ErrAuthForbidden + } + + err := s.store.KnownHostDelete(ctx, req.TenantID, scopeOwner(req.Scope, req.UserID), req.Host, req.Port) + if err != nil && errors.Is(err, store.ErrNoDocuments) { + return NewErrKnownHostNotFound(req.Host, err) + } + + return err +} diff --git a/api/services/mocks/mock_service.go b/api/services/mocks/mock_service.go index 3d2f1fa766f..da65653e73e 100644 --- a/api/services/mocks/mock_service.go +++ b/api/services/mocks/mock_service.go @@ -43,6 +43,74 @@ func (_m *MockService) EXPECT() *MockService_Expecter { return &MockService_Expecter{mock: &_m.Mock} } +// AcceptKnownHost provides a mock function for the type MockService +func (_mock *MockService) AcceptKnownHost(ctx context.Context, req *requests.KnownHostAccept) (*models.KnownHost, error) { + ret := _mock.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for AcceptKnownHost") + } + + var r0 *models.KnownHost + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.KnownHostAccept) (*models.KnownHost, error)); ok { + return returnFunc(ctx, req) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.KnownHostAccept) *models.KnownHost); ok { + r0 = returnFunc(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.KnownHost) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *requests.KnownHostAccept) error); ok { + r1 = returnFunc(ctx, req) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockService_AcceptKnownHost_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AcceptKnownHost' +type MockService_AcceptKnownHost_Call struct { + *mock.Call +} + +// AcceptKnownHost is a helper method to define mock.On call +// - ctx context.Context +// - req *requests.KnownHostAccept +func (_e *MockService_Expecter) AcceptKnownHost(ctx any, req any) *MockService_AcceptKnownHost_Call { + return &MockService_AcceptKnownHost_Call{Call: _e.mock.On("AcceptKnownHost", ctx, req)} +} + +func (_c *MockService_AcceptKnownHost_Call) Run(run func(ctx context.Context, req *requests.KnownHostAccept)) *MockService_AcceptKnownHost_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *requests.KnownHostAccept + if args[1] != nil { + arg1 = args[1].(*requests.KnownHostAccept) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockService_AcceptKnownHost_Call) Return(knownHost *models.KnownHost, err error) *MockService_AcceptKnownHost_Call { + _c.Call.Return(knownHost, err) + return _c +} + +func (_c *MockService_AcceptKnownHost_Call) RunAndReturn(run func(ctx context.Context, req *requests.KnownHostAccept) (*models.KnownHost, error)) *MockService_AcceptKnownHost_Call { + _c.Call.Return(run) + return _c +} + // AddNamespaceMember provides a mock function for the type MockService func (_mock *MockService) AddNamespaceMember(ctx context.Context, req *requests.NamespaceAddMember) (*models.Namespace, error) { ret := _mock.Called(ctx, req) @@ -605,6 +673,72 @@ func (_c *MockService_AuthUncacheToken_Call) RunAndReturn(run func(ctx context.C return _c } +// ConnectionStatus provides a mock function for the type MockService +func (_mock *MockService) ConnectionStatus(ctx context.Context, req *requests.ConnectionGet) (bool, error) { + ret := _mock.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for ConnectionStatus") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.ConnectionGet) (bool, error)); ok { + return returnFunc(ctx, req) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.ConnectionGet) bool); ok { + r0 = returnFunc(ctx, req) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *requests.ConnectionGet) error); ok { + r1 = returnFunc(ctx, req) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockService_ConnectionStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ConnectionStatus' +type MockService_ConnectionStatus_Call struct { + *mock.Call +} + +// ConnectionStatus is a helper method to define mock.On call +// - ctx context.Context +// - req *requests.ConnectionGet +func (_e *MockService_Expecter) ConnectionStatus(ctx any, req any) *MockService_ConnectionStatus_Call { + return &MockService_ConnectionStatus_Call{Call: _e.mock.On("ConnectionStatus", ctx, req)} +} + +func (_c *MockService_ConnectionStatus_Call) Run(run func(ctx context.Context, req *requests.ConnectionGet)) *MockService_ConnectionStatus_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *requests.ConnectionGet + if args[1] != nil { + arg1 = args[1].(*requests.ConnectionGet) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockService_ConnectionStatus_Call) Return(b bool, err error) *MockService_ConnectionStatus_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *MockService_ConnectionStatus_Call) RunAndReturn(run func(ctx context.Context, req *requests.ConnectionGet) (bool, error)) *MockService_ConnectionStatus_Call { + _c.Call.Return(run) + return _c +} + // CreateAPIKey provides a mock function for the type MockService func (_mock *MockService) CreateAPIKey(ctx context.Context, req *requests.CreateAPIKey) (*responses.CreateAPIKey, error) { ret := _mock.Called(ctx, req) @@ -673,6 +807,74 @@ func (_c *MockService_CreateAPIKey_Call) RunAndReturn(run func(ctx context.Conte return _c } +// CreateConnection provides a mock function for the type MockService +func (_mock *MockService) CreateConnection(ctx context.Context, req *requests.ConnectionCreate) (*models.Connection, error) { + ret := _mock.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateConnection") + } + + var r0 *models.Connection + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.ConnectionCreate) (*models.Connection, error)); ok { + return returnFunc(ctx, req) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.ConnectionCreate) *models.Connection); ok { + r0 = returnFunc(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Connection) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *requests.ConnectionCreate) error); ok { + r1 = returnFunc(ctx, req) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockService_CreateConnection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateConnection' +type MockService_CreateConnection_Call struct { + *mock.Call +} + +// CreateConnection is a helper method to define mock.On call +// - ctx context.Context +// - req *requests.ConnectionCreate +func (_e *MockService_Expecter) CreateConnection(ctx any, req any) *MockService_CreateConnection_Call { + return &MockService_CreateConnection_Call{Call: _e.mock.On("CreateConnection", ctx, req)} +} + +func (_c *MockService_CreateConnection_Call) Run(run func(ctx context.Context, req *requests.ConnectionCreate)) *MockService_CreateConnection_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *requests.ConnectionCreate + if args[1] != nil { + arg1 = args[1].(*requests.ConnectionCreate) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockService_CreateConnection_Call) Return(connection *models.Connection, err error) *MockService_CreateConnection_Call { + _c.Call.Return(connection, err) + return _c +} + +func (_c *MockService_CreateConnection_Call) RunAndReturn(run func(ctx context.Context, req *requests.ConnectionCreate) (*models.Connection, error)) *MockService_CreateConnection_Call { + _c.Call.Return(run) + return _c +} + // CreateNamespace provides a mock function for the type MockService func (_mock *MockService) CreateNamespace(ctx context.Context, namespace *requests.NamespaceCreate) (*models.Namespace, error) { ret := _mock.Called(ctx, namespace) @@ -1201,6 +1403,63 @@ func (_c *MockService_DeleteAPIKey_Call) RunAndReturn(run func(ctx context.Conte return _c } +// DeleteConnection provides a mock function for the type MockService +func (_mock *MockService) DeleteConnection(ctx context.Context, req *requests.ConnectionDelete) error { + ret := _mock.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for DeleteConnection") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.ConnectionDelete) error); ok { + r0 = returnFunc(ctx, req) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockService_DeleteConnection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteConnection' +type MockService_DeleteConnection_Call struct { + *mock.Call +} + +// DeleteConnection is a helper method to define mock.On call +// - ctx context.Context +// - req *requests.ConnectionDelete +func (_e *MockService_Expecter) DeleteConnection(ctx any, req any) *MockService_DeleteConnection_Call { + return &MockService_DeleteConnection_Call{Call: _e.mock.On("DeleteConnection", ctx, req)} +} + +func (_c *MockService_DeleteConnection_Call) Run(run func(ctx context.Context, req *requests.ConnectionDelete)) *MockService_DeleteConnection_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *requests.ConnectionDelete + if args[1] != nil { + arg1 = args[1].(*requests.ConnectionDelete) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockService_DeleteConnection_Call) Return(err error) *MockService_DeleteConnection_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockService_DeleteConnection_Call) RunAndReturn(run func(ctx context.Context, req *requests.ConnectionDelete) error) *MockService_DeleteConnection_Call { + _c.Call.Return(run) + return _c +} + // DeleteDevice provides a mock function for the type MockService func (_mock *MockService) DeleteDevice(ctx context.Context, uid models.UID, tenant string) error { ret := _mock.Called(ctx, uid, tenant) @@ -1321,6 +1580,63 @@ func (_c *MockService_DeleteDeviceCustomField_Call) RunAndReturn(run func(ctx co return _c } +// DeleteKnownHost provides a mock function for the type MockService +func (_mock *MockService) DeleteKnownHost(ctx context.Context, req *requests.KnownHostDelete) error { + ret := _mock.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for DeleteKnownHost") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.KnownHostDelete) error); ok { + r0 = returnFunc(ctx, req) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockService_DeleteKnownHost_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteKnownHost' +type MockService_DeleteKnownHost_Call struct { + *mock.Call +} + +// DeleteKnownHost is a helper method to define mock.On call +// - ctx context.Context +// - req *requests.KnownHostDelete +func (_e *MockService_Expecter) DeleteKnownHost(ctx any, req any) *MockService_DeleteKnownHost_Call { + return &MockService_DeleteKnownHost_Call{Call: _e.mock.On("DeleteKnownHost", ctx, req)} +} + +func (_c *MockService_DeleteKnownHost_Call) Run(run func(ctx context.Context, req *requests.KnownHostDelete)) *MockService_DeleteKnownHost_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *requests.KnownHostDelete + if args[1] != nil { + arg1 = args[1].(*requests.KnownHostDelete) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockService_DeleteKnownHost_Call) Return(err error) *MockService_DeleteKnownHost_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockService_DeleteKnownHost_Call) RunAndReturn(run func(ctx context.Context, req *requests.KnownHostDelete) error) *MockService_DeleteKnownHost_Call { + _c.Call.Return(run) + return _c +} + // DeleteNamespace provides a mock function for the type MockService func (_mock *MockService) DeleteNamespace(ctx context.Context, tenantID string) error { ret := _mock.Called(ctx, tenantID) @@ -1899,55 +2215,55 @@ func (_c *MockService_EventSession_Call) RunAndReturn(run func(ctx context.Conte return _c } -// GetDevice provides a mock function for the type MockService -func (_mock *MockService) GetDevice(ctx context.Context, uid models.UID) (*models.Device, error) { - ret := _mock.Called(ctx, uid) +// GetConnection provides a mock function for the type MockService +func (_mock *MockService) GetConnection(ctx context.Context, req *requests.ConnectionGet) (*models.Connection, error) { + ret := _mock.Called(ctx, req) if len(ret) == 0 { - panic("no return value specified for GetDevice") + panic("no return value specified for GetConnection") } - var r0 *models.Device + var r0 *models.Connection var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, models.UID) (*models.Device, error)); ok { - return returnFunc(ctx, uid) + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.ConnectionGet) (*models.Connection, error)); ok { + return returnFunc(ctx, req) } - if returnFunc, ok := ret.Get(0).(func(context.Context, models.UID) *models.Device); ok { - r0 = returnFunc(ctx, uid) + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.ConnectionGet) *models.Connection); ok { + r0 = returnFunc(ctx, req) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*models.Device) + r0 = ret.Get(0).(*models.Connection) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, models.UID) error); ok { - r1 = returnFunc(ctx, uid) + if returnFunc, ok := ret.Get(1).(func(context.Context, *requests.ConnectionGet) error); ok { + r1 = returnFunc(ctx, req) } else { r1 = ret.Error(1) } return r0, r1 } -// MockService_GetDevice_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDevice' -type MockService_GetDevice_Call struct { +// MockService_GetConnection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetConnection' +type MockService_GetConnection_Call struct { *mock.Call } -// GetDevice is a helper method to define mock.On call +// GetConnection is a helper method to define mock.On call // - ctx context.Context -// - uid models.UID -func (_e *MockService_Expecter) GetDevice(ctx any, uid any) *MockService_GetDevice_Call { - return &MockService_GetDevice_Call{Call: _e.mock.On("GetDevice", ctx, uid)} +// - req *requests.ConnectionGet +func (_e *MockService_Expecter) GetConnection(ctx any, req any) *MockService_GetConnection_Call { + return &MockService_GetConnection_Call{Call: _e.mock.On("GetConnection", ctx, req)} } -func (_c *MockService_GetDevice_Call) Run(run func(ctx context.Context, uid models.UID)) *MockService_GetDevice_Call { +func (_c *MockService_GetConnection_Call) Run(run func(ctx context.Context, req *requests.ConnectionGet)) *MockService_GetConnection_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 models.UID + var arg1 *requests.ConnectionGet if args[1] != nil { - arg1 = args[1].(models.UID) + arg1 = args[1].(*requests.ConnectionGet) } run( arg0, @@ -1957,8 +2273,76 @@ func (_c *MockService_GetDevice_Call) Run(run func(ctx context.Context, uid mode return _c } -func (_c *MockService_GetDevice_Call) Return(device *models.Device, err error) *MockService_GetDevice_Call { - _c.Call.Return(device, err) +func (_c *MockService_GetConnection_Call) Return(connection *models.Connection, err error) *MockService_GetConnection_Call { + _c.Call.Return(connection, err) + return _c +} + +func (_c *MockService_GetConnection_Call) RunAndReturn(run func(ctx context.Context, req *requests.ConnectionGet) (*models.Connection, error)) *MockService_GetConnection_Call { + _c.Call.Return(run) + return _c +} + +// GetDevice provides a mock function for the type MockService +func (_mock *MockService) GetDevice(ctx context.Context, uid models.UID) (*models.Device, error) { + ret := _mock.Called(ctx, uid) + + if len(ret) == 0 { + panic("no return value specified for GetDevice") + } + + var r0 *models.Device + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, models.UID) (*models.Device, error)); ok { + return returnFunc(ctx, uid) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, models.UID) *models.Device); ok { + r0 = returnFunc(ctx, uid) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Device) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, models.UID) error); ok { + r1 = returnFunc(ctx, uid) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockService_GetDevice_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDevice' +type MockService_GetDevice_Call struct { + *mock.Call +} + +// GetDevice is a helper method to define mock.On call +// - ctx context.Context +// - uid models.UID +func (_e *MockService_Expecter) GetDevice(ctx any, uid any) *MockService_GetDevice_Call { + return &MockService_GetDevice_Call{Call: _e.mock.On("GetDevice", ctx, uid)} +} + +func (_c *MockService_GetDevice_Call) Run(run func(ctx context.Context, uid models.UID)) *MockService_GetDevice_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 models.UID + if args[1] != nil { + arg1 = args[1].(models.UID) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockService_GetDevice_Call) Return(device *models.Device, err error) *MockService_GetDevice_Call { + _c.Call.Return(device, err) return _c } @@ -1967,6 +2351,74 @@ func (_c *MockService_GetDevice_Call) RunAndReturn(run func(ctx context.Context, return _c } +// GetKnownHost provides a mock function for the type MockService +func (_mock *MockService) GetKnownHost(ctx context.Context, req *requests.KnownHostGet) (*models.KnownHost, error) { + ret := _mock.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for GetKnownHost") + } + + var r0 *models.KnownHost + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.KnownHostGet) (*models.KnownHost, error)); ok { + return returnFunc(ctx, req) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.KnownHostGet) *models.KnownHost); ok { + r0 = returnFunc(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.KnownHost) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *requests.KnownHostGet) error); ok { + r1 = returnFunc(ctx, req) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockService_GetKnownHost_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetKnownHost' +type MockService_GetKnownHost_Call struct { + *mock.Call +} + +// GetKnownHost is a helper method to define mock.On call +// - ctx context.Context +// - req *requests.KnownHostGet +func (_e *MockService_Expecter) GetKnownHost(ctx any, req any) *MockService_GetKnownHost_Call { + return &MockService_GetKnownHost_Call{Call: _e.mock.On("GetKnownHost", ctx, req)} +} + +func (_c *MockService_GetKnownHost_Call) Run(run func(ctx context.Context, req *requests.KnownHostGet)) *MockService_GetKnownHost_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *requests.KnownHostGet + if args[1] != nil { + arg1 = args[1].(*requests.KnownHostGet) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockService_GetKnownHost_Call) Return(knownHost *models.KnownHost, err error) *MockService_GetKnownHost_Call { + _c.Call.Return(knownHost, err) + return _c +} + +func (_c *MockService_GetKnownHost_Call) RunAndReturn(run func(ctx context.Context, req *requests.KnownHostGet) (*models.KnownHost, error)) *MockService_GetKnownHost_Call { + _c.Call.Return(run) + return _c +} + // GetNamespace provides a mock function for the type MockService func (_mock *MockService) GetNamespace(ctx context.Context, tenantID string) (*models.Namespace, error) { ret := _mock.Called(ctx, tenantID) @@ -2584,6 +3036,80 @@ func (_c *MockService_ListAPIKeys_Call) RunAndReturn(run func(ctx context.Contex return _c } +// ListConnections provides a mock function for the type MockService +func (_mock *MockService) ListConnections(ctx context.Context, req *requests.ConnectionList) ([]models.Connection, int, error) { + ret := _mock.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for ListConnections") + } + + var r0 []models.Connection + var r1 int + var r2 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.ConnectionList) ([]models.Connection, int, error)); ok { + return returnFunc(ctx, req) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.ConnectionList) []models.Connection); ok { + r0 = returnFunc(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.Connection) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *requests.ConnectionList) int); ok { + r1 = returnFunc(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + if returnFunc, ok := ret.Get(2).(func(context.Context, *requests.ConnectionList) error); ok { + r2 = returnFunc(ctx, req) + } else { + r2 = ret.Error(2) + } + return r0, r1, r2 +} + +// MockService_ListConnections_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListConnections' +type MockService_ListConnections_Call struct { + *mock.Call +} + +// ListConnections is a helper method to define mock.On call +// - ctx context.Context +// - req *requests.ConnectionList +func (_e *MockService_Expecter) ListConnections(ctx any, req any) *MockService_ListConnections_Call { + return &MockService_ListConnections_Call{Call: _e.mock.On("ListConnections", ctx, req)} +} + +func (_c *MockService_ListConnections_Call) Run(run func(ctx context.Context, req *requests.ConnectionList)) *MockService_ListConnections_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *requests.ConnectionList + if args[1] != nil { + arg1 = args[1].(*requests.ConnectionList) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockService_ListConnections_Call) Return(connections []models.Connection, n int, err error) *MockService_ListConnections_Call { + _c.Call.Return(connections, n, err) + return _c +} + +func (_c *MockService_ListConnections_Call) RunAndReturn(run func(ctx context.Context, req *requests.ConnectionList) ([]models.Connection, int, error)) *MockService_ListConnections_Call { + _c.Call.Return(run) + return _c +} + // ListDevices provides a mock function for the type MockService func (_mock *MockService) ListDevices(ctx context.Context, req *requests.DeviceList) ([]models.Device, int, error) { ret := _mock.Called(ctx, req) @@ -3085,6 +3611,72 @@ func (_c *MockService_OfflineDevice_Call) RunAndReturn(run func(ctx context.Cont return _c } +// ProbeReachability provides a mock function for the type MockService +func (_mock *MockService) ProbeReachability(ctx context.Context, req *requests.ConnectionProbe) (bool, error) { + ret := _mock.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for ProbeReachability") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.ConnectionProbe) (bool, error)); ok { + return returnFunc(ctx, req) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.ConnectionProbe) bool); ok { + r0 = returnFunc(ctx, req) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *requests.ConnectionProbe) error); ok { + r1 = returnFunc(ctx, req) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockService_ProbeReachability_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ProbeReachability' +type MockService_ProbeReachability_Call struct { + *mock.Call +} + +// ProbeReachability is a helper method to define mock.On call +// - ctx context.Context +// - req *requests.ConnectionProbe +func (_e *MockService_Expecter) ProbeReachability(ctx any, req any) *MockService_ProbeReachability_Call { + return &MockService_ProbeReachability_Call{Call: _e.mock.On("ProbeReachability", ctx, req)} +} + +func (_c *MockService_ProbeReachability_Call) Run(run func(ctx context.Context, req *requests.ConnectionProbe)) *MockService_ProbeReachability_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *requests.ConnectionProbe + if args[1] != nil { + arg1 = args[1].(*requests.ConnectionProbe) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockService_ProbeReachability_Call) Return(b bool, err error) *MockService_ProbeReachability_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *MockService_ProbeReachability_Call) RunAndReturn(run func(ctx context.Context, req *requests.ConnectionProbe) (bool, error)) *MockService_ProbeReachability_Call { + _c.Call.Return(run) + return _c +} + // PublicKey provides a mock function for the type MockService func (_mock *MockService) PublicKey() *rsa.PublicKey { ret := _mock.Called() @@ -3462,6 +4054,74 @@ func (_c *MockService_ResolveDevice_Call) RunAndReturn(run func(ctx context.Cont return _c } +// ScanKnownHost provides a mock function for the type MockService +func (_mock *MockService) ScanKnownHost(ctx context.Context, req *requests.KnownHostScan) (*responses0.KnownHostScanResult, error) { + ret := _mock.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for ScanKnownHost") + } + + var r0 *responses0.KnownHostScanResult + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.KnownHostScan) (*responses0.KnownHostScanResult, error)); ok { + return returnFunc(ctx, req) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.KnownHostScan) *responses0.KnownHostScanResult); ok { + r0 = returnFunc(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*responses0.KnownHostScanResult) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *requests.KnownHostScan) error); ok { + r1 = returnFunc(ctx, req) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockService_ScanKnownHost_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ScanKnownHost' +type MockService_ScanKnownHost_Call struct { + *mock.Call +} + +// ScanKnownHost is a helper method to define mock.On call +// - ctx context.Context +// - req *requests.KnownHostScan +func (_e *MockService_Expecter) ScanKnownHost(ctx any, req any) *MockService_ScanKnownHost_Call { + return &MockService_ScanKnownHost_Call{Call: _e.mock.On("ScanKnownHost", ctx, req)} +} + +func (_c *MockService_ScanKnownHost_Call) Run(run func(ctx context.Context, req *requests.KnownHostScan)) *MockService_ScanKnownHost_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *requests.KnownHostScan + if args[1] != nil { + arg1 = args[1].(*requests.KnownHostScan) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockService_ScanKnownHost_Call) Return(knownHostScanResult *responses0.KnownHostScanResult, err error) *MockService_ScanKnownHost_Call { + _c.Call.Return(knownHostScanResult, err) + return _c +} + +func (_c *MockService_ScanKnownHost_Call) RunAndReturn(run func(ctx context.Context, req *requests.KnownHostScan) (*responses0.KnownHostScanResult, error)) *MockService_ScanKnownHost_Call { + _c.Call.Return(run) + return _c +} + // SetDeviceCustomField provides a mock function for the type MockService func (_mock *MockService) SetDeviceCustomField(ctx context.Context, req *requests.DeviceSetCustomField) error { ret := _mock.Called(ctx, req) @@ -3739,6 +4399,74 @@ func (_c *MockService_UpdateAPIKey_Call) RunAndReturn(run func(ctx context.Conte return _c } +// UpdateConnection provides a mock function for the type MockService +func (_mock *MockService) UpdateConnection(ctx context.Context, req *requests.ConnectionUpdate) (*models.Connection, error) { + ret := _mock.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for UpdateConnection") + } + + var r0 *models.Connection + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.ConnectionUpdate) (*models.Connection, error)); ok { + return returnFunc(ctx, req) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *requests.ConnectionUpdate) *models.Connection); ok { + r0 = returnFunc(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Connection) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *requests.ConnectionUpdate) error); ok { + r1 = returnFunc(ctx, req) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockService_UpdateConnection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateConnection' +type MockService_UpdateConnection_Call struct { + *mock.Call +} + +// UpdateConnection is a helper method to define mock.On call +// - ctx context.Context +// - req *requests.ConnectionUpdate +func (_e *MockService_Expecter) UpdateConnection(ctx any, req any) *MockService_UpdateConnection_Call { + return &MockService_UpdateConnection_Call{Call: _e.mock.On("UpdateConnection", ctx, req)} +} + +func (_c *MockService_UpdateConnection_Call) Run(run func(ctx context.Context, req *requests.ConnectionUpdate)) *MockService_UpdateConnection_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *requests.ConnectionUpdate + if args[1] != nil { + arg1 = args[1].(*requests.ConnectionUpdate) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockService_UpdateConnection_Call) Return(connection *models.Connection, err error) *MockService_UpdateConnection_Call { + _c.Call.Return(connection, err) + return _c +} + +func (_c *MockService_UpdateConnection_Call) RunAndReturn(run func(ctx context.Context, req *requests.ConnectionUpdate) (*models.Connection, error)) *MockService_UpdateConnection_Call { + _c.Call.Return(run) + return _c +} + // UpdateDevice provides a mock function for the type MockService func (_mock *MockService) UpdateDevice(ctx context.Context, req *requests.DeviceUpdate) error { ret := _mock.Called(ctx, req) diff --git a/api/services/service.go b/api/services/service.go index 969ba9e09bf..a5296486e0d 100644 --- a/api/services/service.go +++ b/api/services/service.go @@ -41,6 +41,8 @@ type Service interface { SetupService SystemService APIKeyService + ConnectionService + KnownHostService // Store returns the underlying store instance. // diff --git a/api/store/connection.go b/api/store/connection.go new file mode 100644 index 00000000000..ce3bb28c604 --- /dev/null +++ b/api/store/connection.go @@ -0,0 +1,32 @@ +package store + +import ( + "context" + + "github.com/shellhub-io/shellhub/pkg/models" +) + +type ConnectionResolver uint + +const ( + ConnectionIDResolver ConnectionResolver = iota + 1 +) + +type ConnectionStore interface { + // ConnectionCreate creates a connection. Returns the inserted ID and an error if any. + ConnectionCreate(ctx context.Context, connection *models.Connection) (insertedID string, err error) + + // ConnectionResolve fetches a connection using a specific resolver. Scope it to a + // tenant by passing the InNamespace query option. + ConnectionResolve(ctx context.Context, resolver ConnectionResolver, value string, opts ...QueryOption) (*models.Connection, error) + + // ConnectionList retrieves a list of connections. Returns the list, the total count + // of matched documents, and an error if any. + ConnectionList(ctx context.Context, opts ...QueryOption) (connections []models.Connection, count int, err error) + + // ConnectionUpdate updates a connection. Returns an error if any. + ConnectionUpdate(ctx context.Context, connection *models.Connection) (err error) + + // ConnectionDelete deletes a connection. Returns an error if any. + ConnectionDelete(ctx context.Context, connection *models.Connection) (err error) +} diff --git a/api/store/known_host.go b/api/store/known_host.go new file mode 100644 index 00000000000..db1a4b577c6 --- /dev/null +++ b/api/store/known_host.go @@ -0,0 +1,22 @@ +package store + +import ( + "context" + + "github.com/shellhub-io/shellhub/pkg/models" +) + +type KnownHostStore interface { + // KnownHostResolve fetches the stored known host for a target, scoped by + // owner. A non-empty ownerID resolves the caller's personal record; an empty + // ownerID resolves the namespace-shared (team) record. Returns ErrNoDocuments + // when none is stored. + KnownHostResolve(ctx context.Context, tenantID, ownerID, host string, port int) (*models.KnownHost, error) + + // KnownHostUpsert creates or replaces the known host for its (tenant, owner, + // host, port) scope. + KnownHostUpsert(ctx context.Context, knownHost *models.KnownHost) error + + // KnownHostDelete removes the stored known host for a target scope. + KnownHostDelete(ctx context.Context, tenantID, ownerID, host string, port int) error +} diff --git a/api/store/mocks/mock_query_options.go b/api/store/mocks/mock_query_options.go index fcb12211361..6acb29f7980 100644 --- a/api/store/mocks/mock_query_options.go +++ b/api/store/mocks/mock_query_options.go @@ -144,6 +144,59 @@ func (_c *MockQueryOptions_Match_Call) RunAndReturn(run func(fs *query.Filters) return _c } +// OwnedBy provides a mock function for the type MockQueryOptions +func (_mock *MockQueryOptions) OwnedBy(userID string) store.QueryOption { + ret := _mock.Called(userID) + + if len(ret) == 0 { + panic("no return value specified for OwnedBy") + } + + var r0 store.QueryOption + if returnFunc, ok := ret.Get(0).(func(string) store.QueryOption); ok { + r0 = returnFunc(userID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.QueryOption) + } + } + return r0 +} + +// MockQueryOptions_OwnedBy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OwnedBy' +type MockQueryOptions_OwnedBy_Call struct { + *mock.Call +} + +// OwnedBy is a helper method to define mock.On call +// - userID string +func (_e *MockQueryOptions_Expecter) OwnedBy(userID any) *MockQueryOptions_OwnedBy_Call { + return &MockQueryOptions_OwnedBy_Call{Call: _e.mock.On("OwnedBy", userID)} +} + +func (_c *MockQueryOptions_OwnedBy_Call) Run(run func(userID string)) *MockQueryOptions_OwnedBy_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockQueryOptions_OwnedBy_Call) Return(queryOption store.QueryOption) *MockQueryOptions_OwnedBy_Call { + _c.Call.Return(queryOption) + return _c +} + +func (_c *MockQueryOptions_OwnedBy_Call) RunAndReturn(run func(userID string) store.QueryOption) *MockQueryOptions_OwnedBy_Call { + _c.Call.Return(run) + return _c +} + // Paginate provides a mock function for the type MockQueryOptions func (_mock *MockQueryOptions) Paginate(paginator *query.Paginator) store.QueryOption { ret := _mock.Called(paginator) diff --git a/api/store/mocks/mock_store.go b/api/store/mocks/mock_store.go index f55f94b93e2..35061e57e10 100644 --- a/api/store/mocks/mock_store.go +++ b/api/store/mocks/mock_store.go @@ -717,6 +717,358 @@ func (_c *MockStore_ActiveSessionUpdate_Call) RunAndReturn(run func(ctx context. return _c } +// ConnectionCreate provides a mock function for the type MockStore +func (_mock *MockStore) ConnectionCreate(ctx context.Context, connection *models.Connection) (string, error) { + ret := _mock.Called(ctx, connection) + + if len(ret) == 0 { + panic("no return value specified for ConnectionCreate") + } + + var r0 string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *models.Connection) (string, error)); ok { + return returnFunc(ctx, connection) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *models.Connection) string); ok { + r0 = returnFunc(ctx, connection) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *models.Connection) error); ok { + r1 = returnFunc(ctx, connection) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockStore_ConnectionCreate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ConnectionCreate' +type MockStore_ConnectionCreate_Call struct { + *mock.Call +} + +// ConnectionCreate is a helper method to define mock.On call +// - ctx context.Context +// - connection *models.Connection +func (_e *MockStore_Expecter) ConnectionCreate(ctx any, connection any) *MockStore_ConnectionCreate_Call { + return &MockStore_ConnectionCreate_Call{Call: _e.mock.On("ConnectionCreate", ctx, connection)} +} + +func (_c *MockStore_ConnectionCreate_Call) Run(run func(ctx context.Context, connection *models.Connection)) *MockStore_ConnectionCreate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *models.Connection + if args[1] != nil { + arg1 = args[1].(*models.Connection) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockStore_ConnectionCreate_Call) Return(insertedID string, err error) *MockStore_ConnectionCreate_Call { + _c.Call.Return(insertedID, err) + return _c +} + +func (_c *MockStore_ConnectionCreate_Call) RunAndReturn(run func(ctx context.Context, connection *models.Connection) (string, error)) *MockStore_ConnectionCreate_Call { + _c.Call.Return(run) + return _c +} + +// ConnectionDelete provides a mock function for the type MockStore +func (_mock *MockStore) ConnectionDelete(ctx context.Context, connection *models.Connection) error { + ret := _mock.Called(ctx, connection) + + if len(ret) == 0 { + panic("no return value specified for ConnectionDelete") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *models.Connection) error); ok { + r0 = returnFunc(ctx, connection) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockStore_ConnectionDelete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ConnectionDelete' +type MockStore_ConnectionDelete_Call struct { + *mock.Call +} + +// ConnectionDelete is a helper method to define mock.On call +// - ctx context.Context +// - connection *models.Connection +func (_e *MockStore_Expecter) ConnectionDelete(ctx any, connection any) *MockStore_ConnectionDelete_Call { + return &MockStore_ConnectionDelete_Call{Call: _e.mock.On("ConnectionDelete", ctx, connection)} +} + +func (_c *MockStore_ConnectionDelete_Call) Run(run func(ctx context.Context, connection *models.Connection)) *MockStore_ConnectionDelete_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *models.Connection + if args[1] != nil { + arg1 = args[1].(*models.Connection) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockStore_ConnectionDelete_Call) Return(err error) *MockStore_ConnectionDelete_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockStore_ConnectionDelete_Call) RunAndReturn(run func(ctx context.Context, connection *models.Connection) error) *MockStore_ConnectionDelete_Call { + _c.Call.Return(run) + return _c +} + +// ConnectionList provides a mock function for the type MockStore +func (_mock *MockStore) ConnectionList(ctx context.Context, opts ...store.QueryOption) ([]models.Connection, int, error) { + var tmpRet mock.Arguments + if len(opts) > 0 { + tmpRet = _mock.Called(ctx, opts) + } else { + tmpRet = _mock.Called(ctx) + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for ConnectionList") + } + + var r0 []models.Connection + var r1 int + var r2 error + if returnFunc, ok := ret.Get(0).(func(context.Context, ...store.QueryOption) ([]models.Connection, int, error)); ok { + return returnFunc(ctx, opts...) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, ...store.QueryOption) []models.Connection); ok { + r0 = returnFunc(ctx, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.Connection) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, ...store.QueryOption) int); ok { + r1 = returnFunc(ctx, opts...) + } else { + r1 = ret.Get(1).(int) + } + if returnFunc, ok := ret.Get(2).(func(context.Context, ...store.QueryOption) error); ok { + r2 = returnFunc(ctx, opts...) + } else { + r2 = ret.Error(2) + } + return r0, r1, r2 +} + +// MockStore_ConnectionList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ConnectionList' +type MockStore_ConnectionList_Call struct { + *mock.Call +} + +// ConnectionList is a helper method to define mock.On call +// - ctx context.Context +// - opts ...store.QueryOption +func (_e *MockStore_Expecter) ConnectionList(ctx any, opts ...any) *MockStore_ConnectionList_Call { + return &MockStore_ConnectionList_Call{Call: _e.mock.On("ConnectionList", + append([]any{ctx}, opts...)...)} +} + +func (_c *MockStore_ConnectionList_Call) Run(run func(ctx context.Context, opts ...store.QueryOption)) *MockStore_ConnectionList_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 []store.QueryOption + var variadicArgs []store.QueryOption + if len(args) > 1 { + variadicArgs = args[1].([]store.QueryOption) + } + arg1 = variadicArgs + run( + arg0, + arg1..., + ) + }) + return _c +} + +func (_c *MockStore_ConnectionList_Call) Return(connections []models.Connection, count int, err error) *MockStore_ConnectionList_Call { + _c.Call.Return(connections, count, err) + return _c +} + +func (_c *MockStore_ConnectionList_Call) RunAndReturn(run func(ctx context.Context, opts ...store.QueryOption) ([]models.Connection, int, error)) *MockStore_ConnectionList_Call { + _c.Call.Return(run) + return _c +} + +// ConnectionResolve provides a mock function for the type MockStore +func (_mock *MockStore) ConnectionResolve(ctx context.Context, resolver store.ConnectionResolver, value string, opts ...store.QueryOption) (*models.Connection, error) { + var tmpRet mock.Arguments + if len(opts) > 0 { + tmpRet = _mock.Called(ctx, resolver, value, opts) + } else { + tmpRet = _mock.Called(ctx, resolver, value) + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for ConnectionResolve") + } + + var r0 *models.Connection + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, store.ConnectionResolver, string, ...store.QueryOption) (*models.Connection, error)); ok { + return returnFunc(ctx, resolver, value, opts...) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, store.ConnectionResolver, string, ...store.QueryOption) *models.Connection); ok { + r0 = returnFunc(ctx, resolver, value, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Connection) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, store.ConnectionResolver, string, ...store.QueryOption) error); ok { + r1 = returnFunc(ctx, resolver, value, opts...) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockStore_ConnectionResolve_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ConnectionResolve' +type MockStore_ConnectionResolve_Call struct { + *mock.Call +} + +// ConnectionResolve is a helper method to define mock.On call +// - ctx context.Context +// - resolver store.ConnectionResolver +// - value string +// - opts ...store.QueryOption +func (_e *MockStore_Expecter) ConnectionResolve(ctx any, resolver any, value any, opts ...any) *MockStore_ConnectionResolve_Call { + return &MockStore_ConnectionResolve_Call{Call: _e.mock.On("ConnectionResolve", + append([]any{ctx, resolver, value}, opts...)...)} +} + +func (_c *MockStore_ConnectionResolve_Call) Run(run func(ctx context.Context, resolver store.ConnectionResolver, value string, opts ...store.QueryOption)) *MockStore_ConnectionResolve_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 store.ConnectionResolver + if args[1] != nil { + arg1 = args[1].(store.ConnectionResolver) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 []store.QueryOption + var variadicArgs []store.QueryOption + if len(args) > 3 { + variadicArgs = args[3].([]store.QueryOption) + } + arg3 = variadicArgs + run( + arg0, + arg1, + arg2, + arg3..., + ) + }) + return _c +} + +func (_c *MockStore_ConnectionResolve_Call) Return(connection *models.Connection, err error) *MockStore_ConnectionResolve_Call { + _c.Call.Return(connection, err) + return _c +} + +func (_c *MockStore_ConnectionResolve_Call) RunAndReturn(run func(ctx context.Context, resolver store.ConnectionResolver, value string, opts ...store.QueryOption) (*models.Connection, error)) *MockStore_ConnectionResolve_Call { + _c.Call.Return(run) + return _c +} + +// ConnectionUpdate provides a mock function for the type MockStore +func (_mock *MockStore) ConnectionUpdate(ctx context.Context, connection *models.Connection) error { + ret := _mock.Called(ctx, connection) + + if len(ret) == 0 { + panic("no return value specified for ConnectionUpdate") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *models.Connection) error); ok { + r0 = returnFunc(ctx, connection) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockStore_ConnectionUpdate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ConnectionUpdate' +type MockStore_ConnectionUpdate_Call struct { + *mock.Call +} + +// ConnectionUpdate is a helper method to define mock.On call +// - ctx context.Context +// - connection *models.Connection +func (_e *MockStore_Expecter) ConnectionUpdate(ctx any, connection any) *MockStore_ConnectionUpdate_Call { + return &MockStore_ConnectionUpdate_Call{Call: _e.mock.On("ConnectionUpdate", ctx, connection)} +} + +func (_c *MockStore_ConnectionUpdate_Call) Run(run func(ctx context.Context, connection *models.Connection)) *MockStore_ConnectionUpdate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *models.Connection + if args[1] != nil { + arg1 = args[1].(*models.Connection) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockStore_ConnectionUpdate_Call) Return(err error) *MockStore_ConnectionUpdate_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockStore_ConnectionUpdate_Call) RunAndReturn(run func(ctx context.Context, connection *models.Connection) error) *MockStore_ConnectionUpdate_Call { + _c.Call.Return(run) + return _c +} + // DeviceConflicts provides a mock function for the type MockStore func (_mock *MockStore) DeviceConflicts(ctx context.Context, target *models.DeviceConflicts, opts ...store.QueryOption) ([]string, bool, error) { var tmpRet mock.Arguments @@ -1502,6 +1854,224 @@ func (_c *MockStore_GetStats_Call) RunAndReturn(run func(ctx context.Context, te return _c } +// KnownHostDelete provides a mock function for the type MockStore +func (_mock *MockStore) KnownHostDelete(ctx context.Context, tenantID string, ownerID string, host string, port int) error { + ret := _mock.Called(ctx, tenantID, ownerID, host, port) + + if len(ret) == 0 { + panic("no return value specified for KnownHostDelete") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, int) error); ok { + r0 = returnFunc(ctx, tenantID, ownerID, host, port) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockStore_KnownHostDelete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'KnownHostDelete' +type MockStore_KnownHostDelete_Call struct { + *mock.Call +} + +// KnownHostDelete is a helper method to define mock.On call +// - ctx context.Context +// - tenantID string +// - ownerID string +// - host string +// - port int +func (_e *MockStore_Expecter) KnownHostDelete(ctx any, tenantID any, ownerID any, host any, port any) *MockStore_KnownHostDelete_Call { + return &MockStore_KnownHostDelete_Call{Call: _e.mock.On("KnownHostDelete", ctx, tenantID, ownerID, host, port)} +} + +func (_c *MockStore_KnownHostDelete_Call) Run(run func(ctx context.Context, tenantID string, ownerID string, host string, port int)) *MockStore_KnownHostDelete_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 int + if args[4] != nil { + arg4 = args[4].(int) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *MockStore_KnownHostDelete_Call) Return(err error) *MockStore_KnownHostDelete_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockStore_KnownHostDelete_Call) RunAndReturn(run func(ctx context.Context, tenantID string, ownerID string, host string, port int) error) *MockStore_KnownHostDelete_Call { + _c.Call.Return(run) + return _c +} + +// KnownHostResolve provides a mock function for the type MockStore +func (_mock *MockStore) KnownHostResolve(ctx context.Context, tenantID string, ownerID string, host string, port int) (*models.KnownHost, error) { + ret := _mock.Called(ctx, tenantID, ownerID, host, port) + + if len(ret) == 0 { + panic("no return value specified for KnownHostResolve") + } + + var r0 *models.KnownHost + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, int) (*models.KnownHost, error)); ok { + return returnFunc(ctx, tenantID, ownerID, host, port) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, int) *models.KnownHost); ok { + r0 = returnFunc(ctx, tenantID, ownerID, host, port) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.KnownHost) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, int) error); ok { + r1 = returnFunc(ctx, tenantID, ownerID, host, port) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockStore_KnownHostResolve_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'KnownHostResolve' +type MockStore_KnownHostResolve_Call struct { + *mock.Call +} + +// KnownHostResolve is a helper method to define mock.On call +// - ctx context.Context +// - tenantID string +// - ownerID string +// - host string +// - port int +func (_e *MockStore_Expecter) KnownHostResolve(ctx any, tenantID any, ownerID any, host any, port any) *MockStore_KnownHostResolve_Call { + return &MockStore_KnownHostResolve_Call{Call: _e.mock.On("KnownHostResolve", ctx, tenantID, ownerID, host, port)} +} + +func (_c *MockStore_KnownHostResolve_Call) Run(run func(ctx context.Context, tenantID string, ownerID string, host string, port int)) *MockStore_KnownHostResolve_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 int + if args[4] != nil { + arg4 = args[4].(int) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *MockStore_KnownHostResolve_Call) Return(knownHost *models.KnownHost, err error) *MockStore_KnownHostResolve_Call { + _c.Call.Return(knownHost, err) + return _c +} + +func (_c *MockStore_KnownHostResolve_Call) RunAndReturn(run func(ctx context.Context, tenantID string, ownerID string, host string, port int) (*models.KnownHost, error)) *MockStore_KnownHostResolve_Call { + _c.Call.Return(run) + return _c +} + +// KnownHostUpsert provides a mock function for the type MockStore +func (_mock *MockStore) KnownHostUpsert(ctx context.Context, knownHost *models.KnownHost) error { + ret := _mock.Called(ctx, knownHost) + + if len(ret) == 0 { + panic("no return value specified for KnownHostUpsert") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *models.KnownHost) error); ok { + r0 = returnFunc(ctx, knownHost) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockStore_KnownHostUpsert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'KnownHostUpsert' +type MockStore_KnownHostUpsert_Call struct { + *mock.Call +} + +// KnownHostUpsert is a helper method to define mock.On call +// - ctx context.Context +// - knownHost *models.KnownHost +func (_e *MockStore_Expecter) KnownHostUpsert(ctx any, knownHost any) *MockStore_KnownHostUpsert_Call { + return &MockStore_KnownHostUpsert_Call{Call: _e.mock.On("KnownHostUpsert", ctx, knownHost)} +} + +func (_c *MockStore_KnownHostUpsert_Call) Run(run func(ctx context.Context, knownHost *models.KnownHost)) *MockStore_KnownHostUpsert_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *models.KnownHost + if args[1] != nil { + arg1 = args[1].(*models.KnownHost) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockStore_KnownHostUpsert_Call) Return(err error) *MockStore_KnownHostUpsert_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockStore_KnownHostUpsert_Call) RunAndReturn(run func(ctx context.Context, knownHost *models.KnownHost) error) *MockStore_KnownHostUpsert_Call { + _c.Call.Return(run) + return _c +} + // NamespaceConflicts provides a mock function for the type MockStore func (_mock *MockStore) NamespaceConflicts(ctx context.Context, target *models.NamespaceConflicts) ([]string, bool, error) { ret := _mock.Called(ctx, target) diff --git a/api/store/pg/connection.go b/api/store/pg/connection.go new file mode 100644 index 00000000000..b19a37fbd23 --- /dev/null +++ b/api/store/pg/connection.go @@ -0,0 +1,116 @@ +package pg + +import ( + "context" + + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/pkg/clock" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +func (pg *Pg) ConnectionCreate(ctx context.Context, connection *models.Connection) (string, error) { + db := pg.GetConnection(ctx) + + connection.CreatedAt = clock.Now() + connection.UpdatedAt = clock.Now() + if _, err := db.NewInsert().Model(entity.ConnectionFromModel(connection)).Exec(ctx); err != nil { + return "", fromSQLError(err) + } + + return connection.ID, nil +} + +func (pg *Pg) ConnectionList(ctx context.Context, opts ...store.QueryOption) ([]models.Connection, int, error) { + db := pg.GetConnection(ctx) + + entities := make([]entity.Connection, 0) + + query := db.NewSelect().Model(&entities) + var err error + query, err = applyOptions(ctx, query, opts...) + if err != nil { + return nil, 0, err + } + + count, err := query.ScanAndCount(ctx) + if err != nil { + return nil, 0, fromSQLError(err) + } + + connections := make([]models.Connection, len(entities)) + for i, e := range entities { + connections[i] = *entity.ConnectionToModel(&e) + } + + return connections, count, nil +} + +func (pg *Pg) ConnectionResolve(ctx context.Context, resolver store.ConnectionResolver, val string, opts ...store.QueryOption) (*models.Connection, error) { + db := pg.GetConnection(ctx) + + column, err := connectionResolverToString(resolver) + if err != nil { + return nil, err + } + + c := new(entity.Connection) + query := db.NewSelect().Model(c).Where("? = ?", bun.Ident(column), val) + query, err = applyOptions(ctx, query, opts...) + if err != nil { + return nil, err + } + + if err = query.Scan(ctx); err != nil { + return nil, fromSQLError(err) + } + + return entity.ConnectionToModel(c), nil +} + +func (pg *Pg) ConnectionUpdate(ctx context.Context, connection *models.Connection) error { + db := pg.GetConnection(ctx) + + connection.UpdatedAt = clock.Now() + e := entity.ConnectionFromModel(connection) + r, err := db.NewUpdate(). + Model(e). + Column("label", "kind", "host", "port", "device_uid", "username", "auth_method", "key_fingerprint", "updated_at"). + WherePK(). + Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + if rowsAffected, err := r.RowsAffected(); err != nil || rowsAffected == 0 { + return store.ErrNoDocuments + } + + return nil +} + +func (pg *Pg) ConnectionDelete(ctx context.Context, connection *models.Connection) error { + db := pg.GetConnection(ctx) + + c := entity.ConnectionFromModel(connection) + r, err := db.NewDelete().Model(c).WherePK().Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + if rowsAffected, err := r.RowsAffected(); err != nil || rowsAffected == 0 { + return store.ErrNoDocuments + } + + return nil +} + +func connectionResolverToString(resolver store.ConnectionResolver) (string, error) { + switch resolver { + case store.ConnectionIDResolver: + return "id", nil + default: + return "", store.ErrResolverNotFound + } +} diff --git a/api/store/pg/entity/connection.go b/api/store/pg/entity/connection.go new file mode 100644 index 00000000000..b6a6e2a80ed --- /dev/null +++ b/api/store/pg/entity/connection.go @@ -0,0 +1,62 @@ +package entity + +import ( + "time" + + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +type Connection struct { + bun.BaseModel `bun:"table:connections"` + + ID string `bun:"id,pk,type:uuid"` + NamespaceID string `bun:"namespace_id,type:uuid"` + OwnerID string `bun:"owner_id,type:uuid"` + Label string `bun:"label"` + Kind string `bun:"kind"` + Host string `bun:"host"` + Port int `bun:"port"` + DeviceUID string `bun:"device_uid"` + Username string `bun:"username"` + AuthMethod string `bun:"auth_method"` + KeyFingerprint string `bun:"key_fingerprint"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` +} + +func ConnectionFromModel(model *models.Connection) *Connection { + return &Connection{ + ID: model.ID, + NamespaceID: model.TenantID, + OwnerID: model.OwnerID, + Label: model.Label, + Kind: string(model.Kind), + Host: model.Host, + Port: model.Port, + DeviceUID: model.DeviceUID, + Username: model.Username, + AuthMethod: model.AuthMethod, + KeyFingerprint: model.KeyFingerprint, + CreatedAt: model.CreatedAt, + UpdatedAt: model.UpdatedAt, + } +} + +func ConnectionToModel(entity *Connection) *models.Connection { + return &models.Connection{ + ID: entity.ID, + TenantID: entity.NamespaceID, + OwnerID: entity.OwnerID, + Label: entity.Label, + Kind: models.ConnectionKind(entity.Kind), + Host: entity.Host, + Port: entity.Port, + DeviceUID: entity.DeviceUID, + Username: entity.Username, + AuthMethod: entity.AuthMethod, + KeyFingerprint: entity.KeyFingerprint, + CreatedAt: entity.CreatedAt, + UpdatedAt: entity.UpdatedAt, + } +} diff --git a/api/store/pg/entity/known_host.go b/api/store/pg/entity/known_host.go new file mode 100644 index 00000000000..9b2cb61c65b --- /dev/null +++ b/api/store/pg/entity/known_host.go @@ -0,0 +1,56 @@ +package entity + +import ( + "time" + + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/uptrace/bun" +) + +type KnownHost struct { + bun.BaseModel `bun:"table:ssh_known_hosts"` + + ID string `bun:"id,pk,type:uuid"` + NamespaceID string `bun:"namespace_id,type:uuid"` + OwnerID string `bun:"owner_id,type:uuid,nullzero"` + Host string `bun:"host"` + Port int `bun:"port"` + KeyType string `bun:"key_type"` + PublicKey string `bun:"public_key"` + Fingerprint string `bun:"fingerprint"` + AcceptedBy string `bun:"accepted_by,type:uuid,nullzero"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` +} + +func KnownHostFromModel(model *models.KnownHost) *KnownHost { + return &KnownHost{ + ID: model.ID, + NamespaceID: model.TenantID, + OwnerID: model.OwnerID, + Host: model.Host, + Port: model.Port, + KeyType: model.KeyType, + PublicKey: model.PublicKey, + Fingerprint: model.Fingerprint, + AcceptedBy: model.AcceptedBy, + CreatedAt: model.CreatedAt, + UpdatedAt: model.UpdatedAt, + } +} + +func KnownHostToModel(entity *KnownHost) *models.KnownHost { + return &models.KnownHost{ + ID: entity.ID, + TenantID: entity.NamespaceID, + OwnerID: entity.OwnerID, + Host: entity.Host, + Port: entity.Port, + KeyType: entity.KeyType, + PublicKey: entity.PublicKey, + Fingerprint: entity.Fingerprint, + AcceptedBy: entity.AcceptedBy, + CreatedAt: entity.CreatedAt, + UpdatedAt: entity.UpdatedAt, + } +} diff --git a/api/store/pg/known_host.go b/api/store/pg/known_host.go new file mode 100644 index 00000000000..deab9522d90 --- /dev/null +++ b/api/store/pg/known_host.go @@ -0,0 +1,100 @@ +package pg + +import ( + "context" + + "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/pg/entity" + "github.com/shellhub-io/shellhub/pkg/clock" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/shellhub-io/shellhub/pkg/uuid" +) + +func (pg *Pg) KnownHostResolve(ctx context.Context, tenantID, ownerID, host string, port int) (*models.KnownHost, error) { + db := pg.GetConnection(ctx) + + e := new(entity.KnownHost) + q := db.NewSelect(). + Model(e). + Where("namespace_id = ?", tenantID). + Where("host = ?", host). + Where("port = ?", port) + + if ownerID != "" { + q = q.Where("owner_id = ?", ownerID) + } else { + q = q.Where("owner_id IS NULL") + } + + if err := q.Scan(ctx); err != nil { + return nil, fromSQLError(err) + } + + return entity.KnownHostToModel(e), nil +} + +func (pg *Pg) KnownHostUpsert(ctx context.Context, knownHost *models.KnownHost) error { + now := clock.Now() + knownHost.ID = uuid.Generate() + knownHost.CreatedAt = now + knownHost.UpdatedAt = now + + // A single INSERT ... ON CONFLICT atomically inserts or replaces the stored + // key, avoiding the resolve-then-write race on a concurrent first accept. The + // scope picks which partial unique index arbitrates (personal vs team), since + // only one applies to a given row. + conflict := "CONFLICT (namespace_id, host, port) WHERE owner_id IS NULL" + if knownHost.OwnerID != "" { + conflict = "CONFLICT (namespace_id, owner_id, host, port) WHERE owner_id IS NOT NULL" + } + + e := entity.KnownHostFromModel(knownHost) + + var result entity.KnownHost + if err := pg.GetConnection(ctx).NewInsert(). + Model(e). + On(conflict+" DO UPDATE"). + Set("key_type = EXCLUDED.key_type"). + Set("public_key = EXCLUDED.public_key"). + Set("fingerprint = EXCLUDED.fingerprint"). + Set("accepted_by = EXCLUDED.accepted_by"). + Set("updated_at = EXCLUDED.updated_at"). + Returning("id, created_at"). + Scan(ctx, &result); err != nil { + return fromSQLError(err) + } + + // On a conflict update the row keeps its original id/created_at; reflect the + // stored values back onto the caller's model. + knownHost.ID = result.ID + knownHost.CreatedAt = result.CreatedAt + + return nil +} + +func (pg *Pg) KnownHostDelete(ctx context.Context, tenantID, ownerID, host string, port int) error { + db := pg.GetConnection(ctx) + + q := db.NewDelete(). + Model((*entity.KnownHost)(nil)). + Where("namespace_id = ?", tenantID). + Where("host = ?", host). + Where("port = ?", port) + + if ownerID != "" { + q = q.Where("owner_id = ?", ownerID) + } else { + q = q.Where("owner_id IS NULL") + } + + r, err := q.Exec(ctx) + if err != nil { + return fromSQLError(err) + } + + if rows, err := r.RowsAffected(); err != nil || rows == 0 { + return store.ErrNoDocuments + } + + return nil +} diff --git a/api/store/pg/migrations/007_connections.tx.down.sql b/api/store/pg/migrations/007_connections.tx.down.sql new file mode 100644 index 00000000000..31d4a934897 --- /dev/null +++ b/api/store/pg/migrations/007_connections.tx.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS connections; diff --git a/api/store/pg/migrations/007_connections.tx.up.sql b/api/store/pg/migrations/007_connections.tx.up.sql new file mode 100644 index 00000000000..9c5510d37c0 --- /dev/null +++ b/api/store/pg/migrations/007_connections.tx.up.sql @@ -0,0 +1,20 @@ +CREATE TABLE connections ( + id uuid NOT NULL, + namespace_id uuid NOT NULL, + owner_id uuid NOT NULL, + label character varying NOT NULL, + kind character varying NOT NULL DEFAULT 'external', + host character varying NOT NULL DEFAULT '', + port integer NOT NULL DEFAULT 22, + username character varying NOT NULL DEFAULT '', + auth_method character varying NOT NULL DEFAULT '', + key_fingerprint character varying NOT NULL DEFAULT '', + created_at timestamp with time zone NOT NULL, + updated_at timestamp with time zone NOT NULL, + PRIMARY KEY (id), + FOREIGN KEY (namespace_id) REFERENCES namespaces(id) ON DELETE CASCADE, + FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE +); + +-- A connection is personal; labels are unique per owner within a namespace. +CREATE UNIQUE INDEX connections_ns_owner_label_unique ON connections USING btree (namespace_id, owner_id, label); diff --git a/api/store/pg/migrations/008_connection_device_target.tx.down.sql b/api/store/pg/migrations/008_connection_device_target.tx.down.sql new file mode 100644 index 00000000000..5a878368809 --- /dev/null +++ b/api/store/pg/migrations/008_connection_device_target.tx.down.sql @@ -0,0 +1 @@ +ALTER TABLE connections DROP COLUMN IF EXISTS device_uid; diff --git a/api/store/pg/migrations/008_connection_device_target.tx.up.sql b/api/store/pg/migrations/008_connection_device_target.tx.up.sql new file mode 100644 index 00000000000..bebdb1e18bf --- /dev/null +++ b/api/store/pg/migrations/008_connection_device_target.tx.up.sql @@ -0,0 +1 @@ +ALTER TABLE connections ADD COLUMN IF NOT EXISTS device_uid character varying NOT NULL DEFAULT ''; diff --git a/api/store/pg/migrations/009_ssh_known_hosts.tx.down.sql b/api/store/pg/migrations/009_ssh_known_hosts.tx.down.sql new file mode 100644 index 00000000000..a4843059884 --- /dev/null +++ b/api/store/pg/migrations/009_ssh_known_hosts.tx.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ssh_known_hosts; diff --git a/api/store/pg/migrations/009_ssh_known_hosts.tx.up.sql b/api/store/pg/migrations/009_ssh_known_hosts.tx.up.sql new file mode 100644 index 00000000000..93dad63a735 --- /dev/null +++ b/api/store/pg/migrations/009_ssh_known_hosts.tx.up.sql @@ -0,0 +1,22 @@ +CREATE TABLE ssh_known_hosts ( + id uuid NOT NULL, + namespace_id uuid NOT NULL, + owner_id uuid, + host character varying NOT NULL, + port integer NOT NULL, + key_type character varying NOT NULL, + public_key text NOT NULL, + fingerprint character varying NOT NULL, + accepted_by uuid, + created_at timestamp with time zone NOT NULL, + updated_at timestamp with time zone NOT NULL, + PRIMARY KEY (id), + FOREIGN KEY (namespace_id) REFERENCES namespaces(id) ON DELETE CASCADE, + FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE, + FOREIGN KEY (accepted_by) REFERENCES users(id) ON DELETE SET NULL +); + +-- A known host follows the scope of the connection it was reached through: +-- personal (owner_id set) is per-user; team (owner_id NULL) is shared per namespace. +CREATE UNIQUE INDEX ssh_known_hosts_personal_unique ON ssh_known_hosts USING btree (namespace_id, owner_id, host, port) WHERE owner_id IS NOT NULL; +CREATE UNIQUE INDEX ssh_known_hosts_team_unique ON ssh_known_hosts USING btree (namespace_id, host, port) WHERE owner_id IS NULL; diff --git a/api/store/pg/query-options.go b/api/store/pg/query-options.go index 01c79eb1ddd..08653dd6a78 100644 --- a/api/store/pg/query-options.go +++ b/api/store/pg/query-options.go @@ -152,6 +152,31 @@ func (*queryOptions) WithMember(userID string) store.QueryOption { } } +func (*queryOptions) OwnedBy(userID string) store.QueryOption { + return func(ctx context.Context) error { + wrapper, ok := ctx.Value("query").(*queryWrapper) + if !ok { + return ErrQueryNotFound + } + + // owner_id is a uuid-typed column. A blank or non-UUID user id can never + // own a row, so short-circuit to a clean not-found instead of letting the + // uuid cast raise SQLSTATE 22P02. + if _, err := uuid.Parse(userID); err != nil { + return store.ErrNoDocuments + } + + col := "owner_id" + if alias, ok := ctx.Value(CtxTableAlias).(string); ok && alias != "" { + col = alias + ".owner_id" + } + + wrapper.query = wrapper.query.Where("? = ?", bun.Ident(col), userID) + + return nil + } +} + func (*queryOptions) InNamespace(namespaceID string) store.QueryOption { return func(ctx context.Context) error { wrapper, ok := ctx.Value("query").(*queryWrapper) diff --git a/api/store/query-options.go b/api/store/query-options.go index 70cf2642931..73c76887d54 100644 --- a/api/store/query-options.go +++ b/api/store/query-options.go @@ -21,6 +21,9 @@ type QueryOptions interface { // WithMember filters namespaces where the given user is a member. WithMember(userID string) QueryOption + // OwnedBy matches a record owned by the given user (owner_id). + OwnedBy(userID string) QueryOption + // Match applies the provided query filters to match records Match(fs *query.Filters) QueryOption diff --git a/api/store/store.go b/api/store/store.go index 941e9e53ae9..cab5851a94f 100644 --- a/api/store/store.go +++ b/api/store/store.go @@ -11,6 +11,8 @@ type Store interface { PrivateKeyStore StatsStore APIKeyStore + ConnectionStore + KnownHostStore TransactionStore SystemStore diff --git a/go.mod b/go.mod index c925bdf4ed6..798d48715c3 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/shellhub-io/shellhub go 1.25.8 require ( + code.dny.dev/ssrf v0.2.0 github.com/adhocore/gronx v1.8.1 github.com/go-playground/validator/v10 v10.11.2 github.com/go-redis/cache/v8 v8.4.4 @@ -90,6 +91,7 @@ require ( go.opentelemetry.io/otel/sdk/metric v1.40.0 // indirect go.opentelemetry.io/otel/trace v1.41.0 // indirect go.uber.org/goleak v1.3.0 // indirect + golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9 // indirect golang.org/x/net v0.49.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.42.0 // indirect diff --git a/go.sum b/go.sum index d9d166a0823..7843a435f6d 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +code.dny.dev/ssrf v0.2.0 h1:wCBP990rQQ1CYfRpW+YK1+8xhwUjv189AQ3WMo1jQaI= +code.dny.dev/ssrf v0.2.0/go.mod h1:B+91l25OnyaLIeCx0WRJN5qfJ/4/ZTZxRXgm0lj/2w8= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= @@ -256,6 +258,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9 h1:yZNXmy+j/JpX19vZkVktWqAo7Gny4PBWYYK3zskGpx4= +golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= diff --git a/openapi/spec/cloud-openapi.yaml b/openapi/spec/cloud-openapi.yaml index 8f977405c40..6e62b2c9406 100644 --- a/openapi/spec/cloud-openapi.yaml +++ b/openapi/spec/cloud-openapi.yaml @@ -21,6 +21,9 @@ components: securitySchemes: $ref: ./components/schemas/security.yaml tags: + - name: connections + x-displayName: Connections + description: Saved SSH connections (personal and team) and trusted host keys. - name: devices x-displayName: Devices description: List, inspect, accept, reject, rename, and delete devices, and manage their custom fields. @@ -203,3 +206,11 @@ paths: $ref: paths/admin@api@announcements.yaml /admin/api/announcements/{uuid}: $ref: paths/admin@api@announcements@{uuid}.yaml + /api/connections/team: + $ref: paths/api@connections@team.yaml + /api/connections/team/{id}: + $ref: paths/api@connections@team@{id}.yaml + /api/connections/team/{id}/status: + $ref: paths/api@connections@team@{id}@status.yaml + /api/connections/team/{id}/prefs: + $ref: paths/api@connections@team@{id}@prefs.yaml diff --git a/openapi/spec/community-openapi.yaml b/openapi/spec/community-openapi.yaml index 42b1193774b..6a8dbae776a 100644 --- a/openapi/spec/community-openapi.yaml +++ b/openapi/spec/community-openapi.yaml @@ -21,6 +21,9 @@ components: securitySchemes: $ref: ./components/schemas/security.yaml tags: + - name: connections + x-displayName: Connections + description: Saved SSH connections (personal and team) and trusted host keys. - name: devices x-displayName: Devices description: List, inspect, accept, reject, rename, and delete devices, and manage their custom fields. @@ -152,3 +155,15 @@ paths: $ref: paths/api@containers@{uid}@{status}.yaml /api/setup: $ref: paths/api@setup.yaml + /api/connections: + $ref: paths/api@connections.yaml + /api/connections/{id}: + $ref: paths/api@connections@{id}.yaml + /api/connections/{id}/status: + $ref: paths/api@connections@{id}@status.yaml + /api/connections/host-key: + $ref: paths/api@connections@host-key.yaml + /api/connections/host-key/scan: + $ref: paths/api@connections@host-key@scan.yaml + /api/connections/host-key/accept: + $ref: paths/api@connections@host-key@accept.yaml diff --git a/openapi/spec/components/schemas/connection.yaml b/openapi/spec/components/schemas/connection.yaml new file mode 100644 index 00000000000..443e80f21ff --- /dev/null +++ b/openapi/spec/components/schemas/connection.yaml @@ -0,0 +1,68 @@ +type: object +description: A saved SSH connection (personal address-book entry). +properties: + id: + description: Connection's ID. + type: string + example: 3b8c2f1e-1d2a-4c5b-8e9f-0a1b2c3d4e5f + tenant_id: + $ref: namespaceTenantID.yaml + owner_id: + description: The user the connection belongs to. + type: string + example: 507f1f77bcf86cd799439011 + label: + description: Human-friendly name for the connection. + type: string + example: db-primary + kind: + description: How the connection reaches its target. + type: string + enum: + - external + - device + example: external + host: + description: Dial host for an external connection. + type: string + example: 10.0.0.5 + port: + description: Dial port for an external connection. + type: integer + example: 22 + device_uid: + description: Target device UID for a device connection. + type: string + example: a1b2c3d4 + username: + description: OS username to authenticate as. + type: string + example: root + auth_method: + description: Saved authentication method ("password" or "key"; empty when none). + type: string + example: key + key_fingerprint: + description: Pointer to the SSH key in the owner's vault (never the secret). + type: string + example: SHA256:abcd + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time +required: + - id + - tenant_id + - owner_id + - label + - kind + - host + - port + - device_uid + - username + - auth_method + - key_fingerprint + - created_at + - updated_at diff --git a/openapi/spec/components/schemas/connectionCreateRequest.yaml b/openapi/spec/components/schemas/connectionCreateRequest.yaml new file mode 100644 index 00000000000..cba7885b56f --- /dev/null +++ b/openapi/spec/components/schemas/connectionCreateRequest.yaml @@ -0,0 +1,41 @@ +type: object +description: Payload to create a personal connection. +properties: + label: + description: Human-friendly name for the connection. + type: string + example: db-primary + kind: + type: string + enum: + - external + - device + example: external + host: + description: Dial host (required for an external connection). + type: string + example: 10.0.0.5 + port: + description: Dial port (required for an external connection). + type: integer + example: 22 + device_uid: + description: Target device UID (required for a device connection). + type: string + example: a1b2c3d4 + username: + type: string + example: root + auth_method: + type: string + example: key + key_fingerprint: + type: string + example: SHA256:abcd + force: + description: Save even if the external target is currently unreachable. + type: boolean + example: false +required: + - label + - kind diff --git a/openapi/spec/components/schemas/connectionStatus.yaml b/openapi/spec/components/schemas/connectionStatus.yaml new file mode 100644 index 00000000000..b661e593657 --- /dev/null +++ b/openapi/spec/components/schemas/connectionStatus.yaml @@ -0,0 +1,9 @@ +type: object +description: Reachability of a connection's target. +properties: + online: + description: Whether the target is currently reachable. + type: boolean + example: true +required: + - online diff --git a/openapi/spec/components/schemas/connectionUpdateRequest.yaml b/openapi/spec/components/schemas/connectionUpdateRequest.yaml new file mode 100644 index 00000000000..c2500b3e463 --- /dev/null +++ b/openapi/spec/components/schemas/connectionUpdateRequest.yaml @@ -0,0 +1,33 @@ +type: object +description: Payload to update a personal connection. +properties: + label: + type: string + example: db-primary + kind: + type: string + enum: + - external + - device + example: external + host: + type: string + example: 10.0.0.5 + port: + type: integer + example: 22 + device_uid: + type: string + example: a1b2c3d4 + username: + type: string + example: root + auth_method: + type: string + example: key + key_fingerprint: + type: string + example: SHA256:abcd +required: + - label + - kind diff --git a/openapi/spec/components/schemas/knownHost.yaml b/openapi/spec/components/schemas/knownHost.yaml new file mode 100644 index 00000000000..41ea3bb3127 --- /dev/null +++ b/openapi/spec/components/schemas/knownHost.yaml @@ -0,0 +1,52 @@ +type: object +description: A trusted SSH host key (trust-on-first-use) for an external target. +properties: + id: + type: string + example: 3b8c2f1e-1d2a-4c5b-8e9f-0a1b2c3d4e5f + tenant_id: + $ref: namespaceTenantID.yaml + owner_id: + description: Set for a personal (per-user) known host; empty for a namespace-shared one. + type: string + example: 507f1f77bcf86cd799439011 + host: + type: string + example: 10.0.0.5 + port: + type: integer + example: 22 + key_type: + description: Host key algorithm. + type: string + example: ssh-ed25519 + public_key: + description: Host key in authorized_keys format. + type: string + example: ssh-ed25519 AAAAC3Nz... + fingerprint: + description: SHA256 fingerprint of the host key. + type: string + example: SHA256:abcd + accepted_by: + description: User who accepted the key (audit). + type: string + example: 507f1f77bcf86cd799439011 + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time +required: + - id + - tenant_id + - owner_id + - host + - port + - key_type + - public_key + - fingerprint + - accepted_by + - created_at + - updated_at diff --git a/openapi/spec/components/schemas/knownHostAcceptRequest.yaml b/openapi/spec/components/schemas/knownHostAcceptRequest.yaml new file mode 100644 index 00000000000..55dd14d7fc7 --- /dev/null +++ b/openapi/spec/components/schemas/knownHostAcceptRequest.yaml @@ -0,0 +1,31 @@ +type: object +description: Payload to trust (store) a host key for an external target. +properties: + host: + type: string + example: 10.0.0.5 + port: + type: integer + example: 22 + scope: + type: string + enum: + - personal + - namespace + example: personal + key_type: + type: string + example: ssh-ed25519 + public_key: + type: string + example: ssh-ed25519 AAAAC3Nz... + fingerprint: + type: string + example: SHA256:abcd +required: + - host + - port + - scope + - key_type + - public_key + - fingerprint diff --git a/openapi/spec/components/schemas/knownHostScanRequest.yaml b/openapi/spec/components/schemas/knownHostScanRequest.yaml new file mode 100644 index 00000000000..376645699ad --- /dev/null +++ b/openapi/spec/components/schemas/knownHostScanRequest.yaml @@ -0,0 +1,20 @@ +type: object +description: Payload to scan an external target's host key. +properties: + host: + type: string + example: 10.0.0.5 + port: + type: integer + example: 22 + scope: + description: Whether the known host is personal or namespace-shared. + type: string + enum: + - personal + - namespace + example: personal +required: + - host + - port + - scope diff --git a/openapi/spec/components/schemas/knownHostScanResult.yaml b/openapi/spec/components/schemas/knownHostScanResult.yaml new file mode 100644 index 00000000000..3d8aa0550f6 --- /dev/null +++ b/openapi/spec/components/schemas/knownHostScanResult.yaml @@ -0,0 +1,29 @@ +type: object +description: A scanned host key plus its verification status against what is stored. +properties: + key_type: + type: string + example: ssh-ed25519 + fingerprint: + type: string + example: SHA256:abcd + public_key: + type: string + example: ssh-ed25519 AAAAC3Nz... + status: + description: Verification status of the scanned key. + type: string + enum: + - unverified + - trusted + - changed + example: unverified + stored: + description: The currently stored known host, when one exists. + allOf: + - $ref: knownHost.yaml +required: + - key_type + - fingerprint + - public_key + - status diff --git a/openapi/spec/components/schemas/teamConnection.yaml b/openapi/spec/components/schemas/teamConnection.yaml new file mode 100644 index 00000000000..e2931cdcdd9 --- /dev/null +++ b/openapi/spec/components/schemas/teamConnection.yaml @@ -0,0 +1,47 @@ +type: object +description: A saved SSH connection shared with every member of a namespace. +properties: + id: + type: string + example: 3b8c2f1e-1d2a-4c5b-8e9f-0a1b2c3d4e5f + tenant_id: + $ref: namespaceTenantID.yaml + created_by: + description: User that created the connection (audit; does not gate visibility). + type: string + example: 507f1f77bcf86cd799439011 + label: + type: string + example: db-primary + kind: + type: string + enum: + - external + - device + example: external + host: + type: string + example: 10.0.0.5 + port: + type: integer + example: 22 + device_uid: + type: string + example: a1b2c3d4 + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time +required: + - id + - tenant_id + - created_by + - label + - kind + - host + - port + - device_uid + - created_at + - updated_at diff --git a/openapi/spec/components/schemas/teamConnectionPrefs.yaml b/openapi/spec/components/schemas/teamConnectionPrefs.yaml new file mode 100644 index 00000000000..7525ff6924e --- /dev/null +++ b/openapi/spec/components/schemas/teamConnectionPrefs.yaml @@ -0,0 +1,30 @@ +type: object +description: A member's personal auth preference for a team connection. +properties: + team_connection_id: + type: string + example: 3b8c2f1e-1d2a-4c5b-8e9f-0a1b2c3d4e5f + user_id: + type: string + example: 507f1f77bcf86cd799439011 + username: + type: string + example: root + auth_method: + type: string + example: key + key_fingerprint: + type: string + example: SHA256:abcd + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time +required: + - team_connection_id + - user_id + - username + - auth_method + - key_fingerprint diff --git a/openapi/spec/components/schemas/teamConnectionPrefsRequest.yaml b/openapi/spec/components/schemas/teamConnectionPrefsRequest.yaml new file mode 100644 index 00000000000..7131ec305bf --- /dev/null +++ b/openapi/spec/components/schemas/teamConnectionPrefsRequest.yaml @@ -0,0 +1,12 @@ +type: object +description: Payload to save the caller's own auth preference for a team connection. +properties: + username: + type: string + example: root + auth_method: + type: string + example: key + key_fingerprint: + type: string + example: SHA256:abcd diff --git a/openapi/spec/components/schemas/teamConnectionRequest.yaml b/openapi/spec/components/schemas/teamConnectionRequest.yaml new file mode 100644 index 00000000000..f09a190e0df --- /dev/null +++ b/openapi/spec/components/schemas/teamConnectionRequest.yaml @@ -0,0 +1,24 @@ +type: object +description: Payload to create or update a team connection's shared target. +properties: + label: + type: string + example: db-primary + kind: + type: string + enum: + - external + - device + example: external + host: + type: string + example: 10.0.0.5 + port: + type: integer + example: 22 + device_uid: + type: string + example: a1b2c3d4 +required: + - label + - kind diff --git a/openapi/spec/enterprise-openapi.yaml b/openapi/spec/enterprise-openapi.yaml index c9dc60c7f1b..589aa3fe805 100644 --- a/openapi/spec/enterprise-openapi.yaml +++ b/openapi/spec/enterprise-openapi.yaml @@ -18,6 +18,9 @@ servers: - url: / description: ShellHub server as a Enterprise instance. tags: + - name: connections + x-displayName: Connections + description: Saved SSH connections (personal and team) and trusted host keys. - name: devices x-displayName: Devices description: List, inspect, accept, reject, rename, and delete devices, and manage their custom fields. @@ -166,3 +169,11 @@ paths: $ref: paths/admin@api@announcements.yaml /admin/api/announcements/{uuid}: $ref: paths/admin@api@announcements@{uuid}.yaml + /api/connections/team: + $ref: paths/api@connections@team.yaml + /api/connections/team/{id}: + $ref: paths/api@connections@team@{id}.yaml + /api/connections/team/{id}/status: + $ref: paths/api@connections@team@{id}@status.yaml + /api/connections/team/{id}/prefs: + $ref: paths/api@connections@team@{id}@prefs.yaml diff --git a/openapi/spec/openapi.yaml b/openapi/spec/openapi.yaml index c1d8897bcb5..2b4c33a0343 100644 --- a/openapi/spec/openapi.yaml +++ b/openapi/spec/openapi.yaml @@ -30,6 +30,9 @@ components: An API key is an alternative to the standard JWT authentication. Authentication with this method is namespace-related and is not tied to any user. tags: + - name: connections + x-displayName: Connections + description: Saved SSH connections (personal and team) and trusted host keys. - name: internal description: Requests executed internally by ShellHub server. - name: external @@ -318,3 +321,23 @@ paths: $ref: paths/admin@api@announcements.yaml /admin/api/announcements/{uuid}: $ref: paths/admin@api@announcements@{uuid}.yaml + /api/connections: + $ref: paths/api@connections.yaml + /api/connections/{id}: + $ref: paths/api@connections@{id}.yaml + /api/connections/{id}/status: + $ref: paths/api@connections@{id}@status.yaml + /api/connections/host-key: + $ref: paths/api@connections@host-key.yaml + /api/connections/host-key/scan: + $ref: paths/api@connections@host-key@scan.yaml + /api/connections/host-key/accept: + $ref: paths/api@connections@host-key@accept.yaml + /api/connections/team: + $ref: paths/api@connections@team.yaml + /api/connections/team/{id}: + $ref: paths/api@connections@team@{id}.yaml + /api/connections/team/{id}/status: + $ref: paths/api@connections@team@{id}@status.yaml + /api/connections/team/{id}/prefs: + $ref: paths/api@connections@team@{id}@prefs.yaml diff --git a/openapi/spec/paths/api@connections.yaml b/openapi/spec/paths/api@connections.yaml new file mode 100644 index 00000000000..7f83a54d773 --- /dev/null +++ b/openapi/spec/paths/api@connections.yaml @@ -0,0 +1,65 @@ +get: + operationId: listConnections + summary: List connections + description: List the personal connections owned by the caller. + tags: + - connections + security: + - jwt: [] + - api-key: [] + parameters: + - $ref: ../components/parameters/query/pageQuery.yaml + - $ref: ../components/parameters/query/perPageQuery.yaml + responses: + '200': + description: Success to list connections. + headers: + X-Total-Count: + $ref: ../components/headers/XTotalCount.yaml + content: + application/json: + schema: + type: array + items: + $ref: ../components/schemas/connection.yaml + '401': + $ref: ../components/responses/401.yaml + '500': + $ref: ../components/responses/500.yaml +post: + operationId: createConnection + summary: Create a connection + description: Create a personal connection owned by the caller. + tags: + - connections + security: + - jwt: [] + - api-key: [] + requestBody: + content: + application/json: + schema: + $ref: ../components/schemas/connectionCreateRequest.yaml + responses: + '201': + description: Success to create connection. + content: + application/json: + schema: + $ref: ../components/schemas/connection.yaml + '400': + $ref: ../components/responses/400.yaml + '401': + $ref: ../components/responses/401.yaml + '422': + description: The external target is unreachable or not a permitted target. + content: + application/json: + schema: + type: object + properties: + error: + type: string + example: unreachable + '500': + $ref: ../components/responses/500.yaml diff --git a/openapi/spec/paths/api@connections@host-key.yaml b/openapi/spec/paths/api@connections@host-key.yaml new file mode 100644 index 00000000000..cd24792387a --- /dev/null +++ b/openapi/spec/paths/api@connections@host-key.yaml @@ -0,0 +1,78 @@ +get: + operationId: getKnownHost + summary: Get a stored host key + description: Return the stored known host for an external target, or null if none. + tags: + - connections + security: + - jwt: [] + - api-key: [] + parameters: + - name: host + in: query + required: true + schema: + type: string + - name: port + in: query + required: true + schema: + type: integer + - name: scope + in: query + required: true + schema: + type: string + enum: + - personal + - namespace + responses: + '200': + description: The stored known host, or null. + content: + application/json: + schema: + $ref: ../components/schemas/knownHost.yaml + '401': + $ref: ../components/responses/401.yaml + '500': + $ref: ../components/responses/500.yaml +delete: + operationId: deleteKnownHost + summary: Forget a stored host key + description: Forget the stored host key for an external target's scope. + tags: + - connections + security: + - jwt: [] + - api-key: [] + parameters: + - name: host + in: query + required: true + schema: + type: string + - name: port + in: query + required: true + schema: + type: integer + - name: scope + in: query + required: true + schema: + type: string + enum: + - personal + - namespace + responses: + '200': + description: Success to forget the host key. + '401': + $ref: ../components/responses/401.yaml + '403': + $ref: ../components/responses/403.yaml + '404': + $ref: ../components/responses/404.yaml + '500': + $ref: ../components/responses/500.yaml diff --git a/openapi/spec/paths/api@connections@host-key@accept.yaml b/openapi/spec/paths/api@connections@host-key@accept.yaml new file mode 100644 index 00000000000..0586ba31529 --- /dev/null +++ b/openapi/spec/paths/api@connections@host-key@accept.yaml @@ -0,0 +1,39 @@ +post: + operationId: acceptKnownHost + summary: Trust a host key + description: Trust (store) a host key for an external target's scope. + tags: + - connections + security: + - jwt: [] + - api-key: [] + requestBody: + content: + application/json: + schema: + $ref: ../components/schemas/knownHostAcceptRequest.yaml + responses: + '200': + description: Success to trust the host key. + content: + application/json: + schema: + $ref: ../components/schemas/knownHost.yaml + '400': + $ref: ../components/responses/400.yaml + '401': + $ref: ../components/responses/401.yaml + '403': + $ref: ../components/responses/403.yaml + '422': + description: The supplied public key could not be parsed. + content: + application/json: + schema: + type: object + properties: + error: + type: string + example: invalid_key + '500': + $ref: ../components/responses/500.yaml diff --git a/openapi/spec/paths/api@connections@host-key@scan.yaml b/openapi/spec/paths/api@connections@host-key@scan.yaml new file mode 100644 index 00000000000..ae5724cfc14 --- /dev/null +++ b/openapi/spec/paths/api@connections@host-key@scan.yaml @@ -0,0 +1,37 @@ +post: + operationId: scanKnownHost + summary: Scan a host key + description: Read an external target's host key and report it against the stored one. + tags: + - connections + security: + - jwt: [] + - api-key: [] + requestBody: + content: + application/json: + schema: + $ref: ../components/schemas/knownHostScanRequest.yaml + responses: + '200': + description: The scanned host key and its verification status. + content: + application/json: + schema: + $ref: ../components/schemas/knownHostScanResult.yaml + '400': + $ref: ../components/responses/400.yaml + '401': + $ref: ../components/responses/401.yaml + '422': + description: The host could not be reached or is not a permitted target. + content: + application/json: + schema: + type: object + properties: + error: + type: string + example: unreachable + '500': + $ref: ../components/responses/500.yaml diff --git a/openapi/spec/paths/api@connections@team.yaml b/openapi/spec/paths/api@connections@team.yaml new file mode 100644 index 00000000000..e985945c43c --- /dev/null +++ b/openapi/spec/paths/api@connections@team.yaml @@ -0,0 +1,63 @@ +get: + operationId: listTeamConnections + summary: List team connections + description: List every team connection in the namespace. Any member can see them. + tags: + - connections + security: + - jwt: [] + - api-key: [] + parameters: + - $ref: ../components/parameters/query/pageQuery.yaml + - $ref: ../components/parameters/query/perPageQuery.yaml + responses: + '200': + description: Success to list team connections. + headers: + X-Total-Count: + $ref: ../components/headers/XTotalCount.yaml + content: + application/json: + schema: + type: array + items: + $ref: ../components/schemas/teamConnection.yaml + '401': + $ref: ../components/responses/401.yaml + '402': + $ref: ../components/responses/402.yaml + '500': + $ref: ../components/responses/500.yaml +post: + operationId: createTeamConnection + summary: Create a team connection + description: Create a shared connection. Requires operator+. + tags: + - connections + security: + - jwt: [] + - api-key: [] + requestBody: + content: + application/json: + schema: + $ref: ../components/schemas/teamConnectionRequest.yaml + responses: + '200': + description: Success to create team connection. + content: + application/json: + schema: + $ref: ../components/schemas/teamConnection.yaml + '400': + $ref: ../components/responses/400.yaml + '401': + $ref: ../components/responses/401.yaml + '402': + $ref: ../components/responses/402.yaml + '403': + $ref: ../components/responses/403.yaml + '409': + $ref: ../components/responses/409.yaml + '500': + $ref: ../components/responses/500.yaml diff --git a/openapi/spec/paths/api@connections@team@{id}.yaml b/openapi/spec/paths/api@connections@team@{id}.yaml new file mode 100644 index 00000000000..f6b405d9d19 --- /dev/null +++ b/openapi/spec/paths/api@connections@team@{id}.yaml @@ -0,0 +1,64 @@ +parameters: + - name: id + schema: + description: Team connection's ID. + type: string + in: path + required: true +put: + operationId: updateTeamConnection + summary: Update a team connection + description: Update the shared target. Requires operator+. + tags: + - connections + security: + - jwt: [] + - api-key: [] + requestBody: + content: + application/json: + schema: + $ref: ../components/schemas/teamConnectionRequest.yaml + responses: + '200': + description: Success to update team connection. + content: + application/json: + schema: + $ref: ../components/schemas/teamConnection.yaml + '400': + $ref: ../components/responses/400.yaml + '401': + $ref: ../components/responses/401.yaml + '402': + $ref: ../components/responses/402.yaml + '403': + $ref: ../components/responses/403.yaml + '404': + $ref: ../components/responses/404.yaml + '409': + $ref: ../components/responses/409.yaml + '500': + $ref: ../components/responses/500.yaml +delete: + operationId: deleteTeamConnection + summary: Delete a team connection + description: Delete a shared connection. Requires operator+. + tags: + - connections + security: + - jwt: [] + - api-key: [] + responses: + '200': + description: Success to delete team connection. + '401': + $ref: ../components/responses/401.yaml + '402': + $ref: ../components/responses/402.yaml + '403': + $ref: ../components/responses/403.yaml + '404': + $ref: ../components/responses/404.yaml + '500': + $ref: ../components/responses/500.yaml diff --git a/openapi/spec/paths/api@connections@team@{id}@prefs.yaml b/openapi/spec/paths/api@connections@team@{id}@prefs.yaml new file mode 100644 index 00000000000..95beab428d5 --- /dev/null +++ b/openapi/spec/paths/api@connections@team@{id}@prefs.yaml @@ -0,0 +1,62 @@ +parameters: + - name: id + schema: + description: Team connection's ID. + type: string + in: path + required: true +get: + operationId: getTeamConnectionPrefs + summary: Get your team connection auth preference + description: Return the caller's own auth preference (empty when none saved yet). + tags: + - connections + security: + - jwt: [] + - api-key: [] + responses: + '200': + description: Success to get the auth preference. + content: + application/json: + schema: + $ref: ../components/schemas/teamConnectionPrefs.yaml + '401': + $ref: ../components/responses/401.yaml + '402': + $ref: ../components/responses/402.yaml + '404': + $ref: ../components/responses/404.yaml + '500': + $ref: ../components/responses/500.yaml +put: + operationId: updateTeamConnectionPrefs + summary: Save your team connection auth preference + description: Save the caller's own auth preference for a team connection. + tags: + - connections + security: + - jwt: [] + - api-key: [] + requestBody: + content: + application/json: + schema: + $ref: ../components/schemas/teamConnectionPrefsRequest.yaml + responses: + '200': + description: Success to save the auth preference. + content: + application/json: + schema: + $ref: ../components/schemas/teamConnectionPrefs.yaml + '400': + $ref: ../components/responses/400.yaml + '401': + $ref: ../components/responses/401.yaml + '402': + $ref: ../components/responses/402.yaml + '404': + $ref: ../components/responses/404.yaml + '500': + $ref: ../components/responses/500.yaml diff --git a/openapi/spec/paths/api@connections@team@{id}@status.yaml b/openapi/spec/paths/api@connections@team@{id}@status.yaml new file mode 100644 index 00000000000..76ce99716fa --- /dev/null +++ b/openapi/spec/paths/api@connections@team@{id}@status.yaml @@ -0,0 +1,31 @@ +parameters: + - name: id + schema: + description: Team connection's ID. + type: string + in: path + required: true +get: + operationId: getTeamConnectionStatus + summary: Get team connection status + description: Report whether the team connection's target is currently reachable. + tags: + - connections + security: + - jwt: [] + - api-key: [] + responses: + '200': + description: Success to get team connection status. + content: + application/json: + schema: + $ref: ../components/schemas/connectionStatus.yaml + '401': + $ref: ../components/responses/401.yaml + '402': + $ref: ../components/responses/402.yaml + '404': + $ref: ../components/responses/404.yaml + '500': + $ref: ../components/responses/500.yaml diff --git a/openapi/spec/paths/api@connections@{id}.yaml b/openapi/spec/paths/api@connections@{id}.yaml new file mode 100644 index 00000000000..282d6630298 --- /dev/null +++ b/openapi/spec/paths/api@connections@{id}.yaml @@ -0,0 +1,76 @@ +parameters: + - name: id + schema: + description: Connection's ID. + type: string + in: path + required: true +get: + operationId: getConnection + summary: Get a connection + description: Get a single personal connection owned by the caller. + tags: + - connections + security: + - jwt: [] + - api-key: [] + responses: + '200': + description: Success to get connection. + content: + application/json: + schema: + $ref: ../components/schemas/connection.yaml + '401': + $ref: ../components/responses/401.yaml + '404': + $ref: ../components/responses/404.yaml + '500': + $ref: ../components/responses/500.yaml +put: + operationId: updateConnection + summary: Update a connection + description: Update a personal connection owned by the caller. + tags: + - connections + security: + - jwt: [] + - api-key: [] + requestBody: + content: + application/json: + schema: + $ref: ../components/schemas/connectionUpdateRequest.yaml + responses: + '200': + description: Success to update connection. + content: + application/json: + schema: + $ref: ../components/schemas/connection.yaml + '400': + $ref: ../components/responses/400.yaml + '401': + $ref: ../components/responses/401.yaml + '404': + $ref: ../components/responses/404.yaml + '500': + $ref: ../components/responses/500.yaml +delete: + operationId: deleteConnection + summary: Delete a connection + description: Delete a personal connection owned by the caller. + tags: + - connections + security: + - jwt: [] + - api-key: [] + responses: + '200': + description: Success to delete connection. + '401': + $ref: ../components/responses/401.yaml + '404': + $ref: ../components/responses/404.yaml + '500': + $ref: ../components/responses/500.yaml diff --git a/openapi/spec/paths/api@connections@{id}@status.yaml b/openapi/spec/paths/api@connections@{id}@status.yaml new file mode 100644 index 00000000000..fe2c62c4568 --- /dev/null +++ b/openapi/spec/paths/api@connections@{id}@status.yaml @@ -0,0 +1,29 @@ +parameters: + - name: id + schema: + description: Connection's ID. + type: string + in: path + required: true +get: + operationId: getConnectionStatus + summary: Get connection status + description: Report whether the connection's target is currently reachable. + tags: + - connections + security: + - jwt: [] + - api-key: [] + responses: + '200': + description: Success to get connection status. + content: + application/json: + schema: + $ref: ../components/schemas/connectionStatus.yaml + '401': + $ref: ../components/responses/401.yaml + '404': + $ref: ../components/responses/404.yaml + '500': + $ref: ../components/responses/500.yaml diff --git a/pkg/api/authorizer/permissions.go b/pkg/api/authorizer/permissions.go index ef2d8197359..19dc930c7d6 100644 --- a/pkg/api/authorizer/permissions.go +++ b/pkg/api/authorizer/permissions.go @@ -57,12 +57,24 @@ const ( TunnelsCreate TunnelsDelete + + // Connection permissions. Declared last so existing permission ordinals are + // preserved. + ConnectionCreate + ConnectionUpdate + ConnectionDelete ) var observerPermissions = []Permission{ DeviceConnect, DeviceDetails, + // Observers can create/manage their OWN personal connections; the service + // blocks them from writing namespace (team) connections. + ConnectionCreate, + ConnectionUpdate, + ConnectionDelete, + SessionDetails, } @@ -75,6 +87,10 @@ var operatorPermissions = []Permission{ DeviceUpdate, DeviceCustomFieldUpdate, + ConnectionCreate, + ConnectionUpdate, + ConnectionDelete, + TagCreate, TagUpdate, TagDelete, @@ -92,6 +108,10 @@ var adminPermissions = []Permission{ DeviceUpdate, DeviceCustomFieldUpdate, + ConnectionCreate, + ConnectionUpdate, + ConnectionDelete, + TagCreate, TagUpdate, TagDelete, @@ -138,6 +158,10 @@ var ownerPermissions = []Permission{ DeviceUpdate, DeviceCustomFieldUpdate, + ConnectionCreate, + ConnectionUpdate, + ConnectionDelete, + TagCreate, TagUpdate, TagDelete, diff --git a/pkg/api/authorizer/role_test.go b/pkg/api/authorizer/role_test.go index a3d0f252f14..a7c51af58d8 100644 --- a/pkg/api/authorizer/role_test.go +++ b/pkg/api/authorizer/role_test.go @@ -70,6 +70,9 @@ func TestRolePermissions(t *testing.T) { authorizer.DeviceDetails, authorizer.DeviceUpdate, authorizer.DeviceCustomFieldUpdate, + authorizer.ConnectionCreate, + authorizer.ConnectionUpdate, + authorizer.ConnectionDelete, authorizer.TagCreate, authorizer.TagUpdate, authorizer.TagDelete, @@ -120,6 +123,9 @@ func TestRolePermissions(t *testing.T) { authorizer.DeviceDetails, authorizer.DeviceUpdate, authorizer.DeviceCustomFieldUpdate, + authorizer.ConnectionCreate, + authorizer.ConnectionUpdate, + authorizer.ConnectionDelete, authorizer.TagCreate, authorizer.TagUpdate, authorizer.TagDelete, @@ -160,6 +166,9 @@ func TestRolePermissions(t *testing.T) { authorizer.DeviceDetails, authorizer.DeviceUpdate, authorizer.DeviceCustomFieldUpdate, + authorizer.ConnectionCreate, + authorizer.ConnectionUpdate, + authorizer.ConnectionDelete, authorizer.TagCreate, authorizer.TagUpdate, authorizer.TagDelete, @@ -172,6 +181,9 @@ func TestRolePermissions(t *testing.T) { expected: []authorizer.Permission{ authorizer.DeviceConnect, authorizer.DeviceDetails, + authorizer.ConnectionCreate, + authorizer.ConnectionUpdate, + authorizer.ConnectionDelete, authorizer.SessionDetails, }, }, diff --git a/pkg/api/requests/connection.go b/pkg/api/requests/connection.go new file mode 100644 index 00000000000..dba60c51da4 --- /dev/null +++ b/pkg/api/requests/connection.go @@ -0,0 +1,68 @@ +package requests + +import ( + "github.com/shellhub-io/shellhub/pkg/api/query" +) + +// ConnectionCreate is the request data for creating a connection. Kind selects +// the target: "external" requires Host/Port; "device" requires DeviceUID. The +// connection is personal: it belongs to the caller (X-ID). +type ConnectionCreate struct { + TenantID string `header:"X-Tenant-ID"` + UserID string `header:"X-ID"` + Label string `json:"label" validate:"required,min=1,max=200"` + Kind string `json:"kind" validate:"required,oneof=external device"` + Host string `json:"host" validate:"required_if=Kind external,omitempty,hostname_rfc1123|ip"` + Port int `json:"port" validate:"required_if=Kind external,omitempty,min=1,max=65535"` + DeviceUID string `json:"device_uid" validate:"required_if=Kind device"` + Username string `json:"username" validate:"omitempty,max=256"` + AuthMethod string `json:"auth_method" validate:"omitempty,oneof=password key"` + KeyFingerprint string `json:"key_fingerprint" validate:"omitempty,max=256"` + // Force saves an external connection even if its target is currently unreachable. + Force bool `json:"force"` +} + +// ConnectionUpdate is the request data for updating a connection. +type ConnectionUpdate struct { + TenantID string `header:"X-Tenant-ID"` + UserID string `header:"X-ID"` + ID string `param:"id" validate:"required"` + Label string `json:"label" validate:"required,min=1,max=200"` + Kind string `json:"kind" validate:"required,oneof=external device"` + Host string `json:"host" validate:"required_if=Kind external,omitempty,hostname_rfc1123|ip"` + Port int `json:"port" validate:"required_if=Kind external,omitempty,min=1,max=65535"` + DeviceUID string `json:"device_uid" validate:"required_if=Kind device"` + Username string `json:"username" validate:"omitempty,max=256"` + AuthMethod string `json:"auth_method" validate:"omitempty,oneof=password key"` + KeyFingerprint string `json:"key_fingerprint" validate:"omitempty,max=256"` +} + +// ConnectionList is the request data for listing connections. +type ConnectionList struct { + TenantID string `header:"X-Tenant-ID"` + UserID string `header:"X-ID"` + query.Paginator + query.Sorter +} + +// ConnectionProbe is the request data for testing whether a host:port is +// reachable before saving an external connection. +type ConnectionProbe struct { + TenantID string `header:"X-Tenant-ID"` + Host string `json:"host" validate:"required,hostname_rfc1123|ip"` + Port int `json:"port" validate:"required,min=1,max=65535"` +} + +// ConnectionGet is the request data for getting a single connection. +type ConnectionGet struct { + TenantID string `header:"X-Tenant-ID"` + UserID string `header:"X-ID"` + ID string `param:"id" validate:"required"` +} + +// ConnectionDelete is the request data for deleting a connection. +type ConnectionDelete struct { + TenantID string `header:"X-Tenant-ID"` + UserID string `header:"X-ID"` + ID string `param:"id" validate:"required"` +} diff --git a/pkg/api/requests/known_host.go b/pkg/api/requests/known_host.go new file mode 100644 index 00000000000..c32d3546f92 --- /dev/null +++ b/pkg/api/requests/known_host.go @@ -0,0 +1,48 @@ +package requests + +import "github.com/shellhub-io/shellhub/pkg/api/authorizer" + +// Scope selects which known host record a request targets: "personal" (the +// caller's own, per-user) or "namespace" (shared with the team). + +// KnownHostScan probes an external target's host key and reports it against +// what is stored. +type KnownHostScan struct { + TenantID string `header:"X-Tenant-ID"` + UserID string `header:"X-ID"` + Host string `json:"host" validate:"required,hostname_rfc1123|ip"` + Port int `json:"port" validate:"required,min=1,max=65535"` + Scope string `json:"scope" validate:"required,oneof=personal namespace"` +} + +// KnownHostAccept stores (trusts) a host key for a target. +type KnownHostAccept struct { + TenantID string `header:"X-Tenant-ID"` + UserID string `header:"X-ID"` + Role authorizer.Role `header:"X-Role"` + Host string `json:"host" validate:"required,hostname_rfc1123|ip"` + Port int `json:"port" validate:"required,min=1,max=65535"` + Scope string `json:"scope" validate:"required,oneof=personal namespace"` + KeyType string `json:"key_type" validate:"required,max=64"` + PublicKey string `json:"public_key" validate:"required"` + Fingerprint string `json:"fingerprint" validate:"required,max=128"` +} + +// KnownHostGet reads the stored known host for a target. +type KnownHostGet struct { + TenantID string `header:"X-Tenant-ID"` + UserID string `header:"X-ID"` + Host string `query:"host" validate:"required"` + Port int `query:"port" validate:"required,min=1,max=65535"` + Scope string `query:"scope" validate:"required,oneof=personal namespace"` +} + +// KnownHostDelete forgets the stored known host for a target. +type KnownHostDelete struct { + TenantID string `header:"X-Tenant-ID"` + UserID string `header:"X-ID"` + Role authorizer.Role `header:"X-Role"` + Host string `query:"host" validate:"required"` + Port int `query:"port" validate:"required,min=1,max=65535"` + Scope string `query:"scope" validate:"required,oneof=personal namespace"` +} diff --git a/pkg/egress/egress.go b/pkg/egress/egress.go new file mode 100644 index 00000000000..a11137db171 --- /dev/null +++ b/pkg/egress/egress.go @@ -0,0 +1,84 @@ +// Package egress dials connection targets through an SSRF guardian so the server +// can't be used as a pivot to reach internal, reserved, or metadata addresses. It +// is the one place this logic lives, shared by every path that opens a connection +// to a user-supplied host: the API connection service, the SSH direct-connect path, +// and the cloud team-connection service. Keeping it here stops the guardian policy +// from drifting between them. +package egress + +import ( + "context" + "errors" + "net" + "net/netip" + "strconv" + "time" + + "code.dny.dev/ssrf" + "github.com/shellhub-io/shellhub/pkg/envs" +) + +// devAllowedV4Prefixes are private ranges let through the guardian on top of the +// public-only default. They are applied ONLY in development, so external connects +// can be exercised against the docker host or a host-machine sshd. In production +// the guardian is public-only with no extra configuration and no code to remove. +var devAllowedV4Prefixes = []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("172.16.0.0/12"), + netip.MustParsePrefix("192.168.0.0/16"), +} + +// ErrBlocked means the target isn't a permitted connection endpoint: the guardian +// rejected it (loopback, link-local/metadata, reserved, or a private address that +// isn't allowlisted). Distinct from a host that is simply unreachable. +var ErrBlocked = errors.New("the address is not a permitted connection target") + +const dialTimeout = 4 * time.Second + +// GuardedDialer returns a dialer whose Control validates the real resolved IP at +// the socket layer (closing the DNS-rebind window a resolve-then-dial would leave) +// and permits only the given port. Private ranges are allowed only in development; +// otherwise the guardian is public-only. Callers needing a different timeout (a +// host-key scan, a full connect) set Timeout on the returned dialer. +func GuardedDialer(port int) *net.Dialer { + opts := []ssrf.Option{ + ssrf.WithPorts(uint16(port)), //nolint:gosec // port is validated to 1-65535 by the request layer. + } + if envs.IsDevelopment() { + opts = append(opts, ssrf.WithAllowedV4Prefixes(devAllowedV4Prefixes...)) + } + + return &net.Dialer{ + Timeout: dialTimeout, + Control: ssrf.New(opts...).Safe, + } +} + +// IsBlocked reports whether err is the guardian rejecting the target by policy, as +// opposed to the host being unreachable for another reason. +func IsBlocked(err error) bool { + return errors.Is(err, ssrf.ErrProhibitedIP) || + errors.Is(err, ssrf.ErrProhibitedPort) || + errors.Is(err, ssrf.ErrProhibitedNetwork) +} + +// Probe dials host:port through the guardian and reports whether it is reachable +// and, separately, whether the guardian blocked it (policy) versus the host being +// down. The dial honors ctx, so it is cancelled when the caller's request ends. +func Probe(ctx context.Context, host string, port int) (reachable, blocked bool) { + conn, err := GuardedDialer(port).DialContext(ctx, "tcp", net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return false, IsBlocked(err) + } + conn.Close() //nolint:errcheck + + return true, false +} + +// Reachable reports plain reachability, collapsing the blocked distinction (used +// for connection status, where a blocked target reads the same as unreachable). +func Reachable(ctx context.Context, host string, port int) bool { + reachable, _ := Probe(ctx, host, port) + + return reachable +} diff --git a/pkg/egress/egress_test.go b/pkg/egress/egress_test.go new file mode 100644 index 00000000000..76ba3d29c7f --- /dev/null +++ b/pkg/egress/egress_test.go @@ -0,0 +1,25 @@ +package egress + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +// The guardian rejects internal/reserved targets at the socket layer, so they are +// never reached and report unreachable. Loopback, link-local, and unspecified are +// blocked regardless of the development allowlist (which only adds RFC1918 ranges). +func TestProbeGuardsEgress(t *testing.T) { + for _, host := range []string{ + "127.0.0.1", + "::1", + "169.254.169.254", + "0.0.0.0", + "this.host.does.not.exist.invalid", + } { + t.Run(host, func(t *testing.T) { + assert.False(t, Reachable(context.Background(), host, 22)) + }) + } +} diff --git a/pkg/models/connection.go b/pkg/models/connection.go new file mode 100644 index 00000000000..8f6faf64af8 --- /dev/null +++ b/pkg/models/connection.go @@ -0,0 +1,42 @@ +package models + +import "time" + +// ConnectionKind discriminates how a connection reaches its target. +type ConnectionKind string + +const ( + // ConnectionKindExternal dials the target SSH endpoint directly, without the + // agent (an external host reached by Host:Port). + ConnectionKindExternal ConnectionKind = "external" + // ConnectionKindDevice reaches an agent-registered device over the reverse + // tunnel, reusing the standard device session flow. + ConnectionKindDevice ConnectionKind = "device" +) + +// Connection is a saved, reusable way to reach an SSH target, distinct from +// [Device]: it is user-provisioned inventory on top of the agent-registered +// fleet. The target is discriminated by Kind. A connection is personal (belongs +// to OwnerID); sharing one with a team is a separate Enterprise/Cloud capability. +type Connection struct { + ID string `json:"id"` + TenantID string `json:"tenant_id"` + // OwnerID is the user the connection belongs to. It scopes visibility: only + // the owner can see or use a connection. + OwnerID string `json:"owner_id"` + Label string `json:"label"` + Kind ConnectionKind `json:"kind"` + // Host and Port hold the dial target for Kind == ConnectionKindExternal. + Host string `json:"host"` + Port int `json:"port"` + // DeviceUID references the target device for Kind == ConnectionKindDevice. + DeviceUID string `json:"device_uid"` + Username string `json:"username"` + // AuthMethod is "password" or "key"; empty means none saved. + AuthMethod string `json:"auth_method"` + // KeyFingerprint points at the SSH key to use (resolved against the owner's + // vault). The secret never reaches the server. + KeyFingerprint string `json:"key_fingerprint"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/pkg/models/known_host.go b/pkg/models/known_host.go new file mode 100644 index 00000000000..7b674932c4e --- /dev/null +++ b/pkg/models/known_host.go @@ -0,0 +1,40 @@ +package models + +import "time" + +// KnownHost is an accepted SSH host key for an external connection target, +// recorded on trust-on-first-use so later connects can verify the host hasn't +// changed (guarding against a man-in-the-middle on the server egress path). +// +// Scope follows the connection it was reached through: a personal connection +// records a per-user known host (OwnerID set); a team connection records one +// shared with the whole namespace (OwnerID empty). +type KnownHost struct { + ID string `json:"id"` + TenantID string `json:"tenant_id"` + // OwnerID is set for a personal (per-user) known host; empty means a + // namespace-shared (team) one. + OwnerID string `json:"owner_id"` + Host string `json:"host"` + Port int `json:"port"` + KeyType string `json:"key_type"` + // PublicKey is the host key in authorized_keys format. + PublicKey string `json:"public_key"` + Fingerprint string `json:"fingerprint"` + AcceptedBy string `json:"accepted_by"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// KnownHostStatus is the verification state of a scanned host key against what +// is stored. +type KnownHostStatus string + +const ( + // KnownHostUnverified means no key is stored yet for this target (first use). + KnownHostUnverified KnownHostStatus = "unverified" + // KnownHostTrusted means the scanned key matches the stored one. + KnownHostTrusted KnownHostStatus = "trusted" + // KnownHostChanged means a key is stored but the scanned one differs (danger). + KnownHostChanged KnownHostStatus = "changed" +) diff --git a/ssh/go.mod b/ssh/go.mod index 1953989be55..c2d4af861eb 100644 --- a/ssh/go.mod +++ b/ssh/go.mod @@ -3,6 +3,7 @@ module github.com/shellhub-io/shellhub/ssh go 1.25.8 require ( + code.dny.dev/ssrf v0.2.0 github.com/Masterminds/semver v1.5.0 github.com/gliderlabs/ssh v0.3.8 github.com/golang-jwt/jwt/v5 v5.3.1 @@ -55,6 +56,7 @@ require ( github.com/vmihailenco/go-tinylfu v0.2.2 // indirect github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9 // indirect golang.org/x/sync v0.21.0 // indirect golang.org/x/sys v0.46.0 // indirect golang.org/x/text v0.38.0 // indirect diff --git a/ssh/go.sum b/ssh/go.sum index ce03b778a5f..3186701ad4d 100644 --- a/ssh/go.sum +++ b/ssh/go.sum @@ -1,3 +1,5 @@ +code.dny.dev/ssrf v0.2.0 h1:wCBP990rQQ1CYfRpW+YK1+8xhwUjv189AQ3WMo1jQaI= +code.dny.dev/ssrf v0.2.0/go.mod h1:B+91l25OnyaLIeCx0WRJN5qfJ/4/ZTZxRXgm0lj/2w8= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg= @@ -255,6 +257,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto= golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio= +golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9 h1:yZNXmy+j/JpX19vZkVktWqAo7Gny4PBWYYK3zskGpx4= +golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= diff --git a/ssh/main.go b/ssh/main.go index 91f4436eed4..2429e0a6b8d 100644 --- a/ssh/main.go +++ b/ssh/main.go @@ -72,6 +72,7 @@ func main() { router := h.Router web.NewSSHServerBridge(router, cache) + web.NewConnectBridge(router) if envs.IsDevelopment() { runtime.SetBlockProfileRate(1) diff --git a/ssh/web/connect.go b/ssh/web/connect.go new file mode 100644 index 00000000000..59e6ca57d91 --- /dev/null +++ b/ssh/web/connect.go @@ -0,0 +1,340 @@ +package web + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "math" + "net" + "net/http" + "strconv" + "time" + + "github.com/labstack/echo/v4" + "github.com/shellhub-io/shellhub/pkg/egress" + "github.com/shellhub-io/shellhub/ssh/pkg/magickey" + "github.com/shellhub-io/shellhub/ssh/web/pkg/token" + log "github.com/sirupsen/logrus" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/net/websocket" +) + +// NewConnectBridge registers the routes for the direct connection bridge: a +// lightweight web terminal that dials an external SSH endpoint directly. Unlike +// [NewSSHServerBridge], it does NOT route through the agent/reverse-tunnel nor +// the device session machinery — it dials host:port itself and pipes the shell +// to the websocket. Used by saved "direct" connections. +// +// MVP: password authentication only. The credential is encrypted in transit and +// kept only for the short TTL of the token cache; it is never persisted. +func NewConnectBridge(router *echo.Echo) { + const route = "/ws/connect" + + manager := newManager(30 * time.Second) + + // POST receives the connection credentials and returns a short-lived token. + router.Add(http.MethodPost, route, echo.WrapHandler(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + type Success struct { + Token string `json:"token"` + } + + type Fail struct { + Error string `json:"error"` + } + + decoder := json.NewDecoder(req.Body) + encoder := json.NewEncoder(res) + + response := func(res http.ResponseWriter, status int, data any) { + res.Header().Set("Content-Type", "application/json") + res.WriteHeader(status) + + encoder.Encode(data) //nolint: errcheck,errchkjson + } + + var request Credentials + if err := decoder.Decode(&request); err != nil { + // Keep the raw decode error out of the response; it can carry internal + // detail and the client can't act on it anyway. + log.WithError(err).Debug("failed to decode direct connect credentials") + response(res, http.StatusBadRequest, Fail{Error: "invalid request body"}) + + return + } + + key := magickey.GetReference() + + tkn, err := token.NewToken(key) + if err != nil { + log.WithError(err).Error("failed to create direct connect token") + response(res, http.StatusInternalServerError, Fail{Error: "failed to create token"}) + + return + } + + request.encryptPassword(key) //nolint:errcheck + + manager.save(tkn.ID, &request) + + response(res, http.StatusOK, Success{Token: tkn.ID}) + }))) + + // GET upgrades to a websocket and pipes the shell of the dialed host. + router.Add(http.MethodGet, route, echo.WrapHandler(websocket.Handler(func(wsconn *websocket.Conn) { + defer wsconn.Close() + + exit := func(wsconn *websocket.Conn, err error) { + log.WithError(err).Error("web connect terminal error") + + buffer, marshalErr := json.Marshal(Message{Kind: messageKindError, Data: err.Error()}) + if marshalErr != nil { + return + } + + wsconn.Write(buffer) //nolint:errcheck + } + + tkn, err := getToken(wsconn.Request()) + if err != nil { + exit(wsconn, ErrWebSocketGetToken) + + return + } + + cols, rows, err := getDimensions(wsconn.Request()) + if err != nil { + exit(wsconn, ErrWebSocketGetDimensions) + + return + } + + creds, ok := manager.get(tkn) + if !ok { + exit(wsconn, ErrBridgeCredentialsNotFound) + + return + } + + conn := NewConn(wsconn) + defer conn.Close() + + go conn.KeepAlive() + + creds.decryptPassword(magickey.GetReference()) //nolint:errcheck + + if err := connectSession(wsconn.Request().Context(), conn, creds, Dimensions{cols, rows}); err != nil { + exit(wsconn, err) + + return + } + }))) +} + +// connectSession dials the external SSH endpoint described by creds and pipes a +// shell to the websocket connection. The context is tied to the websocket +// request, so a dropped client cancels an in-flight dial. +func connectSession(ctx context.Context, conn *Conn, creds *Credentials, dim Dimensions) error { + logger := log.WithFields(log.Fields{ + "user": creds.Username, + "host": creds.Host, + "port": creds.Port, + }) + + logger.Info("handling direct connect request started") + defer logger.Info("handling direct connect request end") + + addr := net.JoinHostPort(creds.Host, strconv.Itoa(creds.Port)) + + // Public-key auth keeps the private key in the browser: the server advertises + // the supplied public key and proxies each signing challenge over the + // websocket via [Signer]. Falls back to password when no key is selected. + var auth []gossh.AuthMethod + if creds.isPublicKey() { + pubKey, _, _, _, parseErr := gossh.ParseAuthorizedKey([]byte(creds.PublicKey)) + if parseErr != nil { + logger.WithError(parseErr).Debug("failed to parse the direct connection public key") + + return ErrDataPublicKey + } + + auth = []gossh.AuthMethod{gossh.PublicKeys(&Signer{conn: conn, publicKey: &pubKey})} + } else { + auth = []gossh.AuthMethod{gossh.Password(creds.Password)} + } + + // creds come straight from the browser; reject an out-of-range port so the + // uint16 conversion below is well-defined. + if creds.Port < 1 || creds.Port > math.MaxUint16 { + logger.WithField("port", creds.Port).Debug("rejected out-of-range port") + + return ErrAuthentication + } + + // The host comes straight from the browser, so dial through an SSRF guardian: + // guardian.Safe validates the real resolved IP (and the port) at the socket + // layer right before connecting, so the server can't be used as a pivot to + // reach internal/reserved addresses. Only the configured target port passes. + dialer := egress.GuardedDialer(creds.Port) + // Live session, so a longer timeout than the reachability probe's default. + dialer.Timeout = 30 * time.Second + + netConn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + // Distinguish a guardian rejection (host/port not a permitted target) + // from an ordinary unreachable host. + if egress.IsBlocked(err) { + logger.WithError(err).Warn("blocked direct connect to a disallowed host") + + return ErrEgressBlocked + } + + logger.WithError(err).Debug("failed to dial the direct host") + + return ErrUnreachableHost + } + + // Verify the live host key against the one the browser confirmed (TOFU). A + // mismatch means the host differs from what the user trusted, so abort. An + // empty key would disable verification entirely (the value is browser-supplied, + // so an attacker who can shape the request could force it), so refuse it. + if creds.KnownHostKey == "" { + netConn.Close() //nolint:errcheck + logger.Warn("direct connect rejected: no verified host key") + + return ErrHostKeyUnverified + } + + expected, _, _, _, parseErr := gossh.ParseAuthorizedKey([]byte(creds.KnownHostKey)) + if parseErr != nil { + netConn.Close() //nolint:errcheck + logger.WithError(parseErr).Warn("invalid known host key supplied") + + return ErrAuthentication + } + + // gossh wraps the callback error, so errors.Is won't reliably recover the + // sentinel after the handshake. Record the mismatch out of band instead. + hostKeyMismatch := false + expectedKey := expected.Marshal() + hostKeyCallback := func(_ string, _ net.Addr, key gossh.PublicKey) error { + if !bytes.Equal(key.Marshal(), expectedKey) { + hostKeyMismatch = true + + return ErrHostKeyMismatch + } + + return nil + } + + sshConn, chans, reqs, err := gossh.NewClientConn(netConn, addr, &gossh.ClientConfig{ //nolint: exhaustruct + User: creds.Username, + Auth: auth, + HostKeyCallback: hostKeyCallback, + Timeout: 30 * time.Second, + }) + if err != nil { + netConn.Close() //nolint:errcheck + + if hostKeyMismatch { + logger.Warn("host key mismatch on direct connect") + + return ErrHostKeyMismatch + } + + return ErrAuthentication + } + + client := gossh.NewClient(sshConn, chans, reqs) + defer client.Close() + + agent, err := client.NewSession() + if err != nil { + return ErrSession + } + + defer agent.Close() + + // Return a sentinel so the raw error (internal detail) isn't echoed to the browser. + stdin, err := agent.StdinPipe() + if err != nil { + logger.WithError(err).Debug("failed to create the stdin pipe") + + return ErrSession + } + + stdout, err := agent.StdoutPipe() + if err != nil { + logger.WithError(err).Debug("failed to create the stdout pipe") + + return ErrSession + } + + stderr, err := agent.StderrPipe() + if err != nil { + logger.WithError(err).Debug("failed to create the stderr pipe") + + return ErrSession + } + + if err := agent.RequestPty("xterm", int(dim.Rows), int(dim.Cols), gossh.TerminalModes{ + gossh.ECHO: 1, + gossh.TTY_OP_ISPEED: 14400, + gossh.TTY_OP_OSPEED: 14400, + }); err != nil { + return ErrPty + } + + if err := agent.Shell(); err != nil { + return ErrShell + } + + go func() { + defer agent.Close() + + for { + var message Message + + if _, err := conn.ReadMessage(&message); err != nil { + if errors.Is(err, io.EOF) { + return + } + + logger.WithError(err).Error("failed to read the message from the client") + + return + } + + switch message.Kind { + case messageKindInput: + buffer, ok := message.Data.(string) + if !ok { + continue + } + + if _, err := stdin.Write([]byte(buffer)); err != nil { + return + } + case messageKindResize: + d, ok := message.Data.(Dimensions) + if !ok { + continue + } + + if err := agent.WindowChange(int(d.Rows), int(d.Cols)); err != nil { + return + } + } + } + }() + + go redirToWs(stdout, conn) //nolint:errcheck + go io.Copy(conn, stderr) //nolint:errcheck + + if err := agent.Wait(); err != nil { + logger.WithError(err).Warning("client remote command returned an error") + } + + return nil +} diff --git a/ssh/web/errors.go b/ssh/web/errors.go index 4b55cc2ab21..deaaf154534 100644 --- a/ssh/web/errors.go +++ b/ssh/web/errors.go @@ -33,6 +33,10 @@ var ( ErrConfiguration = fmt.Errorf("failed to create communication configuration") ErrInvalidVersion = fmt.Errorf("failed to parse device version") ErrUnsuportedPublicKeyAuth = fmt.Errorf("connections using public keys are not permitted when the agent version is 0.5.x or earlier") + ErrHostKeyMismatch = fmt.Errorf("the host key does not match the trusted one for this connection") + ErrHostKeyUnverified = fmt.Errorf("no verified host key for this connection") + ErrEgressBlocked = fmt.Errorf("the address is not a permitted connection target") + ErrUnreachableHost = fmt.Errorf("could not reach the host") ) var ( diff --git a/ssh/web/manager.go b/ssh/web/manager.go index 7671da27534..d2f59d9f4ec 100644 --- a/ssh/web/manager.go +++ b/ssh/web/manager.go @@ -29,9 +29,11 @@ func (m *manager) save(id string, data *Credentials) { }) } -// get gets the credentials if it time period have not ended. +// get consumes the credentials for id, if the TTL has not elapsed. The token is +// single-use: it is deleted on first read so a leaked token (it travels as a +// query param) can't be replayed within the TTL window. func (m *manager) get(id string) (*Credentials, bool) { - l, ok := m.credentials.Load(id) + l, ok := m.credentials.LoadAndDelete(id) if !ok { return nil, false } diff --git a/ssh/web/utils.go b/ssh/web/utils.go index 6fab6eef4f8..b77174aedd1 100644 --- a/ssh/web/utils.go +++ b/ssh/web/utils.go @@ -17,6 +17,21 @@ type Credentials struct { Password string `json:"password"` // Fingerprint is the identifier of the public key used in the device's OS. Fingerprint string `json:"fingerprint"` + // PublicKey is the OpenSSH authorized-keys blob of the key used for the + // direct connection bridge (/ws/connect). The target host is external, so + // its public key is not registered in ShellHub and must be supplied by the + // browser. The matching private key never leaves the browser: the server + // proxies each signing challenge over the websocket (see [Signer]). + PublicKey string `json:"public_key"` + // Host and Port are set only for the direct connection bridge (/ws/connect), + // where the server dials the target SSH endpoint directly instead of routing + // to a device over the reverse tunnel. + Host string `json:"host"` + Port int `json:"port"` + // KnownHostKey is the host key the browser confirmed (authorized_keys format) + // for this external target. The server verifies the live host key against it + // and aborts on mismatch. Empty means no verified key (legacy/uninitialized). + KnownHostKey string `json:"known_host_key"` } func (c *Credentials) encryptPassword(key *rsa.PrivateKey) error { @@ -54,7 +69,7 @@ func (c *Credentials) decryptPassword(key *rsa.PrivateKey) error { return nil } -func (c *Credentials) isPublicKey() bool { // nolint: unused +func (c *Credentials) isPublicKey() bool { return c.Fingerprint != "" } diff --git a/ui/apps/console/src/App.tsx b/ui/apps/console/src/App.tsx index 3010345ebb5..e0837f54866 100644 --- a/ui/apps/console/src/App.tsx +++ b/ui/apps/console/src/App.tsx @@ -26,6 +26,7 @@ const ConfirmAccount = lazy(() => import("./pages/ConfirmAccount")); const ValidationAccount = lazy(() => import("./pages/ValidationAccount")); const Dashboard = lazy(() => import("./pages/Dashboard")); const Devices = lazy(() => import("./pages/devices")); +const Connections = lazy(() => import("./pages/connections")); const Containers = lazy(() => import("./pages/containers")); const ContainerDetails = lazy(() => import("./pages/ContainerDetails")); const Sessions = lazy(() => import("./pages/sessions")); @@ -194,36 +195,40 @@ export default function App() { /> } /> } /> + } /> } /> } /> } /> - } /> + } + /> } /> } /> } /> } /> - )} + } /> {getConfig().webEndpoints && ( - )} + } /> )} } /> diff --git a/ui/apps/console/src/api/connections.ts b/ui/apps/console/src/api/connections.ts new file mode 100644 index 00000000000..a6d5f0e082d --- /dev/null +++ b/ui/apps/console/src/api/connections.ts @@ -0,0 +1,60 @@ +import { + createConnection as createConnectionSdk, + updateConnection as updateConnectionSdk, + deleteConnection as deleteConnectionSdk, + getConnectionStatus as getConnectionStatusSdk, +} from "@/client"; +import type { Connection } from "@/client"; + +// Maps the generated (all-optional) SDK response types to the app's Connection +// type at the boundary, so the rest of the app keeps working with required +// fields. + +export interface ConnectionBody { + label: string; + kind: "external" | "device"; + host?: string; + port?: number; + device_uid?: string; + username?: string; + /** "" | "password" | "key" */ + auth_method?: string; + /** Points at the SSH key to use (never the secret). */ + key_fingerprint?: string; + /** Save an external connection even if the target is currently unreachable. */ + force?: boolean; +} + +export async function createConnection( + body: ConnectionBody, +): Promise { + const { data } = await createConnectionSdk({ body, throwOnError: true }); + + return data; +} + +export async function updateConnection( + id: string, + body: ConnectionBody, +): Promise { + const { data } = await updateConnectionSdk({ + path: { id }, + body, + throwOnError: true, + }); + + return data; +} + +export async function getConnectionStatus(id: string): Promise { + const { data } = await getConnectionStatusSdk({ + path: { id }, + throwOnError: true, + }); + + return data?.online ?? false; +} + +export async function deleteConnection(id: string): Promise { + await deleteConnectionSdk({ path: { id }, throwOnError: true }); +} diff --git a/ui/apps/console/src/api/hostKeys.ts b/ui/apps/console/src/api/hostKeys.ts new file mode 100644 index 00000000000..e45621e9fb6 --- /dev/null +++ b/ui/apps/console/src/api/hostKeys.ts @@ -0,0 +1,70 @@ +import { + scanKnownHost as scanKnownHostSdk, + acceptKnownHost as acceptKnownHostSdk, + getKnownHost as getKnownHostSdk, + deleteKnownHost as deleteKnownHostSdk, +} from "@/client"; +import type { KnownHost } from "@/client"; + +// Known-host (TOFU) endpoints for external connections. The API scope is +// "personal" (per-user) or "namespace" (team-shared) and follows the connection. +export type HostKeyScope = "personal" | "namespace"; +export type HostKeyStatus = "unverified" | "trusted" | "changed"; + +export interface HostKeyScanResult { + key_type: string; + fingerprint: string; + public_key: string; + status: HostKeyStatus; + stored: KnownHost | null; +} + +export async function scanHostKey( + host: string, + port: number, + scope: HostKeyScope, +): Promise { + const { data } = await scanKnownHostSdk({ + body: { host, port, scope }, + throwOnError: true, + }); + + return data as HostKeyScanResult; +} + +export async function acceptHostKey(body: { + host: string; + port: number; + scope: HostKeyScope; + key_type: string; + public_key: string; + fingerprint: string; +}): Promise { + const { data } = await acceptKnownHostSdk({ body, throwOnError: true }); + + return data; +} + +export async function getHostKey( + host: string, + port: number, + scope: HostKeyScope, +): Promise { + const { data } = await getKnownHostSdk({ + query: { host, port, scope }, + throwOnError: true, + }); + + return data ?? null; +} + +export async function forgetHostKey( + host: string, + port: number, + scope: HostKeyScope, +): Promise { + await deleteKnownHostSdk({ + query: { host, port, scope }, + throwOnError: true, + }); +} diff --git a/ui/apps/console/src/api/teamConnections.ts b/ui/apps/console/src/api/teamConnections.ts new file mode 100644 index 00000000000..7272c648886 --- /dev/null +++ b/ui/apps/console/src/api/teamConnections.ts @@ -0,0 +1,86 @@ +import { + createTeamConnection as createTeamConnectionSdk, + updateTeamConnection as updateTeamConnectionSdk, + deleteTeamConnection as deleteTeamConnectionSdk, + getTeamConnectionStatus as getTeamConnectionStatusSdk, + getTeamConnectionPrefs as getTeamConnectionPrefsSdk, + updateTeamConnectionPrefs as updateTeamConnectionPrefsSdk, +} from "@/client"; +import type { TeamConnection, TeamConnectionPrefs } from "@/client"; + +// Team connections are an Enterprise/Cloud feature, license-gated (402 on an +// edition without it). Maps the generated SDK types to the app types at the +// boundary. + +export interface TeamConnectionBody { + label: string; + kind: "external" | "device"; + host?: string; + port?: number; + device_uid?: string; +} + +/** Per-user auth preference payload for a team connection (never the secret). */ +export interface TeamConnectionPrefsBody { + username?: string; + auth_method?: string; + key_fingerprint?: string; +} + +export async function createTeamConnection( + body: TeamConnectionBody, +): Promise { + const { data } = await createTeamConnectionSdk({ body, throwOnError: true }); + + return data; +} + +export async function updateTeamConnection( + id: string, + body: TeamConnectionBody, +): Promise { + const { data } = await updateTeamConnectionSdk({ + path: { id }, + body, + throwOnError: true, + }); + + return data; +} + +export async function deleteTeamConnection(id: string): Promise { + await deleteTeamConnectionSdk({ path: { id }, throwOnError: true }); +} + +export async function getTeamConnectionStatus(id: string): Promise { + const { data } = await getTeamConnectionStatusSdk({ + path: { id }, + throwOnError: true, + }); + + return data?.online ?? false; +} + +export async function getTeamConnectionPrefs( + id: string, +): Promise { + const { data } = await getTeamConnectionPrefsSdk({ + path: { id }, + throwOnError: true, + }); + + return data; +} + +export async function putTeamConnectionPrefs( + id: string, + body: TeamConnectionPrefsBody, +): Promise { + const { data } = await updateTeamConnectionPrefsSdk({ + path: { id }, + body, + throwOnError: true, + }); + + return data; +} diff --git a/ui/apps/console/src/components/ConnectDrawer.tsx b/ui/apps/console/src/components/ConnectDrawer.tsx index 8c9b6ab9c75..c93846b8299 100644 --- a/ui/apps/console/src/components/ConnectDrawer.tsx +++ b/ui/apps/console/src/components/ConnectDrawer.tsx @@ -1,36 +1,96 @@ -import { useEffect, useReducer, useState, FormEvent } from "react"; +import { useEffect, useReducer, useRef, useState } from "react"; +import { useQueryClient } from "@tanstack/react-query"; +import { Link } from "react-router-dom"; import { LockClosedIcon, KeyIcon, ChevronDoubleRightIcon, + ChevronUpDownIcon, ShieldCheckIcon, ExclamationCircleIcon, + ExclamationTriangleIcon, + ArrowTopRightOnSquareIcon, + ServerStackIcon, + UserIcon, + UsersIcon, } from "@heroicons/react/24/outline"; import { useTerminalStore } from "../stores/terminalStore"; import { useVaultStore } from "../stores/vaultStore"; -import { getFingerprint, validatePrivateKey } from "../utils/ssh-keys"; +import { useClickOutside } from "../hooks/useClickOutside"; +import { + getFingerprint, + getPublicKey, + validatePrivateKey, +} from "../utils/ssh-keys"; +import { + useCreateConnection, + useUpdateConnection, +} from "../hooks/useConnectionMutations"; +import { + useCreateTeamConnection, + useUpdateTeamConnection, + useUpdateTeamConnectionPrefs, +} from "../hooks/useTeamConnectionMutations"; +import { useTeamConnectionPrefs } from "../hooks/useTeamConnections"; +import { + scanHostKey, + getHostKey, + acceptHostKey, + type HostKeyScanResult, +} from "@/api/hostKeys"; import CopyButton from "./common/CopyButton"; import Drawer from "./common/Drawer"; +import Alert from "./common/Alert"; +import DevicePicker from "./common/DevicePicker"; import VaultLockedBanner from "./vault/VaultLockedBanner"; import VaultUnlockDialog from "./vault/VaultUnlockDialog"; import InputField from "@/components/common/fields/InputField"; import PasswordField from "@/components/common/fields/PasswordField"; +import CheckboxField from "@/components/common/fields/CheckboxField"; import FieldLabel from "@/components/common/fields/FieldLabel"; import RadioCard from "@/components/common/fields/RadioCard"; import RadioGroupField from "@/components/common/fields/RadioGroupField"; import RadioSegment from "@/components/common/fields/RadioSegment"; +import KeyFileInput from "@/components/common/fields/KeyFileInput"; +import PremiumUpsell from "@/components/common/PremiumUpsell"; +import { getConfig } from "@/env"; import { INPUT, LABEL } from "../utils/styles"; import { Card, Button } from "@shellhub/design-system/primitives"; +import { isSdkError } from "@/api/errors"; +import { connectionDirty } from "@/utils/connectionDirty"; import type { VaultKeyEntry } from "../types/vault"; +import type { Connection, TeamConnection } from "@/client"; + +// The single drawer for reaching an SSH target. Depending on props it can: +// - Connect to a fixed device target (device/container pages). +// - Connect to a saved connection (connections page rows). +// - Create or edit a saved connection. +// +// Scope decides where the auth lives. A "personal" connection belongs to the +// caller, so its auth (username + key) is stored on the connection record. A +// "team" connection is shared with the namespace, so the target is shared but +// the auth is the caller's own per-user pref. Team is an Enterprise/Cloud +// capability; only the secret never leaves the browser in either case. +type Scope = "personal" | "team"; interface Props { open: boolean; onClose: () => void; - deviceUid: string; - deviceName: string; - sshid: string; + deviceUid?: string; + deviceName?: string; + sshid?: string; + connection?: Connection | null; + // Scope of the connection being edited/connected. Ignored on create, where + // the Personal/Team toggle drives it. + scope?: Scope; + // Whether the caller may create a team connection (Enterprise/Cloud + operator+). + canCreateTeam?: boolean; + editable?: boolean; + onSaved?: () => void; } +type Kind = "external" | "device"; + interface FormState { username: string; authMethod: "password" | "key"; @@ -97,44 +157,370 @@ function formReducer(state: FormState, action: FormAction): FormState { } } +// Map a team mutation result into the shared Connection shape the drawer's +// callbacks consume (auth stays empty; it lives in per-user prefs). +function teamResultToConnection(t: TeamConnection): Connection { + return { + id: t.id, + tenant_id: t.tenant_id, + owner_id: t.created_by, + label: t.label, + kind: t.kind, + host: t.host ?? "", + port: t.port || 22, + device_uid: t.device_uid, + username: "", + auth_method: "", + key_fingerprint: "", + created_at: t.created_at, + updated_at: t.updated_at, + }; +} + +function VaultKeySelect({ + keys, + value, + onChange, +}: { + keys: VaultKeyEntry[]; + value: string; + onChange: (id: string) => void; +}) { + const [open, setOpen] = useState(false); + const ref = useRef(null); + useClickOutside(ref, () => setOpen(false)); + + const selected = keys.find((k) => k.id === value); + + return ( +
+ + + {open && ( +
+
+ {keys.length === 0 ? ( +
+ No keys in vault +
+ ) : ( + keys.map((k) => ( + + )) + )} +
+
+ )} +
+ ); +} + export default function ConnectDrawer({ open, onClose, - deviceUid, - deviceName, - sshid, + deviceUid = "", + deviceName = "", + sshid = "", + connection = null, + scope = "personal", + canCreateTeam = false, + editable = false, + onSaved, }: Props) { + const queryClient = useQueryClient(); const openTerminal = useTerminalStore((s) => s.open); const vaultStatus = useVaultStore((s) => s.status); const vaultKeys = useVaultStore((s) => s.keys); const refreshVault = useVaultStore((s) => s.refreshStatus); + const createMutation = useCreateConnection(); + const updateMutation = useUpdateConnection(); + const createTeamMutation = useCreateTeamConnection(); + const updateTeamMutation = useUpdateTeamConnection(); + const teamPrefsMutation = useUpdateTeamConnectionPrefs(); + + const isCreate = editable && !connection; + const isEdit = editable && !!connection; + const isConnect = !editable; + + // Whether the edition has team connections at all (Cloud/Enterprise). On + // Community it doesn't, so the Team scope is shown as a locked Pro upsell. + const cfg = getConfig(); + const teamEdition = !!cfg.cloud || !!cfg.enterprise; + const [state, dispatch] = useReducer(formReducer, initialState); const [unlockOpen, setUnlockOpen] = useState(false); + // Target editor state (only used when `editable`). + const [tScope, setTScope] = useState("personal"); + const [tLabel, setTLabel] = useState(""); + const [tKind, setTKind] = useState("external"); + const [tHost, setTHost] = useState(""); + const [tPort, setTPort] = useState("22"); + const [tDeviceUid, setTDeviceUid] = useState(""); + const [tDeviceName, setTDeviceName] = useState(""); + const [unreachable, setUnreachable] = useState(false); + const [saveError, setSaveError] = useState(null); + const [saveOnConnect, setSaveOnConnect] = useState(false); + + // Host-key (TOFU) confirmation state for external connects. + const [hostKeyResult, setHostKeyResult] = useState( + null, + ); + const [hostKeyBusy, setHostKeyBusy] = useState(false); + const [hostKeyError, setHostKeyError] = useState(null); + // Stash the resolved key across the host-key confirmation step. + const pendingKeyRef = useRef<{ + key: string; + phrase: string | undefined; + fingerprint: string; + publicKey: string | undefined; + } | null>(null); + + // Effective scope: the toggle drives create; the prop fixes edit/connect. + const effectiveScope: Scope = isCreate ? tScope : scope; + const isTeam = effectiveScope === "team"; + // API scope param: "team" maps to the shared "namespace" record. + const apiScope = isTeam ? "namespace" : "personal"; + + // A team connection's auth is the caller's own per-user pref, fetched + // separately for connect and edit (create has no connection yet). + const { prefs: teamPrefs } = useTeamConnectionPrefs( + !isCreate && scope === "team" && connection ? connection.id : undefined, + ); + useEffect(() => { if (!open) return; dispatch({ type: "reset" }); + // eslint-disable-next-line react-hooks/set-state-in-effect + setUnreachable(false); + setSaveError(null); + // Create defaults to saving the new connection (a "New connection" the user + // can still opt out of); connect starts unchecked. + setSaveOnConnect(isCreate); + setHostKeyResult(null); + setHostKeyBusy(false); + setHostKeyError(null); + setTScope(isCreate ? "personal" : scope); + + if (connection) { + setTLabel(connection.label); + setTKind(connection.kind === "device" ? "device" : "external"); + setTHost(connection.host); + setTPort(String(connection.port || 22)); + setTDeviceUid(connection.device_uid); + setTDeviceName(connection.label); + + // Personal auth lives on the record; team auth is prefilled from the + // per-user prefs effect below. + if (scope === "personal") { + dispatch({ type: "setUsername", value: connection.username ?? "" }); + if ( + connection.auth_method === "password" || + connection.auth_method === "key" + ) { + dispatch({ type: "setAuthMethod", value: connection.auth_method }); + } + } + } else { + setTLabel(""); + setTKind("external"); + setTHost(""); + setTPort("22"); + setTDeviceUid(""); + setTDeviceName(""); + } + void refreshVault(); - }, [open, refreshVault]); + }, [open, connection, scope, isCreate, refreshVault]); + + // Prefill from the caller's team prefs (team connect/edit), once they load. + useEffect(() => { + if (!open || isCreate || scope !== "team" || !teamPrefs) return; + if (teamPrefs.username) + dispatch({ type: "setUsername", value: teamPrefs.username }); + if ( + teamPrefs.auth_method === "password" || + teamPrefs.auth_method === "key" + ) { + dispatch({ type: "setAuthMethod", value: teamPrefs.auth_method }); + } + }, [open, isCreate, scope, teamPrefs]); const hasVaultKeys = vaultStatus === "unlocked" && vaultKeys.length > 0; - const effectiveKeySource = hasVaultKeys ? state.keySource : "manual"; + // The vault is offerable as a key source when it holds keys or is just locked + // (unlocking reveals them). Keeping "vault" selectable while locked lets the + // locked notice live inside the Vault tab instead of floating above the toggle. + const vaultPresent = vaultStatus === "locked" || hasVaultKeys; + const effectiveKeySource = vaultPresent ? state.keySource : "manual"; const selectedVaultKey: VaultKeyEntry | undefined = hasVaultKeys ? vaultKeys.find((k) => k.id === state.selectedKeyId) : undefined; - const canConnect = - state.username.trim().length > 0 && - (state.authMethod === "password" - ? state.password.trim().length > 0 - : effectiveKeySource === "vault" - ? !!selectedVaultKey && - (!selectedVaultKey.hasPassphrase || - state.passphrase.trim().length > 0) - : state.manualKeyValid && - (!state.manualKeyEncrypted || state.passphrase.trim().length > 0)); + // Which key the saved auth pref points at (team: per-user prefs; personal: row). + const preferredKeyFingerprint = isTeam + ? (teamPrefs?.key_fingerprint ?? "") + : (connection?.key_fingerprint ?? ""); + + useEffect(() => { + if (!open || !preferredKeyFingerprint || !hasVaultKeys) return; + const match = vaultKeys.find( + (k) => k.fingerprint === preferredKeyFingerprint, + ); + if (match) dispatch({ type: "setSelectedKeyId", value: match.id }); + }, [open, preferredKeyFingerprint, hasVaultKeys, vaultKeys]); + + const preferredKeyAvailable = + vaultStatus === "unlocked" && + vaultKeys.some((k) => k.fingerprint === preferredKeyFingerprint); + const preferredKeyMissing = + state.authMethod === "key" && + preferredKeyFingerprint !== "" && + vaultStatus !== "locked" && + !preferredKeyAvailable; + + // The connection authenticates with a key that lives in the vault, but the + // vault is locked and the user hasn't pasted a one-off key instead. Connecting + // should prompt an unlock rather than fail. + const needsVaultUnlock = + state.authMethod === "key" && + vaultStatus === "locked" && + preferredKeyFingerprint !== "" && + !state.privateKey.trim(); + + const targetKind: Kind = editable + ? tKind + : connection + ? connection.kind === "device" + ? "device" + : "external" + : "device"; + const isExternal = targetKind === "external"; + const targetHost = editable ? tHost.trim() : (connection?.host ?? ""); + const portNum = editable ? Number(tPort) : (connection?.port ?? 22); + const portValid = + Number.isInteger(portNum) && portNum >= 1 && portNum <= 65535; + const targetDeviceUid = editable + ? tDeviceUid + : (connection?.device_uid ?? deviceUid); + const targetDeviceName = editable + ? tDeviceName + : (connection?.label ?? deviceName); + const targetLabel = editable + ? tLabel.trim() + : (connection?.label ?? deviceName); + + // Whether the editable target differs from the saved record. A team edit may + // be opened by a non-manager just to set their own auth; when the shared + // target is untouched we skip the operator+-only update entirely. + const targetChanged = + !connection || + targetLabel !== connection.label || + targetKind !== (connection.kind === "device" ? "device" : "external") || + (isExternal + ? targetHost !== connection.host || portNum !== connection.port + : targetDeviceUid !== connection.device_uid); + + const showConnect = !isEdit; + // Create folds save into the Connect action via the "Save connection" + // checkbox; only edit keeps a dedicated Save button. + const showSave = isEdit; + // Auth is always editable. Personal auth lives on the record; team auth is the + // caller's own per-user pref, seeded on save (create/edit) or on connect. + const showAuth = true; + const pending = + createMutation.isPending || + updateMutation.isPending || + createTeamMutation.isPending || + updateTeamMutation.isPending || + teamPrefsMutation.isPending; + + const targetValid = editable + ? targetLabel.length > 0 && + (isExternal + ? targetHost.length > 0 && portValid + : targetDeviceUid.length > 0) + : true; + + const authValid = + !showAuth || + (state.username.trim().length > 0 && + (state.authMethod === "password" + ? state.password.trim().length > 0 + : effectiveKeySource === "vault" + ? !!selectedVaultKey && + (!selectedVaultKey.hasPassphrase || + state.passphrase.trim().length > 0) + : state.manualKeyValid && + (!state.manualKeyEncrypted || state.passphrase.trim().length > 0))); + + // A locked vault still lets the user click Connect: the click prompts an + // unlock (handled in doConnect) instead of being blocked by the missing key. + const canConnect = targetValid && (authValid || needsVaultUnlock) && !pending; + const canSave = showSave && targetValid && !pending; + + const currentKeyFp = + state.authMethod === "key" + ? effectiveKeySource === "vault" + ? (selectedVaultKey?.fingerprint ?? "") + : state.manualKeyValid + ? "manual" + : "" + : ""; + const authRecord = isTeam + ? { + username: teamPrefs?.username ?? "", + auth_method: teamPrefs?.auth_method ?? "", + key_fingerprint: teamPrefs?.key_fingerprint ?? "", + } + : { + username: connection?.username ?? "", + auth_method: connection?.auth_method ?? "", + key_fingerprint: connection?.key_fingerprint ?? "", + }; + const dirty = + isConnect && + !!connection && + connectionDirty( + { + username: state.username, + authMethod: state.authMethod, + keyFingerprint: currentKeyFp, + }, + authRecord, + ); const handleManualKeyChange = (pem: string) => { if (!pem.trim()) { @@ -155,68 +541,410 @@ export default function ConnectDrawer({ }); }; - const handleConnect = (e: FormEvent) => { - e.preventDefault(); - if (!canConnect) return; - + // Best-effort current auth preference to persist (never the secret). + const currentAuthPref = (): { + auth_method: string; + key_fingerprint: string; + } => { if (state.authMethod === "password") { - openTerminal({ - deviceUid, - deviceName, - username: state.username.trim(), - password: state.password, + return { auth_method: "password", key_fingerprint: "" }; + } + if (effectiveKeySource === "vault" && selectedVaultKey) { + return { + auth_method: "key", + key_fingerprint: selectedVaultKey.fingerprint, + }; + } + try { + const phrase = state.manualKeyEncrypted ? state.passphrase : undefined; + return { + auth_method: "key", + key_fingerprint: getFingerprint(state.privateKey.trim(), phrase), + }; + } catch { + return { auth_method: "key", key_fingerprint: "" }; + } + }; + + const targetBody = () => ({ + label: targetLabel, + kind: targetKind, + host: isExternal ? targetHost : undefined, + port: isExternal ? portNum : undefined, + device_uid: isExternal ? undefined : targetDeviceUid, + }); + + // Persist the caller's team auth pref (username + key pointer, never secret). + const seedTeamPrefs = (id: string, onDone?: () => void) => { + const pref = currentAuthPref(); + teamPrefsMutation.mutate( + { + id, + body: { + username: state.username.trim(), + auth_method: pref.auth_method, + key_fingerprint: pref.key_fingerprint, + }, + }, + { onSuccess: () => onDone?.() }, + ); + }; + + // Create or update the connection target. Personal connections carry the + // caller's auth on the record; team connections store only the shared target + // (the per-user auth is seeded separately). Runs onDone with the record. + const persistTarget = (force: boolean, onDone?: (c: Connection) => void) => { + setSaveError(null); + const onError = (err?: unknown) => { + // Editing a team target needs operator+; tell the user that apart from a + // generic failure (they reach this drawer to set their own auth too). + if (isSdkError(err) && err.status === 403) { + setSaveError("You don't have permission to change this connection."); + + return; + } + + setSaveError("Failed to save connection. Check the fields."); + }; + + if (isTeam) { + const body = targetBody(); + if (connection) { + // The shared target is operator+-only; if it's unchanged (e.g. a member + // just setting their own auth) keep the record and skip the update. + if (!targetChanged) { + onDone?.(connection); + return; + } + updateTeamMutation.mutate( + { id: connection.id, body }, + { + onSuccess: (c) => { + onSaved?.(); + onDone?.(teamResultToConnection(c)); + }, + onError, + }, + ); + return; + } + + createTeamMutation.mutate(body, { + onSuccess: (c) => { + onSaved?.(); + onDone?.(teamResultToConnection(c)); + }, + onError, }); - } else { - const key = - effectiveKeySource === "vault" && selectedVaultKey - ? selectedVaultKey.data - : state.privateKey.trim(); - const phrase = - effectiveKeySource === "vault" && selectedVaultKey - ? selectedVaultKey.hasPassphrase - ? state.passphrase - : undefined - : state.manualKeyEncrypted - ? state.passphrase - : undefined; - - let fingerprint: string; + return; + } + + const pref = currentAuthPref(); + const body = { + ...targetBody(), + username: state.username.trim(), + auth_method: pref.auth_method, + key_fingerprint: pref.key_fingerprint, + force, + }; + + if (connection) { + updateMutation.mutate( + { id: connection.id, body }, + { + onSuccess: (c) => { + onSaved?.(); + onDone?.(c); + }, + onError, + }, + ); + return; + } + + createMutation.mutate(body, { + onSuccess: (c) => { + onSaved?.(); + onDone?.(c); + }, + onError: (err) => { + if (isSdkError(err) && err.status === 422) { + // 422 distinguishes a blocked address (policy) from an unreachable host + // (the NAT/firewall + install-agent funnel). + if ((err as { error?: string }).error === "blocked") { + setSaveError( + `${targetHost}:${portNum} isn't a permitted connection target.`, + ); + } else { + setUnreachable(true); + } + } else { + onError(); + } + }, + }); + }; + + // Save (bookmark). Creating a team connection also seeds the creator's prefs. + const save = (force: boolean) => { + setUnreachable(false); + persistTarget(force, (saved) => { + // Team save (create or edit) also persists the caller's own auth pref, + // but only when they actually provided a username. + if (isTeam && state.username.trim()) { + seedTeamPrefs(saved.id, () => onClose()); + } else { + onClose(); + } + }); + }; + + const handleSave = () => { + if (!canSave) return; + save(false); + }; + + const resolveKey = () => { + const key = + effectiveKeySource === "vault" && selectedVaultKey + ? selectedVaultKey.data + : state.privateKey.trim(); + const phrase = + effectiveKeySource === "vault" && selectedVaultKey + ? selectedVaultKey.hasPassphrase + ? state.passphrase + : undefined + : state.manualKeyEncrypted + ? state.passphrase + : undefined; + + let fingerprint: string; + try { + fingerprint = getFingerprint(key, phrase); + } catch { + dispatch({ + type: "setKeyError", + value: "Failed to read private key. Check the key or passphrase.", + }); + return null; + } + if ( + effectiveKeySource === "vault" && + selectedVaultKey && + fingerprint !== selectedVaultKey.fingerprint + ) { + dispatch({ + type: "setKeyError", + value: + "Key data appears corrupted. Try re-importing the key into the vault.", + }); + return null; + } + + let publicKey: string | undefined; + if (isExternal) { try { - fingerprint = getFingerprint(key, phrase); + publicKey = getPublicKey(key, phrase); } catch { dispatch({ type: "setKeyError", value: "Failed to read private key. Check the key or passphrase.", }); + return null; + } + } + dispatch({ type: "setKeyError", value: null }); + return { key, phrase, fingerprint, publicKey }; + }; + + const openTerminalFor = ( + connId: string, + keyArg: ReturnType, + hostKey: string, + ) => { + const base = isExternal + ? { + kind: "connect" as const, + deviceUid: connId, + deviceName: targetLabel || targetHost, + host: targetHost, + port: portNum, + knownHostKey: hostKey, + } + : { deviceUid: targetDeviceUid, deviceName: targetDeviceName }; + + if (!keyArg) { + openTerminal({ + ...base, + username: state.username.trim(), + password: state.password, + }); + return; + } + + openTerminal({ + ...base, + username: state.username.trim(), + password: "", + fingerprint: keyArg.fingerprint, + privateKey: keyArg.key, + passphrase: keyArg.phrase, + publicKey: keyArg.publicKey, + }); + }; + + // Open the terminal (creating/seeding as needed) once the host key is settled. + const finishConnect = ( + keyArg: ReturnType, + hostKey: string, + ) => { + // Create mode: minting the record IS the save (forced, since the user chose + // to connect). Team also seeds the creator's prefs before opening. + if (isCreate) { + // The "Save connection" checkbox decides whether to bookmark the target. + // Unchecked is a one-off, ephemeral connect with no record persisted. + if (!saveOnConnect) { + const ephemeralId = isExternal + ? `external:${targetHost}:${portNum}` + : targetDeviceUid; + openTerminalFor(ephemeralId, keyArg, hostKey); + onClose(); return; } + + persistTarget(true, (created) => { + if (isTeam && state.username.trim()) { + seedTeamPrefs(created.id, () => { + openTerminalFor(created.id, keyArg, hostKey); + onClose(); + }); + } else { + openTerminalFor(created.id, keyArg, hostKey); + onClose(); + } + }); + return; + } + + openTerminalFor(connection?.id ?? targetDeviceUid, keyArg, hostKey); + // Connect mode: persist the tweak only if "Save changes" is ticked. Team + // writes per-user prefs; personal updates the connection's auth. + if (connection && dirty && saveOnConnect) { + if (isTeam) { + seedTeamPrefs(connection.id); + } else { + persistTarget(false); + } + } + onClose(); + }; + + const doConnect = async () => { + if (!canConnect || hostKeyBusy) return; + + // Key auth whose key lives in the locked vault: prompt an unlock instead of + // attempting the connect. Once unlocked, the saved key resolves and the user + // can connect. + if (needsVaultUnlock) { + setUnlockOpen(true); + return; + } + + let keyArg: ReturnType = null; + if (state.authMethod === "key") { + keyArg = resolveKey(); + if (!keyArg) return; + } + + // Device targets go through the agent tunnel; no host key to verify. + if (!isExternal) { + finishConnect(keyArg, ""); + + return; + } + + // External: verify the host key (TOFU). If already trusted, pass the stored + // key straight through (the server re-checks). On first use, scan and ask. + setHostKeyError(null); + setHostKeyBusy(true); + try { + const stored = await getHostKey(targetHost, portNum, apiScope); + if (stored) { + setHostKeyBusy(false); + finishConnect(keyArg, stored.public_key); + + return; + } + + const scan = await scanHostKey(targetHost, portNum, apiScope); + setHostKeyBusy(false); + if (scan.status === "trusted") { + finishConnect(keyArg, scan.public_key); + + return; + } + + pendingKeyRef.current = keyArg; + setHostKeyResult(scan); + } catch (err) { + setHostKeyBusy(false); if ( - effectiveKeySource === "vault" && - selectedVaultKey && - fingerprint !== selectedVaultKey.fingerprint + isSdkError(err) && + err.status === 422 && + (err as { error?: string }).error === "blocked" ) { - dispatch({ - type: "setKeyError", - value: - "Key data appears corrupted. Try re-importing the key into the vault.", - }); - return; + setHostKeyError( + `${targetHost}:${portNum} isn't a permitted connection target.`, + ); + } else { + setHostKeyError( + "Couldn't read the host key. Check that the host is reachable.", + ); } - dispatch({ type: "setKeyError", value: null }); + } + }; - openTerminal({ - deviceUid, - deviceName, - username: state.username.trim(), - password: "", - fingerprint, - privateKey: key, - passphrase: phrase, + // Accept the scanned host key (TOFU) and continue connecting. + const acceptAndConnect = async () => { + if (!hostKeyResult) return; + + setHostKeyBusy(true); + try { + await acceptHostKey({ + host: targetHost, + port: portNum, + scope: apiScope, + key_type: hostKeyResult.key_type, + public_key: hostKeyResult.public_key, + fingerprint: hostKeyResult.fingerprint, }); + // The accept goes through the raw API (not useAcceptHostKey), so invalidate + // the cache the host-key modal reads or it shows stale "no key stored". + void queryClient.invalidateQueries({ + queryKey: ["host-key", apiScope, targetHost, portNum], + }); + + const publicKey = hostKeyResult.public_key; + const keyArg = pendingKeyRef.current; + setHostKeyBusy(false); + setHostKeyResult(null); + finishConnect(keyArg, publicKey); + } catch { + setHostKeyBusy(false); + setHostKeyError("Failed to save the host key. Try again."); } - onClose(); }; + const formId = `connect-form-${connection?.id ?? deviceUid ?? "new"}`; + + const title = isCreate + ? "New connection" + : isEdit + ? isTeam + ? "Edit team connection" + : "Edit connection" + : `Connect to ${targetLabel || targetDeviceName}`; + return ( <> {deviceName}} + title={title} + subtitle={ + editable ? "A device or external host you reach over SSH" : undefined + } footer={ <> - - + {showSave && ( + + )} + {showConnect && ( + + )} } >
{ + e.preventDefault(); + if (isEdit) handleSave(); + else void doConnect(); + }} className="space-y-5" > - {/* SSHID helper */} - - Connect via terminal -
- - ssh - {state.username.trim() ? ( - - {state.username.trim()}@{sshid} - - ) : ( - <> - - <username> - - @{sshid} - - )} - - @${sshid}` - } - /> -
- {state.username.trim() ? ( -

- Command ready — copy and run in your terminal. -

- ) : ( -

- Enter your device OS username below to complete this command. -

- )} -
- -
-
- - or connect via web - -
-
+