diff --git a/server/http.go b/server/http.go index 3d9ac0f..18bbf38 100644 --- a/server/http.go +++ b/server/http.go @@ -12,16 +12,10 @@ import ( "strings" "tunnel_pls/session" "tunnel_pls/session/interaction" + "tunnel_pls/types" "tunnel_pls/utils" - - "golang.org/x/crypto/ssh" ) -var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + - "Content-Length: 11\r\n" + - "Content-Type: text/plain\r\n\r\n" + - "Bad Gateway") - type CustomWriter struct { RemoteAddr net.Addr writer io.Writer @@ -130,7 +124,7 @@ func isHTTPHeader(buf []byte) bool { } func (cw *CustomWriter) Write(p []byte) (int, error) { - if len(p) == len(BadGatewayResponse) && bytes.Equal(p, BadGatewayResponse) { + if len(p) == len(types.BadGatewayResponse) && bytes.Equal(p, types.BadGatewayResponse) { return cw.writer.Write(p) } @@ -216,7 +210,7 @@ func NewHTTPServer() error { func Handler(conn net.Conn) { defer func() { err := conn.Close() - if err != nil { + if err != nil && !errors.Is(err, net.ErrClosed) { log.Printf("Error closing connection: %v", err) return } @@ -302,20 +296,8 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) - sendBadGatewayResponse(cw) return } - defer func(channel ssh.Channel) { - err := channel.Close() - if err != nil { - if errors.Is(err, io.EOF) { - sendBadGatewayResponse(cw) - return - } - log.Println("Failed to close connection:", err) - return - } - }(channel) go func() { for req := range reqs { @@ -352,11 +334,3 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr) return } - -func sendBadGatewayResponse(writer io.Writer) { - _, err := writer.Write(BadGatewayResponse) - if err != nil { - log.Printf("failed to write Bad Gateway response: %v", err) - return - } -} diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 450184d..3d846e6 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -38,6 +38,7 @@ type ForwardingController interface { HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) SetLifecycle(lifecycle Lifecycle) CreateForwardedTCPIPPayload(origin net.Addr) []byte + WriteBadGatewayResponse(dst io.Writer) } func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { @@ -76,7 +77,12 @@ func (f *Forwarder) AcceptTCPConnections() { func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { defer func(src ssh.Channel) { - err := src.Close() + _, err := io.Copy(io.Discard, src) + if err != nil { + log.Printf("Failed to discard connection: %v", err) + } + + err = src.Close() if err != nil && !errors.Is(err, io.EOF) { log.Printf("Error closing connection: %v", err) } @@ -122,6 +128,14 @@ func (f *Forwarder) GetListener() net.Listener { return f.Listener } +func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) { + _, err := dst.Write(types.BadGatewayResponse) + if err != nil { + log.Printf("failed to write Bad Gateway response: %v", err) + return + } +} + func (f *Forwarder) Close() error { if f.GetTunnelType() != types.HTTP { return f.Listener.Close() diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 181b3a4..0c998c4 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -98,7 +98,6 @@ func (i *Interaction) HandleUserInput() { if char == 8 || char == 127 { if i.InputLength > 0 { - //i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1) i.SendMessage("\b \b") } if i.CommandBuffer.Len() > 0 { diff --git a/types/types.go b/types/types.go index c007661..f909da5 100644 --- a/types/types.go +++ b/types/types.go @@ -14,3 +14,8 @@ const ( HTTP TunnelType = "HTTP" TCP TunnelType = "TCP" ) + +var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + + "Content-Length: 11\r\n" + + "Content-Type: text/plain\r\n\r\n" + + "Bad Gateway")