From 143b834e578740f4e4eef89c696d8e1e7c1b30be Mon Sep 17 00:00:00 2001 From: Roman Tkachenko Date: Tue, 21 Nov 2017 17:35:58 -0800 Subject: [PATCH] Changes for the upcoming teleport pro: * Allow external audit log plugins * Add support for auth API server plugins * Add license file path configuration parameter (not used in open-source) * Extend audit log with user login events --- Gopkg.lock | 18 +- Gopkg.toml | 8 + e | 2 +- lib/auth/apiserver.go | 9 +- lib/auth/apiserver_test.go | 1 + lib/auth/auth.go | 28 +- lib/auth/init.go | 20 +- lib/auth/init_test.go | 42 + lib/auth/oidc.go | 8 +- lib/auth/plugin.go | 44 + lib/auth/saml.go | 8 +- lib/config/configuration.go | 12 + lib/config/configuration_test.go | 39 + lib/config/fileconf.go | 9 +- lib/config/testdata_test.go | 14 +- lib/defaults/defaults.go | 5 +- lib/events/api.go | 18 +- lib/events/discard.go | 9 +- lib/service/cfg.go | 7 +- lib/service/service.go | 24 +- lib/services/clusterconfig.go | 25 +- lib/services/identity.go | 4 +- lib/services/local/configuration.go | 1 + lib/web/apiserver.go | 5 + lib/web/sessions.go | 2 +- tool/teleport/common/teleport.go | 32 +- tool/teleport/common/teleport_test.go | 25 +- tool/teleport/main.go | 5 +- .../golang/protobuf/jsonpb/jsonpb.go | 843 ++++++++++++++++++ .../golang/protobuf/jsonpb/jsonpb_test.go | 563 ++++++++++++ .../gravitational/license/.gitignore | 14 + .../github.com/gravitational/license/LICENSE | 201 +++++ .../github.com/gravitational/license/Makefile | 13 + .../gravitational/license/README.md | 2 + .../license/constants/constants.go | 98 ++ .../gravitational/license/license.go | 91 ++ .../github.com/gravitational/license/parse.go | 159 ++++ .../gravitational/reporting/.gitignore | 14 + .../gravitational/reporting/Dockerfile | 26 + .../gravitational/reporting/LICENSE | 201 +++++ .../gravitational/reporting/Makefile | 39 + .../gravitational/reporting/README.md | 2 + .../gravitational/reporting/api.pb.go | 554 ++++++++++++ .../gravitational/reporting/api.pb.gw.go | 110 +++ .../gravitational/reporting/api.proto | 45 + .../gravitational/reporting/client/client.go | 160 ++++ .../reporting/types/constants.go | 42 + .../gravitational/reporting/types/events.go | 320 +++++++ .../reporting/types/heartbeat.go | 134 +++ .../reporting/types/types_test.go | 61 ++ .../grpc-gateway/runtime/context.go | 139 +++ .../grpc-gateway/runtime/context_test.go | 169 ++++ .../grpc-gateway/runtime/convert.go | 58 ++ .../grpc-gateway/runtime/doc.go | 5 + .../grpc-gateway/runtime/errors.go | 121 +++ .../grpc-gateway/runtime/errors_test.go | 56 ++ .../grpc-gateway/runtime/handler.go | 164 ++++ .../runtime/internal/stream_chunk.pb.go | 65 ++ .../runtime/internal/stream_chunk.proto | 12 + .../grpc-gateway/runtime/marshal_json.go | 37 + .../grpc-gateway/runtime/marshal_json_test.go | 245 +++++ .../grpc-gateway/runtime/marshal_jsonpb.go | 182 ++++ .../runtime/marshal_jsonpb_test.go | 606 +++++++++++++ .../grpc-gateway/runtime/marshaler.go | 42 + .../runtime/marshaler_registry.go | 91 ++ .../runtime/marshaler_registry_test.go | 107 +++ .../grpc-gateway/runtime/mux.go | 132 +++ .../grpc-gateway/runtime/mux_test.go | 213 +++++ .../grpc-gateway/runtime/pattern.go | 227 +++++ .../grpc-gateway/runtime/pattern_test.go | 590 ++++++++++++ .../grpc-gateway/runtime/proto2_convert.go | 80 ++ .../grpc-gateway/runtime/query.go | 140 +++ .../grpc-gateway/runtime/query_test.go | 311 +++++++ .../grpc-gateway/utilities/doc.go | 2 + .../grpc-gateway/utilities/pattern.go | 22 + .../grpc-gateway/utilities/trie.go | 177 ++++ .../grpc-gateway/utilities/trie_test.go | 372 ++++++++ 77 files changed, 8421 insertions(+), 60 deletions(-) create mode 100644 lib/auth/plugin.go create mode 100644 vendor/github.com/golang/protobuf/jsonpb/jsonpb.go create mode 100644 vendor/github.com/golang/protobuf/jsonpb/jsonpb_test.go create mode 100644 vendor/github.com/gravitational/license/.gitignore create mode 100644 vendor/github.com/gravitational/license/LICENSE create mode 100644 vendor/github.com/gravitational/license/Makefile create mode 100644 vendor/github.com/gravitational/license/README.md create mode 100644 vendor/github.com/gravitational/license/constants/constants.go create mode 100644 vendor/github.com/gravitational/license/license.go create mode 100644 vendor/github.com/gravitational/license/parse.go create mode 100644 vendor/github.com/gravitational/reporting/.gitignore create mode 100644 vendor/github.com/gravitational/reporting/Dockerfile create mode 100644 vendor/github.com/gravitational/reporting/LICENSE create mode 100644 vendor/github.com/gravitational/reporting/Makefile create mode 100644 vendor/github.com/gravitational/reporting/README.md create mode 100644 vendor/github.com/gravitational/reporting/api.pb.go create mode 100644 vendor/github.com/gravitational/reporting/api.pb.gw.go create mode 100644 vendor/github.com/gravitational/reporting/api.proto create mode 100644 vendor/github.com/gravitational/reporting/client/client.go create mode 100644 vendor/github.com/gravitational/reporting/types/constants.go create mode 100644 vendor/github.com/gravitational/reporting/types/events.go create mode 100644 vendor/github.com/gravitational/reporting/types/heartbeat.go create mode 100644 vendor/github.com/gravitational/reporting/types/types_test.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/context.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/context_test.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/convert.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/doc.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/errors.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/errors_test.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/handler.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/internal/stream_chunk.pb.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/internal/stream_chunk.proto create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_json.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_json_test.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_jsonpb.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_jsonpb_test.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshaler.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshaler_registry.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshaler_registry_test.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/mux.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/mux_test.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/pattern.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/pattern_test.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/proto2_convert.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/query.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/query_test.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/doc.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/pattern.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/trie.go create mode 100644 vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/trie_test.go diff --git a/Gopkg.lock b/Gopkg.lock index 91469a0130d..4b447daf878 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -117,7 +117,7 @@ [[projects]] name = "github.com/golang/protobuf" - packages = ["proto","protoc-gen-go/descriptor","ptypes/any","ptypes/empty"] + packages = ["jsonpb","proto","protoc-gen-go/descriptor","ptypes/any","ptypes/empty"] revision = "8ee79997227bf9b34611aee7946ae64735e6fd93" [[projects]] @@ -143,6 +143,18 @@ packages = ["."] revision = "52bc17adf63c0807b5e5b5d91350703630f621c7" +[[projects]] + name = "github.com/gravitational/license" + packages = [".","constants"] + revision = "102213511ace56c97ccf1eef645835e16f84d130" + version = "0.0.4" + +[[projects]] + name = "github.com/gravitational/reporting" + packages = [".","client","types"] + revision = "3c4a4e96fb5896e14fe29da7fcce14b8d93f3965" + version = "0.0.4" + [[projects]] branch = "master" name = "github.com/gravitational/roundtrip" @@ -163,7 +175,7 @@ [[projects]] name = "github.com/grpc-ecosystem/grpc-gateway" - packages = ["third_party/googleapis/google/api"] + packages = ["runtime","runtime/internal","third_party/googleapis/google/api","utilities"] revision = "a8f25bd1ab549f8b87afd48aa9181221e9d439bb" version = "v1.1.0" @@ -393,6 +405,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "c8734d8ce5c599785cf35e78aca518adc1fdd6ff31c139d65777429211331ba5" + inputs-digest = "97a726bb183a88f10d511f1853d157e006b3a6a3a8d2887dddc0aa8f92506175" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index ceda35c5c09..3f2cb24b5f2 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -143,3 +143,11 @@ ignored = ["github.com/Sirupsen/logrus"] [[constraint]] name = "github.com/russellhaering/gosaml2" revision = "8908227c114abe0b63b1f0606abae72d11bf632a" + +[[constraint]] + name = "github.com/gravitational/reporting" + version = "0.0.4" + +[[constraint]] + name = "github.com/gravitational/license" + version = "0.0.4" diff --git a/e b/e index 4f3e4ebd667..96a9523e7e7 160000 --- a/e +++ b/e @@ -1 +1 @@ -Subproject commit 4f3e4ebd66716cd256abc2847a8e80addc85e4a3 +Subproject commit 96a9523e7e7d8937bf738a1b299642f27447b3ef diff --git a/lib/auth/apiserver.go b/lib/auth/apiserver.go index 2ca22f429a0..e58f0549cb6 100644 --- a/lib/auth/apiserver.go +++ b/lib/auth/apiserver.go @@ -34,10 +34,9 @@ import ( "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/julienschmidt/httprouter" log "github.com/sirupsen/logrus" - - "github.com/jonboulle/clockwork" "github.com/tstranex/u2f" ) @@ -193,6 +192,10 @@ func NewAPIServer(config *APIConfig) http.Handler { srv.GET("/:version/events", srv.withAuth(srv.searchEvents)) srv.GET("/:version/events/session", srv.withAuth(srv.searchSessionEvents)) + if plugin := GetPlugin(); plugin != nil { + plugin.AddHandlers(&srv) + } + return httplib.RewritePaths(&srv.Router, httplib.Rewrite("/v1/nodes", "/v1/namespaces/default/nodes"), httplib.Rewrite("/v1/sessions", "/v1/namespaces/default/sessions"), @@ -223,7 +226,7 @@ func (s *APIServer) withAuth(handler HandlerWithAuthFunc) httprouter.Handle { user: authContext.User, checker: authContext.Checker, sessions: s.SessionService, - alog: s.AuditLog, + alog: s.AuthServer.IAuditLog, } version := p.ByName("version") if version == "" { diff --git a/lib/auth/apiserver_test.go b/lib/auth/apiserver_test.go index 5b79706d0bf..8925cc91346 100644 --- a/lib/auth/apiserver_test.go +++ b/lib/auth/apiserver_test.go @@ -79,6 +79,7 @@ func (s *APISuite) SetUpTest(c *C) { s.a = NewAuthServer(&InitConfig{ Backend: s.bk, Authority: authority.New(), + AuditLog: s.alog, }) s.sessions, err = session.New(s.bk) c.Assert(err, IsNil) diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 7f5e5db2e69..7c1cafc9b8b 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -33,12 +33,13 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/trace" "github.com/coreos/go-oidc/oidc" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" saml2 "github.com/russellhaering/gosaml2" log "github.com/sirupsen/logrus" @@ -87,6 +88,9 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) *AuthServer { if cfg.ClusterConfiguration == nil { cfg.ClusterConfiguration = local.NewClusterConfigurationService(cfg.Backend) } + if cfg.AuditLog == nil { + cfg.AuditLog = events.NewDiscardAuditLog() + } closeCtx, cancelFunc := context.WithCancel(context.TODO()) as := AuthServer{ Entry: log.WithFields(log.Fields{ @@ -101,6 +105,7 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) *AuthServer { Access: cfg.Access, AuthServiceName: cfg.AuthServiceName, ClusterConfiguration: cfg.ClusterConfiguration, + IAuditLog: cfg.AuditLog, oidcClients: make(map[string]*oidcClient), samlProviders: make(map[string]*samlProvider), cancelFunc: cancelFunc, @@ -157,6 +162,7 @@ type AuthServer struct { services.Identity services.Access services.ClusterConfiguration + events.IAuditLog } func (a *AuthServer) Close() error { @@ -172,6 +178,11 @@ func (a *AuthServer) SetClock(clock clockwork.Clock) { a.clock = clock } +// SetAuditLog sets the server's audit log +func (a *AuthServer) SetAuditLog(auditLog events.IAuditLog) { + a.IAuditLog = auditLog +} + // GetDomainName returns the domain name that identifies this authority server. // Also known as "cluster name" func (a *AuthServer) GetDomainName() (string, error) { @@ -225,7 +236,6 @@ func (s *AuthServer) GenerateUserCert(key []byte, user services.User, allowedLog if err != nil { return nil, trace.Wrap(err) } - ca, err := s.Trust.GetCertAuthority(services.CertAuthID{ Type: services.UserCA, DomainName: domainName, @@ -237,7 +247,7 @@ func (s *AuthServer) GenerateUserCert(key []byte, user services.User, allowedLog if err != nil { return nil, trace.Wrap(err) } - return s.Authority.GenerateUserCert(services.UserCertParams{ + cert, err := s.Authority.GenerateUserCert(services.UserCertParams{ PrivateCASigningKey: privateKey, PublicUserKey: key, Username: user.GetName(), @@ -247,6 +257,14 @@ func (s *AuthServer) GenerateUserCert(key []byte, user services.User, allowedLog Compatibility: compatibility, PermitAgentForwarding: canForwardAgents, }) + if err != nil { + return nil, trace.Wrap(err) + } + s.EmitAuditEvent(events.UserLoginEvent, events.EventFields{ + events.EventUser: user.GetName(), + events.LoginMethod: events.LoginMethodLocal, + }) + return cert, nil } // WithUserLock executes function authenticateFn that performs user authentication @@ -315,6 +333,10 @@ func (s *AuthServer) SignIn(user string, password []byte) (services.WebSession, if err != nil { return nil, trace.Wrap(err) } + s.EmitAuditEvent(events.UserLoginEvent, events.EventFields{ + events.EventUser: user, + events.LoginMethod: events.LoginMethodLocal, + }) return s.PreAuthenticatedSignIn(user) } diff --git a/lib/auth/init.go b/lib/auth/init.go index 5ed56a4956f..0d4ef5489a7 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -28,10 +28,12 @@ import ( "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" + "github.com/pborman/uuid" log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) @@ -105,10 +107,13 @@ type InitConfig struct { // ClusterConfig holds cluster level configuration. ClusterConfig services.ClusterConfig + + // AuditLog is used for emitting events to audit log + AuditLog events.IAuditLog } // Init instantiates and configures an instance of AuthServer -func Init(cfg InitConfig) (*AuthServer, *Identity, error) { +func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, error) { if cfg.DataDir == "" { return nil, nil, trace.BadParameter("DataDir: data dir can not be empty") } @@ -124,7 +129,7 @@ func Init(cfg InitConfig) (*AuthServer, *Identity, error) { defer cfg.Backend.ReleaseLock(domainName) // check that user CA and host CA are present and set the certs if needed - asrv := NewAuthServer(&cfg) + asrv := NewAuthServer(&cfg, opts...) // INTERNAL: Authorities (plus Roles) and ReverseTunnels don't follow the // same pattern as the rest of the configuration (they are not configuration @@ -155,6 +160,17 @@ func Init(cfg InitConfig) (*AuthServer, *Identity, error) { } // set cluster level config on the backend and then force a sync of the cache. + clusterConfig, err := asrv.GetClusterConfig() + if err != nil && !trace.IsNotFound(err) { + return nil, nil, trace.Wrap(err) + } + // init a unique cluster ID, it must be set once only during the first + // start so if it's already there, reuse it + if clusterConfig != nil && clusterConfig.GetClusterID() != "" { + cfg.ClusterConfig.SetClusterID(clusterConfig.GetClusterID()) + } else { + cfg.ClusterConfig.SetClusterID(uuid.New()) + } err = asrv.SetClusterConfig(cfg.ClusterConfig) if err != nil { return nil, nil, trace.Wrap(err) diff --git a/lib/auth/init_test.go b/lib/auth/init_test.go index 03bff27d357..98f3bfbaa76 100644 --- a/lib/auth/init_test.go +++ b/lib/auth/init_test.go @@ -214,3 +214,45 @@ func (s *AuthInitSuite) TestAuthPreference(c *C) { c.Assert(u.AppID, Equals, "foo") c.Assert(u.Facets, DeepEquals, []string{"bar", "baz"}) } + +func (s *AuthInitSuite) TestClusterID(c *C) { + bk, err := boltbk.New(backend.Params{"path": c.MkDir()}) + c.Assert(err, IsNil) + + clusterName, err := services.NewClusterName(services.ClusterNameSpecV2{ + ClusterName: "me.localhost", + }) + c.Assert(err, IsNil) + + authServer, _, err := Init(InitConfig{ + DataDir: c.MkDir(), + HostUUID: "00000000-0000-0000-0000-000000000000", + NodeName: "foo", + Backend: bk, + Authority: testauthority.New(), + ClusterName: clusterName, + ClusterConfig: services.DefaultClusterConfig(), + }) + c.Assert(err, IsNil) + + cc, err := authServer.GetClusterConfig() + c.Assert(err, IsNil) + clusterID := cc.GetClusterID() + c.Assert(clusterID, Not(Equals), "") + + // do it again and make sure cluster ID hasn't changed + authServer, _, err = Init(InitConfig{ + DataDir: c.MkDir(), + HostUUID: "00000000-0000-0000-0000-000000000000", + NodeName: "foo", + Backend: bk, + Authority: testauthority.New(), + ClusterName: clusterName, + ClusterConfig: services.DefaultClusterConfig(), + }) + c.Assert(err, IsNil) + + cc, err = authServer.GetClusterConfig() + c.Assert(err, IsNil) + c.Assert(cc.GetClusterID(), Equals, clusterID) +} diff --git a/lib/auth/oidc.go b/lib/auth/oidc.go index 4d5b122915b..7309edccb35 100644 --- a/lib/auth/oidc.go +++ b/lib/auth/oidc.go @@ -9,13 +9,14 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/trace" "github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/oauth2" "github.com/coreos/go-oidc/oidc" + "github.com/gravitational/trace" log "github.com/sirupsen/logrus" ) @@ -231,7 +232,10 @@ func (a *AuthServer) ValidateOIDCAuthCallback(q url.Values) (*OIDCAuthResponse, response.HostSigners = append(response.HostSigners, authority) } } - + a.EmitAuditEvent(events.UserLoginEvent, events.EventFields{ + events.EventUser: user.GetName(), + events.LoginMethod: events.LoginMethodOIDC, + }) return response, nil } diff --git a/lib/auth/plugin.go b/lib/auth/plugin.go new file mode 100644 index 00000000000..b2d460dda5c --- /dev/null +++ b/lib/auth/plugin.go @@ -0,0 +1,44 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package auth + +import ( + "sync" +) + +var pluginMutex = &sync.Mutex{} +var plugin Plugin + +// GetPlugin returns auth API server plugin that allows injecting handlers +func GetPlugin() Plugin { + pluginMutex.Lock() + defer pluginMutex.Unlock() + return plugin +} + +// SetPlugin sets plugin for the auth API server +func SetPlugin(p Plugin) { + pluginMutex.Lock() + defer pluginMutex.Unlock() + plugin = p +} + +// Plugin is auth API server extension setter +type Plugin interface { + // AddHandlers adds handlers to the auth API server + AddHandlers(srv *APIServer) +} diff --git a/lib/auth/saml.go b/lib/auth/saml.go index 0b08cfcb5bf..61296415c4d 100644 --- a/lib/auth/saml.go +++ b/lib/auth/saml.go @@ -9,11 +9,12 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/trace" "github.com/beevik/etree" + "github.com/gravitational/trace" saml2 "github.com/russellhaering/gosaml2" log "github.com/sirupsen/logrus" ) @@ -353,6 +354,9 @@ func (a *AuthServer) ValidateSAMLResponse(samlResponse string) (*SAMLAuthRespons response.HostSigners = append(response.HostSigners, authority) } } - + a.EmitAuditEvent(events.UserLoginEvent, events.EventFields{ + events.EventUser: user.GetName(), + events.LoginMethod: events.LoginMethodSAML, + }) return response, nil } diff --git a/lib/config/configuration.go b/lib/config/configuration.go index 0e2d5b5909b..d2abc8fcd26 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -26,6 +26,7 @@ import ( "net" "net/url" "os" + "path/filepath" "strings" "time" "unicode" @@ -404,6 +405,7 @@ func ApplyFileConfig(fc *FileConfig, cfg *service.Config) error { log.Warnf(warningMessage) } } + // read in and set session recording clusterConfig, err := fc.Auth.SessionRecording.Parse() if err != nil { @@ -411,6 +413,16 @@ func ApplyFileConfig(fc *FileConfig, cfg *service.Config) error { } cfg.Auth.ClusterConfig = clusterConfig + // read in and set the license file path (not used in open-source version) + licenseFile := fc.Auth.LicenseFile + if licenseFile != "" { + if filepath.IsAbs(licenseFile) { + cfg.Auth.LicenseFile = licenseFile + } else { + cfg.Auth.LicenseFile = filepath.Join(cfg.DataDir, licenseFile) + } + } + // apply "ssh_service" section if fc.SSH.ListenAddress != "" { addr, err := utils.ParseHostPortAddr(fc.SSH.ListenAddress, int(defaults.SSHServerListenPort)) diff --git a/lib/config/configuration_test.go b/lib/config/configuration_test.go index 571a6daa90d..9d3e2ec916b 100644 --- a/lib/config/configuration_test.go +++ b/lib/config/configuration_test.go @@ -146,6 +146,7 @@ func (s *ConfigTestSuite) TestConfigReading(c *check.C) { c.Assert(conf.DataDir, check.Equals, "/path/to/data") c.Assert(conf.Auth.Enabled(), check.Equals, true) c.Assert(conf.Auth.ListenAddress, check.Equals, "tcp://auth") + c.Assert(conf.Auth.LicenseFile, check.Equals, "lic.pem") c.Assert(conf.SSH.Configured(), check.Equals, true) c.Assert(conf.SSH.Enabled(), check.Equals, true) c.Assert(conf.SSH.ListenAddress, check.Equals, "tcp://ssh") @@ -508,6 +509,7 @@ func makeConfigFixture() string { // auth service: conf.Auth.EnabledFlag = "Yeah" conf.Auth.ListenAddress = "tcp://auth" + conf.Auth.LicenseFile = "lic.pem" // ssh service: conf.SSH.EnabledFlag = "true" @@ -572,3 +574,40 @@ ssh_service: c.Assert(cfg.SSH.PermitUserEnvironment, check.Equals, tt.outPermitUserEnvironment, comment) } } + +func (s *ConfigTestSuite) TestLicenseFile(c *check.C) { + testCases := []struct { + path string + result string + }{ + // 0 - no license + { + path: "", + result: filepath.Join(defaults.DataDir, defaults.LicenseFile), + }, + // 1 - relative path + { + path: "lic.pem", + result: filepath.Join(defaults.DataDir, "lic.pem"), + }, + // 2 - absolute path + { + path: "/etc/teleport/license", + result: "/etc/teleport/license", + }, + } + + cfg := service.MakeDefaultConfig() + c.Assert(cfg.Auth.LicenseFile, check.Equals, + filepath.Join(defaults.DataDir, defaults.LicenseFile)) + + for _, tc := range testCases { + err := ApplyFileConfig(&FileConfig{ + Auth: Auth{ + LicenseFile: tc.path, + }, + }, cfg) + c.Assert(err, check.IsNil) + c.Assert(cfg.Auth.LicenseFile, check.Equals, tc.result) + } +} diff --git a/lib/config/fileconf.go b/lib/config/fileconf.go index 7e94a69977c..c6c6bf3b779 100644 --- a/lib/config/fileconf.go +++ b/lib/config/fileconf.go @@ -44,8 +44,8 @@ import ( var ( // all possible valid YAML config keys - // true = non-scalar - // false = scalar + // true = has sub-keys + // false = does not have sub-keys (a leaf) validKeys = map[string]bool{ "namespace": true, "cluster_name": true, @@ -131,6 +131,7 @@ var ( "session_recording": false, "read_capacity_units": false, "write_capacity_units": false, + "license_file": false, } ) @@ -496,6 +497,10 @@ type Auth struct { // it here overrides defaults. // Deprecated: Remove in Teleport 2.4.1. DynamicConfig *bool `yaml:"dynamic_config,omitempty"` + + // LicenseFile is a path to the license file. The path can be either absolute or + // relative to the global data dir + LicenseFile string `yaml:"license_file,omitempty"` } // TrustedCluster struct holds configuration values under "trusted_clusters" key diff --git a/lib/config/testdata_test.go b/lib/config/testdata_test.go index eaf7f6c70e7..b4fc2c6057a 100644 --- a/lib/config/testdata_test.go +++ b/lib/config/testdata_test.go @@ -46,7 +46,7 @@ teleport: - period: 10m10s average: 170 burst: 171 - keys: + keys: - cert: node.cert private_key: !!binary cHJpdmF0ZSBrZXk= - cert_file: /proxy.cert.file @@ -61,21 +61,21 @@ auth_service: tokens: - "proxy,node:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" - "auth:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - authorities: + authorities: - type: host domain_name: example.com - checking_keys: + checking_keys: - checking key 1 checking_key_files: - /ca.checking.key - signing_keys: + signing_keys: - !!binary c2lnbmluZyBrZXkgMQ== signing_key_files: - /ca.signing.key reverse_tunnels: - - domain_name: tunnel.example.com + - domain_name: tunnel.example.com addresses: ["com-1", "com-2"] - - domain_name: tunnel.example.org + - domain_name: tunnel.example.org addresses: ["org-1"] ssh_service: @@ -135,7 +135,7 @@ proxy_service: // need to support it until it's fully removed. const LegacyAuthenticationSection = ` auth_service: - oidc_connectors: + oidc_connectors: - id: google redirect_url: https://localhost:3080/v1/webapi/oidc/callback client_id: id-from-google.apps.googleusercontent.com diff --git a/lib/defaults/defaults.go b/lib/defaults/defaults.go index 180d51a3d5f..cbc54044254 100644 --- a/lib/defaults/defaults.go +++ b/lib/defaults/defaults.go @@ -235,7 +235,7 @@ var ( // ConfigFilePath is default path to teleport config file ConfigFilePath = "/etc/teleport.yaml" - // DataDir is where all mutable data is stored (user keys, recorded sessions, + // DataDir is where all mutable data is stored (user keys, recorded sessions, // registered SSH servers, etc): DataDir = "/var/lib/teleport" @@ -247,6 +247,9 @@ var ( // ConfigEnvar is a name of teleport's configuration environment variable ConfigEnvar = "TELEPORT_CONFIG" + + // LicenseFile is the default name of the license file + LicenseFile = "license.pem" ) const ( diff --git a/lib/events/api.go b/lib/events/api.go index c6bea981cf9..3c307eb7c08 100644 --- a/lib/events/api.go +++ b/lib/events/api.go @@ -63,10 +63,22 @@ const ( // the beginning SessionByteOffset = "offset" - // Join & Leave events indicate when someone joins/leaves a session - SessionJoinEvent = "session.join" + // SessionJoinEvent indicates that someone joined a session + SessionJoinEvent = "session.join" + // SessionLeaveEvent indicates that someone left a session SessionLeaveEvent = "session.leave" + // UserLoginEvent indicates that a user logged into web UI or via tsh + UserLoginEvent = "user.login" + // LoginMethod is the event field indicating how the login was performed + LoginMethod = "method" + // LoginMethodLocal represents login with username/password + LoginMethodLocal = "local" + // LoginMethodOIDC represents login with OIDC + LoginMethodOIDC = "oidc" + // LoginMethodSAML represents login with SAML + LoginMethodSAML = "saml" + // ExecEvent is an exec command executed by script or user on // the server side ExecEvent = "exec" @@ -101,7 +113,7 @@ const ( MaxChunkBytes = 1024 * 1024 * 5 ) -// IAuditLog is the primary (and the only external-facing) interface for AUditLogger. +// IAuditLog is the primary (and the only external-facing) interface for AuditLogger. // If you wish to implement a different kind of logger (not filesystem-based), you // have to implement this interface type IAuditLog interface { diff --git a/lib/events/discard.go b/lib/events/discard.go index fe08a3dce4b..66cefcae380 100644 --- a/lib/events/discard.go +++ b/lib/events/discard.go @@ -10,7 +10,11 @@ import ( // DiscardAuditLog is do-nothing, discard-everything implementation // of IAuditLog interface used for cases when audit is turned off -type DiscardAuditLog struct { +type DiscardAuditLog struct{} + +// NewDiscardAuditLog returns a no-op audit log instance +func NewDiscardAuditLog() *DiscardAuditLog { + return &DiscardAuditLog{} } func (d *DiscardAuditLog) WaitForDelivery(context.Context) error { @@ -46,8 +50,7 @@ func (d *DiscardAuditLog) SearchSessionEvents(fromUTC time.Time, toUTC time.Time // discardSessionLogger implements a session logger that does nothing. It // discards all events and chunks written to it. It is used when session // recording has been disabled. -type discardSessionLogger struct { -} +type discardSessionLogger struct{} func (d *discardSessionLogger) LogEvent(fields EventFields) { return diff --git a/lib/service/cfg.go b/lib/service/cfg.go index b034c6005a5..2b0892a13b8 100644 --- a/lib/service/cfg.go +++ b/lib/service/cfg.go @@ -21,6 +21,7 @@ import ( "io" "net" "os" + "path/filepath" "time" "golang.org/x/crypto/ssh" @@ -70,7 +71,7 @@ type Config struct { // SSH role an SSH endpoint server SSH SSHConfig - // Auth server authentication and authorizatin server config + // Auth server authentication and authorization server config Auth AuthConfig // Keygen points to a key generator implementation @@ -266,6 +267,9 @@ type AuthConfig struct { // Preference defines the authentication preference (type and second factor) for // the auth server. Preference services.AuthPreference + + // LicenseFile is a full path to the license file + LicenseFile string } // SSHConfig configures SSH server node role @@ -320,6 +324,7 @@ func ApplyDefaults(cfg *Config) { ap := &services.AuthPreferenceV2{} ap.CheckAndSetDefaults() cfg.Auth.Preference = ap + cfg.Auth.LicenseFile = filepath.Join(cfg.DataDir, defaults.LicenseFile) // defaults for the SSH proxy service: cfg.Proxy.Enabled = true diff --git a/lib/service/service.go b/lib/service/service.go index a2a8ffc3b4b..7c6527fb86a 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -103,15 +103,30 @@ type TeleportProcess struct { // localAuth has local auth server listed in case if this process // has started with auth server role enabled localAuth *auth.AuthServer + // backend is the process' backend + backend backend.Backend + // auditLog is the initialized audit log + auditLog events.IAuditLog // identities of this process (credentials to auth sever, basically) Identities map[teleport.Role]*auth.Identity } +// GetAuthServer returns the process' auth server func (process *TeleportProcess) GetAuthServer() *auth.AuthServer { return process.localAuth } +// GetAuditLog returns the process' audit log +func (process *TeleportProcess) GetAuditLog() events.IAuditLog { + return process.auditLog +} + +// GetBackend returns the process' backend +func (process *TeleportProcess) GetBackend() backend.Backend { + return process.backend +} + func (process *TeleportProcess) findStaticIdentity(id auth.IdentityID) (*auth.Identity, error) { for i := range process.Config.Identities { identity := process.Config.Identities[i] @@ -303,13 +318,13 @@ func (process *TeleportProcess) initAuthService(authority auth.Authority) error if err != nil { return trace.Wrap(err) } + process.backend = b // create the audit log, which will be consuming (and recording) all events // and recording all sessions. - var auditLog events.IAuditLog if cfg.Auth.NoAudit { // this is for teleconsole - auditLog = &events.DiscardAuditLog{} + process.auditLog = events.NewDiscardAuditLog() warningMessage := "Warning: Teleport audit and session recording have been " + "turned off. This is dangerous, you will not be able to view audit events " + @@ -345,7 +360,7 @@ func (process *TeleportProcess) initAuthService(authority auth.Authority) error auditConfig.GID = &gid } } - auditLog, err = events.NewAuditLog(auditConfig) + process.auditLog, err = events.NewAuditLog(auditConfig) if err != nil { return trace.Wrap(err) } @@ -372,6 +387,7 @@ func (process *TeleportProcess) initAuthService(authority auth.Authority) error Roles: cfg.Auth.Roles, AuthPreference: cfg.Auth.Preference, OIDCConnectors: cfg.OIDCConnectors, + AuditLog: process.auditLog, }) if err != nil { return trace.Wrap(err) @@ -393,7 +409,7 @@ func (process *TeleportProcess) initAuthService(authority auth.Authority) error AuthServer: authServer, SessionService: sessionService, Authorizer: authorizer, - AuditLog: auditLog, + AuditLog: process.auditLog, } limiter, err := limiter.NewLimiter(cfg.Auth.Limiter) diff --git a/lib/services/clusterconfig.go b/lib/services/clusterconfig.go index 36f9b6bfe6f..41740c122b0 100644 --- a/lib/services/clusterconfig.go +++ b/lib/services/clusterconfig.go @@ -41,6 +41,12 @@ type ClusterConfig interface { // SetSessionRecording sets where the session is recorded. SetSessionRecording(RecordingType) + // GetClusterID returns the unique cluster ID + GetClusterID() string + + // SetClusterID sets the cluster ID + SetClusterID(string) + // CheckAndSetDefaults checks and set default values for missing fields. CheckAndSetDefaults() error @@ -115,6 +121,9 @@ const ( type ClusterConfigSpecV3 struct { // SessionRecording controls where (or if) the session is recorded. SessionRecording RecordingType `json:"session_recording"` + // ClusterID is the unique cluster ID that is set once during the first auth + // server startup. + ClusterID string `json:"cluster_id"` } // GetName returns the name of the cluster. @@ -157,6 +166,16 @@ func (c *ClusterConfigV3) SetSessionRecording(s RecordingType) { c.Spec.SessionRecording = s } +// GetClusterID returns the unique cluster ID +func (c *ClusterConfigV3) GetClusterID() string { + return c.Spec.ClusterID +} + +// SetClusterID sets the cluster ID +func (c *ClusterConfigV3) SetClusterID(id string) { + c.Spec.ClusterID = id +} + // CheckAndSetDefaults checks validity of all parameters and sets defaults. func (c *ClusterConfigV3) CheckAndSetDefaults() error { // make sure we have defaults for all metadata fields @@ -187,7 +206,8 @@ func (c *ClusterConfigV3) Copy() ClusterConfig { // String represents a human readable version of the cluster name. func (c *ClusterConfigV3) String() string { - return fmt.Sprintf("ClusterConfig(SessionRecording=%v)", c.Spec.SessionRecording) + return fmt.Sprintf("ClusterConfig(SessionRecording=%v, ClusterID=%v)", + c.Spec.SessionRecording, c.Spec.ClusterID) } // ClusterConfigSpecSchemaTemplate is a template for ClusterConfig schema. @@ -197,6 +217,9 @@ const ClusterConfigSpecSchemaTemplate = `{ "properties": { "session_recording": { "type": "string" + }, + "cluster_id": { + "type": "string" }%v } }` diff --git a/lib/services/identity.go b/lib/services/identity.go index e921840fc75..34e6ecc3785 100644 --- a/lib/services/identity.go +++ b/lib/services/identity.go @@ -228,8 +228,8 @@ const ExternalIdentitySchema = `{ "type": "object", "additionalProperties": false, "properties": { - "connector_id": {"type": "string"}, - "username": {"type": "string"} + "connector_id": {"type": "string"}, + "username": {"type": "string"} } }` diff --git a/lib/services/local/configuration.go b/lib/services/local/configuration.go index 5e1d28fa556..334e8ad4394 100644 --- a/lib/services/local/configuration.go +++ b/lib/services/local/configuration.go @@ -19,6 +19,7 @@ package local import ( "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/trace" ) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 9de6e5d2934..f58ad7ea9de 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -311,6 +311,11 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*RewritingHandler, error) { }, nil } +// GetProxyClient returns authenticated auth server client +func (h *Handler) GetProxyClient() auth.ClientI { + return h.cfg.ProxyClient +} + // Close closes associated session cache operations func (h *Handler) Close() error { return h.auth.Close() diff --git a/lib/web/sessions.go b/lib/web/sessions.go index b96e02b74f7..ecd22850b6f 100644 --- a/lib/web/sessions.go +++ b/lib/web/sessions.go @@ -138,7 +138,7 @@ func (c *SessionContext) getRemoteClient(siteName string) (auth.ClientI, bool) { } // GetClient returns the client connected to the auth server -func (c *SessionContext) GetClient() (auth.ClientI, error) { +func (c *SessionContext) GetClient() (*auth.TunClient, error) { return c.clt, nil } diff --git a/tool/teleport/common/teleport.go b/tool/teleport/common/teleport.go index 83b3a140c49..0b6a1988efa 100644 --- a/tool/teleport/common/teleport.go +++ b/tool/teleport/common/teleport.go @@ -39,10 +39,17 @@ import ( log "github.com/sirupsen/logrus" ) -// same as main() but has a testing switch -// - cmdlineArgs are passed from main() -// - testRun is 'true' when running under an integration test -func Run(cmdlineArgs []string, testRun bool) (executedCommand string, conf *service.Config) { +// Options combines init/start teleport options +type Options struct { + // Args is a list of command-line args passed from main() + Args []string + // InitOnly when set to true, initializes config and aux + // endpoints but does not start the process + InitOnly bool +} + +// Run inits/starts the process according to the provided options +func Run(options Options) (executedCommand string, conf *service.Config) { var err error // configure trace's errors to produce full stack traces @@ -130,7 +137,7 @@ func Run(cmdlineArgs []string, testRun bool) (executedCommand string, conf *serv scpc.Arg("target", "").StringsVar(&scpCommand.Target) // parse CLI commands+flags: - command, err := app.Parse(cmdlineArgs) + command, err := app.Parse(options.Args) if err != nil { utils.FatalError(err) } @@ -145,7 +152,7 @@ func Run(cmdlineArgs []string, testRun bool) (executedCommand string, conf *serv if err = config.Configure(&ccf, conf); err != nil { utils.FatalError(err) } - if !testRun { + if !options.InitOnly { log.Debug(conf.DebugDumpToYAML()) } if ccf.HTTPProfileEndpoint { @@ -172,8 +179,8 @@ func Run(cmdlineArgs []string, testRun bool) (executedCommand string, conf *serv } }() } - if !testRun { - err = onStart(conf) + if !options.InitOnly { + err = OnStart(conf) } case scpc.FullCommand(): err = onSCP(&scpCommand) @@ -191,12 +198,13 @@ func Run(cmdlineArgs []string, testRun bool) (executedCommand string, conf *serv return command, conf } -// onStart is the handler for "start" CLI command -func onStart(config *service.Config) error { +// OnStart is the handler for "start" CLI command +func OnStart(config *service.Config) error { srv, err := service.NewTeleport(config) if err != nil { return trace.Wrap(err, "initializing teleport") } + if err := srv.Start(); err != nil { return trace.Wrap(err, "starting teleport") } @@ -210,8 +218,8 @@ func onStart(config *service.Config) error { fmt.Fprintf(f, "%v", os.Getpid()) defer f.Close() } - srv.Wait() - return nil + + return trace.Wrap(srv.Wait()) } // onStatus is the handler for "status" CLI command diff --git a/tool/teleport/common/teleport_test.go b/tool/teleport/common/teleport_test.go index db9a30acf4e..621e7eee208 100644 --- a/tool/teleport/common/teleport_test.go +++ b/tool/teleport/common/teleport_test.go @@ -66,7 +66,10 @@ func (s *MainTestSuite) SetUpSuite(c *check.C) { } func (s *MainTestSuite) TestDefault(c *check.C) { - cmd, conf := Run([]string{"start"}, true) + cmd, conf := Run(Options{ + Args: []string{"start"}, + InitOnly: true, + }) c.Assert(cmd, check.Equals, "start") c.Assert(conf.Hostname, check.Equals, s.hostname) c.Assert(conf.DataDir, check.Equals, "/tmp/teleport/var/lib/teleport") @@ -78,17 +81,26 @@ func (s *MainTestSuite) TestDefault(c *check.C) { } func (s *MainTestSuite) TestRolesFlag(c *check.C) { - cmd, conf := Run([]string{"start", "--roles=node"}, true) + cmd, conf := Run(Options{ + Args: []string{"start", "--roles=node"}, + InitOnly: true, + }) c.Assert(conf.SSH.Enabled, check.Equals, true) c.Assert(conf.Auth.Enabled, check.Equals, false) c.Assert(conf.Proxy.Enabled, check.Equals, false) - cmd, conf = Run([]string{"start", "--roles=proxy"}, true) + cmd, conf = Run(Options{ + Args: []string{"start", "--roles=proxy"}, + InitOnly: true, + }) c.Assert(conf.SSH.Enabled, check.Equals, false) c.Assert(conf.Auth.Enabled, check.Equals, false) c.Assert(conf.Proxy.Enabled, check.Equals, true) - cmd, conf = Run([]string{"start", "--roles=auth"}, true) + cmd, conf = Run(Options{ + Args: []string{"start", "--roles=auth"}, + InitOnly: true, + }) c.Assert(conf.SSH.Enabled, check.Equals, false) c.Assert(conf.Auth.Enabled, check.Equals, true) c.Assert(conf.Proxy.Enabled, check.Equals, false) @@ -96,7 +108,10 @@ func (s *MainTestSuite) TestRolesFlag(c *check.C) { } func (s *MainTestSuite) TestConfigFile(c *check.C) { - cmd, conf := Run([]string{"start", "--roles=node", "--labels=a=a1,b=b1", "--config=" + s.configFile}, true) + cmd, conf := Run(Options{ + Args: []string{"start", "--roles=node", "--labels=a=a1,b=b1", "--config=" + s.configFile}, + InitOnly: true, + }) c.Assert(cmd, check.Equals, "start") c.Assert(conf.SSH.Enabled, check.Equals, true) c.Assert(conf.Auth.Enabled, check.Equals, false) diff --git a/tool/teleport/main.go b/tool/teleport/main.go index dff378a6426..35fa11a826e 100644 --- a/tool/teleport/main.go +++ b/tool/teleport/main.go @@ -23,6 +23,7 @@ import ( ) func main() { - const testRun = false - common.Run(os.Args[1:], testRun) + common.Run(common.Options{ + Args: os.Args[1:], + }) } diff --git a/vendor/github.com/golang/protobuf/jsonpb/jsonpb.go b/vendor/github.com/golang/protobuf/jsonpb/jsonpb.go new file mode 100644 index 00000000000..82c61624ef0 --- /dev/null +++ b/vendor/github.com/golang/protobuf/jsonpb/jsonpb.go @@ -0,0 +1,843 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2015 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +Package jsonpb provides marshaling and unmarshaling between protocol buffers and JSON. +It follows the specification at https://developers.google.com/protocol-buffers/docs/proto3#json. + +This package produces a different output than the standard "encoding/json" package, +which does not operate correctly on protocol buffers. +*/ +package jsonpb + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "reflect" + "sort" + "strconv" + "strings" + "time" + + "github.com/golang/protobuf/proto" +) + +// Marshaler is a configurable object for converting between +// protocol buffer objects and a JSON representation for them. +type Marshaler struct { + // Whether to render enum values as integers, as opposed to string values. + EnumsAsInts bool + + // Whether to render fields with zero values. + EmitDefaults bool + + // A string to indent each level by. The presence of this field will + // also cause a space to appear between the field separator and + // value, and for newlines to be appear between fields and array + // elements. + Indent string + + // Whether to use the original (.proto) name for fields. + OrigName bool +} + +// Marshal marshals a protocol buffer into JSON. +func (m *Marshaler) Marshal(out io.Writer, pb proto.Message) error { + writer := &errWriter{writer: out} + return m.marshalObject(writer, pb, "", "") +} + +// MarshalToString converts a protocol buffer object to JSON string. +func (m *Marshaler) MarshalToString(pb proto.Message) (string, error) { + var buf bytes.Buffer + if err := m.Marshal(&buf, pb); err != nil { + return "", err + } + return buf.String(), nil +} + +type int32Slice []int32 + +// For sorting extensions ids to ensure stable output. +func (s int32Slice) Len() int { return len(s) } +func (s int32Slice) Less(i, j int) bool { return s[i] < s[j] } +func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type wkt interface { + XXX_WellKnownType() string +} + +// marshalObject writes a struct to the Writer. +func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeURL string) error { + s := reflect.ValueOf(v).Elem() + + // Handle well-known types. + if wkt, ok := v.(wkt); ok { + switch wkt.XXX_WellKnownType() { + case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value", + "Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue": + // "Wrappers use the same representation in JSON + // as the wrapped primitive type, ..." + sprop := proto.GetProperties(s.Type()) + return m.marshalValue(out, sprop.Prop[0], s.Field(0), indent) + case "Any": + // Any is a bit more involved. + return m.marshalAny(out, v, indent) + case "Duration": + // "Generated output always contains 3, 6, or 9 fractional digits, + // depending on required precision." + s, ns := s.Field(0).Int(), s.Field(1).Int() + d := time.Duration(s)*time.Second + time.Duration(ns)*time.Nanosecond + x := fmt.Sprintf("%.9f", d.Seconds()) + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, "000") + out.write(`"`) + out.write(x) + out.write(`s"`) + return out.err + case "Struct": + // Let marshalValue handle the `fields` map. + // TODO: pass the correct Properties if needed. + return m.marshalValue(out, &proto.Properties{}, s.Field(0), indent) + case "Timestamp": + // "RFC 3339, where generated output will always be Z-normalized + // and uses 3, 6 or 9 fractional digits." + s, ns := s.Field(0).Int(), s.Field(1).Int() + t := time.Unix(s, ns).UTC() + // time.RFC3339Nano isn't exactly right (we need to get 3/6/9 fractional digits). + x := t.Format("2006-01-02T15:04:05.000000000") + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, "000") + out.write(`"`) + out.write(x) + out.write(`Z"`) + return out.err + case "Value": + // Value has a single oneof. + kind := s.Field(0) + if kind.IsNil() { + // "absence of any variant indicates an error" + return errors.New("nil Value") + } + // oneof -> *T -> T -> T.F + x := kind.Elem().Elem().Field(0) + // TODO: pass the correct Properties if needed. + return m.marshalValue(out, &proto.Properties{}, x, indent) + } + } + + out.write("{") + if m.Indent != "" { + out.write("\n") + } + + firstField := true + + if typeURL != "" { + if err := m.marshalTypeURL(out, indent, typeURL); err != nil { + return err + } + firstField = false + } + + for i := 0; i < s.NumField(); i++ { + value := s.Field(i) + valueField := s.Type().Field(i) + if strings.HasPrefix(valueField.Name, "XXX_") { + continue + } + + // IsNil will panic on most value kinds. + switch value.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + if value.IsNil() { + continue + } + } + + if !m.EmitDefaults { + switch value.Kind() { + case reflect.Bool: + if !value.Bool() { + continue + } + case reflect.Int32, reflect.Int64: + if value.Int() == 0 { + continue + } + case reflect.Uint32, reflect.Uint64: + if value.Uint() == 0 { + continue + } + case reflect.Float32, reflect.Float64: + if value.Float() == 0 { + continue + } + case reflect.String: + if value.Len() == 0 { + continue + } + } + } + + // Oneof fields need special handling. + if valueField.Tag.Get("protobuf_oneof") != "" { + // value is an interface containing &T{real_value}. + sv := value.Elem().Elem() // interface -> *T -> T + value = sv.Field(0) + valueField = sv.Type().Field(0) + } + prop := jsonProperties(valueField, m.OrigName) + if !firstField { + m.writeSep(out) + } + if err := m.marshalField(out, prop, value, indent); err != nil { + return err + } + firstField = false + } + + // Handle proto2 extensions. + if ep, ok := v.(proto.Message); ok { + extensions := proto.RegisteredExtensions(v) + // Sort extensions for stable output. + ids := make([]int32, 0, len(extensions)) + for id, desc := range extensions { + if !proto.HasExtension(ep, desc) { + continue + } + ids = append(ids, id) + } + sort.Sort(int32Slice(ids)) + for _, id := range ids { + desc := extensions[id] + if desc == nil { + // unknown extension + continue + } + ext, extErr := proto.GetExtension(ep, desc) + if extErr != nil { + return extErr + } + value := reflect.ValueOf(ext) + var prop proto.Properties + prop.Parse(desc.Tag) + prop.JSONName = fmt.Sprintf("[%s]", desc.Name) + if !firstField { + m.writeSep(out) + } + if err := m.marshalField(out, &prop, value, indent); err != nil { + return err + } + firstField = false + } + + } + + if m.Indent != "" { + out.write("\n") + out.write(indent) + } + out.write("}") + return out.err +} + +func (m *Marshaler) writeSep(out *errWriter) { + if m.Indent != "" { + out.write(",\n") + } else { + out.write(",") + } +} + +func (m *Marshaler) marshalAny(out *errWriter, any proto.Message, indent string) error { + // "If the Any contains a value that has a special JSON mapping, + // it will be converted as follows: {"@type": xxx, "value": yyy}. + // Otherwise, the value will be converted into a JSON object, + // and the "@type" field will be inserted to indicate the actual data type." + v := reflect.ValueOf(any).Elem() + turl := v.Field(0).String() + val := v.Field(1).Bytes() + + // Only the part of type_url after the last slash is relevant. + mname := turl + if slash := strings.LastIndex(mname, "/"); slash >= 0 { + mname = mname[slash+1:] + } + mt := proto.MessageType(mname) + if mt == nil { + return fmt.Errorf("unknown message type %q", mname) + } + msg := reflect.New(mt.Elem()).Interface().(proto.Message) + if err := proto.Unmarshal(val, msg); err != nil { + return err + } + + if _, ok := msg.(wkt); ok { + out.write("{") + if m.Indent != "" { + out.write("\n") + } + if err := m.marshalTypeURL(out, indent, turl); err != nil { + return err + } + m.writeSep(out) + if m.Indent != "" { + out.write(indent) + out.write(m.Indent) + out.write(`"value": `) + } else { + out.write(`"value":`) + } + if err := m.marshalObject(out, msg, indent+m.Indent, ""); err != nil { + return err + } + if m.Indent != "" { + out.write("\n") + out.write(indent) + } + out.write("}") + return out.err + } + + return m.marshalObject(out, msg, indent, turl) +} + +func (m *Marshaler) marshalTypeURL(out *errWriter, indent, typeURL string) error { + if m.Indent != "" { + out.write(indent) + out.write(m.Indent) + } + out.write(`"@type":`) + if m.Indent != "" { + out.write(" ") + } + b, err := json.Marshal(typeURL) + if err != nil { + return err + } + out.write(string(b)) + return out.err +} + +// marshalField writes field description and value to the Writer. +func (m *Marshaler) marshalField(out *errWriter, prop *proto.Properties, v reflect.Value, indent string) error { + if m.Indent != "" { + out.write(indent) + out.write(m.Indent) + } + out.write(`"`) + out.write(prop.JSONName) + out.write(`":`) + if m.Indent != "" { + out.write(" ") + } + if err := m.marshalValue(out, prop, v, indent); err != nil { + return err + } + return nil +} + +// marshalValue writes the value to the Writer. +func (m *Marshaler) marshalValue(out *errWriter, prop *proto.Properties, v reflect.Value, indent string) error { + + var err error + v = reflect.Indirect(v) + + // Handle repeated elements. + if v.Kind() == reflect.Slice && v.Type().Elem().Kind() != reflect.Uint8 { + out.write("[") + comma := "" + for i := 0; i < v.Len(); i++ { + sliceVal := v.Index(i) + out.write(comma) + if m.Indent != "" { + out.write("\n") + out.write(indent) + out.write(m.Indent) + out.write(m.Indent) + } + if err := m.marshalValue(out, prop, sliceVal, indent+m.Indent); err != nil { + return err + } + comma = "," + } + if m.Indent != "" { + out.write("\n") + out.write(indent) + out.write(m.Indent) + } + out.write("]") + return out.err + } + + // Handle well-known types. + // Most are handled up in marshalObject (because 99% are messages). + type wkt interface { + XXX_WellKnownType() string + } + if wkt, ok := v.Interface().(wkt); ok { + switch wkt.XXX_WellKnownType() { + case "NullValue": + out.write("null") + return out.err + } + } + + // Handle enumerations. + if !m.EnumsAsInts && prop.Enum != "" { + // Unknown enum values will are stringified by the proto library as their + // value. Such values should _not_ be quoted or they will be interpreted + // as an enum string instead of their value. + enumStr := v.Interface().(fmt.Stringer).String() + var valStr string + if v.Kind() == reflect.Ptr { + valStr = strconv.Itoa(int(v.Elem().Int())) + } else { + valStr = strconv.Itoa(int(v.Int())) + } + isKnownEnum := enumStr != valStr + if isKnownEnum { + out.write(`"`) + } + out.write(enumStr) + if isKnownEnum { + out.write(`"`) + } + return out.err + } + + // Handle nested messages. + if v.Kind() == reflect.Struct { + return m.marshalObject(out, v.Addr().Interface().(proto.Message), indent+m.Indent, "") + } + + // Handle maps. + // Since Go randomizes map iteration, we sort keys for stable output. + if v.Kind() == reflect.Map { + out.write(`{`) + keys := v.MapKeys() + sort.Sort(mapKeys(keys)) + for i, k := range keys { + if i > 0 { + out.write(`,`) + } + if m.Indent != "" { + out.write("\n") + out.write(indent) + out.write(m.Indent) + out.write(m.Indent) + } + + b, err := json.Marshal(k.Interface()) + if err != nil { + return err + } + s := string(b) + + // If the JSON is not a string value, encode it again to make it one. + if !strings.HasPrefix(s, `"`) { + b, err := json.Marshal(s) + if err != nil { + return err + } + s = string(b) + } + + out.write(s) + out.write(`:`) + if m.Indent != "" { + out.write(` `) + } + + if err := m.marshalValue(out, prop, v.MapIndex(k), indent+m.Indent); err != nil { + return err + } + } + if m.Indent != "" { + out.write("\n") + out.write(indent) + out.write(m.Indent) + } + out.write(`}`) + return out.err + } + + // Default handling defers to the encoding/json library. + b, err := json.Marshal(v.Interface()) + if err != nil { + return err + } + needToQuote := string(b[0]) != `"` && (v.Kind() == reflect.Int64 || v.Kind() == reflect.Uint64) + if needToQuote { + out.write(`"`) + } + out.write(string(b)) + if needToQuote { + out.write(`"`) + } + return out.err +} + +// Unmarshaler is a configurable object for converting from a JSON +// representation to a protocol buffer object. +type Unmarshaler struct { + // Whether to allow messages to contain unknown fields, as opposed to + // failing to unmarshal. + AllowUnknownFields bool +} + +// UnmarshalNext unmarshals the next protocol buffer from a JSON object stream. +// This function is lenient and will decode any options permutations of the +// related Marshaler. +func (u *Unmarshaler) UnmarshalNext(dec *json.Decoder, pb proto.Message) error { + inputValue := json.RawMessage{} + if err := dec.Decode(&inputValue); err != nil { + return err + } + return u.unmarshalValue(reflect.ValueOf(pb).Elem(), inputValue, nil) +} + +// Unmarshal unmarshals a JSON object stream into a protocol +// buffer. This function is lenient and will decode any options +// permutations of the related Marshaler. +func (u *Unmarshaler) Unmarshal(r io.Reader, pb proto.Message) error { + dec := json.NewDecoder(r) + return u.UnmarshalNext(dec, pb) +} + +// UnmarshalNext unmarshals the next protocol buffer from a JSON object stream. +// This function is lenient and will decode any options permutations of the +// related Marshaler. +func UnmarshalNext(dec *json.Decoder, pb proto.Message) error { + return new(Unmarshaler).UnmarshalNext(dec, pb) +} + +// Unmarshal unmarshals a JSON object stream into a protocol +// buffer. This function is lenient and will decode any options +// permutations of the related Marshaler. +func Unmarshal(r io.Reader, pb proto.Message) error { + return new(Unmarshaler).Unmarshal(r, pb) +} + +// UnmarshalString will populate the fields of a protocol buffer based +// on a JSON string. This function is lenient and will decode any options +// permutations of the related Marshaler. +func UnmarshalString(str string, pb proto.Message) error { + return new(Unmarshaler).Unmarshal(strings.NewReader(str), pb) +} + +// unmarshalValue converts/copies a value into the target. +// prop may be nil. +func (u *Unmarshaler) unmarshalValue(target reflect.Value, inputValue json.RawMessage, prop *proto.Properties) error { + targetType := target.Type() + + // Allocate memory for pointer fields. + if targetType.Kind() == reflect.Ptr { + target.Set(reflect.New(targetType.Elem())) + return u.unmarshalValue(target.Elem(), inputValue, prop) + } + + // Handle well-known types. + type wkt interface { + XXX_WellKnownType() string + } + if wkt, ok := target.Addr().Interface().(wkt); ok { + switch wkt.XXX_WellKnownType() { + case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value", + "Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue": + // "Wrappers use the same representation in JSON + // as the wrapped primitive type, except that null is allowed." + // encoding/json will turn JSON `null` into Go `nil`, + // so we don't have to do any extra work. + return u.unmarshalValue(target.Field(0), inputValue, prop) + case "Any": + return fmt.Errorf("unmarshaling Any not supported yet") + case "Duration": + ivStr := string(inputValue) + if ivStr == "null" { + target.Field(0).SetInt(0) + target.Field(1).SetInt(0) + return nil + } + + unq, err := strconv.Unquote(ivStr) + if err != nil { + return err + } + d, err := time.ParseDuration(unq) + if err != nil { + return fmt.Errorf("bad Duration: %v", err) + } + ns := d.Nanoseconds() + s := ns / 1e9 + ns %= 1e9 + target.Field(0).SetInt(s) + target.Field(1).SetInt(ns) + return nil + case "Timestamp": + ivStr := string(inputValue) + if ivStr == "null" { + target.Field(0).SetInt(0) + target.Field(1).SetInt(0) + return nil + } + + unq, err := strconv.Unquote(ivStr) + if err != nil { + return err + } + t, err := time.Parse(time.RFC3339Nano, unq) + if err != nil { + return fmt.Errorf("bad Timestamp: %v", err) + } + target.Field(0).SetInt(int64(t.Unix())) + target.Field(1).SetInt(int64(t.Nanosecond())) + return nil + } + } + + // Handle enums, which have an underlying type of int32, + // and may appear as strings. + // The case of an enum appearing as a number is handled + // at the bottom of this function. + if inputValue[0] == '"' && prop != nil && prop.Enum != "" { + vmap := proto.EnumValueMap(prop.Enum) + // Don't need to do unquoting; valid enum names + // are from a limited character set. + s := inputValue[1 : len(inputValue)-1] + n, ok := vmap[string(s)] + if !ok { + return fmt.Errorf("unknown value %q for enum %s", s, prop.Enum) + } + if target.Kind() == reflect.Ptr { // proto2 + target.Set(reflect.New(targetType.Elem())) + target = target.Elem() + } + target.SetInt(int64(n)) + return nil + } + + // Handle nested messages. + if targetType.Kind() == reflect.Struct { + var jsonFields map[string]json.RawMessage + if err := json.Unmarshal(inputValue, &jsonFields); err != nil { + return err + } + + consumeField := func(prop *proto.Properties) (json.RawMessage, bool) { + // Be liberal in what names we accept; both orig_name and camelName are okay. + fieldNames := acceptedJSONFieldNames(prop) + + vOrig, okOrig := jsonFields[fieldNames.orig] + vCamel, okCamel := jsonFields[fieldNames.camel] + if !okOrig && !okCamel { + return nil, false + } + // If, for some reason, both are present in the data, favour the camelName. + var raw json.RawMessage + if okOrig { + raw = vOrig + delete(jsonFields, fieldNames.orig) + } + if okCamel { + raw = vCamel + delete(jsonFields, fieldNames.camel) + } + return raw, true + } + + sprops := proto.GetProperties(targetType) + for i := 0; i < target.NumField(); i++ { + ft := target.Type().Field(i) + if strings.HasPrefix(ft.Name, "XXX_") { + continue + } + + valueForField, ok := consumeField(sprops.Prop[i]) + if !ok { + continue + } + + if err := u.unmarshalValue(target.Field(i), valueForField, sprops.Prop[i]); err != nil { + return err + } + } + // Check for any oneof fields. + if len(jsonFields) > 0 { + for _, oop := range sprops.OneofTypes { + raw, ok := consumeField(oop.Prop) + if !ok { + continue + } + nv := reflect.New(oop.Type.Elem()) + target.Field(oop.Field).Set(nv) + if err := u.unmarshalValue(nv.Elem().Field(0), raw, oop.Prop); err != nil { + return err + } + } + } + if !u.AllowUnknownFields && len(jsonFields) > 0 { + // Pick any field to be the scapegoat. + var f string + for fname := range jsonFields { + f = fname + break + } + return fmt.Errorf("unknown field %q in %v", f, targetType) + } + return nil + } + + // Handle arrays (which aren't encoded bytes) + if targetType.Kind() == reflect.Slice && targetType.Elem().Kind() != reflect.Uint8 { + var slc []json.RawMessage + if err := json.Unmarshal(inputValue, &slc); err != nil { + return err + } + len := len(slc) + target.Set(reflect.MakeSlice(targetType, len, len)) + for i := 0; i < len; i++ { + if err := u.unmarshalValue(target.Index(i), slc[i], prop); err != nil { + return err + } + } + return nil + } + + // Handle maps (whose keys are always strings) + if targetType.Kind() == reflect.Map { + var mp map[string]json.RawMessage + if err := json.Unmarshal(inputValue, &mp); err != nil { + return err + } + target.Set(reflect.MakeMap(targetType)) + var keyprop, valprop *proto.Properties + if prop != nil { + // These could still be nil if the protobuf metadata is broken somehow. + // TODO: This won't work because the fields are unexported. + // We should probably just reparse them. + //keyprop, valprop = prop.mkeyprop, prop.mvalprop + } + for ks, raw := range mp { + // Unmarshal map key. The core json library already decoded the key into a + // string, so we handle that specially. Other types were quoted post-serialization. + var k reflect.Value + if targetType.Key().Kind() == reflect.String { + k = reflect.ValueOf(ks) + } else { + k = reflect.New(targetType.Key()).Elem() + if err := u.unmarshalValue(k, json.RawMessage(ks), keyprop); err != nil { + return err + } + } + + // Unmarshal map value. + v := reflect.New(targetType.Elem()).Elem() + if err := u.unmarshalValue(v, raw, valprop); err != nil { + return err + } + target.SetMapIndex(k, v) + } + return nil + } + + // 64-bit integers can be encoded as strings. In this case we drop + // the quotes and proceed as normal. + isNum := targetType.Kind() == reflect.Int64 || targetType.Kind() == reflect.Uint64 + if isNum && strings.HasPrefix(string(inputValue), `"`) { + inputValue = inputValue[1 : len(inputValue)-1] + } + + // Use the encoding/json for parsing other value types. + return json.Unmarshal(inputValue, target.Addr().Interface()) +} + +// jsonProperties returns parsed proto.Properties for the field and corrects JSONName attribute. +func jsonProperties(f reflect.StructField, origName bool) *proto.Properties { + var prop proto.Properties + prop.Init(f.Type, f.Name, f.Tag.Get("protobuf"), &f) + if origName || prop.JSONName == "" { + prop.JSONName = prop.OrigName + } + return &prop +} + +type fieldNames struct { + orig, camel string +} + +func acceptedJSONFieldNames(prop *proto.Properties) fieldNames { + opts := fieldNames{orig: prop.OrigName, camel: prop.OrigName} + if prop.JSONName != "" { + opts.camel = prop.JSONName + } + return opts +} + +// Writer wrapper inspired by https://blog.golang.org/errors-are-values +type errWriter struct { + writer io.Writer + err error +} + +func (w *errWriter) write(str string) { + if w.err != nil { + return + } + _, w.err = w.writer.Write([]byte(str)) +} + +// Map fields may have key types of non-float scalars, strings and enums. +// The easiest way to sort them in some deterministic order is to use fmt. +// If this turns out to be inefficient we can always consider other options, +// such as doing a Schwartzian transform. +// +// Numeric keys are sorted in numeric order per +// https://developers.google.com/protocol-buffers/docs/proto#maps. +type mapKeys []reflect.Value + +func (s mapKeys) Len() int { return len(s) } +func (s mapKeys) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s mapKeys) Less(i, j int) bool { + if k := s[i].Kind(); k == s[j].Kind() { + switch k { + case reflect.Int32, reflect.Int64: + return s[i].Int() < s[j].Int() + case reflect.Uint32, reflect.Uint64: + return s[i].Uint() < s[j].Uint() + } + } + return fmt.Sprint(s[i].Interface()) < fmt.Sprint(s[j].Interface()) +} diff --git a/vendor/github.com/golang/protobuf/jsonpb/jsonpb_test.go b/vendor/github.com/golang/protobuf/jsonpb/jsonpb_test.go new file mode 100644 index 00000000000..e237df55c35 --- /dev/null +++ b/vendor/github.com/golang/protobuf/jsonpb/jsonpb_test.go @@ -0,0 +1,563 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2015 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package jsonpb + +import ( + "bytes" + "encoding/json" + "io" + "reflect" + "strings" + "testing" + + "github.com/golang/protobuf/proto" + + pb "github.com/golang/protobuf/jsonpb/jsonpb_test_proto" + proto3pb "github.com/golang/protobuf/proto/proto3_proto" + anypb "github.com/golang/protobuf/ptypes/any" + durpb "github.com/golang/protobuf/ptypes/duration" + stpb "github.com/golang/protobuf/ptypes/struct" + tspb "github.com/golang/protobuf/ptypes/timestamp" + wpb "github.com/golang/protobuf/ptypes/wrappers" +) + +var ( + marshaler = Marshaler{} + + marshalerAllOptions = Marshaler{ + Indent: " ", + } + + simpleObject = &pb.Simple{ + OInt32: proto.Int32(-32), + OInt64: proto.Int64(-6400000000), + OUint32: proto.Uint32(32), + OUint64: proto.Uint64(6400000000), + OSint32: proto.Int32(-13), + OSint64: proto.Int64(-2600000000), + OFloat: proto.Float32(3.14), + ODouble: proto.Float64(6.02214179e23), + OBool: proto.Bool(true), + OString: proto.String("hello \"there\""), + OBytes: []byte("beep boop"), + } + + simpleObjectJSON = `{` + + `"oBool":true,` + + `"oInt32":-32,` + + `"oInt64":"-6400000000",` + + `"oUint32":32,` + + `"oUint64":"6400000000",` + + `"oSint32":-13,` + + `"oSint64":"-2600000000",` + + `"oFloat":3.14,` + + `"oDouble":6.02214179e+23,` + + `"oString":"hello \"there\"",` + + `"oBytes":"YmVlcCBib29w"` + + `}` + + simpleObjectPrettyJSON = `{ + "oBool": true, + "oInt32": -32, + "oInt64": "-6400000000", + "oUint32": 32, + "oUint64": "6400000000", + "oSint32": -13, + "oSint64": "-2600000000", + "oFloat": 3.14, + "oDouble": 6.02214179e+23, + "oString": "hello \"there\"", + "oBytes": "YmVlcCBib29w" +}` + + repeatsObject = &pb.Repeats{ + RBool: []bool{true, false, true}, + RInt32: []int32{-3, -4, -5}, + RInt64: []int64{-123456789, -987654321}, + RUint32: []uint32{1, 2, 3}, + RUint64: []uint64{6789012345, 3456789012}, + RSint32: []int32{-1, -2, -3}, + RSint64: []int64{-6789012345, -3456789012}, + RFloat: []float32{3.14, 6.28}, + RDouble: []float64{299792458 * 1e20, 6.62606957e-34}, + RString: []string{"happy", "days"}, + RBytes: [][]byte{[]byte("skittles"), []byte("m&m's")}, + } + + repeatsObjectJSON = `{` + + `"rBool":[true,false,true],` + + `"rInt32":[-3,-4,-5],` + + `"rInt64":["-123456789","-987654321"],` + + `"rUint32":[1,2,3],` + + `"rUint64":["6789012345","3456789012"],` + + `"rSint32":[-1,-2,-3],` + + `"rSint64":["-6789012345","-3456789012"],` + + `"rFloat":[3.14,6.28],` + + `"rDouble":[2.99792458e+28,6.62606957e-34],` + + `"rString":["happy","days"],` + + `"rBytes":["c2tpdHRsZXM=","bSZtJ3M="]` + + `}` + + repeatsObjectPrettyJSON = `{ + "rBool": [ + true, + false, + true + ], + "rInt32": [ + -3, + -4, + -5 + ], + "rInt64": [ + "-123456789", + "-987654321" + ], + "rUint32": [ + 1, + 2, + 3 + ], + "rUint64": [ + "6789012345", + "3456789012" + ], + "rSint32": [ + -1, + -2, + -3 + ], + "rSint64": [ + "-6789012345", + "-3456789012" + ], + "rFloat": [ + 3.14, + 6.28 + ], + "rDouble": [ + 2.99792458e+28, + 6.62606957e-34 + ], + "rString": [ + "happy", + "days" + ], + "rBytes": [ + "c2tpdHRsZXM=", + "bSZtJ3M=" + ] +}` + + innerSimple = &pb.Simple{OInt32: proto.Int32(-32)} + innerSimple2 = &pb.Simple{OInt64: proto.Int64(25)} + innerRepeats = &pb.Repeats{RString: []string{"roses", "red"}} + innerRepeats2 = &pb.Repeats{RString: []string{"violets", "blue"}} + complexObject = &pb.Widget{ + Color: pb.Widget_GREEN.Enum(), + RColor: []pb.Widget_Color{pb.Widget_RED, pb.Widget_GREEN, pb.Widget_BLUE}, + Simple: innerSimple, + RSimple: []*pb.Simple{innerSimple, innerSimple2}, + Repeats: innerRepeats, + RRepeats: []*pb.Repeats{innerRepeats, innerRepeats2}, + } + + complexObjectJSON = `{"color":"GREEN",` + + `"rColor":["RED","GREEN","BLUE"],` + + `"simple":{"oInt32":-32},` + + `"rSimple":[{"oInt32":-32},{"oInt64":"25"}],` + + `"repeats":{"rString":["roses","red"]},` + + `"rRepeats":[{"rString":["roses","red"]},{"rString":["violets","blue"]}]` + + `}` + + complexObjectPrettyJSON = `{ + "color": "GREEN", + "rColor": [ + "RED", + "GREEN", + "BLUE" + ], + "simple": { + "oInt32": -32 + }, + "rSimple": [ + { + "oInt32": -32 + }, + { + "oInt64": "25" + } + ], + "repeats": { + "rString": [ + "roses", + "red" + ] + }, + "rRepeats": [ + { + "rString": [ + "roses", + "red" + ] + }, + { + "rString": [ + "violets", + "blue" + ] + } + ] +}` + + colorPrettyJSON = `{ + "color": 2 +}` + + colorListPrettyJSON = `{ + "color": 1000, + "rColor": [ + "RED" + ] +}` + + nummyPrettyJSON = `{ + "nummy": { + "1": 2, + "3": 4 + } +}` + + objjyPrettyJSON = `{ + "objjy": { + "1": { + "dub": 1 + } + } +}` + realNumber = &pb.Real{Value: proto.Float64(3.14159265359)} + realNumberName = "Pi" + complexNumber = &pb.Complex{Imaginary: proto.Float64(0.5772156649)} + realNumberJSON = `{` + + `"value":3.14159265359,` + + `"[jsonpb.Complex.real_extension]":{"imaginary":0.5772156649},` + + `"[jsonpb.name]":"Pi"` + + `}` + + anySimple = &pb.KnownTypes{ + An: &anypb.Any{ + TypeUrl: "something.example.com/jsonpb.Simple", + Value: []byte{ + // &pb.Simple{OBool:true} + 1 << 3, 1, + }, + }, + } + anySimpleJSON = `{"an":{"@type":"something.example.com/jsonpb.Simple","oBool":true}}` + anySimplePrettyJSON = `{ + "an": { + "@type": "something.example.com/jsonpb.Simple", + "oBool": true + } +}` + + anyWellKnown = &pb.KnownTypes{ + An: &anypb.Any{ + TypeUrl: "type.googleapis.com/google.protobuf.Duration", + Value: []byte{ + // &durpb.Duration{Seconds: 1, Nanos: 212000000 } + 1 << 3, 1, // seconds + 2 << 3, 0x80, 0xba, 0x8b, 0x65, // nanos + }, + }, + } + anyWellKnownJSON = `{"an":{"@type":"type.googleapis.com/google.protobuf.Duration","value":"1.212s"}}` + anyWellKnownPrettyJSON = `{ + "an": { + "@type": "type.googleapis.com/google.protobuf.Duration", + "value": "1.212s" + } +}` +) + +func init() { + if err := proto.SetExtension(realNumber, pb.E_Name, &realNumberName); err != nil { + panic(err) + } + if err := proto.SetExtension(realNumber, pb.E_Complex_RealExtension, complexNumber); err != nil { + panic(err) + } +} + +var marshalingTests = []struct { + desc string + marshaler Marshaler + pb proto.Message + json string +}{ + {"simple flat object", marshaler, simpleObject, simpleObjectJSON}, + {"simple pretty object", marshalerAllOptions, simpleObject, simpleObjectPrettyJSON}, + {"repeated fields flat object", marshaler, repeatsObject, repeatsObjectJSON}, + {"repeated fields pretty object", marshalerAllOptions, repeatsObject, repeatsObjectPrettyJSON}, + {"nested message/enum flat object", marshaler, complexObject, complexObjectJSON}, + {"nested message/enum pretty object", marshalerAllOptions, complexObject, complexObjectPrettyJSON}, + {"enum-string flat object", Marshaler{}, + &pb.Widget{Color: pb.Widget_BLUE.Enum()}, `{"color":"BLUE"}`}, + {"enum-value pretty object", Marshaler{EnumsAsInts: true, Indent: " "}, + &pb.Widget{Color: pb.Widget_BLUE.Enum()}, colorPrettyJSON}, + {"unknown enum value object", marshalerAllOptions, + &pb.Widget{Color: pb.Widget_Color(1000).Enum(), RColor: []pb.Widget_Color{pb.Widget_RED}}, colorListPrettyJSON}, + {"repeated proto3 enum", Marshaler{}, + &proto3pb.Message{RFunny: []proto3pb.Message_Humour{ + proto3pb.Message_PUNS, + proto3pb.Message_SLAPSTICK, + }}, + `{"rFunny":["PUNS","SLAPSTICK"]}`}, + {"repeated proto3 enum as int", Marshaler{EnumsAsInts: true}, + &proto3pb.Message{RFunny: []proto3pb.Message_Humour{ + proto3pb.Message_PUNS, + proto3pb.Message_SLAPSTICK, + }}, + `{"rFunny":[1,2]}`}, + {"empty value", marshaler, &pb.Simple3{}, `{}`}, + {"empty value emitted", Marshaler{EmitDefaults: true}, &pb.Simple3{}, `{"dub":0}`}, + {"map", marshaler, &pb.Mappy{Nummy: map[int64]int32{1: 2, 3: 4}}, `{"nummy":{"1":2,"3":4}}`}, + {"map", marshalerAllOptions, &pb.Mappy{Nummy: map[int64]int32{1: 2, 3: 4}}, nummyPrettyJSON}, + {"map", marshaler, + &pb.Mappy{Strry: map[string]string{`"one"`: "two", "three": "four"}}, + `{"strry":{"\"one\"":"two","three":"four"}}`}, + {"map", marshaler, + &pb.Mappy{Objjy: map[int32]*pb.Simple3{1: &pb.Simple3{Dub: 1}}}, `{"objjy":{"1":{"dub":1}}}`}, + {"map", marshalerAllOptions, + &pb.Mappy{Objjy: map[int32]*pb.Simple3{1: &pb.Simple3{Dub: 1}}}, objjyPrettyJSON}, + {"map", marshaler, &pb.Mappy{Buggy: map[int64]string{1234: "yup"}}, + `{"buggy":{"1234":"yup"}}`}, + {"map", marshaler, &pb.Mappy{Booly: map[bool]bool{false: true}}, `{"booly":{"false":true}}`}, + // TODO: This is broken. + //{"map", marshaler, &pb.Mappy{Enumy: map[string]pb.Numeral{"XIV": pb.Numeral_ROMAN}}, `{"enumy":{"XIV":"ROMAN"}`}, + {"map", Marshaler{EnumsAsInts: true}, &pb.Mappy{Enumy: map[string]pb.Numeral{"XIV": pb.Numeral_ROMAN}}, `{"enumy":{"XIV":2}}`}, + {"map", marshaler, &pb.Mappy{S32Booly: map[int32]bool{1: true, 3: false, 10: true, 12: false}}, `{"s32booly":{"1":true,"3":false,"10":true,"12":false}}`}, + {"map", marshaler, &pb.Mappy{S64Booly: map[int64]bool{1: true, 3: false, 10: true, 12: false}}, `{"s64booly":{"1":true,"3":false,"10":true,"12":false}}`}, + {"map", marshaler, &pb.Mappy{U32Booly: map[uint32]bool{1: true, 3: false, 10: true, 12: false}}, `{"u32booly":{"1":true,"3":false,"10":true,"12":false}}`}, + {"map", marshaler, &pb.Mappy{U64Booly: map[uint64]bool{1: true, 3: false, 10: true, 12: false}}, `{"u64booly":{"1":true,"3":false,"10":true,"12":false}}`}, + {"proto2 map", marshaler, &pb.Maps{MInt64Str: map[int64]string{213: "cat"}}, + `{"mInt64Str":{"213":"cat"}}`}, + {"proto2 map", marshaler, + &pb.Maps{MBoolSimple: map[bool]*pb.Simple{true: &pb.Simple{OInt32: proto.Int32(1)}}}, + `{"mBoolSimple":{"true":{"oInt32":1}}}`}, + {"oneof, not set", marshaler, &pb.MsgWithOneof{}, `{}`}, + {"oneof, set", marshaler, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Title{"Grand Poobah"}}, `{"title":"Grand Poobah"}`}, + {"force orig_name", Marshaler{OrigName: true}, &pb.Simple{OInt32: proto.Int32(4)}, + `{"o_int32":4}`}, + {"proto2 extension", marshaler, realNumber, realNumberJSON}, + {"Any with message", marshaler, anySimple, anySimpleJSON}, + {"Any with message and indent", marshalerAllOptions, anySimple, anySimplePrettyJSON}, + {"Any with WKT", marshaler, anyWellKnown, anyWellKnownJSON}, + {"Any with WKT and indent", marshalerAllOptions, anyWellKnown, anyWellKnownPrettyJSON}, + {"Duration", marshaler, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 3}}, `{"dur":"3.000s"}`}, + {"Struct", marshaler, &pb.KnownTypes{St: &stpb.Struct{ + Fields: map[string]*stpb.Value{ + "one": &stpb.Value{Kind: &stpb.Value_StringValue{"loneliest number"}}, + "two": &stpb.Value{Kind: &stpb.Value_NullValue{stpb.NullValue_NULL_VALUE}}, + }, + }}, `{"st":{"one":"loneliest number","two":null}}`}, + {"Timestamp", marshaler, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 14e8, Nanos: 21e6}}, `{"ts":"2014-05-13T16:53:20.021Z"}`}, + + {"DoubleValue", marshaler, &pb.KnownTypes{Dbl: &wpb.DoubleValue{Value: 1.2}}, `{"dbl":1.2}`}, + {"FloatValue", marshaler, &pb.KnownTypes{Flt: &wpb.FloatValue{Value: 1.2}}, `{"flt":1.2}`}, + {"Int64Value", marshaler, &pb.KnownTypes{I64: &wpb.Int64Value{Value: -3}}, `{"i64":"-3"}`}, + {"UInt64Value", marshaler, &pb.KnownTypes{U64: &wpb.UInt64Value{Value: 3}}, `{"u64":"3"}`}, + {"Int32Value", marshaler, &pb.KnownTypes{I32: &wpb.Int32Value{Value: -4}}, `{"i32":-4}`}, + {"UInt32Value", marshaler, &pb.KnownTypes{U32: &wpb.UInt32Value{Value: 4}}, `{"u32":4}`}, + {"BoolValue", marshaler, &pb.KnownTypes{Bool: &wpb.BoolValue{Value: true}}, `{"bool":true}`}, + {"StringValue", marshaler, &pb.KnownTypes{Str: &wpb.StringValue{Value: "plush"}}, `{"str":"plush"}`}, + {"BytesValue", marshaler, &pb.KnownTypes{Bytes: &wpb.BytesValue{Value: []byte("wow")}}, `{"bytes":"d293"}`}, +} + +func TestMarshaling(t *testing.T) { + for _, tt := range marshalingTests { + json, err := tt.marshaler.MarshalToString(tt.pb) + if err != nil { + t.Errorf("%s: marshaling error: %v", tt.desc, err) + } else if tt.json != json { + t.Errorf("%s: got [%v] want [%v]", tt.desc, json, tt.json) + } + } +} + +var unmarshalingTests = []struct { + desc string + unmarshaler Unmarshaler + json string + pb proto.Message +}{ + {"simple flat object", Unmarshaler{}, simpleObjectJSON, simpleObject}, + {"simple pretty object", Unmarshaler{}, simpleObjectPrettyJSON, simpleObject}, + {"repeated fields flat object", Unmarshaler{}, repeatsObjectJSON, repeatsObject}, + {"repeated fields pretty object", Unmarshaler{}, repeatsObjectPrettyJSON, repeatsObject}, + {"nested message/enum flat object", Unmarshaler{}, complexObjectJSON, complexObject}, + {"nested message/enum pretty object", Unmarshaler{}, complexObjectPrettyJSON, complexObject}, + {"enum-string object", Unmarshaler{}, `{"color":"BLUE"}`, &pb.Widget{Color: pb.Widget_BLUE.Enum()}}, + {"enum-value object", Unmarshaler{}, "{\n \"color\": 2\n}", &pb.Widget{Color: pb.Widget_BLUE.Enum()}}, + {"unknown field with allowed option", Unmarshaler{AllowUnknownFields: true}, `{"unknown": "foo"}`, new(pb.Simple)}, + {"proto3 enum string", Unmarshaler{}, `{"hilarity":"PUNS"}`, &proto3pb.Message{Hilarity: proto3pb.Message_PUNS}}, + {"proto3 enum value", Unmarshaler{}, `{"hilarity":1}`, &proto3pb.Message{Hilarity: proto3pb.Message_PUNS}}, + {"unknown enum value object", + Unmarshaler{}, + "{\n \"color\": 1000,\n \"r_color\": [\n \"RED\"\n ]\n}", + &pb.Widget{Color: pb.Widget_Color(1000).Enum(), RColor: []pb.Widget_Color{pb.Widget_RED}}}, + {"repeated proto3 enum", Unmarshaler{}, `{"rFunny":["PUNS","SLAPSTICK"]}`, + &proto3pb.Message{RFunny: []proto3pb.Message_Humour{ + proto3pb.Message_PUNS, + proto3pb.Message_SLAPSTICK, + }}}, + {"repeated proto3 enum as int", Unmarshaler{}, `{"rFunny":[1,2]}`, + &proto3pb.Message{RFunny: []proto3pb.Message_Humour{ + proto3pb.Message_PUNS, + proto3pb.Message_SLAPSTICK, + }}}, + {"repeated proto3 enum as mix of strings and ints", Unmarshaler{}, `{"rFunny":["PUNS",2]}`, + &proto3pb.Message{RFunny: []proto3pb.Message_Humour{ + proto3pb.Message_PUNS, + proto3pb.Message_SLAPSTICK, + }}}, + {"unquoted int64 object", Unmarshaler{}, `{"oInt64":-314}`, &pb.Simple{OInt64: proto.Int64(-314)}}, + {"unquoted uint64 object", Unmarshaler{}, `{"oUint64":123}`, &pb.Simple{OUint64: proto.Uint64(123)}}, + {"map", Unmarshaler{}, `{"nummy":{"1":2,"3":4}}`, &pb.Mappy{Nummy: map[int64]int32{1: 2, 3: 4}}}, + {"map", Unmarshaler{}, `{"strry":{"\"one\"":"two","three":"four"}}`, &pb.Mappy{Strry: map[string]string{`"one"`: "two", "three": "four"}}}, + {"map", Unmarshaler{}, `{"objjy":{"1":{"dub":1}}}`, &pb.Mappy{Objjy: map[int32]*pb.Simple3{1: &pb.Simple3{Dub: 1}}}}, + // TODO: This is broken. + //{"map", Unmarshaler{}, `{"enumy":{"XIV":"ROMAN"}`, &pb.Mappy{Enumy: map[string]pb.Numeral{"XIV": pb.Numeral_ROMAN}}}, + {"map", Unmarshaler{}, `{"enumy":{"XIV":2}}`, &pb.Mappy{Enumy: map[string]pb.Numeral{"XIV": pb.Numeral_ROMAN}}}, + {"oneof", Unmarshaler{}, `{"salary":31000}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Salary{31000}}}, + {"oneof spec name", Unmarshaler{}, `{"Country":"Australia"}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Country{"Australia"}}}, + {"oneof orig_name", Unmarshaler{}, `{"Country":"Australia"}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Country{"Australia"}}}, + {"oneof spec name2", Unmarshaler{}, `{"homeAddress":"Australia"}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_HomeAddress{"Australia"}}}, + {"oneof orig_name2", Unmarshaler{}, `{"home_address":"Australia"}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_HomeAddress{"Australia"}}}, + {"orig_name input", Unmarshaler{}, `{"o_bool":true}`, &pb.Simple{OBool: proto.Bool(true)}}, + {"camelName input", Unmarshaler{}, `{"oBool":true}`, &pb.Simple{OBool: proto.Bool(true)}}, + + {"Duration", Unmarshaler{}, `{"dur":"3.000s"}`, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 3}}}, + {"null Duration", Unmarshaler{}, `{"dur":null}`, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 0}}}, + {"Timestamp", Unmarshaler{}, `{"ts":"2014-05-13T16:53:20.021Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 14e8, Nanos: 21e6}}}, + {"PreEpochTimestamp", Unmarshaler{}, `{"ts":"1969-12-31T23:59:58.999999995Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: -2, Nanos: 999999995}}}, + {"ZeroTimeTimestamp", Unmarshaler{}, `{"ts":"0001-01-01T00:00:00Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: -62135596800, Nanos: 0}}}, + {"null Timestamp", Unmarshaler{}, `{"ts":null}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 0, Nanos: 0}}}, + + {"DoubleValue", Unmarshaler{}, `{"dbl":1.2}`, &pb.KnownTypes{Dbl: &wpb.DoubleValue{Value: 1.2}}}, + {"FloatValue", Unmarshaler{}, `{"flt":1.2}`, &pb.KnownTypes{Flt: &wpb.FloatValue{Value: 1.2}}}, + {"Int64Value", Unmarshaler{}, `{"i64":"-3"}`, &pb.KnownTypes{I64: &wpb.Int64Value{Value: -3}}}, + {"UInt64Value", Unmarshaler{}, `{"u64":"3"}`, &pb.KnownTypes{U64: &wpb.UInt64Value{Value: 3}}}, + {"Int32Value", Unmarshaler{}, `{"i32":-4}`, &pb.KnownTypes{I32: &wpb.Int32Value{Value: -4}}}, + {"UInt32Value", Unmarshaler{}, `{"u32":4}`, &pb.KnownTypes{U32: &wpb.UInt32Value{Value: 4}}}, + {"BoolValue", Unmarshaler{}, `{"bool":true}`, &pb.KnownTypes{Bool: &wpb.BoolValue{Value: true}}}, + {"StringValue", Unmarshaler{}, `{"str":"plush"}`, &pb.KnownTypes{Str: &wpb.StringValue{Value: "plush"}}}, + {"BytesValue", Unmarshaler{}, `{"bytes":"d293"}`, &pb.KnownTypes{Bytes: &wpb.BytesValue{Value: []byte("wow")}}}, + // `null` is also a permissible value. Let's just test one. + {"null DoubleValue", Unmarshaler{}, `{"dbl":null}`, &pb.KnownTypes{Dbl: &wpb.DoubleValue{}}}, +} + +func TestUnmarshaling(t *testing.T) { + for _, tt := range unmarshalingTests { + // Make a new instance of the type of our expected object. + p := reflect.New(reflect.TypeOf(tt.pb).Elem()).Interface().(proto.Message) + + err := tt.unmarshaler.Unmarshal(strings.NewReader(tt.json), p) + if err != nil { + t.Errorf("%s: %v", tt.desc, err) + continue + } + + // For easier diffs, compare text strings of the protos. + exp := proto.MarshalTextString(tt.pb) + act := proto.MarshalTextString(p) + if string(exp) != string(act) { + t.Errorf("%s: got [%s] want [%s]", tt.desc, act, exp) + } + } +} + +func TestUnmarshalNext(t *testing.T) { + // We only need to check against a few, not all of them. + tests := unmarshalingTests[:5] + + // Create a buffer with many concatenated JSON objects. + var b bytes.Buffer + for _, tt := range tests { + b.WriteString(tt.json) + } + + dec := json.NewDecoder(&b) + for _, tt := range tests { + // Make a new instance of the type of our expected object. + p := reflect.New(reflect.TypeOf(tt.pb).Elem()).Interface().(proto.Message) + + err := tt.unmarshaler.UnmarshalNext(dec, p) + if err != nil { + t.Errorf("%s: %v", tt.desc, err) + continue + } + + // For easier diffs, compare text strings of the protos. + exp := proto.MarshalTextString(tt.pb) + act := proto.MarshalTextString(p) + if string(exp) != string(act) { + t.Errorf("%s: got [%s] want [%s]", tt.desc, act, exp) + } + } + + p := &pb.Simple{} + err := new(Unmarshaler).UnmarshalNext(dec, p) + if err != io.EOF { + t.Errorf("eof: got %v, expected io.EOF", err) + } +} + +var unmarshalingShouldError = []struct { + desc string + in string + pb proto.Message +}{ + {"a value", "666", new(pb.Simple)}, + {"gibberish", "{adskja123;l23=-=", new(pb.Simple)}, + {"unknown field", `{"unknown": "foo"}`, new(pb.Simple)}, + {"unknown enum name", `{"hilarity":"DAVE"}`, new(proto3pb.Message)}, +} + +func TestUnmarshalingBadInput(t *testing.T) { + for _, tt := range unmarshalingShouldError { + err := UnmarshalString(tt.in, tt.pb) + if err == nil { + t.Errorf("an error was expected when parsing %q instead of an object", tt.desc) + } + } +} diff --git a/vendor/github.com/gravitational/license/.gitignore b/vendor/github.com/gravitational/license/.gitignore new file mode 100644 index 00000000000..a1338d68517 --- /dev/null +++ b/vendor/github.com/gravitational/license/.gitignore @@ -0,0 +1,14 @@ +# Binaries for programs and plugins +*.exe +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 +.glide/ diff --git a/vendor/github.com/gravitational/license/LICENSE b/vendor/github.com/gravitational/license/LICENSE new file mode 100644 index 00000000000..8dada3edaf5 --- /dev/null +++ b/vendor/github.com/gravitational/license/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/gravitational/license/Makefile b/vendor/github.com/gravitational/license/Makefile new file mode 100644 index 00000000000..7ffc5703b27 --- /dev/null +++ b/vendor/github.com/gravitational/license/Makefile @@ -0,0 +1,13 @@ +# +# test runs tests for all packages +# +.PHONY: test +test: + go test -v -test.parallel=0 -race ./... + +# +# build builds all packages +# +.PHONY: build +build: + go build ./... diff --git a/vendor/github.com/gravitational/license/README.md b/vendor/github.com/gravitational/license/README.md new file mode 100644 index 00000000000..6292dcecebf --- /dev/null +++ b/vendor/github.com/gravitational/license/README.md @@ -0,0 +1,2 @@ +# license +CA and licensing tools diff --git a/vendor/github.com/gravitational/license/constants/constants.go b/vendor/github.com/gravitational/license/constants/constants.go new file mode 100644 index 00000000000..9f4742387c4 --- /dev/null +++ b/vendor/github.com/gravitational/license/constants/constants.go @@ -0,0 +1,98 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package constants + +import "encoding/asn1" + +const ( + // TLSKeyAlgo is default TLS algo used for K8s X509 certs + TLSKeyAlgo = "rsa" + + // TLSKeySize is default TLS key size used for K8s X509 certs + TLSKeySize = 2048 + + // RSAPrivateKeyPEMBlock is the name of the PEM block where private key is stored + RSAPrivateKeyPEMBlock = "RSA PRIVATE KEY" + + // CertificatePEMBlock is the name of the PEM block where certificate is stored + CertificatePEMBlock = "CERTIFICATE" + + // LicenseKeyPair is a name of the license key pair + LicenseKeyPair = "license" + + // LoopbackIP is IP of the loopback interface + LoopbackIP = "127.0.0.1" + + // LicenseKeyBits used when generating private key for license certificate + LicenseKeyBits = 2048 + + // LicenseOrg is the default name of license subject organization + LicenseOrg = "gravitational.io" + + // LicenseTimeFormat represents format of expiration time in license payload + LicenseTimeFormat = "2006-01-02 15:04:05" +) + +// LicenseASNExtensionID is an extension ID used when encoding/decoding +// license payload into certificates +var LicenseASN1ExtensionID = asn1.ObjectIdentifier{2, 5, 42} + +// EC2InstanceTypes maps AWS instance types to their number of CPUs, +// used for determining whether license allows a certain instance +// type in some cases +var EC2InstanceTypes = map[string]int{ + "t2.nano": 1, + "t2.micro": 1, + "t2.small": 1, + "t2.medium": 2, + "t2.large": 2, + "m3.medium": 1, + "m3.large": 2, + "m3.xlarge": 4, + "m3.2xlarge": 8, + "m4.large": 2, + "m4.xlarge": 4, + "m4.2xlarge": 8, + "m4.4xlarge": 16, + "m4.10xlarge": 40, + "c3.large": 2, + "c3.xlarge": 4, + "c3.2xlarge": 8, + "c3.4xlarge": 16, + "c3.8xlarge": 32, + "c4.large": 2, + "c4.xlarge": 4, + "c4.2xlarge": 8, + "c4.4xlarge": 16, + "c4.8xlarge": 36, + "x1.32xlarge": 128, + "g2.2xlarge": 8, + "g2.8xlarge": 32, + "r3.large": 2, + "r3.xlarge": 4, + "r3.2xlarge": 8, + "r3.4xlarge": 16, + "r3.8xlarge": 32, + "i2.xlarge": 4, + "i2.2xlarge": 8, + "i2.4xlarge": 16, + "i2.8xlarge": 32, + "d2.xlarge": 4, + "d2.2xlarge": 8, + "d2.4xlarge": 16, + "d2.8xlarge": 36, +} diff --git a/vendor/github.com/gravitational/license/license.go b/vendor/github.com/gravitational/license/license.go new file mode 100644 index 00000000000..af7ea9df4c9 --- /dev/null +++ b/vendor/github.com/gravitational/license/license.go @@ -0,0 +1,91 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package license + +import ( + "crypto/x509" + "time" + + "github.com/gravitational/trace" +) + +// License represents Gravitational license +type License struct { + // Cert is the x509 license certificate + Cert *x509.Certificate + // Payload is the license payload + Payload Payload + // CertPEM is the certificate part of the license in PEM format + CertPEM []byte + // KeyPEM is the private key part of the license in PEM Format, + // may be empty if the license was parsed from certificate only + KeyPEM []byte +} + +// Verify makes sure the license is valid +func (l *License) Verify(caPEM []byte) error { + roots := x509.NewCertPool() + + // add the provided CA certificate to the roots + ok := roots.AppendCertsFromPEM(caPEM) + if !ok { + return trace.BadParameter("could not find any CA certificates") + } + + _, err := l.Cert.Verify(x509.VerifyOptions{Roots: roots}) + if err != nil { + certErr, ok := err.(x509.CertificateInvalidError) + if ok && certErr.Reason == x509.Expired { + return trace.BadParameter("the license has expired") + } + return trace.Wrap(err, "failed to verify the license") + } + + return nil +} + +// Payload is custom information that gets encoded into licenses +type Payload struct { + // ClusterID is vendor-specific cluster ID + ClusterID string `json:"cluster_id,omitempty"` + // Expiration is expiration time for the license + Expiration time.Time `json:"expiration,omitempty"` + // MaxNodes is maximum number of nodes the license allows + MaxNodes int `json:"max_nodes,omitempty"` + // MaxCores is maximum number of CPUs per node the license allows + MaxCores int `json:"max_cores,omitempty"` + // Company is the company name the license is generated for + Company string `json:"company,omitempty"` + // Person is the name of the person the license is generated for + Person string `json:"person,omitempty"` + // Email is the email of the person the license is generated for + Email string `json:"email,omitempty"` + // Metadata is an arbitrary customer metadata + Metadata string `json:"metadata,omitempty"` + // ProductName is the name of the product the license is for + ProductName string `json:"product_name,omitempty"` + // ProductVersion is the product version + ProductVersion string `json:"product_version,omitempty"` + // EncryptionKey is the passphrase for decoding encrypted packages + EncryptionKey []byte `json:"encryption_key,omitempty"` + // Signature is vendor-specific signature + Signature string `json:"signature,omitempty"` + // Shutdown indicates whether the app should be stopped when the license expires + Shutdown bool `json:"shutdown,omitempty"` + // AccountID is the ID of the account the license was issued for + AccountID string `json:"account_id,omitempty"` +} diff --git a/vendor/github.com/gravitational/license/parse.go b/vendor/github.com/gravitational/license/parse.go new file mode 100644 index 00000000000..4278b3f2bb9 --- /dev/null +++ b/vendor/github.com/gravitational/license/parse.go @@ -0,0 +1,159 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package license + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/json" + "encoding/pem" + + "github.com/gravitational/license/constants" + + "github.com/gravitational/trace" +) + +// ParseString parses the license from the provided string +func ParseString(pem string) (*License, error) { + certPEM, keyPEM, err := SplitPEM([]byte(pem)) + if err != nil { + return nil, trace.Wrap(err) + } + certificateBytes, privateBytes, err := parseCertificatePEM(pem) + if err != nil { + return nil, trace.Wrap(err) + } + certificate, err := x509.ParseCertificate(certificateBytes) + if err != nil { + return nil, trace.Wrap(err) + } + payload, err := parsePayloadFromX509(certificate) + if err != nil { + return nil, trace.Wrap(err) + } + // decrypt encryption key + if len(payload.EncryptionKey) != 0 { + private, err := x509.ParsePKCS1PrivateKey(privateBytes) + if err != nil { + return nil, trace.Wrap(err) + } + payload.EncryptionKey, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, + private, payload.EncryptionKey, nil) + if err != nil { + return nil, trace.Wrap(err) + } + } + return &License{ + Cert: certificate, + Payload: *payload, + CertPEM: certPEM, + KeyPEM: keyPEM, + }, nil +} + +// ParseX509 parses the license from the provided x509 certificate +func ParseX509(cert *x509.Certificate) (*License, error) { + payload, err := parsePayloadFromX509(cert) + if err != nil { + return nil, trace.Wrap(err) + } + return &License{ + Cert: cert, + Payload: *payload, + }, nil +} + +// MakeTLSCert takes the provided license and makes a TLS certificate +// which is the format used by Go servers +func MakeTLSCert(license License) (*tls.Certificate, error) { + tlsCert, err := tls.X509KeyPair(license.CertPEM, license.KeyPEM) + if err != nil { + return nil, trace.Wrap(err) + } + return &tlsCert, nil +} + +// MakeTLSConfig builds a client TLS config from the supplied license +func MakeTLSConfig(license License) (*tls.Config, error) { + tlsCert, err := MakeTLSCert(license) + if err != nil { + return nil, trace.Wrap(err) + } + return &tls.Config{ + Certificates: []tls.Certificate{*tlsCert}, + }, nil +} + +// parsePayloadFromX509 parses the extension with license payload from the +// provided x509 certificate +func parsePayloadFromX509(cert *x509.Certificate) (*Payload, error) { + for _, ext := range cert.Extensions { + if ext.Id.Equal(constants.LicenseASN1ExtensionID) { + var p Payload + if err := json.Unmarshal(ext.Value, &p); err != nil { + return nil, trace.Wrap(err) + } + return &p, nil + } + } + return nil, trace.NotFound( + "certificate does not contain extension with license payload") +} + +// parseCertificatePEM parses the concatenated certificate/private key in PEM format +// and returns certificate and private key in decoded DER ASN.1 structure +func parseCertificatePEM(certPEM string) ([]byte, []byte, error) { + var certificateBytes, privateBytes []byte + block, rest := pem.Decode([]byte(certPEM)) + for block != nil { + switch block.Type { + case constants.CertificatePEMBlock: + certificateBytes = block.Bytes + case constants.RSAPrivateKeyPEMBlock: + privateBytes = block.Bytes + } + // parse the next block + block, rest = pem.Decode(rest) + } + if len(certificateBytes) == 0 || len(privateBytes) == 0 { + return nil, nil, trace.BadParameter("could not parse the license") + } + return certificateBytes, privateBytes, nil +} + +// SplitPEM splits the provided PEM data that contains concatenated cert and key +// (in any order) into cert PEM and key PEM respectively. Returns an error if +// any of them is missing +func SplitPEM(pemData []byte) (certPEM []byte, keyPEM []byte, err error) { + block, rest := pem.Decode(pemData) + for block != nil { + switch block.Type { + case constants.CertificatePEMBlock: + certPEM = pem.EncodeToMemory(block) + case constants.RSAPrivateKeyPEMBlock: + keyPEM = pem.EncodeToMemory(block) + } + block, rest = pem.Decode(rest) + } + if len(certPEM) == 0 || len(keyPEM) == 0 { + return nil, nil, trace.BadParameter("cert or key PEM data is missing") + } + return certPEM, keyPEM, nil +} diff --git a/vendor/github.com/gravitational/reporting/.gitignore b/vendor/github.com/gravitational/reporting/.gitignore new file mode 100644 index 00000000000..a1338d68517 --- /dev/null +++ b/vendor/github.com/gravitational/reporting/.gitignore @@ -0,0 +1,14 @@ +# Binaries for programs and plugins +*.exe +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 +.glide/ diff --git a/vendor/github.com/gravitational/reporting/Dockerfile b/vendor/github.com/gravitational/reporting/Dockerfile new file mode 100644 index 00000000000..5845aedc6f2 --- /dev/null +++ b/vendor/github.com/gravitational/reporting/Dockerfile @@ -0,0 +1,26 @@ +FROM quay.io/gravitational/debian-venti:go1.9.1-jessie + +ARG PROTOC_VER +ARG GOGO_PROTO_TAG +ARG GRPC_GATEWAY_TAG +ARG PLATFORM + +ENV TARBALL protoc-${PROTOC_VER}-${PLATFORM}.zip +ENV GOGOPROTO_ROOT ${GOPATH}/src/github.com/gogo/protobuf +ENV LANGUAGE="en_US.UTF-8" \ + LANG="en_US.UTF-8" \ + LC_ALL="en_US.UTF-8" \ + LC_CTYPE="en_US.UTF-8" \ + GOPATH="/gopath" \ + PATH="$PATH:/opt/protoc/bin:/opt/go/bin:/gopath/bin" + +RUN apt-get update && apt-get install unzip + +RUN curl -L -o /tmp/${TARBALL} https://github.com/google/protobuf/releases/download/v${PROTOC_VER}/${TARBALL} +RUN cd /tmp && unzip /tmp/protoc-${PROTOC_VER}-linux-x86_64.zip -d /usr/local && rm /tmp/${TARBALL} + +RUN go get -u github.com/gogo/protobuf/proto github.com/gogo/protobuf/protoc-gen-gogo github.com/gogo/protobuf/gogoproto golang.org/x/tools/cmd/goimports github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger +RUN cd ${GOPATH}/src/github.com/gogo/protobuf && git reset --hard ${GOGO_PROTO_TAG} && make install +RUN cd ${GOPATH}/src/github.com/grpc-ecosystem/grpc-gateway && git reset --hard ${GRPC_GATEWAY_TAG} && go install ./protoc-gen-grpc-gateway + +ENV PROTO_INCLUDE "/usr/local/include":"${GOPATH}/src":"${GOPATH}/src/github.com/gogo/protobuf/protobuf":"${GOPATH}/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis":"${GOPATH}/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis":"${GOGOPROTO_ROOT}:${GOGOPROTO_ROOT}/protobuf" diff --git a/vendor/github.com/gravitational/reporting/LICENSE b/vendor/github.com/gravitational/reporting/LICENSE new file mode 100644 index 00000000000..8dada3edaf5 --- /dev/null +++ b/vendor/github.com/gravitational/reporting/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/gravitational/reporting/Makefile b/vendor/github.com/gravitational/reporting/Makefile new file mode 100644 index 00000000000..79490c19f35 --- /dev/null +++ b/vendor/github.com/gravitational/reporting/Makefile @@ -0,0 +1,39 @@ +PROTOC_VER ?= 3.0.0 +GOGO_PROTO_TAG ?= v0.3 +GRPC_GATEWAY_TAG ?= v1.1.0 +PLATFORM := linux-x86_64 +GRPC_API := . +BUILDBOX_TAG := reporting-buildbox:0.0.1 + +.PHONY: all +all: grpc build + +.PHONY: build +build: + go build ./... + +.PHONY: test +test: + go test -v ./... + +.PHONY: buildbox +buildbox: + docker build \ + --build-arg PROTOC_VER=$(PROTOC_VER) \ + --build-arg GOGO_PROTO_TAG=$(GOGO_PROTO_TAG) \ + --build-arg GRPC_GATEWAY_TAG=$(GRPC_GATEWAY_TAG) \ + --build-arg PLATFORM=$(PLATFORM) \ + -t $(BUILDBOX_TAG) . + +.PHONY: grpc +grpc: buildbox + docker run -v $(shell pwd):/go/src/github.com/gravitational/reporting $(BUILDBOX_TAG) \ + make -C /go/src/github.com/gravitational/reporting buildbox-grpc + +.PHONY: buildbox-grpc +buildbox-grpc: + echo $$PROTO_INCLUDE + cd $(GRPC_API) && protoc -I=.:$$PROTO_INCLUDE \ + --gofast_out=plugins=grpc:.\ + --grpc-gateway_out=logtostderr=true:. \ + *.proto diff --git a/vendor/github.com/gravitational/reporting/README.md b/vendor/github.com/gravitational/reporting/README.md new file mode 100644 index 00000000000..c442d967873 --- /dev/null +++ b/vendor/github.com/gravitational/reporting/README.md @@ -0,0 +1,2 @@ +# reporting +gRPC based client/server usage reporting module diff --git a/vendor/github.com/gravitational/reporting/api.pb.go b/vendor/github.com/gravitational/reporting/api.pb.go new file mode 100644 index 00000000000..534959f5b7d --- /dev/null +++ b/vendor/github.com/gravitational/reporting/api.pb.go @@ -0,0 +1,554 @@ +// Code generated by protoc-gen-gogo. +// source: api.proto +// DO NOT EDIT! + +/* + Package reporting is a generated protocol buffer package. + + It is generated from these files: + api.proto + + It has these top-level messages: + GRPCEvent + GRPCEvents +*/ +package reporting + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" +import _ "github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis/google/api" +import google_protobuf1 "github.com/golang/protobuf/ptypes/empty" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +import io "io" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +// GRPCEvent represents a single event sent over gRPC +type GRPCEvent struct { + // Data is the JSON-encoded event payload + Data []byte `protobuf:"bytes,1,opt,name=Data,json=data,proto3" json:"Data,omitempty"` +} + +func (m *GRPCEvent) Reset() { *m = GRPCEvent{} } +func (m *GRPCEvent) String() string { return proto.CompactTextString(m) } +func (*GRPCEvent) ProtoMessage() {} +func (*GRPCEvent) Descriptor() ([]byte, []int) { return fileDescriptorApi, []int{0} } + +// Events defines a series of events sent over gRPC +type GRPCEvents struct { + // Events is a list of events + Events []*GRPCEvent `protobuf:"bytes,1,rep,name=Events,json=events" json:"Events,omitempty"` +} + +func (m *GRPCEvents) Reset() { *m = GRPCEvents{} } +func (m *GRPCEvents) String() string { return proto.CompactTextString(m) } +func (*GRPCEvents) ProtoMessage() {} +func (*GRPCEvents) Descriptor() ([]byte, []int) { return fileDescriptorApi, []int{1} } + +func (m *GRPCEvents) GetEvents() []*GRPCEvent { + if m != nil { + return m.Events + } + return nil +} + +func init() { + proto.RegisterType((*GRPCEvent)(nil), "reporting.GRPCEvent") + proto.RegisterType((*GRPCEvents)(nil), "reporting.GRPCEvents") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion3 + +// Client API for EventsService service + +type EventsServiceClient interface { + // Record records the provided list of gRPC events + Record(ctx context.Context, in *GRPCEvents, opts ...grpc.CallOption) (*google_protobuf1.Empty, error) +} + +type eventsServiceClient struct { + cc *grpc.ClientConn +} + +func NewEventsServiceClient(cc *grpc.ClientConn) EventsServiceClient { + return &eventsServiceClient{cc} +} + +func (c *eventsServiceClient) Record(ctx context.Context, in *GRPCEvents, opts ...grpc.CallOption) (*google_protobuf1.Empty, error) { + out := new(google_protobuf1.Empty) + err := grpc.Invoke(ctx, "/reporting.EventsService/Record", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// Server API for EventsService service + +type EventsServiceServer interface { + // Record records the provided list of gRPC events + Record(context.Context, *GRPCEvents) (*google_protobuf1.Empty, error) +} + +func RegisterEventsServiceServer(s *grpc.Server, srv EventsServiceServer) { + s.RegisterService(&_EventsService_serviceDesc, srv) +} + +func _EventsService_Record_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GRPCEvents) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EventsServiceServer).Record(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/reporting.EventsService/Record", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(EventsServiceServer).Record(ctx, req.(*GRPCEvents)) + } + return interceptor(ctx, in, info, handler) +} + +var _EventsService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "reporting.EventsService", + HandlerType: (*EventsServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Record", + Handler: _EventsService_Record_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: fileDescriptorApi, +} + +func (m *GRPCEvent) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +func (m *GRPCEvent) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Data) > 0 { + data[i] = 0xa + i++ + i = encodeVarintApi(data, i, uint64(len(m.Data))) + i += copy(data[i:], m.Data) + } + return i, nil +} + +func (m *GRPCEvents) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +func (m *GRPCEvents) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Events) > 0 { + for _, msg := range m.Events { + data[i] = 0xa + i++ + i = encodeVarintApi(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + return i, nil +} + +func encodeFixed64Api(data []byte, offset int, v uint64) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + data[offset+4] = uint8(v >> 32) + data[offset+5] = uint8(v >> 40) + data[offset+6] = uint8(v >> 48) + data[offset+7] = uint8(v >> 56) + return offset + 8 +} +func encodeFixed32Api(data []byte, offset int, v uint32) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + return offset + 4 +} +func encodeVarintApi(data []byte, offset int, v uint64) int { + for v >= 1<<7 { + data[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + data[offset] = uint8(v) + return offset + 1 +} +func (m *GRPCEvent) Size() (n int) { + var l int + _ = l + l = len(m.Data) + if l > 0 { + n += 1 + l + sovApi(uint64(l)) + } + return n +} + +func (m *GRPCEvents) Size() (n int) { + var l int + _ = l + if len(m.Events) > 0 { + for _, e := range m.Events { + l = e.Size() + n += 1 + l + sovApi(uint64(l)) + } + } + return n +} + +func sovApi(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozApi(x uint64) (n int) { + return sovApi(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *GRPCEvent) Unmarshal(data []byte) error { + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowApi + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: GRPCEvent: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: GRPCEvent: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowApi + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthApi + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Data = append(m.Data[:0], data[iNdEx:postIndex]...) + if m.Data == nil { + m.Data = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipApi(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthApi + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *GRPCEvents) Unmarshal(data []byte) error { + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowApi + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: GRPCEvents: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: GRPCEvents: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Events", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowApi + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthApi + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Events = append(m.Events, &GRPCEvent{}) + if err := m.Events[len(m.Events)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipApi(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthApi + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipApi(data []byte) (n int, err error) { + l := len(data) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowApi + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowApi + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if data[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowApi + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthApi + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowApi + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipApi(data[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthApi = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowApi = fmt.Errorf("proto: integer overflow") +) + +func init() { proto.RegisterFile("api.proto", fileDescriptorApi) } + +var fileDescriptorApi = []byte{ + // 286 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x6c, 0x8f, 0xc1, 0x4a, 0x33, 0x31, + 0x14, 0x85, 0xff, 0xfc, 0x96, 0x81, 0x46, 0x45, 0x09, 0x56, 0x4a, 0x85, 0xb1, 0xcc, 0xaa, 0x88, + 0x26, 0x58, 0x77, 0x5d, 0xaa, 0xc5, 0x9d, 0xc8, 0xb8, 0x73, 0x53, 0x6e, 0xa7, 0x31, 0x0d, 0xd8, + 0x49, 0x48, 0x6e, 0x47, 0x66, 0xeb, 0x2b, 0xb8, 0xf1, 0x91, 0x5c, 0x0a, 0xbe, 0x80, 0x8c, 0x3e, + 0x88, 0x38, 0x19, 0x67, 0xe5, 0xee, 0x1c, 0xce, 0x77, 0x92, 0x73, 0x69, 0x17, 0xac, 0xe6, 0xd6, + 0x19, 0x34, 0xac, 0xeb, 0xa4, 0x35, 0x0e, 0x75, 0xae, 0x06, 0x77, 0x4a, 0xe3, 0x72, 0x3d, 0xe7, + 0x99, 0x59, 0x09, 0xe5, 0x6c, 0x76, 0x22, 0x33, 0xe3, 0x4b, 0x8f, 0xb2, 0xb1, 0x0a, 0x50, 0x3e, + 0x42, 0x29, 0x70, 0xa9, 0xdd, 0x62, 0x66, 0xc1, 0x61, 0x29, 0x94, 0x31, 0xea, 0x41, 0x82, 0xd5, + 0xbe, 0x91, 0x02, 0xac, 0x16, 0x90, 0xe7, 0x06, 0x01, 0xb5, 0xc9, 0x7d, 0xf8, 0x66, 0x70, 0xd0, + 0xa4, 0xb5, 0x9b, 0xaf, 0xef, 0x85, 0x5c, 0x59, 0x2c, 0x43, 0x98, 0x1c, 0xd2, 0xee, 0x55, 0x7a, + 0x73, 0x31, 0x2d, 0x64, 0x8e, 0x8c, 0xd1, 0xce, 0x25, 0x20, 0xf4, 0xc9, 0x90, 0x8c, 0xb6, 0xd2, + 0xce, 0x02, 0x10, 0x92, 0x09, 0xa5, 0x2d, 0xe0, 0xd9, 0x31, 0x8d, 0x82, 0xea, 0x93, 0xe1, 0xc6, + 0x68, 0x73, 0xbc, 0xc7, 0xdb, 0x1b, 0x78, 0x8b, 0xa5, 0x91, 0xac, 0x99, 0xf1, 0x8c, 0x6e, 0x07, + 0xfa, 0x56, 0xba, 0x42, 0x67, 0x92, 0x5d, 0xd3, 0x28, 0x95, 0x99, 0x71, 0x0b, 0xd6, 0xfb, 0xab, + 0xe8, 0x07, 0xfb, 0x3c, 0x8c, 0xe5, 0xbf, 0x63, 0xf9, 0xf4, 0x67, 0x6c, 0xd2, 0x7b, 0x7a, 0xff, + 0x7a, 0xfe, 0xbf, 0x93, 0x50, 0x51, 0x9c, 0x8a, 0xf0, 0xfa, 0x84, 0x1c, 0x9d, 0xef, 0xbe, 0x56, + 0x31, 0x79, 0xab, 0x62, 0xf2, 0x51, 0xc5, 0xe4, 0xe5, 0x33, 0xfe, 0x37, 0x8f, 0xea, 0xe2, 0xd9, + 0x77, 0x00, 0x00, 0x00, 0xff, 0xff, 0x5c, 0x2d, 0xa0, 0xdf, 0x67, 0x01, 0x00, 0x00, +} diff --git a/vendor/github.com/gravitational/reporting/api.pb.gw.go b/vendor/github.com/gravitational/reporting/api.pb.gw.go new file mode 100644 index 00000000000..6327192e1b6 --- /dev/null +++ b/vendor/github.com/gravitational/reporting/api.pb.gw.go @@ -0,0 +1,110 @@ +// Code generated by protoc-gen-grpc-gateway +// source: api.proto +// DO NOT EDIT! + +/* +Package reporting is a reverse proxy. + +It translates gRPC into RESTful JSON APIs. +*/ +package reporting + +import ( + "io" + "net/http" + + "github.com/golang/protobuf/proto" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/grpc-ecosystem/grpc-gateway/utilities" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" +) + +var _ codes.Code +var _ io.Reader +var _ = runtime.String +var _ = utilities.NewDoubleArray + +func request_EventsService_Record_0(ctx context.Context, marshaler runtime.Marshaler, client EventsServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq GRPCEvents + var metadata runtime.ServerMetadata + + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil { + return nil, metadata, grpc.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.Record(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +// RegisterEventsServiceHandlerFromEndpoint is same as RegisterEventsServiceHandler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func RegisterEventsServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.Dial(endpoint, opts...) + if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + + return RegisterEventsServiceHandler(ctx, mux, conn) +} + +// RegisterEventsServiceHandler registers the http handlers for service EventsService to "mux". +// The handlers forward requests to the grpc endpoint over "conn". +func RegisterEventsServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + client := NewEventsServiceClient(conn) + + mux.Handle("POST", pattern_EventsService_Record_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + if cn, ok := w.(http.CloseNotifier); ok { + go func(done <-chan struct{}, closed <-chan bool) { + select { + case <-done: + case <-closed: + cancel() + } + }(ctx.Done(), cn.CloseNotify()) + } + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, req) + if err != nil { + runtime.HTTPError(ctx, outboundMarshaler, w, req, err) + } + resp, md, err := request_EventsService_Record_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, outboundMarshaler, w, req, err) + return + } + + forward_EventsService_Record_0(ctx, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +var ( + pattern_EventsService_Record_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"v1", "events"}, "")) +) + +var ( + forward_EventsService_Record_0 = runtime.ForwardResponseMessage +) diff --git a/vendor/github.com/gravitational/reporting/api.proto b/vendor/github.com/gravitational/reporting/api.proto new file mode 100644 index 00000000000..5bb1fbcc8e0 --- /dev/null +++ b/vendor/github.com/gravitational/reporting/api.proto @@ -0,0 +1,45 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +syntax = "proto3"; + +package reporting; + +import "github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis/google/api/annotations.proto"; +import "google/protobuf/empty.proto"; + +// GRPCEvent represents a single event sent over gRPC +message GRPCEvent { + // Data is the JSON-encoded event payload + bytes Data = 1; +} + +// Events defines a series of events sent over gRPC +message GRPCEvents { + // Events is a list of events + repeated GRPCEvent Events = 1; +} + +// EventsService defines an event-recording service +service EventsService { + // Record records the provided list of gRPC events + rpc Record(GRPCEvents) returns (google.protobuf.Empty) { + option (google.api.http) = { + post: "/v1/events" + body: "*" + }; + } +} diff --git a/vendor/github.com/gravitational/reporting/client/client.go b/vendor/github.com/gravitational/reporting/client/client.go new file mode 100644 index 00000000000..8815e7b4853 --- /dev/null +++ b/vendor/github.com/gravitational/reporting/client/client.go @@ -0,0 +1,160 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "context" + "crypto/tls" + "time" + + "github.com/gravitational/reporting" + "github.com/gravitational/reporting/types" + + "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" + grpcapi "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +// ClientConfig defines the reporting client config +type ClientConfig struct { + // ServerAddr is the address of the reporting gRPC server + ServerAddr string + // ServerName is the SNI server name + ServerName string + // Certificate is the client certificate to authenticate with + Certificate tls.Certificate + // Insecure is whether the client should skip server cert verification + Insecure bool +} + +// Client defines the reporting client interface +type Client interface { + // Record records an event + Record(types.Event) +} + +// NewClient returns a new reporting gRPC client +func NewClient(ctx context.Context, config ClientConfig) (*client, error) { + conn, err := grpcapi.Dial(config.ServerAddr, + grpcapi.WithTransportCredentials( + credentials.NewTLS(&tls.Config{ + ServerName: config.ServerName, + InsecureSkipVerify: config.Insecure, + Certificates: []tls.Certificate{config.Certificate}, + }))) + if err != nil { + return nil, trace.Wrap(err) + } + client := &client{ + client: reporting.NewEventsServiceClient(conn), + // give an extra room to the events channel in case events + // are generated faster we can flush them (unlikely due to + // our events nature) + eventsCh: make(chan types.Event, 5*flushCount), + ctx: ctx, + } + go client.receiveAndFlushEvents() + return client, nil +} + +type client struct { + client reporting.EventsServiceClient + // eventsCh is the channel where events are submitted before they are + // put into internal buffer + eventsCh chan types.Event + // events is the internal events buffer that gets flushed periodically + events []types.Event + // ctx may be used to stop client goroutine + ctx context.Context +} + +// Record records an event. Note that the client accumulates events in memory +// and flushes them every once in a while +func (c *client) Record(event types.Event) { + select { + case c.eventsCh <- event: + log.Debugf("queued %v", event) + default: + log.Warnf("events channel is full, discarding %v", event) + } +} + +// receiveAndFlushEvents receives events on a channel, accumulates them in +// memory and flushes them once a certain number has been accumulated, or +// certain amount of time has passed +func (c *client) receiveAndFlushEvents() { + ticker := time.NewTicker(flushInterval) + defer ticker.Stop() + for { + select { + case event := <-c.eventsCh: + if len(c.events) >= flushCount { + if err := c.flush(); err != nil { + log.Errorf("events queue full and failed to flush events, discarding %v: %v", + event, trace.DebugReport(err)) + continue + } + } + c.events = append(c.events, event) + case <-ticker.C: + if err := c.flush(); err != nil { + log.Errorf("failed to flush events: %v", + trace.DebugReport(err)) + } + case <-c.ctx.Done(): + log.Debug("reporting client is shutting down") + if err := c.flush(); err != nil { + log.Errorf("failed to flush events: %v", + trace.DebugReport(err)) + } + return + } + } +} + +// flush flushes all accumulated events +func (c *client) flush() error { + if len(c.events) == 0 { + return nil // nothing to flush + } + var grpcEvents reporting.GRPCEvents + for _, event := range c.events { + grpcEvent, err := types.ToGRPCEvent(event) + if err != nil { + return trace.Wrap(err) + } + grpcEvents.Events = append( + grpcEvents.Events, grpcEvent) + } + // if we fail to flush some events here, they will be retried on + // the next cycle, we may get duplicates but each event includes + // a unique ID which server sinks can use to de-duplicate + if _, err := c.client.Record(c.ctx, &grpcEvents); err != nil { + return trace.Wrap(err) + } + log.Debugf("flushed %v events", len(c.events)) + c.events = []types.Event{} + return nil +} + +const ( + // flushInterval is how often the client flushes accumulated events + flushInterval = 3 * time.Second + // flushCount is the number of events to accumulate before flush triggers + flushCount = 5 +) diff --git a/vendor/github.com/gravitational/reporting/types/constants.go b/vendor/github.com/gravitational/reporting/types/constants.go new file mode 100644 index 00000000000..ab1579644a4 --- /dev/null +++ b/vendor/github.com/gravitational/reporting/types/constants.go @@ -0,0 +1,42 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +const ( + // ResourceVersion is the current event resource version + ResourceVersion = "v2" + // KindEvent is the event resource kind + KindEvent = "event" + // EventTypeServer is the server-related event type + EventTypeServer = "server" + // EventTypeUser is the user-related event type + EventTypeUser = "user" + // EventActionLogin is the event login action + EventActionLogin = "login" + // KindHeartbeat is the heartbeat resource kind + KindHeartbeat = "heartbeat" + // NotificationUsage is the usage limit notification type + NotificationUsage = "usage" + // NotificationTerms is the terms of service violation notification type + NotificationTerms = "terms" + // SeverityInfo is info notification severity + SeverityInfo = "info" + // SeverityWarning is warning notification severity + SeverityWarning = "warning" + // SeverityError is error notification severity + SeverityError = "error" +) diff --git a/vendor/github.com/gravitational/reporting/types/events.go b/vendor/github.com/gravitational/reporting/types/events.go new file mode 100644 index 00000000000..23dc28215dc --- /dev/null +++ b/vendor/github.com/gravitational/reporting/types/events.go @@ -0,0 +1,320 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/gravitational/reporting" + + "github.com/gravitational/configure/jsonschema" + "github.com/gravitational/trace" + "github.com/pborman/uuid" +) + +// Event defines an interface all event types should implement +type Event interface { + // GetName returns the event name + GetName() string + // GetMetadata returns the event metadata + GetMetadata() Metadata + // SetAccountID sets the event account ID + SetAccountID(string) +} + +// Metadata represents event resource metadata +type Metadata struct { + // Name is the event name + Name string `json:"name"` + // Created is the event creation timestamp + Created time.Time `json:"created"` +} + +// ServerEvent represents server-related event, such as "logged into server" +type ServerEvent struct { + // Kind is resource kind, for events it is "event" + Kind string `json:"kind"` + // Version is the event resource version + Version string `json:"version"` + // Metadata is the event metadata + Metadata Metadata `json:"metadata"` + // Spec is the event spec + Spec ServerEventSpec `json:"spec"` +} + +// ServerEventSpec is server event specification +type ServerEventSpec struct { + // ID is event ID, may be used for de-duplication + ID string `json:"id"` + // Action is event action, such as "login" + Action string `json:"action"` + // AccountID is ID of account that triggered the event + AccountID string `json:"accountID"` + // ServerID is anonymized ID of server that triggered the event + ServerID string `json:"serverID"` +} + +// NewServerLoginEvent creates an instance of "server login" event +func NewServerLoginEvent(serverID string) *ServerEvent { + return &ServerEvent{ + Kind: KindEvent, + Version: ResourceVersion, + Metadata: Metadata{ + Name: EventTypeServer, + Created: time.Now().UTC(), + }, + Spec: ServerEventSpec{ + ID: uuid.New(), + Action: EventActionLogin, + ServerID: serverID, + }, + } +} + +// GetName returns the event name +func (e *ServerEvent) GetName() string { return e.Metadata.Name } + +// GetMetadata returns the event metadata +func (e *ServerEvent) GetMetadata() Metadata { return e.Metadata } + +// SetAccountID sets the event account ID +func (e *ServerEvent) SetAccountID(id string) { + e.Spec.AccountID = id +} + +// UserEvent represents user-related event, such as "user logged in" +type UserEvent struct { + // Kind is resource kind, for events it is "event" + Kind string `json:"kind"` + // Version is the event resource version + Version string `json:"version"` + // Metadata is the event metadata + Metadata Metadata `json:"metadata"` + // Spec is the event spec + Spec UserEventSpec `json:"spec"` +} + +// UserEventSpec is user event specification +type UserEventSpec struct { + // ID is event ID, may be used for de-duplication + ID string `json:"id"` + // Action is event action, such as "login" + Action string `json:"action"` + // AccountID is ID of account that triggered the event + AccountID string `json:"accountID"` + // UserID is anonymized ID of user that triggered the event + UserID string `json:"userID"` +} + +// NewUserLoginEvent creates an instance of "user login" event +func NewUserLoginEvent(userID string) *UserEvent { + return &UserEvent{ + Kind: KindEvent, + Version: ResourceVersion, + Metadata: Metadata{ + Name: EventTypeUser, + Created: time.Now().UTC(), + }, + Spec: UserEventSpec{ + ID: uuid.New(), + Action: EventActionLogin, + UserID: userID, + }, + } +} + +// GetName returns the event name +func (e *UserEvent) GetName() string { return e.Metadata.Name } + +// GetMetadata returns the event metadata +func (e *UserEvent) GetMetadata() Metadata { return e.Metadata } + +// SetAccountID sets the event account id +func (e *UserEvent) SetAccountID(id string) { + e.Spec.AccountID = id +} + +// ToGRPCEvent converts provided event to the format used by gRPC server/client +func ToGRPCEvent(event Event) (*reporting.GRPCEvent, error) { + payload, err := json.Marshal(event) + if err != nil { + return nil, trace.Wrap(err) + } + return &reporting.GRPCEvent{ + Data: payload, + }, nil +} + +// FromGRPCEvent converts event from the format used by gRPC server/client +func FromGRPCEvent(grpcEvent reporting.GRPCEvent) (Event, error) { + var header resourceHeader + if err := json.Unmarshal(grpcEvent.Data, &header); err != nil { + return nil, trace.Wrap(err) + } + if header.Kind != KindEvent { + return nil, trace.BadParameter("expected kind %q, got %q", + KindEvent, header.Kind) + } + if header.Version != ResourceVersion { + return nil, trace.BadParameter("expected resource version %q, got %q", + ResourceVersion, header.Version) + } + switch header.Metadata.Name { + case EventTypeServer: + var event ServerEvent + err := unmarshalWithSchema( + getServerEventSchema(), grpcEvent.Data, &event) + if err != nil { + return nil, trace.Wrap(err) + } + return &event, nil + case EventTypeUser: + var event UserEvent + err := unmarshalWithSchema( + getUserEventSchema(), grpcEvent.Data, &event) + if err != nil { + return nil, trace.Wrap(err) + } + return &event, nil + default: + return nil, trace.BadParameter("unknown event type %q", header.Metadata.Name) + } +} + +// FromGRPCEvents converts a series of events from the format used by gRPC server/client +func FromGRPCEvents(grpcEvents reporting.GRPCEvents) ([]Event, error) { + var events []Event + for _, grpcEvent := range grpcEvents.Events { + event, err := FromGRPCEvent(*grpcEvent) + if err != nil { + return nil, trace.Wrap(err) + } + events = append(events, event) + } + return events, nil +} + +// resourceHeader is used when unmarhsaling resources +type resourceHeader struct { + // Kind the the resource kind + Kind string `json:"kind"` + // Version is the resource version + Version string `json:"version"` + // Metadata is the resource metadata + Metadata Metadata `json:"metadata"` +} + +// schemaTemplate is the event resource schema template +const schemaTemplate = `{ + "type": "object", + "additionalProperties": false, + "required": ["kind", "version", "metadata", "spec"], + "properties": { + "kind": {"type": "string"}, + "version": {"type": "string", "default": "v2"}, + "metadata": { + "type": "object", + "additionalProperties": false, + "required": ["name", "created"], + "properties": { + "name": {"type": "string"}, + "created": {"type": "string"} + } + }, + "spec": %v + } +}` + +// getServerEventSchema returns full server event JSON schema +func getServerEventSchema() string { + return fmt.Sprintf(schemaTemplate, serverEventSchema) +} + +// serverEventSchema is the server event spec schema +const serverEventSchema = `{ + "type": "object", + "additionalProperties": false, + "required": ["id", "action", "accountID", "serverID"], + "properties": { + "id": {"type": "string"}, + "action": {"type": "string"}, + "accountID": {"type": "string"}, + "serverID": {"type": "string"} + } +}` + +// getUserEventSchema returns full user event JSON schema +func getUserEventSchema() string { + return fmt.Sprintf(schemaTemplate, userEventSchema) +} + +// userEventSchema is the user event spec schema +const userEventSchema = `{ + "type": "object", + "additionalProperties": false, + "required": ["id", "action", "accountID", "userID"], + "properties": { + "id": {"type": "string"}, + "action": {"type": "string"}, + "accountID": {"type": "string"}, + "userID": {"type": "string"} + } +}` + +// unmarshalWithSchema unmarshals the provided data into the provided object +// using specified JSON schema +func unmarshalWithSchema(objectSchema string, data []byte, object interface{}) error { + schema, err := jsonschema.New([]byte(objectSchema)) + if err != nil { + return trace.Wrap(err) + } + raw := map[string]interface{}{} + if err := json.Unmarshal(data, &raw); err != nil { + return trace.Wrap(err) + } + processed, err := schema.ProcessObject(raw) + if err != nil { + return trace.Wrap(err) + } + bytes, err := json.Marshal(processed) + if err != nil { + return trace.Wrap(err) + } + if err := json.Unmarshal(bytes, object); err != nil { + return trace.Wrap(err) + } + return nil +} + +// marshalWithSchema marshals the provided objects while checking the specified schema +func marshalWithSchema(objectSchema string, object interface{}) ([]byte, error) { + schema, err := jsonschema.New([]byte(objectSchema)) + if err != nil { + return nil, trace.Wrap(err) + } + processed, err := schema.ProcessObject(object) + if err != nil { + return nil, trace.Wrap(err) + } + bytes, err := json.Marshal(processed) + if err != nil { + return nil, trace.Wrap(err) + } + return bytes, nil +} diff --git a/vendor/github.com/gravitational/reporting/types/heartbeat.go b/vendor/github.com/gravitational/reporting/types/heartbeat.go new file mode 100644 index 00000000000..a903e2db1b2 --- /dev/null +++ b/vendor/github.com/gravitational/reporting/types/heartbeat.go @@ -0,0 +1,134 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/gravitational/trace" +) + +// Heartbeat represents a heartbeat that is sent from control plane to teleport +type Heartbeat struct { + // Kind is resource kind, for heartbeat it is "heartbeat" + Kind string `json:"kind"` + // Version is the heartbeat resource version + Version string `json:"version"` + // Metadata is the heartbeat metadata + Metadata Metadata `json:"metadata"` + // Spec is the heartbeat spec + Spec HeartbeatSpec `json:"spec"` +} + +// HeartbeatSpec is the heartbeat resource spec +type HeartbeatSpec struct { + // Notifications is a list of notifications sent with the heartbeat + Notifications []Notification `json:"notifications,omitempty"` +} + +// Notification represents a user notification message +type Notification struct { + // Type is the notification type + Type string `json:"type"` + // Severity is the notification severity: info, warning or error + Severity string `json:"severity"` + // Text is the notification plain text + Text string `json:"text"` + // HTML is the notification HTML + HTML string `json:"html"` +} + +// NewHeartbeat returns a new heartbeat +func NewHeartbeat(notifications ...Notification) *Heartbeat { + return &Heartbeat{ + Kind: KindHeartbeat, + Version: ResourceVersion, + Metadata: Metadata{ + Name: "heartbeat", + Created: time.Now().UTC(), + }, + Spec: HeartbeatSpec{ + Notifications: notifications, + }, + } +} + +// GetName returns the resource name +func (h *Heartbeat) GetName() string { return h.Metadata.Name } + +// GetMetadata returns the heartbeat metadata +func (h *Heartbeat) GetMetadata() Metadata { return h.Metadata } + +// UnmarshalHeartbeat unmarshals heartbeat with schema validation +func UnmarshalHeartbeat(bytes []byte) (*Heartbeat, error) { + var header resourceHeader + if err := json.Unmarshal(bytes, &header); err != nil { + return nil, trace.Wrap(err) + } + if header.Kind != KindHeartbeat { + return nil, trace.BadParameter("expected kind %q, got %q", + KindHeartbeat, header.Kind) + } + if header.Version != ResourceVersion { + return nil, trace.BadParameter("expected resource version %q, got %q", + ResourceVersion, header.Version) + } + var heartbeat Heartbeat + err := unmarshalWithSchema( + getHeartbeatSchema(), bytes, &heartbeat) + if err != nil { + return nil, trace.Wrap(err) + } + return &heartbeat, nil +} + +// MarshalHeartbeat marshals heartbeat with schema validation +func MarshalHeartbeat(h Heartbeat) ([]byte, error) { + bytes, err := marshalWithSchema(getHeartbeatSchema(), h) + if err != nil { + return nil, trace.Wrap(err) + } + return bytes, nil +} + +// heartbeatSchema is the heartbeat spec schema +const heartbeatSchema = `{ + "type": "object", + "properties": { + "notifications": { + "type": "array", + "items": { + "type": "object", + "required": ["type", "severity", "text", "html"], + "additionalProperties": false, + "properties": { + "type": {"type": "string"}, + "severity": {"type": "string"}, + "text": {"type": "string"}, + "html": {"type": "string"} + } + } + } + } +}` + +// getHeartbeatSchema returns the full heartbeat resource schema +func getHeartbeatSchema() string { + return fmt.Sprintf(schemaTemplate, heartbeatSchema) +} diff --git a/vendor/github.com/gravitational/reporting/types/types_test.go b/vendor/github.com/gravitational/reporting/types/types_test.go new file mode 100644 index 00000000000..fe4b4d2c3a9 --- /dev/null +++ b/vendor/github.com/gravitational/reporting/types/types_test.go @@ -0,0 +1,61 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "testing" + + check "gopkg.in/check.v1" +) + +func TestTypes(t *testing.T) { check.TestingT(t) } + +type TypesSuite struct{} + +var _ = check.Suite(&TypesSuite{}) + +func (s *TypesSuite) TestHeartbeat(c *check.C) { + h := NewHeartbeat( + Notification{ + Type: NotificationUsage, + Severity: SeverityWarning, + Text: "Usage limit exceeded", + HTML: "
Usage limit exceeded
", + }, + Notification{ + Type: NotificationTerms, + Severity: SeverityError, + Text: "Terms of service violation", + HTML: "
Terms of service violation
", + }) + bytes, err := MarshalHeartbeat(*h) + c.Assert(err, check.IsNil) + unmarshaled, err := UnmarshalHeartbeat(bytes) + c.Assert(err, check.IsNil) + c.Assert(len(unmarshaled.Spec.Notifications), check.Equals, 2) + c.Assert(unmarshaled, check.DeepEquals, h) +} + +func (s *TypesSuite) TestEmptyHeartbeat(c *check.C) { + h := NewHeartbeat() + bytes, err := MarshalHeartbeat(*h) + c.Assert(err, check.IsNil) + unmarshaled, err := UnmarshalHeartbeat(bytes) + c.Assert(err, check.IsNil) + c.Assert(len(unmarshaled.Spec.Notifications), check.Equals, 0) + c.Assert(unmarshaled, check.DeepEquals, h) +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/context.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/context.go new file mode 100644 index 00000000000..ad42535662a --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/context.go @@ -0,0 +1,139 @@ +package runtime + +import ( + "fmt" + "net" + "net/http" + "strconv" + "strings" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" +) + +const metadataHeaderPrefix = "Grpc-Metadata-" +const metadataTrailerPrefix = "Grpc-Trailer-" +const metadataGrpcTimeout = "Grpc-Timeout" + +const xForwardedFor = "X-Forwarded-For" +const xForwardedHost = "X-Forwarded-Host" + +var ( + // DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a Grpc-Timeout inbound + // header isn't present. If the value is 0 the sent `context` will not have a timeout. + DefaultContextTimeout = 0 * time.Second +) + +/* +AnnotateContext adds context information such as metadata from the request. + +At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For", +except that the forwarded destination is not another HTTP service but rather +a gRPC service. +*/ +func AnnotateContext(ctx context.Context, req *http.Request) (context.Context, error) { + var pairs []string + timeout := DefaultContextTimeout + if tm := req.Header.Get(metadataGrpcTimeout); tm != "" { + var err error + timeout, err = timeoutDecode(tm) + if err != nil { + return nil, grpc.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm) + } + } + + for key, vals := range req.Header { + for _, val := range vals { + if key == "Authorization" { + pairs = append(pairs, "authorization", val) + continue + } + if strings.HasPrefix(key, metadataHeaderPrefix) { + pairs = append(pairs, key[len(metadataHeaderPrefix):], val) + } + } + } + if host := req.Header.Get(xForwardedHost); host != "" { + pairs = append(pairs, strings.ToLower(xForwardedHost), host) + } else if req.Host != "" { + pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host) + } + + if addr := req.RemoteAddr; addr != "" { + if remoteIP, _, err := net.SplitHostPort(addr); err == nil { + if fwd := req.Header.Get(xForwardedFor); fwd == "" { + pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP) + } else { + pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP)) + } + } else { + grpclog.Printf("invalid remote addr: %s", addr) + } + } + + if timeout != 0 { + ctx, _ = context.WithTimeout(ctx, timeout) + } + if len(pairs) == 0 { + return ctx, nil + } + return metadata.NewContext(ctx, metadata.Pairs(pairs...)), nil +} + +// ServerMetadata consists of metadata sent from gRPC server. +type ServerMetadata struct { + HeaderMD metadata.MD + TrailerMD metadata.MD +} + +type serverMetadataKey struct{} + +// NewServerMetadataContext creates a new context with ServerMetadata +func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context { + return context.WithValue(ctx, serverMetadataKey{}, md) +} + +// ServerMetadataFromContext returns the ServerMetadata in ctx +func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) { + md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata) + return +} + +func timeoutDecode(s string) (time.Duration, error) { + size := len(s) + if size < 2 { + return 0, fmt.Errorf("timeout string is too short: %q", s) + } + d, ok := timeoutUnitToDuration(s[size-1]) + if !ok { + return 0, fmt.Errorf("timeout unit is not recognized: %q", s) + } + t, err := strconv.ParseInt(s[:size-1], 10, 64) + if err != nil { + return 0, err + } + return d * time.Duration(t), nil +} + +func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) { + switch u { + case 'H': + return time.Hour, true + case 'M': + return time.Minute, true + case 'S': + return time.Second, true + case 'm': + return time.Millisecond, true + case 'u': + return time.Microsecond, true + case 'n': + return time.Nanosecond, true + default: + } + return +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/context_test.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/context_test.go new file mode 100644 index 00000000000..abc1873fad9 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/context_test.go @@ -0,0 +1,169 @@ +package runtime_test + +import ( + "net/http" + "reflect" + "testing" + "time" + + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "golang.org/x/net/context" + "google.golang.org/grpc/metadata" +) + +const ( + emptyForwardMetaCount = 1 +) + +func TestAnnotateContext_WorksWithEmpty(t *testing.T) { + ctx := context.Background() + + request, err := http.NewRequest("GET", "http://www.example.com", nil) + if err != nil { + t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", "GET", "http://www.example.com", err) + } + request.Header.Add("Some-Irrelevant-Header", "some value") + annotated, err := runtime.AnnotateContext(ctx, request) + if err != nil { + t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err) + return + } + md, ok := metadata.FromContext(annotated) + if !ok || len(md) != emptyForwardMetaCount { + t.Errorf("Expected %d metadata items in context; got %v", emptyForwardMetaCount, md) + } +} + +func TestAnnotateContext_ForwardsGrpcMetadata(t *testing.T) { + ctx := context.Background() + request, err := http.NewRequest("GET", "http://www.example.com", nil) + if err != nil { + t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", "GET", "http://www.example.com", err) + } + request.Header.Add("Some-Irrelevant-Header", "some value") + request.Header.Add("Grpc-Metadata-FooBar", "Value1") + request.Header.Add("Grpc-Metadata-Foo-BAZ", "Value2") + request.Header.Add("Grpc-Metadata-foo-bAz", "Value3") + request.Header.Add("Authorization", "Token 1234567890") + annotated, err := runtime.AnnotateContext(ctx, request) + if err != nil { + t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err) + return + } + md, ok := metadata.FromContext(annotated) + if got, want := len(md), emptyForwardMetaCount+3; !ok || got != want { + t.Errorf("Expected %d metadata items in context; got %d", got, want) + } + if got, want := md["foobar"], []string{"Value1"}; !reflect.DeepEqual(got, want) { + t.Errorf(`md["foobar"] = %q; want %q`, got, want) + } + if got, want := md["foo-baz"], []string{"Value2", "Value3"}; !reflect.DeepEqual(got, want) { + t.Errorf(`md["foo-baz"] = %q want %q`, got, want) + } + if got, want := md["authorization"], []string{"Token 1234567890"}; !reflect.DeepEqual(got, want) { + t.Errorf(`md["authorization"] = %q want %q`, got, want) + } +} + +func TestAnnotateContext_XForwardedFor(t *testing.T) { + ctx := context.Background() + request, err := http.NewRequest("GET", "http://bar.foo.example.com", nil) + if err != nil { + t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", "GET", "http://bar.foo.example.com", err) + } + request.Header.Add("X-Forwarded-For", "192.0.2.100") // client + request.RemoteAddr = "192.0.2.200:12345" // proxy + + annotated, err := runtime.AnnotateContext(ctx, request) + if err != nil { + t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err) + return + } + md, ok := metadata.FromContext(annotated) + if !ok || len(md) != emptyForwardMetaCount+1 { + t.Errorf("Expected %d metadata items in context; got %v", emptyForwardMetaCount+1, md) + } + if got, want := md["x-forwarded-host"], []string{"bar.foo.example.com"}; !reflect.DeepEqual(got, want) { + t.Errorf(`md["host"] = %v; want %v`, got, want) + } + // Note: it must be in order client, proxy1, proxy2 + if got, want := md["x-forwarded-for"], []string{"192.0.2.100, 192.0.2.200"}; !reflect.DeepEqual(got, want) { + t.Errorf(`md["x-forwarded-for"] = %v want %v`, got, want) + } +} + +func TestAnnotateContext_SupportsTimeouts(t *testing.T) { + ctx := context.Background() + request, err := http.NewRequest("GET", "http://example.com", nil) + if err != nil { + t.Fatalf(`http.NewRequest("GET", "http://example.com", nil failed with %v; want success`, err) + } + annotated, err := runtime.AnnotateContext(ctx, request) + if err != nil { + t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err) + return + } + if _, ok := annotated.Deadline(); ok { + // no deadline by default + t.Errorf("annotated.Deadline() = _, true; want _, false") + } + + const acceptableError = 50 * time.Millisecond + runtime.DefaultContextTimeout = 10 * time.Second + annotated, err = runtime.AnnotateContext(ctx, request) + if err != nil { + t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err) + return + } + deadline, ok := annotated.Deadline() + if !ok { + t.Errorf("annotated.Deadline() = _, false; want _, true") + } + if got, want := deadline.Sub(time.Now()), runtime.DefaultContextTimeout; got-want > acceptableError || got-want < -acceptableError { + t.Errorf("deadline.Sub(time.Now()) = %v; want %v; with error %v", got, want, acceptableError) + } + + for _, spec := range []struct { + timeout string + want time.Duration + }{ + { + timeout: "17H", + want: 17 * time.Hour, + }, + { + timeout: "19M", + want: 19 * time.Minute, + }, + { + timeout: "23S", + want: 23 * time.Second, + }, + { + timeout: "1009m", + want: 1009 * time.Millisecond, + }, + { + timeout: "1000003u", + want: 1000003 * time.Microsecond, + }, + { + timeout: "100000007n", + want: 100000007 * time.Nanosecond, + }, + } { + request.Header.Set("Grpc-Timeout", spec.timeout) + annotated, err = runtime.AnnotateContext(ctx, request) + if err != nil { + t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err) + return + } + deadline, ok := annotated.Deadline() + if !ok { + t.Errorf("annotated.Deadline() = _, false; want _, true; timeout = %q", spec.timeout) + } + if got, want := deadline.Sub(time.Now()), spec.want; got-want > acceptableError || got-want < -acceptableError { + t.Errorf("deadline.Sub(time.Now()) = %v; want %v; with error %v; timeout= %q", got, want, acceptableError, spec.timeout) + } + } +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/convert.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/convert.go new file mode 100644 index 00000000000..1af5cc4ebdd --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/convert.go @@ -0,0 +1,58 @@ +package runtime + +import ( + "strconv" +) + +// String just returns the given string. +// It is just for compatibility to other types. +func String(val string) (string, error) { + return val, nil +} + +// Bool converts the given string representation of a boolean value into bool. +func Bool(val string) (bool, error) { + return strconv.ParseBool(val) +} + +// Float64 converts the given string representation into representation of a floating point number into float64. +func Float64(val string) (float64, error) { + return strconv.ParseFloat(val, 64) +} + +// Float32 converts the given string representation of a floating point number into float32. +func Float32(val string) (float32, error) { + f, err := strconv.ParseFloat(val, 32) + if err != nil { + return 0, err + } + return float32(f), nil +} + +// Int64 converts the given string representation of an integer into int64. +func Int64(val string) (int64, error) { + return strconv.ParseInt(val, 0, 64) +} + +// Int32 converts the given string representation of an integer into int32. +func Int32(val string) (int32, error) { + i, err := strconv.ParseInt(val, 0, 32) + if err != nil { + return 0, err + } + return int32(i), nil +} + +// Uint64 converts the given string representation of an integer into uint64. +func Uint64(val string) (uint64, error) { + return strconv.ParseUint(val, 0, 64) +} + +// Uint32 converts the given string representation of an integer into uint32. +func Uint32(val string) (uint32, error) { + i, err := strconv.ParseUint(val, 0, 32) + if err != nil { + return 0, err + } + return uint32(i), nil +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/doc.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/doc.go new file mode 100644 index 00000000000..b6e5ddf7a9f --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/doc.go @@ -0,0 +1,5 @@ +/* +Package runtime contains runtime helper functions used by +servers which protoc-gen-grpc-gateway generates. +*/ +package runtime diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/errors.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/errors.go new file mode 100644 index 00000000000..7d7a9b22415 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/errors.go @@ -0,0 +1,121 @@ +package runtime + +import ( + "io" + "net/http" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" +) + +// HTTPStatusFromCode converts a gRPC error code into the corresponding HTTP response status. +func HTTPStatusFromCode(code codes.Code) int { + switch code { + case codes.OK: + return http.StatusOK + case codes.Canceled: + return http.StatusRequestTimeout + case codes.Unknown: + return http.StatusInternalServerError + case codes.InvalidArgument: + return http.StatusBadRequest + case codes.DeadlineExceeded: + return http.StatusRequestTimeout + case codes.NotFound: + return http.StatusNotFound + case codes.AlreadyExists: + return http.StatusConflict + case codes.PermissionDenied: + return http.StatusForbidden + case codes.Unauthenticated: + return http.StatusUnauthorized + case codes.ResourceExhausted: + return http.StatusForbidden + case codes.FailedPrecondition: + return http.StatusPreconditionFailed + case codes.Aborted: + return http.StatusConflict + case codes.OutOfRange: + return http.StatusBadRequest + case codes.Unimplemented: + return http.StatusNotImplemented + case codes.Internal: + return http.StatusInternalServerError + case codes.Unavailable: + return http.StatusServiceUnavailable + case codes.DataLoss: + return http.StatusInternalServerError + } + + grpclog.Printf("Unknown gRPC error code: %v", code) + return http.StatusInternalServerError +} + +var ( + // HTTPError replies to the request with the error. + // You can set a custom function to this variable to customize error format. + HTTPError = DefaultHTTPError + // OtherErrorHandler handles the following error used by the gateway: StatusMethodNotAllowed StatusNotFound and StatusBadRequest + OtherErrorHandler = DefaultOtherErrorHandler +) + +type errorBody struct { + Error string `json:"error"` + Code int `json:"code"` +} + +//Make this also conform to proto.Message for builtin JSONPb Marshaler +func (e *errorBody) Reset() { *e = errorBody{} } +func (e *errorBody) String() string { return proto.CompactTextString(e) } +func (*errorBody) ProtoMessage() {} + +// DefaultHTTPError is the default implementation of HTTPError. +// If "err" is an error from gRPC system, the function replies with the status code mapped by HTTPStatusFromCode. +// If otherwise, it replies with http.StatusInternalServerError. +// +// The response body returned by this function is a JSON object, +// which contains a member whose key is "error" and whose value is err.Error(). +func DefaultHTTPError(ctx context.Context, marshaler Marshaler, w http.ResponseWriter, _ *http.Request, err error) { + const fallback = `{"error": "failed to marshal error message"}` + + w.Header().Del("Trailer") + w.Header().Set("Content-Type", marshaler.ContentType()) + body := &errorBody{ + Error: grpc.ErrorDesc(err), + Code: int(grpc.Code(err)), + } + + buf, merr := marshaler.Marshal(body) + if merr != nil { + grpclog.Printf("Failed to marshal error message %q: %v", body, merr) + w.WriteHeader(http.StatusInternalServerError) + if _, err := io.WriteString(w, fallback); err != nil { + grpclog.Printf("Failed to write response: %v", err) + } + return + } + + md, ok := ServerMetadataFromContext(ctx) + if !ok { + grpclog.Printf("Failed to extract ServerMetadata from context") + } + + handleForwardResponseServerMetadata(w, md) + handleForwardResponseTrailerHeader(w, md) + st := HTTPStatusFromCode(grpc.Code(err)) + w.WriteHeader(st) + if _, err := w.Write(buf); err != nil { + grpclog.Printf("Failed to write response: %v", err) + } + + handleForwardResponseTrailer(w, md) +} + +// DefaultOtherErrorHandler is the default implementation of OtherErrorHandler. +// It simply writes a string representation of the given error into "w". +func DefaultOtherErrorHandler(w http.ResponseWriter, _ *http.Request, msg string, code int) { + http.Error(w, msg, code) +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/errors_test.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/errors_test.go new file mode 100644 index 00000000000..2bdfca637c1 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/errors_test.go @@ -0,0 +1,56 @@ +package runtime_test + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +func TestDefaultHTTPError(t *testing.T) { + ctx := context.Background() + + for _, spec := range []struct { + err error + status int + msg string + }{ + { + err: fmt.Errorf("example error"), + status: http.StatusInternalServerError, + msg: "example error", + }, + { + err: grpc.Errorf(codes.NotFound, "no such resource"), + status: http.StatusNotFound, + msg: "no such resource", + }, + } { + w := httptest.NewRecorder() + req, _ := http.NewRequest("", "", nil) // Pass in an empty request to match the signature + runtime.DefaultHTTPError(ctx, &runtime.JSONBuiltin{}, w, req, spec.err) + + if got, want := w.Header().Get("Content-Type"), "application/json"; got != want { + t.Errorf(`w.Header().Get("Content-Type") = %q; want %q; on spec.err=%v`, got, want, spec.err) + } + if got, want := w.Code, spec.status; got != want { + t.Errorf("w.Code = %d; want %d", got, want) + } + + body := make(map[string]interface{}) + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Errorf("json.Unmarshal(%q, &body) failed with %v; want success", w.Body.Bytes(), err) + continue + } + if got, want := body["error"].(string), spec.msg; !strings.Contains(got, want) { + t.Errorf(`body["error"] = %q; want %q; on spec.err=%v`, got, want, spec.err) + } + } +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/handler.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/handler.go new file mode 100644 index 00000000000..bafa4285f91 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/handler.go @@ -0,0 +1,164 @@ +package runtime + +import ( + "fmt" + "io" + "net/http" + "net/textproto" + + "github.com/golang/protobuf/proto" + "github.com/grpc-ecosystem/grpc-gateway/runtime/internal" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/grpclog" +) + +// ForwardResponseStream forwards the stream from gRPC server to REST client. +func ForwardResponseStream(ctx context.Context, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) { + f, ok := w.(http.Flusher) + if !ok { + grpclog.Printf("Flush not supported in %T", w) + http.Error(w, "unexpected type of web server", http.StatusInternalServerError) + return + } + + md, ok := ServerMetadataFromContext(ctx) + if !ok { + grpclog.Printf("Failed to extract ServerMetadata from context") + http.Error(w, "unexpected error", http.StatusInternalServerError) + return + } + handleForwardResponseServerMetadata(w, md) + + w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("Content-Type", marshaler.ContentType()) + if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + f.Flush() + for { + resp, err := recv() + if err == io.EOF { + return + } + if err != nil { + handleForwardResponseStreamError(marshaler, w, err) + return + } + if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { + handleForwardResponseStreamError(marshaler, w, err) + return + } + + buf, err := marshaler.Marshal(streamChunk(resp, nil)) + if err != nil { + grpclog.Printf("Failed to marshal response chunk: %v", err) + return + } + if _, err = fmt.Fprintf(w, "%s\n", buf); err != nil { + grpclog.Printf("Failed to send response chunk: %v", err) + return + } + f.Flush() + } +} + +func handleForwardResponseServerMetadata(w http.ResponseWriter, md ServerMetadata) { + for k, vs := range md.HeaderMD { + hKey := fmt.Sprintf("%s%s", metadataHeaderPrefix, k) + for i := range vs { + w.Header().Add(hKey, vs[i]) + } + } +} + +func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) { + for k := range md.TrailerMD { + tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", metadataTrailerPrefix, k)) + w.Header().Add("Trailer", tKey) + } +} + +func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) { + for k, vs := range md.TrailerMD { + tKey := fmt.Sprintf("%s%s", metadataTrailerPrefix, k) + for i := range vs { + w.Header().Add(tKey, vs[i]) + } + } +} + +// ForwardResponseMessage forwards the message "resp" from gRPC server to REST client. +func ForwardResponseMessage(ctx context.Context, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) { + md, ok := ServerMetadataFromContext(ctx) + if !ok { + grpclog.Printf("Failed to extract ServerMetadata from context") + } + + handleForwardResponseServerMetadata(w, md) + handleForwardResponseTrailerHeader(w, md) + w.Header().Set("Content-Type", marshaler.ContentType()) + if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { + HTTPError(ctx, marshaler, w, req, err) + return + } + + buf, err := marshaler.Marshal(resp) + if err != nil { + grpclog.Printf("Marshal error: %v", err) + HTTPError(ctx, marshaler, w, req, err) + return + } + + if _, err = w.Write(buf); err != nil { + grpclog.Printf("Failed to write response: %v", err) + } + + handleForwardResponseTrailer(w, md) +} + +func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error { + if len(opts) == 0 { + return nil + } + for _, opt := range opts { + if err := opt(ctx, w, resp); err != nil { + grpclog.Printf("Error handling ForwardResponseOptions: %v", err) + return err + } + } + return nil +} + +func handleForwardResponseStreamError(marshaler Marshaler, w http.ResponseWriter, err error) { + buf, merr := marshaler.Marshal(streamChunk(nil, err)) + if merr != nil { + grpclog.Printf("Failed to marshal an error: %v", merr) + return + } + if _, werr := fmt.Fprintf(w, "%s\n", buf); werr != nil { + grpclog.Printf("Failed to notify error to client: %v", werr) + return + } +} + +func streamChunk(result proto.Message, err error) map[string]proto.Message { + if err != nil { + grpcCode := grpc.Code(err) + httpCode := HTTPStatusFromCode(grpcCode) + return map[string]proto.Message{ + "error": &internal.StreamError{ + GrpcCode: int32(grpcCode), + HttpCode: int32(httpCode), + Message: err.Error(), + HttpStatus: http.StatusText(httpCode), + }, + } + } + if result == nil { + return streamChunk(nil, fmt.Errorf("empty response")) + } + return map[string]proto.Message{"result": result} +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/internal/stream_chunk.pb.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/internal/stream_chunk.pb.go new file mode 100644 index 00000000000..524e0d3c34c --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/internal/stream_chunk.pb.go @@ -0,0 +1,65 @@ +// Code generated by protoc-gen-go. +// source: runtime/internal/stream_chunk.proto +// DO NOT EDIT! + +/* +Package internal is a generated protocol buffer package. + +It is generated from these files: + runtime/internal/stream_chunk.proto + +It has these top-level messages: + StreamError +*/ +package internal + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +// StreamError is a response type which is returned when +// streaming rpc returns an error. +type StreamError struct { + GrpcCode int32 `protobuf:"varint,1,opt,name=grpc_code,json=grpcCode" json:"grpc_code,omitempty"` + HttpCode int32 `protobuf:"varint,2,opt,name=http_code,json=httpCode" json:"http_code,omitempty"` + Message string `protobuf:"bytes,3,opt,name=message" json:"message,omitempty"` + HttpStatus string `protobuf:"bytes,4,opt,name=http_status,json=httpStatus" json:"http_status,omitempty"` +} + +func (m *StreamError) Reset() { *m = StreamError{} } +func (m *StreamError) String() string { return proto.CompactTextString(m) } +func (*StreamError) ProtoMessage() {} +func (*StreamError) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func init() { + proto.RegisterType((*StreamError)(nil), "grpc.gateway.runtime.StreamError") +} + +func init() { proto.RegisterFile("runtime/internal/stream_chunk.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 180 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0x52, 0x2e, 0x2a, 0xcd, 0x2b, + 0xc9, 0xcc, 0x4d, 0xd5, 0xcf, 0xcc, 0x2b, 0x49, 0x2d, 0xca, 0x4b, 0xcc, 0xd1, 0x2f, 0x2e, 0x29, + 0x4a, 0x4d, 0xcc, 0x8d, 0x4f, 0xce, 0x28, 0xcd, 0xcb, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, + 0x12, 0x49, 0x2f, 0x2a, 0x48, 0xd6, 0x4b, 0x4f, 0x2c, 0x49, 0x2d, 0x4f, 0xac, 0xd4, 0x83, 0xea, + 0x50, 0x6a, 0x62, 0xe4, 0xe2, 0x0e, 0x06, 0x2b, 0x76, 0x2d, 0x2a, 0xca, 0x2f, 0x12, 0x92, 0xe6, + 0xe2, 0x04, 0xa9, 0x8b, 0x4f, 0xce, 0x4f, 0x49, 0x95, 0x60, 0x54, 0x60, 0xd4, 0x60, 0x0d, 0xe2, + 0x00, 0x09, 0x38, 0x03, 0xf9, 0x20, 0xc9, 0x8c, 0x92, 0x92, 0x02, 0x88, 0x24, 0x13, 0x44, 0x12, + 0x24, 0x00, 0x96, 0x94, 0xe0, 0x62, 0xcf, 0x4d, 0x2d, 0x2e, 0x4e, 0x4c, 0x4f, 0x95, 0x60, 0x06, + 0x4a, 0x71, 0x06, 0xc1, 0xb8, 0x42, 0xf2, 0x5c, 0xdc, 0x60, 0x6d, 0xc5, 0x25, 0x89, 0x25, 0xa5, + 0xc5, 0x12, 0x2c, 0x60, 0x59, 0x2e, 0x90, 0x50, 0x30, 0x58, 0xc4, 0x89, 0x2b, 0x8a, 0x03, 0xe6, + 0xf2, 0x24, 0x36, 0xb0, 0x6b, 0x8d, 0x01, 0x01, 0x00, 0x00, 0xff, 0xff, 0xa9, 0x07, 0x92, 0xb6, + 0xd4, 0x00, 0x00, 0x00, +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/internal/stream_chunk.proto b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/internal/stream_chunk.proto new file mode 100644 index 00000000000..f7fba56c35b --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/internal/stream_chunk.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; +package grpc.gateway.runtime; +option go_package = "internal"; + +// StreamError is a response type which is returned when +// streaming rpc returns an error. +message StreamError { + int32 grpc_code = 1; + int32 http_code = 2; + string message = 3; + string http_status = 4; +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_json.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_json.go new file mode 100644 index 00000000000..0acd2ca29ef --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_json.go @@ -0,0 +1,37 @@ +package runtime + +import ( + "encoding/json" + "io" +) + +// JSONBuiltin is a Marshaler which marshals/unmarshals into/from JSON +// with the standard "encoding/json" package of Golang. +// Although it is generally faster for simple proto messages than JSONPb, +// it does not support advanced features of protobuf, e.g. map, oneof, .... +type JSONBuiltin struct{} + +// ContentType always Returns "application/json". +func (*JSONBuiltin) ContentType() string { + return "application/json" +} + +// Marshal marshals "v" into JSON +func (j *JSONBuiltin) Marshal(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +// Unmarshal unmarshals JSON data into "v". +func (j *JSONBuiltin) Unmarshal(data []byte, v interface{}) error { + return json.Unmarshal(data, v) +} + +// NewDecoder returns a Decoder which reads JSON stream from "r". +func (j *JSONBuiltin) NewDecoder(r io.Reader) Decoder { + return json.NewDecoder(r) +} + +// NewEncoder returns an Encoder which writes JSON stream into "w". +func (j *JSONBuiltin) NewEncoder(w io.Writer) Encoder { + return json.NewEncoder(w) +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_json_test.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_json_test.go new file mode 100644 index 00000000000..e6efa291072 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_json_test.go @@ -0,0 +1,245 @@ +package runtime_test + +import ( + "bytes" + "encoding/json" + "reflect" + "strings" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes/empty" + structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/golang/protobuf/ptypes/timestamp" + "github.com/golang/protobuf/ptypes/wrappers" + "github.com/grpc-ecosystem/grpc-gateway/examples/examplepb" + "github.com/grpc-ecosystem/grpc-gateway/runtime" +) + +func TestJSONBuiltinMarshal(t *testing.T) { + var m runtime.JSONBuiltin + msg := examplepb.SimpleMessage{ + Id: "foo", + } + + buf, err := m.Marshal(&msg) + if err != nil { + t.Errorf("m.Marshal(%v) failed with %v; want success", &msg, err) + } + + var got examplepb.SimpleMessage + if err := json.Unmarshal(buf, &got); err != nil { + t.Errorf("json.Unmarshal(%q, &got) failed with %v; want success", buf, err) + } + if want := msg; !reflect.DeepEqual(got, want) { + t.Errorf("got = %v; want %v", &got, &want) + } +} + +func TestJSONBuiltinMarshalField(t *testing.T) { + var m runtime.JSONBuiltin + for _, fixt := range builtinFieldFixtures { + buf, err := m.Marshal(fixt.data) + if err != nil { + t.Errorf("m.Marshal(%v) failed with %v; want success", fixt.data, err) + } + if got, want := string(buf), fixt.json; got != want { + t.Errorf("got = %q; want %q; data = %#v", got, want, fixt.data) + } + } +} + +func TestJSONBuiltinMarshalFieldKnownErrors(t *testing.T) { + var m runtime.JSONBuiltin + for _, fixt := range builtinKnownErrors { + buf, err := m.Marshal(fixt.data) + if err != nil { + t.Errorf("m.Marshal(%v) failed with %v; want success", fixt.data, err) + } + if got, want := string(buf), fixt.json; got == want { + t.Errorf("surprisingly got = %q; as want %q; data = %#v", got, want, fixt.data) + } + } +} + +func TestJSONBuiltinsnmarshal(t *testing.T) { + var ( + m runtime.JSONBuiltin + got examplepb.SimpleMessage + + data = []byte(`{"id": "foo"}`) + ) + if err := m.Unmarshal(data, &got); err != nil { + t.Errorf("m.Unmarshal(%q, &got) failed with %v; want success", data, err) + } + + want := examplepb.SimpleMessage{ + Id: "foo", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got = %v; want = %v", &got, &want) + } +} + +func TestJSONBuiltinUnmarshalField(t *testing.T) { + var m runtime.JSONBuiltin + for _, fixt := range builtinFieldFixtures { + dest := reflect.New(reflect.TypeOf(fixt.data)) + if err := m.Unmarshal([]byte(fixt.json), dest.Interface()); err != nil { + t.Errorf("m.Unmarshal(%q, dest) failed with %v; want success", fixt.json, err) + } + + if got, want := dest.Elem().Interface(), fixt.data; !reflect.DeepEqual(got, want) { + t.Errorf("got = %#v; want = %#v; input = %q", got, want, fixt.json) + } + } +} + +func TestJSONBuiltinUnmarshalFieldKnownErrors(t *testing.T) { + var m runtime.JSONBuiltin + for _, fixt := range builtinKnownErrors { + dest := reflect.New(reflect.TypeOf(fixt.data)) + if err := m.Unmarshal([]byte(fixt.json), dest.Interface()); err == nil { + t.Errorf("m.Unmarshal(%q, dest) succeeded; want ane error", fixt.json) + } + } +} + +func TestJSONBuiltinEncoder(t *testing.T) { + var m runtime.JSONBuiltin + msg := examplepb.SimpleMessage{ + Id: "foo", + } + + var buf bytes.Buffer + enc := m.NewEncoder(&buf) + if err := enc.Encode(&msg); err != nil { + t.Errorf("enc.Encode(%v) failed with %v; want success", &msg, err) + } + + var got examplepb.SimpleMessage + if err := json.Unmarshal(buf.Bytes(), &got); err != nil { + t.Errorf("json.Unmarshal(%q, &got) failed with %v; want success", buf.String(), err) + } + if want := msg; !reflect.DeepEqual(got, want) { + t.Errorf("got = %v; want %v", &got, &want) + } +} + +func TestJSONBuiltinEncoderFields(t *testing.T) { + var m runtime.JSONBuiltin + for _, fixt := range builtinFieldFixtures { + var buf bytes.Buffer + enc := m.NewEncoder(&buf) + if err := enc.Encode(fixt.data); err != nil { + t.Errorf("enc.Encode(%#v) failed with %v; want success", fixt.data, err) + } + + if got, want := buf.String(), fixt.json+"\n"; got != want { + t.Errorf("got = %q; want %q; data = %#v", got, want, fixt.data) + } + } +} + +func TestJSONBuiltinDecoder(t *testing.T) { + var ( + m runtime.JSONBuiltin + got examplepb.SimpleMessage + + data = `{"id": "foo"}` + ) + r := strings.NewReader(data) + dec := m.NewDecoder(r) + if err := dec.Decode(&got); err != nil { + t.Errorf("m.Unmarshal(&got) failed with %v; want success", err) + } + + want := examplepb.SimpleMessage{ + Id: "foo", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got = %v; want = %v", &got, &want) + } +} + +func TestJSONBuiltinDecoderFields(t *testing.T) { + var m runtime.JSONBuiltin + for _, fixt := range builtinFieldFixtures { + r := strings.NewReader(fixt.json) + dec := m.NewDecoder(r) + dest := reflect.New(reflect.TypeOf(fixt.data)) + if err := dec.Decode(dest.Interface()); err != nil { + t.Errorf("dec.Decode(dest) failed with %v; want success; data = %q", err, fixt.json) + } + + if got, want := dest.Elem().Interface(), fixt.data; !reflect.DeepEqual(got, want) { + t.Errorf("got = %v; want = %v; input = %q", got, want, fixt.json) + } + } +} + +var ( + builtinFieldFixtures = []struct { + data interface{} + json string + }{ + {data: "", json: `""`}, + {data: proto.String(""), json: `""`}, + {data: "foo", json: `"foo"`}, + {data: proto.String("foo"), json: `"foo"`}, + {data: int32(-1), json: "-1"}, + {data: proto.Int32(-1), json: "-1"}, + {data: int64(-1), json: "-1"}, + {data: proto.Int64(-1), json: "-1"}, + {data: uint32(123), json: "123"}, + {data: proto.Uint32(123), json: "123"}, + {data: uint64(123), json: "123"}, + {data: proto.Uint64(123), json: "123"}, + {data: float32(-1.5), json: "-1.5"}, + {data: proto.Float32(-1.5), json: "-1.5"}, + {data: float64(-1.5), json: "-1.5"}, + {data: proto.Float64(-1.5), json: "-1.5"}, + {data: true, json: "true"}, + {data: proto.Bool(true), json: "true"}, + {data: (*string)(nil), json: "null"}, + {data: new(empty.Empty), json: "{}"}, + {data: examplepb.NumericEnum_ONE, json: "1"}, + { + data: (*examplepb.NumericEnum)(proto.Int32(int32(examplepb.NumericEnum_ONE))), + json: "1", + }, + } + builtinKnownErrors = []struct { + data interface{} + json string + }{ + {data: examplepb.NumericEnum_ONE, json: "ONE"}, + { + data: (*examplepb.NumericEnum)(proto.Int32(int32(examplepb.NumericEnum_ONE))), + json: "ONE", + }, + { + data: &examplepb.ABitOfEverything_OneofString{OneofString: "abc"}, + json: `"abc"`, + }, + { + data: ×tamp.Timestamp{ + Seconds: 1462875553, + Nanos: 123000000, + }, + json: `"2016-05-10T10:19:13.123Z"`, + }, + { + data: &wrappers.Int32Value{Value: 123}, + json: "123", + }, + { + data: &structpb.Value{ + Kind: &structpb.Value_StringValue{ + StringValue: "abc", + }, + }, + json: `"abc"`, + }, + } +) diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_jsonpb.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_jsonpb.go new file mode 100644 index 00000000000..9a42191119a --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_jsonpb.go @@ -0,0 +1,182 @@ +package runtime + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "reflect" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" +) + +// JSONPb is a Marshaler which marshals/unmarshals into/from JSON +// with the "github.com/golang/protobuf/jsonpb". +// It supports fully functionality of protobuf unlike JSONBuiltin. +type JSONPb jsonpb.Marshaler + +// ContentType always returns "application/json". +func (*JSONPb) ContentType() string { + return "application/json" +} + +// Marshal marshals "v" into JSON +// Currently it can marshal only proto.Message. +// TODO(yugui) Support fields of primitive types in a message. +func (j *JSONPb) Marshal(v interface{}) ([]byte, error) { + if _, ok := v.(proto.Message); !ok { + return j.marshalNonProtoField(v) + } + + var buf bytes.Buffer + if err := j.marshalTo(&buf, v); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (j *JSONPb) marshalTo(w io.Writer, v interface{}) error { + p, ok := v.(proto.Message) + if !ok { + buf, err := j.marshalNonProtoField(v) + if err != nil { + return err + } + _, err = w.Write(buf) + return err + } + return (*jsonpb.Marshaler)(j).Marshal(w, p) +} + +// marshalNonProto marshals a non-message field of a protobuf message. +// This function does not correctly marshals arbitary data structure into JSON, +// but it is only capable of marshaling non-message field values of protobuf, +// i.e. primitive types, enums; pointers to primitives or enums; maps from +// integer/string types to primitives/enums/pointers to messages. +func (j *JSONPb) marshalNonProtoField(v interface{}) ([]byte, error) { + rv := reflect.ValueOf(v) + for rv.Kind() == reflect.Ptr { + if rv.IsNil() { + return []byte("null"), nil + } + rv = rv.Elem() + } + + if rv.Kind() == reflect.Map { + m := make(map[string]*json.RawMessage) + for _, k := range rv.MapKeys() { + buf, err := j.Marshal(rv.MapIndex(k).Interface()) + if err != nil { + return nil, err + } + m[fmt.Sprintf("%v", k.Interface())] = (*json.RawMessage)(&buf) + } + if j.Indent != "" { + return json.MarshalIndent(m, "", j.Indent) + } + return json.Marshal(m) + } + if enum, ok := rv.Interface().(protoEnum); ok && !j.EnumsAsInts { + return json.Marshal(enum.String()) + } + return json.Marshal(rv.Interface()) +} + +// Unmarshal unmarshals JSON "data" into "v" +// Currently it can marshal only proto.Message. +// TODO(yugui) Support fields of primitive types in a message. +func (j *JSONPb) Unmarshal(data []byte, v interface{}) error { + return unmarshalJSONPb(data, v) +} + +// NewDecoder returns a Decoder which reads JSON stream from "r". +func (j *JSONPb) NewDecoder(r io.Reader) Decoder { + d := json.NewDecoder(r) + return DecoderFunc(func(v interface{}) error { return decodeJSONPb(d, v) }) +} + +// NewEncoder returns an Encoder which writes JSON stream into "w". +func (j *JSONPb) NewEncoder(w io.Writer) Encoder { + return EncoderFunc(func(v interface{}) error { return j.marshalTo(w, v) }) +} + +func unmarshalJSONPb(data []byte, v interface{}) error { + d := json.NewDecoder(bytes.NewReader(data)) + return decodeJSONPb(d, v) +} + +func decodeJSONPb(d *json.Decoder, v interface{}) error { + p, ok := v.(proto.Message) + if !ok { + return decodeNonProtoField(d, v) + } + return jsonpb.UnmarshalNext(d, p) +} + +func decodeNonProtoField(d *json.Decoder, v interface{}) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + return fmt.Errorf("%T is not a pointer", v) + } + for rv.Kind() == reflect.Ptr { + if rv.IsNil() { + rv.Set(reflect.New(rv.Type().Elem())) + } + if rv.Type().ConvertibleTo(typeProtoMessage) { + return jsonpb.UnmarshalNext(d, rv.Interface().(proto.Message)) + } + rv = rv.Elem() + } + if rv.Kind() == reflect.Map { + if rv.IsNil() { + rv.Set(reflect.MakeMap(rv.Type())) + } + conv, ok := convFromType[rv.Type().Key().Kind()] + if !ok { + return fmt.Errorf("unsupported type of map field key: %v", rv.Type().Key()) + } + + m := make(map[string]*json.RawMessage) + if err := d.Decode(&m); err != nil { + return err + } + for k, v := range m { + result := conv.Call([]reflect.Value{reflect.ValueOf(k)}) + if err := result[1].Interface(); err != nil { + return err.(error) + } + bk := result[0] + bv := reflect.New(rv.Type().Elem()) + if err := unmarshalJSONPb([]byte(*v), bv.Interface()); err != nil { + return err + } + rv.SetMapIndex(bk, bv.Elem()) + } + return nil + } + if _, ok := rv.Interface().(protoEnum); ok { + var repr interface{} + if err := d.Decode(&repr); err != nil { + return err + } + switch repr.(type) { + case string: + // TODO(yugui) Should use proto.StructProperties? + return fmt.Errorf("unmarshaling of symbolic enum %q not supported: %T", repr, rv.Interface()) + case float64: + rv.Set(reflect.ValueOf(int32(repr.(float64))).Convert(rv.Type())) + return nil + default: + return fmt.Errorf("cannot assign %#v into Go type %T", repr, rv.Interface()) + } + } + return d.Decode(v) +} + +type protoEnum interface { + fmt.Stringer + EnumDescriptor() ([]byte, []int) +} + +var typeProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem() diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_jsonpb_test.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_jsonpb_test.go new file mode 100644 index 00000000000..01e7ce87251 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshal_jsonpb_test.go @@ -0,0 +1,606 @@ +package runtime_test + +import ( + "bytes" + "reflect" + "strings" + "testing" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes/duration" + "github.com/golang/protobuf/ptypes/empty" + structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/golang/protobuf/ptypes/timestamp" + "github.com/golang/protobuf/ptypes/wrappers" + "github.com/grpc-ecosystem/grpc-gateway/examples/examplepb" + "github.com/grpc-ecosystem/grpc-gateway/runtime" +) + +func TestJSONPbMarshal(t *testing.T) { + msg := examplepb.ABitOfEverything{ + Uuid: "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7", + Nested: []*examplepb.ABitOfEverything_Nested{ + { + Name: "foo", + Amount: 12345, + }, + }, + Uint64Value: 0xFFFFFFFFFFFFFFFF, + EnumValue: examplepb.NumericEnum_ONE, + OneofValue: &examplepb.ABitOfEverything_OneofString{ + OneofString: "bar", + }, + MapValue: map[string]examplepb.NumericEnum{ + "a": examplepb.NumericEnum_ONE, + "b": examplepb.NumericEnum_ZERO, + }, + } + + for _, spec := range []struct { + enumsAsInts, emitDefaults bool + indent string + origName bool + verifier func(json string) + }{ + { + verifier: func(json string) { + if strings.ContainsAny(json, " \t\r\n") { + t.Errorf("strings.ContainsAny(%q, %q) = true; want false", json, " \t\r\n") + } + if !strings.Contains(json, "ONE") { + t.Errorf(`strings.Contains(%q, "ONE") = false; want true`, json) + } + if want := "uint64Value"; !strings.Contains(json, want) { + t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want) + } + }, + }, + { + enumsAsInts: true, + verifier: func(json string) { + if strings.Contains(json, "ONE") { + t.Errorf(`strings.Contains(%q, "ONE") = true; want false`, json) + } + }, + }, + { + emitDefaults: true, + verifier: func(json string) { + if want := `"sfixed32Value"`; !strings.Contains(json, want) { + t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want) + } + }, + }, + { + indent: "\t\t", + verifier: func(json string) { + if want := "\t\t\"amount\":"; !strings.Contains(json, want) { + t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want) + } + }, + }, + { + origName: true, + verifier: func(json string) { + if want := "uint64_value"; !strings.Contains(json, want) { + t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want) + } + }, + }, + } { + m := runtime.JSONPb{ + EnumsAsInts: spec.enumsAsInts, + EmitDefaults: spec.emitDefaults, + Indent: spec.indent, + OrigName: spec.origName, + } + buf, err := m.Marshal(&msg) + if err != nil { + t.Errorf("m.Marshal(%v) failed with %v; want success; spec=%v", &msg, err, spec) + } + + var got examplepb.ABitOfEverything + if err := jsonpb.UnmarshalString(string(buf), &got); err != nil { + t.Errorf("jsonpb.UnmarshalString(%q, &got) failed with %v; want success; spec=%v", string(buf), err, spec) + } + if want := msg; !reflect.DeepEqual(got, want) { + t.Errorf("got = %v; want %v; spec=%v", &got, &want, spec) + } + if spec.verifier != nil { + spec.verifier(string(buf)) + } + } +} + +func TestJSONPbMarshalFields(t *testing.T) { + var m runtime.JSONPb + for _, spec := range []struct { + val interface{} + want string + }{} { + buf, err := m.Marshal(spec.val) + if err != nil { + t.Errorf("m.Marshal(%#v) failed with %v; want success", spec.val, err) + } + if got, want := string(buf), spec.want; got != want { + t.Errorf("m.Marshal(%#v) = %q; want %q", spec.val, got, want) + } + } + + m.EnumsAsInts = true + buf, err := m.Marshal(examplepb.NumericEnum_ONE) + if err != nil { + t.Errorf("m.Marshal(%#v) failed with %v; want success", examplepb.NumericEnum_ONE, err) + } + if got, want := string(buf), "1"; got != want { + t.Errorf("m.Marshal(%#v) = %q; want %q", examplepb.NumericEnum_ONE, got, want) + } +} + +func TestJSONPbUnmarshal(t *testing.T) { + var ( + m runtime.JSONPb + got examplepb.ABitOfEverything + ) + for _, data := range []string{ + `{ + "uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7", + "nested": [ + {"name": "foo", "amount": 12345} + ], + "uint64Value": 18446744073709551615, + "enumValue": "ONE", + "oneofString": "bar", + "mapValue": { + "a": 1, + "b": 0 + } + }`, + `{ + "uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7", + "nested": [ + {"name": "foo", "amount": 12345} + ], + "uint64Value": "18446744073709551615", + "enumValue": "ONE", + "oneofString": "bar", + "mapValue": { + "a": 1, + "b": 0 + } + }`, + `{ + "uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7", + "nested": [ + {"name": "foo", "amount": 12345} + ], + "uint64Value": 18446744073709551615, + "enumValue": 1, + "oneofString": "bar", + "mapValue": { + "a": 1, + "b": 0 + } + }`, + } { + if err := m.Unmarshal([]byte(data), &got); err != nil { + t.Errorf("m.Unmarshal(%q, &got) failed with %v; want success", data, err) + } + + want := examplepb.ABitOfEverything{ + Uuid: "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7", + Nested: []*examplepb.ABitOfEverything_Nested{ + { + Name: "foo", + Amount: 12345, + }, + }, + Uint64Value: 0xFFFFFFFFFFFFFFFF, + EnumValue: examplepb.NumericEnum_ONE, + OneofValue: &examplepb.ABitOfEverything_OneofString{ + OneofString: "bar", + }, + MapValue: map[string]examplepb.NumericEnum{ + "a": examplepb.NumericEnum_ONE, + "b": examplepb.NumericEnum_ZERO, + }, + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got = %v; want = %v", &got, &want) + } + } +} + +func TestJSONPbUnmarshalFields(t *testing.T) { + var m runtime.JSONPb + for _, fixt := range fieldFixtures { + if fixt.skipUnmarshal { + continue + } + + dest := reflect.New(reflect.TypeOf(fixt.data)) + if err := m.Unmarshal([]byte(fixt.json), dest.Interface()); err != nil { + t.Errorf("m.Unmarshal(%q, %T) failed with %v; want success", fixt.json, dest.Interface(), err) + } + if got, want := dest.Elem().Interface(), fixt.data; !reflect.DeepEqual(got, want) { + t.Errorf("dest = %#v; want %#v; input = %v", got, want, fixt.json) + } + } +} + +func TestJSONPbEncoder(t *testing.T) { + msg := examplepb.ABitOfEverything{ + Uuid: "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7", + Nested: []*examplepb.ABitOfEverything_Nested{ + { + Name: "foo", + Amount: 12345, + }, + }, + Uint64Value: 0xFFFFFFFFFFFFFFFF, + OneofValue: &examplepb.ABitOfEverything_OneofString{ + OneofString: "bar", + }, + MapValue: map[string]examplepb.NumericEnum{ + "a": examplepb.NumericEnum_ONE, + "b": examplepb.NumericEnum_ZERO, + }, + } + + for _, spec := range []struct { + enumsAsInts, emitDefaults bool + indent string + origName bool + verifier func(json string) + }{ + { + verifier: func(json string) { + if strings.ContainsAny(json, " \t\r\n") { + t.Errorf("strings.ContainsAny(%q, %q) = true; want false", json, " \t\r\n") + } + if strings.Contains(json, "ONE") { + t.Errorf(`strings.Contains(%q, "ONE") = true; want false`, json) + } + if want := "uint64Value"; !strings.Contains(json, want) { + t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want) + } + }, + }, + { + enumsAsInts: true, + verifier: func(json string) { + if strings.Contains(json, "ONE") { + t.Errorf(`strings.Contains(%q, "ONE") = true; want false`, json) + } + }, + }, + { + emitDefaults: true, + verifier: func(json string) { + if want := `"sfixed32Value"`; !strings.Contains(json, want) { + t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want) + } + }, + }, + { + indent: "\t\t", + verifier: func(json string) { + if want := "\t\t\"amount\":"; !strings.Contains(json, want) { + t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want) + } + }, + }, + { + origName: true, + verifier: func(json string) { + if want := "uint64_value"; !strings.Contains(json, want) { + t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want) + } + }, + }, + } { + m := runtime.JSONPb{ + EnumsAsInts: spec.enumsAsInts, + EmitDefaults: spec.emitDefaults, + Indent: spec.indent, + OrigName: spec.origName, + } + + var buf bytes.Buffer + enc := m.NewEncoder(&buf) + if err := enc.Encode(&msg); err != nil { + t.Errorf("enc.Encode(%v) failed with %v; want success; spec=%v", &msg, err, spec) + } + + var got examplepb.ABitOfEverything + if err := jsonpb.UnmarshalString(buf.String(), &got); err != nil { + t.Errorf("jsonpb.UnmarshalString(%q, &got) failed with %v; want success; spec=%v", buf.String(), err, spec) + } + if want := msg; !reflect.DeepEqual(got, want) { + t.Errorf("got = %v; want %v; spec=%v", &got, &want, spec) + } + if spec.verifier != nil { + spec.verifier(buf.String()) + } + } +} + +func TestJSONPbEncoderFields(t *testing.T) { + var m runtime.JSONPb + for _, fixt := range fieldFixtures { + var buf bytes.Buffer + enc := m.NewEncoder(&buf) + if err := enc.Encode(fixt.data); err != nil { + t.Errorf("enc.Encode(%#v) failed with %v; want success", fixt.data, err) + } + if got, want := buf.String(), fixt.json; got != want { + t.Errorf("enc.Encode(%#v) = %q; want %q", fixt.data, got, want) + } + } + + m.EnumsAsInts = true + buf, err := m.Marshal(examplepb.NumericEnum_ONE) + if err != nil { + t.Errorf("m.Marshal(%#v) failed with %v; want success", examplepb.NumericEnum_ONE, err) + } + if got, want := string(buf), "1"; got != want { + t.Errorf("m.Marshal(%#v) = %q; want %q", examplepb.NumericEnum_ONE, got, want) + } +} + +func TestJSONPbDecoder(t *testing.T) { + var ( + m runtime.JSONPb + got examplepb.ABitOfEverything + ) + for _, data := range []string{ + `{ + "uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7", + "nested": [ + {"name": "foo", "amount": 12345} + ], + "uint64Value": 18446744073709551615, + "enumValue": "ONE", + "oneofString": "bar", + "mapValue": { + "a": 1, + "b": 0 + } + }`, + `{ + "uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7", + "nested": [ + {"name": "foo", "amount": 12345} + ], + "uint64Value": "18446744073709551615", + "enumValue": "ONE", + "oneofString": "bar", + "mapValue": { + "a": 1, + "b": 0 + } + }`, + `{ + "uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7", + "nested": [ + {"name": "foo", "amount": 12345} + ], + "uint64Value": 18446744073709551615, + "enumValue": 1, + "oneofString": "bar", + "mapValue": { + "a": 1, + "b": 0 + } + }`, + } { + r := strings.NewReader(data) + dec := m.NewDecoder(r) + if err := dec.Decode(&got); err != nil { + t.Errorf("m.Unmarshal(&got) failed with %v; want success; data=%q", err, data) + } + + want := examplepb.ABitOfEverything{ + Uuid: "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7", + Nested: []*examplepb.ABitOfEverything_Nested{ + { + Name: "foo", + Amount: 12345, + }, + }, + Uint64Value: 0xFFFFFFFFFFFFFFFF, + EnumValue: examplepb.NumericEnum_ONE, + OneofValue: &examplepb.ABitOfEverything_OneofString{ + OneofString: "bar", + }, + MapValue: map[string]examplepb.NumericEnum{ + "a": examplepb.NumericEnum_ONE, + "b": examplepb.NumericEnum_ZERO, + }, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got = %v; want = %v; data = %v", &got, &want, data) + } + } +} + +func TestJSONPbDecoderFields(t *testing.T) { + var m runtime.JSONPb + for _, fixt := range fieldFixtures { + if fixt.skipUnmarshal { + continue + } + + dest := reflect.New(reflect.TypeOf(fixt.data)) + dec := m.NewDecoder(strings.NewReader(fixt.json)) + if err := dec.Decode(dest.Interface()); err != nil { + t.Errorf("dec.Decode(%T) failed with %v; want success; input = %q", dest.Interface(), err, fixt.json) + } + if got, want := dest.Elem().Interface(), fixt.data; !reflect.DeepEqual(got, want) { + t.Errorf("dest = %#v; want %#v; input = %v", got, want, fixt.json) + } + } +} + +var ( + fieldFixtures = []struct { + data interface{} + json string + skipUnmarshal bool + }{ + {data: int32(1), json: "1"}, + {data: proto.Int32(1), json: "1"}, + {data: int64(1), json: "1"}, + {data: proto.Int64(1), json: "1"}, + {data: uint32(1), json: "1"}, + {data: proto.Uint32(1), json: "1"}, + {data: uint64(1), json: "1"}, + {data: proto.Uint64(1), json: "1"}, + {data: "abc", json: `"abc"`}, + {data: proto.String("abc"), json: `"abc"`}, + {data: float32(1.5), json: "1.5"}, + {data: proto.Float32(1.5), json: "1.5"}, + {data: float64(1.5), json: "1.5"}, + {data: proto.Float64(1.5), json: "1.5"}, + {data: true, json: "true"}, + {data: false, json: "false"}, + {data: (*string)(nil), json: "null"}, + { + data: examplepb.NumericEnum_ONE, + json: `"ONE"`, + // TODO(yugui) support unmarshaling of symbolic enum + skipUnmarshal: true, + }, + { + data: (*examplepb.NumericEnum)(proto.Int32(int32(examplepb.NumericEnum_ONE))), + json: `"ONE"`, + // TODO(yugui) support unmarshaling of symbolic enum + skipUnmarshal: true, + }, + + { + data: map[string]int32{ + "foo": 1, + }, + json: `{"foo":1}`, + }, + { + data: map[string]*examplepb.SimpleMessage{ + "foo": {Id: "bar"}, + }, + json: `{"foo":{"id":"bar"}}`, + }, + { + data: map[int32]*examplepb.SimpleMessage{ + 1: {Id: "foo"}, + }, + json: `{"1":{"id":"foo"}}`, + }, + { + data: map[bool]*examplepb.SimpleMessage{ + true: {Id: "foo"}, + }, + json: `{"true":{"id":"foo"}}`, + }, + { + data: &duration.Duration{ + Seconds: 123, + Nanos: 456000000, + }, + json: `"123.456s"`, + }, + { + data: ×tamp.Timestamp{ + Seconds: 1462875553, + Nanos: 123000000, + }, + json: `"2016-05-10T10:19:13.123Z"`, + }, + { + data: new(empty.Empty), + json: "{}", + }, + + // TODO(yugui) Enable unmarshaling of the following examples + // once jsonpb supports them. + { + data: &structpb.Value{ + Kind: new(structpb.Value_NullValue), + }, + json: "null", + skipUnmarshal: true, + }, + { + data: &structpb.Value{ + Kind: &structpb.Value_NumberValue{ + NumberValue: 123.4, + }, + }, + json: "123.4", + skipUnmarshal: true, + }, + { + data: &structpb.Value{ + Kind: &structpb.Value_StringValue{ + StringValue: "abc", + }, + }, + json: `"abc"`, + skipUnmarshal: true, + }, + { + data: &structpb.Value{ + Kind: &structpb.Value_BoolValue{ + BoolValue: true, + }, + }, + json: "true", + skipUnmarshal: true, + }, + { + data: &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "foo_bar": { + Kind: &structpb.Value_BoolValue{ + BoolValue: true, + }, + }, + }, + }, + json: `{"foo_bar":true}`, + skipUnmarshal: true, + }, + + { + data: &wrappers.BoolValue{Value: true}, + json: "true", + }, + { + data: &wrappers.DoubleValue{Value: 123.456}, + json: "123.456", + }, + { + data: &wrappers.FloatValue{Value: 123.456}, + json: "123.456", + }, + { + data: &wrappers.Int32Value{Value: -123}, + json: "-123", + }, + { + data: &wrappers.Int64Value{Value: -123}, + json: `"-123"`, + }, + { + data: &wrappers.UInt32Value{Value: 123}, + json: "123", + }, + { + data: &wrappers.UInt64Value{Value: 123}, + json: `"123"`, + }, + // TODO(yugui) Add other well-known types once jsonpb supports them + } +) diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshaler.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshaler.go new file mode 100644 index 00000000000..6d434f13cb4 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshaler.go @@ -0,0 +1,42 @@ +package runtime + +import ( + "io" +) + +// Marshaler defines a conversion between byte sequence and gRPC payloads / fields. +type Marshaler interface { + // Marshal marshals "v" into byte sequence. + Marshal(v interface{}) ([]byte, error) + // Unmarshal unmarshals "data" into "v". + // "v" must be a pointer value. + Unmarshal(data []byte, v interface{}) error + // NewDecoder returns a Decoder which reads byte sequence from "r". + NewDecoder(r io.Reader) Decoder + // NewEncoder returns an Encoder which writes bytes sequence into "w". + NewEncoder(w io.Writer) Encoder + // ContentType returns the Content-Type which this marshaler is responsible for. + ContentType() string +} + +// Decoder decodes a byte sequence +type Decoder interface { + Decode(v interface{}) error +} + +// Encoder encodes gRPC payloads / fields into byte sequence. +type Encoder interface { + Encode(v interface{}) error +} + +// DecoderFunc adapts an decoder function into Decoder. +type DecoderFunc func(v interface{}) error + +// Decode delegates invocations to the underlying function itself. +func (f DecoderFunc) Decode(v interface{}) error { return f(v) } + +// EncoderFunc adapts an encoder function into Encoder +type EncoderFunc func(v interface{}) error + +// Encode delegates invocations to the underlying function itself. +func (f EncoderFunc) Encode(v interface{}) error { return f(v) } diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshaler_registry.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshaler_registry.go new file mode 100644 index 00000000000..928f0733214 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshaler_registry.go @@ -0,0 +1,91 @@ +package runtime + +import ( + "errors" + "net/http" +) + +// MIMEWildcard is the fallback MIME type used for requests which do not match +// a registered MIME type. +const MIMEWildcard = "*" + +var ( + acceptHeader = http.CanonicalHeaderKey("Accept") + contentTypeHeader = http.CanonicalHeaderKey("Content-Type") + + defaultMarshaler = &JSONPb{OrigName: true} +) + +// MarshalerForRequest returns the inbound/outbound marshalers for this request. +// It checks the registry on the ServeMux for the MIME type set by the Content-Type header. +// If it isn't set (or the request Content-Type is empty), checks for "*". +// If there are multiple Content-Type headers set, choose the first one that it can +// exactly match in the registry. +// Otherwise, it follows the above logic for "*"/InboundMarshaler/OutboundMarshaler. +func MarshalerForRequest(mux *ServeMux, r *http.Request) (inbound Marshaler, outbound Marshaler) { + for _, acceptVal := range r.Header[acceptHeader] { + if m, ok := mux.marshalers.mimeMap[acceptVal]; ok { + outbound = m + break + } + } + + for _, contentTypeVal := range r.Header[contentTypeHeader] { + if m, ok := mux.marshalers.mimeMap[contentTypeVal]; ok { + inbound = m + break + } + } + + if inbound == nil { + inbound = mux.marshalers.mimeMap[MIMEWildcard] + } + if outbound == nil { + outbound = inbound + } + + return inbound, outbound +} + +// marshalerRegistry is a mapping from MIME types to Marshalers. +type marshalerRegistry struct { + mimeMap map[string]Marshaler +} + +// add adds a marshaler for a case-sensitive MIME type string ("*" to match any +// MIME type). +func (m marshalerRegistry) add(mime string, marshaler Marshaler) error { + if len(mime) == 0 { + return errors.New("empty MIME type") + } + + m.mimeMap[mime] = marshaler + + return nil +} + +// makeMarshalerMIMERegistry returns a new registry of marshalers. +// It allows for a mapping of case-sensitive Content-Type MIME type string to runtime.Marshaler interfaces. +// +// For example, you could allow the client to specify the use of the runtime.JSONPb marshaler +// with a "applicaton/jsonpb" Content-Type and the use of the runtime.JSONBuiltin marshaler +// with a "application/json" Content-Type. +// "*" can be used to match any Content-Type. +// This can be attached to a ServerMux with the marshaler option. +func makeMarshalerMIMERegistry() marshalerRegistry { + return marshalerRegistry{ + mimeMap: map[string]Marshaler{ + MIMEWildcard: defaultMarshaler, + }, + } +} + +// WithMarshalerOption returns a ServeMuxOption which associates inbound and outbound +// Marshalers to a MIME type in mux. +func WithMarshalerOption(mime string, marshaler Marshaler) ServeMuxOption { + return func(mux *ServeMux) { + if err := mux.marshalers.add(mime, marshaler); err != nil { + panic(err) + } + } +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshaler_registry_test.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshaler_registry_test.go new file mode 100644 index 00000000000..194de6fee11 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/marshaler_registry_test.go @@ -0,0 +1,107 @@ +package runtime_test + +import ( + "errors" + "io" + "net/http" + "testing" + + "github.com/grpc-ecosystem/grpc-gateway/runtime" +) + +func TestMarshalerForRequest(t *testing.T) { + r, err := http.NewRequest("GET", "http://example.com", nil) + if err != nil { + t.Fatalf(`http.NewRequest("GET", "http://example.com", nil) failed with %v; want success`, err) + } + r.Header.Set("Accept", "application/x-out") + r.Header.Set("Content-Type", "application/x-in") + + mux := runtime.NewServeMux() + + in, out := runtime.MarshalerForRequest(mux, r) + if _, ok := in.(*runtime.JSONPb); !ok { + t.Errorf("in = %#v; want a runtime.JSONPb", in) + } + if _, ok := out.(*runtime.JSONPb); !ok { + t.Errorf("out = %#v; want a runtime.JSONPb", in) + } + + var marshalers [3]dummyMarshaler + specs := []struct { + opt runtime.ServeMuxOption + + wantIn runtime.Marshaler + wantOut runtime.Marshaler + }{ + { + opt: runtime.WithMarshalerOption(runtime.MIMEWildcard, &marshalers[0]), + wantIn: &marshalers[0], + wantOut: &marshalers[0], + }, + { + opt: runtime.WithMarshalerOption("application/x-in", &marshalers[1]), + wantIn: &marshalers[1], + wantOut: &marshalers[0], + }, + { + opt: runtime.WithMarshalerOption("application/x-out", &marshalers[2]), + wantIn: &marshalers[1], + wantOut: &marshalers[2], + }, + } + for i, spec := range specs { + var opts []runtime.ServeMuxOption + for _, s := range specs[:i+1] { + opts = append(opts, s.opt) + } + mux = runtime.NewServeMux(opts...) + + in, out = runtime.MarshalerForRequest(mux, r) + if got, want := in, spec.wantIn; got != want { + t.Errorf("in = %#v; want %#v", got, want) + } + if got, want := out, spec.wantOut; got != want { + t.Errorf("out = %#v; want %#v", got, want) + } + } + + r.Header.Set("Content-Type", "application/x-another") + in, out = runtime.MarshalerForRequest(mux, r) + if got, want := in, &marshalers[1]; got != want { + t.Errorf("in = %#v; want %#v", got, want) + } + if got, want := out, &marshalers[0]; got != want { + t.Errorf("out = %#v; want %#v", got, want) + } +} + +type dummyMarshaler struct{} + +func (dummyMarshaler) ContentType() string { return "" } +func (dummyMarshaler) Marshal(interface{}) ([]byte, error) { + return nil, errors.New("not implemented") +} + +func (dummyMarshaler) Unmarshal([]byte, interface{}) error { + return errors.New("not implemented") +} + +func (dummyMarshaler) NewDecoder(r io.Reader) runtime.Decoder { + return dummyDecoder{} +} +func (dummyMarshaler) NewEncoder(w io.Writer) runtime.Encoder { + return dummyEncoder{} +} + +type dummyDecoder struct{} + +func (dummyDecoder) Decode(interface{}) error { + return errors.New("not implemented") +} + +type dummyEncoder struct{} + +func (dummyEncoder) Encode(interface{}) error { + return errors.New("not implemented") +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/mux.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/mux.go new file mode 100644 index 00000000000..2e6c5621302 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/mux.go @@ -0,0 +1,132 @@ +package runtime + +import ( + "net/http" + "strings" + + "golang.org/x/net/context" + + "github.com/golang/protobuf/proto" +) + +// A HandlerFunc handles a specific pair of path pattern and HTTP method. +type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) + +// ServeMux is a request multiplexer for grpc-gateway. +// It matches http requests to patterns and invokes the corresponding handler. +type ServeMux struct { + // handlers maps HTTP method to a list of handlers. + handlers map[string][]handler + forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error + marshalers marshalerRegistry +} + +// ServeMuxOption is an option that can be given to a ServeMux on construction. +type ServeMuxOption func(*ServeMux) + +// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption. +// +// forwardResponseOption is an option that will be called on the relevant context.Context, +// http.ResponseWriter, and proto.Message before every forwarded response. +// +// The message may be nil in the case where just a header is being sent. +func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption { + return func(serveMux *ServeMux) { + serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption) + } +} + +// NewServeMux returns a new ServeMux whose internal mapping is empty. +func NewServeMux(opts ...ServeMuxOption) *ServeMux { + serveMux := &ServeMux{ + handlers: make(map[string][]handler), + forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0), + marshalers: makeMarshalerMIMERegistry(), + } + + for _, opt := range opts { + opt(serveMux) + } + return serveMux +} + +// Handle associates "h" to the pair of HTTP method and path pattern. +func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) { + s.handlers[meth] = append(s.handlers[meth], handler{pat: pat, h: h}) +} + +// ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path. +func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + if !strings.HasPrefix(path, "/") { + OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + + components := strings.Split(path[1:], "/") + l := len(components) + var verb string + if idx := strings.LastIndex(components[l-1], ":"); idx == 0 { + OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return + } else if idx > 0 { + c := components[l-1] + components[l-1], verb = c[:idx], c[idx+1:] + } + + if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && isPathLengthFallback(r) { + r.Method = strings.ToUpper(override) + if err := r.ParseForm(); err != nil { + OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest) + return + } + } + for _, h := range s.handlers[r.Method] { + pathParams, err := h.pat.Match(components, verb) + if err != nil { + continue + } + h.h(w, r, pathParams) + return + } + + // lookup other methods to handle fallback from GET to POST and + // to determine if it is MethodNotAllowed or NotFound. + for m, handlers := range s.handlers { + if m == r.Method { + continue + } + for _, h := range handlers { + pathParams, err := h.pat.Match(components, verb) + if err != nil { + continue + } + // X-HTTP-Method-Override is optional. Always allow fallback to POST. + if isPathLengthFallback(r) { + if err := r.ParseForm(); err != nil { + OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest) + return + } + h.h(w, r, pathParams) + return + } + OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return + } + } + OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound) +} + +// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux. +func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error { + return s.forwardResponseOptions +} + +func isPathLengthFallback(r *http.Request) bool { + return r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" +} + +type handler struct { + pat Pattern + h HandlerFunc +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/mux_test.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/mux_test.go new file mode 100644 index 00000000000..bb90a7306a1 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/mux_test.go @@ -0,0 +1,213 @@ +package runtime_test + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/grpc-ecosystem/grpc-gateway/utilities" +) + +func TestMuxServeHTTP(t *testing.T) { + type stubPattern struct { + method string + ops []int + pool []string + verb string + } + for _, spec := range []struct { + patterns []stubPattern + + reqMethod string + reqPath string + headers map[string]string + + respStatus int + respContent string + }{ + { + patterns: nil, + reqMethod: "GET", + reqPath: "/", + respStatus: http.StatusNotFound, + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "GET", + reqPath: "/foo", + respStatus: http.StatusOK, + respContent: "GET /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "GET", + reqPath: "/bar", + respStatus: http.StatusNotFound, + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + }, + { + method: "GET", + ops: []int{int(utilities.OpPush), 0}, + }, + }, + reqMethod: "GET", + reqPath: "/foo", + respStatus: http.StatusOK, + respContent: "GET /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + }, + { + method: "POST", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + respStatus: http.StatusOK, + respContent: "POST /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "DELETE", + reqPath: "/foo", + respStatus: http.StatusMethodNotAllowed, + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + headers: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + respStatus: http.StatusOK, + respContent: "GET /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + }, + { + method: "POST", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + headers: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + "X-HTTP-Method-Override": "GET", + }, + respStatus: http.StatusOK, + respContent: "GET /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + headers: map[string]string{ + "Content-Type": "application/json", + }, + respStatus: http.StatusMethodNotAllowed, + }, + { + patterns: []stubPattern{ + { + method: "POST", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + verb: "bar", + }, + }, + reqMethod: "POST", + reqPath: "/foo:bar", + headers: map[string]string{ + "Content-Type": "application/json", + }, + respStatus: http.StatusOK, + respContent: "POST /foo:bar", + }, + } { + mux := runtime.NewServeMux() + for _, p := range spec.patterns { + func(p stubPattern) { + pat, err := runtime.NewPattern(1, p.ops, p.pool, p.verb) + if err != nil { + t.Fatalf("runtime.NewPattern(1, %#v, %#v, %q) failed with %v; want success", p.ops, p.pool, p.verb, err) + } + mux.Handle(p.method, pat, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { + fmt.Fprintf(w, "%s %s", p.method, pat.String()) + }) + }(p) + } + + url := fmt.Sprintf("http://host.example%s", spec.reqPath) + r, err := http.NewRequest(spec.reqMethod, url, bytes.NewReader(nil)) + if err != nil { + t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", spec.reqMethod, url, err) + } + for name, value := range spec.headers { + r.Header.Set(name, value) + } + w := httptest.NewRecorder() + mux.ServeHTTP(w, r) + + if got, want := w.Code, spec.respStatus; got != want { + t.Errorf("w.Code = %d; want %d; patterns=%v; req=%v", got, want, spec.patterns, r) + } + if spec.respContent != "" { + if got, want := w.Body.String(), spec.respContent; got != want { + t.Errorf("w.Body = %q; want %q; patterns=%v; req=%v", got, want, spec.patterns, r) + } + } + } +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/pattern.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/pattern.go new file mode 100644 index 00000000000..3947dbea023 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/pattern.go @@ -0,0 +1,227 @@ +package runtime + +import ( + "errors" + "fmt" + "strings" + + "github.com/grpc-ecosystem/grpc-gateway/utilities" + "google.golang.org/grpc/grpclog" +) + +var ( + // ErrNotMatch indicates that the given HTTP request path does not match to the pattern. + ErrNotMatch = errors.New("not match to the path pattern") + // ErrInvalidPattern indicates that the given definition of Pattern is not valid. + ErrInvalidPattern = errors.New("invalid pattern") +) + +type op struct { + code utilities.OpCode + operand int +} + +// Pattern is a template pattern of http request paths defined in third_party/googleapis/google/api/http.proto. +type Pattern struct { + // ops is a list of operations + ops []op + // pool is a constant pool indexed by the operands or vars. + pool []string + // vars is a list of variables names to be bound by this pattern + vars []string + // stacksize is the max depth of the stack + stacksize int + // tailLen is the length of the fixed-size segments after a deep wildcard + tailLen int + // verb is the VERB part of the path pattern. It is empty if the pattern does not have VERB part. + verb string +} + +// NewPattern returns a new Pattern from the given definition values. +// "ops" is a sequence of op codes. "pool" is a constant pool. +// "verb" is the verb part of the pattern. It is empty if the pattern does not have the part. +// "version" must be 1 for now. +// It returns an error if the given definition is invalid. +func NewPattern(version int, ops []int, pool []string, verb string) (Pattern, error) { + if version != 1 { + grpclog.Printf("unsupported version: %d", version) + return Pattern{}, ErrInvalidPattern + } + + l := len(ops) + if l%2 != 0 { + grpclog.Printf("odd number of ops codes: %d", l) + return Pattern{}, ErrInvalidPattern + } + + var ( + typedOps []op + stack, maxstack int + tailLen int + pushMSeen bool + vars []string + ) + for i := 0; i < l; i += 2 { + op := op{code: utilities.OpCode(ops[i]), operand: ops[i+1]} + switch op.code { + case utilities.OpNop: + continue + case utilities.OpPush: + if pushMSeen { + tailLen++ + } + stack++ + case utilities.OpPushM: + if pushMSeen { + grpclog.Printf("pushM appears twice") + return Pattern{}, ErrInvalidPattern + } + pushMSeen = true + stack++ + case utilities.OpLitPush: + if op.operand < 0 || len(pool) <= op.operand { + grpclog.Printf("negative literal index: %d", op.operand) + return Pattern{}, ErrInvalidPattern + } + if pushMSeen { + tailLen++ + } + stack++ + case utilities.OpConcatN: + if op.operand <= 0 { + grpclog.Printf("negative concat size: %d", op.operand) + return Pattern{}, ErrInvalidPattern + } + stack -= op.operand + if stack < 0 { + grpclog.Print("stack underflow") + return Pattern{}, ErrInvalidPattern + } + stack++ + case utilities.OpCapture: + if op.operand < 0 || len(pool) <= op.operand { + grpclog.Printf("variable name index out of bound: %d", op.operand) + return Pattern{}, ErrInvalidPattern + } + v := pool[op.operand] + op.operand = len(vars) + vars = append(vars, v) + stack-- + if stack < 0 { + grpclog.Printf("stack underflow") + return Pattern{}, ErrInvalidPattern + } + default: + grpclog.Printf("invalid opcode: %d", op.code) + return Pattern{}, ErrInvalidPattern + } + + if maxstack < stack { + maxstack = stack + } + typedOps = append(typedOps, op) + } + return Pattern{ + ops: typedOps, + pool: pool, + vars: vars, + stacksize: maxstack, + tailLen: tailLen, + verb: verb, + }, nil +} + +// MustPattern is a helper function which makes it easier to call NewPattern in variable initialization. +func MustPattern(p Pattern, err error) Pattern { + if err != nil { + grpclog.Fatalf("Pattern initialization failed: %v", err) + } + return p +} + +// Match examines components if it matches to the Pattern. +// If it matches, the function returns a mapping from field paths to their captured values. +// If otherwise, the function returns an error. +func (p Pattern) Match(components []string, verb string) (map[string]string, error) { + if p.verb != verb { + return nil, ErrNotMatch + } + + var pos int + stack := make([]string, 0, p.stacksize) + captured := make([]string, len(p.vars)) + l := len(components) + for _, op := range p.ops { + switch op.code { + case utilities.OpNop: + continue + case utilities.OpPush, utilities.OpLitPush: + if pos >= l { + return nil, ErrNotMatch + } + c := components[pos] + if op.code == utilities.OpLitPush { + if lit := p.pool[op.operand]; c != lit { + return nil, ErrNotMatch + } + } + stack = append(stack, c) + pos++ + case utilities.OpPushM: + end := len(components) + if end < pos+p.tailLen { + return nil, ErrNotMatch + } + end -= p.tailLen + stack = append(stack, strings.Join(components[pos:end], "/")) + pos = end + case utilities.OpConcatN: + n := op.operand + l := len(stack) - n + stack = append(stack[:l], strings.Join(stack[l:], "/")) + case utilities.OpCapture: + n := len(stack) - 1 + captured[op.operand] = stack[n] + stack = stack[:n] + } + } + if pos < l { + return nil, ErrNotMatch + } + bindings := make(map[string]string) + for i, val := range captured { + bindings[p.vars[i]] = val + } + return bindings, nil +} + +// Verb returns the verb part of the Pattern. +func (p Pattern) Verb() string { return p.verb } + +func (p Pattern) String() string { + var stack []string + for _, op := range p.ops { + switch op.code { + case utilities.OpNop: + continue + case utilities.OpPush: + stack = append(stack, "*") + case utilities.OpLitPush: + stack = append(stack, p.pool[op.operand]) + case utilities.OpPushM: + stack = append(stack, "**") + case utilities.OpConcatN: + n := op.operand + l := len(stack) - n + stack = append(stack[:l], strings.Join(stack[l:], "/")) + case utilities.OpCapture: + n := len(stack) - 1 + stack[n] = fmt.Sprintf("{%s=%s}", p.vars[op.operand], stack[n]) + } + } + segs := strings.Join(stack, "/") + if p.verb != "" { + return fmt.Sprintf("/%s:%s", segs, p.verb) + } + return "/" + segs +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/pattern_test.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/pattern_test.go new file mode 100644 index 00000000000..8f5a664aba5 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/pattern_test.go @@ -0,0 +1,590 @@ +package runtime + +import ( + "fmt" + "reflect" + "strings" + "testing" + + "github.com/grpc-ecosystem/grpc-gateway/utilities" +) + +const ( + validVersion = 1 + anything = 0 +) + +func TestNewPattern(t *testing.T) { + for _, spec := range []struct { + ops []int + pool []string + verb string + + stackSizeWant, tailLenWant int + }{ + {}, + { + ops: []int{int(utilities.OpNop), anything}, + stackSizeWant: 0, + tailLenWant: 0, + }, + { + ops: []int{int(utilities.OpPush), anything}, + stackSizeWant: 1, + tailLenWant: 0, + }, + { + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"abc"}, + stackSizeWant: 1, + tailLenWant: 0, + }, + { + ops: []int{int(utilities.OpPushM), anything}, + stackSizeWant: 1, + tailLenWant: 0, + }, + { + ops: []int{ + int(utilities.OpPush), anything, + int(utilities.OpConcatN), 1, + }, + stackSizeWant: 1, + tailLenWant: 0, + }, + { + ops: []int{ + int(utilities.OpPush), anything, + int(utilities.OpConcatN), 1, + int(utilities.OpCapture), 0, + }, + pool: []string{"abc"}, + stackSizeWant: 1, + tailLenWant: 0, + }, + { + ops: []int{ + int(utilities.OpPush), anything, + int(utilities.OpLitPush), 0, + int(utilities.OpLitPush), 1, + int(utilities.OpPushM), anything, + int(utilities.OpConcatN), 2, + int(utilities.OpCapture), 2, + }, + pool: []string{"lit1", "lit2", "var1"}, + stackSizeWant: 4, + tailLenWant: 0, + }, + { + ops: []int{ + int(utilities.OpPushM), anything, + int(utilities.OpConcatN), 1, + int(utilities.OpCapture), 2, + int(utilities.OpLitPush), 0, + int(utilities.OpLitPush), 1, + }, + pool: []string{"lit1", "lit2", "var1"}, + stackSizeWant: 2, + tailLenWant: 2, + }, + { + ops: []int{ + int(utilities.OpLitPush), 0, + int(utilities.OpLitPush), 1, + int(utilities.OpPushM), anything, + int(utilities.OpLitPush), 2, + int(utilities.OpConcatN), 3, + int(utilities.OpLitPush), 3, + int(utilities.OpCapture), 4, + }, + pool: []string{"lit1", "lit2", "lit3", "lit4", "var1"}, + stackSizeWant: 4, + tailLenWant: 2, + }, + { + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"abc"}, + verb: "LOCK", + stackSizeWant: 1, + tailLenWant: 0, + }, + } { + pat, err := NewPattern(validVersion, spec.ops, spec.pool, spec.verb) + if err != nil { + t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want success", validVersion, spec.ops, spec.pool, spec.verb, err) + continue + } + if got, want := pat.stacksize, spec.stackSizeWant; got != want { + t.Errorf("pat.stacksize = %d; want %d", got, want) + } + if got, want := pat.tailLen, spec.tailLenWant; got != want { + t.Errorf("pat.stacksize = %d; want %d", got, want) + } + } +} + +func TestNewPatternWithWrongOp(t *testing.T) { + for _, spec := range []struct { + ops []int + pool []string + verb string + }{ + { + // op code out of bound + ops: []int{-1, anything}, + }, + { + // op code out of bound + ops: []int{int(utilities.OpEnd), 0}, + }, + { + // odd number of items + ops: []int{int(utilities.OpPush)}, + }, + { + // negative index + ops: []int{int(utilities.OpLitPush), -1}, + pool: []string{"abc"}, + }, + { + // index out of bound + ops: []int{int(utilities.OpLitPush), 1}, + pool: []string{"abc"}, + }, + { + // negative # of segments + ops: []int{int(utilities.OpConcatN), -1}, + pool: []string{"abc"}, + }, + { + // negative index + ops: []int{int(utilities.OpCapture), -1}, + pool: []string{"abc"}, + }, + { + // index out of bound + ops: []int{int(utilities.OpCapture), 1}, + pool: []string{"abc"}, + }, + { + // pushM appears twice + ops: []int{ + int(utilities.OpPushM), anything, + int(utilities.OpLitPush), 0, + int(utilities.OpPushM), anything, + }, + pool: []string{"abc"}, + }, + } { + _, err := NewPattern(validVersion, spec.ops, spec.pool, spec.verb) + if err == nil { + t.Errorf("NewPattern(%d, %v, %q, %q) succeeded; want failure with %v", validVersion, spec.ops, spec.pool, spec.verb, ErrInvalidPattern) + continue + } + if err != ErrInvalidPattern { + t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want failure with %v", validVersion, spec.ops, spec.pool, spec.verb, err, ErrInvalidPattern) + continue + } + } +} + +func TestNewPatternWithStackUnderflow(t *testing.T) { + for _, spec := range []struct { + ops []int + pool []string + verb string + }{ + { + ops: []int{int(utilities.OpConcatN), 1}, + }, + { + ops: []int{int(utilities.OpCapture), 0}, + pool: []string{"abc"}, + }, + } { + _, err := NewPattern(validVersion, spec.ops, spec.pool, spec.verb) + if err == nil { + t.Errorf("NewPattern(%d, %v, %q, %q) succeeded; want failure with %v", validVersion, spec.ops, spec.pool, spec.verb, ErrInvalidPattern) + continue + } + if err != ErrInvalidPattern { + t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want failure with %v", validVersion, spec.ops, spec.pool, spec.verb, err, ErrInvalidPattern) + continue + } + } +} + +func TestMatch(t *testing.T) { + for _, spec := range []struct { + ops []int + pool []string + verb string + + match []string + notMatch []string + }{ + { + match: []string{""}, + notMatch: []string{"example"}, + }, + { + ops: []int{int(utilities.OpNop), anything}, + match: []string{""}, + notMatch: []string{"example", "path/to/example"}, + }, + { + ops: []int{int(utilities.OpPush), anything}, + match: []string{"abc", "def"}, + notMatch: []string{"", "abc/def"}, + }, + { + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"v1"}, + match: []string{"v1"}, + notMatch: []string{"", "v2"}, + }, + { + ops: []int{int(utilities.OpPushM), anything}, + match: []string{"", "abc", "abc/def", "abc/def/ghi"}, + }, + { + ops: []int{ + int(utilities.OpPushM), anything, + int(utilities.OpLitPush), 0, + }, + pool: []string{"tail"}, + match: []string{"tail", "abc/tail", "abc/def/tail"}, + notMatch: []string{ + "", "abc", "abc/def", + "tail/extra", "abc/tail/extra", "abc/def/tail/extra", + }, + }, + { + ops: []int{ + int(utilities.OpLitPush), 0, + int(utilities.OpLitPush), 1, + int(utilities.OpPush), anything, + int(utilities.OpConcatN), 1, + int(utilities.OpCapture), 2, + }, + pool: []string{"v1", "bucket", "name"}, + match: []string{"v1/bucket/my-bucket", "v1/bucket/our-bucket"}, + notMatch: []string{ + "", + "v1", + "v1/bucket", + "v2/bucket/my-bucket", + "v1/pubsub/my-topic", + }, + }, + { + ops: []int{ + int(utilities.OpLitPush), 0, + int(utilities.OpLitPush), 1, + int(utilities.OpPushM), anything, + int(utilities.OpConcatN), 2, + int(utilities.OpCapture), 2, + }, + pool: []string{"v1", "o", "name"}, + match: []string{ + "v1/o", + "v1/o/my-bucket", + "v1/o/our-bucket", + "v1/o/my-bucket/dir", + "v1/o/my-bucket/dir/dir2", + "v1/o/my-bucket/dir/dir2/obj", + }, + notMatch: []string{ + "", + "v1", + "v2/o/my-bucket", + "v1/b/my-bucket", + }, + }, + { + ops: []int{ + int(utilities.OpLitPush), 0, + int(utilities.OpLitPush), 1, + int(utilities.OpPush), anything, + int(utilities.OpConcatN), 2, + int(utilities.OpCapture), 2, + int(utilities.OpLitPush), 3, + int(utilities.OpPush), anything, + int(utilities.OpConcatN), 1, + int(utilities.OpCapture), 4, + }, + pool: []string{"v2", "b", "name", "o", "oname"}, + match: []string{ + "v2/b/my-bucket/o/obj", + "v2/b/our-bucket/o/obj", + "v2/b/my-bucket/o/dir", + }, + notMatch: []string{ + "", + "v2", + "v2/b", + "v2/b/my-bucket", + "v2/b/my-bucket/o", + }, + }, + { + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"v1"}, + verb: "LOCK", + match: []string{"v1:LOCK"}, + notMatch: []string{"v1", "LOCK"}, + }, + } { + pat, err := NewPattern(validVersion, spec.ops, spec.pool, spec.verb) + if err != nil { + t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want success", validVersion, spec.ops, spec.pool, spec.verb, err) + continue + } + + for _, path := range spec.match { + _, err = pat.Match(segments(path)) + if err != nil { + t.Errorf("pat.Match(%q) failed with %v; want success; pattern = (%v, %q)", path, err, spec.ops, spec.pool) + } + } + + for _, path := range spec.notMatch { + _, err = pat.Match(segments(path)) + if err == nil { + t.Errorf("pat.Match(%q) succeeded; want failure with %v; pattern = (%v, %q)", path, ErrNotMatch, spec.ops, spec.pool) + continue + } + if err != ErrNotMatch { + t.Errorf("pat.Match(%q) failed with %v; want failure with %v; pattern = (%v, %q)", spec.notMatch, err, ErrNotMatch, spec.ops, spec.pool) + } + } + } +} + +func TestMatchWithBinding(t *testing.T) { + for _, spec := range []struct { + ops []int + pool []string + path string + verb string + + want map[string]string + }{ + { + want: make(map[string]string), + }, + { + ops: []int{int(utilities.OpNop), anything}, + want: make(map[string]string), + }, + { + ops: []int{int(utilities.OpPush), anything}, + path: "abc", + want: make(map[string]string), + }, + { + ops: []int{int(utilities.OpPush), anything}, + verb: "LOCK", + path: "abc:LOCK", + want: make(map[string]string), + }, + { + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"endpoint"}, + path: "endpoint", + want: make(map[string]string), + }, + { + ops: []int{int(utilities.OpPushM), anything}, + path: "abc/def/ghi", + want: make(map[string]string), + }, + { + ops: []int{ + int(utilities.OpLitPush), 0, + int(utilities.OpLitPush), 1, + int(utilities.OpPush), anything, + int(utilities.OpConcatN), 1, + int(utilities.OpCapture), 2, + }, + pool: []string{"v1", "bucket", "name"}, + path: "v1/bucket/my-bucket", + want: map[string]string{ + "name": "my-bucket", + }, + }, + { + ops: []int{ + int(utilities.OpLitPush), 0, + int(utilities.OpLitPush), 1, + int(utilities.OpPush), anything, + int(utilities.OpConcatN), 1, + int(utilities.OpCapture), 2, + }, + pool: []string{"v1", "bucket", "name"}, + verb: "LOCK", + path: "v1/bucket/my-bucket:LOCK", + want: map[string]string{ + "name": "my-bucket", + }, + }, + { + ops: []int{ + int(utilities.OpLitPush), 0, + int(utilities.OpLitPush), 1, + int(utilities.OpPushM), anything, + int(utilities.OpConcatN), 2, + int(utilities.OpCapture), 2, + }, + pool: []string{"v1", "o", "name"}, + path: "v1/o/my-bucket/dir/dir2/obj", + want: map[string]string{ + "name": "o/my-bucket/dir/dir2/obj", + }, + }, + { + ops: []int{ + int(utilities.OpLitPush), 0, + int(utilities.OpLitPush), 1, + int(utilities.OpPushM), anything, + int(utilities.OpLitPush), 2, + int(utilities.OpConcatN), 3, + int(utilities.OpCapture), 4, + int(utilities.OpLitPush), 3, + }, + pool: []string{"v1", "o", ".ext", "tail", "name"}, + path: "v1/o/my-bucket/dir/dir2/obj/.ext/tail", + want: map[string]string{ + "name": "o/my-bucket/dir/dir2/obj/.ext", + }, + }, + { + ops: []int{ + int(utilities.OpLitPush), 0, + int(utilities.OpLitPush), 1, + int(utilities.OpPush), anything, + int(utilities.OpConcatN), 2, + int(utilities.OpCapture), 2, + int(utilities.OpLitPush), 3, + int(utilities.OpPush), anything, + int(utilities.OpConcatN), 1, + int(utilities.OpCapture), 4, + }, + pool: []string{"v2", "b", "name", "o", "oname"}, + path: "v2/b/my-bucket/o/obj", + want: map[string]string{ + "name": "b/my-bucket", + "oname": "obj", + }, + }, + } { + pat, err := NewPattern(validVersion, spec.ops, spec.pool, spec.verb) + if err != nil { + t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want success", validVersion, spec.ops, spec.pool, spec.verb, err) + continue + } + + got, err := pat.Match(segments(spec.path)) + if err != nil { + t.Errorf("pat.Match(%q) failed with %v; want success; pattern = (%v, %q)", spec.path, err, spec.ops, spec.pool) + } + if !reflect.DeepEqual(got, spec.want) { + t.Errorf("pat.Match(%q) = %q; want %q; pattern = (%v, %q)", spec.path, got, spec.want, spec.ops, spec.pool) + } + } +} + +func segments(path string) (components []string, verb string) { + if path == "" { + return nil, "" + } + components = strings.Split(path, "/") + l := len(components) + c := components[l-1] + if idx := strings.LastIndex(c, ":"); idx >= 0 { + components[l-1], verb = c[:idx], c[idx+1:] + } + return components, verb +} + +func TestPatternString(t *testing.T) { + for _, spec := range []struct { + ops []int + pool []string + + want string + }{ + { + want: "/", + }, + { + ops: []int{int(utilities.OpNop), anything}, + want: "/", + }, + { + ops: []int{int(utilities.OpPush), anything}, + want: "/*", + }, + { + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"endpoint"}, + want: "/endpoint", + }, + { + ops: []int{int(utilities.OpPushM), anything}, + want: "/**", + }, + { + ops: []int{ + int(utilities.OpPush), anything, + int(utilities.OpConcatN), 1, + }, + want: "/*", + }, + { + ops: []int{ + int(utilities.OpPush), anything, + int(utilities.OpConcatN), 1, + int(utilities.OpCapture), 0, + }, + pool: []string{"name"}, + want: "/{name=*}", + }, + { + ops: []int{ + int(utilities.OpLitPush), 0, + int(utilities.OpLitPush), 1, + int(utilities.OpPush), anything, + int(utilities.OpConcatN), 2, + int(utilities.OpCapture), 2, + int(utilities.OpLitPush), 3, + int(utilities.OpPushM), anything, + int(utilities.OpLitPush), 4, + int(utilities.OpConcatN), 3, + int(utilities.OpCapture), 6, + int(utilities.OpLitPush), 5, + }, + pool: []string{"v1", "buckets", "bucket_name", "objects", ".ext", "tail", "name"}, + want: "/v1/{bucket_name=buckets/*}/{name=objects/**/.ext}/tail", + }, + } { + p, err := NewPattern(validVersion, spec.ops, spec.pool, "") + if err != nil { + t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want success", validVersion, spec.ops, spec.pool, "", err) + continue + } + if got, want := p.String(), spec.want; got != want { + t.Errorf("%#v.String() = %q; want %q", p, got, want) + } + + verb := "LOCK" + p, err = NewPattern(validVersion, spec.ops, spec.pool, verb) + if err != nil { + t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want success", validVersion, spec.ops, spec.pool, verb, err) + continue + } + if got, want := p.String(), fmt.Sprintf("%s:%s", spec.want, verb); got != want { + t.Errorf("%#v.String() = %q; want %q", p, got, want) + } + } +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/proto2_convert.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/proto2_convert.go new file mode 100644 index 00000000000..a3151e2a552 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/proto2_convert.go @@ -0,0 +1,80 @@ +package runtime + +import ( + "github.com/golang/protobuf/proto" +) + +// StringP returns a pointer to a string whose pointee is same as the given string value. +func StringP(val string) (*string, error) { + return proto.String(val), nil +} + +// BoolP parses the given string representation of a boolean value, +// and returns a pointer to a bool whose value is same as the parsed value. +func BoolP(val string) (*bool, error) { + b, err := Bool(val) + if err != nil { + return nil, err + } + return proto.Bool(b), nil +} + +// Float64P parses the given string representation of a floating point number, +// and returns a pointer to a float64 whose value is same as the parsed number. +func Float64P(val string) (*float64, error) { + f, err := Float64(val) + if err != nil { + return nil, err + } + return proto.Float64(f), nil +} + +// Float32P parses the given string representation of a floating point number, +// and returns a pointer to a float32 whose value is same as the parsed number. +func Float32P(val string) (*float32, error) { + f, err := Float32(val) + if err != nil { + return nil, err + } + return proto.Float32(f), nil +} + +// Int64P parses the given string representation of an integer +// and returns a pointer to a int64 whose value is same as the parsed integer. +func Int64P(val string) (*int64, error) { + i, err := Int64(val) + if err != nil { + return nil, err + } + return proto.Int64(i), nil +} + +// Int32P parses the given string representation of an integer +// and returns a pointer to a int32 whose value is same as the parsed integer. +func Int32P(val string) (*int32, error) { + i, err := Int32(val) + if err != nil { + return nil, err + } + return proto.Int32(i), err +} + +// Uint64P parses the given string representation of an integer +// and returns a pointer to a uint64 whose value is same as the parsed integer. +func Uint64P(val string) (*uint64, error) { + i, err := Uint64(val) + if err != nil { + return nil, err + } + return proto.Uint64(i), err +} + +// Uint32P parses the given string representation of an integer +// and returns a pointer to a uint32 whose value is same as the parsed integer. +func Uint32P(val string) (*uint32, error) { + i, err := Uint32(val) + if err != nil { + return nil, err + } + return proto.Uint32(i), err +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/query.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/query.go new file mode 100644 index 00000000000..56a919a52f1 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/query.go @@ -0,0 +1,140 @@ +package runtime + +import ( + "fmt" + "net/url" + "reflect" + "strings" + + "github.com/golang/protobuf/proto" + "github.com/grpc-ecosystem/grpc-gateway/utilities" + "google.golang.org/grpc/grpclog" +) + +// PopulateQueryParameters populates "values" into "msg". +// A value is ignored if its key starts with one of the elements in "filter". +func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error { + for key, values := range values { + fieldPath := strings.Split(key, ".") + if filter.HasCommonPrefix(fieldPath) { + continue + } + if err := populateFieldValueFromPath(msg, fieldPath, values); err != nil { + return err + } + } + return nil +} + +// PopulateFieldFromPath sets a value in a nested Protobuf structure. +// It instantiates missing protobuf fields as it goes. +func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error { + fieldPath := strings.Split(fieldPathString, ".") + return populateFieldValueFromPath(msg, fieldPath, []string{value}) +} + +func populateFieldValueFromPath(msg proto.Message, fieldPath []string, values []string) error { + m := reflect.ValueOf(msg) + if m.Kind() != reflect.Ptr { + return fmt.Errorf("unexpected type %T: %v", msg, msg) + } + m = m.Elem() + for i, fieldName := range fieldPath { + isLast := i == len(fieldPath)-1 + if !isLast && m.Kind() != reflect.Struct { + return fmt.Errorf("non-aggregate type in the mid of path: %s", strings.Join(fieldPath, ".")) + } + f := fieldByProtoName(m, fieldName) + if !f.IsValid() { + grpclog.Printf("field not found in %T: %s", msg, strings.Join(fieldPath, ".")) + return nil + } + + switch f.Kind() { + case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, reflect.String, reflect.Uint32, reflect.Uint64: + m = f + case reflect.Slice: + // TODO(yugui) Support []byte + if !isLast { + return fmt.Errorf("unexpected repeated field in %s", strings.Join(fieldPath, ".")) + } + return populateRepeatedField(f, values) + case reflect.Ptr: + if f.IsNil() { + m = reflect.New(f.Type().Elem()) + f.Set(m) + } + m = f.Elem() + continue + case reflect.Struct: + m = f + continue + default: + return fmt.Errorf("unexpected type %s in %T", f.Type(), msg) + } + } + switch len(values) { + case 0: + return fmt.Errorf("no value of field: %s", strings.Join(fieldPath, ".")) + case 1: + default: + grpclog.Printf("too many field values: %s", strings.Join(fieldPath, ".")) + } + return populateField(m, values[0]) +} + +// fieldByProtoName looks up a field whose corresponding protobuf field name is "name". +// "m" must be a struct value. It returns zero reflect.Value if no such field found. +func fieldByProtoName(m reflect.Value, name string) reflect.Value { + props := proto.GetProperties(m.Type()) + for _, p := range props.Prop { + if p.OrigName == name { + return m.FieldByName(p.Name) + } + } + return reflect.Value{} +} + +func populateRepeatedField(f reflect.Value, values []string) error { + elemType := f.Type().Elem() + conv, ok := convFromType[elemType.Kind()] + if !ok { + return fmt.Errorf("unsupported field type %s", elemType) + } + f.Set(reflect.MakeSlice(f.Type(), len(values), len(values))) + for i, v := range values { + result := conv.Call([]reflect.Value{reflect.ValueOf(v)}) + if err := result[1].Interface(); err != nil { + return err.(error) + } + f.Index(i).Set(result[0]) + } + return nil +} + +func populateField(f reflect.Value, value string) error { + conv, ok := convFromType[f.Kind()] + if !ok { + return fmt.Errorf("unsupported field type %T", f) + } + result := conv.Call([]reflect.Value{reflect.ValueOf(value)}) + if err := result[1].Interface(); err != nil { + return err.(error) + } + f.Set(result[0]) + return nil +} + +var ( + convFromType = map[reflect.Kind]reflect.Value{ + reflect.String: reflect.ValueOf(String), + reflect.Bool: reflect.ValueOf(Bool), + reflect.Float64: reflect.ValueOf(Float64), + reflect.Float32: reflect.ValueOf(Float32), + reflect.Int64: reflect.ValueOf(Int64), + reflect.Int32: reflect.ValueOf(Int32), + reflect.Uint64: reflect.ValueOf(Uint64), + reflect.Uint32: reflect.ValueOf(Uint32), + // TODO(yugui) Support []byte + } +) diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/query_test.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/query_test.go new file mode 100644 index 00000000000..cf2d4285616 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/query_test.go @@ -0,0 +1,311 @@ +package runtime_test + +import ( + "net/url" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/grpc-ecosystem/grpc-gateway/utilities" +) + +func TestPopulateParameters(t *testing.T) { + for _, spec := range []struct { + values url.Values + filter *utilities.DoubleArray + want proto.Message + }{ + { + values: url.Values{ + "float_value": {"1.5"}, + "double_value": {"2.5"}, + "int64_value": {"-1"}, + "int32_value": {"-2"}, + "uint64_value": {"3"}, + "uint32_value": {"4"}, + "bool_value": {"true"}, + "string_value": {"str"}, + "repeated_value": {"a", "b", "c"}, + }, + filter: utilities.NewDoubleArray(nil), + want: &proto3Message{ + FloatValue: 1.5, + DoubleValue: 2.5, + Int64Value: -1, + Int32Value: -2, + Uint64Value: 3, + Uint32Value: 4, + BoolValue: true, + StringValue: "str", + RepeatedValue: []string{"a", "b", "c"}, + }, + }, + { + values: url.Values{ + "float_value": {"1.5"}, + "double_value": {"2.5"}, + "int64_value": {"-1"}, + "int32_value": {"-2"}, + "uint64_value": {"3"}, + "uint32_value": {"4"}, + "bool_value": {"true"}, + "string_value": {"str"}, + "repeated_value": {"a", "b", "c"}, + }, + filter: utilities.NewDoubleArray(nil), + want: &proto2Message{ + FloatValue: proto.Float32(1.5), + DoubleValue: proto.Float64(2.5), + Int64Value: proto.Int64(-1), + Int32Value: proto.Int32(-2), + Uint64Value: proto.Uint64(3), + Uint32Value: proto.Uint32(4), + BoolValue: proto.Bool(true), + StringValue: proto.String("str"), + RepeatedValue: []string{"a", "b", "c"}, + }, + }, + { + values: url.Values{ + "nested.nested.nested.repeated_value": {"a", "b", "c"}, + "nested.nested.nested.string_value": {"s"}, + "nested.nested.string_value": {"t"}, + "nested.string_value": {"u"}, + "nested_non_null.string_value": {"v"}, + }, + filter: utilities.NewDoubleArray(nil), + want: &proto3Message{ + Nested: &proto2Message{ + Nested: &proto3Message{ + Nested: &proto2Message{ + RepeatedValue: []string{"a", "b", "c"}, + StringValue: proto.String("s"), + }, + StringValue: "t", + }, + StringValue: proto.String("u"), + }, + NestedNonNull: proto2Message{ + StringValue: proto.String("v"), + }, + }, + }, + { + values: url.Values{ + "uint64_value": {"1", "2", "3", "4", "5"}, + }, + filter: utilities.NewDoubleArray(nil), + want: &proto3Message{ + Uint64Value: 1, + }, + }, + } { + msg := proto.Clone(spec.want) + msg.Reset() + err := runtime.PopulateQueryParameters(msg, spec.values, spec.filter) + if err != nil { + t.Errorf("runtime.PoplateQueryParameters(msg, %v, %v) failed with %v; want success", spec.values, spec.filter, err) + continue + } + if got, want := msg, spec.want; !proto.Equal(got, want) { + t.Errorf("runtime.PopulateQueryParameters(msg, %v, %v = %v; want %v", spec.values, spec.filter, got, want) + } + } +} + +func TestPopulateParametersWithFilters(t *testing.T) { + for _, spec := range []struct { + values url.Values + filter *utilities.DoubleArray + want proto.Message + }{ + { + values: url.Values{ + "bool_value": {"true"}, + "string_value": {"str"}, + "repeated_value": {"a", "b", "c"}, + }, + filter: utilities.NewDoubleArray([][]string{ + {"bool_value"}, {"repeated_value"}, + }), + want: &proto3Message{ + StringValue: "str", + }, + }, + { + values: url.Values{ + "nested.nested.bool_value": {"true"}, + "nested.nested.string_value": {"str"}, + "nested.string_value": {"str"}, + "string_value": {"str"}, + }, + filter: utilities.NewDoubleArray([][]string{ + {"nested"}, + }), + want: &proto3Message{ + StringValue: "str", + }, + }, + { + values: url.Values{ + "nested.nested.bool_value": {"true"}, + "nested.nested.string_value": {"str"}, + "nested.string_value": {"str"}, + "string_value": {"str"}, + }, + filter: utilities.NewDoubleArray([][]string{ + {"nested", "nested"}, + }), + want: &proto3Message{ + Nested: &proto2Message{ + StringValue: proto.String("str"), + }, + StringValue: "str", + }, + }, + { + values: url.Values{ + "nested.nested.bool_value": {"true"}, + "nested.nested.string_value": {"str"}, + "nested.string_value": {"str"}, + "string_value": {"str"}, + }, + filter: utilities.NewDoubleArray([][]string{ + {"nested", "nested", "string_value"}, + }), + want: &proto3Message{ + Nested: &proto2Message{ + StringValue: proto.String("str"), + Nested: &proto3Message{ + BoolValue: true, + }, + }, + StringValue: "str", + }, + }, + } { + msg := proto.Clone(spec.want) + msg.Reset() + err := runtime.PopulateQueryParameters(msg, spec.values, spec.filter) + if err != nil { + t.Errorf("runtime.PoplateQueryParameters(msg, %v, %v) failed with %v; want success", spec.values, spec.filter, err) + continue + } + if got, want := msg, spec.want; !proto.Equal(got, want) { + t.Errorf("runtime.PopulateQueryParameters(msg, %v, %v = %v; want %v", spec.values, spec.filter, got, want) + } + } +} + +type proto3Message struct { + Nested *proto2Message `protobuf:"bytes,1,opt,name=nested" json:"nested,omitempty"` + NestedNonNull proto2Message `protobuf:"bytes,11,opt,name=nested_non_null" json:"nested_non_null,omitempty"` + FloatValue float32 `protobuf:"fixed32,2,opt,name=float_value" json:"float_value,omitempty"` + DoubleValue float64 `protobuf:"fixed64,3,opt,name=double_value" json:"double_value,omitempty"` + Int64Value int64 `protobuf:"varint,4,opt,name=int64_value" json:"int64_value,omitempty"` + Int32Value int32 `protobuf:"varint,5,opt,name=int32_value" json:"int32_value,omitempty"` + Uint64Value uint64 `protobuf:"varint,6,opt,name=uint64_value" json:"uint64_value,omitempty"` + Uint32Value uint32 `protobuf:"varint,7,opt,name=uint32_value" json:"uint32_value,omitempty"` + BoolValue bool `protobuf:"varint,8,opt,name=bool_value" json:"bool_value,omitempty"` + StringValue string `protobuf:"bytes,9,opt,name=string_value" json:"string_value,omitempty"` + RepeatedValue []string `protobuf:"bytes,10,rep,name=repeated_value" json:"repeated_value,omitempty"` +} + +func (m *proto3Message) Reset() { *m = proto3Message{} } +func (m *proto3Message) String() string { return proto.CompactTextString(m) } +func (*proto3Message) ProtoMessage() {} + +func (m *proto3Message) GetNested() *proto2Message { + if m != nil { + return m.Nested + } + return nil +} + +type proto2Message struct { + Nested *proto3Message `protobuf:"bytes,1,opt,name=nested" json:"nested,omitempty"` + FloatValue *float32 `protobuf:"fixed32,2,opt,name=float_value" json:"float_value,omitempty"` + DoubleValue *float64 `protobuf:"fixed64,3,opt,name=double_value" json:"double_value,omitempty"` + Int64Value *int64 `protobuf:"varint,4,opt,name=int64_value" json:"int64_value,omitempty"` + Int32Value *int32 `protobuf:"varint,5,opt,name=int32_value" json:"int32_value,omitempty"` + Uint64Value *uint64 `protobuf:"varint,6,opt,name=uint64_value" json:"uint64_value,omitempty"` + Uint32Value *uint32 `protobuf:"varint,7,opt,name=uint32_value" json:"uint32_value,omitempty"` + BoolValue *bool `protobuf:"varint,8,opt,name=bool_value" json:"bool_value,omitempty"` + StringValue *string `protobuf:"bytes,9,opt,name=string_value" json:"string_value,omitempty"` + RepeatedValue []string `protobuf:"bytes,10,rep,name=repeated_value" json:"repeated_value,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *proto2Message) Reset() { *m = proto2Message{} } +func (m *proto2Message) String() string { return proto.CompactTextString(m) } +func (*proto2Message) ProtoMessage() {} + +func (m *proto2Message) GetNested() *proto3Message { + if m != nil { + return m.Nested + } + return nil +} + +func (m *proto2Message) GetFloatValue() float32 { + if m != nil && m.FloatValue != nil { + return *m.FloatValue + } + return 0 +} + +func (m *proto2Message) GetDoubleValue() float64 { + if m != nil && m.DoubleValue != nil { + return *m.DoubleValue + } + return 0 +} + +func (m *proto2Message) GetInt64Value() int64 { + if m != nil && m.Int64Value != nil { + return *m.Int64Value + } + return 0 +} + +func (m *proto2Message) GetInt32Value() int32 { + if m != nil && m.Int32Value != nil { + return *m.Int32Value + } + return 0 +} + +func (m *proto2Message) GetUint64Value() uint64 { + if m != nil && m.Uint64Value != nil { + return *m.Uint64Value + } + return 0 +} + +func (m *proto2Message) GetUint32Value() uint32 { + if m != nil && m.Uint32Value != nil { + return *m.Uint32Value + } + return 0 +} + +func (m *proto2Message) GetBoolValue() bool { + if m != nil && m.BoolValue != nil { + return *m.BoolValue + } + return false +} + +func (m *proto2Message) GetStringValue() string { + if m != nil && m.StringValue != nil { + return *m.StringValue + } + return "" +} + +func (m *proto2Message) GetRepeatedValue() []string { + if m != nil { + return m.RepeatedValue + } + return nil +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/doc.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/doc.go new file mode 100644 index 00000000000..cf79a4d5886 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/doc.go @@ -0,0 +1,2 @@ +// Package utilities provides members for internal use in grpc-gateway. +package utilities diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/pattern.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/pattern.go new file mode 100644 index 00000000000..28ad9461f86 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/pattern.go @@ -0,0 +1,22 @@ +package utilities + +// An OpCode is a opcode of compiled path patterns. +type OpCode int + +// These constants are the valid values of OpCode. +const ( + // OpNop does nothing + OpNop = OpCode(iota) + // OpPush pushes a component to stack + OpPush + // OpLitPush pushes a component to stack if it matches to the literal + OpLitPush + // OpPushM concatenates the remaining components and pushes it to stack + OpPushM + // OpConcatN pops N items from stack, concatenates them and pushes it back to stack + OpConcatN + // OpCapture pops an item and binds it to the variable + OpCapture + // OpEnd is the least postive invalid opcode. + OpEnd +) diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/trie.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/trie.go new file mode 100644 index 00000000000..c2b7b30dd91 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/trie.go @@ -0,0 +1,177 @@ +package utilities + +import ( + "sort" +) + +// DoubleArray is a Double Array implementation of trie on sequences of strings. +type DoubleArray struct { + // Encoding keeps an encoding from string to int + Encoding map[string]int + // Base is the base array of Double Array + Base []int + // Check is the check array of Double Array + Check []int +} + +// NewDoubleArray builds a DoubleArray from a set of sequences of strings. +func NewDoubleArray(seqs [][]string) *DoubleArray { + da := &DoubleArray{Encoding: make(map[string]int)} + if len(seqs) == 0 { + return da + } + + encoded := registerTokens(da, seqs) + sort.Sort(byLex(encoded)) + + root := node{row: -1, col: -1, left: 0, right: len(encoded)} + addSeqs(da, encoded, 0, root) + + for i := len(da.Base); i > 0; i-- { + if da.Check[i-1] != 0 { + da.Base = da.Base[:i] + da.Check = da.Check[:i] + break + } + } + return da +} + +func registerTokens(da *DoubleArray, seqs [][]string) [][]int { + var result [][]int + for _, seq := range seqs { + var encoded []int + for _, token := range seq { + if _, ok := da.Encoding[token]; !ok { + da.Encoding[token] = len(da.Encoding) + } + encoded = append(encoded, da.Encoding[token]) + } + result = append(result, encoded) + } + for i := range result { + result[i] = append(result[i], len(da.Encoding)) + } + return result +} + +type node struct { + row, col int + left, right int +} + +func (n node) value(seqs [][]int) int { + return seqs[n.row][n.col] +} + +func (n node) children(seqs [][]int) []*node { + var result []*node + lastVal := int(-1) + last := new(node) + for i := n.left; i < n.right; i++ { + if lastVal == seqs[i][n.col+1] { + continue + } + last.right = i + last = &node{ + row: i, + col: n.col + 1, + left: i, + } + result = append(result, last) + } + last.right = n.right + return result +} + +func addSeqs(da *DoubleArray, seqs [][]int, pos int, n node) { + ensureSize(da, pos) + + children := n.children(seqs) + var i int + for i = 1; ; i++ { + ok := func() bool { + for _, child := range children { + code := child.value(seqs) + j := i + code + ensureSize(da, j) + if da.Check[j] != 0 { + return false + } + } + return true + }() + if ok { + break + } + } + da.Base[pos] = i + for _, child := range children { + code := child.value(seqs) + j := i + code + da.Check[j] = pos + 1 + } + terminator := len(da.Encoding) + for _, child := range children { + code := child.value(seqs) + if code == terminator { + continue + } + j := i + code + addSeqs(da, seqs, j, *child) + } +} + +func ensureSize(da *DoubleArray, i int) { + for i >= len(da.Base) { + da.Base = append(da.Base, make([]int, len(da.Base)+1)...) + da.Check = append(da.Check, make([]int, len(da.Check)+1)...) + } +} + +type byLex [][]int + +func (l byLex) Len() int { return len(l) } +func (l byLex) Swap(i, j int) { l[i], l[j] = l[j], l[i] } +func (l byLex) Less(i, j int) bool { + si := l[i] + sj := l[j] + var k int + for k = 0; k < len(si) && k < len(sj); k++ { + if si[k] < sj[k] { + return true + } + if si[k] > sj[k] { + return false + } + } + if k < len(sj) { + return true + } + return false +} + +// HasCommonPrefix determines if any sequence in the DoubleArray is a prefix of the given sequence. +func (da *DoubleArray) HasCommonPrefix(seq []string) bool { + if len(da.Base) == 0 { + return false + } + + var i int + for _, t := range seq { + code, ok := da.Encoding[t] + if !ok { + break + } + j := da.Base[i] + code + if len(da.Check) <= j || da.Check[j] != i+1 { + break + } + i = j + } + j := da.Base[i] + len(da.Encoding) + if len(da.Check) <= j || da.Check[j] != i+1 { + return false + } + return true +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/trie_test.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/trie_test.go new file mode 100644 index 00000000000..0178aa827a0 --- /dev/null +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/utilities/trie_test.go @@ -0,0 +1,372 @@ +package utilities_test + +import ( + "reflect" + "testing" + + "github.com/grpc-ecosystem/grpc-gateway/utilities" +) + +func TestMaxCommonPrefix(t *testing.T) { + for _, spec := range []struct { + da utilities.DoubleArray + tokens []string + want bool + }{ + { + da: utilities.DoubleArray{}, + tokens: nil, + want: false, + }, + { + da: utilities.DoubleArray{}, + tokens: []string{"foo"}, + want: false, + }, + { + da: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + }, + Base: []int{1, 1, 0}, + Check: []int{0, 1, 2}, + }, + tokens: nil, + want: false, + }, + { + da: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + }, + Base: []int{1, 1, 0}, + Check: []int{0, 1, 2}, + }, + tokens: []string{"foo"}, + want: true, + }, + { + da: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + }, + Base: []int{1, 1, 0}, + Check: []int{0, 1, 2}, + }, + tokens: []string{"bar"}, + want: false, + }, + { + // foo|bar + da: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + tokens: []string{"foo"}, + want: true, + }, + { + // foo|bar + da: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + tokens: []string{"bar"}, + want: true, + }, + { + // foo|bar + da: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + tokens: []string{"something-else"}, + want: false, + }, + { + // foo|bar + da: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + tokens: []string{"foo", "bar"}, + want: true, + }, + { + // foo|foo\.bar|bar + da: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"foo"}, + want: true, + }, + { + // foo|foo\.bar|bar + da: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"foo", "bar"}, + want: true, + }, + { + // foo|foo\.bar|bar + da: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"bar"}, + want: true, + }, + { + // foo|foo\.bar|bar + da: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"something-else"}, + want: false, + }, + { + // foo|foo\.bar|bar + da: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"foo", "bar", "baz"}, + want: true, + }, + } { + got := spec.da.HasCommonPrefix(spec.tokens) + if got != spec.want { + t.Errorf("%#v.HasCommonPrefix(%v) = %v; want %v", spec.da, spec.tokens, got, spec.want) + } + } +} + +func TestAdd(t *testing.T) { + for _, spec := range []struct { + tokens [][]string + want utilities.DoubleArray + }{ + { + want: utilities.DoubleArray{ + Encoding: make(map[string]int), + }, + }, + { + tokens: [][]string{{"foo"}}, + want: utilities.DoubleArray{ + Encoding: map[string]int{"foo": 0}, + Base: []int{1, 1, 0}, + Check: []int{0, 1, 2}, + // 0: ^ + // 1: ^foo + // 2: ^foo$ + }, + }, + { + tokens: [][]string{{"foo"}, {"bar"}}, + want: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + }, + { + tokens: [][]string{{"foo", "bar"}, {"foo", "baz"}}, + want: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + "baz": 2, + }, + Base: []int{1, 1, 1, 2, 0, 0}, + Check: []int{0, 1, 2, 2, 3, 4}, + // 0: ^ + // 1: ^foo + // 2: ^foo.bar + // 3: ^foo.baz + // 4: ^foo.bar$ + // 5: ^foo.baz$ + }, + }, + { + tokens: [][]string{{"foo", "bar"}, {"foo", "baz"}, {"qux"}}, + want: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + "baz": 2, + "qux": 3, + }, + Base: []int{1, 1, 1, 2, 3, 0, 0, 0}, + Check: []int{0, 1, 2, 2, 1, 3, 4, 5}, + // 0: ^ + // 1: ^foo + // 2: ^foo.bar + // 3: ^foo.baz + // 4: ^qux + // 5: ^foo.bar$ + // 6: ^foo.baz$ + // 7: ^qux$ + }, + }, + { + tokens: [][]string{ + {"foo", "bar"}, + {"foo", "baz", "bar"}, + {"qux", "foo"}, + }, + want: utilities.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + "baz": 2, + "qux": 3, + }, + Base: []int{1, 1, 1, 5, 8, 0, 3, 0, 5, 0}, + Check: []int{0, 1, 2, 2, 1, 3, 4, 7, 5, 9}, + // 0: ^ + // 1: ^foo + // 2: ^foo.bar + // 3: ^foo.baz + // 4: ^qux + // 5: ^foo.bar$ + // 6: ^foo.baz.bar + // 7: ^foo.baz.bar$ + // 8: ^qux.foo + // 9: ^qux.foo$ + }, + }, + } { + da := utilities.NewDoubleArray(spec.tokens) + if got, want := da.Encoding, spec.want.Encoding; !reflect.DeepEqual(got, want) { + t.Errorf("da.Encoding = %v; want %v; tokens = %#v", got, want, spec.tokens) + } + if got, want := da.Base, spec.want.Base; !compareArray(got, want) { + t.Errorf("da.Base = %v; want %v; tokens = %#v", got, want, spec.tokens) + } + if got, want := da.Check, spec.want.Check; !compareArray(got, want) { + t.Errorf("da.Check = %v; want %v; tokens = %#v", got, want, spec.tokens) + } + } +} + +func compareArray(got, want []int) bool { + var i int + for i = 0; i < len(got) && i < len(want); i++ { + if got[i] != want[i] { + return false + } + } + if i < len(want) { + return false + } + for ; i < len(got); i++ { + if got[i] != 0 { + return false + } + } + return true +}