diff --git a/server/http.go b/server/http.go index f51889b..f416b67 100644 --- a/server/http.go +++ b/server/http.go @@ -59,9 +59,33 @@ func (w *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { } var redirectTLS = false +var allowedCors = make(map[string]bool) +var isAllowedAllCors = false + +func init() { + corsList := utils.Getenv("cors_list") + if corsList == "*" { + isAllowedAllCors = true + } else { + for _, allowedOrigin := range strings.Split(corsList, ",") { + allowedCors[allowedOrigin] = true + } + } +} func NewHTTPServer() error { - upgrader.CheckOrigin = func(r *http.Request) bool { return true } + upgrader.CheckOrigin = func(r *http.Request) bool { + if isAllowedAllCors { + return true + } else { + isAllowed, ok := allowedCors[r.Header.Get("Origin")] + if !ok || !isAllowed { + return false + } + return true + } + } + listener, err := net.Listen("tcp", ":80") if err != nil { return errors.New("Error listening: " + err.Error()) @@ -97,16 +121,10 @@ func Handler(conn net.Conn) { host := strings.Split(parseHostFromHeader(headers), ".") if len(host) < 1 { conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - log.Println("Bad Request") conn.Close() return } - if len(host) < 1 { - conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - conn.Close() - return - } slug := host[0] if redirectTLS { @@ -155,7 +173,11 @@ func Handler(conn net.Conn) { sshSession, ok := session.Clients[slug] if !ok { - conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) + conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "\r\n")) conn.Close() return } diff --git a/server/https.go b/server/https.go index 043e74a..fbaf3f9 100644 --- a/server/https.go +++ b/server/https.go @@ -4,6 +4,7 @@ import ( "bufio" "crypto/tls" "errors" + "fmt" "log" "net" "net/http" @@ -58,11 +59,6 @@ func HandlerTLS(conn net.Conn) { return } - if len(host) < 1 { - conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - conn.Close() - return - } slug := host[0] if slug == "ping" { @@ -101,7 +97,11 @@ func HandlerTLS(conn net.Conn) { sshSession, ok := session.Clients[slug] if !ok { - conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) + conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "\r\n")) conn.Close() return } diff --git a/utils/utils.go b/utils/utils.go index a8c3d37..d5d05da 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -1,13 +1,14 @@ package utils import ( - "github.com/joho/godotenv" "log" "math/rand" "os" "strings" "sync" "time" + + "github.com/joho/godotenv" ) type Env struct {