test(transport): add unit tests for transport behavior using Testify
SonarQube Scan / SonarQube Trigger (push) Successful in 1m51s
SonarQube Scan / SonarQube Trigger (push) Successful in 1m51s
This commit is contained in:
@@ -19,6 +19,8 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var openChannelTimeout = 5 * time.Second
|
||||
|
||||
type httpHandler struct {
|
||||
domain string
|
||||
sessionRegistry registry.Registry
|
||||
@@ -52,7 +54,7 @@ func (hh *httpHandler) badRequest(conn net.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
|
||||
func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
|
||||
defer hh.closeConnection(conn)
|
||||
|
||||
dstReader := bufio.NewReader(conn)
|
||||
@@ -69,7 +71,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
|
||||
}
|
||||
|
||||
if hh.shouldRedirectToTLS(isTLS) {
|
||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain))
|
||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.domain))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -77,7 +79,10 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
|
||||
return
|
||||
}
|
||||
|
||||
sshSession, err := hh.getSession(slug)
|
||||
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
|
||||
Id: slug,
|
||||
Type: types.TunnelTypeHTTP,
|
||||
})
|
||||
if err != nil {
|
||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
|
||||
return
|
||||
@@ -102,7 +107,7 @@ func (hh *httpHandler) closeConnection(conn net.Conn) {
|
||||
|
||||
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
|
||||
host := strings.Split(reqhf.Value("Host"), ".")
|
||||
if len(host) < 1 {
|
||||
if len(host) <= 1 {
|
||||
return "", errors.New("invalid host")
|
||||
}
|
||||
return host[0], nil
|
||||
@@ -128,21 +133,11 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
|
||||
))
|
||||
if err != nil {
|
||||
log.Println("Failed to write 200 OK:", err)
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
|
||||
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
|
||||
Id: slug,
|
||||
Type: types.TunnelTypeHTTP,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sshSession, nil
|
||||
}
|
||||
|
||||
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
|
||||
channel, err := hh.openForwardedChannel(hw, sshSession)
|
||||
if err != nil {
|
||||
@@ -180,11 +175,7 @@ func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
|
||||
select {
|
||||
case resultChan <- channelResult{channel, reqs, err}:
|
||||
default:
|
||||
hh.cleanupUnusedChannel(channel, reqs)
|
||||
}
|
||||
resultChan <- channelResult{channel, reqs, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
@@ -194,7 +185,11 @@ func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.
|
||||
}
|
||||
go ssh.DiscardRequests(result.reqs)
|
||||
return result.channel, nil
|
||||
case <-time.After(5 * time.Second):
|
||||
case <-time.After(openChannelTimeout):
|
||||
go func() {
|
||||
result := <-resultChan
|
||||
hh.cleanupUnusedChannel(result.channel, result.reqs)
|
||||
}()
|
||||
return nil, errors.New("timeout opening forwarded-tcpip channel")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user