refactor: improve encapsulation
This commit is contained in:
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.Config)
|
||||
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config)
|
||||
if err != nil {
|
||||
log.Printf("failed to establish SSH connection: %v", err)
|
||||
err := conn.Close()
|
||||
|
||||
@@ -14,21 +14,38 @@ type HeaderManager interface {
|
||||
Finalize() []byte
|
||||
}
|
||||
|
||||
type ResponseHeaderFactory struct {
|
||||
type ResponseHeaderManager interface {
|
||||
Get(key string) string
|
||||
Set(key string, value string)
|
||||
Remove(key string)
|
||||
Finalize() []byte
|
||||
}
|
||||
|
||||
type RequestHeaderManager interface {
|
||||
Get(key string) string
|
||||
Set(key string, value string)
|
||||
Remove(key string)
|
||||
Finalize() []byte
|
||||
GetMethod() string
|
||||
GetPath() string
|
||||
GetVersion() string
|
||||
}
|
||||
|
||||
type responseHeaderFactory struct {
|
||||
startLine []byte
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
type RequestHeaderFactory struct {
|
||||
Method string
|
||||
Path string
|
||||
Version string
|
||||
type requestHeaderFactory struct {
|
||||
method string
|
||||
path string
|
||||
version string
|
||||
startLine []byte
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
func NewRequestHeaderFactory(br *bufio.Reader) (*RequestHeaderFactory, error) {
|
||||
header := &RequestHeaderFactory{
|
||||
func NewRequestHeaderFactory(br *bufio.Reader) (RequestHeaderManager, error) {
|
||||
header := &requestHeaderFactory{
|
||||
headers: make(map[string]string),
|
||||
}
|
||||
|
||||
@@ -44,9 +61,9 @@ func NewRequestHeaderFactory(br *bufio.Reader) (*RequestHeaderFactory, error) {
|
||||
return nil, fmt.Errorf("invalid request line")
|
||||
}
|
||||
|
||||
header.Method = parts[0]
|
||||
header.Path = parts[1]
|
||||
header.Version = parts[2]
|
||||
header.method = parts[0]
|
||||
header.path = parts[1]
|
||||
header.version = parts[2]
|
||||
|
||||
for {
|
||||
line, err := br.ReadString('\n')
|
||||
@@ -69,8 +86,8 @@ func NewRequestHeaderFactory(br *bufio.Reader) (*RequestHeaderFactory, error) {
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func NewResponseHeaderFactory(startLine []byte) *ResponseHeaderFactory {
|
||||
header := &ResponseHeaderFactory{
|
||||
func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager {
|
||||
header := &responseHeaderFactory{
|
||||
startLine: nil,
|
||||
headers: make(map[string]string),
|
||||
}
|
||||
@@ -96,19 +113,19 @@ func NewResponseHeaderFactory(startLine []byte) *ResponseHeaderFactory {
|
||||
return header
|
||||
}
|
||||
|
||||
func (resp *ResponseHeaderFactory) Get(key string) string {
|
||||
func (resp *responseHeaderFactory) Get(key string) string {
|
||||
return resp.headers[key]
|
||||
}
|
||||
|
||||
func (resp *ResponseHeaderFactory) Set(key string, value string) {
|
||||
func (resp *responseHeaderFactory) Set(key string, value string) {
|
||||
resp.headers[key] = value
|
||||
}
|
||||
|
||||
func (resp *ResponseHeaderFactory) Remove(key string) {
|
||||
func (resp *responseHeaderFactory) Remove(key string) {
|
||||
delete(resp.headers, key)
|
||||
}
|
||||
|
||||
func (resp *ResponseHeaderFactory) Finalize() []byte {
|
||||
func (resp *responseHeaderFactory) Finalize() []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
buf.Write(resp.startLine)
|
||||
@@ -125,7 +142,7 @@ func (resp *ResponseHeaderFactory) Finalize() []byte {
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func (req *RequestHeaderFactory) Get(key string) string {
|
||||
func (req *requestHeaderFactory) Get(key string) string {
|
||||
val, ok := req.headers[key]
|
||||
if !ok {
|
||||
return ""
|
||||
@@ -133,15 +150,27 @@ func (req *RequestHeaderFactory) Get(key string) string {
|
||||
return val
|
||||
}
|
||||
|
||||
func (req *RequestHeaderFactory) Set(key string, value string) {
|
||||
func (req *requestHeaderFactory) Set(key string, value string) {
|
||||
req.headers[key] = value
|
||||
}
|
||||
|
||||
func (req *RequestHeaderFactory) Remove(key string) {
|
||||
func (req *requestHeaderFactory) Remove(key string) {
|
||||
delete(req.headers, key)
|
||||
}
|
||||
|
||||
func (req *RequestHeaderFactory) Finalize() []byte {
|
||||
func (req *requestHeaderFactory) GetMethod() string {
|
||||
return req.method
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) GetPath() string {
|
||||
return req.path
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) GetVersion() string {
|
||||
return req.version
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) Finalize() []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
buf.Write(req.startLine)
|
||||
|
||||
@@ -20,25 +20,63 @@ import (
|
||||
type Interaction interface {
|
||||
SendMessage(message string)
|
||||
}
|
||||
type CustomWriter struct {
|
||||
RemoteAddr net.Addr
|
||||
|
||||
type HTTPWriter interface {
|
||||
io.Reader
|
||||
io.Writer
|
||||
SetInteraction(interaction Interaction)
|
||||
AddInteraction(interaction Interaction)
|
||||
GetRemoteAddr() net.Addr
|
||||
GetWriter() io.Writer
|
||||
AddResponseMiddleware(mw ResponseMiddleware)
|
||||
AddRequestStartMiddleware(mw RequestMiddleware)
|
||||
SetRequestHeader(header RequestHeaderManager)
|
||||
GetRequestStartMiddleware() []RequestMiddleware
|
||||
}
|
||||
|
||||
type customWriter struct {
|
||||
remoteAddr net.Addr
|
||||
writer io.Writer
|
||||
reader io.Reader
|
||||
headerBuf []byte
|
||||
buf []byte
|
||||
respHeader *ResponseHeaderFactory
|
||||
reqHeader *RequestHeaderFactory
|
||||
respHeader ResponseHeaderManager
|
||||
reqHeader RequestHeaderManager
|
||||
interaction Interaction
|
||||
respMW []ResponseMiddleware
|
||||
reqStartMW []RequestMiddleware
|
||||
reqEndMW []RequestMiddleware
|
||||
}
|
||||
|
||||
func (cw *CustomWriter) SetInteraction(interaction Interaction) {
|
||||
func (cw *customWriter) SetInteraction(interaction Interaction) {
|
||||
cw.interaction = interaction
|
||||
}
|
||||
|
||||
func (cw *CustomWriter) Read(p []byte) (int, error) {
|
||||
func (cw *customWriter) GetRemoteAddr() net.Addr {
|
||||
return cw.remoteAddr
|
||||
}
|
||||
|
||||
func (cw *customWriter) GetWriter() io.Writer {
|
||||
return cw.writer
|
||||
}
|
||||
|
||||
func (cw *customWriter) AddResponseMiddleware(mw ResponseMiddleware) {
|
||||
cw.respMW = append(cw.respMW, mw)
|
||||
}
|
||||
|
||||
func (cw *customWriter) AddRequestStartMiddleware(mw RequestMiddleware) {
|
||||
cw.reqStartMW = append(cw.reqStartMW, mw)
|
||||
}
|
||||
|
||||
func (cw *customWriter) SetRequestHeader(header RequestHeaderManager) {
|
||||
cw.reqHeader = header
|
||||
}
|
||||
|
||||
func (cw *customWriter) GetRequestStartMiddleware() []RequestMiddleware {
|
||||
return cw.reqStartMW
|
||||
}
|
||||
|
||||
func (cw *customWriter) Read(p []byte) (int, error) {
|
||||
tmp := make([]byte, len(p))
|
||||
read, err := cw.reader.Read(tmp)
|
||||
if read == 0 && err != nil {
|
||||
@@ -95,9 +133,9 @@ func (cw *CustomWriter) Read(p []byte) (int, error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) *CustomWriter {
|
||||
return &CustomWriter{
|
||||
RemoteAddr: remoteAddr,
|
||||
func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter {
|
||||
return &customWriter{
|
||||
remoteAddr: remoteAddr,
|
||||
writer: writer,
|
||||
reader: reader,
|
||||
buf: make([]byte, 0, 4096),
|
||||
@@ -129,7 +167,7 @@ func isHTTPHeader(buf []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (cw *CustomWriter) Write(p []byte) (int, error) {
|
||||
func (cw *customWriter) Write(p []byte) (int, error) {
|
||||
if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" {
|
||||
cw.respHeader = nil
|
||||
}
|
||||
@@ -186,7 +224,7 @@ func (cw *CustomWriter) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (cw *CustomWriter) AddInteraction(interaction Interaction) {
|
||||
func (cw *customWriter) AddInteraction(interaction Interaction) {
|
||||
cw.interaction = interaction
|
||||
}
|
||||
|
||||
@@ -292,13 +330,13 @@ func Handler(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
|
||||
cw.SetInteraction(sshSession.Interaction)
|
||||
cw.SetInteraction(sshSession.GetInteraction())
|
||||
forwardRequest(cw, reqhf, sshSession)
|
||||
return
|
||||
}
|
||||
|
||||
func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) {
|
||||
payload := sshSession.Forwarder.CreateForwardedTCPIPPayload(cw.RemoteAddr)
|
||||
func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession *session.SSHSession) {
|
||||
payload := sshSession.GetForwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
|
||||
|
||||
type channelResult struct {
|
||||
channel ssh.Channel
|
||||
@@ -308,7 +346,7 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
|
||||
resultChan := make(chan channelResult, 1)
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||
channel, reqs, err := sshSession.GetLifecycle().GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||
resultChan <- channelResult{channel, reqs, err}
|
||||
}()
|
||||
|
||||
@@ -319,29 +357,28 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
|
||||
case result := <-resultChan:
|
||||
if result.err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
|
||||
sshSession.Forwarder.WriteBadGatewayResponse(cw.writer)
|
||||
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
|
||||
return
|
||||
}
|
||||
channel = result.channel
|
||||
reqs = result.reqs
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Printf("Timeout opening forwarded-tcpip channel")
|
||||
sshSession.Forwarder.WriteBadGatewayResponse(cw.writer)
|
||||
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
|
||||
return
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
fingerprintMiddleware := NewTunnelFingerprint()
|
||||
forwardedForMiddleware := NewForwardedFor(cw.RemoteAddr)
|
||||
forwardedForMiddleware := NewForwardedFor(cw.GetRemoteAddr())
|
||||
|
||||
cw.respMW = append(cw.respMW, fingerprintMiddleware)
|
||||
cw.reqStartMW = append(cw.reqStartMW, forwardedForMiddleware)
|
||||
cw.reqEndMW = nil
|
||||
cw.reqHeader = initialRequest
|
||||
cw.AddResponseMiddleware(fingerprintMiddleware)
|
||||
cw.AddRequestStartMiddleware(forwardedForMiddleware)
|
||||
cw.SetRequestHeader(initialRequest)
|
||||
|
||||
for _, m := range cw.reqStartMW {
|
||||
if err := m.HandleRequest(cw.reqHeader); err != nil {
|
||||
for _, m := range cw.GetRequestStartMiddleware() {
|
||||
if err := m.HandleRequest(initialRequest); err != nil {
|
||||
log.Printf("Error handling request: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -353,6 +390,6 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
|
||||
return
|
||||
}
|
||||
|
||||
sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr)
|
||||
sshSession.GetForwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func HandlerTLS(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
|
||||
cw.SetInteraction(sshSession.Interaction)
|
||||
cw.SetInteraction(sshSession.GetInteraction())
|
||||
forwardRequest(cw, reqhf, sshSession)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
)
|
||||
|
||||
type RequestMiddleware interface {
|
||||
HandleRequest(header *RequestHeaderFactory) error
|
||||
HandleRequest(header RequestHeaderManager) error
|
||||
}
|
||||
|
||||
type ResponseMiddleware interface {
|
||||
HandleResponse(header *ResponseHeaderFactory, body []byte) error
|
||||
HandleResponse(header ResponseHeaderManager, body []byte) error
|
||||
}
|
||||
|
||||
type TunnelFingerprint struct{}
|
||||
@@ -18,16 +18,11 @@ func NewTunnelFingerprint() *TunnelFingerprint {
|
||||
return &TunnelFingerprint{}
|
||||
}
|
||||
|
||||
func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body []byte) error {
|
||||
func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error {
|
||||
header.Set("Server", "Tunnel Please")
|
||||
return nil
|
||||
}
|
||||
|
||||
type RequestLogger struct {
|
||||
interaction Interaction
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
type ForwardedFor struct {
|
||||
addr net.Addr
|
||||
}
|
||||
@@ -36,7 +31,7 @@ func NewForwardedFor(addr net.Addr) *ForwardedFor {
|
||||
return &ForwardedFor{addr: addr}
|
||||
}
|
||||
|
||||
func (ff *ForwardedFor) HandleRequest(header *RequestHeaderFactory) error {
|
||||
func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error {
|
||||
host, _, err := net.SplitHostPort(ff.addr.String())
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -11,9 +11,21 @@ import (
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
Conn *net.Listener
|
||||
Config *ssh.ServerConfig
|
||||
HttpServer *http.Server
|
||||
conn *net.Listener
|
||||
config *ssh.ServerConfig
|
||||
httpServer *http.Server
|
||||
}
|
||||
|
||||
func (s *Server) GetConn() *net.Listener {
|
||||
return s.conn
|
||||
}
|
||||
|
||||
func (s *Server) GetConfig() *ssh.ServerConfig {
|
||||
return s.config
|
||||
}
|
||||
|
||||
func (s *Server) GetHttpServer() *http.Server {
|
||||
return s.httpServer
|
||||
}
|
||||
|
||||
func NewServer(config *ssh.ServerConfig) *Server {
|
||||
@@ -33,15 +45,15 @@ func NewServer(config *ssh.ServerConfig) *Server {
|
||||
log.Fatalf("failed to start http server: %v", err)
|
||||
}
|
||||
return &Server{
|
||||
Conn: &listener,
|
||||
Config: config,
|
||||
conn: &listener,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Start() {
|
||||
log.Println("SSH server is starting on port 2200...")
|
||||
for {
|
||||
conn, err := (*s.Conn).Accept()
|
||||
conn, err := (*s.conn).Accept()
|
||||
if err != nil {
|
||||
log.Printf("failed to accept connection: %v", err)
|
||||
continue
|
||||
|
||||
@@ -16,7 +16,16 @@ import (
|
||||
"github.com/libdns/cloudflare"
|
||||
)
|
||||
|
||||
type TLSManager struct {
|
||||
type TLSManager interface {
|
||||
userCertsExistAndValid() bool
|
||||
loadUserCerts() error
|
||||
startCertWatcher()
|
||||
initCertMagic() error
|
||||
getTLSConfig() *tls.Config
|
||||
getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
}
|
||||
|
||||
type tlsManager struct {
|
||||
domain string
|
||||
certPath string
|
||||
keyPath string
|
||||
@@ -30,7 +39,7 @@ type TLSManager struct {
|
||||
useCertMagic bool
|
||||
}
|
||||
|
||||
var tlsManager *TLSManager
|
||||
var globalTLSManager TLSManager
|
||||
var tlsManagerOnce sync.Once
|
||||
|
||||
func NewTLSConfig(domain string) (*tls.Config, error) {
|
||||
@@ -41,7 +50,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
|
||||
keyPath := "certs/tls/privkey.pem"
|
||||
storagePath := "certs/tls/certmagic"
|
||||
|
||||
tm := &TLSManager{
|
||||
tm := &tlsManager{
|
||||
domain: domain,
|
||||
certPath: certPath,
|
||||
keyPath: keyPath,
|
||||
@@ -72,14 +81,14 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
|
||||
tm.useCertMagic = true
|
||||
}
|
||||
|
||||
tlsManager = tm
|
||||
globalTLSManager = tm
|
||||
})
|
||||
|
||||
if initErr != nil {
|
||||
return nil, initErr
|
||||
}
|
||||
|
||||
return tlsManager.getTLSConfig(), nil
|
||||
return globalTLSManager.getTLSConfig(), nil
|
||||
}
|
||||
|
||||
func isACMEConfigComplete() bool {
|
||||
@@ -87,7 +96,7 @@ func isACMEConfigComplete() bool {
|
||||
return cfAPIToken != ""
|
||||
}
|
||||
|
||||
func (tm *TLSManager) userCertsExistAndValid() bool {
|
||||
func (tm *tlsManager) userCertsExistAndValid() bool {
|
||||
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
|
||||
log.Printf("Certificate file not found: %s", tm.certPath)
|
||||
return false
|
||||
@@ -158,7 +167,7 @@ func ValidateCertDomains(certPath, domain string) bool {
|
||||
return hasBase && hasWildcard
|
||||
}
|
||||
|
||||
func (tm *TLSManager) loadUserCerts() error {
|
||||
func (tm *tlsManager) loadUserCerts() error {
|
||||
cert, err := tls.LoadX509KeyPair(tm.certPath, tm.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -172,7 +181,7 @@ func (tm *TLSManager) loadUserCerts() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *TLSManager) startCertWatcher() {
|
||||
func (tm *tlsManager) startCertWatcher() {
|
||||
go func() {
|
||||
var lastCertMod, lastKeyMod time.Time
|
||||
|
||||
@@ -227,7 +236,7 @@ func (tm *TLSManager) startCertWatcher() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (tm *TLSManager) initCertMagic() error {
|
||||
func (tm *tlsManager) initCertMagic() error {
|
||||
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
||||
}
|
||||
@@ -289,14 +298,14 @@ func (tm *TLSManager) initCertMagic() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *TLSManager) getTLSConfig() *tls.Config {
|
||||
func (tm *tlsManager) getTLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
GetCertificate: tm.getCertificate,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TLSManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if tm.useCertMagic {
|
||||
return tm.magic.GetCertificate(hello)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user