Improve concurrency and resource management #2
@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleConnection(conn net.Conn) {
|
func (s *Server) handleConnection(conn net.Conn) {
|
||||||
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.Config)
|
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.Config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to establish SSH connection: %v", err)
|
log.Printf("failed to establish SSH connection: %v", err)
|
||||||
conn.Close()
|
conn.Close()
|
||||||
@ -17,5 +17,9 @@ func (s *Server) handleConnection(conn net.Conn) {
|
|||||||
|
|
||||||
log.Println("SSH connection established:", sshConn.User())
|
log.Println("SSH connection established:", sshConn.User())
|
||||||
|
|
||||||
session.New(sshConn, chans, reqs)
|
newSession := session.New(sshConn, forwardingReqs)
|
||||||
|
for ch := range chans {
|
||||||
|
newSession.ChannelChan <- ch
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -40,17 +40,6 @@ var (
|
|||||||
Clients = make(map[string]*Session)
|
Clients = make(map[string]*Session)
|
||||||
)
|
)
|
||||||
|
|
||||||
type Session struct {
|
|
||||||
Connection *ssh.ServerConn
|
|
||||||
ConnChannel ssh.Channel
|
|
||||||
Listener net.Listener
|
|
||||||
TunnelType TunnelType
|
|
||||||
ForwardedPort uint16
|
|
||||||
Status SessionStatus
|
|
||||||
Slug string
|
|
||||||
Done chan bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func registerClient(slug string, session *Session) bool {
|
func registerClient(slug string, session *Session) bool {
|
||||||
clientsMutex.Lock()
|
clientsMutex.Lock()
|
||||||
defer clientsMutex.Unlock()
|
defer clientsMutex.Unlock()
|
||||||
@ -113,29 +102,18 @@ func (s *Session) Close() {
|
|||||||
close(s.Done)
|
close(s.Done)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) handleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||||
ticker := time.NewTicker(1 * time.Second)
|
for req := range GlobalRequest {
|
||||||
for {
|
switch req.Type {
|
||||||
select {
|
case "tcpip-forward":
|
||||||
case req := <-GlobalRequest:
|
|
||||||
ticker.Stop()
|
|
||||||
if req == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if req.Type == "tcpip-forward" {
|
|
||||||
s.handleTCPIPForward(req)
|
s.handleTCPIPForward(req)
|
||||||
} else if req.Type == "shell" || req.Type == "pty-req" || req.Type == "window-change" {
|
return
|
||||||
|
case "shell", "pty-req", "window-change":
|
||||||
req.Reply(true, nil)
|
req.Reply(true, nil)
|
||||||
} else {
|
default:
|
||||||
|
log.Println("Unknown request type:", req.Type)
|
||||||
req.Reply(false, nil)
|
req.Reply(false, nil)
|
||||||
}
|
}
|
||||||
case <-s.Done:
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
s.sendMessage(fmt.Sprintf("Please specify the forwarding tunnel. For example: 'ssh %s -p %s -R 443:localhost:8080' \r\n\n\n", utils.Getenv("domain"), utils.Getenv("port")))
|
|
||||||
s.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -181,6 +159,7 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
|
|||||||
|
|
||||||
showWelcomeMessage(s.ConnChannel)
|
showWelcomeMessage(s.ConnChannel)
|
||||||
s.Status = RUNNING
|
s.Status = RUNNING
|
||||||
|
go s.handleUserInput()
|
||||||
|
|
||||||
if portToBind == 80 || portToBind == 443 {
|
if portToBind == 80 || portToBind == 443 {
|
||||||
s.handleHTTPForward(req, portToBind)
|
s.handleHTTPForward(req, portToBind)
|
||||||
@ -338,29 +317,14 @@ func (s *Session) sendMessage(message string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel, initialRequest <-chan *ssh.Request) {
|
func (s *Session) handleUserInput() {
|
||||||
connection, requests, err := newChannel.Accept()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Could not accept channel: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.ConnChannel = connection
|
|
||||||
s.Status = RUNNING
|
|
||||||
|
|
||||||
go s.handleGlobalRequest(initialRequest)
|
|
||||||
go s.handleGlobalRequest(requests)
|
|
||||||
go s.handleUserInput(connection)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) handleUserInput(connection ssh.Channel) {
|
|
||||||
var commandBuffer bytes.Buffer
|
var commandBuffer bytes.Buffer
|
||||||
buf := make([]byte, 1)
|
buf := make([]byte, 1)
|
||||||
inSlugEditMode := false
|
inSlugEditMode := false
|
||||||
editSlug := s.Slug
|
editSlug := s.Slug
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := connection.Read(buf)
|
n, err := s.ConnChannel.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
log.Printf("Error reading from client: %s", err)
|
log.Printf("Error reading from client: %s", err)
|
||||||
@ -372,16 +336,16 @@ func (s *Session) handleUserInput(connection ssh.Channel) {
|
|||||||
char := buf[0]
|
char := buf[0]
|
||||||
|
|
||||||
if inSlugEditMode {
|
if inSlugEditMode {
|
||||||
s.handleSlugEditMode(connection, &inSlugEditMode, &editSlug, char, &commandBuffer)
|
s.handleSlugEditMode(s.ConnChannel, &inSlugEditMode, &editSlug, char, &commandBuffer)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
connection.Write(buf[:n])
|
s.ConnChannel.Write(buf[:n])
|
||||||
|
|
||||||
if char == 8 || char == 127 {
|
if char == 8 || char == 127 {
|
||||||
if commandBuffer.Len() > 0 {
|
if commandBuffer.Len() > 0 {
|
||||||
commandBuffer.Truncate(commandBuffer.Len() - 1)
|
commandBuffer.Truncate(commandBuffer.Len() - 1)
|
||||||
connection.Write([]byte("\b \b"))
|
s.ConnChannel.Write([]byte("\b \b"))
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -394,7 +358,7 @@ func (s *Session) handleUserInput(connection ssh.Channel) {
|
|||||||
|
|
||||||
if commandBuffer.Len() > 0 {
|
if commandBuffer.Len() > 0 {
|
||||||
if char == 13 {
|
if char == 13 {
|
||||||
s.handleCommand(connection, commandBuffer.String(), &inSlugEditMode, &editSlug, &commandBuffer)
|
s.handleCommand(s.ConnChannel, commandBuffer.String(), &inSlugEditMode, &editSlug, &commandBuffer)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
commandBuffer.WriteByte(char)
|
commandBuffer.WriteByte(char)
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package session
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunnelType string
|
type TunnelType string
|
||||||
@ -13,19 +14,38 @@ const (
|
|||||||
UNKNOWN TunnelType = "unknown"
|
UNKNOWN TunnelType = "unknown"
|
||||||
)
|
)
|
||||||
|
|
||||||
func New(conn *ssh.ServerConn, sshChannel <-chan ssh.NewChannel, req <-chan *ssh.Request) *Session {
|
type Session struct {
|
||||||
|
Connection *ssh.ServerConn
|
||||||
|
ConnChannel ssh.Channel
|
||||||
|
Listener net.Listener
|
||||||
|
TunnelType TunnelType
|
||||||
|
ForwardedPort uint16
|
||||||
|
Status SessionStatus
|
||||||
|
Slug string
|
||||||
|
ChannelChan chan ssh.NewChannel
|
||||||
|
Done chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request) *Session {
|
||||||
session := &Session{
|
session := &Session{
|
||||||
Status: SETUP,
|
Status: SETUP,
|
||||||
Slug: "",
|
Slug: "",
|
||||||
ConnChannel: nil,
|
ConnChannel: nil,
|
||||||
Connection: conn,
|
Connection: conn,
|
||||||
TunnelType: UNKNOWN,
|
TunnelType: UNKNOWN,
|
||||||
|
ChannelChan: make(chan ssh.NewChannel),
|
||||||
Done: make(chan bool),
|
Done: make(chan bool),
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for newChannel := range sshChannel {
|
for channel := range session.ChannelChan {
|
||||||
go session.HandleSessionChannel(newChannel, req)
|
ch, reqs, _ := channel.Accept()
|
||||||
|
if session.ConnChannel == nil {
|
||||||
|
session.ConnChannel = ch
|
||||||
|
session.Status = RUNNING
|
||||||
|
go session.HandleGlobalRequest(forwardingReq)
|
||||||
|
}
|
||||||
|
go session.HandleGlobalRequest(reqs)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user