teleport/lib/multiplexer/multiplexer_test.go
Jakub Nyckowski 9379243390
Use pre-generated RSA keys in tests (#18750)
Many of our tests (db package, I'm looking at you) generate many RSA keys. This has two main side effects; makes our tests slow and flaky as CPU usage spikes in random moments when the tests are run in parallel.
This change pre-generates RSA keys at the beginning of each test module and reuse them in randomized order to reduce the situation that one key has been used multiple times in one test.
I had to move a few files to avoid circular dependencies.
2022-12-05 19:33:49 +00:00

893 lines
24 KiB
Go

/*
Copyright 2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package multiplexer
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"
"time"
"github.com/jackc/pgproto3/v2"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/gravitational/teleport/api/constants"
apisshutils "github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth/native"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/httplib"
"github.com/gravitational/teleport/lib/multiplexer/test"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/cert"
)
func TestMain(m *testing.M) {
utils.InitLoggerForTests()
os.Exit(m.Run())
}
// TestMux tests multiplexing protocols
// using the same listener.
func TestMux(t *testing.T) {
_, signer, err := cert.CreateCertificate("foo", ssh.HostCert)
require.NoError(t, err)
// TestMux tests basic use case of multiplexing TLS
// and SSH on the same listener socket
t.Run("TLSSSH", func(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
mux, err := New(Config{
Listener: listener,
EnableProxyProtocol: true,
})
require.NoError(t, err)
go mux.Serve()
defer mux.Close()
backend1 := &httptest.Server{
Listener: mux.TLS(),
Config: &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "backend 1")
}),
},
}
backend1.StartTLS()
defer backend1.Close()
called := false
sshHandler := sshutils.NewChanHandlerFunc(func(_ context.Context, _ *sshutils.ConnectionContext, nch ssh.NewChannel) {
called = true
err := nch.Reject(ssh.Prohibited, "nothing to see here")
require.NoError(t, err)
})
srv, err := sshutils.NewServer(
"test",
utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"},
sshHandler,
[]ssh.Signer{signer},
sshutils.AuthMethods{Password: pass("abc123")},
)
require.NoError(t, err)
go srv.Serve(mux.SSH())
defer srv.Close()
clt, err := ssh.Dial("tcp", listener.Addr().String(), &ssh.ClientConfig{
Auth: []ssh.AuthMethod{ssh.Password("abc123")},
Timeout: time.Second,
HostKeyCallback: ssh.FixedHostKey(signer.PublicKey()),
})
require.NoError(t, err)
defer clt.Close()
// call new session to initiate opening new channel
_, err = clt.NewSession()
require.NotNil(t, err)
// make sure the channel handler was called OK
require.True(t, called)
client := testClient(backend1)
re, err := client.Get(backend1.URL)
require.NoError(t, err)
defer re.Body.Close()
bytes, err := io.ReadAll(re.Body)
require.NoError(t, err)
require.Equal(t, string(bytes), "backend 1")
// Close mux, new requests should fail
mux.Close()
mux.Wait()
// use new client to use new connection pool
client = testClient(backend1)
re, err = client.Get(backend1.URL)
if err == nil {
re.Body.Close()
}
require.NotNil(t, err)
})
// ProxyLine tests proxy line protocol
t.Run("ProxyLine", func(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
mux, err := New(Config{
Listener: listener,
EnableProxyProtocol: true,
})
require.NoError(t, err)
go mux.Serve()
defer mux.Close()
backend1 := &httptest.Server{
Listener: mux.TLS(),
Config: &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, r.RemoteAddr)
}),
},
}
backend1.StartTLS()
defer backend1.Close()
remoteAddr := net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8000}
proxyLine := ProxyLine{
Protocol: TCP4,
Source: remoteAddr,
Destination: net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 9000},
}
parsedURL, err := url.Parse(backend1.URL)
require.NoError(t, err)
conn, err := net.Dial("tcp", parsedURL.Host)
require.NoError(t, err)
defer conn.Close()
// send proxy line first before establishing TLS connection
_, err = fmt.Fprint(conn, proxyLine.String())
require.NoError(t, err)
// upgrade connection to TLS
tlsConn := tls.Client(conn, clientConfig(backend1))
defer tlsConn.Close()
// make sure the TLS call succeeded and we got remote address
// correctly
out, err := utils.RoundtripWithConn(tlsConn)
require.NoError(t, err)
require.Equal(t, out, remoteAddr.String())
})
// ProxyLineV2 tests proxy protocol v2
t.Run("ProxyLineV2", func(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
mux, err := New(Config{
Listener: listener,
EnableProxyProtocol: true,
})
require.NoError(t, err)
go mux.Serve()
defer mux.Close()
backend1 := &httptest.Server{
Listener: mux.TLS(),
Config: &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, r.RemoteAddr)
}),
},
}
backend1.StartTLS()
defer backend1.Close()
parsedURL, err := url.Parse(backend1.URL)
require.NoError(t, err)
conn, err := net.Dial("tcp", parsedURL.Host)
require.NoError(t, err)
defer conn.Close()
// send proxy header + addresses before establishing TLS connection
_, err = conn.Write([]byte{
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, // signature
0x21, 0x11, // version/command, family
0x00, 12, // address length
0x7F, 0x00, 0x00, 0x01, // source address: 127.0.0.1
0x7F, 0x00, 0x00, 0x01, // destination address: 127.0.0.1
0x1F, 0x40, 0x23, 0x28, // source port: 8000, destination port: 9000
})
require.NoError(t, err)
// upgrade connection to TLS
tlsConn := tls.Client(conn, clientConfig(backend1))
defer tlsConn.Close()
// make sure the TLS call succeeded and we got remote address
// correctly
out, err := utils.RoundtripWithConn(tlsConn)
require.NoError(t, err)
require.Equal(t, out, "127.0.0.1:8000")
})
// TestDisabledProxy makes sure the connection gets dropped
// when Proxy line support protocol is turned off
t.Run("DisabledProxy", func(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
mux, err := New(Config{
Listener: listener,
EnableProxyProtocol: false,
})
require.NoError(t, err)
go mux.Serve()
defer mux.Close()
backend1 := &httptest.Server{
Listener: mux.TLS(),
Config: &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, r.RemoteAddr)
}),
},
}
backend1.StartTLS()
defer backend1.Close()
remoteAddr := net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8000}
proxyLine := ProxyLine{
Protocol: TCP4,
Source: remoteAddr,
Destination: net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 9000},
}
parsedURL, err := url.Parse(backend1.URL)
require.NoError(t, err)
conn, err := net.Dial("tcp", parsedURL.Host)
require.NoError(t, err)
defer conn.Close()
// send proxy line first before establishing TLS connection
_, err = fmt.Fprint(conn, proxyLine.String())
require.NoError(t, err)
// upgrade connection to TLS
tlsConn := tls.Client(conn, clientConfig(backend1))
defer tlsConn.Close()
// make sure the TLS call failed
_, err = utils.RoundtripWithConn(tlsConn)
require.NotNil(t, err)
})
// Timeout test makes sure that multiplexer respects read deadlines.
t.Run("Timeout", func(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
config := Config{
Listener: listener,
// Set read deadline in the past to remove reliance on real time
// and simulate scenario when read deadline has elapsed.
ReadDeadline: -time.Millisecond,
EnableProxyProtocol: true,
}
mux, err := New(config)
require.NoError(t, err)
go mux.Serve()
defer mux.Close()
backend1 := &httptest.Server{
Listener: mux.TLS(),
Config: &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, r.RemoteAddr)
}),
},
}
backend1.StartTLS()
defer backend1.Close()
parsedURL, err := url.Parse(backend1.URL)
require.NoError(t, err)
conn, err := net.Dial("tcp", parsedURL.Host)
require.NoError(t, err)
defer conn.Close()
// upgrade connection to TLS
tlsConn := tls.Client(conn, clientConfig(backend1))
defer tlsConn.Close()
// roundtrip should fail on the timeout
_, err = utils.RoundtripWithConn(tlsConn)
require.NotNil(t, err)
})
// UnknownProtocol make sure that multiplexer closes connection
// with unknown protocol
t.Run("UnknownProtocol", func(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
mux, err := New(Config{
Listener: listener,
EnableProxyProtocol: true,
})
require.NoError(t, err)
go mux.Serve()
defer mux.Close()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
// try plain HTTP
_, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
require.NoError(t, err)
// connection should be closed
_, err = conn.Read(make([]byte, 1))
require.Equal(t, err, io.EOF)
})
// DisableSSH disables SSH
t.Run("DisableSSH", func(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
mux, err := New(Config{
Listener: listener,
EnableProxyProtocol: true,
})
require.NoError(t, err)
go mux.Serve()
defer mux.Close()
backend1 := &httptest.Server{
Listener: mux.TLS(),
Config: &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "backend 1")
}),
},
}
backend1.StartTLS()
defer backend1.Close()
_, err = ssh.Dial("tcp", listener.Addr().String(), &ssh.ClientConfig{
Auth: []ssh.AuthMethod{ssh.Password("abc123")},
Timeout: time.Second,
HostKeyCallback: ssh.FixedHostKey(signer.PublicKey()),
})
require.NotNil(t, err)
// TLS requests will succeed
client := testClient(backend1)
re, err := client.Get(backend1.URL)
require.NoError(t, err)
defer re.Body.Close()
bytes, err := io.ReadAll(re.Body)
require.NoError(t, err)
require.Equal(t, string(bytes), "backend 1")
// Close mux, new requests should fail
mux.Close()
mux.Wait()
// use new client to use new connection pool
client = testClient(backend1)
re, err = client.Get(backend1.URL)
if err == nil {
re.Body.Close()
}
require.NotNil(t, err)
})
// TestDisableTLS tests scenario with disabled TLS
t.Run("DisableTLS", func(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
mux, err := New(Config{
Listener: listener,
EnableProxyProtocol: true,
})
require.NoError(t, err)
go mux.Serve()
defer mux.Close()
backend1 := &httptest.Server{
Listener: &noopListener{addr: listener.Addr()},
Config: &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "backend 1")
}),
},
}
backend1.StartTLS()
defer backend1.Close()
called := false
sshHandler := sshutils.NewChanHandlerFunc(func(_ context.Context, _ *sshutils.ConnectionContext, nch ssh.NewChannel) {
called = true
err := nch.Reject(ssh.Prohibited, "nothing to see here")
require.NoError(t, err)
})
srv, err := sshutils.NewServer(
"test",
utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"},
sshHandler,
[]ssh.Signer{signer},
sshutils.AuthMethods{Password: pass("abc123")},
)
require.NoError(t, err)
go srv.Serve(mux.SSH())
defer srv.Close()
clt, err := ssh.Dial("tcp", listener.Addr().String(), &ssh.ClientConfig{
Auth: []ssh.AuthMethod{ssh.Password("abc123")},
Timeout: time.Second,
HostKeyCallback: ssh.FixedHostKey(signer.PublicKey()),
})
require.NoError(t, err)
defer clt.Close()
// call new session to initiate opening new channel
_, err = clt.NewSession()
require.NotNil(t, err)
// make sure the channel handler was called OK
require.Equal(t, called, true)
client := testClient(backend1)
re, err := client.Get(backend1.URL)
if err == nil {
re.Body.Close()
}
require.NotNil(t, err)
// Close mux, new requests should fail
mux.Close()
mux.Wait()
})
// NextProto tests multiplexing using NextProto selector
t.Run("NextProto", func(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
mux, err := New(Config{
Listener: listener,
EnableProxyProtocol: true,
})
require.NoError(t, err)
go mux.Serve()
defer mux.Close()
cfg, err := fixtures.LocalTLSConfig()
require.NoError(t, err)
tlsLis, err := NewTLSListener(TLSListenerConfig{
Listener: tls.NewListener(mux.TLS(), cfg.TLS),
})
require.NoError(t, err)
go tlsLis.Serve()
opts := []grpc.ServerOption{
grpc.Creds(&httplib.TLSCreds{
Config: cfg.TLS,
}),
}
s := grpc.NewServer(opts...)
test.RegisterPingerServer(s, &server{})
errCh := make(chan error, 2)
go func() {
errCh <- s.Serve(tlsLis.HTTP2())
}()
httpServer := http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "http backend")
}),
}
go func() {
err := httpServer.Serve(tlsLis.HTTP())
if err == nil || err == http.ErrServerClosed {
errCh <- nil
return
}
errCh <- err
}()
url := fmt.Sprintf("https://%s", listener.Addr())
client := cfg.NewClient()
re, err := client.Get(url)
require.NoError(t, err)
defer re.Body.Close()
bytes, err := io.ReadAll(re.Body)
require.NoError(t, err)
require.Equal(t, string(bytes), "http backend")
creds := credentials.NewClientTLSFromCert(cfg.CertPool, "")
// Set up a connection to the server.
conn, err := grpc.Dial(listener.Addr().String(), grpc.WithTransportCredentials(creds), grpc.WithBlock())
require.NoError(t, err)
defer conn.Close()
gclient := test.NewPingerClient(conn)
out, err := gclient.Ping(context.TODO(), &test.Request{})
require.NoError(t, err)
require.Equal(t, out.GetPayload(), "grpc backend")
// Close mux, new requests should fail
mux.Close()
mux.Wait()
// use new client to use new connection pool
client = cfg.NewClient()
re, err = client.Get(url)
if err == nil {
re.Body.Close()
}
require.NotNil(t, err)
httpServer.Close()
s.Stop()
// wait for both servers to finish
for i := 0; i < 2; i++ {
err := <-errCh
require.NoError(t, err)
}
})
t.Run("PostgresProxy", func(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
mux, err := New(Config{
Context: ctx,
Listener: listener,
})
require.NoError(t, err)
go mux.Serve()
defer mux.Close()
// register listener before establishing frontend connection
dblistener := mux.DB()
// Connect to the listener and send Postgres SSLRequest which is what
// psql or other Postgres client will do.
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(conn), conn)
err = frontend.Send(&pgproto3.SSLRequest{})
require.NoError(t, err)
// This should not hang indefinitely since we set timeout on the mux context above.
conn, err = dblistener.Accept()
require.NoError(t, err, "detected Postgres connection")
require.Equal(t, ProtoPostgres, conn.(*Conn).Protocol())
})
// WebListener verifies web listener correctly multiplexes connections
// between web and database listeners based on the client certificate.
t.Run("WebListener", func(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
mux, err := New(Config{
Listener: listener,
EnableProxyProtocol: true,
})
require.NoError(t, err)
go mux.Serve()
defer mux.Close()
// register listener before establishing frontend connection
tlslistener := mux.TLS()
// Generate self-signed CA.
caKey, caCert, err := tlsca.GenerateSelfSignedCA(pkix.Name{CommonName: "test-ca"}, nil, time.Hour)
require.NoError(t, err)
ca, err := tlsca.FromKeys(caCert, caKey)
require.NoError(t, err)
certPool := x509.NewCertPool()
certPool.AppendCertsFromPEM(caCert)
// Sign server certificate.
serverRSAKey, err := native.GenerateRSAPrivateKey()
require.NoError(t, err)
serverPEM, err := ca.GenerateCertificate(tlsca.CertificateRequest{
Subject: pkix.Name{CommonName: "localhost"},
PublicKey: serverRSAKey.Public(),
NotAfter: time.Now().Add(time.Hour),
DNSNames: []string{"127.0.0.1"},
})
require.NoError(t, err)
serverCert, err := tls.X509KeyPair(serverPEM, tlsca.MarshalPrivateKeyPEM(serverRSAKey))
require.NoError(t, err)
// Sign client certificate with database access identity.
clientRSAKey, err := rsa.GenerateKey(rand.Reader, constants.RSAKeySize)
require.NoError(t, err)
subject, err := (&tlsca.Identity{
Username: "alice",
Groups: []string{"admin"},
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: "postgres",
},
}).Subject()
require.NoError(t, err)
clientPEM, err := ca.GenerateCertificate(tlsca.CertificateRequest{
Subject: subject,
PublicKey: clientRSAKey.Public(),
NotAfter: time.Now().Add(time.Hour),
})
require.NoError(t, err)
clientCert, err := tls.X509KeyPair(clientPEM, tlsca.MarshalPrivateKeyPEM(clientRSAKey))
require.NoError(t, err)
webLis, err := NewWebListener(WebListenerConfig{
Listener: tls.NewListener(tlslistener, &tls.Config{
ClientCAs: certPool,
ClientAuth: tls.VerifyClientCertIfGiven,
Certificates: []tls.Certificate{serverCert},
}),
})
require.NoError(t, err)
go webLis.Serve()
defer webLis.Close()
go func() {
conn, err := webLis.Web().Accept()
require.NoError(t, err)
defer conn.Close()
conn.Write([]byte("web listener"))
}()
go func() {
conn, err := webLis.DB().Accept()
require.NoError(t, err)
defer conn.Close()
conn.Write([]byte("db listener"))
}()
webConn, err := tls.Dial("tcp", listener.Addr().String(), &tls.Config{
RootCAs: certPool,
})
require.NoError(t, err)
defer webConn.Close()
webBytes, err := io.ReadAll(webConn)
require.NoError(t, err)
require.Equal(t, "web listener", string(webBytes))
dbConn, err := tls.Dial("tcp", listener.Addr().String(), &tls.Config{
RootCAs: certPool,
Certificates: []tls.Certificate{clientCert},
})
require.NoError(t, err)
defer dbConn.Close()
dbBytes, err := io.ReadAll(dbConn)
require.NoError(t, err)
require.Equal(t, "db listener", string(dbBytes))
})
t.Run("SSHProxyHelloSignature", func(t *testing.T) {
// Ensures SSH connections fronted with the ProxyHelloSignature are
// detected as SSH connections.
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
mux, err := New(Config{
Listener: listener,
EnableProxyProtocol: true,
})
require.NoError(t, err)
go mux.Serve()
defer mux.Close()
// Record the remote addr from the point of view of the ssh handler
// so we can confirm that the header is parsed properly.
calledWithRemoteAddr := ""
sshHandler := sshutils.NewChanHandlerFunc(func(_ context.Context, c *sshutils.ConnectionContext, nch ssh.NewChannel) {
calledWithRemoteAddr = c.ServerConn.RemoteAddr().String()
err := nch.Reject(ssh.Prohibited, "nothing to see here")
require.NoError(t, err)
})
srv, err := sshutils.NewServer(
"test",
utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"},
sshHandler,
[]ssh.Signer{signer},
sshutils.AuthMethods{Password: pass("abc123")},
)
require.NoError(t, err)
go srv.Serve(mux.SSH())
defer srv.Close()
// Manually create client conn so we can inject the ProxyHelloSignature
conn, err := net.DialTimeout("tcp", listener.Addr().String(), time.Second)
remoteAddr := "6.6.6.6:1337"
require.NoError(t, err)
hp := &apisshutils.HandshakePayload{
ClientAddr: remoteAddr,
}
payloadJSON, err := json.Marshal(hp)
require.NoError(t, err)
payload := fmt.Sprintf("%s%s\x00", apisshutils.ProxyHelloSignature, payloadJSON)
_, err = conn.Write([]byte(payload))
require.NoError(t, err)
c, chans, reqs, err := ssh.NewClientConn(conn, listener.Addr().String(), &ssh.ClientConfig{
Auth: []ssh.AuthMethod{ssh.Password("abc123")},
Timeout: time.Second,
HostKeyCallback: ssh.FixedHostKey(signer.PublicKey()),
})
require.NoError(t, err)
clt := ssh.NewClient(c, chans, reqs)
defer clt.Close()
// call new session to initiate opening new channel
_, err = clt.NewSession()
require.EqualError(t, err, "ssh: rejected: administratively prohibited (nothing to see here)")
// make sure the channel handler was called OK
require.Equal(t, remoteAddr, calledWithRemoteAddr)
// Close mux, new requests should fail
mux.Close()
mux.Wait()
})
}
func TestProtocolString(t *testing.T) {
for i := -1; i < len(protocolStrings)+1; i++ {
got := Protocol(i).String()
switch i {
case -1, len(protocolStrings) + 1:
require.Equal(t, "", got)
default:
require.Equal(t, protocolStrings[Protocol(i)], got)
}
}
}
// server is used to implement test.PingerServer
type server struct {
test.UnimplementedPingerServer
}
func (s *server) Ping(ctx context.Context, req *test.Request) (*test.Response, error) {
return &test.Response{Payload: "grpc backend"}, nil
}
// clientConfig returns tls client config from test http server
// set up to listen on TLS
func clientConfig(srv *httptest.Server) *tls.Config {
cert, err := x509.ParseCertificate(srv.TLS.Certificates[0].Certificate[0])
if err != nil {
panic(err)
}
certpool := x509.NewCertPool()
certpool.AddCert(cert)
return &tls.Config{
RootCAs: certpool,
ServerName: fmt.Sprintf("%v", cert.IPAddresses[0].String()),
}
}
// testClient is a test HTTP client set up for TLS
func testClient(srv *httptest.Server) *http.Client {
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: clientConfig(srv),
},
}
}
func pass(need string) sshutils.PasswordFunc {
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if string(password) == need {
return nil, nil
}
return nil, fmt.Errorf("passwords don't match")
}
}
type noopListener struct {
addr net.Addr
}
func (noopListener) Accept() (net.Conn, error) {
return nil, errors.New("noop")
}
func (noopListener) Close() error {
return nil
}
func (l noopListener) Addr() net.Addr {
return l.addr
}
func TestIsHTTP(t *testing.T) {
t.Parallel()
for _, verb := range httpMethods {
t.Run(fmt.Sprintf("Accept %v", string(verb)), func(t *testing.T) {
data := fmt.Sprintf("%v /some/path HTTP/1.1", string(verb))
require.True(t, isHTTP([]byte(data)))
})
}
rejectedInputs := []string{
"some random junk",
"FAKE /some/path HTTP/1.1",
// This case checks for a bug where the arguments to bytes.HasPrefix are reversed.
"GE",
}
for _, input := range rejectedInputs {
t.Run(fmt.Sprintf("Reject %q", input), func(t *testing.T) {
require.False(t, isHTTP([]byte(input)))
})
}
}