refactor: restructure session initialization to avoid circular references
This commit is contained in:
@ -3,12 +3,9 @@ package session
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
portUtil "tunnel_pls/internal/port"
|
||||
"tunnel_pls/types"
|
||||
|
||||
@ -17,10 +14,7 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type UserConnection struct {
|
||||
Reader io.Reader
|
||||
Writer net.Conn
|
||||
}
|
||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||
|
||||
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||
for req := range GlobalRequest {
|
||||
@ -157,23 +151,6 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
|
||||
s.HandleTCPForward(req, addr, portToBind)
|
||||
}
|
||||
|
||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||
|
||||
func isBlockedPort(port uint16) bool {
|
||||
if port == 80 || port == 443 {
|
||||
return false
|
||||
}
|
||||
if port < 1024 && port != 0 {
|
||||
return true
|
||||
}
|
||||
for _, p := range blockedReservedPorts {
|
||||
if p == port {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
|
||||
s.Forwarder.SetType(types.HTTP)
|
||||
s.Forwarder.SetForwardedPort(portToBind)
|
||||
@ -237,7 +214,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
|
||||
s.Interaction.ShowWelcomeMessage()
|
||||
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.GetTunnelType(), utils.Getenv("domain"), s.Forwarder.GetForwardedPort()))
|
||||
|
||||
go s.acceptTCPConnections()
|
||||
go s.Forwarder.AcceptTCPConnections()
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
|
||||
@ -253,37 +230,6 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHSession) acceptTCPConnections() {
|
||||
for {
|
||||
conn, err := s.Forwarder.GetListener().Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
log.Printf("Error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
originHost, originPort := ParseAddr(conn.RemoteAddr().String())
|
||||
payload := createForwardedTCPIPPayload(originHost, uint16(originPort), s.Forwarder.GetForwardedPort())
|
||||
channel, reqs, err := s.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||
if err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
for req := range reqs {
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to reply to request: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
go s.HandleForwardedConnection(conn, channel, conn.RemoteAddr())
|
||||
}
|
||||
}
|
||||
|
||||
func generateUniqueSlug() string {
|
||||
maxAttempts := 5
|
||||
|
||||
@ -303,30 +249,6 @@ func generateUniqueSlug() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *SSHSession) HandleForwardedConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
||||
defer func(src ssh.Channel) {
|
||||
err := src.Close()
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error closing connection: %v", err)
|
||||
}
|
||||
}(src)
|
||||
log.Printf("Handling new forwarded connection from %s", remoteAddr)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(src, dst)
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
log.Printf("Error copying from conn.Reader to channel: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err := io.Copy(dst, src)
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error copying from channel to conn.Writer: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func readSSHString(reader *bytes.Reader) (string, error) {
|
||||
var length uint32
|
||||
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
|
||||
@ -339,40 +261,17 @@ func readSSHString(reader *bytes.Reader) (string, error) {
|
||||
return string(strBytes), nil
|
||||
}
|
||||
|
||||
func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
writeSSHString(&buf, "localhost")
|
||||
err := binary.Write(&buf, binary.BigEndian, uint32(port))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return nil
|
||||
func isBlockedPort(port uint16) bool {
|
||||
if port == 80 || port == 443 {
|
||||
return false
|
||||
}
|
||||
writeSSHString(&buf, host)
|
||||
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return nil
|
||||
if port < 1024 && port != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func writeSSHString(buffer *bytes.Buffer, str string) {
|
||||
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return
|
||||
}
|
||||
buffer.WriteString(str)
|
||||
}
|
||||
|
||||
func ParseAddr(addr string) (string, uint32) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
|
||||
return "0.0.0.0", uint32(0)
|
||||
}
|
||||
port, _ := strconv.Atoi(portStr)
|
||||
return host, uint32(port)
|
||||
for _, p := range blockedReservedPorts {
|
||||
if p == port {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user