Changes for the upcoming teleport pro:

* Allow external audit log plugins
* Add support for auth API server plugins
* Add license file path configuration parameter (not used in open-source)
* Extend audit log with user login events
This commit is contained in:
Roman Tkachenko 2017-11-21 17:35:58 -08:00
parent c66e7f9275
commit 143b834e57
77 changed files with 8421 additions and 60 deletions

18
Gopkg.lock generated
View file

@ -117,7 +117,7 @@
[[projects]]
name = "github.com/golang/protobuf"
packages = ["proto","protoc-gen-go/descriptor","ptypes/any","ptypes/empty"]
packages = ["jsonpb","proto","protoc-gen-go/descriptor","ptypes/any","ptypes/empty"]
revision = "8ee79997227bf9b34611aee7946ae64735e6fd93"
[[projects]]
@ -143,6 +143,18 @@
packages = ["."]
revision = "52bc17adf63c0807b5e5b5d91350703630f621c7"
[[projects]]
name = "github.com/gravitational/license"
packages = [".","constants"]
revision = "102213511ace56c97ccf1eef645835e16f84d130"
version = "0.0.4"
[[projects]]
name = "github.com/gravitational/reporting"
packages = [".","client","types"]
revision = "3c4a4e96fb5896e14fe29da7fcce14b8d93f3965"
version = "0.0.4"
[[projects]]
branch = "master"
name = "github.com/gravitational/roundtrip"
@ -163,7 +175,7 @@
[[projects]]
name = "github.com/grpc-ecosystem/grpc-gateway"
packages = ["third_party/googleapis/google/api"]
packages = ["runtime","runtime/internal","third_party/googleapis/google/api","utilities"]
revision = "a8f25bd1ab549f8b87afd48aa9181221e9d439bb"
version = "v1.1.0"
@ -393,6 +405,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "c8734d8ce5c599785cf35e78aca518adc1fdd6ff31c139d65777429211331ba5"
inputs-digest = "97a726bb183a88f10d511f1853d157e006b3a6a3a8d2887dddc0aa8f92506175"
solver-name = "gps-cdcl"
solver-version = 1

View file

@ -143,3 +143,11 @@ ignored = ["github.com/Sirupsen/logrus"]
[[constraint]]
name = "github.com/russellhaering/gosaml2"
revision = "8908227c114abe0b63b1f0606abae72d11bf632a"
[[constraint]]
name = "github.com/gravitational/reporting"
version = "0.0.4"
[[constraint]]
name = "github.com/gravitational/license"
version = "0.0.4"

2
e

@ -1 +1 @@
Subproject commit 4f3e4ebd66716cd256abc2847a8e80addc85e4a3
Subproject commit 96a9523e7e7d8937bf738a1b299642f27447b3ef

View file

@ -34,10 +34,9 @@ import (
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/julienschmidt/httprouter"
log "github.com/sirupsen/logrus"
"github.com/jonboulle/clockwork"
"github.com/tstranex/u2f"
)
@ -193,6 +192,10 @@ func NewAPIServer(config *APIConfig) http.Handler {
srv.GET("/:version/events", srv.withAuth(srv.searchEvents))
srv.GET("/:version/events/session", srv.withAuth(srv.searchSessionEvents))
if plugin := GetPlugin(); plugin != nil {
plugin.AddHandlers(&srv)
}
return httplib.RewritePaths(&srv.Router,
httplib.Rewrite("/v1/nodes", "/v1/namespaces/default/nodes"),
httplib.Rewrite("/v1/sessions", "/v1/namespaces/default/sessions"),
@ -223,7 +226,7 @@ func (s *APIServer) withAuth(handler HandlerWithAuthFunc) httprouter.Handle {
user: authContext.User,
checker: authContext.Checker,
sessions: s.SessionService,
alog: s.AuditLog,
alog: s.AuthServer.IAuditLog,
}
version := p.ByName("version")
if version == "" {

View file

@ -79,6 +79,7 @@ func (s *APISuite) SetUpTest(c *C) {
s.a = NewAuthServer(&InitConfig{
Backend: s.bk,
Authority: authority.New(),
AuditLog: s.alog,
})
s.sessions, err = session.New(s.bk)
c.Assert(err, IsNil)

View file

@ -33,12 +33,13 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/local"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/coreos/go-oidc/oidc"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
saml2 "github.com/russellhaering/gosaml2"
log "github.com/sirupsen/logrus"
@ -87,6 +88,9 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) *AuthServer {
if cfg.ClusterConfiguration == nil {
cfg.ClusterConfiguration = local.NewClusterConfigurationService(cfg.Backend)
}
if cfg.AuditLog == nil {
cfg.AuditLog = events.NewDiscardAuditLog()
}
closeCtx, cancelFunc := context.WithCancel(context.TODO())
as := AuthServer{
Entry: log.WithFields(log.Fields{
@ -101,6 +105,7 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) *AuthServer {
Access: cfg.Access,
AuthServiceName: cfg.AuthServiceName,
ClusterConfiguration: cfg.ClusterConfiguration,
IAuditLog: cfg.AuditLog,
oidcClients: make(map[string]*oidcClient),
samlProviders: make(map[string]*samlProvider),
cancelFunc: cancelFunc,
@ -157,6 +162,7 @@ type AuthServer struct {
services.Identity
services.Access
services.ClusterConfiguration
events.IAuditLog
}
func (a *AuthServer) Close() error {
@ -172,6 +178,11 @@ func (a *AuthServer) SetClock(clock clockwork.Clock) {
a.clock = clock
}
// SetAuditLog sets the server's audit log
func (a *AuthServer) SetAuditLog(auditLog events.IAuditLog) {
a.IAuditLog = auditLog
}
// GetDomainName returns the domain name that identifies this authority server.
// Also known as "cluster name"
func (a *AuthServer) GetDomainName() (string, error) {
@ -225,7 +236,6 @@ func (s *AuthServer) GenerateUserCert(key []byte, user services.User, allowedLog
if err != nil {
return nil, trace.Wrap(err)
}
ca, err := s.Trust.GetCertAuthority(services.CertAuthID{
Type: services.UserCA,
DomainName: domainName,
@ -237,7 +247,7 @@ func (s *AuthServer) GenerateUserCert(key []byte, user services.User, allowedLog
if err != nil {
return nil, trace.Wrap(err)
}
return s.Authority.GenerateUserCert(services.UserCertParams{
cert, err := s.Authority.GenerateUserCert(services.UserCertParams{
PrivateCASigningKey: privateKey,
PublicUserKey: key,
Username: user.GetName(),
@ -247,6 +257,14 @@ func (s *AuthServer) GenerateUserCert(key []byte, user services.User, allowedLog
Compatibility: compatibility,
PermitAgentForwarding: canForwardAgents,
})
if err != nil {
return nil, trace.Wrap(err)
}
s.EmitAuditEvent(events.UserLoginEvent, events.EventFields{
events.EventUser: user.GetName(),
events.LoginMethod: events.LoginMethodLocal,
})
return cert, nil
}
// WithUserLock executes function authenticateFn that performs user authentication
@ -315,6 +333,10 @@ func (s *AuthServer) SignIn(user string, password []byte) (services.WebSession,
if err != nil {
return nil, trace.Wrap(err)
}
s.EmitAuditEvent(events.UserLoginEvent, events.EventFields{
events.EventUser: user,
events.LoginMethod: events.LoginMethodLocal,
})
return s.PreAuthenticatedSignIn(user)
}

View file

@ -28,10 +28,12 @@ import (
"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/pborman/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)
@ -105,10 +107,13 @@ type InitConfig struct {
// ClusterConfig holds cluster level configuration.
ClusterConfig services.ClusterConfig
// AuditLog is used for emitting events to audit log
AuditLog events.IAuditLog
}
// Init instantiates and configures an instance of AuthServer
func Init(cfg InitConfig) (*AuthServer, *Identity, error) {
func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, error) {
if cfg.DataDir == "" {
return nil, nil, trace.BadParameter("DataDir: data dir can not be empty")
}
@ -124,7 +129,7 @@ func Init(cfg InitConfig) (*AuthServer, *Identity, error) {
defer cfg.Backend.ReleaseLock(domainName)
// check that user CA and host CA are present and set the certs if needed
asrv := NewAuthServer(&cfg)
asrv := NewAuthServer(&cfg, opts...)
// INTERNAL: Authorities (plus Roles) and ReverseTunnels don't follow the
// same pattern as the rest of the configuration (they are not configuration
@ -155,6 +160,17 @@ func Init(cfg InitConfig) (*AuthServer, *Identity, error) {
}
// set cluster level config on the backend and then force a sync of the cache.
clusterConfig, err := asrv.GetClusterConfig()
if err != nil && !trace.IsNotFound(err) {
return nil, nil, trace.Wrap(err)
}
// init a unique cluster ID, it must be set once only during the first
// start so if it's already there, reuse it
if clusterConfig != nil && clusterConfig.GetClusterID() != "" {
cfg.ClusterConfig.SetClusterID(clusterConfig.GetClusterID())
} else {
cfg.ClusterConfig.SetClusterID(uuid.New())
}
err = asrv.SetClusterConfig(cfg.ClusterConfig)
if err != nil {
return nil, nil, trace.Wrap(err)

View file

@ -214,3 +214,45 @@ func (s *AuthInitSuite) TestAuthPreference(c *C) {
c.Assert(u.AppID, Equals, "foo")
c.Assert(u.Facets, DeepEquals, []string{"bar", "baz"})
}
func (s *AuthInitSuite) TestClusterID(c *C) {
bk, err := boltbk.New(backend.Params{"path": c.MkDir()})
c.Assert(err, IsNil)
clusterName, err := services.NewClusterName(services.ClusterNameSpecV2{
ClusterName: "me.localhost",
})
c.Assert(err, IsNil)
authServer, _, err := Init(InitConfig{
DataDir: c.MkDir(),
HostUUID: "00000000-0000-0000-0000-000000000000",
NodeName: "foo",
Backend: bk,
Authority: testauthority.New(),
ClusterName: clusterName,
ClusterConfig: services.DefaultClusterConfig(),
})
c.Assert(err, IsNil)
cc, err := authServer.GetClusterConfig()
c.Assert(err, IsNil)
clusterID := cc.GetClusterID()
c.Assert(clusterID, Not(Equals), "")
// do it again and make sure cluster ID hasn't changed
authServer, _, err = Init(InitConfig{
DataDir: c.MkDir(),
HostUUID: "00000000-0000-0000-0000-000000000000",
NodeName: "foo",
Backend: bk,
Authority: testauthority.New(),
ClusterName: clusterName,
ClusterConfig: services.DefaultClusterConfig(),
})
c.Assert(err, IsNil)
cc, err = authServer.GetClusterConfig()
c.Assert(err, IsNil)
c.Assert(cc.GetClusterID(), Equals, clusterID)
}

View file

@ -9,13 +9,14 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
)
@ -231,7 +232,10 @@ func (a *AuthServer) ValidateOIDCAuthCallback(q url.Values) (*OIDCAuthResponse,
response.HostSigners = append(response.HostSigners, authority)
}
}
a.EmitAuditEvent(events.UserLoginEvent, events.EventFields{
events.EventUser: user.GetName(),
events.LoginMethod: events.LoginMethodOIDC,
})
return response, nil
}

44
lib/auth/plugin.go Normal file
View file

@ -0,0 +1,44 @@
/*
Copyright 2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package auth
import (
"sync"
)
var pluginMutex = &sync.Mutex{}
var plugin Plugin
// GetPlugin returns auth API server plugin that allows injecting handlers
func GetPlugin() Plugin {
pluginMutex.Lock()
defer pluginMutex.Unlock()
return plugin
}
// SetPlugin sets plugin for the auth API server
func SetPlugin(p Plugin) {
pluginMutex.Lock()
defer pluginMutex.Unlock()
plugin = p
}
// Plugin is auth API server extension setter
type Plugin interface {
// AddHandlers adds handlers to the auth API server
AddHandlers(srv *APIServer)
}

View file

@ -9,11 +9,12 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/beevik/etree"
"github.com/gravitational/trace"
saml2 "github.com/russellhaering/gosaml2"
log "github.com/sirupsen/logrus"
)
@ -353,6 +354,9 @@ func (a *AuthServer) ValidateSAMLResponse(samlResponse string) (*SAMLAuthRespons
response.HostSigners = append(response.HostSigners, authority)
}
}
a.EmitAuditEvent(events.UserLoginEvent, events.EventFields{
events.EventUser: user.GetName(),
events.LoginMethod: events.LoginMethodSAML,
})
return response, nil
}

View file

@ -26,6 +26,7 @@ import (
"net"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"unicode"
@ -404,6 +405,7 @@ func ApplyFileConfig(fc *FileConfig, cfg *service.Config) error {
log.Warnf(warningMessage)
}
}
// read in and set session recording
clusterConfig, err := fc.Auth.SessionRecording.Parse()
if err != nil {
@ -411,6 +413,16 @@ func ApplyFileConfig(fc *FileConfig, cfg *service.Config) error {
}
cfg.Auth.ClusterConfig = clusterConfig
// read in and set the license file path (not used in open-source version)
licenseFile := fc.Auth.LicenseFile
if licenseFile != "" {
if filepath.IsAbs(licenseFile) {
cfg.Auth.LicenseFile = licenseFile
} else {
cfg.Auth.LicenseFile = filepath.Join(cfg.DataDir, licenseFile)
}
}
// apply "ssh_service" section
if fc.SSH.ListenAddress != "" {
addr, err := utils.ParseHostPortAddr(fc.SSH.ListenAddress, int(defaults.SSHServerListenPort))

View file

@ -146,6 +146,7 @@ func (s *ConfigTestSuite) TestConfigReading(c *check.C) {
c.Assert(conf.DataDir, check.Equals, "/path/to/data")
c.Assert(conf.Auth.Enabled(), check.Equals, true)
c.Assert(conf.Auth.ListenAddress, check.Equals, "tcp://auth")
c.Assert(conf.Auth.LicenseFile, check.Equals, "lic.pem")
c.Assert(conf.SSH.Configured(), check.Equals, true)
c.Assert(conf.SSH.Enabled(), check.Equals, true)
c.Assert(conf.SSH.ListenAddress, check.Equals, "tcp://ssh")
@ -508,6 +509,7 @@ func makeConfigFixture() string {
// auth service:
conf.Auth.EnabledFlag = "Yeah"
conf.Auth.ListenAddress = "tcp://auth"
conf.Auth.LicenseFile = "lic.pem"
// ssh service:
conf.SSH.EnabledFlag = "true"
@ -572,3 +574,40 @@ ssh_service:
c.Assert(cfg.SSH.PermitUserEnvironment, check.Equals, tt.outPermitUserEnvironment, comment)
}
}
func (s *ConfigTestSuite) TestLicenseFile(c *check.C) {
testCases := []struct {
path string
result string
}{
// 0 - no license
{
path: "",
result: filepath.Join(defaults.DataDir, defaults.LicenseFile),
},
// 1 - relative path
{
path: "lic.pem",
result: filepath.Join(defaults.DataDir, "lic.pem"),
},
// 2 - absolute path
{
path: "/etc/teleport/license",
result: "/etc/teleport/license",
},
}
cfg := service.MakeDefaultConfig()
c.Assert(cfg.Auth.LicenseFile, check.Equals,
filepath.Join(defaults.DataDir, defaults.LicenseFile))
for _, tc := range testCases {
err := ApplyFileConfig(&FileConfig{
Auth: Auth{
LicenseFile: tc.path,
},
}, cfg)
c.Assert(err, check.IsNil)
c.Assert(cfg.Auth.LicenseFile, check.Equals, tc.result)
}
}

View file

@ -44,8 +44,8 @@ import (
var (
// all possible valid YAML config keys
// true = non-scalar
// false = scalar
// true = has sub-keys
// false = does not have sub-keys (a leaf)
validKeys = map[string]bool{
"namespace": true,
"cluster_name": true,
@ -131,6 +131,7 @@ var (
"session_recording": false,
"read_capacity_units": false,
"write_capacity_units": false,
"license_file": false,
}
)
@ -496,6 +497,10 @@ type Auth struct {
// it here overrides defaults.
// Deprecated: Remove in Teleport 2.4.1.
DynamicConfig *bool `yaml:"dynamic_config,omitempty"`
// LicenseFile is a path to the license file. The path can be either absolute or
// relative to the global data dir
LicenseFile string `yaml:"license_file,omitempty"`
}
// TrustedCluster struct holds configuration values under "trusted_clusters" key

View file

@ -46,7 +46,7 @@ teleport:
- period: 10m10s
average: 170
burst: 171
keys:
keys:
- cert: node.cert
private_key: !!binary cHJpdmF0ZSBrZXk=
- cert_file: /proxy.cert.file
@ -61,21 +61,21 @@ auth_service:
tokens:
- "proxy,node:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
- "auth:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
authorities:
authorities:
- type: host
domain_name: example.com
checking_keys:
checking_keys:
- checking key 1
checking_key_files:
- /ca.checking.key
signing_keys:
signing_keys:
- !!binary c2lnbmluZyBrZXkgMQ==
signing_key_files:
- /ca.signing.key
reverse_tunnels:
- domain_name: tunnel.example.com
- domain_name: tunnel.example.com
addresses: ["com-1", "com-2"]
- domain_name: tunnel.example.org
- domain_name: tunnel.example.org
addresses: ["org-1"]
ssh_service:
@ -135,7 +135,7 @@ proxy_service:
// need to support it until it's fully removed.
const LegacyAuthenticationSection = `
auth_service:
oidc_connectors:
oidc_connectors:
- id: google
redirect_url: https://localhost:3080/v1/webapi/oidc/callback
client_id: id-from-google.apps.googleusercontent.com

View file

@ -235,7 +235,7 @@ var (
// ConfigFilePath is default path to teleport config file
ConfigFilePath = "/etc/teleport.yaml"
// DataDir is where all mutable data is stored (user keys, recorded sessions,
// DataDir is where all mutable data is stored (user keys, recorded sessions,
// registered SSH servers, etc):
DataDir = "/var/lib/teleport"
@ -247,6 +247,9 @@ var (
// ConfigEnvar is a name of teleport's configuration environment variable
ConfigEnvar = "TELEPORT_CONFIG"
// LicenseFile is the default name of the license file
LicenseFile = "license.pem"
)
const (

View file

@ -63,10 +63,22 @@ const (
// the beginning
SessionByteOffset = "offset"
// Join & Leave events indicate when someone joins/leaves a session
SessionJoinEvent = "session.join"
// SessionJoinEvent indicates that someone joined a session
SessionJoinEvent = "session.join"
// SessionLeaveEvent indicates that someone left a session
SessionLeaveEvent = "session.leave"
// UserLoginEvent indicates that a user logged into web UI or via tsh
UserLoginEvent = "user.login"
// LoginMethod is the event field indicating how the login was performed
LoginMethod = "method"
// LoginMethodLocal represents login with username/password
LoginMethodLocal = "local"
// LoginMethodOIDC represents login with OIDC
LoginMethodOIDC = "oidc"
// LoginMethodSAML represents login with SAML
LoginMethodSAML = "saml"
// ExecEvent is an exec command executed by script or user on
// the server side
ExecEvent = "exec"
@ -101,7 +113,7 @@ const (
MaxChunkBytes = 1024 * 1024 * 5
)
// IAuditLog is the primary (and the only external-facing) interface for AUditLogger.
// IAuditLog is the primary (and the only external-facing) interface for AuditLogger.
// If you wish to implement a different kind of logger (not filesystem-based), you
// have to implement this interface
type IAuditLog interface {

View file

@ -10,7 +10,11 @@ import (
// DiscardAuditLog is do-nothing, discard-everything implementation
// of IAuditLog interface used for cases when audit is turned off
type DiscardAuditLog struct {
type DiscardAuditLog struct{}
// NewDiscardAuditLog returns a no-op audit log instance
func NewDiscardAuditLog() *DiscardAuditLog {
return &DiscardAuditLog{}
}
func (d *DiscardAuditLog) WaitForDelivery(context.Context) error {
@ -46,8 +50,7 @@ func (d *DiscardAuditLog) SearchSessionEvents(fromUTC time.Time, toUTC time.Time
// discardSessionLogger implements a session logger that does nothing. It
// discards all events and chunks written to it. It is used when session
// recording has been disabled.
type discardSessionLogger struct {
}
type discardSessionLogger struct{}
func (d *discardSessionLogger) LogEvent(fields EventFields) {
return

View file

@ -21,6 +21,7 @@ import (
"io"
"net"
"os"
"path/filepath"
"time"
"golang.org/x/crypto/ssh"
@ -70,7 +71,7 @@ type Config struct {
// SSH role an SSH endpoint server
SSH SSHConfig
// Auth server authentication and authorizatin server config
// Auth server authentication and authorization server config
Auth AuthConfig
// Keygen points to a key generator implementation
@ -266,6 +267,9 @@ type AuthConfig struct {
// Preference defines the authentication preference (type and second factor) for
// the auth server.
Preference services.AuthPreference
// LicenseFile is a full path to the license file
LicenseFile string
}
// SSHConfig configures SSH server node role
@ -320,6 +324,7 @@ func ApplyDefaults(cfg *Config) {
ap := &services.AuthPreferenceV2{}
ap.CheckAndSetDefaults()
cfg.Auth.Preference = ap
cfg.Auth.LicenseFile = filepath.Join(cfg.DataDir, defaults.LicenseFile)
// defaults for the SSH proxy service:
cfg.Proxy.Enabled = true

View file

@ -103,15 +103,30 @@ type TeleportProcess struct {
// localAuth has local auth server listed in case if this process
// has started with auth server role enabled
localAuth *auth.AuthServer
// backend is the process' backend
backend backend.Backend
// auditLog is the initialized audit log
auditLog events.IAuditLog
// identities of this process (credentials to auth sever, basically)
Identities map[teleport.Role]*auth.Identity
}
// GetAuthServer returns the process' auth server
func (process *TeleportProcess) GetAuthServer() *auth.AuthServer {
return process.localAuth
}
// GetAuditLog returns the process' audit log
func (process *TeleportProcess) GetAuditLog() events.IAuditLog {
return process.auditLog
}
// GetBackend returns the process' backend
func (process *TeleportProcess) GetBackend() backend.Backend {
return process.backend
}
func (process *TeleportProcess) findStaticIdentity(id auth.IdentityID) (*auth.Identity, error) {
for i := range process.Config.Identities {
identity := process.Config.Identities[i]
@ -303,13 +318,13 @@ func (process *TeleportProcess) initAuthService(authority auth.Authority) error
if err != nil {
return trace.Wrap(err)
}
process.backend = b
// create the audit log, which will be consuming (and recording) all events
// and recording all sessions.
var auditLog events.IAuditLog
if cfg.Auth.NoAudit {
// this is for teleconsole
auditLog = &events.DiscardAuditLog{}
process.auditLog = events.NewDiscardAuditLog()
warningMessage := "Warning: Teleport audit and session recording have been " +
"turned off. This is dangerous, you will not be able to view audit events " +
@ -345,7 +360,7 @@ func (process *TeleportProcess) initAuthService(authority auth.Authority) error
auditConfig.GID = &gid
}
}
auditLog, err = events.NewAuditLog(auditConfig)
process.auditLog, err = events.NewAuditLog(auditConfig)
if err != nil {
return trace.Wrap(err)
}
@ -372,6 +387,7 @@ func (process *TeleportProcess) initAuthService(authority auth.Authority) error
Roles: cfg.Auth.Roles,
AuthPreference: cfg.Auth.Preference,
OIDCConnectors: cfg.OIDCConnectors,
AuditLog: process.auditLog,
})
if err != nil {
return trace.Wrap(err)
@ -393,7 +409,7 @@ func (process *TeleportProcess) initAuthService(authority auth.Authority) error
AuthServer: authServer,
SessionService: sessionService,
Authorizer: authorizer,
AuditLog: auditLog,
AuditLog: process.auditLog,
}
limiter, err := limiter.NewLimiter(cfg.Auth.Limiter)

View file

@ -41,6 +41,12 @@ type ClusterConfig interface {
// SetSessionRecording sets where the session is recorded.
SetSessionRecording(RecordingType)
// GetClusterID returns the unique cluster ID
GetClusterID() string
// SetClusterID sets the cluster ID
SetClusterID(string)
// CheckAndSetDefaults checks and set default values for missing fields.
CheckAndSetDefaults() error
@ -115,6 +121,9 @@ const (
type ClusterConfigSpecV3 struct {
// SessionRecording controls where (or if) the session is recorded.
SessionRecording RecordingType `json:"session_recording"`
// ClusterID is the unique cluster ID that is set once during the first auth
// server startup.
ClusterID string `json:"cluster_id"`
}
// GetName returns the name of the cluster.
@ -157,6 +166,16 @@ func (c *ClusterConfigV3) SetSessionRecording(s RecordingType) {
c.Spec.SessionRecording = s
}
// GetClusterID returns the unique cluster ID
func (c *ClusterConfigV3) GetClusterID() string {
return c.Spec.ClusterID
}
// SetClusterID sets the cluster ID
func (c *ClusterConfigV3) SetClusterID(id string) {
c.Spec.ClusterID = id
}
// CheckAndSetDefaults checks validity of all parameters and sets defaults.
func (c *ClusterConfigV3) CheckAndSetDefaults() error {
// make sure we have defaults for all metadata fields
@ -187,7 +206,8 @@ func (c *ClusterConfigV3) Copy() ClusterConfig {
// String represents a human readable version of the cluster name.
func (c *ClusterConfigV3) String() string {
return fmt.Sprintf("ClusterConfig(SessionRecording=%v)", c.Spec.SessionRecording)
return fmt.Sprintf("ClusterConfig(SessionRecording=%v, ClusterID=%v)",
c.Spec.SessionRecording, c.Spec.ClusterID)
}
// ClusterConfigSpecSchemaTemplate is a template for ClusterConfig schema.
@ -197,6 +217,9 @@ const ClusterConfigSpecSchemaTemplate = `{
"properties": {
"session_recording": {
"type": "string"
},
"cluster_id": {
"type": "string"
}%v
}
}`

View file

@ -228,8 +228,8 @@ const ExternalIdentitySchema = `{
"type": "object",
"additionalProperties": false,
"properties": {
"connector_id": {"type": "string"},
"username": {"type": "string"}
"connector_id": {"type": "string"},
"username": {"type": "string"}
}
}`

View file

@ -19,6 +19,7 @@ package local
import (
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/trace"
)

View file

@ -311,6 +311,11 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*RewritingHandler, error) {
}, nil
}
// GetProxyClient returns authenticated auth server client
func (h *Handler) GetProxyClient() auth.ClientI {
return h.cfg.ProxyClient
}
// Close closes associated session cache operations
func (h *Handler) Close() error {
return h.auth.Close()

View file

@ -138,7 +138,7 @@ func (c *SessionContext) getRemoteClient(siteName string) (auth.ClientI, bool) {
}
// GetClient returns the client connected to the auth server
func (c *SessionContext) GetClient() (auth.ClientI, error) {
func (c *SessionContext) GetClient() (*auth.TunClient, error) {
return c.clt, nil
}

View file

@ -39,10 +39,17 @@ import (
log "github.com/sirupsen/logrus"
)
// same as main() but has a testing switch
// - cmdlineArgs are passed from main()
// - testRun is 'true' when running under an integration test
func Run(cmdlineArgs []string, testRun bool) (executedCommand string, conf *service.Config) {
// Options combines init/start teleport options
type Options struct {
// Args is a list of command-line args passed from main()
Args []string
// InitOnly when set to true, initializes config and aux
// endpoints but does not start the process
InitOnly bool
}
// Run inits/starts the process according to the provided options
func Run(options Options) (executedCommand string, conf *service.Config) {
var err error
// configure trace's errors to produce full stack traces
@ -130,7 +137,7 @@ func Run(cmdlineArgs []string, testRun bool) (executedCommand string, conf *serv
scpc.Arg("target", "").StringsVar(&scpCommand.Target)
// parse CLI commands+flags:
command, err := app.Parse(cmdlineArgs)
command, err := app.Parse(options.Args)
if err != nil {
utils.FatalError(err)
}
@ -145,7 +152,7 @@ func Run(cmdlineArgs []string, testRun bool) (executedCommand string, conf *serv
if err = config.Configure(&ccf, conf); err != nil {
utils.FatalError(err)
}
if !testRun {
if !options.InitOnly {
log.Debug(conf.DebugDumpToYAML())
}
if ccf.HTTPProfileEndpoint {
@ -172,8 +179,8 @@ func Run(cmdlineArgs []string, testRun bool) (executedCommand string, conf *serv
}
}()
}
if !testRun {
err = onStart(conf)
if !options.InitOnly {
err = OnStart(conf)
}
case scpc.FullCommand():
err = onSCP(&scpCommand)
@ -191,12 +198,13 @@ func Run(cmdlineArgs []string, testRun bool) (executedCommand string, conf *serv
return command, conf
}
// onStart is the handler for "start" CLI command
func onStart(config *service.Config) error {
// OnStart is the handler for "start" CLI command
func OnStart(config *service.Config) error {
srv, err := service.NewTeleport(config)
if err != nil {
return trace.Wrap(err, "initializing teleport")
}
if err := srv.Start(); err != nil {
return trace.Wrap(err, "starting teleport")
}
@ -210,8 +218,8 @@ func onStart(config *service.Config) error {
fmt.Fprintf(f, "%v", os.Getpid())
defer f.Close()
}
srv.Wait()
return nil
return trace.Wrap(srv.Wait())
}
// onStatus is the handler for "status" CLI command

View file

@ -66,7 +66,10 @@ func (s *MainTestSuite) SetUpSuite(c *check.C) {
}
func (s *MainTestSuite) TestDefault(c *check.C) {
cmd, conf := Run([]string{"start"}, true)
cmd, conf := Run(Options{
Args: []string{"start"},
InitOnly: true,
})
c.Assert(cmd, check.Equals, "start")
c.Assert(conf.Hostname, check.Equals, s.hostname)
c.Assert(conf.DataDir, check.Equals, "/tmp/teleport/var/lib/teleport")
@ -78,17 +81,26 @@ func (s *MainTestSuite) TestDefault(c *check.C) {
}
func (s *MainTestSuite) TestRolesFlag(c *check.C) {
cmd, conf := Run([]string{"start", "--roles=node"}, true)
cmd, conf := Run(Options{
Args: []string{"start", "--roles=node"},
InitOnly: true,
})
c.Assert(conf.SSH.Enabled, check.Equals, true)
c.Assert(conf.Auth.Enabled, check.Equals, false)
c.Assert(conf.Proxy.Enabled, check.Equals, false)
cmd, conf = Run([]string{"start", "--roles=proxy"}, true)
cmd, conf = Run(Options{
Args: []string{"start", "--roles=proxy"},
InitOnly: true,
})
c.Assert(conf.SSH.Enabled, check.Equals, false)
c.Assert(conf.Auth.Enabled, check.Equals, false)
c.Assert(conf.Proxy.Enabled, check.Equals, true)
cmd, conf = Run([]string{"start", "--roles=auth"}, true)
cmd, conf = Run(Options{
Args: []string{"start", "--roles=auth"},
InitOnly: true,
})
c.Assert(conf.SSH.Enabled, check.Equals, false)
c.Assert(conf.Auth.Enabled, check.Equals, true)
c.Assert(conf.Proxy.Enabled, check.Equals, false)
@ -96,7 +108,10 @@ func (s *MainTestSuite) TestRolesFlag(c *check.C) {
}
func (s *MainTestSuite) TestConfigFile(c *check.C) {
cmd, conf := Run([]string{"start", "--roles=node", "--labels=a=a1,b=b1", "--config=" + s.configFile}, true)
cmd, conf := Run(Options{
Args: []string{"start", "--roles=node", "--labels=a=a1,b=b1", "--config=" + s.configFile},
InitOnly: true,
})
c.Assert(cmd, check.Equals, "start")
c.Assert(conf.SSH.Enabled, check.Equals, true)
c.Assert(conf.Auth.Enabled, check.Equals, false)

View file

@ -23,6 +23,7 @@ import (
)
func main() {
const testRun = false
common.Run(os.Args[1:], testRun)
common.Run(common.Options{
Args: os.Args[1:],
})
}

843
vendor/github.com/golang/protobuf/jsonpb/jsonpb.go generated vendored Normal file
View file

@ -0,0 +1,843 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2015 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/*
Package jsonpb provides marshaling and unmarshaling between protocol buffers and JSON.
It follows the specification at https://developers.google.com/protocol-buffers/docs/proto3#json.
This package produces a different output than the standard "encoding/json" package,
which does not operate correctly on protocol buffers.
*/
package jsonpb
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"reflect"
"sort"
"strconv"
"strings"
"time"
"github.com/golang/protobuf/proto"
)
// Marshaler is a configurable object for converting between
// protocol buffer objects and a JSON representation for them.
type Marshaler struct {
// Whether to render enum values as integers, as opposed to string values.
EnumsAsInts bool
// Whether to render fields with zero values.
EmitDefaults bool
// A string to indent each level by. The presence of this field will
// also cause a space to appear between the field separator and
// value, and for newlines to be appear between fields and array
// elements.
Indent string
// Whether to use the original (.proto) name for fields.
OrigName bool
}
// Marshal marshals a protocol buffer into JSON.
func (m *Marshaler) Marshal(out io.Writer, pb proto.Message) error {
writer := &errWriter{writer: out}
return m.marshalObject(writer, pb, "", "")
}
// MarshalToString converts a protocol buffer object to JSON string.
func (m *Marshaler) MarshalToString(pb proto.Message) (string, error) {
var buf bytes.Buffer
if err := m.Marshal(&buf, pb); err != nil {
return "", err
}
return buf.String(), nil
}
type int32Slice []int32
// For sorting extensions ids to ensure stable output.
func (s int32Slice) Len() int { return len(s) }
func (s int32Slice) Less(i, j int) bool { return s[i] < s[j] }
func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type wkt interface {
XXX_WellKnownType() string
}
// marshalObject writes a struct to the Writer.
func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeURL string) error {
s := reflect.ValueOf(v).Elem()
// Handle well-known types.
if wkt, ok := v.(wkt); ok {
switch wkt.XXX_WellKnownType() {
case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value",
"Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue":
// "Wrappers use the same representation in JSON
// as the wrapped primitive type, ..."
sprop := proto.GetProperties(s.Type())
return m.marshalValue(out, sprop.Prop[0], s.Field(0), indent)
case "Any":
// Any is a bit more involved.
return m.marshalAny(out, v, indent)
case "Duration":
// "Generated output always contains 3, 6, or 9 fractional digits,
// depending on required precision."
s, ns := s.Field(0).Int(), s.Field(1).Int()
d := time.Duration(s)*time.Second + time.Duration(ns)*time.Nanosecond
x := fmt.Sprintf("%.9f", d.Seconds())
x = strings.TrimSuffix(x, "000")
x = strings.TrimSuffix(x, "000")
out.write(`"`)
out.write(x)
out.write(`s"`)
return out.err
case "Struct":
// Let marshalValue handle the `fields` map.
// TODO: pass the correct Properties if needed.
return m.marshalValue(out, &proto.Properties{}, s.Field(0), indent)
case "Timestamp":
// "RFC 3339, where generated output will always be Z-normalized
// and uses 3, 6 or 9 fractional digits."
s, ns := s.Field(0).Int(), s.Field(1).Int()
t := time.Unix(s, ns).UTC()
// time.RFC3339Nano isn't exactly right (we need to get 3/6/9 fractional digits).
x := t.Format("2006-01-02T15:04:05.000000000")
x = strings.TrimSuffix(x, "000")
x = strings.TrimSuffix(x, "000")
out.write(`"`)
out.write(x)
out.write(`Z"`)
return out.err
case "Value":
// Value has a single oneof.
kind := s.Field(0)
if kind.IsNil() {
// "absence of any variant indicates an error"
return errors.New("nil Value")
}
// oneof -> *T -> T -> T.F
x := kind.Elem().Elem().Field(0)
// TODO: pass the correct Properties if needed.
return m.marshalValue(out, &proto.Properties{}, x, indent)
}
}
out.write("{")
if m.Indent != "" {
out.write("\n")
}
firstField := true
if typeURL != "" {
if err := m.marshalTypeURL(out, indent, typeURL); err != nil {
return err
}
firstField = false
}
for i := 0; i < s.NumField(); i++ {
value := s.Field(i)
valueField := s.Type().Field(i)
if strings.HasPrefix(valueField.Name, "XXX_") {
continue
}
// IsNil will panic on most value kinds.
switch value.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
if value.IsNil() {
continue
}
}
if !m.EmitDefaults {
switch value.Kind() {
case reflect.Bool:
if !value.Bool() {
continue
}
case reflect.Int32, reflect.Int64:
if value.Int() == 0 {
continue
}
case reflect.Uint32, reflect.Uint64:
if value.Uint() == 0 {
continue
}
case reflect.Float32, reflect.Float64:
if value.Float() == 0 {
continue
}
case reflect.String:
if value.Len() == 0 {
continue
}
}
}
// Oneof fields need special handling.
if valueField.Tag.Get("protobuf_oneof") != "" {
// value is an interface containing &T{real_value}.
sv := value.Elem().Elem() // interface -> *T -> T
value = sv.Field(0)
valueField = sv.Type().Field(0)
}
prop := jsonProperties(valueField, m.OrigName)
if !firstField {
m.writeSep(out)
}
if err := m.marshalField(out, prop, value, indent); err != nil {
return err
}
firstField = false
}
// Handle proto2 extensions.
if ep, ok := v.(proto.Message); ok {
extensions := proto.RegisteredExtensions(v)
// Sort extensions for stable output.
ids := make([]int32, 0, len(extensions))
for id, desc := range extensions {
if !proto.HasExtension(ep, desc) {
continue
}
ids = append(ids, id)
}
sort.Sort(int32Slice(ids))
for _, id := range ids {
desc := extensions[id]
if desc == nil {
// unknown extension
continue
}
ext, extErr := proto.GetExtension(ep, desc)
if extErr != nil {
return extErr
}
value := reflect.ValueOf(ext)
var prop proto.Properties
prop.Parse(desc.Tag)
prop.JSONName = fmt.Sprintf("[%s]", desc.Name)
if !firstField {
m.writeSep(out)
}
if err := m.marshalField(out, &prop, value, indent); err != nil {
return err
}
firstField = false
}
}
if m.Indent != "" {
out.write("\n")
out.write(indent)
}
out.write("}")
return out.err
}
func (m *Marshaler) writeSep(out *errWriter) {
if m.Indent != "" {
out.write(",\n")
} else {
out.write(",")
}
}
func (m *Marshaler) marshalAny(out *errWriter, any proto.Message, indent string) error {
// "If the Any contains a value that has a special JSON mapping,
// it will be converted as follows: {"@type": xxx, "value": yyy}.
// Otherwise, the value will be converted into a JSON object,
// and the "@type" field will be inserted to indicate the actual data type."
v := reflect.ValueOf(any).Elem()
turl := v.Field(0).String()
val := v.Field(1).Bytes()
// Only the part of type_url after the last slash is relevant.
mname := turl
if slash := strings.LastIndex(mname, "/"); slash >= 0 {
mname = mname[slash+1:]
}
mt := proto.MessageType(mname)
if mt == nil {
return fmt.Errorf("unknown message type %q", mname)
}
msg := reflect.New(mt.Elem()).Interface().(proto.Message)
if err := proto.Unmarshal(val, msg); err != nil {
return err
}
if _, ok := msg.(wkt); ok {
out.write("{")
if m.Indent != "" {
out.write("\n")
}
if err := m.marshalTypeURL(out, indent, turl); err != nil {
return err
}
m.writeSep(out)
if m.Indent != "" {
out.write(indent)
out.write(m.Indent)
out.write(`"value": `)
} else {
out.write(`"value":`)
}
if err := m.marshalObject(out, msg, indent+m.Indent, ""); err != nil {
return err
}
if m.Indent != "" {
out.write("\n")
out.write(indent)
}
out.write("}")
return out.err
}
return m.marshalObject(out, msg, indent, turl)
}
func (m *Marshaler) marshalTypeURL(out *errWriter, indent, typeURL string) error {
if m.Indent != "" {
out.write(indent)
out.write(m.Indent)
}
out.write(`"@type":`)
if m.Indent != "" {
out.write(" ")
}
b, err := json.Marshal(typeURL)
if err != nil {
return err
}
out.write(string(b))
return out.err
}
// marshalField writes field description and value to the Writer.
func (m *Marshaler) marshalField(out *errWriter, prop *proto.Properties, v reflect.Value, indent string) error {
if m.Indent != "" {
out.write(indent)
out.write(m.Indent)
}
out.write(`"`)
out.write(prop.JSONName)
out.write(`":`)
if m.Indent != "" {
out.write(" ")
}
if err := m.marshalValue(out, prop, v, indent); err != nil {
return err
}
return nil
}
// marshalValue writes the value to the Writer.
func (m *Marshaler) marshalValue(out *errWriter, prop *proto.Properties, v reflect.Value, indent string) error {
var err error
v = reflect.Indirect(v)
// Handle repeated elements.
if v.Kind() == reflect.Slice && v.Type().Elem().Kind() != reflect.Uint8 {
out.write("[")
comma := ""
for i := 0; i < v.Len(); i++ {
sliceVal := v.Index(i)
out.write(comma)
if m.Indent != "" {
out.write("\n")
out.write(indent)
out.write(m.Indent)
out.write(m.Indent)
}
if err := m.marshalValue(out, prop, sliceVal, indent+m.Indent); err != nil {
return err
}
comma = ","
}
if m.Indent != "" {
out.write("\n")
out.write(indent)
out.write(m.Indent)
}
out.write("]")
return out.err
}
// Handle well-known types.
// Most are handled up in marshalObject (because 99% are messages).
type wkt interface {
XXX_WellKnownType() string
}
if wkt, ok := v.Interface().(wkt); ok {
switch wkt.XXX_WellKnownType() {
case "NullValue":
out.write("null")
return out.err
}
}
// Handle enumerations.
if !m.EnumsAsInts && prop.Enum != "" {
// Unknown enum values will are stringified by the proto library as their
// value. Such values should _not_ be quoted or they will be interpreted
// as an enum string instead of their value.
enumStr := v.Interface().(fmt.Stringer).String()
var valStr string
if v.Kind() == reflect.Ptr {
valStr = strconv.Itoa(int(v.Elem().Int()))
} else {
valStr = strconv.Itoa(int(v.Int()))
}
isKnownEnum := enumStr != valStr
if isKnownEnum {
out.write(`"`)
}
out.write(enumStr)
if isKnownEnum {
out.write(`"`)
}
return out.err
}
// Handle nested messages.
if v.Kind() == reflect.Struct {
return m.marshalObject(out, v.Addr().Interface().(proto.Message), indent+m.Indent, "")
}
// Handle maps.
// Since Go randomizes map iteration, we sort keys for stable output.
if v.Kind() == reflect.Map {
out.write(`{`)
keys := v.MapKeys()
sort.Sort(mapKeys(keys))
for i, k := range keys {
if i > 0 {
out.write(`,`)
}
if m.Indent != "" {
out.write("\n")
out.write(indent)
out.write(m.Indent)
out.write(m.Indent)
}
b, err := json.Marshal(k.Interface())
if err != nil {
return err
}
s := string(b)
// If the JSON is not a string value, encode it again to make it one.
if !strings.HasPrefix(s, `"`) {
b, err := json.Marshal(s)
if err != nil {
return err
}
s = string(b)
}
out.write(s)
out.write(`:`)
if m.Indent != "" {
out.write(` `)
}
if err := m.marshalValue(out, prop, v.MapIndex(k), indent+m.Indent); err != nil {
return err
}
}
if m.Indent != "" {
out.write("\n")
out.write(indent)
out.write(m.Indent)
}
out.write(`}`)
return out.err
}
// Default handling defers to the encoding/json library.
b, err := json.Marshal(v.Interface())
if err != nil {
return err
}
needToQuote := string(b[0]) != `"` && (v.Kind() == reflect.Int64 || v.Kind() == reflect.Uint64)
if needToQuote {
out.write(`"`)
}
out.write(string(b))
if needToQuote {
out.write(`"`)
}
return out.err
}
// Unmarshaler is a configurable object for converting from a JSON
// representation to a protocol buffer object.
type Unmarshaler struct {
// Whether to allow messages to contain unknown fields, as opposed to
// failing to unmarshal.
AllowUnknownFields bool
}
// UnmarshalNext unmarshals the next protocol buffer from a JSON object stream.
// This function is lenient and will decode any options permutations of the
// related Marshaler.
func (u *Unmarshaler) UnmarshalNext(dec *json.Decoder, pb proto.Message) error {
inputValue := json.RawMessage{}
if err := dec.Decode(&inputValue); err != nil {
return err
}
return u.unmarshalValue(reflect.ValueOf(pb).Elem(), inputValue, nil)
}
// Unmarshal unmarshals a JSON object stream into a protocol
// buffer. This function is lenient and will decode any options
// permutations of the related Marshaler.
func (u *Unmarshaler) Unmarshal(r io.Reader, pb proto.Message) error {
dec := json.NewDecoder(r)
return u.UnmarshalNext(dec, pb)
}
// UnmarshalNext unmarshals the next protocol buffer from a JSON object stream.
// This function is lenient and will decode any options permutations of the
// related Marshaler.
func UnmarshalNext(dec *json.Decoder, pb proto.Message) error {
return new(Unmarshaler).UnmarshalNext(dec, pb)
}
// Unmarshal unmarshals a JSON object stream into a protocol
// buffer. This function is lenient and will decode any options
// permutations of the related Marshaler.
func Unmarshal(r io.Reader, pb proto.Message) error {
return new(Unmarshaler).Unmarshal(r, pb)
}
// UnmarshalString will populate the fields of a protocol buffer based
// on a JSON string. This function is lenient and will decode any options
// permutations of the related Marshaler.
func UnmarshalString(str string, pb proto.Message) error {
return new(Unmarshaler).Unmarshal(strings.NewReader(str), pb)
}
// unmarshalValue converts/copies a value into the target.
// prop may be nil.
func (u *Unmarshaler) unmarshalValue(target reflect.Value, inputValue json.RawMessage, prop *proto.Properties) error {
targetType := target.Type()
// Allocate memory for pointer fields.
if targetType.Kind() == reflect.Ptr {
target.Set(reflect.New(targetType.Elem()))
return u.unmarshalValue(target.Elem(), inputValue, prop)
}
// Handle well-known types.
type wkt interface {
XXX_WellKnownType() string
}
if wkt, ok := target.Addr().Interface().(wkt); ok {
switch wkt.XXX_WellKnownType() {
case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value",
"Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue":
// "Wrappers use the same representation in JSON
// as the wrapped primitive type, except that null is allowed."
// encoding/json will turn JSON `null` into Go `nil`,
// so we don't have to do any extra work.
return u.unmarshalValue(target.Field(0), inputValue, prop)
case "Any":
return fmt.Errorf("unmarshaling Any not supported yet")
case "Duration":
ivStr := string(inputValue)
if ivStr == "null" {
target.Field(0).SetInt(0)
target.Field(1).SetInt(0)
return nil
}
unq, err := strconv.Unquote(ivStr)
if err != nil {
return err
}
d, err := time.ParseDuration(unq)
if err != nil {
return fmt.Errorf("bad Duration: %v", err)
}
ns := d.Nanoseconds()
s := ns / 1e9
ns %= 1e9
target.Field(0).SetInt(s)
target.Field(1).SetInt(ns)
return nil
case "Timestamp":
ivStr := string(inputValue)
if ivStr == "null" {
target.Field(0).SetInt(0)
target.Field(1).SetInt(0)
return nil
}
unq, err := strconv.Unquote(ivStr)
if err != nil {
return err
}
t, err := time.Parse(time.RFC3339Nano, unq)
if err != nil {
return fmt.Errorf("bad Timestamp: %v", err)
}
target.Field(0).SetInt(int64(t.Unix()))
target.Field(1).SetInt(int64(t.Nanosecond()))
return nil
}
}
// Handle enums, which have an underlying type of int32,
// and may appear as strings.
// The case of an enum appearing as a number is handled
// at the bottom of this function.
if inputValue[0] == '"' && prop != nil && prop.Enum != "" {
vmap := proto.EnumValueMap(prop.Enum)
// Don't need to do unquoting; valid enum names
// are from a limited character set.
s := inputValue[1 : len(inputValue)-1]
n, ok := vmap[string(s)]
if !ok {
return fmt.Errorf("unknown value %q for enum %s", s, prop.Enum)
}
if target.Kind() == reflect.Ptr { // proto2
target.Set(reflect.New(targetType.Elem()))
target = target.Elem()
}
target.SetInt(int64(n))
return nil
}
// Handle nested messages.
if targetType.Kind() == reflect.Struct {
var jsonFields map[string]json.RawMessage
if err := json.Unmarshal(inputValue, &jsonFields); err != nil {
return err
}
consumeField := func(prop *proto.Properties) (json.RawMessage, bool) {
// Be liberal in what names we accept; both orig_name and camelName are okay.
fieldNames := acceptedJSONFieldNames(prop)
vOrig, okOrig := jsonFields[fieldNames.orig]
vCamel, okCamel := jsonFields[fieldNames.camel]
if !okOrig && !okCamel {
return nil, false
}
// If, for some reason, both are present in the data, favour the camelName.
var raw json.RawMessage
if okOrig {
raw = vOrig
delete(jsonFields, fieldNames.orig)
}
if okCamel {
raw = vCamel
delete(jsonFields, fieldNames.camel)
}
return raw, true
}
sprops := proto.GetProperties(targetType)
for i := 0; i < target.NumField(); i++ {
ft := target.Type().Field(i)
if strings.HasPrefix(ft.Name, "XXX_") {
continue
}
valueForField, ok := consumeField(sprops.Prop[i])
if !ok {
continue
}
if err := u.unmarshalValue(target.Field(i), valueForField, sprops.Prop[i]); err != nil {
return err
}
}
// Check for any oneof fields.
if len(jsonFields) > 0 {
for _, oop := range sprops.OneofTypes {
raw, ok := consumeField(oop.Prop)
if !ok {
continue
}
nv := reflect.New(oop.Type.Elem())
target.Field(oop.Field).Set(nv)
if err := u.unmarshalValue(nv.Elem().Field(0), raw, oop.Prop); err != nil {
return err
}
}
}
if !u.AllowUnknownFields && len(jsonFields) > 0 {
// Pick any field to be the scapegoat.
var f string
for fname := range jsonFields {
f = fname
break
}
return fmt.Errorf("unknown field %q in %v", f, targetType)
}
return nil
}
// Handle arrays (which aren't encoded bytes)
if targetType.Kind() == reflect.Slice && targetType.Elem().Kind() != reflect.Uint8 {
var slc []json.RawMessage
if err := json.Unmarshal(inputValue, &slc); err != nil {
return err
}
len := len(slc)
target.Set(reflect.MakeSlice(targetType, len, len))
for i := 0; i < len; i++ {
if err := u.unmarshalValue(target.Index(i), slc[i], prop); err != nil {
return err
}
}
return nil
}
// Handle maps (whose keys are always strings)
if targetType.Kind() == reflect.Map {
var mp map[string]json.RawMessage
if err := json.Unmarshal(inputValue, &mp); err != nil {
return err
}
target.Set(reflect.MakeMap(targetType))
var keyprop, valprop *proto.Properties
if prop != nil {
// These could still be nil if the protobuf metadata is broken somehow.
// TODO: This won't work because the fields are unexported.
// We should probably just reparse them.
//keyprop, valprop = prop.mkeyprop, prop.mvalprop
}
for ks, raw := range mp {
// Unmarshal map key. The core json library already decoded the key into a
// string, so we handle that specially. Other types were quoted post-serialization.
var k reflect.Value
if targetType.Key().Kind() == reflect.String {
k = reflect.ValueOf(ks)
} else {
k = reflect.New(targetType.Key()).Elem()
if err := u.unmarshalValue(k, json.RawMessage(ks), keyprop); err != nil {
return err
}
}
// Unmarshal map value.
v := reflect.New(targetType.Elem()).Elem()
if err := u.unmarshalValue(v, raw, valprop); err != nil {
return err
}
target.SetMapIndex(k, v)
}
return nil
}
// 64-bit integers can be encoded as strings. In this case we drop
// the quotes and proceed as normal.
isNum := targetType.Kind() == reflect.Int64 || targetType.Kind() == reflect.Uint64
if isNum && strings.HasPrefix(string(inputValue), `"`) {
inputValue = inputValue[1 : len(inputValue)-1]
}
// Use the encoding/json for parsing other value types.
return json.Unmarshal(inputValue, target.Addr().Interface())
}
// jsonProperties returns parsed proto.Properties for the field and corrects JSONName attribute.
func jsonProperties(f reflect.StructField, origName bool) *proto.Properties {
var prop proto.Properties
prop.Init(f.Type, f.Name, f.Tag.Get("protobuf"), &f)
if origName || prop.JSONName == "" {
prop.JSONName = prop.OrigName
}
return &prop
}
type fieldNames struct {
orig, camel string
}
func acceptedJSONFieldNames(prop *proto.Properties) fieldNames {
opts := fieldNames{orig: prop.OrigName, camel: prop.OrigName}
if prop.JSONName != "" {
opts.camel = prop.JSONName
}
return opts
}
// Writer wrapper inspired by https://blog.golang.org/errors-are-values
type errWriter struct {
writer io.Writer
err error
}
func (w *errWriter) write(str string) {
if w.err != nil {
return
}
_, w.err = w.writer.Write([]byte(str))
}
// Map fields may have key types of non-float scalars, strings and enums.
// The easiest way to sort them in some deterministic order is to use fmt.
// If this turns out to be inefficient we can always consider other options,
// such as doing a Schwartzian transform.
//
// Numeric keys are sorted in numeric order per
// https://developers.google.com/protocol-buffers/docs/proto#maps.
type mapKeys []reflect.Value
func (s mapKeys) Len() int { return len(s) }
func (s mapKeys) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s mapKeys) Less(i, j int) bool {
if k := s[i].Kind(); k == s[j].Kind() {
switch k {
case reflect.Int32, reflect.Int64:
return s[i].Int() < s[j].Int()
case reflect.Uint32, reflect.Uint64:
return s[i].Uint() < s[j].Uint()
}
}
return fmt.Sprint(s[i].Interface()) < fmt.Sprint(s[j].Interface())
}

563
vendor/github.com/golang/protobuf/jsonpb/jsonpb_test.go generated vendored Normal file
View file

@ -0,0 +1,563 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2015 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package jsonpb
import (
"bytes"
"encoding/json"
"io"
"reflect"
"strings"
"testing"
"github.com/golang/protobuf/proto"
pb "github.com/golang/protobuf/jsonpb/jsonpb_test_proto"
proto3pb "github.com/golang/protobuf/proto/proto3_proto"
anypb "github.com/golang/protobuf/ptypes/any"
durpb "github.com/golang/protobuf/ptypes/duration"
stpb "github.com/golang/protobuf/ptypes/struct"
tspb "github.com/golang/protobuf/ptypes/timestamp"
wpb "github.com/golang/protobuf/ptypes/wrappers"
)
var (
marshaler = Marshaler{}
marshalerAllOptions = Marshaler{
Indent: " ",
}
simpleObject = &pb.Simple{
OInt32: proto.Int32(-32),
OInt64: proto.Int64(-6400000000),
OUint32: proto.Uint32(32),
OUint64: proto.Uint64(6400000000),
OSint32: proto.Int32(-13),
OSint64: proto.Int64(-2600000000),
OFloat: proto.Float32(3.14),
ODouble: proto.Float64(6.02214179e23),
OBool: proto.Bool(true),
OString: proto.String("hello \"there\""),
OBytes: []byte("beep boop"),
}
simpleObjectJSON = `{` +
`"oBool":true,` +
`"oInt32":-32,` +
`"oInt64":"-6400000000",` +
`"oUint32":32,` +
`"oUint64":"6400000000",` +
`"oSint32":-13,` +
`"oSint64":"-2600000000",` +
`"oFloat":3.14,` +
`"oDouble":6.02214179e+23,` +
`"oString":"hello \"there\"",` +
`"oBytes":"YmVlcCBib29w"` +
`}`
simpleObjectPrettyJSON = `{
"oBool": true,
"oInt32": -32,
"oInt64": "-6400000000",
"oUint32": 32,
"oUint64": "6400000000",
"oSint32": -13,
"oSint64": "-2600000000",
"oFloat": 3.14,
"oDouble": 6.02214179e+23,
"oString": "hello \"there\"",
"oBytes": "YmVlcCBib29w"
}`
repeatsObject = &pb.Repeats{
RBool: []bool{true, false, true},
RInt32: []int32{-3, -4, -5},
RInt64: []int64{-123456789, -987654321},
RUint32: []uint32{1, 2, 3},
RUint64: []uint64{6789012345, 3456789012},
RSint32: []int32{-1, -2, -3},
RSint64: []int64{-6789012345, -3456789012},
RFloat: []float32{3.14, 6.28},
RDouble: []float64{299792458 * 1e20, 6.62606957e-34},
RString: []string{"happy", "days"},
RBytes: [][]byte{[]byte("skittles"), []byte("m&m's")},
}
repeatsObjectJSON = `{` +
`"rBool":[true,false,true],` +
`"rInt32":[-3,-4,-5],` +
`"rInt64":["-123456789","-987654321"],` +
`"rUint32":[1,2,3],` +
`"rUint64":["6789012345","3456789012"],` +
`"rSint32":[-1,-2,-3],` +
`"rSint64":["-6789012345","-3456789012"],` +
`"rFloat":[3.14,6.28],` +
`"rDouble":[2.99792458e+28,6.62606957e-34],` +
`"rString":["happy","days"],` +
`"rBytes":["c2tpdHRsZXM=","bSZtJ3M="]` +
`}`
repeatsObjectPrettyJSON = `{
"rBool": [
true,
false,
true
],
"rInt32": [
-3,
-4,
-5
],
"rInt64": [
"-123456789",
"-987654321"
],
"rUint32": [
1,
2,
3
],
"rUint64": [
"6789012345",
"3456789012"
],
"rSint32": [
-1,
-2,
-3
],
"rSint64": [
"-6789012345",
"-3456789012"
],
"rFloat": [
3.14,
6.28
],
"rDouble": [
2.99792458e+28,
6.62606957e-34
],
"rString": [
"happy",
"days"
],
"rBytes": [
"c2tpdHRsZXM=",
"bSZtJ3M="
]
}`
innerSimple = &pb.Simple{OInt32: proto.Int32(-32)}
innerSimple2 = &pb.Simple{OInt64: proto.Int64(25)}
innerRepeats = &pb.Repeats{RString: []string{"roses", "red"}}
innerRepeats2 = &pb.Repeats{RString: []string{"violets", "blue"}}
complexObject = &pb.Widget{
Color: pb.Widget_GREEN.Enum(),
RColor: []pb.Widget_Color{pb.Widget_RED, pb.Widget_GREEN, pb.Widget_BLUE},
Simple: innerSimple,
RSimple: []*pb.Simple{innerSimple, innerSimple2},
Repeats: innerRepeats,
RRepeats: []*pb.Repeats{innerRepeats, innerRepeats2},
}
complexObjectJSON = `{"color":"GREEN",` +
`"rColor":["RED","GREEN","BLUE"],` +
`"simple":{"oInt32":-32},` +
`"rSimple":[{"oInt32":-32},{"oInt64":"25"}],` +
`"repeats":{"rString":["roses","red"]},` +
`"rRepeats":[{"rString":["roses","red"]},{"rString":["violets","blue"]}]` +
`}`
complexObjectPrettyJSON = `{
"color": "GREEN",
"rColor": [
"RED",
"GREEN",
"BLUE"
],
"simple": {
"oInt32": -32
},
"rSimple": [
{
"oInt32": -32
},
{
"oInt64": "25"
}
],
"repeats": {
"rString": [
"roses",
"red"
]
},
"rRepeats": [
{
"rString": [
"roses",
"red"
]
},
{
"rString": [
"violets",
"blue"
]
}
]
}`
colorPrettyJSON = `{
"color": 2
}`
colorListPrettyJSON = `{
"color": 1000,
"rColor": [
"RED"
]
}`
nummyPrettyJSON = `{
"nummy": {
"1": 2,
"3": 4
}
}`
objjyPrettyJSON = `{
"objjy": {
"1": {
"dub": 1
}
}
}`
realNumber = &pb.Real{Value: proto.Float64(3.14159265359)}
realNumberName = "Pi"
complexNumber = &pb.Complex{Imaginary: proto.Float64(0.5772156649)}
realNumberJSON = `{` +
`"value":3.14159265359,` +
`"[jsonpb.Complex.real_extension]":{"imaginary":0.5772156649},` +
`"[jsonpb.name]":"Pi"` +
`}`
anySimple = &pb.KnownTypes{
An: &anypb.Any{
TypeUrl: "something.example.com/jsonpb.Simple",
Value: []byte{
// &pb.Simple{OBool:true}
1 << 3, 1,
},
},
}
anySimpleJSON = `{"an":{"@type":"something.example.com/jsonpb.Simple","oBool":true}}`
anySimplePrettyJSON = `{
"an": {
"@type": "something.example.com/jsonpb.Simple",
"oBool": true
}
}`
anyWellKnown = &pb.KnownTypes{
An: &anypb.Any{
TypeUrl: "type.googleapis.com/google.protobuf.Duration",
Value: []byte{
// &durpb.Duration{Seconds: 1, Nanos: 212000000 }
1 << 3, 1, // seconds
2 << 3, 0x80, 0xba, 0x8b, 0x65, // nanos
},
},
}
anyWellKnownJSON = `{"an":{"@type":"type.googleapis.com/google.protobuf.Duration","value":"1.212s"}}`
anyWellKnownPrettyJSON = `{
"an": {
"@type": "type.googleapis.com/google.protobuf.Duration",
"value": "1.212s"
}
}`
)
func init() {
if err := proto.SetExtension(realNumber, pb.E_Name, &realNumberName); err != nil {
panic(err)
}
if err := proto.SetExtension(realNumber, pb.E_Complex_RealExtension, complexNumber); err != nil {
panic(err)
}
}
var marshalingTests = []struct {
desc string
marshaler Marshaler
pb proto.Message
json string
}{
{"simple flat object", marshaler, simpleObject, simpleObjectJSON},
{"simple pretty object", marshalerAllOptions, simpleObject, simpleObjectPrettyJSON},
{"repeated fields flat object", marshaler, repeatsObject, repeatsObjectJSON},
{"repeated fields pretty object", marshalerAllOptions, repeatsObject, repeatsObjectPrettyJSON},
{"nested message/enum flat object", marshaler, complexObject, complexObjectJSON},
{"nested message/enum pretty object", marshalerAllOptions, complexObject, complexObjectPrettyJSON},
{"enum-string flat object", Marshaler{},
&pb.Widget{Color: pb.Widget_BLUE.Enum()}, `{"color":"BLUE"}`},
{"enum-value pretty object", Marshaler{EnumsAsInts: true, Indent: " "},
&pb.Widget{Color: pb.Widget_BLUE.Enum()}, colorPrettyJSON},
{"unknown enum value object", marshalerAllOptions,
&pb.Widget{Color: pb.Widget_Color(1000).Enum(), RColor: []pb.Widget_Color{pb.Widget_RED}}, colorListPrettyJSON},
{"repeated proto3 enum", Marshaler{},
&proto3pb.Message{RFunny: []proto3pb.Message_Humour{
proto3pb.Message_PUNS,
proto3pb.Message_SLAPSTICK,
}},
`{"rFunny":["PUNS","SLAPSTICK"]}`},
{"repeated proto3 enum as int", Marshaler{EnumsAsInts: true},
&proto3pb.Message{RFunny: []proto3pb.Message_Humour{
proto3pb.Message_PUNS,
proto3pb.Message_SLAPSTICK,
}},
`{"rFunny":[1,2]}`},
{"empty value", marshaler, &pb.Simple3{}, `{}`},
{"empty value emitted", Marshaler{EmitDefaults: true}, &pb.Simple3{}, `{"dub":0}`},
{"map<int64, int32>", marshaler, &pb.Mappy{Nummy: map[int64]int32{1: 2, 3: 4}}, `{"nummy":{"1":2,"3":4}}`},
{"map<int64, int32>", marshalerAllOptions, &pb.Mappy{Nummy: map[int64]int32{1: 2, 3: 4}}, nummyPrettyJSON},
{"map<string, string>", marshaler,
&pb.Mappy{Strry: map[string]string{`"one"`: "two", "three": "four"}},
`{"strry":{"\"one\"":"two","three":"four"}}`},
{"map<int32, Object>", marshaler,
&pb.Mappy{Objjy: map[int32]*pb.Simple3{1: &pb.Simple3{Dub: 1}}}, `{"objjy":{"1":{"dub":1}}}`},
{"map<int32, Object>", marshalerAllOptions,
&pb.Mappy{Objjy: map[int32]*pb.Simple3{1: &pb.Simple3{Dub: 1}}}, objjyPrettyJSON},
{"map<int64, string>", marshaler, &pb.Mappy{Buggy: map[int64]string{1234: "yup"}},
`{"buggy":{"1234":"yup"}}`},
{"map<bool, bool>", marshaler, &pb.Mappy{Booly: map[bool]bool{false: true}}, `{"booly":{"false":true}}`},
// TODO: This is broken.
//{"map<string, enum>", marshaler, &pb.Mappy{Enumy: map[string]pb.Numeral{"XIV": pb.Numeral_ROMAN}}, `{"enumy":{"XIV":"ROMAN"}`},
{"map<string, enum as int>", Marshaler{EnumsAsInts: true}, &pb.Mappy{Enumy: map[string]pb.Numeral{"XIV": pb.Numeral_ROMAN}}, `{"enumy":{"XIV":2}}`},
{"map<int32, bool>", marshaler, &pb.Mappy{S32Booly: map[int32]bool{1: true, 3: false, 10: true, 12: false}}, `{"s32booly":{"1":true,"3":false,"10":true,"12":false}}`},
{"map<int64, bool>", marshaler, &pb.Mappy{S64Booly: map[int64]bool{1: true, 3: false, 10: true, 12: false}}, `{"s64booly":{"1":true,"3":false,"10":true,"12":false}}`},
{"map<uint32, bool>", marshaler, &pb.Mappy{U32Booly: map[uint32]bool{1: true, 3: false, 10: true, 12: false}}, `{"u32booly":{"1":true,"3":false,"10":true,"12":false}}`},
{"map<uint64, bool>", marshaler, &pb.Mappy{U64Booly: map[uint64]bool{1: true, 3: false, 10: true, 12: false}}, `{"u64booly":{"1":true,"3":false,"10":true,"12":false}}`},
{"proto2 map<int64, string>", marshaler, &pb.Maps{MInt64Str: map[int64]string{213: "cat"}},
`{"mInt64Str":{"213":"cat"}}`},
{"proto2 map<bool, Object>", marshaler,
&pb.Maps{MBoolSimple: map[bool]*pb.Simple{true: &pb.Simple{OInt32: proto.Int32(1)}}},
`{"mBoolSimple":{"true":{"oInt32":1}}}`},
{"oneof, not set", marshaler, &pb.MsgWithOneof{}, `{}`},
{"oneof, set", marshaler, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Title{"Grand Poobah"}}, `{"title":"Grand Poobah"}`},
{"force orig_name", Marshaler{OrigName: true}, &pb.Simple{OInt32: proto.Int32(4)},
`{"o_int32":4}`},
{"proto2 extension", marshaler, realNumber, realNumberJSON},
{"Any with message", marshaler, anySimple, anySimpleJSON},
{"Any with message and indent", marshalerAllOptions, anySimple, anySimplePrettyJSON},
{"Any with WKT", marshaler, anyWellKnown, anyWellKnownJSON},
{"Any with WKT and indent", marshalerAllOptions, anyWellKnown, anyWellKnownPrettyJSON},
{"Duration", marshaler, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 3}}, `{"dur":"3.000s"}`},
{"Struct", marshaler, &pb.KnownTypes{St: &stpb.Struct{
Fields: map[string]*stpb.Value{
"one": &stpb.Value{Kind: &stpb.Value_StringValue{"loneliest number"}},
"two": &stpb.Value{Kind: &stpb.Value_NullValue{stpb.NullValue_NULL_VALUE}},
},
}}, `{"st":{"one":"loneliest number","two":null}}`},
{"Timestamp", marshaler, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 14e8, Nanos: 21e6}}, `{"ts":"2014-05-13T16:53:20.021Z"}`},
{"DoubleValue", marshaler, &pb.KnownTypes{Dbl: &wpb.DoubleValue{Value: 1.2}}, `{"dbl":1.2}`},
{"FloatValue", marshaler, &pb.KnownTypes{Flt: &wpb.FloatValue{Value: 1.2}}, `{"flt":1.2}`},
{"Int64Value", marshaler, &pb.KnownTypes{I64: &wpb.Int64Value{Value: -3}}, `{"i64":"-3"}`},
{"UInt64Value", marshaler, &pb.KnownTypes{U64: &wpb.UInt64Value{Value: 3}}, `{"u64":"3"}`},
{"Int32Value", marshaler, &pb.KnownTypes{I32: &wpb.Int32Value{Value: -4}}, `{"i32":-4}`},
{"UInt32Value", marshaler, &pb.KnownTypes{U32: &wpb.UInt32Value{Value: 4}}, `{"u32":4}`},
{"BoolValue", marshaler, &pb.KnownTypes{Bool: &wpb.BoolValue{Value: true}}, `{"bool":true}`},
{"StringValue", marshaler, &pb.KnownTypes{Str: &wpb.StringValue{Value: "plush"}}, `{"str":"plush"}`},
{"BytesValue", marshaler, &pb.KnownTypes{Bytes: &wpb.BytesValue{Value: []byte("wow")}}, `{"bytes":"d293"}`},
}
func TestMarshaling(t *testing.T) {
for _, tt := range marshalingTests {
json, err := tt.marshaler.MarshalToString(tt.pb)
if err != nil {
t.Errorf("%s: marshaling error: %v", tt.desc, err)
} else if tt.json != json {
t.Errorf("%s: got [%v] want [%v]", tt.desc, json, tt.json)
}
}
}
var unmarshalingTests = []struct {
desc string
unmarshaler Unmarshaler
json string
pb proto.Message
}{
{"simple flat object", Unmarshaler{}, simpleObjectJSON, simpleObject},
{"simple pretty object", Unmarshaler{}, simpleObjectPrettyJSON, simpleObject},
{"repeated fields flat object", Unmarshaler{}, repeatsObjectJSON, repeatsObject},
{"repeated fields pretty object", Unmarshaler{}, repeatsObjectPrettyJSON, repeatsObject},
{"nested message/enum flat object", Unmarshaler{}, complexObjectJSON, complexObject},
{"nested message/enum pretty object", Unmarshaler{}, complexObjectPrettyJSON, complexObject},
{"enum-string object", Unmarshaler{}, `{"color":"BLUE"}`, &pb.Widget{Color: pb.Widget_BLUE.Enum()}},
{"enum-value object", Unmarshaler{}, "{\n \"color\": 2\n}", &pb.Widget{Color: pb.Widget_BLUE.Enum()}},
{"unknown field with allowed option", Unmarshaler{AllowUnknownFields: true}, `{"unknown": "foo"}`, new(pb.Simple)},
{"proto3 enum string", Unmarshaler{}, `{"hilarity":"PUNS"}`, &proto3pb.Message{Hilarity: proto3pb.Message_PUNS}},
{"proto3 enum value", Unmarshaler{}, `{"hilarity":1}`, &proto3pb.Message{Hilarity: proto3pb.Message_PUNS}},
{"unknown enum value object",
Unmarshaler{},
"{\n \"color\": 1000,\n \"r_color\": [\n \"RED\"\n ]\n}",
&pb.Widget{Color: pb.Widget_Color(1000).Enum(), RColor: []pb.Widget_Color{pb.Widget_RED}}},
{"repeated proto3 enum", Unmarshaler{}, `{"rFunny":["PUNS","SLAPSTICK"]}`,
&proto3pb.Message{RFunny: []proto3pb.Message_Humour{
proto3pb.Message_PUNS,
proto3pb.Message_SLAPSTICK,
}}},
{"repeated proto3 enum as int", Unmarshaler{}, `{"rFunny":[1,2]}`,
&proto3pb.Message{RFunny: []proto3pb.Message_Humour{
proto3pb.Message_PUNS,
proto3pb.Message_SLAPSTICK,
}}},
{"repeated proto3 enum as mix of strings and ints", Unmarshaler{}, `{"rFunny":["PUNS",2]}`,
&proto3pb.Message{RFunny: []proto3pb.Message_Humour{
proto3pb.Message_PUNS,
proto3pb.Message_SLAPSTICK,
}}},
{"unquoted int64 object", Unmarshaler{}, `{"oInt64":-314}`, &pb.Simple{OInt64: proto.Int64(-314)}},
{"unquoted uint64 object", Unmarshaler{}, `{"oUint64":123}`, &pb.Simple{OUint64: proto.Uint64(123)}},
{"map<int64, int32>", Unmarshaler{}, `{"nummy":{"1":2,"3":4}}`, &pb.Mappy{Nummy: map[int64]int32{1: 2, 3: 4}}},
{"map<string, string>", Unmarshaler{}, `{"strry":{"\"one\"":"two","three":"four"}}`, &pb.Mappy{Strry: map[string]string{`"one"`: "two", "three": "four"}}},
{"map<int32, Object>", Unmarshaler{}, `{"objjy":{"1":{"dub":1}}}`, &pb.Mappy{Objjy: map[int32]*pb.Simple3{1: &pb.Simple3{Dub: 1}}}},
// TODO: This is broken.
//{"map<string, enum>", Unmarshaler{}, `{"enumy":{"XIV":"ROMAN"}`, &pb.Mappy{Enumy: map[string]pb.Numeral{"XIV": pb.Numeral_ROMAN}}},
{"map<string, enum as int>", Unmarshaler{}, `{"enumy":{"XIV":2}}`, &pb.Mappy{Enumy: map[string]pb.Numeral{"XIV": pb.Numeral_ROMAN}}},
{"oneof", Unmarshaler{}, `{"salary":31000}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Salary{31000}}},
{"oneof spec name", Unmarshaler{}, `{"Country":"Australia"}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Country{"Australia"}}},
{"oneof orig_name", Unmarshaler{}, `{"Country":"Australia"}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Country{"Australia"}}},
{"oneof spec name2", Unmarshaler{}, `{"homeAddress":"Australia"}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_HomeAddress{"Australia"}}},
{"oneof orig_name2", Unmarshaler{}, `{"home_address":"Australia"}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_HomeAddress{"Australia"}}},
{"orig_name input", Unmarshaler{}, `{"o_bool":true}`, &pb.Simple{OBool: proto.Bool(true)}},
{"camelName input", Unmarshaler{}, `{"oBool":true}`, &pb.Simple{OBool: proto.Bool(true)}},
{"Duration", Unmarshaler{}, `{"dur":"3.000s"}`, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 3}}},
{"null Duration", Unmarshaler{}, `{"dur":null}`, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 0}}},
{"Timestamp", Unmarshaler{}, `{"ts":"2014-05-13T16:53:20.021Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 14e8, Nanos: 21e6}}},
{"PreEpochTimestamp", Unmarshaler{}, `{"ts":"1969-12-31T23:59:58.999999995Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: -2, Nanos: 999999995}}},
{"ZeroTimeTimestamp", Unmarshaler{}, `{"ts":"0001-01-01T00:00:00Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: -62135596800, Nanos: 0}}},
{"null Timestamp", Unmarshaler{}, `{"ts":null}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 0, Nanos: 0}}},
{"DoubleValue", Unmarshaler{}, `{"dbl":1.2}`, &pb.KnownTypes{Dbl: &wpb.DoubleValue{Value: 1.2}}},
{"FloatValue", Unmarshaler{}, `{"flt":1.2}`, &pb.KnownTypes{Flt: &wpb.FloatValue{Value: 1.2}}},
{"Int64Value", Unmarshaler{}, `{"i64":"-3"}`, &pb.KnownTypes{I64: &wpb.Int64Value{Value: -3}}},
{"UInt64Value", Unmarshaler{}, `{"u64":"3"}`, &pb.KnownTypes{U64: &wpb.UInt64Value{Value: 3}}},
{"Int32Value", Unmarshaler{}, `{"i32":-4}`, &pb.KnownTypes{I32: &wpb.Int32Value{Value: -4}}},
{"UInt32Value", Unmarshaler{}, `{"u32":4}`, &pb.KnownTypes{U32: &wpb.UInt32Value{Value: 4}}},
{"BoolValue", Unmarshaler{}, `{"bool":true}`, &pb.KnownTypes{Bool: &wpb.BoolValue{Value: true}}},
{"StringValue", Unmarshaler{}, `{"str":"plush"}`, &pb.KnownTypes{Str: &wpb.StringValue{Value: "plush"}}},
{"BytesValue", Unmarshaler{}, `{"bytes":"d293"}`, &pb.KnownTypes{Bytes: &wpb.BytesValue{Value: []byte("wow")}}},
// `null` is also a permissible value. Let's just test one.
{"null DoubleValue", Unmarshaler{}, `{"dbl":null}`, &pb.KnownTypes{Dbl: &wpb.DoubleValue{}}},
}
func TestUnmarshaling(t *testing.T) {
for _, tt := range unmarshalingTests {
// Make a new instance of the type of our expected object.
p := reflect.New(reflect.TypeOf(tt.pb).Elem()).Interface().(proto.Message)
err := tt.unmarshaler.Unmarshal(strings.NewReader(tt.json), p)
if err != nil {
t.Errorf("%s: %v", tt.desc, err)
continue
}
// For easier diffs, compare text strings of the protos.
exp := proto.MarshalTextString(tt.pb)
act := proto.MarshalTextString(p)
if string(exp) != string(act) {
t.Errorf("%s: got [%s] want [%s]", tt.desc, act, exp)
}
}
}
func TestUnmarshalNext(t *testing.T) {
// We only need to check against a few, not all of them.
tests := unmarshalingTests[:5]
// Create a buffer with many concatenated JSON objects.
var b bytes.Buffer
for _, tt := range tests {
b.WriteString(tt.json)
}
dec := json.NewDecoder(&b)
for _, tt := range tests {
// Make a new instance of the type of our expected object.
p := reflect.New(reflect.TypeOf(tt.pb).Elem()).Interface().(proto.Message)
err := tt.unmarshaler.UnmarshalNext(dec, p)
if err != nil {
t.Errorf("%s: %v", tt.desc, err)
continue
}
// For easier diffs, compare text strings of the protos.
exp := proto.MarshalTextString(tt.pb)
act := proto.MarshalTextString(p)
if string(exp) != string(act) {
t.Errorf("%s: got [%s] want [%s]", tt.desc, act, exp)
}
}
p := &pb.Simple{}
err := new(Unmarshaler).UnmarshalNext(dec, p)
if err != io.EOF {
t.Errorf("eof: got %v, expected io.EOF", err)
}
}
var unmarshalingShouldError = []struct {
desc string
in string
pb proto.Message
}{
{"a value", "666", new(pb.Simple)},
{"gibberish", "{adskja123;l23=-=", new(pb.Simple)},
{"unknown field", `{"unknown": "foo"}`, new(pb.Simple)},
{"unknown enum name", `{"hilarity":"DAVE"}`, new(proto3pb.Message)},
}
func TestUnmarshalingBadInput(t *testing.T) {
for _, tt := range unmarshalingShouldError {
err := UnmarshalString(tt.in, tt.pb)
if err == nil {
t.Errorf("an error was expected when parsing %q instead of an object", tt.desc)
}
}
}

14
vendor/github.com/gravitational/license/.gitignore generated vendored Normal file
View file

@ -0,0 +1,14 @@
# Binaries for programs and plugins
*.exe
*.dll
*.so
*.dylib
# Test binary, build with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
.glide/

201
vendor/github.com/gravitational/license/LICENSE generated vendored Normal file
View file

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

13
vendor/github.com/gravitational/license/Makefile generated vendored Normal file
View file

@ -0,0 +1,13 @@
#
# test runs tests for all packages
#
.PHONY: test
test:
go test -v -test.parallel=0 -race ./...
#
# build builds all packages
#
.PHONY: build
build:
go build ./...

2
vendor/github.com/gravitational/license/README.md generated vendored Normal file
View file

@ -0,0 +1,2 @@
# license
CA and licensing tools

View file

@ -0,0 +1,98 @@
/*
Copyright 2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package constants
import "encoding/asn1"
const (
// TLSKeyAlgo is default TLS algo used for K8s X509 certs
TLSKeyAlgo = "rsa"
// TLSKeySize is default TLS key size used for K8s X509 certs
TLSKeySize = 2048
// RSAPrivateKeyPEMBlock is the name of the PEM block where private key is stored
RSAPrivateKeyPEMBlock = "RSA PRIVATE KEY"
// CertificatePEMBlock is the name of the PEM block where certificate is stored
CertificatePEMBlock = "CERTIFICATE"
// LicenseKeyPair is a name of the license key pair
LicenseKeyPair = "license"
// LoopbackIP is IP of the loopback interface
LoopbackIP = "127.0.0.1"
// LicenseKeyBits used when generating private key for license certificate
LicenseKeyBits = 2048
// LicenseOrg is the default name of license subject organization
LicenseOrg = "gravitational.io"
// LicenseTimeFormat represents format of expiration time in license payload
LicenseTimeFormat = "2006-01-02 15:04:05"
)
// LicenseASNExtensionID is an extension ID used when encoding/decoding
// license payload into certificates
var LicenseASN1ExtensionID = asn1.ObjectIdentifier{2, 5, 42}
// EC2InstanceTypes maps AWS instance types to their number of CPUs,
// used for determining whether license allows a certain instance
// type in some cases
var EC2InstanceTypes = map[string]int{
"t2.nano": 1,
"t2.micro": 1,
"t2.small": 1,
"t2.medium": 2,
"t2.large": 2,
"m3.medium": 1,
"m3.large": 2,
"m3.xlarge": 4,
"m3.2xlarge": 8,
"m4.large": 2,
"m4.xlarge": 4,
"m4.2xlarge": 8,
"m4.4xlarge": 16,
"m4.10xlarge": 40,
"c3.large": 2,
"c3.xlarge": 4,
"c3.2xlarge": 8,
"c3.4xlarge": 16,
"c3.8xlarge": 32,
"c4.large": 2,
"c4.xlarge": 4,
"c4.2xlarge": 8,
"c4.4xlarge": 16,
"c4.8xlarge": 36,
"x1.32xlarge": 128,
"g2.2xlarge": 8,
"g2.8xlarge": 32,
"r3.large": 2,
"r3.xlarge": 4,
"r3.2xlarge": 8,
"r3.4xlarge": 16,
"r3.8xlarge": 32,
"i2.xlarge": 4,
"i2.2xlarge": 8,
"i2.4xlarge": 16,
"i2.8xlarge": 32,
"d2.xlarge": 4,
"d2.2xlarge": 8,
"d2.4xlarge": 16,
"d2.8xlarge": 36,
}

91
vendor/github.com/gravitational/license/license.go generated vendored Normal file
View file

@ -0,0 +1,91 @@
/*
Copyright 2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package license
import (
"crypto/x509"
"time"
"github.com/gravitational/trace"
)
// License represents Gravitational license
type License struct {
// Cert is the x509 license certificate
Cert *x509.Certificate
// Payload is the license payload
Payload Payload
// CertPEM is the certificate part of the license in PEM format
CertPEM []byte
// KeyPEM is the private key part of the license in PEM Format,
// may be empty if the license was parsed from certificate only
KeyPEM []byte
}
// Verify makes sure the license is valid
func (l *License) Verify(caPEM []byte) error {
roots := x509.NewCertPool()
// add the provided CA certificate to the roots
ok := roots.AppendCertsFromPEM(caPEM)
if !ok {
return trace.BadParameter("could not find any CA certificates")
}
_, err := l.Cert.Verify(x509.VerifyOptions{Roots: roots})
if err != nil {
certErr, ok := err.(x509.CertificateInvalidError)
if ok && certErr.Reason == x509.Expired {
return trace.BadParameter("the license has expired")
}
return trace.Wrap(err, "failed to verify the license")
}
return nil
}
// Payload is custom information that gets encoded into licenses
type Payload struct {
// ClusterID is vendor-specific cluster ID
ClusterID string `json:"cluster_id,omitempty"`
// Expiration is expiration time for the license
Expiration time.Time `json:"expiration,omitempty"`
// MaxNodes is maximum number of nodes the license allows
MaxNodes int `json:"max_nodes,omitempty"`
// MaxCores is maximum number of CPUs per node the license allows
MaxCores int `json:"max_cores,omitempty"`
// Company is the company name the license is generated for
Company string `json:"company,omitempty"`
// Person is the name of the person the license is generated for
Person string `json:"person,omitempty"`
// Email is the email of the person the license is generated for
Email string `json:"email,omitempty"`
// Metadata is an arbitrary customer metadata
Metadata string `json:"metadata,omitempty"`
// ProductName is the name of the product the license is for
ProductName string `json:"product_name,omitempty"`
// ProductVersion is the product version
ProductVersion string `json:"product_version,omitempty"`
// EncryptionKey is the passphrase for decoding encrypted packages
EncryptionKey []byte `json:"encryption_key,omitempty"`
// Signature is vendor-specific signature
Signature string `json:"signature,omitempty"`
// Shutdown indicates whether the app should be stopped when the license expires
Shutdown bool `json:"shutdown,omitempty"`
// AccountID is the ID of the account the license was issued for
AccountID string `json:"account_id,omitempty"`
}

159
vendor/github.com/gravitational/license/parse.go generated vendored Normal file
View file

@ -0,0 +1,159 @@
/*
Copyright 2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package license
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"github.com/gravitational/license/constants"
"github.com/gravitational/trace"
)
// ParseString parses the license from the provided string
func ParseString(pem string) (*License, error) {
certPEM, keyPEM, err := SplitPEM([]byte(pem))
if err != nil {
return nil, trace.Wrap(err)
}
certificateBytes, privateBytes, err := parseCertificatePEM(pem)
if err != nil {
return nil, trace.Wrap(err)
}
certificate, err := x509.ParseCertificate(certificateBytes)
if err != nil {
return nil, trace.Wrap(err)
}
payload, err := parsePayloadFromX509(certificate)
if err != nil {
return nil, trace.Wrap(err)
}
// decrypt encryption key
if len(payload.EncryptionKey) != 0 {
private, err := x509.ParsePKCS1PrivateKey(privateBytes)
if err != nil {
return nil, trace.Wrap(err)
}
payload.EncryptionKey, err = rsa.DecryptOAEP(sha256.New(), rand.Reader,
private, payload.EncryptionKey, nil)
if err != nil {
return nil, trace.Wrap(err)
}
}
return &License{
Cert: certificate,
Payload: *payload,
CertPEM: certPEM,
KeyPEM: keyPEM,
}, nil
}
// ParseX509 parses the license from the provided x509 certificate
func ParseX509(cert *x509.Certificate) (*License, error) {
payload, err := parsePayloadFromX509(cert)
if err != nil {
return nil, trace.Wrap(err)
}
return &License{
Cert: cert,
Payload: *payload,
}, nil
}
// MakeTLSCert takes the provided license and makes a TLS certificate
// which is the format used by Go servers
func MakeTLSCert(license License) (*tls.Certificate, error) {
tlsCert, err := tls.X509KeyPair(license.CertPEM, license.KeyPEM)
if err != nil {
return nil, trace.Wrap(err)
}
return &tlsCert, nil
}
// MakeTLSConfig builds a client TLS config from the supplied license
func MakeTLSConfig(license License) (*tls.Config, error) {
tlsCert, err := MakeTLSCert(license)
if err != nil {
return nil, trace.Wrap(err)
}
return &tls.Config{
Certificates: []tls.Certificate{*tlsCert},
}, nil
}
// parsePayloadFromX509 parses the extension with license payload from the
// provided x509 certificate
func parsePayloadFromX509(cert *x509.Certificate) (*Payload, error) {
for _, ext := range cert.Extensions {
if ext.Id.Equal(constants.LicenseASN1ExtensionID) {
var p Payload
if err := json.Unmarshal(ext.Value, &p); err != nil {
return nil, trace.Wrap(err)
}
return &p, nil
}
}
return nil, trace.NotFound(
"certificate does not contain extension with license payload")
}
// parseCertificatePEM parses the concatenated certificate/private key in PEM format
// and returns certificate and private key in decoded DER ASN.1 structure
func parseCertificatePEM(certPEM string) ([]byte, []byte, error) {
var certificateBytes, privateBytes []byte
block, rest := pem.Decode([]byte(certPEM))
for block != nil {
switch block.Type {
case constants.CertificatePEMBlock:
certificateBytes = block.Bytes
case constants.RSAPrivateKeyPEMBlock:
privateBytes = block.Bytes
}
// parse the next block
block, rest = pem.Decode(rest)
}
if len(certificateBytes) == 0 || len(privateBytes) == 0 {
return nil, nil, trace.BadParameter("could not parse the license")
}
return certificateBytes, privateBytes, nil
}
// SplitPEM splits the provided PEM data that contains concatenated cert and key
// (in any order) into cert PEM and key PEM respectively. Returns an error if
// any of them is missing
func SplitPEM(pemData []byte) (certPEM []byte, keyPEM []byte, err error) {
block, rest := pem.Decode(pemData)
for block != nil {
switch block.Type {
case constants.CertificatePEMBlock:
certPEM = pem.EncodeToMemory(block)
case constants.RSAPrivateKeyPEMBlock:
keyPEM = pem.EncodeToMemory(block)
}
block, rest = pem.Decode(rest)
}
if len(certPEM) == 0 || len(keyPEM) == 0 {
return nil, nil, trace.BadParameter("cert or key PEM data is missing")
}
return certPEM, keyPEM, nil
}

14
vendor/github.com/gravitational/reporting/.gitignore generated vendored Normal file
View file

@ -0,0 +1,14 @@
# Binaries for programs and plugins
*.exe
*.dll
*.so
*.dylib
# Test binary, build with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
.glide/

26
vendor/github.com/gravitational/reporting/Dockerfile generated vendored Normal file
View file

@ -0,0 +1,26 @@
FROM quay.io/gravitational/debian-venti:go1.9.1-jessie
ARG PROTOC_VER
ARG GOGO_PROTO_TAG
ARG GRPC_GATEWAY_TAG
ARG PLATFORM
ENV TARBALL protoc-${PROTOC_VER}-${PLATFORM}.zip
ENV GOGOPROTO_ROOT ${GOPATH}/src/github.com/gogo/protobuf
ENV LANGUAGE="en_US.UTF-8" \
LANG="en_US.UTF-8" \
LC_ALL="en_US.UTF-8" \
LC_CTYPE="en_US.UTF-8" \
GOPATH="/gopath" \
PATH="$PATH:/opt/protoc/bin:/opt/go/bin:/gopath/bin"
RUN apt-get update && apt-get install unzip
RUN curl -L -o /tmp/${TARBALL} https://github.com/google/protobuf/releases/download/v${PROTOC_VER}/${TARBALL}
RUN cd /tmp && unzip /tmp/protoc-${PROTOC_VER}-linux-x86_64.zip -d /usr/local && rm /tmp/${TARBALL}
RUN go get -u github.com/gogo/protobuf/proto github.com/gogo/protobuf/protoc-gen-gogo github.com/gogo/protobuf/gogoproto golang.org/x/tools/cmd/goimports github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger
RUN cd ${GOPATH}/src/github.com/gogo/protobuf && git reset --hard ${GOGO_PROTO_TAG} && make install
RUN cd ${GOPATH}/src/github.com/grpc-ecosystem/grpc-gateway && git reset --hard ${GRPC_GATEWAY_TAG} && go install ./protoc-gen-grpc-gateway
ENV PROTO_INCLUDE "/usr/local/include":"${GOPATH}/src":"${GOPATH}/src/github.com/gogo/protobuf/protobuf":"${GOPATH}/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis":"${GOPATH}/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis":"${GOGOPROTO_ROOT}:${GOGOPROTO_ROOT}/protobuf"

201
vendor/github.com/gravitational/reporting/LICENSE generated vendored Normal file
View file

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

39
vendor/github.com/gravitational/reporting/Makefile generated vendored Normal file
View file

@ -0,0 +1,39 @@
PROTOC_VER ?= 3.0.0
GOGO_PROTO_TAG ?= v0.3
GRPC_GATEWAY_TAG ?= v1.1.0
PLATFORM := linux-x86_64
GRPC_API := .
BUILDBOX_TAG := reporting-buildbox:0.0.1
.PHONY: all
all: grpc build
.PHONY: build
build:
go build ./...
.PHONY: test
test:
go test -v ./...
.PHONY: buildbox
buildbox:
docker build \
--build-arg PROTOC_VER=$(PROTOC_VER) \
--build-arg GOGO_PROTO_TAG=$(GOGO_PROTO_TAG) \
--build-arg GRPC_GATEWAY_TAG=$(GRPC_GATEWAY_TAG) \
--build-arg PLATFORM=$(PLATFORM) \
-t $(BUILDBOX_TAG) .
.PHONY: grpc
grpc: buildbox
docker run -v $(shell pwd):/go/src/github.com/gravitational/reporting $(BUILDBOX_TAG) \
make -C /go/src/github.com/gravitational/reporting buildbox-grpc
.PHONY: buildbox-grpc
buildbox-grpc:
echo $$PROTO_INCLUDE
cd $(GRPC_API) && protoc -I=.:$$PROTO_INCLUDE \
--gofast_out=plugins=grpc:.\
--grpc-gateway_out=logtostderr=true:. \
*.proto

2
vendor/github.com/gravitational/reporting/README.md generated vendored Normal file
View file

@ -0,0 +1,2 @@
# reporting
gRPC based client/server usage reporting module

554
vendor/github.com/gravitational/reporting/api.pb.go generated vendored Normal file
View file

@ -0,0 +1,554 @@
// Code generated by protoc-gen-gogo.
// source: api.proto
// DO NOT EDIT!
/*
Package reporting is a generated protocol buffer package.
It is generated from these files:
api.proto
It has these top-level messages:
GRPCEvent
GRPCEvents
*/
package reporting
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import _ "github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis/google/api"
import google_protobuf1 "github.com/golang/protobuf/ptypes/empty"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
import io "io"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// GRPCEvent represents a single event sent over gRPC
type GRPCEvent struct {
// Data is the JSON-encoded event payload
Data []byte `protobuf:"bytes,1,opt,name=Data,json=data,proto3" json:"Data,omitempty"`
}
func (m *GRPCEvent) Reset() { *m = GRPCEvent{} }
func (m *GRPCEvent) String() string { return proto.CompactTextString(m) }
func (*GRPCEvent) ProtoMessage() {}
func (*GRPCEvent) Descriptor() ([]byte, []int) { return fileDescriptorApi, []int{0} }
// Events defines a series of events sent over gRPC
type GRPCEvents struct {
// Events is a list of events
Events []*GRPCEvent `protobuf:"bytes,1,rep,name=Events,json=events" json:"Events,omitempty"`
}
func (m *GRPCEvents) Reset() { *m = GRPCEvents{} }
func (m *GRPCEvents) String() string { return proto.CompactTextString(m) }
func (*GRPCEvents) ProtoMessage() {}
func (*GRPCEvents) Descriptor() ([]byte, []int) { return fileDescriptorApi, []int{1} }
func (m *GRPCEvents) GetEvents() []*GRPCEvent {
if m != nil {
return m.Events
}
return nil
}
func init() {
proto.RegisterType((*GRPCEvent)(nil), "reporting.GRPCEvent")
proto.RegisterType((*GRPCEvents)(nil), "reporting.GRPCEvents")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion3
// Client API for EventsService service
type EventsServiceClient interface {
// Record records the provided list of gRPC events
Record(ctx context.Context, in *GRPCEvents, opts ...grpc.CallOption) (*google_protobuf1.Empty, error)
}
type eventsServiceClient struct {
cc *grpc.ClientConn
}
func NewEventsServiceClient(cc *grpc.ClientConn) EventsServiceClient {
return &eventsServiceClient{cc}
}
func (c *eventsServiceClient) Record(ctx context.Context, in *GRPCEvents, opts ...grpc.CallOption) (*google_protobuf1.Empty, error) {
out := new(google_protobuf1.Empty)
err := grpc.Invoke(ctx, "/reporting.EventsService/Record", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// Server API for EventsService service
type EventsServiceServer interface {
// Record records the provided list of gRPC events
Record(context.Context, *GRPCEvents) (*google_protobuf1.Empty, error)
}
func RegisterEventsServiceServer(s *grpc.Server, srv EventsServiceServer) {
s.RegisterService(&_EventsService_serviceDesc, srv)
}
func _EventsService_Record_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GRPCEvents)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(EventsServiceServer).Record(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/reporting.EventsService/Record",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(EventsServiceServer).Record(ctx, req.(*GRPCEvents))
}
return interceptor(ctx, in, info, handler)
}
var _EventsService_serviceDesc = grpc.ServiceDesc{
ServiceName: "reporting.EventsService",
HandlerType: (*EventsServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Record",
Handler: _EventsService_Record_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: fileDescriptorApi,
}
func (m *GRPCEvent) Marshal() (data []byte, err error) {
size := m.Size()
data = make([]byte, size)
n, err := m.MarshalTo(data)
if err != nil {
return nil, err
}
return data[:n], nil
}
func (m *GRPCEvent) MarshalTo(data []byte) (int, error) {
var i int
_ = i
var l int
_ = l
if len(m.Data) > 0 {
data[i] = 0xa
i++
i = encodeVarintApi(data, i, uint64(len(m.Data)))
i += copy(data[i:], m.Data)
}
return i, nil
}
func (m *GRPCEvents) Marshal() (data []byte, err error) {
size := m.Size()
data = make([]byte, size)
n, err := m.MarshalTo(data)
if err != nil {
return nil, err
}
return data[:n], nil
}
func (m *GRPCEvents) MarshalTo(data []byte) (int, error) {
var i int
_ = i
var l int
_ = l
if len(m.Events) > 0 {
for _, msg := range m.Events {
data[i] = 0xa
i++
i = encodeVarintApi(data, i, uint64(msg.Size()))
n, err := msg.MarshalTo(data[i:])
if err != nil {
return 0, err
}
i += n
}
}
return i, nil
}
func encodeFixed64Api(data []byte, offset int, v uint64) int {
data[offset] = uint8(v)
data[offset+1] = uint8(v >> 8)
data[offset+2] = uint8(v >> 16)
data[offset+3] = uint8(v >> 24)
data[offset+4] = uint8(v >> 32)
data[offset+5] = uint8(v >> 40)
data[offset+6] = uint8(v >> 48)
data[offset+7] = uint8(v >> 56)
return offset + 8
}
func encodeFixed32Api(data []byte, offset int, v uint32) int {
data[offset] = uint8(v)
data[offset+1] = uint8(v >> 8)
data[offset+2] = uint8(v >> 16)
data[offset+3] = uint8(v >> 24)
return offset + 4
}
func encodeVarintApi(data []byte, offset int, v uint64) int {
for v >= 1<<7 {
data[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
data[offset] = uint8(v)
return offset + 1
}
func (m *GRPCEvent) Size() (n int) {
var l int
_ = l
l = len(m.Data)
if l > 0 {
n += 1 + l + sovApi(uint64(l))
}
return n
}
func (m *GRPCEvents) Size() (n int) {
var l int
_ = l
if len(m.Events) > 0 {
for _, e := range m.Events {
l = e.Size()
n += 1 + l + sovApi(uint64(l))
}
}
return n
}
func sovApi(x uint64) (n int) {
for {
n++
x >>= 7
if x == 0 {
break
}
}
return n
}
func sozApi(x uint64) (n int) {
return sovApi(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *GRPCEvent) Unmarshal(data []byte) error {
l := len(data)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowApi
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := data[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: GRPCEvent: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: GRPCEvent: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowApi
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := data[iNdEx]
iNdEx++
byteLen |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthApi
}
postIndex := iNdEx + byteLen
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Data = append(m.Data[:0], data[iNdEx:postIndex]...)
if m.Data == nil {
m.Data = []byte{}
}
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipApi(data[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthApi
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func (m *GRPCEvents) Unmarshal(data []byte) error {
l := len(data)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowApi
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := data[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: GRPCEvents: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: GRPCEvents: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Events", wireType)
}
var msglen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowApi
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := data[iNdEx]
iNdEx++
msglen |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
if msglen < 0 {
return ErrInvalidLengthApi
}
postIndex := iNdEx + msglen
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Events = append(m.Events, &GRPCEvent{})
if err := m.Events[len(m.Events)-1].Unmarshal(data[iNdEx:postIndex]); err != nil {
return err
}
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipApi(data[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthApi
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipApi(data []byte) (n int, err error) {
l := len(data)
iNdEx := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowApi
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := data[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowApi
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if data[iNdEx-1] < 0x80 {
break
}
}
return iNdEx, nil
case 1:
iNdEx += 8
return iNdEx, nil
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowApi
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := data[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
iNdEx += length
if length < 0 {
return 0, ErrInvalidLengthApi
}
return iNdEx, nil
case 3:
for {
var innerWire uint64
var start int = iNdEx
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowApi
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := data[iNdEx]
iNdEx++
innerWire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
innerWireType := int(innerWire & 0x7)
if innerWireType == 4 {
break
}
next, err := skipApi(data[start:])
if err != nil {
return 0, err
}
iNdEx = start + next
}
return iNdEx, nil
case 4:
return iNdEx, nil
case 5:
iNdEx += 4
return iNdEx, nil
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
}
panic("unreachable")
}
var (
ErrInvalidLengthApi = fmt.Errorf("proto: negative length found during unmarshaling")
ErrIntOverflowApi = fmt.Errorf("proto: integer overflow")
)
func init() { proto.RegisterFile("api.proto", fileDescriptorApi) }
var fileDescriptorApi = []byte{
// 286 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x6c, 0x8f, 0xc1, 0x4a, 0x33, 0x31,
0x14, 0x85, 0xff, 0xfc, 0x96, 0x81, 0x46, 0x45, 0x09, 0x56, 0x4a, 0x85, 0xb1, 0xcc, 0xaa, 0x88,
0x26, 0x58, 0x77, 0x5d, 0xaa, 0xc5, 0x9d, 0xc8, 0xb8, 0x73, 0x53, 0x6e, 0xa7, 0x31, 0x0d, 0xd8,
0x49, 0x48, 0x6e, 0x47, 0x66, 0xeb, 0x2b, 0xb8, 0xf1, 0x91, 0x5c, 0x0a, 0xbe, 0x80, 0x8c, 0x3e,
0x88, 0x38, 0x19, 0x67, 0xe5, 0xee, 0x1c, 0xce, 0x77, 0x92, 0x73, 0x69, 0x17, 0xac, 0xe6, 0xd6,
0x19, 0x34, 0xac, 0xeb, 0xa4, 0x35, 0x0e, 0x75, 0xae, 0x06, 0x77, 0x4a, 0xe3, 0x72, 0x3d, 0xe7,
0x99, 0x59, 0x09, 0xe5, 0x6c, 0x76, 0x22, 0x33, 0xe3, 0x4b, 0x8f, 0xb2, 0xb1, 0x0a, 0x50, 0x3e,
0x42, 0x29, 0x70, 0xa9, 0xdd, 0x62, 0x66, 0xc1, 0x61, 0x29, 0x94, 0x31, 0xea, 0x41, 0x82, 0xd5,
0xbe, 0x91, 0x02, 0xac, 0x16, 0x90, 0xe7, 0x06, 0x01, 0xb5, 0xc9, 0x7d, 0xf8, 0x66, 0x70, 0xd0,
0xa4, 0xb5, 0x9b, 0xaf, 0xef, 0x85, 0x5c, 0x59, 0x2c, 0x43, 0x98, 0x1c, 0xd2, 0xee, 0x55, 0x7a,
0x73, 0x31, 0x2d, 0x64, 0x8e, 0x8c, 0xd1, 0xce, 0x25, 0x20, 0xf4, 0xc9, 0x90, 0x8c, 0xb6, 0xd2,
0xce, 0x02, 0x10, 0x92, 0x09, 0xa5, 0x2d, 0xe0, 0xd9, 0x31, 0x8d, 0x82, 0xea, 0x93, 0xe1, 0xc6,
0x68, 0x73, 0xbc, 0xc7, 0xdb, 0x1b, 0x78, 0x8b, 0xa5, 0x91, 0xac, 0x99, 0xf1, 0x8c, 0x6e, 0x07,
0xfa, 0x56, 0xba, 0x42, 0x67, 0x92, 0x5d, 0xd3, 0x28, 0x95, 0x99, 0x71, 0x0b, 0xd6, 0xfb, 0xab,
0xe8, 0x07, 0xfb, 0x3c, 0x8c, 0xe5, 0xbf, 0x63, 0xf9, 0xf4, 0x67, 0x6c, 0xd2, 0x7b, 0x7a, 0xff,
0x7a, 0xfe, 0xbf, 0x93, 0x50, 0x51, 0x9c, 0x8a, 0xf0, 0xfa, 0x84, 0x1c, 0x9d, 0xef, 0xbe, 0x56,
0x31, 0x79, 0xab, 0x62, 0xf2, 0x51, 0xc5, 0xe4, 0xe5, 0x33, 0xfe, 0x37, 0x8f, 0xea, 0xe2, 0xd9,
0x77, 0x00, 0x00, 0x00, 0xff, 0xff, 0x5c, 0x2d, 0xa0, 0xdf, 0x67, 0x01, 0x00, 0x00,
}

110
vendor/github.com/gravitational/reporting/api.pb.gw.go generated vendored Normal file
View file

@ -0,0 +1,110 @@
// Code generated by protoc-gen-grpc-gateway
// source: api.proto
// DO NOT EDIT!
/*
Package reporting is a reverse proxy.
It translates gRPC into RESTful JSON APIs.
*/
package reporting
import (
"io"
"net/http"
"github.com/golang/protobuf/proto"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/grpc-ecosystem/grpc-gateway/utilities"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
)
var _ codes.Code
var _ io.Reader
var _ = runtime.String
var _ = utilities.NewDoubleArray
func request_EventsService_Record_0(ctx context.Context, marshaler runtime.Marshaler, client EventsServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var protoReq GRPCEvents
var metadata runtime.ServerMetadata
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil {
return nil, metadata, grpc.Errorf(codes.InvalidArgument, "%v", err)
}
msg, err := client.Record(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
return msg, metadata, err
}
// RegisterEventsServiceHandlerFromEndpoint is same as RegisterEventsServiceHandler but
// automatically dials to "endpoint" and closes the connection when "ctx" gets done.
func RegisterEventsServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) {
conn, err := grpc.Dial(endpoint, opts...)
if err != nil {
return err
}
defer func() {
if err != nil {
if cerr := conn.Close(); cerr != nil {
grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr)
}
return
}
go func() {
<-ctx.Done()
if cerr := conn.Close(); cerr != nil {
grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr)
}
}()
}()
return RegisterEventsServiceHandler(ctx, mux, conn)
}
// RegisterEventsServiceHandler registers the http handlers for service EventsService to "mux".
// The handlers forward requests to the grpc endpoint over "conn".
func RegisterEventsServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {
client := NewEventsServiceClient(conn)
mux.Handle("POST", pattern_EventsService_Record_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
if cn, ok := w.(http.CloseNotifier); ok {
go func(done <-chan struct{}, closed <-chan bool) {
select {
case <-done:
case <-closed:
cancel()
}
}(ctx.Done(), cn.CloseNotify())
}
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
rctx, err := runtime.AnnotateContext(ctx, req)
if err != nil {
runtime.HTTPError(ctx, outboundMarshaler, w, req, err)
}
resp, md, err := request_EventsService_Record_0(rctx, inboundMarshaler, client, req, pathParams)
ctx = runtime.NewServerMetadataContext(ctx, md)
if err != nil {
runtime.HTTPError(ctx, outboundMarshaler, w, req, err)
return
}
forward_EventsService_Record_0(ctx, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
return nil
}
var (
pattern_EventsService_Record_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"v1", "events"}, ""))
)
var (
forward_EventsService_Record_0 = runtime.ForwardResponseMessage
)

45
vendor/github.com/gravitational/reporting/api.proto generated vendored Normal file
View file

@ -0,0 +1,45 @@
/*
Copyright 2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
syntax = "proto3";
package reporting;
import "github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis/google/api/annotations.proto";
import "google/protobuf/empty.proto";
// GRPCEvent represents a single event sent over gRPC
message GRPCEvent {
// Data is the JSON-encoded event payload
bytes Data = 1;
}
// Events defines a series of events sent over gRPC
message GRPCEvents {
// Events is a list of events
repeated GRPCEvent Events = 1;
}
// EventsService defines an event-recording service
service EventsService {
// Record records the provided list of gRPC events
rpc Record(GRPCEvents) returns (google.protobuf.Empty) {
option (google.api.http) = {
post: "/v1/events"
body: "*"
};
}
}

View file

@ -0,0 +1,160 @@
/*
Copyright 2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package client
import (
"context"
"crypto/tls"
"time"
"github.com/gravitational/reporting"
"github.com/gravitational/reporting/types"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
grpcapi "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// ClientConfig defines the reporting client config
type ClientConfig struct {
// ServerAddr is the address of the reporting gRPC server
ServerAddr string
// ServerName is the SNI server name
ServerName string
// Certificate is the client certificate to authenticate with
Certificate tls.Certificate
// Insecure is whether the client should skip server cert verification
Insecure bool
}
// Client defines the reporting client interface
type Client interface {
// Record records an event
Record(types.Event)
}
// NewClient returns a new reporting gRPC client
func NewClient(ctx context.Context, config ClientConfig) (*client, error) {
conn, err := grpcapi.Dial(config.ServerAddr,
grpcapi.WithTransportCredentials(
credentials.NewTLS(&tls.Config{
ServerName: config.ServerName,
InsecureSkipVerify: config.Insecure,
Certificates: []tls.Certificate{config.Certificate},
})))
if err != nil {
return nil, trace.Wrap(err)
}
client := &client{
client: reporting.NewEventsServiceClient(conn),
// give an extra room to the events channel in case events
// are generated faster we can flush them (unlikely due to
// our events nature)
eventsCh: make(chan types.Event, 5*flushCount),
ctx: ctx,
}
go client.receiveAndFlushEvents()
return client, nil
}
type client struct {
client reporting.EventsServiceClient
// eventsCh is the channel where events are submitted before they are
// put into internal buffer
eventsCh chan types.Event
// events is the internal events buffer that gets flushed periodically
events []types.Event
// ctx may be used to stop client goroutine
ctx context.Context
}
// Record records an event. Note that the client accumulates events in memory
// and flushes them every once in a while
func (c *client) Record(event types.Event) {
select {
case c.eventsCh <- event:
log.Debugf("queued %v", event)
default:
log.Warnf("events channel is full, discarding %v", event)
}
}
// receiveAndFlushEvents receives events on a channel, accumulates them in
// memory and flushes them once a certain number has been accumulated, or
// certain amount of time has passed
func (c *client) receiveAndFlushEvents() {
ticker := time.NewTicker(flushInterval)
defer ticker.Stop()
for {
select {
case event := <-c.eventsCh:
if len(c.events) >= flushCount {
if err := c.flush(); err != nil {
log.Errorf("events queue full and failed to flush events, discarding %v: %v",
event, trace.DebugReport(err))
continue
}
}
c.events = append(c.events, event)
case <-ticker.C:
if err := c.flush(); err != nil {
log.Errorf("failed to flush events: %v",
trace.DebugReport(err))
}
case <-c.ctx.Done():
log.Debug("reporting client is shutting down")
if err := c.flush(); err != nil {
log.Errorf("failed to flush events: %v",
trace.DebugReport(err))
}
return
}
}
}
// flush flushes all accumulated events
func (c *client) flush() error {
if len(c.events) == 0 {
return nil // nothing to flush
}
var grpcEvents reporting.GRPCEvents
for _, event := range c.events {
grpcEvent, err := types.ToGRPCEvent(event)
if err != nil {
return trace.Wrap(err)
}
grpcEvents.Events = append(
grpcEvents.Events, grpcEvent)
}
// if we fail to flush some events here, they will be retried on
// the next cycle, we may get duplicates but each event includes
// a unique ID which server sinks can use to de-duplicate
if _, err := c.client.Record(c.ctx, &grpcEvents); err != nil {
return trace.Wrap(err)
}
log.Debugf("flushed %v events", len(c.events))
c.events = []types.Event{}
return nil
}
const (
// flushInterval is how often the client flushes accumulated events
flushInterval = 3 * time.Second
// flushCount is the number of events to accumulate before flush triggers
flushCount = 5
)

View file

@ -0,0 +1,42 @@
/*
Copyright 2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package types
const (
// ResourceVersion is the current event resource version
ResourceVersion = "v2"
// KindEvent is the event resource kind
KindEvent = "event"
// EventTypeServer is the server-related event type
EventTypeServer = "server"
// EventTypeUser is the user-related event type
EventTypeUser = "user"
// EventActionLogin is the event login action
EventActionLogin = "login"
// KindHeartbeat is the heartbeat resource kind
KindHeartbeat = "heartbeat"
// NotificationUsage is the usage limit notification type
NotificationUsage = "usage"
// NotificationTerms is the terms of service violation notification type
NotificationTerms = "terms"
// SeverityInfo is info notification severity
SeverityInfo = "info"
// SeverityWarning is warning notification severity
SeverityWarning = "warning"
// SeverityError is error notification severity
SeverityError = "error"
)

View file

@ -0,0 +1,320 @@
/*
Copyright 2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package types
import (
"encoding/json"
"fmt"
"time"
"github.com/gravitational/reporting"
"github.com/gravitational/configure/jsonschema"
"github.com/gravitational/trace"
"github.com/pborman/uuid"
)
// Event defines an interface all event types should implement
type Event interface {
// GetName returns the event name
GetName() string
// GetMetadata returns the event metadata
GetMetadata() Metadata
// SetAccountID sets the event account ID
SetAccountID(string)
}
// Metadata represents event resource metadata
type Metadata struct {
// Name is the event name
Name string `json:"name"`
// Created is the event creation timestamp
Created time.Time `json:"created"`
}
// ServerEvent represents server-related event, such as "logged into server"
type ServerEvent struct {
// Kind is resource kind, for events it is "event"
Kind string `json:"kind"`
// Version is the event resource version
Version string `json:"version"`
// Metadata is the event metadata
Metadata Metadata `json:"metadata"`
// Spec is the event spec
Spec ServerEventSpec `json:"spec"`
}
// ServerEventSpec is server event specification
type ServerEventSpec struct {
// ID is event ID, may be used for de-duplication
ID string `json:"id"`
// Action is event action, such as "login"
Action string `json:"action"`
// AccountID is ID of account that triggered the event
AccountID string `json:"accountID"`
// ServerID is anonymized ID of server that triggered the event
ServerID string `json:"serverID"`
}
// NewServerLoginEvent creates an instance of "server login" event
func NewServerLoginEvent(serverID string) *ServerEvent {
return &ServerEvent{
Kind: KindEvent,
Version: ResourceVersion,
Metadata: Metadata{
Name: EventTypeServer,
Created: time.Now().UTC(),
},
Spec: ServerEventSpec{
ID: uuid.New(),
Action: EventActionLogin,
ServerID: serverID,
},
}
}
// GetName returns the event name
func (e *ServerEvent) GetName() string { return e.Metadata.Name }
// GetMetadata returns the event metadata
func (e *ServerEvent) GetMetadata() Metadata { return e.Metadata }
// SetAccountID sets the event account ID
func (e *ServerEvent) SetAccountID(id string) {
e.Spec.AccountID = id
}
// UserEvent represents user-related event, such as "user logged in"
type UserEvent struct {
// Kind is resource kind, for events it is "event"
Kind string `json:"kind"`
// Version is the event resource version
Version string `json:"version"`
// Metadata is the event metadata
Metadata Metadata `json:"metadata"`
// Spec is the event spec
Spec UserEventSpec `json:"spec"`
}
// UserEventSpec is user event specification
type UserEventSpec struct {
// ID is event ID, may be used for de-duplication
ID string `json:"id"`
// Action is event action, such as "login"
Action string `json:"action"`
// AccountID is ID of account that triggered the event
AccountID string `json:"accountID"`
// UserID is anonymized ID of user that triggered the event
UserID string `json:"userID"`
}
// NewUserLoginEvent creates an instance of "user login" event
func NewUserLoginEvent(userID string) *UserEvent {
return &UserEvent{
Kind: KindEvent,
Version: ResourceVersion,
Metadata: Metadata{
Name: EventTypeUser,
Created: time.Now().UTC(),
},
Spec: UserEventSpec{
ID: uuid.New(),
Action: EventActionLogin,
UserID: userID,
},
}
}
// GetName returns the event name
func (e *UserEvent) GetName() string { return e.Metadata.Name }
// GetMetadata returns the event metadata
func (e *UserEvent) GetMetadata() Metadata { return e.Metadata }
// SetAccountID sets the event account id
func (e *UserEvent) SetAccountID(id string) {
e.Spec.AccountID = id
}
// ToGRPCEvent converts provided event to the format used by gRPC server/client
func ToGRPCEvent(event Event) (*reporting.GRPCEvent, error) {
payload, err := json.Marshal(event)
if err != nil {
return nil, trace.Wrap(err)
}
return &reporting.GRPCEvent{
Data: payload,
}, nil
}
// FromGRPCEvent converts event from the format used by gRPC server/client
func FromGRPCEvent(grpcEvent reporting.GRPCEvent) (Event, error) {
var header resourceHeader
if err := json.Unmarshal(grpcEvent.Data, &header); err != nil {
return nil, trace.Wrap(err)
}
if header.Kind != KindEvent {
return nil, trace.BadParameter("expected kind %q, got %q",
KindEvent, header.Kind)
}
if header.Version != ResourceVersion {
return nil, trace.BadParameter("expected resource version %q, got %q",
ResourceVersion, header.Version)
}
switch header.Metadata.Name {
case EventTypeServer:
var event ServerEvent
err := unmarshalWithSchema(
getServerEventSchema(), grpcEvent.Data, &event)
if err != nil {
return nil, trace.Wrap(err)
}
return &event, nil
case EventTypeUser:
var event UserEvent
err := unmarshalWithSchema(
getUserEventSchema(), grpcEvent.Data, &event)
if err != nil {
return nil, trace.Wrap(err)
}
return &event, nil
default:
return nil, trace.BadParameter("unknown event type %q", header.Metadata.Name)
}
}
// FromGRPCEvents converts a series of events from the format used by gRPC server/client
func FromGRPCEvents(grpcEvents reporting.GRPCEvents) ([]Event, error) {
var events []Event
for _, grpcEvent := range grpcEvents.Events {
event, err := FromGRPCEvent(*grpcEvent)
if err != nil {
return nil, trace.Wrap(err)
}
events = append(events, event)
}
return events, nil
}
// resourceHeader is used when unmarhsaling resources
type resourceHeader struct {
// Kind the the resource kind
Kind string `json:"kind"`
// Version is the resource version
Version string `json:"version"`
// Metadata is the resource metadata
Metadata Metadata `json:"metadata"`
}
// schemaTemplate is the event resource schema template
const schemaTemplate = `{
"type": "object",
"additionalProperties": false,
"required": ["kind", "version", "metadata", "spec"],
"properties": {
"kind": {"type": "string"},
"version": {"type": "string", "default": "v2"},
"metadata": {
"type": "object",
"additionalProperties": false,
"required": ["name", "created"],
"properties": {
"name": {"type": "string"},
"created": {"type": "string"}
}
},
"spec": %v
}
}`
// getServerEventSchema returns full server event JSON schema
func getServerEventSchema() string {
return fmt.Sprintf(schemaTemplate, serverEventSchema)
}
// serverEventSchema is the server event spec schema
const serverEventSchema = `{
"type": "object",
"additionalProperties": false,
"required": ["id", "action", "accountID", "serverID"],
"properties": {
"id": {"type": "string"},
"action": {"type": "string"},
"accountID": {"type": "string"},
"serverID": {"type": "string"}
}
}`
// getUserEventSchema returns full user event JSON schema
func getUserEventSchema() string {
return fmt.Sprintf(schemaTemplate, userEventSchema)
}
// userEventSchema is the user event spec schema
const userEventSchema = `{
"type": "object",
"additionalProperties": false,
"required": ["id", "action", "accountID", "userID"],
"properties": {
"id": {"type": "string"},
"action": {"type": "string"},
"accountID": {"type": "string"},
"userID": {"type": "string"}
}
}`
// unmarshalWithSchema unmarshals the provided data into the provided object
// using specified JSON schema
func unmarshalWithSchema(objectSchema string, data []byte, object interface{}) error {
schema, err := jsonschema.New([]byte(objectSchema))
if err != nil {
return trace.Wrap(err)
}
raw := map[string]interface{}{}
if err := json.Unmarshal(data, &raw); err != nil {
return trace.Wrap(err)
}
processed, err := schema.ProcessObject(raw)
if err != nil {
return trace.Wrap(err)
}
bytes, err := json.Marshal(processed)
if err != nil {
return trace.Wrap(err)
}
if err := json.Unmarshal(bytes, object); err != nil {
return trace.Wrap(err)
}
return nil
}
// marshalWithSchema marshals the provided objects while checking the specified schema
func marshalWithSchema(objectSchema string, object interface{}) ([]byte, error) {
schema, err := jsonschema.New([]byte(objectSchema))
if err != nil {
return nil, trace.Wrap(err)
}
processed, err := schema.ProcessObject(object)
if err != nil {
return nil, trace.Wrap(err)
}
bytes, err := json.Marshal(processed)
if err != nil {
return nil, trace.Wrap(err)
}
return bytes, nil
}

View file

@ -0,0 +1,134 @@
/*
Copyright 2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package types
import (
"encoding/json"
"fmt"
"time"
"github.com/gravitational/trace"
)
// Heartbeat represents a heartbeat that is sent from control plane to teleport
type Heartbeat struct {
// Kind is resource kind, for heartbeat it is "heartbeat"
Kind string `json:"kind"`
// Version is the heartbeat resource version
Version string `json:"version"`
// Metadata is the heartbeat metadata
Metadata Metadata `json:"metadata"`
// Spec is the heartbeat spec
Spec HeartbeatSpec `json:"spec"`
}
// HeartbeatSpec is the heartbeat resource spec
type HeartbeatSpec struct {
// Notifications is a list of notifications sent with the heartbeat
Notifications []Notification `json:"notifications,omitempty"`
}
// Notification represents a user notification message
type Notification struct {
// Type is the notification type
Type string `json:"type"`
// Severity is the notification severity: info, warning or error
Severity string `json:"severity"`
// Text is the notification plain text
Text string `json:"text"`
// HTML is the notification HTML
HTML string `json:"html"`
}
// NewHeartbeat returns a new heartbeat
func NewHeartbeat(notifications ...Notification) *Heartbeat {
return &Heartbeat{
Kind: KindHeartbeat,
Version: ResourceVersion,
Metadata: Metadata{
Name: "heartbeat",
Created: time.Now().UTC(),
},
Spec: HeartbeatSpec{
Notifications: notifications,
},
}
}
// GetName returns the resource name
func (h *Heartbeat) GetName() string { return h.Metadata.Name }
// GetMetadata returns the heartbeat metadata
func (h *Heartbeat) GetMetadata() Metadata { return h.Metadata }
// UnmarshalHeartbeat unmarshals heartbeat with schema validation
func UnmarshalHeartbeat(bytes []byte) (*Heartbeat, error) {
var header resourceHeader
if err := json.Unmarshal(bytes, &header); err != nil {
return nil, trace.Wrap(err)
}
if header.Kind != KindHeartbeat {
return nil, trace.BadParameter("expected kind %q, got %q",
KindHeartbeat, header.Kind)
}
if header.Version != ResourceVersion {
return nil, trace.BadParameter("expected resource version %q, got %q",
ResourceVersion, header.Version)
}
var heartbeat Heartbeat
err := unmarshalWithSchema(
getHeartbeatSchema(), bytes, &heartbeat)
if err != nil {
return nil, trace.Wrap(err)
}
return &heartbeat, nil
}
// MarshalHeartbeat marshals heartbeat with schema validation
func MarshalHeartbeat(h Heartbeat) ([]byte, error) {
bytes, err := marshalWithSchema(getHeartbeatSchema(), h)
if err != nil {
return nil, trace.Wrap(err)
}
return bytes, nil
}
// heartbeatSchema is the heartbeat spec schema
const heartbeatSchema = `{
"type": "object",
"properties": {
"notifications": {
"type": "array",
"items": {
"type": "object",
"required": ["type", "severity", "text", "html"],
"additionalProperties": false,
"properties": {
"type": {"type": "string"},
"severity": {"type": "string"},
"text": {"type": "string"},
"html": {"type": "string"}
}
}
}
}
}`
// getHeartbeatSchema returns the full heartbeat resource schema
func getHeartbeatSchema() string {
return fmt.Sprintf(schemaTemplate, heartbeatSchema)
}

View file

@ -0,0 +1,61 @@
/*
Copyright 2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package types
import (
"testing"
check "gopkg.in/check.v1"
)
func TestTypes(t *testing.T) { check.TestingT(t) }
type TypesSuite struct{}
var _ = check.Suite(&TypesSuite{})
func (s *TypesSuite) TestHeartbeat(c *check.C) {
h := NewHeartbeat(
Notification{
Type: NotificationUsage,
Severity: SeverityWarning,
Text: "Usage limit exceeded",
HTML: "<div>Usage limit exceeded</div>",
},
Notification{
Type: NotificationTerms,
Severity: SeverityError,
Text: "Terms of service violation",
HTML: "<div>Terms of service violation</div>",
})
bytes, err := MarshalHeartbeat(*h)
c.Assert(err, check.IsNil)
unmarshaled, err := UnmarshalHeartbeat(bytes)
c.Assert(err, check.IsNil)
c.Assert(len(unmarshaled.Spec.Notifications), check.Equals, 2)
c.Assert(unmarshaled, check.DeepEquals, h)
}
func (s *TypesSuite) TestEmptyHeartbeat(c *check.C) {
h := NewHeartbeat()
bytes, err := MarshalHeartbeat(*h)
c.Assert(err, check.IsNil)
unmarshaled, err := UnmarshalHeartbeat(bytes)
c.Assert(err, check.IsNil)
c.Assert(len(unmarshaled.Spec.Notifications), check.Equals, 0)
c.Assert(unmarshaled, check.DeepEquals, h)
}

View file

@ -0,0 +1,139 @@
package runtime
import (
"fmt"
"net"
"net/http"
"strconv"
"strings"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/metadata"
)
const metadataHeaderPrefix = "Grpc-Metadata-"
const metadataTrailerPrefix = "Grpc-Trailer-"
const metadataGrpcTimeout = "Grpc-Timeout"
const xForwardedFor = "X-Forwarded-For"
const xForwardedHost = "X-Forwarded-Host"
var (
// DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a Grpc-Timeout inbound
// header isn't present. If the value is 0 the sent `context` will not have a timeout.
DefaultContextTimeout = 0 * time.Second
)
/*
AnnotateContext adds context information such as metadata from the request.
At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For",
except that the forwarded destination is not another HTTP service but rather
a gRPC service.
*/
func AnnotateContext(ctx context.Context, req *http.Request) (context.Context, error) {
var pairs []string
timeout := DefaultContextTimeout
if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
var err error
timeout, err = timeoutDecode(tm)
if err != nil {
return nil, grpc.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm)
}
}
for key, vals := range req.Header {
for _, val := range vals {
if key == "Authorization" {
pairs = append(pairs, "authorization", val)
continue
}
if strings.HasPrefix(key, metadataHeaderPrefix) {
pairs = append(pairs, key[len(metadataHeaderPrefix):], val)
}
}
}
if host := req.Header.Get(xForwardedHost); host != "" {
pairs = append(pairs, strings.ToLower(xForwardedHost), host)
} else if req.Host != "" {
pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
}
if addr := req.RemoteAddr; addr != "" {
if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
if fwd := req.Header.Get(xForwardedFor); fwd == "" {
pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
} else {
pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
}
} else {
grpclog.Printf("invalid remote addr: %s", addr)
}
}
if timeout != 0 {
ctx, _ = context.WithTimeout(ctx, timeout)
}
if len(pairs) == 0 {
return ctx, nil
}
return metadata.NewContext(ctx, metadata.Pairs(pairs...)), nil
}
// ServerMetadata consists of metadata sent from gRPC server.
type ServerMetadata struct {
HeaderMD metadata.MD
TrailerMD metadata.MD
}
type serverMetadataKey struct{}
// NewServerMetadataContext creates a new context with ServerMetadata
func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context {
return context.WithValue(ctx, serverMetadataKey{}, md)
}
// ServerMetadataFromContext returns the ServerMetadata in ctx
func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) {
md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata)
return
}
func timeoutDecode(s string) (time.Duration, error) {
size := len(s)
if size < 2 {
return 0, fmt.Errorf("timeout string is too short: %q", s)
}
d, ok := timeoutUnitToDuration(s[size-1])
if !ok {
return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
}
t, err := strconv.ParseInt(s[:size-1], 10, 64)
if err != nil {
return 0, err
}
return d * time.Duration(t), nil
}
func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) {
switch u {
case 'H':
return time.Hour, true
case 'M':
return time.Minute, true
case 'S':
return time.Second, true
case 'm':
return time.Millisecond, true
case 'u':
return time.Microsecond, true
case 'n':
return time.Nanosecond, true
default:
}
return
}

View file

@ -0,0 +1,169 @@
package runtime_test
import (
"net/http"
"reflect"
"testing"
"time"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"golang.org/x/net/context"
"google.golang.org/grpc/metadata"
)
const (
emptyForwardMetaCount = 1
)
func TestAnnotateContext_WorksWithEmpty(t *testing.T) {
ctx := context.Background()
request, err := http.NewRequest("GET", "http://www.example.com", nil)
if err != nil {
t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", "GET", "http://www.example.com", err)
}
request.Header.Add("Some-Irrelevant-Header", "some value")
annotated, err := runtime.AnnotateContext(ctx, request)
if err != nil {
t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err)
return
}
md, ok := metadata.FromContext(annotated)
if !ok || len(md) != emptyForwardMetaCount {
t.Errorf("Expected %d metadata items in context; got %v", emptyForwardMetaCount, md)
}
}
func TestAnnotateContext_ForwardsGrpcMetadata(t *testing.T) {
ctx := context.Background()
request, err := http.NewRequest("GET", "http://www.example.com", nil)
if err != nil {
t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", "GET", "http://www.example.com", err)
}
request.Header.Add("Some-Irrelevant-Header", "some value")
request.Header.Add("Grpc-Metadata-FooBar", "Value1")
request.Header.Add("Grpc-Metadata-Foo-BAZ", "Value2")
request.Header.Add("Grpc-Metadata-foo-bAz", "Value3")
request.Header.Add("Authorization", "Token 1234567890")
annotated, err := runtime.AnnotateContext(ctx, request)
if err != nil {
t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err)
return
}
md, ok := metadata.FromContext(annotated)
if got, want := len(md), emptyForwardMetaCount+3; !ok || got != want {
t.Errorf("Expected %d metadata items in context; got %d", got, want)
}
if got, want := md["foobar"], []string{"Value1"}; !reflect.DeepEqual(got, want) {
t.Errorf(`md["foobar"] = %q; want %q`, got, want)
}
if got, want := md["foo-baz"], []string{"Value2", "Value3"}; !reflect.DeepEqual(got, want) {
t.Errorf(`md["foo-baz"] = %q want %q`, got, want)
}
if got, want := md["authorization"], []string{"Token 1234567890"}; !reflect.DeepEqual(got, want) {
t.Errorf(`md["authorization"] = %q want %q`, got, want)
}
}
func TestAnnotateContext_XForwardedFor(t *testing.T) {
ctx := context.Background()
request, err := http.NewRequest("GET", "http://bar.foo.example.com", nil)
if err != nil {
t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", "GET", "http://bar.foo.example.com", err)
}
request.Header.Add("X-Forwarded-For", "192.0.2.100") // client
request.RemoteAddr = "192.0.2.200:12345" // proxy
annotated, err := runtime.AnnotateContext(ctx, request)
if err != nil {
t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err)
return
}
md, ok := metadata.FromContext(annotated)
if !ok || len(md) != emptyForwardMetaCount+1 {
t.Errorf("Expected %d metadata items in context; got %v", emptyForwardMetaCount+1, md)
}
if got, want := md["x-forwarded-host"], []string{"bar.foo.example.com"}; !reflect.DeepEqual(got, want) {
t.Errorf(`md["host"] = %v; want %v`, got, want)
}
// Note: it must be in order client, proxy1, proxy2
if got, want := md["x-forwarded-for"], []string{"192.0.2.100, 192.0.2.200"}; !reflect.DeepEqual(got, want) {
t.Errorf(`md["x-forwarded-for"] = %v want %v`, got, want)
}
}
func TestAnnotateContext_SupportsTimeouts(t *testing.T) {
ctx := context.Background()
request, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf(`http.NewRequest("GET", "http://example.com", nil failed with %v; want success`, err)
}
annotated, err := runtime.AnnotateContext(ctx, request)
if err != nil {
t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err)
return
}
if _, ok := annotated.Deadline(); ok {
// no deadline by default
t.Errorf("annotated.Deadline() = _, true; want _, false")
}
const acceptableError = 50 * time.Millisecond
runtime.DefaultContextTimeout = 10 * time.Second
annotated, err = runtime.AnnotateContext(ctx, request)
if err != nil {
t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err)
return
}
deadline, ok := annotated.Deadline()
if !ok {
t.Errorf("annotated.Deadline() = _, false; want _, true")
}
if got, want := deadline.Sub(time.Now()), runtime.DefaultContextTimeout; got-want > acceptableError || got-want < -acceptableError {
t.Errorf("deadline.Sub(time.Now()) = %v; want %v; with error %v", got, want, acceptableError)
}
for _, spec := range []struct {
timeout string
want time.Duration
}{
{
timeout: "17H",
want: 17 * time.Hour,
},
{
timeout: "19M",
want: 19 * time.Minute,
},
{
timeout: "23S",
want: 23 * time.Second,
},
{
timeout: "1009m",
want: 1009 * time.Millisecond,
},
{
timeout: "1000003u",
want: 1000003 * time.Microsecond,
},
{
timeout: "100000007n",
want: 100000007 * time.Nanosecond,
},
} {
request.Header.Set("Grpc-Timeout", spec.timeout)
annotated, err = runtime.AnnotateContext(ctx, request)
if err != nil {
t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err)
return
}
deadline, ok := annotated.Deadline()
if !ok {
t.Errorf("annotated.Deadline() = _, false; want _, true; timeout = %q", spec.timeout)
}
if got, want := deadline.Sub(time.Now()), spec.want; got-want > acceptableError || got-want < -acceptableError {
t.Errorf("deadline.Sub(time.Now()) = %v; want %v; with error %v; timeout= %q", got, want, acceptableError, spec.timeout)
}
}
}

View file

@ -0,0 +1,58 @@
package runtime
import (
"strconv"
)
// String just returns the given string.
// It is just for compatibility to other types.
func String(val string) (string, error) {
return val, nil
}
// Bool converts the given string representation of a boolean value into bool.
func Bool(val string) (bool, error) {
return strconv.ParseBool(val)
}
// Float64 converts the given string representation into representation of a floating point number into float64.
func Float64(val string) (float64, error) {
return strconv.ParseFloat(val, 64)
}
// Float32 converts the given string representation of a floating point number into float32.
func Float32(val string) (float32, error) {
f, err := strconv.ParseFloat(val, 32)
if err != nil {
return 0, err
}
return float32(f), nil
}
// Int64 converts the given string representation of an integer into int64.
func Int64(val string) (int64, error) {
return strconv.ParseInt(val, 0, 64)
}
// Int32 converts the given string representation of an integer into int32.
func Int32(val string) (int32, error) {
i, err := strconv.ParseInt(val, 0, 32)
if err != nil {
return 0, err
}
return int32(i), nil
}
// Uint64 converts the given string representation of an integer into uint64.
func Uint64(val string) (uint64, error) {
return strconv.ParseUint(val, 0, 64)
}
// Uint32 converts the given string representation of an integer into uint32.
func Uint32(val string) (uint32, error) {
i, err := strconv.ParseUint(val, 0, 32)
if err != nil {
return 0, err
}
return uint32(i), nil
}

View file

@ -0,0 +1,5 @@
/*
Package runtime contains runtime helper functions used by
servers which protoc-gen-grpc-gateway generates.
*/
package runtime

View file

@ -0,0 +1,121 @@
package runtime
import (
"io"
"net/http"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
)
// HTTPStatusFromCode converts a gRPC error code into the corresponding HTTP response status.
func HTTPStatusFromCode(code codes.Code) int {
switch code {
case codes.OK:
return http.StatusOK
case codes.Canceled:
return http.StatusRequestTimeout
case codes.Unknown:
return http.StatusInternalServerError
case codes.InvalidArgument:
return http.StatusBadRequest
case codes.DeadlineExceeded:
return http.StatusRequestTimeout
case codes.NotFound:
return http.StatusNotFound
case codes.AlreadyExists:
return http.StatusConflict
case codes.PermissionDenied:
return http.StatusForbidden
case codes.Unauthenticated:
return http.StatusUnauthorized
case codes.ResourceExhausted:
return http.StatusForbidden
case codes.FailedPrecondition:
return http.StatusPreconditionFailed
case codes.Aborted:
return http.StatusConflict
case codes.OutOfRange:
return http.StatusBadRequest
case codes.Unimplemented:
return http.StatusNotImplemented
case codes.Internal:
return http.StatusInternalServerError
case codes.Unavailable:
return http.StatusServiceUnavailable
case codes.DataLoss:
return http.StatusInternalServerError
}
grpclog.Printf("Unknown gRPC error code: %v", code)
return http.StatusInternalServerError
}
var (
// HTTPError replies to the request with the error.
// You can set a custom function to this variable to customize error format.
HTTPError = DefaultHTTPError
// OtherErrorHandler handles the following error used by the gateway: StatusMethodNotAllowed StatusNotFound and StatusBadRequest
OtherErrorHandler = DefaultOtherErrorHandler
)
type errorBody struct {
Error string `json:"error"`
Code int `json:"code"`
}
//Make this also conform to proto.Message for builtin JSONPb Marshaler
func (e *errorBody) Reset() { *e = errorBody{} }
func (e *errorBody) String() string { return proto.CompactTextString(e) }
func (*errorBody) ProtoMessage() {}
// DefaultHTTPError is the default implementation of HTTPError.
// If "err" is an error from gRPC system, the function replies with the status code mapped by HTTPStatusFromCode.
// If otherwise, it replies with http.StatusInternalServerError.
//
// The response body returned by this function is a JSON object,
// which contains a member whose key is "error" and whose value is err.Error().
func DefaultHTTPError(ctx context.Context, marshaler Marshaler, w http.ResponseWriter, _ *http.Request, err error) {
const fallback = `{"error": "failed to marshal error message"}`
w.Header().Del("Trailer")
w.Header().Set("Content-Type", marshaler.ContentType())
body := &errorBody{
Error: grpc.ErrorDesc(err),
Code: int(grpc.Code(err)),
}
buf, merr := marshaler.Marshal(body)
if merr != nil {
grpclog.Printf("Failed to marshal error message %q: %v", body, merr)
w.WriteHeader(http.StatusInternalServerError)
if _, err := io.WriteString(w, fallback); err != nil {
grpclog.Printf("Failed to write response: %v", err)
}
return
}
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Printf("Failed to extract ServerMetadata from context")
}
handleForwardResponseServerMetadata(w, md)
handleForwardResponseTrailerHeader(w, md)
st := HTTPStatusFromCode(grpc.Code(err))
w.WriteHeader(st)
if _, err := w.Write(buf); err != nil {
grpclog.Printf("Failed to write response: %v", err)
}
handleForwardResponseTrailer(w, md)
}
// DefaultOtherErrorHandler is the default implementation of OtherErrorHandler.
// It simply writes a string representation of the given error into "w".
func DefaultOtherErrorHandler(w http.ResponseWriter, _ *http.Request, msg string, code int) {
http.Error(w, msg, code)
}

View file

@ -0,0 +1,56 @@
package runtime_test
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)
func TestDefaultHTTPError(t *testing.T) {
ctx := context.Background()
for _, spec := range []struct {
err error
status int
msg string
}{
{
err: fmt.Errorf("example error"),
status: http.StatusInternalServerError,
msg: "example error",
},
{
err: grpc.Errorf(codes.NotFound, "no such resource"),
status: http.StatusNotFound,
msg: "no such resource",
},
} {
w := httptest.NewRecorder()
req, _ := http.NewRequest("", "", nil) // Pass in an empty request to match the signature
runtime.DefaultHTTPError(ctx, &runtime.JSONBuiltin{}, w, req, spec.err)
if got, want := w.Header().Get("Content-Type"), "application/json"; got != want {
t.Errorf(`w.Header().Get("Content-Type") = %q; want %q; on spec.err=%v`, got, want, spec.err)
}
if got, want := w.Code, spec.status; got != want {
t.Errorf("w.Code = %d; want %d", got, want)
}
body := make(map[string]interface{})
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
t.Errorf("json.Unmarshal(%q, &body) failed with %v; want success", w.Body.Bytes(), err)
continue
}
if got, want := body["error"].(string), spec.msg; !strings.Contains(got, want) {
t.Errorf(`body["error"] = %q; want %q; on spec.err=%v`, got, want, spec.err)
}
}
}

View file

@ -0,0 +1,164 @@
package runtime
import (
"fmt"
"io"
"net/http"
"net/textproto"
"github.com/golang/protobuf/proto"
"github.com/grpc-ecosystem/grpc-gateway/runtime/internal"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/grpclog"
)
// ForwardResponseStream forwards the stream from gRPC server to REST client.
func ForwardResponseStream(ctx context.Context, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
f, ok := w.(http.Flusher)
if !ok {
grpclog.Printf("Flush not supported in %T", w)
http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
return
}
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Printf("Failed to extract ServerMetadata from context")
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
handleForwardResponseServerMetadata(w, md)
w.Header().Set("Transfer-Encoding", "chunked")
w.Header().Set("Content-Type", marshaler.ContentType())
if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
f.Flush()
for {
resp, err := recv()
if err == io.EOF {
return
}
if err != nil {
handleForwardResponseStreamError(marshaler, w, err)
return
}
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
handleForwardResponseStreamError(marshaler, w, err)
return
}
buf, err := marshaler.Marshal(streamChunk(resp, nil))
if err != nil {
grpclog.Printf("Failed to marshal response chunk: %v", err)
return
}
if _, err = fmt.Fprintf(w, "%s\n", buf); err != nil {
grpclog.Printf("Failed to send response chunk: %v", err)
return
}
f.Flush()
}
}
func handleForwardResponseServerMetadata(w http.ResponseWriter, md ServerMetadata) {
for k, vs := range md.HeaderMD {
hKey := fmt.Sprintf("%s%s", metadataHeaderPrefix, k)
for i := range vs {
w.Header().Add(hKey, vs[i])
}
}
}
func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
for k := range md.TrailerMD {
tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", metadataTrailerPrefix, k))
w.Header().Add("Trailer", tKey)
}
}
func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
for k, vs := range md.TrailerMD {
tKey := fmt.Sprintf("%s%s", metadataTrailerPrefix, k)
for i := range vs {
w.Header().Add(tKey, vs[i])
}
}
}
// ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
func ForwardResponseMessage(ctx context.Context, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Printf("Failed to extract ServerMetadata from context")
}
handleForwardResponseServerMetadata(w, md)
handleForwardResponseTrailerHeader(w, md)
w.Header().Set("Content-Type", marshaler.ContentType())
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
HTTPError(ctx, marshaler, w, req, err)
return
}
buf, err := marshaler.Marshal(resp)
if err != nil {
grpclog.Printf("Marshal error: %v", err)
HTTPError(ctx, marshaler, w, req, err)
return
}
if _, err = w.Write(buf); err != nil {
grpclog.Printf("Failed to write response: %v", err)
}
handleForwardResponseTrailer(w, md)
}
func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
if len(opts) == 0 {
return nil
}
for _, opt := range opts {
if err := opt(ctx, w, resp); err != nil {
grpclog.Printf("Error handling ForwardResponseOptions: %v", err)
return err
}
}
return nil
}
func handleForwardResponseStreamError(marshaler Marshaler, w http.ResponseWriter, err error) {
buf, merr := marshaler.Marshal(streamChunk(nil, err))
if merr != nil {
grpclog.Printf("Failed to marshal an error: %v", merr)
return
}
if _, werr := fmt.Fprintf(w, "%s\n", buf); werr != nil {
grpclog.Printf("Failed to notify error to client: %v", werr)
return
}
}
func streamChunk(result proto.Message, err error) map[string]proto.Message {
if err != nil {
grpcCode := grpc.Code(err)
httpCode := HTTPStatusFromCode(grpcCode)
return map[string]proto.Message{
"error": &internal.StreamError{
GrpcCode: int32(grpcCode),
HttpCode: int32(httpCode),
Message: err.Error(),
HttpStatus: http.StatusText(httpCode),
},
}
}
if result == nil {
return streamChunk(nil, fmt.Errorf("empty response"))
}
return map[string]proto.Message{"result": result}
}

View file

@ -0,0 +1,65 @@
// Code generated by protoc-gen-go.
// source: runtime/internal/stream_chunk.proto
// DO NOT EDIT!
/*
Package internal is a generated protocol buffer package.
It is generated from these files:
runtime/internal/stream_chunk.proto
It has these top-level messages:
StreamError
*/
package internal
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// StreamError is a response type which is returned when
// streaming rpc returns an error.
type StreamError struct {
GrpcCode int32 `protobuf:"varint,1,opt,name=grpc_code,json=grpcCode" json:"grpc_code,omitempty"`
HttpCode int32 `protobuf:"varint,2,opt,name=http_code,json=httpCode" json:"http_code,omitempty"`
Message string `protobuf:"bytes,3,opt,name=message" json:"message,omitempty"`
HttpStatus string `protobuf:"bytes,4,opt,name=http_status,json=httpStatus" json:"http_status,omitempty"`
}
func (m *StreamError) Reset() { *m = StreamError{} }
func (m *StreamError) String() string { return proto.CompactTextString(m) }
func (*StreamError) ProtoMessage() {}
func (*StreamError) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func init() {
proto.RegisterType((*StreamError)(nil), "grpc.gateway.runtime.StreamError")
}
func init() { proto.RegisterFile("runtime/internal/stream_chunk.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 180 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0x52, 0x2e, 0x2a, 0xcd, 0x2b,
0xc9, 0xcc, 0x4d, 0xd5, 0xcf, 0xcc, 0x2b, 0x49, 0x2d, 0xca, 0x4b, 0xcc, 0xd1, 0x2f, 0x2e, 0x29,
0x4a, 0x4d, 0xcc, 0x8d, 0x4f, 0xce, 0x28, 0xcd, 0xcb, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17,
0x12, 0x49, 0x2f, 0x2a, 0x48, 0xd6, 0x4b, 0x4f, 0x2c, 0x49, 0x2d, 0x4f, 0xac, 0xd4, 0x83, 0xea,
0x50, 0x6a, 0x62, 0xe4, 0xe2, 0x0e, 0x06, 0x2b, 0x76, 0x2d, 0x2a, 0xca, 0x2f, 0x12, 0x92, 0xe6,
0xe2, 0x04, 0xa9, 0x8b, 0x4f, 0xce, 0x4f, 0x49, 0x95, 0x60, 0x54, 0x60, 0xd4, 0x60, 0x0d, 0xe2,
0x00, 0x09, 0x38, 0x03, 0xf9, 0x20, 0xc9, 0x8c, 0x92, 0x92, 0x02, 0x88, 0x24, 0x13, 0x44, 0x12,
0x24, 0x00, 0x96, 0x94, 0xe0, 0x62, 0xcf, 0x4d, 0x2d, 0x2e, 0x4e, 0x4c, 0x4f, 0x95, 0x60, 0x06,
0x4a, 0x71, 0x06, 0xc1, 0xb8, 0x42, 0xf2, 0x5c, 0xdc, 0x60, 0x6d, 0xc5, 0x25, 0x89, 0x25, 0xa5,
0xc5, 0x12, 0x2c, 0x60, 0x59, 0x2e, 0x90, 0x50, 0x30, 0x58, 0xc4, 0x89, 0x2b, 0x8a, 0x03, 0xe6,
0xf2, 0x24, 0x36, 0xb0, 0x6b, 0x8d, 0x01, 0x01, 0x00, 0x00, 0xff, 0xff, 0xa9, 0x07, 0x92, 0xb6,
0xd4, 0x00, 0x00, 0x00,
}

View file

@ -0,0 +1,12 @@
syntax = "proto3";
package grpc.gateway.runtime;
option go_package = "internal";
// StreamError is a response type which is returned when
// streaming rpc returns an error.
message StreamError {
int32 grpc_code = 1;
int32 http_code = 2;
string message = 3;
string http_status = 4;
}

View file

@ -0,0 +1,37 @@
package runtime
import (
"encoding/json"
"io"
)
// JSONBuiltin is a Marshaler which marshals/unmarshals into/from JSON
// with the standard "encoding/json" package of Golang.
// Although it is generally faster for simple proto messages than JSONPb,
// it does not support advanced features of protobuf, e.g. map, oneof, ....
type JSONBuiltin struct{}
// ContentType always Returns "application/json".
func (*JSONBuiltin) ContentType() string {
return "application/json"
}
// Marshal marshals "v" into JSON
func (j *JSONBuiltin) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
// Unmarshal unmarshals JSON data into "v".
func (j *JSONBuiltin) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}
// NewDecoder returns a Decoder which reads JSON stream from "r".
func (j *JSONBuiltin) NewDecoder(r io.Reader) Decoder {
return json.NewDecoder(r)
}
// NewEncoder returns an Encoder which writes JSON stream into "w".
func (j *JSONBuiltin) NewEncoder(w io.Writer) Encoder {
return json.NewEncoder(w)
}

View file

@ -0,0 +1,245 @@
package runtime_test
import (
"bytes"
"encoding/json"
"reflect"
"strings"
"testing"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes/empty"
structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/golang/protobuf/ptypes/wrappers"
"github.com/grpc-ecosystem/grpc-gateway/examples/examplepb"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
)
func TestJSONBuiltinMarshal(t *testing.T) {
var m runtime.JSONBuiltin
msg := examplepb.SimpleMessage{
Id: "foo",
}
buf, err := m.Marshal(&msg)
if err != nil {
t.Errorf("m.Marshal(%v) failed with %v; want success", &msg, err)
}
var got examplepb.SimpleMessage
if err := json.Unmarshal(buf, &got); err != nil {
t.Errorf("json.Unmarshal(%q, &got) failed with %v; want success", buf, err)
}
if want := msg; !reflect.DeepEqual(got, want) {
t.Errorf("got = %v; want %v", &got, &want)
}
}
func TestJSONBuiltinMarshalField(t *testing.T) {
var m runtime.JSONBuiltin
for _, fixt := range builtinFieldFixtures {
buf, err := m.Marshal(fixt.data)
if err != nil {
t.Errorf("m.Marshal(%v) failed with %v; want success", fixt.data, err)
}
if got, want := string(buf), fixt.json; got != want {
t.Errorf("got = %q; want %q; data = %#v", got, want, fixt.data)
}
}
}
func TestJSONBuiltinMarshalFieldKnownErrors(t *testing.T) {
var m runtime.JSONBuiltin
for _, fixt := range builtinKnownErrors {
buf, err := m.Marshal(fixt.data)
if err != nil {
t.Errorf("m.Marshal(%v) failed with %v; want success", fixt.data, err)
}
if got, want := string(buf), fixt.json; got == want {
t.Errorf("surprisingly got = %q; as want %q; data = %#v", got, want, fixt.data)
}
}
}
func TestJSONBuiltinsnmarshal(t *testing.T) {
var (
m runtime.JSONBuiltin
got examplepb.SimpleMessage
data = []byte(`{"id": "foo"}`)
)
if err := m.Unmarshal(data, &got); err != nil {
t.Errorf("m.Unmarshal(%q, &got) failed with %v; want success", data, err)
}
want := examplepb.SimpleMessage{
Id: "foo",
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got = %v; want = %v", &got, &want)
}
}
func TestJSONBuiltinUnmarshalField(t *testing.T) {
var m runtime.JSONBuiltin
for _, fixt := range builtinFieldFixtures {
dest := reflect.New(reflect.TypeOf(fixt.data))
if err := m.Unmarshal([]byte(fixt.json), dest.Interface()); err != nil {
t.Errorf("m.Unmarshal(%q, dest) failed with %v; want success", fixt.json, err)
}
if got, want := dest.Elem().Interface(), fixt.data; !reflect.DeepEqual(got, want) {
t.Errorf("got = %#v; want = %#v; input = %q", got, want, fixt.json)
}
}
}
func TestJSONBuiltinUnmarshalFieldKnownErrors(t *testing.T) {
var m runtime.JSONBuiltin
for _, fixt := range builtinKnownErrors {
dest := reflect.New(reflect.TypeOf(fixt.data))
if err := m.Unmarshal([]byte(fixt.json), dest.Interface()); err == nil {
t.Errorf("m.Unmarshal(%q, dest) succeeded; want ane error", fixt.json)
}
}
}
func TestJSONBuiltinEncoder(t *testing.T) {
var m runtime.JSONBuiltin
msg := examplepb.SimpleMessage{
Id: "foo",
}
var buf bytes.Buffer
enc := m.NewEncoder(&buf)
if err := enc.Encode(&msg); err != nil {
t.Errorf("enc.Encode(%v) failed with %v; want success", &msg, err)
}
var got examplepb.SimpleMessage
if err := json.Unmarshal(buf.Bytes(), &got); err != nil {
t.Errorf("json.Unmarshal(%q, &got) failed with %v; want success", buf.String(), err)
}
if want := msg; !reflect.DeepEqual(got, want) {
t.Errorf("got = %v; want %v", &got, &want)
}
}
func TestJSONBuiltinEncoderFields(t *testing.T) {
var m runtime.JSONBuiltin
for _, fixt := range builtinFieldFixtures {
var buf bytes.Buffer
enc := m.NewEncoder(&buf)
if err := enc.Encode(fixt.data); err != nil {
t.Errorf("enc.Encode(%#v) failed with %v; want success", fixt.data, err)
}
if got, want := buf.String(), fixt.json+"\n"; got != want {
t.Errorf("got = %q; want %q; data = %#v", got, want, fixt.data)
}
}
}
func TestJSONBuiltinDecoder(t *testing.T) {
var (
m runtime.JSONBuiltin
got examplepb.SimpleMessage
data = `{"id": "foo"}`
)
r := strings.NewReader(data)
dec := m.NewDecoder(r)
if err := dec.Decode(&got); err != nil {
t.Errorf("m.Unmarshal(&got) failed with %v; want success", err)
}
want := examplepb.SimpleMessage{
Id: "foo",
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got = %v; want = %v", &got, &want)
}
}
func TestJSONBuiltinDecoderFields(t *testing.T) {
var m runtime.JSONBuiltin
for _, fixt := range builtinFieldFixtures {
r := strings.NewReader(fixt.json)
dec := m.NewDecoder(r)
dest := reflect.New(reflect.TypeOf(fixt.data))
if err := dec.Decode(dest.Interface()); err != nil {
t.Errorf("dec.Decode(dest) failed with %v; want success; data = %q", err, fixt.json)
}
if got, want := dest.Elem().Interface(), fixt.data; !reflect.DeepEqual(got, want) {
t.Errorf("got = %v; want = %v; input = %q", got, want, fixt.json)
}
}
}
var (
builtinFieldFixtures = []struct {
data interface{}
json string
}{
{data: "", json: `""`},
{data: proto.String(""), json: `""`},
{data: "foo", json: `"foo"`},
{data: proto.String("foo"), json: `"foo"`},
{data: int32(-1), json: "-1"},
{data: proto.Int32(-1), json: "-1"},
{data: int64(-1), json: "-1"},
{data: proto.Int64(-1), json: "-1"},
{data: uint32(123), json: "123"},
{data: proto.Uint32(123), json: "123"},
{data: uint64(123), json: "123"},
{data: proto.Uint64(123), json: "123"},
{data: float32(-1.5), json: "-1.5"},
{data: proto.Float32(-1.5), json: "-1.5"},
{data: float64(-1.5), json: "-1.5"},
{data: proto.Float64(-1.5), json: "-1.5"},
{data: true, json: "true"},
{data: proto.Bool(true), json: "true"},
{data: (*string)(nil), json: "null"},
{data: new(empty.Empty), json: "{}"},
{data: examplepb.NumericEnum_ONE, json: "1"},
{
data: (*examplepb.NumericEnum)(proto.Int32(int32(examplepb.NumericEnum_ONE))),
json: "1",
},
}
builtinKnownErrors = []struct {
data interface{}
json string
}{
{data: examplepb.NumericEnum_ONE, json: "ONE"},
{
data: (*examplepb.NumericEnum)(proto.Int32(int32(examplepb.NumericEnum_ONE))),
json: "ONE",
},
{
data: &examplepb.ABitOfEverything_OneofString{OneofString: "abc"},
json: `"abc"`,
},
{
data: &timestamp.Timestamp{
Seconds: 1462875553,
Nanos: 123000000,
},
json: `"2016-05-10T10:19:13.123Z"`,
},
{
data: &wrappers.Int32Value{Value: 123},
json: "123",
},
{
data: &structpb.Value{
Kind: &structpb.Value_StringValue{
StringValue: "abc",
},
},
json: `"abc"`,
},
}
)

View file

@ -0,0 +1,182 @@
package runtime
import (
"bytes"
"encoding/json"
"fmt"
"io"
"reflect"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
)
// JSONPb is a Marshaler which marshals/unmarshals into/from JSON
// with the "github.com/golang/protobuf/jsonpb".
// It supports fully functionality of protobuf unlike JSONBuiltin.
type JSONPb jsonpb.Marshaler
// ContentType always returns "application/json".
func (*JSONPb) ContentType() string {
return "application/json"
}
// Marshal marshals "v" into JSON
// Currently it can marshal only proto.Message.
// TODO(yugui) Support fields of primitive types in a message.
func (j *JSONPb) Marshal(v interface{}) ([]byte, error) {
if _, ok := v.(proto.Message); !ok {
return j.marshalNonProtoField(v)
}
var buf bytes.Buffer
if err := j.marshalTo(&buf, v); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func (j *JSONPb) marshalTo(w io.Writer, v interface{}) error {
p, ok := v.(proto.Message)
if !ok {
buf, err := j.marshalNonProtoField(v)
if err != nil {
return err
}
_, err = w.Write(buf)
return err
}
return (*jsonpb.Marshaler)(j).Marshal(w, p)
}
// marshalNonProto marshals a non-message field of a protobuf message.
// This function does not correctly marshals arbitary data structure into JSON,
// but it is only capable of marshaling non-message field values of protobuf,
// i.e. primitive types, enums; pointers to primitives or enums; maps from
// integer/string types to primitives/enums/pointers to messages.
func (j *JSONPb) marshalNonProtoField(v interface{}) ([]byte, error) {
rv := reflect.ValueOf(v)
for rv.Kind() == reflect.Ptr {
if rv.IsNil() {
return []byte("null"), nil
}
rv = rv.Elem()
}
if rv.Kind() == reflect.Map {
m := make(map[string]*json.RawMessage)
for _, k := range rv.MapKeys() {
buf, err := j.Marshal(rv.MapIndex(k).Interface())
if err != nil {
return nil, err
}
m[fmt.Sprintf("%v", k.Interface())] = (*json.RawMessage)(&buf)
}
if j.Indent != "" {
return json.MarshalIndent(m, "", j.Indent)
}
return json.Marshal(m)
}
if enum, ok := rv.Interface().(protoEnum); ok && !j.EnumsAsInts {
return json.Marshal(enum.String())
}
return json.Marshal(rv.Interface())
}
// Unmarshal unmarshals JSON "data" into "v"
// Currently it can marshal only proto.Message.
// TODO(yugui) Support fields of primitive types in a message.
func (j *JSONPb) Unmarshal(data []byte, v interface{}) error {
return unmarshalJSONPb(data, v)
}
// NewDecoder returns a Decoder which reads JSON stream from "r".
func (j *JSONPb) NewDecoder(r io.Reader) Decoder {
d := json.NewDecoder(r)
return DecoderFunc(func(v interface{}) error { return decodeJSONPb(d, v) })
}
// NewEncoder returns an Encoder which writes JSON stream into "w".
func (j *JSONPb) NewEncoder(w io.Writer) Encoder {
return EncoderFunc(func(v interface{}) error { return j.marshalTo(w, v) })
}
func unmarshalJSONPb(data []byte, v interface{}) error {
d := json.NewDecoder(bytes.NewReader(data))
return decodeJSONPb(d, v)
}
func decodeJSONPb(d *json.Decoder, v interface{}) error {
p, ok := v.(proto.Message)
if !ok {
return decodeNonProtoField(d, v)
}
return jsonpb.UnmarshalNext(d, p)
}
func decodeNonProtoField(d *json.Decoder, v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return fmt.Errorf("%T is not a pointer", v)
}
for rv.Kind() == reflect.Ptr {
if rv.IsNil() {
rv.Set(reflect.New(rv.Type().Elem()))
}
if rv.Type().ConvertibleTo(typeProtoMessage) {
return jsonpb.UnmarshalNext(d, rv.Interface().(proto.Message))
}
rv = rv.Elem()
}
if rv.Kind() == reflect.Map {
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
conv, ok := convFromType[rv.Type().Key().Kind()]
if !ok {
return fmt.Errorf("unsupported type of map field key: %v", rv.Type().Key())
}
m := make(map[string]*json.RawMessage)
if err := d.Decode(&m); err != nil {
return err
}
for k, v := range m {
result := conv.Call([]reflect.Value{reflect.ValueOf(k)})
if err := result[1].Interface(); err != nil {
return err.(error)
}
bk := result[0]
bv := reflect.New(rv.Type().Elem())
if err := unmarshalJSONPb([]byte(*v), bv.Interface()); err != nil {
return err
}
rv.SetMapIndex(bk, bv.Elem())
}
return nil
}
if _, ok := rv.Interface().(protoEnum); ok {
var repr interface{}
if err := d.Decode(&repr); err != nil {
return err
}
switch repr.(type) {
case string:
// TODO(yugui) Should use proto.StructProperties?
return fmt.Errorf("unmarshaling of symbolic enum %q not supported: %T", repr, rv.Interface())
case float64:
rv.Set(reflect.ValueOf(int32(repr.(float64))).Convert(rv.Type()))
return nil
default:
return fmt.Errorf("cannot assign %#v into Go type %T", repr, rv.Interface())
}
}
return d.Decode(v)
}
type protoEnum interface {
fmt.Stringer
EnumDescriptor() ([]byte, []int)
}
var typeProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem()

View file

@ -0,0 +1,606 @@
package runtime_test
import (
"bytes"
"reflect"
"strings"
"testing"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes/duration"
"github.com/golang/protobuf/ptypes/empty"
structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/golang/protobuf/ptypes/wrappers"
"github.com/grpc-ecosystem/grpc-gateway/examples/examplepb"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
)
func TestJSONPbMarshal(t *testing.T) {
msg := examplepb.ABitOfEverything{
Uuid: "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",
Nested: []*examplepb.ABitOfEverything_Nested{
{
Name: "foo",
Amount: 12345,
},
},
Uint64Value: 0xFFFFFFFFFFFFFFFF,
EnumValue: examplepb.NumericEnum_ONE,
OneofValue: &examplepb.ABitOfEverything_OneofString{
OneofString: "bar",
},
MapValue: map[string]examplepb.NumericEnum{
"a": examplepb.NumericEnum_ONE,
"b": examplepb.NumericEnum_ZERO,
},
}
for _, spec := range []struct {
enumsAsInts, emitDefaults bool
indent string
origName bool
verifier func(json string)
}{
{
verifier: func(json string) {
if strings.ContainsAny(json, " \t\r\n") {
t.Errorf("strings.ContainsAny(%q, %q) = true; want false", json, " \t\r\n")
}
if !strings.Contains(json, "ONE") {
t.Errorf(`strings.Contains(%q, "ONE") = false; want true`, json)
}
if want := "uint64Value"; !strings.Contains(json, want) {
t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want)
}
},
},
{
enumsAsInts: true,
verifier: func(json string) {
if strings.Contains(json, "ONE") {
t.Errorf(`strings.Contains(%q, "ONE") = true; want false`, json)
}
},
},
{
emitDefaults: true,
verifier: func(json string) {
if want := `"sfixed32Value"`; !strings.Contains(json, want) {
t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want)
}
},
},
{
indent: "\t\t",
verifier: func(json string) {
if want := "\t\t\"amount\":"; !strings.Contains(json, want) {
t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want)
}
},
},
{
origName: true,
verifier: func(json string) {
if want := "uint64_value"; !strings.Contains(json, want) {
t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want)
}
},
},
} {
m := runtime.JSONPb{
EnumsAsInts: spec.enumsAsInts,
EmitDefaults: spec.emitDefaults,
Indent: spec.indent,
OrigName: spec.origName,
}
buf, err := m.Marshal(&msg)
if err != nil {
t.Errorf("m.Marshal(%v) failed with %v; want success; spec=%v", &msg, err, spec)
}
var got examplepb.ABitOfEverything
if err := jsonpb.UnmarshalString(string(buf), &got); err != nil {
t.Errorf("jsonpb.UnmarshalString(%q, &got) failed with %v; want success; spec=%v", string(buf), err, spec)
}
if want := msg; !reflect.DeepEqual(got, want) {
t.Errorf("got = %v; want %v; spec=%v", &got, &want, spec)
}
if spec.verifier != nil {
spec.verifier(string(buf))
}
}
}
func TestJSONPbMarshalFields(t *testing.T) {
var m runtime.JSONPb
for _, spec := range []struct {
val interface{}
want string
}{} {
buf, err := m.Marshal(spec.val)
if err != nil {
t.Errorf("m.Marshal(%#v) failed with %v; want success", spec.val, err)
}
if got, want := string(buf), spec.want; got != want {
t.Errorf("m.Marshal(%#v) = %q; want %q", spec.val, got, want)
}
}
m.EnumsAsInts = true
buf, err := m.Marshal(examplepb.NumericEnum_ONE)
if err != nil {
t.Errorf("m.Marshal(%#v) failed with %v; want success", examplepb.NumericEnum_ONE, err)
}
if got, want := string(buf), "1"; got != want {
t.Errorf("m.Marshal(%#v) = %q; want %q", examplepb.NumericEnum_ONE, got, want)
}
}
func TestJSONPbUnmarshal(t *testing.T) {
var (
m runtime.JSONPb
got examplepb.ABitOfEverything
)
for _, data := range []string{
`{
"uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",
"nested": [
{"name": "foo", "amount": 12345}
],
"uint64Value": 18446744073709551615,
"enumValue": "ONE",
"oneofString": "bar",
"mapValue": {
"a": 1,
"b": 0
}
}`,
`{
"uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",
"nested": [
{"name": "foo", "amount": 12345}
],
"uint64Value": "18446744073709551615",
"enumValue": "ONE",
"oneofString": "bar",
"mapValue": {
"a": 1,
"b": 0
}
}`,
`{
"uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",
"nested": [
{"name": "foo", "amount": 12345}
],
"uint64Value": 18446744073709551615,
"enumValue": 1,
"oneofString": "bar",
"mapValue": {
"a": 1,
"b": 0
}
}`,
} {
if err := m.Unmarshal([]byte(data), &got); err != nil {
t.Errorf("m.Unmarshal(%q, &got) failed with %v; want success", data, err)
}
want := examplepb.ABitOfEverything{
Uuid: "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",
Nested: []*examplepb.ABitOfEverything_Nested{
{
Name: "foo",
Amount: 12345,
},
},
Uint64Value: 0xFFFFFFFFFFFFFFFF,
EnumValue: examplepb.NumericEnum_ONE,
OneofValue: &examplepb.ABitOfEverything_OneofString{
OneofString: "bar",
},
MapValue: map[string]examplepb.NumericEnum{
"a": examplepb.NumericEnum_ONE,
"b": examplepb.NumericEnum_ZERO,
},
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got = %v; want = %v", &got, &want)
}
}
}
func TestJSONPbUnmarshalFields(t *testing.T) {
var m runtime.JSONPb
for _, fixt := range fieldFixtures {
if fixt.skipUnmarshal {
continue
}
dest := reflect.New(reflect.TypeOf(fixt.data))
if err := m.Unmarshal([]byte(fixt.json), dest.Interface()); err != nil {
t.Errorf("m.Unmarshal(%q, %T) failed with %v; want success", fixt.json, dest.Interface(), err)
}
if got, want := dest.Elem().Interface(), fixt.data; !reflect.DeepEqual(got, want) {
t.Errorf("dest = %#v; want %#v; input = %v", got, want, fixt.json)
}
}
}
func TestJSONPbEncoder(t *testing.T) {
msg := examplepb.ABitOfEverything{
Uuid: "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",
Nested: []*examplepb.ABitOfEverything_Nested{
{
Name: "foo",
Amount: 12345,
},
},
Uint64Value: 0xFFFFFFFFFFFFFFFF,
OneofValue: &examplepb.ABitOfEverything_OneofString{
OneofString: "bar",
},
MapValue: map[string]examplepb.NumericEnum{
"a": examplepb.NumericEnum_ONE,
"b": examplepb.NumericEnum_ZERO,
},
}
for _, spec := range []struct {
enumsAsInts, emitDefaults bool
indent string
origName bool
verifier func(json string)
}{
{
verifier: func(json string) {
if strings.ContainsAny(json, " \t\r\n") {
t.Errorf("strings.ContainsAny(%q, %q) = true; want false", json, " \t\r\n")
}
if strings.Contains(json, "ONE") {
t.Errorf(`strings.Contains(%q, "ONE") = true; want false`, json)
}
if want := "uint64Value"; !strings.Contains(json, want) {
t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want)
}
},
},
{
enumsAsInts: true,
verifier: func(json string) {
if strings.Contains(json, "ONE") {
t.Errorf(`strings.Contains(%q, "ONE") = true; want false`, json)
}
},
},
{
emitDefaults: true,
verifier: func(json string) {
if want := `"sfixed32Value"`; !strings.Contains(json, want) {
t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want)
}
},
},
{
indent: "\t\t",
verifier: func(json string) {
if want := "\t\t\"amount\":"; !strings.Contains(json, want) {
t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want)
}
},
},
{
origName: true,
verifier: func(json string) {
if want := "uint64_value"; !strings.Contains(json, want) {
t.Errorf(`strings.Contains(%q, %q) = false; want true`, json, want)
}
},
},
} {
m := runtime.JSONPb{
EnumsAsInts: spec.enumsAsInts,
EmitDefaults: spec.emitDefaults,
Indent: spec.indent,
OrigName: spec.origName,
}
var buf bytes.Buffer
enc := m.NewEncoder(&buf)
if err := enc.Encode(&msg); err != nil {
t.Errorf("enc.Encode(%v) failed with %v; want success; spec=%v", &msg, err, spec)
}
var got examplepb.ABitOfEverything
if err := jsonpb.UnmarshalString(buf.String(), &got); err != nil {
t.Errorf("jsonpb.UnmarshalString(%q, &got) failed with %v; want success; spec=%v", buf.String(), err, spec)
}
if want := msg; !reflect.DeepEqual(got, want) {
t.Errorf("got = %v; want %v; spec=%v", &got, &want, spec)
}
if spec.verifier != nil {
spec.verifier(buf.String())
}
}
}
func TestJSONPbEncoderFields(t *testing.T) {
var m runtime.JSONPb
for _, fixt := range fieldFixtures {
var buf bytes.Buffer
enc := m.NewEncoder(&buf)
if err := enc.Encode(fixt.data); err != nil {
t.Errorf("enc.Encode(%#v) failed with %v; want success", fixt.data, err)
}
if got, want := buf.String(), fixt.json; got != want {
t.Errorf("enc.Encode(%#v) = %q; want %q", fixt.data, got, want)
}
}
m.EnumsAsInts = true
buf, err := m.Marshal(examplepb.NumericEnum_ONE)
if err != nil {
t.Errorf("m.Marshal(%#v) failed with %v; want success", examplepb.NumericEnum_ONE, err)
}
if got, want := string(buf), "1"; got != want {
t.Errorf("m.Marshal(%#v) = %q; want %q", examplepb.NumericEnum_ONE, got, want)
}
}
func TestJSONPbDecoder(t *testing.T) {
var (
m runtime.JSONPb
got examplepb.ABitOfEverything
)
for _, data := range []string{
`{
"uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",
"nested": [
{"name": "foo", "amount": 12345}
],
"uint64Value": 18446744073709551615,
"enumValue": "ONE",
"oneofString": "bar",
"mapValue": {
"a": 1,
"b": 0
}
}`,
`{
"uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",
"nested": [
{"name": "foo", "amount": 12345}
],
"uint64Value": "18446744073709551615",
"enumValue": "ONE",
"oneofString": "bar",
"mapValue": {
"a": 1,
"b": 0
}
}`,
`{
"uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",
"nested": [
{"name": "foo", "amount": 12345}
],
"uint64Value": 18446744073709551615,
"enumValue": 1,
"oneofString": "bar",
"mapValue": {
"a": 1,
"b": 0
}
}`,
} {
r := strings.NewReader(data)
dec := m.NewDecoder(r)
if err := dec.Decode(&got); err != nil {
t.Errorf("m.Unmarshal(&got) failed with %v; want success; data=%q", err, data)
}
want := examplepb.ABitOfEverything{
Uuid: "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",
Nested: []*examplepb.ABitOfEverything_Nested{
{
Name: "foo",
Amount: 12345,
},
},
Uint64Value: 0xFFFFFFFFFFFFFFFF,
EnumValue: examplepb.NumericEnum_ONE,
OneofValue: &examplepb.ABitOfEverything_OneofString{
OneofString: "bar",
},
MapValue: map[string]examplepb.NumericEnum{
"a": examplepb.NumericEnum_ONE,
"b": examplepb.NumericEnum_ZERO,
},
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got = %v; want = %v; data = %v", &got, &want, data)
}
}
}
func TestJSONPbDecoderFields(t *testing.T) {
var m runtime.JSONPb
for _, fixt := range fieldFixtures {
if fixt.skipUnmarshal {
continue
}
dest := reflect.New(reflect.TypeOf(fixt.data))
dec := m.NewDecoder(strings.NewReader(fixt.json))
if err := dec.Decode(dest.Interface()); err != nil {
t.Errorf("dec.Decode(%T) failed with %v; want success; input = %q", dest.Interface(), err, fixt.json)
}
if got, want := dest.Elem().Interface(), fixt.data; !reflect.DeepEqual(got, want) {
t.Errorf("dest = %#v; want %#v; input = %v", got, want, fixt.json)
}
}
}
var (
fieldFixtures = []struct {
data interface{}
json string
skipUnmarshal bool
}{
{data: int32(1), json: "1"},
{data: proto.Int32(1), json: "1"},
{data: int64(1), json: "1"},
{data: proto.Int64(1), json: "1"},
{data: uint32(1), json: "1"},
{data: proto.Uint32(1), json: "1"},
{data: uint64(1), json: "1"},
{data: proto.Uint64(1), json: "1"},
{data: "abc", json: `"abc"`},
{data: proto.String("abc"), json: `"abc"`},
{data: float32(1.5), json: "1.5"},
{data: proto.Float32(1.5), json: "1.5"},
{data: float64(1.5), json: "1.5"},
{data: proto.Float64(1.5), json: "1.5"},
{data: true, json: "true"},
{data: false, json: "false"},
{data: (*string)(nil), json: "null"},
{
data: examplepb.NumericEnum_ONE,
json: `"ONE"`,
// TODO(yugui) support unmarshaling of symbolic enum
skipUnmarshal: true,
},
{
data: (*examplepb.NumericEnum)(proto.Int32(int32(examplepb.NumericEnum_ONE))),
json: `"ONE"`,
// TODO(yugui) support unmarshaling of symbolic enum
skipUnmarshal: true,
},
{
data: map[string]int32{
"foo": 1,
},
json: `{"foo":1}`,
},
{
data: map[string]*examplepb.SimpleMessage{
"foo": {Id: "bar"},
},
json: `{"foo":{"id":"bar"}}`,
},
{
data: map[int32]*examplepb.SimpleMessage{
1: {Id: "foo"},
},
json: `{"1":{"id":"foo"}}`,
},
{
data: map[bool]*examplepb.SimpleMessage{
true: {Id: "foo"},
},
json: `{"true":{"id":"foo"}}`,
},
{
data: &duration.Duration{
Seconds: 123,
Nanos: 456000000,
},
json: `"123.456s"`,
},
{
data: &timestamp.Timestamp{
Seconds: 1462875553,
Nanos: 123000000,
},
json: `"2016-05-10T10:19:13.123Z"`,
},
{
data: new(empty.Empty),
json: "{}",
},
// TODO(yugui) Enable unmarshaling of the following examples
// once jsonpb supports them.
{
data: &structpb.Value{
Kind: new(structpb.Value_NullValue),
},
json: "null",
skipUnmarshal: true,
},
{
data: &structpb.Value{
Kind: &structpb.Value_NumberValue{
NumberValue: 123.4,
},
},
json: "123.4",
skipUnmarshal: true,
},
{
data: &structpb.Value{
Kind: &structpb.Value_StringValue{
StringValue: "abc",
},
},
json: `"abc"`,
skipUnmarshal: true,
},
{
data: &structpb.Value{
Kind: &structpb.Value_BoolValue{
BoolValue: true,
},
},
json: "true",
skipUnmarshal: true,
},
{
data: &structpb.Struct{
Fields: map[string]*structpb.Value{
"foo_bar": {
Kind: &structpb.Value_BoolValue{
BoolValue: true,
},
},
},
},
json: `{"foo_bar":true}`,
skipUnmarshal: true,
},
{
data: &wrappers.BoolValue{Value: true},
json: "true",
},
{
data: &wrappers.DoubleValue{Value: 123.456},
json: "123.456",
},
{
data: &wrappers.FloatValue{Value: 123.456},
json: "123.456",
},
{
data: &wrappers.Int32Value{Value: -123},
json: "-123",
},
{
data: &wrappers.Int64Value{Value: -123},
json: `"-123"`,
},
{
data: &wrappers.UInt32Value{Value: 123},
json: "123",
},
{
data: &wrappers.UInt64Value{Value: 123},
json: `"123"`,
},
// TODO(yugui) Add other well-known types once jsonpb supports them
}
)

View file

@ -0,0 +1,42 @@
package runtime
import (
"io"
)
// Marshaler defines a conversion between byte sequence and gRPC payloads / fields.
type Marshaler interface {
// Marshal marshals "v" into byte sequence.
Marshal(v interface{}) ([]byte, error)
// Unmarshal unmarshals "data" into "v".
// "v" must be a pointer value.
Unmarshal(data []byte, v interface{}) error
// NewDecoder returns a Decoder which reads byte sequence from "r".
NewDecoder(r io.Reader) Decoder
// NewEncoder returns an Encoder which writes bytes sequence into "w".
NewEncoder(w io.Writer) Encoder
// ContentType returns the Content-Type which this marshaler is responsible for.
ContentType() string
}
// Decoder decodes a byte sequence
type Decoder interface {
Decode(v interface{}) error
}
// Encoder encodes gRPC payloads / fields into byte sequence.
type Encoder interface {
Encode(v interface{}) error
}
// DecoderFunc adapts an decoder function into Decoder.
type DecoderFunc func(v interface{}) error
// Decode delegates invocations to the underlying function itself.
func (f DecoderFunc) Decode(v interface{}) error { return f(v) }
// EncoderFunc adapts an encoder function into Encoder
type EncoderFunc func(v interface{}) error
// Encode delegates invocations to the underlying function itself.
func (f EncoderFunc) Encode(v interface{}) error { return f(v) }

View file

@ -0,0 +1,91 @@
package runtime
import (
"errors"
"net/http"
)
// MIMEWildcard is the fallback MIME type used for requests which do not match
// a registered MIME type.
const MIMEWildcard = "*"
var (
acceptHeader = http.CanonicalHeaderKey("Accept")
contentTypeHeader = http.CanonicalHeaderKey("Content-Type")
defaultMarshaler = &JSONPb{OrigName: true}
)
// MarshalerForRequest returns the inbound/outbound marshalers for this request.
// It checks the registry on the ServeMux for the MIME type set by the Content-Type header.
// If it isn't set (or the request Content-Type is empty), checks for "*".
// If there are multiple Content-Type headers set, choose the first one that it can
// exactly match in the registry.
// Otherwise, it follows the above logic for "*"/InboundMarshaler/OutboundMarshaler.
func MarshalerForRequest(mux *ServeMux, r *http.Request) (inbound Marshaler, outbound Marshaler) {
for _, acceptVal := range r.Header[acceptHeader] {
if m, ok := mux.marshalers.mimeMap[acceptVal]; ok {
outbound = m
break
}
}
for _, contentTypeVal := range r.Header[contentTypeHeader] {
if m, ok := mux.marshalers.mimeMap[contentTypeVal]; ok {
inbound = m
break
}
}
if inbound == nil {
inbound = mux.marshalers.mimeMap[MIMEWildcard]
}
if outbound == nil {
outbound = inbound
}
return inbound, outbound
}
// marshalerRegistry is a mapping from MIME types to Marshalers.
type marshalerRegistry struct {
mimeMap map[string]Marshaler
}
// add adds a marshaler for a case-sensitive MIME type string ("*" to match any
// MIME type).
func (m marshalerRegistry) add(mime string, marshaler Marshaler) error {
if len(mime) == 0 {
return errors.New("empty MIME type")
}
m.mimeMap[mime] = marshaler
return nil
}
// makeMarshalerMIMERegistry returns a new registry of marshalers.
// It allows for a mapping of case-sensitive Content-Type MIME type string to runtime.Marshaler interfaces.
//
// For example, you could allow the client to specify the use of the runtime.JSONPb marshaler
// with a "applicaton/jsonpb" Content-Type and the use of the runtime.JSONBuiltin marshaler
// with a "application/json" Content-Type.
// "*" can be used to match any Content-Type.
// This can be attached to a ServerMux with the marshaler option.
func makeMarshalerMIMERegistry() marshalerRegistry {
return marshalerRegistry{
mimeMap: map[string]Marshaler{
MIMEWildcard: defaultMarshaler,
},
}
}
// WithMarshalerOption returns a ServeMuxOption which associates inbound and outbound
// Marshalers to a MIME type in mux.
func WithMarshalerOption(mime string, marshaler Marshaler) ServeMuxOption {
return func(mux *ServeMux) {
if err := mux.marshalers.add(mime, marshaler); err != nil {
panic(err)
}
}
}

View file

@ -0,0 +1,107 @@
package runtime_test
import (
"errors"
"io"
"net/http"
"testing"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
)
func TestMarshalerForRequest(t *testing.T) {
r, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf(`http.NewRequest("GET", "http://example.com", nil) failed with %v; want success`, err)
}
r.Header.Set("Accept", "application/x-out")
r.Header.Set("Content-Type", "application/x-in")
mux := runtime.NewServeMux()
in, out := runtime.MarshalerForRequest(mux, r)
if _, ok := in.(*runtime.JSONPb); !ok {
t.Errorf("in = %#v; want a runtime.JSONPb", in)
}
if _, ok := out.(*runtime.JSONPb); !ok {
t.Errorf("out = %#v; want a runtime.JSONPb", in)
}
var marshalers [3]dummyMarshaler
specs := []struct {
opt runtime.ServeMuxOption
wantIn runtime.Marshaler
wantOut runtime.Marshaler
}{
{
opt: runtime.WithMarshalerOption(runtime.MIMEWildcard, &marshalers[0]),
wantIn: &marshalers[0],
wantOut: &marshalers[0],
},
{
opt: runtime.WithMarshalerOption("application/x-in", &marshalers[1]),
wantIn: &marshalers[1],
wantOut: &marshalers[0],
},
{
opt: runtime.WithMarshalerOption("application/x-out", &marshalers[2]),
wantIn: &marshalers[1],
wantOut: &marshalers[2],
},
}
for i, spec := range specs {
var opts []runtime.ServeMuxOption
for _, s := range specs[:i+1] {
opts = append(opts, s.opt)
}
mux = runtime.NewServeMux(opts...)
in, out = runtime.MarshalerForRequest(mux, r)
if got, want := in, spec.wantIn; got != want {
t.Errorf("in = %#v; want %#v", got, want)
}
if got, want := out, spec.wantOut; got != want {
t.Errorf("out = %#v; want %#v", got, want)
}
}
r.Header.Set("Content-Type", "application/x-another")
in, out = runtime.MarshalerForRequest(mux, r)
if got, want := in, &marshalers[1]; got != want {
t.Errorf("in = %#v; want %#v", got, want)
}
if got, want := out, &marshalers[0]; got != want {
t.Errorf("out = %#v; want %#v", got, want)
}
}
type dummyMarshaler struct{}
func (dummyMarshaler) ContentType() string { return "" }
func (dummyMarshaler) Marshal(interface{}) ([]byte, error) {
return nil, errors.New("not implemented")
}
func (dummyMarshaler) Unmarshal([]byte, interface{}) error {
return errors.New("not implemented")
}
func (dummyMarshaler) NewDecoder(r io.Reader) runtime.Decoder {
return dummyDecoder{}
}
func (dummyMarshaler) NewEncoder(w io.Writer) runtime.Encoder {
return dummyEncoder{}
}
type dummyDecoder struct{}
func (dummyDecoder) Decode(interface{}) error {
return errors.New("not implemented")
}
type dummyEncoder struct{}
func (dummyEncoder) Encode(interface{}) error {
return errors.New("not implemented")
}

View file

@ -0,0 +1,132 @@
package runtime
import (
"net/http"
"strings"
"golang.org/x/net/context"
"github.com/golang/protobuf/proto"
)
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
type ServeMux struct {
// handlers maps HTTP method to a list of handlers.
handlers map[string][]handler
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
marshalers marshalerRegistry
}
// ServeMuxOption is an option that can be given to a ServeMux on construction.
type ServeMuxOption func(*ServeMux)
// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
//
// forwardResponseOption is an option that will be called on the relevant context.Context,
// http.ResponseWriter, and proto.Message before every forwarded response.
//
// The message may be nil in the case where just a header is being sent.
func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
}
}
// NewServeMux returns a new ServeMux whose internal mapping is empty.
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
serveMux := &ServeMux{
handlers: make(map[string][]handler),
forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
marshalers: makeMarshalerMIMERegistry(),
}
for _, opt := range opts {
opt(serveMux)
}
return serveMux
}
// Handle associates "h" to the pair of HTTP method and path pattern.
func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
s.handlers[meth] = append(s.handlers[meth], handler{pat: pat, h: h})
}
// ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path.
func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
if !strings.HasPrefix(path, "/") {
OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
components := strings.Split(path[1:], "/")
l := len(components)
var verb string
if idx := strings.LastIndex(components[l-1], ":"); idx == 0 {
OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
} else if idx > 0 {
c := components[l-1]
components[l-1], verb = c[:idx], c[idx+1:]
}
if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && isPathLengthFallback(r) {
r.Method = strings.ToUpper(override)
if err := r.ParseForm(); err != nil {
OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
return
}
}
for _, h := range s.handlers[r.Method] {
pathParams, err := h.pat.Match(components, verb)
if err != nil {
continue
}
h.h(w, r, pathParams)
return
}
// lookup other methods to handle fallback from GET to POST and
// to determine if it is MethodNotAllowed or NotFound.
for m, handlers := range s.handlers {
if m == r.Method {
continue
}
for _, h := range handlers {
pathParams, err := h.pat.Match(components, verb)
if err != nil {
continue
}
// X-HTTP-Method-Override is optional. Always allow fallback to POST.
if isPathLengthFallback(r) {
if err := r.ParseForm(); err != nil {
OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
return
}
h.h(w, r, pathParams)
return
}
OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
}
OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
}
// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
return s.forwardResponseOptions
}
func isPathLengthFallback(r *http.Request) bool {
return r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
}
type handler struct {
pat Pattern
h HandlerFunc
}

View file

@ -0,0 +1,213 @@
package runtime_test
import (
"bytes"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/grpc-ecosystem/grpc-gateway/utilities"
)
func TestMuxServeHTTP(t *testing.T) {
type stubPattern struct {
method string
ops []int
pool []string
verb string
}
for _, spec := range []struct {
patterns []stubPattern
reqMethod string
reqPath string
headers map[string]string
respStatus int
respContent string
}{
{
patterns: nil,
reqMethod: "GET",
reqPath: "/",
respStatus: http.StatusNotFound,
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
},
reqMethod: "GET",
reqPath: "/foo",
respStatus: http.StatusOK,
respContent: "GET /foo",
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
},
reqMethod: "GET",
reqPath: "/bar",
respStatus: http.StatusNotFound,
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
{
method: "GET",
ops: []int{int(utilities.OpPush), 0},
},
},
reqMethod: "GET",
reqPath: "/foo",
respStatus: http.StatusOK,
respContent: "GET /foo",
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
{
method: "POST",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
},
reqMethod: "POST",
reqPath: "/foo",
respStatus: http.StatusOK,
respContent: "POST /foo",
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
},
reqMethod: "DELETE",
reqPath: "/foo",
respStatus: http.StatusMethodNotAllowed,
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
},
reqMethod: "POST",
reqPath: "/foo",
headers: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
},
respStatus: http.StatusOK,
respContent: "GET /foo",
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
{
method: "POST",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
},
reqMethod: "POST",
reqPath: "/foo",
headers: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
"X-HTTP-Method-Override": "GET",
},
respStatus: http.StatusOK,
respContent: "GET /foo",
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
},
reqMethod: "POST",
reqPath: "/foo",
headers: map[string]string{
"Content-Type": "application/json",
},
respStatus: http.StatusMethodNotAllowed,
},
{
patterns: []stubPattern{
{
method: "POST",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
verb: "bar",
},
},
reqMethod: "POST",
reqPath: "/foo:bar",
headers: map[string]string{
"Content-Type": "application/json",
},
respStatus: http.StatusOK,
respContent: "POST /foo:bar",
},
} {
mux := runtime.NewServeMux()
for _, p := range spec.patterns {
func(p stubPattern) {
pat, err := runtime.NewPattern(1, p.ops, p.pool, p.verb)
if err != nil {
t.Fatalf("runtime.NewPattern(1, %#v, %#v, %q) failed with %v; want success", p.ops, p.pool, p.verb, err)
}
mux.Handle(p.method, pat, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
fmt.Fprintf(w, "%s %s", p.method, pat.String())
})
}(p)
}
url := fmt.Sprintf("http://host.example%s", spec.reqPath)
r, err := http.NewRequest(spec.reqMethod, url, bytes.NewReader(nil))
if err != nil {
t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", spec.reqMethod, url, err)
}
for name, value := range spec.headers {
r.Header.Set(name, value)
}
w := httptest.NewRecorder()
mux.ServeHTTP(w, r)
if got, want := w.Code, spec.respStatus; got != want {
t.Errorf("w.Code = %d; want %d; patterns=%v; req=%v", got, want, spec.patterns, r)
}
if spec.respContent != "" {
if got, want := w.Body.String(), spec.respContent; got != want {
t.Errorf("w.Body = %q; want %q; patterns=%v; req=%v", got, want, spec.patterns, r)
}
}
}
}

View file

@ -0,0 +1,227 @@
package runtime
import (
"errors"
"fmt"
"strings"
"github.com/grpc-ecosystem/grpc-gateway/utilities"
"google.golang.org/grpc/grpclog"
)
var (
// ErrNotMatch indicates that the given HTTP request path does not match to the pattern.
ErrNotMatch = errors.New("not match to the path pattern")
// ErrInvalidPattern indicates that the given definition of Pattern is not valid.
ErrInvalidPattern = errors.New("invalid pattern")
)
type op struct {
code utilities.OpCode
operand int
}
// Pattern is a template pattern of http request paths defined in third_party/googleapis/google/api/http.proto.
type Pattern struct {
// ops is a list of operations
ops []op
// pool is a constant pool indexed by the operands or vars.
pool []string
// vars is a list of variables names to be bound by this pattern
vars []string
// stacksize is the max depth of the stack
stacksize int
// tailLen is the length of the fixed-size segments after a deep wildcard
tailLen int
// verb is the VERB part of the path pattern. It is empty if the pattern does not have VERB part.
verb string
}
// NewPattern returns a new Pattern from the given definition values.
// "ops" is a sequence of op codes. "pool" is a constant pool.
// "verb" is the verb part of the pattern. It is empty if the pattern does not have the part.
// "version" must be 1 for now.
// It returns an error if the given definition is invalid.
func NewPattern(version int, ops []int, pool []string, verb string) (Pattern, error) {
if version != 1 {
grpclog.Printf("unsupported version: %d", version)
return Pattern{}, ErrInvalidPattern
}
l := len(ops)
if l%2 != 0 {
grpclog.Printf("odd number of ops codes: %d", l)
return Pattern{}, ErrInvalidPattern
}
var (
typedOps []op
stack, maxstack int
tailLen int
pushMSeen bool
vars []string
)
for i := 0; i < l; i += 2 {
op := op{code: utilities.OpCode(ops[i]), operand: ops[i+1]}
switch op.code {
case utilities.OpNop:
continue
case utilities.OpPush:
if pushMSeen {
tailLen++
}
stack++
case utilities.OpPushM:
if pushMSeen {
grpclog.Printf("pushM appears twice")
return Pattern{}, ErrInvalidPattern
}
pushMSeen = true
stack++
case utilities.OpLitPush:
if op.operand < 0 || len(pool) <= op.operand {
grpclog.Printf("negative literal index: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
if pushMSeen {
tailLen++
}
stack++
case utilities.OpConcatN:
if op.operand <= 0 {
grpclog.Printf("negative concat size: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
stack -= op.operand
if stack < 0 {
grpclog.Print("stack underflow")
return Pattern{}, ErrInvalidPattern
}
stack++
case utilities.OpCapture:
if op.operand < 0 || len(pool) <= op.operand {
grpclog.Printf("variable name index out of bound: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
v := pool[op.operand]
op.operand = len(vars)
vars = append(vars, v)
stack--
if stack < 0 {
grpclog.Printf("stack underflow")
return Pattern{}, ErrInvalidPattern
}
default:
grpclog.Printf("invalid opcode: %d", op.code)
return Pattern{}, ErrInvalidPattern
}
if maxstack < stack {
maxstack = stack
}
typedOps = append(typedOps, op)
}
return Pattern{
ops: typedOps,
pool: pool,
vars: vars,
stacksize: maxstack,
tailLen: tailLen,
verb: verb,
}, nil
}
// MustPattern is a helper function which makes it easier to call NewPattern in variable initialization.
func MustPattern(p Pattern, err error) Pattern {
if err != nil {
grpclog.Fatalf("Pattern initialization failed: %v", err)
}
return p
}
// Match examines components if it matches to the Pattern.
// If it matches, the function returns a mapping from field paths to their captured values.
// If otherwise, the function returns an error.
func (p Pattern) Match(components []string, verb string) (map[string]string, error) {
if p.verb != verb {
return nil, ErrNotMatch
}
var pos int
stack := make([]string, 0, p.stacksize)
captured := make([]string, len(p.vars))
l := len(components)
for _, op := range p.ops {
switch op.code {
case utilities.OpNop:
continue
case utilities.OpPush, utilities.OpLitPush:
if pos >= l {
return nil, ErrNotMatch
}
c := components[pos]
if op.code == utilities.OpLitPush {
if lit := p.pool[op.operand]; c != lit {
return nil, ErrNotMatch
}
}
stack = append(stack, c)
pos++
case utilities.OpPushM:
end := len(components)
if end < pos+p.tailLen {
return nil, ErrNotMatch
}
end -= p.tailLen
stack = append(stack, strings.Join(components[pos:end], "/"))
pos = end
case utilities.OpConcatN:
n := op.operand
l := len(stack) - n
stack = append(stack[:l], strings.Join(stack[l:], "/"))
case utilities.OpCapture:
n := len(stack) - 1
captured[op.operand] = stack[n]
stack = stack[:n]
}
}
if pos < l {
return nil, ErrNotMatch
}
bindings := make(map[string]string)
for i, val := range captured {
bindings[p.vars[i]] = val
}
return bindings, nil
}
// Verb returns the verb part of the Pattern.
func (p Pattern) Verb() string { return p.verb }
func (p Pattern) String() string {
var stack []string
for _, op := range p.ops {
switch op.code {
case utilities.OpNop:
continue
case utilities.OpPush:
stack = append(stack, "*")
case utilities.OpLitPush:
stack = append(stack, p.pool[op.operand])
case utilities.OpPushM:
stack = append(stack, "**")
case utilities.OpConcatN:
n := op.operand
l := len(stack) - n
stack = append(stack[:l], strings.Join(stack[l:], "/"))
case utilities.OpCapture:
n := len(stack) - 1
stack[n] = fmt.Sprintf("{%s=%s}", p.vars[op.operand], stack[n])
}
}
segs := strings.Join(stack, "/")
if p.verb != "" {
return fmt.Sprintf("/%s:%s", segs, p.verb)
}
return "/" + segs
}

View file

@ -0,0 +1,590 @@
package runtime
import (
"fmt"
"reflect"
"strings"
"testing"
"github.com/grpc-ecosystem/grpc-gateway/utilities"
)
const (
validVersion = 1
anything = 0
)
func TestNewPattern(t *testing.T) {
for _, spec := range []struct {
ops []int
pool []string
verb string
stackSizeWant, tailLenWant int
}{
{},
{
ops: []int{int(utilities.OpNop), anything},
stackSizeWant: 0,
tailLenWant: 0,
},
{
ops: []int{int(utilities.OpPush), anything},
stackSizeWant: 1,
tailLenWant: 0,
},
{
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"abc"},
stackSizeWant: 1,
tailLenWant: 0,
},
{
ops: []int{int(utilities.OpPushM), anything},
stackSizeWant: 1,
tailLenWant: 0,
},
{
ops: []int{
int(utilities.OpPush), anything,
int(utilities.OpConcatN), 1,
},
stackSizeWant: 1,
tailLenWant: 0,
},
{
ops: []int{
int(utilities.OpPush), anything,
int(utilities.OpConcatN), 1,
int(utilities.OpCapture), 0,
},
pool: []string{"abc"},
stackSizeWant: 1,
tailLenWant: 0,
},
{
ops: []int{
int(utilities.OpPush), anything,
int(utilities.OpLitPush), 0,
int(utilities.OpLitPush), 1,
int(utilities.OpPushM), anything,
int(utilities.OpConcatN), 2,
int(utilities.OpCapture), 2,
},
pool: []string{"lit1", "lit2", "var1"},
stackSizeWant: 4,
tailLenWant: 0,
},
{
ops: []int{
int(utilities.OpPushM), anything,
int(utilities.OpConcatN), 1,
int(utilities.OpCapture), 2,
int(utilities.OpLitPush), 0,
int(utilities.OpLitPush), 1,
},
pool: []string{"lit1", "lit2", "var1"},
stackSizeWant: 2,
tailLenWant: 2,
},
{
ops: []int{
int(utilities.OpLitPush), 0,
int(utilities.OpLitPush), 1,
int(utilities.OpPushM), anything,
int(utilities.OpLitPush), 2,
int(utilities.OpConcatN), 3,
int(utilities.OpLitPush), 3,
int(utilities.OpCapture), 4,
},
pool: []string{"lit1", "lit2", "lit3", "lit4", "var1"},
stackSizeWant: 4,
tailLenWant: 2,
},
{
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"abc"},
verb: "LOCK",
stackSizeWant: 1,
tailLenWant: 0,
},
} {
pat, err := NewPattern(validVersion, spec.ops, spec.pool, spec.verb)
if err != nil {
t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want success", validVersion, spec.ops, spec.pool, spec.verb, err)
continue
}
if got, want := pat.stacksize, spec.stackSizeWant; got != want {
t.Errorf("pat.stacksize = %d; want %d", got, want)
}
if got, want := pat.tailLen, spec.tailLenWant; got != want {
t.Errorf("pat.stacksize = %d; want %d", got, want)
}
}
}
func TestNewPatternWithWrongOp(t *testing.T) {
for _, spec := range []struct {
ops []int
pool []string
verb string
}{
{
// op code out of bound
ops: []int{-1, anything},
},
{
// op code out of bound
ops: []int{int(utilities.OpEnd), 0},
},
{
// odd number of items
ops: []int{int(utilities.OpPush)},
},
{
// negative index
ops: []int{int(utilities.OpLitPush), -1},
pool: []string{"abc"},
},
{
// index out of bound
ops: []int{int(utilities.OpLitPush), 1},
pool: []string{"abc"},
},
{
// negative # of segments
ops: []int{int(utilities.OpConcatN), -1},
pool: []string{"abc"},
},
{
// negative index
ops: []int{int(utilities.OpCapture), -1},
pool: []string{"abc"},
},
{
// index out of bound
ops: []int{int(utilities.OpCapture), 1},
pool: []string{"abc"},
},
{
// pushM appears twice
ops: []int{
int(utilities.OpPushM), anything,
int(utilities.OpLitPush), 0,
int(utilities.OpPushM), anything,
},
pool: []string{"abc"},
},
} {
_, err := NewPattern(validVersion, spec.ops, spec.pool, spec.verb)
if err == nil {
t.Errorf("NewPattern(%d, %v, %q, %q) succeeded; want failure with %v", validVersion, spec.ops, spec.pool, spec.verb, ErrInvalidPattern)
continue
}
if err != ErrInvalidPattern {
t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want failure with %v", validVersion, spec.ops, spec.pool, spec.verb, err, ErrInvalidPattern)
continue
}
}
}
func TestNewPatternWithStackUnderflow(t *testing.T) {
for _, spec := range []struct {
ops []int
pool []string
verb string
}{
{
ops: []int{int(utilities.OpConcatN), 1},
},
{
ops: []int{int(utilities.OpCapture), 0},
pool: []string{"abc"},
},
} {
_, err := NewPattern(validVersion, spec.ops, spec.pool, spec.verb)
if err == nil {
t.Errorf("NewPattern(%d, %v, %q, %q) succeeded; want failure with %v", validVersion, spec.ops, spec.pool, spec.verb, ErrInvalidPattern)
continue
}
if err != ErrInvalidPattern {
t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want failure with %v", validVersion, spec.ops, spec.pool, spec.verb, err, ErrInvalidPattern)
continue
}
}
}
func TestMatch(t *testing.T) {
for _, spec := range []struct {
ops []int
pool []string
verb string
match []string
notMatch []string
}{
{
match: []string{""},
notMatch: []string{"example"},
},
{
ops: []int{int(utilities.OpNop), anything},
match: []string{""},
notMatch: []string{"example", "path/to/example"},
},
{
ops: []int{int(utilities.OpPush), anything},
match: []string{"abc", "def"},
notMatch: []string{"", "abc/def"},
},
{
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"v1"},
match: []string{"v1"},
notMatch: []string{"", "v2"},
},
{
ops: []int{int(utilities.OpPushM), anything},
match: []string{"", "abc", "abc/def", "abc/def/ghi"},
},
{
ops: []int{
int(utilities.OpPushM), anything,
int(utilities.OpLitPush), 0,
},
pool: []string{"tail"},
match: []string{"tail", "abc/tail", "abc/def/tail"},
notMatch: []string{
"", "abc", "abc/def",
"tail/extra", "abc/tail/extra", "abc/def/tail/extra",
},
},
{
ops: []int{
int(utilities.OpLitPush), 0,
int(utilities.OpLitPush), 1,
int(utilities.OpPush), anything,
int(utilities.OpConcatN), 1,
int(utilities.OpCapture), 2,
},
pool: []string{"v1", "bucket", "name"},
match: []string{"v1/bucket/my-bucket", "v1/bucket/our-bucket"},
notMatch: []string{
"",
"v1",
"v1/bucket",
"v2/bucket/my-bucket",
"v1/pubsub/my-topic",
},
},
{
ops: []int{
int(utilities.OpLitPush), 0,
int(utilities.OpLitPush), 1,
int(utilities.OpPushM), anything,
int(utilities.OpConcatN), 2,
int(utilities.OpCapture), 2,
},
pool: []string{"v1", "o", "name"},
match: []string{
"v1/o",
"v1/o/my-bucket",
"v1/o/our-bucket",
"v1/o/my-bucket/dir",
"v1/o/my-bucket/dir/dir2",
"v1/o/my-bucket/dir/dir2/obj",
},
notMatch: []string{
"",
"v1",
"v2/o/my-bucket",
"v1/b/my-bucket",
},
},
{
ops: []int{
int(utilities.OpLitPush), 0,
int(utilities.OpLitPush), 1,
int(utilities.OpPush), anything,
int(utilities.OpConcatN), 2,
int(utilities.OpCapture), 2,
int(utilities.OpLitPush), 3,
int(utilities.OpPush), anything,
int(utilities.OpConcatN), 1,
int(utilities.OpCapture), 4,
},
pool: []string{"v2", "b", "name", "o", "oname"},
match: []string{
"v2/b/my-bucket/o/obj",
"v2/b/our-bucket/o/obj",
"v2/b/my-bucket/o/dir",
},
notMatch: []string{
"",
"v2",
"v2/b",
"v2/b/my-bucket",
"v2/b/my-bucket/o",
},
},
{
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"v1"},
verb: "LOCK",
match: []string{"v1:LOCK"},
notMatch: []string{"v1", "LOCK"},
},
} {
pat, err := NewPattern(validVersion, spec.ops, spec.pool, spec.verb)
if err != nil {
t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want success", validVersion, spec.ops, spec.pool, spec.verb, err)
continue
}
for _, path := range spec.match {
_, err = pat.Match(segments(path))
if err != nil {
t.Errorf("pat.Match(%q) failed with %v; want success; pattern = (%v, %q)", path, err, spec.ops, spec.pool)
}
}
for _, path := range spec.notMatch {
_, err = pat.Match(segments(path))
if err == nil {
t.Errorf("pat.Match(%q) succeeded; want failure with %v; pattern = (%v, %q)", path, ErrNotMatch, spec.ops, spec.pool)
continue
}
if err != ErrNotMatch {
t.Errorf("pat.Match(%q) failed with %v; want failure with %v; pattern = (%v, %q)", spec.notMatch, err, ErrNotMatch, spec.ops, spec.pool)
}
}
}
}
func TestMatchWithBinding(t *testing.T) {
for _, spec := range []struct {
ops []int
pool []string
path string
verb string
want map[string]string
}{
{
want: make(map[string]string),
},
{
ops: []int{int(utilities.OpNop), anything},
want: make(map[string]string),
},
{
ops: []int{int(utilities.OpPush), anything},
path: "abc",
want: make(map[string]string),
},
{
ops: []int{int(utilities.OpPush), anything},
verb: "LOCK",
path: "abc:LOCK",
want: make(map[string]string),
},
{
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"endpoint"},
path: "endpoint",
want: make(map[string]string),
},
{
ops: []int{int(utilities.OpPushM), anything},
path: "abc/def/ghi",
want: make(map[string]string),
},
{
ops: []int{
int(utilities.OpLitPush), 0,
int(utilities.OpLitPush), 1,
int(utilities.OpPush), anything,
int(utilities.OpConcatN), 1,
int(utilities.OpCapture), 2,
},
pool: []string{"v1", "bucket", "name"},
path: "v1/bucket/my-bucket",
want: map[string]string{
"name": "my-bucket",
},
},
{
ops: []int{
int(utilities.OpLitPush), 0,
int(utilities.OpLitPush), 1,
int(utilities.OpPush), anything,
int(utilities.OpConcatN), 1,
int(utilities.OpCapture), 2,
},
pool: []string{"v1", "bucket", "name"},
verb: "LOCK",
path: "v1/bucket/my-bucket:LOCK",
want: map[string]string{
"name": "my-bucket",
},
},
{
ops: []int{
int(utilities.OpLitPush), 0,
int(utilities.OpLitPush), 1,
int(utilities.OpPushM), anything,
int(utilities.OpConcatN), 2,
int(utilities.OpCapture), 2,
},
pool: []string{"v1", "o", "name"},
path: "v1/o/my-bucket/dir/dir2/obj",
want: map[string]string{
"name": "o/my-bucket/dir/dir2/obj",
},
},
{
ops: []int{
int(utilities.OpLitPush), 0,
int(utilities.OpLitPush), 1,
int(utilities.OpPushM), anything,
int(utilities.OpLitPush), 2,
int(utilities.OpConcatN), 3,
int(utilities.OpCapture), 4,
int(utilities.OpLitPush), 3,
},
pool: []string{"v1", "o", ".ext", "tail", "name"},
path: "v1/o/my-bucket/dir/dir2/obj/.ext/tail",
want: map[string]string{
"name": "o/my-bucket/dir/dir2/obj/.ext",
},
},
{
ops: []int{
int(utilities.OpLitPush), 0,
int(utilities.OpLitPush), 1,
int(utilities.OpPush), anything,
int(utilities.OpConcatN), 2,
int(utilities.OpCapture), 2,
int(utilities.OpLitPush), 3,
int(utilities.OpPush), anything,
int(utilities.OpConcatN), 1,
int(utilities.OpCapture), 4,
},
pool: []string{"v2", "b", "name", "o", "oname"},
path: "v2/b/my-bucket/o/obj",
want: map[string]string{
"name": "b/my-bucket",
"oname": "obj",
},
},
} {
pat, err := NewPattern(validVersion, spec.ops, spec.pool, spec.verb)
if err != nil {
t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want success", validVersion, spec.ops, spec.pool, spec.verb, err)
continue
}
got, err := pat.Match(segments(spec.path))
if err != nil {
t.Errorf("pat.Match(%q) failed with %v; want success; pattern = (%v, %q)", spec.path, err, spec.ops, spec.pool)
}
if !reflect.DeepEqual(got, spec.want) {
t.Errorf("pat.Match(%q) = %q; want %q; pattern = (%v, %q)", spec.path, got, spec.want, spec.ops, spec.pool)
}
}
}
func segments(path string) (components []string, verb string) {
if path == "" {
return nil, ""
}
components = strings.Split(path, "/")
l := len(components)
c := components[l-1]
if idx := strings.LastIndex(c, ":"); idx >= 0 {
components[l-1], verb = c[:idx], c[idx+1:]
}
return components, verb
}
func TestPatternString(t *testing.T) {
for _, spec := range []struct {
ops []int
pool []string
want string
}{
{
want: "/",
},
{
ops: []int{int(utilities.OpNop), anything},
want: "/",
},
{
ops: []int{int(utilities.OpPush), anything},
want: "/*",
},
{
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"endpoint"},
want: "/endpoint",
},
{
ops: []int{int(utilities.OpPushM), anything},
want: "/**",
},
{
ops: []int{
int(utilities.OpPush), anything,
int(utilities.OpConcatN), 1,
},
want: "/*",
},
{
ops: []int{
int(utilities.OpPush), anything,
int(utilities.OpConcatN), 1,
int(utilities.OpCapture), 0,
},
pool: []string{"name"},
want: "/{name=*}",
},
{
ops: []int{
int(utilities.OpLitPush), 0,
int(utilities.OpLitPush), 1,
int(utilities.OpPush), anything,
int(utilities.OpConcatN), 2,
int(utilities.OpCapture), 2,
int(utilities.OpLitPush), 3,
int(utilities.OpPushM), anything,
int(utilities.OpLitPush), 4,
int(utilities.OpConcatN), 3,
int(utilities.OpCapture), 6,
int(utilities.OpLitPush), 5,
},
pool: []string{"v1", "buckets", "bucket_name", "objects", ".ext", "tail", "name"},
want: "/v1/{bucket_name=buckets/*}/{name=objects/**/.ext}/tail",
},
} {
p, err := NewPattern(validVersion, spec.ops, spec.pool, "")
if err != nil {
t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want success", validVersion, spec.ops, spec.pool, "", err)
continue
}
if got, want := p.String(), spec.want; got != want {
t.Errorf("%#v.String() = %q; want %q", p, got, want)
}
verb := "LOCK"
p, err = NewPattern(validVersion, spec.ops, spec.pool, verb)
if err != nil {
t.Errorf("NewPattern(%d, %v, %q, %q) failed with %v; want success", validVersion, spec.ops, spec.pool, verb, err)
continue
}
if got, want := p.String(), fmt.Sprintf("%s:%s", spec.want, verb); got != want {
t.Errorf("%#v.String() = %q; want %q", p, got, want)
}
}
}

View file

@ -0,0 +1,80 @@
package runtime
import (
"github.com/golang/protobuf/proto"
)
// StringP returns a pointer to a string whose pointee is same as the given string value.
func StringP(val string) (*string, error) {
return proto.String(val), nil
}
// BoolP parses the given string representation of a boolean value,
// and returns a pointer to a bool whose value is same as the parsed value.
func BoolP(val string) (*bool, error) {
b, err := Bool(val)
if err != nil {
return nil, err
}
return proto.Bool(b), nil
}
// Float64P parses the given string representation of a floating point number,
// and returns a pointer to a float64 whose value is same as the parsed number.
func Float64P(val string) (*float64, error) {
f, err := Float64(val)
if err != nil {
return nil, err
}
return proto.Float64(f), nil
}
// Float32P parses the given string representation of a floating point number,
// and returns a pointer to a float32 whose value is same as the parsed number.
func Float32P(val string) (*float32, error) {
f, err := Float32(val)
if err != nil {
return nil, err
}
return proto.Float32(f), nil
}
// Int64P parses the given string representation of an integer
// and returns a pointer to a int64 whose value is same as the parsed integer.
func Int64P(val string) (*int64, error) {
i, err := Int64(val)
if err != nil {
return nil, err
}
return proto.Int64(i), nil
}
// Int32P parses the given string representation of an integer
// and returns a pointer to a int32 whose value is same as the parsed integer.
func Int32P(val string) (*int32, error) {
i, err := Int32(val)
if err != nil {
return nil, err
}
return proto.Int32(i), err
}
// Uint64P parses the given string representation of an integer
// and returns a pointer to a uint64 whose value is same as the parsed integer.
func Uint64P(val string) (*uint64, error) {
i, err := Uint64(val)
if err != nil {
return nil, err
}
return proto.Uint64(i), err
}
// Uint32P parses the given string representation of an integer
// and returns a pointer to a uint32 whose value is same as the parsed integer.
func Uint32P(val string) (*uint32, error) {
i, err := Uint32(val)
if err != nil {
return nil, err
}
return proto.Uint32(i), err
}

View file

@ -0,0 +1,140 @@
package runtime
import (
"fmt"
"net/url"
"reflect"
"strings"
"github.com/golang/protobuf/proto"
"github.com/grpc-ecosystem/grpc-gateway/utilities"
"google.golang.org/grpc/grpclog"
)
// PopulateQueryParameters populates "values" into "msg".
// A value is ignored if its key starts with one of the elements in "filter".
func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
for key, values := range values {
fieldPath := strings.Split(key, ".")
if filter.HasCommonPrefix(fieldPath) {
continue
}
if err := populateFieldValueFromPath(msg, fieldPath, values); err != nil {
return err
}
}
return nil
}
// PopulateFieldFromPath sets a value in a nested Protobuf structure.
// It instantiates missing protobuf fields as it goes.
func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
fieldPath := strings.Split(fieldPathString, ".")
return populateFieldValueFromPath(msg, fieldPath, []string{value})
}
func populateFieldValueFromPath(msg proto.Message, fieldPath []string, values []string) error {
m := reflect.ValueOf(msg)
if m.Kind() != reflect.Ptr {
return fmt.Errorf("unexpected type %T: %v", msg, msg)
}
m = m.Elem()
for i, fieldName := range fieldPath {
isLast := i == len(fieldPath)-1
if !isLast && m.Kind() != reflect.Struct {
return fmt.Errorf("non-aggregate type in the mid of path: %s", strings.Join(fieldPath, "."))
}
f := fieldByProtoName(m, fieldName)
if !f.IsValid() {
grpclog.Printf("field not found in %T: %s", msg, strings.Join(fieldPath, "."))
return nil
}
switch f.Kind() {
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, reflect.String, reflect.Uint32, reflect.Uint64:
m = f
case reflect.Slice:
// TODO(yugui) Support []byte
if !isLast {
return fmt.Errorf("unexpected repeated field in %s", strings.Join(fieldPath, "."))
}
return populateRepeatedField(f, values)
case reflect.Ptr:
if f.IsNil() {
m = reflect.New(f.Type().Elem())
f.Set(m)
}
m = f.Elem()
continue
case reflect.Struct:
m = f
continue
default:
return fmt.Errorf("unexpected type %s in %T", f.Type(), msg)
}
}
switch len(values) {
case 0:
return fmt.Errorf("no value of field: %s", strings.Join(fieldPath, "."))
case 1:
default:
grpclog.Printf("too many field values: %s", strings.Join(fieldPath, "."))
}
return populateField(m, values[0])
}
// fieldByProtoName looks up a field whose corresponding protobuf field name is "name".
// "m" must be a struct value. It returns zero reflect.Value if no such field found.
func fieldByProtoName(m reflect.Value, name string) reflect.Value {
props := proto.GetProperties(m.Type())
for _, p := range props.Prop {
if p.OrigName == name {
return m.FieldByName(p.Name)
}
}
return reflect.Value{}
}
func populateRepeatedField(f reflect.Value, values []string) error {
elemType := f.Type().Elem()
conv, ok := convFromType[elemType.Kind()]
if !ok {
return fmt.Errorf("unsupported field type %s", elemType)
}
f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)))
for i, v := range values {
result := conv.Call([]reflect.Value{reflect.ValueOf(v)})
if err := result[1].Interface(); err != nil {
return err.(error)
}
f.Index(i).Set(result[0])
}
return nil
}
func populateField(f reflect.Value, value string) error {
conv, ok := convFromType[f.Kind()]
if !ok {
return fmt.Errorf("unsupported field type %T", f)
}
result := conv.Call([]reflect.Value{reflect.ValueOf(value)})
if err := result[1].Interface(); err != nil {
return err.(error)
}
f.Set(result[0])
return nil
}
var (
convFromType = map[reflect.Kind]reflect.Value{
reflect.String: reflect.ValueOf(String),
reflect.Bool: reflect.ValueOf(Bool),
reflect.Float64: reflect.ValueOf(Float64),
reflect.Float32: reflect.ValueOf(Float32),
reflect.Int64: reflect.ValueOf(Int64),
reflect.Int32: reflect.ValueOf(Int32),
reflect.Uint64: reflect.ValueOf(Uint64),
reflect.Uint32: reflect.ValueOf(Uint32),
// TODO(yugui) Support []byte
}
)

View file

@ -0,0 +1,311 @@
package runtime_test
import (
"net/url"
"testing"
"github.com/golang/protobuf/proto"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/grpc-ecosystem/grpc-gateway/utilities"
)
func TestPopulateParameters(t *testing.T) {
for _, spec := range []struct {
values url.Values
filter *utilities.DoubleArray
want proto.Message
}{
{
values: url.Values{
"float_value": {"1.5"},
"double_value": {"2.5"},
"int64_value": {"-1"},
"int32_value": {"-2"},
"uint64_value": {"3"},
"uint32_value": {"4"},
"bool_value": {"true"},
"string_value": {"str"},
"repeated_value": {"a", "b", "c"},
},
filter: utilities.NewDoubleArray(nil),
want: &proto3Message{
FloatValue: 1.5,
DoubleValue: 2.5,
Int64Value: -1,
Int32Value: -2,
Uint64Value: 3,
Uint32Value: 4,
BoolValue: true,
StringValue: "str",
RepeatedValue: []string{"a", "b", "c"},
},
},
{
values: url.Values{
"float_value": {"1.5"},
"double_value": {"2.5"},
"int64_value": {"-1"},
"int32_value": {"-2"},
"uint64_value": {"3"},
"uint32_value": {"4"},
"bool_value": {"true"},
"string_value": {"str"},
"repeated_value": {"a", "b", "c"},
},
filter: utilities.NewDoubleArray(nil),
want: &proto2Message{
FloatValue: proto.Float32(1.5),
DoubleValue: proto.Float64(2.5),
Int64Value: proto.Int64(-1),
Int32Value: proto.Int32(-2),
Uint64Value: proto.Uint64(3),
Uint32Value: proto.Uint32(4),
BoolValue: proto.Bool(true),
StringValue: proto.String("str"),
RepeatedValue: []string{"a", "b", "c"},
},
},
{
values: url.Values{
"nested.nested.nested.repeated_value": {"a", "b", "c"},
"nested.nested.nested.string_value": {"s"},
"nested.nested.string_value": {"t"},
"nested.string_value": {"u"},
"nested_non_null.string_value": {"v"},
},
filter: utilities.NewDoubleArray(nil),
want: &proto3Message{
Nested: &proto2Message{
Nested: &proto3Message{
Nested: &proto2Message{
RepeatedValue: []string{"a", "b", "c"},
StringValue: proto.String("s"),
},
StringValue: "t",
},
StringValue: proto.String("u"),
},
NestedNonNull: proto2Message{
StringValue: proto.String("v"),
},
},
},
{
values: url.Values{
"uint64_value": {"1", "2", "3", "4", "5"},
},
filter: utilities.NewDoubleArray(nil),
want: &proto3Message{
Uint64Value: 1,
},
},
} {
msg := proto.Clone(spec.want)
msg.Reset()
err := runtime.PopulateQueryParameters(msg, spec.values, spec.filter)
if err != nil {
t.Errorf("runtime.PoplateQueryParameters(msg, %v, %v) failed with %v; want success", spec.values, spec.filter, err)
continue
}
if got, want := msg, spec.want; !proto.Equal(got, want) {
t.Errorf("runtime.PopulateQueryParameters(msg, %v, %v = %v; want %v", spec.values, spec.filter, got, want)
}
}
}
func TestPopulateParametersWithFilters(t *testing.T) {
for _, spec := range []struct {
values url.Values
filter *utilities.DoubleArray
want proto.Message
}{
{
values: url.Values{
"bool_value": {"true"},
"string_value": {"str"},
"repeated_value": {"a", "b", "c"},
},
filter: utilities.NewDoubleArray([][]string{
{"bool_value"}, {"repeated_value"},
}),
want: &proto3Message{
StringValue: "str",
},
},
{
values: url.Values{
"nested.nested.bool_value": {"true"},
"nested.nested.string_value": {"str"},
"nested.string_value": {"str"},
"string_value": {"str"},
},
filter: utilities.NewDoubleArray([][]string{
{"nested"},
}),
want: &proto3Message{
StringValue: "str",
},
},
{
values: url.Values{
"nested.nested.bool_value": {"true"},
"nested.nested.string_value": {"str"},
"nested.string_value": {"str"},
"string_value": {"str"},
},
filter: utilities.NewDoubleArray([][]string{
{"nested", "nested"},
}),
want: &proto3Message{
Nested: &proto2Message{
StringValue: proto.String("str"),
},
StringValue: "str",
},
},
{
values: url.Values{
"nested.nested.bool_value": {"true"},
"nested.nested.string_value": {"str"},
"nested.string_value": {"str"},
"string_value": {"str"},
},
filter: utilities.NewDoubleArray([][]string{
{"nested", "nested", "string_value"},
}),
want: &proto3Message{
Nested: &proto2Message{
StringValue: proto.String("str"),
Nested: &proto3Message{
BoolValue: true,
},
},
StringValue: "str",
},
},
} {
msg := proto.Clone(spec.want)
msg.Reset()
err := runtime.PopulateQueryParameters(msg, spec.values, spec.filter)
if err != nil {
t.Errorf("runtime.PoplateQueryParameters(msg, %v, %v) failed with %v; want success", spec.values, spec.filter, err)
continue
}
if got, want := msg, spec.want; !proto.Equal(got, want) {
t.Errorf("runtime.PopulateQueryParameters(msg, %v, %v = %v; want %v", spec.values, spec.filter, got, want)
}
}
}
type proto3Message struct {
Nested *proto2Message `protobuf:"bytes,1,opt,name=nested" json:"nested,omitempty"`
NestedNonNull proto2Message `protobuf:"bytes,11,opt,name=nested_non_null" json:"nested_non_null,omitempty"`
FloatValue float32 `protobuf:"fixed32,2,opt,name=float_value" json:"float_value,omitempty"`
DoubleValue float64 `protobuf:"fixed64,3,opt,name=double_value" json:"double_value,omitempty"`
Int64Value int64 `protobuf:"varint,4,opt,name=int64_value" json:"int64_value,omitempty"`
Int32Value int32 `protobuf:"varint,5,opt,name=int32_value" json:"int32_value,omitempty"`
Uint64Value uint64 `protobuf:"varint,6,opt,name=uint64_value" json:"uint64_value,omitempty"`
Uint32Value uint32 `protobuf:"varint,7,opt,name=uint32_value" json:"uint32_value,omitempty"`
BoolValue bool `protobuf:"varint,8,opt,name=bool_value" json:"bool_value,omitempty"`
StringValue string `protobuf:"bytes,9,opt,name=string_value" json:"string_value,omitempty"`
RepeatedValue []string `protobuf:"bytes,10,rep,name=repeated_value" json:"repeated_value,omitempty"`
}
func (m *proto3Message) Reset() { *m = proto3Message{} }
func (m *proto3Message) String() string { return proto.CompactTextString(m) }
func (*proto3Message) ProtoMessage() {}
func (m *proto3Message) GetNested() *proto2Message {
if m != nil {
return m.Nested
}
return nil
}
type proto2Message struct {
Nested *proto3Message `protobuf:"bytes,1,opt,name=nested" json:"nested,omitempty"`
FloatValue *float32 `protobuf:"fixed32,2,opt,name=float_value" json:"float_value,omitempty"`
DoubleValue *float64 `protobuf:"fixed64,3,opt,name=double_value" json:"double_value,omitempty"`
Int64Value *int64 `protobuf:"varint,4,opt,name=int64_value" json:"int64_value,omitempty"`
Int32Value *int32 `protobuf:"varint,5,opt,name=int32_value" json:"int32_value,omitempty"`
Uint64Value *uint64 `protobuf:"varint,6,opt,name=uint64_value" json:"uint64_value,omitempty"`
Uint32Value *uint32 `protobuf:"varint,7,opt,name=uint32_value" json:"uint32_value,omitempty"`
BoolValue *bool `protobuf:"varint,8,opt,name=bool_value" json:"bool_value,omitempty"`
StringValue *string `protobuf:"bytes,9,opt,name=string_value" json:"string_value,omitempty"`
RepeatedValue []string `protobuf:"bytes,10,rep,name=repeated_value" json:"repeated_value,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
func (m *proto2Message) Reset() { *m = proto2Message{} }
func (m *proto2Message) String() string { return proto.CompactTextString(m) }
func (*proto2Message) ProtoMessage() {}
func (m *proto2Message) GetNested() *proto3Message {
if m != nil {
return m.Nested
}
return nil
}
func (m *proto2Message) GetFloatValue() float32 {
if m != nil && m.FloatValue != nil {
return *m.FloatValue
}
return 0
}
func (m *proto2Message) GetDoubleValue() float64 {
if m != nil && m.DoubleValue != nil {
return *m.DoubleValue
}
return 0
}
func (m *proto2Message) GetInt64Value() int64 {
if m != nil && m.Int64Value != nil {
return *m.Int64Value
}
return 0
}
func (m *proto2Message) GetInt32Value() int32 {
if m != nil && m.Int32Value != nil {
return *m.Int32Value
}
return 0
}
func (m *proto2Message) GetUint64Value() uint64 {
if m != nil && m.Uint64Value != nil {
return *m.Uint64Value
}
return 0
}
func (m *proto2Message) GetUint32Value() uint32 {
if m != nil && m.Uint32Value != nil {
return *m.Uint32Value
}
return 0
}
func (m *proto2Message) GetBoolValue() bool {
if m != nil && m.BoolValue != nil {
return *m.BoolValue
}
return false
}
func (m *proto2Message) GetStringValue() string {
if m != nil && m.StringValue != nil {
return *m.StringValue
}
return ""
}
func (m *proto2Message) GetRepeatedValue() []string {
if m != nil {
return m.RepeatedValue
}
return nil
}

View file

@ -0,0 +1,2 @@
// Package utilities provides members for internal use in grpc-gateway.
package utilities

View file

@ -0,0 +1,22 @@
package utilities
// An OpCode is a opcode of compiled path patterns.
type OpCode int
// These constants are the valid values of OpCode.
const (
// OpNop does nothing
OpNop = OpCode(iota)
// OpPush pushes a component to stack
OpPush
// OpLitPush pushes a component to stack if it matches to the literal
OpLitPush
// OpPushM concatenates the remaining components and pushes it to stack
OpPushM
// OpConcatN pops N items from stack, concatenates them and pushes it back to stack
OpConcatN
// OpCapture pops an item and binds it to the variable
OpCapture
// OpEnd is the least postive invalid opcode.
OpEnd
)

View file

@ -0,0 +1,177 @@
package utilities
import (
"sort"
)
// DoubleArray is a Double Array implementation of trie on sequences of strings.
type DoubleArray struct {
// Encoding keeps an encoding from string to int
Encoding map[string]int
// Base is the base array of Double Array
Base []int
// Check is the check array of Double Array
Check []int
}
// NewDoubleArray builds a DoubleArray from a set of sequences of strings.
func NewDoubleArray(seqs [][]string) *DoubleArray {
da := &DoubleArray{Encoding: make(map[string]int)}
if len(seqs) == 0 {
return da
}
encoded := registerTokens(da, seqs)
sort.Sort(byLex(encoded))
root := node{row: -1, col: -1, left: 0, right: len(encoded)}
addSeqs(da, encoded, 0, root)
for i := len(da.Base); i > 0; i-- {
if da.Check[i-1] != 0 {
da.Base = da.Base[:i]
da.Check = da.Check[:i]
break
}
}
return da
}
func registerTokens(da *DoubleArray, seqs [][]string) [][]int {
var result [][]int
for _, seq := range seqs {
var encoded []int
for _, token := range seq {
if _, ok := da.Encoding[token]; !ok {
da.Encoding[token] = len(da.Encoding)
}
encoded = append(encoded, da.Encoding[token])
}
result = append(result, encoded)
}
for i := range result {
result[i] = append(result[i], len(da.Encoding))
}
return result
}
type node struct {
row, col int
left, right int
}
func (n node) value(seqs [][]int) int {
return seqs[n.row][n.col]
}
func (n node) children(seqs [][]int) []*node {
var result []*node
lastVal := int(-1)
last := new(node)
for i := n.left; i < n.right; i++ {
if lastVal == seqs[i][n.col+1] {
continue
}
last.right = i
last = &node{
row: i,
col: n.col + 1,
left: i,
}
result = append(result, last)
}
last.right = n.right
return result
}
func addSeqs(da *DoubleArray, seqs [][]int, pos int, n node) {
ensureSize(da, pos)
children := n.children(seqs)
var i int
for i = 1; ; i++ {
ok := func() bool {
for _, child := range children {
code := child.value(seqs)
j := i + code
ensureSize(da, j)
if da.Check[j] != 0 {
return false
}
}
return true
}()
if ok {
break
}
}
da.Base[pos] = i
for _, child := range children {
code := child.value(seqs)
j := i + code
da.Check[j] = pos + 1
}
terminator := len(da.Encoding)
for _, child := range children {
code := child.value(seqs)
if code == terminator {
continue
}
j := i + code
addSeqs(da, seqs, j, *child)
}
}
func ensureSize(da *DoubleArray, i int) {
for i >= len(da.Base) {
da.Base = append(da.Base, make([]int, len(da.Base)+1)...)
da.Check = append(da.Check, make([]int, len(da.Check)+1)...)
}
}
type byLex [][]int
func (l byLex) Len() int { return len(l) }
func (l byLex) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
func (l byLex) Less(i, j int) bool {
si := l[i]
sj := l[j]
var k int
for k = 0; k < len(si) && k < len(sj); k++ {
if si[k] < sj[k] {
return true
}
if si[k] > sj[k] {
return false
}
}
if k < len(sj) {
return true
}
return false
}
// HasCommonPrefix determines if any sequence in the DoubleArray is a prefix of the given sequence.
func (da *DoubleArray) HasCommonPrefix(seq []string) bool {
if len(da.Base) == 0 {
return false
}
var i int
for _, t := range seq {
code, ok := da.Encoding[t]
if !ok {
break
}
j := da.Base[i] + code
if len(da.Check) <= j || da.Check[j] != i+1 {
break
}
i = j
}
j := da.Base[i] + len(da.Encoding)
if len(da.Check) <= j || da.Check[j] != i+1 {
return false
}
return true
}

View file

@ -0,0 +1,372 @@
package utilities_test
import (
"reflect"
"testing"
"github.com/grpc-ecosystem/grpc-gateway/utilities"
)
func TestMaxCommonPrefix(t *testing.T) {
for _, spec := range []struct {
da utilities.DoubleArray
tokens []string
want bool
}{
{
da: utilities.DoubleArray{},
tokens: nil,
want: false,
},
{
da: utilities.DoubleArray{},
tokens: []string{"foo"},
want: false,
},
{
da: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
},
Base: []int{1, 1, 0},
Check: []int{0, 1, 2},
},
tokens: nil,
want: false,
},
{
da: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
},
Base: []int{1, 1, 0},
Check: []int{0, 1, 2},
},
tokens: []string{"foo"},
want: true,
},
{
da: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
},
Base: []int{1, 1, 0},
Check: []int{0, 1, 2},
},
tokens: []string{"bar"},
want: false,
},
{
// foo|bar
da: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
"bar": 1,
},
Base: []int{1, 1, 2, 0, 0},
Check: []int{0, 1, 1, 2, 3},
// 0: ^
// 1: ^foo
// 2: ^bar
// 3: ^foo$
// 4: ^bar$
},
tokens: []string{"foo"},
want: true,
},
{
// foo|bar
da: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
"bar": 1,
},
Base: []int{1, 1, 2, 0, 0},
Check: []int{0, 1, 1, 2, 3},
// 0: ^
// 1: ^foo
// 2: ^bar
// 3: ^foo$
// 4: ^bar$
},
tokens: []string{"bar"},
want: true,
},
{
// foo|bar
da: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
"bar": 1,
},
Base: []int{1, 1, 2, 0, 0},
Check: []int{0, 1, 1, 2, 3},
// 0: ^
// 1: ^foo
// 2: ^bar
// 3: ^foo$
// 4: ^bar$
},
tokens: []string{"something-else"},
want: false,
},
{
// foo|bar
da: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
"bar": 1,
},
Base: []int{1, 1, 2, 0, 0},
Check: []int{0, 1, 1, 2, 3},
// 0: ^
// 1: ^foo
// 2: ^bar
// 3: ^foo$
// 4: ^bar$
},
tokens: []string{"foo", "bar"},
want: true,
},
{
// foo|foo\.bar|bar
da: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
"bar": 1,
},
Base: []int{1, 3, 1, 0, 4, 0, 0},
Check: []int{0, 1, 1, 3, 2, 2, 5},
// 0: ^
// 1: ^foo
// 2: ^bar
// 3: ^bar$
// 4: ^foo.bar
// 5: ^foo$
// 6: ^foo.bar$
},
tokens: []string{"foo"},
want: true,
},
{
// foo|foo\.bar|bar
da: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
"bar": 1,
},
Base: []int{1, 3, 1, 0, 4, 0, 0},
Check: []int{0, 1, 1, 3, 2, 2, 5},
// 0: ^
// 1: ^foo
// 2: ^bar
// 3: ^bar$
// 4: ^foo.bar
// 5: ^foo$
// 6: ^foo.bar$
},
tokens: []string{"foo", "bar"},
want: true,
},
{
// foo|foo\.bar|bar
da: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
"bar": 1,
},
Base: []int{1, 3, 1, 0, 4, 0, 0},
Check: []int{0, 1, 1, 3, 2, 2, 5},
// 0: ^
// 1: ^foo
// 2: ^bar
// 3: ^bar$
// 4: ^foo.bar
// 5: ^foo$
// 6: ^foo.bar$
},
tokens: []string{"bar"},
want: true,
},
{
// foo|foo\.bar|bar
da: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
"bar": 1,
},
Base: []int{1, 3, 1, 0, 4, 0, 0},
Check: []int{0, 1, 1, 3, 2, 2, 5},
// 0: ^
// 1: ^foo
// 2: ^bar
// 3: ^bar$
// 4: ^foo.bar
// 5: ^foo$
// 6: ^foo.bar$
},
tokens: []string{"something-else"},
want: false,
},
{
// foo|foo\.bar|bar
da: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
"bar": 1,
},
Base: []int{1, 3, 1, 0, 4, 0, 0},
Check: []int{0, 1, 1, 3, 2, 2, 5},
// 0: ^
// 1: ^foo
// 2: ^bar
// 3: ^bar$
// 4: ^foo.bar
// 5: ^foo$
// 6: ^foo.bar$
},
tokens: []string{"foo", "bar", "baz"},
want: true,
},
} {
got := spec.da.HasCommonPrefix(spec.tokens)
if got != spec.want {
t.Errorf("%#v.HasCommonPrefix(%v) = %v; want %v", spec.da, spec.tokens, got, spec.want)
}
}
}
func TestAdd(t *testing.T) {
for _, spec := range []struct {
tokens [][]string
want utilities.DoubleArray
}{
{
want: utilities.DoubleArray{
Encoding: make(map[string]int),
},
},
{
tokens: [][]string{{"foo"}},
want: utilities.DoubleArray{
Encoding: map[string]int{"foo": 0},
Base: []int{1, 1, 0},
Check: []int{0, 1, 2},
// 0: ^
// 1: ^foo
// 2: ^foo$
},
},
{
tokens: [][]string{{"foo"}, {"bar"}},
want: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
"bar": 1,
},
Base: []int{1, 1, 2, 0, 0},
Check: []int{0, 1, 1, 2, 3},
// 0: ^
// 1: ^foo
// 2: ^bar
// 3: ^foo$
// 4: ^bar$
},
},
{
tokens: [][]string{{"foo", "bar"}, {"foo", "baz"}},
want: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
"bar": 1,
"baz": 2,
},
Base: []int{1, 1, 1, 2, 0, 0},
Check: []int{0, 1, 2, 2, 3, 4},
// 0: ^
// 1: ^foo
// 2: ^foo.bar
// 3: ^foo.baz
// 4: ^foo.bar$
// 5: ^foo.baz$
},
},
{
tokens: [][]string{{"foo", "bar"}, {"foo", "baz"}, {"qux"}},
want: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
"bar": 1,
"baz": 2,
"qux": 3,
},
Base: []int{1, 1, 1, 2, 3, 0, 0, 0},
Check: []int{0, 1, 2, 2, 1, 3, 4, 5},
// 0: ^
// 1: ^foo
// 2: ^foo.bar
// 3: ^foo.baz
// 4: ^qux
// 5: ^foo.bar$
// 6: ^foo.baz$
// 7: ^qux$
},
},
{
tokens: [][]string{
{"foo", "bar"},
{"foo", "baz", "bar"},
{"qux", "foo"},
},
want: utilities.DoubleArray{
Encoding: map[string]int{
"foo": 0,
"bar": 1,
"baz": 2,
"qux": 3,
},
Base: []int{1, 1, 1, 5, 8, 0, 3, 0, 5, 0},
Check: []int{0, 1, 2, 2, 1, 3, 4, 7, 5, 9},
// 0: ^
// 1: ^foo
// 2: ^foo.bar
// 3: ^foo.baz
// 4: ^qux
// 5: ^foo.bar$
// 6: ^foo.baz.bar
// 7: ^foo.baz.bar$
// 8: ^qux.foo
// 9: ^qux.foo$
},
},
} {
da := utilities.NewDoubleArray(spec.tokens)
if got, want := da.Encoding, spec.want.Encoding; !reflect.DeepEqual(got, want) {
t.Errorf("da.Encoding = %v; want %v; tokens = %#v", got, want, spec.tokens)
}
if got, want := da.Base, spec.want.Base; !compareArray(got, want) {
t.Errorf("da.Base = %v; want %v; tokens = %#v", got, want, spec.tokens)
}
if got, want := da.Check, spec.want.Check; !compareArray(got, want) {
t.Errorf("da.Check = %v; want %v; tokens = %#v", got, want, spec.tokens)
}
}
}
func compareArray(got, want []int) bool {
var i int
for i = 0; i < len(got) && i < len(want); i++ {
if got[i] != want[i] {
return false
}
}
if i < len(want) {
return false
}
for ; i < len(got); i++ {
if got[i] != 0 {
return false
}
}
return true
}