fix: ensure proper buffer reuse with pointer handling in sync.Pool

This commit is contained in:
2026-01-26 19:50:34 +07:00
parent a3f6baa6ae
commit 7f44cc7bc0
2 changed files with 16 additions and 11 deletions
+4 -3
View File
@@ -46,16 +46,17 @@ func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
bufferPool: sync.Pool{
New: func() interface{} {
bufSize := config.BufferSize()
return make([]byte, bufSize)
buf := make([]byte, bufSize)
return &buf
},
},
}
}
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := f.bufferPool.Get().([]byte)
buf := f.bufferPool.Get().(*[]byte)
defer f.bufferPool.Put(buf)
return io.CopyBuffer(dst, src, buf)
return io.CopyBuffer(dst, src, *buf)
}
func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
+12 -8
View File
@@ -273,8 +273,8 @@ func TestNew(t *testing.T) {
forwarder := New(cfg, s, conn).(*forwarder)
buf := forwarder.bufferPool.Get().([]byte)
require.Len(t, buf, tt.wantBufLen)
buf := forwarder.bufferPool.Get().(*[]byte)
require.Len(t, *buf, tt.wantBufLen)
forwarder.bufferPool.Put(buf)
assert.Equal(t, types.TunnelTypeUNKNOWN, forwarder.TunnelType())
@@ -1082,8 +1082,9 @@ func TestCopyWithBufferReusesBuffer(t *testing.T) {
cfg.On("BufferSize").Return(16).Maybe()
forwarder := New(cfg, slug.New(), nil).(*forwarder)
buf1 := forwarder.bufferPool.Get().([]byte)
initialPtr := &buf1[0]
buf1 := forwarder.bufferPool.Get().(*[]byte)
initialPtr := buf1
forwarder.bufferPool.Put(buf1)
src := io.NopCloser(bytes.NewReader([]byte("test")))
@@ -1091,16 +1092,19 @@ func TestCopyWithBufferReusesBuffer(t *testing.T) {
_, err := forwarder.copyWithBuffer(dst, src)
require.NoError(t, err)
buf2 := forwarder.bufferPool.Get().([]byte)
secondPtr := &buf2[0]
buf2 := forwarder.bufferPool.Get().(*[]byte)
secondPtr := buf2
forwarder.bufferPool.Put(buf2)
assert.Equal(t, len(buf1), len(buf2))
assert.Equal(t, initialPtr, secondPtr, "Buffers should be the same pointer")
assert.Len(t, buf2, 16)
assert.Len(t, *buf2, 16)
assert.Len(t, *buf1, 16)
_ = initialPtr
_ = secondPtr
cfg.AssertExpectations(t)
}