diff --git a/lib/auth/apiserver.go b/lib/auth/apiserver.go index 3a6e5e9b613..6df4888bf20 100644 --- a/lib/auth/apiserver.go +++ b/lib/auth/apiserver.go @@ -1935,7 +1935,7 @@ func (s *APIServer) emitAuditEvent(auth ClientI, w http.ResponseWriter, r *http. // Validate serverID field in event matches server ID from x509 identity. This // check makes sure nodes can only submit events for themselves. - serverID, err := getServerID(r) + serverID, err := s.getServerID(r) if err != nil { return nil, trace.Wrap(err) } @@ -1975,7 +1975,7 @@ func (s *APIServer) postSessionSlice(auth ClientI, w http.ResponseWriter, r *htt // Validate serverID field in event matches server ID from x509 identity. This // check makes sure nodes can only submit events for themselves. - serverID, err := getServerID(r) + serverID, err := s.getServerID(r) if err != nil { return nil, trace.Wrap(err) } @@ -2036,7 +2036,7 @@ func (s *APIServer) uploadSessionRecording(auth ClientI, w http.ResponseWriter, // Validate namespace and serverID fields in the archive match namespace and // serverID of the authenticated client. This check makes sure nodes can // only submit recordings for themselves. - serverID, err := getServerID(r) + serverID, err := s.getServerID(r) if err != nil { return nil, trace.Wrap(err) } @@ -2527,18 +2527,26 @@ func (s *APIServer) processKubeCSR(auth ClientI, w http.ResponseWriter, r *http. } // getServerID returns the ID of the connected client. -func getServerID(r *http.Request) (string, error) { +func (s *APIServer) getServerID(r *http.Request) (string, error) { role, ok := r.Context().Value(ContextUser).(BuiltinRole) if !ok { return "", trace.BadParameter("invalid role %T", r.Context().Value(ContextUser)) } - parts := strings.Split(role.Username, ".") - if len(parts) == 0 { - return "", trace.BadParameter("invalid username: %v", role.Username) + clusterName, err := s.AuthServer.GetDomainName() + if err != nil { + return "", trace.Wrap(err) } - return parts[0], nil + // The username extracted from the node's identity (x.509 certificate) + // is expected to consist of "." so strip the + // cluster name suffix to get the server id. + // + // Note that as of right now Teleport expects server id to be a uuid4 + // but older Gravity clusters used to override it with strings like + // "192_168_1_1." so this code can't rely on it being + // uuid4 to account for clusters upgraded from older versions. + return strings.TrimSuffix(role.Username, "."+clusterName), nil } func message(msg string) map[string]interface{} { diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index f15ef75a88d..9692b6713e7 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -901,7 +901,7 @@ func (s *TLSSuite) TestValidateUploadSessionRecording(c *check.C) { }, } for _, tt := range tests { - clt, err := s.server.NewClient(TestServerID(s.server.Identity.ID.HostUUID)) + clt, err := s.server.NewClient(TestServerID(serverID)) c.Assert(err, check.IsNil) sessionID := session.NewID() @@ -997,7 +997,7 @@ func (s *TLSSuite) TestValidatePostSessionSlice(c *check.C) { }, } for _, tt := range tests { - clt, err := s.server.NewClient(TestServerID(s.server.Identity.ID.HostUUID)) + clt, err := s.server.NewClient(TestServerID(serverID)) c.Assert(err, check.IsNil) date := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)