diff --git a/server/http.go b/server/http.go index 4886f87..945cfc7 100644 --- a/server/http.go +++ b/server/http.go @@ -59,9 +59,34 @@ func (w *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { } var redirectTLS = false +var allowedCors = make(map[string]bool) +var isAllowedAllCors = false + +func init() { + corsList := utils.Getenv("cors_list") + if corsList == "*" { + isAllowedAllCors = true + } else { + for _, allowedOrigin := range strings.Split(corsList, ",") { + allowedCors[allowedOrigin] = true + } + } + fmt.Println(allowedCors) +} func NewHTTPServer() error { - upgrader.CheckOrigin = func(r *http.Request) bool { return true } + upgrader.CheckOrigin = func(r *http.Request) bool { + if isAllowedAllCors { + return true + } else { + isAllowed, ok := allowedCors[r.Host] + if !ok || !isAllowed { + return false + } + return true + } + } + listener, err := net.Listen("tcp", ":80") if err != nil { return errors.New("Error listening: " + err.Error()) @@ -97,16 +122,10 @@ func Handler(conn net.Conn) { host := strings.Split(parseHostFromHeader(headers), ".") if len(host) < 1 { conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - log.Println("Bad Request") conn.Close() return } - if len(host) < 1 { - conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - conn.Close() - return - } slug := host[0] if redirectTLS { diff --git a/server/https.go b/server/https.go index f7d9827..fbaf3f9 100644 --- a/server/https.go +++ b/server/https.go @@ -59,11 +59,6 @@ func HandlerTLS(conn net.Conn) { return } - if len(host) < 1 { - conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - conn.Close() - return - } slug := host[0] if slug == "ping" {