Emit data transfer events.

Created *utils.TrackingConn, which wraps the server-side net.Conn and
tracks how much data is transmitted and received over it. When the
connection closes (on close of the *srv.ServerContext), the totals are
emitted to the audit log as a "session.data" event.
Russell Jones 2019-02-15 02:15:45 +00:00
parent d6253f25e7
commit ac9af87dfb
9 changed files with 239 additions and 19 deletions
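
As context for the diff below, here is a minimal usage sketch of the pattern the commit message describes: wrap an accepted net.Conn with the tracking wrapper, pass the wrapper everywhere, and read the totals from Stat() at close. The echo handler, listen address, and printing are illustrative only; in this commit the SSH servers emit a "session.data" audit event instead of printing.

package main

import (
    "fmt"
    "net"

    "github.com/gravitational/teleport/lib/utils"
)

// handleEcho serves one connection through the tracking wrapper so every
// Read and Write is counted, then reports the totals when it returns.
func handleEcho(conn net.Conn) {
    wconn := utils.NewTrackingConn(conn)
    defer func() {
        tx, rx := wconn.Stat()
        // This is the point at which the real servers emit "session.data".
        fmt.Printf("connection closed: tx=%d bytes, rx=%d bytes\n", tx, rx)
        wconn.Close()
    }()

    // Echo everything back through the wrapper.
    buf := make([]byte, 4096)
    for {
        n, err := wconn.Read(buf)
        if err != nil {
            return
        }
        if _, err := wconn.Write(buf[:n]); err != nil {
            return
        }
    }
}

func main() {
    ln, err := net.Listen("tcp", "127.0.0.1:2222")
    if err != nil {
        panic(err)
    }
    for {
        conn, err := ln.Accept()
        if err != nil {
            panic(err)
        }
        go handleEcho(conn)
    }
}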


@ -17,9 +17,11 @@ limitations under the License.
package integration
import (
"bufio"
"bytes"
"context"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"io/ioutil"
@ -3151,6 +3153,83 @@ func (s *IntSuite) TestMultipleSignup(c *check.C) {
}
}
// TestDataTransfer makes sure that a "session.data" event is emitted at the
// end of a session and that it reflects the amount of data that was
// transferred.
func (s *IntSuite) TestDataTransfer(c *check.C) {
KB := 1024
MB := 1048576
// Create a Teleport cluster.
main := s.newTeleport(c, nil, true)
defer main.Stop(true)
// Create a client to the above Teleport cluster.
clientConfig := ClientConfig{
Login: s.me.Username,
Cluster: Site,
Host: Host,
Port: main.GetPortSSHInt(),
}
// Write 1 MB to stdout.
command := []string{"dd", "if=/dev/zero", "bs=1024", "count=1024"}
output, err := runCommand(main, command, clientConfig, 1)
c.Assert(err, check.IsNil)
// Make sure exactly 1 MB was written to output.
c.Assert(len(output) == MB, check.Equals, true)
// Make sure the session.data event was emitted to the audit log.
eventFields, err := findEventInLog(main, events.SessionDataEvent)
c.Assert(err, check.IsNil)
// Make sure the audit event shows that over 1 MB was received (the command
// output plus protocol overhead) and that at least 1 KB was transmitted.
c.Assert(eventFields.GetInt(events.DataReceived) > MB, check.Equals, true)
c.Assert(eventFields.GetInt(events.DataTransmitted) > KB, check.Equals, true)
}
// findEventInLog polls the audit log file for the named event, retrying up
// to 10 times with a one second pause between attempts.
func findEventInLog(t *TeleInstance, eventName string) (events.EventFields, error) {
for i := 0; i < 10; i++ {
eventFields, err := eventInLog(t.Config.DataDir+"/log/events.log", eventName)
if err != nil {
time.Sleep(1 * time.Second)
continue
}
return eventFields, nil
}
return nil, trace.NotFound("event not found")
}
// eventInLog scans the audit log file at path for an event with the given name.
func eventInLog(path string, eventName string) (events.EventFields, error) {
file, err := os.Open(path)
if err != nil {
return nil, trace.Wrap(err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
var fields events.EventFields
err = json.Unmarshal(scanner.Bytes(), &fields)
if err != nil {
return nil, trace.Wrap(err)
}
eventType, ok := fields[events.EventType]
if !ok {
return nil, trace.BadParameter("not found")
}
if eventType == eventName {
return fields, nil
}
}
return nil, trace.NotFound("event not found")
}
// runCommand is a shortcut for running an SSH command. It creates a client
// connected to the proxy of the passed-in instance, runs the command, and
// returns the result. If multiple attempts are requested, a 250 millisecond
// delay is


@ -88,6 +88,11 @@ const (
// SessionLeaveEvent indicates that someone left a session
SessionLeaveEvent = "session.leave"
// SessionDataEvent is emitted at the close of a connection and records how
// much data was transmitted and received over it.
SessionDataEvent = "session.data"
// DataTransmitted is the amount of data transmitted (TX) in bytes.
DataTransmitted = "tx"
// DataReceived is the amount of data received (RX) in bytes.
DataReceived = "rx"
// ClientDisconnectEvent is emitted when client is disconnected
// by the server due to inactivity or any other reason
ClientDisconnectEvent = "client.disconnect"
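
A short, hedged sketch of how these keys are read back off an emitted event. The helper below is illustrative and not part of the commit; it assumes the EventFields getters behave as they do in the integration test above.

package example

import (
    "fmt"

    "github.com/gravitational/teleport/lib/events"
)

// printSessionData prints the transfer totals carried by a "session.data"
// event. fields is assumed to have been read from the audit log.
func printSessionData(fields events.EventFields) {
    tx := fields.GetInt(events.DataTransmitted) // bytes the client transmitted
    rx := fields.GetInt(events.DataReceived)    // bytes the client received
    fmt.Printf("session %v: tx=%d bytes, rx=%d bytes\n",
        fields[events.SessionEventID], tx, rx)
}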


@ -12,7 +12,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package reversetunnel
@ -484,7 +483,7 @@ func (s *server) Shutdown(ctx context.Context) error {
}
func (s *server) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
// apply read/write timeouts to the server connection
// Apply read/write timeouts to the server connection.
conn = utils.ObeyIdleTimeout(conn,
defaults.ReverseTunnelAgentHeartbeatPeriod*10,
"reverse tunnel server")


@ -98,7 +98,7 @@ func (s *ServiceTestSuite) TestMonitor(c *check.C) {
// Broadcast a degraded event and make sure Teleport reports it's in a
// degraded state.
process.BroadcastEvent(Event{Name: TeleportDegradedEvent, Payload: nil})
err = waitForStatus(endpoint, http.StatusServiceUnavailable)
err = waitForStatus(endpoint, http.StatusServiceUnavailable, http.StatusBadRequest)
c.Assert(err, check.IsNil)
// Broadcast an OK event; this should put Teleport into a recovering state.
@ -120,7 +120,7 @@ func (s *ServiceTestSuite) TestMonitor(c *check.C) {
c.Assert(err, check.IsNil)
}
func waitForStatus(diagAddr string, statusCode int) error {
func waitForStatus(diagAddr string, statusCodes ...int) error {
tickCh := time.Tick(250 * time.Millisecond)
timeoutCh := time.After(10 * time.Second)
for {
@ -130,11 +130,13 @@ func waitForStatus(diagAddr string, statusCode int) error {
if err != nil {
return trace.Wrap(err)
}
if resp.StatusCode == statusCode {
return nil
for _, statusCode := range statusCodes {
if resp.StatusCode == statusCode {
return nil
}
}
case <-timeoutCh:
return trace.BadParameter("timeout waiting for status %v", statusCode)
return trace.BadParameter("timeout waiting for status %v", statusCodes)
}
}
}


@ -20,6 +20,7 @@ import (
"context"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
@ -27,6 +28,7 @@ import (
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
@ -38,11 +40,32 @@ import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
)
var ctxID int32
var (
serverTX = prometheus.NewCounter(
prometheus.CounterOpts{
Name: "tx",
Help: "Number of bytes transmitted.",
},
)
serverRX = prometheus.NewCounter(
prometheus.CounterOpts{
Name: "rx",
Help: "Number of bytes received.",
},
)
)
func init() {
prometheus.MustRegister(serverTX)
prometheus.MustRegister(serverRX)
}
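
A hedged sketch, not part of the commit, of how the counters registered above surface: any process that serves the default Prometheus registry, for example via promhttp, exposes them as plain "tx" and "rx" series. The handler wiring and listen address below are illustrative.

package main

import (
    "net/http"

    "github.com/prometheus/client_golang/prometheus/promhttp"
)

func main() {
    // Serves every collector registered with prometheus.MustRegister,
    // including the serverTX ("tx") and serverRX ("rx") counters above.
    http.Handle("/metrics", promhttp.Handler())
    http.ListenAndServe("127.0.0.1:3434", nil)
}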
// Server is regular or forwarding SSH server.
type Server interface {
// ID is the unique ID of the server.
@ -164,6 +187,9 @@ type ServerContext struct {
// Conn is the underlying *ssh.ServerConn.
Conn *ssh.ServerConn
// Connection is the underlying net.Conn for the connection.
Connection net.Conn
// Identity holds the identity of the user that is currently logged in on
// the Conn.
Identity IdentityContext
@ -478,9 +504,63 @@ func (c *ServerContext) takeClosers() []io.Closer {
return closers
}
// reportStats emits a "session.data" event when the ServerContext
// (connection) is closed, recording how much data was transmitted and
// received over the net.Conn.
func (c *ServerContext) reportStats(conn utils.Stater) {
// Never emit session data events for the proxy or from a Teleport node if
// sessions are being recorded at the proxy (this would result in double
// events).
if c.GetServer().Component() == teleport.ComponentProxy {
return
}
if c.ClusterConfig.GetSessionRecording() == services.RecordAtProxy &&
c.GetServer().Component() == teleport.ComponentNode {
return
}
// Get the TX and RX bytes.
txBytes, rxBytes := conn.Stat()
// Build and emit the session data event. Note that TX and RX are swapped
// below: the connection counts bytes from the server's perspective, while
// the audit log reports the transfer from the client's perspective.
eventFields := events.EventFields{
events.DataTransmitted: rxBytes,
events.DataReceived: txBytes,
events.SessionServerID: c.GetServer().ID(),
events.EventLogin: c.Identity.Login,
events.EventUser: c.Identity.TeleportUser,
events.LocalAddr: c.Conn.LocalAddr().String(),
events.RemoteAddr: c.Conn.RemoteAddr().String(),
}
if c.session != nil {
eventFields[events.SessionEventID] = c.session.id
}
c.GetServer().GetAuditLog().EmitAuditEvent(events.SessionDataEvent, eventFields)
// Emit TX and RX bytes to their respective Prometheus counters.
serverTX.Add(float64(txBytes))
serverRX.Add(float64(rxBytes))
}
func (c *ServerContext) Close() error {
// If the underlying connection tracks transfer statistics, report them to
// the audit log when the context is closed.
if stats, ok := c.Connection.(*utils.TrackingConn); ok {
defer c.reportStats(stats)
}
// Unblock any goroutines waiting until session is closed.
c.cancel()
return closeAll(c.takeClosers()...)
// Close and release all resources.
err := closeAll(c.takeClosers()...)
if err != nil {
return trace.Wrap(err)
}
return nil
}
// SendExecResult sends the result of execution of the "exec" command over the


@ -214,7 +214,7 @@ func New(c ServerConfig) (*Server, error) {
}),
id: uuid.New(),
targetConn: c.TargetConn,
serverConn: serverConn,
serverConn: utils.NewTrackingConn(serverConn),
clientConn: clientConn,
userAgent: c.UserAgent,
hostCertificate: c.HostCertificate,
@ -594,6 +594,7 @@ func (s *Server) handleDirectTCPIPRequest(ch ssh.Channel, req *sshutils.DirectTC
ch.Stderr().Write([]byte("Unable to create connection context."))
return
}
ctx.Connection = s.serverConn
ctx.RemoteClient = s.remoteClient
defer ctx.Close()
@ -657,6 +658,7 @@ func (s *Server) handleSessionRequests(ch ssh.Channel, in <-chan *ssh.Request) {
ch.Stderr().Write([]byte("Unable to create connection context."))
return
}
ctx.Connection = s.serverConn
ctx.RemoteClient = s.remoteClient
ctx.AddCloser(ch)
defer ctx.Close()


@ -741,7 +741,7 @@ func (s *Server) HandleRequest(r *ssh.Request) {
}
// HandleNewChan is called when new channel is opened
func (s *Server) HandleNewChan(nc net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
func (s *Server) HandleNewChan(wconn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
identityContext, err := s.authHandlers.CreateIdentityContext(sconn)
if err != nil {
nch.Reject(ssh.Prohibited, fmt.Sprintf("Unable to create identity from connection: %v", err))
@ -760,7 +760,7 @@ func (s *Server) HandleNewChan(nc net.Conn, sconn *ssh.ServerConn, nch ssh.NewCh
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleSessionRequests(sconn, identityContext, ch, requests)
go s.handleSessionRequests(wconn, sconn, identityContext, ch, requests)
} else {
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
}
@ -777,7 +777,7 @@ func (s *Server) HandleNewChan(nc net.Conn, sconn *ssh.ServerConn, nch ssh.NewCh
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleSessionRequests(sconn, identityContext, ch, requests)
go s.handleSessionRequests(wconn, sconn, identityContext, ch, requests)
// Channels of type "direct-tcpip" handle requests for port forwarding.
case "direct-tcpip":
req, err := sshutils.ParseDirectTCPIPReq(nch.ExtraData())
@ -792,14 +792,14 @@ func (s *Server) HandleNewChan(nc net.Conn, sconn *ssh.ServerConn, nch ssh.NewCh
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleDirectTCPIPRequest(sconn, identityContext, ch, req)
go s.handleDirectTCPIPRequest(wconn, sconn, identityContext, ch, req)
default:
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
}
}
// handleDirectTCPIPRequest handles port forwarding requests.
func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
func (s *Server) handleDirectTCPIPRequest(wconn net.Conn, sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
// Create context for this channel. This context will be closed when
// forwarding is complete.
ctx, err := srv.NewServerContext(s, sconn, identityContext)
@ -808,7 +808,7 @@ func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext
ch.Stderr().Write([]byte("Unable to create connection context."))
return
}
ctx.Connection = wconn
ctx.IsTestStub = s.isTestStub
ctx.AddCloser(ch)
defer ctx.Debugf("direct-tcp closed")
@ -892,7 +892,7 @@ func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext
// handleSessionRequests handles out-of-band session requests once the
// session channel has been created. This function's loop handles all the
// "exec", "subsystem" and "shell" requests.
func (s *Server) handleSessionRequests(sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) {
func (s *Server) handleSessionRequests(conn net.Conn, sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) {
// Create context for this channel. This context will be closed when the
// session request is complete.
ctx, err := srv.NewServerContext(s, sconn, identityContext)
@ -901,12 +901,13 @@ func (s *Server) handleSessionRequests(sconn *ssh.ServerConn, identityContext sr
ch.Stderr().Write([]byte("Unable to create connection context."))
return
}
ctx.Connection = conn
ctx.IsTestStub = s.isTestStub
ctx.AddCloser(ch)
defer ctx.Close()
// Create a close context used to signal between the server and the
// keep-alive loop when to close th connection (from either side).
// keep-alive loop when to close the connection (from either side).
closeContext, closeCancel := context.WithCancel(context.Background())
defer closeCancel()


@ -364,10 +364,14 @@ func (s *Server) handleConnection(conn net.Conn) {
defaults.DefaultIdleConnectionDuration,
s.component)
// Wrap connection with a tracker used to monitor how much data was
// transmitted and received over the connection.
wconn := utils.NewTrackingConn(conn)
// create a new SSH server which handles the handshake (and pass the custom
// payload structure which will be populated only when/if this connection
// comes from another Teleport proxy):
sconn, chans, reqs, err := ssh.NewServerConn(wrapConnection(conn), &s.cfg)
sconn, chans, reqs, err := ssh.NewServerConn(wrapConnection(wconn), &s.cfg)
if err != nil {
conn.SetDeadline(time.Time{})
return
@ -413,7 +417,7 @@ func (s *Server) handleConnection(conn net.Conn) {
connClosed()
return
}
go s.newChanHandler.HandleNewChan(conn, sconn, nch)
go s.newChanHandler.HandleNewChan(wconn, sconn, nch)
// send keepalive pings to the clients
case <-keepAliveTick.C:
const wantReply = true


@ -23,6 +23,7 @@ import (
"io/ioutil"
"net"
"net/http"
"sync/atomic"
"github.com/gravitational/trace"
)
@ -88,3 +89,50 @@ func RoundtripWithConn(conn net.Conn) (string, error) {
}
return string(out), nil
}
// Stater is an extension interface of net.Conn for implementations that
// track connection statistics.
type Stater interface {
// Stat returns TX, RX data.
Stat() (uint64, uint64)
}
// TrackingConn is a net.Conn that keeps track of how much data is
// transmitted (TX) and received (RX) over it. Each counter is a uint64, so
// about 18446 petabytes can be tracked for TX and RX before rolling over.
// See https://golang.org/ref/spec#Numeric_types for more details.
type TrackingConn struct {
// net.Conn is the underlying net.Conn.
net.Conn
// txBytes keeps track of how many bytes were transmitted.
txBytes uint64
// rxBytes keeps track of how many bytes were received.
rxBytes uint64
}
// NewTrackingConn returns a net.Conn that keeps track of how much data is
// transmitted and received over it.
func NewTrackingConn(conn net.Conn) *TrackingConn {
return &TrackingConn{
Conn: conn,
}
}
// Stat returns the transmitted (TX) and received (RX) bytes over the net.Conn.
func (s *TrackingConn) Stat() (uint64, uint64) {
return atomic.LoadUint64(&s.txBytes), atomic.LoadUint64(&s.rxBytes)
}
// Read reads from the underlying net.Conn and adds the number of bytes read
// to the RX counter.
func (s *TrackingConn) Read(b []byte) (n int, err error) {
n, err = s.Conn.Read(b)
atomic.AddUint64(&s.rxBytes, uint64(n))
return n, trace.Wrap(err)
}
// Write writes to the underlying net.Conn and adds the number of bytes
// written to the TX counter.
func (s *TrackingConn) Write(b []byte) (n int, err error) {
n, err = s.Conn.Write(b)
atomic.AddUint64(&s.txBytes, uint64(n))
return n, trace.Wrap(err)
}
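
A minimal sketch, not part of the commit, of how the wrapper could be exercised in a test: push a known number of bytes in each direction over a net.Pipe and check the totals from Stat(). The test name and byte counts are illustrative.

package utils_test

import (
    "io"
    "io/ioutil"
    "net"
    "testing"

    "github.com/gravitational/teleport/lib/utils"
)

func TestTrackingConnCounts(t *testing.T) {
    client, server := net.Pipe()
    defer client.Close()

    tconn := utils.NewTrackingConn(server)
    defer tconn.Close()

    // Drive the client side concurrently: net.Pipe writes block until the
    // other end reads, and vice versa.
    go func() {
        client.Write([]byte("ping"))
        io.Copy(ioutil.Discard, client)
    }()

    // Read 4 bytes and write 5 bytes through the wrapper.
    buf := make([]byte, 4)
    if _, err := io.ReadFull(tconn, buf); err != nil {
        t.Fatal(err)
    }
    if _, err := tconn.Write([]byte("pong!")); err != nil {
        t.Fatal(err)
    }

    tx, rx := tconn.Stat()
    if tx != 5 || rx != 4 {
        t.Fatalf("expected tx=5 rx=4, got tx=%d rx=%d", tx, rx)
    }
}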