Tweak the PNG encoder (#9817)

Instead of using the default, allow for providing a custom
png.Encoder. Set BestSpeed for compression and use a pool
in order to reuse buffers.

Also add some sample PNG frames and a benchmark for encoding them.
This results in ~ 70% reduction in encoding time and > 99% reduction
in memory allocations.

    name         old time/op    new time/op
    EncodePNG-8     107µs ± 0%      30µs ± 0%

    name         old alloc/op   new alloc/op
    EncodePNG-8     850kB ± 0%       1kB ± 0%

    name         old allocs/op  new allocs/op
    EncodePNG-8      42.0 ± 0%      13.0 ± 0%

Updates #8742
This commit is contained in:
Zac Bergquist 2022-01-19 14:11:55 -07:00 committed by GitHub
parent 46bf623c51
commit 74e580ab5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 222 additions and 24 deletions

View file

@ -160,7 +160,7 @@ func New(ctx context.Context, cfg Config) (*Client, error) {
func (c *Client) readClientUsername() error {
for {
msg, err := c.cfg.InputMessage()
msg, err := c.cfg.Conn.InputMessage()
if err != nil {
return trace.Wrap(err)
}
@ -177,7 +177,7 @@ func (c *Client) readClientUsername() error {
func (c *Client) readClientSize() error {
for {
msg, err := c.cfg.InputMessage()
msg, err := c.cfg.Conn.InputMessage()
if err != nil {
return trace.Wrap(err)
}
@ -255,7 +255,7 @@ func (c *Client) start() {
// Remember mouse coordinates to send them with all CGOPointer events.
var mouseX, mouseY uint32
for {
msg, err := c.cfg.InputMessage()
msg, err := c.cfg.Conn.InputMessage()
if err != nil {
c.cfg.Log.Warningf("Failed reading RDP input message: %v", err)
return
@ -383,7 +383,7 @@ func (c *Client) handleBitmap(cb C.CGOBitmap) C.CGOError {
})
copy(img.Pix, data)
if err := c.cfg.OutputMessage(tdp.PNGFrame{Img: img}); err != nil {
if err := c.cfg.Conn.OutputMessage(tdp.NewPNG(img, c.cfg.Encoder)); err != nil {
return C.CString(fmt.Sprintf("failed to send PNG frame %v: %v", img.Rect, err))
}
return nil

View file

@ -19,6 +19,7 @@ package rdpclient
import (
"context"
"image/png"
"github.com/gravitational/teleport/lib/srv/desktop/tdp"
"github.com/gravitational/trace"
@ -35,13 +36,12 @@ type Config struct {
// AuthorizeFn is called to authorize a user connecting to a Windows desktop.
AuthorizeFn func(login string) error
// TODO(zmb3): replace these callbacks with a tdp.Conn
// Conn handles TDP messages between Windows Desktop Service
// and a Teleport Proxy.
Conn *tdp.Conn
// InputMessage is called to receive a message from the client for the RDP
// server. This function should block until there is a message.
InputMessage func() (tdp.Message, error)
// OutputMessage is called to send a message from RDP server to the client.
OutputMessage func(tdp.Message) error
// Encoder is an optional override for PNG encoding.
Encoder *png.Encoder
// Log is the logger for status messages.
Log logrus.FieldLogger
@ -58,15 +58,15 @@ func (c *Config) checkAndSetDefaults() error {
if c.GenerateUserCert == nil {
return trace.BadParameter("missing GenerateUserCert in rdpclient.Config")
}
if c.InputMessage == nil {
return trace.BadParameter("missing InputMessage in rdpclient.Config")
}
if c.OutputMessage == nil {
return trace.BadParameter("missing OutputMessage in rdpclient.Config")
if c.Conn == nil {
return trace.BadParameter("missing Conn in rdpclient.Config")
}
if c.AuthorizeFn == nil {
return trace.BadParameter("missing AuthorizeFn in rdpclient.Config")
}
if c.Encoder == nil {
c.Encoder = tdp.PNGEncoder()
}
if c.Log == nil {
c.Log = logrus.New()
}

View file

@ -0,0 +1,38 @@
// Copyright 2022 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tdp
import "image/png"
// PNGEncoder returns the encoder used for PNG Frames.
// It is not safe for concurrent use.
func PNGEncoder() *png.Encoder {
return &png.Encoder{
CompressionLevel: png.BestSpeed,
BufferPool: &pool{},
}
}
// pool implements png.EncoderBufferPool,
// allowing us to reuse encoding resources
type pool struct {
b *png.EncoderBuffer
}
// all encoding happens in a single thread, so we don't
// need anything as sophisticated as a sync.Pool here
func (p *pool) Get() *png.EncoderBuffer { return p.b }
func (p *pool) Put(eb *png.EncoderBuffer) { p.b = eb }

View file

@ -104,6 +104,15 @@ func decode(in peekReader) (Message, error) {
// https://github.com/gravitational/teleport/blob/master/rfd/0037-desktop-access-protocol.md#2---png-frame
type PNGFrame struct {
Img image.Image
enc *png.Encoder // optionally override the PNG encoder
}
func NewPNG(img image.Image, enc *png.Encoder) PNGFrame {
return PNGFrame{
Img: img,
enc: enc,
}
}
func (f PNGFrame) Encode() ([]byte, error) {
@ -123,10 +132,11 @@ func (f PNGFrame) Encode() ([]byte, error) {
}); err != nil {
return nil, trace.Wrap(err)
}
// Note: this uses the default png.Encoder parameters.
// You can tweak compression level and reduce memory allocations by using a
// custom png.Encoder, if this happens to be a bottleneck.
if err := png.Encode(buf, f.Img); err != nil {
encoder := f.enc
if encoder == nil {
encoder = &png.Encoder{}
}
if err := encoder.Encode(buf, f.Img); err != nil {
return nil, trace.Wrap(err)
}
return buf.Bytes(), nil

View file

@ -17,11 +17,16 @@ limitations under the License.
package tdp
import (
"bufio"
"encoding/json"
"image"
"image/color"
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/require"
)
@ -51,7 +56,7 @@ func TestEncodeDecode(t *testing.T) {
out, err := Decode(buf)
require.NoError(t, err)
require.Empty(t, cmp.Diff(m, out))
require.Empty(t, cmp.Diff(m, out, cmpopts.IgnoreUnexported(PNGFrame{})))
}
}
@ -60,3 +65,49 @@ func TestBadDecode(t *testing.T) {
_, err := Decode([]byte{254})
require.Error(t, err)
}
var encodedFrame []byte
func BenchmarkEncodePNG(b *testing.B) {
b.StopTimer()
frames := loadBitmaps(b)
b.StartTimer()
var err error
for i := 0; i < b.N; i++ {
fi := i % len(frames)
encodedFrame, err = frames[fi].Encode()
if err != nil {
b.Fatal(err)
}
}
}
func loadBitmaps(b *testing.B) []PNGFrame {
b.Helper()
f, err := os.Open(filepath.Join("testdata", "png_frames.json"))
require.NoError(b, err)
defer f.Close()
enc := PNGEncoder()
var result []PNGFrame
type record struct {
Top, Left, Right, Bottom int
Pix []byte
}
s := bufio.NewScanner(f)
for s.Scan() {
var r record
require.NoError(b, json.Unmarshal(s.Bytes(), &r))
img := image.NewNRGBA(image.Rectangle{
Min: image.Pt(r.Left, r.Top),
Max: image.Pt(r.Right, r.Bottom),
})
copy(img.Pix, r.Pix)
result = append(result, NewPNG(img, enc))
}
require.NoError(b, s.Err())
return result
}

File diff suppressed because one or more lines are too long

View file

@ -547,10 +547,9 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger,
GenerateUserCert: func(ctx context.Context, username string) (certDER, keyDER []byte, err error) {
return s.generateCredentials(ctx, username, desktop.GetDomain())
},
Addr: desktop.GetAddr(),
InputMessage: tdpConn.InputMessage,
OutputMessage: tdpConn.OutputMessage,
AuthorizeFn: authorize,
Addr: desktop.GetAddr(),
Conn: tdpConn,
AuthorizeFn: authorize,
})
if err != nil {
s.onSessionStart(ctx, &identity, windowsUser, string(sessionID), desktop, err)