Fix memory leak on Kubernetes port-forwarding (#24741)

This PR fixes a memory leak in Kubernetes access when using the SPDY protocol - the protocol used by `kubectl`.

The memory leak happens when new connections are established using SPDY's HTTP streams. Each time a connection is created locally, `kubectl` opens two streams for it - one for data and another for returning errors. When the multiplexed streams finished, they were not properly cleaned up from the SPDY connection and, although closed, their memory stayed alive and reachable for the duration of the long-lived SPDY connection. This leaks memory and results in OOM events for the Proxy and Kubernetes services when a large number of connections are established within the same port-forwarding session.
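
For reference, the cleanup pattern this change introduces looks roughly like the sketch below. It is a minimal illustration rather than the exact patch: the `forwardOnce` helper is hypothetical, while `RemoveStreams` is the method the fix relies on from `k8s.io/apimachinery/pkg/util/httpstream`.

```go
package sketch

import (
	"io"
	"net/http"

	"k8s.io/apimachinery/pkg/util/httpstream"
)

// forwardOnce (hypothetical helper) opens a stream on the target SPDY
// connection, copies data through it, and removes the stream from the
// connection before closing it. Without RemoveStreams, the connection keeps
// a reference to every finished stream for its whole lifetime, which is the
// leak described above.
func forwardOnce(conn httpstream.Connection, headers http.Header, src io.Reader) error {
	stream, err := conn.CreateStream(headers)
	if err != nil {
		return err
	}
	defer func() {
		// Drop the connection's reference to the finished stream so its
		// buffers become unreachable, then close the stream itself.
		conn.RemoveStreams(stream)
		stream.Close()
	}()

	_, err = io.Copy(stream, src)
	return err
}
```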

Fixes #10966
Commit a05165dd64 (parent bebef91be8) authored by Tiago Silva on 2023-04-18 21:09:53 +01:00, committed by GitHub.
3 changed files with 286 additions and 81 deletions


@@ -19,10 +19,9 @@ package proxy
import (
"context"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
@@ -33,6 +32,7 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/utils"
)
// portForwardRequest is a request that specifies port forwarding
@@ -165,86 +165,68 @@ func (h *portForwardProxy) Close() error {
return nil
}
// forwardStreamPair creates new data and error streams using the same requestID
// received from the client and copies data between the target's and the client's
// data and error streams. It blocks until all copy operations complete.
// It does not close the client's data and error streams, as they are closed by
// the caller.
func (h *portForwardProxy) forwardStreamPair(p *httpStreamPair, remotePort int64) error {
// create error stream
headers := http.Header{}
port := fmt.Sprintf("%d", remotePort)
headers.Set(StreamType, StreamTypeError)
headers.Set(PortHeader, fmt.Sprintf("%d", remotePort))
headers.Set(PortHeader, port)
headers.Set(PortForwardRequestIDHeader, p.requestID)
// read and write from the error stream
targetErrorStream, err := h.targetConn.CreateStream(headers)
h.onPortForward(net.JoinHostPort(h.podName, port), err == nil /* success */)
if err != nil {
h.onPortForward(fmt.Sprintf("%v:%v", h.podName, remotePort), false)
return trace.ConnectionProblem(err, "error creating error stream for port %d", remotePort)
err := trace.ConnectionProblem(err, "error creating error stream for port %d", remotePort)
p.sendErr(err)
return err
}
h.onPortForward(fmt.Sprintf("%v:%v", h.podName, remotePort), true)
defer targetErrorStream.Close()
go func() {
_, err := io.Copy(targetErrorStream, p.errorStream)
if err != nil && err != io.EOF {
h.Debugf("Copy stream error: %v.", err)
}
defer func() {
// on stream close, remove the stream from the connection and close it.
h.targetConn.RemoveStreams(targetErrorStream)
targetErrorStream.Close()
}()
errClose := make(chan struct{})
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
defer close(errClose)
_, err := io.Copy(p.errorStream, targetErrorStream)
if err != nil && err != io.EOF {
h.Debugf("Copy stream error: %v.", err)
defer wg.Done()
if err := utils.ProxyConn(h.context, p.errorStream, targetErrorStream); err != nil {
h.WithError(err).Debugf("Unable to proxy portforward error-stream.")
}
}()
// create data stream
headers.Set(StreamType, StreamTypeData)
dataStream, err := h.targetConn.CreateStream(headers)
targetDataStream, err := h.targetConn.CreateStream(headers)
if err != nil {
return trace.ConnectionProblem(err, "error creating forwarding stream for port -> %d: %v", remotePort, err)
err := trace.ConnectionProblem(err, "error creating forwarding stream for port -> %d: %v", remotePort, err)
p.sendErr(err)
return err
}
defer dataStream.Close()
localError := make(chan struct{})
remoteDone := make(chan struct{})
go func() {
// inform the select below that the remote copy is done
defer close(remoteDone)
// Copy from the remote side to the local port.
if _, err := io.Copy(p.dataStream, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
log.Error(fmt.Errorf("error copying from remote stream to local connection: %v", err))
}
defer func() {
// on stream close, remove the stream from the connection and close it.
h.targetConn.RemoveStreams(targetDataStream)
targetDataStream.Close()
}()
wg.Add(1)
go func() {
// inform server we're not sending any more data after copy unblocks
defer dataStream.Close()
// Copy from the local port to the target side.
if _, err := io.Copy(dataStream, p.dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
h.Warningf("Error copying from local connection to remote stream: %v.", err)
// break out of the select below without waiting for the other copy to finish
close(localError)
defer wg.Done()
if err := utils.ProxyConn(h.context, p.dataStream, targetDataStream); err != nil {
h.WithError(err).Debugf("Unable to proxy portforward data-stream.")
}
}()
h.Debugf("Streams have been created, Waiting for copy to complete.")
// wait for either a local->remote error or for copying from remote->local to finish
select {
case <-remoteDone:
case <-localError:
case <-h.context.Done():
h.Debugf("Context is closing, cleaning up.")
}
// always expect something on errorChan (it may be nil)
select {
case <-errClose:
case <-h.context.Done():
h.Debugf("Context is closing, cleaning up.")
}
// wait for the copies to complete before returning.
wg.Wait()
h.Debugf("Port forwarding pair completed.")
return nil
}
@@ -272,9 +254,11 @@ func (h *portForwardProxy) getStreamPair(requestID string) (*httpStreamPair, bool)
// monitorStreamPair waits for the pair to receive both its error and data
// streams, or for the timeout to expire (whichever happens first), and then
// removes the pair.
func (h *portForwardProxy) monitorStreamPair(p *httpStreamPair, timeout <-chan time.Time) {
func (h *portForwardProxy) monitorStreamPair(p *httpStreamPair) {
timeC := time.NewTimer(h.streamCreationTimeout)
defer timeC.Stop()
select {
case <-timeout:
case <-timeC.C:
h.Errorf("Request %s, timed out waiting for streams.", p.requestID)
case <-p.complete:
h.Debugf("Request %s, successfully received error and data streams.", p.requestID)
@@ -286,7 +270,14 @@ func (h *portForwardProxy) monitorStreamPair(p *httpStreamPair, timeout <-chan time.Time)
func (h *portForwardProxy) removeStreamPair(requestID string) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()
pair, ok := h.streamPairs[requestID]
if !ok {
return
}
if h.sourceConn != nil {
// remove the streams from the connection and close them.
h.sourceConn.RemoveStreams(pair.dataStream, pair.errorStream)
}
delete(h.streamPairs, requestID)
}
@@ -323,11 +314,11 @@ func (h *portForwardProxy) run() {
p, created := h.getStreamPair(requestID)
if created {
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
go h.monitorStreamPair(p)
}
if complete, err := p.add(stream); err != nil {
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
p.printError(msg)
err := trace.BadParameter("error processing stream for request %s: %v", requestID, err)
p.sendErr(err)
} else if complete {
go h.portForward(p)
}
@@ -335,29 +326,27 @@ func (h *portForwardProxy) run() {
}
}
// portForward invokes the portForwardProxy's forwarder.PortForward
// function for the given stream pair.
// portForward handles the port-forwarding for the given stream pair.
// It closes the pair when it is done.
func (h *portForwardProxy) portForward(p *httpStreamPair) {
defer p.dataStream.Close()
defer p.errorStream.Close()
defer p.close()
portString := p.dataStream.Headers().Get(PortHeader)
port, _ := strconv.ParseInt(portString, 10, 32)
h.Debugf("Forwarding port %v -> %v.", p.requestID, portString)
err := h.forwardStreamPair(p, port)
h.Debugf("Completed forwarding port %v -> %v.", p.requestID, portString)
if err != nil {
msg := fmt.Errorf("error forwarding port %d to pod %s: %v", port, h.podName, err)
fmt.Fprint(p.errorStream, msg.Error())
if err := h.forwardStreamPair(p, port); err != nil {
h.WithError(err).Debugf("Error forwarding port %v -> %v.", p.requestID, portString)
return
}
h.Debugf("Completed forwarding port %v -> %v.", p.requestID, portString)
}
// httpStreamPair represents the error and data streams for a port
// forwarding request.
type httpStreamPair struct {
lock sync.RWMutex
lock sync.Mutex
requestID string
dataStream httpstream.Stream
errorStream httpstream.Stream
@@ -400,11 +389,26 @@ func (p *httpStreamPair) add(stream httpstream.Stream) (bool, error) {
return complete, nil
}
// printError writes s to p.errorStream if p.errorStream has been set.
func (p *httpStreamPair) printError(s string) {
p.lock.RLock()
defer p.lock.RUnlock()
// sendErr writes err to p.errorStream if p.errorStream has been set.
func (p *httpStreamPair) sendErr(err error) {
if err == nil {
return
}
p.lock.Lock()
defer p.lock.Unlock()
if p.errorStream != nil {
fmt.Fprint(p.errorStream, s)
fmt.Fprint(p.errorStream, err.Error())
}
}
// close closes the data and error streams for this pair.
func (p *httpStreamPair) close() {
p.lock.Lock()
defer p.lock.Unlock()
if p.dataStream != nil {
p.dataStream.Close()
}
if p.errorStream != nil {
p.errorStream.Close()
}
}


@@ -19,14 +19,20 @@ package proxy
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"sync"
"testing"
"time"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/portforward"
"k8s.io/client-go/transport/spdy"
@@ -213,3 +219,190 @@ type portForwarder interface {
ForwardPorts() error
Close()
}
// TestPortForwardProxy_run_connsClosed tests that the port forward proxy cleans up the
// spdy stream when it is closed. This is important because the spdy connection
// holds a reference to the stream and if the stream is not removed from the
// connection, it will leak memory.
func TestPortForwardProxy_run_connsClosed(t *testing.T) {
t.Parallel()
logger := log.NewEntry(&log.Logger{Out: io.Discard})
const (
reqID = "reqID"
// portHeaderValue is the value of the port header in the stream.
// This value is not used to listen for requests, but it is used to identify the stream
// destination.
portHeaderValue = "8080"
)
sourceConn := newfakeSPDYConnection()
targetConn := newfakeSPDYConnection()
h := &portForwardProxy{
portForwardRequest: portForwardRequest{
context: context.Background(),
onPortForward: func(addr string, success bool) {},
},
Entry: logger,
sourceConn: sourceConn,
targetConn: targetConn,
streamChan: make(chan httpstream.Stream),
streamPairs: map[string]*httpStreamPair{},
streamCreationTimeout: 5 * time.Second,
}
go func() {
dataStream, err := sourceConn.CreateStream(http.Header{
PortForwardRequestIDHeader: []string{reqID},
PortHeader: []string{portHeaderValue},
StreamType: []string{StreamTypeError},
})
assert.NoError(t, err)
h.streamChan <- dataStream
errStream, err := sourceConn.CreateStream(http.Header{
PortForwardRequestIDHeader: []string{reqID},
PortHeader: []string{portHeaderValue},
StreamType: []string{StreamTypeData},
})
assert.NoError(t, err)
h.streamChan <- errStream
// Close the source after the streams are processed to unblock the call.
sourceConn.Close()
}()
// run the port forward proxy. it will read the h.streamChan and
// process the streams synchronously. Once the streams are processed,
// the sourceConn will be closed and the proxy will exit the run loop.
h.run()
// targetConn is closed once all streams are removed. It is a hack to
// unblock the targetConn.waitForClose() call; otherwise it will block
// forever.
require.Eventually(t, func() bool {
select {
case <-targetConn.closed:
return true
default:
return false
}
}, 5*time.Second, 100*time.Millisecond, "streams weren't properly removed from targetConn")
require.True(t, sourceConn.streamsClosed(), "sourceConn streams not closed")
require.True(t, targetConn.streamsClosed(), "targetConn streams not closed")
}
type fakeSPDYStream struct {
closed bool
headers http.Header
identifier uint32
mu sync.Mutex
}
func (f *fakeSPDYStream) Read(p []byte) (n int, err error) {
return 0, io.EOF
}
func (f *fakeSPDYStream) Write(p []byte) (n int, err error) {
return len(p), nil
}
func (f *fakeSPDYStream) Headers() http.Header {
return f.headers
}
func (f *fakeSPDYStream) Reset() error {
return nil
}
func (f *fakeSPDYStream) Identifier() uint32 {
return f.identifier
}
func (f *fakeSPDYStream) Close() error {
f.mu.Lock()
defer f.mu.Unlock()
f.closed = true
return nil
}
func (f *fakeSPDYStream) isClosed() bool {
f.mu.Lock()
defer f.mu.Unlock()
return f.closed
}
type fakeSPDYConnection struct {
count int
streams map[uint32]*fakeSPDYStream
streamsSlice []*fakeSPDYStream
closed chan bool
closedOnce sync.Once
mu sync.Mutex
}
func newfakeSPDYConnection() *fakeSPDYConnection {
return &fakeSPDYConnection{
streams: make(map[uint32]*fakeSPDYStream),
closed: make(chan bool),
}
}
// CreateStream creates a new Stream with the supplied headers.
func (f *fakeSPDYConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
f.mu.Lock()
defer f.mu.Unlock()
newHeader := http.Header{}
for k, v := range headers {
newHeader.Set(k, v[0])
}
f.count++
identifier := uint32(f.count)
stream := &fakeSPDYStream{identifier: identifier, headers: newHeader}
f.streamsSlice = append(f.streamsSlice, stream)
f.streams[identifier] = stream
return stream, nil
}
// Close resets all streams and closes the connection.
func (f *fakeSPDYConnection) Close() error {
f.closedOnce.Do(func() {
close(f.closed)
})
return nil
}
// CloseChan returns a channel that is closed when the underlying connection is closed.
func (f *fakeSPDYConnection) CloseChan() <-chan bool {
return f.closed
}
// SetIdleTimeout sets the amount of time the connection may remain idle before
// it is automatically closed.
func (f *fakeSPDYConnection) SetIdleTimeout(_ time.Duration) {}
// RemoveStreams can be used to remove a set of streams from the Connection.
func (f *fakeSPDYConnection) RemoveStreams(streams ...httpstream.Stream) {
f.mu.Lock()
defer f.mu.Unlock()
for _, stream := range streams {
if stream == nil {
continue
}
delete(f.streams, stream.Identifier())
}
// if there are no streams left, close the connection so the test can exit
if len(f.streams) == 0 {
f.Close()
}
}
func (f *fakeSPDYConnection) streamsClosed() bool {
f.mu.Lock()
defer f.mu.Unlock()
if len(f.streams) != 0 {
return false
}
for _, stream := range f.streamsSlice {
if !stream.isClosed() {
return false
}
}
return true
}


@@ -248,10 +248,14 @@ func (h *websocketPortforwardHandler) forwardStreamPair(p *websocketChannelPair)
h.onPortForward(fmt.Sprintf("%v:%v", h.podName, p.port), err == nil /* success */)
if err != nil {
p.sendErr(err)
p.close()
return
}
defer targetErrorStream.Close()
defer func() {
// on stream close, remove the stream from the connection and close it.
h.targetConn.RemoveStreams(targetErrorStream)
targetErrorStream.Close()
}()
wg := &sync.WaitGroup{}
wg.Add(1)
@@ -264,19 +268,23 @@ func (h *websocketPortforwardHandler) forwardStreamPair(p *websocketChannelPair)
// create data stream
headers.Set(StreamType, StreamTypeData)
dataStream, err := h.targetConn.CreateStream(headers)
targetDataStream, err := h.targetConn.CreateStream(headers)
if err != nil {
p.sendErr(err)
p.close()
wg.Wait()
return
}
defer dataStream.Close()
defer func() {
// on stream close, remove the stream from the connection and close it.
h.targetConn.RemoveStreams(targetDataStream)
targetDataStream.Close()
}()
wg.Add(1)
go func() {
defer wg.Done()
if err := utils.ProxyConn(h.context, p.dataStream, dataStream); err != nil {
if err := utils.ProxyConn(h.context, p.dataStream, targetDataStream); err != nil {
h.WithError(err).Debugf("Unable to proxy portforward data-stream.")
}
}()