chore(restructure): reorganize project layout
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 13m1s

- Reorganize internal packages and overall project structure
- Update imports and wiring to match the new layout
- Separate HTTP parsing and streaming from the server package
- Separate middleware from the server package
- Separate session registry from the session package
- Move HTTP, HTTPS, and TCP servers to the transport package
- Session package no longer starts the TCP server directly
- Server package no longer starts HTTP/HTTPS servers on initialization
- Forwarder no longer handles accepting TCP requests
- Move session details to the types package
- HTTP/HTTPS initialization is now the responsibility of main
This commit is contained in:
2026-01-21 14:06:46 +07:00
parent 9a4539cc02
commit 1e12373359
30 changed files with 862 additions and 704 deletions
+3 -4
View File
@@ -8,10 +8,9 @@ import (
"log" "log"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/registry"
"tunnel_pls/types" "tunnel_pls/types"
"tunnel_pls/session"
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@@ -32,13 +31,13 @@ type Client interface {
type client struct { type client struct {
conn *grpc.ClientConn conn *grpc.ClientConn
address string address string
sessionRegistry session.Registry sessionRegistry registry.Registry
eventService proto.EventServiceClient eventService proto.EventServiceClient
authorizeConnectionService proto.UserServiceClient authorizeConnectionService proto.UserServiceClient
closing bool closing bool
} }
func New(address string, sessionRegistry session.Registry) (Client, error) { func New(address string, sessionRegistry registry.Registry) (Client, error) {
var opts []grpc.DialOption var opts []grpc.DialOption
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
@@ -1,4 +1,4 @@
package httpheader package header
type ResponseHeader interface { type ResponseHeader interface {
Value(key string) string Value(key string) string
@@ -1,4 +1,4 @@
package httpheader package header
import ( import (
"bufio" "bufio"
@@ -1,11 +1,11 @@
package httpheader package header
import ( import (
"bufio" "bufio"
"fmt" "fmt"
) )
func NewRequestHeader(r interface{}) (RequestHeader, error) { func NewRequest(r interface{}) (RequestHeader, error) {
switch v := r.(type) { switch v := r.(type) {
case []byte: case []byte:
return parseHeadersFromBytes(v) return parseHeadersFromBytes(v)
@@ -1,11 +1,11 @@
package httpheader package header
import ( import (
"bytes" "bytes"
"fmt" "fmt"
) )
func NewResponseHeader(headerData []byte) (ResponseHeader, error) { func NewResponse(headerData []byte) (ResponseHeader, error) {
header := &responseHeader{ header := &responseHeader{
startLine: nil, startLine: nil,
headers: make(map[string]string, 16), headers: make(map[string]string, 16),
+29
View File
@@ -0,0 +1,29 @@
package stream
import "bytes"
func splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) {
headerByte := data[:delimiterIdx+len(DELIMITER)]
body := data[delimiterIdx+len(DELIMITER):]
return headerByte, body
}
func isHTTPHeader(buf []byte) bool {
lines := bytes.Split(buf, []byte("\r\n"))
startLine := string(lines[0])
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
return false
}
for _, line := range lines[1:] {
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx <= 0 {
return false
}
}
return true
}
+50
View File
@@ -0,0 +1,50 @@
package stream
import (
"bytes"
"tunnel_pls/internal/http/header"
)
func (hs *http) Read(p []byte) (int, error) {
tmp := make([]byte, len(p))
read, err := hs.reader.Read(tmp)
if read == 0 && err != nil {
return 0, err
}
tmp = tmp[:read]
headerEndIdx := bytes.Index(tmp, DELIMITER)
if headerEndIdx == -1 {
return handleNoDelimiter(p, tmp, err)
}
headerByte, bodyByte := splitHeaderAndBody(tmp, headerEndIdx)
if !isHTTPHeader(headerByte) {
copy(p, tmp)
return read, nil
}
return hs.processHTTPRequest(p, headerByte, bodyByte)
}
func (hs *http) processHTTPRequest(p, headerByte, bodyByte []byte) (int, error) {
reqhf, err := header.NewRequest(headerByte)
if err != nil {
return 0, err
}
if err = hs.ApplyRequestMiddlewares(reqhf); err != nil {
return 0, err
}
hs.reqHeader = reqhf
combined := append(reqhf.Finalize(), bodyByte...)
return copy(p, combined), nil
}
func handleNoDelimiter(p, tmp []byte, err error) (int, error) {
copy(p, tmp)
return len(tmp), err
}
+103
View File
@@ -0,0 +1,103 @@
package stream
import (
"io"
"log"
"net"
"regexp"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/middleware"
)
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
type HTTP interface {
io.ReadWriteCloser
CloseWrite() error
RemoteAddr() net.Addr
UseResponseMiddleware(mw middleware.ResponseMiddleware)
UseRequestMiddleware(mw middleware.RequestMiddleware)
SetRequestHeader(header header.RequestHeader)
RequestMiddlewares() []middleware.RequestMiddleware
ResponseMiddlewares() []middleware.ResponseMiddleware
ApplyResponseMiddlewares(resphf header.ResponseHeader, body []byte) error
ApplyRequestMiddlewares(reqhf header.RequestHeader) error
}
type http struct {
remoteAddr net.Addr
writer io.Writer
reader io.Reader
headerBuf []byte
buf []byte
respHeader header.ResponseHeader
reqHeader header.RequestHeader
respMW []middleware.ResponseMiddleware
reqMW []middleware.RequestMiddleware
}
func New(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTP {
return &http{
remoteAddr: remoteAddr,
writer: writer,
reader: reader,
buf: make([]byte, 0, 4096),
}
}
func (hs *http) RemoteAddr() net.Addr {
return hs.remoteAddr
}
func (hs *http) UseResponseMiddleware(mw middleware.ResponseMiddleware) {
hs.respMW = append(hs.respMW, mw)
}
func (hs *http) UseRequestMiddleware(mw middleware.RequestMiddleware) {
hs.reqMW = append(hs.reqMW, mw)
}
func (hs *http) SetRequestHeader(header header.RequestHeader) {
hs.reqHeader = header
}
func (hs *http) RequestMiddlewares() []middleware.RequestMiddleware {
return hs.reqMW
}
func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware {
return hs.respMW
}
func (hs *http) Close() error {
return hs.writer.(io.Closer).Close()
}
func (hs *http) CloseWrite() error {
if closer, ok := hs.writer.(interface{ CloseWrite() error }); ok {
return closer.CloseWrite()
}
return hs.Close()
}
func (hs *http) ApplyRequestMiddlewares(reqhf header.RequestHeader) error {
for _, m := range hs.RequestMiddlewares() {
if err := m.HandleRequest(reqhf); err != nil {
log.Printf("Error when applying request middleware: %v", err)
return err
}
}
return nil
}
func (hs *http) ApplyResponseMiddlewares(resphf header.ResponseHeader, bodyByte []byte) error {
for _, m := range hs.ResponseMiddlewares() {
if err := m.HandleResponse(resphf, bodyByte); err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return err
}
}
return nil
}
+88
View File
@@ -0,0 +1,88 @@
package stream
import (
"bytes"
"tunnel_pls/internal/http/header"
)
func (hs *http) Write(p []byte) (int, error) {
if hs.shouldBypassBuffering(p) {
hs.respHeader = nil
}
if hs.respHeader != nil {
return hs.writer.Write(p)
}
hs.buf = append(hs.buf, p...)
headerEndIdx := bytes.Index(hs.buf, DELIMITER)
if headerEndIdx == -1 {
return len(p), nil
}
return hs.processBufferedResponse(p, headerEndIdx)
}
func (hs *http) shouldBypassBuffering(p []byte) bool {
return hs.respHeader != nil && len(hs.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/"
}
func (hs *http) processBufferedResponse(p []byte, delimiterIdx int) (int, error) {
headerByte, bodyByte := splitHeaderAndBody(hs.buf, delimiterIdx)
if !isHTTPHeader(headerByte) {
return hs.writeRawBuffer()
}
if err := hs.processHTTPResponse(headerByte, bodyByte); err != nil {
return 0, err
}
hs.buf = nil
return len(p), nil
}
func (hs *http) writeRawBuffer() (int, error) {
_, err := hs.writer.Write(hs.buf)
length := len(hs.buf)
hs.buf = nil
if err != nil {
return 0, err
}
return length, nil
}
func (hs *http) processHTTPResponse(headerByte, bodyByte []byte) error {
resphf, err := header.NewResponse(headerByte)
if err != nil {
return err
}
if err = hs.ApplyResponseMiddlewares(resphf, bodyByte); err != nil {
return err
}
hs.respHeader = resphf
finalHeader := resphf.Finalize()
if err = hs.writeHeaderAndBody(finalHeader, bodyByte); err != nil {
return err
}
return nil
}
func (hs *http) writeHeaderAndBody(header, bodyByte []byte) error {
if _, err := hs.writer.Write(header); err != nil {
return err
}
if len(bodyByte) > 0 {
if _, err := hs.writer.Write(bodyByte); err != nil {
return err
}
}
return nil
}
+23
View File
@@ -0,0 +1,23 @@
package middleware
import (
"net"
"tunnel_pls/internal/http/header"
)
type ForwardedFor struct {
addr net.Addr
}
func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr}
}
func (ff *ForwardedFor) HandleRequest(header header.RequestHeader) error {
host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil {
return err
}
header.Set("X-Forwarded-For", host)
return nil
}
+13
View File
@@ -0,0 +1,13 @@
package middleware
import (
"tunnel_pls/internal/http/header"
)
type RequestMiddleware interface {
HandleRequest(header header.RequestHeader) error
}
type ResponseMiddleware interface {
HandleResponse(header header.ResponseHeader, body []byte) error
}
+16
View File
@@ -0,0 +1,16 @@
package middleware
import (
"tunnel_pls/internal/http/header"
)
type TunnelFingerprint struct{}
func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{}
}
func (h *TunnelFingerprint) HandleResponse(header header.ResponseHeader, body []byte) error {
header.Set("Server", "Tunnel Please")
return nil
}
+19 -19
View File
@@ -6,37 +6,37 @@ import (
"sync" "sync"
) )
type Registry interface { type Port interface {
AddPortRange(startPort, endPort uint16) error AddRange(startPort, endPort uint16) error
GetUnassignedPort() (uint16, bool) Unassigned() (uint16, bool)
SetPortStatus(port uint16, assigned bool) error SetStatus(port uint16, assigned bool) error
ClaimPort(port uint16) (claimed bool) Claim(port uint16) (claimed bool)
} }
type registry struct { type port struct {
mu sync.RWMutex mu sync.RWMutex
ports map[uint16]bool ports map[uint16]bool
sortedPorts []uint16 sortedPorts []uint16
} }
func New() Registry { func New() Port {
return &registry{ return &port{
ports: make(map[uint16]bool), ports: make(map[uint16]bool),
sortedPorts: []uint16{}, sortedPorts: []uint16{},
} }
} }
func (pm *registry) AddPortRange(startPort, endPort uint16) error { func (pm *port) AddRange(startPort, endPort uint16) error {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
if startPort > endPort { if startPort > endPort {
return fmt.Errorf("start port cannot be greater than end port") return fmt.Errorf("start port cannot be greater than end port")
} }
for port := startPort; port <= endPort; port++ { for index := startPort; index <= endPort; index++ {
if _, exists := pm.ports[port]; !exists { if _, exists := pm.ports[index]; !exists {
pm.ports[port] = false pm.ports[index] = false
pm.sortedPorts = append(pm.sortedPorts, port) pm.sortedPorts = append(pm.sortedPorts, index)
} }
} }
sort.Slice(pm.sortedPorts, func(i, j int) bool { sort.Slice(pm.sortedPorts, func(i, j int) bool {
@@ -45,19 +45,19 @@ func (pm *registry) AddPortRange(startPort, endPort uint16) error {
return nil return nil
} }
func (pm *registry) GetUnassignedPort() (uint16, bool) { func (pm *port) Unassigned() (uint16, bool) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
for _, port := range pm.sortedPorts { for _, index := range pm.sortedPorts {
if !pm.ports[port] { if !pm.ports[index] {
return port, true return index, true
} }
} }
return 0, false return 0, false
} }
func (pm *registry) SetPortStatus(port uint16, assigned bool) error { func (pm *port) SetStatus(port uint16, assigned bool) error {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
@@ -65,7 +65,7 @@ func (pm *registry) SetPortStatus(port uint16, assigned bool) error {
return nil return nil
} }
func (pm *registry) ClaimPort(port uint16) (claimed bool) { func (pm *port) Claim(port uint16) (claimed bool) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
@@ -1,13 +1,25 @@
package session package registry
import ( import (
"fmt" "fmt"
"sync" "sync"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types" "tunnel_pls/types"
) )
type Key = types.SessionKey type Key = types.SessionKey
type Session interface {
Lifecycle() lifecycle.Lifecycle
Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder
Slug() slug.Slug
Detail() *types.Detail
}
type Registry interface { type Registry interface {
Get(key Key) (session Session, err error) Get(key Key) (session Session, err error)
GetWithUser(user string, key Key) (session Session, err error) GetWithUser(user string, key Key) (session Session, err error)
@@ -35,12 +47,12 @@ func (r *registry) Get(key Key) (session Session, err error) {
userID, ok := r.slugIndex[key] userID, ok := r.slugIndex[key]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, fmt.Errorf("Session not found")
} }
client, ok := r.byUser[userID][key] client, ok := r.byUser[userID][key]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, fmt.Errorf("Session not found")
} }
return client, nil return client, nil
} }
@@ -51,7 +63,7 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error
client, ok := r.byUser[user][key] client, ok := r.byUser[user][key]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, fmt.Errorf("Session not found")
} }
return client, nil return client, nil
} }
@@ -81,7 +93,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
} }
client, ok := r.byUser[user][oldKey] client, ok := r.byUser[user][oldKey]
if !ok { if !ok {
return fmt.Errorf("session not found") return fmt.Errorf("Session not found")
} }
delete(r.byUser[user], oldKey) delete(r.byUser[user], oldKey)
@@ -97,7 +109,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
return nil return nil
} }
func (r *registry) Register(key Key, session Session) (success bool) { func (r *registry) Register(key Key, userSession Session) (success bool) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
@@ -105,12 +117,12 @@ func (r *registry) Register(key Key, session Session) (success bool) {
return false return false
} }
userID := session.Lifecycle().User() userID := userSession.Lifecycle().User()
if r.byUser[userID] == nil { if r.byUser[userID] == nil {
r.byUser[userID] = make(map[Key]Session) r.byUser[userID] = make(map[Key]Session)
} }
r.byUser[userID][key] = session r.byUser[userID][key] = userSession
r.slugIndex[key] = userID r.slugIndex[key] = userID
return true return true
} }
+40
View File
@@ -0,0 +1,40 @@
package transport
import (
"errors"
"log"
"net"
"tunnel_pls/internal/registry"
)
type httpServer struct {
handler *httpHandler
port string
}
func NewHTTPServer(port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
return &httpServer{
handler: newHTTPHandler(sessionRegistry, redirectTLS),
port: port,
}
}
func (ht *httpServer) Listen() (net.Listener, error) {
return net.Listen("tcp", ":"+ht.port)
}
func (ht *httpServer) Serve(listener net.Listener) error {
log.Printf("HTTP server is starting on port %s", ht.port)
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
}
log.Printf("Error accepting connection: %v", err)
continue
}
go ht.handler.handler(conn, false)
}
}
+227
View File
@@ -0,0 +1,227 @@
package transport
import (
"bufio"
"errors"
"fmt"
"log"
"net"
"net/http"
"strings"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/http/stream"
"tunnel_pls/internal/middleware"
"tunnel_pls/internal/registry"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
type httpHandler struct {
sessionRegistry registry.Registry
redirectTLS bool
}
func newHTTPHandler(sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
return &httpHandler{
sessionRegistry: sessionRegistry,
redirectTLS: redirectTLS,
}
}
func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error {
_, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) +
fmt.Sprintf("Location: %s", location) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
return err
}
return nil
}
func (hh *httpHandler) badRequest(conn net.Conn) error {
if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
return err
}
return nil
}
func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
defer hh.closeConnection(conn)
dstReader := bufio.NewReader(conn)
reqhf, err := header.NewRequest(dstReader)
if err != nil {
log.Printf("Error creating request header: %v", err)
return
}
slug, err := hh.extractSlug(reqhf)
if err != nil {
_ = hh.badRequest(conn)
return
}
if hh.shouldRedirectToTLS(isTLS) {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")))
return
}
if hh.handlePingRequest(slug, conn) {
return
}
sshSession, err := hh.getSession(slug)
if err != nil {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
return
}
hw := stream.New(conn, dstReader, conn.RemoteAddr())
defer func(hw stream.HTTP) {
err = hw.Close()
if err != nil {
log.Printf("Error closing HTTP stream: %v", err)
}
}(hw)
hh.forwardRequest(hw, reqhf, sshSession)
}
func (hh *httpHandler) closeConnection(conn net.Conn) {
err := conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("Error closing connection: %v", err)
}
}
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
host := strings.Split(reqhf.Value("Host"), ".")
if len(host) < 1 {
return "", errors.New("invalid host")
}
return host[0], nil
}
func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool {
return !isTLS && hh.redirectTLS
}
func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
if slug != "ping" {
return false
}
_, err := conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
}
return true
}
func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
})
if err != nil {
return nil, err
}
return sshSession, nil
}
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
channel, err := hh.openForwardedChannel(hw, sshSession)
defer func() {
err = channel.Close()
if err != nil {
log.Printf("Error closing forwarded channel: %v", err)
}
}()
if err != nil {
log.Printf("Failed to establish channel: %v", err)
sshSession.Forwarder().WriteBadGatewayResponse(hw)
return
}
hh.setupMiddlewares(hw)
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
log.Printf("Failed to forward initial request: %v", err)
return
}
sshSession.Forwarder().HandleConnection(hw, channel)
}
func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) {
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
select {
case resultChan <- channelResult{channel, reqs, err}:
default:
hh.cleanupUnusedChannel(channel, reqs)
}
}()
select {
case result := <-resultChan:
if result.err != nil {
return nil, result.err
}
go ssh.DiscardRequests(result.reqs)
return result.channel, nil
case <-time.After(5 * time.Second):
return nil, errors.New("timeout opening forwarded-tcpip channel")
}
}
func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) {
if channel != nil {
if err := channel.Close(); err != nil {
log.Printf("Failed to close unused channel: %v", err)
}
go ssh.DiscardRequests(reqs)
}
}
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
fingerprintMiddleware := middleware.NewTunnelFingerprint()
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
hw.UseResponseMiddleware(fingerprintMiddleware)
hw.UseRequestMiddleware(forwardedForMiddleware)
}
func (hh *httpHandler) sendInitialRequest(hw stream.HTTP, initialRequest header.RequestHeader, channel ssh.Channel) error {
hw.SetRequestHeader(initialRequest)
if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
return fmt.Errorf("error applying request middlewares: %w", err)
}
if _, err := channel.Write(initialRequest.Finalize()); err != nil {
return fmt.Errorf("error writing to channel: %w", err)
}
return nil
}
+48
View File
@@ -0,0 +1,48 @@
package transport
import (
"crypto/tls"
"errors"
"log"
"net"
"tunnel_pls/internal/registry"
)
type https struct {
httpHandler *httpHandler
domain string
port string
}
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
return &https{
httpHandler: newHTTPHandler(sessionRegistry, redirectTLS),
domain: domain,
port: port,
}
}
func (ht *https) Listen() (net.Listener, error) {
tlsConfig, err := NewTLSConfig(ht.domain)
if err != nil {
return nil, err
}
return tls.Listen("tcp", ":"+ht.port, tlsConfig)
}
func (ht *https) Serve(listener net.Listener) error {
log.Printf("HTTPS server is starting on port %s", ht.port)
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
}
log.Printf("Error accepting connection: %v", err)
continue
}
go ht.httpHandler.handler(conn, true)
}
}
+66
View File
@@ -0,0 +1,66 @@
package transport
import (
"errors"
"fmt"
"io"
"log"
"net"
"golang.org/x/crypto/ssh"
)
type tcp struct {
port uint16
forwarder forwarder
}
type forwarder interface {
CreateForwardedTCPIPPayload(origin net.Addr) []byte
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
HandleConnection(dst io.ReadWriter, src ssh.Channel)
}
func NewTCPServer(port uint16, forwarder forwarder) Transport {
return &tcp{
port: port,
forwarder: forwarder,
}
}
func (tt *tcp) Listen() (net.Listener, error) {
return net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", tt.port))
}
func (tt *tcp) Serve(listener net.Listener) error {
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
log.Printf("Error accepting connection: %v", err)
continue
}
go tt.handleTcp(conn)
}
}
func (tt *tcp) handleTcp(conn net.Conn) {
defer func() {
err := conn.Close()
if err != nil {
log.Printf("Failed to close connection: %v", err)
}
}()
payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr())
channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return
}
go ssh.DiscardRequests(reqs)
tt.forwarder.HandleConnection(conn, channel)
}
+1 -1
View File
@@ -1,4 +1,4 @@
package server package transport
import ( import (
"context" "context"
+10
View File
@@ -0,0 +1,10 @@
package transport
import (
"net"
)
type Transport interface {
Listen() (net.Listener, error)
Serve(listener net.Listener) error
}
+49 -5
View File
@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"net"
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
@@ -16,9 +17,10 @@ import (
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/key" "tunnel_pls/internal/key"
"tunnel_pls/internal/port" "tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/internal/transport"
"tunnel_pls/internal/version"
"tunnel_pls/server" "tunnel_pls/server"
"tunnel_pls/session"
"tunnel_pls/version"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -76,7 +78,7 @@ func main() {
} }
sshConfig.AddHostKey(private) sshConfig.AddHostKey(private)
sessionRegistry := session.NewRegistry() sessionRegistry := registry.NewRegistry()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@@ -131,7 +133,7 @@ func main() {
log.Fatalf("Failed to parse end port: %s", err) log.Fatalf("Failed to parse end port: %s", err)
} }
if err = portManager.AddPortRange(uint16(start), uint16(end)); err != nil { if err = portManager.AddRange(uint16(start), uint16(end)); err != nil {
log.Fatalf("Failed to add port range: %s", err) log.Fatalf("Failed to add port range: %s", err)
} }
log.Printf("PortRegistry range configured: %d-%d", start, end) log.Printf("PortRegistry range configured: %d-%d", start, end)
@@ -139,9 +141,51 @@ func main() {
log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange) log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange)
} }
} }
tlsEnabled := config.Getenv("TLS_ENABLED", "false") == "true"
redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true"
go func() {
httpPort := config.Getenv("HTTP_PORT", "8080")
var httpListener net.Listener
httpserver := transport.NewHTTPServer(httpPort, sessionRegistry, redirectTLS)
httpListener, err = httpserver.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
err = httpserver.Serve(httpListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
if tlsEnabled {
go func() {
httpsPort := config.Getenv("HTTPS_PORT", "8443")
domain := config.Getenv("DOMAIN", "localhost")
var httpListener net.Listener
httpserver := transport.NewHTTPSServer(domain, httpsPort, sessionRegistry, redirectTLS)
httpListener, err = httpserver.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
err = httpserver.Serve(httpListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
}
var app server.Server var app server.Server
go func() { go func() {
app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager) sshPort := config.Getenv("PORT", "2200")
app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager, sshPort)
if err != nil { if err != nil {
errChan <- fmt.Errorf("failed to start server: %s", err) errChan <- fmt.Errorf("failed to start server: %s", err)
return return
-276
View File
@@ -1,276 +0,0 @@
package server
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/http"
"strings"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/httpheader"
"tunnel_pls/session"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
type HTTPServer interface {
ListenAndServe() error
ListenAndServeTLS() error
}
type httpServer struct {
sessionRegistry session.Registry
redirectTLS bool
}
func NewHTTPServer(sessionRegistry session.Registry, redirectTLS bool) HTTPServer {
return &httpServer{
sessionRegistry: sessionRegistry,
redirectTLS: redirectTLS,
}
}
func (hs *httpServer) ListenAndServe() error {
httpPort := config.Getenv("HTTP_PORT", "8080")
listener, err := net.Listen("tcp", ":"+httpPort)
if err != nil {
return errors.New("Error listening: " + err.Error())
}
go func() {
for {
var conn net.Conn
conn, err = listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
go hs.handler(conn, false)
}
}()
return nil
}
func (hs *httpServer) ListenAndServeTLS() error {
domain := config.Getenv("DOMAIN", "localhost")
httpsPort := config.Getenv("HTTPS_PORT", "8443")
tlsConfig, err := NewTLSConfig(domain)
if err != nil {
return fmt.Errorf("failed to initialize TLS config: %w", err)
}
ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig)
if err != nil {
return err
}
go func() {
for {
var conn net.Conn
conn, err = ln.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
log.Println("https server closed")
}
log.Printf("Error accepting connection: %v", err)
continue
}
go hs.handler(conn, true)
}
}()
return nil
}
func (hs *httpServer) redirect(conn net.Conn, status int, location string) error {
_, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) +
fmt.Sprintf("Location: %s", location) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
return err
}
return nil
}
func (hs *httpServer) badRequest(conn net.Conn) error {
if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
return err
}
return nil
}
func (hs *httpServer) handler(conn net.Conn, isTLS bool) {
defer hs.closeConnection(conn)
dstReader := bufio.NewReader(conn)
reqhf, err := httpheader.NewRequestHeader(dstReader)
if err != nil {
log.Printf("Error creating request header: %v", err)
return
}
slug, err := hs.extractSlug(reqhf)
if err != nil {
_ = hs.badRequest(conn)
return
}
if hs.shouldRedirectToTLS(isTLS) {
_ = hs.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")))
return
}
if hs.handlePingRequest(slug, conn) {
return
}
sshSession, err := hs.getSession(slug)
if err != nil {
_ = hs.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
return
}
hw := NewHTTPWriter(conn, dstReader, conn.RemoteAddr())
hs.forwardRequest(hw, reqhf, sshSession)
}
func (hs *httpServer) closeConnection(conn net.Conn) {
err := conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("Error closing connection: %v", err)
}
}
func (hs *httpServer) extractSlug(reqhf httpheader.RequestHeader) (string, error) {
host := strings.Split(reqhf.Value("Host"), ".")
if len(host) < 1 {
return "", errors.New("invalid host")
}
return host[0], nil
}
func (hs *httpServer) shouldRedirectToTLS(isTLS bool) bool {
return !isTLS && hs.redirectTLS
}
func (hs *httpServer) handlePingRequest(slug string, conn net.Conn) bool {
if slug != "ping" {
return false
}
_, err := conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
}
return true
}
func (hs *httpServer) getSession(slug string) (session.Session, error) {
sshSession, err := hs.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
})
if err != nil {
return nil, err
}
return sshSession, nil
}
func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, sshSession session.Session) {
channel, err := hs.openForwardedChannel(hw, sshSession)
if err != nil {
log.Printf("Failed to establish channel: %v", err)
sshSession.Forwarder().WriteBadGatewayResponse(hw)
return
}
hs.setupMiddlewares(hw)
if err := hs.sendInitialRequest(hw, initialRequest, channel); err != nil {
log.Printf("Failed to forward initial request: %v", err)
return
}
sshSession.Forwarder().HandleConnection(hw, channel, hw.RemoteAddr())
}
func (hs *httpServer) openForwardedChannel(hw HTTPWriter, sshSession session.Session) (ssh.Channel, error) {
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
select {
case resultChan <- channelResult{channel, reqs, err}:
default:
hs.cleanupUnusedChannel(channel, reqs)
}
}()
select {
case result := <-resultChan:
if result.err != nil {
return nil, result.err
}
go ssh.DiscardRequests(result.reqs)
return result.channel, nil
case <-time.After(5 * time.Second):
return nil, errors.New("timeout opening forwarded-tcpip channel")
}
}
func (hs *httpServer) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) {
if channel != nil {
if err := channel.Close(); err != nil {
log.Printf("Failed to close unused channel: %v", err)
}
go ssh.DiscardRequests(reqs)
}
}
func (hs *httpServer) setupMiddlewares(hw HTTPWriter) {
fingerprintMiddleware := NewTunnelFingerprint()
forwardedForMiddleware := NewForwardedFor(hw.RemoteAddr())
hw.UseResponseMiddleware(fingerprintMiddleware)
hw.UseRequestMiddleware(forwardedForMiddleware)
}
func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, channel ssh.Channel) error {
hw.SetRequestHeader(initialRequest)
if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
return fmt.Errorf("error applying request middlewares: %w", err)
}
if _, err := channel.Write(initialRequest.Finalize()); err != nil {
return fmt.Errorf("error writing to channel: %w", err)
}
return nil
}
-254
View File
@@ -1,254 +0,0 @@
package server
import (
"bytes"
"io"
"log"
"net"
"regexp"
"tunnel_pls/internal/httpheader"
)
type HTTPWriter interface {
io.ReadWriteCloser
CloseWrite() error
RemoteAddr() net.Addr
UseResponseMiddleware(mw ResponseMiddleware)
UseRequestMiddleware(mw RequestMiddleware)
SetRequestHeader(header httpheader.RequestHeader)
RequestMiddlewares() []RequestMiddleware
ResponseMiddlewares() []ResponseMiddleware
ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error
ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error
}
type httpWriter struct {
remoteAddr net.Addr
writer io.Writer
reader io.Reader
headerBuf []byte
buf []byte
respHeader httpheader.ResponseHeader
reqHeader httpheader.RequestHeader
respMW []ResponseMiddleware
reqMW []RequestMiddleware
}
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
func (hw *httpWriter) RemoteAddr() net.Addr {
return hw.remoteAddr
}
func (hw *httpWriter) UseResponseMiddleware(mw ResponseMiddleware) {
hw.respMW = append(hw.respMW, mw)
}
func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) {
hw.reqMW = append(hw.reqMW, mw)
}
func (hw *httpWriter) SetRequestHeader(header httpheader.RequestHeader) {
hw.reqHeader = header
}
func (hw *httpWriter) RequestMiddlewares() []RequestMiddleware {
return hw.reqMW
}
func (hw *httpWriter) ResponseMiddlewares() []ResponseMiddleware {
return hw.respMW
}
func (hw *httpWriter) Close() error {
return hw.writer.(io.Closer).Close()
}
func (hw *httpWriter) CloseWrite() error {
if closer, ok := hw.writer.(interface{ CloseWrite() error }); ok {
return closer.CloseWrite()
}
return hw.Close()
}
func (hw *httpWriter) Read(p []byte) (int, error) {
tmp := make([]byte, len(p))
read, err := hw.reader.Read(tmp)
if read == 0 && err != nil {
return 0, err
}
tmp = tmp[:read]
headerEndIdx := bytes.Index(tmp, DELIMITER)
if headerEndIdx == -1 {
return hw.handleNoDelimiter(p, tmp, err)
}
header, body := hw.splitHeaderAndBody(tmp, headerEndIdx)
if !isHTTPHeader(header) {
copy(p, tmp)
return read, nil
}
return hw.processHTTPRequest(p, header, body)
}
func (hw *httpWriter) handleNoDelimiter(p, tmp []byte, err error) (int, error) {
copy(p, tmp)
return len(tmp), err
}
func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) {
header := data[:delimiterIdx+len(DELIMITER)]
body := data[delimiterIdx+len(DELIMITER):]
return header, body
}
func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) {
reqhf, err := httpheader.NewRequestHeader(header)
if err != nil {
return 0, err
}
if err = hw.ApplyRequestMiddlewares(reqhf); err != nil {
return 0, err
}
hw.reqHeader = reqhf
combined := append(reqhf.Finalize(), body...)
return copy(p, combined), nil
}
func (hw *httpWriter) ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error {
for _, m := range hw.RequestMiddlewares() {
if err := m.HandleRequest(reqhf); err != nil {
log.Printf("Error when applying request middleware: %v", err)
return err
}
}
return nil
}
func (hw *httpWriter) Write(p []byte) (int, error) {
if hw.shouldBypassBuffering(p) {
hw.respHeader = nil
}
if hw.respHeader != nil {
return hw.writer.Write(p)
}
hw.buf = append(hw.buf, p...)
headerEndIdx := bytes.Index(hw.buf, DELIMITER)
if headerEndIdx == -1 {
return len(p), nil
}
return hw.processBufferedResponse(p, headerEndIdx)
}
func (hw *httpWriter) shouldBypassBuffering(p []byte) bool {
return hw.respHeader != nil && len(hw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/"
}
func (hw *httpWriter) processBufferedResponse(p []byte, delimiterIdx int) (int, error) {
header, body := hw.splitHeaderAndBody(hw.buf, delimiterIdx)
if !isHTTPHeader(header) {
return hw.writeRawBuffer()
}
if err := hw.processHTTPResponse(header, body); err != nil {
return 0, err
}
hw.buf = nil
return len(p), nil
}
func (hw *httpWriter) writeRawBuffer() (int, error) {
_, err := hw.writer.Write(hw.buf)
length := len(hw.buf)
hw.buf = nil
if err != nil {
return 0, err
}
return length, nil
}
func (hw *httpWriter) processHTTPResponse(header, body []byte) error {
resphf, err := httpheader.NewResponseHeader(header)
if err != nil {
return err
}
if err = hw.ApplyResponseMiddlewares(resphf, body); err != nil {
return err
}
hw.respHeader = resphf
finalHeader := resphf.Finalize()
if err = hw.writeHeaderAndBody(finalHeader, body); err != nil {
return err
}
return nil
}
func (hw *httpWriter) ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error {
for _, m := range hw.ResponseMiddlewares() {
if err := m.HandleResponse(resphf, body); err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return err
}
}
return nil
}
func (hw *httpWriter) writeHeaderAndBody(header, body []byte) error {
if _, err := hw.writer.Write(header); err != nil {
return err
}
if len(body) > 0 {
if _, err := hw.writer.Write(body); err != nil {
return err
}
}
return nil
}
func NewHTTPWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter {
return &httpWriter{
remoteAddr: remoteAddr,
writer: writer,
reader: reader,
buf: make([]byte, 0, 4096),
}
}
func isHTTPHeader(buf []byte) bool {
lines := bytes.Split(buf, []byte("\r\n"))
startLine := string(lines[0])
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
return false
}
for _, line := range lines[1:] {
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx <= 0 {
return false
}
}
return true
}
-42
View File
@@ -1,42 +0,0 @@
package server
import (
"net"
"tunnel_pls/internal/httpheader"
)
type RequestMiddleware interface {
HandleRequest(header httpheader.RequestHeader) error
}
type ResponseMiddleware interface {
HandleResponse(header httpheader.ResponseHeader, body []byte) error
}
type TunnelFingerprint struct{}
func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{}
}
func (h *TunnelFingerprint) HandleResponse(header httpheader.ResponseHeader, body []byte) error {
header.Set("Server", "Tunnel Please")
return nil
}
type ForwardedFor struct {
addr net.Addr
}
func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr}
}
func (ff *ForwardedFor) HandleRequest(header httpheader.RequestHeader) error {
host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil {
return err
}
header.Set("X-Forwarded-For", host)
return nil
}
+12 -27
View File
@@ -7,9 +7,9 @@ import (
"log" "log"
"net" "net"
"time" "time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/port" "tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/session" "tunnel_pls/session"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@@ -20,38 +20,23 @@ type Server interface {
Close() error Close() error
} }
type server struct { type server struct {
listener net.Listener sshPort string
sshListener net.Listener
config *ssh.ServerConfig config *ssh.ServerConfig
grpcClient client.Client grpcClient client.Client
sessionRegistry session.Registry sessionRegistry registry.Registry
portRegistry port.Registry portRegistry port.Port
} }
func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client, portRegistry port.Registry) (Server, error) { func New(sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200"))) listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
if err != nil { if err != nil {
log.Fatalf("failed to listen on port 2200: %v", err)
return nil, err return nil, err
} }
redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true"
HttpServer := NewHTTPServer(sessionRegistry, redirectTLS)
err = HttpServer.ListenAndServe()
if err != nil {
log.Fatalf("failed to start http server: %v", err)
return nil, err
}
if config.Getenv("TLS_ENABLED", "false") == "true" {
err = HttpServer.ListenAndServeTLS()
if err != nil {
log.Fatalf("failed to start https server: %v", err)
return nil, err
}
}
return &server{ return &server{
listener: listener, sshPort: sshPort,
sshListener: listener,
config: sshConfig, config: sshConfig,
grpcClient: grpcClient, grpcClient: grpcClient,
sessionRegistry: sessionRegistry, sessionRegistry: sessionRegistry,
@@ -60,9 +45,9 @@ func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClie
} }
func (s *server) Start() { func (s *server) Start() {
log.Println("SSH server is starting on port 2200...") log.Printf("SSH server is starting on port %s", s.sshPort)
for { for {
conn, err := s.listener.Accept() conn, err := s.sshListener.Accept()
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
log.Println("listener closed, stopping server") log.Println("listener closed, stopping server")
@@ -77,7 +62,7 @@ func (s *server) Start() {
} }
func (s *server) Close() error { func (s *server) Close() error {
return s.listener.Close() return s.sshListener.Close()
} }
func (s *server) handleConnection(conn net.Conn) { func (s *server) handleConnection(conn net.Conn) {
+5 -39
View File
@@ -56,14 +56,14 @@ type Forwarder interface {
Listener() net.Listener Listener() net.Listener
TunnelType() types.TunnelType TunnelType() types.TunnelType
ForwardedPort() uint16 ForwardedPort() uint16
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) HandleConnection(dst io.ReadWriter, src ssh.Channel)
CreateForwardedTCPIPPayload(origin net.Addr) []byte CreateForwardedTCPIPPayload(origin net.Addr) []byte
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
WriteBadGatewayResponse(dst io.Writer) WriteBadGatewayResponse(dst io.Writer)
AcceptTCPConnections()
Close() error Close() error
} }
func (f *forwarder) openForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
type channelResult struct { type channelResult struct {
channel ssh.Channel channel ssh.Channel
reqs <-chan *ssh.Request reqs <-chan *ssh.Request
@@ -95,38 +95,6 @@ func (f *forwarder) openForwardedChannel(payload []byte) (ssh.Channel, <-chan *s
} }
} }
func (f *forwarder) handleIncomingConnection(conn net.Conn) {
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
channel, reqs, err := f.openForwardedChannel(payload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
err = conn.Close()
if err != nil {
log.Printf("Failed to close connection: %v", err)
}
return
}
go ssh.DiscardRequests(reqs)
go f.HandleConnection(conn, channel, conn.RemoteAddr())
}
func (f *forwarder) AcceptTCPConnections() {
for {
conn, err := f.Listener().Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
go f.handleIncomingConnection(conn)
}
}
func closeWriter(w io.Writer) error { func closeWriter(w io.Writer) error {
if cw, ok := w.(interface{ CloseWrite() error }); ok { if cw, ok := w.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite() return cw.CloseWrite()
@@ -145,12 +113,12 @@ func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string)
} }
if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) { if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) {
errs = append(errs, fmt.Errorf("close writer error (%s): %w", direction, err)) errs = append(errs, fmt.Errorf("close stream error (%s): %w", direction, err))
} }
return errors.Join(errs...) return errors.Join(errs...)
} }
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
defer func() { defer func() {
_, err := io.Copy(io.Discard, src) _, err := io.Copy(io.Discard, src)
if err != nil { if err != nil {
@@ -158,8 +126,6 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
} }
}() }()
log.Printf("Handling new forwarded connection from %s", remoteAddr)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
+5 -5
View File
@@ -31,11 +31,11 @@ type lifecycle struct {
slug slug.Slug slug slug.Slug
startedAt time.Time startedAt time.Time
sessionRegistry SessionRegistry sessionRegistry SessionRegistry
portRegistry portUtil.Registry portRegistry portUtil.Port
user string user string
} }
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Registry, sessionRegistry SessionRegistry, user string) Lifecycle { func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
return &lifecycle{ return &lifecycle{
status: types.INITIALIZING, status: types.INITIALIZING,
conn: conn, conn: conn,
@@ -51,7 +51,7 @@ func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUti
type Lifecycle interface { type Lifecycle interface {
Connection() ssh.Conn Connection() ssh.Conn
PortRegistry() portUtil.Registry PortRegistry() portUtil.Port
User() string User() string
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetStatus(status types.Status) SetStatus(status types.Status)
@@ -60,7 +60,7 @@ type Lifecycle interface {
Close() error Close() error
} }
func (l *lifecycle) PortRegistry() portUtil.Registry { func (l *lifecycle) PortRegistry() portUtil.Port {
return l.portRegistry return l.portRegistry
} }
@@ -113,7 +113,7 @@ func (l *lifecycle) Close() error {
l.sessionRegistry.Remove(key) l.sessionRegistry.Remove(key)
if tunnelType == types.TCP { if tunnelType == types.TCP {
if err := l.PortRegistry().SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
firstErr = err firstErr = err
} }
} }
+19 -18
View File
@@ -12,6 +12,8 @@ import (
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
portUtil "tunnel_pls/internal/port" portUtil "tunnel_pls/internal/port"
"tunnel_pls/internal/random" "tunnel_pls/internal/random"
"tunnel_pls/internal/registry"
"tunnel_pls/internal/transport"
"tunnel_pls/session/forwarder" "tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction" "tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle" "tunnel_pls/session/lifecycle"
@@ -21,14 +23,6 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type Detail struct {
ForwardingType string `json:"forwarding_type,omitempty"`
Slug string `json:"slug,omitempty"`
UserID string `json:"user_id,omitempty"`
Active bool `json:"active,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
}
type Session interface { type Session interface {
HandleGlobalRequest(ch <-chan *ssh.Request) error HandleGlobalRequest(ch <-chan *ssh.Request) error
HandleTCPIPForward(req *ssh.Request) error HandleTCPIPForward(req *ssh.Request) error
@@ -38,7 +32,7 @@ type Session interface {
Interaction() interaction.Interaction Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder Forwarder() forwarder.Forwarder
Slug() slug.Slug Slug() slug.Slug
Detail() *Detail Detail() *types.Detail
Start() error Start() error
} }
@@ -49,12 +43,12 @@ type session struct {
interaction interaction.Interaction interaction interaction.Interaction
forwarder forwarder.Forwarder forwarder forwarder.Forwarder
slug slug.Slug slug slug.Slug
registry Registry registry registry.Registry
} }
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, portRegistry portUtil.Registry, user string) Session { func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
slugManager := slug.New() slugManager := slug.New()
forwarderManager := forwarder.New(slugManager, conn) forwarderManager := forwarder.New(slugManager, conn)
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user) lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
@@ -87,7 +81,7 @@ func (s *session) Slug() slug.Slug {
return s.slug return s.slug
} }
func (s *session) Detail() *Detail { func (s *session) Detail() *types.Detail {
tunnelTypeMap := map[types.TunnelType]string{ tunnelTypeMap := map[types.TunnelType]string{
types.HTTP: "HTTP", types.HTTP: "HTTP",
types.TCP: "TCP", types.TCP: "TCP",
@@ -97,7 +91,7 @@ func (s *session) Detail() *Detail {
tunnelType = "UNKNOWN" tunnelType = "UNKNOWN"
} }
return &Detail{ return &types.Detail{
ForwardingType: tunnelType, ForwardingType: tunnelType,
Slug: s.slug.String(), Slug: s.slug.String(),
UserID: s.lifecycle.User(), UserID: s.lifecycle.User(),
@@ -271,7 +265,7 @@ func (s *session) parseForwardPayload(payloadReader io.Reader) (address string,
} }
if port == 0 { if port == 0 {
unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort() unassigned, ok := s.lifecycle.PortRegistry().Unassigned()
if !ok { if !ok {
return "", 0, fmt.Errorf("no available port") return "", 0, fmt.Errorf("no available port")
} }
@@ -328,7 +322,6 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen
if listener != nil { if listener != nil {
s.forwarder.SetListener(listener) s.forwarder.SetListener(listener)
go s.forwarder.AcceptTCPConnections()
} }
return nil return nil
@@ -346,7 +339,6 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) error {
case 80, 443: case 80, 443:
return s.HandleHTTPForward(req, port) return s.HandleHTTPForward(req, port)
default: default:
return s.HandleTCPForward(req, address, port) return s.HandleTCPForward(req, address, port)
} }
} }
@@ -369,11 +361,12 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
} }
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error { func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
if claimed := s.lifecycle.PortRegistry().ClaimPort(portToBind); !claimed { if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
} }
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
listener, err := tcpServer.Listen()
if err != nil { if err != nil {
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
} }
@@ -387,6 +380,14 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
if err != nil { if err != nil {
return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err)) return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
} }
go func() {
err = tcpServer.Serve(listener)
if err != nil {
log.Printf("Failed serving tcp server: %s\n", err)
}
}()
return nil return nil
} }
+10
View File
@@ -1,5 +1,7 @@
package types package types
import "time"
type Status int type Status int
const ( const (
@@ -27,6 +29,14 @@ type SessionKey struct {
Type TunnelType Type TunnelType
} }
type Detail struct {
ForwardingType string `json:"forwarding_type,omitempty"`
Slug string `json:"slug,omitempty"`
UserID string `json:"user_id,omitempty"`
Active bool `json:"active,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
}
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
"Content-Length: 11\r\n" + "Content-Length: 11\r\n" +
"Content-Type: text/plain\r\n\r\n" + "Content-Type: text/plain\r\n\r\n" +