From 9be24db41049549cc029819b661ce61c7c66b840 Mon Sep 17 00:00:00 2001
From: fatedier <fatedier@gmail.com>
Date: Fri, 15 Mar 2019 16:22:41 +0800
Subject: [PATCH] support multilevel subdomain, fix #1132

---
 tests/ci/auto_test_frpc.ini |   6 +
 tests/ci/normal_test.go     |  15 +++
 tests/mock/http_server.go   |   6 +-
 utils/vhost/http.go         | 235 ------------------------------------
 utils/vhost/newhttp.go      |  22 +++-
 utils/vhost/resource.go     |  14 +++
 utils/vhost/vhost.go        |  23 ++--
 7 files changed, 71 insertions(+), 250 deletions(-)
 delete mode 100644 utils/vhost/http.go

diff --git a/tests/ci/auto_test_frpc.ini b/tests/ci/auto_test_frpc.ini
index 407d679e..28ea5fd5 100644
--- a/tests/ci/auto_test_frpc.ini
+++ b/tests/ci/auto_test_frpc.ini
@@ -127,6 +127,12 @@ custom_domains = test6.frp.com
 host_header_rewrite = test6.frp.com
 header_X-From-Where = frp
 
+[wildcard_http]
+type = http
+local_ip = 127.0.0.1
+local_port = 10704
+custom_domains = *.frp1.com
+
 [subhost01]
 type = http
 local_ip = 127.0.0.1
diff --git a/tests/ci/normal_test.go b/tests/ci/normal_test.go
index 24f5795a..4f976c81 100644
--- a/tests/ci/normal_test.go
+++ b/tests/ci/normal_test.go
@@ -182,6 +182,21 @@ func TestHttp(t *testing.T) {
 		assert.Equal("true", header.Get("X-Header-Set"))
 	}
 
+	// wildcard_http
+	// test.frp1.com match *.frp1.com
+	code, body, _, err = util.SendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", consts.TEST_HTTP_FRP_PORT), "test.frp1.com", nil, "")
+	if assert.NoError(err) {
+		assert.Equal(200, code)
+		assert.Equal(consts.TEST_HTTP_NORMAL_STR, body)
+	}
+
+	// new.test.frp1.com also match *.frp1.com
+	code, body, _, err = util.SendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", consts.TEST_HTTP_FRP_PORT), "new.test.frp1.com", nil, "")
+	if assert.NoError(err) {
+		assert.Equal(200, code)
+		assert.Equal(consts.TEST_HTTP_NORMAL_STR, body)
+	}
+
 	// subhost01
 	code, body, _, err = util.SendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", consts.TEST_HTTP_FRP_PORT), "test01.sub.com", nil, "")
 	if assert.NoError(err) {
diff --git a/tests/mock/http_server.go b/tests/mock/http_server.go
index 37b2b1e6..92c51a6c 100644
--- a/tests/mock/http_server.go
+++ b/tests/mock/http_server.go
@@ -88,8 +88,10 @@ func handleHttp(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	if strings.Contains(r.Host, "127.0.0.1") || strings.Contains(r.Host, "test2.frp.com") ||
-		strings.Contains(r.Host, "test5.frp.com") || strings.Contains(r.Host, "test6.frp.com") {
+	if strings.HasPrefix(r.Host, "127.0.0.1") || strings.HasPrefix(r.Host, "test2.frp.com") ||
+		strings.HasPrefix(r.Host, "test5.frp.com") || strings.HasPrefix(r.Host, "test6.frp.com") ||
+		strings.HasPrefix(r.Host, "test.frp1.com") || strings.HasPrefix(r.Host, "new.test.frp1.com") {
+
 		w.WriteHeader(200)
 		w.Write([]byte(consts.TEST_HTTP_NORMAL_STR))
 	} else if strings.Contains(r.Host, "test3.frp.com") {
diff --git a/utils/vhost/http.go b/utils/vhost/http.go
deleted file mode 100644
index 9fc05bdb..00000000
--- a/utils/vhost/http.go
+++ /dev/null
@@ -1,235 +0,0 @@
-// Copyright 2016 fatedier, fatedier@gmail.com
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package vhost
-
-import (
-	"bufio"
-	"bytes"
-	"encoding/base64"
-	"fmt"
-	"io"
-	"net/http"
-	"net/url"
-	"strings"
-	"time"
-
-	frpNet "github.com/fatedier/frp/utils/net"
-
-	gnet "github.com/fatedier/golib/net"
-	"github.com/fatedier/golib/pool"
-)
-
-type HttpMuxer struct {
-	*VhostMuxer
-}
-
-func GetHttpRequestInfo(c frpNet.Conn) (_ frpNet.Conn, _ map[string]string, err error) {
-	reqInfoMap := make(map[string]string, 0)
-	sc, rd := gnet.NewSharedConn(c)
-
-	request, err := http.ReadRequest(bufio.NewReader(rd))
-	if err != nil {
-		return nil, reqInfoMap, err
-	}
-	// hostName
-	tmpArr := strings.Split(request.Host, ":")
-	reqInfoMap["Host"] = tmpArr[0]
-	reqInfoMap["Path"] = request.URL.Path
-	reqInfoMap["Scheme"] = request.URL.Scheme
-
-	// Authorization
-	authStr := request.Header.Get("Authorization")
-	if authStr != "" {
-		reqInfoMap["Authorization"] = authStr
-	}
-	request.Body.Close()
-	return frpNet.WrapConn(sc), reqInfoMap, nil
-}
-
-func NewHttpMuxer(listener frpNet.Listener, timeout time.Duration) (*HttpMuxer, error) {
-	mux, err := NewVhostMuxer(listener, GetHttpRequestInfo, HttpAuthFunc, ModifyHttpRequest, timeout)
-	return &HttpMuxer{mux}, err
-}
-
-func ModifyHttpRequest(c frpNet.Conn, rewriteHost string) (_ frpNet.Conn, err error) {
-	sc, rd := gnet.NewSharedConn(c)
-	var buff []byte
-	remoteIP := strings.Split(c.RemoteAddr().String(), ":")[0]
-	if buff, err = hostNameRewrite(rd, rewriteHost, remoteIP); err != nil {
-		return nil, err
-	}
-	err = sc.ResetBuf(buff)
-	return frpNet.WrapConn(sc), err
-}
-
-func hostNameRewrite(request io.Reader, rewriteHost string, remoteIP string) (_ []byte, err error) {
-	buf := pool.GetBuf(1024)
-	defer pool.PutBuf(buf)
-
-	var n int
-	n, err = request.Read(buf)
-	if err != nil {
-		return
-	}
-	retBuffer, err := parseRequest(buf[:n], rewriteHost, remoteIP)
-	return retBuffer, err
-}
-
-func parseRequest(org []byte, rewriteHost string, remoteIP string) (ret []byte, err error) {
-	tp := bytes.NewBuffer(org)
-	// First line: GET /index.html HTTP/1.0
-	var b []byte
-	if b, err = tp.ReadBytes('\n'); err != nil {
-		return nil, err
-	}
-	req := new(http.Request)
-	// we invoked ReadRequest in GetHttpHostname before, so we ignore error
-	req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b))
-	rawurl := req.RequestURI
-	// CONNECT www.google.com:443 HTTP/1.1
-	justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
-	if justAuthority {
-		rawurl = "http://" + rawurl
-	}
-	req.URL, _ = url.ParseRequestURI(rawurl)
-	if justAuthority {
-		// Strip the bogus "http://" back off.
-		req.URL.Scheme = ""
-	}
-
-	//  RFC2616: first case
-	//  GET /index.html HTTP/1.1
-	//  Host: www.google.com
-	if req.URL.Host == "" {
-		var changedBuf []byte
-		if rewriteHost != "" {
-			changedBuf, err = changeHostName(tp, rewriteHost)
-		}
-		buf := new(bytes.Buffer)
-		buf.Write(b)
-		buf.WriteString(fmt.Sprintf("X-Forwarded-For: %s\r\n", remoteIP))
-		buf.WriteString(fmt.Sprintf("X-Real-IP: %s\r\n", remoteIP))
-		if len(changedBuf) == 0 {
-			tp.WriteTo(buf)
-		} else {
-			buf.Write(changedBuf)
-		}
-		return buf.Bytes(), err
-	}
-
-	// RFC2616: second case
-	// GET http://www.google.com/index.html HTTP/1.1
-	// Host: doesntmatter
-	// In this case, any Host line is ignored.
-	if rewriteHost != "" {
-		hostPort := strings.Split(req.URL.Host, ":")
-		if len(hostPort) == 1 {
-			req.URL.Host = rewriteHost
-		} else if len(hostPort) == 2 {
-			req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1])
-		}
-	}
-	firstLine := req.Method + " " + req.URL.String() + " " + req.Proto
-	buf := new(bytes.Buffer)
-	buf.WriteString(firstLine)
-	buf.WriteString(fmt.Sprintf("X-Forwarded-For: %s\r\n", remoteIP))
-	buf.WriteString(fmt.Sprintf("X-Real-IP: %s\r\n", remoteIP))
-	tp.WriteTo(buf)
-	return buf.Bytes(), err
-}
-
-// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
-func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
-	s1 := strings.Index(line, " ")
-	s2 := strings.Index(line[s1+1:], " ")
-	if s1 < 0 || s2 < 0 {
-		return
-	}
-	s2 += s1 + 1
-	return line[:s1], line[s1+1 : s2], line[s2+1:], true
-}
-
-func changeHostName(buff *bytes.Buffer, rewriteHost string) (_ []byte, err error) {
-	retBuf := new(bytes.Buffer)
-
-	peek := buff.Bytes()
-	for len(peek) > 0 {
-		i := bytes.IndexByte(peek, '\n')
-		if i < 3 {
-			// Not present (-1) or found within the next few bytes,
-			// implying we're at the end ("\r\n\r\n" or "\n\n")
-			return nil, err
-		}
-		kv := peek[:i]
-		j := bytes.IndexByte(kv, ':')
-		if j < 0 {
-			return nil, fmt.Errorf("malformed MIME header line: " + string(kv))
-		}
-		if strings.Contains(strings.ToLower(string(kv[:j])), "host") {
-			var hostHeader string
-			portPos := bytes.IndexByte(kv[j+1:], ':')
-			if portPos == -1 {
-				hostHeader = fmt.Sprintf("Host: %s\r\n", rewriteHost)
-			} else {
-				hostHeader = fmt.Sprintf("Host: %s:%s\r\n", rewriteHost, kv[j+portPos+2:])
-			}
-			retBuf.WriteString(hostHeader)
-			peek = peek[i+1:]
-			break
-		} else {
-			retBuf.Write(peek[:i])
-			retBuf.WriteByte('\n')
-		}
-
-		peek = peek[i+1:]
-	}
-	retBuf.Write(peek)
-	return retBuf.Bytes(), err
-}
-
-func HttpAuthFunc(c frpNet.Conn, userName, passWord, authorization string) (bAccess bool, err error) {
-	s := strings.SplitN(authorization, " ", 2)
-	if len(s) != 2 {
-		res := noAuthResponse()
-		res.Write(c)
-		return
-	}
-	b, err := base64.StdEncoding.DecodeString(s[1])
-	if err != nil {
-		return
-	}
-	pair := strings.SplitN(string(b), ":", 2)
-	if len(pair) != 2 {
-		return
-	}
-	if pair[0] != userName || pair[1] != passWord {
-		return
-	}
-	return true, nil
-}
-
-func noAuthResponse() *http.Response {
-	header := make(map[string][]string)
-	header["WWW-Authenticate"] = []string{`Basic realm="Restricted"`}
-	res := &http.Response{
-		Status:     "401 Not authorized",
-		StatusCode: 401,
-		Proto:      "HTTP/1.1",
-		ProtoMajor: 1,
-		ProtoMinor: 1,
-		Header:     header,
-	}
-	return res
-}
diff --git a/utils/vhost/newhttp.go b/utils/vhost/newhttp.go
index fef991fa..59a4a0e6 100644
--- a/utils/vhost/newhttp.go
+++ b/utils/vhost/newhttp.go
@@ -18,6 +18,7 @@ import (
 	"bytes"
 	"context"
 	"errors"
+	"fmt"
 	"log"
 	"net"
 	"net/http"
@@ -145,7 +146,7 @@ func (rp *HttpReverseProxy) CreateConnection(domain string, location string) (ne
 			return fn()
 		}
 	}
-	return nil, ErrNoDomain
+	return nil, fmt.Errorf("%v: %s %s", ErrNoDomain, domain, location)
 }
 
 func (rp *HttpReverseProxy) CheckAuth(domain, location, user, passwd string) bool {
@@ -173,11 +174,22 @@ func (rp *HttpReverseProxy) getVhost(domain string, location string) (vr *VhostR
 
 	domainSplit := strings.Split(domain, ".")
 	if len(domainSplit) < 3 {
-		return vr, false
+		return nil, false
+	}
+
+	for {
+		if len(domainSplit) < 3 {
+			return nil, false
+		}
+
+		domainSplit[0] = "*"
+		domain = strings.Join(domainSplit, ".")
+		vr, ok = rp.vhostRouter.Get(domain, location)
+		if ok {
+			return vr, true
+		}
+		domainSplit = domainSplit[1:]
 	}
-	domainSplit[0] = "*"
-	domain = strings.Join(domainSplit, ".")
-	vr, ok = rp.vhostRouter.Get(domain, location)
 	return
 }
 
diff --git a/utils/vhost/resource.go b/utils/vhost/resource.go
index ed8149c6..40cb9523 100644
--- a/utils/vhost/resource.go
+++ b/utils/vhost/resource.go
@@ -61,3 +61,17 @@ func notFoundResponse() *http.Response {
 	}
 	return res
 }
+
+func noAuthResponse() *http.Response {
+	header := make(map[string][]string)
+	header["WWW-Authenticate"] = []string{`Basic realm="Restricted"`}
+	res := &http.Response{
+		Status:     "401 Not authorized",
+		StatusCode: 401,
+		Proto:      "HTTP/1.1",
+		ProtoMajor: 1,
+		ProtoMinor: 1,
+		Header:     header,
+	}
+	return res
+}
diff --git a/utils/vhost/vhost.go b/utils/vhost/vhost.go
index 84e57e95..2e386524 100644
--- a/utils/vhost/vhost.go
+++ b/utils/vhost/vhost.go
@@ -102,17 +102,24 @@ func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) {
 
 	domainSplit := strings.Split(name, ".")
 	if len(domainSplit) < 3 {
-		return l, false
-	}
-	domainSplit[0] = "*"
-	name = strings.Join(domainSplit, ".")
-
-	vr, found = v.registryRouter.Get(name, path)
-	if !found {
 		return
 	}
 
-	return vr.payload.(*Listener), true
+	for {
+		if len(domainSplit) < 3 {
+			return
+		}
+
+		domainSplit[0] = "*"
+		name = strings.Join(domainSplit, ".")
+
+		vr, found = v.registryRouter.Get(name, path)
+		if found {
+			return vr.payload.(*Listener), true
+		}
+		domainSplit = domainSplit[1:]
+	}
+	return
 }
 
 func (v *VhostMuxer) run() {