2 Commits

4 changed files with 30 additions and 31 deletions

View File

@@ -12,16 +12,10 @@ import (
"strings" "strings"
"tunnel_pls/session" "tunnel_pls/session"
"tunnel_pls/session/interaction" "tunnel_pls/session/interaction"
"tunnel_pls/types"
"tunnel_pls/utils" "tunnel_pls/utils"
"golang.org/x/crypto/ssh"
) )
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
"Content-Length: 11\r\n" +
"Content-Type: text/plain\r\n\r\n" +
"Bad Gateway")
type CustomWriter struct { type CustomWriter struct {
RemoteAddr net.Addr RemoteAddr net.Addr
writer io.Writer writer io.Writer
@@ -130,7 +124,7 @@ func isHTTPHeader(buf []byte) bool {
} }
func (cw *CustomWriter) Write(p []byte) (int, error) { func (cw *CustomWriter) Write(p []byte) (int, error) {
if len(p) == len(BadGatewayResponse) && bytes.Equal(p, BadGatewayResponse) { if len(p) == len(types.BadGatewayResponse) && bytes.Equal(p, types.BadGatewayResponse) {
return cw.writer.Write(p) return cw.writer.Write(p)
} }
@@ -216,7 +210,7 @@ func NewHTTPServer() error {
func Handler(conn net.Conn) { func Handler(conn net.Conn) {
defer func() { defer func() {
err := conn.Close() err := conn.Close()
if err != nil { if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("Error closing connection: %v", err) log.Printf("Error closing connection: %v", err)
return return
} }
@@ -302,20 +296,8 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
if err != nil { if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err) log.Printf("Failed to open forwarded-tcpip channel: %v", err)
sendBadGatewayResponse(cw)
return return
} }
defer func(channel ssh.Channel) {
err := channel.Close()
if err != nil {
if errors.Is(err, io.EOF) {
sendBadGatewayResponse(cw)
return
}
log.Println("Failed to close connection:", err)
return
}
}(channel)
go func() { go func() {
for req := range reqs { for req := range reqs {
@@ -352,11 +334,3 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr) sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr)
return return
} }
func sendBadGatewayResponse(writer io.Writer) {
_, err := writer.Write(BadGatewayResponse)
if err != nil {
log.Printf("failed to write Bad Gateway response: %v", err)
return
}
}

View File

@@ -38,6 +38,7 @@ type ForwardingController interface {
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
SetLifecycle(lifecycle Lifecycle) SetLifecycle(lifecycle Lifecycle)
CreateForwardedTCPIPPayload(origin net.Addr) []byte CreateForwardedTCPIPPayload(origin net.Addr) []byte
WriteBadGatewayResponse(dst io.Writer)
} }
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
@@ -76,7 +77,12 @@ func (f *Forwarder) AcceptTCPConnections() {
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
defer func(src ssh.Channel) { defer func(src ssh.Channel) {
err := src.Close() _, err := io.Copy(io.Discard, src)
if err != nil {
log.Printf("Failed to discard connection: %v", err)
}
err = src.Close()
if err != nil && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing connection: %v", err) log.Printf("Error closing connection: %v", err)
} }
@@ -122,6 +128,14 @@ func (f *Forwarder) GetListener() net.Listener {
return f.Listener return f.Listener
} }
func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
_, err := dst.Write(types.BadGatewayResponse)
if err != nil {
log.Printf("failed to write Bad Gateway response: %v", err)
return
}
}
func (f *Forwarder) Close() error { func (f *Forwarder) Close() error {
if f.GetTunnelType() != types.HTTP { if f.GetTunnelType() != types.HTTP {
return f.Listener.Close() return f.Listener.Close()

View File

@@ -44,6 +44,7 @@ type Forwarder interface {
} }
type Interaction struct { type Interaction struct {
InputLength int
CommandBuffer *bytes.Buffer CommandBuffer *bytes.Buffer
EditMode bool EditMode bool
EditSlug string EditSlug string
@@ -96,13 +97,17 @@ func (i *Interaction) HandleUserInput() {
i.SendMessage(string(buf[:n])) i.SendMessage(string(buf[:n]))
if char == 8 || char == 127 { if char == 8 || char == 127 {
if i.InputLength > 0 {
i.SendMessage("\b \b")
}
if i.CommandBuffer.Len() > 0 { if i.CommandBuffer.Len() > 0 {
i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1) i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1)
i.SendMessage("\b \b")
} }
continue continue
} }
i.InputLength += n
if char == '/' { if char == '/' {
i.CommandBuffer.Reset() i.CommandBuffer.Reset()
i.CommandBuffer.WriteByte(char) i.CommandBuffer.WriteByte(char)
@@ -111,6 +116,7 @@ func (i *Interaction) HandleUserInput() {
if i.CommandBuffer.Len() > 0 { if i.CommandBuffer.Len() > 0 {
if char == 13 { if char == 13 {
i.SendMessage("\033[K")
i.HandleCommand(i.CommandBuffer.String()) i.HandleCommand(i.CommandBuffer.String())
continue continue
} }

View File

@@ -14,3 +14,8 @@ const (
HTTP TunnelType = "HTTP" HTTP TunnelType = "HTTP"
TCP TunnelType = "TCP" TCP TunnelType = "TCP"
) )
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
"Content-Length: 11\r\n" +
"Content-Type: text/plain\r\n\r\n" +
"Bad Gateway")