revert-54069ad305 #11

Closed
bagas wants to merge 217 commits from revert-54069ad305 into main
21 changed files with 460 additions and 277 deletions
Showing only changes of commit 961a905542 - Show all commits
+20
View File
@@ -0,0 +1,20 @@
on:
push:
pull_request:
types: [opened, synchronize, reopened]
name: SonarQube Scan
jobs:
sonarqube:
name: SonarQube Trigger
runs-on: ubuntu-latest
steps:
- name: Checking out
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: SonarQube Scan
uses: SonarSource/sonarqube-scan-action@v7.0.0
env:
SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
+53 -23
View File
@@ -1,33 +1,63 @@
package config
import (
"os"
"strconv"
import "tunnel_pls/types"
"github.com/joho/godotenv"
)
type Config interface {
Domain() string
SSHPort() string
func Load() error {
if _, err := os.Stat(".env"); err == nil {
return godotenv.Load(".env")
}
return nil
HTTPPort() string
HTTPSPort() string
TLSEnabled() bool
TLSRedirect() bool
ACMEEmail() string
CFAPIToken() string
ACMEStaging() bool
AllowedPortsStart() uint16
AllowedPortsEnd() uint16
BufferSize() int
PprofEnabled() bool
PprofPort() string
Mode() types.ServerMode
GRPCAddress() string
GRPCPort() string
NodeToken() string
}
func Getenv(key, defaultValue string) string {
val := os.Getenv(key)
if val == "" {
val = defaultValue
func MustLoad() (Config, error) {
if err := loadEnvFile(); err != nil {
return nil, err
}
return val
cfg, err := parse()
if err != nil {
return nil, err
}
return cfg, nil
}
func GetBufferSize() int {
sizeStr := Getenv("BUFFER_SIZE", "32768")
size, err := strconv.Atoi(sizeStr)
if err != nil || size < 4096 || size > 1048576 {
return 32768
}
return size
}
func (c *config) Domain() string { return c.domain }
func (c *config) SSHPort() string { return c.sshPort }
func (c *config) HTTPPort() string { return c.httpPort }
func (c *config) HTTPSPort() string { return c.httpsPort }
func (c *config) TLSEnabled() bool { return c.tlsEnabled }
func (c *config) TLSRedirect() bool { return c.tlsRedirect }
func (c *config) ACMEEmail() string { return c.acmeEmail }
func (c *config) CFAPIToken() string { return c.cfAPIToken }
func (c *config) ACMEStaging() bool { return c.acmeStaging }
func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart }
func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd }
func (c *config) BufferSize() int { return c.bufferSize }
func (c *config) PprofEnabled() bool { return c.pprofEnabled }
func (c *config) PprofPort() string { return c.pprofPort }
func (c *config) Mode() types.ServerMode { return c.mode }
func (c *config) GRPCAddress() string { return c.grpcAddress }
func (c *config) GRPCPort() string { return c.grpcPort }
func (c *config) NodeToken() string { return c.nodeToken }
+170
View File
@@ -0,0 +1,170 @@
package config
import (
"fmt"
"log"
"os"
"strconv"
"strings"
"tunnel_pls/types"
"github.com/joho/godotenv"
)
type config struct {
domain string
sshPort string
httpPort string
httpsPort string
tlsEnabled bool
tlsRedirect bool
acmeEmail string
cfAPIToken string
acmeStaging bool
allowedPortsStart uint16
allowedPortsEnd uint16
bufferSize int
pprofEnabled bool
pprofPort string
mode types.ServerMode
grpcAddress string
grpcPort string
nodeToken string
}
func parse() (*config, error) {
mode, err := parseMode()
if err != nil {
return nil, err
}
domain := getenv("DOMAIN", "localhost")
sshPort := getenv("PORT", "2200")
httpPort := getenv("HTTP_PORT", "8080")
httpsPort := getenv("HTTPS_PORT", "8443")
tlsEnabled := getenvBool("TLS_ENABLED", false)
tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false)
acmeEmail := getenv("ACME_EMAIL", "admin@"+domain)
acmeStaging := getenvBool("ACME_STAGING", false)
cfToken := getenv("CF_API_TOKEN", "")
if tlsEnabled && cfToken == "" {
return nil, fmt.Errorf("CF_API_TOKEN is required when TLS is enabled")
}
start, end, err := parseAllowedPorts()
if err != nil {
return nil, err
}
bufferSize := parseBufferSize()
pprofEnabled := getenvBool("PPROF_ENABLED", false)
pprofPort := getenv("PPROF_PORT", "6060")
grpcHost := getenv("GRPC_ADDRESS", "localhost")
grpcPort := getenv("GRPC_PORT", "8080")
nodeToken := getenv("NODE_TOKEN", "")
if mode == types.ServerModeNODE && nodeToken == "" {
return nil, fmt.Errorf("NODE_TOKEN is required in node mode")
}
return &config{
domain: domain,
sshPort: sshPort,
httpPort: httpPort,
httpsPort: httpsPort,
tlsEnabled: tlsEnabled,
tlsRedirect: tlsRedirect,
acmeEmail: acmeEmail,
cfAPIToken: cfToken,
acmeStaging: acmeStaging,
allowedPortsStart: start,
allowedPortsEnd: end,
bufferSize: bufferSize,
pprofEnabled: pprofEnabled,
pprofPort: pprofPort,
mode: mode,
grpcAddress: grpcHost,
grpcPort: grpcPort,
nodeToken: nodeToken,
}, nil
}
func loadEnvFile() error {
if _, err := os.Stat(".env"); err == nil {
return godotenv.Load(".env")
}
return nil
}
func parseMode() (types.ServerMode, error) {
switch strings.ToLower(getenv("MODE", "standalone")) {
case "standalone":
return types.ServerModeSTANDALONE, nil
case "node":
return types.ServerModeNODE, nil
default:
return 0, fmt.Errorf("invalid MODE value")
}
}
func parseAllowedPorts() (uint16, uint16, error) {
raw := getenv("ALLOWED_PORTS", "")
if raw == "" {
return 0, 0, nil
}
parts := strings.Split(raw, "-")
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid ALLOWED_PORTS format")
}
start, err := strconv.ParseUint(parts[0], 10, 16)
if err != nil {
return 0, 0, err
}
end, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return 0, 0, err
}
return uint16(start), uint16(end), nil
}
func parseBufferSize() int {
raw := getenv("BUFFER_SIZE", "32768")
size, err := strconv.Atoi(raw)
if err != nil || size < 4096 || size > 1048576 {
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
return 4096
}
return size
}
func getenv(key, def string) string {
if v := os.Getenv(key); v != "" {
return v
}
return def
}
func getenvBool(key string, def bool) bool {
val := os.Getenv(key)
if val == "" {
return def
}
return val == "true"
}
+9 -7
View File
@@ -29,6 +29,7 @@ type Client interface {
CheckServerHealth(ctx context.Context) error
}
type client struct {
config config.Config
conn *grpc.ClientConn
address string
sessionRegistry registry.Registry
@@ -37,7 +38,7 @@ type client struct {
closing bool
}
func New(address string, sessionRegistry registry.Registry) (Client, error) {
func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) {
var opts []grpc.DialOption
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
@@ -66,6 +67,7 @@ func New(address string, sessionRegistry registry.Registry) (Client, error) {
authorizeConnectionService := proto.NewUserServiceClient(conn)
return &client{
config: config,
conn: conn,
address: address,
sessionRegistry: sessionRegistry,
@@ -192,7 +194,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
oldSlug := slugEvent.GetOld()
newSlug := slugEvent.GetNew()
userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.HTTP})
userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP})
if err != nil {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
@@ -202,7 +204,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
}, "slug change failure response")
}
if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.HTTP}, types.SessionKey{Id: newSlug, Type: types.HTTP}); err != nil {
if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}, types.SessionKey{Id: newSlug, Type: types.TunnelTypeHTTP}); err != nil {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
@@ -227,7 +229,7 @@ func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node
for _, ses := range sessions {
detail := ses.Detail()
details = append(details, &proto.Detail{
Node: config.Getenv("DOMAIN", "localhost"),
Node: c.config.Domain(),
ForwardingType: detail.ForwardingType,
Slug: detail.Slug,
UserId: detail.UserID,
@@ -299,11 +301,11 @@ func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.E
func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) {
switch t {
case proto.TunnelType_HTTP:
return types.HTTP, nil
return types.TunnelTypeHTTP, nil
case proto.TunnelType_TCP:
return types.TCP, nil
return types.TunnelTypeTCP, nil
default:
return types.UNKNOWN, fmt.Errorf("unknown tunnel type received")
return types.TunnelTypeUNKNOWN, fmt.Errorf("unknown tunnel type received")
}
}
+3 -3
View File
@@ -17,9 +17,9 @@ type RequestHeader interface {
Set(key string, value string)
Remove(key string)
Finalize() []byte
GetMethod() string
GetPath() string
GetVersion() string
Method() string
Path() string
Version() string
}
type requestHeader struct {
method string
+3 -3
View File
@@ -32,15 +32,15 @@ func (req *requestHeader) Remove(key string) {
delete(req.headers, key)
}
func (req *requestHeader) GetMethod() string {
func (req *requestHeader) Method() string {
return req.method
}
func (req *requestHeader) GetPath() string {
func (req *requestHeader) Path() string {
return req.path
}
func (req *requestHeader) GetVersion() string {
func (req *requestHeader) Version() string {
return req.version
}
+19 -10
View File
@@ -34,6 +34,15 @@ type registry struct {
slugIndex map[Key]string
}
var (
ErrSessionNotFound = fmt.Errorf("session not found")
ErrSlugInUse = fmt.Errorf("slug already in use")
ErrInvalidSlug = fmt.Errorf("invalid slug")
ErrForbiddenSlug = fmt.Errorf("forbidden slug")
ErrSlugChangeNotAllowed = fmt.Errorf("slug change not allowed for this tunnel type")
ErrSlugUnchanged = fmt.Errorf("slug is unchanged")
)
func NewRegistry() Registry {
return &registry{
byUser: make(map[string]map[Key]Session),
@@ -47,12 +56,12 @@ func (r *registry) Get(key Key) (session Session, err error) {
userID, ok := r.slugIndex[key]
if !ok {
return nil, fmt.Errorf("Session not found")
return nil, ErrSessionNotFound
}
client, ok := r.byUser[userID][key]
if !ok {
return nil, fmt.Errorf("Session not found")
return nil, ErrSessionNotFound
}
return client, nil
}
@@ -63,37 +72,37 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error
client, ok := r.byUser[user][key]
if !ok {
return nil, fmt.Errorf("Session not found")
return nil, ErrSessionNotFound
}
return client, nil
}
func (r *registry) Update(user string, oldKey, newKey Key) error {
if oldKey.Type != newKey.Type {
return fmt.Errorf("tunnel type cannot change")
return ErrSlugUnchanged
}
if newKey.Type != types.HTTP {
return fmt.Errorf("non http tunnel cannot change slug")
if newKey.Type != types.TunnelTypeHTTP {
return ErrSlugChangeNotAllowed
}
if isForbiddenSlug(newKey.Id) {
return fmt.Errorf("this subdomain is reserved. Please choose a different one")
return ErrForbiddenSlug
}
if !isValidSlug(newKey.Id) {
return fmt.Errorf("invalid subdomain. Follow the rules")
return ErrInvalidSlug
}
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
return fmt.Errorf("someone already uses this subdomain")
return ErrSlugInUse
}
client, ok := r.byUser[user][oldKey]
if !ok {
return fmt.Errorf("Session not found")
return ErrSessionNotFound
}
delete(r.byUser[user], oldKey)
+2 -2
View File
@@ -12,9 +12,9 @@ type httpServer struct {
port string
}
func NewHTTPServer(port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
return &httpServer{
handler: newHTTPHandler(sessionRegistry, redirectTLS),
handler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
port: port,
}
}
+14 -10
View File
@@ -4,12 +4,12 @@ import (
"bufio"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"strings"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/http/stream"
"tunnel_pls/internal/middleware"
@@ -20,12 +20,14 @@ import (
)
type httpHandler struct {
domain string
sessionRegistry registry.Registry
redirectTLS bool
}
func newHTTPHandler(sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
return &httpHandler{
domain: domain,
sessionRegistry: sessionRegistry,
redirectTLS: redirectTLS,
}
@@ -67,7 +69,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
}
if hh.shouldRedirectToTLS(isTLS) {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")))
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain))
return
}
@@ -133,7 +135,7 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
Type: types.TunnelTypeHTTP,
})
if err != nil {
return nil, err
@@ -143,17 +145,19 @@ func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
channel, err := hh.openForwardedChannel(hw, sshSession)
defer func() {
err = channel.Close()
if err != nil {
log.Printf("Error closing forwarded channel: %v", err)
}
}()
if err != nil {
log.Printf("Failed to establish channel: %v", err)
sshSession.Forwarder().WriteBadGatewayResponse(hw)
return
}
defer func() {
err = channel.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing forwarded channel: %v", err)
}
}()
hh.setupMiddlewares(hw)
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
+6 -9
View File
@@ -9,28 +9,25 @@ import (
)
type https struct {
tlsConfig *tls.Config
httpHandler *httpHandler
domain string
port string
}
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport {
return &https{
httpHandler: newHTTPHandler(sessionRegistry, redirectTLS),
tlsConfig: tlsConfig,
httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
domain: domain,
port: port,
}
}
func (ht *https) Listen() (net.Listener, error) {
tlsConfig, err := NewTLSConfig(ht.domain)
if err != nil {
return nil, err
}
return tls.Listen("tcp", ":"+ht.port, tlsConfig)
return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig)
}
func (ht *https) Serve(listener net.Listener) error {
log.Printf("HTTPS server is starting on port %s", ht.port)
for {
+12 -33
View File
@@ -26,7 +26,8 @@ type TLSManager interface {
}
type tlsManager struct {
domain string
config config.Config
certPath string
keyPath string
storagePath string
@@ -42,7 +43,7 @@ type tlsManager struct {
var globalTLSManager TLSManager
var tlsManagerOnce sync.Once
func NewTLSConfig(domain string) (*tls.Config, error) {
func NewTLSConfig(config config.Config) (*tls.Config, error) {
var initErr error
tlsManagerOnce.Do(func() {
@@ -51,7 +52,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
storagePath := "certs/tls/certmagic"
tm := &tlsManager{
domain: domain,
config: config,
certPath: certPath,
keyPath: keyPath,
storagePath: storagePath,
@@ -66,14 +67,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
tm.useCertMagic = false
tm.startCertWatcher()
} else {
if !isACMEConfigComplete() {
log.Printf("User certificates missing or invalid, and ACME configuration is incomplete")
log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable")
initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)")
return
}
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain)
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain(), config.Domain())
if err := tm.initCertMagic(); err != nil {
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
return
@@ -91,11 +85,6 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
return globalTLSManager.getTLSConfig(), nil
}
func isACMEConfigComplete() bool {
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
return cfAPIToken != ""
}
func (tm *tlsManager) userCertsExistAndValid() bool {
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
log.Printf("Certificate file not found: %s", tm.certPath)
@@ -106,7 +95,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool {
return false
}
return ValidateCertDomains(tm.certPath, tm.domain)
return ValidateCertDomains(tm.certPath, tm.config.Domain())
}
func ValidateCertDomains(certPath, domain string) bool {
@@ -206,15 +195,9 @@ func (tm *tlsManager) startCertWatcher() {
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
log.Printf("Certificate files changed, reloading...")
if !ValidateCertDomains(tm.certPath, tm.domain) {
if !ValidateCertDomains(tm.certPath, tm.config.Domain()) {
log.Printf("New certificates don't cover required domains")
if !isACMEConfigComplete() {
log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)")
continue
}
log.Printf("Switching to CertMagic for automatic certificate management")
if err := tm.initCertMagic(); err != nil {
log.Printf("Failed to initialize CertMagic: %v", err)
continue
@@ -241,16 +224,12 @@ func (tm *tlsManager) initCertMagic() error {
return fmt.Errorf("failed to create cert storage directory: %w", err)
}
acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain)
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
acmeStaging := config.Getenv("ACME_STAGING", "false") == "true"
if cfAPIToken == "" {
if tm.config.CFAPIToken() == "" {
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
}
cfProvider := &cloudflare.Provider{
APIToken: cfAPIToken,
APIToken: tm.config.CFAPIToken(),
}
storage := &certmagic.FileStorage{Path: tm.storagePath}
@@ -266,7 +245,7 @@ func (tm *tlsManager) initCertMagic() error {
})
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
Email: acmeEmail,
Email: tm.config.ACMEEmail(),
Agreed: true,
DNS01Solver: &certmagic.DNS01Solver{
DNSManager: certmagic.DNSManager{
@@ -275,7 +254,7 @@ func (tm *tlsManager) initCertMagic() error {
},
})
if acmeStaging {
if tm.config.ACMEStaging() {
acmeIssuer.CA = certmagic.LetsEncryptStagingCA
log.Printf("Using Let's Encrypt staging server")
} else {
@@ -286,7 +265,7 @@ func (tm *tlsManager) initCertMagic() error {
magic.Issuers = []certmagic.Issuer{acmeIssuer}
tm.magic = magic
domains := []string{tm.domain, "*." + tm.domain}
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
log.Printf("Requesting certificates for: %v", domains)
ctx := context.Background()
+33 -70
View File
@@ -9,8 +9,6 @@ import (
_ "net/http/pprof"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"tunnel_pls/internal/config"
@@ -21,6 +19,7 @@ import (
"tunnel_pls/internal/transport"
"tunnel_pls/internal/version"
"tunnel_pls/server"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
@@ -36,27 +35,12 @@ func main() {
log.Printf("Starting %s", version.GetVersion())
err := config.Load()
conf, err := config.MustLoad()
if err != nil {
log.Fatalf("Failed to load configuration: %s", err)
return
}
mode := strings.ToLower(config.Getenv("MODE", "standalone"))
isNodeMode := mode == "node"
pprofEnabled := config.Getenv("PPROF_ENABLED", "false")
if pprofEnabled == "true" {
pprofPort := config.Getenv("PPROF_PORT", "6060")
go func() {
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err = http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err)
}
}()
}
sshConfig := &ssh.ServerConfig{
NoClientAuth: true,
ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
@@ -88,16 +72,11 @@ func main() {
signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
var grpcClient client.Client
if isNodeMode {
grpcHost := config.Getenv("GRPC_ADDRESS", "localhost")
grpcPort := config.Getenv("GRPC_PORT", "8080")
grpcAddr := fmt.Sprintf("%s:%s", grpcHost, grpcPort)
nodeToken := config.Getenv("NODE_TOKEN", "")
if nodeToken == "" {
log.Fatalf("NODE_TOKEN is required in node mode")
}
grpcClient, err = client.New(grpcAddr, sessionRegistry)
if conf.Mode() == types.ServerModeNODE {
grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort())
grpcClient, err = client.New(conf, grpcAddr, sessionRegistry)
if err != nil {
log.Fatalf("failed to create grpc client: %v", err)
}
@@ -110,46 +89,15 @@ func main() {
healthCancel()
go func() {
identity := config.Getenv("DOMAIN", "localhost")
if err = grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil {
if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
}
}()
}
portManager := port.New()
rawRange := config.Getenv("ALLOWED_PORTS", "")
if rawRange != "" {
splitRange := strings.Split(rawRange, "-")
if len(splitRange) == 2 {
var start, end uint64
start, err = strconv.ParseUint(splitRange[0], 10, 16)
if err != nil {
log.Fatalf("Failed to parse start port: %s", err)
}
end, err = strconv.ParseUint(splitRange[1], 10, 16)
if err != nil {
log.Fatalf("Failed to parse end port: %s", err)
}
if err = portManager.AddRange(uint16(start), uint16(end)); err != nil {
log.Fatalf("Failed to add port range: %s", err)
}
log.Printf("PortRegistry range configured: %d-%d", start, end)
} else {
log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange)
}
}
tlsEnabled := config.Getenv("TLS_ENABLED", "false") == "true"
redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true"
go func() {
httpPort := config.Getenv("HTTP_PORT", "8080")
var httpListener net.Listener
httpserver := transport.NewHTTPServer(httpPort, sessionRegistry, redirectTLS)
httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), sessionRegistry, conf.TLSRedirect())
httpListener, err = httpserver.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
@@ -162,19 +110,17 @@ func main() {
}
}()
if tlsEnabled {
if conf.TLSEnabled() {
go func() {
httpsPort := config.Getenv("HTTPS_PORT", "8443")
domain := config.Getenv("DOMAIN", "localhost")
var httpListener net.Listener
httpserver := transport.NewHTTPSServer(domain, httpsPort, sessionRegistry, redirectTLS)
httpListener, err = httpserver.Listen()
var httpsListener net.Listener
tlsConfig, _ := transport.NewTLSConfig(conf)
httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), sessionRegistry, conf.TLSRedirect(), tlsConfig)
httpsListener, err = httpsServer.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
err = httpserver.Serve(httpListener)
err = httpsServer.Serve(httpsListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
@@ -182,17 +128,34 @@ func main() {
}()
}
portManager := port.New()
err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd())
if err != nil {
log.Fatalf("Failed to initialize port manager: %s", err)
return
}
var app server.Server
go func() {
sshPort := config.Getenv("PORT", "2200")
app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager, sshPort)
app, err = server.New(conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort())
if err != nil {
errChan <- fmt.Errorf("failed to start server: %s", err)
return
}
app.Start()
}()
if conf.PprofEnabled() {
go func() {
pprofAddr := fmt.Sprintf("localhost:%s", conf.PprofPort())
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err = http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err)
}
}()
}
select {
case err = <-errChan:
log.Printf("error happen : %s", err)
+8 -5
View File
@@ -7,6 +7,7 @@ import (
"log"
"net"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
@@ -20,24 +21,26 @@ type Server interface {
Close() error
}
type server struct {
config config.Config
sshPort string
sshListener net.Listener
config *ssh.ServerConfig
sshConfig *ssh.ServerConfig
grpcClient client.Client
sessionRegistry registry.Registry
portRegistry port.Port
}
func New(sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
func New(config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
if err != nil {
return nil, err
}
return &server{
config: config,
sshPort: sshPort,
sshListener: listener,
config: sshConfig,
sshConfig: sshConfig,
grpcClient: grpcClient,
sessionRegistry: sessionRegistry,
portRegistry: portRegistry,
@@ -66,7 +69,7 @@ func (s *server) Close() error {
}
func (s *server) handleConnection(conn net.Conn) {
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config)
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.sshConfig)
if err != nil {
log.Printf("failed to establish SSH connection: %v", err)
err = conn.Close()
@@ -92,7 +95,7 @@ func (s *server) handleConnection(conn net.Conn) {
cancel()
}
log.Println("SSH connection established:", sshConn.User())
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
sshSession := session.New(s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
err = sshSession.Start()
if err != nil {
log.Printf("SSH session ended with error: %v", err)
+31 -32
View File
@@ -18,37 +18,6 @@ import (
"golang.org/x/crypto/ssh"
)
var bufferPool = sync.Pool{
New: func() interface{} {
bufSize := config.GetBufferSize()
return make([]byte, bufSize)
},
}
func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := bufferPool.Get().([]byte)
defer bufferPool.Put(buf)
return io.CopyBuffer(dst, src, buf)
}
type forwarder struct {
listener net.Listener
tunnelType types.TunnelType
forwardedPort uint16
slug slug.Slug
conn ssh.Conn
}
func New(slug slug.Slug, conn ssh.Conn) Forwarder {
return &forwarder{
listener: nil,
tunnelType: types.UNKNOWN,
forwardedPort: 0,
slug: slug,
conn: conn,
}
}
type Forwarder interface {
SetType(tunnelType types.TunnelType)
SetForwardedPort(port uint16)
@@ -62,6 +31,36 @@ type Forwarder interface {
WriteBadGatewayResponse(dst io.Writer)
Close() error
}
type forwarder struct {
listener net.Listener
tunnelType types.TunnelType
forwardedPort uint16
slug slug.Slug
conn ssh.Conn
bufferPool sync.Pool
}
func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
return &forwarder{
listener: nil,
tunnelType: types.TunnelTypeUNKNOWN,
forwardedPort: 0,
slug: slug,
conn: conn,
bufferPool: sync.Pool{
New: func() interface{} {
bufSize := config.BufferSize()
return make([]byte, bufSize)
},
},
}
}
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := f.bufferPool.Get().([]byte)
defer f.bufferPool.Put(buf)
return io.CopyBuffer(dst, src, buf)
}
func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
type channelResult struct {
@@ -107,7 +106,7 @@ func closeWriter(w io.Writer) error {
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
var errs []error
_, err := copyWithBuffer(dst, src)
_, err := f.copyWithBuffer(dst, src)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
}
+11 -10
View File
@@ -18,9 +18,9 @@ import (
)
type Interaction interface {
Mode() types.Mode
Mode() types.InteractiveMode
SetChannel(channel ssh.Channel)
SetMode(m types.Mode)
SetMode(m types.InteractiveMode)
SetWH(w, h int)
Start()
Redraw()
@@ -39,6 +39,7 @@ type Forwarder interface {
type CloseFunc func() error
type interaction struct {
config config.Config
channel ssh.Channel
slug slug.Slug
forwarder Forwarder
@@ -48,14 +49,14 @@ type interaction struct {
program *tea.Program
ctx context.Context
cancel context.CancelFunc
mode types.Mode
mode types.InteractiveMode
}
func (i *interaction) SetMode(m types.Mode) {
func (i *interaction) SetMode(m types.InteractiveMode) {
i.mode = m
}
func (i *interaction) Mode() types.Mode {
func (i *interaction) Mode() types.InteractiveMode {
return i.mode
}
@@ -75,9 +76,10 @@ func (i *interaction) SetWH(w, h int) {
}
}
func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
func New(config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
ctx, cancel := context.WithCancel(context.Background())
return &interaction{
config: config,
channel: nil,
slug: slug,
forwarder: forwarder,
@@ -174,14 +176,13 @@ func (m *model) View() string {
}
func (i *interaction) Start() {
if i.mode == types.HEADLESS {
if i.mode == types.InteractiveModeHEADLESS {
return
}
lipgloss.SetColorProfile(termenv.TrueColor)
domain := config.Getenv("DOMAIN", "localhost")
protocol := "http"
if config.Getenv("TLS_ENABLED", "false") == "true" {
if i.config.TLSEnabled() {
protocol = "https"
}
@@ -209,7 +210,7 @@ func (i *interaction) Start() {
ti.Width = 50
m := &model{
domain: domain,
domain: i.config.Domain(),
protocol: protocol,
tunnelType: tunnelType,
port: port,
+1 -1
View File
@@ -41,7 +41,7 @@ type model struct {
}
func (m *model) getTunnelURL() string {
if m.tunnelType == types.HTTP {
if m.tunnelType == types.TunnelTypeHTTP {
return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
}
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
+4 -4
View File
@@ -15,7 +15,7 @@ import (
func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
if m.tunnelType != types.HTTP {
if m.tunnelType != types.TunnelTypeHTTP {
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
@@ -30,10 +30,10 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
inputValue := m.slugInput.Value()
if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{
Id: m.interaction.slug.String(),
Type: types.HTTP,
Type: types.TunnelTypeHTTP,
}, types.SessionKey{
Id: inputValue,
Type: types.HTTP,
Type: types.TunnelTypeHTTP,
}); err != nil {
m.slugError = err.Error()
return m, nil
@@ -130,7 +130,7 @@ func (m *model) slugView() string {
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
if m.tunnelType != types.HTTP {
if m.tunnelType != types.TunnelTypeHTTP {
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
warningBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FFA500")).
+25 -29
View File
@@ -2,8 +2,6 @@ package lifecycle
import (
"errors"
"io"
"net"
"time"
portUtil "tunnel_pls/internal/port"
@@ -24,7 +22,7 @@ type SessionRegistry interface {
}
type lifecycle struct {
status types.Status
status types.SessionStatus
conn ssh.Conn
channel ssh.Channel
forwarder Forwarder
@@ -37,7 +35,7 @@ type lifecycle struct {
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
return &lifecycle{
status: types.INITIALIZING,
status: types.SessionStatusINITIALIZING,
conn: conn,
channel: nil,
forwarder: forwarder,
@@ -54,7 +52,7 @@ type Lifecycle interface {
PortRegistry() portUtil.Port
User() string
SetChannel(channel ssh.Channel)
SetStatus(status types.Status)
SetStatus(status types.SessionStatus)
IsActive() bool
StartedAt() time.Time
Close() error
@@ -74,35 +72,30 @@ func (l *lifecycle) SetChannel(channel ssh.Channel) {
func (l *lifecycle) Connection() ssh.Conn {
return l.conn
}
func (l *lifecycle) SetStatus(status types.Status) {
func (l *lifecycle) SetStatus(status types.SessionStatus) {
l.status = status
if status == types.RUNNING && l.startedAt.IsZero() {
if status == types.SessionStatusRUNNING && l.startedAt.IsZero() {
l.startedAt = time.Now()
}
}
func closeIfNotNil(c interface{ Close() error }) error {
if c != nil {
return c.Close()
}
return nil
}
func (l *lifecycle) Close() error {
var firstErr error
var errs []error
tunnelType := l.forwarder.TunnelType()
if err := l.forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
firstErr = err
if err := closeIfNotNil(l.channel); err != nil {
errs = append(errs, err)
}
if l.channel != nil {
if err := l.channel.Close(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
if firstErr == nil {
firstErr = err
}
}
}
if l.conn != nil {
if err := l.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
if firstErr == nil {
firstErr = err
}
}
if err := closeIfNotNil(l.conn); err != nil {
errs = append(errs, err)
}
clientSlug := l.slug.String()
@@ -112,17 +105,20 @@ func (l *lifecycle) Close() error {
}
l.sessionRegistry.Remove(key)
if tunnelType == types.TCP {
if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
firstErr = err
if tunnelType == types.TunnelTypeTCP {
if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil {
errs = append(errs, err)
}
if err := l.forwarder.Close(); err != nil {
errs = append(errs, err)
}
}
return firstErr
return errors.Join(errs...)
}
func (l *lifecycle) IsActive() bool {
return l.status == types.RUNNING
return l.status == types.SessionStatusRUNNING
}
func (l *lifecycle) StartedAt() time.Time {
+19 -17
View File
@@ -37,6 +37,7 @@ type Session interface {
}
type session struct {
config config.Config
initialReq <-chan *ssh.Request
sshChan <-chan ssh.NewChannel
lifecycle lifecycle.Lifecycle
@@ -48,13 +49,14 @@ type session struct {
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
func New(config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
slugManager := slug.New()
forwarderManager := forwarder.New(slugManager, conn)
forwarderManager := forwarder.New(config, slugManager, conn)
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
interactionManager := interaction.New(config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
return &session{
config: config,
initialReq: initialReq,
sshChan: sshChan,
lifecycle: lifecycleManager,
@@ -83,12 +85,12 @@ func (s *session) Slug() slug.Slug {
func (s *session) Detail() *types.Detail {
tunnelTypeMap := map[types.TunnelType]string{
types.HTTP: "HTTP",
types.TCP: "TCP",
types.TunnelTypeHTTP: "TunnelTypeHTTP",
types.TunnelTypeTCP: "TunnelTypeTCP",
}
tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
if !ok {
tunnelType = "UNKNOWN"
tunnelType = "TunnelTypeUNKNOWN"
}
return &types.Detail{
@@ -131,7 +133,7 @@ func (s *session) setupSessionMode() error {
}
return s.setupInteractiveMode(channel)
case <-time.After(500 * time.Millisecond):
s.interaction.SetMode(types.HEADLESS)
s.interaction.SetMode(types.InteractiveModeHEADLESS)
return nil
}
}
@@ -152,13 +154,13 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch)
s.interaction.SetMode(types.INTERACTIVE)
s.interaction.SetMode(types.InteractiveModeINTERACTIVE)
return nil
}
func (s *session) handleMissingForwardRequest() error {
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
if err != nil {
return err
}
@@ -169,8 +171,8 @@ func (s *session) handleMissingForwardRequest() error {
}
func (s *session) shouldRejectUnauthorized() bool {
return s.interaction.Mode() == types.HEADLESS &&
config.Getenv("MODE", "standalone") == "standalone" &&
return s.interaction.Mode() == types.InteractiveModeHEADLESS &&
s.config.Mode() == types.ServerModeSTANDALONE &&
s.lifecycle.User() == "UNAUTHORIZED"
}
@@ -318,7 +320,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen
s.forwarder.SetType(tunnelType)
s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(slug)
s.lifecycle.SetStatus(types.RUNNING)
s.lifecycle.SetStatus(types.SessionStatusRUNNING)
if listener != nil {
s.forwarder.SetListener(listener)
@@ -348,12 +350,12 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
}
key := types.SessionKey{Id: randomString, Type: types.HTTP}
key := types.SessionKey{Id: randomString, Type: types.TunnelTypeHTTP}
if !s.registry.Register(key, s) {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString))
}
err = s.finalizeForwarding(req, portToBind, nil, types.HTTP, key.Id)
err = s.finalizeForwarding(req, portToBind, nil, types.TunnelTypeHTTP, key.Id)
if err != nil {
return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err))
}
@@ -371,12 +373,12 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
}
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP}
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
if !s.registry.Register(key, s) {
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TCP client with id: %s", key.Id))
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TunnelTypeTCP client with id: %s", key.Id))
}
err = s.finalizeForwarding(req, portToBind, listener, types.TCP, key.Id)
err = s.finalizeForwarding(req, portToBind, listener, types.TunnelTypeTCP, key.Id)
if err != nil {
return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
}
+1
View File
@@ -0,0 +1 @@
sonar.projectKey=tunnel-please
+16 -9
View File
@@ -2,26 +2,33 @@ package types
import "time"
type Status int
type SessionStatus int
const (
INITIALIZING Status = iota
RUNNING
SessionStatusINITIALIZING SessionStatus = iota
SessionStatusRUNNING
)
type Mode int
type InteractiveMode int
const (
INTERACTIVE Mode = iota
HEADLESS
InteractiveModeINTERACTIVE InteractiveMode = iota + 1
InteractiveModeHEADLESS
)
type TunnelType int
const (
UNKNOWN TunnelType = iota
HTTP
TCP
TunnelTypeUNKNOWN TunnelType = iota
TunnelTypeHTTP
TunnelTypeTCP
)
type ServerMode int
const (
ServerModeSTANDALONE = iota + 1
ServerModeNODE
)
type SessionKey struct {