diff --git a/addrmgr/netaddress.go b/addrmgr/netaddress.go index 4380bb147..1b8a4a699 100644 --- a/addrmgr/netaddress.go +++ b/addrmgr/netaddress.go @@ -75,9 +75,18 @@ func (netAddr *NetAddress) Key() string { return net.JoinHostPort(netAddr.ipString(), portString) } +// Network returns the name of the network. It is always tcp. +// +// This is part of the [net.Addr] implementation. +func (netAddr *NetAddress) Network() string { + return "tcp" +} + // String returns a human-readable string for the network address. This is // equivalent to calling Key, but is provided so the type can be used as a // fmt.Stringer. +// +// This is part of the [net.Addr] implementation. func (netAddr *NetAddress) String() string { return netAddr.Key() } diff --git a/internal/connmgr/README.md b/internal/connmgr/README.md index 46114366a..3a83f035c 100644 --- a/internal/connmgr/README.md +++ b/internal/connmgr/README.md @@ -5,26 +5,48 @@ connmgr [![ISC License](https://img.shields.io/badge/license-ISC-blue.svg)](http://copyfree.org) [![Doc](https://img.shields.io/badge/doc-reference-blue.svg)](https://pkg.go.dev/github.com/decred/dcrd/internal/connmgr) -Package connmgr implements a generic Decred network connection manager. - ## Overview -This package handles all the general connection concerns such as maintaining a -set number of outbound connections, sourcing peers, banning, limiting max -connections, tor lookup, etc. - -The package provides a generic connection manager which is able to accept -connection requests from a source or a set of given addresses, dial them and -notify the caller on connections. The main intended use is to initialize a pool -of active connections and maintain them to remain connected to the P2P network. 
- -In addition the connection manager provides the following utilities: - -- Notifications on connections or disconnections -- Handle failures and retry new addresses from the source -- Connect only to specified addresses -- Permanent connections with increasing backoff retry timers -- Disconnect or Remove an established connection +Package `connmgr` provides a flexible and robust context-aware connection +manager for inbound, outbound, and persistent network connections with retry +logic. + +In short, it handles all general connection concerns such as accepting inbound +connections, automatically maintaining a set number of outbound connections, +maintaining persistent connections, and limiting max connections. + +The design has a strong emphasis on reliability, readability, and efficiency +while also aiming to provide an ergonomic API. + +The following is a brief overview of the key features: + +- Full context support +- Inbound listening + - Accepts inbound connections on provided `Listeners` + - Uses connection shedding for rejected inbound connections +- Automatic outbound maintenance + - Maintains up to `TargetOutbound` normal outbound connections via a provided + address source (`GetNewAddress`) +- Persistent connections + - Maintains up to `MaxPersistent` addresses that are automatically retried + with exponential backoff on disconnect +- Manual connections + - Supports manual connection establishment via `Connect` +- Custom connection wrapping via `Conn` + - Connection types for differentiated handling + - Automatic cleanup on connection close + - Concrete parsed address access +- Manual disconnection and removal + - Ability to disconnect / remove established, pending, and persistent + connections via `Disconnect` and `Remove` +- Duplicate address prevention + - Rejects duplicate connections to and from the same address (host:port) +- Notification callbacks + - Provides callbacks for connection establishment and disconnects +- Graceful network outage 
handling + - Automatic connection attempts are throttled during network outages + +A full suite of tests is provided to help ensure proper functionality. ## License diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 60e1afcc1..ed966c2a8 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -3,24 +3,29 @@ // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. +// Package connmgr provides a robust connection manager for inbound, outbound, +// and persistent network connections with retry logic. package connmgr import ( "context" + "errors" "fmt" "net" + "net/netip" + "strconv" "sync" "sync/atomic" "time" -) -var ( + "github.com/decred/dcrd/addrmgr/v4" +) - // maxRetryDuration is the max duration of time retrying of a persistent - // connection is allowed to grow to. This is necessary since the retry - // logic uses a backoff mechanism which increases the interval base times - // the number of retries that have been done. - maxRetryDuration = time.Minute * 5 +const ( + // MaxPersistent is the maximum number of persistent connections that can be + // added. Persistent connections do not count towards the automatic + // outbound connection limits. + MaxPersistent = 8 ) const ( @@ -33,77 +38,179 @@ const ( // persistent connections. defaultRetryDuration = time.Second * 5 + // defaultMaxRetryDuration is the default maximum duration a persistent + // connection retry backoff is allowed to grow to. This is necessary since + // the retry logic uses a backoff mechanism which increases the interval + // base times the number of retries that have been done. + defaultMaxRetryDuration = time.Minute * 5 + + // defaultMaxNormalConns is the default number of maximum normal inbound, + // outbound, and pending connections to permit. + defaultMaxNormalConns = 125 + // defaultTargetOutbound is the default number of outbound connections to // maintain. 
- defaultTargetOutbound = uint32(8) + defaultTargetOutbound = 8 ) -// ConnState represents the state of the requested connection. -type ConnState uint32 +// ConnectionType specifies the different types of supported connections. +type ConnectionType uint8 -// ConnState can be either pending, established, disconnected or failed. When -// a new connection is requested, it is attempted and categorized as -// established or failed depending on the connection result. An established -// connection which was disconnected is categorized as disconnected. const ( - ConnPending ConnState = iota - ConnEstablished - ConnDisconnected - ConnFailed - ConnCanceled + // ConnTypeInbound indicates the connection was established by a remote + // peer. No further details are known about this connection until a + // handshake takes place. + ConnTypeInbound ConnectionType = iota + + // ConnTypeOutbound indicates a normal outbound connection that was + // established with no additional restrictions imposed on the type of + // information that the local peer/server is willing to relay. + // + // Note that this in no way implies further restrictions may not be + // negotiated depending on the protocol messages exchanged between the two + // peers. + ConnTypeOutbound + + // ConnTypeManual indicates an outbound connection that was manually + // requested via [ConnManager.Connect] or [ConnManager.AddPersistent]. In + // practice, this connection type is the result of requesting manual + // connections via an RPC method (e.g. "node connect") or via command line + // configuration options (e.g. --addpeer and --connect). + ConnTypeManual + + // numConnTypes is the number of connection types. This entry MUST be the + // last entry in the enum. + numConnTypes ) -// ConnReq is the connection request to a network address. If permanent, the -// connection will be retried on disconnection. -type ConnReq struct { - // id is the unique identifier for this connection request. 
- id atomic.Uint64 +// connTypeStrings is a map of connection types to human-readable names for +// pretty printing. +var connTypeStrings = map[ConnectionType]string{ + ConnTypeInbound: "inbound", + ConnTypeOutbound: "outbound", + ConnTypeManual: "manual", +} - // state is the current connection state for this connection request. - state atomic.Uint32 +// String returns the [ConnectionType] in human-readable form. +func (connType ConnectionType) String() string { + if s, ok := connTypeStrings[connType]; ok { + return s + } - // The following fields are owned by the connection manager and must not - // be accessed without its connection mutex held. + return fmt.Sprintf("Unknown ConnectionType (%d)", uint8(connType)) +} + +// Conn houses information about a managed connection. It is the caller's +// responsibility to always ensure [Conn.Close] is called when the connection +// is no longer required. +type Conn struct { + // The following variables are set at the time the instance is created and + // are safe for concurrent access. + // + // net.Conn is the underlying connection. It is embedded which makes all of + // its methods immediately available. // - // retryCount is the number of times a permanent connection request that - // fails to connect has been retried since the last successful connection. + // id is the unique identifier for this connection. // - // conn is the underlying network connection. It will be nil before a - // connection has been established. - retryCount uint32 - conn net.Conn + // connType specifies the connection type. + // + // remoteAddr is the remote address associated with the connection. It is + // a concrete address manager address. + // + // onClose is a callback that will be invoked when the connection is closed. + net.Conn + id uint64 + connType ConnectionType + remoteAddr addrmgr.NetAddress + onClose func() + + // closed houses whether or not the connection has already been closed. 
+ closed atomic.Bool +} - // Addr is the address to connect to. - Addr net.Addr +// newConn returns a new connection given an underlying [net.Conn], connection +// ID, and connection type. +// +// The returned connection is automatically configured to invoke the provided on +// close handler (when non-nil) followed by the [Config.OnDisconnection] that +// was configured when initially creating the connection manager when the +// connection is closed. The on close handler is invoked in the same goroutine +// as the caller of [Conn.Close] and [Config.OnDisconnection] is invoked in a +// separate goroutine. +func newConn(cm *ConnManager, conn net.Conn, id uint64, connType ConnectionType, remoteAddr *addrmgr.NetAddress, onClose func()) *Conn { + c := &Conn{Conn: conn, id: id, connType: connType, remoteAddr: *remoteAddr} + c.onClose = func() { + onClose() + if cm.cfg.OnDisconnection != nil { + go cm.cfg.OnDisconnection(c) + } + } + return c +} - // Permanent specifies whether or not the connection request represents what - // should be treated as a permanent connection, meaning the connection - // manager will try to always maintain the connection including retries with - // increasing backoff timeouts. - Permanent bool +// ID returns a unique identifier for the connection. +// +// This function is safe for concurrent access. +func (c *Conn) ID() uint64 { + return c.id } -// updateState updates the state of the connection request. -func (c *ConnReq) updateState(state ConnState) { - c.state.Store(uint32(state)) +// Close closes the connection. The [Config.OnDisconnection] that was +// configured when initially creating the connection manager will be invoked in +// a separate goroutine prior to closing the underlying connection. +// +// Repeated close attempts are ignored. Closing a connection that has already +// been closed will not return an error. +// +// This function is safe for concurrent access. +func (c *Conn) Close() error { + // Already closed. 
+ if !c.closed.CompareAndSwap(false, true) { + return nil + } + + // Invoke close callback associated with the connection when it's closed. + if c.onClose != nil { + c.onClose() + } + + // Close the underlying connection. + return c.Conn.Close() } -// ID returns a unique identifier for the connection request. -func (c *ConnReq) ID() uint64 { - return c.id.Load() +// RemoteAddr returns the remote address manager network address associated with +// the connection. It returns a [net.Addr] to implement the [net.Conn] +// interface, but the underlying type will be a [*addrmgr.NetAddress]. +func (c *Conn) RemoteAddr() net.Addr { + return &c.remoteAddr } -// State is the connection state of the requested connection. -func (c *ConnReq) State() ConnState { - return ConnState(c.state.Load()) +// Type returns the [ConnectionType] of the connection. +// +// This function is safe for concurrent access. +func (c *Conn) Type() ConnectionType { + return c.connType } -// String returns a human-readable string for the connection request. -func (c *ConnReq) String() string { - if c.Addr == nil || c.Addr.String() == "" { - return fmt.Sprintf("reqid %d", c.ID()) - } - return fmt.Sprintf("%s (reqid %d)", c.Addr, c.ID()) +// pendingConnInfo houses information about a pending connection attempt. +type pendingConnInfo struct { + id uint64 + addr net.Addr + cancel context.CancelFunc +} + +// persistentEntry houses information about a persistent connection that has +// been added to the connection manager. Once an ID has been assigned, all +// future connections established for the persistent connection will have the +// same ID. This allows it to be uniquely identified and removed later. +type persistentEntry struct { + id uint64 + addr net.Addr + + // cancel shuts down the goroutine that maintains the persistent connection. + // It is owned by the connection manager and must not be accessed without + // its connection mutex held. 
+ cancel context.CancelFunc } // Config holds the configuration options related to the connection manager. @@ -129,10 +236,27 @@ type Config struct { // This field will not have any effect if the Listeners field is not // also specified since there couldn't possibly be any accepted // connections in that case. - OnAccept func(net.Conn) + OnAccept func(*Conn) - // TargetOutbound is the number of outbound network connections to - // maintain. Defaults to 8. + // MaxNormalConns is the maximum number of normal inbound, outbound, and + // pending connections to permit. Defaults to 125. + // + // Persistent connections do not count against this limit. They have their + // own maximum defined by [MaxPersistent]. + // + // Whitelisted connections and some connections with special permissions are + // also exempt. As a result, the total number of connections may exceed + // this value. + MaxNormalConns uint32 + + // TargetOutbound is the number of outbound network connections to maintain + // automatically. Defaults to 8. + // + // Persistent connections do not count against this value. They have their + // own maximum limit defined by [MaxPersistent]. + // + // This will be forced to the smaller of the specified value (or its default + // value when unspecified) and [Config.MaxNormalConns]. TargetOutbound uint32 // RetryDuration is the duration to wait before retrying connection @@ -141,11 +265,10 @@ type Config struct { // OnConnection is a callback that is fired when a new outbound // connection is established. - OnConnection func(*ConnReq, net.Conn) + OnConnection func(*Conn) - // OnDisconnection is a callback that is fired when an outbound - // connection is disconnected. - OnDisconnection func(*ConnReq) + // OnDisconnection is a callback that is fired when a connection is closed. + OnDisconnection func(*Conn) // GetNewAddress is a way to get an address to make a network connection // to. If nil, no new connections will be made automatically. 
@@ -157,17 +280,16 @@ type Config struct { // DialTimeout specifies the amount of time to wait for a connection to // complete before giving up. DialTimeout time.Duration + + // Whitelists specifies CIDR address prefixes to whitelist. Whitelisted + // addresses are exempt from banning and certain connection limits. + Whitelists []netip.Prefix } // ConnManager provides a manager to handle network connections. type ConnManager struct { - // connReqCount is the number of connection requests that have been made and - // is primarily used to assign unique connection request IDs. - connReqCount atomic.Uint64 - - // assignIDMtx synchronizes the assignment of an ID to a connection request - // with overall connection request count above. - assignIDMtx sync.Mutex + // nextConnID is used to assign unique connection IDs. + nextConnID atomic.Uint64 // quit is used for lifecycle management of the connection manager. quit chan struct{} @@ -176,424 +298,970 @@ type ConnManager struct { // creating time and treated as immutable after that. cfg Config - // failedAttempts tracks the total number of failed oubound connection - // attempts since the last successful connection made by the connection - // manager. It is primarily used to detect network outages in order to - // impose a retry timeout on achieving the target number of outbound - // connections which prevents runaway failed connection attempt churn. + // maxRetryDuration is the maximum duration a persistent connection retry + // backoff is allowed to grow to. + maxRetryDuration time.Duration + + // runPersistentChan is used to signal the persistent connections handler to + // launch a goroutine that attempts to always maintain an established + // connection with a given address. // - // This field is owned by the connection handler and must not be accessed - // outside of it. - failedAttempts uint64 + // It is a buffered channel with size [MaxPersistent]. 
+ runPersistentChan chan *persistentEntry - // The following fields are used to track the various connections managed - // by the connection manager. They are protected by the associated - // connection mutex. + // These semaphores are used to enforce max limits on the number of + // connections of different kinds. They do not apply to persistent + // connections which are separately limited to [MaxPersistent]. // - // pending holds all registered connection requests that have yet to - // succeed. + // totalNormalConnsSem limits the total overall number of normal inbound, + // outbound, and pending connections. // - // conns represents the set of all active connections. + // activeOutboundsSem limits the number of active outbound connections. + totalNormalConnsSem semaphore + activeOutboundsSem semaphore + + // The fields below this point are all protected by the connection mutex. connMtx sync.Mutex - pending map[uint64]*ConnReq - conns map[uint64]*ConnReq + + // persistent tracks all registered persistent connection entries. + // + // A persistent connection can be in one of three states: + // + // - Established with the connection instance in the active map + // - Pending with an entry in the pending map + // - Awaiting a retry + // + // Regardless of the state, there will always be an entry in this map. + persistent map[uint64]*persistentEntry + + // pending holds all registered connection attempts that have yet to + // succeed. + pending map[uint64]pendingConnInfo + + // active represents the set of all active connections. + active map[uint64]*Conn + + // connIDByAddr provides fast O(1) lookup of connection IDs by address + // (host:port). It is kept in sync with the persistent, pending, and active + // maps and is primarily used to efficiently reject duplicate connections. + connIDByAddr map[string]uint64 } -// registerPending registers the provided connection request as a pending -// connection attempt. 
-// -// This function MUST be called with the connection mutex lock held (writes). -func (cm *ConnManager) registerPending(connReq *ConnReq) { - connReq.updateState(ConnPending) - cm.pending[connReq.ID()] = connReq +// IsWhitelisted returns whether the IP address is included in the whitelisted +// networks and IPs. +func (cm *ConnManager) IsWhitelisted(addr *addrmgr.NetAddress) bool { + if len(cm.cfg.Whitelists) == 0 { + return false + } + + ip, _ := netip.AddrFromSlice(addr.IP) + for _, prefix := range cm.cfg.Whitelists { + if prefix.Contains(ip) { + return true + } + } + return false } -// newConnReq creates a new connection request and connects to the corresponding -// address. -func (cm *ConnManager) newConnReq(ctx context.Context) { - // Ignore during shutdown. - if ctx.Err() != nil { - return +// checkShutdown returns [ErrShutdown] when the connection manager quit channel +// has been closed. +func (cm *ConnManager) checkShutdown() error { + select { + case <-cm.quit: + const str = "connection manager shutdown" + return MakeError(ErrShutdown, str) + default: } + return nil +} - c := &ConnReq{} - c.id.Store(cm.connReqCount.Add(1)) +// stdlibNetAddrToAddrMgrNetAddr converts the provided standard lib [net.Addr] +// to a concrete address manager network address. +func stdlibNetAddrToAddrMgrNetAddr(addr net.Addr) (*addrmgr.NetAddress, error) { + addrStr := addr.String() + host, portStr, err := net.SplitHostPort(addrStr) + if err != nil { + str := fmt.Sprintf("unable to split address %q", addrStr) + return nil, MakeError(ErrUnsupportedAddr, str) + } + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + str := fmt.Sprintf("invalid port for address %q", addrStr) + return nil, MakeError(ErrUnsupportedAddr, str) + } - // Register the pending connection attempt so it can be canceled via the - // [ConnManager.Remove] method. 
- cm.connMtx.Lock() - cm.registerPending(c) - cm.connMtx.Unlock() + addrType, addrBytes := addrmgr.EncodeHost(host) + if addrType == addrmgr.UnknownAddressType { + str := fmt.Sprintf("unable to determine address type for %q", addrStr) + return nil, MakeError(ErrUnsupportedAddr, str) + } - addr, err := cm.cfg.GetNewAddress() + // Since the host type has been successfully recognized and encoded, there + // is no need to perform a DNS lookup. + now := time.Unix(time.Now().Unix(), 0) + netAddr, err := addrmgr.NewNetAddressFromParams(addrType, addrBytes, + uint16(port), now, 0) if err != nil { - cm.connMtx.Lock() - cm.handleFailedPending(ctx, c, err) - cm.connMtx.Unlock() - return + return nil, MakeError(ErrUnsupportedAddr, err.Error()) } + return netAddr, nil +} - c.Addr = addr +// addPendingInfo adds information about a pending connection attempt to the +// local state. +// +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) addPendingInfo(info *pendingConnInfo) { + cm.pending[info.id] = *info + if _, ok := cm.persistent[info.id]; !ok { + cm.connIDByAddr[info.addr.String()] = info.id + } +} - cm.Connect(ctx, c) +// removePendingInfo removes a pending connection attempt from the local state. +// +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) removePendingInfo(id uint64, addr net.Addr) { + delete(cm.pending, id) + if _, ok := cm.persistent[id]; !ok { + delete(cm.connIDByAddr, addr.String()) + } } -// handleFailedConn handles a connection failed due to a disconnect or any other -// failure. Permanent connection requests are retried after the configured -// retry duration. A new connection request is created if required. +// addActiveConn adds an established connection to the local state. // -// In the event there have been [maxFailedAttempts] failed successive attempts, -// new connections will be retried after the configured retry duration. 
+// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) addActiveConn(conn *Conn) { + cm.active[conn.id] = conn + if _, ok := cm.persistent[conn.id]; !ok { + cm.connIDByAddr[conn.remoteAddr.String()] = conn.id + } +} + +// removeActiveConn removes an established connection from the local state. It +// has no effect if the connection has already been removed from the active map. // -// This function MUST be called with the connection lock held (writes). -func (cm *ConnManager) handleFailedConn(ctx context.Context, c *ConnReq) { - // Ignore during shutdown. - select { - case <-cm.quit: - return - case <-ctx.Done(): +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) removeActiveConn(id uint64, addr net.Addr) { + // The active connection might have already been removed before releasing + // the mutex to call [Conn.Close]. + if _, ok := cm.active[id]; !ok { return - default: } - // Reconnect to permanent connection requests after a retry timeout with - // an increasing backoff up to a max for repeated failed attempts. - if c.Permanent { - c.retryCount++ - retryWait := time.Duration(c.retryCount) * cm.cfg.RetryDuration - retryWait = min(retryWait, maxRetryDuration) - log.Debugf("Retrying connection to %v in %v", c, retryWait) - go func() { - select { - case <-time.After(retryWait): - cm.Connect(ctx, c) - case <-cm.quit: - case <-ctx.Done(): - } - }() - return + delete(cm.active, id) + if _, ok := cm.persistent[id]; !ok { + delete(cm.connIDByAddr, addr.String()) } +} - // Nothing more to do when the method to automatically get new addresses - // to connect to isn't configured. - if cm.cfg.GetNewAddress == nil { - return +// addPersistentEntry adds a persistent connection entry to the local state. +// +// This function MUST be called with the connection mutex held (writes). 
+func (cm *ConnManager) addPersistentEntry(entry *persistentEntry) { + cm.persistent[entry.id] = entry + cm.connIDByAddr[entry.addr.String()] = entry.id +} + +// removePersistentEntry removes a persistent connection entry from the local +// state. +// +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) removePersistentEntry(id uint64, addr net.Addr) { + delete(cm.persistent, id) + _, pending := cm.pending[id] + _, active := cm.active[id] + if !pending && !active { + delete(cm.connIDByAddr, addr.String()) } +} - // Wait to attempt new connections when there are too many successive - // failures. This prevents massive connection spam when no connections can - // be made, such as a network outtage. - cm.failedAttempts++ - if cm.failedAttempts >= maxFailedAttempts { - log.Debugf("Max failed connection attempts reached: [%d] -- retrying "+ - "connection in: %v", maxFailedAttempts, cm.cfg.RetryDuration) - go func() { - select { - case <-time.After(cm.cfg.RetryDuration): - cm.newConnReq(ctx) - case <-cm.quit: - case <-ctx.Done(): - } - }() - return +// rejectConnectedAddr returns an error if there is already either an +// established connection to the provided address or a pending attempt to +// connect to it. Persistent connections in the retry state are intentionally +// not detected. +// +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) rejectConnectedAddr(addr net.Addr) error { + addrStr := addr.String() + connID, ok := cm.connIDByAddr[addrStr] + if !ok { + return nil } - // Otherwise, attempt a new connection with a new address now. 
- go cm.newConnReq(ctx) + if _, ok := cm.pending[connID]; ok { + str := fmt.Sprintf("a pending connection to %s already exists", addr) + return MakeError(ErrAlreadyPending, str) + } + if _, ok := cm.active[connID]; ok { + str := fmt.Sprintf("a connection to %s is already established", addr) + return MakeError(ErrAlreadyConnected, str) + } + return nil } -// handleFailedPending handles failed pending connection requests. Connection -// requests that were canceled are ignored. Otherwise, their state is updated -// to mark it failed and it is passed along to [ConnManager.handleFailedConn] to -// possibly retry or be reused for a new connection depending on settings. +// findPersistentAddrID attempts to find and return the persistent connection ID +// associated with the passed address. The bool return indicates whether or not +// it was found. // -// This function MUST be called with the connection lock held (writes). -func (cm *ConnManager) handleFailedPending(ctx context.Context, c *ConnReq, failedErr error) { - if _, ok := cm.pending[c.ID()]; !ok { - log.Debugf("Ignoring connection for canceled conn req: %v", c) - return +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) findPersistentAddrID(addr net.Addr) (uint64, bool) { + addrStr := addr.String() + connID, ok := cm.connIDByAddr[addrStr] + if !ok { + return 0, false } - c.updateState(ConnFailed) - log.Debugf("Failed to connect to %v: %v", c, failedErr) - cm.handleFailedConn(ctx, c) + entry, ok := cm.persistent[connID] + if !ok { + return 0, false + } + + return entry.id, true } -// Connect assigns an id and dials a connection to the address of the connection -// request using the provided context and the dial function configured when -// initially creating the connection manager. +// rejectPersistentAddr returns an error if there is already a persistent +// connection entry for the given address. 
// -// The connection attempt will be ignored if the connection manager has been -// shutdown by canceling the lifecycle context the Run method was invoked with -// or the provided connection request is already in the failed state. +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) rejectPersistentAddr(addr net.Addr) error { + if _, ok := cm.findPersistentAddrID(addr); ok { + str := fmt.Sprintf("a persistent connection for %s already exists", addr) + return MakeError(ErrDuplicatePersistent, str) + } + return nil +} + +// rejectDuplicateAddr returns an error if there is already a persistent +// connection entry, a pending connection attempt, or an established connection +// for the given address. // -// Note that the context parameter to this function and the lifecycle context -// may be independent. -func (cm *ConnManager) Connect(ctx context.Context, c *ConnReq) { +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) rejectDuplicateAddr(addr net.Addr) error { + if err := cm.rejectPersistentAddr(addr); err != nil { + return err + } + if err := cm.rejectConnectedAddr(addr); err != nil { + return err + } + return nil +} + +// dial attempts to connect to the provided address and returns a connection +// configured with the provided params on success. +// +// A new globally unique connection ID is assigned unless one is provided by +// passing a non-nil value in the persistent connection ID parameter. This +// allows persistent connections to retain the same ID across reconnects. +// +// Attempts to dial addresses that are already connected, pending, or (in most +// cases) persistent will return an error as described below. Only established +// and pending connections are rejected when a non-nil persistent connection ID +// is passed. 
+// +// The following connection limits are enforced: +// +// - Total normal connections ([Config.MaxNormalConns]) +// +// On success, the returned connection is configured to remove itself from the +// set of all active connections and invoke the provided on close callback (if +// set) when it is closed. +// +// On failure, the provided on close callback will be invoked prior to +// returning. +// +// In addition to errors returned by the underlying dialer, the following errors +// are possible: +// +// - [ErrDuplicatePersistent] when a persistent connection already exists for +// the address and no persistent connection ID is provided +// - [ErrAlreadyPending] when there is already a pending connection attempt +// to the address +// - [ErrAlreadyConnected] when there is already an established connection to +// the address +// - [ErrMaxNormalConns] when there are already the maximum allowed number of +// normal connections (inbound, outbound, and pending) +// - [ErrShutdown] when the connection manager is shutting down +// - [context.Canceled] or [context.DeadlineExceeded] depending on the +// provided context or when the dialer fails to establish a connection +// before the timeout configured for the connection manager +// +// This function is safe for concurrent access. +func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType ConnectionType, onClose func(), persistentConnID *uint64) (*Conn, error) { + var skipOnClose bool + defer func() { + if !skipOnClose && onClose != nil { + onClose() + } + }() + // Ignore during shutdown and when caller provided context is already // canceled. - select { - case <-cm.quit: - return - default: + if err := cm.checkShutdown(); err != nil { + return nil, err } if ctx.Err() != nil { - return + return nil, ctx.Err() } - // During the time we wait for retry there is a chance that this - // connection was already cancelled. 
- if c.State() == ConnCanceled { - log.Debugf("Ignoring connect for canceled connreq=%v", c) - return + rAddr, err := stdlibNetAddrToAddrMgrNetAddr(addr) + if err != nil { + return nil, err } - // Assign an ID and register the pending connection attempt when an ID has - // not already been assigned so it can be canceled via the - // [ConnManager.Remove] method. + // Reject attempts to dial addresses that are already connected (or in the + // process of it). Additionally, reject attempts to dial existing + // persistent addresses unless a persistent connection ID was provided + // indicating the dial is specifically for a persistent connection. // - // Note that the assignment of the ID and the overall request count need to - // be synchronized. So long as this is the only place an existing conn - // request ID is updated and this method is not called concurrently on the - // same conn request, no race could occur. However, those preconditions - // would be easy to inadvertently violate via updates to the code, so the - // mutex is added here for additional safety. - var doRegisterPending bool - cm.assignIDMtx.Lock() - if c.ID() == 0 { - c.id.Store(cm.connReqCount.Add(1)) - doRegisterPending = true - } - connReqID := c.ID() - cm.assignIDMtx.Unlock() - if doRegisterPending { - cm.connMtx.Lock() - cm.registerPending(c) + // This needs to be done under the same lock as adding a pending entry to + // prevent two simultaneous attempts from racing with each other. + rejectFn := cm.rejectDuplicateAddr + if persistentConnID != nil { + rejectFn = cm.rejectConnectedAddr + } + cm.connMtx.Lock() + if err := rejectFn(addr); err != nil { cm.connMtx.Unlock() + log.Debugf("Rejected connection: %v", err) + return nil, err } - log.Debugf("Attempting to connect to %v", c) - - // Attempt to establish the connection to the address associated with the - // connection request. Apply a timeout if requested. + // Apply a dial timeout if requested.
Otherwise, use a regular cancel + // context to support canceling the pending connection later. + var cancel context.CancelFunc if cm.cfg.DialTimeout != 0 { - var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, cm.cfg.DialTimeout) - defer cancel() + } else { + ctx, cancel = context.WithCancel(ctx) } - var conn net.Conn - conn, err := cm.cfg.Dial(ctx, c.Addr.Network(), c.Addr.String()) - if err != nil { + defer cancel() + + // Register the pending connection attempt and defer its removal to ensure + // it is always removed on failure. + var connID uint64 + if persistentConnID != nil { + connID = *persistentConnID + } else { + connID = cm.nextConnID.Add(1) + } + info := &pendingConnInfo{connID, rAddr, cancel} + cm.addPendingInfo(info) + cm.connMtx.Unlock() + defer func() { cm.connMtx.Lock() - cm.handleFailedPending(ctx, c, err) + if _, ok := cm.pending[connID]; ok { + cm.removePendingInfo(connID, rAddr) + } cm.connMtx.Unlock() - return + }() + + log.Debugf("Attempting to connect to %v (id: %d, type: %v)", addr, connID, + connType) + + // Attempt to establish the connection to the address. + netConn, err := cm.cfg.Dial(ctx, addr.Network(), addr.String()) + if err != nil { + var logErrStr string + switch { + case errors.Is(err, context.DeadlineExceeded): + logErrStr = fmt.Sprintf("no response for %v", cm.cfg.DialTimeout) + case errors.Is(err, context.Canceled): + // Override the error with the shutdown error instead when that is + // the upstream cause of the context cancel. + if sErr := cm.checkShutdown(); sErr != nil { + err = sErr + break + } + logErrStr = "attempt manually canceled" + } + if logErrStr == "" { + logErrStr = err.Error() + } + log.Debugf("Failed to connect to %v: %v", addr, logErrStr) + return nil, err } + // Ignore any connections that succeed after they were manually canceled. 
cm.connMtx.Lock() - defer cm.connMtx.Unlock() - - if _, ok := cm.pending[connReqID]; !ok { - conn.Close() - log.Debugf("Ignoring connection for canceled connreq=%v", c) - return + if _, ok := cm.pending[connID]; !ok { + cm.connMtx.Unlock() + netConn.Close() + log.Debugf("Ignoring canceled connection %v (id: %d, type: %v)", addr, + connID, connType) + return nil, context.Canceled } - c.updateState(ConnEstablished) - c.conn = conn - cm.conns[connReqID] = c - log.Debugf("Connected to %v", c) - c.retryCount = 0 - cm.failedAttempts = 0 - delete(cm.pending, connReqID) + // Remove the pending entry under the lock. This ensures the maps are + // mutually exclusive for a given id. + cm.removePendingInfo(connID, rAddr) - if cm.cfg.OnConnection != nil { - go cm.cfg.OnConnection(c, conn) + // Successful return means the on close callback is not invoked until the + // connection is closed. + skipOnClose = true + + // Setup a close callback to remove the connection from the map that tracks + // all active connections when the connection is closed and also to invoke + // the close callback provided by the caller when specified. + dialOnClose := func() { + cm.connMtx.Lock() + cm.removeActiveConn(connID, rAddr) + cm.connMtx.Unlock() + if onClose != nil { + onClose() + } + log.Debugf("Disconnected from %v (id: %d, type: %v)", addr, connID, + connType) } + + // Create a new connection instance with the connection ID and type and add + // an entry to the map that tracks all active connections. + conn := newConn(cm, netConn, connID, connType, rAddr, dialOnClose) + cm.addActiveConn(conn) + cm.connMtx.Unlock() + + log.Debugf("Connected to %v (id: %d, type: %v)", addr, connID, connType) + return conn, nil } -// handleDisconnected handles a connection that has been disconnected. +// Connect assigns an ID and dials a connection to the provided address using +// the provided context and the dial function configured when initially creating +// the connection manager. 
// -// This function MUST be called with the connection mutex held (writes). -func (cm *ConnManager) handleDisconnected(id uint64, retry bool) { - // Mark the connection request as canceled and remove it from the pending - // connections when it is still pending. Since the connection attempt is - // taking place asynchronously, this ensures any later successful connection - // is ignored. - connReq, ok := cm.pending[id] - if ok { - connReq.updateState(ConnCanceled) - log.Debugf("Canceling: %v", connReq) - delete(cm.pending, id) - } - - connReq, ok = cm.conns[id] - if !ok { - log.Errorf("Unknown connid=%d", id) - return - } - - // Close the underlying connection and invoke the associated callback (if - // assigned). - log.Debugf("Disconnected from %v", connReq) - delete(cm.conns, id) - if connReq.conn != nil { - connReq.conn.Close() +// Attempts to dial addresses that already have an established, pending, or +// persistent connection or would exceed max allowed limits will return an error +// as described below. +// +// The connection will have type [ConnTypeManual] and the following connection +// limits are enforced: +// +// - Total normal connections ([Config.MaxNormalConns]) +// +// Note that the context parameter to this function and the lifecycle context +// may be independent. 
+// +// In addition to errors returned by the underlying dialer, the following errors +// are possible: +// +// - [ErrDuplicatePersistent] when a persistent connection already exists for +// the address (regardless of its current state) +// - [ErrAlreadyPending] when there is already a pending connection attempt +// to the address +// - [ErrAlreadyConnected] when there is already an established connection to +// the address +// - [ErrMaxNormalConns] when there are already the maximum allowed number of +// normal connections (inbound, outbound, and pending) +// - [ErrShutdown] when the connection manager is shutting down +// - [context.Canceled] or [context.DeadlineExceeded] depending on the +// provided context or when the dialer fails to establish a connection +// before the timeout configured for the connection manager +// +// This function is safe for concurrent access. +func (cm *ConnManager) Connect(ctx context.Context, addr net.Addr) (*Conn, error) { + acquired, err := cm.totalNormalConnsSem.TryAcquire(ctx) + if err != nil { + if sErr := cm.checkShutdown(); sErr != nil { + return nil, sErr + } + return nil, err } - if cm.cfg.OnDisconnection != nil { - go cm.cfg.OnDisconnection(connReq) + if !acquired { + maxAllowed := cm.cfg.MaxNormalConns + str := fmt.Sprintf("a maximum of %d %s is allowed", maxAllowed, + pickNoun(maxAllowed, "connection", "connections")) + return nil, MakeError(ErrMaxNormalConns, str) } - - // Mark the associated connection request as disconnected and return when no - // further attempts will be made now that all internal state has been - // cleaned up. 
- if !retry { - connReq.updateState(ConnDisconnected) - return + onClose := cm.totalNormalConnsSem.Release + conn, err := cm.dial(ctx, addr, ConnTypeManual, onClose, nil) + if err != nil { + return nil, err } - - // Otherwise, attempt a reconnection when the associated connection request - // is marked as permanent or there are not already enough outbound peers to - // satisfy the target number of outbound peers. - numConns := uint32(len(cm.conns)) - if connReq.Permanent || numConns < cm.cfg.TargetOutbound { - // The connection request is reused for permanent ones, so add it back - // to the pending map in that case so that subsequent processing of - // connections and failures do not ignore the request. - if connReq.Permanent { - cm.registerPending(connReq) - log.Debugf("Reconnecting to %v", connReq) - } - - // A background context is the only viable choice here. It is not - // ideal, but it is acceptable, because, ultimately, this context is - // really only used for persistent peers when they retry and persistent - // peers are not tied to a specific context anyway. They are instead - // removed by other means. Due to that, there also is no machinery to - // cancel a given persistent peer from a given context anyway. - // - // Future work ideally should refactor the persistent peer handling to - // have proper full context support. - cm.handleFailedConn(context.Background(), connReq) + if cm.cfg.OnConnection != nil { + go cm.cfg.OnConnection(conn) } + return conn, nil } -// Disconnect disconnects the connection corresponding to the given connection -// id. Permanent connections will be retried with an increasing backoff -// duration. +// Disconnect either disconnects the connection corresponding to the given +// connection id or cancels any pending attempts associated with it. Persistent +// connections will be retried with an increasing backoff duration. // // This function is safe for concurrent access. 
-func (cm *ConnManager) Disconnect(id uint64) { +func (cm *ConnManager) Disconnect(id uint64) error { + // Cancel and remove pending entries. Even though the pending entry will be + // removed from the map regardless by the dialer, doing it now ensures that + // any connections that are already in progress and later succeed are + // ignored. cm.connMtx.Lock() - cm.handleDisconnected(id, true) + if info, ok := cm.pending[id]; ok { + info.cancel() + cm.removePendingInfo(id, info.addr) + cm.connMtx.Unlock() + return nil + } + + conn := cm.active[id] + if conn != nil { + cm.connMtx.Unlock() + conn.Close() // Close requires the conn mutex. + return nil + } + _, isPersistent := cm.persistent[id] cm.connMtx.Unlock() + + // Not found in active or pending, but it might still be a persistent conn + // waiting to retry. No error in that case. + if isPersistent { + return nil + } + + str := fmt.Sprintf("no entries with id %d exist", id) + return MakeError(ErrNotFound, str) } -// Remove removes the connection corresponding to the given connection id from -// known connections. +// Remove closes, cancels, or removes the connection corresponding to the given +// connection id. +// +// This function may be used for all connection states and types, including +// established, pending, and persistent connections. // -// NOTE: This method can also be used to cancel a pending connection attempt -// that hasn't yet succeeded. +// Connections that are already established are closed and connection attempts +// that are still pending are canceled. Persistent connections are additionally +// removed so that no future retries will occur. // // This function is safe for concurrent access. -func (cm *ConnManager) Remove(id uint64) { +func (cm *ConnManager) Remove(id uint64) error { + // When the ID is for a persistent connection, cancel the associated context + // and remove it from the persistent map to prevent future retries.
cm.connMtx.Lock() - cm.handleDisconnected(id, false) + entry, isPersistent := cm.persistent[id] + if isPersistent { + cm.removePersistentEntry(id, entry.addr) + if entry.cancel != nil { + entry.cancel() + } + log.Debugf("Removed persistent connection to %v (id %d)", entry.addr, + entry.id) + } + + // Cancel and remove pending entries. Even though the pending entry will be + // removed from the map regardless by the dialer, doing it now ensures that + // any connections that are already in progress and later succeed are + // ignored. + if info, ok := cm.pending[id]; ok { + info.cancel() + cm.removePendingInfo(id, info.addr) + cm.connMtx.Unlock() + return nil + } + + // Close active connections and remove the entry from the active map. + // + // Even though the connection close handler would remove it from the map, it + // needs to be removed under the same lock as removals from the persistent + // map to prevent two simultaneous attempts from racing with each other. + if conn, ok := cm.active[id]; ok { + cm.removeActiveConn(id, &conn.remoteAddr) + cm.connMtx.Unlock() + conn.Close() // Close requires the conn mutex. + return nil + } cm.connMtx.Unlock() + + // Not found in active or pending, but no error if it was a removed + // persistent conn. + if isPersistent { + return nil + } + + str := fmt.Sprintf("no entries with id %d exist", id) + return MakeError(ErrNotFound, str) } -// findPendingByAddr attempts to find and return the pending connection request -// associated with the provided address. It returns nil if no matching request -// is found. -// -// This function MUST be called with the connection mutex held (writes). -func (cm *ConnManager) findPendingByAddr(addr net.Addr) *ConnReq { - pendingAddr := addr.String() - for _, req := range cm.pending { - if req == nil || req.Addr == nil { +// listenHandler accepts incoming connections on a given listener. It must be +// run as a goroutine.
+func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) { + log.Infof("Server listening on %s", listener.Addr()) + defer log.Tracef("Listener handler done for %s", listener.Addr()) + + for ctx.Err() == nil { + // The following is intentionally implementing active connection + // shedding by accepting connections and then immediately disconnecting + // them after the [net.Listener.Accept] call if any policies are + // violated. + // + // Reversing it and blocking until a permit is available and only then + // calling Accept would cause the connections to build up in the kernel. + // Then, since the kernel will still create the 3-way handshake, clients + // would connect and hang until their own timeouts are hit, and, + // eventually, the entire service could appear entirely down if the SYN + // queue were to fill. It also would not allow implementing better + // additional policies. + netConn, err := listener.Accept() + if err != nil { + // Only log the error if not forcibly shutting down. + if ctx.Err() == nil { + log.Errorf("Can't accept connection: %v", err) + } continue } - if pendingAddr == req.Addr.String() { - return req + + rAddr, err := stdlibNetAddrToAddrMgrNetAddr(netConn.RemoteAddr()) + if err != nil { + log.Warnf("Dropped connection from %v: failed to parse address", + netConn.RemoteAddr()) + netConn.Close() + continue } + + // Reject connections with the same host:port as any existing pending, + // established, or persistent connections. Note that this does NOT + // prevent multiple connections from the same host given they typically + // will be coming from different ports. + // + // The aforementioned behavior is intentional as it allows connections + // from the same host to be independently limited to more than one + // elsewhere. 
+ cm.connMtx.Lock() + if err := cm.rejectDuplicateAddr(rAddr); err != nil { + cm.connMtx.Unlock() + log.Debugf("Dropped connection from %v: %v", rAddr, err) + netConn.Close() + continue + } + cm.connMtx.Unlock() + + // Require a permit to allow the inbound connection unless the address + // has special permissions (e.g. whitelisted). + // + // Attempt to acquire a permit via a non-blocking call and immediately + // disconnect if unsuccessful so that all blocking happens on + // [net.Listener.Accept] for the reasons described above. + requirePermit := !cm.IsWhitelisted(rAddr) + if requirePermit { + acquired, err := cm.totalNormalConnsSem.TryAcquire(ctx) + if err != nil { + netConn.Close() + continue + } + if !acquired { + maxAllowed := cm.cfg.MaxNormalConns + log.Debugf("Dropped connection from %v: a maximum of %d %s is "+ + "allowed", rAddr, maxAllowed, pickNoun(maxAllowed, + "connection", "connections")) + netConn.Close() + continue + } + } + go func(netConn net.Conn, requirePermit bool) { + // Create a new connection instance with the next globally unique + // connection ID, add an entry to the map that tracks all active + // connections, and invoke the configured accept callback with it. + // + // Also set a close callback to remove the connection from the map + // when it is closed. 
+ id := cm.nextConnID.Add(1) + const connType = ConnTypeInbound + onClose := func() { + cm.connMtx.Lock() + cm.removeActiveConn(id, rAddr) + cm.connMtx.Unlock() + log.Debugf("Disconnected from %v (id: %d, type: %v)", rAddr, id, + connType) + if requirePermit { + cm.totalNormalConnsSem.Release() + } + } + conn := newConn(cm, netConn, id, connType, rAddr, onClose) + cm.connMtx.Lock() + cm.addActiveConn(conn) + cm.connMtx.Unlock() + log.Debugf("Accepted connection from %v (id: %d, type: %v)", rAddr, + id, connType) + cm.cfg.OnAccept(conn) + }(netConn, requirePermit) } - return nil } -// CancelPending removes the connection corresponding to the given address -// from the list of pending failed connections. +// AddPersistent adds an address the connection manager will attempt to always +// maintain an established connection with until the persistent connection entry +// is removed via [ConnManager.Remove] or the context associated with +// [ConnManager.Run] is canceled. +// +// When the associated connection is dropped, it will be retried with an +// increasing backoff, up to a maximum for repeated failed attempts. +// +// A maximum of [MaxPersistent] connections may be added. Attempting to add any +// more will return [ErrMaxPersistent]. +// +// Adding a duplicate persistent address will return [ErrDuplicatePersistent] +// and adding addresses that already have an established or pending connection +// will return [ErrAlreadyConnected] or [ErrAlreadyPending], respectively. // -// Returns an error if the connection manager is stopped or there is no pending -// connection for the given address. -func (cm *ConnManager) CancelPending(addr net.Addr) error { +// An ID is returned that uniquely identifies the persistent connection. All +// future connections established will have the same ID. +// +// Persistent connections do not count against [Config.TargetOutbound]. 
+// +// Note that the actual connections to the address happen asynchronously and +// will have type [ConnTypeManual]. Established connections will invoke the +// [Config.OnConnection] callback that was configured when initially creating +// the connection manager. +// +// Since connections happen asynchronously, the error only indicates issues with +// adding the persistent connection entry. +// +// The persistent connection may be removed by passing the returned connection +// ID to [ConnManager.Remove]. +// +// This function is safe for concurrent access. +func (cm *ConnManager) AddPersistent(addr net.Addr) (uint64, error) { cm.connMtx.Lock() defer cm.connMtx.Unlock() - connReq := cm.findPendingByAddr(addr) - if connReq == nil { - str := fmt.Sprintf("no pending connection to %v", addr) - return MakeError(ErrNotFound, str) + if len(cm.persistent)+1 > MaxPersistent { + str := fmt.Sprintf("a maximum of %d persistent connections is allowed", + MaxPersistent) + return 0, MakeError(ErrMaxPersistent, str) } - delete(cm.pending, connReq.ID()) - connReq.updateState(ConnCanceled) - log.Debugf("Canceled pending connection to %v", addr) - return nil + if err := cm.rejectDuplicateAddr(addr); err != nil { + return 0, err + } + + rAddr, err := stdlibNetAddrToAddrMgrNetAddr(addr) + if err != nil { + return 0, err + } + + entry := &persistentEntry{id: cm.nextConnID.Add(1), addr: rAddr} + cm.addPersistentEntry(entry) + log.Debugf("Added persistent connection to %v (id: %d)", addr, entry.id) + + // The channel is buffered with the max allowed persistent conns, so there + // is no possibility of blocking here. This approach allows persistent + // peers to be added both before and after the connection manager is running + // without starting the goroutines before it is running. + cm.runPersistentChan <- entry + return entry.id, nil } -// ForEachConnReq calls the provided function with each connection request known -// to the connection manager, including pending requests. 
Returning an error -// from the provided function will stop the iteration early and return said -// error from this function. +// IsPersistent returns whether or not the provided connection id belongs to a +// persistent connection. // // This function is safe for concurrent access. +func (cm *ConnManager) IsPersistent(id uint64) bool { + cm.connMtx.Lock() + _, ok := cm.persistent[id] + cm.connMtx.Unlock() + return ok +} + +// FindPersistentAddrID attempts to find and return the persistent connection ID +// associated with the passed address. The bool return indicates whether or not +// it was found. // -// NOTE: This must not call any other connection manager methods during -// iteration or it will result in a deadlock. -func (cm *ConnManager) ForEachConnReq(f func(c *ConnReq) error) error { +// This function is safe for concurrent access. +func (cm *ConnManager) FindPersistentAddrID(addr net.Addr) (uint64, bool) { cm.connMtx.Lock() - defer cm.connMtx.Unlock() + id, ok := cm.findPersistentAddrID(addr) + cm.connMtx.Unlock() + return id, ok +} - var err error - for _, connReq := range cm.pending { - err = f(connReq) - if err != nil { - return err +// runPersistent attempts to maintain a persistent connection to the provided +// address until the passed context is canceled. +// +// When the associated connection is dropped, it will be retried with an +// increasing backoff, up to a maximum for repeated failed attempts. +// +// This MUST be run as a goroutine. +func (cm *ConnManager) runPersistent(ctx context.Context, connID uint64, addr net.Addr) { + // Ensure the connection is closed when the goroutine exits. + var conn *Conn + defer func() { + if conn != nil { + conn.Close() } + }() + + // Setup a callback that notifies a disconnect channel for use below and + // start with the channel signaled. 
+ disconnected := make(chan struct{}, 1) + disconnected <- struct{}{} + onClose := func() { + disconnected <- struct{}{} } - for _, connReq := range cm.conns { - err = f(connReq) + + var retryCount uint32 + var retryAfter <-chan time.Time + var lastAttempt time.Time + for { + // Wait for disconnect or retry timer when it's set. + select { + case <-ctx.Done(): + return + case <-cm.quit: + return + case <-retryAfter: + case <-disconnected: + // Wait to retry any time the connection was not maintained for at + // least a single retry interval. + // + // This approach is used over only incrementing the retry count when + // the dial fails to effectively rate limit the attempts with an + // increasing backoff regardless of the reason a stable connection + // was not maintained. + // + // For example, the remote might repeatedly reject the peer for a + // variety of reasons (max limits, not enough peers of a desired + // type, etc) after a successful connection is made. + if !lastAttempt.IsZero() && time.Since(lastAttempt) < cm.cfg.RetryDuration { + // Reconnect after a retry timeout with an increasing backoff up + // to a max for repeated failed attempts. + const maxUint32 = 1<<32 - 1 + if retryCount < maxUint32 { + retryCount++ + } + retryWait := time.Duration(retryCount) * cm.cfg.RetryDuration + retryWait = min(retryWait, cm.maxRetryDuration) + log.Debugf("Retrying connection to %v in %v (retries %d)", addr, + retryWait, retryCount) + retryAfter = time.After(retryWait) + continue + } + + // A connection succeeded and was maintained for at least a single + // retry interval. + // + // Clear the retry state. + retryCount = 0 + retryAfter = nil + } + + lastAttempt = time.Now() + var err error + conn, err = cm.dial(ctx, addr, ConnTypeManual, onClose, &connID) if err != nil { - return err + if ctx.Err() != nil { + return + } + + // Retry, potentially after a timeout with backoff. + continue + } + + // Successful connection. 
+ if cm.cfg.OnConnection != nil { + go cm.cfg.OnConnection(conn) } } - return nil } -// listenHandler accepts incoming connections on a given listener. It must be -// run as a goroutine. -func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) { - log.Infof("Server listening on %s", listener.Addr()) +// persistentConnsHandler handles launching individual goroutines for persistent +// connections. +func (cm *ConnManager) persistentConnsHandler(ctx context.Context) { + for { + select { + case entry := <-cm.runPersistentChan: + pCtx, cancel := context.WithCancel(ctx) + cm.connMtx.Lock() + entry.cancel = cancel + cm.connMtx.Unlock() + go cm.runPersistent(pCtx, entry.id, entry.addr) + + case <-ctx.Done(): + return + } + } +} + +// targetOutboundHandler attempts to automatically maintain the target number of +// outbound connections configured via [Config.TargetOutbound] when initially +// creating the connection manager. +// +// This MUST be run as a goroutine. +func (cm *ConnManager) targetOutboundHandler(ctx context.Context) { + log.Trace("Starting target outbound handler") + defer log.Trace("Target outbound handler done") + + // failedAttempts tracks the total number of failed outbound connection + // attempts since the last successful connection. It is primarily used to + // detect network outages in order to impose a retry timeout on achieving + // the target number of outbound connections which prevents runaway failed + // connection attempt churn. + // + // Overflow is not checked since it would be virtually impossible to hit + // anywhere near max uint64 in practice and even if it ever happened, the only + // consequence would potentially be a few extra retries before it hit the + // max failures again. + var failedAttempts atomic.Uint64 + for ctx.Err() == nil { - conn, err := listener.Accept() - if err != nil { - // Only log the error if not forcibly shutting down.
- if ctx.Err() == nil { - log.Errorf("Can't accept connection: %v", err) + // Pause automatic outbound connections for a retry timeout after too + // many failed connection attempts. The network very likely has become + // temporarily unreachable. + if failedAttempts.Load() >= maxFailedAttempts { + log.Debugf("Max failed connection attempts reached [%d] -- "+ + "pausing connections for %v", maxFailedAttempts, + cm.cfg.RetryDuration) + + select { + case <-time.After(cm.cfg.RetryDuration): + case <-cm.quit: + return + case <-ctx.Done(): + return } + } + + // Wait for a permit to make another outbound connection. + if !cm.activeOutboundsSem.Acquire(ctx) { + return + } + + // Wait for a permit to make another overall connection. This limits + // the total number of normal connections while the previous limits the + // total number of automatic outbound connections. + if !cm.totalNormalConnsSem.Acquire(ctx) { + cm.activeOutboundsSem.Release() + return + } + + addr, err := cm.cfg.GetNewAddress() + if err != nil { + failedAttempts.Add(1) + log.Debugf("Failed to get address for outbound connection: %v", err) + cm.totalNormalConnsSem.Release() + cm.activeOutboundsSem.Release() continue } - go cm.cfg.OnAccept(conn) - } - log.Tracef("Listener handler done for %s", listener.Addr()) + go func(addr net.Addr) { + onClose := func() { + cm.totalNormalConnsSem.Release() + cm.activeOutboundsSem.Release() + } + conn, err := cm.dial(ctx, addr, ConnTypeOutbound, onClose, nil) + if err != nil { + failedAttempts.Add(1) + return + } + + failedAttempts.Store(0) + if cm.cfg.OnConnection != nil { + go cm.cfg.OnConnection(conn) + } + }(addr) + } } // Run starts the connection manager along with its configured listeners and // begins connecting to the network. It blocks until the provided context is -// cancelled. +// canceled. 
func (cm *ConnManager) Run(ctx context.Context) { log.Trace("Starting connection manager") + defer log.Trace("Connection manager stopped") // Start all the listeners so long as the caller requested them and provided // a callback to be invoked when connections are accepted. @@ -610,27 +1278,56 @@ func (cm *ConnManager) Run(ctx context.Context) { }(listener) } - // Start enough outbound connections to reach the target number when not - // in manual connect mode. + // Start persistent connections handler which starts individual goroutines + // for each persistent connection already added and any newly added ones + // later. + wg.Add(1) + go func() { + cm.persistentConnsHandler(ctx) + wg.Done() + }() + + // Start outbound connection handler to maintain the target number of + // normal outbound connections when not in manual connect mode. if cm.cfg.GetNewAddress != nil { - curConnReqCount := cm.connReqCount.Load() - for i := curConnReqCount; i < uint64(cm.cfg.TargetOutbound); i++ { - go cm.newConnReq(ctx) - } + wg.Add(1) + go func() { + cm.targetOutboundHandler(ctx) + wg.Done() + }() } - // Stop all the listeners and shutdown the connection manager when the - // context is cancelled. There will not be any listeners if listening is - // disabled. + // Shutdown the connection manager when the context is canceled. <-ctx.Done() close(cm.quit) + + // Stop all the listeners. There will not be any listeners if listening is + // disabled. for _, listener := range listeners { // Ignore the error since this is shutdown and there is no way // to recover anyways. _ = listener.Close() } + + // Shutdown persistent conns, cancel pending conns, and close active conns. 
+ cm.connMtx.Lock() + totalIDs := len(cm.persistent) + len(cm.pending) + len(cm.active) + ids := make(map[uint64]struct{}, totalIDs) + for id := range cm.persistent { + ids[id] = struct{}{} + } + for id := range cm.pending { + ids[id] = struct{}{} + } + for id := range cm.active { + ids[id] = struct{}{} + } + cm.connMtx.Unlock() + for id := range ids { + cm.Remove(id) + } + wg.Wait() - log.Trace("Connection manager stopped") } // New returns a new connection manager with the provided configuration. @@ -640,18 +1337,28 @@ func New(cfg *Config) (*ConnManager, error) { if cfg.Dial == nil { return nil, MakeError(ErrDialNil, "dial cannot be nil") } - // Default to sane values + // Default to sane values. if cfg.RetryDuration <= 0 { cfg.RetryDuration = defaultRetryDuration } + if cfg.MaxNormalConns == 0 { + cfg.MaxNormalConns = defaultMaxNormalConns + } if cfg.TargetOutbound == 0 { cfg.TargetOutbound = defaultTargetOutbound } + cfg.TargetOutbound = min(cfg.TargetOutbound, cfg.MaxNormalConns) cm := ConnManager{ - cfg: *cfg, // Copy so caller can't mutate - quit: make(chan struct{}), - pending: make(map[uint64]*ConnReq), - conns: make(map[uint64]*ConnReq, cfg.TargetOutbound), + cfg: *cfg, // Copy so caller can't mutate + quit: make(chan struct{}), + maxRetryDuration: defaultMaxRetryDuration, + runPersistentChan: make(chan *persistentEntry, MaxPersistent), + totalNormalConnsSem: makeSemaphore(cfg.MaxNormalConns), + activeOutboundsSem: makeSemaphore(cfg.TargetOutbound), + persistent: make(map[uint64]*persistentEntry, MaxPersistent), + pending: make(map[uint64]pendingConnInfo), + active: make(map[uint64]*Conn, cfg.TargetOutbound), + connIDByAddr: make(map[string]uint64), } return &cm, nil } diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 65900d583..5b42ab06b 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -11,15 +11,37 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "sync/atomic" 
"testing" "time" ) -func init() { - // Override the max retry duration when running tests. - maxRetryDuration = 2 * time.Millisecond +const ( + // defaultTestMaxRetryDuration is the default max duration a connection + // retry backoff is allowed to grow to when running tests. + defaultTestMaxRetryDuration = 2 * time.Millisecond + + // connTestReceiveTimeout is the default receive timeout used throughout the + // tests when expecting to receive connections to prevent test hangs. + connTestReceiveTimeout = 10 * time.Millisecond + + // connTestNonReceiveTimeout is the default timeout used throughout the + // tests when expecting that a connection will NOT be received. + connTestNonReceiveTimeout = 20 * time.Millisecond +) + +// mustParseAddrPort parses the provided address into a [*net.TCPAddr] and will +// panic if there is an error. It will only (and must only) be called with +// hard-coded, and therefore known good, addresses. +func mustParseAddrPort(addr string) *net.TCPAddr { + addrPort := netip.MustParseAddrPort(addr) + return &net.TCPAddr{ + IP: addrPort.Addr().AsSlice(), + Port: int(addrPort.Port()), + Zone: addrPort.Addr().Zone(), + } } // runConnMgrAsync invokes the Run method on the passed connection manager in a @@ -86,85 +108,571 @@ func mockDialer(ctx context.Context, network, addr string) (net.Conn, error) { return c, ctx.Err() } +// newTestConnManager returns a new connection manager with the provided +// configuration and some timeout tweaks so that it is suitable for use in the +// tests. +func newTestConnManager(t *testing.T, cfg *Config) *ConnManager { + t.Helper() + + cmgr, err := New(cfg) + if err != nil { + t.Fatalf("New: unexpected error: %v", err) + } + cmgr.maxRetryDuration = defaultTestMaxRetryDuration + return cmgr +} + // TestNewConfig tests that new ConnManager config is validated as expected. 
func TestNewConfig(t *testing.T) { + t.Parallel() + _, err := New(&Config{}) if err == nil { t.Fatal("New expected error: 'Dial can't be nil', got nil") } - _, err = New(&Config{ + + newTestConnManager(t, &Config{ Dial: mockDialer, }) - if err != nil { - t.Fatalf("New unexpected error: %v", err) +} + +// TestIsWhitelisted ensures [ConnManager.IsWhitelisted] works as expected. +func TestIsWhitelisted(t *testing.T) { + type perManagerTest struct { + addr string // address to test against whitelist + whitelisted bool // expected whitelisted result + } + + tests := []struct { + name string // test description + prefixes []string // CIDR prefixes to whitelist + perManagerTests []perManagerTest // tests to run against the prefixes + }{{ + name: "no whitelisted entries", + prefixes: nil, + perManagerTests: []perManagerTest{ + {"1.2.3.4:18555", false}, + {"127.0.0.1:18555", false}, + }, + }, { + name: "single /32 IPv4 entry", + prefixes: []string{"1.2.3.4/32"}, + perManagerTests: []perManagerTest{ + {"1.2.3.4:18555", true}, + {"1.2.3.4:9108", true}, + {"[::1.2.3.4]:18555", false}, // IPv4 in IPv6 + {"1.2.3.5:18555", false}, + }, + }, { + name: "single /128 IPv6 entry", + prefixes: []string{"::1.2.3.4/128"}, + perManagerTests: []perManagerTest{ + {"[::1.2.3.4]:18555", true}, + {"[::1.2.3.4]:9108", true}, + {"1.2.3.4:18555", false}, // IPv4 doesn't match IPv4 in IPv6 + {"[::1.2.3.5]:9108", false}, + }, + }, { + name: "mixed IPv4 and IPv6 with different prefix lengths", + prefixes: []string{"12.13.14.0/24", "20.21.22.23/8", "fe80::/64"}, + perManagerTests: []perManagerTest{ + {"12.13.14.1:18555", true}, + {"12.13.14.255:18555", true}, + {"12.13.15.0:18555", false}, + {"20.0.0.0:18555", true}, + {"20.0.0.0:9108", true}, + {"20.255.255.255:18555", true}, + {"20.255.255.255:9108", true}, + {"21.0.0.0:18555", false}, + {"[fe80::1]:18555", true}, + {"[fe80::1]:9108", true}, + {"[fe80::ffff:ffff:ffff:ffff]:18555", true}, + {"[fe80::ffff:ffff:ffff:ffff]:1234", true}, + 
{"[fe80::1:ffff:ffff:ffff:ffff]:18555", false}, + }, + }} + + for _, test := range tests { + // Parse the whitelist entries for the test. + prefixes := make([]netip.Prefix, 0, len(test.prefixes)) + for _, prefixStr := range test.prefixes { + prefix, err := netip.ParsePrefix(prefixStr) + if err != nil { + t.Fatalf("%q: failed to parse prefix %q: %v", test.name, prefixStr, + err) + } + prefixes = append(prefixes, prefix) + } + cmgr := newTestConnManager(t, &Config{ + Dial: mockDialer, + Whitelists: prefixes, + }) + + for _, pmTest := range test.perManagerTests { + mAddr := mockAddr{"tcp", pmTest.addr} + addr, err := stdlibNetAddrToAddrMgrNetAddr(mAddr) + if err != nil { + t.Fatalf("%q-%q: failed to parse address: %v", test.name, + pmTest.addr, err) + } + if got := cmgr.IsWhitelisted(addr); got != pmTest.whitelisted { + t.Errorf("%q-%q: mismatched result -- got %v, want %v", + test.name, pmTest.addr, got, pmTest.whitelisted) + continue + } + } } } -// assertConnReqID ensures the provided connection request has the given ID. -func assertConnReqID(t *testing.T, connReq *ConnReq, wantID uint64) { +// assertConnID ensures the provided connection has the given ID. +func assertConnID(t *testing.T, conn *Conn, wantID uint64) { t.Helper() - gotID := connReq.ID() + gotID := conn.ID() if gotID != wantID { t.Fatalf("unexpected ID -- got %v, want %v", gotID, wantID) } } -// assertConnReqState ensures the provided connection request has the given -// state. -func assertConnReqState(t *testing.T, connReq *ConnReq, wantState ConnState) { +// assertConnType ensures the provided connection has the given type. +func assertConnType(t *testing.T, conn *Conn, wantType ConnectionType) { + t.Helper() + + gotType := conn.Type() + if gotType != wantType { + t.Fatalf("unexpected type -- got %v, want %v", gotType, wantType) + } +} + +// pendingAddrConnID returns the connection ID associated with the pending +// connection attempt for the provided address. 
The second return value will be +// false if no pending attempt is found. +func pendingAddrConnID(cm *ConnManager, addr net.Addr) (uint64, bool) { + cm.connMtx.Lock() + defer cm.connMtx.Unlock() + addrStr := addr.String() + for _, info := range cm.pending { + if info.addr.String() == addrStr { + return info.id, true + } + } + return 0, false +} + +// assertPendingAddr ensures there is a pending connection with the given +// address. +func assertPendingAddr(t *testing.T, cm *ConnManager, addr net.Addr) { + t.Helper() + + if _, ok := pendingAddrConnID(cm, addr); !ok { + t.Fatalf("connection %s is not pending", addr) + } +} + +// assertRemovedPersistent ensures there are no persistent conns with the +// provided address. +func assertRemovedPersistent(t *testing.T, cm *ConnManager, addr net.Addr) { + t.Helper() + + if _, ok := cm.FindPersistentAddrID(addr); ok { + t.Fatalf("found persistent entry for %s", addr) + } +} + +// assertConnReceivedTimeout ensures a connection with the given type is +// received on the provided channel before the given timeout. When given a +// non-zero connection ID, it asserts the received connection has that ID. +func assertConnReceivedTimeout(t *testing.T, ch <-chan *Conn, timeout time.Duration, connID uint64, connType ConnectionType) *Conn { + t.Helper() + + select { + case conn := <-ch: + if connID != 0 { + assertConnID(t, conn, connID) + } + assertConnType(t, conn, connType) + return conn + case <-time.After(timeout): + t.Fatal("connection not received before timeout") + } + return nil +} + +// assertConnReceived ensures a connection with the given type is received on +// the provided channel before the default timeout. When given a non-zero +// connection ID, it asserts the received connection has that ID. 
+func assertConnReceived(t *testing.T, ch <-chan *Conn, connID uint64, connType ConnectionType) *Conn { + t.Helper() + + return assertConnReceivedTimeout(t, ch, connTestReceiveTimeout, connID, + connType) +} + +// assertNoConnReceivedTimeout ensures no connections are received on the +// provided channel before the given timeout. +func assertNoConnReceivedTimeout(t *testing.T, ch <-chan *Conn, timeout time.Duration) { t.Helper() - gotState := connReq.State() - if gotState != wantState { - t.Fatalf("unexpected state -- got %v, want %v", gotState, wantState) + select { + case conn := <-ch: + conn.Close() + t.Fatalf("got unexpected connection from %v", conn.RemoteAddr()) + case <-time.After(timeout): + // Connection not received as expected. } } +// assertNoConnReceived ensures no connections are received on the provided +// channel before the default timeout. +func assertNoConnReceived(t *testing.T, ch <-chan *Conn) { + t.Helper() + + assertNoConnReceivedTimeout(t, ch, connTestNonReceiveTimeout) +} + // TestConnectMode tests that the connection manager works in the connect mode. // -// In connect mode, automatic connections are disabled, so we test that -// requests using Connect are handled and that no other connections are made. +// In connect mode, automatic connections are disabled, so test that connections +// using [ConnManager.Connect] are handled and that no other connections are +// made. 
func TestConnectMode(t *testing.T) { - connected := make(chan *ConnReq) - cmgr, err := New(&Config{ + t.Parallel() + + connected := make(chan *Conn) + cmgr := newTestConnManager(t, &Config{ TargetOutbound: 2, Dial: mockDialer, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c + OnConnection: func(conn *Conn) { + connected <- conn }, }) - if err != nil { - t.Fatalf("New error: %v", err) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) + + // Ensure that only a single connection is received. + assertConnReceived(t, connected, 0, ConnTypeManual) + assertNoConnReceived(t, connected) + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + +// TestDisconnect ensures that [ConnManager.Disconnect] properly disconnects +// pending and established connections for both non-persistent and persistent +// connections. +func TestDisconnect(t *testing.T) { + t.Parallel() + + // Create a connection manager instance with a dialer that has a few + // synchronization channels to notify when a dial attempt is made, to keep + // connection attempts in a pending state, and to notify when the context + // for the attempt is canceled. Whether or not to wait/send the signals is + // controlled by the associated atomic flags.
+ connected := make(chan *Conn) + disconnected := make(chan *Conn) + dialed := make(chan struct{}) + pending := make(chan struct{}) + canceled := make(chan struct{}) + var notifyDialed, waitForPending, notifyCanceled atomic.Bool + pendingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + if notifyDialed.Load() { + dialed <- struct{}{} + } + if waitForPending.Load() { + <-pending + } + conn, err := mockDialer(ctx, network, addr) + if errors.Is(err, context.Canceled) && notifyCanceled.Load() { + canceled <- struct{}{} + } + return conn, err } + cmgr := newTestConnManager(t, &Config{ + Dial: pendingDialer, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, + // Attempt a connection to a localhost IP. + notifyDialed.Store(true) + waitForPending.Store(true) + notifyCanceled.Store(true) + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) + + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. + select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for dial") + } + assertPendingAddr(t, cmgr, addr) + + // Disconnect the connection attempt while it's still pending. + connID, _ := pendingAddrConnID(cmgr, addr) + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + + // Allow the dialer to proceed with the disconnected connection attempt and + // then wait for the dialer to signal the context associated with the dial + // was canceled. Finally, ensure the internal pending state is removed. 
+ select { + case pending <- struct{}{}: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting to signal pending") + } + select { + case <-canceled: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for cancel") + } + if _, ok := pendingAddrConnID(cmgr, addr); ok { + t.Fatalf("connection %s is still pending", addr) + } + + // Start a connection attempt and wait for it to be established. + notifyDialed.Store(false) + waitForPending.Store(false) + notifyCanceled.Store(false) + go cmgr.Connect(ctx, addr) + conn := assertConnReceived(t, connected, 0, ConnTypeManual) + + // Disconnect the established connection and wait for the disconnect + // notification to ensure it is disconnected as intended. + connID = conn.ID() + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + + // Add a persistent connection back to the same address. + notifyDialed.Store(true) + waitForPending.Store(true) + notifyCanceled.Store(true) + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. + select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for dial") + } + assertPendingAddr(t, cmgr, addr) + + // Disconnect the persistent connection attempt while it's still pending. + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + + // Allow the dialer to proceed with the disconnected persistent connection + // attempt and then wait for the dialer to signal the context associated + // with the dial was canceled. + select { + case pending <- struct{}{}: + // Ensure the reconnect attempt doesn't notify the dialed chan or + // wait for the pending chan. 
+ notifyDialed.Store(false) + waitForPending.Store(false) + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting to signal pending") + } + select { + case <-canceled: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for cancel") + } + + // Wait for the retry to be established. + assertConnReceived(t, connected, connID, ConnTypeManual) + + // Disconnect the established persistent connection and wait for the + // disconnect notification to ensure it is disconnected as intended. + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + +// TestRemove ensures that [ConnManager.Remove] properly removes pending and +// established connections for both non-persistent and persistent connections. +// +// It also ensures removal of an invalid ID returns the expected error. +func TestRemove(t *testing.T) { + t.Parallel() + + // Create a connection manager instance with a dialer that has a few + // synchronization channels to notify when a dial attempt is made, to keep + // connection attempts in a pending state, and to notify when the context + // for the attempt is canceled. Whether or not to wait/send the signals is + // controlled by the associated atomic flags.
+ connected := make(chan *Conn) + disconnected := make(chan *Conn) + dialed := make(chan struct{}) + pending := make(chan struct{}) + canceled := make(chan struct{}) + var notifyDialed, waitForPending, notifyCanceled atomic.Bool + pendingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + if notifyDialed.Load() { + dialed <- struct{}{} + } + if waitForPending.Load() { + <-pending + } + conn, err := mockDialer(ctx, network, addr) + if errors.Is(err, context.Canceled) && notifyCanceled.Load() { + canceled <- struct{}{} + } + return conn, err + } + cmgr := newTestConnManager(t, &Config{ + Dial: pendingDialer, + OnConnection: func(conn *Conn) { + connected <- conn }, - Permanent: true, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Ensure removing an ID that doesn't exist returns the expected error. + if err := cmgr.Remove(^uint64(0)); !errors.Is(err, ErrNotFound) { + t.Fatalf("mismatched remove error: got %v, want %v", err, ErrNotFound) } - go cmgr.Connect(ctx, cr) - // Ensure that the connection was received. + // Attempt a connection to a localhost IP. + notifyDialed.Store(true) + waitForPending.Store(true) + notifyCanceled.Store(true) + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) + + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. select { - case gotConnReq := <-connected: - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnEstablished) + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for dial") + } + assertPendingAddr(t, cmgr, addr) + + // Remove the connection attempt while it's still pending. 
+ connID, _ := pendingAddrConnID(cmgr, addr) + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected remove err: %v", err) + } + // Allow the dialer to proceed with the removed connection attempt and then + // wait for the dialer to signal the context associated with the dial was + // canceled. Finally, ensure the internal pending state is removed. + select { + case pending <- struct{}{}: case <-time.After(time.Millisecond * 5): - t.Fatalf("connect mode: connection timeout - %v", cr.Addr) + t.Fatal("timeout waiting to signal pending") + } + select { + case <-canceled: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for cancel") + } + if _, ok := pendingAddrConnID(cmgr, addr); ok { + t.Fatalf("connection %s is still pending", addr) } - // Ensure only a single connection was made. + // Start a connection attempt and wait for it to be established. + notifyDialed.Store(false) + waitForPending.Store(false) + notifyCanceled.Store(false) + go cmgr.Connect(ctx, addr) + conn := assertConnReceived(t, connected, 0, ConnTypeManual) + + // Remove the established connection and wait for the disconnect + // notification to ensure it is disconnected as intended. + connID = conn.ID() + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + + // Add a persistent connection back to the same address. + notifyDialed.Store(true) + waitForPending.Store(true) + notifyCanceled.Store(true) + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. 
+ select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for dial") + } + assertPendingAddr(t, cmgr, addr) + + // Remove the persistent connection attempt while it's still pending. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + + // Allow the dialer to proceed with the removed persistent connection + // attempt and then wait for the dialer to signal the context associated + // with the dial was canceled. + select { + case pending <- struct{}{}: + // Ensure the reconnect attempt doesn't notify the dialed chan or + // wait for the pending chan. + notifyDialed.Store(false) + waitForPending.Store(false) + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting to signal pending") + } select { - case c := <-connected: - t.Fatalf("connect mode: got unexpected connection - %v", c.Addr) + case <-canceled: case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for cancel") } + // Add a persistent connection back to the same address and wait for it to + // be established. + notifyDialed.Store(false) + waitForPending.Store(false) + notifyCanceled.Store(false) + connID, err = cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + conn2 := assertConnReceived(t, connected, connID, ConnTypeManual) + + // Remove the established persistent connection and wait for the disconnect + // notification to ensure it is disconnected as intended. Also, ensure the + // persistent connection entry is removed. + connID = conn2.ID() + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + // Ensure clean shutdown of connection manager. 
shutdown() wg.Wait() @@ -174,48 +682,71 @@ func TestConnectMode(t *testing.T) { // configuration option by waiting until all connections are established and // ensuring they are the only connections made. func TestTargetOutbound(t *testing.T) { + t.Parallel() + const targetOutbound = 10 - var numConnections atomic.Uint32 - hitTargetConns := make(chan struct{}) - extraConns := make(chan *ConnReq) - cmgr, err := New(&Config{ + var nextAddr atomic.Uint32 + connected := make(chan *Conn) + cmgr := newTestConnManager(t, &Config{ TargetOutbound: targetOutbound, Dial: mockDialer, GetNewAddress: func() (net.Addr, error) { - return &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, nil + addrStr := fmt.Sprintf("127.0.0.%d:18555", nextAddr.Add(1)) + return mustParseAddrPort(addrStr), nil }, - OnConnection: func(c *ConnReq, conn net.Conn) { - totalConnections := numConnections.Add(1) - if totalConnections == targetOutbound { - close(hitTargetConns) - return - } - if totalConnections > targetOutbound { - extraConns <- c - } + OnConnection: func(conn *Conn) { + connected <- conn }, }) - if err != nil { - t.Fatalf("New error: %v", err) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Ensure only the expected number of target outbound conns are established + // and no more. + for range targetOutbound { + assertConnReceived(t, connected, 0, ConnTypeOutbound) } + assertNoConnReceived(t, connected) + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + +// TestDoubleClose ensures closing a connection multiple times is a noop after +// the first call. 
+func TestDoubleClose(t *testing.T) { + t.Parallel() + + connected := make(chan *Conn) + cmgr := newTestConnManager(t, &Config{ + TargetOutbound: 1, + Dial: mockDialer, + GetNewAddress: func() (net.Addr, error) { + return mustParseAddrPort("127.0.0.1:18555"), nil + }, + OnConnection: func(conn *Conn) { + connected <- conn + }, + }) _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - // Wait for the expected number of target outbound conns to be established. - select { - case <-hitTargetConns: - case <-time.After(20 * time.Millisecond): - t.Fatal("did not reach target number of conns before timeout") + // Wait for the connection to be established. + conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) + + // Override the close func to cleanly detect closes. + var numClosed uint32 + origOnClose := conn.onClose + conn.onClose = func() { + numClosed++ + origOnClose() } - // Ensure no additional connections are made. - select { - case c := <-extraConns: - t.Fatalf("target outbound: got unexpected connection - %v", c.Addr) - case <-time.After(time.Millisecond * 5): - break + // Close the connection multiple times and make sure it only happens once. + for range 3 { + conn.Close() + } + if numClosed != 1 { + t.Fatal("connection closed more than once") } // Ensure clean shutdown of connection manager. @@ -223,54 +754,128 @@ func TestTargetOutbound(t *testing.T) { wg.Wait() } -// TestRetryPermanent tests that permanent connection requests are retried. -// -// We make a permanent connection request using Connect, disconnect it using -// Disconnect and we wait for it to be connected back. -func TestRetryPermanent(t *testing.T) { - connected := make(chan *ConnReq) - disconnected := make(chan *ConnReq) - cmgr, err := New(&Config{ +// TestRetryPersistent tests that persistent connections are retried. 
+func TestRetryPersistent(t *testing.T) { + t.Parallel() + + connected := make(chan *Conn) + disconnected := make(chan *Conn) + cmgr := newTestConnManager(t, &Config{ RetryDuration: time.Millisecond, TargetOutbound: 1, Dial: mockDialer, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c + OnConnection: func(conn *Conn) { + connected <- conn }, - OnDisconnection: func(c *ConnReq) { - disconnected <- c + OnDisconnection: func(conn *Conn) { + disconnected <- conn }, }) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + addr := mustParseAddrPort("127.0.0.1:18555") + connID, err := cmgr.AddPersistent(addr) if err != nil { - t.Fatalf("New error: %v", err) + t.Fatalf("failed to add persistent connection: %v", err) + } + if !cmgr.IsPersistent(connID) { + t.Fatal("IsPersistent did not reported true for persistent conn") } - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, + // Wait for the first connection, close it, wait for the disconnect, and + // ensure the retry succeeds. + conn := assertConnReceived(t, connected, connID, ConnTypeManual) + conn.Close() + assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertConnReceived(t, connected, connID, ConnTypeManual) + + // Remove the persistent connection, wait for it to disconnect, and ensure + // it is actually removed. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("failed to remove persistent connection: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertRemovedPersistent(t, cmgr, addr) + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + +// TestMaxPersistent ensures [ConnManager.AddPersistent] limits the maximum +// number of persistent connections including a removal and addition of a new +// one after achieving the max. 
+func TestMaxPersistent(t *testing.T) { + t.Parallel() + + connected := make(chan *Conn) + disconnected := make(chan *Conn) + cmgr := newTestConnManager(t, &Config{ + Dial: mockDialer, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn }, - Permanent: true, + }) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + var numAddrs uint32 + nextAddr := func() net.Addr { + numAddrs++ + addrStr := fmt.Sprintf("127.0.0.%d:18555", numAddrs) + return mustParseAddrPort(addrStr) } - go cmgr.Connect(ctx, cr) - gotConnReq := <-connected - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnEstablished) - cmgr.Disconnect(cr.ID()) - gotConnReq = <-disconnected - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnPending) + // Add the maximum allowed number of persistent conns. + connIDs := make([]uint64, 0, MaxPersistent) + addrs := make([]net.Addr, 0, MaxPersistent) + for range MaxPersistent { + addr := nextAddr() + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection %v: %v", addr, err) + } + connIDs = append(connIDs, connID) + addrs = append(addrs, addr) - gotConnReq = <-connected - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnEstablished) + // Wait for the connection. + assertConnReceived(t, connected, connID, ConnTypeManual) + } - cmgr.Remove(cr.ID()) - gotConnReq = <-disconnected - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnDisconnected) + // Attempting to add more than the max allowed number of persistent conns + // should be rejected. + _, err := cmgr.AddPersistent(nextAddr()) + if !errors.Is(err, ErrMaxPersistent) { + t.Fatalf("did not reject > max persistent, err: %v", err) + } + + // Ensure disconnecting the persistent conn does not incorrectly decrement + // the count. 
+ connID, addr := connIDs[0], addrs[0] + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("failed to disconnect persistent conn %v: %v", addr, err) + } + _, err = cmgr.AddPersistent(nextAddr()) + if !errors.Is(err, ErrMaxPersistent) { + t.Fatalf("did not reject max persistent after dc, err: %v", err) + } + + // Remove the first persistent connection, wait for it to disconnect, and + // ensure it is actually removed. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("failed to remove persistent conn %v: %v", addr, err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertRemovedPersistent(t, cmgr, addr) + + // A new persistent conn should now be allowed. + addr = nextAddr() + _, err = cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection %v: %v", addr, err) + } // Ensure clean shutdown of connection manager. shutdown() @@ -282,17 +887,16 @@ func TestRetryPermanent(t *testing.T) { // We have a timed dialer which initially returns err but after RetryDuration // hits maxRetryDuration returns a mock conn. func TestMaxRetryDuration(t *testing.T) { + t.Parallel() + // This test relies on the current value of the max retry duration defined // in the tests, so assert it. 
- if maxRetryDuration != 2*time.Millisecond { + if defaultTestMaxRetryDuration != 2*time.Millisecond { t.Fatalf("max retry duration of %v is not the required value for test", - maxRetryDuration) + defaultTestMaxRetryDuration) } networkUp := make(chan struct{}) - time.AfterFunc(5*time.Millisecond, func() { - close(networkUp) - }) timedDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { select { case <-networkUp: @@ -302,36 +906,31 @@ func TestMaxRetryDuration(t *testing.T) { } } - connected := make(chan *ConnReq) - cmgr, err := New(&Config{ + connected := make(chan *Conn) + cmgr := newTestConnManager(t, &Config{ RetryDuration: time.Millisecond, TargetOutbound: 1, Dial: timedDialer, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c + OnConnection: func(conn *Conn) { + connected <- conn }, }) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + connID, err := cmgr.AddPersistent(mustParseAddrPort("127.0.0.1:18555")) if err != nil { - t.Fatalf("New error: %v", err) + t.Fatalf("failed to add persistent connection: %v", err) } - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - Permanent: true, - } - go cmgr.Connect(ctx, cr) // retry in 1ms // retry in 2ms - max retry duration reached // retry in 2ms - timedDialer returns mockDial - select { - case <-connected: - case <-time.After(200 * time.Millisecond): - t.Fatal("max retry duration: connection timeout") - } + const networkUpTimeout = 5 * time.Millisecond + time.AfterFunc(networkUpTimeout, func() { + close(networkUp) + }) + const timeout = connTestReceiveTimeout + networkUpTimeout + assertConnReceivedTimeout(t, connected, timeout, connID, ConnTypeManual) // Ensure clean shutdown of connection manager. 
shutdown() @@ -341,6 +940,8 @@ func TestMaxRetryDuration(t *testing.T) { // TestNetworkFailure tests that the connection manager handles a network // failure gracefully. func TestNetworkFailure(t *testing.T) { + t.Parallel() + var closeOnce sync.Once const targetOutbound = 5 const retryTimeout = time.Millisecond * 5 @@ -349,35 +950,36 @@ func TestNetworkFailure(t *testing.T) { connMgrDone := make(chan struct{}) errDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { totalDials := dials.Add(1) - if totalDials >= maxFailedAttempts { + if totalDials > maxFailedAttempts { closeOnce.Do(func() { close(reachedMaxFailedAttempts) }) <-connMgrDone } return nil, errors.New("network down") } - cmgr, err := New(&Config{ + var nextAddr atomic.Uint32 + cmgr := newTestConnManager(t, &Config{ TargetOutbound: targetOutbound, RetryDuration: retryTimeout, Dial: errDialer, GetNewAddress: func() (net.Addr, error) { - return &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, nil + addrStr := fmt.Sprintf("127.0.0.%d:18555", nextAddr.Add(1)) + return mustParseAddrPort(addrStr), nil }, - OnConnection: func(c *ConnReq, conn net.Conn) { - t.Fatalf("network failure: got unexpected connection - %v", c.Addr) + OnConnection: func(conn *Conn) { + t.Fatalf("network failure: got unexpected connection - %v", + conn.RemoteAddr()) }, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Shutdown the connection manager after the max failed attempts is reached // and an additional retry duration has passed and then wait for the // shutdown to complete. 
- <-reachedMaxFailedAttempts + select { + case <-reachedMaxFailedAttempts: + case <-time.After(retryTimeout * maxFailedAttempts * 3): + t.Fatal("did not reach target number of failed attempts before timeout") + } time.Sleep(retryTimeout) shutdown() close(connMgrDone) @@ -396,15 +998,13 @@ func TestNetworkFailure(t *testing.T) { // TestMultipleFailedConns ensures that the connection manager remains // responsive when there are multiple simultaneous failed connections for -// persistent peers in the retry state. +// persistent conns in the retry state. func TestMultipleFailedConns(t *testing.T) { + t.Parallel() + // Override the max retry duration for this test since it relies on having // multiple connections in the retry state. - curMaxRetryDuration := maxRetryDuration - maxRetryDuration = 500 * time.Millisecond - defer func() { - maxRetryDuration = curMaxRetryDuration - }() + const maxRetryDuration = 500 * time.Millisecond const targetFailed = 5 var dials atomic.Uint32 @@ -417,25 +1017,20 @@ func TestMultipleFailedConns(t *testing.T) { } return nil, errors.New("network down") } - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ RetryDuration: maxRetryDuration, Dial: errDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + cmgr.maxRetryDuration = maxRetryDuration + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish several connection requests to localhost IPs. 
- for i := 0; i < targetFailed; i++ { - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP(fmt.Sprintf("127.0.0.%d", i+1)), - Port: 18555, - }, - Permanent: true, + for i := range targetFailed { + addr := mustParseAddrPort(fmt.Sprintf("127.0.0.%d:18555", i+1)) + _, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("unexpected add err: %v", err) } - go cmgr.Connect(ctx, cr) } // Wait for the target number of dials and ensure they happen simultaneously @@ -467,184 +1062,171 @@ func TestMultipleFailedConns(t *testing.T) { // TestShutdownFailedConns tests that failed connections are ignored after // connmgr is shutdown. -// -// We have a dialer which sets the stop flag on the conn manager and returns an -// err so that the handler assumes that the conn manager is stopped and ignores -// the failure. func TestShutdownFailedConns(t *testing.T) { + t.Parallel() + var closeOnce sync.Once dialed := make(chan struct{}) waitDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { closeOnce.Do(func() { close(dialed) }) return nil, errors.New("network down") } - cmgr, err := New(&Config{ - RetryDuration: maxRetryDuration, + cmgr := newTestConnManager(t, &Config{ + RetryDuration: defaultTestMaxRetryDuration, Dial: waitDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Shutdown the connection manager during the retry timeout after a failed // dial attempt. go func() { <-dialed - time.Sleep(maxRetryDuration / 2) + time.Sleep(cmgr.maxRetryDuration / 2) shutdown() }() - // Establish a connection request to a localhost IP. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - Permanent: true, - } - go cmgr.Connect(ctx, cr) + // Establish a connection. + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) // Ensure clean shutdown of connection manager. 
wg.Wait() } -// TestRemovePendingConnection tests that it's possible to cancel a pending -// connection, removing its internal state from the connection manager. +// TestRemovePendingConnection ensures that removing a pending outbound +// connection correctly cancels the context used to dial and removes the +// internal state. func TestRemovePendingConnection(t *testing.T) { - // Create a ConnMgr instance with an instance of a dialer that'll never - // succeed. + t.Parallel() + + // Create a conn manager with an instance of a dialer that'll never succeed. dialed := make(chan struct{}) - wait := make(chan struct{}) + canceled := make(chan struct{}) indefiniteDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { close(dialed) - <-wait + <-ctx.Done() + close(canceled) return nil, errors.New("error") } - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Dial: indefiniteDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a connection request to a localhost IP. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - Permanent: true, - } - go cmgr.Connect(ctx, cr) + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) - // Wait for the connection manager to attempt to dial the connection request - // and ensure the connection is marked as pending while the dialer is - // blocked. + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. select { case <-dialed: case <-time.After(time.Millisecond * 20): t.Fatal("timeout waiting for dial") } - assertConnReqState(t, cr, ConnPending) + assertPendingAddr(t, cmgr, addr) - // The request launched above will never be able to establish a connection, - // so cancel it _before_ it's able to be completed. 
- cmgr.Remove(cr.ID()) + // Cancel the connection attempt while it's still pending. + connID, _ := pendingAddrConnID(cmgr, addr) + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected remove err: %v", err) + } - // Ensure the connection request is now marked as canceled after a short - // timeout to allow the transition to occur. - time.Sleep(10 * time.Millisecond) - assertConnReqState(t, cr, ConnCanceled) + // Wait for the dialer to signal the context associated with the dial was + // canceled and ensure the internal pending state is removed. + select { + case <-canceled: + case <-time.After(time.Millisecond * 20): + t.Fatal("timeout waiting for cancel") + } + if _, ok := pendingAddrConnID(cmgr, addr); ok { + t.Fatalf("connection %s is still pending", addr) + } // Ensure clean shutdown of connection manager. - close(wait) shutdown() wg.Wait() } -// TestCancelIgnoreDelayedConnection tests that a canceled connection request -// will not execute the on connection callback, even if an outstanding retry -// succeeds. +// TestCancelIgnoreDelayedConnection tests that a canceled pending persistent +// connection will not execute the on connection callback, even if a pending +// retry succeeds. func TestCancelIgnoreDelayedConnection(t *testing.T) { + t.Parallel() + const retryTimeout = 10 * time.Millisecond - // Setup a dialer that will continue to return an error until the - // connect chan is signaled. The dial attempt immediately after that - // will succeed in returning a connection. + // Setup a dialer that returns an error on the first attempt and then blocks + // until the connect chan is signaled. The dial attempt immediately after + // that will succeed in returning a connection. 
+ var numAttempts atomic.Uint32 connect := make(chan struct{}) + retried := make(chan struct{}) failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { - select { - case <-connect: - return mockDialer(ctx, network, addr) - default: + if numAttempts.Add(1) == 1 { + return nil, errors.New("network down") } - return nil, errors.New("error") + close(retried) + <-connect + + // Override the context to ensure the pending dial succeeds even though + // the passed context will be canceled. + ctx = context.Background() + return mockDialer(ctx, network, addr) } - connected := make(chan *ConnReq) - cmgr, err := New(&Config{ + connected := make(chan *Conn) + cmgr := newTestConnManager(t, &Config{ Dial: failingDialer, RetryDuration: retryTimeout, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c + OnConnection: func(conn *Conn) { + connected <- conn }, }) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Establish a persistent connection to a localhost IP. + addr := mustParseAddrPort("127.0.0.1:18555") + connID, err := cmgr.AddPersistent(addr) if err != nil { - t.Fatalf("New error: %v", err) + t.Fatalf("unexpected error: %v", err) } - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - // Establish a connection request to a localhost IP. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, + // Wait for the retry and ensure the connection is pending. + select { + case <-retried: + case <-time.After(20 * time.Millisecond): + t.Fatalf("did not get retry before timeout") } - go cmgr.Connect(ctx, cr) - - // Allow for the first retry timeout to elapse. - time.Sleep(2 * retryTimeout) + assertPendingAddr(t, cmgr, addr) - // Ensure the status of the connection request is marked as failed, even - // after reattempting to connect. - assertConnReqState(t, cr, ConnFailed) - - // Remove the connection, and then immediately allow the next connection - // to succeed. 
- cmgr.Remove(cr.ID()) + // Remove the connection and then immediately allow the next connection to + // succeed. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected remove err: %v", err) + } close(connect) - // Allow the connection manager to process the removal. - time.Sleep(5 * time.Millisecond) - - // Ensure the status of the connection request is canceled. - assertConnReqState(t, cr, ConnCanceled) - // Finally, the connection manager should not signal the OnConnection // callback, since the request was explicitly canceled. Give a generous - // timeout window to ensure the connection manager's linear backoff is - // allowed to properly elapse. - select { - case <-connected: - t.Fatal("on-connect should not be called for canceled req") - case <-time.After(5 * retryTimeout): - } + // timeout window to ensure the connection manager's backoff is allowed to + // properly elapse. + assertNoConnReceivedTimeout(t, connected, 5*retryTimeout) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() } -// TestDialTimeout ensure the Timeout configuration parameter works as intended -// by creating a dialer that blocks for three times the configured dial timeout -// before connecting and ensuring the connection fails as expected. +// TestDialTimeout ensures [Config.DialTimeout] works as intended by creating a +// dialer that blocks for three times the configured dial timeout before +// connecting and ensuring the connection fails as expected. func TestDialTimeout(t *testing.T) { - // Create a connection manager instance with a dialer that blocks for twice - // the configured dial timeout before connecting. + t.Parallel() + + // Create a connection manager instance with a dialer that blocks for three + // times the configured dial timeout before connecting. 
const dialTimeout = time.Millisecond * 20 cancelled := make(chan struct{}) timeoutDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -657,43 +1239,34 @@ func TestDialTimeout(t *testing.T) { return mockDialer(ctx, network, addr) } - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Dial: timeoutDialer, DialTimeout: dialTimeout, }) - if err != nil { - t.Fatalf("New error: %v", err) - } - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - // Establish a connection request to a localhost IP. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - } - go cmgr.Connect(context.Background(), cr) + // Establish a connection to a localhost IP. + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) // Wait to receive the signal that the dialer context was cancelled, which - // means the dial timeout was hit, and ensure the connection request is - // marked as failed after a short timeout to allow the transition to occur. + // means the dial timeout was hit. select { case <-cancelled: case <-time.After(dialTimeout * 10): t.Fatal("timeout waiting for dial cancellation") } - time.Sleep(10 * time.Millisecond) - assertConnReqState(t, cr, ConnFailed) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() } -// TestConnectContext ensures the Connect method works as intended when provided -// with a context that times out before a dial attempt succeeds. +// TestConnectContext ensures the [ConnManager.Connect] method works as intended +// when provided with a context that is canceled before a dial attempt succeeds. func TestConnectContext(t *testing.T) { + t.Parallel() + // Create a connection manager instance with a dialer that blocks until its // provided context is canceled. 
dialed := make(chan struct{}) @@ -702,24 +1275,20 @@ func TestConnectContext(t *testing.T) { <-ctx.Done() return nil, ctx.Err() } - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Dial: indefiniteDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a connection request to a localhost IP with a separate context // that can be canceled. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - } - connectCtx, cancelConnect := context.WithCancel(context.Background()) - go cmgr.Connect(connectCtx, cr) + addr := mustParseAddrPort("127.0.0.1:18555") + connectCtx, cancelConnect := context.WithCancel(ctx) + connectErr := make(chan error, 1) + go func() { + _, err := cmgr.Connect(connectCtx, addr) + connectErr <- err + }() // Wait for the connection manager to attempt to dial the connection request // and ensure the connection is marked as pending while the dialer is @@ -729,119 +1298,19 @@ func TestConnectContext(t *testing.T) { case <-time.After(time.Millisecond * 20): t.Fatal("timeout waiting for dial") } - assertConnReqState(t, cr, ConnPending) + assertPendingAddr(t, cmgr, addr) - // Cancel the connection context and ensure the connection request is marked - // as failed after a short timeout to allow the transition to occur. + // Cancel the connection context, wait for the error from connect, and + // ensure it is the expected error. cancelConnect() - time.Sleep(10 * time.Millisecond) - assertConnReqState(t, cr, ConnFailed) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() -} - -// TestForEachConnReq tests the connection request iteration logic work as -// expected including for normal, permanent, and pending connections. 
-func TestForEachConnReq(t *testing.T) { - // Create a connection manager instance with a dialer that recognizes a - // special address to delay on in order to keep it pending. - targetOutbound := uint32(5) - connected := make(chan *ConnReq) - pending := make(chan struct{}) - delayDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { - if addr == "127.0.0.1:18557" { - close(pending) - time.Sleep(time.Second) - return nil, errors.New("error") - } - return mockDialer(ctx, network, addr) - } - cmgr, err := New(&Config{ - TargetOutbound: targetOutbound, - Dial: delayDialer, - GetNewAddress: func() (net.Addr, error) { - return &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, nil - }, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c - }, - }) - if err != nil { - t.Fatalf("New error: %v", err) - } - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - - // Wait for the expected number of target outbound conns to be established. - allConnected := make(chan struct{}) - go func() { - for i := uint32(0); i < targetOutbound; i++ { - <-connected - } - close(allConnected) - }() - select { - case <-allConnected: - case <-time.After(time.Millisecond * 5 * time.Duration(targetOutbound)): - t.Fatal("timeout waiting for connections") - } - - // Create a permanent connection. - cr := &ConnReq{ - Permanent: true, - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18556, - }, - } - go cmgr.Connect(context.Background(), cr) - select { - case <-connected: - case <-time.After(time.Millisecond * 5): - t.Fatal("timeout waiting for permanent connection") - } - - // Create a connection that triggers the mock dialer to keep it pending. 
- cr = &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18557, - }, - } - go cmgr.Connect(context.Background(), cr) select { - case <-pending: - case <-time.After(time.Millisecond * 5): - t.Fatal("timeout waiting for pending connection") - } - - // Ensure the expected number of each type of connection exists. - var numConnected, numPermanent, numPending uint32 - _ = cmgr.ForEachConnReq(func(cr *ConnReq) error { - numConnected++ - if cr.State() == ConnPending { - numPending++ - } - if cr.Permanent { - numPermanent++ + case err := <-connectErr: + if !errors.Is(err, context.Canceled) { + t.Fatalf("unexpected connect err: got %v, want %v", err, + context.Canceled) } - return nil - }) - if numConnected != targetOutbound+2 { - t.Fatalf("unexpected number of iterated conn reqs -- got %d, want %d", - numConnected, targetOutbound+2) - } - if numPermanent != 1 { - t.Fatalf("unexpected number of permanent conn reqs -- got %d, want %d", - numPermanent, 1) - } - if numPending != 1 { - t.Fatalf("unexpected number of pending conn reqs -- got %d, want %d", - numPending, 1) + case <-time.After(10 * time.Millisecond): + t.Fatal("timeout waiting for dial cancellation") } // Ensure clean shutdown of connection manager. @@ -888,14 +1357,11 @@ func (m *mockListener) Addr() net.Addr { // address. It will cause the Accept function to return a mock connection // configured with the provided remote address and the local address for the // mock listener. -func (m *mockListener) Connect(ip string, port int) { +func (m *mockListener) Connect(addr net.Addr) { m.provideConn <- &mockConn{ laddr: m.localAddr, lnet: "tcp", - rAddr: &net.TCPAddr{ - IP: net.ParseIP(ip), - Port: port, - }, + rAddr: addr, } } @@ -911,52 +1377,310 @@ func newMockListener(localAddr string) *mockListener { // TestListeners ensures providing listeners to the connection manager along // with an accept callback works properly. 
func TestListeners(t *testing.T) { + t.Parallel() + // Setup a connection manager with a couple of mock listeners that // notify a channel when they receive mock connections. - receivedConns := make(chan net.Conn) - listener1 := newMockListener("127.0.0.1:8333") - listener2 := newMockListener("127.0.0.1:9333") + receivedConns := make(chan *Conn) + listener1 := newMockListener("127.0.0.1:9108") + listener2 := newMockListener("127.0.0.1:9208") listeners := []net.Listener{listener1, listener2} - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Listeners: listeners, - OnAccept: func(conn net.Conn) { + OnAccept: func(conn *Conn) { receivedConns <- conn }, Dial: mockDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Fake a couple of mock connections to each of the listeners. go func() { for i, listener := range listeners { l := listener.(*mockListener) - l.Connect("127.0.0.1", 10000+i*2) - l.Connect("127.0.0.1", 10000+i*2+1) + l.Connect(mustParseAddrPort(fmt.Sprintf("127.0.0.1:%d", 10000+i*2))) + l.Connect(mustParseAddrPort(fmt.Sprintf("127.0.0.1:%d", 10000+i*2+1))) } }() - // Tally the receive connections to ensure the expected number are - // received. Also, fail the test after a timeout so it will not hang - // forever should the test not work. + // Ensure the expected number of inbound connections are received. expectedNumConns := len(listeners) * 2 - var numConns int -out: - for { - select { - case <-receivedConns: - numConns++ - if numConns == expectedNumConns { - break out + for range expectedNumConns { + assertConnReceived(t, receivedConns, 0, ConnTypeInbound) + } + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + +// TestRejectDuplicateConns ensures duplicate addresses are rejected. 
This +// includes: +// - Attempts to dial addresses that already have pending, established, and +// persistent connections (via [ConnManager.Connect]) +// - Attempts to add duplicate persistent conns (via [ConnManager.AddPersistent]) +// - Attempts to receive inbound remote addresses that already have pending, +// established, and persistent connections +func TestRejectDuplicateConns(t *testing.T) { + t.Parallel() + + var closeDialedOnce sync.Once + inboundConns := make(chan *Conn) + listener := newMockListener("127.0.0.1:18109") + connected := make(chan *Conn) + disconnected := make(chan *Conn) + dialed := make(chan struct{}) + pending := make(chan struct{}) + pendingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + closeDialedOnce.Do(func() { close(dialed) }) + <-pending + return mockDialer(ctx, network, addr) + } + cmgr := newTestConnManager(t, &Config{ + Listeners: []net.Listener{listener}, + OnAccept: func(conn *Conn) { + inboundConns <- conn + }, + Dial: pendingDialer, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Dial a manual connection and wait for it to become pending. + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) + select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("did not receive pending dial before timeout") + } + assertPendingAddr(t, cmgr, addr) + + // Duplicate connect to the pending address should be rejected. + if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyPending) { + t.Fatalf("did not reject duplicate pending connection, err: %v", err) + } + + // Inbound attempts from the pending outbound address should be rejected. + go listener.Connect(addr) + assertNoConnReceived(t, inboundConns) + + // Allow the pending connection to complete. 
+ close(pending) + conn := assertConnReceived(t, connected, 0, ConnTypeManual) + + // Duplicate connect to the established address should be rejected. + if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { + t.Fatalf("did not reject duplicate active connection, err: %v", err) + } + + // Inbound attempts from the established outbound address should be + // rejected. + go listener.Connect(addr) + assertNoConnReceived(t, inboundConns) + + // Close the connection and wait for the disconnect. + conn.Close() + assertConnReceived(t, disconnected, conn.ID(), ConnTypeManual) + + // Add a persistent connection back to the same address and wait for it to + // connect since there are no longer any connections to the address. + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + assertConnReceived(t, connected, connID, ConnTypeManual) + + // Duplicate persistent connection attempts should be rejected. + _, err = cmgr.AddPersistent(addr) + if !errors.Is(err, ErrDuplicatePersistent) { + t.Fatalf("did not reject duplicate persistent connection, err: %v", err) + } + + // Manual connection attempts to persistent connection should be rejected. + _, err = cmgr.Connect(ctx, addr) + if !errors.Is(err, ErrDuplicatePersistent) { + t.Fatalf("did not reject manual connection to persistent, err: %v", err) + } + + // Inbound attempts from the persistent address should be rejected. + go listener.Connect(addr) + assertNoConnReceived(t, inboundConns) + + // Remove the persistent connection, wait for it to disconnect, and ensure + // it is actually removed. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("failed to remove persistent connection: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertRemovedPersistent(t, cmgr, addr) + + // Inbound connections from the same address should now succeed. 
+ go listener.Connect(addr) + assertConnReceived(t, inboundConns, 0, ConnTypeInbound) + + // Manual connection attempts to the inbound address should be rejected. + if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { + t.Fatalf("did not reject outbound for existing inbound conn, err: %v", + err) + } + + // Attempts to add a persistent connection to an existing inbound should be + // rejected. + _, err = cmgr.AddPersistent(addr) + if !errors.Is(err, ErrAlreadyConnected) { + t.Fatalf("did not reject persistent conn for existing inbound conn: %v", + err) + } + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + +// TestMaxNormalConns ensures the connection manager limits the total number of +// normal connections to [Config.MaxNormalConns] including automatic outbound, +// manual outbound, and inbound connections. It also ensures that it is not +// applied to persistent connections. +func TestMaxNormalConns(t *testing.T) { + t.Parallel() + + // nextAddr is a convenience func to return a new unique address with every + // invocation. + var numAddrs atomic.Uint32 + nextAddr := func() net.Addr { + addrStr := fmt.Sprintf("10.0.0.%d:18555", numAddrs.Add(1)) + return mustParseAddrPort(addrStr) + } + + // Constants for the number of various normal connection types to test + // overall max normal connection limits. 
+ const ( + targetOutbound = 3 + targetManual = 4 + targetInbound = 5 + maxNormalConns = targetOutbound + targetManual + targetInbound + ) + connected := make(chan *Conn) + disconnected := make(chan *Conn) + inboundConns := make(chan *Conn) + listener := newMockListener("127.0.0.1:9108") + var pauseTargetOutbound atomic.Bool + var totalPausedAddrs atomic.Uint32 + hitMaxFailedAttempts := make(chan struct{}) + cmgr := newTestConnManager(t, &Config{ + Listeners: []net.Listener{listener}, + MaxNormalConns: maxNormalConns, + TargetOutbound: targetOutbound, + RetryDuration: 50 * time.Millisecond, + Dial: mockDialer, + OnAccept: func(conn *Conn) { + inboundConns <- conn + }, + GetNewAddress: func() (net.Addr, error) { + if pauseTargetOutbound.Load() { + total := totalPausedAddrs.Add(1) + if total == maxFailedAttempts { + hitMaxFailedAttempts <- struct{}{} + } + return nil, errors.New("network down") } + return nextAddr(), nil + }, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + cmgr.maxRetryDuration = cmgr.cfg.RetryDuration + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - case <-time.After(time.Millisecond * 50): - t.Fatalf("Timeout waiting for %d expected connections", - expectedNumConns) + // Wait for the expected number of target outbound conns to be established. + outbounds := make([]*Conn, 0, targetOutbound) + for len(outbounds) < targetOutbound { + conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) + outbounds = append(outbounds, conn) + } + + // Establish target number of inbounds to the listener and wait for them to + // be established. 
+ go func() { + for range targetInbound { + listener.Connect(nextAddr()) } + }() + inbounds := make([]*Conn, 0, targetInbound) + for len(inbounds) < targetInbound { + conn := assertConnReceived(t, inboundConns, 0, ConnTypeInbound) + inbounds = append(inbounds, conn) + } + + // Establish target number of manual connections and wait for them to be + // established. + go func() { + for range targetManual { + go cmgr.Connect(ctx, nextAddr()) + } + }() + manualConns := make([]*Conn, 0, targetManual+1) + for len(manualConns) < targetManual { + conn := assertConnReceived(t, connected, 0, ConnTypeManual) + manualConns = append(manualConns, conn) + } + + // Ensure manual connections that would exceed the max allowed normal + // connections are rejected. + _, err := cmgr.Connect(ctx, nextAddr()) + if !errors.Is(err, ErrMaxNormalConns) { + t.Fatalf("did not reject manual connection at max allowed, err: %v", err) + } + + // Ensure inbound connections that would exceed the max allowed normal + // connections are rejected. + go listener.Connect(nextAddr()) + assertNoConnReceived(t, inboundConns) + + // Pause the target outbound dials and remove one of the target outbound + // connections to make room for another manual connection. Then wait for + // the max failures to be hit so attempts are paused for a retry timeout. + pauseTargetOutbound.Store(true) + outboundConn := outbounds[0] + outboundConn.Close() + assertConnReceived(t, disconnected, outboundConn.ID(), ConnTypeOutbound) + select { + case <-hitMaxFailedAttempts: + time.Sleep(connTestReceiveTimeout) + case <-time.After(maxFailedAttempts * connTestReceiveTimeout): + t.Fatal("did not reach max failed attempts before timeout") + } + + // Establish another manual connection to take the place of the target + // outbound connection that was just closed and wait for it to be + // established. 
+ go cmgr.Connect(ctx, nextAddr()) + assertConnReceived(t, connected, 0, ConnTypeManual) + + // Unpause the target outbound dials and ensure no additional automatic + // outbound connections are made despite being under the target outbound due + // to max total conns. + pauseTargetOutbound.Store(false) + assertNoConnReceivedTimeout(t, connected, connTestNonReceiveTimeout+ + cmgr.cfg.RetryDuration) + + // Ensure persistent connections are not subject to the max total normal + // connections by adding one and waiting for it to be established. + connID, err := cmgr.AddPersistent(nextAddr()) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) } + assertConnReceived(t, connected, connID, ConnTypeManual) // Ensure clean shutdown of connection manager. shutdown() diff --git a/internal/connmgr/conntype_test.go b/internal/connmgr/conntype_test.go new file mode 100644 index 000000000..f4d3f69e1 --- /dev/null +++ b/internal/connmgr/conntype_test.go @@ -0,0 +1,35 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import ( + "testing" +) + +// TestConnectionTypeStringer tests the stringized output for connection types. +func TestConnectionTypeStringer(t *testing.T) { + tests := []struct { + in ConnectionType + want string + }{ + {ConnTypeInbound, "inbound"}, + {ConnTypeOutbound, "outbound"}, + {ConnTypeManual, "manual"}, + {0xff, "Unknown ConnectionType (255)"}, + } + + // Detect additional defines that don't have the stringer added. 
+ if len(tests)-1 != int(numConnTypes) { + t.Fatal("It appears a connection type was added without adding an " + + "associated stringer test") + } + + for i, test := range tests { + if got := test.in.String(); got != test.want { + t.Errorf("String #%d: got: %s, want: %s", i, got, test.want) + continue + } + } +} diff --git a/internal/connmgr/doc.go b/internal/connmgr/doc.go deleted file mode 100644 index 3eb872c3c..000000000 --- a/internal/connmgr/doc.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) 2016 The btcsuite developers -// Copyright (c) 2017-2022 The Decred developers -// Use of this source code is governed by an ISC -// license that can be found in the LICENSE file. - -/* -Package connmgr implements a generic Decred network connection manager. - -# Deprecated - -This module is deprecated and is no longer maintained. Callers are encouraged -to use github.com/decred/dcrd/addrmgr/vX for methods that were moved to it -instead. - -# Connection Manager Overview - -Connection manager handles all the general connection concerns such as -maintaining a set number of outbound connections, sourcing peers, banning, -limiting max connections, tor lookup, etc. -*/ -package connmgr diff --git a/internal/connmgr/error.go b/internal/connmgr/error.go index 932a13f28..81b360a00 100644 --- a/internal/connmgr/error.go +++ b/internal/connmgr/error.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 The Decred developers +// Copyright (c) 2020-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -14,10 +14,38 @@ const ( // the configuration. ErrDialNil = ErrorKind("ErrDialNil") + // ErrAlreadyPending indicates an attempt to connect to an address that + // already has a pending connection attempt. + ErrAlreadyPending = ErrorKind("ErrAlreadyPending") + + // ErrAlreadyConnected indicates an attempt to connect to an address that + // already has an established connection. 
+ ErrAlreadyConnected = ErrorKind("ErrAlreadyConnected") + + // ErrMaxNormalConns indicates a connection attempt (inbound or outbound) + // would exceed the maximum allowed number of normal connections. + ErrMaxNormalConns = ErrorKind("ErrMaxNormalConns") + + // ErrMaxPersistent indicates an attempt to add more than the maximum + // allowed number of persistent connections. + ErrMaxPersistent = ErrorKind("ErrMaxPersistent") + + // ErrDuplicatePersistent indicates an attempt to add a persistent + // connection to an address that already exists. + ErrDuplicatePersistent = ErrorKind("ErrDuplicatePersistent") + // ErrNotFound indicates a specified connection ID or address is unknown to // the connection manager. ErrNotFound = ErrorKind("ErrNotFound") + // ErrUnsupportedAddr indicates an address is either an unsupported type or + // an unrecognized type due to being malformed. + ErrUnsupportedAddr = ErrorKind("ErrUnsupportedAddr") + + // ErrShutdown indicates the connection manager is either in the process of + // shutting down or has already been shut down. + ErrShutdown = ErrorKind("ErrShutdown") + // ErrTorInvalidAddressResponse indicates an invalid address was // returned by the Tor DNS resolver. ErrTorInvalidAddressResponse = ErrorKind("ErrTorInvalidAddressResponse") diff --git a/internal/connmgr/error_test.go b/internal/connmgr/error_test.go index 1e177b97e..b3881bf7c 100644 --- a/internal/connmgr/error_test.go +++ b/internal/connmgr/error_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 The Decred developers +// Copyright (c) 2020-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. 
@@ -17,7 +17,14 @@ func TestErrorKindStringer(t *testing.T) { want string }{ {ErrDialNil, "ErrDialNil"}, + {ErrAlreadyPending, "ErrAlreadyPending"}, + {ErrAlreadyConnected, "ErrAlreadyConnected"}, + {ErrMaxNormalConns, "ErrMaxNormalConns"}, + {ErrMaxPersistent, "ErrMaxPersistent"}, + {ErrDuplicatePersistent, "ErrDuplicatePersistent"}, {ErrNotFound, "ErrNotFound"}, + {ErrUnsupportedAddr, "ErrUnsupportedAddr"}, + {ErrShutdown, "ErrShutdown"}, {ErrTorInvalidAddressResponse, "ErrTorInvalidAddressResponse"}, {ErrTorInvalidProxyResponse, "ErrTorInvalidProxyResponse"}, {ErrTorUnrecognizedAuthMethod, "ErrTorUnrecognizedAuthMethod"}, diff --git a/internal/connmgr/log.go b/internal/connmgr/log.go index 4bf44f579..f6ba6f5f4 100644 --- a/internal/connmgr/log.go +++ b/internal/connmgr/log.go @@ -19,3 +19,12 @@ var log = slog.Disabled func UseLogger(logger slog.Logger) { log = logger } + +// pickNoun returns the singular or plural form of a noun depending on the count +// n. +func pickNoun[T ~uint32 | ~uint64](n T, singular, plural string) string { + if n == 1 { + return singular + } + return plural +} diff --git a/internal/connmgr/semaphore.go b/internal/connmgr/semaphore.go new file mode 100644 index 000000000..532efab5c --- /dev/null +++ b/internal/connmgr/semaphore.go @@ -0,0 +1,57 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import "context" + +// semaphore is a simple context-aware channel based semaphore for bounding +// concurrent access. +type semaphore chan struct{} + +// makeSemaphore returns a new semaphore with the given capacity. +func makeSemaphore(n uint32) semaphore { + return make(chan struct{}, n) +} + +// Acquire acquires the semaphore. It blocks until resources are available or +// the provided context is done. It returns true on success and false when the +// context is done before semaphore can be acquired. 
+func (s semaphore) Acquire(ctx context.Context) bool { + select { + case s <- struct{}{}: + case <-ctx.Done(): + return false + } + return true +} + +// TryAcquire attempts to acquire the semaphore without blocking when there are +// no resources immediately available. +// +// It returns true with a nil error on success. It returns false with a nil +// error when the semaphore is at capacity and no permit is available. +// +// Finally, it returns false with the error associated with the context +// immediately when the context is already canceled or timed out at the time of +// the call. It does not attempt to acquire the semaphore in that case. +func (s semaphore) TryAcquire(ctx context.Context) (bool, error) { + if ctx.Err() != nil { + return false, ctx.Err() + } + select { + case s <- struct{}{}: + return true, nil + default: + } + return false, nil +} + +// Release releases the semaphore. +func (s semaphore) Release() { + select { + case <-s: + default: + } +} diff --git a/internal/connmgr/semaphore_test.go b/internal/connmgr/semaphore_test.go new file mode 100644 index 000000000..1d7d5b750 --- /dev/null +++ b/internal/connmgr/semaphore_test.go @@ -0,0 +1,180 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import ( + "context" + "errors" + "testing" + "time" +) + +// TestSemaphore ensures the semaphore acquire, release, and context cancel +// semantics are as expected. +func TestSemaphore(t *testing.T) { + // Create a closure that acquires a semaphore with a timeout. + ctx := context.Background() + timedAcquire := func(sem semaphore, timeout time.Duration) bool { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return sem.Acquire(ctx) + } + + // Create a closure that tries to acquire a semaphore via the nonblocking + // method with a timeout. 
+ timedTryAcquire := func(sem semaphore, timeout time.Duration) (bool, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + if timeout == 0 { + cancel() + } else { + defer cancel() + } + return sem.TryAcquire(ctx) + } + + // perSemTest describes a test to run against the same semaphore. + type perSemTest struct { + name string // test description + numAcquires uint32 // num to acquire + numTries uint32 // num to try to acquire via nonblocking method + cancelTry bool // whether or not to cancel nonblocking try + numReleases uint32 // num to release + } + + tests := []struct { + name string // test description + cap uint32 // capacity of the semaphore + perSemTests []perSemTest // tests to run against same semaphore + want []bool // expected results + }{{ + name: "normal block/release behavior", + cap: 2, + perSemTests: []perSemTest{{ + name: "cap 2 (0 acquired): acquire 3, release 1", + numAcquires: 3, + numReleases: 1, + }, { + name: "cap 2 (1 acquired): acquire 2, release 0", + numAcquires: 2, + numReleases: 0, + }, { + name: "cap 2 (2 acquired): acquire 1, release 2", + numAcquires: 1, + numReleases: 2, + }}, + want: []bool{true, true, false, true, false, false}, + }, { + // Releasing more than acquired ignores the extra release and does not + // influence future ops. 
+ name: "relase more than acquired", + cap: 5, + perSemTests: []perSemTest{{ + name: "cap 5 (0 acquired): acquire 1, release 2", + numAcquires: 1, + numReleases: 2, + }, { + name: "cap 5 (0 acquired): acquire 5, release 1", + numAcquires: 5, + numReleases: 1, + }, { + name: "cap 5 (4 acquired): acquire 2, release 5", + numAcquires: 2, + numReleases: 5, + }}, + want: []bool{true, true, true, true, true, true, true, false}, + }, { + name: "nonblocking tryacquire and blocking acquire mixed", + cap: 3, + perSemTests: []perSemTest{{ + name: "cap 3 (0 acquired): try 1, release 2", + numTries: 1, + numReleases: 2, + }, { + name: "cap 3 (0 acquired): acquire 2, try 1, release 1", + numAcquires: 2, + numTries: 1, + numReleases: 1, + }, { + name: "cap 3 (2 acquired): acquire 1, try 2, release 3", + numAcquires: 1, + numTries: 2, + numReleases: 3, + }}, + want: []bool{true, true, true, true, true, false, false}, + }, { + name: "nonblocking tryacquire with canceled context", + cap: 1, + perSemTests: []perSemTest{{ + name: "cap 1 (0 acquired): try 1 (canceled), release 0", + numTries: 1, + cancelTry: true, + numReleases: 0, + }, { + name: "cap 1 (0 acquired): acquire 1, try 1, release 1", + numAcquires: 1, + numTries: 1, + numReleases: 1, + }, { + name: "cap 1 (0 acquired): try 2, release 1", + numAcquires: 0, + numTries: 2, + numReleases: 1, + }}, + want: []bool{false, true, false, true, false}, + }} + + for _, test := range tests { + // Create semaphore with the capacity specified in the test and + // a slice to hold the results. + sem := makeSemaphore(test.cap) + results := make([]bool, 0, len(test.want)) + + // Perform each sequence of acquires, try acquires, and releases as + // specified by the per semaphore tests. 
+ for _, psTest := range test.perSemTests { + const timeout = 10 * time.Millisecond + for range psTest.numAcquires { + results = append(results, timedAcquire(sem, timeout)) + } + for range psTest.numTries { + // Override timeout with a duration 0 and expected error when + // the flag to force the context for the try acquire to be + // canceled is specified. + var wantErr error + tryTimeout := timeout + if psTest.cancelTry { + tryTimeout = 0 + wantErr = context.DeadlineExceeded + } + acquired, err := timedTryAcquire(sem, tryTimeout) + if !errors.Is(err, wantErr) { + t.Fatalf("%q: unexpected try acquire error: got %v, want %v", + psTest.name, err, wantErr) + } + results = append(results, acquired) + } + for range psTest.numReleases { + sem.Release() + } + } + + if len(results) != len(test.want) { + t.Errorf("%q: unexpected number of results: got %d, want %d", + test.name, len(results), len(test.want)) + } + for i := range results { + if results[i] != test.want[i] { + t.Errorf("%q: unexpected result for [%d]: got %v, want %v", + test.name, i, results[i], test.want[i]) + } + } + + // Ensure all acquires were released as expected. + if numAcquired := uint32(len(sem)); numAcquired != 0 { + t.Errorf("%q: unexpected final semaphore count: got %v, want %v", + test.name, numAcquired, 0) + } + } +} diff --git a/internal/rpcserver/interface.go b/internal/rpcserver/interface.go index 4a078f181..164ad974a 100644 --- a/internal/rpcserver/interface.go +++ b/internal/rpcserver/interface.go @@ -114,7 +114,7 @@ type ConnManager interface { // permanent flag indicates whether or not to make the peer persistent // and reconnect if the connection is lost. Attempting to connect to an // already existing peer will return an error. - Connect(addr string, permanent bool) error + Connect(ctx context.Context, addr string, permanent bool) error // RemoveByID removes the peer associated with the provided id from the // list of persistent peers. 
Attempting to remove an id that does not diff --git a/internal/rpcserver/rpcserver.go b/internal/rpcserver/rpcserver.go index da4e0c1dd..6ae6c6d59 100644 --- a/internal/rpcserver/rpcserver.go +++ b/internal/rpcserver/rpcserver.go @@ -598,7 +598,7 @@ func newWorkState() *workState { } // handleAddNode handles addnode commands. -func handleAddNode(_ context.Context, s *Server, cmd any) (any, error) { +func handleAddNode(ctx context.Context, s *Server, cmd any) (any, error) { c := cmd.(*types.AddNodeCmd) addr := normalizeAddress(c.Addr, s.cfg.ChainParams.DefaultPort) @@ -606,25 +606,38 @@ func handleAddNode(_ context.Context, s *Server, cmd any) (any, error) { var err error switch c.SubCmd { case "add": - err = connMgr.Connect(addr, true) + err = connMgr.Connect(ctx, addr, true) case "remove": err = connMgr.RemoveByAddr(addr) case "onetry": - err = connMgr.Connect(addr, false) + err = connMgr.Connect(ctx, addr, false) default: return nil, rpcInvalidError("Invalid subcommand for addnode") } if err != nil { - return nil, rpcInvalidError("%v: %v", c.SubCmd, err) + switch { + case ctx.Err() != nil: + return nil, rpcConnectionClosedError() + + case errors.Is(err, context.Canceled): + return nil, rpcCancelError("%v: connection attempt to %v canceled", + c.SubCmd, addr) + + case errors.Is(err, context.DeadlineExceeded): + return nil, rpcCancelError("%v: timeout connecting to %v", c.SubCmd, + addr) + } + + prefix := fmt.Sprintf("%v: failed operation on %v", c.SubCmd, addr) + return nil, rpcInternalErr(err, prefix) } - // no data returned unless an error. return nil, nil } // handleNode handles node commands. 
-func handleNode(_ context.Context, s *Server, cmd any) (any, error) { +func handleNode(ctx context.Context, s *Server, cmd any) (any, error) { c := cmd.(*types.NodeCmd) connMgr := s.cfg.ConnMgr @@ -646,13 +659,16 @@ func handleNode(_ context.Context, s *Server, cmd any) (any, error) { addr = normalizeAddress(c.Target, params.DefaultPort) err = connMgr.DisconnectByAddr(addr) } else { - return nil, rpcInvalidError("%v: Invalid "+ - "address or node ID", c.SubCmd) + return nil, rpcInvalidError("%v: invalid address or node ID", + c.SubCmd) } } if err != nil && peerExists(connMgr, addr, int32(nodeID)) { - return nil, rpcMiscError("can't disconnect a permanent peer, " + - "use remove") + return nil, rpcMiscError("can't disconnect a permanent peer, use " + + "remove") + } + if err != nil { + return nil, rpcInvalidError("%v: %v", c.SubCmd, err) } case "remove": @@ -667,13 +683,16 @@ func handleNode(_ context.Context, s *Server, cmd any) (any, error) { addr = normalizeAddress(c.Target, params.DefaultPort) err = connMgr.RemoveByAddr(addr) } else { - return nil, rpcInvalidError("%v: invalid "+ - "address or node ID", c.SubCmd) + return nil, rpcInvalidError("%v: invalid address or node ID", + c.SubCmd) } } if err != nil && peerExists(connMgr, addr, int32(nodeID)) { - return nil, rpcMiscError("can't remove a temporary peer, " + - "use disconnect") + return nil, rpcMiscError("can't remove a temporary peer, use " + + "disconnect") + } + if err != nil { + return nil, rpcInvalidError("%v: %v", c.SubCmd, err) } case "connect": @@ -687,20 +706,33 @@ func handleNode(_ context.Context, s *Server, cmd any) (any, error) { switch subCmd { case "perm", "temp": - err = connMgr.Connect(addr, subCmd == "perm") + err = connMgr.Connect(ctx, addr, subCmd == "perm") default: - return nil, rpcInvalidError("%v: invalid subcommand "+ - "for node connect", subCmd) + return nil, rpcInvalidError("%v: invalid subcommand for node "+ + "connect", subCmd) } + if err != nil { + switch { + case ctx.Err() != 
nil: + return nil, rpcConnectionClosedError() + + case errors.Is(err, context.Canceled): + return nil, rpcCancelError("%v: connection attempt to %v "+ + "canceled", c.SubCmd, addr) + + case errors.Is(err, context.DeadlineExceeded): + return nil, rpcCancelError("%v: timeout connecting to %v", + c.SubCmd, addr) + } + + prefix := fmt.Sprintf("%v: failed operation on %v", c.SubCmd, addr) + return nil, rpcInternalErr(err, prefix) + } + default: return nil, rpcInvalidError("%v: invalid subcommand for node", c.SubCmd) } - if err != nil { - return nil, rpcInvalidError("%v: %v", c.SubCmd, err) - } - - // no data returned unless an error. return nil, nil } diff --git a/internal/rpcserver/rpcserverhandlers_test.go b/internal/rpcserver/rpcserverhandlers_test.go index 562282735..957aeba95 100644 --- a/internal/rpcserver/rpcserverhandlers_test.go +++ b/internal/rpcserver/rpcserverhandlers_test.go @@ -856,7 +856,7 @@ type testConnManager struct { // Connect provides a mock implementation for adding the provided address as a // new outbound peer. 
-func (c *testConnManager) Connect(addr string, permanent bool) error { +func (c *testConnManager) Connect(ctx context.Context, addr string, permanent bool) error { return c.connectErr } @@ -1919,7 +1919,7 @@ func TestHandleAddNode(t *testing.T) { return connManager }(), wantErr: true, - errCode: dcrjson.ErrRPCInvalidParameter, + errCode: dcrjson.ErrRPCInternal.Code, }, { name: "handleAddNode: 'remove' subcommand error", handler: handleAddNode, @@ -1933,7 +1933,7 @@ func TestHandleAddNode(t *testing.T) { return connManager }(), wantErr: true, - errCode: dcrjson.ErrRPCInvalidParameter, + errCode: dcrjson.ErrRPCInternal.Code, }, { name: "handleAddNode: 'onetry' subcommand error", handler: handleAddNode, @@ -1947,7 +1947,7 @@ func TestHandleAddNode(t *testing.T) { return connManager }(), wantErr: true, - errCode: dcrjson.ErrRPCInvalidParameter, + errCode: dcrjson.ErrRPCInternal.Code, }, { name: "handleAddNode: invalid subcommand", handler: handleAddNode, diff --git a/rpcadaptors.go b/rpcadaptors.go index 8dd072a94..30d6c6c27 100644 --- a/rpcadaptors.go +++ b/rpcadaptors.go @@ -15,7 +15,6 @@ import ( "github.com/decred/dcrd/chaincfg/v3" "github.com/decred/dcrd/dcrutil/v4" "github.com/decred/dcrd/internal/blockchain" - "github.com/decred/dcrd/internal/connmgr" "github.com/decred/dcrd/internal/mempool" "github.com/decred/dcrd/internal/mining" "github.com/decred/dcrd/internal/mining/cpuminer" @@ -125,49 +124,28 @@ var _ rpcserver.ConnManager = (*rpcConnManager)(nil) // // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. -func (cm *rpcConnManager) Connect(addr string, permanent bool) error { - // Prevent duplicate connections to the same peer. 
- connManager := cm.server.connManager - err := connManager.ForEachConnReq(func(c *connmgr.ConnReq) error { - if c.Addr != nil && c.Addr.String() == addr { - if c.Permanent { - return errors.New("peer exists as a permanent peer") - } - - switch c.State() { - case connmgr.ConnPending: - return errors.New("peer pending connection") - case connmgr.ConnEstablished: - return errors.New("peer already connected") - - } - } - return nil - }) - if err != nil { - return err - } - +func (cm *rpcConnManager) Connect(ctx context.Context, addr string, permanent bool) error { netAddr, err := addrStringToNetAddr(addr) if err != nil { return err } - // Limit max number of total peers. - cm.server.peerState.Lock() - count := cm.server.peerState.count() - cm.server.peerState.Unlock() - if count >= cfg.MaxPeers { - return errors.New("max peers reached") + // Attempt to add a persistent peer when requested. + connManager := cm.server.connManager + if permanent { + _, err := connManager.AddPersistent(netAddr) + return err } - go connManager.Connect(context.Background(), &connmgr.ConnReq{ - Addr: netAddr, - Permanent: permanent, - }) - return nil + // Attempt to connect to the address. + _, err = connManager.Connect(ctx, netAddr) + return err } +// errPeerNotFound is returned by the RPC conn manager when no matching peer for +// a given address or ID is found. +var errPeerNotFound = errors.New("peer not found") + // removeNode removes any peers that the provided compare function return true // for from the list of persistent peers. // @@ -175,26 +153,21 @@ func (cm *rpcConnManager) Connect(addr string, permanent bool) error { // function returns false for all peers). 
func (cm *rpcConnManager) removeNode(cmp func(*serverPeer) bool) error { state := &cm.server.peerState + var found *serverPeer state.Lock() - found := disconnectPeer(state.persistentPeers, cmp, func(sp *serverPeer) { - // Update the group counts since the peer will be removed from the - // persistent peers just after this func returns. - state.outboundGroups[sp.remoteAddr.GroupKey()]-- - - connReq := sp.connReq.Load() - peerLog.Debugf("Removing persistent peer %s (reqid %d)", sp.remoteAddr, - connReq.ID()) - - // Mark the peer's connReq as nil to prevent it from scheduling a - // re-connect attempt. - sp.connReq.Store(nil) - cm.server.connManager.Remove(connReq.ID()) - }) + for _, peer := range state.persistentPeers { + if cmp(peer) { + found = peer + break + } + } state.Unlock() - - if !found { - return errors.New("peer not found") + if found == nil { + return errPeerNotFound } + + peerLog.Debugf("Removing persistent peer %s", found.remoteAddr) + cm.server.connManager.Remove(found.conn.ID()) return nil } @@ -203,10 +176,20 @@ func (cm *rpcConnManager) removeNode(cmp func(*serverPeer) bool) error { // an error. // // This function is safe for concurrent access and is part of the -// rpcserver.ConnManager interface implementation. +// [rpcserver.ConnManager] interface implementation. func (cm *rpcConnManager) RemoveByID(id int32) error { + // Attempt to remove the peer by ID first. When the ID does not correspond + // to an established persistent peer, fall back to treating the ID as a + // connection ID and remove it when it is for a persistent connection. 
+ connManager := cm.server.connManager cmp := func(sp *serverPeer) bool { return sp.ID() == id } - return cm.removeNode(cmp) + err := cm.removeNode(cmp) + if errors.Is(err, errPeerNotFound) && connManager.IsPersistent(uint64(id)) { + if rErr := connManager.Remove(uint64(id)); rErr == nil { + return nil + } + } + return err } // RemoveByAddr removes the peer associated with the provided address from the @@ -214,18 +197,22 @@ func (cm *rpcConnManager) RemoveByID(id int32) error { // exist will return an error. // // This function is safe for concurrent access and is part of the -// rpcserver.ConnManager interface implementation. +// [rpcserver.ConnManager] interface implementation. func (cm *rpcConnManager) RemoveByAddr(addr string) error { + // Attempt to remove the peer by address first. When the address does not + // correspond to an established persistent peer, fall back to searching the + // connection manager directly for a matching persistent connection entry + // and remove it when found. cmp := func(sp *serverPeer) bool { return sp.Addr() == addr } err := cm.removeNode(cmp) - if err != nil { - netAddr, err := addrStringToNetAddr(addr) - if err != nil { - return err + if errors.Is(err, errPeerNotFound) { + netAddr := simpleAddr{"tcp", addr} + if id, ok := cm.server.connManager.FindPersistentAddrID(netAddr); ok { + cm.server.connManager.Remove(id) + return nil } - return cm.server.connManager.CancelPending(netAddr) } - return nil + return err } // disconnectNode disconnects any peers that the provided compare function @@ -237,34 +224,40 @@ func (cm *rpcConnManager) RemoveByAddr(addr string) error { // This function is safe for concurrent access. func (cm *rpcConnManager) disconnectNode(cmp func(sp *serverPeer) bool) error { state := &cm.server.peerState - defer state.Unlock() - state.Lock() - // Check inbound peers. No callback is passed since there are no additional - // actions on disconnect for inbound peers. 
- found := disconnectPeer(state.inboundPeers, cmp, nil) - if found { + // Check inbound peers. + var inbound *serverPeer + state.Lock() + for _, peer := range state.inboundPeers { + if cmp(peer) { + inbound = peer + break + } + } + state.Unlock() + if inbound != nil { + inbound.Disconnect() return nil } // Check outbound peers in a loop to ensure all outbound connections to the // same ip:port are disconnected when there are multiple. - var numFound uint32 - for ; ; numFound++ { - found = disconnectPeer(state.outboundPeers, cmp, func(sp *serverPeer) { - // Update the group counts since the peer will be removed from the - // persistent peers just after this func returns. - state.outboundGroups[sp.remoteAddr.GroupKey()]-- - }) - if !found { - break + outbounds := make([]*serverPeer, 0, 1) + state.Lock() + for _, peer := range state.outboundPeers { + if cmp(peer) { + outbounds = append(outbounds, peer) } } - - if numFound == 0 { - return errors.New("peer not found") + state.Unlock() + if len(outbounds) > 0 { + for _, sp := range outbounds { + sp.Disconnect() + } + return nil } - return nil + + return errors.New("peer not found") } // DisconnectByID disconnects the peer associated with the provided id. This diff --git a/server.go b/server.go index 3f823bc93..81cfb2d4f 100644 --- a/server.go +++ b/server.go @@ -442,6 +442,7 @@ type serverPeer struct { // The service flags are updated in the address manager directly once the // peer reports them. The service flags on this instance are never used. server *server + conn *connmgr.Conn remoteAddr *addrmgr.NetAddress persistent bool isWhitelisted bool @@ -455,7 +456,6 @@ type serverPeer struct { // otherwise modified during operation and thus need to consider whether or // not they need to be protected for concurrent access. 
- connReq atomic.Pointer[connmgr.ConnReq] continueHash atomic.Pointer[chainhash.Hash] disableRelayTx atomic.Bool knownAddresses *apbf.Filter @@ -507,11 +507,13 @@ type serverPeer struct { // newServerPeer returns a new serverPeer instance. The peer needs to be set by // the caller. -func newServerPeer(s *server, remoteAddr *addrmgr.NetAddress, isPersistent bool) *serverPeer { +func newServerPeer(s *server, conn *connmgr.Conn, remoteAddr *addrmgr.NetAddress) *serverPeer { return &serverPeer{ server: s, + conn: conn, remoteAddr: remoteAddr, - persistent: isPersistent, + persistent: s.connManager.IsPersistent(conn.ID()), + isWhitelisted: s.connManager.IsWhitelisted(remoteAddr), knownAddresses: apbf.NewFilter(maxKnownAddrsPerPeer, knownAddrsFPRate), quit: make(chan struct{}), getDataQueue: make(chan []*wire.InvVect, maxConcurrentGetDataReqs), @@ -1036,42 +1038,6 @@ func (sp *serverPeer) OnVersion(msg *wire.MsgVersion) error { "providing desired services %v", msg.Services, missingServices) } - // Update the address manager and request known addresses from the - // remote peer for outbound connections. This is skipped when running - // on the simulation and regression test networks since they are only - // intended to connect to specified peers and actively avoid advertising - // and connecting to discovered peers. - if !cfg.SimNet && !cfg.RegNet && !isInbound { - // Advertise the local address when the server accepts incoming - // connections and it believes itself to be close to the best - // known tip. - if !cfg.DisableListen && sp.server.syncManager.IsCurrent() { - // Get address that best matches. 
- pver := uint32(msg.ProtocolVersion) - addrTypeFilter := natfSupported(pver) - lna := addrManager.GetBestLocalAddress(sp.remoteAddr, addrTypeFilter) - if lna.IsRoutable() { - addresses := []*addrmgr.NetAddress{lna} - sp.pushAddrMsg(pver, addresses) - } else { - srvrLog.Debugf("Local address %s is not routable and will not "+ - "be broadcast to outbound peer %v", lna.Key(), sp.Addr()) - } - } - - // Request known addresses if the server address manager needs - // more. - if addrManager.NeedMoreAddresses() { - sp.QueueMessage(wire.NewMsgGetAddr(), nil) - } - - // Mark the address as a known good address. - err := addrManager.Good(sp.remoteAddr) - if err != nil { - srvrLog.Errorf("Marking address as good failed: %v", err) - } - } - sp.reportedLocalAddr.Store(&msg.AddrYou) // Choose whether or not to relay transactions. @@ -2213,57 +2179,6 @@ func (s *server) handleBroadcastMsg(state *peerState, bmsg *broadcastMsg) { }) } -// disconnectPeer attempts to drop the connection of a targeted peer in the -// passed peer list. Targets are identified via usage of the passed -// `compareFunc`, which should return `true` if the passed peer is the target -// peer. This function returns true on success and false if the peer is unable -// to be located. If the peer is found, and the passed callback: `whenFound' -// isn't nil, we call it with the peer as the argument before it is removed -// from the peerList, and is disconnected from the server. -func disconnectPeer(peerList map[int32]*serverPeer, compareFunc func(*serverPeer) bool, whenFound func(*serverPeer)) bool { - for addr, peer := range peerList { - if compareFunc(peer) { - if whenFound != nil { - whenFound(peer) - } - - // This is ok because we are not continuing - // to iterate so won't corrupt the loop. 
- delete(peerList, addr) - peer.Disconnect() - return true - } - } - return false -} - -// connToNetAddr parses and returns an address manager network address from the -// remote address associated with the given connection. -// -// This function is safe for concurrent access. -func connToNetAddr(conn net.Conn) (*addrmgr.NetAddress, error) { - addrStr := conn.RemoteAddr().String() - host, portStr, err := net.SplitHostPort(addrStr) - if err != nil { - return nil, err - } - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return nil, err - } - - addrType, addrBytes := addrmgr.EncodeHost(host) - if addrType == addrmgr.UnknownAddressType { - return nil, fmt.Errorf("unable to determine address type: %v", addrStr) - } - - // Since the host type has been successfully recognized and encoded, - // there is no need to perform a DNS lookup. - now := time.Unix(time.Now().Unix(), 0) - return addrmgr.NewNetAddressFromParams(addrType, addrBytes, uint16(port), - now, 0) -} - // handleBannedConn closes the provided connection if the remote address // associated with it is banned or the address can't be properly parsed. It // returns true when the connection is closed. @@ -2359,11 +2274,11 @@ func newPeerConfig(sp *serverPeer) *peer.Config { // instance, associates it with the connection, runs the peer (which starts all // additional server peer processing goroutines) and blocks until the peer // disconnects. 
-func (s *server) inboundPeerConnected(ctx context.Context, conn net.Conn) { - remoteNetAddr, err := connToNetAddr(conn) - if err != nil { - srvrLog.Debugf("Unable to create inbound peer for address %s: %v", - conn.RemoteAddr(), err) +func (s *server) inboundPeerConnected(ctx context.Context, conn *connmgr.Conn) { + remoteNetAddr, ok := conn.RemoteAddr().(*addrmgr.NetAddress) + if !ok { + srvrLog.Warnf("remote address for connection is incorrect type %T", + conn.RemoteAddr()) conn.Close() return } @@ -2373,8 +2288,7 @@ func (s *server) inboundPeerConnected(ctx context.Context, conn net.Conn) { return } - sp := newServerPeer(s, remoteNetAddr, false) - sp.isWhitelisted = isWhitelisted(remoteNetAddr) + sp := newServerPeer(s, conn, remoteNetAddr) sp.Peer = peer.NewInboundPeer(newPeerConfig(sp), conn) if err := sp.Handshake(ctx, sp.OnVersion); err != nil { srvrLog.Debugf("Failed handshake for inbound peer %s: %v", @@ -2391,31 +2305,27 @@ func (s *server) inboundPeerConnected(ctx context.Context, conn net.Conn) { // peer instance, associates it with the relevant state such as the connection // request instance and the connection itself, and start all additional server // peer processing goroutines. -func (s *server) outboundPeerConnected(ctx context.Context, c *connmgr.ConnReq, conn net.Conn) { - remoteNetAddr, err := connToNetAddr(conn) - if err != nil { - srvrLog.Debugf("Unable to create outbound peer for address %s: %v", - conn.RemoteAddr(), err) +func (s *server) outboundPeerConnected(ctx context.Context, conn *connmgr.Conn) { + remoteNetAddr, ok := conn.RemoteAddr().(*addrmgr.NetAddress) + if !ok { + srvrLog.Warnf("remote address for connection is incorrect type %T", + conn.RemoteAddr()) conn.Close() - s.connManager.Disconnect(c.ID()) + return } // Disconnect banned connections. Ideally we would never connect to a // banned peer, but the connection manager is currently unaware of banned // addresses, so this is needed. 
if disconnected := s.handleBannedConn(remoteNetAddr, conn); disconnected { - s.connManager.Disconnect(c.ID()) return } - sp := newServerPeer(s, remoteNetAddr, c.Permanent) - p := peer.NewOutboundPeer(newPeerConfig(sp), c.Addr, conn) - sp.Peer = p - sp.connReq.Store(c) - sp.isWhitelisted = isWhitelisted(remoteNetAddr) + sp := newServerPeer(s, conn, remoteNetAddr) + sp.Peer = peer.NewOutboundPeer(newPeerConfig(sp), conn.RemoteAddr(), conn) if err := sp.Handshake(ctx, sp.OnVersion); err != nil { - srvrLog.Debugf("Failed handshake for outbound peer %s: %v", c.Addr, err) - s.connManager.Disconnect(c.ID()) + srvrLog.Debugf("Failed handshake for outbound peer %s: %v", + conn.RemoteAddr(), err) return } sp.syncMgrPeer = netsync.NewPeer(sp.Peer) @@ -2737,6 +2647,41 @@ func (s *server) handleAddPeer(sp *serverPeer) bool { return false } + // Update the address manager and request known addresses from the remote + // peer for outbound connections. This is skipped when running on the + // simulation and regression test networks since they are only intended to + // connect to specified peers and actively avoid advertising and connecting + // to discovered peers. + addrManager := sp.server.addrManager + if !cfg.SimNet && !cfg.RegNet && !sp.Inbound() { + // Advertise the local address when the server accepts incoming + // connections and it believes itself to be close to the best known tip. + if !cfg.DisableListen && sp.server.syncManager.IsCurrent() { + // Get address that best matches. + pver := sp.ProtocolVersion() + addrTypeFilter := natfSupported(pver) + lna := addrManager.GetBestLocalAddress(sp.remoteAddr, addrTypeFilter) + if lna.IsRoutable() { + addrs := []*addrmgr.NetAddress{lna} + sp.pushAddrMsg(pver, addrs) + } else { + srvrLog.Debugf("Local address %s is not routable and will not "+ + "be broadcast to outbound peer %v", lna.Key(), sp.Addr()) + } + } + + // Request known addresses if the server address manager needs more. 
+ if addrManager.NeedMoreAddresses() { + sp.QueueMessage(wire.NewMsgGetAddr(), nil) + } + + // Mark the address as a known good address. + err := addrManager.Good(sp.remoteAddr) + if err != nil { + srvrLog.Errorf("Marking address as good failed: %v", err) + } + } + // Consider the address the remote peer reported for the local connection as // a potential external address candidate for the server. s.considerReportedAddr(sp, sp.reportedLocalAddr.Load()) @@ -2758,17 +2703,6 @@ func (s *server) handleAddPeer(sp *serverPeer) bool { return false } - // Limit max number of total peers. However, allow whitelisted inbound - // peers regardless. - if state.count()+1 > cfg.MaxPeers && !isInboundWhitelisted { - srvrLog.Infof("Max peers reached [%d] - disconnecting peer %s", - cfg.MaxPeers, sp) - sp.Disconnect() - // TODO: how to handle permanent peers here? - // they should be rescheduled. - return false - } - // Add the new peer. if sp.Inbound() { state.inboundPeers[sp.ID()] = sp @@ -2819,21 +2753,12 @@ func (s *server) DonePeer(sp *serverPeer) { if _, ok := list[sp.ID()]; ok { if !sp.Inbound() { state.outboundGroups[sp.remoteAddr.GroupKey()]-- - connReq := sp.connReq.Load() - if connReq != nil { - s.connManager.Disconnect(connReq.ID()) - } } delete(list, sp.ID()) srvrLog.Debugf("Removed peer %s", sp) return } - connReq := sp.connReq.Load() - if connReq != nil { - s.connManager.Disconnect(connReq.ID()) - } - // Update the address manager with the last seen time. 
This is skipped when // running on the simulation and regression test networks since they are // only intended to connect to specified peers and actively avoid @@ -4419,39 +4344,40 @@ func newServer(ctx context.Context, profiler *profileServer, } cmgr, err := connmgr.New(&connmgr.Config{ Listeners: listeners, - OnAccept: func(conn net.Conn) { + OnAccept: func(conn *connmgr.Conn) { s.inboundPeerConnected(ctx, conn) }, RetryDuration: connectionRetryInterval, + MaxNormalConns: uint32(cfg.MaxPeers), TargetOutbound: s.targetOutbound, Dial: s.attemptDcrdDial, DialTimeout: cfg.DialTimeout, - OnConnection: func(c *connmgr.ConnReq, conn net.Conn) { - s.outboundPeerConnected(ctx, c, conn) + OnConnection: func(conn *connmgr.Conn) { + s.outboundPeerConnected(ctx, conn) }, GetNewAddress: newAddressFunc, + Whitelists: cfg.whitelists, }) if err != nil { return nil, err } s.connManager = cmgr - // Start up persistent peers. - permanentPeers := cfg.ConnectPeers - if len(permanentPeers) == 0 { - permanentPeers = cfg.AddPeers + // Add persistent peers. 
+ persistentPeers := cfg.ConnectPeers + if len(persistentPeers) == 0 { + persistentPeers = cfg.AddPeers } - for _, addr := range permanentPeers { + for _, addr := range persistentPeers { tcpAddr, err := addrStringToNetAddr(addr) if err != nil { return nil, err } - go s.connManager.Connect(ctx, - &connmgr.ConnReq{ - Addr: tcpAddr, - Permanent: true, - }) + _, err = s.connManager.AddPersistent(tcpAddr) + if err != nil { + return nil, err + } } if !cfg.DisableRPC { @@ -4637,14 +4563,14 @@ func addrStringToNetAddr(addr string) (net.Addr, error) { return nil, fmt.Errorf("no addresses found for %s", host) } - port, err := strconv.Atoi(strPort) + port, err := strconv.ParseUint(strPort, 10, 16) if err != nil { return nil, err } return &net.TCPAddr{ IP: ips[0], - Port: port, + Port: int(port), }, nil } @@ -4695,19 +4621,3 @@ func addLocalAddress(addrMgr *addrmgr.AddrManager, addr string, services wire.Se return nil } - -// isWhitelisted returns whether the IP address is included in the whitelisted -// networks and IPs. -func isWhitelisted(addr *addrmgr.NetAddress) bool { - if len(cfg.whitelists) == 0 { - return false - } - - ip, _ := netip.AddrFromSlice(addr.IP) - for _, prefix := range cfg.whitelists { - if prefix.Contains(ip) { - return true - } - } - return false -}