diff --git a/dtlstransport.go b/dtlstransport.go index c9276a19a52..3156c26149c 100644 --- a/dtlstransport.go +++ b/dtlstransport.go @@ -13,6 +13,7 @@ import ( "crypto/x509" "errors" "fmt" + "net" "strings" "sync" "sync/atomic" @@ -319,7 +320,16 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { return t.failStart(err) } + // Configure ICE for SPED after we created the DTLS transport. + if t.api.settingEngine.enableSped { + t.iceTransport.SetDtlsCallback(func(packet []byte, rAddr net.Addr) { + dtlsConn.InjectInboundPacket(packet, rAddr) + }) + } + + // This awaits the DTLS handshake. if err = t.handshakeDTLS(dtlsConn); err != nil { + fmt.Println("DTLS handshake complete") dtlsEndpoint.SetOnClose(nil) _ = dtlsConn.Close() @@ -368,7 +378,8 @@ func (t *DTLSTransport) dtlsSharedOptions(certificate tls.Certificate) []dtls.Op dtls.WithCertificates(certificate), dtls.WithSRTPProtectionProfiles(t.srtpProtectionProfiles()...), dtls.WithExtendedMasterSecret(t.api.settingEngine.dtls.extendedMasterSecret), - dtls.WithInsecureSkipVerify(!t.api.settingEngine.dtls.disableInsecureSkipVerify), + // TODO: this should be the default, DTLS runs over ICE which *hopefully* checks the source. + dtls.WithInsecureSkipVerify(true), dtls.WithLoggerFactory(t.api.settingEngine.LoggerFactory), dtls.WithVerifyPeerCertificate(t.verifyPeerCertificateFunc()), } @@ -380,6 +391,7 @@ func (t *DTLSTransport) dtlsSharedOptions(certificate tls.Certificate) []dtls.Op ) } + // TODO: should this initially be set to one day for SPED? if t.api.settingEngine.dtls.retransmissionInterval > 0 { sharedOpts = append( sharedOpts, @@ -387,6 +399,20 @@ func (t *DTLSTransport) dtlsSharedOptions(certificate tls.Certificate) []dtls.Op ) } + // Configure DTLS for SPED. + if t.api.settingEngine.enableSped { + sharedOpts = append( + sharedOpts, + dtls.WithOutboundHandshakePacketInterceptor(func(packet []byte, end bool) bool { + // Forward the packet to the ICE transport for piggybacking. + return t.iceTransport.Piggyback(packet, end) + }), + dtls.WithInboundHandshakePacketNotifier(func(packet []byte) { + t.iceTransport.ReportDtlsPacket(packet) + }), + ) + } + if t.api.settingEngine.replayProtection.DTLS != nil { sharedOpts = append( sharedOpts, @@ -559,6 +585,10 @@ func (t *DTLSTransport) completeStart(dtlsConn *dtls.Conn) error { t.srtpProtectionProfile = srtpProtectionProfile t.conn = dtlsConn t.onStateChange(DTLSTransportStateConnected) + if t.api.settingEngine.enableSped { + t.iceTransport.Piggyback(nil, true) + t.iceTransport.SetDtlsCallback(nil) + } return t.startSRTP() } diff --git a/examples/warp/main.go b/examples/warp/main.go index 2814017126b..561b972e746 100644 --- a/examples/warp/main.go +++ b/examples/warp/main.go @@ -44,9 +44,10 @@ func setupOfferHandler(pc **webrtc.PeerConnection) { return } - // Enable SNAP. + // Enable SNAP and SPED. s := webrtc.SettingEngine{} s.EnableSctpSnap(true) + s.EnableSped(true) api := webrtc.NewAPI(webrtc.WithSettingEngine(s)) var err error diff --git a/go.mod b/go.mod index 261908e630a..0c8b5d795b1 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,11 @@ module github.com/pion/webrtc/v4 go 1.24.0 +replace github.com/pion/sctp => /home/fippo/pion/sctp +replace github.com/pion/dtls/v3 => /home/fippo/pion/dtls +replace github.com/pion/stun/v3 => /home/fippo/pion/stun +replace github.com/pion/ice/v4 => /home/fippo/pion/ice + require ( github.com/pion/datachannel v1.6.0 github.com/pion/dtls/v3 v3.1.2 @@ -28,6 +33,7 @@ require ( github.com/onsi/ginkgo v1.16.5 // indirect github.com/onsi/gomega v1.17.0 // indirect github.com/pion/mdns/v2 v2.1.0 // indirect + github.com/pion/transport/v3 v3.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/wlynxg/anet v0.0.5 // indirect golang.org/x/crypto v0.48.0 // indirect diff --git a/icetransport.go b/icetransport.go index 0e6d864f341..8c0d9aee32a 100644 --- a/icetransport.go +++ b/icetransport.go @@ -8,6 +8,7 @@ package webrtc import ( "context" "fmt" + "net" "sync" "sync/atomic" "time" @@ -39,6 +40,8 @@ type ICETransport struct { loggerFactory logging.LoggerFactory + dtlsCallback func(packet []byte, rAddr net.Addr) + log logging.LeveledLogger } @@ -69,7 +72,7 @@ func (t *ICETransport) GetSelectedCandidatePair() (*ICECandidatePair, error) { } // GetSelectedCandidatePairStats returns the selected candidate pair stats on which packets are sent -// if there is no selected pair empty stats, false is returned to indicate stats not available. +// if there is no selected pair, false is returned to indicate stats are not available. func (t *ICETransport) GetSelectedCandidatePairStats() (ICECandidatePairStats, bool) { return t.gatherer.getSelectedCandidatePairStats() } @@ -107,6 +110,7 @@ func (t *ICETransport) Start(gatherer *ICEGatherer, params ICEParameters, role * if agent == nil { return fmt.Errorf("%w: unable to start ICETransport", errICEAgentNotExist) } + agent.SetDtlsCallback(t.dtlsCallback) if err := agent.OnConnectionStateChange(func(iceState ice.ConnectionState) { state := newICETransportStateFromICE(iceState) @@ -145,12 +149,12 @@ func (t *ICETransport) Start(gatherer *ICEGatherer, params ICEParameters, role * var err error switch *role { case ICERoleControlling: - iceConn, err = agent.Dial(ctx, + iceConn, err = agent.StartDial( params.UsernameFragment, params.Password) case ICERoleControlled: - iceConn, err = agent.Accept(ctx, + iceConn, err = agent.StartAccept( params.UsernameFragment, params.Password) @@ -158,6 +162,17 @@ func (t *ICETransport) Start(gatherer *ICEGatherer, params ICEParameters, role * err = errICERoleUnknown } + if err != nil { + t.lock.Lock() + + return err + } + + if !t.gatherer.api.settingEngine.enableSped { + // Note: this blocks until a pair is found. + err = agent.AwaitConnect(ctx) + } + // Reacquire the lock to set the connection/mux t.lock.Lock() if err != nil { @@ -180,6 +195,16 @@ func (t *ICETransport) Start(gatherer *ICEGatherer, params ICEParameters, role * return nil } +func (t *ICETransport) SetDtlsCallback(cb func(packet []byte, rAddr net.Addr)) { + t.lock.Lock() + defer t.lock.Unlock() + if agent := t.gatherer.getAgent(); agent != nil { + agent.SetDtlsCallback(cb) + } else { + t.dtlsCallback = cb + } +} + // restart is not exposed currently because ORTC has users create a whole new ICETransport // so for now lets keep it private so we don't cause ORTC users to depend on non-standard APIs. func (t *ICETransport) restart() error { @@ -455,3 +480,31 @@ func (t *ICETransport) setRemoteCredentials(newUfrag, newPwd string) error { return agent.SetRemoteCredentials(newUfrag, newPwd) } + +// Piggyback forwards a raw packet to the ICE Agent. +func (t *ICETransport) Piggyback(packet []byte, end bool) bool { + t.lock.Lock() + defer t.lock.Unlock() + + agent := t.gatherer.getAgent() + if agent == nil { + t.log.Warnf("%w: unable to Piggyback", errICEAgentNotExist) + + return false + } + + return agent.Piggyback(packet, end) +} + +func (t *ICETransport) ReportDtlsPacket(packet []byte) { + t.lock.Lock() + defer t.lock.Unlock() + + agent := t.gatherer.getAgent() + if agent == nil { + t.log.Warnf("%w: unable report DTLS packet", errICEAgentNotExist) + + return + } + agent.ReportDtlsPacket(packet) +} diff --git a/peerconnection.go b/peerconnection.go index acb5b5bbb37..d2f5e12216a 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -2770,6 +2770,7 @@ func (pc *PeerConnection) startTransports( dtlsRole DTLSRole, remoteUfrag, remotePwd, fingerprint, fingerprintHash string, ) { + fmt.Println("START ICE", time.Now()) // Start the ice transport err := pc.iceTransport.Start( pc.iceGatherer, @@ -2799,6 +2800,7 @@ func (pc *PeerConnection) startTransports( }() } + fmt.Println("START DTLS", time.Now(), dtlsRole) // Start the dtls transport err = pc.dtlsTransport.Start(DTLSParameters{ Role: dtlsRole, diff --git a/peerconnection_test.go b/peerconnection_test.go index 09f09eba167..4877d7cf4a1 100644 --- a/peerconnection_test.go +++ b/peerconnection_test.go @@ -5,6 +5,7 @@ package webrtc import ( "runtime" + "strings" "sync" "sync/atomic" "testing" @@ -968,3 +969,44 @@ func TestICETrickleCapabilityString(t *testing.T) { assert.Equal(t, tt.expected, tt.value.String()) } } + +func TestWarp(t *testing.T) { + s := SettingEngine{} + s.EnableSped(true) + api := NewAPI(WithSettingEngine(s)) + + offer, err := api.NewPeerConnection(Configuration{}) + assert.NoError(t, err) + answer, err := api.NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + peerConnectionsConnected := untilConnectionState(PeerConnectionStateConnected, offer, answer) + assert.NoError(t, signalPair(offer, answer)) + peerConnectionsConnected.Wait() + + closePairNow(t, offer, answer) +} + +func TestWarpClient(t *testing.T) { + s := SettingEngine{} + s.EnableSped(true) + api := NewAPI(WithSettingEngine(s)) + + offer, err := api.NewPeerConnection(Configuration{}) + assert.NoError(t, err) + answer, err := api.NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + peerConnectionsConnected := untilConnectionState(PeerConnectionStateConnected, offer, answer) + assert.NoError(t, signalPairWithModification( + offer, answer, + func(sessionDescription string) string { + return strings.ReplaceAll( + sessionDescription, + "setup:actpass", + "setup:active") + })) + peerConnectionsConnected.Wait() + + closePairNow(t, offer, answer) +} diff --git a/settingengine.go b/settingengine.go index dda020f7745..f4bed675779 100644 --- a/settingengine.go +++ b/settingengine.go @@ -117,6 +117,7 @@ type SettingEngine struct { dataChannelBlockWrite bool handleUndeclaredSSRCWithoutAnswer bool ignoreRidPauseForRecv bool + enableSped bool } type renominationSettings struct { @@ -435,7 +436,7 @@ func (e *SettingEngine) SetMulticastDNSHostName(hostName string) { e.candidates.MulticastDNSHostName = hostName } -// SetICECredentials sets a staic uFrag/uPwd to be used by pion/ice +// SetICECredentials sets a static ice-ufrag/ice-pwd to be used by pion/ice // // This is useful if you want to do signalless WebRTC session, // or having a reproducible environment with static credentials. @@ -477,7 +478,7 @@ func (e *SettingEngine) DisableSRTCPReplayProtection(isDisabled bool) { } // SetSDPMediaLevelFingerprints configures the logic for DTLS Fingerprint insertion -// If true, fingerprints will be inserted in the sdp at the fingerprint +// If true, fingerprints will be inserted in the sdp at the media // level, instead of the session level. This helps with compatibility with // some webrtc implementations. func (e *SettingEngine) SetSDPMediaLevelFingerprints(sdpMediaLevelFingerprints bool) { @@ -722,3 +723,8 @@ func (e *SettingEngine) SetHandleUndeclaredSSRCWithoutAnswer(handleUndeclaredSSR func (e *SettingEngine) SetIgnoreRidPauseForRecv(ignoreRidPauseForRecv bool) { e.ignoreRidPauseForRecv = ignoreRidPauseForRecv } + +// EnableSped enabled SPED/dtls-in-stun. +func (e *SettingEngine) EnableSped(enable bool) { + e.enableSped = enable +}