From a1dcaea6ee8cd74c010da68c0e5984addc066e97 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:07 -0500 Subject: [PATCH 01/24] connmgr: Add context-aware semaphore. This adds a new context-aware semaphore type with Acquire and Release methods for use in upcoming changes that aim to simplify connection limiting by making use of semaphores for blocking until permits become available. --- internal/connmgr/semaphore.go | 36 +++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 internal/connmgr/semaphore.go diff --git a/internal/connmgr/semaphore.go b/internal/connmgr/semaphore.go new file mode 100644 index 000000000..fb7d7eed4 --- /dev/null +++ b/internal/connmgr/semaphore.go @@ -0,0 +1,36 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import "context" + +// semaphore is a simple context-aware channel based semaphore for bounding +// concurrent access. +type semaphore chan struct{} + +// makeSemaphore returns a new semaphore with the given capacity. +func makeSemaphore(n uint32) semaphore { + return make(chan struct{}, n) +} + +// Acquire acquires the semaphore. It blocks until resources are available or +// the provided context is done. It returns true on success and false when the +// context is done before semaphore can be acquired. +func (s semaphore) Acquire(ctx context.Context) bool { + select { + case s <- struct{}{}: + case <-ctx.Done(): + return false + } + return true +} + +// Release release the semaphore. +func (s semaphore) Release() { + select { + case <-s: + default: + } +} From 704eb8509483489f301ad2f8f1045353e3c1b182 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:08 -0500 Subject: [PATCH 02/24] connmgr: Add semaphore tests. This adds tests for the new context-aware semaphore to ensure the acquire, release, and context cancel semantics work as expected. --- internal/connmgr/semaphore_test.go | 109 +++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 internal/connmgr/semaphore_test.go diff --git a/internal/connmgr/semaphore_test.go b/internal/connmgr/semaphore_test.go new file mode 100644 index 000000000..9542176df --- /dev/null +++ b/internal/connmgr/semaphore_test.go @@ -0,0 +1,109 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import ( + "context" + "testing" + "time" +) + +// TestSemaphore ensures the semaphore acquire, release, and context cancel +// semantics are as expected. +func TestSemaphore(t *testing.T) { + // Create a closure that acquires a semaphore with a timeout. + ctx := context.Background() + timedAcquire := func(sem semaphore, timeout time.Duration) bool { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return sem.Acquire(ctx) + } + + // perSemTest describes a test to run against the same semaphore. + type perSemTest struct { + name string // test description + numAcquires uint32 // num to acquire + numReleases uint32 // num to release + } + + tests := []struct { + name string // test description + cap uint32 // capacity of the semaphore + perSemTests []perSemTest // tests to run against same semaphore + want []bool // expected results + }{{ + name: "normal block/release behavior", + cap: 2, + perSemTests: []perSemTest{{ + name: "cap 2 (0 acquired): acquire 3, release 1", + numAcquires: 3, + numReleases: 1, + }, { + name: "cap 2 (1 acquired): acquire 2, release 0", + numAcquires: 2, + numReleases: 0, + }, { + name: "cap 2 (2 acquired): acquire 1, release 2", + numAcquires: 1, + numReleases: 2, + }}, + want: []bool{true, true, false, true, false, false}, + }, { + // Releasing more than acquired ignores the extra release and does not + // influence future ops. + name: "relase more than acquired", + cap: 5, + perSemTests: []perSemTest{{ + name: "cap 5 (0 acquired): acquire 1, release 2", + numAcquires: 1, + numReleases: 2, + }, { + name: "cap 5 (0 acquired): acquire 5, release 1", + numAcquires: 5, + numReleases: 1, + }, { + name: "cap 5 (4 acquired): acquire 2, release 5", + numAcquires: 2, + numReleases: 5, + }}, + want: []bool{true, true, true, true, true, true, true, false}, + }} + + for _, test := range tests { + // Create semaphore with the capacity specified in the test and the + // a slice to hold the results. + sem := makeSemaphore(test.cap) + results := make([]bool, 0, len(test.want)) + + // Perform each sequence of acquires and releases as specified by the + // per semaphore tests. + for _, psTest := range test.perSemTests { + const timeout = 10 * time.Millisecond + for range psTest.numAcquires { + results = append(results, timedAcquire(sem, timeout)) + } + for range psTest.numReleases { + sem.Release() + } + } + + if len(results) != len(test.want) { + t.Errorf("%q: unexpected number of results: got %d, want %d", + test.name, len(results), len(test.want)) + } + for i := range results { + if results[i] != test.want[i] { + t.Errorf("%q: unexpected result for [%d]: got %v, want %v", + test.name, i, results[i], test.want[i]) + } + } + + // Ensure all acquires were released as expected. + if numAcquired := uint32(len(sem)); numAcquired != 0 { + t.Errorf("%q: unexpected final semaphore count: got %v, want %v", + test.name, numAcquired, 0) + } + } +} From 7db9f9e35ae9a93e124a281384d70410d912b3b0 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:09 -0500 Subject: [PATCH 03/24] connmgr: Overhaul to use wrapped conns plus ctx. The existing connection manager code was written well before contexts were introduced. Further, due to the old async model that has now been converted to a synchronous model, it is based around connection requests that have their state atomically updated asynchronously as various things happen. While it has undoubtedly worked well enough for over a decade, it has always been a challenge to add new functionality to it and requires the use of a lot of less than ideal and highly outdated techniques such as polling for state changes. It is also rather brittle in terms of requiring output connections to be manually disconnected in the connection manager after they've been closed to avoid things like leaking goroutines and failing to update target outbound counts. Moreover, it only tracks outgoing connections which ultimately forces a lot of connection-related tasks to be split across different layers instead of residing in the connection manager itself where they more naturally belong. Notably, that split, for all intents and purposes, prevents implementing some desirable more advanced features such as immediate connection shedding, different connection types, and listeners tied to specific network types. With the primary goal of addressing all of the aforementioned points and providing a solid base to work on for adding new features moving forward, this significantly reworks the connection manager to completely get rid of the notion of exposed connection requests in favor of a new custom connection type that wraps the underlying net.Conn. The new wrapped connections automatically handle cleanup when closed and have an associated connection type enum that allows easily distinguishing inbound, outbound, and manual connections as well as supporting new connection types in the future. Another nice feature of the new wrapped connections is they provide efficient access to concrete parsed address types which paves the way for avoiding a lot of constant reparsing, repeated host/port splitting and joining, and generally much more ergonomic immutable address types. Since changing to wrapped connections basically required a rather significant rewrite of large portions of the connection manager anyway, this also takes the opportunity to improve several other aspects of the connection manager in the process such as implementing full context support, full tracking of all connection types by the manager itself, much more robust semaphore-based automatic connection limiting, cleaner persistent connection handling with independent limits, prevention of multiple connections of any type to the same address:port, more useful debug logging, and cleanly closing all connections during shutdown. It is also important to note that the following overall semantics have intentionally been changed versus the existing connection manager: - A maximum of 8 persistent connections is now imposed and they no longer count toward the configured target number of automatic outbound peers to maintain - Duplicate addresses (host:port) are now rejected by the connection manager for all types (inbound, outbound, manual, persistent) - Note that inbound conns from the same IP will necessarily have different ports, so the same max IP limits apply in that case - RPC 'node connect' for all connection attempts now: - Supports the RPC connection and server contexts - Properly handles duplicate address rejection including pending attempts - RPC 'node connect' for non-persistent conn attempts now: - Waits for the connection attempt result before returning - Returns an error if the connection attempt fails - Cancels the connection attempt if the RPC connection is closed before it succeeds - RPC 'node remove' now supports removing a pending connection by its persistent connection ID (since no peer ID exists before a valid connection is established) - It is no longer possible for state transitions to allow things like duplicate addresses or failed cancellation --- internal/connmgr/connmanager.go | 1353 +++++++++++++----- internal/connmgr/connmanager_test.go | 701 ++++----- internal/connmgr/conntype_test.go | 35 + internal/connmgr/error.go | 26 +- internal/connmgr/error_test.go | 8 +- internal/rpcserver/interface.go | 2 +- internal/rpcserver/rpcserver.go | 94 +- internal/rpcserver/rpcserverhandlers_test.go | 8 +- rpcadaptors.go | 159 +- server.go | 133 +- 10 files changed, 1544 insertions(+), 975 deletions(-) create mode 100644 internal/connmgr/conntype_test.go diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 60e1afcc1..83ddf10c5 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -7,19 +7,29 @@ package connmgr import ( "context" + "errors" "fmt" "net" + "strconv" "sync" "sync/atomic" "time" + + "github.com/decred/dcrd/addrmgr/v4" ) -var ( +const ( + // MaxPersistent is the maximum number of persistent connections that can be + // added. Persistent connections do not count towards the automatic + // outbound connection limits. + MaxPersistent = 8 +) - // maxRetryDuration is the max duration of time retrying of a persistent - // connection is allowed to grow to. This is necessary since the retry - // logic uses a backoff mechanism which increases the interval base times - // the number of retries that have been done. +var ( + // maxRetryDuration is the maximum duration a persistent connection retry + // backoff is allowed to grow to. This is necessary since the retry logic + // uses a backoff mechanism which increases the interval base times the + // number of retries that have been done. maxRetryDuration = time.Minute * 5 ) @@ -35,75 +45,166 @@ const ( // defaultTargetOutbound is the default number of outbound connections to // maintain. - defaultTargetOutbound = uint32(8) + defaultTargetOutbound = 8 ) -// ConnState represents the state of the requested connection. -type ConnState uint32 +// ConnectionType specifies the different types of supported connections. +type ConnectionType uint8 -// ConnState can be either pending, established, disconnected or failed. When -// a new connection is requested, it is attempted and categorized as -// established or failed depending on the connection result. An established -// connection which was disconnected is categorized as disconnected. const ( - ConnPending ConnState = iota - ConnEstablished - ConnDisconnected - ConnFailed - ConnCanceled + // ConnTypeInbound indicates the connection was established by a remote + // peer. No further details are known about this connection until a + // handshake takes place. + ConnTypeInbound ConnectionType = iota + + // ConnTypeOutbound indicates a normal outbound connection that was + // established with no additional restrictions imposed on the type of + // information that the local peer/server is willing to relay. + // + // Note that this in no way implies further restrictions may not be + // negotiated depending on the protocol messages exchanged between the two + // peers. + ConnTypeOutbound + + // ConnTypeManual indicates an outbound connection that was manually + // requested via [ConnManager.Connect] or [ConnManager.AddPersistent]. In + // practice, this connection type is the result of requesting manual + // connections via an RPC method (e.g. "node connect") or via command line + // configuration options (e.g. --addpeer and --connect). + ConnTypeManual + + // numConnTypes is the number of connection types. This entry MUST be the + // last entry in the enum. + numConnTypes ) -// ConnReq is the connection request to a network address. If permanent, the -// connection will be retried on disconnection. -type ConnReq struct { - // id is the unique identifier for this connection request. - id atomic.Uint64 +// connTypeStrings is a map of connection types to human-readable names for +// pretty printing. +var connTypeStrings = map[ConnectionType]string{ + ConnTypeInbound: "inbound", + ConnTypeOutbound: "outbound", + ConnTypeManual: "manual", +} - // state is the current connection state for this connection request. - state atomic.Uint32 +// String returns the [ConnectionType] in human-readable form. +func (connType ConnectionType) String() string { + if s, ok := connTypeStrings[connType]; ok { + return s + } - // The following fields are owned by the connection manager and must not - // be accessed without its connection mutex held. + return fmt.Sprintf("Unknown ConnectionType (%d)", uint8(connType)) +} + +// Conn houses information about a managed connection. It is the callers +// responsibility to always ensure [Conn.Close] is called when the connection +// is no longer required. +type Conn struct { + // The following variables are set at the time the instance is created and + // are safe for concurrent access. + // + // net.Conn is the underlying connection. It is embedded which makes all of + // its methods immediately available. // - // retryCount is the number of times a permanent connection request that - // fails to connect has been retried since the last successful connection. + // id is the unique identifier for this connection. // - // conn is the underlying network connection. It will be nil before a - // connection has been established. - retryCount uint32 - conn net.Conn + // connType specifies the connection type. + // + // remoteAddr is the remote address associated with the connection. It is + // a concrete address manager address. + // + // onClose is a callback that will be invoked when the connection is closed. + net.Conn + id uint64 + connType ConnectionType + remoteAddr addrmgr.NetAddress + onClose func() + + // closed houses whether or not the connection has already been closed. + closed atomic.Bool +} - // Addr is the address to connect to. - Addr net.Addr +// newConn returns a new connection given an underlying [net.Conn], connection +// ID, and connection type. +// +// The returned connection is automatically configured to invoke the provided on +// close handler followed by the [Config.OnDisconnection] that was configured +// when initially creating the connection manager when the connection is closed. +// The on close handler is invoked in the same goroutine as the caller of +// [Conn.Close] and [Config.OnDisconnection] is invoked in a separate goroutine. +func newConn(cm *ConnManager, conn net.Conn, id uint64, connType ConnectionType, remoteAddr *addrmgr.NetAddress, onClose func()) *Conn { + c := &Conn{Conn: conn, id: id, connType: connType, remoteAddr: *remoteAddr} + c.onClose = func() { + onClose() + if cm.cfg.OnDisconnection != nil { + go cm.cfg.OnDisconnection(c) + } + } + return c +} - // Permanent specifies whether or not the connection request represents what - // should be treated as a permanent connection, meaning the connection - // manager will try to always maintain the connection including retries with - // increasing backoff timeouts. - Permanent bool +// ID returns a unique identifier for the connection. +// +// This function is safe for concurrent access. +func (c *Conn) ID() uint64 { + return c.id } -// updateState updates the state of the connection request. -func (c *ConnReq) updateState(state ConnState) { - c.state.Store(uint32(state)) +// Close closes the connection. The [Config.OnDisconnection] that was +// configured when initially creating the connection manager will be invoked in +// a separate goroutine prior to closing the underlying connection. +// +// Repeated close attempts are ignored. Closing a connection that has already +// been closed will not return an error. +// +// This function is safe for concurrent access. +func (c *Conn) Close() error { + // Already closed. + if !c.closed.CompareAndSwap(false, true) { + return nil + } + + // Invoke close callback associated with the connection when it's closed. + if c.onClose != nil { + c.onClose() + } + + // Close the underlying connection. + return c.Conn.Close() } -// ID returns a unique identifier for the connection request. -func (c *ConnReq) ID() uint64 { - return c.id.Load() +// RemoteAddr returns the remote address manager network address associated with +// the connection. It returns a [net.Addr] to implement the [net.Conn] +// interface, but the underlying type will be a [*addrmgr.NetAddress]. +func (c *Conn) RemoteAddr() net.Addr { + return &c.remoteAddr } -// State is the connection state of the requested connection. -func (c *ConnReq) State() ConnState { - return ConnState(c.state.Load()) +// Type returns the [ConnectionType] of the connection. +// +// This function is safe for concurrent access. +func (c *Conn) Type() ConnectionType { + return c.connType } -// String returns a human-readable string for the connection request. -func (c *ConnReq) String() string { - if c.Addr == nil || c.Addr.String() == "" { - return fmt.Sprintf("reqid %d", c.ID()) - } - return fmt.Sprintf("%s (reqid %d)", c.Addr, c.ID()) +// pendingConnInfo houses information about a pending connection attempt. +type pendingConnInfo struct { + id uint64 + addr *addrmgr.NetAddress + cancel context.CancelFunc +} + +// persistentEntry houses information about a persistent connection that has +// been added to the connection manager. Once an ID has been assigned, all +// future connections established for the persistent connection will have the +// same ID. This allows it to be uniquely identified and removed later. +type persistentEntry struct { + id uint64 + addr *addrmgr.NetAddress + + // cancel shuts down the goroutine that maintains the persistent connection. + // It is owned by the connection manager and must not be accessed without + // its connection mutex held. + cancel context.CancelFunc } // Config holds the configuration options related to the connection manager. @@ -129,10 +230,13 @@ type Config struct { // This field will not have any effect if the Listeners field is not // also specified since there couldn't possibly be any accepted // connections in that case. - OnAccept func(net.Conn) + OnAccept func(*Conn) - // TargetOutbound is the number of outbound network connections to - // maintain. Defaults to 8. + // TargetOutbound is the number of outbound network connections to maintain + // automatically. Defaults to 8. + // + // Persistent connections do not count against this value. They have their + // own maximum limit defined by [MaxPersistent]. TargetOutbound uint32 // RetryDuration is the duration to wait before retrying connection @@ -141,11 +245,10 @@ type Config struct { // OnConnection is a callback that is fired when a new outbound // connection is established. - OnConnection func(*ConnReq, net.Conn) + OnConnection func(*Conn) - // OnDisconnection is a callback that is fired when an outbound - // connection is disconnected. - OnDisconnection func(*ConnReq) + // OnDisconnection is a callback that is fired when a connection is closed. + OnDisconnection func(*Conn) // GetNewAddress is a way to get an address to make a network connection // to. If nil, no new connections will be made automatically. @@ -161,13 +264,8 @@ type Config struct { // ConnManager provides a manager to handle network connections. type ConnManager struct { - // connReqCount is the number of connection requests that have been made and - // is primarily used to assign unique connection request IDs. - connReqCount atomic.Uint64 - - // assignIDMtx synchronizes the assignment of an ID to a connection request - // with overall connection request count above. - assignIDMtx sync.Mutex + // nextConnID is used to assign unique connection request IDs. + nextConnID atomic.Uint64 // quit is used for lifecycle management of the connection manager. quit chan struct{} @@ -176,424 +274,884 @@ type ConnManager struct { // creating time and treated as immutable after that. cfg Config - // failedAttempts tracks the total number of failed oubound connection - // attempts since the last successful connection made by the connection - // manager. It is primarily used to detect network outages in order to - // impose a retry timeout on achieving the target number of outbound - // connections which prevents runaway failed connection attempt churn. + // runPersistentChan is used to signal the persistent connections handler to + // launch a goroutine that attempts to always maintain an established + // connection with a given address. // - // This field is owned by the connection handler and must not be accessed - // outside of it. - failedAttempts uint64 + // It is a buffered channel with size [MaxPersistent]. + runPersistentChan chan *persistentEntry + + // outboundSem limits the number of active outbound connections. It does + // not apply to persistent connections which are separately limited to + // [MaxPersistent]. + activeOutboundsSem semaphore + + // The fields below this point are all protected by the connection mutex. + connMtx sync.Mutex - // The following fields are used to track the various connections managed - // by the connection manager. They are protected by the associated - // connection mutex. + // persistent tracks all registered persistent connection entries. // - // pending holds all registered connection requests that have yet to - // succeed. + // A persistent connection can be in one of three states: // - // conns represents the set of all active connections. - connMtx sync.Mutex - pending map[uint64]*ConnReq - conns map[uint64]*ConnReq -} + // - Established with the connection instance in the active map + // - Pending with an entry in the pending map + // - Awaiting a retry + // + // Regardless of the state, there will always be an entry in this map. + persistent map[uint64]*persistentEntry -// registerPending registers the provided connection request as a pending -// connection attempt. -// -// This function MUST be called with the connection mutex lock held (writes). -func (cm *ConnManager) registerPending(connReq *ConnReq) { - connReq.updateState(ConnPending) - cm.pending[connReq.ID()] = connReq + // pending holds all registered connection attempts that have yet to + // succeed. + pending map[uint64]*pendingConnInfo + + // active represents the set of all active connections. + active map[uint64]*Conn + + // connIDByAddr provides fast O(1) lookup of connection IDs by address + // (host:port). It is kept in sync with the persistent, pending, and active + // maps and is primarily used to efficiently reject duplicate connections. + connIDByAddr map[string]uint64 } -// newConnReq creates a new connection request and connects to the corresponding -// address. -func (cm *ConnManager) newConnReq(ctx context.Context) { - // Ignore during shutdown. - if ctx.Err() != nil { - return +// checkShutdown returns [ErrShutdown] when the connection manager quit channel +// has been closed. +func (cm *ConnManager) checkShutdown() error { + select { + case <-cm.quit: + const str = "connection manager shutdown" + return MakeError(ErrShutdown, str) + default: } + return nil +} - c := &ConnReq{} - c.id.Store(cm.connReqCount.Add(1)) +// stdlibNetAddrToAddrMgrNetAddr converts the provided standard lib [net.Addr] +// to a concrete address manager address. +func stdlibNetAddrToAddrMgrNetAddr(addr net.Addr) (*addrmgr.NetAddress, error) { + host, portStr, err := net.SplitHostPort(addr.String()) + if err != nil { + str := fmt.Sprintf("unable to split address %q", addr) + return nil, MakeError(ErrUnsupportedAddr, str) + } + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + str := fmt.Sprintf("invalid port for address %q", addr) + return nil, MakeError(ErrUnsupportedAddr, str) + } - // Register the pending connection attempt so it can be canceled via the - // [ConnManager.Remove] method. - cm.connMtx.Lock() - cm.registerPending(c) - cm.connMtx.Unlock() + addrType, addrBytes := addrmgr.EncodeHost(host) + if addrType == addrmgr.UnknownAddressType { + str := fmt.Sprintf("unable to determine address type for %q", addr) + return nil, MakeError(ErrUnsupportedAddr, str) + } - addr, err := cm.cfg.GetNewAddress() + now := time.Unix(time.Now().Unix(), 0) + netAddr, err := addrmgr.NewNetAddressFromParams(addrType, addrBytes, + uint16(port), now, 0) if err != nil { - cm.connMtx.Lock() - cm.handleFailedPending(ctx, c, err) - cm.connMtx.Unlock() - return + return nil, MakeError(ErrUnsupportedAddr, err.Error()) } + return netAddr, nil +} - c.Addr = addr +// addPendingInfo adds information about a pending connection attempt to the +// local state. +// +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) addPendingInfo(info *pendingConnInfo) { + cm.pending[info.id] = info + if _, ok := cm.persistent[info.id]; !ok { + cm.connIDByAddr[info.addr.String()] = info.id + } +} - cm.Connect(ctx, c) +// removePendingInfo removes a pending connection attempt from the local state. +// +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) removePendingInfo(info *pendingConnInfo) { + delete(cm.pending, info.id) + if _, ok := cm.persistent[info.id]; !ok { + delete(cm.connIDByAddr, info.addr.String()) + } } -// handleFailedConn handles a connection failed due to a disconnect or any other -// failure. Permanent connection requests are retried after the configured -// retry duration. A new connection request is created if required. +// addActiveConn adds an established connection to the local state. // -// In the event there have been [maxFailedAttempts] failed successive attempts, -// new connections will be retried after the configured retry duration. +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) addActiveConn(conn *Conn) { + cm.active[conn.id] = conn + if _, ok := cm.persistent[conn.id]; !ok { + cm.connIDByAddr[conn.remoteAddr.String()] = conn.id + } +} + +// removeActiveConn removes an established connection from the local state. It +// has no effect if the connection has already been removed from the active map. // -// This function MUST be called with the connection lock held (writes). -func (cm *ConnManager) handleFailedConn(ctx context.Context, c *ConnReq) { - // Ignore during shutdown. - select { - case <-cm.quit: - return - case <-ctx.Done(): +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) removeActiveConn(conn *Conn) { + // The active connection might have already been removed before releasing + // the mutex to call [Conn.Close]. + if _, ok := cm.active[conn.id]; !ok { return - default: } - // Reconnect to permanent connection requests after a retry timeout with - // an increasing backoff up to a max for repeated failed attempts. - if c.Permanent { - c.retryCount++ - retryWait := time.Duration(c.retryCount) * cm.cfg.RetryDuration - retryWait = min(retryWait, maxRetryDuration) - log.Debugf("Retrying connection to %v in %v", c, retryWait) - go func() { - select { - case <-time.After(retryWait): - cm.Connect(ctx, c) - case <-cm.quit: - case <-ctx.Done(): - } - }() - return + delete(cm.active, conn.id) + if _, ok := cm.persistent[conn.id]; !ok { + delete(cm.connIDByAddr, conn.remoteAddr.String()) } +} - // Nothing more to do when the method to automatically get new addresses - // to connect to isn't configured. - if cm.cfg.GetNewAddress == nil { - return +// addPersistentEntry adds a persistent connection entry to the local state. +// +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) addPersistentEntry(entry *persistentEntry) { + cm.persistent[entry.id] = entry + cm.connIDByAddr[entry.addr.String()] = entry.id +} + +// removePersistentEntry removes a persistent connection entry from the local +// state. +// +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) removePersistentEntry(entry *persistentEntry) { + delete(cm.persistent, entry.id) + _, pending := cm.pending[entry.id] + _, active := cm.active[entry.id] + if !pending && !active { + delete(cm.connIDByAddr, entry.addr.String()) } +} - // Wait to attempt new connections when there are too many successive - // failures. This prevents massive connection spam when no connections can - // be made, such as a network outtage. - cm.failedAttempts++ - if cm.failedAttempts >= maxFailedAttempts { - log.Debugf("Max failed connection attempts reached: [%d] -- retrying "+ - "connection in: %v", maxFailedAttempts, cm.cfg.RetryDuration) - go func() { - select { - case <-time.After(cm.cfg.RetryDuration): - cm.newConnReq(ctx) - case <-cm.quit: - case <-ctx.Done(): - } - }() - return +// rejectConnectedAddr returns an error if there is already either an +// established connection to the provided address or a pending attempt to +// connect to it. Persistent connections in the retry state are intentionally +// not detected. +// +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) rejectConnectedAddr(addr *addrmgr.NetAddress) error { + connID, ok := cm.connIDByAddr[addr.String()] + if !ok { + return nil } - // Otherwise, attempt a new connection with a new address now. - go cm.newConnReq(ctx) + if _, ok := cm.pending[connID]; ok { + str := fmt.Sprintf("a pending connection to %s already exists", addr) + return MakeError(ErrAlreadyPending, str) + } + if _, ok := cm.active[connID]; ok { + str := fmt.Sprintf("a connection to %s is already established", addr) + return MakeError(ErrAlreadyConnected, str) + } + return nil } -// handleFailedPending handles failed pending connection requests. Connection -// requests that were canceled are ignored. Otherwise, their state is updated -// to mark it failed and it is passed along to [ConnManager.handleFailedConn] to -// possibly retry or be reused for a new connection depending on settings. +// findPersistentAddrID attempts to find and return the persistent connection ID +// associated with the passed address. The bool return indicates whether or not +// it was found. // -// This function MUST be called with the connection lock held (writes). -func (cm *ConnManager) handleFailedPending(ctx context.Context, c *ConnReq, failedErr error) { - if _, ok := cm.pending[c.ID()]; !ok { - log.Debugf("Ignoring connection for canceled conn req: %v", c) - return +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) findPersistentAddrID(addr net.Addr) (uint64, bool) { + connID, ok := cm.connIDByAddr[addr.String()] + if !ok { + return 0, false + } + + entry, ok := cm.persistent[connID] + if !ok { + return 0, false } - c.updateState(ConnFailed) - log.Debugf("Failed to connect to %v: %v", c, failedErr) - cm.handleFailedConn(ctx, c) + return entry.id, true } -// Connect assigns an id and dials a connection to the address of the connection -// request using the provided context and the dial function configured when -// initially creating the connection manager. +// rejectPersistentAddr returns an error if there is already a persistent +// connection entry for the given address. // -// The connection attempt will be ignored if the connection manager has been -// shutdown by canceling the lifecycle context the Run method was invoked with -// or the provided connection request is already in the failed state. +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) rejectPersistentAddr(addr *addrmgr.NetAddress) error { + if _, ok := cm.findPersistentAddrID(addr); ok { + str := fmt.Sprintf("a persistent connection for %s already exists", addr) + return MakeError(ErrDuplicatePersistent, str) + } + return nil +} + +// rejectDuplicateAddr returns an error if there is already a persistent +// connection entry, a pending connection attempt, or an established connection +// for the given address. // -// Note that the context parameter to this function and the lifecycle context -// may be independent. -func (cm *ConnManager) Connect(ctx context.Context, c *ConnReq) { +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) rejectDuplicateAddr(addr *addrmgr.NetAddress) error { + if err := cm.rejectPersistentAddr(addr); err != nil { + return err + } + if err := cm.rejectConnectedAddr(addr); err != nil { + return err + } + return nil +} + +// dial attempts to connect to the provided address and returns a connection +// configured with the provided params on success. +// +// A new globally unique connection ID is assigned unless one is provided by +// passing a non-nil value in the persistent connection ID parameter. This +// allows persistent connections to retain the same ID across reconnects. +// +// Attempts to dial addresses that are already connected, pending, or (in most +// cases) persistent will return an error as described below. Only established +// and pending connections are rejected when a non-nil persistent connection ID +// is passed. +// +// On success, the returned connection is configured to remove itself from the +// set of all active connections and invoke the provided on close callback (if +// set) when it is closed. +// +// On failure, the provided on close callback (when non-nil) will be invoked +// prior to returning. +// +// In addition to errors returned by [Config.Dial], the following errors are +// possible: +// +// - [ErrDuplicatePersistent] when a persistent connection already exists for +// the address and no persistent connection ID is provided +// - [ErrAlreadyPending] when there is already a pending connection attempt +// to the address +// - [ErrAlreadyConnected] when there is already an established connection to +// the address +// - [ErrShutdown] when the connection manager is shutting down +// - [context.Canceled] or [context.DeadlineExceeded] depending on the +// provided context or when the dialer fails to establish a connection +// before the timeout configured for the connection manager +// +// This function is safe for concurrent access. +func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType ConnectionType, onClose func(), persistentConnID *uint64) (*Conn, error) { + var skipOnClose bool + defer func() { + if !skipOnClose && onClose != nil { + onClose() + } + }() + // Ignore during shutdown and when caller provided context is already // canceled. - select { - case <-cm.quit: - return - default: + if err := cm.checkShutdown(); err != nil { + return nil, err } if ctx.Err() != nil { - return + return nil, ctx.Err() } - // During the time we wait for retry there is a chance that this - // connection was already cancelled. - if c.State() == ConnCanceled { - log.Debugf("Ignoring connect for canceled connreq=%v", c) - return + rAddr, err := stdlibNetAddrToAddrMgrNetAddr(addr) + if err != nil { + return nil, err } - // Assign an ID and register the pending connection attempt when an ID has - // not already been assigned so it can be canceled via the - // [ConnManager.Remove] method. + // Reject attempts to dial addresses that are already connected (or in the + // process of it). Additionally, reject attempts to dial existing + // persistent addresses unless a persistent connection ID was provided + // indicating the dial is specifically for a persistent connection. // - // Note that the assignment of the ID and the overall request count need to - // be synchronized. So long as this is the only place an existing conn - // request ID is updated and this method is not called concurrently on the - // same conn request, no race could occur. However, those preconditions - // would be easy to inadvertently violate via updates to the code, so the - // mutex is added here for additional safety. - var doRegisterPending bool - cm.assignIDMtx.Lock() - if c.ID() == 0 { - c.id.Store(cm.connReqCount.Add(1)) - doRegisterPending = true - } - connReqID := c.ID() - cm.assignIDMtx.Unlock() - if doRegisterPending { - cm.connMtx.Lock() - cm.registerPending(c) + // This needs to be done under the same lock as adding a pending entry to + // prevent the possibility of two simultaneous attempts logic racing. + rejectFn := cm.rejectDuplicateAddr + if persistentConnID != nil { + rejectFn = cm.rejectConnectedAddr + } + cm.connMtx.Lock() + if err := rejectFn(rAddr); err != nil { cm.connMtx.Unlock() + log.Debugf("Rejected connection: %v", err) + return nil, err } - log.Debugf("Attempting to connect to %v", c) - - // Attempt to establish the connection to the address associated with the - // connection request. Apply a timeout if requested. + // Apply a dial timeout if requested. Otherwise, use a regular cancel + // context to support canceling the pending connection later. + var cancel context.CancelFunc if cm.cfg.DialTimeout != 0 { - var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, cm.cfg.DialTimeout) - defer cancel() + } else { + ctx, cancel = context.WithCancel(ctx) } - var conn net.Conn - conn, err := cm.cfg.Dial(ctx, c.Addr.Network(), c.Addr.String()) - if err != nil { + defer cancel() + + // Register the pending connection attempt and defer its removal to ensure + // it is always removed on failure. + var connID uint64 + if persistentConnID != nil { + connID = *persistentConnID + } else { + connID = cm.nextConnID.Add(1) + } + info := &pendingConnInfo{connID, rAddr, cancel} + cm.addPendingInfo(info) + cm.connMtx.Unlock() + defer func() { cm.connMtx.Lock() - cm.handleFailedPending(ctx, c, err) + if _, ok := cm.pending[connID]; ok { + cm.removePendingInfo(info) + } cm.connMtx.Unlock() - return + }() + + log.Debugf("Attempting to connect to %v (id: %d, type: %v)", addr, connID, + connType) + + // Attempt to establish the connection to the address. + netConn, err := cm.cfg.Dial(ctx, addr.Network(), addr.String()) + if err != nil { + var logErrStr string + switch { + case errors.Is(err, context.DeadlineExceeded): + logErrStr = fmt.Sprintf("no response for %v", cm.cfg.DialTimeout) + case errors.Is(err, context.Canceled): + // Override the error with the shutdown error instead when that is + // the upstream cause of the context cancel. + if sErr := cm.checkShutdown(); sErr != nil { + err = sErr + break + } + logErrStr = "attempt manually canceled" + } + if logErrStr == "" { + logErrStr = err.Error() + } + log.Debugf("Failed to connect to %v: %v", addr, logErrStr) + return nil, err } + // Ignore any connections that succeed after they were manually canceled. cm.connMtx.Lock() - defer cm.connMtx.Unlock() + if _, ok := cm.pending[connID]; !ok { + cm.connMtx.Unlock() + netConn.Close() + log.Debugf("Ignoring canceled connection %v (id: %d, type: %v)", addr, + connID, connType) + return nil, context.Canceled + } - if _, ok := cm.pending[connReqID]; !ok { - conn.Close() - log.Debugf("Ignoring connection for canceled connreq=%v", c) - return + // Remove the pending entry under the lock. This ensures the maps are + // mutually exclusive for a given id. + cm.removePendingInfo(info) + + // Successful return means the on close callback is not invoked until the + // connection is closed. + skipOnClose = true + + // Setup a close callback to remove the connection from the map that tracks + // all active connections when the connection is closed and also to invoke + // the close callback provided by the caller when specified. + var conn *Conn + dialOnClose := func() { + cm.connMtx.Lock() + cm.removeActiveConn(conn) + cm.connMtx.Unlock() + if onClose != nil { + onClose() + } + log.Debugf("Disconnected from %v (id: %d, type: %v)", addr, connID, + connType) } - c.updateState(ConnEstablished) - c.conn = conn - cm.conns[connReqID] = c - log.Debugf("Connected to %v", c) - c.retryCount = 0 - cm.failedAttempts = 0 - delete(cm.pending, connReqID) + // Create a new connection instance with the connection ID and type and add + // an entry to the map that tracks all active connections. + conn = newConn(cm, netConn, connID, connType, rAddr, dialOnClose) + cm.addActiveConn(conn) + cm.connMtx.Unlock() + + log.Debugf("Connected to %v (id: %d, type: %v)", addr, connID, connType) + return conn, nil +} +// Connect assigns an ID and dials a connection to the provided address using +// the provided context and the dial function configured when initially creating +// the connection manager. +// +// Attempts to dial addresses that already have an established, pending, or +// persistent connection will return an error as described below. +// +// The connection will have type [ConnTypeManual]. +// +// Note that the context parameter to this function and the lifecycle context +// may be independent. +// +// In addition to errors returned by the underlying dialer, the following errors +// are possible: +// +// - [ErrDuplicatePersistent] when a persistent connection already exists for +// the address (regardless of its current state) +// - [ErrAlreadyPending] when there is already a pending connection attempt +// to the address +// - [ErrAlreadyConnected] when there is already an established connection to +// the address +// - [ErrShutdown] when the connection manager is shutting down +// - [context.Canceled] or [context.DeadlineExceeded] depending on the +// provided context or when the dialer fails to establish a connection +// before the timeout configured for the connection manager +// +// This function is safe for concurrent access. +func (cm *ConnManager) Connect(ctx context.Context, addr net.Addr) (*Conn, error) { + conn, err := cm.dial(ctx, addr, ConnTypeManual, nil, nil) + if err != nil { + return nil, err + } if cm.cfg.OnConnection != nil { - go cm.cfg.OnConnection(c, conn) + go cm.cfg.OnConnection(conn) } + return conn, nil } -// handleDisconnected handles a connection that has been disconnected. +// Disconnect either disconnects the connection corresponding to the given +// connection id or cancels any pending attempts associated with it. Persistent +// connections will be retried with an increasing backoff duration. // -// This function MUST be called with the connection mutex held (writes). -func (cm *ConnManager) handleDisconnected(id uint64, retry bool) { - // Mark the connection request as canceled and remove it from the pending - // connections when it is still pending. Since the connection attempt is - // taking place asynchronously, this ensures any later successful connection - // is ignored. - connReq, ok := cm.pending[id] - if ok { - connReq.updateState(ConnCanceled) - log.Debugf("Canceling: %v", connReq) - delete(cm.pending, id) - } - - connReq, ok = cm.conns[id] - if !ok { - log.Errorf("Unknown connid=%d", id) - return +// This function is safe for concurrent access. +func (cm *ConnManager) Disconnect(id uint64) error { + // Cancel and remove pending entries. Even though the pending entry will be + // removed from the map regardless by the dialer, doing it now ensures that + // any connections that are already in progress and later succeed are + // ignored. + cm.connMtx.Lock() + if info, ok := cm.pending[id]; ok { + info.cancel() + cm.removePendingInfo(info) + cm.connMtx.Unlock() + return nil } - // Close the underlying connection and invoke the associated callback (if - // assigned). - log.Debugf("Disconnected from %v", connReq) - delete(cm.conns, id) - if connReq.conn != nil { - connReq.conn.Close() - } - if cm.cfg.OnDisconnection != nil { - go cm.cfg.OnDisconnection(connReq) + conn := cm.active[id] + if conn != nil { + cm.connMtx.Unlock() + conn.Close() // Close requires the conn mutex. + return nil } + _, isPersistent := cm.persistent[id] + cm.connMtx.Unlock() - // Mark the associated connection request as disconnected and return when no - // further attempts will be made now that all internal state has been - // cleaned up. - if !retry { - connReq.updateState(ConnDisconnected) - return + // Not found in active or pending, but it might still be a persistent conn + // waiting to retry. No error in that case. + if isPersistent { + return nil } - // Otherwise, attempt a reconnection when the associated connection request - // is marked as permanent or there are not already enough outbound peers to - // satisfy the target number of outbound peers. - numConns := uint32(len(cm.conns)) - if connReq.Permanent || numConns < cm.cfg.TargetOutbound { - // The connection request is reused for permanent ones, so add it back - // to the pending map in that case so that subsequent processing of - // connections and failures do not ignore the request. - if connReq.Permanent { - cm.registerPending(connReq) - log.Debugf("Reconnecting to %v", connReq) - } - - // A background context is the only viable choice here. It is not - // ideal, but it is acceptable, because, ultimately, this context is - // really only used for persistent peers when they retry and persistent - // peers are not tied to a specific context anyway. They are instead - // removed by other means. Due to that, there also is no machinery to - // cancel a given persistent peer from a given context anyway. - // - // Future work ideally should refactor the persistent peer handling to - // have proper full context support. - cm.handleFailedConn(context.Background(), connReq) - } + str := fmt.Sprintf("no entries with id %d exist", id) + return MakeError(ErrNotFound, str) } -// Disconnect disconnects the connection corresponding to the given connection -// id. Permanent connections will be retried with an increasing backoff -// duration. +// Remove closes, cancels, or removes the connection corresponding to the given +// connection id. // -// This function is safe for concurrent access. -func (cm *ConnManager) Disconnect(id uint64) { - cm.connMtx.Lock() - cm.handleDisconnected(id, true) - cm.connMtx.Unlock() -} - -// Remove removes the connection corresponding to the given connection id from -// known connections. +// This function may be used for all connections states and types, including +// established, pending, and persistent connections. // -// NOTE: This method can also be used to cancel a pending connection attempt -// that hasn't yet succeeded. +// Connections that are already established are closed and connection attempts +// that are still pending are canceled. Persistent connections are additionally +// removed so that no future retries will occur. // // This function is safe for concurrent access. -func (cm *ConnManager) Remove(id uint64) { +func (cm *ConnManager) Remove(id uint64) error { + // When the ID is for a persistent connection, cancel the associated context + // and remove it from the persistent map to prevent future retries. cm.connMtx.Lock() - cm.handleDisconnected(id, false) + entry, isPersistent := cm.persistent[id] + if isPersistent { + cm.removePersistentEntry(entry) + if entry.cancel != nil { + entry.cancel() + } + log.Debugf("Removed persistent connection to %v (id %d)", entry.addr, + entry.id) + } + + // Cancel and remove pending entries. Even though the pending entry will be + // removed from the map regardless by the dialer, doing it now ensures that + // any connections that are already in progress and later succeed are + // ignored. + if info, ok := cm.pending[id]; ok { + info.cancel() + cm.removePendingInfo(info) + cm.connMtx.Unlock() + return nil + } + + // Close active connections and remove the entry from the active map. + // + // Even though the connection close handler would remove it from the map, it + // needs to be removed under same lock as removals from the persistent map + // to prevent the possibility of two simultaneous attempts logic racing. + if conn, ok := cm.active[id]; ok { + cm.removeActiveConn(conn) + cm.connMtx.Unlock() + conn.Close() // Close requires the conn mutex. + return nil + } cm.connMtx.Unlock() + + // Not found in active or pending, but no error if it was a removed + // persistent conn. + if isPersistent { + return nil + } + + str := fmt.Sprintf("no entries with id %d exist", id) + return MakeError(ErrNotFound, str) } -// findPendingByAddr attempts to find and return the pending connection request -// associated with the provided address. It returns nil if no matching request -// is found. -// -// This function MUST be called with the connection mutex held (writes). -func (cm *ConnManager) findPendingByAddr(addr net.Addr) *ConnReq { - pendingAddr := addr.String() - for _, req := range cm.pending { - if req == nil || req.Addr == nil { +// inboundStdlibNetAddrToAddrMgrAddr converts the provided standard library +// [net.Addr] that is expected to be from an inbound connection to a concrete +// address manager address. +func inboundStdlibNetAddrToAddrMgrAddr(addr net.Addr) (*addrmgr.NetAddress, error) { + // Fast path for inbounds since they will almost always be one of these + // given they are created by [net.Listener.Accept]. + switch a := addr.(type) { + case *net.TCPAddr: + return addrmgr.NewNetAddressFromIPPort(a.IP, uint16(a.Port), 0), nil + case *net.UDPAddr: + return addrmgr.NewNetAddressFromIPPort(a.IP, uint16(a.Port), 0), nil + } + + // Fall back to slower string parsing. + return stdlibNetAddrToAddrMgrNetAddr(addr) +} + +// listenHandler accepts incoming connections on a given listener. It must be +// run as a goroutine. +func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) { + log.Infof("Server listening on %s", listener.Addr()) + defer log.Tracef("Listener handler done for %s", listener.Addr()) + + for ctx.Err() == nil { + netConn, err := listener.Accept() + if err != nil { + // Only log the error if not forcibly shutting down. + if ctx.Err() == nil { + log.Errorf("Can't accept connection: %v", err) + } + continue + } + + rAddr, err := inboundStdlibNetAddrToAddrMgrAddr(netConn.RemoteAddr()) + if err != nil { + log.Warnf("Dropped connection from %v: failed to parse address", + netConn.RemoteAddr()) + netConn.Close() continue } - if pendingAddr == req.Addr.String() { - return req + + // Reject connections with the same host:port as any existing pending, + // established, or persistent connections. Note that this does NOT + // prevent multiple connections from the same host given they typically + // will be coming from different ports. + // + // The aforementioned behavior is intentional as it allows connections + // from the same host to be independently limited to more than one + // elsewhere. + cm.connMtx.Lock() + if err := cm.rejectDuplicateAddr(rAddr); err != nil { + cm.connMtx.Unlock() + log.Debugf("Dropped connection from %v: %v", rAddr, err) + netConn.Close() + continue } + cm.connMtx.Unlock() + + go func(netConn net.Conn) { + // Create a new connection instance with the next globally unique + // connection ID, add an entry to the map that tracks all active + // connections, and invoke the configured accept callback with it. + // + // Also set a close callback to remove the connection from the map + // when it is closed. + id := cm.nextConnID.Add(1) + const connType = ConnTypeInbound + var conn *Conn + onClose := func() { + cm.connMtx.Lock() + cm.removeActiveConn(conn) + cm.connMtx.Unlock() + log.Debugf("Disconnected from %v (id: %d, type: %v)", rAddr, id, + connType) + } + conn = newConn(cm, netConn, id, connType, rAddr, onClose) + cm.connMtx.Lock() + cm.addActiveConn(conn) + cm.connMtx.Unlock() + log.Debugf("Accepted connection from %v (id: %d, type: %v)", rAddr, + id, connType) + cm.cfg.OnAccept(conn) + }(netConn) } - return nil } -// CancelPending removes the connection corresponding to the given address -// from the list of pending failed connections. +// AddPersistent adds an address the connection manager will attempt to always +// maintain an established connection with until the persistent connection entry +// is removed via [ConnManager.Remove] or the context associated with +// [ConnManager.Run] is canceled. +// +// When the associated connection is dropped, it will be retried with an +// increasing backoff, up to a maximum for repeated failed attempts. +// +// A maximum of [MaxPersistent] connections may be added. Attempting to add any +// more will return [ErrMaxPersistent]. +// +// Adding a duplicate persistent address will return [ErrDuplicatePersistent] +// and adding addresses that already have an established or pending connection +// will return [ErrAlreadyConnected] or [ErrAlreadyPending], respectively. +// +// An ID is returned that uniquely identifies the persistent connection. All +// future connections established will have the same ID. // -// Returns an error if the connection manager is stopped or there is no pending -// connection for the given address. -func (cm *ConnManager) CancelPending(addr net.Addr) error { +// Persistent connections do not count against [Config.TargetOutbound]. +// +// Note that the actual connections to the address happen asynchronously and +// will have type [ConnTypeManual]. Established connections will invoke the +// [Config.OnConnection] callback that was configured when initially creating +// the connection manager. +// +// Since connections happen asynchronously, the error only indicates issues with +// adding the persistent connection entry. +// +// The persistent connection may be removed by passing the returned connection +// ID to [ConnManager.Remove]. +// +// This function is safe for concurrent access. +func (cm *ConnManager) AddPersistent(addr net.Addr) (uint64, error) { cm.connMtx.Lock() defer cm.connMtx.Unlock() - connReq := cm.findPendingByAddr(addr) - if connReq == nil { - str := fmt.Sprintf("no pending connection to %v", addr) - return MakeError(ErrNotFound, str) + if len(cm.persistent)+1 > MaxPersistent { + str := fmt.Sprintf("a maximum of %d persistent connections is allowed", + MaxPersistent) + return 0, MakeError(ErrMaxPersistent, str) } - delete(cm.pending, connReq.ID()) - connReq.updateState(ConnCanceled) - log.Debugf("Canceled pending connection to %v", addr) - return nil + rAddr, err := stdlibNetAddrToAddrMgrNetAddr(addr) + if err != nil { + return 0, err + } + + if err := cm.rejectDuplicateAddr(rAddr); err != nil { + return 0, err + } + + entry := &persistentEntry{id: cm.nextConnID.Add(1), addr: rAddr} + cm.addPersistentEntry(entry) + log.Debugf("Added persistent connection to %v (id: %d)", addr, entry.id) + + // The channel is buffered with the max allowed persistent conns, so there + // is no possibility of blocking here. This approach allows persistent + // peers to be added both before and after the connection manager is running + // without starting the goroutines before it is running. + cm.runPersistentChan <- entry + return entry.id, nil } -// ForEachConnReq calls the provided function with each connection request known -// to the connection manager, including pending requests. Returning an error -// from the provided function will stop the iteration early and return said -// error from this function. +// IsPersistent returns whether or not the provided connection id belongs to a +// persistent connection. // // This function is safe for concurrent access. +func (cm *ConnManager) IsPersistent(id uint64) bool { + cm.connMtx.Lock() + _, ok := cm.persistent[id] + cm.connMtx.Unlock() + return ok +} + +// FindPersistentAddrID attempts to find and return the persistent connection ID +// associated with the passed address. The bool return indicates whether or not +// it was found. // -// NOTE: This must not call any other connection manager methods during -// iteration or it will result in a deadlock. -func (cm *ConnManager) ForEachConnReq(f func(c *ConnReq) error) error { +// This function is safe for concurrent access. +func (cm *ConnManager) FindPersistentAddrID(addr net.Addr) (uint64, bool) { cm.connMtx.Lock() - defer cm.connMtx.Unlock() + id, ok := cm.findPersistentAddrID(addr) + cm.connMtx.Unlock() + return id, ok +} - var err error - for _, connReq := range cm.pending { - err = f(connReq) - if err != nil { - return err +// runPersistent attempts to maintain a persistent connection to the provided +// address until the passed context is canceled. +// +// When the associated connection is dropped, it will be retried with an +// increasing backoff, up to a maximum for repeated failed attempts. +// +// This MUST be run as a goroutine. +func (cm *ConnManager) runPersistent(ctx context.Context, connID uint64, addr net.Addr) { + // Ensure the connection is closed when the goroutine exits. + var conn *Conn + defer func() { + if conn != nil { + conn.Close() } + }() + + // Setup a callback that notifies a disconnect channel for use below and + // start with the channel signaled. + disconnected := make(chan struct{}, 1) + disconnected <- struct{}{} + onClose := func() { + disconnected <- struct{}{} } - for _, connReq := range cm.conns { - err = f(connReq) + + var retryCount uint32 + var retryAfter <-chan time.Time + var lastAttempt time.Time + for { + // Wait for disconnect or retry timer when it's set. + select { + case <-ctx.Done(): + return + case <-cm.quit: + return + case <-retryAfter: + retryAfter = nil + case <-disconnected: + // Wait to retry any time the connection was not maintained for at + // least a single retry interval. + // + // This approach is used over only incrementing the retry count when + // the dial fails to effectively rate limit the attempts with an + // increasing backoff regardless of the reason a stable connection + // was not maintained. + // + // For example, the remote might repeatedly reject the peer for a + // variety of reasons (max limits, not enough peers of a desired + // type, etc) after a successful connection is made. + if !lastAttempt.IsZero() && time.Since(lastAttempt) < cm.cfg.RetryDuration { + // Reconnect after a retry timeout with an increasing backoff up + // to a max for repeated failed attempts. + const maxUint32 = 1<<32 - 1 + if retryCount < maxUint32 { + retryCount++ + } + retryWait := time.Duration(retryCount) * cm.cfg.RetryDuration + retryWait = min(retryWait, maxRetryDuration) + log.Debugf("Retrying connection to %v in %v (retries %d)", addr, + retryWait, retryCount) + retryAfter = time.After(retryWait) + continue + } + + // A connection succeeded and was maintained for at least a single + // retry interval. + // + // Clear the retry state. + retryCount = 0 + retryAfter = nil + } + + lastAttempt = time.Now() + var err error + conn, err = cm.dial(ctx, addr, ConnTypeManual, onClose, &connID) if err != nil { - return err + if ctx.Err() != nil { + return + } + + // Retry, potentially after a timeout with backoff. + continue + } + + // Successful connection. + if cm.cfg.OnConnection != nil { + go cm.cfg.OnConnection(conn) } } - return nil } -// listenHandler accepts incoming connections on a given listener. It must be -// run as a goroutine. -func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) { - log.Infof("Server listening on %s", listener.Addr()) +// persistentConnsHandler handles launching individual goroutines for persistent +// connections. +func (cm *ConnManager) persistentConnsHandler(ctx context.Context) { + for { + select { + case entry := <-cm.runPersistentChan: + pCtx, cancel := context.WithCancel(ctx) + cm.connMtx.Lock() + entry.cancel = cancel + cm.connMtx.Unlock() + go cm.runPersistent(pCtx, entry.id, entry.addr) + + case <-ctx.Done(): + return + } + } +} + +// targetOutboundHandler attempts to automatically maintain the target number of +// outbound connections configured via [Config.TargetOutbound] when initially +// creating the connection manager. +// +// This MUST be run as a goroutine. +func (cm *ConnManager) targetOutboundHandler(ctx context.Context) { + log.Trace("Starting target outbound handler") + defer log.Trace("Target outbound handler done") + + // failedAttempts tracks the total number of failed outbound connection + // attempts since the last successful connection. It is primarily used to + // detect network outages in order to impose a retry timeout on achieving + // the target number of outbound connections which prevents runaway failed + // connection attempt churn. + // + // Overflow is not checked since it would be virtually impossible to hit + // anywhere max uint64 in practice and even if it ever happened, the only + // consequence would potentially be a few extra retries before it hit the + // max failures again. + var failedAttempts atomic.Uint64 + for ctx.Err() == nil { - conn, err := listener.Accept() - if err != nil { - // Only log the error if not forcibly shutting down. - if ctx.Err() == nil { - log.Errorf("Can't accept connection: %v", err) + // Pause automatic outbound connections for a retry timeout after too + // many failed connection attempts. The network very likely has become + // temporarily unreachable. + if failedAttempts.Load() >= maxFailedAttempts { + log.Debugf("Max failed connection attempts reached [%d] -- "+ + "pausing connections for %v", maxFailedAttempts, + cm.cfg.RetryDuration) + + select { + case <-time.After(cm.cfg.RetryDuration): + case <-cm.quit: + return + case <-ctx.Done(): + return } + } + + // Wait for a permit to make another outbound connection. + if !cm.activeOutboundsSem.Acquire(ctx) { + return + } + + addr, err := cm.cfg.GetNewAddress() + if err != nil { + failedAttempts.Add(1) + log.Debugf("Failed to get address for outbound connection: %v", err) + cm.activeOutboundsSem.Release() continue } - go cm.cfg.OnAccept(conn) - } - log.Tracef("Listener handler done for %s", listener.Addr()) + go func(addr net.Addr) { + onClose := cm.activeOutboundsSem.Release + conn, err := cm.dial(ctx, addr, ConnTypeOutbound, onClose, nil) + if err != nil { + failedAttempts.Add(1) + return + } + + failedAttempts.Store(0) + if cm.cfg.OnConnection != nil { + go cm.cfg.OnConnection(conn) + } + }(addr) + } } // Run starts the connection manager along with its configured listeners and // begins connecting to the network. It blocks until the provided context is -// cancelled. +// canceled. func (cm *ConnManager) Run(ctx context.Context) { log.Trace("Starting connection manager") + defer log.Trace("Connection manager stopped") // Start all the listeners so long as the caller requested them and provided // a callback to be invoked when connections are accepted. @@ -610,27 +1168,56 @@ func (cm *ConnManager) Run(ctx context.Context) { }(listener) } - // Start enough outbound connections to reach the target number when not - // in manual connect mode. + // Start persistent connections handler which starts individual goroutines + // for each persistent connection already added and any newly added ones + // later. + wg.Add(1) + go func() { + cm.persistentConnsHandler(ctx) + wg.Done() + }() + + // Start outbound connection handler to maintain the target number of + // normal outbound connections when not in manual connect mode. if cm.cfg.GetNewAddress != nil { - curConnReqCount := cm.connReqCount.Load() - for i := curConnReqCount; i < uint64(cm.cfg.TargetOutbound); i++ { - go cm.newConnReq(ctx) - } + wg.Add(1) + go func() { + cm.targetOutboundHandler(ctx) + wg.Done() + }() } - // Stop all the listeners and shutdown the connection manager when the - // context is cancelled. There will not be any listeners if listening is - // disabled. + // Shutdown the connection manager when the context is canceled. <-ctx.Done() close(cm.quit) + + // Stop all the listeners. There will not be any listeners if listening is + // disabled. for _, listener := range listeners { // Ignore the error since this is shutdown and there is no way // to recover anyways. _ = listener.Close() } + + // Shutdown persistent conns, cancel pending conns, and close active conns. + cm.connMtx.Lock() + totalIDs := len(cm.persistent) + len(cm.pending) + len(cm.active) + ids := make(map[uint64]struct{}, totalIDs) + for id := range cm.persistent { + ids[id] = struct{}{} + } + for id := range cm.pending { + ids[id] = struct{}{} + } + for id := range cm.active { + ids[id] = struct{}{} + } + cm.connMtx.Unlock() + for id := range ids { + cm.Remove(id) + } + wg.Wait() - log.Trace("Connection manager stopped") } // New returns a new connection manager with the provided configuration. @@ -648,10 +1235,14 @@ func New(cfg *Config) (*ConnManager, error) { cfg.TargetOutbound = defaultTargetOutbound } cm := ConnManager{ - cfg: *cfg, // Copy so caller can't mutate - quit: make(chan struct{}), - pending: make(map[uint64]*ConnReq), - conns: make(map[uint64]*ConnReq, cfg.TargetOutbound), + cfg: *cfg, // Copy so caller can't mutate + quit: make(chan struct{}), + runPersistentChan: make(chan *persistentEntry, MaxPersistent), + activeOutboundsSem: makeSemaphore(cfg.TargetOutbound), + persistent: make(map[uint64]*persistentEntry, MaxPersistent), + pending: make(map[uint64]*pendingConnInfo), + active: make(map[uint64]*Conn, cfg.TargetOutbound), + connIDByAddr: make(map[string]uint64), } return &cm, nil } diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 65900d583..6001ec75a 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "sync/atomic" "testing" @@ -22,6 +23,28 @@ func init() { maxRetryDuration = 2 * time.Millisecond } +const ( + // connTestReceiveTimeout is the default receive timeout used throughout the + // tests when expecting to receive connections to prevent test hangs. + connTestReceiveTimeout = 10 * time.Millisecond + + // connTestNonReceiveTimeout is the default timeout used throughout the + // tests when expecting that a connection will NOT be received. + connTestNonReceiveTimeout = 20 * time.Millisecond +) + +// mustParseAddrPort parses the provided address into a [*net.TCPAddr] and will +// panic if there is an error. It will only (and must only) be called with +// hard-coded, and therefore known good, addresses. +func mustParseAddrPort(addr string) *net.TCPAddr { + addrPort := netip.MustParseAddrPort(addr) + return &net.TCPAddr{ + IP: addrPort.Addr().AsSlice(), + Port: int(addrPort.Port()), + Zone: addrPort.Addr().Zone(), + } +} + // runConnMgrAsync invokes the Run method on the passed connection manager in a // separate goroutine and returns a cancelable context and wait group the caller // can use to shutdown the connection manager and wait for clean shutdown. @@ -100,38 +123,124 @@ func TestNewConfig(t *testing.T) { } } -// assertConnReqID ensures the provided connection request has the given ID. -func assertConnReqID(t *testing.T, connReq *ConnReq, wantID uint64) { +// assertConnID ensures the provided connection has the given ID. +func assertConnID(t *testing.T, conn *Conn, wantID uint64) { t.Helper() - gotID := connReq.ID() + gotID := conn.ID() if gotID != wantID { t.Fatalf("unexpected ID -- got %v, want %v", gotID, wantID) } } -// assertConnReqState ensures the provided connection request has the given -// state. -func assertConnReqState(t *testing.T, connReq *ConnReq, wantState ConnState) { +// assertConnType ensures the provided connection has the given type. +func assertConnType(t *testing.T, conn *Conn, wantType ConnectionType) { + t.Helper() + + gotType := conn.Type() + if gotType != wantType { + t.Fatalf("unexpected type -- got %v, want %v", gotType, wantType) + } +} + +// pendingAddrConnID returns the connection ID associated with the pending +// connection attempt for the provided address. The second return value will be +// false if no pending attempt is found. +func pendingAddrConnID(cm *ConnManager, addr net.Addr) (uint64, bool) { + cm.connMtx.Lock() + defer cm.connMtx.Unlock() + addrStr := addr.String() + for _, info := range cm.pending { + if info.addr.String() == addrStr { + return info.id, true + } + } + return 0, false +} + +// assertPendingAddr ensures there is a pending connection with the given +// address. +func assertPendingAddr(t *testing.T, cm *ConnManager, addr net.Addr) { t.Helper() - gotState := connReq.State() - if gotState != wantState { - t.Fatalf("unexpected state -- got %v, want %v", gotState, wantState) + if _, ok := pendingAddrConnID(cm, addr); !ok { + t.Fatalf("connection %s is not pending", addr) } } +// assertRemovedPersistent ensures there are no persistent conns with the +// provided address. +func assertRemovedPersistent(t *testing.T, cm *ConnManager, addr net.Addr) { + t.Helper() + + if _, ok := cm.FindPersistentAddrID(addr); ok { + t.Fatalf("found persistent entry for %s", addr) + } +} + +// assertConnReceivedTimeout ensures a connection with the given type is +// received on the provided channel before the given timeout. When given a +// non-zero connection ID, it asserts the received connection has that ID. +func assertConnReceivedTimeout(t *testing.T, ch <-chan *Conn, timeout time.Duration, connID uint64, connType ConnectionType) *Conn { + t.Helper() + + select { + case conn := <-ch: + if connID != 0 { + assertConnID(t, conn, connID) + } + assertConnType(t, conn, connType) + return conn + case <-time.After(timeout): + t.Fatal("connection not received before timeout") + } + return nil +} + +// assertConnReceived ensures a connection with the given type is received on +// the provided channel before the default timeout. When given a non-zero +// connection ID, it asserts the received connection has that ID. +func assertConnReceived(t *testing.T, ch <-chan *Conn, connID uint64, connType ConnectionType) *Conn { + t.Helper() + + return assertConnReceivedTimeout(t, ch, connTestReceiveTimeout, connID, + connType) +} + +// assertNoConnReceivedTimeout ensures no connections are received on the +// provided channel before the given timeout. +func assertNoConnReceivedTimeout(t *testing.T, ch <-chan *Conn, timeout time.Duration) { + t.Helper() + + select { + case conn := <-ch: + conn.Close() + t.Fatalf("got unexpected connection from %v", conn.RemoteAddr()) + case <-time.After(timeout): + // Connection not received as expected. + } +} + +// assertNoConnReceived ensures no connections are received on the provided +// channel before the default timeout. +func assertNoConnReceived(t *testing.T, ch <-chan *Conn) { + t.Helper() + + assertNoConnReceivedTimeout(t, ch, connTestNonReceiveTimeout) +} + // TestConnectMode tests that the connection manager works in the connect mode. // -// In connect mode, automatic connections are disabled, so we test that -// requests using Connect are handled and that no other connections are made. +// In connect mode, automatic connections are disabled, so test that connections +// using [ConnManager.Connect] are handled and that no other connections are +// made. func TestConnectMode(t *testing.T) { - connected := make(chan *ConnReq) + connected := make(chan *Conn) cmgr, err := New(&Config{ TargetOutbound: 2, Dial: mockDialer, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c + OnConnection: func(conn *Conn) { + connected <- conn }, }) if err != nil { @@ -139,31 +248,12 @@ func TestConnectMode(t *testing.T) { } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - Permanent: true, - } - go cmgr.Connect(ctx, cr) - - // Ensure that the connection was received. - select { - case gotConnReq := <-connected: - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnEstablished) + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) - case <-time.After(time.Millisecond * 5): - t.Fatalf("connect mode: connection timeout - %v", cr.Addr) - } - - // Ensure only a single connection was made. - select { - case c := <-connected: - t.Fatalf("connect mode: got unexpected connection - %v", c.Addr) - case <-time.After(time.Millisecond * 5): - } + // Ensure that only a single connection is received. + assertConnReceived(t, connected, 0, ConnTypeManual) + assertNoConnReceived(t, connected) // Ensure clean shutdown of connection manager. shutdown() @@ -175,27 +265,17 @@ func TestConnectMode(t *testing.T) { // ensuring they are the only connections made. func TestTargetOutbound(t *testing.T) { const targetOutbound = 10 - var numConnections atomic.Uint32 - hitTargetConns := make(chan struct{}) - extraConns := make(chan *ConnReq) + var nextAddr atomic.Uint32 + connected := make(chan *Conn) cmgr, err := New(&Config{ TargetOutbound: targetOutbound, Dial: mockDialer, GetNewAddress: func() (net.Addr, error) { - return &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, nil + addrStr := fmt.Sprintf("127.0.0.%d:18555", nextAddr.Add(1)) + return mustParseAddrPort(addrStr), nil }, - OnConnection: func(c *ConnReq, conn net.Conn) { - totalConnections := numConnections.Add(1) - if totalConnections == targetOutbound { - close(hitTargetConns) - return - } - if totalConnections > targetOutbound { - extraConns <- c - } + OnConnection: func(conn *Conn) { + connected <- conn }, }) if err != nil { @@ -203,74 +283,61 @@ func TestTargetOutbound(t *testing.T) { } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - // Wait for the expected number of target outbound conns to be established. - select { - case <-hitTargetConns: - case <-time.After(20 * time.Millisecond): - t.Fatal("did not reach target number of conns before timeout") - } - - // Ensure no additional connections are made. - select { - case c := <-extraConns: - t.Fatalf("target outbound: got unexpected connection - %v", c.Addr) - case <-time.After(time.Millisecond * 5): - break + // Ensure only the expected number of target outbound conns are established + // and no more. + for range targetOutbound { + assertConnReceived(t, connected, 0, ConnTypeOutbound) } + assertNoConnReceived(t, connected) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() } -// TestRetryPermanent tests that permanent connection requests are retried. -// -// We make a permanent connection request using Connect, disconnect it using -// Disconnect and we wait for it to be connected back. -func TestRetryPermanent(t *testing.T) { - connected := make(chan *ConnReq) - disconnected := make(chan *ConnReq) +// TestRetryPersistent tests that persistent connections are retried. +func TestRetryPersistent(t *testing.T) { + connected := make(chan *Conn) + disconnected := make(chan *Conn) cmgr, err := New(&Config{ RetryDuration: time.Millisecond, TargetOutbound: 1, Dial: mockDialer, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c + OnConnection: func(conn *Conn) { + connected <- conn }, - OnDisconnection: func(c *ConnReq) { - disconnected <- c + OnDisconnection: func(conn *Conn) { + disconnected <- conn }, }) if err != nil { t.Fatalf("New error: %v", err) } - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - Permanent: true, + addr := mustParseAddrPort("127.0.0.1:18555") + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + if !cmgr.IsPersistent(connID) { + t.Fatal("IsPersistent did not reported true for persistent conn") } - go cmgr.Connect(ctx, cr) - gotConnReq := <-connected - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnEstablished) - - cmgr.Disconnect(cr.ID()) - gotConnReq = <-disconnected - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnPending) - gotConnReq = <-connected - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnEstablished) + // Wait for the first connection, close it, wait for the disconnect, and + // ensure the retry succeeds. + conn := assertConnReceived(t, connected, connID, ConnTypeManual) + conn.Close() + assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertConnReceived(t, connected, connID, ConnTypeManual) - cmgr.Remove(cr.ID()) - gotConnReq = <-disconnected - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnDisconnected) + // Remove the persistent connection, wait for it to disconnect, and ensure + // it is actually removed. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("failed to remove persistent connection: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertRemovedPersistent(t, cmgr, addr) // Ensure clean shutdown of connection manager. shutdown() @@ -290,9 +357,6 @@ func TestMaxRetryDuration(t *testing.T) { } networkUp := make(chan struct{}) - time.AfterFunc(5*time.Millisecond, func() { - close(networkUp) - }) timedDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { select { case <-networkUp: @@ -302,36 +366,34 @@ func TestMaxRetryDuration(t *testing.T) { } } - connected := make(chan *ConnReq) + connected := make(chan *Conn) cmgr, err := New(&Config{ RetryDuration: time.Millisecond, TargetOutbound: 1, Dial: timedDialer, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c + OnConnection: func(conn *Conn) { + connected <- conn }, }) if err != nil { t.Fatalf("New error: %v", err) } - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - Permanent: true, + connID, err := cmgr.AddPersistent(mustParseAddrPort("127.0.0.1:18555")) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) } - go cmgr.Connect(ctx, cr) + // retry in 1ms // retry in 2ms - max retry duration reached - // retry in 2ms - timedDialer returns mockDial - select { - case <-connected: - case <-time.After(200 * time.Millisecond): - t.Fatal("max retry duration: connection timeout") - } + // retry in 2ms - timedDialer returns [mockDialer] + const networkUpTimeout = 5 * time.Millisecond + time.AfterFunc(networkUpTimeout, func() { + close(networkUp) + }) + const timeout = connTestReceiveTimeout + networkUpTimeout + assertConnReceivedTimeout(t, connected, timeout, connID, ConnTypeManual) // Ensure clean shutdown of connection manager. shutdown() @@ -349,24 +411,24 @@ func TestNetworkFailure(t *testing.T) { connMgrDone := make(chan struct{}) errDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { totalDials := dials.Add(1) - if totalDials >= maxFailedAttempts { + if totalDials > maxFailedAttempts { closeOnce.Do(func() { close(reachedMaxFailedAttempts) }) <-connMgrDone } return nil, errors.New("network down") } + var nextAddr atomic.Uint32 cmgr, err := New(&Config{ TargetOutbound: targetOutbound, RetryDuration: retryTimeout, Dial: errDialer, GetNewAddress: func() (net.Addr, error) { - return &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, nil + addrStr := fmt.Sprintf("127.0.0.%d:18555", nextAddr.Add(1)) + return mustParseAddrPort(addrStr), nil }, - OnConnection: func(c *ConnReq, conn net.Conn) { - t.Fatalf("network failure: got unexpected connection - %v", c.Addr) + OnConnection: func(conn *Conn) { + t.Fatalf("network failure: got unexpected connection - %v", + conn.RemoteAddr()) }, }) if err != nil { @@ -377,7 +439,11 @@ func TestNetworkFailure(t *testing.T) { // Shutdown the connection manager after the max failed attempts is reached // and an additional retry duration has passed and then wait for the // shutdown to complete. - <-reachedMaxFailedAttempts + select { + case <-reachedMaxFailedAttempts: + case <-time.After(retryTimeout * maxFailedAttempts * 3): + t.Fatal("did not reach target number of failed attempts before timeout") + } time.Sleep(retryTimeout) shutdown() close(connMgrDone) @@ -396,7 +462,7 @@ func TestNetworkFailure(t *testing.T) { // TestMultipleFailedConns ensures that the connection manager remains // responsive when there are multiple simultaneous failed connections for -// persistent peers in the retry state. +// persistent conns in the retry state. func TestMultipleFailedConns(t *testing.T) { // Override the max retry duration for this test since it relies on having // multiple connections in the retry state. @@ -424,18 +490,15 @@ func TestMultipleFailedConns(t *testing.T) { if err != nil { t.Fatalf("New error: %v", err) } - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish several connection requests to localhost IPs. - for i := 0; i < targetFailed; i++ { - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP(fmt.Sprintf("127.0.0.%d", i+1)), - Port: 18555, - }, - Permanent: true, + for i := range targetFailed { + addr := mustParseAddrPort(fmt.Sprintf("127.0.0.%d:18555", i+1)) + _, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("unexpected add err: %v", err) } - go cmgr.Connect(ctx, cr) } // Wait for the target number of dials and ensure they happen simultaneously @@ -467,10 +530,6 @@ func TestMultipleFailedConns(t *testing.T) { // TestShutdownFailedConns tests that failed connections are ignored after // connmgr is shutdown. -// -// We have a dialer which sets the stop flag on the conn manager and returns an -// err so that the handler assumes that the conn manager is stopped and ignores -// the failure. func TestShutdownFailedConns(t *testing.T) { var closeOnce sync.Once dialed := make(chan struct{}) @@ -495,30 +554,25 @@ func TestShutdownFailedConns(t *testing.T) { shutdown() }() - // Establish a connection request to a localhost IP. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - Permanent: true, - } - go cmgr.Connect(ctx, cr) + // Establish a connection. + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) // Ensure clean shutdown of connection manager. wg.Wait() } -// TestRemovePendingConnection tests that it's possible to cancel a pending -// connection, removing its internal state from the connection manager. +// TestRemovePendingConnection ensures that removing a pending outbound +// connection correctly cancels the context used to dial and removes the +// internal state. func TestRemovePendingConnection(t *testing.T) { - // Create a ConnMgr instance with an instance of a dialer that'll never - // succeed. + // Create a conn manager with an instance of a dialer that'll never succeed. dialed := make(chan struct{}) - wait := make(chan struct{}) + canceled := make(chan struct{}) indefiniteDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { close(dialed) - <-wait + <-ctx.Done() + close(canceled) return nil, errors.New("error") } cmgr, err := New(&Config{ @@ -530,121 +584,118 @@ func TestRemovePendingConnection(t *testing.T) { ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a connection request to a localhost IP. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - Permanent: true, - } - go cmgr.Connect(ctx, cr) + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) - // Wait for the connection manager to attempt to dial the connection request - // and ensure the connection is marked as pending while the dialer is - // blocked. + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. select { case <-dialed: case <-time.After(time.Millisecond * 20): t.Fatal("timeout waiting for dial") } - assertConnReqState(t, cr, ConnPending) + assertPendingAddr(t, cmgr, addr) - // The request launched above will never be able to establish a connection, - // so cancel it _before_ it's able to be completed. - cmgr.Remove(cr.ID()) + // Cancel the connection attempt while it's still pending. + connID, _ := pendingAddrConnID(cmgr, addr) + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected remove err: %v", err) + } - // Ensure the connection request is now marked as canceled after a short - // timeout to allow the transition to occur. - time.Sleep(10 * time.Millisecond) - assertConnReqState(t, cr, ConnCanceled) + // Wait for the dialer to signal the context associated with the dial was + // canceled and ensure the internal pending state is removed. + select { + case <-canceled: + case <-time.After(time.Millisecond * 20): + t.Fatal("timeout waiting for cancel") + } + if _, ok := pendingAddrConnID(cmgr, addr); ok { + t.Fatalf("connection %s is still pending", addr) + } // Ensure clean shutdown of connection manager. - close(wait) shutdown() wg.Wait() } -// TestCancelIgnoreDelayedConnection tests that a canceled connection request -// will not execute the on connection callback, even if an outstanding retry -// succeeds. +// TestCancelIgnoreDelayedConnection tests that a canceled pending persistent +// connection will not execute the on connection callback, even if a pending +// retry succeeds. func TestCancelIgnoreDelayedConnection(t *testing.T) { const retryTimeout = 10 * time.Millisecond - // Setup a dialer that will continue to return an error until the - // connect chan is signaled. The dial attempt immediately after that - // will succeed in returning a connection. + // Setup a dialer that returns an error on the first attempt and then blocks + // until the connect chan is signaled. The dial attempt immediately after + // that will succeed in returning a connection. + var numAttempts atomic.Uint32 connect := make(chan struct{}) + retried := make(chan struct{}) failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { - select { - case <-connect: - return mockDialer(ctx, network, addr) - default: + if numAttempts.Add(1) == 1 { + return nil, errors.New("network down") } - return nil, errors.New("error") + close(retried) + <-connect + + // Override the context to ensure the pending dial succeeds even though + // the passed context will be canceled. + ctx = context.Background() + return mockDialer(ctx, network, addr) } - connected := make(chan *ConnReq) + connected := make(chan *Conn) cmgr, err := New(&Config{ Dial: failingDialer, RetryDuration: retryTimeout, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c + OnConnection: func(conn *Conn) { + connected <- conn }, }) if err != nil { t.Fatalf("New error: %v", err) } - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - // Establish a connection request to a localhost IP. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, + // Establish a persistent connection to a localhost IP. + addr := mustParseAddrPort("127.0.0.1:18555") + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("unexpected error: %v", err) } - go cmgr.Connect(ctx, cr) - - // Allow for the first retry timeout to elapse. - time.Sleep(2 * retryTimeout) - // Ensure the status of the connection request is marked as failed, even - // after reattempting to connect. - assertConnReqState(t, cr, ConnFailed) + // Wait for the retry and ensure the connection is pending. + select { + case <-retried: + case <-time.After(20 * time.Millisecond): + t.Fatalf("did not get retry before timeout") + } + assertPendingAddr(t, cmgr, addr) - // Remove the connection, and then immediately allow the next connection - // to succeed. - cmgr.Remove(cr.ID()) + // Remove the connection and then immediately allow the next connection to + // succeed. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected remove err: %v", err) + } close(connect) - // Allow the connection manager to process the removal. - time.Sleep(5 * time.Millisecond) - - // Ensure the status of the connection request is canceled. - assertConnReqState(t, cr, ConnCanceled) - // Finally, the connection manager should not signal the OnConnection // callback, since the request was explicitly canceled. Give a generous - // timeout window to ensure the connection manager's linear backoff is - // allowed to properly elapse. - select { - case <-connected: - t.Fatal("on-connect should not be called for canceled req") - case <-time.After(5 * retryTimeout): - } + // timeout window to ensure the connection manager's backoff is allowed to + // properly elapse. + assertNoConnReceivedTimeout(t, connected, 5*retryTimeout) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() } -// TestDialTimeout ensure the Timeout configuration parameter works as intended -// by creating a dialer that blocks for three times the configured dial timeout -// before connecting and ensuring the connection fails as expected. +// TestDialTimeout ensure [Config.Timeout] works as intended by creating a +// dialer that blocks for three times the configured dial timeout before +// connecting and ensuring the connection fails as expected. func TestDialTimeout(t *testing.T) { - // Create a connection manager instance with a dialer that blocks for twice - // the configured dial timeout before connecting. + // Create a connection manager instance with a dialer that blocks for three + // times the configured dial timeout before connecting. const dialTimeout = time.Millisecond * 20 cancelled := make(chan struct{}) timeoutDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -664,35 +715,27 @@ func TestDialTimeout(t *testing.T) { if err != nil { t.Fatalf("New error: %v", err) } - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - // Establish a connection request to a localhost IP. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - } - go cmgr.Connect(context.Background(), cr) + // Establish a connection to a localhost IP. + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) // Wait to receive the signal that the dialer context was cancelled, which - // means the dial timeout was hit, and ensure the connection request is - // marked as failed after a short timeout to allow the transition to occur. + // means the dial timeout was hit. select { case <-cancelled: case <-time.After(dialTimeout * 10): t.Fatal("timeout waiting for dial cancellation") } - time.Sleep(10 * time.Millisecond) - assertConnReqState(t, cr, ConnFailed) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() } -// TestConnectContext ensures the Connect method works as intended when provided -// with a context that times out before a dial attempt succeeds. +// TestConnectContext ensures the [ConnManager.Connect] method works as intended +// when provided with a context that is canceled before a dial attempt succeeds. func TestConnectContext(t *testing.T) { // Create a connection manager instance with a dialer that blocks until its // provided context is canceled. @@ -708,18 +751,17 @@ func TestConnectContext(t *testing.T) { if err != nil { t.Fatalf("New error: %v", err) } - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a connection request to a localhost IP with a separate context // that can be canceled. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - } - connectCtx, cancelConnect := context.WithCancel(context.Background()) - go cmgr.Connect(connectCtx, cr) + addr := mustParseAddrPort("127.0.0.1:18555") + connectCtx, cancelConnect := context.WithCancel(ctx) + connectErr := make(chan error, 1) + go func() { + _, err := cmgr.Connect(connectCtx, addr) + connectErr <- err + }() // Wait for the connection manager to attempt to dial the connection request // and ensure the connection is marked as pending while the dialer is @@ -729,119 +771,19 @@ func TestConnectContext(t *testing.T) { case <-time.After(time.Millisecond * 20): t.Fatal("timeout waiting for dial") } - assertConnReqState(t, cr, ConnPending) + assertPendingAddr(t, cmgr, addr) - // Cancel the connection context and ensure the connection request is marked - // as failed after a short timeout to allow the transition to occur. + // Cancel the connection context, wait for the error from connect, and + // ensure it is the expected error. cancelConnect() - time.Sleep(10 * time.Millisecond) - assertConnReqState(t, cr, ConnFailed) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() -} - -// TestForEachConnReq tests the connection request iteration logic work as -// expected including for normal, permanent, and pending connections. -func TestForEachConnReq(t *testing.T) { - // Create a connection manager instance with a dialer that recognizes a - // special address to delay on in order to keep it pending. - targetOutbound := uint32(5) - connected := make(chan *ConnReq) - pending := make(chan struct{}) - delayDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { - if addr == "127.0.0.1:18557" { - close(pending) - time.Sleep(time.Second) - return nil, errors.New("error") - } - return mockDialer(ctx, network, addr) - } - cmgr, err := New(&Config{ - TargetOutbound: targetOutbound, - Dial: delayDialer, - GetNewAddress: func() (net.Addr, error) { - return &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, nil - }, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c - }, - }) - if err != nil { - t.Fatalf("New error: %v", err) - } - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - - // Wait for the expected number of target outbound conns to be established. - allConnected := make(chan struct{}) - go func() { - for i := uint32(0); i < targetOutbound; i++ { - <-connected - } - close(allConnected) - }() - select { - case <-allConnected: - case <-time.After(time.Millisecond * 5 * time.Duration(targetOutbound)): - t.Fatal("timeout waiting for connections") - } - - // Create a permanent connection. - cr := &ConnReq{ - Permanent: true, - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18556, - }, - } - go cmgr.Connect(context.Background(), cr) - select { - case <-connected: - case <-time.After(time.Millisecond * 5): - t.Fatal("timeout waiting for permanent connection") - } - - // Create a connection that triggers the mock dialer to keep it pending. - cr = &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18557, - }, - } - go cmgr.Connect(context.Background(), cr) select { - case <-pending: - case <-time.After(time.Millisecond * 5): - t.Fatal("timeout waiting for pending connection") - } - - // Ensure the expected number of each type of connection exists. - var numConnected, numPermanent, numPending uint32 - _ = cmgr.ForEachConnReq(func(cr *ConnReq) error { - numConnected++ - if cr.State() == ConnPending { - numPending++ - } - if cr.Permanent { - numPermanent++ + case err := <-connectErr: + if !errors.Is(err, context.Canceled) { + t.Fatalf("unexpected connect err: got %v, want %v", err, + context.Canceled) } - return nil - }) - if numConnected != targetOutbound+2 { - t.Fatalf("unexpected number of iterated conn reqs -- got %d, want %d", - numConnected, targetOutbound+2) - } - if numPermanent != 1 { - t.Fatalf("unexpected number of permanent conn reqs -- got %d, want %d", - numPermanent, 1) - } - if numPending != 1 { - t.Fatalf("unexpected number of pending conn reqs -- got %d, want %d", - numPending, 1) + case <-time.After(10 * time.Millisecond): + t.Fatal("timeout waiting for dial cancellation") } // Ensure clean shutdown of connection manager. @@ -888,14 +830,11 @@ func (m *mockListener) Addr() net.Addr { // address. It will cause the Accept function to return a mock connection // configured with the provided remote address and the local address for the // mock listener. -func (m *mockListener) Connect(ip string, port int) { +func (m *mockListener) Connect(addr net.Addr) { m.provideConn <- &mockConn{ laddr: m.localAddr, lnet: "tcp", - rAddr: &net.TCPAddr{ - IP: net.ParseIP(ip), - Port: port, - }, + rAddr: addr, } } @@ -913,13 +852,13 @@ func newMockListener(localAddr string) *mockListener { func TestListeners(t *testing.T) { // Setup a connection manager with a couple of mock listeners that // notify a channel when they receive mock connections. - receivedConns := make(chan net.Conn) - listener1 := newMockListener("127.0.0.1:8333") - listener2 := newMockListener("127.0.0.1:9333") + receivedConns := make(chan *Conn) + listener1 := newMockListener("127.0.0.1:9108") + listener2 := newMockListener("127.0.0.1:9208") listeners := []net.Listener{listener1, listener2} cmgr, err := New(&Config{ Listeners: listeners, - OnAccept: func(conn net.Conn) { + OnAccept: func(conn *Conn) { receivedConns <- conn }, Dial: mockDialer, @@ -933,29 +872,15 @@ func TestListeners(t *testing.T) { go func() { for i, listener := range listeners { l := listener.(*mockListener) - l.Connect("127.0.0.1", 10000+i*2) - l.Connect("127.0.0.1", 10000+i*2+1) + l.Connect(mustParseAddrPort(fmt.Sprintf("127.0.0.1:%d", 10000+i*2))) + l.Connect(mustParseAddrPort(fmt.Sprintf("127.0.0.1:%d", 10000+i*2+1))) } }() - // Tally the receive connections to ensure the expected number are - // received. Also, fail the test after a timeout so it will not hang - // forever should the test not work. + // Ensure the expected number of inbound connections are received. expectedNumConns := len(listeners) * 2 - var numConns int -out: - for { - select { - case <-receivedConns: - numConns++ - if numConns == expectedNumConns { - break out - } - - case <-time.After(time.Millisecond * 50): - t.Fatalf("Timeout waiting for %d expected connections", - expectedNumConns) - } + for range expectedNumConns { + assertConnReceived(t, receivedConns, 0, ConnTypeInbound) } // Ensure clean shutdown of connection manager. diff --git a/internal/connmgr/conntype_test.go b/internal/connmgr/conntype_test.go new file mode 100644 index 000000000..f4d3f69e1 --- /dev/null +++ b/internal/connmgr/conntype_test.go @@ -0,0 +1,35 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import ( + "testing" +) + +// TestConnectionTypeStringer tests the stringized output for connection types. +func TestConnectionTypeStringer(t *testing.T) { + tests := []struct { + in ConnectionType + want string + }{ + {ConnTypeInbound, "inbound"}, + {ConnTypeOutbound, "outbound"}, + {ConnTypeManual, "manual"}, + {0xff, "Unknown ConnectionType (255)"}, + } + + // Detect additional defines that don't have the stringer added. + if len(tests)-1 != int(numConnTypes) { + t.Fatal("It appears a connection type was added without adding an " + + "associated stringer test") + } + + for i, test := range tests { + if got := test.in.String(); got != test.want { + t.Errorf("String #%d: got: %s, want: %s", i, got, test.want) + continue + } + } +} diff --git a/internal/connmgr/error.go b/internal/connmgr/error.go index 932a13f28..9c87dad6e 100644 --- a/internal/connmgr/error.go +++ b/internal/connmgr/error.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 The Decred developers +// Copyright (c) 2020-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -14,10 +14,34 @@ const ( // the configuration. ErrDialNil = ErrorKind("ErrDialNil") + // ErrAlreadyPending indicates an attempt to connect to an address that + // already has a pending connection attempt. + ErrAlreadyPending = ErrorKind("ErrAlreadyPending") + + // ErrAlreadyConnected indicates an attempt to connect to an address that + // already has an established connection. + ErrAlreadyConnected = ErrorKind("ErrAlreadyConnected") + + // ErrMaxPersistent indicates an attempt to add more than the maximum + // allowed number of persistent connections. + ErrMaxPersistent = ErrorKind("ErrMaxPersistent") + + // ErrDuplicatePersistent indicates an attempt to add a persistent + // connection to an address that already exists. + ErrDuplicatePersistent = ErrorKind("ErrDuplicatePersistent") + // ErrNotFound indicates a specified connection ID or address is unknown to // the connection manager. ErrNotFound = ErrorKind("ErrNotFound") + // ErrUnsupportedAddr indicates an address is either an unsupported type or + // an unrecognized type due to being malformed. + ErrUnsupportedAddr = ErrorKind("ErrUnsupportedAddr") + + // ErrShutdown indicates the connection manager is either in the process of + // shutting down or has already been shutdown. + ErrShutdown = ErrorKind("ErrShutdown") + // ErrTorInvalidAddressResponse indicates an invalid address was // returned by the Tor DNS resolver. ErrTorInvalidAddressResponse = ErrorKind("ErrTorInvalidAddressResponse") diff --git a/internal/connmgr/error_test.go b/internal/connmgr/error_test.go index 1e177b97e..d4e5d2262 100644 --- a/internal/connmgr/error_test.go +++ b/internal/connmgr/error_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 The Decred developers +// Copyright (c) 2020-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -17,7 +17,13 @@ func TestErrorKindStringer(t *testing.T) { want string }{ {ErrDialNil, "ErrDialNil"}, + {ErrAlreadyPending, "ErrAlreadyPending"}, + {ErrAlreadyConnected, "ErrAlreadyConnected"}, + {ErrMaxPersistent, "ErrMaxPersistent"}, + {ErrDuplicatePersistent, "ErrDuplicatePersistent"}, {ErrNotFound, "ErrNotFound"}, + {ErrUnsupportedAddr, "ErrUnsupportedAddr"}, + {ErrShutdown, "ErrShutdown"}, {ErrTorInvalidAddressResponse, "ErrTorInvalidAddressResponse"}, {ErrTorInvalidProxyResponse, "ErrTorInvalidProxyResponse"}, {ErrTorUnrecognizedAuthMethod, "ErrTorUnrecognizedAuthMethod"}, diff --git a/internal/rpcserver/interface.go b/internal/rpcserver/interface.go index 97b88fde9..c87134f49 100644 --- a/internal/rpcserver/interface.go +++ b/internal/rpcserver/interface.go @@ -114,7 +114,7 @@ type ConnManager interface { // permanent flag indicates whether or not to make the peer persistent // and reconnect if the connection is lost. Attempting to connect to an // already existing peer will return an error. - Connect(addr string, permanent bool) error + Connect(ctx context.Context, addr string, permanent bool) error // RemoveByID removes the peer associated with the provided id from the // list of persistent peers. Attempting to remove an id that does not diff --git a/internal/rpcserver/rpcserver.go b/internal/rpcserver/rpcserver.go index f4907bdd8..175ec9831 100644 --- a/internal/rpcserver/rpcserver.go +++ b/internal/rpcserver/rpcserver.go @@ -598,7 +598,7 @@ func newWorkState() *workState { } // handleAddNode handles addnode commands. -func handleAddNode(_ context.Context, s *Server, cmd any) (any, error) { +func handleAddNode(ctx context.Context, s *Server, cmd any) (any, error) { c := cmd.(*types.AddNodeCmd) addr := normalizeAddress(c.Addr, s.cfg.ChainParams.DefaultPort) @@ -606,25 +606,47 @@ func handleAddNode(_ context.Context, s *Server, cmd any) (any, error) { var err error switch c.SubCmd { case "add": - err = connMgr.Connect(addr, true) + err = connMgr.Connect(ctx, addr, true) case "remove": err = connMgr.RemoveByAddr(addr) case "onetry": - err = connMgr.Connect(addr, false) + err = connMgr.Connect(ctx, addr, false) default: return nil, rpcInvalidError("Invalid subcommand for addnode") } if err != nil { - return nil, rpcInvalidError("%v: %v", c.SubCmd, err) + switch { + // Connecting involves child contexts, so there is no guarantee that + // context errors returned from Connect are the result of the parent + // context. + // + // Check the parent context first to determine if the failure is the + // result of the RPC server (e.g. RPC connection closed, server + // shutdown, etc). + // + // Otherwise, context errors refer to the actual connection attempt. + case ctx.Err() != nil: + return nil, rpcConnectionClosedError() + + case errors.Is(err, context.Canceled): + return nil, rpcCancelError("%v: connection attempt to %v canceled", + c.SubCmd, addr) + + case errors.Is(err, context.DeadlineExceeded): + return nil, rpcCancelError("%v: timeout connecting to %v", c.SubCmd, + addr) + } + + prefix := fmt.Sprintf("%v: failed operation on %v", c.SubCmd, addr) + return nil, rpcInternalErr(err, prefix) } - // no data returned unless an error. return nil, nil } // handleNode handles node commands. -func handleNode(_ context.Context, s *Server, cmd any) (any, error) { +func handleNode(ctx context.Context, s *Server, cmd any) (any, error) { c := cmd.(*types.NodeCmd) connMgr := s.cfg.ConnMgr @@ -646,13 +668,16 @@ func handleNode(_ context.Context, s *Server, cmd any) (any, error) { addr = normalizeAddress(c.Target, params.DefaultPort) err = connMgr.DisconnectByAddr(addr) } else { - return nil, rpcInvalidError("%v: Invalid "+ - "address or node ID", c.SubCmd) + return nil, rpcInvalidError("%v: invalid address or node ID", + c.SubCmd) } } if err != nil && peerExists(connMgr, addr, int32(nodeID)) { - return nil, rpcMiscError("can't disconnect a permanent peer, " + - "use remove") + return nil, rpcMiscError("can't disconnect a permanent peer, use " + + "remove") + } + if err != nil { + return nil, rpcInvalidError("%v: %v", c.SubCmd, err) } case "remove": @@ -667,13 +692,16 @@ func handleNode(_ context.Context, s *Server, cmd any) (any, error) { addr = normalizeAddress(c.Target, params.DefaultPort) err = connMgr.RemoveByAddr(addr) } else { - return nil, rpcInvalidError("%v: invalid "+ - "address or node ID", c.SubCmd) + return nil, rpcInvalidError("%v: invalid address or node ID", + c.SubCmd) } } if err != nil && peerExists(connMgr, addr, int32(nodeID)) { - return nil, rpcMiscError("can't remove a temporary peer, " + - "use disconnect") + return nil, rpcMiscError("can't remove a temporary peer, use " + + "disconnect") + } + if err != nil { + return nil, rpcInvalidError("%v: %v", c.SubCmd, err) } case "connect": @@ -687,20 +715,42 @@ func handleNode(_ context.Context, s *Server, cmd any) (any, error) { switch subCmd { case "perm", "temp": - err = connMgr.Connect(addr, subCmd == "perm") + err = connMgr.Connect(ctx, addr, subCmd == "perm") default: - return nil, rpcInvalidError("%v: invalid subcommand "+ - "for node connect", subCmd) + return nil, rpcInvalidError("%v: invalid subcommand for node "+ + "connect", subCmd) } + if err != nil { + // Connecting involves child contexts, so there is no guarantee that + // context errors returned from Connect are the result of the parent + // context. + // + // Check the parent context first to determine if the failure is the + // result of the RPC server (e.g. RPC connection closed, server + // shutdown, etc). + // + // Otherwise, context errors refer to the actual connection attempt. + switch { + case ctx.Err() != nil: + return nil, rpcConnectionClosedError() + + case errors.Is(err, context.Canceled): + return nil, rpcCancelError("%v: connection attempt to %v "+ + "canceled", c.SubCmd, addr) + + case errors.Is(err, context.DeadlineExceeded): + return nil, rpcCancelError("%v: timeout connecting to %v", + c.SubCmd, addr) + } + + prefix := fmt.Sprintf("%v: failed operation on %v", c.SubCmd, addr) + return nil, rpcInternalErr(err, prefix) + } + default: return nil, rpcInvalidError("%v: invalid subcommand for node", c.SubCmd) } - if err != nil { - return nil, rpcInvalidError("%v: %v", c.SubCmd, err) - } - - // no data returned unless an error. return nil, nil } diff --git a/internal/rpcserver/rpcserverhandlers_test.go b/internal/rpcserver/rpcserverhandlers_test.go index ac6d8fd3b..25d1a070c 100644 --- a/internal/rpcserver/rpcserverhandlers_test.go +++ b/internal/rpcserver/rpcserverhandlers_test.go @@ -856,7 +856,7 @@ type testConnManager struct { // Connect provides a mock implementation for adding the provided address as a // new outbound peer. -func (c *testConnManager) Connect(addr string, permanent bool) error { +func (c *testConnManager) Connect(ctx context.Context, addr string, permanent bool) error { return c.connectErr } @@ -1919,7 +1919,7 @@ func TestHandleAddNode(t *testing.T) { return connManager }(), wantErr: true, - errCode: dcrjson.ErrRPCInvalidParameter, + errCode: dcrjson.ErrRPCInternal.Code, }, { name: "handleAddNode: 'remove' subcommand error", handler: handleAddNode, @@ -1933,7 +1933,7 @@ func TestHandleAddNode(t *testing.T) { return connManager }(), wantErr: true, - errCode: dcrjson.ErrRPCInvalidParameter, + errCode: dcrjson.ErrRPCInternal.Code, }, { name: "handleAddNode: 'onetry' subcommand error", handler: handleAddNode, @@ -1947,7 +1947,7 @@ func TestHandleAddNode(t *testing.T) { return connManager }(), wantErr: true, - errCode: dcrjson.ErrRPCInvalidParameter, + errCode: dcrjson.ErrRPCInternal.Code, }, { name: "handleAddNode: invalid subcommand", handler: handleAddNode, diff --git a/rpcadaptors.go b/rpcadaptors.go index 8dd072a94..a6826e047 100644 --- a/rpcadaptors.go +++ b/rpcadaptors.go @@ -15,7 +15,6 @@ import ( "github.com/decred/dcrd/chaincfg/v3" "github.com/decred/dcrd/dcrutil/v4" "github.com/decred/dcrd/internal/blockchain" - "github.com/decred/dcrd/internal/connmgr" "github.com/decred/dcrd/internal/mempool" "github.com/decred/dcrd/internal/mining" "github.com/decred/dcrd/internal/mining/cpuminer" @@ -125,29 +124,7 @@ var _ rpcserver.ConnManager = (*rpcConnManager)(nil) // // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. -func (cm *rpcConnManager) Connect(addr string, permanent bool) error { - // Prevent duplicate connections to the same peer. - connManager := cm.server.connManager - err := connManager.ForEachConnReq(func(c *connmgr.ConnReq) error { - if c.Addr != nil && c.Addr.String() == addr { - if c.Permanent { - return errors.New("peer exists as a permanent peer") - } - - switch c.State() { - case connmgr.ConnPending: - return errors.New("peer pending connection") - case connmgr.ConnEstablished: - return errors.New("peer already connected") - - } - } - return nil - }) - if err != nil { - return err - } - +func (cm *rpcConnManager) Connect(ctx context.Context, addr string, permanent bool) error { netAddr, err := addrStringToNetAddr(addr) if err != nil { return err @@ -161,40 +138,44 @@ func (cm *rpcConnManager) Connect(addr string, permanent bool) error { return errors.New("max peers reached") } - go connManager.Connect(context.Background(), &connmgr.ConnReq{ - Addr: netAddr, - Permanent: permanent, - }) - return nil + // Attempt to add a persistent peer when requested. + connManager := cm.server.connManager + if permanent { + _, err := connManager.AddPersistent(netAddr) + return err + } + + // Attempt to connect to the address. + _, err = connManager.Connect(ctx, netAddr) + return err } -// removeNode removes any peers that the provided compare function return true +// errPeerNotFound is returned by the RPC conn manager when no matching peer for +// a given address or ID is found. +var errPeerNotFound = errors.New("peer not found") + +// removeNode removes any peer that the provided compare function return true // for from the list of persistent peers. // -// An error will be returned if no matching peers are found (aka the compare +// An error will be returned if no matching peer is found (aka the compare // function returns false for all peers). func (cm *rpcConnManager) removeNode(cmp func(*serverPeer) bool) error { state := &cm.server.peerState + var found *serverPeer state.Lock() - found := disconnectPeer(state.persistentPeers, cmp, func(sp *serverPeer) { - // Update the group counts since the peer will be removed from the - // persistent peers just after this func returns. - state.outboundGroups[sp.remoteAddr.GroupKey()]-- - - connReq := sp.connReq.Load() - peerLog.Debugf("Removing persistent peer %s (reqid %d)", sp.remoteAddr, - connReq.ID()) - - // Mark the peer's connReq as nil to prevent it from scheduling a - // re-connect attempt. - sp.connReq.Store(nil) - cm.server.connManager.Remove(connReq.ID()) - }) + for _, peer := range state.persistentPeers { + if cmp(peer) { + found = peer + break + } + } state.Unlock() - - if !found { - return errors.New("peer not found") + if found == nil { + return errPeerNotFound } + + peerLog.Debugf("Removing persistent peer %s", found.remoteAddr) + cm.server.connManager.Remove(found.conn.ID()) return nil } @@ -203,10 +184,20 @@ func (cm *rpcConnManager) removeNode(cmp func(*serverPeer) bool) error { // an error. // // This function is safe for concurrent access and is part of the -// rpcserver.ConnManager interface implementation. +// [rpcserver.ConnManager] interface implementation. func (cm *rpcConnManager) RemoveByID(id int32) error { + // Attempt to remove the peer by ID first. When the ID does not correspond + // to an established persistent peer, fall back to treating the ID as a + // connection ID and remove it when it is for a persistent connection. + connManager := cm.server.connManager cmp := func(sp *serverPeer) bool { return sp.ID() == id } - return cm.removeNode(cmp) + err := cm.removeNode(cmp) + if errors.Is(err, errPeerNotFound) && connManager.IsPersistent(uint64(id)) { + if rErr := connManager.Remove(uint64(id)); rErr == nil { + return nil + } + } + return err } // RemoveByAddr removes the peer associated with the provided address from the @@ -214,57 +205,67 @@ func (cm *rpcConnManager) RemoveByID(id int32) error { // exist will return an error. // // This function is safe for concurrent access and is part of the -// rpcserver.ConnManager interface implementation. +// [rpcserver.ConnManager] interface implementation. func (cm *rpcConnManager) RemoveByAddr(addr string) error { + // Attempt to remove the peer by address first. When the address does not + // correspond to an established persistent peer, fall back to searching the + // connection manager directly for a matching persistent connection entry + // and remove it when found. cmp := func(sp *serverPeer) bool { return sp.Addr() == addr } err := cm.removeNode(cmp) - if err != nil { - netAddr, err := addrStringToNetAddr(addr) - if err != nil { - return err + if errors.Is(err, errPeerNotFound) { + netAddr := simpleAddr{"tcp", addr} + if id, ok := cm.server.connManager.FindPersistentAddrID(netAddr); ok { + cm.server.connManager.Remove(id) + return nil } - return cm.server.connManager.CancelPending(netAddr) } - return nil + return err } -// disconnectNode disconnects any peers that the provided compare function +// disconnectNode disconnects any peer that the provided compare function // returns true for. It applies to both inbound and outbound peers. // -// An error will be returned if no matching peers are found (aka the compare +// An error will be returned if no matching peer is found (aka the compare // function returns false for all peers). // // This function is safe for concurrent access. func (cm *rpcConnManager) disconnectNode(cmp func(sp *serverPeer) bool) error { state := &cm.server.peerState - defer state.Unlock() + state.Lock() + defer state.Unlock() - // Check inbound peers. No callback is passed since there are no additional - // actions on disconnect for inbound peers. - found := disconnectPeer(state.inboundPeers, cmp, nil) - if found { - return nil - } + // The code below uses the fact that the connection manager prevents + // connections with duplicate addresses to limit the search to a single + // match. - // Check outbound peers in a loop to ensure all outbound connections to the - // same ip:port are disconnected when there are multiple. - var numFound uint32 - for ; ; numFound++ { - found = disconnectPeer(state.outboundPeers, cmp, func(sp *serverPeer) { - // Update the group counts since the peer will be removed from the - // persistent peers just after this func returns. - state.outboundGroups[sp.remoteAddr.GroupKey()]-- - }) - if !found { + // Check inbound peers. + var inbound *serverPeer + for _, peer := range state.inboundPeers { + if cmp(peer) { + inbound = peer break } } + if inbound != nil { + inbound.Disconnect() + return nil + } - if numFound == 0 { - return errors.New("peer not found") + // Check outbound peers. + var outbound *serverPeer + for _, peer := range state.outboundPeers { + if cmp(peer) { + outbound = peer + } } - return nil + if outbound != nil { + outbound.Disconnect() + return nil + } + + return errPeerNotFound } // DisconnectByID disconnects the peer associated with the provided id. This diff --git a/server.go b/server.go index 8914090da..f61afaf33 100644 --- a/server.go +++ b/server.go @@ -442,6 +442,7 @@ type serverPeer struct { // The service flags are updated in the address manager directly once the // peer reports them. The service flags on this instance are never used. server *server + conn *connmgr.Conn remoteAddr *addrmgr.NetAddress persistent bool isWhitelisted bool @@ -455,7 +456,6 @@ type serverPeer struct { // otherwise modified during operation and thus need to consider whether or // not they need to be protected for concurrent access. - connReq atomic.Pointer[connmgr.ConnReq] continueHash atomic.Pointer[chainhash.Hash] disableRelayTx atomic.Bool knownAddresses *apbf.Filter @@ -507,11 +507,12 @@ type serverPeer struct { // newServerPeer returns a new serverPeer instance. The peer needs to be set by // the caller. -func newServerPeer(s *server, remoteAddr *addrmgr.NetAddress, isPersistent bool) *serverPeer { +func newServerPeer(s *server, conn *connmgr.Conn, remoteAddr *addrmgr.NetAddress) *serverPeer { return &serverPeer{ server: s, + conn: conn, remoteAddr: remoteAddr, - persistent: isPersistent, + persistent: s.connManager.IsPersistent(conn.ID()), knownAddresses: apbf.NewFilter(maxKnownAddrsPerPeer, knownAddrsFPRate), quit: make(chan struct{}), getDataQueue: make(chan []*wire.InvVect, maxConcurrentGetDataReqs), @@ -2177,57 +2178,6 @@ func (s *server) handleBroadcastMsg(state *peerState, bmsg *broadcastMsg) { }) } -// disconnectPeer attempts to drop the connection of a targeted peer in the -// passed peer list. Targets are identified via usage of the passed -// `compareFunc`, which should return `true` if the passed peer is the target -// peer. This function returns true on success and false if the peer is unable -// to be located. If the peer is found, and the passed callback: `whenFound' -// isn't nil, we call it with the peer as the argument before it is removed -// from the peerList, and is disconnected from the server. -func disconnectPeer(peerList map[int32]*serverPeer, compareFunc func(*serverPeer) bool, whenFound func(*serverPeer)) bool { - for addr, peer := range peerList { - if compareFunc(peer) { - if whenFound != nil { - whenFound(peer) - } - - // This is ok because we are not continuing - // to iterate so won't corrupt the loop. - delete(peerList, addr) - peer.Disconnect() - return true - } - } - return false -} - -// connToNetAddr parses and returns an address manager network address from the -// remote address associated with the given connection. -// -// This function is safe for concurrent access. -func connToNetAddr(conn net.Conn) (*addrmgr.NetAddress, error) { - addrStr := conn.RemoteAddr().String() - host, portStr, err := net.SplitHostPort(addrStr) - if err != nil { - return nil, err - } - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return nil, err - } - - addrType, addrBytes := addrmgr.EncodeHost(host) - if addrType == addrmgr.UnknownAddressType { - return nil, fmt.Errorf("unable to determine address type: %v", addrStr) - } - - // Since the host type has been successfully recognized and encoded, - // there is no need to perform a DNS lookup. - now := time.Unix(time.Now().Unix(), 0) - return addrmgr.NewNetAddressFromParams(addrType, addrBytes, uint16(port), - now, 0) -} - // handleBannedConn closes the provided connection if the remote address // associated with it is banned or the address can't be properly parsed. It // returns true when the connection is closed. @@ -2323,11 +2273,11 @@ func newPeerConfig(sp *serverPeer) *peer.Config { // instance, associates it with the connection, runs the peer (which starts all // additional server peer processing goroutines) and blocks until the peer // disconnects. -func (s *server) inboundPeerConnected(ctx context.Context, conn net.Conn) { - remoteNetAddr, err := connToNetAddr(conn) - if err != nil { - srvrLog.Debugf("Unable to create inbound peer for address %s: %v", - conn.RemoteAddr(), err) +func (s *server) inboundPeerConnected(ctx context.Context, conn *connmgr.Conn) { + remoteNetAddr, ok := conn.RemoteAddr().(*addrmgr.NetAddress) + if !ok { + srvrLog.Warnf("remote address for connection is incorrect type %T", + conn.RemoteAddr()) conn.Close() return } @@ -2337,9 +2287,9 @@ func (s *server) inboundPeerConnected(ctx context.Context, conn net.Conn) { return } - sp := newServerPeer(s, remoteNetAddr, false) - sp.isWhitelisted = isWhitelisted(remoteNetAddr) + sp := newServerPeer(s, conn, remoteNetAddr) sp.Peer = peer.NewInboundPeer(newPeerConfig(sp), conn) + sp.isWhitelisted = isWhitelisted(remoteNetAddr) if err := sp.Handshake(ctx, sp.OnVersion); err != nil { srvrLog.Debugf("Failed handshake for inbound peer %s: %v", remoteNetAddr, err) @@ -2355,31 +2305,28 @@ func (s *server) inboundPeerConnected(ctx context.Context, conn net.Conn) { // peer instance, associates it with the relevant state such as the connection // request instance and the connection itself, and start all additional server // peer processing goroutines. -func (s *server) outboundPeerConnected(ctx context.Context, c *connmgr.ConnReq, conn net.Conn) { - remoteNetAddr, err := connToNetAddr(conn) - if err != nil { - srvrLog.Debugf("Unable to create outbound peer for address %s: %v", - conn.RemoteAddr(), err) +func (s *server) outboundPeerConnected(ctx context.Context, conn *connmgr.Conn) { + remoteNetAddr, ok := conn.RemoteAddr().(*addrmgr.NetAddress) + if !ok { + srvrLog.Warnf("remote address for connection is incorrect type %T", + conn.RemoteAddr()) conn.Close() - s.connManager.Disconnect(c.ID()) + return } // Disconnect banned connections. Ideally we would never connect to a // banned peer, but the connection manager is currently unaware of banned // addresses, so this is needed. if disconnected := s.handleBannedConn(remoteNetAddr, conn); disconnected { - s.connManager.Disconnect(c.ID()) return } - sp := newServerPeer(s, remoteNetAddr, c.Permanent) - p := peer.NewOutboundPeer(newPeerConfig(sp), c.Addr, conn) - sp.Peer = p - sp.connReq.Store(c) + sp := newServerPeer(s, conn, remoteNetAddr) + sp.Peer = peer.NewOutboundPeer(newPeerConfig(sp), conn.RemoteAddr(), conn) sp.isWhitelisted = isWhitelisted(remoteNetAddr) if err := sp.Handshake(ctx, sp.OnVersion); err != nil { - srvrLog.Debugf("Failed handshake for outbound peer %s: %v", c.Addr, err) - s.connManager.Disconnect(c.ID()) + srvrLog.Debugf("Failed handshake for outbound peer %s: %v", + conn.RemoteAddr(), err) return } sp.syncMgrPeer = netsync.NewPeer(sp.Peer) @@ -2818,21 +2765,12 @@ func (s *server) DonePeer(sp *serverPeer) { if _, ok := list[sp.ID()]; ok { if !sp.Inbound() { state.outboundGroups[sp.remoteAddr.GroupKey()]-- - connReq := sp.connReq.Load() - if connReq != nil { - s.connManager.Disconnect(connReq.ID()) - } } delete(list, sp.ID()) srvrLog.Debugf("Removed peer %s", sp) return } - connReq := sp.connReq.Load() - if connReq != nil { - s.connManager.Disconnect(connReq.ID()) - } - // Update the address manager with the last seen time. This is skipped when // running on the simulation and regression test networks since they are // only intended to connect to specified peers and actively avoid @@ -4418,15 +4356,15 @@ func newServer(ctx context.Context, profiler *profileServer, } cmgr, err := connmgr.New(&connmgr.Config{ Listeners: listeners, - OnAccept: func(conn net.Conn) { + OnAccept: func(conn *connmgr.Conn) { s.inboundPeerConnected(ctx, conn) }, RetryDuration: connectionRetryInterval, TargetOutbound: s.targetOutbound, Dial: s.attemptDcrdDial, DialTimeout: cfg.DialTimeout, - OnConnection: func(c *connmgr.ConnReq, conn net.Conn) { - s.outboundPeerConnected(ctx, c, conn) + OnConnection: func(conn *connmgr.Conn) { + s.outboundPeerConnected(ctx, conn) }, GetNewAddress: newAddressFunc, }) @@ -4435,22 +4373,21 @@ func newServer(ctx context.Context, profiler *profileServer, } s.connManager = cmgr - // Start up persistent peers. - permanentPeers := cfg.ConnectPeers - if len(permanentPeers) == 0 { - permanentPeers = cfg.AddPeers + // Add persistent peers. + persistentPeers := cfg.ConnectPeers + if len(persistentPeers) == 0 { + persistentPeers = cfg.AddPeers } - for _, addr := range permanentPeers { + for _, addr := range persistentPeers { tcpAddr, err := addrStringToNetAddr(addr) if err != nil { return nil, err } - go s.connManager.Connect(ctx, - &connmgr.ConnReq{ - Addr: tcpAddr, - Permanent: true, - }) + _, err = s.connManager.AddPersistent(tcpAddr) + if err != nil { + return nil, err + } } if !cfg.DisableRPC { @@ -4636,14 +4573,14 @@ func addrStringToNetAddr(addr string) (net.Addr, error) { return nil, fmt.Errorf("no addresses found for %s", host) } - port, err := strconv.Atoi(strPort) + port, err := strconv.ParseUint(strPort, 10, 16) if err != nil { return nil, err } return &net.TCPAddr{ IP: ips[0], - Port: port, + Port: int(port), }, nil } From 9ad909d60d27e833f5582c641e528aa484e2dd8e Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Wed, 20 May 2026 18:45:53 -0500 Subject: [PATCH 04/24] connmgr: Make max retry duration a field. The max retry duration is currently an unexported global variable that the tests override at init time. At least one of the tests also additionally overrides it for that specified test too. While this works, it is somewhat brittle and prevents the tests from being run in parallel. This improves the situation by making the max retry duration a field on the connection manager instead of a global variable and adding a test helper for creating a new connection manager that overrides it by default. Then any tests that need a different value can simply override it on their local instance. It also makes the tests parallel since they can no longer clobber one another. --- internal/connmgr/connmanager.go | 21 +++-- internal/connmgr/connmanager_test.go | 130 +++++++++++++-------------- 2 files changed, 76 insertions(+), 75 deletions(-) diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 83ddf10c5..e4e75b06d 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -25,14 +25,6 @@ const ( MaxPersistent = 8 ) -var ( - // maxRetryDuration is the maximum duration a persistent connection retry - // backoff is allowed to grow to. This is necessary since the retry logic - // uses a backoff mechanism which increases the interval base times the - // number of retries that have been done. - maxRetryDuration = time.Minute * 5 -) - const ( // maxFailedAttempts is the maximum number of successive failed connection // attempts after which network failure is assumed and new connections will @@ -43,6 +35,12 @@ const ( // persistent connections. defaultRetryDuration = time.Second * 5 + // defaultMaxRetryDuration is the default maximum duration a persistent + // connection retry backoff is allowed to grow to. This is necessary since + // the retry logic uses a backoff mechanism which increases the interval + // base times the number of retries that have been done. + defaultMaxRetryDuration = time.Minute * 5 + // defaultTargetOutbound is the default number of outbound connections to // maintain. defaultTargetOutbound = 8 @@ -274,6 +272,10 @@ type ConnManager struct { // creating time and treated as immutable after that. cfg Config + // maxRetryDuration is the maximum duration a persistent connection retry + // backoff is allowed to grow to. + maxRetryDuration time.Duration + // runPersistentChan is used to signal the persistent connections handler to // launch a goroutine that attempts to always maintain an established // connection with a given address. @@ -1026,7 +1028,7 @@ func (cm *ConnManager) runPersistent(ctx context.Context, connID uint64, addr ne retryCount++ } retryWait := time.Duration(retryCount) * cm.cfg.RetryDuration - retryWait = min(retryWait, maxRetryDuration) + retryWait = min(retryWait, cm.maxRetryDuration) log.Debugf("Retrying connection to %v in %v (retries %d)", addr, retryWait, retryCount) retryAfter = time.After(retryWait) @@ -1237,6 +1239,7 @@ func New(cfg *Config) (*ConnManager, error) { cm := ConnManager{ cfg: *cfg, // Copy so caller can't mutate quit: make(chan struct{}), + maxRetryDuration: defaultMaxRetryDuration, runPersistentChan: make(chan *persistentEntry, MaxPersistent), activeOutboundsSem: makeSemaphore(cfg.TargetOutbound), persistent: make(map[uint64]*persistentEntry, MaxPersistent), diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 6001ec75a..fda415cd2 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -18,12 +18,11 @@ import ( "time" ) -func init() { - // Override the max retry duration when running tests. - maxRetryDuration = 2 * time.Millisecond -} - const ( + // defaultTestMaxRetryDuration is the default max duration a connection + // retry backoff is allowed to grow to when running tests. + defaultTestMaxRetryDuration = 2 * time.Millisecond + // connTestReceiveTimeout is the default receive timeout used throughout the // tests when expecting to receive connections to prevent test hangs. connTestReceiveTimeout = 10 * time.Millisecond @@ -109,18 +108,32 @@ func mockDialer(ctx context.Context, network, addr string) (net.Conn, error) { return c, ctx.Err() } +// newTestConnManager returns a new connection manager with the provided +// configuration and some timeout tweaks so that it is suitable for use in the +// tests. +func newTestConnManager(t *testing.T, cfg *Config) *ConnManager { + t.Helper() + + cmgr, err := New(cfg) + if err != nil { + t.Fatalf("New: unexpected error: %v", err) + } + cmgr.maxRetryDuration = defaultTestMaxRetryDuration + return cmgr +} + // TestNewConfig tests that new ConnManager config is validated as expected. func TestNewConfig(t *testing.T) { + t.Parallel() + _, err := New(&Config{}) if err == nil { t.Fatal("New expected error: 'Dial can't be nil', got nil") } - _, err = New(&Config{ + + newTestConnManager(t, &Config{ Dial: mockDialer, }) - if err != nil { - t.Fatalf("New unexpected error: %v", err) - } } // assertConnID ensures the provided connection has the given ID. @@ -235,17 +248,16 @@ func assertNoConnReceived(t *testing.T, ch <-chan *Conn) { // using [ConnManager.Connect] are handled and that no other connections are // made. func TestConnectMode(t *testing.T) { + t.Parallel() + connected := make(chan *Conn) - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ TargetOutbound: 2, Dial: mockDialer, OnConnection: func(conn *Conn) { connected <- conn }, }) - if err != nil { - t.Fatalf("New error: %v", err) - } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) addr := mustParseAddrPort("127.0.0.1:18555") @@ -264,10 +276,12 @@ func TestConnectMode(t *testing.T) { // configuration option by waiting until all connections are established and // ensuring they are the only connections made. func TestTargetOutbound(t *testing.T) { + t.Parallel() + const targetOutbound = 10 var nextAddr atomic.Uint32 connected := make(chan *Conn) - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ TargetOutbound: targetOutbound, Dial: mockDialer, GetNewAddress: func() (net.Addr, error) { @@ -278,9 +292,6 @@ func TestTargetOutbound(t *testing.T) { connected <- conn }, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Ensure only the expected number of target outbound conns are established @@ -297,9 +308,11 @@ func TestTargetOutbound(t *testing.T) { // TestRetryPersistent tests that persistent connections are retried. func TestRetryPersistent(t *testing.T) { + t.Parallel() + connected := make(chan *Conn) disconnected := make(chan *Conn) - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ RetryDuration: time.Millisecond, TargetOutbound: 1, Dial: mockDialer, @@ -310,9 +323,6 @@ func TestRetryPersistent(t *testing.T) { disconnected <- conn }, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) addr := mustParseAddrPort("127.0.0.1:18555") @@ -349,11 +359,13 @@ func TestRetryPersistent(t *testing.T) { // We have a timed dialer which initially returns err but after RetryDuration // hits maxRetryDuration returns a mock conn. func TestMaxRetryDuration(t *testing.T) { + t.Parallel() + // This test relies on the current value of the max retry duration defined // in the tests, so assert it. - if maxRetryDuration != 2*time.Millisecond { + if defaultTestMaxRetryDuration != 2*time.Millisecond { t.Fatalf("max retry duration of %v is not the required value for test", - maxRetryDuration) + defaultTestMaxRetryDuration) } networkUp := make(chan struct{}) @@ -367,7 +379,7 @@ func TestMaxRetryDuration(t *testing.T) { } connected := make(chan *Conn) - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ RetryDuration: time.Millisecond, TargetOutbound: 1, Dial: timedDialer, @@ -375,9 +387,6 @@ func TestMaxRetryDuration(t *testing.T) { connected <- conn }, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) connID, err := cmgr.AddPersistent(mustParseAddrPort("127.0.0.1:18555")) @@ -403,6 +412,8 @@ func TestMaxRetryDuration(t *testing.T) { // TestNetworkFailure tests that the connection manager handles a network // failure gracefully. func TestNetworkFailure(t *testing.T) { + t.Parallel() + var closeOnce sync.Once const targetOutbound = 5 const retryTimeout = time.Millisecond * 5 @@ -418,7 +429,7 @@ func TestNetworkFailure(t *testing.T) { return nil, errors.New("network down") } var nextAddr atomic.Uint32 - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ TargetOutbound: targetOutbound, RetryDuration: retryTimeout, Dial: errDialer, @@ -431,9 +442,6 @@ func TestNetworkFailure(t *testing.T) { conn.RemoteAddr()) }, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Shutdown the connection manager after the max failed attempts is reached @@ -464,13 +472,11 @@ func TestNetworkFailure(t *testing.T) { // responsive when there are multiple simultaneous failed connections for // persistent conns in the retry state. func TestMultipleFailedConns(t *testing.T) { + t.Parallel() + // Override the max retry duration for this test since it relies on having // multiple connections in the retry state. - curMaxRetryDuration := maxRetryDuration - maxRetryDuration = 500 * time.Millisecond - defer func() { - maxRetryDuration = curMaxRetryDuration - }() + const maxRetryDuration = 500 * time.Millisecond const targetFailed = 5 var dials atomic.Uint32 @@ -483,13 +489,11 @@ func TestMultipleFailedConns(t *testing.T) { } return nil, errors.New("network down") } - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ RetryDuration: maxRetryDuration, Dial: errDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } + cmgr.maxRetryDuration = maxRetryDuration _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish several connection requests to localhost IPs. @@ -531,26 +535,25 @@ func TestMultipleFailedConns(t *testing.T) { // TestShutdownFailedConns tests that failed connections are ignored after // connmgr is shutdown. func TestShutdownFailedConns(t *testing.T) { + t.Parallel() + var closeOnce sync.Once dialed := make(chan struct{}) waitDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { closeOnce.Do(func() { close(dialed) }) return nil, errors.New("network down") } - cmgr, err := New(&Config{ - RetryDuration: maxRetryDuration, + cmgr := newTestConnManager(t, &Config{ + RetryDuration: defaultTestMaxRetryDuration, Dial: waitDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Shutdown the connection manager during the retry timeout after a failed // dial attempt. go func() { <-dialed - time.Sleep(maxRetryDuration / 2) + time.Sleep(cmgr.maxRetryDuration / 2) shutdown() }() @@ -566,6 +569,8 @@ func TestShutdownFailedConns(t *testing.T) { // connection correctly cancels the context used to dial and removes the // internal state. func TestRemovePendingConnection(t *testing.T) { + t.Parallel() + // Create a conn manager with an instance of a dialer that'll never succeed. dialed := make(chan struct{}) canceled := make(chan struct{}) @@ -575,12 +580,9 @@ func TestRemovePendingConnection(t *testing.T) { close(canceled) return nil, errors.New("error") } - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Dial: indefiniteDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a connection request to a localhost IP. @@ -622,6 +624,8 @@ func TestRemovePendingConnection(t *testing.T) { // connection will not execute the on connection callback, even if a pending // retry succeeds. func TestCancelIgnoreDelayedConnection(t *testing.T) { + t.Parallel() + const retryTimeout = 10 * time.Millisecond // Setup a dialer that returns an error on the first attempt and then blocks @@ -645,16 +649,13 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { } connected := make(chan *Conn) - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Dial: failingDialer, RetryDuration: retryTimeout, OnConnection: func(conn *Conn) { connected <- conn }, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a persistent connection to a localhost IP. @@ -694,6 +695,8 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { // dialer that blocks for three times the configured dial timeout before // connecting and ensuring the connection fails as expected. func TestDialTimeout(t *testing.T) { + t.Parallel() + // Create a connection manager instance with a dialer that blocks for three // times the configured dial timeout before connecting. const dialTimeout = time.Millisecond * 20 @@ -708,13 +711,10 @@ func TestDialTimeout(t *testing.T) { return mockDialer(ctx, network, addr) } - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Dial: timeoutDialer, DialTimeout: dialTimeout, }) - if err != nil { - t.Fatalf("New error: %v", err) - } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a connection to a localhost IP. @@ -737,6 +737,8 @@ func TestDialTimeout(t *testing.T) { // TestConnectContext ensures the [ConnManager.Connect] method works as intended // when provided with a context that is canceled before a dial attempt succeeds. func TestConnectContext(t *testing.T) { + t.Parallel() + // Create a connection manager instance with a dialer that blocks until its // provided context is canceled. dialed := make(chan struct{}) @@ -745,12 +747,9 @@ func TestConnectContext(t *testing.T) { <-ctx.Done() return nil, ctx.Err() } - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Dial: indefiniteDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a connection request to a localhost IP with a separate context @@ -850,22 +849,21 @@ func newMockListener(localAddr string) *mockListener { // TestListeners ensures providing listeners to the connection manager along // with an accept callback works properly. func TestListeners(t *testing.T) { + t.Parallel() + // Setup a connection manager with a couple of mock listeners that // notify a channel when they receive mock connections. receivedConns := make(chan *Conn) listener1 := newMockListener("127.0.0.1:9108") listener2 := newMockListener("127.0.0.1:9208") listeners := []net.Listener{listener1, listener2} - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Listeners: listeners, OnAccept: func(conn *Conn) { receivedConns <- conn }, Dial: mockDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Fake a couple of mock connections to each of the listeners. From 61122e4761d0a559a3b833580b669c2e0ef0dee7 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 24 May 2026 17:25:31 -0500 Subject: [PATCH 05/24] connmgr: Correct shutdown failed conns test. This updates the test for checking the connection manager cleanly shuts down with failed conns to actualy test what it is intended to. Manual connections do not automatically retry, only persistent connections. --- internal/connmgr/connmanager_test.go | 29 +++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index fda415cd2..166387a53 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -537,6 +537,7 @@ func TestMultipleFailedConns(t *testing.T) { func TestShutdownFailedConns(t *testing.T) { t.Parallel() + const retryTimeout = time.Second var closeOnce sync.Once dialed := make(chan struct{}) waitDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -544,22 +545,28 @@ func TestShutdownFailedConns(t *testing.T) { return nil, errors.New("network down") } cmgr := newTestConnManager(t, &Config{ - RetryDuration: defaultTestMaxRetryDuration, + RetryDuration: retryTimeout, Dial: waitDialer, }) - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + cmgr.maxRetryDuration = retryTimeout + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Add a persistent connection. + addr := mustParseAddrPort("127.0.0.1:18555") + _, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Shutdown the connection manager during the retry timeout after a failed // dial attempt. - go func() { - <-dialed - time.Sleep(cmgr.maxRetryDuration / 2) - shutdown() - }() - - // Establish a connection. - addr := mustParseAddrPort("127.0.0.1:18555") - go cmgr.Connect(ctx, addr) + select { + case <-dialed: + case <-time.After(connTestNonReceiveTimeout): + t.Fatal("timeout waiting for dial") + } + time.Sleep(connTestNonReceiveTimeout) + shutdown() // Ensure clean shutdown of connection manager. wg.Wait() From 1ef516b1380b7e3defc6557eae33351c46cab4df Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:10 -0500 Subject: [PATCH 06/24] connmgr: Add double close tests. This adds tests to ensure closing a connection multiple times works as intended. --- internal/connmgr/connmanager_test.go | 42 ++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 166387a53..b298f9198 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -306,6 +306,48 @@ func TestTargetOutbound(t *testing.T) { wg.Wait() } +// TestDoubleClose ensures closing a connection multiple times is a noop after +// the first call. +func TestDoubleClose(t *testing.T) { + t.Parallel() + + connected := make(chan *Conn) + cmgr := newTestConnManager(t, &Config{ + TargetOutbound: 1, + Dial: mockDialer, + GetNewAddress: func() (net.Addr, error) { + return mustParseAddrPort("127.0.0.1:18555"), nil + }, + OnConnection: func(conn *Conn) { + connected <- conn + }, + }) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Wait for the connection to be established. + conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) + + // Override the close func to cleanly detect closes. + var numClosed uint32 + origOnClose := conn.onClose + conn.onClose = func() { + numClosed++ + origOnClose() + } + + // Close the connection multiple times and make sure it only happens once. + for range 3 { + conn.Close() + } + if numClosed != 1 { + t.Fatal("connection closed more than once") + } + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + // TestRetryPersistent tests that persistent connections are retried. func TestRetryPersistent(t *testing.T) { t.Parallel() From a3f4690705f37fec1c7dccf0ecd7056bad371a86 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:10 -0500 Subject: [PATCH 07/24] connmgr: Add duplicate conn rejection tests. This adds tests to ensure duplication connections are rejected for all possible states. --- internal/connmgr/connmanager_test.go | 129 +++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index b298f9198..aaa59b273 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -934,3 +934,132 @@ func TestListeners(t *testing.T) { shutdown() wg.Wait() } + +// TestRejectDuplicateConns ensures duplicate addresses are rejected. This +// includes: +// - Attempts to dial addresses that already have pending, established, and +// persistent connections (via [ConnManager.Connect] +// - Attempts to add duplicate persistent conns (via [ConnManager.AddPersistent]) +// - Attempts to receive inbound remote addresses that already have pending, +// established, and persistent connections +func TestRejectDuplicateConns(t *testing.T) { + t.Parallel() + + var closeDialedOnce sync.Once + inboundConns := make(chan *Conn) + listener := newMockListener("127.0.0.1:18109") + connected := make(chan *Conn) + disconnected := make(chan *Conn) + dialed := make(chan struct{}) + pending := make(chan struct{}) + pendingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + closeDialedOnce.Do(func() { close(dialed) }) + <-pending + return mockDialer(ctx, network, addr) + } + cmgr := newTestConnManager(t, &Config{ + Listeners: []net.Listener{listener}, + OnAccept: func(conn *Conn) { + inboundConns <- conn + }, + Dial: pendingDialer, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Dial a manual connection and wait for it to become pending. + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) + select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("did not receive pending dial before timeout") + } + assertPendingAddr(t, cmgr, addr) + + // Duplicate connect to the pending address should be rejected. + if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyPending) { + t.Fatalf("did not reject duplicate pending connection, err: %v", err) + } + + // Inbound attempts from the pending outbound address should be rejected. + go listener.Connect(addr) + assertNoConnReceived(t, inboundConns) + + // Allow the pending connection to complete. + close(pending) + conn := assertConnReceived(t, connected, 0, ConnTypeManual) + + // Duplicate connect to the established address should be rejected. + if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { + t.Fatalf("did not reject duplicate active connection, err: %v", err) + } + + // Inbound attempts from the established outbound address should be + // rejected. + go listener.Connect(addr) + assertNoConnReceived(t, inboundConns) + + // Close the connection and wait for the disconnect. + conn.Close() + assertConnReceived(t, disconnected, conn.ID(), ConnTypeManual) + + // Add a persistent connection back to the same address and wait for it to + // connect since there are no longer any connections to the address. + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + assertConnReceived(t, connected, connID, ConnTypeManual) + + // Duplicate persistent connection attempts should be rejected. + _, err = cmgr.AddPersistent(addr) + if !errors.Is(err, ErrDuplicatePersistent) { + t.Fatalf("did not reject duplicate persistent connection, err: %v", err) + } + + // Manual connection attempts to persistent connection should be rejected. + _, err = cmgr.Connect(ctx, addr) + if !errors.Is(err, ErrDuplicatePersistent) { + t.Fatalf("did not reject manual connection to persistent, err: %v", err) + } + + // Inbound atempts from the persistent address should be rejected. + go listener.Connect(addr) + assertNoConnReceived(t, inboundConns) + + // Remove the persistent connection, wait for it to disconnect, and ensure + // it is actually removed. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("failed to remove persistent connection: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertRemovedPersistent(t, cmgr, addr) + + // Inbound connections from the same address should now succeed. + go listener.Connect(addr) + assertConnReceived(t, inboundConns, 0, ConnTypeInbound) + + // Manual connection attempts to the inbound address should be rejected. + if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { + t.Fatalf("did not reject outbound for existing inbound conn, err: %v", + err) + } + + // Attempts to add a persistent connection to an existing inbound should be + // rejected. + _, err = cmgr.AddPersistent(addr) + if !errors.Is(err, ErrAlreadyConnected) { + t.Fatalf("did not reject persistent conn for existing inbound conn: %v", + err) + } + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} From 9e6fe3762d27283861c6fc49634764e7ef13b6ae Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:11 -0500 Subject: [PATCH 08/24] connmgr: Add max persistent conns test. This adds tests to ensure attempts to add more than the maximum allowed number of persistent are rejected. --- internal/connmgr/connmanager_test.go | 80 ++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index aaa59b273..a4dce9dd9 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -396,6 +396,86 @@ func TestRetryPersistent(t *testing.T) { wg.Wait() } +// TestMaxPersistent ensures [ConnManager.AddPersistent] limits the maximum +// number of persistent connections including a removal and addition of a new +// one after achieving the max. +func TestMaxPersistent(t *testing.T) { + t.Parallel() + + connected := make(chan *Conn) + disconnected := make(chan *Conn) + cmgr := newTestConnManager(t, &Config{ + Dial: mockDialer, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + var numAddrs uint32 + nextAddr := func() net.Addr { + numAddrs++ + addrStr := fmt.Sprintf("127.0.0.%d:18555", numAddrs) + return mustParseAddrPort(addrStr) + } + + // Add the maximum allowed number of persistent conns. + connIDs := make([]uint64, 0, MaxPersistent) + addrs := make([]net.Addr, 0, MaxPersistent) + for range MaxPersistent { + addr := nextAddr() + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection %v: %v", addr, err) + } + connIDs = append(connIDs, connID) + addrs = append(addrs, addr) + + // Wait for the connection. + assertConnReceived(t, connected, connID, ConnTypeManual) + } + + // Attempting to add more than the max allowed number of persistent conns + // should be rejected. + _, err := cmgr.AddPersistent(nextAddr()) + if !errors.Is(err, ErrMaxPersistent) { + t.Fatalf("did not reject > max persistent, err: %v", err) + } + + // Ensure disconnecting the persistent conn does not incorrectly decrement + // the count. + connID, addr := connIDs[0], addrs[0] + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("failed to disconnect persistent conn %v: %v", addr, err) + } + _, err = cmgr.AddPersistent(nextAddr()) + if !errors.Is(err, ErrMaxPersistent) { + t.Fatalf("did not reject max persistent after dc, err: %v", err) + } + + // Remove the first persistent connection, wait for it to disconnect, and + // ensure it is actually removed. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("failed to remove persistent conn %v: %v", addr, err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertRemovedPersistent(t, cmgr, addr) + + // A new persistent conn should now be allowed. + addr = nextAddr() + _, err = cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection %v: %v", addr, err) + } + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + // TestMaxRetryDuration tests the maximum retry duration. // // We have a timed dialer which initially returns err but after RetryDuration From 354adde8c96923390e8eaa49634d2feebfebc752 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:12 -0500 Subject: [PATCH 09/24] connmgr: Add disconnect by id tests. This adds tests to ensure the Disconnect method properly disconnects pending and established connections for both non-persistent and persistent connections. --- internal/connmgr/connmanager_test.go | 151 +++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index a4dce9dd9..70389177c 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -272,6 +272,157 @@ func TestConnectMode(t *testing.T) { wg.Wait() } +// TestDisconnect ensures that [ConnManager.Disconnect] properly disconnects +// pending and established connections for both non-persistent and persistent +// connections. +func TestDisconnect(t *testing.T) { + t.Parallel() + + // Create a connection manager instance with a dialer that has a few + // synchronization channels to notify when a dial attempt is made, to keep + // connection attempts in a pending state, and to notify when the context + // for the attempt is canceled. Whether or not to wait/send the signals are + // controlled by the associated atomic flags. + connected := make(chan *Conn) + disconnected := make(chan *Conn) + dialed := make(chan struct{}) + pending := make(chan struct{}) + canceled := make(chan struct{}) + var notifyDialed, waitForPending, notifyCanceled atomic.Bool + pendingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + if notifyDialed.Load() { + dialed <- struct{}{} + } + if waitForPending.Load() { + <-pending + } + conn, err := mockDialer(ctx, network, addr) + if errors.Is(err, context.Canceled) && notifyCanceled.Load() { + canceled <- struct{}{} + } + return conn, err + } + cmgr := newTestConnManager(t, &Config{ + Dial: pendingDialer, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Attempt a connection to a localhost IP. + notifyDialed.Store(true) + waitForPending.Store(true) + notifyCanceled.Store(true) + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) + + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. + select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for dial") + } + assertPendingAddr(t, cmgr, addr) + + // Disconnect the connection attempt while it's still pending. + connID, _ := pendingAddrConnID(cmgr, addr) + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + + // Allow the dialer to proceed with the disconnected connection attempt and + // then wait for the dialer to signal the context associated with the dial + // was canceled. Finally, ensure the internal pending state is removed. + select { + case pending <- struct{}{}: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting to signal pending") + } + select { + case <-canceled: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for cancel") + } + if _, ok := pendingAddrConnID(cmgr, addr); ok { + t.Fatalf("connection %s is still pending", addr) + } + + // Start a connection attempt and wait for it to be established. + notifyDialed.Store(false) + waitForPending.Store(false) + notifyCanceled.Store(false) + go cmgr.Connect(ctx, addr) + conn := assertConnReceived(t, connected, 0, ConnTypeManual) + + // Disconnect the established connection and wait for the disconnect + // notification to ensure it is disconnected as intended. + connID = conn.ID() + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + + // Add a persistent connection back to the same address. + notifyDialed.Store(true) + waitForPending.Store(true) + notifyCanceled.Store(true) + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. + select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for dial") + } + assertPendingAddr(t, cmgr, addr) + + // Disconnect the persistent connection attempt while it's still pending. + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + + // Allow the dialer to proceed with the disconnected persistent connection + // attempt and then wait for the dialer to signal the context associated + // with the dial was canceled. + select { + case pending <- struct{}{}: + // Ensure the reconnect attempt doesn't notify the dialed chan or + // wait for the pending chan. + notifyDialed.Store(false) + waitForPending.Store(false) + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting to signal pending") + } + select { + case <-canceled: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for cancel") + } + + // Wait for the retry to be established. + assertConnReceived(t, connected, connID, ConnTypeManual) + + // Disconnect the established persistent connection and wait for the + // disconnect notification to ensure it is disconnected as intended. + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + // TestTargetOutbound tests the target number of outbound connections // configuration option by waiting until all connections are established and // ensuring they are the only connections made. From bd5aaf2d92b949ea208c7da087ff13553336d1bf Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:12 -0500 Subject: [PATCH 10/24] connmgr: Add remove by id tests. This adds tests to ensure the Remove method properly disconnects and removes pending and established connections for both non-persistent and persistent connections. --- internal/connmgr/connmanager_test.go | 167 +++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 70389177c..0e22b60c8 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -423,6 +423,173 @@ func TestDisconnect(t *testing.T) { wg.Wait() } +// TestRemove ensures that [ConnManager.Remove] properly removes pending and +// established connections for both non-persistent and persistent connections. +// +// It also ensures removal of an invalid ID returns the expected error. +func TestRemove(t *testing.T) { + t.Parallel() + + // Create a connection manager instance with a dialer that has a few + // synchronization channels to notify when a dial attempt is made, to keep + // connection attempts in a pending state, and to notify when the context + // for the attempt is canceled. Whether or not to wait/send the signals are + // controlled by the associated atomic flags. + connected := make(chan *Conn) + disconnected := make(chan *Conn) + dialed := make(chan struct{}) + pending := make(chan struct{}) + canceled := make(chan struct{}) + var notifyDialed, waitForPending, notifyCanceled atomic.Bool + pendingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + if notifyDialed.Load() { + dialed <- struct{}{} + } + if waitForPending.Load() { + <-pending + } + conn, err := mockDialer(ctx, network, addr) + if errors.Is(err, context.Canceled) && notifyCanceled.Load() { + canceled <- struct{}{} + } + return conn, err + } + cmgr := newTestConnManager(t, &Config{ + Dial: pendingDialer, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Ensure removing an ID that doesn't exist returns the expected error. + if err := cmgr.Remove(^uint64(0)); !errors.Is(err, ErrNotFound) { + t.Fatalf("mismatched remove error: got %v, want %v", err, ErrNotFound) + } + + // Attempt a connection to a localhost IP. + notifyDialed.Store(true) + waitForPending.Store(true) + notifyCanceled.Store(true) + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) + + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. + select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for dial") + } + assertPendingAddr(t, cmgr, addr) + + // Remove the connection attempt while it's still pending. + connID, _ := pendingAddrConnID(cmgr, addr) + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected remove err: %v", err) + } + + // Allow the dialer to proceed with the removed connection attempt and then + // wait for the dialer to signal the context associated with the dial was + // canceled. Finally, ensure the internal pending state is removed. + select { + case pending <- struct{}{}: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting to signal pending") + } + select { + case <-canceled: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for cancel") + } + if _, ok := pendingAddrConnID(cmgr, addr); ok { + t.Fatalf("connection %s is still pending", addr) + } + + // Start a connection attempt and wait for it to be established. + notifyDialed.Store(false) + waitForPending.Store(false) + notifyCanceled.Store(false) + go cmgr.Connect(ctx, addr) + conn := assertConnReceived(t, connected, 0, ConnTypeManual) + + // Remove the established connection and wait for the disconnect + // notification to ensure it is disconnected as intended. + connID = conn.ID() + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + + // Add a persistent connection back to the same address. + notifyDialed.Store(true) + waitForPending.Store(true) + notifyCanceled.Store(true) + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. + select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for dial") + } + assertPendingAddr(t, cmgr, addr) + + // Remove the persistent connection attempt while it's still pending. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + + // Allow the dialer to proceed with the removed persistent connection + // attempt and then wait for the dialer to signal the context associated + // with the dial was canceled. + select { + case pending <- struct{}{}: + // Ensure the reconnect attempt doesn't notify the dialed chan or + // wait for the pending chan. + notifyDialed.Store(false) + waitForPending.Store(false) + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting to signal pending") + } + select { + case <-canceled: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for cancel") + } + + // Add a persistent connection back to the same address and wait for it to + // be established. + notifyDialed.Store(false) + waitForPending.Store(false) + notifyCanceled.Store(false) + connID, err = cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + conn2 := assertConnReceived(t, connected, connID, ConnTypeManual) + + // Remove the established persistent connection and wait for the disconnect + // notification to ensure it is disconnected as intended. Also, ensure the + // persistent connection entry is removed. + connID = conn2.ID() + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + // TestTargetOutbound tests the target number of outbound connections // configuration option by waiting until all connections are established and // ensuring they are the only connections made. From 3b78fbb51e4449068af9fc27f52d70ca459049a8 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:13 -0500 Subject: [PATCH 11/24] connmgr: Update README.md. This updates the connmgr package README.md to match the new design and capabilities. --- internal/connmgr/README.md | 58 +++++++++++++++++++++++---------- internal/connmgr/connmanager.go | 2 ++ internal/connmgr/doc.go | 21 ------------ 3 files changed, 42 insertions(+), 39 deletions(-) delete mode 100644 internal/connmgr/doc.go diff --git a/internal/connmgr/README.md b/internal/connmgr/README.md index 46114366a..b9299d864 100644 --- a/internal/connmgr/README.md +++ b/internal/connmgr/README.md @@ -5,26 +5,48 @@ connmgr [![ISC License](https://img.shields.io/badge/license-ISC-blue.svg)](http://copyfree.org) [![Doc](https://img.shields.io/badge/doc-reference-blue.svg)](https://pkg.go.dev/github.com/decred/dcrd/internal/connmgr) -Package connmgr implements a generic Decred network connection manager. - ## Overview -This package handles all the general connection concerns such as maintaining a -set number of outbound connections, sourcing peers, banning, limiting max -connections, tor lookup, etc. - -The package provides a generic connection manager which is able to accept -connection requests from a source or a set of given addresses, dial them and -notify the caller on connections. The main intended use is to initialize a pool -of active connections and maintain them to remain connected to the P2P network. - -In addition the connection manager provides the following utilities: - -- Notifications on connections or disconnections -- Handle failures and retry new addresses from the source -- Connect only to specified addresses -- Permanent connections with increasing backoff retry timers -- Disconnect or Remove an established connection +Package `connmgr` provides a flexible and robust context-aware connection +manager for inbound, outbound, and persistent network connections with retry +logic. + +It handles all general connection lifecycle concerns such as accepting inbound +connections, automatically maintaining a set number of outbound connections, +maintaining persistent connections, and limiting max connections. + +The design has a strong emphasis on reliability, readability, and efficiency under high connection load while also aiming to provide an ergonomic API. + +The following is a brief overview of the key features: + +- Full context support +- Inbound listening + - Accepts inbound connections on provided `Listeners` + - Uses connection shedding for rejected inbound connections +- Automatic outbound maintenance + - Maintains up to `TargetOutbound` normal outbound connections via a provided + address source (`GetNewAddress`) +- Persistent connections + - Maintains up to `MaxPersistent` addresses that are automatically retried + with exponential backoff on disconnect +- Manual connections + - Supports manual connection establishment via `Connect` +- Duplicate address prevention + - Rejects duplicate connections to and from the same address (host:port) +- Rich managed connections via `Conn` + - Connection types for differentiated handling + - Automatic cleanup on connection close + - Concrete parsed address access +- Manual disconnection and removal + - Ability to disconnect / remove established, pending, and persistent + connections via `Disconnect` and `Remove` +- Notification callbacks + - Provides callbacks for connection establishment and disconnects +- Graceful network outage handling + - Automatic connection attempts are throttled during network outages +- Clear and actionable programatically-detectable errors + +A full suite of tests is provided to help ensure proper functionality. ## License diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index e4e75b06d..453e4a33d 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -3,6 +3,8 @@ // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. +// Package connmgr provides a robust connection manager for inbound, outbound, +// and persistent network connections with retry logic. package connmgr import ( diff --git a/internal/connmgr/doc.go b/internal/connmgr/doc.go deleted file mode 100644 index 3eb872c3c..000000000 --- a/internal/connmgr/doc.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) 2016 The btcsuite developers -// Copyright (c) 2017-2022 The Decred developers -// Use of this source code is governed by an ISC -// license that can be found in the LICENSE file. - -/* -Package connmgr implements a generic Decred network connection manager. - -# Deprecated - -This module is deprecated and is no longer maintained. Callers are encouraged -to use github.com/decred/dcrd/addrmgr/vX for methods that were moved to it -instead. - -# Connection Manager Overview - -Connection manager handles all the general connection concerns such as -maintaining a set number of outbound connections, sourcing peers, banning, -limiting max connections, tor lookup, etc. -*/ -package connmgr From 1cc2fa725c10039873fe7d0344f5fd5460275d37 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Fri, 22 May 2026 20:58:21 -0500 Subject: [PATCH 12/24] connmgr: Add internal state test assertions. This adds a couple of test helpers for asserting the internal state of the connection manager updates all tests to call the new helpers throughout. The first one asserts the internal maps are all coherent and do not violate any preconditions. The second one asserts clean shutdown. --- internal/connmgr/connmanager_test.go | 162 +++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 0e22b60c8..ecbb0ccf0 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -12,6 +12,7 @@ import ( "io" "net" "net/netip" + "reflect" "sync" "sync/atomic" "testing" @@ -122,6 +123,88 @@ func newTestConnManager(t *testing.T, cfg *Config) *ConnManager { return cmgr } +// assertConnManagerInternalState ensures the internal state of the passed +// connection manager instance is coherent. +func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { + t.Helper() + + cm.connMtx.Lock() + defer cm.connMtx.Unlock() + + // Assert established persistent conns have the correct connection type. + for id, conn := range cm.active { + if _, ok := cm.persistent[id]; ok { + want := ConnTypeManual + if got := conn.Type(); got != want { + t.Fatalf("bad conn type in active map: %v != %v", got, want) + } + } + } + + // Assert the pending and active maps are mutually exclusive for both conn + // IDs and addrs. + // + // Also build a map of addrs to conn IDs in the pending, active, and + // persistent maps for the checks below. + connIDByAddr := make(map[string]uint64) + for id, info := range cm.pending { + if _, ok := cm.active[id]; ok { + t.Fatalf("conn ID %d is both pending and active", id) + } + connIDByAddr[info.addr.String()] = id + } + for id, conn := range cm.active { + if _, ok := cm.pending[id]; ok { + t.Fatalf("conn ID %d is both pending and active", id) + } + addrStr := conn.remoteAddr.String() + if _, ok := connIDByAddr[addrStr]; ok { + t.Fatalf("addr %s is both pending and active", addrStr) + } + connIDByAddr[addrStr] = id + } + for id, entry := range cm.persistent { + // Assert the conn ID of established/pending persistent conns matches. + addrStr := entry.addr.String() + if existingID, ok := connIDByAddr[addrStr]; ok && existingID != id { + t.Fatalf("conn ID for addr %s mismatch: %d != %d", addrStr, + existingID, id) + } + connIDByAddr[addrStr] = id + } + + // Assert the addr to conn ID mappings match the values obtained from + // manually constructing them. + if !reflect.DeepEqual(cm.connIDByAddr, connIDByAddr) { + t.Fatalf("mismatched conn ID by addr maps\ngot: %v\nwant %v", + cm.connIDByAddr, connIDByAddr) + } +} + +// assertConnManagerCleanShutdown ensures the internal state of the passed +// connection manager is fully cleaned up as expected. It must only be called +// after [ConnManager.Run] returns. +func assertConnManagerCleanShutdown(t *testing.T, cm *ConnManager) { + t.Helper() + + cm.connMtx.Lock() + defer cm.connMtx.Unlock() + + if len(cm.active) != 0 { + t.Fatalf("active map is not empty: %d entries", len(cm.active)) + } + if len(cm.pending) != 0 { + t.Fatalf("pending map is not empty: %d entries", len(cm.pending)) + } + if len(cm.persistent) != 0 { + t.Fatalf("persistent map is not empty: %d entries", len(cm.persistent)) + } + if len(cm.connIDByAddr) != 0 { + t.Fatalf("conn ID by addr map not empty: %d entries", + len(cm.connIDByAddr)) + } +} + // TestNewConfig tests that new ConnManager config is validated as expected. func TestNewConfig(t *testing.T) { t.Parallel() @@ -266,10 +349,12 @@ func TestConnectMode(t *testing.T) { // Ensure that only a single connection is received. assertConnReceived(t, connected, 0, ConnTypeManual) assertNoConnReceived(t, connected) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestDisconnect ensures that [ConnManager.Disconnect] properly disconnects @@ -328,12 +413,14 @@ func TestDisconnect(t *testing.T) { t.Fatal("timeout waiting for dial") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Disconnect the connection attempt while it's still pending. connID, _ := pendingAddrConnID(cmgr, addr) if err := cmgr.Disconnect(connID); err != nil { t.Fatalf("unexpected disconnect err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Allow the dialer to proceed with the disconnected connection attempt and // then wait for the dialer to signal the context associated with the dial @@ -351,6 +438,7 @@ func TestDisconnect(t *testing.T) { if _, ok := pendingAddrConnID(cmgr, addr); ok { t.Fatalf("connection %s is still pending", addr) } + assertConnManagerInternalState(t, cmgr) // Start a connection attempt and wait for it to be established. notifyDialed.Store(false) @@ -358,6 +446,7 @@ func TestDisconnect(t *testing.T) { notifyCanceled.Store(false) go cmgr.Connect(ctx, addr) conn := assertConnReceived(t, connected, 0, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Disconnect the established connection and wait for the disconnect // notification to ensure it is disconnected as intended. @@ -366,6 +455,7 @@ func TestDisconnect(t *testing.T) { t.Fatalf("unexpected disconnect err: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Add a persistent connection back to the same address. notifyDialed.Store(true) @@ -375,6 +465,7 @@ func TestDisconnect(t *testing.T) { if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } + assertConnManagerInternalState(t, cmgr) // Wait for the connection manager to attempt to dial and ensure the // connection is marked as pending while the dialer is blocked. @@ -384,11 +475,13 @@ func TestDisconnect(t *testing.T) { t.Fatal("timeout waiting for dial") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Disconnect the persistent connection attempt while it's still pending. if err := cmgr.Disconnect(connID); err != nil { t.Fatalf("unexpected disconnect err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Allow the dialer to proceed with the disconnected persistent connection // attempt and then wait for the dialer to signal the context associated @@ -410,6 +503,7 @@ func TestDisconnect(t *testing.T) { // Wait for the retry to be established. assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Disconnect the established persistent connection and wait for the // disconnect notification to ensure it is disconnected as intended. @@ -417,10 +511,12 @@ func TestDisconnect(t *testing.T) { t.Fatalf("unexpected disconnect err: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestRemove ensures that [ConnManager.Remove] properly removes pending and @@ -485,6 +581,7 @@ func TestRemove(t *testing.T) { t.Fatal("timeout waiting for dial") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Remove the connection attempt while it's still pending. connID, _ := pendingAddrConnID(cmgr, addr) @@ -508,6 +605,7 @@ func TestRemove(t *testing.T) { if _, ok := pendingAddrConnID(cmgr, addr); ok { t.Fatalf("connection %s is still pending", addr) } + assertConnManagerInternalState(t, cmgr) // Start a connection attempt and wait for it to be established. notifyDialed.Store(false) @@ -515,6 +613,7 @@ func TestRemove(t *testing.T) { notifyCanceled.Store(false) go cmgr.Connect(ctx, addr) conn := assertConnReceived(t, connected, 0, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Remove the established connection and wait for the disconnect // notification to ensure it is disconnected as intended. @@ -523,6 +622,7 @@ func TestRemove(t *testing.T) { t.Fatalf("unexpected disconnect err: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Add a persistent connection back to the same address. notifyDialed.Store(true) @@ -532,6 +632,7 @@ func TestRemove(t *testing.T) { if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } + assertConnManagerInternalState(t, cmgr) // Wait for the connection manager to attempt to dial and ensure the // connection is marked as pending while the dialer is blocked. @@ -546,6 +647,7 @@ func TestRemove(t *testing.T) { if err := cmgr.Remove(connID); err != nil { t.Fatalf("unexpected disconnect err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Allow the dialer to proceed with the removed persistent connection // attempt and then wait for the dialer to signal the context associated @@ -564,6 +666,7 @@ func TestRemove(t *testing.T) { case <-time.After(time.Millisecond * 5): t.Fatal("timeout waiting for cancel") } + assertConnManagerInternalState(t, cmgr) // Add a persistent connection back to the same address and wait for it to // be established. @@ -575,6 +678,7 @@ func TestRemove(t *testing.T) { t.Fatalf("failed to add persistent connection: %v", err) } conn2 := assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Remove the established persistent connection and wait for the disconnect // notification to ensure it is disconnected as intended. Also, ensure the @@ -584,10 +688,12 @@ func TestRemove(t *testing.T) { t.Fatalf("unexpected disconnect err: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestTargetOutbound tests the target number of outbound connections @@ -618,10 +724,12 @@ func TestTargetOutbound(t *testing.T) { assertConnReceived(t, connected, 0, ConnTypeOutbound) } assertNoConnReceived(t, connected) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestDoubleClose ensures closing a connection multiple times is a noop after @@ -644,6 +752,7 @@ func TestDoubleClose(t *testing.T) { // Wait for the connection to be established. conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) + assertConnManagerInternalState(t, cmgr) // Override the close func to cleanly detect closes. var numClosed uint32 @@ -660,10 +769,12 @@ func TestDoubleClose(t *testing.T) { if numClosed != 1 { t.Fatal("connection closed more than once") } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestRetryPersistent tests that persistent connections are retried. @@ -700,6 +811,7 @@ func TestRetryPersistent(t *testing.T) { conn.Close() assertConnReceived(t, disconnected, connID, ConnTypeManual) assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Remove the persistent connection, wait for it to disconnect, and ensure // it is actually removed. @@ -708,10 +820,12 @@ func TestRetryPersistent(t *testing.T) { } assertConnReceived(t, disconnected, connID, ConnTypeManual) assertRemovedPersistent(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestMaxPersistent ensures [ConnManager.AddPersistent] limits the maximum @@ -754,6 +868,7 @@ func TestMaxPersistent(t *testing.T) { // Wait for the connection. assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) } // Attempting to add more than the max allowed number of persistent conns @@ -762,6 +877,7 @@ func TestMaxPersistent(t *testing.T) { if !errors.Is(err, ErrMaxPersistent) { t.Fatalf("did not reject > max persistent, err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Ensure disconnecting the persistent conn does not incorrectly decrement // the count. @@ -773,6 +889,7 @@ func TestMaxPersistent(t *testing.T) { if !errors.Is(err, ErrMaxPersistent) { t.Fatalf("did not reject max persistent after dc, err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Remove the first persistent connection, wait for it to disconnect, and // ensure it is actually removed. @@ -781,6 +898,7 @@ func TestMaxPersistent(t *testing.T) { } assertConnReceived(t, disconnected, connID, ConnTypeManual) assertRemovedPersistent(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // A new persistent conn should now be allowed. addr = nextAddr() @@ -788,10 +906,12 @@ func TestMaxPersistent(t *testing.T) { if err != nil { t.Fatalf("failed to add persistent connection %v: %v", addr, err) } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestMaxRetryDuration tests the maximum retry duration. @@ -843,10 +963,12 @@ func TestMaxRetryDuration(t *testing.T) { }) const timeout = connTestReceiveTimeout + networkUpTimeout assertConnReceivedTimeout(t, connected, timeout, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestNetworkFailure tests that the connection manager handles a network @@ -906,6 +1028,8 @@ func TestNetworkFailure(t *testing.T) { t.Fatalf("unexpected number of dials - got %v, want <= %v", gotDials, wantMaxDials) } + + assertConnManagerCleanShutdown(t, cmgr) } // TestMultipleFailedConns ensures that the connection manager remains @@ -944,6 +1068,7 @@ func TestMultipleFailedConns(t *testing.T) { t.Fatalf("unexpected add err: %v", err) } } + assertConnManagerInternalState(t, cmgr) // Wait for the target number of dials and ensure they happen simultaneously // by checking it happens before the retry timeout. @@ -952,6 +1077,7 @@ func TestMultipleFailedConns(t *testing.T) { case <-time.After(20 * time.Millisecond): t.Fatal("did not reach target number of dials before timeout") } + assertConnManagerInternalState(t, cmgr) // Ensure that the connection manager still responds to requests while the // failed connections are still retrying. @@ -966,10 +1092,12 @@ func TestMultipleFailedConns(t *testing.T) { case <-time.After(20 * time.Millisecond): t.Fatal("timeout servicing connmgr requests") } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestShutdownFailedConns tests that failed connections are ignored after @@ -997,6 +1125,7 @@ func TestShutdownFailedConns(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } + assertConnManagerInternalState(t, cmgr) // Shutdown the connection manager during the retry timeout after a failed // dial attempt. @@ -1010,6 +1139,7 @@ func TestShutdownFailedConns(t *testing.T) { // Ensure clean shutdown of connection manager. wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestRemovePendingConnection ensures that removing a pending outbound @@ -1035,6 +1165,7 @@ func TestRemovePendingConnection(t *testing.T) { // Establish a connection request to a localhost IP. addr := mustParseAddrPort("127.0.0.1:18555") go cmgr.Connect(ctx, addr) + assertConnManagerInternalState(t, cmgr) // Wait for the connection manager to attempt to dial and ensure the // connection is marked as pending while the dialer is blocked. @@ -1044,12 +1175,14 @@ func TestRemovePendingConnection(t *testing.T) { t.Fatal("timeout waiting for dial") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Cancel the connection attempt while it's still pending. connID, _ := pendingAddrConnID(cmgr, addr) if err := cmgr.Remove(connID); err != nil { t.Fatalf("unexpected remove err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Wait for the dialer to signal the context associated with the dial was // canceled and ensure the internal pending state is removed. @@ -1061,10 +1194,12 @@ func TestRemovePendingConnection(t *testing.T) { if _, ok := pendingAddrConnID(cmgr, addr); ok { t.Fatalf("connection %s is still pending", addr) } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestCancelIgnoreDelayedConnection tests that a canceled pending persistent @@ -1111,6 +1246,7 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } + assertConnManagerInternalState(t, cmgr) // Wait for the retry and ensure the connection is pending. select { @@ -1119,6 +1255,7 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { t.Fatalf("did not get retry before timeout") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Remove the connection and then immediately allow the next connection to // succeed. @@ -1132,10 +1269,12 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { // timeout window to ensure the connection manager's backoff is allowed to // properly elapse. assertNoConnReceivedTimeout(t, connected, 5*retryTimeout) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestDialTimeout ensure [Config.Timeout] works as intended by creating a @@ -1167,6 +1306,7 @@ func TestDialTimeout(t *testing.T) { // Establish a connection to a localhost IP. addr := mustParseAddrPort("127.0.0.1:18555") go cmgr.Connect(ctx, addr) + assertConnManagerInternalState(t, cmgr) // Wait to receive the signal that the dialer context was cancelled, which // means the dial timeout was hit. @@ -1175,10 +1315,12 @@ func TestDialTimeout(t *testing.T) { case <-time.After(dialTimeout * 10): t.Fatal("timeout waiting for dial cancellation") } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestConnectContext ensures the [ConnManager.Connect] method works as intended @@ -1218,6 +1360,7 @@ func TestConnectContext(t *testing.T) { t.Fatal("timeout waiting for dial") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Cancel the connection context, wait for the error from connect, and // ensure it is the expected error. @@ -1231,10 +1374,12 @@ func TestConnectContext(t *testing.T) { case <-time.After(10 * time.Millisecond): t.Fatal("timeout waiting for dial cancellation") } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // mockListener implements the net.Listener interface and is used to test @@ -1327,10 +1472,12 @@ func TestListeners(t *testing.T) { for range expectedNumConns { assertConnReceived(t, receivedConns, 0, ConnTypeInbound) } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestRejectDuplicateConns ensures duplicate addresses are rejected. This @@ -1379,6 +1526,7 @@ func TestRejectDuplicateConns(t *testing.T) { t.Fatal("did not receive pending dial before timeout") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Duplicate connect to the pending address should be rejected. if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyPending) { @@ -1388,24 +1536,29 @@ func TestRejectDuplicateConns(t *testing.T) { // Inbound attempts from the pending outbound address should be rejected. go listener.Connect(addr) assertNoConnReceived(t, inboundConns) + assertConnManagerInternalState(t, cmgr) // Allow the pending connection to complete. close(pending) conn := assertConnReceived(t, connected, 0, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Duplicate connect to the established address should be rejected. if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { t.Fatalf("did not reject duplicate active connection, err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Inbound attempts from the established outbound address should be // rejected. go listener.Connect(addr) assertNoConnReceived(t, inboundConns) + assertConnManagerInternalState(t, cmgr) // Close the connection and wait for the disconnect. conn.Close() assertConnReceived(t, disconnected, conn.ID(), ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Add a persistent connection back to the same address and wait for it to // connect since there are no longer any connections to the address. @@ -1414,22 +1567,26 @@ func TestRejectDuplicateConns(t *testing.T) { t.Fatalf("failed to add persistent connection: %v", err) } assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Duplicate persistent connection attempts should be rejected. _, err = cmgr.AddPersistent(addr) if !errors.Is(err, ErrDuplicatePersistent) { t.Fatalf("did not reject duplicate persistent connection, err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Manual connection attempts to persistent connection should be rejected. _, err = cmgr.Connect(ctx, addr) if !errors.Is(err, ErrDuplicatePersistent) { t.Fatalf("did not reject manual connection to persistent, err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Inbound atempts from the persistent address should be rejected. go listener.Connect(addr) assertNoConnReceived(t, inboundConns) + assertConnManagerInternalState(t, cmgr) // Remove the persistent connection, wait for it to disconnect, and ensure // it is actually removed. @@ -1438,16 +1595,19 @@ func TestRejectDuplicateConns(t *testing.T) { } assertConnReceived(t, disconnected, connID, ConnTypeManual) assertRemovedPersistent(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Inbound connections from the same address should now succeed. go listener.Connect(addr) assertConnReceived(t, inboundConns, 0, ConnTypeInbound) + assertConnManagerInternalState(t, cmgr) // Manual connection attempts to the inbound address should be rejected. if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { t.Fatalf("did not reject outbound for existing inbound conn, err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Attempts to add a persistent connection to an existing inbound should be // rejected. @@ -1456,8 +1616,10 @@ func TestRejectDuplicateConns(t *testing.T) { t.Fatalf("did not reject persistent conn for existing inbound conn: %v", err) } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } From 76bf20218ce048223f98cc83d80501d251e83b4d Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 18 May 2026 15:59:15 -0500 Subject: [PATCH 13/24] connmgr: Support whitelisting. Currently the whitelisting logic happens in the server which makes it inaccessible to the connection manager. In order to pave the way for supporting various connection-related logic that currently happens in the server, but ideally should be happening in the connection manager, this adds basic support for whitelisting CIDR prefixes to the connection manager. The connection manager config struct now accepts a slice of prefixes and a new method named IsWhitelisted is added. Note that this only adds support . It does not update anything to use the new functionality yet. --- internal/connmgr/connmanager.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 453e4a33d..a2df96398 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "net" + "net/netip" "strconv" "sync" "sync/atomic" @@ -260,6 +261,10 @@ type Config struct { // DialTimeout specifies the amount of time to wait for a connection to // complete before giving up. DialTimeout time.Duration + + // Whitelists specifies CIDR address prefixes to whitelist. Whitelisted + // addresses are exempt from banning and certain connection limits. + Whitelists []netip.Prefix } // ConnManager provides a manager to handle network connections. @@ -317,6 +322,22 @@ type ConnManager struct { connIDByAddr map[string]uint64 } +// IsWhitelisted returns whether the IP address is included in the whitelisted +// networks and IPs. +func (cm *ConnManager) IsWhitelisted(addr *addrmgr.NetAddress) bool { + if len(cm.cfg.Whitelists) == 0 { + return false + } + + ip, _ := netip.AddrFromSlice(addr.IP) + for _, prefix := range cm.cfg.Whitelists { + if prefix.Contains(ip) { + return true + } + } + return false +} + // checkShutdown returns [ErrShutdown] when the connection manager quit channel // has been closed. func (cm *ConnManager) checkShutdown() error { From f151dcd8a75b68307d2e00985138590f551e04b4 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Wed, 20 May 2026 19:21:44 -0500 Subject: [PATCH 14/24] connmgr: Add whitelist detection tests. This adds tests to ensure the new whitelist detection method works as expected. --- internal/connmgr/connmanager_test.go | 88 ++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index ecbb0ccf0..6ce82f066 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -219,6 +219,94 @@ func TestNewConfig(t *testing.T) { }) } +// TestIsWhitelisted ensures [ConnManager.IsWhitelisted] works as expected. +func TestIsWhitelisted(t *testing.T) { + type perManagerTest struct { + addr string // address to test against whitelist + whitelisted bool // expected whitelisted result + } + + tests := []struct { + name string // test description + prefixes []string // CIDR prefixes to whitelist + perManagerTests []perManagerTest // tests to run against the prefixes + }{{ + name: "no whitelisted entries", + prefixes: nil, + perManagerTests: []perManagerTest{ + {"1.2.3.4:18555", false}, + {"127.0.0.1:18555", false}, + }, + }, { + name: "single /32 IPv4 entry", + prefixes: []string{"1.2.3.4/32"}, + perManagerTests: []perManagerTest{ + {"1.2.3.4:18555", true}, + {"1.2.3.4:9108", true}, + {"[::1.2.3.4]:18555", false}, // IPv4 in IPv6 + {"1.2.3.5:18555", false}, + }, + }, { + name: "single /128 IPv6 entry", + prefixes: []string{"::1.2.3.4/128"}, + perManagerTests: []perManagerTest{ + {"[::1.2.3.4]:18555", true}, + {"[::1.2.3.4]:9108", true}, + {"1.2.3.4:18555", false}, // IPv4 doesn't match IPv4 in IPv6 + {"[::1.2.3.5]:9108", false}, + }, + }, { + name: "mixed IPv4 and IPv6 with different prefix lengths", + prefixes: []string{"12.13.14.0/24", "20.21.22.23/8", "fe80::/64"}, + perManagerTests: []perManagerTest{ + {"12.13.14.1:18555", true}, + {"12.13.14.255:18555", true}, + {"12.13.15.0:18555", false}, + {"20.0.0.0:18555", true}, + {"20.0.0.0:9108", true}, + {"20.255.255.255:18555", true}, + {"20.255.255.255:9108", true}, + {"21.0.0.0:18555", false}, + {"[fe80::1]:18555", true}, + {"[fe80::1]:9108", true}, + {"[fe80::ffff:ffff:ffff:ffff]:18555", true}, + {"[fe80::ffff:ffff:ffff:ffff]:1234", true}, + {"[fe80::1:ffff:ffff:ffff:ffff]:18555", false}, + }, + }} + + for _, test := range tests { + // Parse the whitelist entries for the test. + prefixes := make([]netip.Prefix, 0, len(test.prefixes)) + for _, prefixStr := range test.prefixes { + prefix, err := netip.ParsePrefix(prefixStr) + if err != nil { + t.Fatalf("%q: failed to parse prefix %q: %v", test.name, prefixStr, + err) + } + prefixes = append(prefixes, prefix) + } + cmgr := newTestConnManager(t, &Config{ + Dial: mockDialer, + Whitelists: prefixes, + }) + + for _, pmTest := range test.perManagerTests { + mAddr := mockAddr{"tcp", pmTest.addr} + addr, err := stdlibNetAddrToAddrMgrNetAddr(mAddr) + if err != nil { + t.Fatalf("%q-%q: failed to parse address: %v", test.name, + pmTest.addr, err) + } + if got := cmgr.IsWhitelisted(addr); got != pmTest.whitelisted { + t.Errorf("%q-%q: mismatched result -- got %v, want %v", + test.name, pmTest.addr, got, pmTest.whitelisted) + continue + } + } + } +} + // assertConnID ensures the provided connection has the given ID. func assertConnID(t *testing.T, conn *Conn, wantID uint64) { t.Helper() From bb01b2146fd661167ada03292c874d87a3e937ff Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 18 May 2026 15:59:20 -0500 Subject: [PATCH 15/24] server: Integrate connmgr whitelisting. This modifies the server to pass in the parsed whitelist entries to the connection manager config and the relevant code to make use of the new method it exposes. Finally, it removes the no longer used local isWhitelisted method. --- server.go | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/server.go b/server.go index f61afaf33..3a7b1b671 100644 --- a/server.go +++ b/server.go @@ -513,6 +513,7 @@ func newServerPeer(s *server, conn *connmgr.Conn, remoteAddr *addrmgr.NetAddress conn: conn, remoteAddr: remoteAddr, persistent: s.connManager.IsPersistent(conn.ID()), + isWhitelisted: s.connManager.IsWhitelisted(remoteAddr), knownAddresses: apbf.NewFilter(maxKnownAddrsPerPeer, knownAddrsFPRate), quit: make(chan struct{}), getDataQueue: make(chan []*wire.InvVect, maxConcurrentGetDataReqs), @@ -2289,7 +2290,6 @@ func (s *server) inboundPeerConnected(ctx context.Context, conn *connmgr.Conn) { sp := newServerPeer(s, conn, remoteNetAddr) sp.Peer = peer.NewInboundPeer(newPeerConfig(sp), conn) - sp.isWhitelisted = isWhitelisted(remoteNetAddr) if err := sp.Handshake(ctx, sp.OnVersion); err != nil { srvrLog.Debugf("Failed handshake for inbound peer %s: %v", remoteNetAddr, err) @@ -2323,7 +2323,6 @@ func (s *server) outboundPeerConnected(ctx context.Context, conn *connmgr.Conn) sp := newServerPeer(s, conn, remoteNetAddr) sp.Peer = peer.NewOutboundPeer(newPeerConfig(sp), conn.RemoteAddr(), conn) - sp.isWhitelisted = isWhitelisted(remoteNetAddr) if err := sp.Handshake(ctx, sp.OnVersion); err != nil { srvrLog.Debugf("Failed handshake for outbound peer %s: %v", conn.RemoteAddr(), err) @@ -4367,6 +4366,7 @@ func newServer(ctx context.Context, profiler *profileServer, s.outboundPeerConnected(ctx, conn) }, GetNewAddress: newAddressFunc, + Whitelists: cfg.whitelists, }) if err != nil { return nil, err @@ -4631,19 +4631,3 @@ func addLocalAddress(addrMgr *addrmgr.AddrManager, addr string, services wire.Se return nil } - -// isWhitelisted returns whether the IP address is included in the whitelisted -// networks and IPs. -func isWhitelisted(addr *addrmgr.NetAddress) bool { - if len(cfg.whitelists) == 0 { - return false - } - - ip, _ := netip.AddrFromSlice(addr.IP) - for _, prefix := range cfg.whitelists { - if prefix.Contains(ip) { - return true - } - } - return false -} From 259a45f815e0f3280c3f951563a1f27640efe0d2 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Thu, 21 May 2026 17:10:47 -0500 Subject: [PATCH 16/24] connmgr: Update README.md for whitelist support. --- internal/connmgr/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/connmgr/README.md b/internal/connmgr/README.md index b9299d864..cf43ac92a 100644 --- a/internal/connmgr/README.md +++ b/internal/connmgr/README.md @@ -33,6 +33,8 @@ The following is a brief overview of the key features: - Supports manual connection establishment via `Connect` - Duplicate address prevention - Rejects duplicate connections to and from the same address (host:port) +- Whitelist support + - CIDR-based whitelists that allow bypassing certain limits and restrictions - Rich managed connections via `Conn` - Connection types for differentiated handling - Automatic cleanup on connection close From 54a2d5999e952d04d18c1e87b47b53472472e308 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 19 May 2026 14:11:47 -0500 Subject: [PATCH 17/24] connmgr: Add try acquire support to semaphore. This adds a new TryAcquire method to the context-aware semaphore. As the name implies, the method supports conditionally acquiring the semaphore only when resources are immediately available. In other words, it will not block when there are no resources immediately available. --- internal/connmgr/semaphore.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/internal/connmgr/semaphore.go b/internal/connmgr/semaphore.go index fb7d7eed4..532efab5c 100644 --- a/internal/connmgr/semaphore.go +++ b/internal/connmgr/semaphore.go @@ -27,6 +27,27 @@ func (s semaphore) Acquire(ctx context.Context) bool { return true } +// TryAcquire attempts to acquire the semaphore without blocking when there are +// no resources immediately available. +// +// It returns true with a nil error on success. It return false with a nil +// error when the semaphore is at capacity and no permit is available. +// +// Finally, it returns false with the error associated with the context +// immediately when the context is already canceled or timed out at the time of +// the call. It does not attempt to acquire the semaphore in that case. +func (s semaphore) TryAcquire(ctx context.Context) (bool, error) { + if ctx.Err() != nil { + return false, ctx.Err() + } + select { + case s <- struct{}{}: + return true, nil + default: + } + return false, nil +} + // Release release the semaphore. func (s semaphore) Release() { select { From 87b5fe809f0b1eb81ceb765a22c073d5b75cdd09 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 19 May 2026 14:11:48 -0500 Subject: [PATCH 18/24] connmgr: Add semaphore try acquire tests. This adds tests for the new TryAcquire method on the context-aware semaphore to ensure the semantics work as expected. --- internal/connmgr/semaphore_test.go | 75 +++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 2 deletions(-) diff --git a/internal/connmgr/semaphore_test.go b/internal/connmgr/semaphore_test.go index 9542176df..1d7d5b750 100644 --- a/internal/connmgr/semaphore_test.go +++ b/internal/connmgr/semaphore_test.go @@ -6,6 +6,7 @@ package connmgr import ( "context" + "errors" "testing" "time" ) @@ -21,10 +22,24 @@ func TestSemaphore(t *testing.T) { return sem.Acquire(ctx) } + // Create a closure that tries to acquire a semaphore via the nonblocking + // method with a timeout. + timedTryAcquire := func(sem semaphore, timeout time.Duration) (bool, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + if timeout == 0 { + cancel() + } else { + defer cancel() + } + return sem.TryAcquire(ctx) + } + // perSemTest describes a test to run against the same semaphore. type perSemTest struct { name string // test description numAcquires uint32 // num to acquire + numTries uint32 // num to try to acquire via nonblocking method + cancelTry bool // whether or not to cancel nonblocking try numReleases uint32 // num to release } @@ -69,6 +84,45 @@ func TestSemaphore(t *testing.T) { numReleases: 5, }}, want: []bool{true, true, true, true, true, true, true, false}, + }, { + name: "nonblocking tryacquire and blocking acquire mixed", + cap: 3, + perSemTests: []perSemTest{{ + name: "cap 3 (0 acquired): try 1, release 2", + numTries: 1, + numReleases: 2, + }, { + name: "cap 3 (0 acquired): acquire 2, try 1, release 1", + numAcquires: 2, + numTries: 1, + numReleases: 1, + }, { + name: "cap 3 (2 acquired): acquire 1, try 2, release 3", + numAcquires: 1, + numTries: 2, + numReleases: 3, + }}, + want: []bool{true, true, true, true, true, false, false}, + }, { + name: "nonblocking tryacquire with canceled context", + cap: 1, + perSemTests: []perSemTest{{ + name: "cap 1 (0 acquired): try 1 (canceled), release 0", + numTries: 1, + cancelTry: true, + numReleases: 0, + }, { + name: "cap 1 (0 acquired): acquire 1, try 1, release 1", + numAcquires: 1, + numTries: 1, + numReleases: 1, + }, { + name: "cap 1 (0 acquired): try 2, release 1", + numAcquires: 0, + numTries: 2, + numReleases: 1, + }}, + want: []bool{false, true, false, true, false}, }} for _, test := range tests { @@ -77,13 +131,30 @@ func TestSemaphore(t *testing.T) { sem := makeSemaphore(test.cap) results := make([]bool, 0, len(test.want)) - // Perform each sequence of acquires and releases as specified by the - // per semaphore tests. + // Perform each sequence of acquires, try acquires, and releases as + // specified by the per semaphore tests. for _, psTest := range test.perSemTests { const timeout = 10 * time.Millisecond for range psTest.numAcquires { results = append(results, timedAcquire(sem, timeout)) } + for range psTest.numTries { + // Override timeout with a duration 0 and expected error when + // the flag to force the context for the try acquire to be + // canceled is specified. + var wantErr error + tryTimeout := timeout + if psTest.cancelTry { + tryTimeout = 0 + wantErr = context.DeadlineExceeded + } + acquired, err := timedTryAcquire(sem, tryTimeout) + if !errors.Is(err, wantErr) { + t.Fatalf("%q: unexpected try acquire error: got %v, want %v", + psTest.name, err, wantErr) + } + results = append(results, acquired) + } for range psTest.numReleases { sem.Release() } From 5d594abc7aa39b602ef67a58f7a04f6548677c2f Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 19 May 2026 14:11:49 -0500 Subject: [PATCH 19/24] connmgr: Limit total overall normal connections. The current overall total connection limits are enforced by the server rather than the connection manager. This is not ideal for many reasons, but one of the most important consequences is that it makes DoS attacks easier. Another example of some less than ideal behavior that it allows is that some rare combinations of events can lead to temporary extra connection churn. It is much more robust and natural to perform the limiting in the connection manager itself via semaphores. That approach not only significantly hardens the server against DoS attacks and solves various edge cases present in the current code, it also paves the way for even more advanced features such as traffic shaping in the future. To that end, this adds semaphore-based limiting for the total overall number of normal connections to the connection manager and removes the relevant current limiting for it from the server. Normal connections are the automatic outbound, manual outbound, and inbound connections. Persistent connections, on the other hand, are not subject to the limit since they have their own limiting. This is consistent with them not being subject to the automatic target outbound limit either. --- internal/connmgr/connmanager.go | 144 +++++++++++++++++++++++++++----- internal/connmgr/error.go | 4 + internal/connmgr/error_test.go | 1 + internal/connmgr/log.go | 9 ++ rpcadaptors.go | 8 -- server.go | 12 +-- 6 files changed, 139 insertions(+), 39 deletions(-) diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index a2df96398..08722cdd4 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -44,6 +44,10 @@ const ( // base times the number of retries that have been done. defaultMaxRetryDuration = time.Minute * 5 + // defaultMaxNormalConns is the default number of maximum normal inbound, + // outbound, and pending connections to permit. + defaultMaxNormalConns = 125 + // defaultTargetOutbound is the default number of outbound connections to // maintain. defaultTargetOutbound = 8 @@ -233,11 +237,25 @@ type Config struct { // connections in that case. OnAccept func(*Conn) + // MaxNormalConns is the maximum number of normal inbound, outbound, and + // pending connections to permit. Defaults to 125. + // + // Persistent connections do not count against this limit. They have their + // own maximum defined by [MaxPersistent]. + // + // Whitelisted connections and some connections with special permissions are + // also exempt. As a result, the total number of connections may exceed + // this value. + MaxNormalConns uint32 + // TargetOutbound is the number of outbound network connections to maintain // automatically. Defaults to 8. // // Persistent connections do not count against this value. They have their // own maximum limit defined by [MaxPersistent]. + // + // This will be forced to the smaller of the specified value (or its default + // value when unspecified) and [Config.MaxNormalConns]. TargetOutbound uint32 // RetryDuration is the duration to wait before retrying connection @@ -290,10 +308,16 @@ type ConnManager struct { // It is a buffered channel with size [MaxPersistent]. runPersistentChan chan *persistentEntry - // outboundSem limits the number of active outbound connections. It does - // not apply to persistent connections which are separately limited to - // [MaxPersistent]. - activeOutboundsSem semaphore + // These semaphores are used to enforce max limits on the number of + // connections of different kinds. They do not apply to persistent + // connections which are separately limited to [MaxPersistent]. + // + // totalNormalConnsSem limits the total overall number of normal inbound, + // outbound, and pending connections. + // + // outboundSem limits the number of active outbound connections. + totalNormalConnsSem semaphore + activeOutboundsSem semaphore // The fields below this point are all protected by the connection mutex. connMtx sync.Mutex @@ -529,6 +553,10 @@ func (cm *ConnManager) rejectDuplicateAddr(addr *addrmgr.NetAddress) error { // and pending connections are rejected when a non-nil persistent connection ID // is passed. // +// The following connection limits are enforced: +// +// - Total normal connections ([Config.MaxNormalConns]) +// // On success, the returned connection is configured to remove itself from the // set of all active connections and invoke the provided on close callback (if // set) when it is closed. @@ -545,6 +573,8 @@ func (cm *ConnManager) rejectDuplicateAddr(addr *addrmgr.NetAddress) error { // to the address // - [ErrAlreadyConnected] when there is already an established connection to // the address +// - [ErrMaxNormalConns] when there are already the maximum allowed number of +// normal connections (inbound, outbound, and pending) // - [ErrShutdown] when the connection manager is shutting down // - [context.Canceled] or [context.DeadlineExceeded] depending on the // provided context or when the dialer fails to establish a connection @@ -694,9 +724,13 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect // the connection manager. // // Attempts to dial addresses that already have an established, pending, or -// persistent connection will return an error as described below. +// persistent connection or would exceed max allowed limits will return an error +// as described below. +// +// The connection will have type [ConnTypeManual] and the following connection +// limits are enforced: // -// The connection will have type [ConnTypeManual]. +// - Total normal connections ([Config.MaxNormalConns]) // // Note that the context parameter to this function and the lifecycle context // may be independent. @@ -710,6 +744,8 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect // to the address // - [ErrAlreadyConnected] when there is already an established connection to // the address +// - [ErrMaxNormalConns] when there are already the maximum allowed number of +// normal connections (inbound, outbound, and pending) // - [ErrShutdown] when the connection manager is shutting down // - [context.Canceled] or [context.DeadlineExceeded] depending on the // provided context or when the dialer fails to establish a connection @@ -717,7 +753,21 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect // // This function is safe for concurrent access. func (cm *ConnManager) Connect(ctx context.Context, addr net.Addr) (*Conn, error) { - conn, err := cm.dial(ctx, addr, ConnTypeManual, nil, nil) + acquired, err := cm.totalNormalConnsSem.TryAcquire(ctx) + if err != nil { + if sErr := cm.checkShutdown(); sErr != nil { + return nil, sErr + } + return nil, err + } + if !acquired { + maxAllowed := cm.cfg.MaxNormalConns + str := fmt.Sprintf("a maximum of %d %s is allowed", maxAllowed, + pickNoun(maxAllowed, "connection", "connections")) + return nil, MakeError(ErrMaxNormalConns, str) + } + onClose := cm.totalNormalConnsSem.Release + conn, err := cm.dial(ctx, addr, ConnTypeManual, onClose, nil) if err != nil { return nil, err } @@ -847,6 +897,18 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) defer log.Tracef("Listener handler done for %s", listener.Addr()) for ctx.Err() == nil { + // The following is intentionally implementing active connection + // shedding by accepting connections and then immediately disconnecting + // them after the [net.Listener.Accept] call if any policies are + // violated. + // + // Reversing it and blocking until a permit is available and only then + // calling Accept would cause the connections to build up in the kernel. + // Then, since the kernel will still create the 3-way handshake, clients + // would connect and hang until their own timeouts are hit, and, + // eventually, the entire service could appear entirely down if the SYN + // queue were to fill. It also would not allow implementing better + // additional policies. netConn, err := listener.Accept() if err != nil { // Only log the error if not forcibly shutting down. @@ -881,7 +943,29 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) } cm.connMtx.Unlock() - go func(netConn net.Conn) { + // Require a permit to allow the inbound connection unless the address + // has special permissions (e.g. whitelisted). + // + // Attempt to acquire a permit via a non-blocking call and immediately + // disconnect if unsuccessful so that all blocking happens on + // [net.Listener.Accept] for the reasons described above. + requirePermit := !cm.IsWhitelisted(rAddr) + if requirePermit { + acquired, err := cm.totalNormalConnsSem.TryAcquire(ctx) + if err != nil { + netConn.Close() + continue + } + if !acquired { + maxAllowed := cm.cfg.MaxNormalConns + log.Debugf("Dropped connection from %v: a maximum of %d %s is "+ + "allowed", rAddr, maxAllowed, pickNoun(maxAllowed, + "connection", "connections")) + netConn.Close() + continue + } + } + go func(netConn net.Conn, requirePermit bool) { // Create a new connection instance with the next globally unique // connection ID, add an entry to the map that tracks all active // connections, and invoke the configured accept callback with it. @@ -897,6 +981,9 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) cm.connMtx.Unlock() log.Debugf("Disconnected from %v (id: %d, type: %v)", rAddr, id, connType) + if requirePermit { + cm.totalNormalConnsSem.Release() + } } conn = newConn(cm, netConn, id, connType, rAddr, onClose) cm.connMtx.Lock() @@ -905,7 +992,7 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) log.Debugf("Accepted connection from %v (id: %d, type: %v)", rAddr, id, connType) cm.cfg.OnAccept(conn) - }(netConn) + }(netConn, requirePermit) } } @@ -1147,16 +1234,28 @@ func (cm *ConnManager) targetOutboundHandler(ctx context.Context) { return } + // Wait for a permit to make another overall connection. This limits + // the total number of normal connections while the previous limits the + // total number of automatic outbound connections. + if !cm.totalNormalConnsSem.Acquire(ctx) { + cm.activeOutboundsSem.Release() + return + } + addr, err := cm.cfg.GetNewAddress() if err != nil { failedAttempts.Add(1) log.Debugf("Failed to get address for outbound connection: %v", err) + cm.totalNormalConnsSem.Release() cm.activeOutboundsSem.Release() continue } go func(addr net.Addr) { - onClose := cm.activeOutboundsSem.Release + onClose := func() { + cm.totalNormalConnsSem.Release() + cm.activeOutboundsSem.Release() + } conn, err := cm.dial(ctx, addr, ConnTypeOutbound, onClose, nil) if err != nil { failedAttempts.Add(1) @@ -1252,23 +1351,28 @@ func New(cfg *Config) (*ConnManager, error) { if cfg.Dial == nil { return nil, MakeError(ErrDialNil, "dial cannot be nil") } - // Default to sane values + // Default to sane values. if cfg.RetryDuration <= 0 { cfg.RetryDuration = defaultRetryDuration } + if cfg.MaxNormalConns == 0 { + cfg.MaxNormalConns = defaultMaxNormalConns + } if cfg.TargetOutbound == 0 { cfg.TargetOutbound = defaultTargetOutbound } + cfg.TargetOutbound = min(cfg.TargetOutbound, cfg.MaxNormalConns) cm := ConnManager{ - cfg: *cfg, // Copy so caller can't mutate - quit: make(chan struct{}), - maxRetryDuration: defaultMaxRetryDuration, - runPersistentChan: make(chan *persistentEntry, MaxPersistent), - activeOutboundsSem: makeSemaphore(cfg.TargetOutbound), - persistent: make(map[uint64]*persistentEntry, MaxPersistent), - pending: make(map[uint64]*pendingConnInfo), - active: make(map[uint64]*Conn, cfg.TargetOutbound), - connIDByAddr: make(map[string]uint64), + cfg: *cfg, // Copy so caller can't mutate + quit: make(chan struct{}), + maxRetryDuration: defaultMaxRetryDuration, + runPersistentChan: make(chan *persistentEntry, MaxPersistent), + totalNormalConnsSem: makeSemaphore(cfg.MaxNormalConns), + activeOutboundsSem: makeSemaphore(cfg.TargetOutbound), + persistent: make(map[uint64]*persistentEntry, MaxPersistent), + pending: make(map[uint64]*pendingConnInfo), + active: make(map[uint64]*Conn, cfg.TargetOutbound), + connIDByAddr: make(map[string]uint64), } return &cm, nil } diff --git a/internal/connmgr/error.go b/internal/connmgr/error.go index 9c87dad6e..6b313d89f 100644 --- a/internal/connmgr/error.go +++ b/internal/connmgr/error.go @@ -22,6 +22,10 @@ const ( // already has an established connection. ErrAlreadyConnected = ErrorKind("ErrAlreadyConnected") + // ErrMaxNormalConns indicates a connection attempt (inbound or outbound) + // would exceed the maximum allowed number of normal connections. + ErrMaxNormalConns = ErrorKind("ErrMaxNormalConns") + // ErrMaxPersistent indicates an attempt to add more than the maximum // allowed number of persistent connections. ErrMaxPersistent = ErrorKind("ErrMaxPersistent") diff --git a/internal/connmgr/error_test.go b/internal/connmgr/error_test.go index d4e5d2262..b3881bf7c 100644 --- a/internal/connmgr/error_test.go +++ b/internal/connmgr/error_test.go @@ -19,6 +19,7 @@ func TestErrorKindStringer(t *testing.T) { {ErrDialNil, "ErrDialNil"}, {ErrAlreadyPending, "ErrAlreadyPending"}, {ErrAlreadyConnected, "ErrAlreadyConnected"}, + {ErrMaxNormalConns, "ErrMaxNormalConns"}, {ErrMaxPersistent, "ErrMaxPersistent"}, {ErrDuplicatePersistent, "ErrDuplicatePersistent"}, {ErrNotFound, "ErrNotFound"}, diff --git a/internal/connmgr/log.go b/internal/connmgr/log.go index 4bf44f579..f6ba6f5f4 100644 --- a/internal/connmgr/log.go +++ b/internal/connmgr/log.go @@ -19,3 +19,12 @@ var log = slog.Disabled func UseLogger(logger slog.Logger) { log = logger } + +// pickNoun returns the singular or plural form of a noun depending on the count +// n. +func pickNoun[T ~uint32 | ~uint64](n T, singular, plural string) string { + if n == 1 { + return singular + } + return plural +} diff --git a/rpcadaptors.go b/rpcadaptors.go index a6826e047..b88fd5bee 100644 --- a/rpcadaptors.go +++ b/rpcadaptors.go @@ -130,14 +130,6 @@ func (cm *rpcConnManager) Connect(ctx context.Context, addr string, permanent bo return err } - // Limit max number of total peers. - cm.server.peerState.Lock() - count := cm.server.peerState.count() - cm.server.peerState.Unlock() - if count >= cfg.MaxPeers { - return errors.New("max peers reached") - } - // Attempt to add a persistent peer when requested. connManager := cm.server.connManager if permanent { diff --git a/server.go b/server.go index 3a7b1b671..5d019fdb4 100644 --- a/server.go +++ b/server.go @@ -2703,17 +2703,6 @@ func (s *server) handleAddPeer(sp *serverPeer) bool { return false } - // Limit max number of total peers. However, allow whitelisted inbound - // peers regardless. - if state.count()+1 > cfg.MaxPeers && !isInboundWhitelisted { - srvrLog.Infof("Max peers reached [%d] - disconnecting peer %s", - cfg.MaxPeers, sp) - sp.Disconnect() - // TODO: how to handle permanent peers here? - // they should be rescheduled. - return false - } - // Add the new peer. if sp.Inbound() { state.inboundPeers[sp.ID()] = sp @@ -4359,6 +4348,7 @@ func newServer(ctx context.Context, profiler *profileServer, s.inboundPeerConnected(ctx, conn) }, RetryDuration: connectionRetryInterval, + MaxNormalConns: uint32(cfg.MaxPeers), TargetOutbound: s.targetOutbound, Dial: s.attemptDcrdDial, DialTimeout: cfg.DialTimeout, From f14e676cb425b159b1454095cfd9af612147ab71 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 19 May 2026 14:11:50 -0500 Subject: [PATCH 20/24] connmgr: Add total max normal conns tests. This adds tests to ensure that the new max normal connection limiting properly enforces the limit including automatic outbound, manual outbound, and inbound connections. It also ensures that it not applied to persistent connections. --- internal/connmgr/connmanager_test.go | 154 +++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 6ce82f066..3003b9869 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -1711,3 +1711,157 @@ func TestRejectDuplicateConns(t *testing.T) { wg.Wait() assertConnManagerCleanShutdown(t, cmgr) } + +// TestMaxNormalConns ensures the connection manager limits the total number of +// normal connections to [Config.MaxNormalConns] including automatic outbound, +// manual outbound, and inbound connections. It also ensures that it is not +// applied to persistent connections. +func TestMaxNormalConns(t *testing.T) { + t.Parallel() + + // nextAddr is a convenience func to return a new unique address with every + // invocation. + var numAddrs atomic.Uint32 + nextAddr := func() net.Addr { + addrStr := fmt.Sprintf("10.0.0.%d:18555", numAddrs.Add(1)) + return mustParseAddrPort(addrStr) + } + + // Constants for the number of various normal connection types to test + // overall max normal connection limits. + const ( + targetOutbound = 3 + targetManual = 4 + targetInbound = 5 + maxNormalConns = targetOutbound + targetManual + targetInbound + ) + connected := make(chan *Conn) + disconnected := make(chan *Conn) + inboundConns := make(chan *Conn) + listener := newMockListener("127.0.0.1:9108") + var pauseTargetOutbound atomic.Bool + var totalPausedAddrs atomic.Uint32 + hitMaxFailedAttempts := make(chan struct{}) + cmgr := newTestConnManager(t, &Config{ + Listeners: []net.Listener{listener}, + MaxNormalConns: maxNormalConns, + TargetOutbound: targetOutbound, + RetryDuration: 50 * time.Millisecond, + Dial: mockDialer, + OnAccept: func(conn *Conn) { + inboundConns <- conn + }, + GetNewAddress: func() (net.Addr, error) { + if pauseTargetOutbound.Load() { + total := totalPausedAddrs.Add(1) + if total == maxFailedAttempts { + hitMaxFailedAttempts <- struct{}{} + } + return nil, errors.New("network down") + } + return nextAddr(), nil + }, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + cmgr.maxRetryDuration = cmgr.cfg.RetryDuration + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Wait for the expected number of target outbound conns to be established. + outbounds := make([]*Conn, 0, targetOutbound) + for len(outbounds) < targetOutbound { + conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) + outbounds = append(outbounds, conn) + } + assertConnManagerInternalState(t, cmgr) + + // Establish target number of inbounds to the listener and wait for them to + // be established. + go func() { + for range targetInbound { + listener.Connect(nextAddr()) + } + }() + inbounds := make([]*Conn, 0, targetInbound) + for len(inbounds) < targetInbound { + conn := assertConnReceived(t, inboundConns, 0, ConnTypeInbound) + inbounds = append(inbounds, conn) + } + assertConnManagerInternalState(t, cmgr) + + // Establish target number of manual connections and wait for them to be + // established. + go func() { + for range targetManual { + go cmgr.Connect(ctx, nextAddr()) + } + }() + manualConns := make([]*Conn, 0, targetManual+1) + for len(manualConns) < targetManual { + conn := assertConnReceived(t, connected, 0, ConnTypeManual) + manualConns = append(manualConns, conn) + } + assertConnManagerInternalState(t, cmgr) + + // Ensure manual connections that would exceed the max allowed normal + // connections are rejected. + _, err := cmgr.Connect(ctx, nextAddr()) + if !errors.Is(err, ErrMaxNormalConns) { + t.Fatalf("did not reject manual connection at max allowed, err: %v", err) + } + assertConnManagerInternalState(t, cmgr) + + // Ensure inbound connections that would exceed the max allowed normal + // connections are rejected. + go listener.Connect(nextAddr()) + assertNoConnReceived(t, inboundConns) + assertConnManagerInternalState(t, cmgr) + + // Pause the target outbound dials and remove one of the target outbound + // connections to make room for another manual connection. Then wait for + // the max failures to be hit so attempts are paused for a retry timeout. + pauseTargetOutbound.Store(true) + outboundConn := outbounds[0] + outboundConn.Close() + assertConnReceived(t, disconnected, outboundConn.ID(), ConnTypeOutbound) + select { + case <-hitMaxFailedAttempts: + time.Sleep(connTestReceiveTimeout) + case <-time.After(maxFailedAttempts * connTestReceiveTimeout): + t.Fatal("did not reach max failed attempts before timeout") + } + assertConnManagerInternalState(t, cmgr) + + // Establish another manual connection to take the place of the target + // outbound connection that was just closed and wait for it to be + // established. + go cmgr.Connect(ctx, nextAddr()) + assertConnReceived(t, connected, 0, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) + + // Unpause the target outbound dials and ensure no additional automatic + // outbound connections are made despite being under the target outbound due + // to max total conns. + pauseTargetOutbound.Store(false) + assertNoConnReceivedTimeout(t, connected, connTestNonReceiveTimeout+ + cmgr.cfg.RetryDuration) + assertConnManagerInternalState(t, cmgr) + + // Ensure persistent connections are not subject to the max total normal + // connections by adding one and waiting for it to be established. + connID, err := cmgr.AddPersistent(nextAddr()) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) +} From 0fdcbecddc3d14a582022e6ded7d377dc5ddf6ba Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Thu, 21 May 2026 17:14:30 -0500 Subject: [PATCH 21/24] connmgr: Update README.md for total conn limits. --- internal/connmgr/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/connmgr/README.md b/internal/connmgr/README.md index cf43ac92a..850f2620e 100644 --- a/internal/connmgr/README.md +++ b/internal/connmgr/README.md @@ -31,6 +31,8 @@ The following is a brief overview of the key features: with exponential backoff on disconnect - Manual connections - Supports manual connection establishment via `Connect` +- Connection limits + - Limits total normal (non-persistent) connections to `MaxNormalConns` - Duplicate address prevention - Rejects duplicate connections to and from the same address (host:port) - Whitelist support From e41935768f64b71ebda8b51501375b08e9fccb34 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Thu, 21 May 2026 16:44:17 -0500 Subject: [PATCH 22/24] connmgr: Limit max connections per host. Similar to the recent total normal connection limiting, the current per-host connection limits are enforced by the server. For similar reasons, it is much more robust and natural to perform the limiting early in the connection manager. With that in mind, this implements the per-host connection limiting in the connection manager and removes the relevant current limiting for it from the server. The limiting is applied to inbound, outbound, and persistent connections. The new limiting is handled early in both the inbound and outbound paths now which allows it to take advantage of fast connection shedding for inbound connections and to preemptively prevent all outbound attempts that would exceed the limit regardless of source. It also provides the flexibility to apply independent special permissions in the future. This also slightly changes the semantics to exempt whitelisted addresses for both inbound and outbound connections as opposed to only inbound connections. --- internal/connmgr/connmanager.go | 123 ++++++++++++++++++++++++++++++-- internal/connmgr/error.go | 4 ++ server.go | 38 ++-------- 3 files changed, 127 insertions(+), 38 deletions(-) diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 08722cdd4..be2eaeda7 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -48,6 +48,11 @@ const ( // outbound, and pending connections to permit. defaultMaxNormalConns = 125 + // defaultMaxConnsPerHost is the default number of maximum connections with + // the same host to permit. It does not apply to whitelisted or loopback + // addresses. + defaultMaxConnsPerHost = 5 + // defaultTargetOutbound is the default number of outbound connections to // maintain. defaultTargetOutbound = 8 @@ -193,9 +198,10 @@ func (c *Conn) Type() ConnectionType { // pendingConnInfo houses information about a pending connection attempt. type pendingConnInfo struct { - id uint64 - addr *addrmgr.NetAddress - cancel context.CancelFunc + id uint64 + addr *addrmgr.NetAddress + hostKey string + cancel context.CancelFunc } // persistentEntry houses information about a persistent connection that has @@ -248,6 +254,21 @@ type Config struct { // this value. MaxNormalConns uint32 + // MaxConnsPerHost is the maximum number of connections with the same host + // to permit. Defaults to 5. + // + // This applies to inbound, outbound, and persistent connections. However, + // in practice, it is highly unlikely that outbound connections will hit the + // default limit (unless intentionally connecting manually) because: + // + // - connections to the same host:port are rejected and it is extremely rare + // for the same host to serve multiple instances on different ports + // - all automatic outbound connections are heavily biased toward different + // network groups + // + // This limit is not applied to whitelisted or loopback connections. + MaxConnsPerHost uint32 + // TargetOutbound is the number of outbound network connections to maintain // automatically. Defaults to 8. // @@ -344,6 +365,11 @@ type ConnManager struct { // (host:port). It is kept in sync with the persistent, pending, and active // maps and is primarily used to efficiently reject duplicate connections. connIDByAddr map[string]uint64 + + // perHostCounts provides fast O(1) lookup of the number of entries per + // host. It is kept in sync with the persistent, pending, and active maps + // and is primarily used to efficiently enforce per-host connection limits. + perHostCounts map[string]uint32 } // IsWhitelisted returns whether the IP address is included in the whitelisted @@ -403,6 +429,32 @@ func stdlibNetAddrToAddrMgrNetAddr(addr net.Addr) (*addrmgr.NetAddress, error) { return netAddr, nil } +// addrHostKey returns the host portion of the passed address as a string +// suitable for use as a map key. +func addrHostKey(addr net.Addr) string { + if na, ok := addr.(*addrmgr.NetAddress); ok { + return net.IP(na.IP).String() + } + + addrStr := addr.String() + host, _, err := net.SplitHostPort(addrStr) + if err == nil { + return host + } + return addrStr +} + +// decrementPerHostCount decrements the reference count for the provided host +// and cleans up the associated entry when there are no more references. +// +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) decrementPerHostCount(hostKey string) { + cm.perHostCounts[hostKey]-- + if cm.perHostCounts[hostKey] == 0 { + delete(cm.perHostCounts, hostKey) + } +} + // addPendingInfo adds information about a pending connection attempt to the // local state. // @@ -411,6 +463,7 @@ func (cm *ConnManager) addPendingInfo(info *pendingConnInfo) { cm.pending[info.id] = info if _, ok := cm.persistent[info.id]; !ok { cm.connIDByAddr[info.addr.String()] = info.id + cm.perHostCounts[info.hostKey]++ } } @@ -421,6 +474,7 @@ func (cm *ConnManager) removePendingInfo(info *pendingConnInfo) { delete(cm.pending, info.id) if _, ok := cm.persistent[info.id]; !ok { delete(cm.connIDByAddr, info.addr.String()) + cm.decrementPerHostCount(info.hostKey) } } @@ -431,6 +485,7 @@ func (cm *ConnManager) addActiveConn(conn *Conn) { cm.active[conn.id] = conn if _, ok := cm.persistent[conn.id]; !ok { cm.connIDByAddr[conn.remoteAddr.String()] = conn.id + cm.perHostCounts[addrHostKey(&conn.remoteAddr)]++ } } @@ -448,6 +503,7 @@ func (cm *ConnManager) removeActiveConn(conn *Conn) { delete(cm.active, conn.id) if _, ok := cm.persistent[conn.id]; !ok { delete(cm.connIDByAddr, conn.remoteAddr.String()) + cm.decrementPerHostCount(addrHostKey(&conn.remoteAddr)) } } @@ -457,6 +513,7 @@ func (cm *ConnManager) removeActiveConn(conn *Conn) { func (cm *ConnManager) addPersistentEntry(entry *persistentEntry) { cm.persistent[entry.id] = entry cm.connIDByAddr[entry.addr.String()] = entry.id + cm.perHostCounts[addrHostKey(entry.addr)]++ } // removePersistentEntry removes a persistent connection entry from the local @@ -469,6 +526,7 @@ func (cm *ConnManager) removePersistentEntry(entry *persistentEntry) { _, active := cm.active[entry.id] if !pending && !active { delete(cm.connIDByAddr, entry.addr.String()) + cm.decrementPerHostCount(addrHostKey(entry.addr)) } } @@ -541,6 +599,28 @@ func (cm *ConnManager) rejectDuplicateAddr(addr *addrmgr.NetAddress) error { return nil } +// rejectMaxConnsPerHost returns an error if adding an additional connection +// with the provided host address would exceed [Config.MaxConnsPerHost] and is +// not exempt. +// +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) rejectMaxConnsPerHost(addr *addrmgr.NetAddress, hostKey string, isWhitelisted bool) error { + // Whitelisted and loopback addresses are exempt. + isLoopback := net.IP(addr.IP).IsLoopback() + if isWhitelisted || isLoopback { + return nil + } + + maxAllowed := cm.cfg.MaxConnsPerHost + if numConns := cm.perHostCounts[hostKey]; numConns+1 > maxAllowed { + str := fmt.Sprintf("a maximum of %d %s per host is allowed", maxAllowed, + pickNoun(maxAllowed, "connection", "connections")) + return MakeError(ErrMaxConnsPerHost, str) + } + + return nil +} + // dial attempts to connect to the provided address and returns a connection // configured with the provided params on success. // @@ -556,6 +636,7 @@ func (cm *ConnManager) rejectDuplicateAddr(addr *addrmgr.NetAddress) error { // The following connection limits are enforced: // // - Total normal connections ([Config.MaxNormalConns]) +// - Total connections with the same host ([Config.MaxConnsPerHost]) // // On success, the returned connection is configured to remove itself from the // set of all active connections and invoke the provided on close callback (if @@ -575,6 +656,8 @@ func (cm *ConnManager) rejectDuplicateAddr(addr *addrmgr.NetAddress) error { // the address // - [ErrMaxNormalConns] when there are already the maximum allowed number of // normal connections (inbound, outbound, and pending) +// - [ErrMaxConnsPerHost] when there are already the maximum allowed number of +// connections (pending, active, and persistent) with the same host // - [ErrShutdown] when the connection manager is shutting down // - [context.Canceled] or [context.DeadlineExceeded] depending on the // provided context or when the dialer fails to establish a connection @@ -602,6 +685,8 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect if err != nil { return nil, err } + rAddrHostKey := addrHostKey(rAddr) + isWhitelisted := cm.IsWhitelisted(rAddr) // Reject attempts to dial addresses that are already connected (or in the // process of it). Additionally, reject attempts to dial existing @@ -621,6 +706,14 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect return nil, err } + // Limit the max number of connections per host. + err = cm.rejectMaxConnsPerHost(rAddr, rAddrHostKey, isWhitelisted) + if err != nil { + cm.connMtx.Unlock() + log.Debugf("Rejected connection to %v: %v", rAddr, err) + return nil, err + } + // Apply a dial timeout if requested. Otherwise, use a regular cancel // context to support canceling the pending connection later. var cancel context.CancelFunc @@ -639,7 +732,7 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect } else { connID = cm.nextConnID.Add(1) } - info := &pendingConnInfo{connID, rAddr, cancel} + info := &pendingConnInfo{connID, rAddr, rAddrHostKey, cancel} cm.addPendingInfo(info) cm.connMtx.Unlock() defer func() { @@ -731,6 +824,7 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect // limits are enforced: // // - Total normal connections ([Config.MaxNormalConns]) +// - Total connections with the same host ([Config.MaxConnsPerHost]) // // Note that the context parameter to this function and the lifecycle context // may be independent. @@ -746,6 +840,8 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect // the address // - [ErrMaxNormalConns] when there are already the maximum allowed number of // normal connections (inbound, outbound, and pending) +// - [ErrMaxConnsPerHost] when there are already the maximum allowed number of +// connections (pending, active, and persistent) with the same host // - [ErrShutdown] when the connection manager is shutting down // - [context.Canceled] or [context.DeadlineExceeded] depending on the // provided context or when the dialer fails to establish a connection @@ -925,6 +1021,8 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) netConn.Close() continue } + rAddrHostKey := addrHostKey(rAddr) + isWhitelisted := cm.IsWhitelisted(rAddr) // Reject connections with the same host:port as any existing pending, // established, or persistent connections. Note that this does NOT @@ -933,7 +1031,7 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) // // The aforementioned behavior is intentional as it allows connections // from the same host to be independently limited to more than one - // elsewhere. + // below. cm.connMtx.Lock() if err := cm.rejectDuplicateAddr(rAddr); err != nil { cm.connMtx.Unlock() @@ -941,6 +1039,15 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) netConn.Close() continue } + + // Limit the max number of connections per host. + err = cm.rejectMaxConnsPerHost(rAddr, rAddrHostKey, isWhitelisted) + if err != nil { + cm.connMtx.Unlock() + log.Debugf("Dropped connection from %v: %v", rAddr, err) + netConn.Close() + continue + } cm.connMtx.Unlock() // Require a permit to allow the inbound connection unless the address @@ -949,7 +1056,7 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) // Attempt to acquire a permit via a non-blocking call and immediately // disconnect if unsuccessful so that all blocking happens on // [net.Listener.Accept] for the reasons described above. - requirePermit := !cm.IsWhitelisted(rAddr) + requirePermit := !isWhitelisted if requirePermit { acquired, err := cm.totalNormalConnsSem.TryAcquire(ctx) if err != nil { @@ -1358,6 +1465,9 @@ func New(cfg *Config) (*ConnManager, error) { if cfg.MaxNormalConns == 0 { cfg.MaxNormalConns = defaultMaxNormalConns } + if cfg.MaxConnsPerHost == 0 { + cfg.MaxConnsPerHost = defaultMaxConnsPerHost + } if cfg.TargetOutbound == 0 { cfg.TargetOutbound = defaultTargetOutbound } @@ -1373,6 +1483,7 @@ func New(cfg *Config) (*ConnManager, error) { pending: make(map[uint64]*pendingConnInfo), active: make(map[uint64]*Conn, cfg.TargetOutbound), connIDByAddr: make(map[string]uint64), + perHostCounts: make(map[string]uint32), } return &cm, nil } diff --git a/internal/connmgr/error.go b/internal/connmgr/error.go index 6b313d89f..966475894 100644 --- a/internal/connmgr/error.go +++ b/internal/connmgr/error.go @@ -26,6 +26,10 @@ const ( // would exceed the maximum allowed number of normal connections. ErrMaxNormalConns = ErrorKind("ErrMaxNormalConns") + // ErrMaxConnsPerHost indicates a connection attempt (inbound or outbound) + // would exceed the maximum allowed number of connections per host. + ErrMaxConnsPerHost = ErrorKind("ErrMaxConnsPerHost") + // ErrMaxPersistent indicates an attempt to add more than the maximum // allowed number of persistent connections. ErrMaxPersistent = ErrorKind("ErrMaxPersistent") diff --git a/server.go b/server.go index 5d019fdb4..327ed87ba 100644 --- a/server.go +++ b/server.go @@ -2618,20 +2618,6 @@ func (s *server) considerReportedAddr(from *serverPeer, addr *wire.NetAddress) { s.considerReportedAddrOutbound(from, addr) } -// connectionsWithIP returns the number of connections with the given IP. -// -// This function MUST be called with the embedded mutex locked (for reads). -func (ps *peerState) connectionsWithIP(ip net.IP) int { - var total int - ps.forAllPeers(func(sp *serverPeer) { - if ip.Equal(sp.remoteAddr.IP) { - total++ - } - - }) - return total -} - // handleAddPeer deals with adding new peers and includes logic such as // categorizing the type of peer, limiting the maximum allowed number of peers, // and local external address resolution. @@ -2690,19 +2676,6 @@ func (s *server) handleAddPeer(sp *serverPeer) bool { defer state.Unlock() state.Lock() - // Limit max number of connections from a single IP. However, allow - // whitelisted inbound peers and localhost connections regardless. - isInboundWhitelisted := sp.isWhitelisted && sp.Inbound() - peerIP := net.IP(sp.remoteAddr.IP) - if cfg.MaxSameIP > 0 && !isInboundWhitelisted && !peerIP.IsLoopback() && - state.connectionsWithIP(peerIP)+1 > cfg.MaxSameIP { - - srvrLog.Infof("Max connections with %s reached [%d] - disconnecting "+ - "peer", sp, cfg.MaxSameIP) - sp.Disconnect() - return false - } - // Add the new peer. if sp.Inbound() { state.inboundPeers[sp.ID()] = sp @@ -4347,11 +4320,12 @@ func newServer(ctx context.Context, profiler *profileServer, OnAccept: func(conn *connmgr.Conn) { s.inboundPeerConnected(ctx, conn) }, - RetryDuration: connectionRetryInterval, - MaxNormalConns: uint32(cfg.MaxPeers), - TargetOutbound: s.targetOutbound, - Dial: s.attemptDcrdDial, - DialTimeout: cfg.DialTimeout, + RetryDuration: connectionRetryInterval, + MaxNormalConns: uint32(cfg.MaxPeers), + MaxConnsPerHost: uint32(cfg.MaxSameIP), + TargetOutbound: s.targetOutbound, + Dial: s.attemptDcrdDial, + DialTimeout: cfg.DialTimeout, OnConnection: func(conn *connmgr.Conn) { s.outboundPeerConnected(ctx, conn) }, From 4f627e96a4c5dc1c42e5f94fab4f9d37241e047e Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Thu, 21 May 2026 16:44:17 -0500 Subject: [PATCH 23/24] connmgr: Add max per-host conn tests. This adds tests to ensure that the new max connections per host limiting properly enforces the limit including automatic outbound, manual outbound, inbound, and persistent connections. It also tests whitelisted addresses are exempt. --- internal/connmgr/connmanager_test.go | 175 ++++++++++++++++++++++++++- 1 file changed, 173 insertions(+), 2 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 3003b9869..1313c9233 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -144,14 +144,18 @@ func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { // Assert the pending and active maps are mutually exclusive for both conn // IDs and addrs. // - // Also build a map of addrs to conn IDs in the pending, active, and - // persistent maps for the checks below. + // Also build a map of addrs to conn IDs and tally the per host counts in + // the pending, active, and persistent maps for the checks below. connIDByAddr := make(map[string]uint64) + perHostCounts := make(map[string]uint32) for id, info := range cm.pending { if _, ok := cm.active[id]; ok { t.Fatalf("conn ID %d is both pending and active", id) } connIDByAddr[info.addr.String()] = id + if _, ok := cm.persistent[id]; !ok { + perHostCounts[addrHostKey(info.addr)]++ + } } for id, conn := range cm.active { if _, ok := cm.pending[id]; ok { @@ -162,6 +166,9 @@ func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { t.Fatalf("addr %s is both pending and active", addrStr) } connIDByAddr[addrStr] = id + if _, ok := cm.persistent[id]; !ok { + perHostCounts[addrHostKey(&conn.remoteAddr)]++ + } } for id, entry := range cm.persistent { // Assert the conn ID of established/pending persistent conns matches. @@ -170,6 +177,7 @@ func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { t.Fatalf("conn ID for addr %s mismatch: %d != %d", addrStr, existingID, id) } + perHostCounts[addrHostKey(entry.addr)]++ connIDByAddr[addrStr] = id } @@ -179,6 +187,13 @@ func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { t.Fatalf("mismatched conn ID by addr maps\ngot: %v\nwant %v", cm.connIDByAddr, connIDByAddr) } + + // Assert the per host counts match the values obtained from manually + // tallying them. + if !reflect.DeepEqual(cm.perHostCounts, perHostCounts) { + t.Fatalf("mismatched per host count maps\ngot: %v\nwant %v", + cm.perHostCounts, perHostCounts) + } } // assertConnManagerCleanShutdown ensures the internal state of the passed @@ -203,6 +218,10 @@ func assertConnManagerCleanShutdown(t *testing.T, cm *ConnManager) { t.Fatalf("conn ID by addr map not empty: %d entries", len(cm.connIDByAddr)) } + if len(cm.perHostCounts) != 0 { + t.Fatalf("per host counts map not empty: %d entries", + len(cm.perHostCounts)) + } } // TestNewConfig tests that new ConnManager config is validated as expected. @@ -1865,3 +1884,155 @@ func TestMaxNormalConns(t *testing.T) { wg.Wait() assertConnManagerCleanShutdown(t, cmgr) } + +// TestMaxConnsPerHost ensures the connection manager limits the total number of +// connections with the same host to [Config.MaxConnsPerHost] including +// automatic outbound, manual outbound, inbound, and persistent connections. It +// also tests whitelisted addresses are exempt. +func TestMaxConnsPerHost(t *testing.T) { + t.Parallel() + + // nextSameHost is a convenience func to return a new address to the same IP + // with a different port on every invocation. + var nextPort atomic.Uint32 + nextSameHost := func() net.Addr { + addrStr := fmt.Sprintf("10.10.0.1:%d", nextPort.Add(1)+1024) + return mustParseAddrPort(addrStr) + } + + // nextSameHostWhitelisted is a convenience func to return a new address to + // the same whitelisted IP with a different port on every invocation. + allowedIP := netip.MustParseAddr("10.20.0.1") + nextSameWhitelistedHost := func() net.Addr { + addrStr := fmt.Sprintf("%s:%d", allowedIP, nextPort.Add(1)+1024) + return mustParseAddrPort(addrStr) + } + + const maxConnsPerHost = 3 + connected := make(chan *Conn, 1) + disconnected := make(chan *Conn, 1) + inboundConns := make(chan *Conn) + listener := newMockListener("127.0.0.1:9108") + var pauseTargetOutbound atomic.Bool + var totalPausedAddrs atomic.Uint32 + hitMaxFailedAttempts := make(chan struct{}) + cmgr := newTestConnManager(t, &Config{ + Listeners: []net.Listener{listener}, + MaxNormalConns: 30, // High enough to not interfere with per-host tests. + MaxConnsPerHost: maxConnsPerHost, + TargetOutbound: maxConnsPerHost, + RetryDuration: 50 * time.Millisecond, + Dial: mockDialer, + Whitelists: []netip.Prefix{netip.PrefixFrom(allowedIP, 32)}, + OnAccept: func(conn *Conn) { + inboundConns <- conn + }, + GetNewAddress: func() (net.Addr, error) { + if pauseTargetOutbound.Load() { + total := totalPausedAddrs.Add(1) + if total == maxFailedAttempts { + close(hitMaxFailedAttempts) + } + return nil, errors.New("network down") + } + return nextSameHost(), nil + }, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + cmgr.maxRetryDuration = cmgr.cfg.RetryDuration + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Wait for the maximum allowed non-whitelisted per-host automatic outbound + // conns. + outboundConns := make([]*Conn, 0, maxConnsPerHost) + for len(outboundConns) < maxConnsPerHost { + conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) + outboundConns = append(outboundConns, conn) + } + assertConnManagerInternalState(t, cmgr) + + // Ensure non-whitelisted manual connections that would exceed the max + // allowed per-host connections are rejected. + _, err := cmgr.Connect(ctx, nextSameHost()) + if !errors.Is(err, ErrMaxConnsPerHost) { + t.Fatalf("did not reject manual connection at per-host limit, err: %v", + err) + } + assertConnManagerInternalState(t, cmgr) + + // Ensure non-whitelisted inbound connections that would exceed the max + // allowed per-host connections are rejected. + go listener.Connect(nextSameHost()) + assertNoConnReceived(t, inboundConns) + assertConnManagerInternalState(t, cmgr) + + // Ensure whitelisted manual connections are allowed to exceed the per-host + // limit. + for range maxConnsPerHost + 1 { + go cmgr.Connect(ctx, nextSameWhitelistedHost()) + assertConnReceived(t, connected, 0, ConnTypeManual) + } + + // Ensure whitelisted inbound connections are allowed to exceed the per-host + // limit. + go listener.Connect(nextSameWhitelistedHost()) + assertConnReceived(t, inboundConns, 0, ConnTypeInbound) + assertConnManagerInternalState(t, cmgr) + + // Ensure whitelisted persistent connections are allowed to exceed the + // per-host limit. + connID, err := cmgr.AddPersistent(nextSameWhitelistedHost()) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) + + // Pause the target outbound dials and remove one of the target outbound + // connections to make room for another manual connection with the same + // host. Then wait for the max failures to be hit so attempts are paused + // for a retry timeout. + pauseTargetOutbound.Store(true) + outboundConn := outboundConns[0] + outboundConn.Close() + assertConnReceived(t, disconnected, outboundConn.ID(), ConnTypeOutbound) + select { + case <-hitMaxFailedAttempts: + time.Sleep(connTestReceiveTimeout) + case <-time.After(maxFailedAttempts * connTestReceiveTimeout): + t.Fatal("did not reach max failed attempts before timeout") + } + + // Ensure a new non-whitelisted manual connection to the same host now + // succeeds. + go cmgr.Connect(ctx, nextSameHost()) + assertConnReceived(t, connected, 0, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) + + // Unpause the target outbound dials and ensure no additional automatic + // outbound connections to the same host are made despite being under the + // target outbound. + noConnWaitTimeout := connTestReceiveTimeout + cmgr.cfg.RetryDuration + pauseTargetOutbound.Store(false) + assertNoConnReceivedTimeout(t, connected, noConnWaitTimeout) + assertConnManagerInternalState(t, cmgr) + + // Ensure persistent connections are also subject to the max per-host + // connections by adding one and confirming it is NOT established. + _, err = cmgr.AddPersistent(nextSameHost()) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + assertNoConnReceivedTimeout(t, connected, noConnWaitTimeout) + assertConnManagerInternalState(t, cmgr) + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) +} From 92268d2bcbde18985a447d9dafde95fa4bf3f6ba Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Thu, 21 May 2026 17:20:32 -0500 Subject: [PATCH 24/24] connmgr: Update README.md for per-host conn limits. --- internal/connmgr/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/connmgr/README.md b/internal/connmgr/README.md index 850f2620e..30ebaef01 100644 --- a/internal/connmgr/README.md +++ b/internal/connmgr/README.md @@ -33,6 +33,8 @@ The following is a brief overview of the key features: - Supports manual connection establishment via `Connect` - Connection limits - Limits total normal (non-persistent) connections to `MaxNormalConns` + - Limits per-host connections to `MaxConnsPerHost` with exemptions for + whitelisted and loopback addresses - Duplicate address prevention - Rejects duplicate connections to and from the same address (host:port) - Whitelist support