Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion lib/go/thrift/socket_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package thrift

import (
"crypto/tls"
"errors"
"io"
"net"
Expand All @@ -32,13 +33,16 @@ import (

type serverSocketConnCallback func(testing.TB, *socketConn)

func serverSocketConn(tb testing.TB, f serverSocketConnCallback) (net.Listener, error) {
func serverSocketConn(tb testing.TB, f serverSocketConnCallback, tlsCert *tls.Certificate) (net.Listener, error) {
tb.Helper()

ln, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, err
}
if tlsCert != nil {
ln = tls.NewListener(ln, &tls.Config{Certificates: []tls.Certificate{*tlsCert}})
}
go func() {
for {
sc, err := createSocketConnFromReturn(ln.Accept())
Expand Down Expand Up @@ -86,6 +90,7 @@ func TestSocketConn(t *testing.T) {
time.Sleep(interval)
writeFully(tb, sc, second)
},
nil,
)
if err != nil {
t.Fatal(err)
Expand Down
10 changes: 8 additions & 2 deletions lib/go/thrift/socket_unix_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
package thrift

import (
"crypto/tls"
"errors"
"io"
"syscall"
Expand All @@ -38,7 +39,12 @@ func (sc *socketConn) read0() error {
}

func (sc *socketConn) checkConn() error {
syscallConn, ok := sc.Conn.(syscall.Conn)
rawConn := sc.Conn
if tlsConn, ok := rawConn.(*tls.Conn); ok {
rawConn = tlsConn.NetConn()
}

syscallConn, ok := rawConn.(syscall.Conn)
if !ok {
// No way to check, return nil
return nil
Expand All @@ -47,7 +53,7 @@ func (sc *socketConn) checkConn() error {
// The reading about to be done here is non-blocking so we don't really
// need a read deadline. We just need to clear the previously set read
// deadline, if any.
sc.Conn.SetReadDeadline(zeroTime)
rawConn.SetReadDeadline(zeroTime)

rc, err := syscallConn.SyscallConn()
if err != nil {
Expand Down
69 changes: 68 additions & 1 deletion lib/go/thrift/socket_unix_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,29 @@
package thrift

import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"io"
"net"
"testing"
"time"
)

func TestSocketConnUnix(t *testing.T) {

t.Run("plain", func(t *testing.T) {
testSocketConn(t, nil)
})
t.Run("tls", func(t *testing.T) {
tlsCert := randomTLSCertificate(t)
testSocketConn(t, tlsCert)
})
}

func testSocketConn(t *testing.T, tlsCert *tls.Certificate) {
const (
interval = time.Millisecond * 10
first = "hello"
Expand All @@ -47,16 +63,24 @@ func TestSocketConnUnix(t *testing.T) {
time.Sleep(interval)
writeFully(tb, sc, second)
},
tlsCert,
)
if err != nil {
t.Fatal(err)
}
defer ln.Close()

sc, err := createSocketConnFromReturn(net.Dial("tcp", ln.Addr().String()))
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatal(err)
}
if tlsCert != nil {
conn = tls.Client(conn, &tls.Config{
InsecureSkipVerify: true,
})
}
sc := wrapSocketConn(conn)

buf := make([]byte, 1024)

if !sc.IsOpen() {
Expand Down Expand Up @@ -100,3 +124,46 @@ func TestSocketConnUnix(t *testing.T) {
t.Error("Expected sc to report not open, got true")
}
}

func randomTLSCertificate(t *testing.T) *tls.Certificate {
t.Helper()

privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate private key: %v", err)
}

template := x509.Certificate{
Subject: pkix.Name{
CommonName: "localhost",
},
NotBefore: time.Now().Add(-time.Minute),
NotAfter: time.Now().Add(time.Hour),

BasicConstraintsValid: true,

DNSNames: []string{"localhost"},
IPAddresses: []net.IP{
net.ParseIP("127.0.0.1"),
net.ParseIP("::1"),
},
}

derBytes, err := x509.CreateCertificate(
rand.Reader,
&template,
&template,
&privateKey.PublicKey,
privateKey,
)
if err != nil {
t.Fatalf("create certificate: %v", err)
}

cert := tls.Certificate{
Certificate: [][]byte{derBytes},
PrivateKey: privateKey,
}

return &cert
}
Loading