diff --git a/lib/go/thrift/socket_conn_test.go b/lib/go/thrift/socket_conn_test.go index 073fd818b34..d6f554effad 100644 --- a/lib/go/thrift/socket_conn_test.go +++ b/lib/go/thrift/socket_conn_test.go @@ -20,6 +20,7 @@ package thrift import ( + "crypto/tls" "errors" "io" "net" @@ -32,13 +33,16 @@ import ( type serverSocketConnCallback func(testing.TB, *socketConn) -func serverSocketConn(tb testing.TB, f serverSocketConnCallback) (net.Listener, error) { +func serverSocketConn(tb testing.TB, f serverSocketConnCallback, tlsCert *tls.Certificate) (net.Listener, error) { tb.Helper() ln, err := net.Listen("tcp", "localhost:0") if err != nil { return nil, err } + if tlsCert != nil { + ln = tls.NewListener(ln, &tls.Config{Certificates: []tls.Certificate{*tlsCert}}) + } go func() { for { sc, err := createSocketConnFromReturn(ln.Accept()) @@ -86,6 +90,7 @@ func TestSocketConn(t *testing.T) { time.Sleep(interval) writeFully(tb, sc, second) }, + nil, ) if err != nil { t.Fatal(err) diff --git a/lib/go/thrift/socket_unix_conn.go b/lib/go/thrift/socket_unix_conn.go index c7621257995..2722c1b82aa 100644 --- a/lib/go/thrift/socket_unix_conn.go +++ b/lib/go/thrift/socket_unix_conn.go @@ -22,6 +22,7 @@ package thrift import ( + "crypto/tls" "errors" "io" "syscall" @@ -38,7 +39,12 @@ func (sc *socketConn) read0() error { } func (sc *socketConn) checkConn() error { - syscallConn, ok := sc.Conn.(syscall.Conn) + rawConn := sc.Conn + if tlsConn, ok := rawConn.(*tls.Conn); ok { + rawConn = tlsConn.NetConn() + } + + syscallConn, ok := rawConn.(syscall.Conn) if !ok { // No way to check, return nil return nil @@ -47,7 +53,7 @@ func (sc *socketConn) checkConn() error { // The reading about to be done here is non-blocking so we don't really // need a read deadline. We just need to clear the previously set read // deadline, if any. - sc.Conn.SetReadDeadline(zeroTime) + rawConn.SetReadDeadline(zeroTime) rc, err := syscallConn.SyscallConn() if err != nil { diff --git a/lib/go/thrift/socket_unix_conn_test.go b/lib/go/thrift/socket_unix_conn_test.go index 612d21325e9..0cf30b9574b 100644 --- a/lib/go/thrift/socket_unix_conn_test.go +++ b/lib/go/thrift/socket_unix_conn_test.go @@ -22,6 +22,11 @@ package thrift import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "io" "net" "testing" @@ -29,6 +34,17 @@ import ( ) func TestSocketConnUnix(t *testing.T) { + + t.Run("plain", func(t *testing.T) { + testSocketConn(t, nil) + }) + t.Run("tls", func(t *testing.T) { + tlsCert := randomTLSCertificate(t) + testSocketConn(t, tlsCert) + }) +} + +func testSocketConn(t *testing.T, tlsCert *tls.Certificate) { const ( interval = time.Millisecond * 10 first = "hello" @@ -47,16 +63,24 @@ func TestSocketConnUnix(t *testing.T) { time.Sleep(interval) writeFully(tb, sc, second) }, + tlsCert, ) if err != nil { t.Fatal(err) } defer ln.Close() - sc, err := createSocketConnFromReturn(net.Dial("tcp", ln.Addr().String())) + conn, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) } + if tlsCert != nil { + conn = tls.Client(conn, &tls.Config{ + InsecureSkipVerify: true, + }) + } + sc := wrapSocketConn(conn) + buf := make([]byte, 1024) if !sc.IsOpen() { @@ -100,3 +124,46 @@ func TestSocketConnUnix(t *testing.T) { t.Error("Expected sc to report not open, got true") } } + +func randomTLSCertificate(t *testing.T) *tls.Certificate { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate private key: %v", err) + } + + template := x509.Certificate{ + Subject: pkix.Name{ + CommonName: "localhost", + }, + NotBefore: time.Now().Add(-time.Minute), + NotAfter: time.Now().Add(time.Hour), + + BasicConstraintsValid: true, + + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{ + net.ParseIP("127.0.0.1"), + net.ParseIP("::1"), + }, + } + + derBytes, err := x509.CreateCertificate( + rand.Reader, + &template, + &template, + &privateKey.PublicKey, + privateKey, + ) + if err != nil { + t.Fatalf("create certificate: %v", err) + } + + cert := tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: privateKey, + } + + return &cert +}