diff --git a/session/handler.go b/session/handler.go index 1efad69..b30c7ff 100644 --- a/session/handler.go +++ b/session/handler.go @@ -11,7 +11,6 @@ import ( "log" "net" "strconv" - "strings" "sync" "time" portUtil "tunnel_pls/internal/port" @@ -37,71 +36,6 @@ var ( Clients = make(map[string]*SSHSession) ) -type HeaderModifier struct { - r io.Reader - headerBuf []byte - headerDone bool - state int -} - -func (hm *HeaderModifier) Read(p []byte) (int, error) { - n, err := hm.r.Read(p) - if n > 0 && !hm.headerDone { - for i := 0; i < n; i++ { - b := p[i] - hm.headerBuf = append(hm.headerBuf, b) - - switch hm.state { - case 0: - if b == '\r' { - hm.state = 1 - } - case 1: - if b == '\n' { - hm.state = 2 - } else { - hm.state = 0 - } - case 2: - if b == '\r' { - hm.state = 3 - } else { - hm.state = 0 - } - case 3: - if b == '\n' { - hm.headerDone = true - modifiedHeader := hm.modifyHeader(hm.headerBuf) - copy(p, modifiedHeader) - return len(modifiedHeader), nil - } else { - hm.state = 0 - } - } - } - } - - return n, err -} -func (hm *HeaderModifier) modifyHeader(header []byte) []byte { - lines := strings.Split(string(header), "\r\n") - found := false - - for i, line := range lines { - if strings.HasPrefix(strings.ToLower(line), "server:") { - lines[i] = "Server: tunnel_please" - found = true - } - } - - if !found { - lines = append(lines[:len(lines)-2], "Server: tunnel_please", "", "") - } - - modified := strings.Join(lines, "\r\n") - return []byte(modified) -} - func registerClient(slug string, session *SSHSession) bool { clientsMutex.Lock() defer clientsMutex.Unlock() @@ -586,12 +520,7 @@ func (s *SSHSession) HandleForwardedConnection(conn UserConnection, sshConn *ssh s.interaction.SendMessage(fmt.Sprintf("\033[32m%s -> [%s] TUNNEL ADDRESS -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.forwarder.TunnelType, timestamp)) - if s.forwarder.GetTunnelType() == HTTP { - ir := &HeaderModifier{r: reader} - _, err = io.Copy(conn.Writer, ir) - } else { - _, err = io.Copy(conn.Writer, reader) - } + _, err = io.Copy(conn.Writer, reader) if err != nil && !errors.Is(err, io.EOF) { log.Printf("Error copying from channel to conn.Writer: %v", err)