Replaced weg sockets with HTTP POST/GET chunks

This commit is contained in:
Ev Kontsevoy 2016-05-05 23:51:56 -07:00
parent a08ea32b5e
commit f27e207afc
11 changed files with 185 additions and 131 deletions

View file

@ -14,8 +14,9 @@ export
$(eval BUILDFLAGS := $(ADDFLAGS) -ldflags "-w $(shell go install $(PKGPATH)/vendor/github.com/gravitational/version/cmd/linkflags && linkflags -pkg=$(GOPATH)/src/$(PKGPATH) -verpkg=$(PKGPATH)/vendor/github.com/gravitational/version)")
ev:
TELEPORT_DEBUG_TESTS=1 go test -v ./integration/... -check.f "IntSuite.TestAudit"
#ev:
# TELEPORT_DEBUG_TESTS=1 go test -v ./integration/... -check.f "IntSuite.TestAudit" 2>&1
# TELEPORT_DEBUG_TESTS=1 go test -v ./integration/... -check.f "IntSuite.TestAudit" 2>&1 | grep -e "---"
#
# Default target: builds all 3 executables and plaaces them in a current directory

View file

@ -20,7 +20,6 @@ import (
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"os/user"
"strconv"
@ -131,7 +130,8 @@ func (s *IntSuite) TestAudit(c *check.C) {
c.Assert(err, check.IsNil)
cl.Output = &myTerm
endC <- cl.SSH([]string{}, false, &myTerm)
err = cl.SSH([]string{}, false, &myTerm)
endC <- err
}()
// wait until there's a session in there:
@ -149,31 +149,34 @@ func (s *IntSuite) TestAudit(c *check.C) {
// make sure it's us who joined! :)
c.Assert(session.Parties[0].User, check.Equals, s.me.Username)
// lets add something to the session streaam:
// write 1MB chunk
bigChunk := make([]byte, 1024*1024)
err = site.PostSessionChunk(session.ID, bytes.NewReader(bigChunk))
c.Assert(err, check.Equals, nil)
// then add small prefix:
err = site.PostSessionChunk(session.ID, bytes.NewBufferString("\nsuffix"))
c.Assert(err, check.Equals, nil)
// lets type "echo hi" followed by "enter" and then "exit" + "enter":
myTerm.Type("\aecho hi\n\r\aexit\n\r\a")
// wait for session to end:
<-endC
// using 'session writer' lets add something to the session streaam:
w, err := site.GetSessionWriter(session.ID)
c.Assert(err, check.IsNil)
// write 32Kb chunk
bigChunk := make([]byte, 1024*32)
n, err := w.Write(bigChunk)
c.Assert(err, check.Equals, nil)
c.Assert(n, check.Equals, len(bigChunk))
// then add small prefix:
w.Write([]byte("\nsuffix"))
w.Close()
// read back the entire session:
r, err := site.GetSessionReader(session.ID, 0)
c.Assert(err, check.IsNil)
sessionStream, err := ioutil.ReadAll(r)
c.Assert(err, check.IsNil)
c.Assert(len(sessionStream) > len(bigChunk), check.Equals, true)
r.Close()
// read back the entire session (we have to try several times until we get back
// everything because the session is closing)
expectedLen := 1048600
var sessionStream []byte
for i := 0; len(sessionStream) < expectedLen; i++ {
sessionStream, err = site.GetSessionChunk(session.ID, 0, events.MaxChunkBytes)
c.Assert(err, check.IsNil)
time.Sleep(time.Millisecond * 250)
if i > 10 {
// session stream keeps coming back short
c.Fatalf("stream is too short: <%d", expectedLen)
}
}
// see what we got. It looks different based on bash settings, but here it is
// on Ev's machine (hostname is 'edsger'):
@ -182,18 +185,15 @@ func (s *IntSuite) TestAudit(c *check.C) {
// hi
// edsger ~: exit
// logout
// <5MB of zeros here>
// <1MB of zeros here>
// suffix
//
c.Assert(strings.Contains(string(sessionStream), "echo hi"), check.Equals, true)
c.Assert(strings.HasSuffix(string(sessionStream), "\nsuffix"), check.Equals, true)
c.Assert(strings.Contains(string(sessionStream), "\nsuffix"), check.Equals, true)
// now lets look at session events:
history, err := site.GetSessionEvents(session.ID, 0)
c.Assert(err, check.IsNil)
first := history[0]
beforeLast := history[len(history)-2]
last := history[len(history)-1]
getChunk := func(e events.EventFields) string {
offset := e.GetInt("offset")
@ -201,32 +201,41 @@ func (s *IntSuite) TestAudit(c *check.C) {
if length == 0 {
return ""
}
c.Assert(offset+length <= len(sessionStream), check.Equals, true)
return string(sessionStream[offset : offset+length])
}
// last two are manually-typed (32Kb chunk and "suffix"):
c.Assert(last.GetString(events.EventType), check.Equals, "print")
c.Assert(beforeLast.GetString(events.EventType), check.Equals, "print")
c.Assert(last.GetInt("bytes"), check.Equals, len("\nsuffix"))
c.Assert(beforeLast.GetInt("bytes"), check.Equals, len(bigChunk))
findByType := func(et string) events.EventFields {
for _, e := range history {
if e.GetType() == et {
return e
}
}
return nil
}
// 10th chunk should be printed "hi":
c.Assert(strings.HasPrefix(getChunk(history[10]), "hi"), check.Equals, true)
// there should alwys be 'session.start' event (and it must be first)
first := history[0]
start := findByType(events.SessionStartEvent)
c.Assert(start, check.DeepEquals, first)
c.Assert(start.GetInt("bytes"), check.Equals, 0)
c.Assert(start.GetString(events.SessionEventID) != "", check.Equals, true)
c.Assert(start.GetString(events.TerminalSize) != "", check.Equals, true)
// 1st should be "session.start"
c.Assert(first.GetString(events.EventType), check.Equals, events.SessionStartEvent)
// 3rd event is always "print suffix"
c.Assert(history[2].GetType(), check.Equals, events.SessionPrintEvent)
c.Assert(getChunk(history[2]), check.Equals, "\nsuffix")
// last-3 should be "session.end", and the one before - "session.leave"
endEvent := history[len(history)-3]
leaveEvent := history[len(history)-4]
c.Assert(endEvent.GetString(events.EventType), check.Equals, events.SessionEndEvent)
c.Assert(leaveEvent.GetString(events.EventType), check.Equals, events.SessionLeaveEvent)
// there should alwys be 'session.end' event
end := findByType(events.SessionEndEvent)
c.Assert(end, check.NotNil)
c.Assert(end.GetInt("bytes"), check.Equals, 0)
c.Assert(end.GetString(events.SessionEventID) != "", check.Equals, true)
// session events should have session ID assigned
c.Assert(first.GetString(events.SessionEventID) != "", check.Equals, true)
c.Assert(endEvent.GetString(events.SessionEventID) != "", check.Equals, true)
c.Assert(leaveEvent.GetString(events.SessionEventID) != "", check.Equals, true)
// there should alwys be 'session.leave' event
leave := findByType(events.SessionLeaveEvent)
c.Assert(leave, check.NotNil)
c.Assert(leave.GetInt("bytes"), check.Equals, 0)
c.Assert(leave.GetString(events.SessionEventID) != "", check.Equals, true)
// all of them should have a proper time:
for _, e := range history {

View file

@ -19,7 +19,6 @@ package auth
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
@ -36,8 +35,6 @@ import (
log "github.com/Sirupsen/logrus"
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
"golang.org/x/net/websocket"
)
// Config is APIServer config
@ -113,7 +110,7 @@ func NewAPIServer(a *AuthWithRoles) *APIServer {
srv.GET("/v1/sessions", httplib.MakeHandler(srv.getSessions))
srv.GET("/v1/sessions/:id", httplib.MakeHandler(srv.getSession))
srv.POST("/v1/sessions/:id/stream", httplib.MakeHandler(srv.postSessionChunk))
srv.GET("/v1/sessions/:id/reader", httplib.MakeHandler(srv.getSessionReader))
srv.GET("/v1/sessions/:id/stream", srv.getSessionChunk)
srv.GET("/v1/sessions/:id/events", httplib.MakeHandler(srv.getSessionEvents))
// OIDC stuff
@ -802,40 +799,46 @@ func (s *APIServer) postSessionChunk(w http.ResponseWriter, r *http.Request, p h
if err != nil {
return nil, trace.Wrap(err)
}
if err = s.a.PostSessionChunk(*sid, r.Body); err != nil {
return nil, trace.Wrap(err)
}
return message("ok"), nil
}
// HTTP GET /v1/sessions/:id/reader
func (s *APIServer) getSessionReader(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
// HTTP GET /v1/sessions/:id/stream?offset=x&bytes=y
// Query parameters:
// "offset" : bytes from the beginning
// "bytes" : number of bytes to read (it won't return more than 512Kb)
func (s *APIServer) getSessionChunk(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
sid, err := session.ParseID(p.ByName("id"))
if err != nil {
return nil, trace.Wrap(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// parse "offset bytes"
offsetBytes, err := strconv.Atoi(r.URL.Query().Get("from"))
// "offset bytes" query param
offsetBytes, err := strconv.Atoi(r.URL.Query().Get("offset"))
if err != nil || offsetBytes < 0 {
offsetBytes = 0
}
log.Infof("[AUTH] api.getSessionReader(%v, %v)", *sid, offsetBytes)
reader, err := s.a.GetSessionReader(*sid, offsetBytes)
// "max bytes" query param
max, err := strconv.Atoi(r.URL.Query().Get("bytes"))
if err != nil || offsetBytes < 0 {
offsetBytes = 0
}
log.Infof("----> apiserver.GetSessionChunk(%v, offset=%d)", *sid, offsetBytes)
w.Header().Set("Content-Type", "text/plain")
buffer, err := s.a.GetSessionChunk(*sid, offsetBytes, max)
if err != nil {
return nil, trace.Wrap(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer reader.Close()
ws := websocket.Server{
Handler: func(conn *websocket.Conn) {
log.Info("[AUTH] session streaming websocket open")
// set websocket to 64K read/writes
buffer := make([]byte, 1024*64)
read, _ := io.CopyBuffer(conn, reader, buffer)
log.Infof("[AUTH] session streaming websocket closed: %v bytes streamed", read)
},
if _, err = w.Write(buffer); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
ws.ServeHTTP(w, r)
return nil, nil
w.Header().Set("Content-Type", "application/octet-stream")
}
// HTTP GET /v1/sessions/:id/events?maxage=n

View file

@ -21,6 +21,7 @@ import (
"net/url"
"time"
"github.com/Sirupsen/logrus"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/services"
@ -368,11 +369,12 @@ func (a *AuthWithRoles) PostSessionChunk(sid session.ID, reader io.Reader) error
return a.alog.PostSessionChunk(sid, reader)
}
func (a *AuthWithRoles) GetSessionReader(sid session.ID, offsetBytes int) (io.ReadCloser, error) {
func (a *AuthWithRoles) GetSessionChunk(sid session.ID, offsetBytes, maxBytes int) ([]byte, error) {
logrus.Infof("----> authWithRoles.GetSessionChunk(%v, offset=%d)", sid, offsetBytes)
if err := a.permChecker.HasPermission(a.role, ActionViewSession); err != nil {
return nil, trace.Wrap(err)
}
return a.alog.GetSessionReader(sid, offsetBytes)
return a.alog.GetSessionChunk(sid, offsetBytes, maxBytes)
}
func (a *AuthWithRoles) GetSessionEvents(sid session.ID, afterN int) ([]events.EventFields, error) {

View file

@ -738,24 +738,35 @@ func (c *Client) EmitAuditEvent(eventType string, fields events.EventFields) err
// The data is POSTed to HTTP server as a simple binary body (no encodings of any
// kind are needed)
func (c *Client) PostSessionChunk(sid session.ID, reader io.Reader) error {
logrus.Infof("authClient.PostSessionChunk(%v)", sid)
logrus.Infof("----> authClient.PostSessionChunk(%v)", sid)
request, err := http.NewRequest("POST",
c.Endpoint("sessions", string(sid), "stream"),
reader)
request.Header.Set("Content-Type", "application/octet-stream")
_, err = c.Client.HTTPClient().Do(request)
resp, err := c.Client.HTTPClient().Do(request)
if err != nil {
return trace.Wrap(err)
}
defer resp.Body.Close()
return nil
}
// GetSessionReader allows clients to recewive a live stream of an active session
func (c *Client) GetSessionReader(sid session.ID, offsetBytes int) (io.ReadCloser, error) {
return c.openWebsocket(c.Endpoint("sessions", string(sid), "reader") +
fmt.Sprintf("?from=%d", offsetBytes))
// GetSessionChunk allows clients to receive a byte array (chunk) from a recorded
// session stream, starting from 'offset', up to 'max' in length. The upper bound
// of 'max' is set to events.MaxChunkBytes
func (c *Client) GetSessionChunk(sid session.ID, offsetBytes, maxBytes int) ([]byte, error) {
logrus.Infof("----> authClient.GetSessionChunk(from=%d, max=%d)", offsetBytes, maxBytes)
response, err := c.Get(c.Endpoint("sessions", string(sid), "stream"), url.Values{
"offset": []string{strconv.Itoa(offsetBytes)},
"bytes": []string{strconv.Itoa(maxBytes)},
})
if err != nil {
logrus.Error(err)
return nil, trace.Wrap(err)
}
return response.Bytes(), nil
}
// Returns events that happen during a session sorted by time

View file

@ -549,6 +549,8 @@ func NewTunClient(purpose string,
for _, o := range opts {
o(tc)
}
log.Infof("newTunClient(%s)", purpose)
clt, err := NewClient("http://stub:0", tc.Dial)
if err != nil {
return nil, trace.Wrap(err)
@ -621,7 +623,9 @@ func (c *TunClient) Dial(network, address string) (net.Conn, error) {
}
// dialed & authenticated? lets start synchronizing the
// list of auth servers:
go c.authServersSyncLoop()
if c.refreshTicker == nil {
go c.authServersSyncLoop()
}
return &tunConn{client: client, Conn: conn}, nil
}
@ -647,7 +651,6 @@ func (c *TunClient) fetchAndSync() error {
// authServersSyncLoop continuously refreshes the list of available auth servers
// for this client
func (c *TunClient) authServersSyncLoop() {
log.Infof("TunClient[%s]: authServersSyncLoop() started", c.purpose)
alreadyRunning := func() bool {
c.Lock()
defer c.Unlock()
@ -661,6 +664,7 @@ func (c *TunClient) authServersSyncLoop() {
if alreadyRunning() {
return
}
log.Infof("TunClient[%s]: authServersSyncLoop() started", c.purpose)
defer c.refreshTicker.Stop()
// initial fetch for quick start-ups

View file

@ -37,6 +37,7 @@ import (
"github.com/gravitational/teleport/lib/auth/native"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/utils"
@ -382,26 +383,6 @@ func (tc *TeleportClient) Join(sessionID session.ID, input io.Reader) (err error
return tc.runShell(nc, session.ID, input)
}
// readAll is similarl to ioutil.ReadAll, except it doesn't use ever-increasing
// internal buffer, instead asking for the exact buffer size.
//
// we need this for websockets: they can't deal with huge Reads
// (set bufsize to 32K)
func readAll(r io.Reader, bufsize int) (out []byte, err error) {
buff := make([]byte, bufsize)
n := 0
for err == nil {
n, err = r.Read(buff)
if n > 0 {
out = append(out, buff[:n]...)
}
}
if err == io.EOF {
err = nil
}
return out, err
}
// SCP securely copies file(s) from one SSH server to another
func (tc *TeleportClient) Play(sessionId string) (err error) {
sid, err := session.ParseID(sessionId)
@ -424,14 +405,17 @@ func (tc *TeleportClient) Play(sessionId string) (err error) {
}
// read the stream into a buffer:
reader, err := site.GetSessionReader(*sid, 0)
if err != nil {
return trace.Wrap(err)
}
defer reader.Close()
stream, err := readAll(reader, 1024*32) //ioutil.ReadAll(reader)
if err != nil {
return trace.Wrap(err)
var stream []byte
for err == nil {
tmp, err := site.GetSessionChunk(*sid, len(stream), events.MaxChunkBytes)
if err != nil {
return trace.Wrap(err)
}
if len(tmp) == 0 {
err = io.EOF
break
}
stream = append(stream, tmp...)
}
// configure terminal for direct unbuffered echo-less input:

View file

@ -91,6 +91,12 @@ const (
TerminalSize = "size" // expressed as 'W:H'
)
const (
// MaxChunkBytes defines the maximum size of a session stream chunk that
// can be requested via AuditLog.GetSessionChunk(). Set to 5MB
MaxChunkBytes = 1024 * 1024 * 5
)
// AuditLogI is the primary (and the only external-facing) interface for AUditLogger.
// If you wish to implement a different kind of logger (not filesystem-based), you
// have to implement this interface
@ -101,10 +107,12 @@ type AuditLogI interface {
// their live sessions into the session log
PostSessionChunk(sid session.ID, reader io.Reader) error
// GetSessionReader returns a reader which can be used to read a byte stream
// GetSessionChunk returns a reader which can be used to read a byte stream
// of a recorded session starting from 'offsetBytes' (pass 0 to start from the
// beginning)
GetSessionReader(sid session.ID, offsetBytes int) (io.ReadCloser, error)
// beginning) up to maxBytes bytes.
//
// If maxBytes > MaxChunkBytes, it gets rounded down to MaxChunkBytes
GetSessionChunk(sid session.ID, offsetBytes, maxBytes int) ([]byte, error)
// Returns all events that happen during a session sorted by time
// (oldest first).
@ -138,6 +146,11 @@ func (f EventFields) AsString() string {
f.GetInt(SessionPrintEventBytes))
}
// GetType returns the type (string) of the event
func (f EventFields) GetType() string {
return f.GetString(EventType)
}
// GetString returns a string representation of a logged field
func (f EventFields) GetString(key string) string {
val, found := f[key]

View file

@ -50,6 +50,7 @@ package events
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
@ -63,6 +64,7 @@ import (
"github.com/Sirupsen/logrus"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
)
@ -242,12 +244,16 @@ func NewAuditLog(dataDir string, testMode bool) (*AuditLog, error) {
// PostSessionChunk writes a new chunk of session stream into the audit log
func (l *AuditLog) PostSessionChunk(sid session.ID, reader io.Reader) error {
//buffer, err := utils.ReadAll(reader, 1024*32)
//logrus.Infof("----> auditLog.OnPostSessionChunk() got %d bytes,err=%v", len(buffer), err)
sl := l.LoggerFor(sid)
if sl == nil {
logrus.Warnf("audit.log: no session writer for %s", sid)
return nil
}
written, err := io.CopyBuffer(sl, reader, make([]byte, 8*1024))
tmp, err := utils.ReadAll(reader, 8*1024)
written, err := sl.Write(tmp)
if err != nil {
logrus.Error(err)
return trace.Wrap(err)
@ -256,18 +262,28 @@ func (l *AuditLog) PostSessionChunk(sid session.ID, reader io.Reader) error {
return nil
}
// GetSessionReader returns a reader which console and web clients request
// to receive a live stream of a given session
func (l *AuditLog) GetSessionReader(sid session.ID, offsetBytes int) (io.ReadCloser, error) {
// GetSessionChunk returns a reader which console and web clients request
// to receive a live stream of a given session. The reader allows access to a
// session stream range from offsetBytes to offsetBytes+maxBytes
//
func (l *AuditLog) GetSessionChunk(sid session.ID, offsetBytes, maxBytes int) ([]byte, error) {
logrus.Infof("audit.log: getSessionReader(%v)", sid)
fstream, err := os.OpenFile(l.sessionStreamFn(sid), os.O_RDONLY, 0640)
if err != nil {
logrus.Warning(err)
return nil, trace.Wrap(err)
}
defer fstream.Close()
// seek to 'offset'
const fromBeginning = 0
fstream.Seek(int64(offsetBytes), fromBeginning)
return fstream, nil
// copy up to maxBytes from the offset position:
var buff bytes.Buffer
io.Copy(&buff, io.LimitReader(fstream, int64(maxBytes)))
return buff.Bytes(), nil
}
// Returns all events that happen during a session sorted by time

View file

@ -1,6 +1,9 @@
package utils
import "os"
import (
"io"
"os"
)
// IsDir is a helper function to quickly check if a given path is a valid directory
func IsDir(dirPath string) bool {
@ -10,3 +13,22 @@ func IsDir(dirPath string) bool {
}
return false
}
// ReadAll is similarl to ioutil.ReadAll, except it doesn't use ever-increasing
// internal buffer, instead asking for the exact buffer size.
//
// We need this for websockets: they can't deal with huge Reads > 32K
func ReadAll(r io.Reader, bufsize int) (out []byte, err error) {
buff := make([]byte, bufsize)
n := 0
for err == nil {
n, err = r.Read(buff)
if n > 0 {
out = append(out, buff[:n]...)
}
}
if err == io.EOF {
err = nil
}
return out, err
}

View file

@ -19,12 +19,10 @@ limitations under the License.
package web
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"html/template"
"io"
"io/ioutil"
"net/http"
"net/url"
@ -999,22 +997,13 @@ func (m *Handler) siteSessionStreamGet(w http.ResponseWriter, r *http.Request, p
return nil, trace.BadParameter("bytes", "bytes=%d, cannot exceed %d", max, maxStreamBytes)
}
// read file:
reader, err := clt.GetSessionReader(*sid, offset)
if err != nil {
// return empty buffer if no file found
return siteSessionStreamGetResponse{Bytes: []byte{}}, nil
}
defer reader.Close()
var buff bytes.Buffer
written, err := io.CopyN(&buff, reader, int64(max))
bytes, err := clt.GetSessionChunk(*sid, offset, max)
if err != nil {
log.Error(err)
return nil, trace.Wrap(err)
}
log.Infof("[web] siteSessionStreamGet() returned %d/%d bytes", len(buff.Bytes()), written)
return siteSessionStreamGetResponse{Bytes: buff.Bytes()}, nil
log.Infof("----> [web] siteSessionStreamGet() returned %d bytes", len(bytes))
return siteSessionStreamGetResponse{Bytes: bytes}, nil
}
type eventsListGetResponse struct {