fix: discard unused buffers in the ssh channel before disconnecting
This commit is contained in:
@ -12,16 +12,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"tunnel_pls/session"
|
"tunnel_pls/session"
|
||||||
"tunnel_pls/session/interaction"
|
"tunnel_pls/session/interaction"
|
||||||
|
"tunnel_pls/types"
|
||||||
"tunnel_pls/utils"
|
"tunnel_pls/utils"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
|
|
||||||
"Content-Length: 11\r\n" +
|
|
||||||
"Content-Type: text/plain\r\n\r\n" +
|
|
||||||
"Bad Gateway")
|
|
||||||
|
|
||||||
type CustomWriter struct {
|
type CustomWriter struct {
|
||||||
RemoteAddr net.Addr
|
RemoteAddr net.Addr
|
||||||
writer io.Writer
|
writer io.Writer
|
||||||
@ -130,7 +124,7 @@ func isHTTPHeader(buf []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cw *CustomWriter) Write(p []byte) (int, error) {
|
func (cw *CustomWriter) Write(p []byte) (int, error) {
|
||||||
if len(p) == len(BadGatewayResponse) && bytes.Equal(p, BadGatewayResponse) {
|
if len(p) == len(types.BadGatewayResponse) && bytes.Equal(p, types.BadGatewayResponse) {
|
||||||
return cw.writer.Write(p)
|
return cw.writer.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -216,7 +210,7 @@ func NewHTTPServer() error {
|
|||||||
func Handler(conn net.Conn) {
|
func Handler(conn net.Conn) {
|
||||||
defer func() {
|
defer func() {
|
||||||
err := conn.Close()
|
err := conn.Close()
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
log.Printf("Error closing connection: %v", err)
|
log.Printf("Error closing connection: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -302,20 +296,8 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
|
|||||||
channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||||
sendBadGatewayResponse(cw)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func(channel ssh.Channel) {
|
|
||||||
err := channel.Close()
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
sendBadGatewayResponse(cw)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Println("Failed to close connection:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}(channel)
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for req := range reqs {
|
for req := range reqs {
|
||||||
@ -352,11 +334,3 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
|
|||||||
sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr)
|
sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendBadGatewayResponse(writer io.Writer) {
|
|
||||||
_, err := writer.Write(BadGatewayResponse)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("failed to write Bad Gateway response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@ -38,6 +38,7 @@ type ForwardingController interface {
|
|||||||
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
|
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
|
||||||
SetLifecycle(lifecycle Lifecycle)
|
SetLifecycle(lifecycle Lifecycle)
|
||||||
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
||||||
|
WriteBadGatewayResponse(dst io.Writer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
|
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
|
||||||
@ -76,7 +77,12 @@ func (f *Forwarder) AcceptTCPConnections() {
|
|||||||
|
|
||||||
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
||||||
defer func(src ssh.Channel) {
|
defer func(src ssh.Channel) {
|
||||||
err := src.Close()
|
_, err := io.Copy(io.Discard, src)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to discard connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = src.Close()
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
log.Printf("Error closing connection: %v", err)
|
log.Printf("Error closing connection: %v", err)
|
||||||
}
|
}
|
||||||
@ -122,6 +128,14 @@ func (f *Forwarder) GetListener() net.Listener {
|
|||||||
return f.Listener
|
return f.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
|
||||||
|
_, err := dst.Write(types.BadGatewayResponse)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to write Bad Gateway response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Forwarder) Close() error {
|
func (f *Forwarder) Close() error {
|
||||||
if f.GetTunnelType() != types.HTTP {
|
if f.GetTunnelType() != types.HTTP {
|
||||||
return f.Listener.Close()
|
return f.Listener.Close()
|
||||||
|
|||||||
@ -98,7 +98,6 @@ func (i *Interaction) HandleUserInput() {
|
|||||||
|
|
||||||
if char == 8 || char == 127 {
|
if char == 8 || char == 127 {
|
||||||
if i.InputLength > 0 {
|
if i.InputLength > 0 {
|
||||||
//i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1)
|
|
||||||
i.SendMessage("\b \b")
|
i.SendMessage("\b \b")
|
||||||
}
|
}
|
||||||
if i.CommandBuffer.Len() > 0 {
|
if i.CommandBuffer.Len() > 0 {
|
||||||
|
|||||||
@ -14,3 +14,8 @@ const (
|
|||||||
HTTP TunnelType = "HTTP"
|
HTTP TunnelType = "HTTP"
|
||||||
TCP TunnelType = "TCP"
|
TCP TunnelType = "TCP"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
|
||||||
|
"Content-Length: 11\r\n" +
|
||||||
|
"Content-Type: text/plain\r\n\r\n" +
|
||||||
|
"Bad Gateway")
|
||||||
|
|||||||
Reference in New Issue
Block a user