feat/restructure #73
@@ -0,0 +1,20 @@
|
|||||||
|
on:
|
||||||
|
push:
|
||||||
|
pull_request:
|
||||||
|
types: [opened, synchronize, reopened]
|
||||||
|
|
||||||
|
name: SonarQube Scan
|
||||||
|
jobs:
|
||||||
|
sonarqube:
|
||||||
|
name: SonarQube Trigger
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checking out
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
- name: SonarQube Scan
|
||||||
|
uses: SonarSource/sonarqube-scan-action@v7.0.0
|
||||||
|
env:
|
||||||
|
SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
|
||||||
|
SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
|
||||||
+53
-23
@@ -1,33 +1,63 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import "tunnel_pls/types"
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/joho/godotenv"
|
type Config interface {
|
||||||
)
|
Domain() string
|
||||||
|
SSHPort() string
|
||||||
|
|
||||||
func Load() error {
|
HTTPPort() string
|
||||||
if _, err := os.Stat(".env"); err == nil {
|
HTTPSPort() string
|
||||||
return godotenv.Load(".env")
|
|
||||||
}
|
TLSEnabled() bool
|
||||||
return nil
|
TLSRedirect() bool
|
||||||
|
|
||||||
|
ACMEEmail() string
|
||||||
|
CFAPIToken() string
|
||||||
|
ACMEStaging() bool
|
||||||
|
|
||||||
|
AllowedPortsStart() uint16
|
||||||
|
AllowedPortsEnd() uint16
|
||||||
|
|
||||||
|
BufferSize() int
|
||||||
|
|
||||||
|
PprofEnabled() bool
|
||||||
|
PprofPort() string
|
||||||
|
|
||||||
|
Mode() types.ServerMode
|
||||||
|
GRPCAddress() string
|
||||||
|
GRPCPort() string
|
||||||
|
NodeToken() string
|
||||||
}
|
}
|
||||||
|
|
||||||
func Getenv(key, defaultValue string) string {
|
func MustLoad() (Config, error) {
|
||||||
val := os.Getenv(key)
|
if err := loadEnvFile(); err != nil {
|
||||||
if val == "" {
|
return nil, err
|
||||||
val = defaultValue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return val
|
cfg, err := parse()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetBufferSize() int {
|
func (c *config) Domain() string { return c.domain }
|
||||||
sizeStr := Getenv("BUFFER_SIZE", "32768")
|
func (c *config) SSHPort() string { return c.sshPort }
|
||||||
size, err := strconv.Atoi(sizeStr)
|
func (c *config) HTTPPort() string { return c.httpPort }
|
||||||
if err != nil || size < 4096 || size > 1048576 {
|
func (c *config) HTTPSPort() string { return c.httpsPort }
|
||||||
return 32768
|
func (c *config) TLSEnabled() bool { return c.tlsEnabled }
|
||||||
}
|
func (c *config) TLSRedirect() bool { return c.tlsRedirect }
|
||||||
return size
|
func (c *config) ACMEEmail() string { return c.acmeEmail }
|
||||||
}
|
func (c *config) CFAPIToken() string { return c.cfAPIToken }
|
||||||
|
func (c *config) ACMEStaging() bool { return c.acmeStaging }
|
||||||
|
func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart }
|
||||||
|
func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd }
|
||||||
|
func (c *config) BufferSize() int { return c.bufferSize }
|
||||||
|
func (c *config) PprofEnabled() bool { return c.pprofEnabled }
|
||||||
|
func (c *config) PprofPort() string { return c.pprofPort }
|
||||||
|
func (c *config) Mode() types.ServerMode { return c.mode }
|
||||||
|
func (c *config) GRPCAddress() string { return c.grpcAddress }
|
||||||
|
func (c *config) GRPCPort() string { return c.grpcPort }
|
||||||
|
func (c *config) NodeToken() string { return c.nodeToken }
|
||||||
|
|||||||
@@ -0,0 +1,170 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"tunnel_pls/types"
|
||||||
|
|
||||||
|
"github.com/joho/godotenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type config struct {
|
||||||
|
domain string
|
||||||
|
sshPort string
|
||||||
|
|
||||||
|
httpPort string
|
||||||
|
httpsPort string
|
||||||
|
|
||||||
|
tlsEnabled bool
|
||||||
|
tlsRedirect bool
|
||||||
|
|
||||||
|
acmeEmail string
|
||||||
|
cfAPIToken string
|
||||||
|
acmeStaging bool
|
||||||
|
|
||||||
|
allowedPortsStart uint16
|
||||||
|
allowedPortsEnd uint16
|
||||||
|
|
||||||
|
bufferSize int
|
||||||
|
|
||||||
|
pprofEnabled bool
|
||||||
|
pprofPort string
|
||||||
|
|
||||||
|
mode types.ServerMode
|
||||||
|
grpcAddress string
|
||||||
|
grpcPort string
|
||||||
|
nodeToken string
|
||||||
|
}
|
||||||
|
|
||||||
|
func parse() (*config, error) {
|
||||||
|
mode, err := parseMode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
domain := getenv("DOMAIN", "localhost")
|
||||||
|
sshPort := getenv("PORT", "2200")
|
||||||
|
|
||||||
|
httpPort := getenv("HTTP_PORT", "8080")
|
||||||
|
httpsPort := getenv("HTTPS_PORT", "8443")
|
||||||
|
|
||||||
|
tlsEnabled := getenvBool("TLS_ENABLED", false)
|
||||||
|
tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false)
|
||||||
|
|
||||||
|
acmeEmail := getenv("ACME_EMAIL", "admin@"+domain)
|
||||||
|
acmeStaging := getenvBool("ACME_STAGING", false)
|
||||||
|
|
||||||
|
cfToken := getenv("CF_API_TOKEN", "")
|
||||||
|
if tlsEnabled && cfToken == "" {
|
||||||
|
return nil, fmt.Errorf("CF_API_TOKEN is required when TLS is enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
start, end, err := parseAllowedPorts()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
bufferSize := parseBufferSize()
|
||||||
|
|
||||||
|
pprofEnabled := getenvBool("PPROF_ENABLED", false)
|
||||||
|
pprofPort := getenv("PPROF_PORT", "6060")
|
||||||
|
|
||||||
|
grpcHost := getenv("GRPC_ADDRESS", "localhost")
|
||||||
|
grpcPort := getenv("GRPC_PORT", "8080")
|
||||||
|
|
||||||
|
nodeToken := getenv("NODE_TOKEN", "")
|
||||||
|
if mode == types.ServerModeNODE && nodeToken == "" {
|
||||||
|
return nil, fmt.Errorf("NODE_TOKEN is required in node mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &config{
|
||||||
|
domain: domain,
|
||||||
|
sshPort: sshPort,
|
||||||
|
httpPort: httpPort,
|
||||||
|
httpsPort: httpsPort,
|
||||||
|
tlsEnabled: tlsEnabled,
|
||||||
|
tlsRedirect: tlsRedirect,
|
||||||
|
acmeEmail: acmeEmail,
|
||||||
|
cfAPIToken: cfToken,
|
||||||
|
acmeStaging: acmeStaging,
|
||||||
|
allowedPortsStart: start,
|
||||||
|
allowedPortsEnd: end,
|
||||||
|
bufferSize: bufferSize,
|
||||||
|
pprofEnabled: pprofEnabled,
|
||||||
|
pprofPort: pprofPort,
|
||||||
|
mode: mode,
|
||||||
|
grpcAddress: grpcHost,
|
||||||
|
grpcPort: grpcPort,
|
||||||
|
nodeToken: nodeToken,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadEnvFile() error {
|
||||||
|
if _, err := os.Stat(".env"); err == nil {
|
||||||
|
return godotenv.Load(".env")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMode() (types.ServerMode, error) {
|
||||||
|
switch strings.ToLower(getenv("MODE", "standalone")) {
|
||||||
|
case "standalone":
|
||||||
|
return types.ServerModeSTANDALONE, nil
|
||||||
|
case "node":
|
||||||
|
return types.ServerModeNODE, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("invalid MODE value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAllowedPorts() (uint16, uint16, error) {
|
||||||
|
raw := getenv("ALLOWED_PORTS", "")
|
||||||
|
if raw == "" {
|
||||||
|
return 0, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(raw, "-")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return 0, 0, fmt.Errorf("invalid ALLOWED_PORTS format")
|
||||||
|
}
|
||||||
|
|
||||||
|
start, err := strconv.ParseUint(parts[0], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
end, err := strconv.ParseUint(parts[1], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint16(start), uint16(end), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseBufferSize() int {
|
||||||
|
raw := getenv("BUFFER_SIZE", "32768")
|
||||||
|
size, err := strconv.Atoi(raw)
|
||||||
|
if err != nil || size < 4096 || size > 1048576 {
|
||||||
|
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
|
||||||
|
return 4096
|
||||||
|
}
|
||||||
|
return size
|
||||||
|
}
|
||||||
|
|
||||||
|
func getenv(key, def string) string {
|
||||||
|
if v := os.Getenv(key); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
func getenvBool(key string, def bool) bool {
|
||||||
|
val := os.Getenv(key)
|
||||||
|
if val == "" {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return val == "true"
|
||||||
|
}
|
||||||
@@ -29,6 +29,7 @@ type Client interface {
|
|||||||
CheckServerHealth(ctx context.Context) error
|
CheckServerHealth(ctx context.Context) error
|
||||||
}
|
}
|
||||||
type client struct {
|
type client struct {
|
||||||
|
config config.Config
|
||||||
conn *grpc.ClientConn
|
conn *grpc.ClientConn
|
||||||
address string
|
address string
|
||||||
sessionRegistry registry.Registry
|
sessionRegistry registry.Registry
|
||||||
@@ -37,7 +38,7 @@ type client struct {
|
|||||||
closing bool
|
closing bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(address string, sessionRegistry registry.Registry) (Client, error) {
|
func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) {
|
||||||
var opts []grpc.DialOption
|
var opts []grpc.DialOption
|
||||||
|
|
||||||
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
@@ -66,6 +67,7 @@ func New(address string, sessionRegistry registry.Registry) (Client, error) {
|
|||||||
authorizeConnectionService := proto.NewUserServiceClient(conn)
|
authorizeConnectionService := proto.NewUserServiceClient(conn)
|
||||||
|
|
||||||
return &client{
|
return &client{
|
||||||
|
config: config,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
address: address,
|
address: address,
|
||||||
sessionRegistry: sessionRegistry,
|
sessionRegistry: sessionRegistry,
|
||||||
@@ -192,7 +194,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
|
|||||||
oldSlug := slugEvent.GetOld()
|
oldSlug := slugEvent.GetOld()
|
||||||
newSlug := slugEvent.GetNew()
|
newSlug := slugEvent.GetNew()
|
||||||
|
|
||||||
userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.HTTP})
|
userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.sendNode(subscribe, &proto.Node{
|
return c.sendNode(subscribe, &proto.Node{
|
||||||
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
|
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
|
||||||
@@ -202,7 +204,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
|
|||||||
}, "slug change failure response")
|
}, "slug change failure response")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.HTTP}, types.SessionKey{Id: newSlug, Type: types.HTTP}); err != nil {
|
if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}, types.SessionKey{Id: newSlug, Type: types.TunnelTypeHTTP}); err != nil {
|
||||||
return c.sendNode(subscribe, &proto.Node{
|
return c.sendNode(subscribe, &proto.Node{
|
||||||
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
|
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
|
||||||
Payload: &proto.Node_SlugEventResponse{
|
Payload: &proto.Node_SlugEventResponse{
|
||||||
@@ -227,7 +229,7 @@ func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node
|
|||||||
for _, ses := range sessions {
|
for _, ses := range sessions {
|
||||||
detail := ses.Detail()
|
detail := ses.Detail()
|
||||||
details = append(details, &proto.Detail{
|
details = append(details, &proto.Detail{
|
||||||
Node: config.Getenv("DOMAIN", "localhost"),
|
Node: c.config.Domain(),
|
||||||
ForwardingType: detail.ForwardingType,
|
ForwardingType: detail.ForwardingType,
|
||||||
Slug: detail.Slug,
|
Slug: detail.Slug,
|
||||||
UserId: detail.UserID,
|
UserId: detail.UserID,
|
||||||
@@ -299,11 +301,11 @@ func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.E
|
|||||||
func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) {
|
func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) {
|
||||||
switch t {
|
switch t {
|
||||||
case proto.TunnelType_HTTP:
|
case proto.TunnelType_HTTP:
|
||||||
return types.HTTP, nil
|
return types.TunnelTypeHTTP, nil
|
||||||
case proto.TunnelType_TCP:
|
case proto.TunnelType_TCP:
|
||||||
return types.TCP, nil
|
return types.TunnelTypeTCP, nil
|
||||||
default:
|
default:
|
||||||
return types.UNKNOWN, fmt.Errorf("unknown tunnel type received")
|
return types.TunnelTypeUNKNOWN, fmt.Errorf("unknown tunnel type received")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ type RequestHeader interface {
|
|||||||
Set(key string, value string)
|
Set(key string, value string)
|
||||||
Remove(key string)
|
Remove(key string)
|
||||||
Finalize() []byte
|
Finalize() []byte
|
||||||
GetMethod() string
|
Method() string
|
||||||
GetPath() string
|
Path() string
|
||||||
GetVersion() string
|
Version() string
|
||||||
}
|
}
|
||||||
type requestHeader struct {
|
type requestHeader struct {
|
||||||
method string
|
method string
|
||||||
|
|||||||
@@ -32,15 +32,15 @@ func (req *requestHeader) Remove(key string) {
|
|||||||
delete(req.headers, key)
|
delete(req.headers, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (req *requestHeader) GetMethod() string {
|
func (req *requestHeader) Method() string {
|
||||||
return req.method
|
return req.method
|
||||||
}
|
}
|
||||||
|
|
||||||
func (req *requestHeader) GetPath() string {
|
func (req *requestHeader) Path() string {
|
||||||
return req.path
|
return req.path
|
||||||
}
|
}
|
||||||
|
|
||||||
func (req *requestHeader) GetVersion() string {
|
func (req *requestHeader) Version() string {
|
||||||
return req.version
|
return req.version
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,15 @@ type registry struct {
|
|||||||
slugIndex map[Key]string
|
slugIndex map[Key]string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrSessionNotFound = fmt.Errorf("session not found")
|
||||||
|
ErrSlugInUse = fmt.Errorf("slug already in use")
|
||||||
|
ErrInvalidSlug = fmt.Errorf("invalid slug")
|
||||||
|
ErrForbiddenSlug = fmt.Errorf("forbidden slug")
|
||||||
|
ErrSlugChangeNotAllowed = fmt.Errorf("slug change not allowed for this tunnel type")
|
||||||
|
ErrSlugUnchanged = fmt.Errorf("slug is unchanged")
|
||||||
|
)
|
||||||
|
|
||||||
func NewRegistry() Registry {
|
func NewRegistry() Registry {
|
||||||
return ®istry{
|
return ®istry{
|
||||||
byUser: make(map[string]map[Key]Session),
|
byUser: make(map[string]map[Key]Session),
|
||||||
@@ -47,12 +56,12 @@ func (r *registry) Get(key Key) (session Session, err error) {
|
|||||||
|
|
||||||
userID, ok := r.slugIndex[key]
|
userID, ok := r.slugIndex[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("Session not found")
|
return nil, ErrSessionNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
client, ok := r.byUser[userID][key]
|
client, ok := r.byUser[userID][key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("Session not found")
|
return nil, ErrSessionNotFound
|
||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
@@ -63,37 +72,37 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error
|
|||||||
|
|
||||||
client, ok := r.byUser[user][key]
|
client, ok := r.byUser[user][key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("Session not found")
|
return nil, ErrSessionNotFound
|
||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registry) Update(user string, oldKey, newKey Key) error {
|
func (r *registry) Update(user string, oldKey, newKey Key) error {
|
||||||
if oldKey.Type != newKey.Type {
|
if oldKey.Type != newKey.Type {
|
||||||
return fmt.Errorf("tunnel type cannot change")
|
return ErrSlugUnchanged
|
||||||
}
|
}
|
||||||
|
|
||||||
if newKey.Type != types.HTTP {
|
if newKey.Type != types.TunnelTypeHTTP {
|
||||||
return fmt.Errorf("non http tunnel cannot change slug")
|
return ErrSlugChangeNotAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
if isForbiddenSlug(newKey.Id) {
|
if isForbiddenSlug(newKey.Id) {
|
||||||
return fmt.Errorf("this subdomain is reserved. Please choose a different one")
|
return ErrForbiddenSlug
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isValidSlug(newKey.Id) {
|
if !isValidSlug(newKey.Id) {
|
||||||
return fmt.Errorf("invalid subdomain. Follow the rules")
|
return ErrInvalidSlug
|
||||||
}
|
}
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
|
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
|
||||||
return fmt.Errorf("someone already uses this subdomain")
|
return ErrSlugInUse
|
||||||
}
|
}
|
||||||
client, ok := r.byUser[user][oldKey]
|
client, ok := r.byUser[user][oldKey]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("Session not found")
|
return ErrSessionNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(r.byUser[user], oldKey)
|
delete(r.byUser[user], oldKey)
|
||||||
|
|||||||
@@ -12,9 +12,9 @@ type httpServer struct {
|
|||||||
port string
|
port string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTPServer(port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
|
func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
|
||||||
return &httpServer{
|
return &httpServer{
|
||||||
handler: newHTTPHandler(sessionRegistry, redirectTLS),
|
handler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
|
||||||
port: port,
|
port: port,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"tunnel_pls/internal/config"
|
|
||||||
"tunnel_pls/internal/http/header"
|
"tunnel_pls/internal/http/header"
|
||||||
"tunnel_pls/internal/http/stream"
|
"tunnel_pls/internal/http/stream"
|
||||||
"tunnel_pls/internal/middleware"
|
"tunnel_pls/internal/middleware"
|
||||||
@@ -20,12 +20,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type httpHandler struct {
|
type httpHandler struct {
|
||||||
|
domain string
|
||||||
sessionRegistry registry.Registry
|
sessionRegistry registry.Registry
|
||||||
redirectTLS bool
|
redirectTLS bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHTTPHandler(sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
|
func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
|
||||||
return &httpHandler{
|
return &httpHandler{
|
||||||
|
domain: domain,
|
||||||
sessionRegistry: sessionRegistry,
|
sessionRegistry: sessionRegistry,
|
||||||
redirectTLS: redirectTLS,
|
redirectTLS: redirectTLS,
|
||||||
}
|
}
|
||||||
@@ -67,7 +69,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if hh.shouldRedirectToTLS(isTLS) {
|
if hh.shouldRedirectToTLS(isTLS) {
|
||||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")))
|
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,7 +135,7 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
|
|||||||
func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
|
func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
|
||||||
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
|
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
|
||||||
Id: slug,
|
Id: slug,
|
||||||
Type: types.HTTP,
|
Type: types.TunnelTypeHTTP,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -143,17 +145,19 @@ func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
|
|||||||
|
|
||||||
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
|
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
|
||||||
channel, err := hh.openForwardedChannel(hw, sshSession)
|
channel, err := hh.openForwardedChannel(hw, sshSession)
|
||||||
defer func() {
|
|
||||||
err = channel.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error closing forwarded channel: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to establish channel: %v", err)
|
log.Printf("Failed to establish channel: %v", err)
|
||||||
sshSession.Forwarder().WriteBadGatewayResponse(hw)
|
sshSession.Forwarder().WriteBadGatewayResponse(hw)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err = channel.Close()
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
log.Printf("Error closing forwarded channel: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
hh.setupMiddlewares(hw)
|
hh.setupMiddlewares(hw)
|
||||||
|
|
||||||
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
|
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
|
||||||
|
|||||||
@@ -9,28 +9,25 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type https struct {
|
type https struct {
|
||||||
|
tlsConfig *tls.Config
|
||||||
httpHandler *httpHandler
|
httpHandler *httpHandler
|
||||||
domain string
|
domain string
|
||||||
port string
|
port string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
|
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport {
|
||||||
return &https{
|
return &https{
|
||||||
httpHandler: newHTTPHandler(sessionRegistry, redirectTLS),
|
tlsConfig: tlsConfig,
|
||||||
|
httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
|
||||||
domain: domain,
|
domain: domain,
|
||||||
port: port,
|
port: port,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ht *https) Listen() (net.Listener, error) {
|
func (ht *https) Listen() (net.Listener, error) {
|
||||||
tlsConfig, err := NewTLSConfig(ht.domain)
|
return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return tls.Listen("tcp", ":"+ht.port, tlsConfig)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ht *https) Serve(listener net.Listener) error {
|
func (ht *https) Serve(listener net.Listener) error {
|
||||||
log.Printf("HTTPS server is starting on port %s", ht.port)
|
log.Printf("HTTPS server is starting on port %s", ht.port)
|
||||||
for {
|
for {
|
||||||
|
|||||||
+12
-33
@@ -26,7 +26,8 @@ type TLSManager interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type tlsManager struct {
|
type tlsManager struct {
|
||||||
domain string
|
config config.Config
|
||||||
|
|
||||||
certPath string
|
certPath string
|
||||||
keyPath string
|
keyPath string
|
||||||
storagePath string
|
storagePath string
|
||||||
@@ -42,7 +43,7 @@ type tlsManager struct {
|
|||||||
var globalTLSManager TLSManager
|
var globalTLSManager TLSManager
|
||||||
var tlsManagerOnce sync.Once
|
var tlsManagerOnce sync.Once
|
||||||
|
|
||||||
func NewTLSConfig(domain string) (*tls.Config, error) {
|
func NewTLSConfig(config config.Config) (*tls.Config, error) {
|
||||||
var initErr error
|
var initErr error
|
||||||
|
|
||||||
tlsManagerOnce.Do(func() {
|
tlsManagerOnce.Do(func() {
|
||||||
@@ -51,7 +52,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
|
|||||||
storagePath := "certs/tls/certmagic"
|
storagePath := "certs/tls/certmagic"
|
||||||
|
|
||||||
tm := &tlsManager{
|
tm := &tlsManager{
|
||||||
domain: domain,
|
config: config,
|
||||||
certPath: certPath,
|
certPath: certPath,
|
||||||
keyPath: keyPath,
|
keyPath: keyPath,
|
||||||
storagePath: storagePath,
|
storagePath: storagePath,
|
||||||
@@ -66,14 +67,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
|
|||||||
tm.useCertMagic = false
|
tm.useCertMagic = false
|
||||||
tm.startCertWatcher()
|
tm.startCertWatcher()
|
||||||
} else {
|
} else {
|
||||||
if !isACMEConfigComplete() {
|
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain(), config.Domain())
|
||||||
log.Printf("User certificates missing or invalid, and ACME configuration is incomplete")
|
|
||||||
log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable")
|
|
||||||
initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain)
|
|
||||||
if err := tm.initCertMagic(); err != nil {
|
if err := tm.initCertMagic(); err != nil {
|
||||||
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
|
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
|
||||||
return
|
return
|
||||||
@@ -91,11 +85,6 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
|
|||||||
return globalTLSManager.getTLSConfig(), nil
|
return globalTLSManager.getTLSConfig(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isACMEConfigComplete() bool {
|
|
||||||
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
|
|
||||||
return cfAPIToken != ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) userCertsExistAndValid() bool {
|
func (tm *tlsManager) userCertsExistAndValid() bool {
|
||||||
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
|
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
|
||||||
log.Printf("Certificate file not found: %s", tm.certPath)
|
log.Printf("Certificate file not found: %s", tm.certPath)
|
||||||
@@ -106,7 +95,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return ValidateCertDomains(tm.certPath, tm.domain)
|
return ValidateCertDomains(tm.certPath, tm.config.Domain())
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateCertDomains(certPath, domain string) bool {
|
func ValidateCertDomains(certPath, domain string) bool {
|
||||||
@@ -206,15 +195,9 @@ func (tm *tlsManager) startCertWatcher() {
|
|||||||
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
|
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
|
||||||
log.Printf("Certificate files changed, reloading...")
|
log.Printf("Certificate files changed, reloading...")
|
||||||
|
|
||||||
if !ValidateCertDomains(tm.certPath, tm.domain) {
|
if !ValidateCertDomains(tm.certPath, tm.config.Domain()) {
|
||||||
log.Printf("New certificates don't cover required domains")
|
log.Printf("New certificates don't cover required domains")
|
||||||
|
|
||||||
if !isACMEConfigComplete() {
|
|
||||||
log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Switching to CertMagic for automatic certificate management")
|
|
||||||
if err := tm.initCertMagic(); err != nil {
|
if err := tm.initCertMagic(); err != nil {
|
||||||
log.Printf("Failed to initialize CertMagic: %v", err)
|
log.Printf("Failed to initialize CertMagic: %v", err)
|
||||||
continue
|
continue
|
||||||
@@ -241,16 +224,12 @@ func (tm *tlsManager) initCertMagic() error {
|
|||||||
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain)
|
if tm.config.CFAPIToken() == "" {
|
||||||
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
|
|
||||||
acmeStaging := config.Getenv("ACME_STAGING", "false") == "true"
|
|
||||||
|
|
||||||
if cfAPIToken == "" {
|
|
||||||
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
|
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
|
||||||
}
|
}
|
||||||
|
|
||||||
cfProvider := &cloudflare.Provider{
|
cfProvider := &cloudflare.Provider{
|
||||||
APIToken: cfAPIToken,
|
APIToken: tm.config.CFAPIToken(),
|
||||||
}
|
}
|
||||||
|
|
||||||
storage := &certmagic.FileStorage{Path: tm.storagePath}
|
storage := &certmagic.FileStorage{Path: tm.storagePath}
|
||||||
@@ -266,7 +245,7 @@ func (tm *tlsManager) initCertMagic() error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
|
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
|
||||||
Email: acmeEmail,
|
Email: tm.config.ACMEEmail(),
|
||||||
Agreed: true,
|
Agreed: true,
|
||||||
DNS01Solver: &certmagic.DNS01Solver{
|
DNS01Solver: &certmagic.DNS01Solver{
|
||||||
DNSManager: certmagic.DNSManager{
|
DNSManager: certmagic.DNSManager{
|
||||||
@@ -275,7 +254,7 @@ func (tm *tlsManager) initCertMagic() error {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
if acmeStaging {
|
if tm.config.ACMEStaging() {
|
||||||
acmeIssuer.CA = certmagic.LetsEncryptStagingCA
|
acmeIssuer.CA = certmagic.LetsEncryptStagingCA
|
||||||
log.Printf("Using Let's Encrypt staging server")
|
log.Printf("Using Let's Encrypt staging server")
|
||||||
} else {
|
} else {
|
||||||
@@ -286,7 +265,7 @@ func (tm *tlsManager) initCertMagic() error {
|
|||||||
magic.Issuers = []certmagic.Issuer{acmeIssuer}
|
magic.Issuers = []certmagic.Issuer{acmeIssuer}
|
||||||
tm.magic = magic
|
tm.magic = magic
|
||||||
|
|
||||||
domains := []string{tm.domain, "*." + tm.domain}
|
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
|
||||||
log.Printf("Requesting certificates for: %v", domains)
|
log.Printf("Requesting certificates for: %v", domains)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import (
|
|||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
"tunnel_pls/internal/config"
|
"tunnel_pls/internal/config"
|
||||||
@@ -21,6 +19,7 @@ import (
|
|||||||
"tunnel_pls/internal/transport"
|
"tunnel_pls/internal/transport"
|
||||||
"tunnel_pls/internal/version"
|
"tunnel_pls/internal/version"
|
||||||
"tunnel_pls/server"
|
"tunnel_pls/server"
|
||||||
|
"tunnel_pls/types"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
@@ -36,27 +35,12 @@ func main() {
|
|||||||
|
|
||||||
log.Printf("Starting %s", version.GetVersion())
|
log.Printf("Starting %s", version.GetVersion())
|
||||||
|
|
||||||
err := config.Load()
|
conf, err := config.MustLoad()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to load configuration: %s", err)
|
log.Fatalf("Failed to load configuration: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
mode := strings.ToLower(config.Getenv("MODE", "standalone"))
|
|
||||||
isNodeMode := mode == "node"
|
|
||||||
|
|
||||||
pprofEnabled := config.Getenv("PPROF_ENABLED", "false")
|
|
||||||
if pprofEnabled == "true" {
|
|
||||||
pprofPort := config.Getenv("PPROF_PORT", "6060")
|
|
||||||
go func() {
|
|
||||||
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
|
|
||||||
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
|
|
||||||
if err = http.ListenAndServe(pprofAddr, nil); err != nil {
|
|
||||||
log.Printf("pprof server error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
sshConfig := &ssh.ServerConfig{
|
sshConfig := &ssh.ServerConfig{
|
||||||
NoClientAuth: true,
|
NoClientAuth: true,
|
||||||
ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
|
ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
|
||||||
@@ -88,16 +72,11 @@ func main() {
|
|||||||
signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
|
signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
|
||||||
|
|
||||||
var grpcClient client.Client
|
var grpcClient client.Client
|
||||||
if isNodeMode {
|
|
||||||
grpcHost := config.Getenv("GRPC_ADDRESS", "localhost")
|
|
||||||
grpcPort := config.Getenv("GRPC_PORT", "8080")
|
|
||||||
grpcAddr := fmt.Sprintf("%s:%s", grpcHost, grpcPort)
|
|
||||||
nodeToken := config.Getenv("NODE_TOKEN", "")
|
|
||||||
if nodeToken == "" {
|
|
||||||
log.Fatalf("NODE_TOKEN is required in node mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
grpcClient, err = client.New(grpcAddr, sessionRegistry)
|
if conf.Mode() == types.ServerModeNODE {
|
||||||
|
grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort())
|
||||||
|
|
||||||
|
grpcClient, err = client.New(conf, grpcAddr, sessionRegistry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create grpc client: %v", err)
|
log.Fatalf("failed to create grpc client: %v", err)
|
||||||
}
|
}
|
||||||
@@ -110,46 +89,15 @@ func main() {
|
|||||||
healthCancel()
|
healthCancel()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
identity := config.Getenv("DOMAIN", "localhost")
|
if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
|
||||||
if err = grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil {
|
|
||||||
errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
|
errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
portManager := port.New()
|
|
||||||
rawRange := config.Getenv("ALLOWED_PORTS", "")
|
|
||||||
if rawRange != "" {
|
|
||||||
splitRange := strings.Split(rawRange, "-")
|
|
||||||
if len(splitRange) == 2 {
|
|
||||||
var start, end uint64
|
|
||||||
start, err = strconv.ParseUint(splitRange[0], 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to parse start port: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
end, err = strconv.ParseUint(splitRange[1], 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to parse end port: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = portManager.AddRange(uint16(start), uint16(end)); err != nil {
|
|
||||||
log.Fatalf("Failed to add port range: %s", err)
|
|
||||||
}
|
|
||||||
log.Printf("PortRegistry range configured: %d-%d", start, end)
|
|
||||||
} else {
|
|
||||||
log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsEnabled := config.Getenv("TLS_ENABLED", "false") == "true"
|
|
||||||
redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true"
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
httpPort := config.Getenv("HTTP_PORT", "8080")
|
|
||||||
|
|
||||||
var httpListener net.Listener
|
var httpListener net.Listener
|
||||||
httpserver := transport.NewHTTPServer(httpPort, sessionRegistry, redirectTLS)
|
httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), sessionRegistry, conf.TLSRedirect())
|
||||||
httpListener, err = httpserver.Listen()
|
httpListener, err = httpserver.Listen()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("failed to start http server: %w", err)
|
errChan <- fmt.Errorf("failed to start http server: %w", err)
|
||||||
@@ -162,19 +110,17 @@ func main() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if tlsEnabled {
|
if conf.TLSEnabled() {
|
||||||
go func() {
|
go func() {
|
||||||
httpsPort := config.Getenv("HTTPS_PORT", "8443")
|
var httpsListener net.Listener
|
||||||
domain := config.Getenv("DOMAIN", "localhost")
|
tlsConfig, _ := transport.NewTLSConfig(conf)
|
||||||
|
httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), sessionRegistry, conf.TLSRedirect(), tlsConfig)
|
||||||
var httpListener net.Listener
|
httpsListener, err = httpsServer.Listen()
|
||||||
httpserver := transport.NewHTTPSServer(domain, httpsPort, sessionRegistry, redirectTLS)
|
|
||||||
httpListener, err = httpserver.Listen()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("failed to start http server: %w", err)
|
errChan <- fmt.Errorf("failed to start http server: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = httpserver.Serve(httpListener)
|
err = httpsServer.Serve(httpsListener)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("error when serving http server: %w", err)
|
errChan <- fmt.Errorf("error when serving http server: %w", err)
|
||||||
return
|
return
|
||||||
@@ -182,17 +128,34 @@ func main() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
portManager := port.New()
|
||||||
|
err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd())
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to initialize port manager: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var app server.Server
|
var app server.Server
|
||||||
go func() {
|
go func() {
|
||||||
sshPort := config.Getenv("PORT", "2200")
|
app, err = server.New(conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort())
|
||||||
app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager, sshPort)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("failed to start server: %s", err)
|
errChan <- fmt.Errorf("failed to start server: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
app.Start()
|
app.Start()
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
if conf.PprofEnabled() {
|
||||||
|
go func() {
|
||||||
|
pprofAddr := fmt.Sprintf("localhost:%s", conf.PprofPort())
|
||||||
|
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
|
||||||
|
if err = http.ListenAndServe(pprofAddr, nil); err != nil {
|
||||||
|
log.Printf("pprof server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case err = <-errChan:
|
case err = <-errChan:
|
||||||
log.Printf("error happen : %s", err)
|
log.Printf("error happen : %s", err)
|
||||||
|
|||||||
+8
-5
@@ -7,6 +7,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
"tunnel_pls/internal/config"
|
||||||
"tunnel_pls/internal/grpc/client"
|
"tunnel_pls/internal/grpc/client"
|
||||||
"tunnel_pls/internal/port"
|
"tunnel_pls/internal/port"
|
||||||
"tunnel_pls/internal/registry"
|
"tunnel_pls/internal/registry"
|
||||||
@@ -20,24 +21,26 @@ type Server interface {
|
|||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
type server struct {
|
type server struct {
|
||||||
|
config config.Config
|
||||||
sshPort string
|
sshPort string
|
||||||
sshListener net.Listener
|
sshListener net.Listener
|
||||||
config *ssh.ServerConfig
|
sshConfig *ssh.ServerConfig
|
||||||
grpcClient client.Client
|
grpcClient client.Client
|
||||||
sessionRegistry registry.Registry
|
sessionRegistry registry.Registry
|
||||||
portRegistry port.Port
|
portRegistry port.Port
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
|
func New(config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
|
||||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
|
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &server{
|
return &server{
|
||||||
|
config: config,
|
||||||
sshPort: sshPort,
|
sshPort: sshPort,
|
||||||
sshListener: listener,
|
sshListener: listener,
|
||||||
config: sshConfig,
|
sshConfig: sshConfig,
|
||||||
grpcClient: grpcClient,
|
grpcClient: grpcClient,
|
||||||
sessionRegistry: sessionRegistry,
|
sessionRegistry: sessionRegistry,
|
||||||
portRegistry: portRegistry,
|
portRegistry: portRegistry,
|
||||||
@@ -66,7 +69,7 @@ func (s *server) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) handleConnection(conn net.Conn) {
|
func (s *server) handleConnection(conn net.Conn) {
|
||||||
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config)
|
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.sshConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to establish SSH connection: %v", err)
|
log.Printf("failed to establish SSH connection: %v", err)
|
||||||
err = conn.Close()
|
err = conn.Close()
|
||||||
@@ -92,7 +95,7 @@ func (s *server) handleConnection(conn net.Conn) {
|
|||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
log.Println("SSH connection established:", sshConn.User())
|
log.Println("SSH connection established:", sshConn.User())
|
||||||
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
|
sshSession := session.New(s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
|
||||||
err = sshSession.Start()
|
err = sshSession.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("SSH session ended with error: %v", err)
|
log.Printf("SSH session ended with error: %v", err)
|
||||||
|
|||||||
@@ -18,37 +18,6 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
var bufferPool = sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
bufSize := config.GetBufferSize()
|
|
||||||
return make([]byte, bufSize)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
|
||||||
buf := bufferPool.Get().([]byte)
|
|
||||||
defer bufferPool.Put(buf)
|
|
||||||
return io.CopyBuffer(dst, src, buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
type forwarder struct {
|
|
||||||
listener net.Listener
|
|
||||||
tunnelType types.TunnelType
|
|
||||||
forwardedPort uint16
|
|
||||||
slug slug.Slug
|
|
||||||
conn ssh.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(slug slug.Slug, conn ssh.Conn) Forwarder {
|
|
||||||
return &forwarder{
|
|
||||||
listener: nil,
|
|
||||||
tunnelType: types.UNKNOWN,
|
|
||||||
forwardedPort: 0,
|
|
||||||
slug: slug,
|
|
||||||
conn: conn,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Forwarder interface {
|
type Forwarder interface {
|
||||||
SetType(tunnelType types.TunnelType)
|
SetType(tunnelType types.TunnelType)
|
||||||
SetForwardedPort(port uint16)
|
SetForwardedPort(port uint16)
|
||||||
@@ -62,6 +31,36 @@ type Forwarder interface {
|
|||||||
WriteBadGatewayResponse(dst io.Writer)
|
WriteBadGatewayResponse(dst io.Writer)
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
type forwarder struct {
|
||||||
|
listener net.Listener
|
||||||
|
tunnelType types.TunnelType
|
||||||
|
forwardedPort uint16
|
||||||
|
slug slug.Slug
|
||||||
|
conn ssh.Conn
|
||||||
|
bufferPool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
|
||||||
|
return &forwarder{
|
||||||
|
listener: nil,
|
||||||
|
tunnelType: types.TunnelTypeUNKNOWN,
|
||||||
|
forwardedPort: 0,
|
||||||
|
slug: slug,
|
||||||
|
conn: conn,
|
||||||
|
bufferPool: sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
bufSize := config.BufferSize()
|
||||||
|
return make([]byte, bufSize)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
||||||
|
buf := f.bufferPool.Get().([]byte)
|
||||||
|
defer f.bufferPool.Put(buf)
|
||||||
|
return io.CopyBuffer(dst, src, buf)
|
||||||
|
}
|
||||||
|
|
||||||
func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||||
type channelResult struct {
|
type channelResult struct {
|
||||||
@@ -107,7 +106,7 @@ func closeWriter(w io.Writer) error {
|
|||||||
|
|
||||||
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
|
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
|
||||||
var errs []error
|
var errs []error
|
||||||
_, err := copyWithBuffer(dst, src)
|
_, err := f.copyWithBuffer(dst, src)
|
||||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||||
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
|
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,9 +18,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Interaction interface {
|
type Interaction interface {
|
||||||
Mode() types.Mode
|
Mode() types.InteractiveMode
|
||||||
SetChannel(channel ssh.Channel)
|
SetChannel(channel ssh.Channel)
|
||||||
SetMode(m types.Mode)
|
SetMode(m types.InteractiveMode)
|
||||||
SetWH(w, h int)
|
SetWH(w, h int)
|
||||||
Start()
|
Start()
|
||||||
Redraw()
|
Redraw()
|
||||||
@@ -39,6 +39,7 @@ type Forwarder interface {
|
|||||||
|
|
||||||
type CloseFunc func() error
|
type CloseFunc func() error
|
||||||
type interaction struct {
|
type interaction struct {
|
||||||
|
config config.Config
|
||||||
channel ssh.Channel
|
channel ssh.Channel
|
||||||
slug slug.Slug
|
slug slug.Slug
|
||||||
forwarder Forwarder
|
forwarder Forwarder
|
||||||
@@ -48,14 +49,14 @@ type interaction struct {
|
|||||||
program *tea.Program
|
program *tea.Program
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
mode types.Mode
|
mode types.InteractiveMode
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *interaction) SetMode(m types.Mode) {
|
func (i *interaction) SetMode(m types.InteractiveMode) {
|
||||||
i.mode = m
|
i.mode = m
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *interaction) Mode() types.Mode {
|
func (i *interaction) Mode() types.InteractiveMode {
|
||||||
return i.mode
|
return i.mode
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,9 +76,10 @@ func (i *interaction) SetWH(w, h int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
|
func New(config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &interaction{
|
return &interaction{
|
||||||
|
config: config,
|
||||||
channel: nil,
|
channel: nil,
|
||||||
slug: slug,
|
slug: slug,
|
||||||
forwarder: forwarder,
|
forwarder: forwarder,
|
||||||
@@ -174,14 +176,13 @@ func (m *model) View() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *interaction) Start() {
|
func (i *interaction) Start() {
|
||||||
if i.mode == types.HEADLESS {
|
if i.mode == types.InteractiveModeHEADLESS {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lipgloss.SetColorProfile(termenv.TrueColor)
|
lipgloss.SetColorProfile(termenv.TrueColor)
|
||||||
|
|
||||||
domain := config.Getenv("DOMAIN", "localhost")
|
|
||||||
protocol := "http"
|
protocol := "http"
|
||||||
if config.Getenv("TLS_ENABLED", "false") == "true" {
|
if i.config.TLSEnabled() {
|
||||||
protocol = "https"
|
protocol = "https"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -209,7 +210,7 @@ func (i *interaction) Start() {
|
|||||||
ti.Width = 50
|
ti.Width = 50
|
||||||
|
|
||||||
m := &model{
|
m := &model{
|
||||||
domain: domain,
|
domain: i.config.Domain(),
|
||||||
protocol: protocol,
|
protocol: protocol,
|
||||||
tunnelType: tunnelType,
|
tunnelType: tunnelType,
|
||||||
port: port,
|
port: port,
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ type model struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *model) getTunnelURL() string {
|
func (m *model) getTunnelURL() string {
|
||||||
if m.tunnelType == types.HTTP {
|
if m.tunnelType == types.TunnelTypeHTTP {
|
||||||
return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
|
return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
|
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||||
var cmd tea.Cmd
|
var cmd tea.Cmd
|
||||||
|
|
||||||
if m.tunnelType != types.HTTP {
|
if m.tunnelType != types.TunnelTypeHTTP {
|
||||||
m.editingSlug = false
|
m.editingSlug = false
|
||||||
m.slugError = ""
|
m.slugError = ""
|
||||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||||
@@ -30,10 +30,10 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
|||||||
inputValue := m.slugInput.Value()
|
inputValue := m.slugInput.Value()
|
||||||
if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{
|
if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{
|
||||||
Id: m.interaction.slug.String(),
|
Id: m.interaction.slug.String(),
|
||||||
Type: types.HTTP,
|
Type: types.TunnelTypeHTTP,
|
||||||
}, types.SessionKey{
|
}, types.SessionKey{
|
||||||
Id: inputValue,
|
Id: inputValue,
|
||||||
Type: types.HTTP,
|
Type: types.TunnelTypeHTTP,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
m.slugError = err.Error()
|
m.slugError = err.Error()
|
||||||
return m, nil
|
return m, nil
|
||||||
@@ -130,7 +130,7 @@ func (m *model) slugView() string {
|
|||||||
b.WriteString(titleStyle.Render(title))
|
b.WriteString(titleStyle.Render(title))
|
||||||
b.WriteString("\n\n")
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
if m.tunnelType != types.HTTP {
|
if m.tunnelType != types.TunnelTypeHTTP {
|
||||||
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
||||||
warningBoxStyle := lipgloss.NewStyle().
|
warningBoxStyle := lipgloss.NewStyle().
|
||||||
Foreground(lipgloss.Color("#FFA500")).
|
Foreground(lipgloss.Color("#FFA500")).
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ package lifecycle
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
portUtil "tunnel_pls/internal/port"
|
portUtil "tunnel_pls/internal/port"
|
||||||
@@ -24,7 +22,7 @@ type SessionRegistry interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type lifecycle struct {
|
type lifecycle struct {
|
||||||
status types.Status
|
status types.SessionStatus
|
||||||
conn ssh.Conn
|
conn ssh.Conn
|
||||||
channel ssh.Channel
|
channel ssh.Channel
|
||||||
forwarder Forwarder
|
forwarder Forwarder
|
||||||
@@ -37,7 +35,7 @@ type lifecycle struct {
|
|||||||
|
|
||||||
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
|
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
|
||||||
return &lifecycle{
|
return &lifecycle{
|
||||||
status: types.INITIALIZING,
|
status: types.SessionStatusINITIALIZING,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
channel: nil,
|
channel: nil,
|
||||||
forwarder: forwarder,
|
forwarder: forwarder,
|
||||||
@@ -54,7 +52,7 @@ type Lifecycle interface {
|
|||||||
PortRegistry() portUtil.Port
|
PortRegistry() portUtil.Port
|
||||||
User() string
|
User() string
|
||||||
SetChannel(channel ssh.Channel)
|
SetChannel(channel ssh.Channel)
|
||||||
SetStatus(status types.Status)
|
SetStatus(status types.SessionStatus)
|
||||||
IsActive() bool
|
IsActive() bool
|
||||||
StartedAt() time.Time
|
StartedAt() time.Time
|
||||||
Close() error
|
Close() error
|
||||||
@@ -74,35 +72,30 @@ func (l *lifecycle) SetChannel(channel ssh.Channel) {
|
|||||||
func (l *lifecycle) Connection() ssh.Conn {
|
func (l *lifecycle) Connection() ssh.Conn {
|
||||||
return l.conn
|
return l.conn
|
||||||
}
|
}
|
||||||
func (l *lifecycle) SetStatus(status types.Status) {
|
func (l *lifecycle) SetStatus(status types.SessionStatus) {
|
||||||
l.status = status
|
l.status = status
|
||||||
if status == types.RUNNING && l.startedAt.IsZero() {
|
if status == types.SessionStatusRUNNING && l.startedAt.IsZero() {
|
||||||
l.startedAt = time.Now()
|
l.startedAt = time.Now()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func closeIfNotNil(c interface{ Close() error }) error {
|
||||||
|
if c != nil {
|
||||||
|
return c.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (l *lifecycle) Close() error {
|
func (l *lifecycle) Close() error {
|
||||||
var firstErr error
|
var errs []error
|
||||||
tunnelType := l.forwarder.TunnelType()
|
tunnelType := l.forwarder.TunnelType()
|
||||||
|
|
||||||
if err := l.forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
if err := closeIfNotNil(l.channel); err != nil {
|
||||||
firstErr = err
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if l.channel != nil {
|
if err := closeIfNotNil(l.conn); err != nil {
|
||||||
if err := l.channel.Close(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
errs = append(errs, err)
|
||||||
if firstErr == nil {
|
|
||||||
firstErr = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if l.conn != nil {
|
|
||||||
if err := l.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
|
||||||
if firstErr == nil {
|
|
||||||
firstErr = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
clientSlug := l.slug.String()
|
clientSlug := l.slug.String()
|
||||||
@@ -112,17 +105,20 @@ func (l *lifecycle) Close() error {
|
|||||||
}
|
}
|
||||||
l.sessionRegistry.Remove(key)
|
l.sessionRegistry.Remove(key)
|
||||||
|
|
||||||
if tunnelType == types.TCP {
|
if tunnelType == types.TunnelTypeTCP {
|
||||||
if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
|
if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil {
|
||||||
firstErr = err
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
if err := l.forwarder.Close(); err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return firstErr
|
return errors.Join(errs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *lifecycle) IsActive() bool {
|
func (l *lifecycle) IsActive() bool {
|
||||||
return l.status == types.RUNNING
|
return l.status == types.SessionStatusRUNNING
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *lifecycle) StartedAt() time.Time {
|
func (l *lifecycle) StartedAt() time.Time {
|
||||||
|
|||||||
+19
-17
@@ -37,6 +37,7 @@ type Session interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type session struct {
|
type session struct {
|
||||||
|
config config.Config
|
||||||
initialReq <-chan *ssh.Request
|
initialReq <-chan *ssh.Request
|
||||||
sshChan <-chan ssh.NewChannel
|
sshChan <-chan ssh.NewChannel
|
||||||
lifecycle lifecycle.Lifecycle
|
lifecycle lifecycle.Lifecycle
|
||||||
@@ -48,13 +49,14 @@ type session struct {
|
|||||||
|
|
||||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||||
|
|
||||||
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
|
func New(config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
|
||||||
slugManager := slug.New()
|
slugManager := slug.New()
|
||||||
forwarderManager := forwarder.New(slugManager, conn)
|
forwarderManager := forwarder.New(config, slugManager, conn)
|
||||||
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
|
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
|
||||||
interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
|
interactionManager := interaction.New(config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
|
||||||
|
|
||||||
return &session{
|
return &session{
|
||||||
|
config: config,
|
||||||
initialReq: initialReq,
|
initialReq: initialReq,
|
||||||
sshChan: sshChan,
|
sshChan: sshChan,
|
||||||
lifecycle: lifecycleManager,
|
lifecycle: lifecycleManager,
|
||||||
@@ -83,12 +85,12 @@ func (s *session) Slug() slug.Slug {
|
|||||||
|
|
||||||
func (s *session) Detail() *types.Detail {
|
func (s *session) Detail() *types.Detail {
|
||||||
tunnelTypeMap := map[types.TunnelType]string{
|
tunnelTypeMap := map[types.TunnelType]string{
|
||||||
types.HTTP: "HTTP",
|
types.TunnelTypeHTTP: "TunnelTypeHTTP",
|
||||||
types.TCP: "TCP",
|
types.TunnelTypeTCP: "TunnelTypeTCP",
|
||||||
}
|
}
|
||||||
tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
|
tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
|
||||||
if !ok {
|
if !ok {
|
||||||
tunnelType = "UNKNOWN"
|
tunnelType = "TunnelTypeUNKNOWN"
|
||||||
}
|
}
|
||||||
|
|
||||||
return &types.Detail{
|
return &types.Detail{
|
||||||
@@ -131,7 +133,7 @@ func (s *session) setupSessionMode() error {
|
|||||||
}
|
}
|
||||||
return s.setupInteractiveMode(channel)
|
return s.setupInteractiveMode(channel)
|
||||||
case <-time.After(500 * time.Millisecond):
|
case <-time.After(500 * time.Millisecond):
|
||||||
s.interaction.SetMode(types.HEADLESS)
|
s.interaction.SetMode(types.InteractiveModeHEADLESS)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -152,13 +154,13 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
|
|||||||
|
|
||||||
s.lifecycle.SetChannel(ch)
|
s.lifecycle.SetChannel(ch)
|
||||||
s.interaction.SetChannel(ch)
|
s.interaction.SetChannel(ch)
|
||||||
s.interaction.SetMode(types.INTERACTIVE)
|
s.interaction.SetMode(types.InteractiveModeINTERACTIVE)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) handleMissingForwardRequest() error {
|
func (s *session) handleMissingForwardRequest() error {
|
||||||
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
|
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -169,8 +171,8 @@ func (s *session) handleMissingForwardRequest() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) shouldRejectUnauthorized() bool {
|
func (s *session) shouldRejectUnauthorized() bool {
|
||||||
return s.interaction.Mode() == types.HEADLESS &&
|
return s.interaction.Mode() == types.InteractiveModeHEADLESS &&
|
||||||
config.Getenv("MODE", "standalone") == "standalone" &&
|
s.config.Mode() == types.ServerModeSTANDALONE &&
|
||||||
s.lifecycle.User() == "UNAUTHORIZED"
|
s.lifecycle.User() == "UNAUTHORIZED"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,7 +320,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen
|
|||||||
s.forwarder.SetType(tunnelType)
|
s.forwarder.SetType(tunnelType)
|
||||||
s.forwarder.SetForwardedPort(portToBind)
|
s.forwarder.SetForwardedPort(portToBind)
|
||||||
s.slug.Set(slug)
|
s.slug.Set(slug)
|
||||||
s.lifecycle.SetStatus(types.RUNNING)
|
s.lifecycle.SetStatus(types.SessionStatusRUNNING)
|
||||||
|
|
||||||
if listener != nil {
|
if listener != nil {
|
||||||
s.forwarder.SetListener(listener)
|
s.forwarder.SetListener(listener)
|
||||||
@@ -348,12 +350,12 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
|
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
|
||||||
}
|
}
|
||||||
key := types.SessionKey{Id: randomString, Type: types.HTTP}
|
key := types.SessionKey{Id: randomString, Type: types.TunnelTypeHTTP}
|
||||||
if !s.registry.Register(key, s) {
|
if !s.registry.Register(key, s) {
|
||||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString))
|
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString))
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.finalizeForwarding(req, portToBind, nil, types.HTTP, key.Id)
|
err = s.finalizeForwarding(req, portToBind, nil, types.TunnelTypeHTTP, key.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err))
|
return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err))
|
||||||
}
|
}
|
||||||
@@ -371,12 +373,12 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
|||||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
|
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
|
||||||
}
|
}
|
||||||
|
|
||||||
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP}
|
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
|
||||||
if !s.registry.Register(key, s) {
|
if !s.registry.Register(key, s) {
|
||||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TCP client with id: %s", key.Id))
|
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TunnelTypeTCP client with id: %s", key.Id))
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.finalizeForwarding(req, portToBind, listener, types.TCP, key.Id)
|
err = s.finalizeForwarding(req, portToBind, listener, types.TunnelTypeTCP, key.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
|
return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
sonar.projectKey=tunnel-please
|
||||||
+16
-9
@@ -2,26 +2,33 @@ package types
|
|||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
type Status int
|
type SessionStatus int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
INITIALIZING Status = iota
|
SessionStatusINITIALIZING SessionStatus = iota
|
||||||
RUNNING
|
SessionStatusRUNNING
|
||||||
)
|
)
|
||||||
|
|
||||||
type Mode int
|
type InteractiveMode int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
INTERACTIVE Mode = iota
|
InteractiveModeINTERACTIVE InteractiveMode = iota + 1
|
||||||
HEADLESS
|
InteractiveModeHEADLESS
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunnelType int
|
type TunnelType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
UNKNOWN TunnelType = iota
|
TunnelTypeUNKNOWN TunnelType = iota
|
||||||
HTTP
|
TunnelTypeHTTP
|
||||||
TCP
|
TunnelTypeTCP
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServerMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ServerModeSTANDALONE = iota + 1
|
||||||
|
ServerModeNODE
|
||||||
)
|
)
|
||||||
|
|
||||||
type SessionKey struct {
|
type SessionKey struct {
|
||||||
|
|||||||
Reference in New Issue
Block a user