refactor: replace Get/Set patterns with idiomatic Go interfaces
- rename constructors to New - remove Get/Set-style accessors - replace string-based enums with iota-backed types
This commit is contained in:
@@ -30,50 +30,50 @@ func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
||||
return io.CopyBuffer(dst, src, buf)
|
||||
}
|
||||
|
||||
type Forwarder struct {
|
||||
type forwarder struct {
|
||||
listener net.Listener
|
||||
tunnelType types.TunnelType
|
||||
forwardedPort uint16
|
||||
slugManager slug.Manager
|
||||
slug slug.Slug
|
||||
lifecycle Lifecycle
|
||||
}
|
||||
|
||||
func NewForwarder(slugManager slug.Manager) *Forwarder {
|
||||
return &Forwarder{
|
||||
func New(slug slug.Slug) Forwarder {
|
||||
return &forwarder{
|
||||
listener: nil,
|
||||
tunnelType: "",
|
||||
tunnelType: types.UNKNOWN,
|
||||
forwardedPort: 0,
|
||||
slugManager: slugManager,
|
||||
slug: slug,
|
||||
lifecycle: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type Lifecycle interface {
|
||||
GetConnection() ssh.Conn
|
||||
Connection() ssh.Conn
|
||||
}
|
||||
|
||||
type ForwardingController interface {
|
||||
AcceptTCPConnections()
|
||||
type Forwarder interface {
|
||||
SetType(tunnelType types.TunnelType)
|
||||
GetTunnelType() types.TunnelType
|
||||
GetForwardedPort() uint16
|
||||
SetLifecycle(lifecycle Lifecycle)
|
||||
SetForwardedPort(port uint16)
|
||||
SetListener(listener net.Listener)
|
||||
GetListener() net.Listener
|
||||
Close() error
|
||||
Listener() net.Listener
|
||||
TunnelType() types.TunnelType
|
||||
ForwardedPort() uint16
|
||||
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
|
||||
SetLifecycle(lifecycle Lifecycle)
|
||||
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
||||
WriteBadGatewayResponse(dst io.Writer)
|
||||
AcceptTCPConnections()
|
||||
Close() error
|
||||
}
|
||||
|
||||
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
|
||||
func (f *forwarder) SetLifecycle(lifecycle Lifecycle) {
|
||||
f.lifecycle = lifecycle
|
||||
}
|
||||
|
||||
func (f *Forwarder) AcceptTCPConnections() {
|
||||
func (f *forwarder) AcceptTCPConnections() {
|
||||
for {
|
||||
conn, err := f.GetListener().Accept()
|
||||
conn, err := f.Listener().Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
@@ -100,7 +100,7 @@ func (f *Forwarder) AcceptTCPConnections() {
|
||||
resultChan := make(chan channelResult, 1)
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := f.lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||
channel, reqs, err := f.lifecycle.Connection().OpenChannel("forwarded-tcpip", payload)
|
||||
resultChan <- channelResult{channel, reqs, err}
|
||||
}()
|
||||
|
||||
@@ -130,7 +130,7 @@ func (f *Forwarder) AcceptTCPConnections() {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
||||
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
||||
defer func() {
|
||||
_, err := io.Copy(io.Discard, src)
|
||||
if err != nil {
|
||||
@@ -174,31 +174,31 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (f *Forwarder) SetType(tunnelType types.TunnelType) {
|
||||
func (f *forwarder) SetType(tunnelType types.TunnelType) {
|
||||
f.tunnelType = tunnelType
|
||||
}
|
||||
|
||||
func (f *Forwarder) GetTunnelType() types.TunnelType {
|
||||
func (f *forwarder) TunnelType() types.TunnelType {
|
||||
return f.tunnelType
|
||||
}
|
||||
|
||||
func (f *Forwarder) GetForwardedPort() uint16 {
|
||||
func (f *forwarder) ForwardedPort() uint16 {
|
||||
return f.forwardedPort
|
||||
}
|
||||
|
||||
func (f *Forwarder) SetForwardedPort(port uint16) {
|
||||
func (f *forwarder) SetForwardedPort(port uint16) {
|
||||
f.forwardedPort = port
|
||||
}
|
||||
|
||||
func (f *Forwarder) SetListener(listener net.Listener) {
|
||||
func (f *forwarder) SetListener(listener net.Listener) {
|
||||
f.listener = listener
|
||||
}
|
||||
|
||||
func (f *Forwarder) GetListener() net.Listener {
|
||||
func (f *forwarder) Listener() net.Listener {
|
||||
return f.listener
|
||||
}
|
||||
|
||||
func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
|
||||
func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) {
|
||||
_, err := dst.Write(types.BadGatewayResponse)
|
||||
if err != nil {
|
||||
log.Printf("failed to write Bad Gateway response: %v", err)
|
||||
@@ -206,20 +206,20 @@ func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Forwarder) Close() error {
|
||||
if f.GetListener() != nil {
|
||||
func (f *forwarder) Close() error {
|
||||
if f.Listener() != nil {
|
||||
return f.listener.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
||||
func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
host, originPort := parseAddr(origin.String())
|
||||
|
||||
writeSSHString(&buf, "localhost")
|
||||
err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort()))
|
||||
err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort()))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user