fix: ensure proper buffer reuse with pointer handling in sync.Pool
This commit is contained in:
@@ -46,16 +46,17 @@ func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
|
||||
bufferPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
bufSize := config.BufferSize()
|
||||
return make([]byte, bufSize)
|
||||
buf := make([]byte, bufSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
||||
buf := f.bufferPool.Get().([]byte)
|
||||
buf := f.bufferPool.Get().(*[]byte)
|
||||
defer f.bufferPool.Put(buf)
|
||||
return io.CopyBuffer(dst, src, buf)
|
||||
return io.CopyBuffer(dst, src, *buf)
|
||||
}
|
||||
|
||||
func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
|
||||
@@ -273,8 +273,8 @@ func TestNew(t *testing.T) {
|
||||
|
||||
forwarder := New(cfg, s, conn).(*forwarder)
|
||||
|
||||
buf := forwarder.bufferPool.Get().([]byte)
|
||||
require.Len(t, buf, tt.wantBufLen)
|
||||
buf := forwarder.bufferPool.Get().(*[]byte)
|
||||
require.Len(t, *buf, tt.wantBufLen)
|
||||
forwarder.bufferPool.Put(buf)
|
||||
|
||||
assert.Equal(t, types.TunnelTypeUNKNOWN, forwarder.TunnelType())
|
||||
@@ -1082,8 +1082,9 @@ func TestCopyWithBufferReusesBuffer(t *testing.T) {
|
||||
cfg.On("BufferSize").Return(16).Maybe()
|
||||
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
||||
|
||||
buf1 := forwarder.bufferPool.Get().([]byte)
|
||||
initialPtr := &buf1[0]
|
||||
buf1 := forwarder.bufferPool.Get().(*[]byte)
|
||||
initialPtr := buf1
|
||||
|
||||
forwarder.bufferPool.Put(buf1)
|
||||
|
||||
src := io.NopCloser(bytes.NewReader([]byte("test")))
|
||||
@@ -1091,16 +1092,19 @@ func TestCopyWithBufferReusesBuffer(t *testing.T) {
|
||||
_, err := forwarder.copyWithBuffer(dst, src)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf2 := forwarder.bufferPool.Get().([]byte)
|
||||
secondPtr := &buf2[0]
|
||||
buf2 := forwarder.bufferPool.Get().(*[]byte)
|
||||
secondPtr := buf2
|
||||
|
||||
forwarder.bufferPool.Put(buf2)
|
||||
|
||||
assert.Equal(t, len(buf1), len(buf2))
|
||||
assert.Equal(t, initialPtr, secondPtr, "Buffers should be the same pointer")
|
||||
|
||||
assert.Len(t, buf2, 16)
|
||||
assert.Len(t, *buf2, 16)
|
||||
assert.Len(t, *buf1, 16)
|
||||
|
||||
_ = initialPtr
|
||||
_ = secondPtr
|
||||
|
||||
cfg.AssertExpectations(t)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user