fix: close connection on TCP/IP request timeout
Some checks failed
Docker Build and Push / build-and-push (push) Has been cancelled
Some checks failed
Docker Build and Push / build-and-push (push) Has been cancelled
This commit is contained in:
@ -1,10 +1,11 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"tunnel_pls/session"
|
"tunnel_pls/session"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleConnection(conn net.Conn) {
|
func (s *Server) handleConnection(conn net.Conn) {
|
||||||
@ -21,16 +22,7 @@ func (s *Server) handleConnection(conn net.Conn) {
|
|||||||
|
|
||||||
log.Println("SSH connection established:", sshConn.User())
|
log.Println("SSH connection established:", sshConn.User())
|
||||||
|
|
||||||
newSession := session.New(sshConn, forwardingReqs)
|
session.New(sshConn, forwardingReqs, chans)
|
||||||
for ch := range chans {
|
|
||||||
newSession.ChannelChan <- ch
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func(newSession *session.Session) {
|
|
||||||
err := newSession.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("failed to close session: %v", err)
|
|
||||||
}
|
|
||||||
}(newSession)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,12 +23,6 @@ import (
|
|||||||
|
|
||||||
type SessionStatus string
|
type SessionStatus string
|
||||||
|
|
||||||
const (
|
|
||||||
INITIALIZING SessionStatus = "INITIALIZING"
|
|
||||||
RUNNING SessionStatus = "RUNNING"
|
|
||||||
SETUP SessionStatus = "SETUP"
|
|
||||||
)
|
|
||||||
|
|
||||||
var forbiddenSlug = []string{
|
var forbiddenSlug = []string{
|
||||||
"ping",
|
"ping",
|
||||||
}
|
}
|
||||||
@ -81,18 +75,10 @@ func updateClientSlug(oldSlug, newSlug string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) safeClose() {
|
|
||||||
s.once.Do(func() {
|
|
||||||
close(s.ChannelChan)
|
|
||||||
close(s.Done)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) Close() error {
|
func (s *Session) Close() error {
|
||||||
if s.Listener != nil {
|
if s.Listener != nil {
|
||||||
err := s.Listener.Close()
|
err := s.Listener.Close()
|
||||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
fmt.Println("1")
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -100,7 +86,6 @@ func (s *Session) Close() error {
|
|||||||
if s.ConnChannel != nil {
|
if s.ConnChannel != nil {
|
||||||
err := s.ConnChannel.Close()
|
err := s.ConnChannel.Close()
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
fmt.Println("2")
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -108,8 +93,6 @@ func (s *Session) Close() error {
|
|||||||
if s.Connection != nil {
|
if s.Connection != nil {
|
||||||
err := s.Connection.Close()
|
err := s.Connection.Close()
|
||||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
fmt.Println("3")
|
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -121,12 +104,10 @@ func (s *Session) Close() error {
|
|||||||
if s.TunnelType == TCP {
|
if s.TunnelType == TCP {
|
||||||
err := portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
|
err := portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("4")
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.safeClose()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -184,9 +165,8 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.sendMessage("\033[H\033[2J")
|
s.sendMessage("\033[H\033[2J")
|
||||||
|
|
||||||
showWelcomeMessage(s.ConnChannel)
|
|
||||||
s.Status = RUNNING
|
s.Status = RUNNING
|
||||||
|
showWelcomeMessage(s.ConnChannel)
|
||||||
go s.handleUserInput()
|
go s.handleUserInput()
|
||||||
|
|
||||||
if portToBind == 80 || portToBind == 443 {
|
if portToBind == 80 || portToBind == 443 {
|
||||||
@ -210,7 +190,6 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
|
|||||||
}
|
}
|
||||||
portUtil.Manager.SetPortStatus(portToBind, true)
|
portUtil.Manager.SetPortStatus(portToBind, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.handleTCPForward(req, addr, portToBind)
|
s.handleTCPForward(req, addr, portToBind)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -248,8 +227,6 @@ func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint16) {
|
|||||||
binary.Write(buf, binary.BigEndian, uint32(80))
|
binary.Write(buf, binary.BigEndian, uint32(80))
|
||||||
log.Printf("HTTP forwarding approved on port: %d", 80)
|
log.Printf("HTTP forwarding approved on port: %d", 80)
|
||||||
|
|
||||||
s.waitForRunningStatus()
|
|
||||||
|
|
||||||
domain := utils.Getenv("domain")
|
domain := utils.Getenv("domain")
|
||||||
protocol := "http"
|
protocol := "http"
|
||||||
if utils.Getenv("tls_enabled") == "true" {
|
if utils.Getenv("tls_enabled") == "true" {
|
||||||
@ -321,17 +298,23 @@ func (s *Session) generateUniqueSlug() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) waitForRunningStatus() {
|
func (s *Session) waitForRunningStatus() {
|
||||||
timeout := time.After(10 * time.Second)
|
timeout := time.After(3 * time.Second)
|
||||||
ticker := time.NewTicker(500 * time.Millisecond)
|
ticker := time.NewTicker(150 * time.Millisecond)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
frames := []string{"-", "\\", "|", "/"}
|
||||||
|
i := 0
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
|
s.sendMessage(fmt.Sprintf("\rLoading %s", frames[i]))
|
||||||
|
i = (i + 1) % len(frames)
|
||||||
if s.Status == RUNNING {
|
if s.Status == RUNNING {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case <-timeout:
|
case <-timeout:
|
||||||
|
s.sendMessage("\r\033[K")
|
||||||
|
s.sendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n")
|
||||||
|
s.Close()
|
||||||
log.Println("Timeout waiting for session to start running")
|
log.Println("Timeout waiting for session to start running")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,9 +1,16 @@
|
|||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"golang.org/x/crypto/ssh"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
INITIALIZING SessionStatus = "INITIALIZING"
|
||||||
|
RUNNING SessionStatus = "RUNNING"
|
||||||
|
SETUP SessionStatus = "SETUP"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunnelType string
|
type TunnelType string
|
||||||
@ -11,7 +18,6 @@ type TunnelType string
|
|||||||
const (
|
const (
|
||||||
HTTP TunnelType = "http"
|
HTTP TunnelType = "http"
|
||||||
TCP TunnelType = "tcp"
|
TCP TunnelType = "tcp"
|
||||||
UDP TunnelType = "udp"
|
|
||||||
UNKNOWN TunnelType = "unknown"
|
UNKNOWN TunnelType = "unknown"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,33 +29,32 @@ type Session struct {
|
|||||||
ForwardedPort uint16
|
ForwardedPort uint16
|
||||||
Status SessionStatus
|
Status SessionStatus
|
||||||
Slug string
|
Slug string
|
||||||
ChannelChan chan ssh.NewChannel
|
|
||||||
Done chan bool
|
|
||||||
once sync.Once
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request) *Session {
|
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) {
|
||||||
session := &Session{
|
session := &Session{
|
||||||
Status: SETUP,
|
Status: INITIALIZING,
|
||||||
Slug: "",
|
Slug: "",
|
||||||
ConnChannel: nil,
|
ConnChannel: nil,
|
||||||
Connection: conn,
|
Connection: conn,
|
||||||
TunnelType: UNKNOWN,
|
TunnelType: UNKNOWN,
|
||||||
ChannelChan: make(chan ssh.NewChannel),
|
|
||||||
Done: make(chan bool),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for channel := range session.ChannelChan {
|
go session.waitForRunningStatus()
|
||||||
|
|
||||||
|
for channel := range sshChan {
|
||||||
ch, reqs, _ := channel.Accept()
|
ch, reqs, _ := channel.Accept()
|
||||||
if session.ConnChannel == nil {
|
if session.ConnChannel == nil {
|
||||||
session.ConnChannel = ch
|
session.ConnChannel = ch
|
||||||
session.Status = RUNNING
|
session.Status = SETUP
|
||||||
go session.HandleGlobalRequest(forwardingReq)
|
go session.HandleGlobalRequest(forwardingReq)
|
||||||
}
|
}
|
||||||
go session.HandleGlobalRequest(reqs)
|
go session.HandleGlobalRequest(reqs)
|
||||||
}
|
}
|
||||||
|
err := session.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to close session: %v", err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return session
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user