mirror of
https://github.com/gravitational/teleport
synced 2024-10-19 08:43:58 +00:00
Tweak the PNG encoder (#9817)
Instead of using the default, allow for providing a custom png.Encoder. Set BestSpeed for compression and use a pool in order to reuse buffers. Also add some sample PNG frames and a benchmark for encoding them. This results in ~ 70% reduction in encoding time and > 99% reduction in memory allocations. name old time/op new time/op EncodePNG-8 107µs ± 0% 30µs ± 0% name old alloc/op new alloc/op EncodePNG-8 850kB ± 0% 1kB ± 0% name old allocs/op new allocs/op EncodePNG-8 42.0 ± 0% 13.0 ± 0% Updates #8742
This commit is contained in:
parent
46bf623c51
commit
74e580ab5f
|
@ -160,7 +160,7 @@ func New(ctx context.Context, cfg Config) (*Client, error) {
|
|||
|
||||
func (c *Client) readClientUsername() error {
|
||||
for {
|
||||
msg, err := c.cfg.InputMessage()
|
||||
msg, err := c.cfg.Conn.InputMessage()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
@ -177,7 +177,7 @@ func (c *Client) readClientUsername() error {
|
|||
|
||||
func (c *Client) readClientSize() error {
|
||||
for {
|
||||
msg, err := c.cfg.InputMessage()
|
||||
msg, err := c.cfg.Conn.InputMessage()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
@ -255,7 +255,7 @@ func (c *Client) start() {
|
|||
// Remember mouse coordinates to send them with all CGOPointer events.
|
||||
var mouseX, mouseY uint32
|
||||
for {
|
||||
msg, err := c.cfg.InputMessage()
|
||||
msg, err := c.cfg.Conn.InputMessage()
|
||||
if err != nil {
|
||||
c.cfg.Log.Warningf("Failed reading RDP input message: %v", err)
|
||||
return
|
||||
|
@ -383,7 +383,7 @@ func (c *Client) handleBitmap(cb C.CGOBitmap) C.CGOError {
|
|||
})
|
||||
copy(img.Pix, data)
|
||||
|
||||
if err := c.cfg.OutputMessage(tdp.PNGFrame{Img: img}); err != nil {
|
||||
if err := c.cfg.Conn.OutputMessage(tdp.NewPNG(img, c.cfg.Encoder)); err != nil {
|
||||
return C.CString(fmt.Sprintf("failed to send PNG frame %v: %v", img.Rect, err))
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -19,6 +19,7 @@ package rdpclient
|
|||
|
||||
import (
|
||||
"context"
|
||||
"image/png"
|
||||
|
||||
"github.com/gravitational/teleport/lib/srv/desktop/tdp"
|
||||
"github.com/gravitational/trace"
|
||||
|
@ -35,13 +36,12 @@ type Config struct {
|
|||
// AuthorizeFn is called to authorize a user connecting to a Windows desktop.
|
||||
AuthorizeFn func(login string) error
|
||||
|
||||
// TODO(zmb3): replace these callbacks with a tdp.Conn
|
||||
// Conn handles TDP messages between Windows Desktop Service
|
||||
// and a Teleport Proxy.
|
||||
Conn *tdp.Conn
|
||||
|
||||
// InputMessage is called to receive a message from the client for the RDP
|
||||
// server. This function should block until there is a message.
|
||||
InputMessage func() (tdp.Message, error)
|
||||
// OutputMessage is called to send a message from RDP server to the client.
|
||||
OutputMessage func(tdp.Message) error
|
||||
// Encoder is an optional override for PNG encoding.
|
||||
Encoder *png.Encoder
|
||||
|
||||
// Log is the logger for status messages.
|
||||
Log logrus.FieldLogger
|
||||
|
@ -58,15 +58,15 @@ func (c *Config) checkAndSetDefaults() error {
|
|||
if c.GenerateUserCert == nil {
|
||||
return trace.BadParameter("missing GenerateUserCert in rdpclient.Config")
|
||||
}
|
||||
if c.InputMessage == nil {
|
||||
return trace.BadParameter("missing InputMessage in rdpclient.Config")
|
||||
}
|
||||
if c.OutputMessage == nil {
|
||||
return trace.BadParameter("missing OutputMessage in rdpclient.Config")
|
||||
if c.Conn == nil {
|
||||
return trace.BadParameter("missing Conn in rdpclient.Config")
|
||||
}
|
||||
if c.AuthorizeFn == nil {
|
||||
return trace.BadParameter("missing AuthorizeFn in rdpclient.Config")
|
||||
}
|
||||
if c.Encoder == nil {
|
||||
c.Encoder = tdp.PNGEncoder()
|
||||
}
|
||||
if c.Log == nil {
|
||||
c.Log = logrus.New()
|
||||
}
|
||||
|
|
38
lib/srv/desktop/tdp/png.go
Normal file
38
lib/srv/desktop/tdp/png.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
// Copyright 2022 Gravitational, Inc
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tdp
|
||||
|
||||
import "image/png"
|
||||
|
||||
// PNGEncoder returns the encoder used for PNG Frames.
|
||||
// It is not safe for concurrent use.
|
||||
func PNGEncoder() *png.Encoder {
|
||||
return &png.Encoder{
|
||||
CompressionLevel: png.BestSpeed,
|
||||
BufferPool: &pool{},
|
||||
}
|
||||
}
|
||||
|
||||
// pool implements png.EncoderBufferPool,
|
||||
// allowing us to reuse encoding resources
|
||||
type pool struct {
|
||||
b *png.EncoderBuffer
|
||||
}
|
||||
|
||||
// all encoding happens in a single thread, so we don't
|
||||
// need anything as sophisticated as a sync.Pool here
|
||||
|
||||
func (p *pool) Get() *png.EncoderBuffer { return p.b }
|
||||
func (p *pool) Put(eb *png.EncoderBuffer) { p.b = eb }
|
|
@ -104,6 +104,15 @@ func decode(in peekReader) (Message, error) {
|
|||
// https://github.com/gravitational/teleport/blob/master/rfd/0037-desktop-access-protocol.md#2---png-frame
|
||||
type PNGFrame struct {
|
||||
Img image.Image
|
||||
|
||||
enc *png.Encoder // optionally override the PNG encoder
|
||||
}
|
||||
|
||||
func NewPNG(img image.Image, enc *png.Encoder) PNGFrame {
|
||||
return PNGFrame{
|
||||
Img: img,
|
||||
enc: enc,
|
||||
}
|
||||
}
|
||||
|
||||
func (f PNGFrame) Encode() ([]byte, error) {
|
||||
|
@ -123,10 +132,11 @@ func (f PNGFrame) Encode() ([]byte, error) {
|
|||
}); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
// Note: this uses the default png.Encoder parameters.
|
||||
// You can tweak compression level and reduce memory allocations by using a
|
||||
// custom png.Encoder, if this happens to be a bottleneck.
|
||||
if err := png.Encode(buf, f.Img); err != nil {
|
||||
encoder := f.enc
|
||||
if encoder == nil {
|
||||
encoder = &png.Encoder{}
|
||||
}
|
||||
if err := encoder.Encode(buf, f.Img); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
|
|
|
@ -17,11 +17,16 @@ limitations under the License.
|
|||
package tdp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"image"
|
||||
"image/color"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -51,7 +56,7 @@ func TestEncodeDecode(t *testing.T) {
|
|||
out, err := Decode(buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Empty(t, cmp.Diff(m, out))
|
||||
require.Empty(t, cmp.Diff(m, out, cmpopts.IgnoreUnexported(PNGFrame{})))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -60,3 +65,49 @@ func TestBadDecode(t *testing.T) {
|
|||
_, err := Decode([]byte{254})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
var encodedFrame []byte
|
||||
|
||||
func BenchmarkEncodePNG(b *testing.B) {
|
||||
b.StopTimer()
|
||||
frames := loadBitmaps(b)
|
||||
b.StartTimer()
|
||||
var err error
|
||||
for i := 0; i < b.N; i++ {
|
||||
fi := i % len(frames)
|
||||
encodedFrame, err = frames[fi].Encode()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func loadBitmaps(b *testing.B) []PNGFrame {
|
||||
b.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("testdata", "png_frames.json"))
|
||||
require.NoError(b, err)
|
||||
defer f.Close()
|
||||
|
||||
enc := PNGEncoder()
|
||||
|
||||
var result []PNGFrame
|
||||
type record struct {
|
||||
Top, Left, Right, Bottom int
|
||||
Pix []byte
|
||||
}
|
||||
s := bufio.NewScanner(f)
|
||||
for s.Scan() {
|
||||
var r record
|
||||
require.NoError(b, json.Unmarshal(s.Bytes(), &r))
|
||||
|
||||
img := image.NewNRGBA(image.Rectangle{
|
||||
Min: image.Pt(r.Left, r.Top),
|
||||
Max: image.Pt(r.Right, r.Bottom),
|
||||
})
|
||||
copy(img.Pix, r.Pix)
|
||||
result = append(result, NewPNG(img, enc))
|
||||
}
|
||||
require.NoError(b, s.Err())
|
||||
return result
|
||||
}
|
||||
|
|
100
lib/srv/desktop/tdp/testdata/png_frames.json
vendored
Normal file
100
lib/srv/desktop/tdp/testdata/png_frames.json
vendored
Normal file
File diff suppressed because one or more lines are too long
|
@ -547,10 +547,9 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger,
|
|||
GenerateUserCert: func(ctx context.Context, username string) (certDER, keyDER []byte, err error) {
|
||||
return s.generateCredentials(ctx, username, desktop.GetDomain())
|
||||
},
|
||||
Addr: desktop.GetAddr(),
|
||||
InputMessage: tdpConn.InputMessage,
|
||||
OutputMessage: tdpConn.OutputMessage,
|
||||
AuthorizeFn: authorize,
|
||||
Addr: desktop.GetAddr(),
|
||||
Conn: tdpConn,
|
||||
AuthorizeFn: authorize,
|
||||
})
|
||||
if err != nil {
|
||||
s.onSessionStart(ctx, &identity, windowsUser, string(sessionID), desktop, err)
|
||||
|
|
Loading…
Reference in a new issue