diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index b894536..03528c9 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -46,16 +46,17 @@ func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder { bufferPool: sync.Pool{ New: func() interface{} { bufSize := config.BufferSize() - return make([]byte, bufSize) + buf := make([]byte, bufSize) + return &buf }, }, } } func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { - buf := f.bufferPool.Get().([]byte) + buf := f.bufferPool.Get().(*[]byte) defer f.bufferPool.Put(buf) - return io.CopyBuffer(dst, src, buf) + return io.CopyBuffer(dst, src, *buf) } func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) { diff --git a/session/forwarder/forwarder_test.go b/session/forwarder/forwarder_test.go index c3e1284..092d783 100644 --- a/session/forwarder/forwarder_test.go +++ b/session/forwarder/forwarder_test.go @@ -273,8 +273,8 @@ func TestNew(t *testing.T) { forwarder := New(cfg, s, conn).(*forwarder) - buf := forwarder.bufferPool.Get().([]byte) - require.Len(t, buf, tt.wantBufLen) + buf := forwarder.bufferPool.Get().(*[]byte) + require.Len(t, *buf, tt.wantBufLen) forwarder.bufferPool.Put(buf) assert.Equal(t, types.TunnelTypeUNKNOWN, forwarder.TunnelType()) @@ -1082,8 +1082,9 @@ func TestCopyWithBufferReusesBuffer(t *testing.T) { cfg.On("BufferSize").Return(16).Maybe() forwarder := New(cfg, slug.New(), nil).(*forwarder) - buf1 := forwarder.bufferPool.Get().([]byte) - initialPtr := &buf1[0] + buf1 := forwarder.bufferPool.Get().(*[]byte) + initialPtr := buf1 + forwarder.bufferPool.Put(buf1) src := io.NopCloser(bytes.NewReader([]byte("test"))) @@ -1091,16 +1092,19 @@ func TestCopyWithBufferReusesBuffer(t *testing.T) { _, err := forwarder.copyWithBuffer(dst, src) require.NoError(t, err) - buf2 := forwarder.bufferPool.Get().([]byte) - secondPtr := &buf2[0] + buf2 := forwarder.bufferPool.Get().(*[]byte) + secondPtr := buf2 + forwarder.bufferPool.Put(buf2) - assert.Equal(t, len(buf1), len(buf2)) + assert.Equal(t, initialPtr, secondPtr, "Buffers should be the same pointer") - assert.Len(t, buf2, 16) + assert.Len(t, *buf2, 16) + assert.Len(t, *buf1, 16) _ = initialPtr _ = secondPtr + cfg.AssertExpectations(t) }