Merge pull request 'staging' (#25) from staging into main
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 4m21s
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 4m21s
Reviewed-on: bagas/tunnl_please#25
This commit is contained in:
9
go.mod
9
go.mod
@ -3,13 +3,8 @@ module tunnel_pls
|
|||||||
go 1.24.4
|
go 1.24.4
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/a-h/templ v0.3.833
|
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
golang.org/x/crypto v0.32.0
|
golang.org/x/crypto v0.45.0
|
||||||
golang.org/x/net v0.33.0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require golang.org/x/sys v0.38.0 // indirect
|
||||||
github.com/gorilla/websocket v1.5.3 // indirect
|
|
||||||
golang.org/x/sys v0.29.0 // indirect
|
|
||||||
)
|
|
||||||
|
|||||||
13
go.sum
13
go.sum
@ -1,16 +1,13 @@
|
|||||||
github.com/a-h/templ v0.3.833 h1:L/KOk/0VvVTBegtE0fp2RJQiBm7/52Zxv5fqlEHiQUU=
|
|
||||||
github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YYmfk=
|
|
||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
|
||||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
|
||||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
|
||||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||||
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
|
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
|
||||||
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
|
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
|
||||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||||
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||||
|
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
|
golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
|
||||||
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
|
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
|
||||||
|
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||||
|
|||||||
@ -11,16 +11,13 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"tunnel_pls/session"
|
"tunnel_pls/session"
|
||||||
|
"tunnel_pls/types"
|
||||||
"tunnel_pls/utils"
|
"tunnel_pls/utils"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
|
type Interaction interface {
|
||||||
"Content-Length: 11\r\n" +
|
SendMessage(message string)
|
||||||
"Content-Type: text/plain\r\n\r\n" +
|
}
|
||||||
"Bad Gateway")
|
|
||||||
|
|
||||||
type CustomWriter struct {
|
type CustomWriter struct {
|
||||||
RemoteAddr net.Addr
|
RemoteAddr net.Addr
|
||||||
writer io.Writer
|
writer io.Writer
|
||||||
@ -29,12 +26,16 @@ type CustomWriter struct {
|
|||||||
buf []byte
|
buf []byte
|
||||||
respHeader *ResponseHeaderFactory
|
respHeader *ResponseHeaderFactory
|
||||||
reqHeader *RequestHeaderFactory
|
reqHeader *RequestHeaderFactory
|
||||||
interaction *session.Interaction
|
interaction Interaction
|
||||||
respMW []ResponseMiddleware
|
respMW []ResponseMiddleware
|
||||||
reqStartMW []RequestMiddleware
|
reqStartMW []RequestMiddleware
|
||||||
reqEndMW []RequestMiddleware
|
reqEndMW []RequestMiddleware
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cw *CustomWriter) SetInteraction(interaction Interaction) {
|
||||||
|
cw.interaction = interaction
|
||||||
|
}
|
||||||
|
|
||||||
func (cw *CustomWriter) Read(p []byte) (int, error) {
|
func (cw *CustomWriter) Read(p []byte) (int, error) {
|
||||||
tmp := make([]byte, len(p))
|
tmp := make([]byte, len(p))
|
||||||
read, err := cw.reader.Read(tmp)
|
read, err := cw.reader.Read(tmp)
|
||||||
@ -125,7 +126,7 @@ func isHTTPHeader(buf []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cw *CustomWriter) Write(p []byte) (int, error) {
|
func (cw *CustomWriter) Write(p []byte) (int, error) {
|
||||||
if len(p) == len(BadGatewayResponse) && bytes.Equal(p, BadGatewayResponse) {
|
if len(p) == len(types.BadGatewayResponse) && bytes.Equal(p, types.BadGatewayResponse) {
|
||||||
return cw.writer.Write(p)
|
return cw.writer.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -177,7 +178,7 @@ func (cw *CustomWriter) Write(p []byte) (int, error) {
|
|||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cw *CustomWriter) AddInteraction(interaction *session.Interaction) {
|
func (cw *CustomWriter) AddInteraction(interaction Interaction) {
|
||||||
cw.interaction = interaction
|
cw.interaction = interaction
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -211,7 +212,7 @@ func NewHTTPServer() error {
|
|||||||
func Handler(conn net.Conn) {
|
func Handler(conn net.Conn) {
|
||||||
defer func() {
|
defer func() {
|
||||||
err := conn.Close()
|
err := conn.Close()
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
log.Printf("Error closing connection: %v", err)
|
log.Printf("Error closing connection: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -287,32 +288,18 @@ func Handler(conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
|
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
|
||||||
|
cw.SetInteraction(sshSession.Interaction)
|
||||||
forwardRequest(cw, reqhf, sshSession)
|
forwardRequest(cw, reqhf, sshSession)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) {
|
func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) {
|
||||||
cw.AddInteraction(sshSession.Interaction)
|
payload := sshSession.Forwarder.CreateForwardedTCPIPPayload(cw.RemoteAddr)
|
||||||
originHost, originPort := ParseAddr(cw.RemoteAddr.String())
|
channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||||
payload := createForwardedTCPIPPayload(originHost, uint16(originPort), sshSession.Forwarder.GetForwardedPort())
|
|
||||||
channel, reqs, err := sshSession.Conn.OpenChannel("forwarded-tcpip", payload)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||||
sendBadGatewayResponse(cw)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func(channel ssh.Channel) {
|
|
||||||
err := channel.Close()
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
sendBadGatewayResponse(cw)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Println("Failed to close connection:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}(channel)
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for req := range reqs {
|
for req := range reqs {
|
||||||
@ -346,14 +333,6 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr)
|
sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendBadGatewayResponse(writer io.Writer) {
|
|
||||||
_, err := writer.Write(BadGatewayResponse)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("failed to write Bad Gateway response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@ -112,7 +112,7 @@ func HandlerTLS(conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
|
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
|
||||||
|
cw.SetInteraction(sshSession.Interaction)
|
||||||
forwardRequest(cw, reqhf, sshSession)
|
forwardRequest(cw, reqhf, sshSession)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,7 +4,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
"tunnel_pls/session"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type RequestMiddleware interface {
|
type RequestMiddleware interface {
|
||||||
@ -29,20 +28,22 @@ func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body [
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RequestLogger struct {
|
type RequestLogger struct {
|
||||||
interaction session.Interaction
|
interaction Interaction
|
||||||
remoteAddr net.Addr
|
remoteAddr net.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRequestLogger(interaction *session.Interaction, remoteAddr net.Addr) *RequestLogger {
|
func NewRequestLogger(interaction Interaction, remoteAddr net.Addr) *RequestLogger {
|
||||||
return &RequestLogger{
|
return &RequestLogger{
|
||||||
interaction: *interaction,
|
interaction: interaction,
|
||||||
remoteAddr: remoteAddr,
|
remoteAddr: remoteAddr,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rl *RequestLogger) HandleRequest(header *RequestHeaderFactory) error {
|
func (rl *RequestLogger) HandleRequest(header *RequestHeaderFactory) error {
|
||||||
rl.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), rl.remoteAddr.String(), header.Method, header.Path))
|
rl.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), rl.remoteAddr.String(), header.Method, header.Path))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rl *RequestLogger) HandleResponse(header *ResponseHeaderFactory, body []byte) error { return nil }
|
func (rl *RequestLogger) HandleResponse(header *ResponseHeaderFactory, body []byte) error { return nil }
|
||||||
|
|
||||||
//TODO: Implement caching atau enggak
|
//TODO: Implement caching atau enggak
|
||||||
|
|||||||
@ -1,13 +1,10 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
|
||||||
"tunnel_pls/utils"
|
"tunnel_pls/utils"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
@ -58,41 +55,3 @@ func (s *Server) Start() {
|
|||||||
go s.handleConnection(conn)
|
go s.handleConnection(conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
|
|
||||||
writeSSHString(&buf, "localhost")
|
|
||||||
err := binary.Write(&buf, binary.BigEndian, uint32(port))
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to write string to buffer: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
writeSSHString(&buf, host)
|
|
||||||
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to write string to buffer: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeSSHString(buffer *bytes.Buffer, str string) {
|
|
||||||
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to write string to buffer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
buffer.WriteString(str)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseAddr(addr string) (string, uint32) {
|
|
||||||
host, portStr, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
|
|
||||||
return "0.0.0.0", uint32(0)
|
|
||||||
}
|
|
||||||
port, _ := strconv.Atoi(portStr)
|
|
||||||
return host, uint32(port)
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,37 +0,0 @@
|
|||||||
package session
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Forwarder struct {
|
|
||||||
Listener net.Listener
|
|
||||||
TunnelType TunnelType
|
|
||||||
ForwardedPort uint16
|
|
||||||
|
|
||||||
getSlug func() string
|
|
||||||
setSlug func(string)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ForwardingController interface {
|
|
||||||
HandleGlobalRequest(ch <-chan *ssh.Request)
|
|
||||||
HandleTCPIPForward(req *ssh.Request)
|
|
||||||
HandleHTTPForward(req *ssh.Request, port uint16)
|
|
||||||
HandleTCPForward(req *ssh.Request, addr string, port uint16)
|
|
||||||
AcceptTCPConnections()
|
|
||||||
}
|
|
||||||
|
|
||||||
type ForwarderInfo interface {
|
|
||||||
GetTunnelType() TunnelType
|
|
||||||
GetForwardedPort() uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Forwarder) GetTunnelType() TunnelType {
|
|
||||||
return f.TunnelType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Forwarder) GetForwardedPort() uint16 {
|
|
||||||
return f.ForwardedPort
|
|
||||||
}
|
|
||||||
185
session/forwarder/forwarder.go
Normal file
185
session/forwarder/forwarder.go
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"tunnel_pls/session/slug"
|
||||||
|
"tunnel_pls/types"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Forwarder struct {
|
||||||
|
Listener net.Listener
|
||||||
|
TunnelType types.TunnelType
|
||||||
|
ForwardedPort uint16
|
||||||
|
SlugManager slug.Manager
|
||||||
|
Lifecycle Lifecycle
|
||||||
|
}
|
||||||
|
|
||||||
|
type Lifecycle interface {
|
||||||
|
GetConnection() ssh.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
type ForwardingController interface {
|
||||||
|
AcceptTCPConnections()
|
||||||
|
SetType(tunnelType types.TunnelType)
|
||||||
|
GetTunnelType() types.TunnelType
|
||||||
|
GetForwardedPort() uint16
|
||||||
|
SetForwardedPort(port uint16)
|
||||||
|
SetListener(listener net.Listener)
|
||||||
|
GetListener() net.Listener
|
||||||
|
Close() error
|
||||||
|
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
|
||||||
|
SetLifecycle(lifecycle Lifecycle)
|
||||||
|
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
||||||
|
WriteBadGatewayResponse(dst io.Writer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
|
||||||
|
f.Lifecycle = lifecycle
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) AcceptTCPConnections() {
|
||||||
|
for {
|
||||||
|
conn, err := f.GetListener().Accept()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Error accepting connection: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
||||||
|
channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for req := range reqs {
|
||||||
|
err := req.Reply(false, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to reply to request: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go f.HandleConnection(conn, channel, conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
||||||
|
defer func(src ssh.Channel) {
|
||||||
|
_, err := io.Copy(io.Discard, src)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to discard connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = src.Close()
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
log.Printf("Error closing connection: %v", err)
|
||||||
|
}
|
||||||
|
}(src)
|
||||||
|
log.Printf("Handling new forwarded connection from %s", remoteAddr)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(src, dst)
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||||
|
log.Printf("Error copying from conn.Reader to channel: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := io.Copy(dst, src)
|
||||||
|
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
log.Printf("Error copying from channel to conn.Writer: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) SetType(tunnelType types.TunnelType) {
|
||||||
|
f.TunnelType = tunnelType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) GetTunnelType() types.TunnelType {
|
||||||
|
return f.TunnelType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) GetForwardedPort() uint16 {
|
||||||
|
return f.ForwardedPort
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) SetForwardedPort(port uint16) {
|
||||||
|
f.ForwardedPort = port
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) SetListener(listener net.Listener) {
|
||||||
|
f.Listener = listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) GetListener() net.Listener {
|
||||||
|
return f.Listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
|
||||||
|
_, err := dst.Write(types.BadGatewayResponse)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to write Bad Gateway response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) Close() error {
|
||||||
|
if f.GetTunnelType() != types.HTTP {
|
||||||
|
return f.Listener.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
host, originPort := parseAddr(origin.String())
|
||||||
|
|
||||||
|
writeSSHString(&buf, "localhost")
|
||||||
|
err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort()))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to write string to buffer: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
writeSSHString(&buf, host)
|
||||||
|
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to write string to buffer: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAddr(addr string) (string, uint16) {
|
||||||
|
host, portStr, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
|
||||||
|
return "0.0.0.0", uint16(0)
|
||||||
|
}
|
||||||
|
port, _ := strconv.Atoi(portStr)
|
||||||
|
return host, uint16(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSSHString(buffer *bytes.Buffer, str string) {
|
||||||
|
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to write string to buffer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
buffer.WriteString(str)
|
||||||
|
}
|
||||||
@ -3,98 +3,24 @@ package session
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
portUtil "tunnel_pls/internal/port"
|
portUtil "tunnel_pls/internal/port"
|
||||||
|
"tunnel_pls/types"
|
||||||
|
|
||||||
"tunnel_pls/utils"
|
"tunnel_pls/utils"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Status string
|
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||||
|
|
||||||
var forbiddenSlug = []string{
|
|
||||||
"ping",
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserConnection struct {
|
|
||||||
Reader io.Reader
|
|
||||||
Writer net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
clientsMutex sync.RWMutex
|
|
||||||
Clients = make(map[string]*SSHSession)
|
|
||||||
)
|
|
||||||
|
|
||||||
func registerClient(slug string, session *SSHSession) bool {
|
|
||||||
clientsMutex.Lock()
|
|
||||||
defer clientsMutex.Unlock()
|
|
||||||
|
|
||||||
if _, exists := Clients[slug]; exists {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
Clients[slug] = session
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func unregisterClient(slug string) {
|
|
||||||
clientsMutex.Lock()
|
|
||||||
defer clientsMutex.Unlock()
|
|
||||||
|
|
||||||
delete(Clients, slug)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SSHSession) Close() error {
|
|
||||||
if s.Forwarder.Listener != nil {
|
|
||||||
err := s.Forwarder.Listener.Close()
|
|
||||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.channel != nil {
|
|
||||||
err := s.channel.Close()
|
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Conn != nil {
|
|
||||||
err := s.Conn.Close()
|
|
||||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
slug := s.Forwarder.getSlug()
|
|
||||||
if slug != "" {
|
|
||||||
unregisterClient(slug)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Forwarder.TunnelType == TCP && s.Forwarder.Listener != nil {
|
|
||||||
err := portUtil.Manager.SetPortStatus(s.Forwarder.ForwardedPort, false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||||
for req := range GlobalRequest {
|
for req := range GlobalRequest {
|
||||||
switch req.Type {
|
switch req.Type {
|
||||||
case "tcpip-forward":
|
case "tcpip-forward":
|
||||||
s.handleTCPIPForward(req)
|
s.HandleTCPIPForward(req)
|
||||||
return
|
return
|
||||||
case "shell", "pty-req", "window-change":
|
case "shell", "pty-req", "window-change":
|
||||||
err := req.Reply(true, nil)
|
err := req.Reply(true, nil)
|
||||||
@ -113,7 +39,7 @@ func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSHSession) handleTCPIPForward(req *ssh.Request) {
|
func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
|
||||||
log.Println("Port forwarding request detected")
|
log.Println("Port forwarding request detected")
|
||||||
|
|
||||||
reader := bytes.NewReader(req.Payload)
|
reader := bytes.NewReader(req.Payload)
|
||||||
@ -126,7 +52,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) {
|
|||||||
log.Println("Failed to reply to request:", err)
|
log.Println("Failed to reply to request:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = s.Close()
|
err = s.Lifecycle.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to close session: %v", err)
|
log.Printf("failed to close session: %v", err)
|
||||||
}
|
}
|
||||||
@ -142,7 +68,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) {
|
|||||||
log.Println("Failed to reply to request:", err)
|
log.Println("Failed to reply to request:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = s.Close()
|
err = s.Lifecycle.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to close session: %v", err)
|
log.Printf("failed to close session: %v", err)
|
||||||
}
|
}
|
||||||
@ -156,7 +82,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) {
|
|||||||
log.Println("Failed to reply to request:", err)
|
log.Println("Failed to reply to request:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = s.Close()
|
err = s.Lifecycle.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to close session: %v", err)
|
log.Printf("failed to close session: %v", err)
|
||||||
}
|
}
|
||||||
@ -172,7 +98,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) {
|
|||||||
log.Println("Failed to reply to request:", err)
|
log.Println("Failed to reply to request:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = s.Close()
|
err = s.Lifecycle.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to close session: %v", err)
|
log.Printf("failed to close session: %v", err)
|
||||||
}
|
}
|
||||||
@ -180,11 +106,11 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.Interaction.SendMessage("\033[H\033[2J")
|
s.Interaction.SendMessage("\033[H\033[2J")
|
||||||
s.Lifecycle.Status = RUNNING
|
s.Lifecycle.SetStatus(types.RUNNING)
|
||||||
go s.Interaction.HandleUserInput()
|
go s.Interaction.HandleUserInput()
|
||||||
|
|
||||||
if portToBind == 80 || portToBind == 443 {
|
if portToBind == 80 || portToBind == 443 {
|
||||||
s.handleHTTPForward(req, portToBind)
|
s.HandleHTTPForward(req, portToBind)
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
if portToBind == 0 {
|
if portToBind == 0 {
|
||||||
@ -197,7 +123,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) {
|
|||||||
log.Println("Failed to reply to request:", err)
|
log.Println("Failed to reply to request:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = s.Close()
|
err = s.Lifecycle.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to close session: %v", err)
|
log.Printf("failed to close session: %v", err)
|
||||||
}
|
}
|
||||||
@ -210,7 +136,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) {
|
|||||||
log.Println("Failed to reply to request:", err)
|
log.Println("Failed to reply to request:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = s.Close()
|
err = s.Lifecycle.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to close session: %v", err)
|
log.Printf("failed to close session: %v", err)
|
||||||
}
|
}
|
||||||
@ -222,29 +148,12 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.handleTCPForward(req, addr, portToBind)
|
s.HandleTCPForward(req, addr, portToBind)
|
||||||
}
|
}
|
||||||
|
|
||||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
|
||||||
|
s.Forwarder.SetType(types.HTTP)
|
||||||
func isBlockedPort(port uint16) bool {
|
s.Forwarder.SetForwardedPort(portToBind)
|
||||||
if port == 80 || port == 443 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if port < 1024 && port != 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
for _, p := range blockedReservedPorts {
|
|
||||||
if p == port {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) {
|
|
||||||
s.Forwarder.TunnelType = HTTP
|
|
||||||
s.Forwarder.ForwardedPort = portToBind
|
|
||||||
|
|
||||||
slug := generateUniqueSlug()
|
slug := generateUniqueSlug()
|
||||||
if slug == "" {
|
if slug == "" {
|
||||||
@ -256,7 +165,7 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.Forwarder.setSlug(slug)
|
s.SlugManager.Set(slug)
|
||||||
registerClient(slug, s)
|
registerClient(slug, s)
|
||||||
|
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
@ -282,8 +191,8 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
|
func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
|
||||||
s.Forwarder.TunnelType = TCP
|
s.Forwarder.SetType(types.TCP)
|
||||||
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
|
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
||||||
@ -294,18 +203,18 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind
|
|||||||
log.Println("Failed to reply to request:", err)
|
log.Println("Failed to reply to request:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = s.Close()
|
err = s.Lifecycle.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to close session: %v", err)
|
log.Printf("failed to close session: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.Forwarder.Listener = listener
|
s.Forwarder.SetListener(listener)
|
||||||
s.Forwarder.ForwardedPort = portToBind
|
s.Forwarder.SetForwardedPort(portToBind)
|
||||||
s.Interaction.ShowWelcomeMessage()
|
s.Interaction.ShowWelcomeMessage()
|
||||||
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.TunnelType, utils.Getenv("domain"), s.Forwarder.ForwardedPort))
|
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("domain"), s.Forwarder.GetForwardedPort()))
|
||||||
|
|
||||||
go s.acceptTCPConnections()
|
go s.Forwarder.AcceptTCPConnections()
|
||||||
|
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
|
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
|
||||||
@ -321,37 +230,6 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSHSession) acceptTCPConnections() {
|
|
||||||
for {
|
|
||||||
conn, err := s.Forwarder.Listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, net.ErrClosed) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("Error accepting connection: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
originHost, originPort := ParseAddr(conn.RemoteAddr().String())
|
|
||||||
payload := createForwardedTCPIPPayload(originHost, uint16(originPort), s.Forwarder.GetForwardedPort())
|
|
||||||
channel, reqs, err := s.Conn.OpenChannel("forwarded-tcpip", payload)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for req := range reqs {
|
|
||||||
err := req.Reply(false, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to reply to request: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
go s.HandleForwardedConnection(conn, channel, conn.RemoteAddr())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateUniqueSlug() string {
|
func generateUniqueSlug() string {
|
||||||
maxAttempts := 5
|
maxAttempts := 5
|
||||||
|
|
||||||
@ -371,95 +249,6 @@ func generateUniqueSlug() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSHSession) waitForRunningStatus() {
|
|
||||||
timeout := time.After(3 * time.Second)
|
|
||||||
ticker := time.NewTicker(150 * time.Millisecond)
|
|
||||||
defer ticker.Stop()
|
|
||||||
frames := []string{"-", "\\", "|", "/"}
|
|
||||||
i := 0
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
s.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i]))
|
|
||||||
i = (i + 1) % len(frames)
|
|
||||||
if s.Lifecycle.Status == RUNNING {
|
|
||||||
s.Interaction.SendMessage("\r\033[K")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case <-timeout:
|
|
||||||
s.Interaction.SendMessage("\r\033[K")
|
|
||||||
s.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n")
|
|
||||||
err := s.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("failed to close session: %v", err)
|
|
||||||
}
|
|
||||||
log.Println("Timeout waiting for session to start running")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func isForbiddenSlug(slug string) bool {
|
|
||||||
for _, s := range forbiddenSlug {
|
|
||||||
if slug == s {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func isValidSlug(slug string) bool {
|
|
||||||
if len(slug) < 3 || len(slug) > 20 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if slug[0] == '-' || slug[len(slug)-1] == '-' {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range slug {
|
|
||||||
if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitForKeyPress(connection ssh.Channel) {
|
|
||||||
keyBuf := make([]byte, 1)
|
|
||||||
for {
|
|
||||||
_, err := connection.Read(keyBuf)
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SSHSession) HandleForwardedConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
|
||||||
defer func(src ssh.Channel) {
|
|
||||||
err := src.Close()
|
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
|
||||||
log.Printf("Error closing connection: %v", err)
|
|
||||||
}
|
|
||||||
}(src)
|
|
||||||
log.Printf("Handling new forwarded connection from %s", remoteAddr)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_, err := io.Copy(src, dst)
|
|
||||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
|
||||||
log.Printf("Error copying from conn.Reader to channel: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
_, err := io.Copy(dst, src)
|
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
|
||||||
log.Printf("Error copying from channel to conn.Writer: %v", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func readSSHString(reader *bytes.Reader) (string, error) {
|
func readSSHString(reader *bytes.Reader) (string, error) {
|
||||||
var length uint32
|
var length uint32
|
||||||
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
|
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
|
||||||
@ -472,40 +261,17 @@ func readSSHString(reader *bytes.Reader) (string, error) {
|
|||||||
return string(strBytes), nil
|
return string(strBytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
|
func isBlockedPort(port uint16) bool {
|
||||||
var buf bytes.Buffer
|
if port == 80 || port == 443 {
|
||||||
|
return false
|
||||||
writeSSHString(&buf, "localhost")
|
|
||||||
err := binary.Write(&buf, binary.BigEndian, uint32(port))
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to write string to buffer: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
writeSSHString(&buf, host)
|
if port < 1024 && port != 0 {
|
||||||
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
return true
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to write string to buffer: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
for _, p := range blockedReservedPorts {
|
||||||
return buf.Bytes()
|
if p == port {
|
||||||
}
|
return true
|
||||||
|
}
|
||||||
func writeSSHString(buffer *bytes.Buffer, str string) {
|
}
|
||||||
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
|
return false
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to write string to buffer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
buffer.WriteString(str)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseAddr(addr string) (string, uint32) {
|
|
||||||
host, portStr, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
|
|
||||||
return "0.0.0.0", uint32(0)
|
|
||||||
}
|
|
||||||
port, _ := strconv.Atoi(portStr)
|
|
||||||
return host, uint32(port)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
package session
|
package interaction
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -7,35 +7,60 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
"tunnel_pls/session/slug"
|
||||||
|
"tunnel_pls/types"
|
||||||
"tunnel_pls/utils"
|
"tunnel_pls/utils"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type InteractionController interface {
|
var forbiddenSlug = []string{
|
||||||
|
"ping",
|
||||||
|
}
|
||||||
|
|
||||||
|
type Lifecycle interface {
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Controller interface {
|
||||||
SendMessage(message string)
|
SendMessage(message string)
|
||||||
HandleUserInput()
|
HandleUserInput()
|
||||||
HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer)
|
HandleCommand(command string)
|
||||||
HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer)
|
HandleSlugEditMode(connection ssh.Channel, char byte)
|
||||||
HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer)
|
HandleSlugSave(conn ssh.Channel)
|
||||||
HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer)
|
HandleSlugCancel(connection ssh.Channel)
|
||||||
HandleSlugUpdateError()
|
HandleSlugUpdateError()
|
||||||
ShowWelcomeMessage()
|
ShowWelcomeMessage()
|
||||||
DisplaySlugEditor()
|
DisplaySlugEditor()
|
||||||
|
SetChannel(channel ssh.Channel)
|
||||||
|
SetLifecycle(lifecycle Lifecycle)
|
||||||
|
SetSlugModificator(func(oldSlug, newSlug string) bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Forwarder interface {
|
||||||
|
Close() error
|
||||||
|
GetTunnelType() types.TunnelType
|
||||||
|
GetForwardedPort() uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interaction struct {
|
type Interaction struct {
|
||||||
CommandBuffer *bytes.Buffer
|
InputLength int
|
||||||
EditMode bool
|
CommandBuffer *bytes.Buffer
|
||||||
EditSlug string
|
EditMode bool
|
||||||
channel ssh.Channel
|
EditSlug string
|
||||||
|
channel ssh.Channel
|
||||||
|
SlugManager slug.Manager
|
||||||
|
Forwarder Forwarder
|
||||||
|
Lifecycle Lifecycle
|
||||||
|
updateClientSlug func(oldSlug, newSlug string) bool
|
||||||
|
}
|
||||||
|
|
||||||
getSlug func() string
|
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
|
||||||
setSlug func(string)
|
i.Lifecycle = lifecycle
|
||||||
|
}
|
||||||
|
|
||||||
session SessionCloser
|
func (i *Interaction) SetChannel(channel ssh.Channel) {
|
||||||
|
i.channel = channel
|
||||||
forwarder ForwarderInfo
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Interaction) SendMessage(message string) {
|
func (i *Interaction) SendMessage(message string) {
|
||||||
@ -49,7 +74,6 @@ func (i *Interaction) SendMessage(message string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *Interaction) HandleUserInput() {
|
func (i *Interaction) HandleUserInput() {
|
||||||
var commandBuffer bytes.Buffer
|
|
||||||
buf := make([]byte, 1)
|
buf := make([]byte, 1)
|
||||||
i.EditMode = false
|
i.EditMode = false
|
||||||
|
|
||||||
@ -66,42 +90,47 @@ func (i *Interaction) HandleUserInput() {
|
|||||||
char := buf[0]
|
char := buf[0]
|
||||||
|
|
||||||
if i.EditMode {
|
if i.EditMode {
|
||||||
i.HandleSlugEditMode(i.channel, char, &commandBuffer)
|
i.HandleSlugEditMode(i.channel, char)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
i.SendMessage(string(buf[:n]))
|
i.SendMessage(string(buf[:n]))
|
||||||
|
|
||||||
if char == 8 || char == 127 {
|
if char == 8 || char == 127 {
|
||||||
if commandBuffer.Len() > 0 {
|
if i.InputLength > 0 {
|
||||||
commandBuffer.Truncate(commandBuffer.Len() - 1)
|
|
||||||
i.SendMessage("\b \b")
|
i.SendMessage("\b \b")
|
||||||
}
|
}
|
||||||
|
if i.CommandBuffer.Len() > 0 {
|
||||||
|
i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
i.InputLength += n
|
||||||
|
|
||||||
if char == '/' {
|
if char == '/' {
|
||||||
commandBuffer.Reset()
|
i.CommandBuffer.Reset()
|
||||||
commandBuffer.WriteByte(char)
|
i.CommandBuffer.WriteByte(char)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if commandBuffer.Len() > 0 {
|
if i.CommandBuffer.Len() > 0 {
|
||||||
if char == 13 {
|
if char == 13 {
|
||||||
i.HandleCommand(commandBuffer.String(), &commandBuffer)
|
i.SendMessage("\033[K")
|
||||||
|
i.HandleCommand(i.CommandBuffer.String())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
commandBuffer.WriteByte(char)
|
i.CommandBuffer.WriteByte(char)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte, commandBuffer *bytes.Buffer) {
|
func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte) {
|
||||||
if char == 13 {
|
if char == 13 {
|
||||||
i.HandleSlugSave(connection)
|
i.HandleSlugSave(connection)
|
||||||
} else if char == 27 {
|
} else if char == 27 {
|
||||||
i.HandleSlugCancel(connection, commandBuffer)
|
i.HandleSlugCancel(connection)
|
||||||
} else if char == 8 || char == 127 {
|
} else if char == 8 || char == 127 {
|
||||||
if len(i.EditSlug) > 0 {
|
if len(i.EditSlug) > 0 {
|
||||||
i.EditSlug = (i.EditSlug)[:len(i.EditSlug)-1]
|
i.EditSlug = (i.EditSlug)[:len(i.EditSlug)-1]
|
||||||
@ -142,10 +171,10 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if isValid {
|
if isValid {
|
||||||
oldSlug := i.getSlug()
|
oldSlug := i.SlugManager.Get()
|
||||||
newSlug := i.EditSlug
|
newSlug := i.EditSlug
|
||||||
|
|
||||||
if !updateClientSlug(oldSlug, newSlug) {
|
if !i.updateClientSlug(oldSlug, newSlug) {
|
||||||
i.HandleSlugUpdateError()
|
i.HandleSlugUpdateError()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -223,7 +252,7 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) {
|
|||||||
if utils.Getenv("tls_enabled") == "true" {
|
if utils.Getenv("tls_enabled") == "true" {
|
||||||
protocol = "https"
|
protocol = "https"
|
||||||
}
|
}
|
||||||
_, err = connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.getSlug(), domain)))
|
_, err = connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to write to channel: %v", err)
|
log.Printf("failed to write to channel: %v", err)
|
||||||
return
|
return
|
||||||
@ -233,7 +262,7 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) {
|
|||||||
i.CommandBuffer.Reset()
|
i.CommandBuffer.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Interaction) HandleSlugCancel(connection ssh.Channel, commandBuffer *bytes.Buffer) {
|
func (i *Interaction) HandleSlugCancel(connection ssh.Channel) {
|
||||||
i.EditMode = false
|
i.EditMode = false
|
||||||
_, err := connection.Write([]byte("\033[H\033[2J"))
|
_, err := connection.Write([]byte("\033[H\033[2J"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -260,7 +289,7 @@ func (i *Interaction) HandleSlugCancel(connection ssh.Channel, commandBuffer *by
|
|||||||
}
|
}
|
||||||
i.ShowWelcomeMessage()
|
i.ShowWelcomeMessage()
|
||||||
|
|
||||||
commandBuffer.Reset()
|
i.CommandBuffer.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Interaction) HandleSlugUpdateError() {
|
func (i *Interaction) HandleSlugUpdateError() {
|
||||||
@ -271,44 +300,44 @@ func (i *Interaction) HandleSlugUpdateError() {
|
|||||||
i.SendMessage(fmt.Sprintf("Disconnecting in %d...\r\n", iter))
|
i.SendMessage(fmt.Sprintf("Disconnecting in %d...\r\n", iter))
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
err := i.session.Close()
|
err := i.Lifecycle.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to close session: %v", err)
|
log.Printf("failed to close session: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) {
|
func (i *Interaction) HandleCommand(command string) {
|
||||||
switch command {
|
switch command {
|
||||||
case "/bye":
|
case "/bye":
|
||||||
i.SendMessage("\r\nClosing connection...")
|
i.SendMessage("\r\nClosing connection...")
|
||||||
err := i.session.Close()
|
err := i.Lifecycle.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to close session: %v", err)
|
log.Printf("failed to close session: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
case "/help":
|
case "/help":
|
||||||
i.SendMessage("\r\nAvailable commands: /bye, /help, /clear, /slug")
|
i.SendMessage("\r\nAvailable commands: /bye, /help, /clear, /slug\r\n")
|
||||||
case "/clear":
|
case "/clear":
|
||||||
i.SendMessage("\033[H\033[2J")
|
i.SendMessage("\033[H\033[2J")
|
||||||
i.ShowWelcomeMessage()
|
i.ShowWelcomeMessage()
|
||||||
domain := utils.Getenv("domain")
|
domain := utils.Getenv("domain")
|
||||||
if i.forwarder.GetTunnelType() == HTTP {
|
if i.Forwarder.GetTunnelType() == types.HTTP {
|
||||||
protocol := "http"
|
protocol := "http"
|
||||||
if utils.Getenv("tls_enabled") == "true" {
|
if utils.Getenv("tls_enabled") == "true" {
|
||||||
protocol = "https"
|
protocol = "https"
|
||||||
}
|
}
|
||||||
i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.getSlug(), domain))
|
i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain))
|
||||||
} else {
|
} else {
|
||||||
i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", i.forwarder.GetTunnelType(), domain, i.forwarder.GetForwardedPort()))
|
i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.Forwarder.GetForwardedPort()))
|
||||||
}
|
}
|
||||||
case "/slug":
|
case "/slug":
|
||||||
if i.forwarder.GetTunnelType() != HTTP {
|
if i.Forwarder.GetTunnelType() != types.HTTP {
|
||||||
i.SendMessage((fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.forwarder.GetTunnelType())))
|
i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType()))
|
||||||
} else {
|
} else {
|
||||||
i.EditMode = true
|
i.EditMode = true
|
||||||
i.EditSlug = i.getSlug()
|
i.EditSlug = i.SlugManager.Get()
|
||||||
i.SendMessage("\033[H\033[2J")
|
i.SendMessage("\033[H\033[2J")
|
||||||
i.DisplaySlugEditor()
|
i.DisplaySlugEditor()
|
||||||
i.SendMessage("➤ " + i.EditSlug + "." + utils.Getenv("domain"))
|
i.SendMessage("➤ " + i.EditSlug + "." + utils.Getenv("domain"))
|
||||||
@ -317,7 +346,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer)
|
|||||||
i.SendMessage("Unknown command")
|
i.SendMessage("Unknown command")
|
||||||
}
|
}
|
||||||
|
|
||||||
commandBuffer.Reset()
|
i.CommandBuffer.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Interaction) ShowWelcomeMessage() {
|
func (i *Interaction) ShowWelcomeMessage() {
|
||||||
@ -347,7 +376,7 @@ func (i *Interaction) ShowWelcomeMessage() {
|
|||||||
|
|
||||||
func (i *Interaction) DisplaySlugEditor() {
|
func (i *Interaction) DisplaySlugEditor() {
|
||||||
domain := utils.Getenv("domain")
|
domain := utils.Getenv("domain")
|
||||||
fullDomain := i.getSlug() + "." + domain
|
fullDomain := i.SlugManager.Get() + "." + domain
|
||||||
|
|
||||||
const paddingRight = 4
|
const paddingRight = 4
|
||||||
|
|
||||||
@ -383,23 +412,8 @@ func (i *Interaction) DisplaySlugEditor() {
|
|||||||
i.SendMessage("\r\n\r\n")
|
i.SendMessage("\r\n\r\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateClientSlug(oldSlug, newSlug string) bool {
|
func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) bool) {
|
||||||
clientsMutex.Lock()
|
i.updateClientSlug = modificator
|
||||||
defer clientsMutex.Unlock()
|
|
||||||
|
|
||||||
if _, exists := Clients[newSlug]; exists && newSlug != oldSlug {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
client, ok := Clients[oldSlug]
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(Clients, oldSlug)
|
|
||||||
client.Forwarder.setSlug(newSlug)
|
|
||||||
Clients[newSlug] = client
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func centerText(text string, width int) string {
|
func centerText(text string, width int) string {
|
||||||
@ -409,3 +423,40 @@ func centerText(text string, width int) string {
|
|||||||
}
|
}
|
||||||
return strings.Repeat(" ", padding) + text + strings.Repeat(" ", width-len(text)-padding)
|
return strings.Repeat(" ", padding) + text + strings.Repeat(" ", width-len(text)-padding)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isValidSlug(slug string) bool {
|
||||||
|
if len(slug) < 3 || len(slug) > 20 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if slug[0] == '-' || slug[len(slug)-1] == '-' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range slug {
|
||||||
|
if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForKeyPress(connection ssh.Channel) {
|
||||||
|
keyBuf := make([]byte, 1)
|
||||||
|
for {
|
||||||
|
_, err := connection.Read(keyBuf)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isForbiddenSlug(slug string) bool {
|
||||||
|
for _, s := range forbiddenSlug {
|
||||||
|
if slug == s {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
126
session/lifecycle/lifecycle.go
Normal file
126
session/lifecycle/lifecycle.go
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
package lifecycle
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
portUtil "tunnel_pls/internal/port"
|
||||||
|
"tunnel_pls/session/slug"
|
||||||
|
"tunnel_pls/types"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Interaction interface {
|
||||||
|
SendMessage(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Forwarder interface {
|
||||||
|
Close() error
|
||||||
|
GetTunnelType() types.TunnelType
|
||||||
|
GetForwardedPort() uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type Lifecycle struct {
|
||||||
|
Status types.Status
|
||||||
|
Conn ssh.Conn
|
||||||
|
Channel ssh.Channel
|
||||||
|
|
||||||
|
Interaction Interaction
|
||||||
|
Forwarder Forwarder
|
||||||
|
SlugManager slug.Manager
|
||||||
|
unregisterClient func(slug string)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) {
|
||||||
|
l.unregisterClient = unregisterClient
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionLifecycle interface {
|
||||||
|
Close() error
|
||||||
|
WaitForRunningStatus()
|
||||||
|
SetStatus(status types.Status)
|
||||||
|
GetConnection() ssh.Conn
|
||||||
|
GetChannel() ssh.Channel
|
||||||
|
SetChannel(channel ssh.Channel)
|
||||||
|
SetUnregisterClient(unregisterClient func(slug string))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lifecycle) GetChannel() ssh.Channel {
|
||||||
|
return l.Channel
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lifecycle) SetChannel(channel ssh.Channel) {
|
||||||
|
l.Channel = channel
|
||||||
|
}
|
||||||
|
func (l *Lifecycle) GetConnection() ssh.Conn {
|
||||||
|
return l.Conn
|
||||||
|
}
|
||||||
|
func (l *Lifecycle) SetStatus(status types.Status) {
|
||||||
|
l.Status = status
|
||||||
|
}
|
||||||
|
func (l *Lifecycle) WaitForRunningStatus() {
|
||||||
|
timeout := time.After(3 * time.Second)
|
||||||
|
ticker := time.NewTicker(150 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
frames := []string{"-", "\\", "|", "/"}
|
||||||
|
i := 0
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
l.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i]))
|
||||||
|
i = (i + 1) % len(frames)
|
||||||
|
if l.Status == types.RUNNING {
|
||||||
|
l.Interaction.SendMessage("\r\033[K")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-timeout:
|
||||||
|
l.Interaction.SendMessage("\r\033[K")
|
||||||
|
l.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n")
|
||||||
|
err := l.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to close session: %v", err)
|
||||||
|
}
|
||||||
|
log.Println("Timeout waiting for session to start running")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lifecycle) Close() error {
|
||||||
|
err := l.Forwarder.Close()
|
||||||
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if l.Channel != nil {
|
||||||
|
err := l.Channel.Close()
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if l.Conn != nil {
|
||||||
|
err := l.Conn.Close()
|
||||||
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
clientSlug := l.SlugManager.Get()
|
||||||
|
if clientSlug != "" {
|
||||||
|
l.unregisterClient(clientSlug)
|
||||||
|
}
|
||||||
|
|
||||||
|
if l.Forwarder.GetTunnelType() == types.TCP {
|
||||||
|
err := portUtil.Manager.SetPortStatus(l.Forwarder.GetForwardedPort(), false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@ -4,102 +4,85 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"log"
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
|
"tunnel_pls/session/forwarder"
|
||||||
|
"tunnel_pls/session/interaction"
|
||||||
|
"tunnel_pls/session/lifecycle"
|
||||||
|
"tunnel_pls/session/slug"
|
||||||
|
"tunnel_pls/types"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
INITIALIZING Status = "INITIALIZING"
|
clientsMutex sync.RWMutex
|
||||||
RUNNING Status = "RUNNING"
|
Clients = make(map[string]*SSHSession)
|
||||||
SETUP Status = "SETUP"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunnelType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
HTTP TunnelType = "http"
|
|
||||||
TCP TunnelType = "tcp"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SessionLifecycle interface {
|
|
||||||
Close() error
|
|
||||||
WaitForRunningStatus()
|
|
||||||
}
|
|
||||||
|
|
||||||
type SessionCloser interface {
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
type Session interface {
|
type Session interface {
|
||||||
SessionLifecycle
|
HandleGlobalRequest(ch <-chan *ssh.Request)
|
||||||
InteractionController
|
HandleTCPIPForward(req *ssh.Request)
|
||||||
ForwardingController
|
HandleHTTPForward(req *ssh.Request, port uint16)
|
||||||
}
|
HandleTCPForward(req *ssh.Request, addr string, port uint16)
|
||||||
|
|
||||||
type Lifecycle struct {
|
|
||||||
Status Status
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type SSHSession struct {
|
type SSHSession struct {
|
||||||
Lifecycle *Lifecycle
|
Lifecycle lifecycle.SessionLifecycle
|
||||||
Interaction *Interaction
|
Interaction interaction.Controller
|
||||||
Forwarder *Forwarder
|
Forwarder forwarder.ForwardingController
|
||||||
|
SlugManager slug.Manager
|
||||||
Conn *ssh.ServerConn
|
|
||||||
channel ssh.Channel
|
|
||||||
|
|
||||||
slug string
|
|
||||||
slugMu sync.RWMutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) {
|
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) {
|
||||||
session := SSHSession{
|
slugManager := slug.NewManager()
|
||||||
Lifecycle: &Lifecycle{
|
forwarderManager := &forwarder.Forwarder{
|
||||||
Status: INITIALIZING,
|
Listener: nil,
|
||||||
},
|
TunnelType: "",
|
||||||
Interaction: &Interaction{
|
ForwardedPort: 0,
|
||||||
CommandBuffer: new(bytes.Buffer),
|
SlugManager: slugManager,
|
||||||
EditMode: false,
|
}
|
||||||
EditSlug: "",
|
interactionManager := &interaction.Interaction{
|
||||||
channel: nil,
|
CommandBuffer: bytes.NewBuffer(make([]byte, 0, 20)),
|
||||||
getSlug: nil,
|
EditMode: false,
|
||||||
setSlug: nil,
|
EditSlug: "",
|
||||||
session: nil,
|
SlugManager: slugManager,
|
||||||
forwarder: nil,
|
Forwarder: forwarderManager,
|
||||||
},
|
Lifecycle: nil,
|
||||||
Forwarder: &Forwarder{
|
}
|
||||||
Listener: nil,
|
lifecycleManager := &lifecycle.Lifecycle{
|
||||||
TunnelType: "",
|
Status: "",
|
||||||
ForwardedPort: 0,
|
Conn: conn,
|
||||||
getSlug: nil,
|
Channel: nil,
|
||||||
setSlug: nil,
|
Interaction: interactionManager,
|
||||||
},
|
Forwarder: forwarderManager,
|
||||||
Conn: conn,
|
SlugManager: slugManager,
|
||||||
channel: nil,
|
|
||||||
slug: "",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
session.Forwarder.getSlug = session.GetSlug
|
interactionManager.SetLifecycle(lifecycleManager)
|
||||||
session.Forwarder.setSlug = session.SetSlug
|
interactionManager.SetSlugModificator(updateClientSlug)
|
||||||
session.Interaction.getSlug = session.GetSlug
|
forwarderManager.SetLifecycle(lifecycleManager)
|
||||||
session.Interaction.setSlug = session.SetSlug
|
lifecycleManager.SetUnregisterClient(unregisterClient)
|
||||||
session.Interaction.session = &session
|
|
||||||
session.Interaction.forwarder = session.Forwarder
|
session := &SSHSession{
|
||||||
|
Lifecycle: lifecycleManager,
|
||||||
|
Interaction: interactionManager,
|
||||||
|
Forwarder: forwarderManager,
|
||||||
|
SlugManager: slugManager,
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
go session.waitForRunningStatus()
|
go session.Lifecycle.WaitForRunningStatus()
|
||||||
|
|
||||||
for channel := range sshChan {
|
for channel := range sshChan {
|
||||||
ch, reqs, _ := channel.Accept()
|
ch, reqs, _ := channel.Accept()
|
||||||
if session.channel == nil {
|
if session.Lifecycle.GetChannel() == nil {
|
||||||
session.channel = ch
|
session.Lifecycle.SetChannel(ch)
|
||||||
session.Interaction.channel = ch
|
session.Interaction.SetChannel(ch)
|
||||||
session.Lifecycle.Status = SETUP
|
session.Lifecycle.SetStatus(types.SETUP)
|
||||||
go session.HandleGlobalRequest(forwardingReq)
|
go session.HandleGlobalRequest(forwardingReq)
|
||||||
}
|
}
|
||||||
go session.HandleGlobalRequest(reqs)
|
go session.HandleGlobalRequest(reqs)
|
||||||
}
|
}
|
||||||
err := session.Close()
|
err := session.Lifecycle.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to close session: %v", err)
|
log.Printf("failed to close session: %v", err)
|
||||||
}
|
}
|
||||||
@ -107,14 +90,40 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSHSession) GetSlug() string {
|
func updateClientSlug(oldSlug, newSlug string) bool {
|
||||||
s.slugMu.RLock()
|
clientsMutex.Lock()
|
||||||
defer s.slugMu.RUnlock()
|
defer clientsMutex.Unlock()
|
||||||
return s.slug
|
|
||||||
|
if _, exists := Clients[newSlug]; exists && newSlug != oldSlug {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
client, ok := Clients[oldSlug]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(Clients, oldSlug)
|
||||||
|
client.SlugManager.Set(newSlug)
|
||||||
|
Clients[newSlug] = client
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSHSession) SetSlug(slug string) {
|
func registerClient(slug string, session *SSHSession) bool {
|
||||||
s.slugMu.Lock()
|
clientsMutex.Lock()
|
||||||
s.slug = slug
|
defer clientsMutex.Unlock()
|
||||||
s.slugMu.Unlock()
|
|
||||||
|
if _, exists := Clients[slug]; exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
Clients[slug] = session
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func unregisterClient(slug string) {
|
||||||
|
clientsMutex.Lock()
|
||||||
|
defer clientsMutex.Unlock()
|
||||||
|
|
||||||
|
delete(Clients, slug)
|
||||||
}
|
}
|
||||||
|
|||||||
32
session/slug/slug.go
Normal file
32
session/slug/slug.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
package slug
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
type Manager interface {
|
||||||
|
Get() string
|
||||||
|
Set(slug string)
|
||||||
|
}
|
||||||
|
|
||||||
|
type manager struct {
|
||||||
|
slug string
|
||||||
|
slugMu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager() Manager {
|
||||||
|
return &manager{
|
||||||
|
slug: "",
|
||||||
|
slugMu: sync.RWMutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *manager) Get() string {
|
||||||
|
s.slugMu.RLock()
|
||||||
|
defer s.slugMu.RUnlock()
|
||||||
|
return s.slug
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *manager) Set(slug string) {
|
||||||
|
s.slugMu.Lock()
|
||||||
|
s.slug = slug
|
||||||
|
s.slugMu.Unlock()
|
||||||
|
}
|
||||||
21
types/types.go
Normal file
21
types/types.go
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type Status string
|
||||||
|
|
||||||
|
const (
|
||||||
|
INITIALIZING Status = "INITIALIZING"
|
||||||
|
RUNNING Status = "RUNNING"
|
||||||
|
SETUP Status = "SETUP"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TunnelType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
HTTP TunnelType = "HTTP"
|
||||||
|
TCP TunnelType = "TCP"
|
||||||
|
)
|
||||||
|
|
||||||
|
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
|
||||||
|
"Content-Length: 11\r\n" +
|
||||||
|
"Content-Type: text/plain\r\n\r\n" +
|
||||||
|
"Bad Gateway")
|
||||||
Reference in New Issue
Block a user