refactor: restructure session initialization to avoid circular references
This commit is contained in:
@ -11,6 +11,7 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"tunnel_pls/session"
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/utils"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
@ -29,12 +30,16 @@ type CustomWriter struct {
|
||||
buf []byte
|
||||
respHeader *ResponseHeaderFactory
|
||||
reqHeader *RequestHeaderFactory
|
||||
interaction *session.Interaction
|
||||
interaction interaction.InteractionController
|
||||
respMW []ResponseMiddleware
|
||||
reqStartMW []RequestMiddleware
|
||||
reqEndMW []RequestMiddleware
|
||||
}
|
||||
|
||||
func (cw *CustomWriter) SetInteraction(interaction interaction.InteractionController) {
|
||||
cw.interaction = interaction
|
||||
}
|
||||
|
||||
func (cw *CustomWriter) Read(p []byte) (int, error) {
|
||||
tmp := make([]byte, len(p))
|
||||
read, err := cw.reader.Read(tmp)
|
||||
@ -177,7 +182,7 @@ func (cw *CustomWriter) Write(p []byte) (int, error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (cw *CustomWriter) AddInteraction(interaction *session.Interaction) {
|
||||
func (cw *CustomWriter) AddInteraction(interaction *interaction.Interaction) {
|
||||
cw.interaction = interaction
|
||||
}
|
||||
|
||||
@ -287,16 +292,15 @@ func Handler(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
|
||||
|
||||
cw.SetInteraction(sshSession.Interaction)
|
||||
forwardRequest(cw, reqhf, sshSession)
|
||||
return
|
||||
}
|
||||
|
||||
func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) {
|
||||
cw.AddInteraction(sshSession.Interaction)
|
||||
originHost, originPort := ParseAddr(cw.RemoteAddr.String())
|
||||
payload := createForwardedTCPIPPayload(originHost, uint16(originPort), sshSession.Forwarder.GetForwardedPort())
|
||||
channel, reqs, err := sshSession.Conn.OpenChannel("forwarded-tcpip", payload)
|
||||
channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||
if err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||
sendBadGatewayResponse(cw)
|
||||
|
||||
Reference in New Issue
Block a user