fix(deps): update module github.com/caddyserver/certmagic to v0.25.1 - autoclosed #63

Closed
Renovate-Clanker wants to merge 113 commits from renovate/github.com-caddyserver-certmagic-0.x into main
38 changed files with 1763 additions and 1397 deletions
Showing only changes of commit d91eecb2a0 - Show all commits
+20
View File
@@ -0,0 +1,20 @@
on:
push:
pull_request:
types: [opened, synchronize, reopened]
name: SonarQube Scan
jobs:
sonarqube:
name: SonarQube Trigger
runs-on: ubuntu-latest
steps:
- name: Checking out
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: SonarQube Scan
uses: SonarSource/sonarqube-scan-action@v7.0.0
env:
SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
+53 -23
View File
@@ -1,33 +1,63 @@
package config package config
import ( import "tunnel_pls/types"
"os"
"strconv"
"github.com/joho/godotenv" type Config interface {
) Domain() string
SSHPort() string
func Load() error { HTTPPort() string
if _, err := os.Stat(".env"); err == nil { HTTPSPort() string
return godotenv.Load(".env")
} TLSEnabled() bool
return nil TLSRedirect() bool
ACMEEmail() string
CFAPIToken() string
ACMEStaging() bool
AllowedPortsStart() uint16
AllowedPortsEnd() uint16
BufferSize() int
PprofEnabled() bool
PprofPort() string
Mode() types.ServerMode
GRPCAddress() string
GRPCPort() string
NodeToken() string
} }
func Getenv(key, defaultValue string) string { func MustLoad() (Config, error) {
val := os.Getenv(key) if err := loadEnvFile(); err != nil {
if val == "" { return nil, err
val = defaultValue
} }
return val cfg, err := parse()
if err != nil {
return nil, err
}
return cfg, nil
} }
func GetBufferSize() int { func (c *config) Domain() string { return c.domain }
sizeStr := Getenv("BUFFER_SIZE", "32768") func (c *config) SSHPort() string { return c.sshPort }
size, err := strconv.Atoi(sizeStr) func (c *config) HTTPPort() string { return c.httpPort }
if err != nil || size < 4096 || size > 1048576 { func (c *config) HTTPSPort() string { return c.httpsPort }
return 32768 func (c *config) TLSEnabled() bool { return c.tlsEnabled }
} func (c *config) TLSRedirect() bool { return c.tlsRedirect }
return size func (c *config) ACMEEmail() string { return c.acmeEmail }
} func (c *config) CFAPIToken() string { return c.cfAPIToken }
func (c *config) ACMEStaging() bool { return c.acmeStaging }
func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart }
func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd }
func (c *config) BufferSize() int { return c.bufferSize }
func (c *config) PprofEnabled() bool { return c.pprofEnabled }
func (c *config) PprofPort() string { return c.pprofPort }
func (c *config) Mode() types.ServerMode { return c.mode }
func (c *config) GRPCAddress() string { return c.grpcAddress }
func (c *config) GRPCPort() string { return c.grpcPort }
func (c *config) NodeToken() string { return c.nodeToken }
+170
View File
@@ -0,0 +1,170 @@
package config
import (
"fmt"
"log"
"os"
"strconv"
"strings"
"tunnel_pls/types"
"github.com/joho/godotenv"
)
type config struct {
domain string
sshPort string
httpPort string
httpsPort string
tlsEnabled bool
tlsRedirect bool
acmeEmail string
cfAPIToken string
acmeStaging bool
allowedPortsStart uint16
allowedPortsEnd uint16
bufferSize int
pprofEnabled bool
pprofPort string
mode types.ServerMode
grpcAddress string
grpcPort string
nodeToken string
}
func parse() (*config, error) {
mode, err := parseMode()
if err != nil {
return nil, err
}
domain := getenv("DOMAIN", "localhost")
sshPort := getenv("PORT", "2200")
httpPort := getenv("HTTP_PORT", "8080")
httpsPort := getenv("HTTPS_PORT", "8443")
tlsEnabled := getenvBool("TLS_ENABLED", false)
tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false)
acmeEmail := getenv("ACME_EMAIL", "admin@"+domain)
acmeStaging := getenvBool("ACME_STAGING", false)
cfToken := getenv("CF_API_TOKEN", "")
if tlsEnabled && cfToken == "" {
return nil, fmt.Errorf("CF_API_TOKEN is required when TLS is enabled")
}
start, end, err := parseAllowedPorts()
if err != nil {
return nil, err
}
bufferSize := parseBufferSize()
pprofEnabled := getenvBool("PPROF_ENABLED", false)
pprofPort := getenv("PPROF_PORT", "6060")
grpcHost := getenv("GRPC_ADDRESS", "localhost")
grpcPort := getenv("GRPC_PORT", "8080")
nodeToken := getenv("NODE_TOKEN", "")
if mode == types.ServerModeNODE && nodeToken == "" {
return nil, fmt.Errorf("NODE_TOKEN is required in node mode")
}
return &config{
domain: domain,
sshPort: sshPort,
httpPort: httpPort,
httpsPort: httpsPort,
tlsEnabled: tlsEnabled,
tlsRedirect: tlsRedirect,
acmeEmail: acmeEmail,
cfAPIToken: cfToken,
acmeStaging: acmeStaging,
allowedPortsStart: start,
allowedPortsEnd: end,
bufferSize: bufferSize,
pprofEnabled: pprofEnabled,
pprofPort: pprofPort,
mode: mode,
grpcAddress: grpcHost,
grpcPort: grpcPort,
nodeToken: nodeToken,
}, nil
}
func loadEnvFile() error {
if _, err := os.Stat(".env"); err == nil {
return godotenv.Load(".env")
}
return nil
}
func parseMode() (types.ServerMode, error) {
switch strings.ToLower(getenv("MODE", "standalone")) {
case "standalone":
return types.ServerModeSTANDALONE, nil
case "node":
return types.ServerModeNODE, nil
default:
return 0, fmt.Errorf("invalid MODE value")
}
}
func parseAllowedPorts() (uint16, uint16, error) {
raw := getenv("ALLOWED_PORTS", "")
if raw == "" {
return 0, 0, nil
}
parts := strings.Split(raw, "-")
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid ALLOWED_PORTS format")
}
start, err := strconv.ParseUint(parts[0], 10, 16)
if err != nil {
return 0, 0, err
}
end, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return 0, 0, err
}
return uint16(start), uint16(end), nil
}
func parseBufferSize() int {
raw := getenv("BUFFER_SIZE", "32768")
size, err := strconv.Atoi(raw)
if err != nil || size < 4096 || size > 1048576 {
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
return 4096
}
return size
}
func getenv(key, def string) string {
if v := os.Getenv(key); v != "" {
return v
}
return def
}
func getenvBool(key string, def bool) bool {
val := os.Getenv(key)
if val == "" {
return def
}
return val == "true"
}
+11 -10
View File
@@ -8,10 +8,9 @@ import (
"log" "log"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/registry"
"tunnel_pls/types" "tunnel_pls/types"
"tunnel_pls/session"
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@@ -30,15 +29,16 @@ type Client interface {
CheckServerHealth(ctx context.Context) error CheckServerHealth(ctx context.Context) error
} }
type client struct { type client struct {
config config.Config
conn *grpc.ClientConn conn *grpc.ClientConn
address string address string
sessionRegistry session.Registry sessionRegistry registry.Registry
eventService proto.EventServiceClient eventService proto.EventServiceClient
authorizeConnectionService proto.UserServiceClient authorizeConnectionService proto.UserServiceClient
closing bool closing bool
} }
func New(address string, sessionRegistry session.Registry) (Client, error) { func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) {
var opts []grpc.DialOption var opts []grpc.DialOption
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
@@ -67,6 +67,7 @@ func New(address string, sessionRegistry session.Registry) (Client, error) {
authorizeConnectionService := proto.NewUserServiceClient(conn) authorizeConnectionService := proto.NewUserServiceClient(conn)
return &client{ return &client{
config: config,
conn: conn, conn: conn,
address: address, address: address,
sessionRegistry: sessionRegistry, sessionRegistry: sessionRegistry,
@@ -193,7 +194,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
oldSlug := slugEvent.GetOld() oldSlug := slugEvent.GetOld()
newSlug := slugEvent.GetNew() newSlug := slugEvent.GetNew()
userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.HTTP}) userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP})
if err != nil { if err != nil {
return c.sendNode(subscribe, &proto.Node{ return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE, Type: proto.EventType_SLUG_CHANGE_RESPONSE,
@@ -203,7 +204,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
}, "slug change failure response") }, "slug change failure response")
} }
if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.HTTP}, types.SessionKey{Id: newSlug, Type: types.HTTP}); err != nil { if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}, types.SessionKey{Id: newSlug, Type: types.TunnelTypeHTTP}); err != nil {
return c.sendNode(subscribe, &proto.Node{ return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE, Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{ Payload: &proto.Node_SlugEventResponse{
@@ -228,7 +229,7 @@ func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node
for _, ses := range sessions { for _, ses := range sessions {
detail := ses.Detail() detail := ses.Detail()
details = append(details, &proto.Detail{ details = append(details, &proto.Detail{
Node: config.Getenv("DOMAIN", "localhost"), Node: c.config.Domain(),
ForwardingType: detail.ForwardingType, ForwardingType: detail.ForwardingType,
Slug: detail.Slug, Slug: detail.Slug,
UserId: detail.UserID, UserId: detail.UserID,
@@ -300,11 +301,11 @@ func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.E
func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) { func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) {
switch t { switch t {
case proto.TunnelType_HTTP: case proto.TunnelType_HTTP:
return types.HTTP, nil return types.TunnelTypeHTTP, nil
case proto.TunnelType_TCP: case proto.TunnelType_TCP:
return types.TCP, nil return types.TunnelTypeTCP, nil
default: default:
return types.UNKNOWN, fmt.Errorf("unknown tunnel type received") return types.TunnelTypeUNKNOWN, fmt.Errorf("unknown tunnel type received")
} }
} }
+30
View File
@@ -0,0 +1,30 @@
package header
type ResponseHeader interface {
Value(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
}
type responseHeader struct {
startLine []byte
headers map[string]string
}
type RequestHeader interface {
Value(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
Method() string
Path() string
Version() string
}
type requestHeader struct {
method string
path string
version string
startLine []byte
headers map[string]string
}
+148
View File
@@ -0,0 +1,148 @@
package header
import (
"bufio"
"bytes"
"fmt"
)
func setRemainingHeaders(remaining []byte, header interface {
Set(key string, value string)
}) {
for len(remaining) > 0 {
lineEnd := bytes.Index(remaining, []byte("\r\n"))
if lineEnd == -1 {
lineEnd = len(remaining)
}
line := remaining[:lineEnd]
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx != -1 {
key := bytes.TrimSpace(line[:colonIdx])
value := bytes.TrimSpace(line[colonIdx+1:])
header.Set(string(key), string(value))
}
if lineEnd == len(remaining) {
break
}
remaining = remaining[lineEnd+2:]
}
}
func parseHeadersFromBytes(headerData []byte) (RequestHeader, error) {
header := &requestHeader{
headers: make(map[string]string, 16),
}
lineEnd := bytes.Index(headerData, []byte("\r\n"))
if lineEnd == -1 {
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
}
startLine := headerData[:lineEnd]
header.startLine = startLine
var err error
header.method, header.path, header.version, err = parseStartLine(startLine)
if err != nil {
return nil, err
}
remaining := headerData[lineEnd+2:]
setRemainingHeaders(remaining, header)
return header, nil
}
func parseStartLine(startLine []byte) (method, path, version string, err error) {
firstSpace := bytes.IndexByte(startLine, ' ')
if firstSpace == -1 {
return "", "", "", fmt.Errorf("invalid start line: missing method")
}
secondSpace := bytes.IndexByte(startLine[firstSpace+1:], ' ')
if secondSpace == -1 {
return "", "", "", fmt.Errorf("invalid start line: missing version")
}
secondSpace += firstSpace + 1
method = string(startLine[:firstSpace])
path = string(startLine[firstSpace+1 : secondSpace])
version = string(startLine[secondSpace+1:])
return method, path, version, nil
}
func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) {
header := &requestHeader{
headers: make(map[string]string, 16),
}
startLineBytes, err := br.ReadSlice('\n')
if err != nil {
return nil, err
}
startLineBytes = bytes.TrimRight(startLineBytes, "\r\n")
header.startLine = make([]byte, len(startLineBytes))
copy(header.startLine, startLineBytes)
header.method, header.path, header.version, err = parseStartLine(header.startLine)
if err != nil {
return nil, err
}
for {
lineBytes, err := br.ReadSlice('\n')
if err != nil {
return nil, err
}
lineBytes = bytes.TrimRight(lineBytes, "\r\n")
if len(lineBytes) == 0 {
break
}
colonIdx := bytes.IndexByte(lineBytes, ':')
if colonIdx == -1 {
continue
}
key := bytes.TrimSpace(lineBytes[:colonIdx])
value := bytes.TrimSpace(lineBytes[colonIdx+1:])
header.headers[string(key)] = string(value)
}
return header, nil
}
func finalize(startLine []byte, headers map[string]string) []byte {
size := len(startLine) + 2
for key, val := range headers {
size += len(key) + 2 + len(val) + 2
}
size += 2
buf := make([]byte, 0, size)
buf = append(buf, startLine...)
buf = append(buf, '\r', '\n')
for key, val := range headers {
buf = append(buf, key...)
buf = append(buf, ':', ' ')
buf = append(buf, val...)
buf = append(buf, '\r', '\n')
}
buf = append(buf, '\r', '\n')
return buf
}
+49
View File
@@ -0,0 +1,49 @@
package header
import (
"bufio"
"fmt"
)
func NewRequest(r interface{}) (RequestHeader, error) {
switch v := r.(type) {
case []byte:
return parseHeadersFromBytes(v)
case *bufio.Reader:
return parseHeadersFromReader(v)
default:
return nil, fmt.Errorf("unsupported type: %T", r)
}
}
func (req *requestHeader) Value(key string) string {
val, ok := req.headers[key]
if !ok {
return ""
}
return val
}
func (req *requestHeader) Set(key string, value string) {
req.headers[key] = value
}
func (req *requestHeader) Remove(key string) {
delete(req.headers, key)
}
func (req *requestHeader) Method() string {
return req.method
}
func (req *requestHeader) Path() string {
return req.path
}
func (req *requestHeader) Version() string {
return req.version
}
func (req *requestHeader) Finalize() []byte {
return finalize(req.startLine, req.headers)
}
+40
View File
@@ -0,0 +1,40 @@
package header
import (
"bytes"
"fmt"
)
func NewResponse(headerData []byte) (ResponseHeader, error) {
header := &responseHeader{
startLine: nil,
headers: make(map[string]string, 16),
}
lineEnd := bytes.Index(headerData, []byte("\r\n"))
if lineEnd == -1 {
return nil, fmt.Errorf("invalid response: no CRLF found in start line")
}
header.startLine = headerData[:lineEnd]
remaining := headerData[lineEnd+2:]
setRemainingHeaders(remaining, header)
return header, nil
}
func (resp *responseHeader) Value(key string) string {
return resp.headers[key]
}
func (resp *responseHeader) Set(key string, value string) {
resp.headers[key] = value
}
func (resp *responseHeader) Remove(key string) {
delete(resp.headers, key)
}
func (resp *responseHeader) Finalize() []byte {
return finalize(resp.startLine, resp.headers)
}
+29
View File
@@ -0,0 +1,29 @@
package stream
import "bytes"
func splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) {
headerByte := data[:delimiterIdx+len(DELIMITER)]
body := data[delimiterIdx+len(DELIMITER):]
return headerByte, body
}
func isHTTPHeader(buf []byte) bool {
lines := bytes.Split(buf, []byte("\r\n"))
startLine := string(lines[0])
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
return false
}
for _, line := range lines[1:] {
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx <= 0 {
return false
}
}
return true
}
+50
View File
@@ -0,0 +1,50 @@
package stream
import (
"bytes"
"tunnel_pls/internal/http/header"
)
func (hs *http) Read(p []byte) (int, error) {
tmp := make([]byte, len(p))
read, err := hs.reader.Read(tmp)
if read == 0 && err != nil {
return 0, err
}
tmp = tmp[:read]
headerEndIdx := bytes.Index(tmp, DELIMITER)
if headerEndIdx == -1 {
return handleNoDelimiter(p, tmp, err)
}
headerByte, bodyByte := splitHeaderAndBody(tmp, headerEndIdx)
if !isHTTPHeader(headerByte) {
copy(p, tmp)
return read, nil
}
return hs.processHTTPRequest(p, headerByte, bodyByte)
}
func (hs *http) processHTTPRequest(p, headerByte, bodyByte []byte) (int, error) {
reqhf, err := header.NewRequest(headerByte)
if err != nil {
return 0, err
}
if err = hs.ApplyRequestMiddlewares(reqhf); err != nil {
return 0, err
}
hs.reqHeader = reqhf
combined := append(reqhf.Finalize(), bodyByte...)
return copy(p, combined), nil
}
func handleNoDelimiter(p, tmp []byte, err error) (int, error) {
copy(p, tmp)
return len(tmp), err
}
+103
View File
@@ -0,0 +1,103 @@
package stream
import (
"io"
"log"
"net"
"regexp"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/middleware"
)
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
type HTTP interface {
io.ReadWriteCloser
CloseWrite() error
RemoteAddr() net.Addr
UseResponseMiddleware(mw middleware.ResponseMiddleware)
UseRequestMiddleware(mw middleware.RequestMiddleware)
SetRequestHeader(header header.RequestHeader)
RequestMiddlewares() []middleware.RequestMiddleware
ResponseMiddlewares() []middleware.ResponseMiddleware
ApplyResponseMiddlewares(resphf header.ResponseHeader, body []byte) error
ApplyRequestMiddlewares(reqhf header.RequestHeader) error
}
type http struct {
remoteAddr net.Addr
writer io.Writer
reader io.Reader
headerBuf []byte
buf []byte
respHeader header.ResponseHeader
reqHeader header.RequestHeader
respMW []middleware.ResponseMiddleware
reqMW []middleware.RequestMiddleware
}
func New(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTP {
return &http{
remoteAddr: remoteAddr,
writer: writer,
reader: reader,
buf: make([]byte, 0, 4096),
}
}
func (hs *http) RemoteAddr() net.Addr {
return hs.remoteAddr
}
func (hs *http) UseResponseMiddleware(mw middleware.ResponseMiddleware) {
hs.respMW = append(hs.respMW, mw)
}
func (hs *http) UseRequestMiddleware(mw middleware.RequestMiddleware) {
hs.reqMW = append(hs.reqMW, mw)
}
func (hs *http) SetRequestHeader(header header.RequestHeader) {
hs.reqHeader = header
}
func (hs *http) RequestMiddlewares() []middleware.RequestMiddleware {
return hs.reqMW
}
func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware {
return hs.respMW
}
func (hs *http) Close() error {
return hs.writer.(io.Closer).Close()
}
func (hs *http) CloseWrite() error {
if closer, ok := hs.writer.(interface{ CloseWrite() error }); ok {
return closer.CloseWrite()
}
return hs.Close()
}
func (hs *http) ApplyRequestMiddlewares(reqhf header.RequestHeader) error {
for _, m := range hs.RequestMiddlewares() {
if err := m.HandleRequest(reqhf); err != nil {
log.Printf("Error when applying request middleware: %v", err)
return err
}
}
return nil
}
func (hs *http) ApplyResponseMiddlewares(resphf header.ResponseHeader, bodyByte []byte) error {
for _, m := range hs.ResponseMiddlewares() {
if err := m.HandleResponse(resphf, bodyByte); err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return err
}
}
return nil
}
+88
View File
@@ -0,0 +1,88 @@
package stream
import (
"bytes"
"tunnel_pls/internal/http/header"
)
func (hs *http) Write(p []byte) (int, error) {
if hs.shouldBypassBuffering(p) {
hs.respHeader = nil
}
if hs.respHeader != nil {
return hs.writer.Write(p)
}
hs.buf = append(hs.buf, p...)
headerEndIdx := bytes.Index(hs.buf, DELIMITER)
if headerEndIdx == -1 {
return len(p), nil
}
return hs.processBufferedResponse(p, headerEndIdx)
}
func (hs *http) shouldBypassBuffering(p []byte) bool {
return hs.respHeader != nil && len(hs.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/"
}
func (hs *http) processBufferedResponse(p []byte, delimiterIdx int) (int, error) {
headerByte, bodyByte := splitHeaderAndBody(hs.buf, delimiterIdx)
if !isHTTPHeader(headerByte) {
return hs.writeRawBuffer()
}
if err := hs.processHTTPResponse(headerByte, bodyByte); err != nil {
return 0, err
}
hs.buf = nil
return len(p), nil
}
func (hs *http) writeRawBuffer() (int, error) {
_, err := hs.writer.Write(hs.buf)
length := len(hs.buf)
hs.buf = nil
if err != nil {
return 0, err
}
return length, nil
}
func (hs *http) processHTTPResponse(headerByte, bodyByte []byte) error {
resphf, err := header.NewResponse(headerByte)
if err != nil {
return err
}
if err = hs.ApplyResponseMiddlewares(resphf, bodyByte); err != nil {
return err
}
hs.respHeader = resphf
finalHeader := resphf.Finalize()
if err = hs.writeHeaderAndBody(finalHeader, bodyByte); err != nil {
return err
}
return nil
}
func (hs *http) writeHeaderAndBody(header, bodyByte []byte) error {
if _, err := hs.writer.Write(header); err != nil {
return err
}
if len(bodyByte) > 0 {
if _, err := hs.writer.Write(bodyByte); err != nil {
return err
}
}
return nil
}
+23
View File
@@ -0,0 +1,23 @@
package middleware
import (
"net"
"tunnel_pls/internal/http/header"
)
type ForwardedFor struct {
addr net.Addr
}
func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr}
}
func (ff *ForwardedFor) HandleRequest(header header.RequestHeader) error {
host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil {
return err
}
header.Set("X-Forwarded-For", host)
return nil
}
+13
View File
@@ -0,0 +1,13 @@
package middleware
import (
"tunnel_pls/internal/http/header"
)
type RequestMiddleware interface {
HandleRequest(header header.RequestHeader) error
}
type ResponseMiddleware interface {
HandleResponse(header header.ResponseHeader, body []byte) error
}
+16
View File
@@ -0,0 +1,16 @@
package middleware
import (
"tunnel_pls/internal/http/header"
)
type TunnelFingerprint struct{}
func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{}
}
func (h *TunnelFingerprint) HandleResponse(header header.ResponseHeader, body []byte) error {
header.Set("Server", "Tunnel Please")
return nil
}
+19 -19
View File
@@ -6,37 +6,37 @@ import (
"sync" "sync"
) )
type Registry interface { type Port interface {
AddPortRange(startPort, endPort uint16) error AddRange(startPort, endPort uint16) error
GetUnassignedPort() (uint16, bool) Unassigned() (uint16, bool)
SetPortStatus(port uint16, assigned bool) error SetStatus(port uint16, assigned bool) error
ClaimPort(port uint16) (claimed bool) Claim(port uint16) (claimed bool)
} }
type registry struct { type port struct {
mu sync.RWMutex mu sync.RWMutex
ports map[uint16]bool ports map[uint16]bool
sortedPorts []uint16 sortedPorts []uint16
} }
func New() Registry { func New() Port {
return &registry{ return &port{
ports: make(map[uint16]bool), ports: make(map[uint16]bool),
sortedPorts: []uint16{}, sortedPorts: []uint16{},
} }
} }
func (pm *registry) AddPortRange(startPort, endPort uint16) error { func (pm *port) AddRange(startPort, endPort uint16) error {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
if startPort > endPort { if startPort > endPort {
return fmt.Errorf("start port cannot be greater than end port") return fmt.Errorf("start port cannot be greater than end port")
} }
for port := startPort; port <= endPort; port++ { for index := startPort; index <= endPort; index++ {
if _, exists := pm.ports[port]; !exists { if _, exists := pm.ports[index]; !exists {
pm.ports[port] = false pm.ports[index] = false
pm.sortedPorts = append(pm.sortedPorts, port) pm.sortedPorts = append(pm.sortedPorts, index)
} }
} }
sort.Slice(pm.sortedPorts, func(i, j int) bool { sort.Slice(pm.sortedPorts, func(i, j int) bool {
@@ -45,19 +45,19 @@ func (pm *registry) AddPortRange(startPort, endPort uint16) error {
return nil return nil
} }
func (pm *registry) GetUnassignedPort() (uint16, bool) { func (pm *port) Unassigned() (uint16, bool) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
for _, port := range pm.sortedPorts { for _, index := range pm.sortedPorts {
if !pm.ports[port] { if !pm.ports[index] {
return port, true return index, true
} }
} }
return 0, false return 0, false
} }
func (pm *registry) SetPortStatus(port uint16, assigned bool) error { func (pm *port) SetStatus(port uint16, assigned bool) error {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
@@ -65,7 +65,7 @@ func (pm *registry) SetPortStatus(port uint16, assigned bool) error {
return nil return nil
} }
func (pm *registry) ClaimPort(port uint16) (claimed bool) { func (pm *port) Claim(port uint16) (claimed bool) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
@@ -1,13 +1,25 @@
package session package registry
import ( import (
"fmt" "fmt"
"sync" "sync"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types" "tunnel_pls/types"
) )
type Key = types.SessionKey type Key = types.SessionKey
type Session interface {
Lifecycle() lifecycle.Lifecycle
Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder
Slug() slug.Slug
Detail() *types.Detail
}
type Registry interface { type Registry interface {
Get(key Key) (session Session, err error) Get(key Key) (session Session, err error)
GetWithUser(user string, key Key) (session Session, err error) GetWithUser(user string, key Key) (session Session, err error)
@@ -22,6 +34,15 @@ type registry struct {
slugIndex map[Key]string slugIndex map[Key]string
} }
var (
ErrSessionNotFound = fmt.Errorf("session not found")
ErrSlugInUse = fmt.Errorf("slug already in use")
ErrInvalidSlug = fmt.Errorf("invalid slug")
ErrForbiddenSlug = fmt.Errorf("forbidden slug")
ErrSlugChangeNotAllowed = fmt.Errorf("slug change not allowed for this tunnel type")
ErrSlugUnchanged = fmt.Errorf("slug is unchanged")
)
func NewRegistry() Registry { func NewRegistry() Registry {
return &registry{ return &registry{
byUser: make(map[string]map[Key]Session), byUser: make(map[string]map[Key]Session),
@@ -35,12 +56,12 @@ func (r *registry) Get(key Key) (session Session, err error) {
userID, ok := r.slugIndex[key] userID, ok := r.slugIndex[key]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, ErrSessionNotFound
} }
client, ok := r.byUser[userID][key] client, ok := r.byUser[userID][key]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, ErrSessionNotFound
} }
return client, nil return client, nil
} }
@@ -51,37 +72,37 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error
client, ok := r.byUser[user][key] client, ok := r.byUser[user][key]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, ErrSessionNotFound
} }
return client, nil return client, nil
} }
func (r *registry) Update(user string, oldKey, newKey Key) error { func (r *registry) Update(user string, oldKey, newKey Key) error {
if oldKey.Type != newKey.Type { if oldKey.Type != newKey.Type {
return fmt.Errorf("tunnel type cannot change") return ErrSlugUnchanged
} }
if newKey.Type != types.HTTP { if newKey.Type != types.TunnelTypeHTTP {
return fmt.Errorf("non http tunnel cannot change slug") return ErrSlugChangeNotAllowed
} }
if isForbiddenSlug(newKey.Id) { if isForbiddenSlug(newKey.Id) {
return fmt.Errorf("this subdomain is reserved. Please choose a different one") return ErrForbiddenSlug
} }
if !isValidSlug(newKey.Id) { if !isValidSlug(newKey.Id) {
return fmt.Errorf("invalid subdomain. Follow the rules") return ErrInvalidSlug
} }
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey { if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
return fmt.Errorf("someone already uses this subdomain") return ErrSlugInUse
} }
client, ok := r.byUser[user][oldKey] client, ok := r.byUser[user][oldKey]
if !ok { if !ok {
return fmt.Errorf("session not found") return ErrSessionNotFound
} }
delete(r.byUser[user], oldKey) delete(r.byUser[user], oldKey)
@@ -97,7 +118,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
return nil return nil
} }
func (r *registry) Register(key Key, session Session) (success bool) { func (r *registry) Register(key Key, userSession Session) (success bool) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
@@ -105,12 +126,12 @@ func (r *registry) Register(key Key, session Session) (success bool) {
return false return false
} }
userID := session.Lifecycle().User() userID := userSession.Lifecycle().User()
if r.byUser[userID] == nil { if r.byUser[userID] == nil {
r.byUser[userID] = make(map[Key]Session) r.byUser[userID] = make(map[Key]Session)
} }
r.byUser[userID][key] = session r.byUser[userID][key] = userSession
r.slugIndex[key] = userID r.slugIndex[key] = userID
return true return true
} }
+40
View File
@@ -0,0 +1,40 @@
package transport
import (
"errors"
"log"
"net"
"tunnel_pls/internal/registry"
)
type httpServer struct {
handler *httpHandler
port string
}
func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
return &httpServer{
handler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
port: port,
}
}
func (ht *httpServer) Listen() (net.Listener, error) {
return net.Listen("tcp", ":"+ht.port)
}
func (ht *httpServer) Serve(listener net.Listener) error {
log.Printf("HTTP server is starting on port %s", ht.port)
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
}
log.Printf("Error accepting connection: %v", err)
continue
}
go ht.handler.handler(conn, false)
}
}
+231
View File
@@ -0,0 +1,231 @@
package transport
import (
"bufio"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"strings"
"time"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/http/stream"
"tunnel_pls/internal/middleware"
"tunnel_pls/internal/registry"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
type httpHandler struct {
domain string
sessionRegistry registry.Registry
redirectTLS bool
}
func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
return &httpHandler{
domain: domain,
sessionRegistry: sessionRegistry,
redirectTLS: redirectTLS,
}
}
func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error {
_, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) +
fmt.Sprintf("Location: %s", location) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
return err
}
return nil
}
func (hh *httpHandler) badRequest(conn net.Conn) error {
if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
return err
}
return nil
}
func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
defer hh.closeConnection(conn)
dstReader := bufio.NewReader(conn)
reqhf, err := header.NewRequest(dstReader)
if err != nil {
log.Printf("Error creating request header: %v", err)
return
}
slug, err := hh.extractSlug(reqhf)
if err != nil {
_ = hh.badRequest(conn)
return
}
if hh.shouldRedirectToTLS(isTLS) {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain))
return
}
if hh.handlePingRequest(slug, conn) {
return
}
sshSession, err := hh.getSession(slug)
if err != nil {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
return
}
hw := stream.New(conn, dstReader, conn.RemoteAddr())
defer func(hw stream.HTTP) {
err = hw.Close()
if err != nil {
log.Printf("Error closing HTTP stream: %v", err)
}
}(hw)
hh.forwardRequest(hw, reqhf, sshSession)
}
func (hh *httpHandler) closeConnection(conn net.Conn) {
err := conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("Error closing connection: %v", err)
}
}
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
host := strings.Split(reqhf.Value("Host"), ".")
if len(host) < 1 {
return "", errors.New("invalid host")
}
return host[0], nil
}
func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool {
return !isTLS && hh.redirectTLS
}
func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
if slug != "ping" {
return false
}
_, err := conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
}
return true
}
func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.TunnelTypeHTTP,
})
if err != nil {
return nil, err
}
return sshSession, nil
}
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
channel, err := hh.openForwardedChannel(hw, sshSession)
if err != nil {
log.Printf("Failed to establish channel: %v", err)
sshSession.Forwarder().WriteBadGatewayResponse(hw)
return
}
defer func() {
err = channel.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing forwarded channel: %v", err)
}
}()
hh.setupMiddlewares(hw)
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
log.Printf("Failed to forward initial request: %v", err)
return
}
sshSession.Forwarder().HandleConnection(hw, channel)
}
func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) {
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
select {
case resultChan <- channelResult{channel, reqs, err}:
default:
hh.cleanupUnusedChannel(channel, reqs)
}
}()
select {
case result := <-resultChan:
if result.err != nil {
return nil, result.err
}
go ssh.DiscardRequests(result.reqs)
return result.channel, nil
case <-time.After(5 * time.Second):
return nil, errors.New("timeout opening forwarded-tcpip channel")
}
}
func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) {
if channel != nil {
if err := channel.Close(); err != nil {
log.Printf("Failed to close unused channel: %v", err)
}
go ssh.DiscardRequests(reqs)
}
}
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
fingerprintMiddleware := middleware.NewTunnelFingerprint()
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
hw.UseResponseMiddleware(fingerprintMiddleware)
hw.UseRequestMiddleware(forwardedForMiddleware)
}
func (hh *httpHandler) sendInitialRequest(hw stream.HTTP, initialRequest header.RequestHeader, channel ssh.Channel) error {
hw.SetRequestHeader(initialRequest)
if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
return fmt.Errorf("error applying request middlewares: %w", err)
}
if _, err := channel.Write(initialRequest.Finalize()); err != nil {
return fmt.Errorf("error writing to channel: %w", err)
}
return nil
}
+45
View File
@@ -0,0 +1,45 @@
package transport
import (
"crypto/tls"
"errors"
"log"
"net"
"tunnel_pls/internal/registry"
)
type https struct {
tlsConfig *tls.Config
httpHandler *httpHandler
domain string
port string
}
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport {
return &https{
tlsConfig: tlsConfig,
httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
domain: domain,
port: port,
}
}
func (ht *https) Listen() (net.Listener, error) {
return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig)
}
func (ht *https) Serve(listener net.Listener) error {
log.Printf("HTTPS server is starting on port %s", ht.port)
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
}
log.Printf("Error accepting connection: %v", err)
continue
}
go ht.httpHandler.handler(conn, true)
}
}
+66
View File
@@ -0,0 +1,66 @@
package transport
import (
"errors"
"fmt"
"io"
"log"
"net"
"golang.org/x/crypto/ssh"
)
type tcp struct {
port uint16
forwarder forwarder
}
type forwarder interface {
CreateForwardedTCPIPPayload(origin net.Addr) []byte
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
HandleConnection(dst io.ReadWriter, src ssh.Channel)
}
func NewTCPServer(port uint16, forwarder forwarder) Transport {
return &tcp{
port: port,
forwarder: forwarder,
}
}
func (tt *tcp) Listen() (net.Listener, error) {
return net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", tt.port))
}
func (tt *tcp) Serve(listener net.Listener) error {
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
log.Printf("Error accepting connection: %v", err)
continue
}
go tt.handleTcp(conn)
}
}
func (tt *tcp) handleTcp(conn net.Conn) {
defer func() {
err := conn.Close()
if err != nil {
log.Printf("Failed to close connection: %v", err)
}
}()
payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr())
channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return
}
go ssh.DiscardRequests(reqs)
tt.forwarder.HandleConnection(conn, channel)
}
+13 -34
View File
@@ -1,4 +1,4 @@
package server package transport
import ( import (
"context" "context"
@@ -26,7 +26,8 @@ type TLSManager interface {
} }
type tlsManager struct { type tlsManager struct {
domain string config config.Config
certPath string certPath string
keyPath string keyPath string
storagePath string storagePath string
@@ -42,7 +43,7 @@ type tlsManager struct {
var globalTLSManager TLSManager var globalTLSManager TLSManager
var tlsManagerOnce sync.Once var tlsManagerOnce sync.Once
func NewTLSConfig(domain string) (*tls.Config, error) { func NewTLSConfig(config config.Config) (*tls.Config, error) {
var initErr error var initErr error
tlsManagerOnce.Do(func() { tlsManagerOnce.Do(func() {
@@ -51,7 +52,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
storagePath := "certs/tls/certmagic" storagePath := "certs/tls/certmagic"
tm := &tlsManager{ tm := &tlsManager{
domain: domain, config: config,
certPath: certPath, certPath: certPath,
keyPath: keyPath, keyPath: keyPath,
storagePath: storagePath, storagePath: storagePath,
@@ -66,14 +67,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
tm.useCertMagic = false tm.useCertMagic = false
tm.startCertWatcher() tm.startCertWatcher()
} else { } else {
if !isACMEConfigComplete() { log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain(), config.Domain())
log.Printf("User certificates missing or invalid, and ACME configuration is incomplete")
log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable")
initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)")
return
}
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain)
if err := tm.initCertMagic(); err != nil { if err := tm.initCertMagic(); err != nil {
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err) initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
return return
@@ -91,11 +85,6 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
return globalTLSManager.getTLSConfig(), nil return globalTLSManager.getTLSConfig(), nil
} }
func isACMEConfigComplete() bool {
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
return cfAPIToken != ""
}
func (tm *tlsManager) userCertsExistAndValid() bool { func (tm *tlsManager) userCertsExistAndValid() bool {
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) { if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
log.Printf("Certificate file not found: %s", tm.certPath) log.Printf("Certificate file not found: %s", tm.certPath)
@@ -106,7 +95,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool {
return false return false
} }
return ValidateCertDomains(tm.certPath, tm.domain) return ValidateCertDomains(tm.certPath, tm.config.Domain())
} }
func ValidateCertDomains(certPath, domain string) bool { func ValidateCertDomains(certPath, domain string) bool {
@@ -206,15 +195,9 @@ func (tm *tlsManager) startCertWatcher() {
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) { if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
log.Printf("Certificate files changed, reloading...") log.Printf("Certificate files changed, reloading...")
if !ValidateCertDomains(tm.certPath, tm.domain) { if !ValidateCertDomains(tm.certPath, tm.config.Domain()) {
log.Printf("New certificates don't cover required domains") log.Printf("New certificates don't cover required domains")
if !isACMEConfigComplete() {
log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)")
continue
}
log.Printf("Switching to CertMagic for automatic certificate management")
if err := tm.initCertMagic(); err != nil { if err := tm.initCertMagic(); err != nil {
log.Printf("Failed to initialize CertMagic: %v", err) log.Printf("Failed to initialize CertMagic: %v", err)
continue continue
@@ -241,16 +224,12 @@ func (tm *tlsManager) initCertMagic() error {
return fmt.Errorf("failed to create cert storage directory: %w", err) return fmt.Errorf("failed to create cert storage directory: %w", err)
} }
acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain) if tm.config.CFAPIToken() == "" {
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
acmeStaging := config.Getenv("ACME_STAGING", "false") == "true"
if cfAPIToken == "" {
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation") return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
} }
cfProvider := &cloudflare.Provider{ cfProvider := &cloudflare.Provider{
APIToken: cfAPIToken, APIToken: tm.config.CFAPIToken(),
} }
storage := &certmagic.FileStorage{Path: tm.storagePath} storage := &certmagic.FileStorage{Path: tm.storagePath}
@@ -266,7 +245,7 @@ func (tm *tlsManager) initCertMagic() error {
}) })
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{ acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
Email: acmeEmail, Email: tm.config.ACMEEmail(),
Agreed: true, Agreed: true,
DNS01Solver: &certmagic.DNS01Solver{ DNS01Solver: &certmagic.DNS01Solver{
DNSManager: certmagic.DNSManager{ DNSManager: certmagic.DNSManager{
@@ -275,7 +254,7 @@ func (tm *tlsManager) initCertMagic() error {
}, },
}) })
if acmeStaging { if tm.config.ACMEStaging() {
acmeIssuer.CA = certmagic.LetsEncryptStagingCA acmeIssuer.CA = certmagic.LetsEncryptStagingCA
log.Printf("Using Let's Encrypt staging server") log.Printf("Using Let's Encrypt staging server")
} else { } else {
@@ -286,7 +265,7 @@ func (tm *tlsManager) initCertMagic() error {
magic.Issuers = []certmagic.Issuer{acmeIssuer} magic.Issuers = []certmagic.Issuer{acmeIssuer}
tm.magic = magic tm.magic = magic
domains := []string{tm.domain, "*." + tm.domain} domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
log.Printf("Requesting certificates for: %v", domains) log.Printf("Requesting certificates for: %v", domains)
ctx := context.Background() ctx := context.Background()
+10
View File
@@ -0,0 +1,10 @@
package transport
import (
"net"
)
type Transport interface {
Listen() (net.Listener, error)
Serve(listener net.Listener) error
}
+62 -55
View File
@@ -4,21 +4,22 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"net"
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"os/signal" "os/signal"
"strconv"
"strings"
"syscall" "syscall"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/key" "tunnel_pls/internal/key"
"tunnel_pls/internal/port" "tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/internal/transport"
"tunnel_pls/internal/version"
"tunnel_pls/server" "tunnel_pls/server"
"tunnel_pls/session" "tunnel_pls/types"
"tunnel_pls/version"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -34,27 +35,12 @@ func main() {
log.Printf("Starting %s", version.GetVersion()) log.Printf("Starting %s", version.GetVersion())
err := config.Load() conf, err := config.MustLoad()
if err != nil { if err != nil {
log.Fatalf("Failed to load configuration: %s", err) log.Fatalf("Failed to load configuration: %s", err)
return return
} }
mode := strings.ToLower(config.Getenv("MODE", "standalone"))
isNodeMode := mode == "node"
pprofEnabled := config.Getenv("PPROF_ENABLED", "false")
if pprofEnabled == "true" {
pprofPort := config.Getenv("PPROF_PORT", "6060")
go func() {
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err = http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err)
}
}()
}
sshConfig := &ssh.ServerConfig{ sshConfig := &ssh.ServerConfig{
NoClientAuth: true, NoClientAuth: true,
ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()), ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
@@ -76,7 +62,7 @@ func main() {
} }
sshConfig.AddHostKey(private) sshConfig.AddHostKey(private)
sessionRegistry := session.NewRegistry() sessionRegistry := registry.NewRegistry()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@@ -86,16 +72,11 @@ func main() {
signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM) signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
var grpcClient client.Client var grpcClient client.Client
if isNodeMode {
grpcHost := config.Getenv("GRPC_ADDRESS", "localhost")
grpcPort := config.Getenv("GRPC_PORT", "8080")
grpcAddr := fmt.Sprintf("%s:%s", grpcHost, grpcPort)
nodeToken := config.Getenv("NODE_TOKEN", "")
if nodeToken == "" {
log.Fatalf("NODE_TOKEN is required in node mode")
}
grpcClient, err = client.New(grpcAddr, sessionRegistry) if conf.Mode() == types.ServerModeNODE {
grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort())
grpcClient, err = client.New(conf, grpcAddr, sessionRegistry)
if err != nil { if err != nil {
log.Fatalf("failed to create grpc client: %v", err) log.Fatalf("failed to create grpc client: %v", err)
} }
@@ -108,47 +89,73 @@ func main() {
healthCancel() healthCancel()
go func() { go func() {
identity := config.Getenv("DOMAIN", "localhost") if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
if err = grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil {
errChan <- fmt.Errorf("failed to subscribe to events: %w", err) errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
} }
}() }()
} }
portManager := port.New() go func() {
rawRange := config.Getenv("ALLOWED_PORTS", "") var httpListener net.Listener
if rawRange != "" { httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), sessionRegistry, conf.TLSRedirect())
splitRange := strings.Split(rawRange, "-") httpListener, err = httpserver.Listen()
if len(splitRange) == 2 { if err != nil {
var start, end uint64 errChan <- fmt.Errorf("failed to start http server: %w", err)
start, err = strconv.ParseUint(splitRange[0], 10, 16) return
if err != nil {
log.Fatalf("Failed to parse start port: %s", err)
}
end, err = strconv.ParseUint(splitRange[1], 10, 16)
if err != nil {
log.Fatalf("Failed to parse end port: %s", err)
}
if err = portManager.AddPortRange(uint16(start), uint16(end)); err != nil {
log.Fatalf("Failed to add port range: %s", err)
}
log.Printf("PortRegistry range configured: %d-%d", start, end)
} else {
log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange)
} }
err = httpserver.Serve(httpListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
if conf.TLSEnabled() {
go func() {
var httpsListener net.Listener
tlsConfig, _ := transport.NewTLSConfig(conf)
httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), sessionRegistry, conf.TLSRedirect(), tlsConfig)
httpsListener, err = httpsServer.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
err = httpsServer.Serve(httpsListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
} }
portManager := port.New()
err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd())
if err != nil {
log.Fatalf("Failed to initialize port manager: %s", err)
return
}
var app server.Server var app server.Server
go func() { go func() {
app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager) app, err = server.New(conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort())
if err != nil { if err != nil {
errChan <- fmt.Errorf("failed to start server: %s", err) errChan <- fmt.Errorf("failed to start server: %s", err)
return return
} }
app.Start() app.Start()
}() }()
if conf.PprofEnabled() {
go func() {
pprofAddr := fmt.Sprintf("localhost:%s", conf.PprofPort())
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err = http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err)
}
}()
}
select { select {
case err = <-errChan: case err = <-errChan:
log.Printf("error happen : %s", err) log.Printf("error happen : %s", err)
-276
View File
@@ -1,276 +0,0 @@
package server
import (
"bufio"
"bytes"
"fmt"
)
type HeaderManager interface {
Get(key string) []byte
Set(key string, value []byte)
Remove(key string)
Finalize() []byte
}
type ResponseHeaderManager interface {
Get(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
}
type RequestHeaderManager interface {
Get(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
GetMethod() string
GetPath() string
GetVersion() string
}
type responseHeaderFactory struct {
startLine []byte
headers map[string]string
}
type requestHeaderFactory struct {
method string
path string
version string
startLine []byte
headers map[string]string
}
func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) {
switch v := r.(type) {
case []byte:
return parseHeadersFromBytes(v)
case *bufio.Reader:
return parseHeadersFromReader(v)
default:
return nil, fmt.Errorf("unsupported type: %T", r)
}
}
func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) {
header := &requestHeaderFactory{
headers: make(map[string]string, 16),
}
lineEnd := bytes.IndexByte(headerData, '\n')
if lineEnd == -1 {
return nil, fmt.Errorf("invalid request: no newline found")
}
startLine := bytes.TrimRight(headerData[:lineEnd], "\r\n")
header.startLine = make([]byte, len(startLine))
copy(header.startLine, startLine)
parts := bytes.Split(startLine, []byte{' '})
if len(parts) < 3 {
return nil, fmt.Errorf("invalid request line")
}
header.method = string(parts[0])
header.path = string(parts[1])
header.version = string(parts[2])
remaining := headerData[lineEnd+1:]
for len(remaining) > 0 {
lineEnd = bytes.IndexByte(remaining, '\n')
if lineEnd == -1 {
lineEnd = len(remaining)
}
line := bytes.TrimRight(remaining[:lineEnd], "\r\n")
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx != -1 {
key := bytes.TrimSpace(line[:colonIdx])
value := bytes.TrimSpace(line[colonIdx+1:])
header.headers[string(key)] = string(value)
}
if lineEnd == len(remaining) {
break
}
remaining = remaining[lineEnd+1:]
}
return header, nil
}
func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) {
header := &requestHeaderFactory{
headers: make(map[string]string, 16),
}
startLineBytes, err := br.ReadSlice('\n')
if err != nil {
if err == bufio.ErrBufferFull {
var startLine string
startLine, err = br.ReadString('\n')
if err != nil {
return nil, err
}
startLineBytes = []byte(startLine)
} else {
return nil, err
}
}
startLineBytes = bytes.TrimRight(startLineBytes, "\r\n")
header.startLine = make([]byte, len(startLineBytes))
copy(header.startLine, startLineBytes)
parts := bytes.Split(startLineBytes, []byte{' '})
if len(parts) < 3 {
return nil, fmt.Errorf("invalid request line")
}
header.method = string(parts[0])
header.path = string(parts[1])
header.version = string(parts[2])
for {
lineBytes, err := br.ReadSlice('\n')
if err != nil {
if err == bufio.ErrBufferFull {
var line string
line, err = br.ReadString('\n')
if err != nil {
return nil, err
}
lineBytes = []byte(line)
} else {
return nil, err
}
}
lineBytes = bytes.TrimRight(lineBytes, "\r\n")
if len(lineBytes) == 0 {
break
}
colonIdx := bytes.IndexByte(lineBytes, ':')
if colonIdx == -1 {
continue
}
key := bytes.TrimSpace(lineBytes[:colonIdx])
value := bytes.TrimSpace(lineBytes[colonIdx+1:])
header.headers[string(key)] = string(value)
}
return header, nil
}
func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager {
header := &responseHeaderFactory{
startLine: nil,
headers: make(map[string]string),
}
lines := bytes.Split(startLine, []byte("\r\n"))
if len(lines) == 0 {
return header
}
header.startLine = lines[0]
for _, h := range lines[1:] {
if len(h) == 0 {
continue
}
parts := bytes.SplitN(h, []byte(":"), 2)
if len(parts) < 2 {
continue
}
key := parts[0]
val := bytes.TrimSpace(parts[1])
header.headers[string(key)] = string(val)
}
return header
}
func (resp *responseHeaderFactory) Get(key string) string {
return resp.headers[key]
}
func (resp *responseHeaderFactory) Set(key string, value string) {
resp.headers[key] = value
}
func (resp *responseHeaderFactory) Remove(key string) {
delete(resp.headers, key)
}
func (resp *responseHeaderFactory) Finalize() []byte {
var buf bytes.Buffer
buf.Write(resp.startLine)
buf.WriteString("\r\n")
for key, val := range resp.headers {
buf.WriteString(key)
buf.WriteString(": ")
buf.WriteString(val)
buf.WriteString("\r\n")
}
buf.WriteString("\r\n")
return buf.Bytes()
}
func (req *requestHeaderFactory) Get(key string) string {
val, ok := req.headers[key]
if !ok {
return ""
}
return val
}
func (req *requestHeaderFactory) Set(key string, value string) {
req.headers[key] = value
}
func (req *requestHeaderFactory) Remove(key string) {
delete(req.headers, key)
}
func (req *requestHeaderFactory) GetMethod() string {
return req.method
}
func (req *requestHeaderFactory) GetPath() string {
return req.path
}
func (req *requestHeaderFactory) GetVersion() string {
return req.version
}
func (req *requestHeaderFactory) Finalize() []byte {
var buf bytes.Buffer
buf.Write(req.startLine)
buf.WriteString("\r\n")
for key, val := range req.headers {
buf.WriteString(key)
buf.WriteString(": ")
buf.WriteString(val)
buf.WriteString("\r\n")
}
buf.WriteString("\r\n")
return buf.Bytes()
}
-395
View File
@@ -1,395 +0,0 @@
package server
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log"
"net"
"regexp"
"strings"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/session"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
type HTTPWriter interface {
io.Reader
io.Writer
GetRemoteAddr() net.Addr
GetWriter() io.Writer
AddResponseMiddleware(mw ResponseMiddleware)
AddRequestStartMiddleware(mw RequestMiddleware)
SetRequestHeader(header RequestHeaderManager)
GetRequestStartMiddleware() []RequestMiddleware
}
type customWriter struct {
remoteAddr net.Addr
writer io.Writer
reader io.Reader
headerBuf []byte
buf []byte
respHeader ResponseHeaderManager
reqHeader RequestHeaderManager
respMW []ResponseMiddleware
reqStartMW []RequestMiddleware
reqEndMW []RequestMiddleware
}
func (cw *customWriter) GetRemoteAddr() net.Addr {
return cw.remoteAddr
}
func (cw *customWriter) GetWriter() io.Writer {
return cw.writer
}
func (cw *customWriter) AddResponseMiddleware(mw ResponseMiddleware) {
cw.respMW = append(cw.respMW, mw)
}
func (cw *customWriter) AddRequestStartMiddleware(mw RequestMiddleware) {
cw.reqStartMW = append(cw.reqStartMW, mw)
}
func (cw *customWriter) SetRequestHeader(header RequestHeaderManager) {
cw.reqHeader = header
}
func (cw *customWriter) GetRequestStartMiddleware() []RequestMiddleware {
return cw.reqStartMW
}
func (cw *customWriter) Read(p []byte) (int, error) {
tmp := make([]byte, len(p))
read, err := cw.reader.Read(tmp)
if read == 0 && err != nil {
return 0, err
}
tmp = tmp[:read]
idx := bytes.Index(tmp, DELIMITER)
if idx == -1 {
copy(p, tmp)
if err != nil {
return read, err
}
return read, nil
}
header := tmp[:idx+len(DELIMITER)]
body := tmp[idx+len(DELIMITER):]
if !isHTTPHeader(header) {
copy(p, tmp)
return read, nil
}
for _, m := range cw.reqEndMW {
err = m.HandleRequest(cw.reqHeader)
if err != nil {
log.Printf("Error when applying request middleware: %v", err)
return 0, err
}
}
reqhf, err := NewRequestHeaderFactory(header)
if err != nil {
return 0, err
}
for _, m := range cw.reqStartMW {
if mwErr := m.HandleRequest(reqhf); mwErr != nil {
log.Printf("Error when applying request middleware: %v", mwErr)
return 0, mwErr
}
}
cw.reqHeader = reqhf
finalHeader := reqhf.Finalize()
combined := append(finalHeader, body...)
n := copy(p, combined)
return n, nil
}
func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter {
return &customWriter{
remoteAddr: remoteAddr,
writer: writer,
reader: reader,
buf: make([]byte, 0, 4096),
}
}
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
func isHTTPHeader(buf []byte) bool {
lines := bytes.Split(buf, []byte("\r\n"))
startLine := string(lines[0])
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
return false
}
for _, line := range lines[1:] {
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx <= 0 {
return false
}
}
return true
}
func (cw *customWriter) Write(p []byte) (int, error) {
if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" {
cw.respHeader = nil
}
if cw.respHeader != nil {
n, err := cw.writer.Write(p)
if err != nil {
return n, err
}
return n, nil
}
cw.buf = append(cw.buf, p...)
idx := bytes.Index(cw.buf, DELIMITER)
if idx == -1 {
return len(p), nil
}
header := cw.buf[:idx+len(DELIMITER)]
body := cw.buf[idx+len(DELIMITER):]
if !isHTTPHeader(header) {
_, err := cw.writer.Write(cw.buf)
cw.buf = nil
if err != nil {
return 0, err
}
return len(p), nil
}
resphf := NewResponseHeaderFactory(header)
for _, m := range cw.respMW {
err := m.HandleResponse(resphf, body)
if err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return 0, err
}
}
header = resphf.Finalize()
cw.respHeader = resphf
_, err := cw.writer.Write(header)
if err != nil {
return 0, err
}
if len(body) > 0 {
_, err = cw.writer.Write(body)
if err != nil {
return 0, err
}
}
cw.buf = nil
return len(p), nil
}
var redirectTLS = false
type HTTPServer interface {
ListenAndServe() error
ListenAndServeTLS() error
handler(conn net.Conn)
handlerTLS(conn net.Conn)
}
type httpServer struct {
sessionRegistry session.Registry
}
func NewHTTPServer(sessionRegistry session.Registry) HTTPServer {
return &httpServer{sessionRegistry: sessionRegistry}
}
func (hs *httpServer) ListenAndServe() error {
httpPort := config.Getenv("HTTP_PORT", "8080")
listener, err := net.Listen("tcp", ":"+httpPort)
if err != nil {
return errors.New("Error listening: " + err.Error())
}
if config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" {
redirectTLS = true
}
go func() {
for {
var conn net.Conn
conn, err = listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
go hs.handler(conn)
}
}()
return nil
}
func (hs *httpServer) handler(conn net.Conn) {
defer func() {
err := conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("Error closing connection: %v", err)
return
}
return
}()
dstReader := bufio.NewReader(conn)
reqhf, err := NewRequestHeaderFactory(dstReader)
if err != nil {
log.Printf("Error creating request header: %v", err)
return
}
host := strings.Split(reqhf.Get("Host"), ".")
if len(host) < 1 {
_, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
if err != nil {
log.Println("Failed to write 400 Bad Request:", err)
return
}
return
}
slug := host[0]
if redirectTLS {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
log.Println("Failed to write 301 Moved Permanently:", err)
return
}
return
}
if slug == "ping" {
_, err = conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
return
}
return
}
sshSession, err := hs.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
})
if err != nil {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
log.Println("Failed to write 301 Moved Permanently:", err)
return
}
return
}
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
forwardRequest(cw, reqhf, sshSession)
return
}
func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) {
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err}
}()
var channel ssh.Channel
var reqs <-chan *ssh.Request
select {
case result := <-resultChan:
if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter())
return
}
channel = result.channel
reqs = result.reqs
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter())
return
}
go ssh.DiscardRequests(reqs)
fingerprintMiddleware := NewTunnelFingerprint()
forwardedForMiddleware := NewForwardedFor(cw.GetRemoteAddr())
cw.AddResponseMiddleware(fingerprintMiddleware)
cw.AddRequestStartMiddleware(forwardedForMiddleware)
cw.SetRequestHeader(initialRequest)
for _, m := range cw.GetRequestStartMiddleware() {
if err := m.HandleRequest(initialRequest); err != nil {
log.Printf("Error handling request: %v", err)
return
}
}
_, err := channel.Write(initialRequest.Finalize())
if err != nil {
log.Printf("Failed to forward request: %v", err)
return
}
sshSession.Forwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
return
}
-112
View File
@@ -1,112 +0,0 @@
package server
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"strings"
"tunnel_pls/internal/config"
"tunnel_pls/types"
)
func (hs *httpServer) ListenAndServeTLS() error {
domain := config.Getenv("DOMAIN", "localhost")
httpsPort := config.Getenv("HTTPS_PORT", "8443")
tlsConfig, err := NewTLSConfig(domain)
if err != nil {
return fmt.Errorf("failed to initialize TLS config: %w", err)
}
ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig)
if err != nil {
return err
}
go func() {
for {
var conn net.Conn
conn, err = ln.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
log.Println("https server closed")
}
log.Printf("Error accepting connection: %v", err)
continue
}
go hs.handlerTLS(conn)
}
}()
return nil
}
func (hs *httpServer) handlerTLS(conn net.Conn) {
defer func() {
err := conn.Close()
if err != nil {
log.Printf("Error closing connection: %v", err)
return
}
return
}()
dstReader := bufio.NewReader(conn)
reqhf, err := NewRequestHeaderFactory(dstReader)
if err != nil {
log.Printf("Error creating request header: %v", err)
return
}
host := strings.Split(reqhf.Get("Host"), ".")
if len(host) < 1 {
_, err = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
if err != nil {
log.Println("Failed to write 400 Bad Request:", err)
return
}
return
}
slug := host[0]
if slug == "ping" {
_, err = conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
return
}
return
}
sshSession, err := hs.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
})
if err != nil {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
log.Println("Failed to write 301 Moved Permanently:", err)
return
}
return
}
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
forwardRequest(cw, reqhf, sshSession)
return
}
-41
View File
@@ -1,41 +0,0 @@
package server
import (
"net"
)
type RequestMiddleware interface {
HandleRequest(header RequestHeaderManager) error
}
type ResponseMiddleware interface {
HandleResponse(header ResponseHeaderManager, body []byte) error
}
type TunnelFingerprint struct{}
func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{}
}
func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error {
header.Set("Server", "Tunnel Please")
return nil
}
type ForwardedFor struct {
addr net.Addr
}
func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr}
}
func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error {
host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil {
return err
}
header.Set("X-Forwarded-For", host)
return nil
}
+18 -29
View File
@@ -10,6 +10,7 @@ import (
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/port" "tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/session" "tunnel_pls/session"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@@ -20,38 +21,26 @@ type Server interface {
Close() error Close() error
} }
type server struct { type server struct {
listener net.Listener config config.Config
config *ssh.ServerConfig sshPort string
sshListener net.Listener
sshConfig *ssh.ServerConfig
grpcClient client.Client grpcClient client.Client
sessionRegistry session.Registry sessionRegistry registry.Registry
portRegistry port.Registry portRegistry port.Port
} }
func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client, portRegistry port.Registry) (Server, error) { func New(config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200"))) listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
if err != nil { if err != nil {
log.Fatalf("failed to listen on port 2200: %v", err)
return nil, err return nil, err
} }
HttpServer := NewHTTPServer(sessionRegistry)
err = HttpServer.ListenAndServe()
if err != nil {
log.Fatalf("failed to start http server: %v", err)
return nil, err
}
if config.Getenv("TLS_ENABLED", "false") == "true" {
err = HttpServer.ListenAndServeTLS()
if err != nil {
log.Fatalf("failed to start https server: %v", err)
return nil, err
}
}
return &server{ return &server{
listener: listener, config: config,
config: sshConfig, sshPort: sshPort,
sshListener: listener,
sshConfig: sshConfig,
grpcClient: grpcClient, grpcClient: grpcClient,
sessionRegistry: sessionRegistry, sessionRegistry: sessionRegistry,
portRegistry: portRegistry, portRegistry: portRegistry,
@@ -59,9 +48,9 @@ func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClie
} }
func (s *server) Start() { func (s *server) Start() {
log.Println("SSH server is starting on port 2200...") log.Printf("SSH server is starting on port %s", s.sshPort)
for { for {
conn, err := s.listener.Accept() conn, err := s.sshListener.Accept()
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
log.Println("listener closed, stopping server") log.Println("listener closed, stopping server")
@@ -76,11 +65,11 @@ func (s *server) Start() {
} }
func (s *server) Close() error { func (s *server) Close() error {
return s.listener.Close() return s.sshListener.Close()
} }
func (s *server) handleConnection(conn net.Conn) { func (s *server) handleConnection(conn net.Conn) {
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.sshConfig)
if err != nil { if err != nil {
log.Printf("failed to establish SSH connection: %v", err) log.Printf("failed to establish SSH connection: %v", err)
err = conn.Close() err = conn.Close()
@@ -106,7 +95,7 @@ func (s *server) handleConnection(conn net.Conn) {
cancel() cancel()
} }
log.Println("SSH connection established:", sshConn.User()) log.Println("SSH connection established:", sshConn.User())
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) sshSession := session.New(s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
err = sshSession.Start() err = sshSession.Start()
if err != nil { if err != nil {
log.Printf("SSH session ended with error: %v", err) log.Printf("SSH session ended with error: %v", err)
+94 -110
View File
@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"log" "log"
"net" "net"
@@ -17,37 +18,6 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
var bufferPool = sync.Pool{
New: func() interface{} {
bufSize := config.GetBufferSize()
return make([]byte, bufSize)
},
}
func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := bufferPool.Get().([]byte)
defer bufferPool.Put(buf)
return io.CopyBuffer(dst, src, buf)
}
type forwarder struct {
listener net.Listener
tunnelType types.TunnelType
forwardedPort uint16
slug slug.Slug
conn ssh.Conn
}
func New(slug slug.Slug, conn ssh.Conn) Forwarder {
return &forwarder{
listener: nil,
tunnelType: types.UNKNOWN,
forwardedPort: 0,
slug: slug,
conn: conn,
}
}
type Forwarder interface { type Forwarder interface {
SetType(tunnelType types.TunnelType) SetType(tunnelType types.TunnelType)
SetForwardedPort(port uint16) SetForwardedPort(port uint16)
@@ -55,110 +25,124 @@ type Forwarder interface {
Listener() net.Listener Listener() net.Listener
TunnelType() types.TunnelType TunnelType() types.TunnelType
ForwardedPort() uint16 ForwardedPort() uint16
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) HandleConnection(dst io.ReadWriter, src ssh.Channel)
CreateForwardedTCPIPPayload(origin net.Addr) []byte CreateForwardedTCPIPPayload(origin net.Addr) []byte
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
WriteBadGatewayResponse(dst io.Writer) WriteBadGatewayResponse(dst io.Writer)
AcceptTCPConnections()
Close() error Close() error
} }
type forwarder struct {
listener net.Listener
tunnelType types.TunnelType
forwardedPort uint16
slug slug.Slug
conn ssh.Conn
bufferPool sync.Pool
}
func (f *forwarder) AcceptTCPConnections() { func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
for { return &forwarder{
conn, err := f.Listener().Accept() listener: nil,
if err != nil { tunnelType: types.TunnelTypeUNKNOWN,
if errors.Is(err, net.ErrClosed) { forwardedPort: 0,
return slug: slug,
} conn: conn,
log.Printf("Error accepting connection: %v", err) bufferPool: sync.Pool{
continue New: func() interface{} {
} bufSize := config.BufferSize()
return make([]byte, bufSize)
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { },
log.Printf("Failed to set connection deadline: %v", err) },
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
continue
}
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err}
}()
select {
case result := <-resultChan:
if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
continue
}
if err = conn.SetDeadline(time.Time{}); err != nil {
log.Printf("Failed to clear connection deadline: %v", err)
}
go ssh.DiscardRequests(result.reqs)
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
}
} }
} }
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := f.bufferPool.Get().([]byte)
defer f.bufferPool.Put(buf)
return io.CopyBuffer(dst, src, buf)
}
func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
select {
case resultChan <- channelResult{channel, reqs, err}:
default:
if channel != nil {
err = channel.Close()
if err != nil {
log.Printf("Failed to close unused channel: %v", err)
return
}
go ssh.DiscardRequests(reqs)
}
}
}()
select {
case result := <-resultChan:
return result.channel, result.reqs, result.err
case <-time.After(5 * time.Second):
return nil, nil, errors.New("timeout opening forwarded-tcpip channel")
}
}
func closeWriter(w io.Writer) error {
if cw, ok := w.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
if closer, ok := w.(io.Closer); ok {
return closer.Close()
}
return nil
}
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
var errs []error
_, err := f.copyWithBuffer(dst, src)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
}
if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) {
errs = append(errs, fmt.Errorf("close stream error (%s): %w", direction, err))
}
return errors.Join(errs...)
}
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
defer func() { defer func() {
_, err := io.Copy(io.Discard, src) _, err := io.Copy(io.Discard, src)
if err != nil { if err != nil {
log.Printf("Failed to discard connection: %v", err) log.Printf("Failed to discard connection: %v", err)
} }
err = src.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing source channel: %v", err)
}
if closer, ok := dst.(io.Closer); ok {
err = closer.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing destination connection: %v", err)
}
}
}() }()
log.Printf("Handling new forwarded connection from %s", remoteAddr)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
go func() { go func() {
defer wg.Done() defer wg.Done()
_, err := copyWithBuffer(dst, src) err := f.copyAndClose(dst, src, "src to dst")
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { if err != nil {
log.Printf("Error copying src→dst: %v", err) log.Println("Error during copy: ", err)
return
} }
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
_, err := copyWithBuffer(src, dst) err := f.copyAndClose(src, dst, "dst to src")
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { if err != nil {
log.Printf("Error copying dst→src: %v", err) log.Println("Error during copy: ", err)
return
} }
}() }()
+11 -10
View File
@@ -18,9 +18,9 @@ import (
) )
type Interaction interface { type Interaction interface {
Mode() types.Mode Mode() types.InteractiveMode
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetMode(m types.Mode) SetMode(m types.InteractiveMode)
SetWH(w, h int) SetWH(w, h int)
Start() Start()
Redraw() Redraw()
@@ -39,6 +39,7 @@ type Forwarder interface {
type CloseFunc func() error type CloseFunc func() error
type interaction struct { type interaction struct {
config config.Config
channel ssh.Channel channel ssh.Channel
slug slug.Slug slug slug.Slug
forwarder Forwarder forwarder Forwarder
@@ -48,14 +49,14 @@ type interaction struct {
program *tea.Program program *tea.Program
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
mode types.Mode mode types.InteractiveMode
} }
func (i *interaction) SetMode(m types.Mode) { func (i *interaction) SetMode(m types.InteractiveMode) {
i.mode = m i.mode = m
} }
func (i *interaction) Mode() types.Mode { func (i *interaction) Mode() types.InteractiveMode {
return i.mode return i.mode
} }
@@ -75,9 +76,10 @@ func (i *interaction) SetWH(w, h int) {
} }
} }
func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { func New(config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &interaction{ return &interaction{
config: config,
channel: nil, channel: nil,
slug: slug, slug: slug,
forwarder: forwarder, forwarder: forwarder,
@@ -174,14 +176,13 @@ func (m *model) View() string {
} }
func (i *interaction) Start() { func (i *interaction) Start() {
if i.mode == types.HEADLESS { if i.mode == types.InteractiveModeHEADLESS {
return return
} }
lipgloss.SetColorProfile(termenv.TrueColor) lipgloss.SetColorProfile(termenv.TrueColor)
domain := config.Getenv("DOMAIN", "localhost")
protocol := "http" protocol := "http"
if config.Getenv("TLS_ENABLED", "false") == "true" { if i.config.TLSEnabled() {
protocol = "https" protocol = "https"
} }
@@ -209,7 +210,7 @@ func (i *interaction) Start() {
ti.Width = 50 ti.Width = 50
m := &model{ m := &model{
domain: domain, domain: i.config.Domain(),
protocol: protocol, protocol: protocol,
tunnelType: tunnelType, tunnelType: tunnelType,
port: port, port: port,
+1 -1
View File
@@ -41,7 +41,7 @@ type model struct {
} }
func (m *model) getTunnelURL() string { func (m *model) getTunnelURL() string {
if m.tunnelType == types.HTTP { if m.tunnelType == types.TunnelTypeHTTP {
return buildURL(m.protocol, m.interaction.slug.String(), m.domain) return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
} }
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port) return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
+4 -4
View File
@@ -15,7 +15,7 @@ import (
func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd var cmd tea.Cmd
if m.tunnelType != types.HTTP { if m.tunnelType != types.TunnelTypeHTTP {
m.editingSlug = false m.editingSlug = false
m.slugError = "" m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink) return m, tea.Batch(tea.ClearScreen, textinput.Blink)
@@ -30,10 +30,10 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
inputValue := m.slugInput.Value() inputValue := m.slugInput.Value()
if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{ if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{
Id: m.interaction.slug.String(), Id: m.interaction.slug.String(),
Type: types.HTTP, Type: types.TunnelTypeHTTP,
}, types.SessionKey{ }, types.SessionKey{
Id: inputValue, Id: inputValue,
Type: types.HTTP, Type: types.TunnelTypeHTTP,
}); err != nil { }); err != nil {
m.slugError = err.Error() m.slugError = err.Error()
return m, nil return m, nil
@@ -130,7 +130,7 @@ func (m *model) slugView() string {
b.WriteString(titleStyle.Render(title)) b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n") b.WriteString("\n\n")
if m.tunnelType != types.HTTP { if m.tunnelType != types.TunnelTypeHTTP {
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60) warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
warningBoxStyle := lipgloss.NewStyle(). warningBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FFA500")). Foreground(lipgloss.Color("#FFA500")).
+29 -33
View File
@@ -2,8 +2,6 @@ package lifecycle
import ( import (
"errors" "errors"
"io"
"net"
"time" "time"
portUtil "tunnel_pls/internal/port" portUtil "tunnel_pls/internal/port"
@@ -24,20 +22,20 @@ type SessionRegistry interface {
} }
type lifecycle struct { type lifecycle struct {
status types.Status status types.SessionStatus
conn ssh.Conn conn ssh.Conn
channel ssh.Channel channel ssh.Channel
forwarder Forwarder forwarder Forwarder
slug slug.Slug slug slug.Slug
startedAt time.Time startedAt time.Time
sessionRegistry SessionRegistry sessionRegistry SessionRegistry
portRegistry portUtil.Registry portRegistry portUtil.Port
user string user string
} }
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Registry, sessionRegistry SessionRegistry, user string) Lifecycle { func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
return &lifecycle{ return &lifecycle{
status: types.INITIALIZING, status: types.SessionStatusINITIALIZING,
conn: conn, conn: conn,
channel: nil, channel: nil,
forwarder: forwarder, forwarder: forwarder,
@@ -51,16 +49,16 @@ func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUti
type Lifecycle interface { type Lifecycle interface {
Connection() ssh.Conn Connection() ssh.Conn
PortRegistry() portUtil.Registry PortRegistry() portUtil.Port
User() string User() string
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetStatus(status types.Status) SetStatus(status types.SessionStatus)
IsActive() bool IsActive() bool
StartedAt() time.Time StartedAt() time.Time
Close() error Close() error
} }
func (l *lifecycle) PortRegistry() portUtil.Registry { func (l *lifecycle) PortRegistry() portUtil.Port {
return l.portRegistry return l.portRegistry
} }
@@ -74,35 +72,30 @@ func (l *lifecycle) SetChannel(channel ssh.Channel) {
func (l *lifecycle) Connection() ssh.Conn { func (l *lifecycle) Connection() ssh.Conn {
return l.conn return l.conn
} }
func (l *lifecycle) SetStatus(status types.Status) { func (l *lifecycle) SetStatus(status types.SessionStatus) {
l.status = status l.status = status
if status == types.RUNNING && l.startedAt.IsZero() { if status == types.SessionStatusRUNNING && l.startedAt.IsZero() {
l.startedAt = time.Now() l.startedAt = time.Now()
} }
} }
func closeIfNotNil(c interface{ Close() error }) error {
if c != nil {
return c.Close()
}
return nil
}
func (l *lifecycle) Close() error { func (l *lifecycle) Close() error {
var firstErr error var errs []error
tunnelType := l.forwarder.TunnelType() tunnelType := l.forwarder.TunnelType()
if err := l.forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := closeIfNotNil(l.channel); err != nil {
firstErr = err errs = append(errs, err)
} }
if l.channel != nil { if err := closeIfNotNil(l.conn); err != nil {
if err := l.channel.Close(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { errs = append(errs, err)
if firstErr == nil {
firstErr = err
}
}
}
if l.conn != nil {
if err := l.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
if firstErr == nil {
firstErr = err
}
}
} }
clientSlug := l.slug.String() clientSlug := l.slug.String()
@@ -112,17 +105,20 @@ func (l *lifecycle) Close() error {
} }
l.sessionRegistry.Remove(key) l.sessionRegistry.Remove(key)
if tunnelType == types.TCP { if tunnelType == types.TunnelTypeTCP {
if err := l.PortRegistry().SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil {
firstErr = err errs = append(errs, err)
}
if err := l.forwarder.Close(); err != nil {
errs = append(errs, err)
} }
} }
return firstErr return errors.Join(errs...)
} }
func (l *lifecycle) IsActive() bool { func (l *lifecycle) IsActive() bool {
return l.status == types.RUNNING return l.status == types.SessionStatusRUNNING
} }
func (l *lifecycle) StartedAt() time.Time { func (l *lifecycle) StartedAt() time.Time {
+215 -222
View File
@@ -12,6 +12,8 @@ import (
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
portUtil "tunnel_pls/internal/port" portUtil "tunnel_pls/internal/port"
"tunnel_pls/internal/random" "tunnel_pls/internal/random"
"tunnel_pls/internal/registry"
"tunnel_pls/internal/transport"
"tunnel_pls/session/forwarder" "tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction" "tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle" "tunnel_pls/session/lifecycle"
@@ -21,46 +23,40 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type Detail struct {
ForwardingType string `json:"forwarding_type,omitempty"`
Slug string `json:"slug,omitempty"`
UserID string `json:"user_id,omitempty"`
Active bool `json:"active,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
}
type Session interface { type Session interface {
HandleGlobalRequest(ch <-chan *ssh.Request) HandleGlobalRequest(ch <-chan *ssh.Request) error
HandleTCPIPForward(req *ssh.Request) HandleTCPIPForward(req *ssh.Request) error
HandleHTTPForward(req *ssh.Request, port uint16) HandleHTTPForward(req *ssh.Request, port uint16) error
HandleTCPForward(req *ssh.Request, addr string, port uint16) HandleTCPForward(req *ssh.Request, addr string, port uint16) error
Lifecycle() lifecycle.Lifecycle Lifecycle() lifecycle.Lifecycle
Interaction() interaction.Interaction Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder Forwarder() forwarder.Forwarder
Slug() slug.Slug Slug() slug.Slug
Detail() *Detail Detail() *types.Detail
Start() error Start() error
} }
type session struct { type session struct {
config config.Config
initialReq <-chan *ssh.Request initialReq <-chan *ssh.Request
sshChan <-chan ssh.NewChannel sshChan <-chan ssh.NewChannel
lifecycle lifecycle.Lifecycle lifecycle lifecycle.Lifecycle
interaction interaction.Interaction interaction interaction.Interaction
forwarder forwarder.Forwarder forwarder forwarder.Forwarder
slug slug.Slug slug slug.Slug
registry Registry registry registry.Registry
} }
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, portRegistry portUtil.Registry, user string) Session { func New(config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
slugManager := slug.New() slugManager := slug.New()
forwarderManager := forwarder.New(slugManager, conn) forwarderManager := forwarder.New(config, slugManager, conn)
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user) lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close) interactionManager := interaction.New(config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
return &session{ return &session{
config: config,
initialReq: initialReq, initialReq: initialReq,
sshChan: sshChan, sshChan: sshChan,
lifecycle: lifecycleManager, lifecycle: lifecycleManager,
@@ -87,16 +83,17 @@ func (s *session) Slug() slug.Slug {
return s.slug return s.slug
} }
func (s *session) Detail() *Detail { func (s *session) Detail() *types.Detail {
var tunnelType string tunnelTypeMap := map[types.TunnelType]string{
if s.forwarder.TunnelType() == types.HTTP { types.TunnelTypeHTTP: "TunnelTypeHTTP",
tunnelType = "HTTP" types.TunnelTypeTCP: "TunnelTypeTCP",
} else if s.forwarder.TunnelType() == types.TCP {
tunnelType = "TCP"
} else {
tunnelType = "UNKNOWN"
} }
return &Detail{ tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
if !ok {
tunnelType = "TunnelTypeUNKNOWN"
}
return &types.Detail{
ForwardingType: tunnelType, ForwardingType: tunnelType,
Slug: s.slug.String(), Slug: s.slug.String(),
UserID: s.lifecycle.User(), UserID: s.lifecycle.User(),
@@ -106,55 +103,80 @@ func (s *session) Detail() *Detail {
} }
func (s *session) Start() error { func (s *session) Start() error {
var channel ssh.NewChannel if err := s.setupSessionMode(); err != nil {
var ok bool return err
select {
case channel, ok = <-s.sshChan:
if !ok {
log.Println("Forwarding request channel closed")
return nil
}
ch, reqs, err := channel.Accept()
if err != nil {
log.Printf("failed to accept channel: %v", err)
return err
}
go s.HandleGlobalRequest(reqs)
s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch)
s.interaction.SetMode(types.INTERACTIVE)
case <-time.After(500 * time.Millisecond):
s.interaction.SetMode(types.HEADLESS)
} }
tcpipReq := s.waitForTCPIPForward() tcpipReq := s.waitForTCPIPForward()
if tcpipReq == nil { if tcpipReq == nil {
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) return s.handleMissingForwardRequest()
if err != nil {
return err
}
if err = s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
return fmt.Errorf("no forwarding Request")
} }
if (s.interaction.Mode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.User() == "UNAUTHORIZED" { if s.shouldRejectUnauthorized() {
if err := tcpipReq.Reply(false, nil); err != nil { return s.denyForwardingRequest(tcpipReq, nil, nil, fmt.Sprintf("headless forwarding only allowed on node mode"))
log.Printf("cannot reply to tcpip req: %s\n", err)
return err
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
return err
}
return nil
} }
s.HandleTCPIPForward(tcpipReq) if err := s.HandleTCPIPForward(tcpipReq); err != nil {
return err
}
s.interaction.Start() s.interaction.Start()
return s.waitForSessionEnd()
}
func (s *session) setupSessionMode() error {
select {
case channel, ok := <-s.sshChan:
if !ok {
log.Println("Forwarding request channel closed")
return nil
}
return s.setupInteractiveMode(channel)
case <-time.After(500 * time.Millisecond):
s.interaction.SetMode(types.InteractiveModeHEADLESS)
return nil
}
}
func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
ch, reqs, err := channel.Accept()
if err != nil {
log.Printf("failed to accept channel: %v", err)
return err
}
go func() {
err = s.HandleGlobalRequest(reqs)
if err != nil {
log.Printf("global request handler error: %v", err)
}
}()
s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch)
s.interaction.SetMode(types.InteractiveModeINTERACTIVE)
return nil
}
func (s *session) handleMissingForwardRequest() error {
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
if err != nil {
return err
}
if err = s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
return fmt.Errorf("no forwarding Request")
}
func (s *session) shouldRejectUnauthorized() bool {
return s.interaction.Mode() == types.InteractiveModeHEADLESS &&
s.config.Mode() == types.ServerModeSTANDALONE &&
s.lifecycle.User() == "UNAUTHORIZED"
}
func (s *session) waitForSessionEnd() error {
if err := s.lifecycle.Connection().Wait(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { if err := s.lifecycle.Connection().Wait(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("ssh connection closed with error: %v", err) log.Printf("ssh connection closed with error: %v", err)
} }
@@ -187,220 +209,191 @@ func (s *session) waitForTCPIPForward() *ssh.Request {
} }
} }
func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { func (s *session) handleWindowChange(req *ssh.Request) error {
p := req.Payload
if len(p) < 16 {
log.Println("invalid window-change payload")
return req.Reply(false, nil)
}
cols := binary.BigEndian.Uint32(p[0:4])
rows := binary.BigEndian.Uint32(p[4:8])
s.interaction.SetWH(int(cols), int(rows))
return req.Reply(true, nil)
}
func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
for req := range GlobalRequest { for req := range GlobalRequest {
switch req.Type { switch req.Type {
case "shell", "pty-req": case "shell", "pty-req":
err := req.Reply(true, nil) err := req.Reply(true, nil)
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) return err
return
} }
case "window-change": case "window-change":
p := req.Payload if err := s.handleWindowChange(req); err != nil {
if len(p) < 16 { return err
log.Println("invalid window-change payload")
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
return
}
cols := binary.BigEndian.Uint32(p[0:4])
rows := binary.BigEndian.Uint32(p[4:8])
s.interaction.SetWH(int(cols), int(rows))
err := req.Reply(true, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
} }
default: default:
log.Println("Unknown request type:", req.Type) log.Println("Unknown request type:", req.Type)
err := req.Reply(false, nil) err := req.Reply(false, nil)
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) return err
return
} }
} }
} }
return nil
} }
func (s *session) HandleTCPIPForward(req *ssh.Request) { func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, port uint16, err error) {
log.Println("PortRegistry forwarding request detected") address, err = readSSHString(payloadReader)
fail := func(msg string) {
log.Println(msg)
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
return
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
}
reader := bytes.NewReader(req.Payload)
addr, err := readSSHString(reader)
if err != nil { if err != nil {
fail(fmt.Sprintf("Failed to read address from payload: %v", err)) return "", 0, err
return
} }
var rawPortToBind uint32 var rawPortToBind uint32
if err = binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { if err = binary.Read(payloadReader, binary.BigEndian, &rawPortToBind); err != nil {
fail(fmt.Sprintf("Failed to read port from payload: %v", err)) return "", 0, err
return
} }
if rawPortToBind > 65535 { if rawPortToBind > 65535 {
fail(fmt.Sprintf("PortRegistry %d is larger than allowed port of 65535", rawPortToBind)) return "", 0, fmt.Errorf("port is larger than allowed port of 65535")
return
} }
portToBind := uint16(rawPortToBind) port = uint16(rawPortToBind)
if isBlockedPort(portToBind) { if isBlockedPort(port) {
fail(fmt.Sprintf("PortRegistry %d is blocked or restricted", portToBind)) return "", 0, fmt.Errorf("port is block")
return
} }
switch portToBind { if port == 0 {
unassigned, ok := s.lifecycle.PortRegistry().Unassigned()
if !ok {
return "", 0, fmt.Errorf("no available port")
}
return address, unassigned, err
}
return address, port, err
}
func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error {
var errs []error
if key != nil {
s.registry.Remove(*key)
}
if listener != nil {
if err := listener.Close(); err != nil {
errs = append(errs, fmt.Errorf("close listener: %w", err))
}
}
if err := req.Reply(false, nil); err != nil {
errs = append(errs, fmt.Errorf("reply request: %w", err))
}
if err := s.lifecycle.Close(); err != nil {
errs = append(errs, fmt.Errorf("close session: %w", err))
}
errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg))
return errors.Join(errs...)
}
func (s *session) approveForwardingRequest(req *ssh.Request, port uint16) (err error) {
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, uint32(port))
if err != nil {
return err
}
err = req.Reply(true, buf.Bytes())
if err != nil {
return err
}
return nil
}
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error {
err := s.approveForwardingRequest(req, portToBind)
if err != nil {
return err
}
s.forwarder.SetType(tunnelType)
s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(slug)
s.lifecycle.SetStatus(types.SessionStatusRUNNING)
if listener != nil {
s.forwarder.SetListener(listener)
}
return nil
}
func (s *session) HandleTCPIPForward(req *ssh.Request) error {
reader := bytes.NewReader(req.Payload)
address, port, err := s.parseForwardPayload(reader)
if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error()))
}
switch port {
case 80, 443: case 80, 443:
s.HandleHTTPForward(req, portToBind) return s.HandleHTTPForward(req, port)
default: default:
s.HandleTCPForward(req, addr, portToBind) return s.HandleTCPForward(req, address, port)
} }
} }
func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) { func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
fail := func(msg string, key *types.SessionKey) {
log.Println(msg)
if key != nil {
s.registry.Remove(*key)
}
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
}
}
randomString, err := random.GenerateRandomString(20) randomString, err := random.GenerateRandomString(20)
if err != nil { if err != nil {
fail(fmt.Sprintf("Failed to create slug: %s", err), nil) return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
return
} }
key := types.SessionKey{Id: randomString, Type: types.HTTP} key := types.SessionKey{Id: randomString, Type: types.TunnelTypeHTTP}
if !s.registry.Register(key, s) { if !s.registry.Register(key, s) {
fail(fmt.Sprintf("Failed to register client with slug: %s", randomString), nil) return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString))
return
} }
buf := new(bytes.Buffer) err = s.finalizeForwarding(req, portToBind, nil, types.TunnelTypeHTTP, key.Id)
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil { if err != nil {
fail(fmt.Sprintf("Failed to write port to buffer: %v", err), &key) return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err))
return
} }
log.Printf("HTTP forwarding approved on port: %d", portToBind) return nil
err = req.Reply(true, buf.Bytes())
if err != nil {
fail(fmt.Sprintf("Failed to reply to request: %v", err), &key)
return
}
s.forwarder.SetType(types.HTTP)
s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(randomString)
s.lifecycle.SetStatus(types.RUNNING)
} }
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
fail := func(msg string) { if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
log.Println(msg) return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
return
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
} }
cleanup := func(msg string, port uint16, listener net.Listener, key *types.SessionKey) { tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
log.Println(msg) listener, err := tcpServer.Listen()
if key != nil {
s.registry.Remove(*key)
}
if port != 0 {
if setErr := s.lifecycle.PortRegistry().SetPortStatus(port, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
}
if listener != nil {
if closeErr := listener.Close(); closeErr != nil {
log.Printf("Failed to close listener: %v", closeErr)
}
}
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
}
_ = s.lifecycle.Close()
}
if portToBind == 0 {
unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort()
if !ok {
fail("No available port")
return
}
portToBind = unassigned
}
if claimed := s.lifecycle.PortRegistry().ClaimPort(portToBind); !claimed {
fail(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
return
}
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil { if err != nil {
cleanup(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind), portToBind, nil, nil) return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
return
} }
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP} key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
if !s.registry.Register(key, s) { if !s.registry.Register(key, s) {
cleanup(fmt.Sprintf("Failed to register TCP client with id: %s", key.Id), portToBind, listener, nil) return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TunnelTypeTCP client with id: %s", key.Id))
return
} }
buf := new(bytes.Buffer) err = s.finalizeForwarding(req, portToBind, listener, types.TunnelTypeTCP, key.Id)
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil { if err != nil {
cleanup(fmt.Sprintf("Failed to write port to buffer: %v", err), portToBind, listener, &key) return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
return
} }
log.Printf("TCP forwarding approved on port: %d", portToBind) go func() {
err = req.Reply(true, buf.Bytes()) err = tcpServer.Serve(listener)
if err != nil { if err != nil {
cleanup(fmt.Sprintf("Failed to reply to request: %v", err), portToBind, listener, &key) log.Printf("Failed serving tcp server: %s\n", err)
return }
} }()
s.forwarder.SetType(types.TCP) return nil
s.forwarder.SetListener(listener)
s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(key.Id)
s.lifecycle.SetStatus(types.RUNNING)
go s.forwarder.AcceptTCPConnections()
} }
func readSSHString(reader *bytes.Reader) (string, error) { func readSSHString(reader io.Reader) (string, error) {
var length uint32 var length uint32
if err := binary.Read(reader, binary.BigEndian, &length); err != nil { if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
return "", err return "", err
+1
View File
@@ -0,0 +1 @@
sonar.projectKey=tunnel-please
+26 -9
View File
@@ -1,25 +1,34 @@
package types package types
type Status int import "time"
type SessionStatus int
const ( const (
INITIALIZING Status = iota SessionStatusINITIALIZING SessionStatus = iota
RUNNING SessionStatusRUNNING
) )
type Mode int type InteractiveMode int
const ( const (
INTERACTIVE Mode = iota InteractiveModeINTERACTIVE InteractiveMode = iota + 1
HEADLESS InteractiveModeHEADLESS
) )
type TunnelType int type TunnelType int
const ( const (
UNKNOWN TunnelType = iota TunnelTypeUNKNOWN TunnelType = iota
HTTP TunnelTypeHTTP
TCP TunnelTypeTCP
)
type ServerMode int
const (
ServerModeSTANDALONE = iota + 1
ServerModeNODE
) )
type SessionKey struct { type SessionKey struct {
@@ -27,6 +36,14 @@ type SessionKey struct {
Type TunnelType Type TunnelType
} }
type Detail struct {
ForwardingType string `json:"forwarding_type,omitempty"`
Slug string `json:"slug,omitempty"`
UserID string `json:"user_id,omitempty"`
Active bool `json:"active,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
}
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
"Content-Length: 11\r\n" + "Content-Length: 11\r\n" +
"Content-Type: text/plain\r\n\r\n" + "Content-Type: text/plain\r\n\r\n" +