use "google.golang.org/protobuf" to clone protobuf messages (#25466)

"github.com/gogo/protobuf/proto.Clone" has multiple bugs that cause
panics, so create a wrapper function that uses a thoroughly tested
cloning function from "google.golang.org/protobuf".
This commit is contained in:
Andrew LeFevre 2023-05-04 17:22:01 -04:00 committed by GitHub
parent 76074b4f1c
commit 35b837de87
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 220 additions and 69 deletions

View file

@ -23,7 +23,6 @@ import (
"sort"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/utils"
@ -422,7 +421,7 @@ func (r *AccessRequestV3) SetDryRun(dryRun bool) {
// Copy returns a copy of the access request resource.
func (r *AccessRequestV3) Copy() AccessRequest {
return proto.Clone(r).(*AccessRequestV3)
return utils.CloneProtoMsg(r)
}
// GetLabel retrieves the label with the provided key. If not found

View file

@ -22,7 +22,6 @@ import (
"strings"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/constants"
@ -298,7 +297,7 @@ func (a *AppV3) String() string {
// Copy returns a copy of this database resource.
func (a *AppV3) Copy() *AppV3 {
return proto.Clone(a).(*AppV3)
return utils.CloneProtoMsg(a)
}
// MatchSearch goes through select field values and tries to

View file

@ -21,10 +21,10 @@ import (
"sort"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api"
"github.com/gravitational/teleport/api/utils"
)
// AppServer represents a single proxied web app.
@ -289,7 +289,7 @@ func (s *AppServerV3) SetStaticLabels(sl map[string]string) {
// Copy returns a copy of this app server object.
func (s *AppServerV3) Copy() AppServer {
return proto.Clone(s).(*AppServerV3)
return utils.CloneProtoMsg(s)
}
// MatchSearch goes through select field values and tries to

View file

@ -19,8 +19,9 @@ package types
import (
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/utils"
)
// ClusterAuditConfig defines cluster-wide audit log configuration. This is
@ -253,7 +254,7 @@ func (c *ClusterAuditConfigV2) RetentionPeriod() *Duration {
// Clone performs a deep copy.
func (c *ClusterAuditConfigV2) Clone() ClusterAuditConfig {
return proto.Clone(c).(*ClusterAuditConfigV2)
return utils.CloneProtoMsg(c)
}
// setStaticFields sets static resource header and metadata fields.

View file

@ -20,11 +20,11 @@ import (
"fmt"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"golang.org/x/exp/slices"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/utils"
)
// CertAuthority is a host or user certificate authority that can check and if
@ -108,7 +108,7 @@ func (ca *CertAuthorityV2) SetSubKind(s string) {
// Clone returns a copy of the cert authority object.
func (ca *CertAuthorityV2) Clone() CertAuthority {
return proto.Clone(ca).(*CertAuthorityV2)
return utils.CloneProtoMsg(ca)
}
// GetRotation returns rotation state.
@ -717,5 +717,4 @@ func (f *CertAuthorityFilter) FromMap(m map[string]string) {
for key, val := range m {
(*f)[CertAuthType(key)] = val
}
}

69
api/types/clone_test.go Normal file
View file

@ -0,0 +1,69 @@
/*
Copyright 2023 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package types
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/protoadapt"
"github.com/gravitational/teleport/api/utils"
)
type protoResource interface {
Resource
protoadapt.MessageV1
}
func TestCloning(t *testing.T) {
// Test that cloning some of our messages produces the same type
// with the same contents. When CheckAndSetDefaults sets an empty
// slice or map instead of a nil one, set it to nil so the
// equality check below won't fail.
var resources []protoResource
a, err := NewAccessRequest("foo", "bar", "role")
require.NoError(t, err)
accessRequest := a.(*AccessRequestV3)
accessRequest.Spec.SuggestedReviewers = nil
accessRequest.Spec.RequestedResourceIDs = nil
resources = append(resources, accessRequest)
user, err := NewUser("foo")
require.NoError(t, err)
resources = append(resources, user.(*UserV2))
s, err := NewServer("foo", KindNode, ServerSpecV2{})
require.NoError(t, err)
server := s.(*ServerV2)
server.Metadata.Labels = nil
resources = append(resources, server)
remCluster, err := NewRemoteCluster("foo")
require.NoError(t, err)
resources = append(resources, remCluster.(*RemoteClusterV3))
for _, r := range resources {
t.Run(fmt.Sprintf("%T", r), func(t *testing.T) {
rCopy := utils.CloneProtoMsg(r)
require.Equal(t, r, rCopy)
require.IsType(t, r, rCopy)
})
}
}

View file

@ -20,8 +20,9 @@ import (
"fmt"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/utils"
)
// ClusterName defines the name of the cluster. This is a configuration
@ -130,7 +131,7 @@ func (c *ClusterNameV2) GetClusterID() string {
// Clone performs a deep copy.
func (c *ClusterNameV2) Clone() ClusterName {
return proto.Clone(c).(*ClusterNameV2)
return utils.CloneProtoMsg(c)
}
// setStaticFields sets static resource header and metadata fields.

View file

@ -22,7 +22,6 @@ import (
"strings"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
@ -502,7 +501,7 @@ func (d *DatabaseV3) String() string {
// Copy returns a copy of this database resource.
func (d *DatabaseV3) Copy() *DatabaseV3 {
return proto.Clone(d).(*DatabaseV3)
return utils.CloneProtoMsg(d)
}
// MatchSearch goes through select field values and tries to

View file

@ -21,10 +21,10 @@ import (
"sort"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api"
"github.com/gravitational/teleport/api/utils"
)
// DatabaseServer represents a database access server.
@ -309,7 +309,7 @@ func (s *DatabaseServerV3) SetStaticLabels(sl map[string]string) {
// Copy returns a copy of this database server object.
func (s *DatabaseServerV3) Copy() DatabaseServer {
return proto.Clone(s).(*DatabaseServerV3)
return utils.CloneProtoMsg(s)
}
// MatchSearch goes through select field values and tries to

View file

@ -16,7 +16,9 @@ limitations under the License.
package events
import "github.com/gogo/protobuf/proto"
import (
"github.com/gravitational/teleport/api/utils"
)
func trimN(s string, n int) string {
// Starting at 2 to leave room for quotes at the begging and end.
@ -50,7 +52,7 @@ func (m *DatabaseSessionQuery) TrimToMaxSize(maxSize int) AuditEvent {
return m
}
out := proto.Clone(m).(*DatabaseSessionQuery)
out := utils.CloneProtoMsg(m)
out.DatabaseQuery = ""
out.DatabaseQueryParameters = nil
@ -88,7 +90,7 @@ func (e *SessionStart) TrimToMaxSize(maxSize int) AuditEvent {
return e
}
out := proto.Clone(e).(*SessionStart)
out := utils.CloneProtoMsg(e)
out.InitialCommand = nil
// Use 10% max size ballast + message size without InitialCommand
@ -114,7 +116,7 @@ func (e *Exec) TrimToMaxSize(maxSize int) AuditEvent {
return e
}
out := proto.Clone(e).(*Exec)
out := utils.CloneProtoMsg(e)
out.Command = ""
// Use 10% max size ballast + message size without Command
@ -137,7 +139,7 @@ func (e *UserLogin) TrimToMaxSize(maxSize int) AuditEvent {
return e
}
out := proto.Clone(e).(*UserLogin)
out := utils.CloneProtoMsg(e)
out.Status.Error = ""
out.Status.UserMessage = ""

View file

@ -19,11 +19,11 @@ package types
import (
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"golang.org/x/exp/slices"
"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/utils"
)
// Match checks if the given instance appears to match this filter.
@ -247,10 +247,10 @@ func (i *InstanceV1) expireControlLog(now time.Time, ttl time.Duration) time.Tim
}
func (i *InstanceV1) Clone() Instance {
return proto.Clone(i).(*InstanceV1)
return utils.CloneProtoMsg(i)
}
func (e *InstanceControlLogEntry) Clone() InstanceControlLogEntry {
e.Time = e.Time.UTC()
return *proto.Clone(e).(*InstanceControlLogEntry)
return *utils.CloneProtoMsg(e)
}

View file

@ -22,7 +22,6 @@ import (
"sort"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"golang.org/x/exp/slices"
@ -313,7 +312,7 @@ func (k *KubernetesClusterV3) String() string {
// Copy returns a copy of this resource.
func (k *KubernetesClusterV3) Copy() *KubernetesClusterV3 {
return proto.Clone(k).(*KubernetesClusterV3)
return utils.CloneProtoMsg(k)
}
// MatchSearch goes through select field values and tries to

View file

@ -21,10 +21,10 @@ import (
"sort"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api"
"github.com/gravitational/teleport/api/utils"
)
// KubeServer represents a single Kubernetes server.
@ -279,7 +279,7 @@ func (s *KubernetesServerV3) SetStaticLabels(sl map[string]string) {
// Copy returns a copy of this kube server object.
func (s *KubernetesServerV3) Copy() KubeServer {
return proto.Clone(s).(*KubernetesServerV3)
return utils.CloneProtoMsg(s)
}
// MatchSearch goes through select field values and tries to

View file

@ -20,8 +20,9 @@ import (
"strings"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/utils"
)
const (
@ -118,7 +119,7 @@ func (w *AgentUpgradeWindow) Export(from time.Time, n int) AgentUpgradeSchedule
}
func (s *AgentUpgradeSchedule) Clone() *AgentUpgradeSchedule {
return proto.Clone(s).(*AgentUpgradeSchedule)
return utils.CloneProtoMsg(s)
}
// NewClusterMaintenanceConfig creates a new maintenance config with no parameters set.

View file

@ -15,15 +15,16 @@
package types
import (
proto "github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/utils"
)
func (d *MFADevice) WithoutSensitiveData() (*MFADevice, error) {
if d == nil {
return nil, trace.BadParameter("cannot hide sensitive data on empty object")
}
out := proto.Clone(d).(*MFADevice)
out := utils.CloneProtoMsg(d)
switch mfad := out.Device.(type) {
case *MFADevice_Totp:

View file

@ -20,10 +20,10 @@ import (
"strings"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/utils"
)
// ClusterNetworkingConfig defines cluster networking configuration. This is
@ -276,7 +276,7 @@ func (c *ClusterNetworkingConfigV2) SetProxyListenerMode(mode ProxyListenerMode)
// Clone performs a deep copy.
func (c *ClusterNetworkingConfigV2) Clone() ClusterNetworkingConfig {
return proto.Clone(c).(*ClusterNetworkingConfigV2)
return utils.CloneProtoMsg(c)
}
// setStaticFields sets static resource header and metadata fields.

View file

@ -20,7 +20,6 @@ import (
"fmt"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/constants"
@ -337,7 +336,7 @@ func (o *OktaAssignmentV1) SetFinalized(finalized bool) {
// Copy returns a copy of this Okta assignment resource.
func (o *OktaAssignmentV1) Copy() OktaAssignment {
return proto.Clone(o).(*OktaAssignmentV1)
return utils.CloneProtoMsg(o)
}
// String returns the Okta assignment rule string representation.

View file

@ -19,8 +19,9 @@ package types
import (
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/utils"
)
// PluginType represents the type of the plugin
@ -122,7 +123,7 @@ func (p *PluginV1) setStaticFields() {
// Clone returns a copy of the Plugin instance
func (p *PluginV1) Clone() Plugin {
return proto.Clone(p).(*PluginV1)
return utils.CloneProtoMsg(p)
}
// GetVersion returns resource version

View file

@ -20,8 +20,9 @@ import (
"fmt"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/utils"
)
// RemoteCluster represents a remote cluster that has connected via reverse tunnel
@ -116,7 +117,7 @@ func (c *RemoteClusterV3) SetLastHeartbeat(t time.Time) {
// Clone performs a deep copy.
func (c *RemoteClusterV3) Clone() RemoteCluster {
return proto.Clone(c).(*RemoteClusterV3)
return utils.CloneProtoMsg(c)
}
// GetConnectionStatus returns connection status

View file

@ -23,7 +23,6 @@ import (
"strings"
"time"
"github.com/gogo/protobuf/proto"
"github.com/google/uuid"
"github.com/gravitational/trace"
@ -469,7 +468,7 @@ func (s *ServerV2) MatchSearch(values []string) bool {
// DeepCopy creates a clone of this server value
func (s *ServerV2) DeepCopy() Server {
return proto.Clone(s).(*ServerV2)
return utils.CloneProtoMsg(s)
}
// IsAWSConsole returns true if this app is AWS management console.

View file

@ -21,8 +21,6 @@ import (
"time"
"github.com/gravitational/trace"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/protoadapt"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/utils"
@ -454,12 +452,7 @@ func (u *UserV2) ResetLocks() {
// DeepCopy creates a clone of this user value.
func (u *UserV2) DeepCopy() User {
// github.com/golang/protobuf/proto.Clone panics when trying to
// copy a map[K]V where the type of V is a slice of anything
// other than byte. See https://github.com/gogo/protobuf/issues/14
uV2 := protoadapt.MessageV2Of(u)
uV2Copy := proto.Clone(uV2)
return protoadapt.MessageV1Of(uV2Copy).(*UserV2)
return utils.CloneProtoMsg(u)
}
// IsEmpty returns true if there's no info about who created this user

37
api/utils/protobuf.go Normal file
View file

@ -0,0 +1,37 @@
/*
Copyright 2023 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/protoadapt"
)
// CloneProtoMsg returns a deep copy of msg. Modifying the returned
// protobuf message will not affect msg. If msg contains any empty
// slices, the returned copy will have nil slices instead.
func CloneProtoMsg[T protoadapt.MessageV1](msg T) T {
// github.com/golang/protobuf/proto.Clone panics when trying to
// copy a map[K]V where the type of V is a slice of anything
// other than byte. See https://github.com/gogo/protobuf/issues/14
msgV2 := protoadapt.MessageV2Of(msg)
msgV2 = proto.Clone(msgV2)
// this is safe as protoadapt.MessageV2Of will simply wrap the message
// with a type that implements the protobuf v2 API, and
// protoadapt.MessageV1Of will return the unwrapped message
return protoadapt.MessageV1Of(msgV2).(T)
}

View file

@ -0,0 +1,47 @@
/*
Copyright 2023 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import (
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/structpb"
)
func TestCloneProtoMsg(t *testing.T) {
m := map[string]any{
"4": 2.0,
"2": 4.0,
}
origMsg, err := structpb.NewStruct(m)
require.NoError(t, err)
msgCopy := CloneProtoMsg(origMsg)
require.Equal(t, origMsg, msgCopy)
require.IsType(t, origMsg, msgCopy)
// test that modifying the original doesn't affect the copy
delete(origMsg.Fields, "2")
require.Equal(t, m, msgCopy.AsMap())
// test cloning a nil message
var sm *structpb.Struct
smCopy := CloneProtoMsg(sm)
require.Equal(t, sm, smCopy)
require.IsType(t, sm, smCopy)
}

View file

@ -30,6 +30,7 @@ import (
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/wrappers"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
@ -373,17 +374,11 @@ func (s *Server) validateGenerationLabel(ctx context.Context, user types.User, c
return trace.BadParameter("explicitly requested generation %d is not equal to 1, this is a logic error", certReq.generation)
}
// Fetch a fresh copy of the user we can mutate safely. We can't
// implement a protobuf clone on User due to protobuf's proto.Clone()
// panicing when the user object has traits set, and a JSON
// marshal/unmarshal creates an import cycle so... here we are.
// There's a tiny chance the underlying user is mutated between calls
// to GetUser() but we're comparing with an older value so it'll fail
// safely.
newUser, err := s.Services.GetUser(user.GetName(), false)
if err != nil {
return trace.Wrap(err)
userV2, ok := user.(*types.UserV2)
if !ok {
return trace.BadParameter("unsupported version of user: %T", user)
}
newUser := apiutils.CloneProtoMsg(userV2)
metadata := newUser.GetMetadata()
metadata.Labels[types.BotGenerationLabel] = fmt.Sprint(certReq.generation)
newUser.SetMetadata(metadata)

View file

@ -37,9 +37,9 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
"google.golang.org/protobuf/proto"
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/auth/keystore/internal/faketime"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/jwt"
@ -120,7 +120,7 @@ func (f *fakeGCPKMSServer) CreateCryptoKey(ctx context.Context, req *kmspb.Creat
keyName := req.Parent + "/cryptoKeys/" + req.CryptoKeyId
keyVersionName := keyName + "/cryptoKeyVersions/1"
cryptoKey := proto.Clone(req.CryptoKey).(*kmspb.CryptoKey)
cryptoKey := apiutils.CloneProtoMsg(req.CryptoKey)
cryptoKey.Name = keyName
cryptoKeyVersion := &kmspb.CryptoKeyVersion{

View file

@ -1086,12 +1086,21 @@ func MetadataFromElastiCacheCluster(cluster *elasticache.ReplicationGroup, endpo
return nil, trace.Wrap(err)
}
// aws.StringValueSlice will return an empty slice is the input slice
// is empty, but when cloning protobuf messages a cloned empty slice
// will return nil. Keep this behavior so tests comparing cloned
// messages don't fail.
var userGroupIDs []string
if len(cluster.UserGroupIds) != 0 {
userGroupIDs = aws.StringValueSlice(cluster.UserGroupIds)
}
return &types.AWS{
Region: parsedARN.Region,
AccountID: parsedARN.AccountID,
ElastiCache: types.ElastiCache{
ReplicationGroupID: aws.StringValue(cluster.ReplicationGroupId),
UserGroupIDs: aws.StringValueSlice(cluster.UserGroupIds),
UserGroupIDs: userGroupIDs,
TransitEncryptionEnabled: aws.BoolValue(cluster.TransitEncryptionEnabled),
EndpointType: endpointType,
},

View file

@ -22,12 +22,12 @@ import (
"sync"
"time"
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/defaults"
@ -225,7 +225,7 @@ func (h *HeadlessAuthenticationWatcher) notify(headlessAuthns ...*types.Headless
for _, s := range h.subscribers {
if s != nil && s.name == ha.Metadata.Name {
select {
case s.updates <- proto.Clone(ha).(*types.HeadlessAuthentication):
case s.updates <- apiutils.CloneProtoMsg(ha):
default:
select {
case s.stale <- struct{}{}:

View file

@ -15,12 +15,12 @@
package loginrule
import (
"github.com/gogo/protobuf/proto"
"github.com/gravitational/trace"
loginrulepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/loginrule/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/wrappers"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/utils"
)
@ -105,7 +105,7 @@ func ProtoToResource(rule *loginrulepb.LoginRule) *Resource {
ResourceHeader: types.ResourceHeader{
Kind: types.KindLoginRule,
Version: rule.Version,
Metadata: *proto.Clone(rule.Metadata).(*types.Metadata),
Metadata: *apiutils.CloneProtoMsg(rule.Metadata),
},
Spec: spec{
Priority: rule.Priority,
@ -118,7 +118,7 @@ func ProtoToResource(rule *loginrulepb.LoginRule) *Resource {
func resourceToProto(r *Resource) *loginrulepb.LoginRule {
return &loginrulepb.LoginRule{
Metadata: proto.Clone(&r.Metadata).(*types.Metadata),
Metadata: apiutils.CloneProtoMsg(&r.Metadata),
Version: r.Version,
Priority: r.Spec.Priority,
TraitsMap: traitsMapResourceToProto(r.Spec.TraitsMap),