fix: close connection on TCP/IP request timeout
Some checks failed
Docker Build and Push / build-and-push (push) Has been cancelled
Some checks failed
Docker Build and Push / build-and-push (push) Has been cancelled
This commit is contained in:
@ -1,10 +1,11 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
"log"
|
||||
"net"
|
||||
"tunnel_pls/session"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
@ -21,16 +22,7 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
|
||||
log.Println("SSH connection established:", sshConn.User())
|
||||
|
||||
newSession := session.New(sshConn, forwardingReqs)
|
||||
for ch := range chans {
|
||||
newSession.ChannelChan <- ch
|
||||
}
|
||||
session.New(sshConn, forwardingReqs, chans)
|
||||
|
||||
defer func(newSession *session.Session) {
|
||||
err := newSession.Close()
|
||||
if err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
}(newSession)
|
||||
return
|
||||
}
|
||||
|
||||
@ -23,12 +23,6 @@ import (
|
||||
|
||||
type SessionStatus string
|
||||
|
||||
const (
|
||||
INITIALIZING SessionStatus = "INITIALIZING"
|
||||
RUNNING SessionStatus = "RUNNING"
|
||||
SETUP SessionStatus = "SETUP"
|
||||
)
|
||||
|
||||
var forbiddenSlug = []string{
|
||||
"ping",
|
||||
}
|
||||
@ -81,18 +75,10 @@ func updateClientSlug(oldSlug, newSlug string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Session) safeClose() {
|
||||
s.once.Do(func() {
|
||||
close(s.ChannelChan)
|
||||
close(s.Done)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) Close() error {
|
||||
if s.Listener != nil {
|
||||
err := s.Listener.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
fmt.Println("1")
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -100,7 +86,6 @@ func (s *Session) Close() error {
|
||||
if s.ConnChannel != nil {
|
||||
err := s.ConnChannel.Close()
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
fmt.Println("2")
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -108,8 +93,6 @@ func (s *Session) Close() error {
|
||||
if s.Connection != nil {
|
||||
err := s.Connection.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
fmt.Println("3")
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -121,12 +104,10 @@ func (s *Session) Close() error {
|
||||
if s.TunnelType == TCP {
|
||||
err := portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
|
||||
if err != nil {
|
||||
fmt.Println("4")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.safeClose()
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -184,9 +165,8 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
|
||||
}
|
||||
|
||||
s.sendMessage("\033[H\033[2J")
|
||||
|
||||
showWelcomeMessage(s.ConnChannel)
|
||||
s.Status = RUNNING
|
||||
showWelcomeMessage(s.ConnChannel)
|
||||
go s.handleUserInput()
|
||||
|
||||
if portToBind == 80 || portToBind == 443 {
|
||||
@ -210,7 +190,6 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
|
||||
}
|
||||
portUtil.Manager.SetPortStatus(portToBind, true)
|
||||
}
|
||||
|
||||
s.handleTCPForward(req, addr, portToBind)
|
||||
}
|
||||
|
||||
@ -248,8 +227,6 @@ func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint16) {
|
||||
binary.Write(buf, binary.BigEndian, uint32(80))
|
||||
log.Printf("HTTP forwarding approved on port: %d", 80)
|
||||
|
||||
s.waitForRunningStatus()
|
||||
|
||||
domain := utils.Getenv("domain")
|
||||
protocol := "http"
|
||||
if utils.Getenv("tls_enabled") == "true" {
|
||||
@ -321,17 +298,23 @@ func (s *Session) generateUniqueSlug() string {
|
||||
}
|
||||
|
||||
func (s *Session) waitForRunningStatus() {
|
||||
timeout := time.After(10 * time.Second)
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
timeout := time.After(3 * time.Second)
|
||||
ticker := time.NewTicker(150 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
frames := []string{"-", "\\", "|", "/"}
|
||||
i := 0
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.sendMessage(fmt.Sprintf("\rLoading %s", frames[i]))
|
||||
i = (i + 1) % len(frames)
|
||||
if s.Status == RUNNING {
|
||||
return
|
||||
}
|
||||
case <-timeout:
|
||||
s.sendMessage("\r\033[K")
|
||||
s.sendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n")
|
||||
s.Close()
|
||||
log.Println("Timeout waiting for session to start running")
|
||||
return
|
||||
}
|
||||
|
||||
@ -1,9 +1,16 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
INITIALIZING SessionStatus = "INITIALIZING"
|
||||
RUNNING SessionStatus = "RUNNING"
|
||||
SETUP SessionStatus = "SETUP"
|
||||
)
|
||||
|
||||
type TunnelType string
|
||||
@ -11,7 +18,6 @@ type TunnelType string
|
||||
const (
|
||||
HTTP TunnelType = "http"
|
||||
TCP TunnelType = "tcp"
|
||||
UDP TunnelType = "udp"
|
||||
UNKNOWN TunnelType = "unknown"
|
||||
)
|
||||
|
||||
@ -23,33 +29,32 @@ type Session struct {
|
||||
ForwardedPort uint16
|
||||
Status SessionStatus
|
||||
Slug string
|
||||
ChannelChan chan ssh.NewChannel
|
||||
Done chan bool
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request) *Session {
|
||||
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) {
|
||||
session := &Session{
|
||||
Status: SETUP,
|
||||
Status: INITIALIZING,
|
||||
Slug: "",
|
||||
ConnChannel: nil,
|
||||
Connection: conn,
|
||||
TunnelType: UNKNOWN,
|
||||
ChannelChan: make(chan ssh.NewChannel),
|
||||
Done: make(chan bool),
|
||||
}
|
||||
|
||||
go func() {
|
||||
for channel := range session.ChannelChan {
|
||||
go session.waitForRunningStatus()
|
||||
|
||||
for channel := range sshChan {
|
||||
ch, reqs, _ := channel.Accept()
|
||||
if session.ConnChannel == nil {
|
||||
session.ConnChannel = ch
|
||||
session.Status = RUNNING
|
||||
session.Status = SETUP
|
||||
go session.HandleGlobalRequest(forwardingReq)
|
||||
}
|
||||
go session.HandleGlobalRequest(reqs)
|
||||
}
|
||||
}()
|
||||
|
||||
return session
|
||||
err := session.Close()
|
||||
if err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user