mirror of https://github.com/fatedier/frp.git
Merge pull request #746 from fatedier/mux
http port and https port can be same with frps bind_portpull/750/head
commit
178efd67f1
@ -0,0 +1,210 @@
|
|||||||
|
package mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/utils/errors"
|
||||||
|
frpNet "github.com/fatedier/frp/utils/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultTimeout is the default length of time to wait for bytes we need.
|
||||||
|
DefaultTimeout = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type Mux struct {
|
||||||
|
ln net.Listener
|
||||||
|
|
||||||
|
defaultLn *listener
|
||||||
|
lns []*listener
|
||||||
|
maxNeedBytesNum uint32
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMux() (mux *Mux) {
|
||||||
|
mux = &Mux{
|
||||||
|
lns: make([]*listener, 0),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mux *Mux) Listen(priority int, needBytesNum uint32, fn MatchFunc) net.Listener {
|
||||||
|
ln := &listener{
|
||||||
|
c: make(chan net.Conn),
|
||||||
|
mux: mux,
|
||||||
|
needBytesNum: needBytesNum,
|
||||||
|
matchFn: fn,
|
||||||
|
}
|
||||||
|
|
||||||
|
mux.mu.Lock()
|
||||||
|
defer mux.mu.Unlock()
|
||||||
|
if needBytesNum > mux.maxNeedBytesNum {
|
||||||
|
mux.maxNeedBytesNum = needBytesNum
|
||||||
|
}
|
||||||
|
|
||||||
|
newlns := append(mux.copyLns(), ln)
|
||||||
|
sort.Slice(newlns, func(i, j int) bool {
|
||||||
|
return newlns[i].needBytesNum < newlns[j].needBytesNum
|
||||||
|
})
|
||||||
|
mux.lns = newlns
|
||||||
|
return ln
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mux *Mux) ListenHttp(priority int) net.Listener {
|
||||||
|
return mux.Listen(priority, HttpNeedBytesNum, HttpMatchFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mux *Mux) ListenHttps(priority int) net.Listener {
|
||||||
|
return mux.Listen(priority, HttpsNeedBytesNum, HttpsMatchFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mux *Mux) DefaultListener() net.Listener {
|
||||||
|
mux.mu.Lock()
|
||||||
|
defer mux.mu.Unlock()
|
||||||
|
if mux.defaultLn == nil {
|
||||||
|
mux.defaultLn = &listener{
|
||||||
|
c: make(chan net.Conn),
|
||||||
|
mux: mux,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mux.defaultLn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mux *Mux) release(ln *listener) bool {
|
||||||
|
result := false
|
||||||
|
mux.mu.Lock()
|
||||||
|
defer mux.mu.Unlock()
|
||||||
|
lns := mux.copyLns()
|
||||||
|
|
||||||
|
for i, l := range lns {
|
||||||
|
if l == ln {
|
||||||
|
lns = append(lns[:i], lns[i+1:]...)
|
||||||
|
result = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mux.lns = lns
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mux *Mux) copyLns() []*listener {
|
||||||
|
lns := make([]*listener, 0, len(mux.lns))
|
||||||
|
for _, l := range mux.lns {
|
||||||
|
lns = append(lns, l)
|
||||||
|
}
|
||||||
|
return lns
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve handles connections from ln and multiplexes then across registered listeners.
|
||||||
|
func (mux *Mux) Serve(ln net.Listener) error {
|
||||||
|
mux.mu.Lock()
|
||||||
|
mux.ln = ln
|
||||||
|
mux.mu.Unlock()
|
||||||
|
for {
|
||||||
|
// Wait for the next connection.
|
||||||
|
// If it returns a temporary error then simply retry.
|
||||||
|
// If it returns any other error then exit immediately.
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err, ok := err.(interface {
|
||||||
|
Temporary() bool
|
||||||
|
}); ok && err.Temporary() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
go mux.handleConn(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mux *Mux) handleConn(conn net.Conn) {
|
||||||
|
mux.mu.RLock()
|
||||||
|
maxNeedBytesNum := mux.maxNeedBytesNum
|
||||||
|
lns := mux.lns
|
||||||
|
defaultLn := mux.defaultLn
|
||||||
|
mux.mu.RUnlock()
|
||||||
|
|
||||||
|
shareConn, rd := frpNet.NewShareConnSize(frpNet.WrapConn(conn), int(maxNeedBytesNum))
|
||||||
|
data := make([]byte, maxNeedBytesNum)
|
||||||
|
|
||||||
|
conn.SetReadDeadline(time.Now().Add(DefaultTimeout))
|
||||||
|
_, err := io.ReadFull(rd, data)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.SetReadDeadline(time.Time{})
|
||||||
|
|
||||||
|
for _, ln := range lns {
|
||||||
|
if match := ln.matchFn(data); match {
|
||||||
|
err = errors.PanicToError(func() {
|
||||||
|
ln.c <- shareConn
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No match listeners
|
||||||
|
if defaultLn != nil {
|
||||||
|
err = errors.PanicToError(func() {
|
||||||
|
defaultLn.c <- shareConn
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// No listeners for this connection, close it.
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type listener struct {
|
||||||
|
mux *Mux
|
||||||
|
|
||||||
|
needBytesNum uint32
|
||||||
|
matchFn MatchFunc
|
||||||
|
|
||||||
|
c chan net.Conn
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept waits for and returns the next connection to the listener.
|
||||||
|
func (ln *listener) Accept() (net.Conn, error) {
|
||||||
|
conn, ok := <-ln.c
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("network connection closed")
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close removes this listener from the parent mux and closes the channel.
|
||||||
|
func (ln *listener) Close() error {
|
||||||
|
if ok := ln.mux.release(ln); ok {
|
||||||
|
// Close done to signal to any RLock holders to release their lock.
|
||||||
|
close(ln.c)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ln *listener) Addr() net.Addr {
|
||||||
|
if ln.mux == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ln.mux.mu.RLock()
|
||||||
|
defer ln.mux.mu.RUnlock()
|
||||||
|
if ln.mux.ln == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ln.mux.ln.Addr()
|
||||||
|
}
|
@ -0,0 +1,95 @@
|
|||||||
|
package mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func runHttpSvr(ln net.Listener) *httptest.Server {
|
||||||
|
svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("http service"))
|
||||||
|
}))
|
||||||
|
svr.Listener = ln
|
||||||
|
svr.Start()
|
||||||
|
return svr
|
||||||
|
}
|
||||||
|
|
||||||
|
func runHttpsSvr(ln net.Listener) *httptest.Server {
|
||||||
|
svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("https service"))
|
||||||
|
}))
|
||||||
|
svr.Listener = ln
|
||||||
|
svr.StartTLS()
|
||||||
|
return svr
|
||||||
|
}
|
||||||
|
|
||||||
|
func runEchoSvr(ln net.Listener) {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rd := bufio.NewReader(conn)
|
||||||
|
data, err := rd.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.Write([]byte(data))
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMux(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||||
|
assert.NoError(err)
|
||||||
|
|
||||||
|
mux := NewMux()
|
||||||
|
httpLn := mux.ListenHttp(0)
|
||||||
|
httpsLn := mux.ListenHttps(0)
|
||||||
|
defaultLn := mux.DefaultListener()
|
||||||
|
go mux.Serve(ln)
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
httpSvr := runHttpSvr(httpLn)
|
||||||
|
defer httpSvr.Close()
|
||||||
|
httpsSvr := runHttpsSvr(httpsLn)
|
||||||
|
defer httpsSvr.Close()
|
||||||
|
runEchoSvr(defaultLn)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
// test http service
|
||||||
|
resp, err := http.Get(httpSvr.URL)
|
||||||
|
assert.NoError(err)
|
||||||
|
data, err := ioutil.ReadAll(resp.Body)
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal("http service", string(data))
|
||||||
|
|
||||||
|
// test https service
|
||||||
|
client := httpsSvr.Client()
|
||||||
|
resp, err = client.Get(httpsSvr.URL)
|
||||||
|
assert.NoError(err)
|
||||||
|
data, err = ioutil.ReadAll(resp.Body)
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal("https service", string(data))
|
||||||
|
|
||||||
|
// test echo service
|
||||||
|
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
|
assert.NoError(err)
|
||||||
|
_, err = conn.Write([]byte("test echo\n"))
|
||||||
|
assert.NoError(err)
|
||||||
|
data = make([]byte, 1024)
|
||||||
|
n, err := conn.Read(data)
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal("test echo\n", string(data[:n]))
|
||||||
|
}
|
@ -0,0 +1,55 @@
|
|||||||
|
package mux
|
||||||
|
|
||||||
|
type MatchFunc func(data []byte) (match bool)
|
||||||
|
|
||||||
|
var (
|
||||||
|
HttpsNeedBytesNum uint32 = 1
|
||||||
|
HttpNeedBytesNum uint32 = 3
|
||||||
|
YamuxNeedBytesNum uint32 = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
var HttpsMatchFunc MatchFunc = func(data []byte) bool {
|
||||||
|
if len(data) < int(HttpsNeedBytesNum) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if data[0] == 0x16 {
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// From https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
|
||||||
|
var httpHeadBytes = map[string]struct{}{
|
||||||
|
"GET": struct{}{},
|
||||||
|
"HEA": struct{}{},
|
||||||
|
"POS": struct{}{},
|
||||||
|
"PUT": struct{}{},
|
||||||
|
"DEL": struct{}{},
|
||||||
|
"CON": struct{}{},
|
||||||
|
"OPT": struct{}{},
|
||||||
|
"TRA": struct{}{},
|
||||||
|
"PAT": struct{}{},
|
||||||
|
}
|
||||||
|
|
||||||
|
var HttpMatchFunc MatchFunc = func(data []byte) bool {
|
||||||
|
if len(data) < int(HttpNeedBytesNum) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := httpHeadBytes[string(data[:3])]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// From https://github.com/hashicorp/yamux/blob/master/spec.md
|
||||||
|
var YamuxMatchFunc MatchFunc = func(data []byte) bool {
|
||||||
|
if len(data) < int(YamuxNeedBytesNum) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if data[0] == 0 && data[1] >= 0x0 && data[1] <= 0x3 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
Loading…
Reference in New Issue