From 14e8885c3237eae03b125316cccea3973264f8d0 Mon Sep 17 00:00:00 2001 From: Percy Wegmann Date: Wed, 6 Aug 2014 11:22:00 -0700 Subject: [PATCH] crypto/tls: Added dynamic alternative to NameToCertificate map for SNI Revised version of https://golang.org/cl/81260045/ LGTM=agl R=golang-codereviews, gobot, agl, ox CC=golang-codereviews https://golang.org/cl/107400043 --- src/pkg/crypto/tls/common.go | 61 +++++++++++++--- src/pkg/crypto/tls/conn_test.go | 22 ++++-- src/pkg/crypto/tls/handshake_server.go | 11 ++- src/pkg/crypto/tls/handshake_server_test.go | 78 +++++++++++++++++++-- 4 files changed, 149 insertions(+), 23 deletions(-) diff --git a/src/pkg/crypto/tls/common.go b/src/pkg/crypto/tls/common.go index ce30385917..2b59136e65 100644 --- a/src/pkg/crypto/tls/common.go +++ b/src/pkg/crypto/tls/common.go @@ -202,6 +202,32 @@ type ClientSessionCache interface { Put(sessionKey string, cs *ClientSessionState) } +// ClientHelloInfo contains information from a ClientHello message in order to +// guide certificate selection in the GetCertificate callback. +type ClientHelloInfo struct { + // CipherSuites lists the CipherSuites supported by the client (e.g. + // TLS_RSA_WITH_RC4_128_SHA). + CipherSuites []uint16 + + // ServerName indicates the name of the server requested by the client + // in order to support virtual hosting. ServerName is only set if the + // client is using SNI (see + // http://tools.ietf.org/html/rfc4366#section-3.1). + ServerName string + + // SupportedCurves lists the elliptic curves supported by the client. + // SupportedCurves is set only if the Supported Elliptic Curves + // Extension is being used (see + // http://tools.ietf.org/html/rfc4492#section-5.1.1). + SupportedCurves []CurveID + + // SupportedPoints lists the point formats supported by the client. + // SupportedPoints is set only if the Supported Point Formats Extension + // is being used (see + // http://tools.ietf.org/html/rfc4492#section-5.1.2). + SupportedPoints []uint8 +} + // A Config structure is used to configure a TLS client or server. // After one has been passed to a TLS function it must not be // modified. A Config may be reused; the tls package will also not @@ -230,6 +256,13 @@ type Config struct { // for all connections. NameToCertificate map[string]*Certificate + // GetCertificate returns a Certificate based on the given + // ClientHelloInfo. If GetCertificate is nil or returns nil, then the + // certificate is retrieved from NameToCertificate. If + // NameToCertificate is nil, the first element of Certificates will be + // used. + GetCertificate func(clientHello *ClientHelloInfo) (*Certificate, error) + // RootCAs defines the set of root certificate authorities // that clients use when verifying server certificates. // If RootCAs is nil, TLS uses the host's root CA set. @@ -384,22 +417,28 @@ func (c *Config) mutualVersion(vers uint16) (uint16, bool) { return vers, true } -// getCertificateForName returns the best certificate for the given name, -// defaulting to the first element of c.Certificates if there are no good -// options. -func (c *Config) getCertificateForName(name string) *Certificate { - if len(c.Certificates) == 1 || c.NameToCertificate == nil { - // There's only one choice, so no point doing any work. - return &c.Certificates[0] +// getCertificate returns the best certificate for the given ClientHelloInfo, +// defaulting to the first element of c.Certificates. +func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { + if c.GetCertificate != nil { + cert, err := c.GetCertificate(clientHello) + if cert != nil || err != nil { + return cert, err + } } - name = strings.ToLower(name) + if len(c.Certificates) == 1 || c.NameToCertificate == nil { + // There's only one choice, so no point doing any work. + return &c.Certificates[0], nil + } + + name := strings.ToLower(clientHello.ServerName) for len(name) > 0 && name[len(name)-1] == '.' { name = name[:len(name)-1] } if cert, ok := c.NameToCertificate[name]; ok { - return cert + return cert, nil } // try replacing labels in the name with wildcards until we get a @@ -409,12 +448,12 @@ func (c *Config) getCertificateForName(name string) *Certificate { labels[i] = "*" candidate := strings.Join(labels, ".") if cert, ok := c.NameToCertificate[candidate]; ok { - return cert + return cert, nil } } // If nothing matches, return the first certificate. - return &c.Certificates[0] + return &c.Certificates[0], nil } // BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate diff --git a/src/pkg/crypto/tls/conn_test.go b/src/pkg/crypto/tls/conn_test.go index 5c555147ca..ec802cad70 100644 --- a/src/pkg/crypto/tls/conn_test.go +++ b/src/pkg/crypto/tls/conn_test.go @@ -88,19 +88,31 @@ func TestCertificateSelection(t *testing.T) { return -1 } - if n := pointerToIndex(config.getCertificateForName("example.com")); n != 0 { + certificateForName := func(name string) *Certificate { + clientHello := &ClientHelloInfo{ + ServerName: name, + } + if cert, err := config.getCertificate(clientHello); err != nil { + t.Errorf("unable to get certificate for name '%s': %s", name, err) + return nil + } else { + return cert + } + } + + if n := pointerToIndex(certificateForName("example.com")); n != 0 { t.Errorf("example.com returned certificate %d, not 0", n) } - if n := pointerToIndex(config.getCertificateForName("bar.example.com")); n != 1 { + if n := pointerToIndex(certificateForName("bar.example.com")); n != 1 { t.Errorf("bar.example.com returned certificate %d, not 1", n) } - if n := pointerToIndex(config.getCertificateForName("foo.example.com")); n != 2 { + if n := pointerToIndex(certificateForName("foo.example.com")); n != 2 { t.Errorf("foo.example.com returned certificate %d, not 2", n) } - if n := pointerToIndex(config.getCertificateForName("foo.bar.example.com")); n != 3 { + if n := pointerToIndex(certificateForName("foo.bar.example.com")); n != 3 { t.Errorf("foo.bar.example.com returned certificate %d, not 3", n) } - if n := pointerToIndex(config.getCertificateForName("foo.bar.baz.example.com")); n != 0 { + if n := pointerToIndex(certificateForName("foo.bar.baz.example.com")); n != 0 { t.Errorf("foo.bar.baz.example.com returned certificate %d, not 0", n) } } diff --git a/src/pkg/crypto/tls/handshake_server.go b/src/pkg/crypto/tls/handshake_server.go index 910f3d6ef0..39eeb363cd 100644 --- a/src/pkg/crypto/tls/handshake_server.go +++ b/src/pkg/crypto/tls/handshake_server.go @@ -186,7 +186,16 @@ Curves: } hs.cert = &config.Certificates[0] if len(hs.clientHello.serverName) > 0 { - hs.cert = config.getCertificateForName(hs.clientHello.serverName) + chi := &ClientHelloInfo{ + CipherSuites: hs.clientHello.cipherSuites, + ServerName: hs.clientHello.serverName, + SupportedCurves: hs.clientHello.supportedCurves, + SupportedPoints: hs.clientHello.supportedPoints, + } + if hs.cert, err = config.getCertificate(chi); err != nil { + c.sendAlert(alertInternalError) + return false, err + } } _, hs.ecdsaOk = hs.cert.PrivateKey.(*ecdsa.PrivateKey) diff --git a/src/pkg/crypto/tls/handshake_server_test.go b/src/pkg/crypto/tls/handshake_server_test.go index e94dc9f995..36c79a9b0c 100644 --- a/src/pkg/crypto/tls/handshake_server_test.go +++ b/src/pkg/crypto/tls/handshake_server_test.go @@ -257,6 +257,9 @@ type serverTest struct { expectedPeerCerts []string // config, if not nil, contains a custom Config to use for this test. config *Config + // expectAlert, if true, indicates that a fatal alert should be returned + // when handshaking with the server. + expectAlert bool // validate, if not nil, is a function that will be called with the // ConnectionState of the resulting connection. It returns false if the // ConnectionState is unacceptable. @@ -370,7 +373,9 @@ func (test *serverTest) run(t *testing.T, write bool) { if !write { flows, err := test.loadData() if err != nil { - t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath()) + if !test.expectAlert { + t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath()) + } } for i, b := range flows { if i%2 == 0 { @@ -379,11 +384,17 @@ func (test *serverTest) run(t *testing.T, write bool) { } bb := make([]byte, len(b)) n, err := io.ReadFull(clientConn, bb) - if err != nil { - t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b) - } - if !bytes.Equal(b, bb) { - t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b) + if test.expectAlert { + if err == nil { + t.Fatal("Expected read failure but read succeeded") + } + } else { + if err != nil { + t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b) + } + if !bytes.Equal(b, bb) { + t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b) + } } } clientConn.Close() @@ -562,6 +573,61 @@ func TestHandshakeServerSNI(t *testing.T) { runServerTestTLS12(t, test) } +// TestHandshakeServerSNICertForName is similar to TestHandshakeServerSNI, but +// tests the dynamic GetCertificate method +func TestHandshakeServerSNIGetCertificate(t *testing.T) { + config := *testConfig + + // Replace the NameToCertificate map with a GetCertificate function + nameToCert := config.NameToCertificate + config.NameToCertificate = nil + config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { + cert, _ := nameToCert[clientHello.ServerName] + return cert, nil + } + test := &serverTest{ + name: "SNI", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"}, + config: &config, + } + runServerTestTLS12(t, test) +} + +// TestHandshakeServerSNICertForNameNotFound is similar to +// TestHandshakeServerSNICertForName, but tests to make sure that when the +// GetCertificate method doesn't return a cert, we fall back to what's in +// the NameToCertificate map. +func TestHandshakeServerSNIGetCertificateNotFound(t *testing.T) { + config := *testConfig + + config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { + return nil, nil + } + test := &serverTest{ + name: "SNI", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"}, + config: &config, + } + runServerTestTLS12(t, test) +} + +// TestHandshakeServerSNICertForNameError tests to make sure that errors in +// GetCertificate result in a tls alert. +func TestHandshakeServerSNIGetCertificateError(t *testing.T) { + config := *testConfig + + config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { + return nil, fmt.Errorf("Test error in GetCertificate") + } + test := &serverTest{ + name: "SNI", + command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"}, + config: &config, + expectAlert: true, + } + runServerTestTLS12(t, test) +} + // TestCipherSuiteCertPreferance ensures that we select an RSA ciphersuite with // an RSA certificate and an ECDSA ciphersuite with an ECDSA certificate. func TestCipherSuiteCertPreferenceECDSA(t *testing.T) {