crypto/tls: add WrapSession and UnwrapSession

There was a bug in TestResumption: the first ExpiredSessionTicket was
inserting a ticket far in the future, so the second ExpiredSessionTicket
wasn't actually supposed to fail. However, there was a bug in
checkForResumption->sendSessionTicket, too: if a session was not resumed
because it was too old, its createdAt was still persisted in the next
ticket. The two bugs used to cancel each other out.

For #60105
Fixes #19199

Change-Id: Ic9b2aab943dcbf0de62b8758a6195319dc286e2f
Reviewed-on: https://go-review.googlesource.com/c/go/+/496821
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Filippo Valsorda 2023-05-22 10:49:07 +02:00
parent 371ebe731b
commit 6824765b4b
7 changed files with 182 additions and 55 deletions

View file

@ -3,3 +3,7 @@ pkg crypto/tls, method (*SessionState) Bytes() ([]uint8, error) #60105
pkg crypto/tls, type SessionState struct #60105
pkg crypto/tls, func NewResumptionState([]uint8, *SessionState) (*ClientSessionState, error) #60105
pkg crypto/tls, method (*ClientSessionState) ResumptionState() ([]uint8, *SessionState, error) #60105
pkg crypto/tls, method (*Config) DecryptTicket([]uint8, ConnectionState) (*SessionState, error) #60105
pkg crypto/tls, method (*Config) EncryptTicket(ConnectionState, *SessionState) ([]uint8, error) #60105
pkg crypto/tls, type Config struct, UnwrapSession func([]uint8, ConnectionState) (*SessionState, error) #60105
pkg crypto/tls, type Config struct, WrapSession func(ConnectionState, *SessionState) ([]uint8, error) #60105

View file

@ -673,6 +673,35 @@ type Config struct {
// session resumption. It is only used by clients.
ClientSessionCache ClientSessionCache
// UnwrapSession is called on the server to turn a ticket/identity
// previously produced by [WrapSession] into a usable session.
//
// UnwrapSession will usually either decrypt a session state in the ticket
// (for example with [Config.EncryptTicket]), or use the ticket as a handle
// to recover a previously stored state. It must use [ParseSessionState] to
// deserialize the session state.
//
// If UnwrapSession returns an error, the connection is terminated. If it
// returns (nil, nil), the session is ignored. crypto/tls may still choose
// not to resume the returned session.
UnwrapSession func(identity []byte, cs ConnectionState) (*SessionState, error)
// WrapSession is called on the server to produce a session ticket/identity.
//
// WrapSession must serialize the session state with [SessionState.Bytes].
// It may then encrypt the serialized state (for example with
// [Config.DecryptTicket]) and use it as the ticket, or store the state and
// return a handle for it.
//
// If WrapSession returns an error, the connection is terminated.
//
// Warning: the return value will be exposed on the wire and to clients in
// plaintext. The application is in charge of encrypting and authenticating
// it (and rotating keys) or returning high-entropy identifiers. Failing to
// do so correctly can compromise current, previous, and future connections
// depending on the protocol version.
WrapSession func(ConnectionState, *SessionState) ([]byte, error)
// MinVersion contains the minimum TLS version that is acceptable.
//
// By default, TLS 1.2 is currently used as the minimum when acting as a
@ -794,6 +823,8 @@ func (c *Config) Clone() *Config {
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
ClientSessionCache: c.ClientSessionCache,
UnwrapSession: c.UnwrapSession,
WrapSession: c.WrapSession,
MinVersion: c.MinVersion,
MaxVersion: c.MaxVersion,
CurvePreferences: c.CurvePreferences,

View file

@ -900,6 +900,7 @@ func testResumption(t *testing.T, version uint16) {
}
testResumeState := func(test string, didResume bool) {
t.Helper()
_, hs, err := testHandshake(t, clientConfig, serverConfig)
if err != nil {
t.Fatalf("%s: handshake failed: %s", test, err)
@ -985,9 +986,11 @@ func testResumption(t *testing.T, version uint16) {
// Age the session ticket a bit at a time, but don't expire it.
d := 0 * time.Hour
serverConfig.Time = func() time.Time { return time.Now().Add(d) }
deleteTicket()
testResumeState("GetFreshSessionTicket", false)
for i := 0; i < 13; i++ {
d += 12 * time.Hour
serverConfig.Time = func() time.Time { return time.Now().Add(d) }
testResumeState("OldSessionTicket", true)
}
// Expire it (now a little more than 7 days) and make sure a full
@ -995,7 +998,6 @@ func testResumption(t *testing.T, version uint16) {
// TLS 1.3 since the client should be using a fresh ticket sent over
// by the server.
d += 12 * time.Hour
serverConfig.Time = func() time.Time { return time.Now().Add(d) }
if version == VersionTLS13 {
testResumeState("ExpiredSessionTicket", true)
} else {

View file

@ -70,9 +70,11 @@ func (hs *serverHandshakeState) handshake() error {
// For an overview of TLS handshaking, see RFC 5246, Section 7.3.
c.buffering = true
if hs.checkForResumption() {
if err := hs.checkForResumption(); err != nil {
return err
}
if hs.sessionState != nil {
// The client has included a session ticket and so we do an abbreviated handshake.
c.didResume = true
if err := hs.doResumeHandshake(); err != nil {
return err
}
@ -399,65 +401,80 @@ func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool {
}
// checkForResumption reports whether we should perform resumption on this connection.
func (hs *serverHandshakeState) checkForResumption() bool {
func (hs *serverHandshakeState) checkForResumption() error {
c := hs.c
if c.config.SessionTicketsDisabled {
return false
return nil
}
plaintext := c.decryptTicket(hs.clientHello.sessionTicket)
if plaintext == nil {
return false
var sessionState *SessionState
if c.config.UnwrapSession != nil {
ss, err := c.config.UnwrapSession(hs.clientHello.sessionTicket, c.connectionStateLocked())
if err != nil {
return err
}
if ss == nil {
return nil
}
sessionState = ss
} else {
plaintext := c.config.decryptTicket(hs.clientHello.sessionTicket, c.ticketKeys)
if plaintext == nil {
return nil
}
ss, err := ParseSessionState(plaintext)
if err != nil {
return nil
}
sessionState = ss
}
ss, err := ParseSessionState(plaintext)
if err != nil {
return false
}
hs.sessionState = ss
// TLS 1.2 tickets don't natively have a lifetime, but we want to avoid
// re-wrapping the same master secret in different tickets over and over for
// too long, weakening forward secrecy.
createdAt := time.Unix(int64(hs.sessionState.createdAt), 0)
createdAt := time.Unix(int64(sessionState.createdAt), 0)
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
return false
return nil
}
// Never resume a session for a different TLS version.
if c.vers != hs.sessionState.version {
return false
if c.vers != sessionState.version {
return nil
}
cipherSuiteOk := false
// Check that the client is still offering the ciphersuite in the session.
for _, id := range hs.clientHello.cipherSuites {
if id == hs.sessionState.cipherSuite {
if id == sessionState.cipherSuite {
cipherSuiteOk = true
break
}
}
if !cipherSuiteOk {
return false
return nil
}
// Check that we also support the ciphersuite from the session.
hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite},
suite := selectCipherSuite([]uint16{sessionState.cipherSuite},
c.config.cipherSuites(), hs.cipherSuiteOk)
if hs.suite == nil {
return false
if suite == nil {
return nil
}
sessionHasClientCerts := len(hs.sessionState.peerCertificates) != 0
sessionHasClientCerts := len(sessionState.peerCertificates) != 0
needClientCerts := requiresClientCert(c.config.ClientAuth)
if needClientCerts && !sessionHasClientCerts {
return false
return nil
}
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
return false
return nil
}
return true
hs.sessionState = sessionState
hs.suite = suite
c.didResume = true
return nil
}
func (hs *serverHandshakeState) doResumeHandshake() error {
@ -769,13 +786,20 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
// the original time it was created.
state.createdAt = hs.sessionState.createdAt
}
stateBytes, err := state.Bytes()
if err != nil {
return err
}
m.ticket, err = c.encryptTicket(stateBytes)
if err != nil {
return err
if c.config.WrapSession != nil {
m.ticket, err = c.config.WrapSession(c.connectionStateLocked(), state)
if err != nil {
return err
}
} else {
stateBytes, err := state.Bytes()
if err != nil {
return err
}
m.ticket, err = c.config.encryptTicket(stateBytes, c.ticketKeys)
if err != nil {
return err
}
}
if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil {

View file

@ -275,12 +275,29 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
break
}
plaintext := c.decryptTicket(identity.label)
if plaintext == nil {
continue
var sessionState *SessionState
if c.config.UnwrapSession != nil {
var err error
sessionState, err = c.config.UnwrapSession(identity.label, c.connectionStateLocked())
if err != nil {
return err
}
if sessionState == nil {
continue
}
} else {
plaintext := c.config.decryptTicket(identity.label, c.ticketKeys)
if plaintext == nil {
continue
}
var err error
sessionState, err = ParseSessionState(plaintext)
if err != nil {
continue
}
}
sessionState, err := ParseSessionState(plaintext)
if err != nil || sessionState.version != VersionTLS13 {
if sessionState.version != VersionTLS13 {
continue
}
@ -781,14 +798,21 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
return err
}
state.secret = psk
stateBytes, err := state.Bytes()
if err != nil {
c.sendAlert(alertInternalError)
return err
}
m.label, err = c.encryptTicket(stateBytes)
if err != nil {
return err
if c.config.WrapSession != nil {
m.label, err = c.config.WrapSession(c.connectionStateLocked(), state)
if err != nil {
return err
}
} else {
stateBytes, err := state.Bytes()
if err != nil {
c.sendAlert(alertInternalError)
return err
}
m.label, err = c.config.encryptTicket(stateBytes, c.ticketKeys)
if err != nil {
return err
}
}
m.lifetime = uint32(maxSessionTicketLifetime / time.Second)

View file

@ -228,8 +228,21 @@ func (c *Conn) sessionState() (*SessionState, error) {
}, nil
}
func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
if len(c.ticketKeys) == 0 {
// EncryptTicket encrypts a ticket with the Config's configured (or default)
// session ticket keys. It can be used as a [Config.WrapSession] implementation.
func (c *Config) EncryptTicket(cs ConnectionState, ss *SessionState) ([]byte, error) {
ticketKeys := c.ticketKeys(nil)
stateBytes, err := ss.Bytes()
if err != nil {
return nil, err
}
return c.encryptTicket(stateBytes, ticketKeys)
}
var _ = &Config{WrapSession: (&Config{}).EncryptTicket}
func (c *Config) encryptTicket(state []byte, ticketKeys []ticketKey) ([]byte, error) {
if len(ticketKeys) == 0 {
return nil, errors.New("tls: internal error: session ticket keys unavailable")
}
@ -239,10 +252,10 @@ func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
authenticated := encrypted[:len(encrypted)-sha256.Size]
macBytes := encrypted[len(encrypted)-sha256.Size:]
if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
if _, err := io.ReadFull(c.rand(), iv); err != nil {
return nil, err
}
key := c.ticketKeys[0]
key := ticketKeys[0]
block, err := aes.NewCipher(key.aesKey[:])
if err != nil {
return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
@ -256,7 +269,26 @@ func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
return encrypted, nil
}
func (c *Conn) decryptTicket(encrypted []byte) []byte {
// DecryptTicket decrypts a ticket encrypted by [Config.EncryptTicket]. It can
// be used as a [Config.UnwrapSession] implementation.
//
// If the ticket can't be decrypted or parsed, DecryptTicket returns (nil, nil).
func (c *Config) DecryptTicket(identity []byte, cs ConnectionState) (*SessionState, error) {
ticketKeys := c.ticketKeys(nil)
stateBytes := c.decryptTicket(identity, ticketKeys)
if stateBytes == nil {
return nil, nil
}
s, err := ParseSessionState(stateBytes)
if err != nil {
return nil, nil // drop unparsable tickets on the floor
}
return s, nil
}
var _ = &Config{UnwrapSession: (&Config{}).DecryptTicket}
func (c *Config) decryptTicket(encrypted []byte, ticketKeys []ticketKey) []byte {
if len(encrypted) < aes.BlockSize+sha256.Size {
return nil
}
@ -266,7 +298,7 @@ func (c *Conn) decryptTicket(encrypted []byte) []byte {
authenticated := encrypted[:len(encrypted)-sha256.Size]
macBytes := encrypted[len(encrypted)-sha256.Size:]
for _, key := range c.ticketKeys {
for _, key := range ticketKeys {
mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(authenticated)
expected := mac.Sum(nil)

View file

@ -758,7 +758,7 @@ func TestWarningAlertFlood(t *testing.T) {
}
func TestCloneFuncFields(t *testing.T) {
const expectedCount = 6
const expectedCount = 8
called := 0
c1 := Config{
@ -786,6 +786,14 @@ func TestCloneFuncFields(t *testing.T) {
called |= 1 << 5
return nil
},
UnwrapSession: func(identity []byte, cs ConnectionState) (*SessionState, error) {
called |= 1 << 6
return nil, nil
},
WrapSession: func(cs ConnectionState, ss *SessionState) ([]byte, error) {
called |= 1 << 7
return nil, nil
},
}
c2 := c1.Clone()
@ -796,6 +804,8 @@ func TestCloneFuncFields(t *testing.T) {
c2.GetConfigForClient(nil)
c2.VerifyPeerCertificate(nil, nil)
c2.VerifyConnection(ConnectionState{})
c2.UnwrapSession(nil, ConnectionState{})
c2.WrapSession(ConnectionState{}, nil)
if called != (1<<expectedCount)-1 {
t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
@ -814,7 +824,7 @@ func TestCloneNonFuncFields(t *testing.T) {
switch fn := typ.Field(i).Name; fn {
case "Rand":
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate":
case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate", "WrapSession", "UnwrapSession":
// DeepEqual can't compare functions. If you add a
// function field to this list, you must also change
// TestCloneFuncFields to ensure that the func field is