mirror of
https://github.com/gravitational/teleport
synced 2024-10-20 17:23:22 +00:00
Emit data transfer events.
Created *utils.TrackingConn that wraps the server side net.Conn and is used to track how much data is transmitted and received over the net.Conn. At the close of a connection (close of a *srv.ServerContext) the total data transmitted and received is emitted to the Audit Log.
This commit is contained in:
parent
d6253f25e7
commit
ac9af87dfb
|
@ -17,9 +17,11 @@ limitations under the License.
|
|||
package integration
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
@ -3151,6 +3153,83 @@ func (s *IntSuite) TestMultipleSignup(c *check.C) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestDataTransfer makes sure that a "session.data" event is emitted at the
|
||||
// end of a session that matches the amount of data that was transferred.
|
||||
func (s *IntSuite) TestDataTransfer(c *check.C) {
|
||||
KB := 1024
|
||||
MB := 1048576
|
||||
|
||||
// Create a Teleport cluster.
|
||||
main := s.newTeleport(c, nil, true)
|
||||
defer main.Stop(true)
|
||||
|
||||
// Create a client to the above Teleport cluster.
|
||||
clientConfig := ClientConfig{
|
||||
Login: s.me.Username,
|
||||
Cluster: Site,
|
||||
Host: Host,
|
||||
Port: main.GetPortSSHInt(),
|
||||
}
|
||||
|
||||
// Write 1 MB to stdout.
|
||||
command := []string{"dd", "if=/dev/zero", "bs=1024", "count=1024"}
|
||||
output, err := runCommand(main, command, clientConfig, 1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// Make sure exactly 1 MB was written to output.
|
||||
c.Assert(len(output) == MB, check.Equals, true)
|
||||
|
||||
// Make sure the session.data event was emitted to the audit log.
|
||||
eventFields, err := findEventInLog(main, events.SessionDataEvent)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// Make sure the audit event shows that 1 MB was written to the output.
|
||||
c.Assert(eventFields.GetInt(events.DataReceived) > MB, check.Equals, true)
|
||||
c.Assert(eventFields.GetInt(events.DataTransmitted) > KB, check.Equals, true)
|
||||
}
|
||||
|
||||
// findEventInLog tries to find an event in the audit log file 10 times.
|
||||
func findEventInLog(t *TeleInstance, eventName string) (events.EventFields, error) {
|
||||
for i := 0; i < 10; i++ {
|
||||
eventFields, err := eventInLog(t.Config.DataDir+"/log/events.log", eventName)
|
||||
if err != nil {
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
return eventFields, nil
|
||||
}
|
||||
return nil, trace.NotFound("event not found")
|
||||
}
|
||||
|
||||
// eventInLog finds event in audit log file.
|
||||
func eventInLog(path string, eventName string) (events.EventFields, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
var fields events.EventFields
|
||||
err = json.Unmarshal(scanner.Bytes(), &fields)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
eventType, ok := fields[events.EventType]
|
||||
if !ok {
|
||||
return nil, trace.BadParameter("not found")
|
||||
}
|
||||
if eventType == eventName {
|
||||
return fields, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, trace.NotFound("event not found")
|
||||
}
|
||||
|
||||
// runCommand is a shortcut for running SSH command, it creates a client
|
||||
// connected to proxy of the passed in instance, runs the command, and returns
|
||||
// the result. If multiple attempts are requested, a 250 millisecond delay is
|
||||
|
|
|
@ -88,6 +88,11 @@ const (
|
|||
// SessionLeaveEvent indicates that someone left a session
|
||||
SessionLeaveEvent = "session.leave"
|
||||
|
||||
// Data transfer events.
|
||||
SessionDataEvent = "session.data"
|
||||
DataTransmitted = "tx"
|
||||
DataReceived = "rx"
|
||||
|
||||
// ClientDisconnectEvent is emitted when client is disconnected
|
||||
// by the server due to inactivity or any other reason
|
||||
ClientDisconnectEvent = "client.disconnect"
|
||||
|
|
|
@ -12,7 +12,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
|
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
*/
|
||||
|
||||
package reversetunnel
|
||||
|
@ -484,7 +483,7 @@ func (s *server) Shutdown(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (s *server) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
|
||||
// apply read/write timeouts to the server connection
|
||||
// Apply read/write timeouts to the server connection.
|
||||
conn = utils.ObeyIdleTimeout(conn,
|
||||
defaults.ReverseTunnelAgentHeartbeatPeriod*10,
|
||||
"reverse tunnel server")
|
||||
|
|
|
@ -98,7 +98,7 @@ func (s *ServiceTestSuite) TestMonitor(c *check.C) {
|
|||
// Broadcast a degraded event and make sure Teleport reports it's in a
|
||||
// degraded state.
|
||||
process.BroadcastEvent(Event{Name: TeleportDegradedEvent, Payload: nil})
|
||||
err = waitForStatus(endpoint, http.StatusServiceUnavailable)
|
||||
err = waitForStatus(endpoint, http.StatusServiceUnavailable, http.StatusBadRequest)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// Broadcast a OK event, this should put Teleport into a recovering state.
|
||||
|
@ -120,7 +120,7 @@ func (s *ServiceTestSuite) TestMonitor(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
||||
func waitForStatus(diagAddr string, statusCode int) error {
|
||||
func waitForStatus(diagAddr string, statusCodes ...int) error {
|
||||
tickCh := time.Tick(250 * time.Millisecond)
|
||||
timeoutCh := time.After(10 * time.Second)
|
||||
for {
|
||||
|
@ -130,11 +130,13 @@ func waitForStatus(diagAddr string, statusCode int) error {
|
|||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if resp.StatusCode == statusCode {
|
||||
return nil
|
||||
for _, statusCode := range statusCodes {
|
||||
if resp.StatusCode == statusCode {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
case <-timeoutCh:
|
||||
return trace.BadParameter("timeout waiting for status %v", statusCode)
|
||||
return trace.BadParameter("timeout waiting for status %v", statusCodes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -27,6 +28,7 @@ import (
|
|||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
|
||||
"github.com/gravitational/teleport"
|
||||
"github.com/gravitational/teleport/lib/auth"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/events"
|
||||
|
@ -38,11 +40,32 @@ import (
|
|||
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var ctxID int32
|
||||
|
||||
var (
|
||||
serverTX = prometheus.NewCounter(
|
||||
prometheus.CounterOpts{
|
||||
Name: "tx",
|
||||
Help: "Number of bytes transmitted.",
|
||||
},
|
||||
)
|
||||
serverRX = prometheus.NewCounter(
|
||||
prometheus.CounterOpts{
|
||||
Name: "rx",
|
||||
Help: "Number of bytes received.",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(serverTX)
|
||||
prometheus.MustRegister(serverRX)
|
||||
}
|
||||
|
||||
// Server is regular or forwarding SSH server.
|
||||
type Server interface {
|
||||
// ID is the unique ID of the server.
|
||||
|
@ -164,6 +187,9 @@ type ServerContext struct {
|
|||
// Conn is the underlying *ssh.ServerConn.
|
||||
Conn *ssh.ServerConn
|
||||
|
||||
// Connection is the underlying net.Conn for the connection.
|
||||
Connection net.Conn
|
||||
|
||||
// Identity holds the identity of the user that is currently logged in on
|
||||
// the Conn.
|
||||
Identity IdentityContext
|
||||
|
@ -478,9 +504,63 @@ func (c *ServerContext) takeClosers() []io.Closer {
|
|||
return closers
|
||||
}
|
||||
|
||||
// When the ServerContext (connection) is closed, emit "session.data" event
|
||||
// containing how much data was transmitted and received over the net.Conn.
|
||||
func (c *ServerContext) reportStats(conn utils.Stater) {
|
||||
// Never emit session data events for the proxy or from a Teleport node if
|
||||
// sessions are being recorded at the proxy (this would result in double
|
||||
// events).
|
||||
if c.GetServer().Component() == teleport.ComponentProxy {
|
||||
return
|
||||
}
|
||||
if c.ClusterConfig.GetSessionRecording() == services.RecordAtProxy &&
|
||||
c.GetServer().Component() == teleport.ComponentNode {
|
||||
return
|
||||
}
|
||||
|
||||
// Get the TX and RX bytes.
|
||||
txBytes, rxBytes := conn.Stat()
|
||||
|
||||
// Build and emit session data event. Note that TX and RX are reversed
|
||||
// below, that is because the connection is held from the perspective of
|
||||
// the server not the client, but the logs are from the perspective of the
|
||||
// client.
|
||||
eventFields := events.EventFields{
|
||||
events.DataTransmitted: rxBytes,
|
||||
events.DataReceived: txBytes,
|
||||
events.SessionServerID: c.GetServer().ID(),
|
||||
events.EventLogin: c.Identity.Login,
|
||||
events.EventUser: c.Identity.TeleportUser,
|
||||
events.LocalAddr: c.Conn.LocalAddr().String(),
|
||||
events.RemoteAddr: c.Conn.RemoteAddr().String(),
|
||||
}
|
||||
if c.session != nil {
|
||||
eventFields[events.SessionEventID] = c.session.id
|
||||
}
|
||||
c.GetServer().GetAuditLog().EmitAuditEvent(events.SessionDataEvent, eventFields)
|
||||
|
||||
// Emit TX and RX bytes to their respective Prometheus counters.
|
||||
serverTX.Add(float64(txBytes))
|
||||
serverRX.Add(float64(rxBytes))
|
||||
}
|
||||
|
||||
func (c *ServerContext) Close() error {
|
||||
// If the underlying connection is holding tracking information, report that
|
||||
// to the audit log at close.
|
||||
if stats, ok := c.Connection.(*utils.TrackingConn); ok {
|
||||
defer c.reportStats(stats)
|
||||
}
|
||||
|
||||
// Unblock any goroutines waiting until session is closed.
|
||||
c.cancel()
|
||||
return closeAll(c.takeClosers()...)
|
||||
|
||||
// Close and release all resources.
|
||||
err := closeAll(c.takeClosers()...)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendExecResult sends the result of execution of the "exec" command over the
|
||||
|
|
|
@ -214,7 +214,7 @@ func New(c ServerConfig) (*Server, error) {
|
|||
}),
|
||||
id: uuid.New(),
|
||||
targetConn: c.TargetConn,
|
||||
serverConn: serverConn,
|
||||
serverConn: utils.NewTrackingConn(serverConn),
|
||||
clientConn: clientConn,
|
||||
userAgent: c.UserAgent,
|
||||
hostCertificate: c.HostCertificate,
|
||||
|
@ -594,6 +594,7 @@ func (s *Server) handleDirectTCPIPRequest(ch ssh.Channel, req *sshutils.DirectTC
|
|||
ch.Stderr().Write([]byte("Unable to create connection context."))
|
||||
return
|
||||
}
|
||||
ctx.Connection = s.serverConn
|
||||
ctx.RemoteClient = s.remoteClient
|
||||
defer ctx.Close()
|
||||
|
||||
|
@ -657,6 +658,7 @@ func (s *Server) handleSessionRequests(ch ssh.Channel, in <-chan *ssh.Request) {
|
|||
ch.Stderr().Write([]byte("Unable to create connection context."))
|
||||
return
|
||||
}
|
||||
ctx.Connection = s.serverConn
|
||||
ctx.RemoteClient = s.remoteClient
|
||||
ctx.AddCloser(ch)
|
||||
defer ctx.Close()
|
||||
|
|
|
@ -741,7 +741,7 @@ func (s *Server) HandleRequest(r *ssh.Request) {
|
|||
}
|
||||
|
||||
// HandleNewChan is called when new channel is opened
|
||||
func (s *Server) HandleNewChan(nc net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
|
||||
func (s *Server) HandleNewChan(wconn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
|
||||
identityContext, err := s.authHandlers.CreateIdentityContext(sconn)
|
||||
if err != nil {
|
||||
nch.Reject(ssh.Prohibited, fmt.Sprintf("Unable to create identity from connection: %v", err))
|
||||
|
@ -760,7 +760,7 @@ func (s *Server) HandleNewChan(nc net.Conn, sconn *ssh.ServerConn, nch ssh.NewCh
|
|||
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
|
||||
return
|
||||
}
|
||||
go s.handleSessionRequests(sconn, identityContext, ch, requests)
|
||||
go s.handleSessionRequests(wconn, sconn, identityContext, ch, requests)
|
||||
} else {
|
||||
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
|
||||
}
|
||||
|
@ -777,7 +777,7 @@ func (s *Server) HandleNewChan(nc net.Conn, sconn *ssh.ServerConn, nch ssh.NewCh
|
|||
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
|
||||
return
|
||||
}
|
||||
go s.handleSessionRequests(sconn, identityContext, ch, requests)
|
||||
go s.handleSessionRequests(wconn, sconn, identityContext, ch, requests)
|
||||
// Channels of type "direct-tcpip" handles request for port forwarding.
|
||||
case "direct-tcpip":
|
||||
req, err := sshutils.ParseDirectTCPIPReq(nch.ExtraData())
|
||||
|
@ -792,14 +792,14 @@ func (s *Server) HandleNewChan(nc net.Conn, sconn *ssh.ServerConn, nch ssh.NewCh
|
|||
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
|
||||
return
|
||||
}
|
||||
go s.handleDirectTCPIPRequest(sconn, identityContext, ch, req)
|
||||
go s.handleDirectTCPIPRequest(wconn, sconn, identityContext, ch, req)
|
||||
default:
|
||||
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
|
||||
}
|
||||
}
|
||||
|
||||
// handleDirectTCPIPRequest handles port forwarding requests.
|
||||
func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
|
||||
func (s *Server) handleDirectTCPIPRequest(wconn net.Conn, sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
|
||||
// Create context for this channel. This context will be closed when
|
||||
// forwarding is complete.
|
||||
ctx, err := srv.NewServerContext(s, sconn, identityContext)
|
||||
|
@ -808,7 +808,7 @@ func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext
|
|||
ch.Stderr().Write([]byte("Unable to create connection context."))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Connection = wconn
|
||||
ctx.IsTestStub = s.isTestStub
|
||||
ctx.AddCloser(ch)
|
||||
defer ctx.Debugf("direct-tcp closed")
|
||||
|
@ -892,7 +892,7 @@ func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext
|
|||
// handleSessionRequests handles out of band session requests once the session
|
||||
// channel has been created this function's loop handles all the "exec",
|
||||
// "subsystem" and "shell" requests.
|
||||
func (s *Server) handleSessionRequests(sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) {
|
||||
func (s *Server) handleSessionRequests(conn net.Conn, sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) {
|
||||
// Create context for this channel. This context will be closed when the
|
||||
// session request is complete.
|
||||
ctx, err := srv.NewServerContext(s, sconn, identityContext)
|
||||
|
@ -901,12 +901,13 @@ func (s *Server) handleSessionRequests(sconn *ssh.ServerConn, identityContext sr
|
|||
ch.Stderr().Write([]byte("Unable to create connection context."))
|
||||
return
|
||||
}
|
||||
ctx.Connection = conn
|
||||
ctx.IsTestStub = s.isTestStub
|
||||
ctx.AddCloser(ch)
|
||||
defer ctx.Close()
|
||||
|
||||
// Create a close context used to signal between the server and the
|
||||
// keep-alive loop when to close th connection (from either side).
|
||||
// keep-alive loop when to close the connection (from either side).
|
||||
closeContext, closeCancel := context.WithCancel(context.Background())
|
||||
defer closeCancel()
|
||||
|
||||
|
|
|
@ -364,10 +364,14 @@ func (s *Server) handleConnection(conn net.Conn) {
|
|||
defaults.DefaultIdleConnectionDuration,
|
||||
s.component)
|
||||
|
||||
// Wrap connection with a tracker used to monitor how much data was
|
||||
// transmitted and received over the connection.
|
||||
wconn := utils.NewTrackingConn(conn)
|
||||
|
||||
// create a new SSH server which handles the handshake (and pass the custom
|
||||
// payload structure which will be populated only when/if this connection
|
||||
// comes from another Teleport proxy):
|
||||
sconn, chans, reqs, err := ssh.NewServerConn(wrapConnection(conn), &s.cfg)
|
||||
sconn, chans, reqs, err := ssh.NewServerConn(wrapConnection(wconn), &s.cfg)
|
||||
if err != nil {
|
||||
conn.SetDeadline(time.Time{})
|
||||
return
|
||||
|
@ -413,7 +417,7 @@ func (s *Server) handleConnection(conn net.Conn) {
|
|||
connClosed()
|
||||
return
|
||||
}
|
||||
go s.newChanHandler.HandleNewChan(conn, sconn, nch)
|
||||
go s.newChanHandler.HandleNewChan(wconn, sconn, nch)
|
||||
// send keepalive pings to the clients
|
||||
case <-keepAliveTick.C:
|
||||
const wantReply = true
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
@ -88,3 +89,50 @@ func RoundtripWithConn(conn net.Conn) (string, error) {
|
|||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
// Stater is extension interface of the net.Conn for implementations that
|
||||
// track connection statistics.
|
||||
type Stater interface {
|
||||
// Stat returns TX, RX data.
|
||||
Stat() (uint64, uint64)
|
||||
}
|
||||
|
||||
// TrackingConn is a net.Conn that keeps track of how much data was transmitted
|
||||
// (TX) and received (RX) over the net.Conn. A maximum of about 18446
|
||||
// petabytes can be kept track of for TX and RX before it rolls over.
|
||||
// See https://golang.org/ref/spec#Numeric_types for more details.
|
||||
type TrackingConn struct {
|
||||
// net.Conn is the underlying net.Conn.
|
||||
net.Conn
|
||||
|
||||
// txBytes keeps track of how many bytes were transmitted.
|
||||
txBytes uint64
|
||||
|
||||
// rxBytes keeps track of how many bytes were received.
|
||||
rxBytes uint64
|
||||
}
|
||||
|
||||
// NewTrackingConn returns a net.Conn that can keep track of how much data was
|
||||
// transmitted over it.
|
||||
func NewTrackingConn(conn net.Conn) *TrackingConn {
|
||||
return &TrackingConn{
|
||||
Conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
// Stat returns the transmitted (TX) and received (RX) bytes over the net.Conn.
|
||||
func (s *TrackingConn) Stat() (uint64, uint64) {
|
||||
return atomic.LoadUint64(&s.txBytes), atomic.LoadUint64(&s.rxBytes)
|
||||
}
|
||||
|
||||
func (s *TrackingConn) Read(b []byte) (n int, err error) {
|
||||
n, err = s.Conn.Read(b)
|
||||
atomic.AddUint64(&s.rxBytes, uint64(n))
|
||||
return n, trace.Wrap(err)
|
||||
}
|
||||
|
||||
func (s *TrackingConn) Write(b []byte) (n int, err error) {
|
||||
n, err = s.Conn.Write(b)
|
||||
atomic.AddUint64(&s.txBytes, uint64(n))
|
||||
return n, trace.Wrap(err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue