mirror of https://github.com/fatedier/frp.git
sshTunnelGateway refactor (#3784)
parent
8b432e179d
commit
d5b41f1e14
@ -0,0 +1,223 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
libdial "github.com/fatedier/golib/net/dial"
|
||||
fmux "github.com/hashicorp/yamux"
|
||||
quic "github.com/quic-go/quic-go"
|
||||
"github.com/samber/lo"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
utilnet "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
)
|
||||
|
||||
// Connector is a interface for establishing connections to the server.
|
||||
type Connector interface {
|
||||
Open() error
|
||||
Connect() (net.Conn, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// defaultConnectorImpl is the default implementation of Connector for normal frpc.
|
||||
type defaultConnectorImpl struct {
|
||||
ctx context.Context
|
||||
cfg *v1.ClientCommonConfig
|
||||
|
||||
muxSession *fmux.Session
|
||||
quicConn quic.Connection
|
||||
}
|
||||
|
||||
func NewConnector(ctx context.Context, cfg *v1.ClientCommonConfig) Connector {
|
||||
return &defaultConnectorImpl{
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Open opens a underlying connection to the server.
|
||||
// The underlying connection is either a TCP connection or a QUIC connection.
|
||||
// After the underlying connection is established, you can call Connect() to get a stream.
|
||||
// If TCPMux isn't enabled, the underlying connection is nil, you will get a new real TCP connection every time you call Connect().
|
||||
func (c *defaultConnectorImpl) Open() error {
|
||||
xl := xlog.FromContextSafe(c.ctx)
|
||||
|
||||
// special for quic
|
||||
if strings.EqualFold(c.cfg.Transport.Protocol, "quic") {
|
||||
var tlsConfig *tls.Config
|
||||
var err error
|
||||
sn := c.cfg.Transport.TLS.ServerName
|
||||
if sn == "" {
|
||||
sn = c.cfg.ServerAddr
|
||||
}
|
||||
if lo.FromPtr(c.cfg.Transport.TLS.Enable) {
|
||||
tlsConfig, err = transport.NewClientTLSConfig(
|
||||
c.cfg.Transport.TLS.CertFile,
|
||||
c.cfg.Transport.TLS.KeyFile,
|
||||
c.cfg.Transport.TLS.TrustedCaFile,
|
||||
sn)
|
||||
} else {
|
||||
tlsConfig, err = transport.NewClientTLSConfig("", "", "", sn)
|
||||
}
|
||||
if err != nil {
|
||||
xl.Warn("fail to build tls configuration, err: %v", err)
|
||||
return err
|
||||
}
|
||||
tlsConfig.NextProtos = []string{"frp"}
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
c.ctx,
|
||||
net.JoinHostPort(c.cfg.ServerAddr, strconv.Itoa(c.cfg.ServerPort)),
|
||||
tlsConfig, &quic.Config{
|
||||
MaxIdleTimeout: time.Duration(c.cfg.Transport.QUIC.MaxIdleTimeout) * time.Second,
|
||||
MaxIncomingStreams: int64(c.cfg.Transport.QUIC.MaxIncomingStreams),
|
||||
KeepAlivePeriod: time.Duration(c.cfg.Transport.QUIC.KeepalivePeriod) * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.quicConn = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
if !lo.FromPtr(c.cfg.Transport.TCPMux) {
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, err := c.realConnect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmuxCfg := fmux.DefaultConfig()
|
||||
fmuxCfg.KeepAliveInterval = time.Duration(c.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second
|
||||
fmuxCfg.LogOutput = io.Discard
|
||||
fmuxCfg.MaxStreamWindowSize = 6 * 1024 * 1024
|
||||
session, err := fmux.Client(conn, fmuxCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.muxSession = session
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect returns a stream from the underlying connection, or a new TCP connection if TCPMux isn't enabled.
|
||||
func (c *defaultConnectorImpl) Connect() (net.Conn, error) {
|
||||
if c.quicConn != nil {
|
||||
stream, err := c.quicConn.OpenStreamSync(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return utilnet.QuicStreamToNetConn(stream, c.quicConn), nil
|
||||
} else if c.muxSession != nil {
|
||||
stream, err := c.muxSession.OpenStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
return c.realConnect()
|
||||
}
|
||||
|
||||
func (c *defaultConnectorImpl) realConnect() (net.Conn, error) {
|
||||
xl := xlog.FromContextSafe(c.ctx)
|
||||
var tlsConfig *tls.Config
|
||||
var err error
|
||||
tlsEnable := lo.FromPtr(c.cfg.Transport.TLS.Enable)
|
||||
if c.cfg.Transport.Protocol == "wss" {
|
||||
tlsEnable = true
|
||||
}
|
||||
if tlsEnable {
|
||||
sn := c.cfg.Transport.TLS.ServerName
|
||||
if sn == "" {
|
||||
sn = c.cfg.ServerAddr
|
||||
}
|
||||
|
||||
tlsConfig, err = transport.NewClientTLSConfig(
|
||||
c.cfg.Transport.TLS.CertFile,
|
||||
c.cfg.Transport.TLS.KeyFile,
|
||||
c.cfg.Transport.TLS.TrustedCaFile,
|
||||
sn)
|
||||
if err != nil {
|
||||
xl.Warn("fail to build tls configuration, err: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
proxyType, addr, auth, err := libdial.ParseProxyURL(c.cfg.Transport.ProxyURL)
|
||||
if err != nil {
|
||||
xl.Error("fail to parse proxy url")
|
||||
return nil, err
|
||||
}
|
||||
dialOptions := []libdial.DialOption{}
|
||||
protocol := c.cfg.Transport.Protocol
|
||||
switch protocol {
|
||||
case "websocket":
|
||||
protocol = "tcp"
|
||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, "")}))
|
||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
|
||||
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
|
||||
}))
|
||||
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
|
||||
case "wss":
|
||||
protocol = "tcp"
|
||||
dialOptions = append(dialOptions, libdial.WithTLSConfigAndPriority(100, tlsConfig))
|
||||
// Make sure that if it is wss, the websocket hook is executed after the tls hook.
|
||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, tlsConfig.ServerName), Priority: 110}))
|
||||
default:
|
||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
|
||||
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
|
||||
}))
|
||||
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
|
||||
}
|
||||
|
||||
if c.cfg.Transport.ConnectServerLocalIP != "" {
|
||||
dialOptions = append(dialOptions, libdial.WithLocalAddr(c.cfg.Transport.ConnectServerLocalIP))
|
||||
}
|
||||
dialOptions = append(dialOptions,
|
||||
libdial.WithProtocol(protocol),
|
||||
libdial.WithTimeout(time.Duration(c.cfg.Transport.DialServerTimeout)*time.Second),
|
||||
libdial.WithKeepAlive(time.Duration(c.cfg.Transport.DialServerKeepAlive)*time.Second),
|
||||
libdial.WithProxy(proxyType, addr),
|
||||
libdial.WithProxyAuth(auth),
|
||||
)
|
||||
conn, err := libdial.DialContext(
|
||||
c.ctx,
|
||||
net.JoinHostPort(c.cfg.ServerAddr, strconv.Itoa(c.cfg.ServerPort)),
|
||||
dialOptions...,
|
||||
)
|
||||
return conn, err
|
||||
}
|
||||
|
||||
func (c *defaultConnectorImpl) Close() error {
|
||||
if c.quicConn != nil {
|
||||
_ = c.quicConn.CloseWithError(0, "")
|
||||
}
|
||||
if c.muxSession != nil {
|
||||
_ = c.muxSession.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
@ -1,110 +0,0 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/fatedier/frp/pkg/config/types"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
type PortsRangeSliceFlag struct {
|
||||
V *[]types.PortsRange
|
||||
}
|
||||
|
||||
func (f *PortsRangeSliceFlag) String() string {
|
||||
if f.V == nil {
|
||||
return ""
|
||||
}
|
||||
return types.PortsRangeSlice(*f.V).String()
|
||||
}
|
||||
|
||||
func (f *PortsRangeSliceFlag) Set(s string) error {
|
||||
slice, err := types.NewPortsRangeSliceFromString(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*f.V = slice
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *PortsRangeSliceFlag) Type() string {
|
||||
return "string"
|
||||
}
|
||||
|
||||
type BoolFuncFlag struct {
|
||||
TrueFunc func()
|
||||
FalseFunc func()
|
||||
|
||||
v bool
|
||||
}
|
||||
|
||||
func (f *BoolFuncFlag) String() string {
|
||||
return strconv.FormatBool(f.v)
|
||||
}
|
||||
|
||||
func (f *BoolFuncFlag) Set(s string) error {
|
||||
f.v = strconv.FormatBool(f.v) == "true"
|
||||
|
||||
if !f.v {
|
||||
if f.FalseFunc != nil {
|
||||
f.FalseFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if f.TrueFunc != nil {
|
||||
f.TrueFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *BoolFuncFlag) Type() string {
|
||||
return "bool"
|
||||
}
|
||||
|
||||
func RegisterServerConfigFlags(cmd *cobra.Command, c *v1.ServerConfig) {
|
||||
cmd.PersistentFlags().StringVarP(&c.BindAddr, "bind_addr", "", "0.0.0.0", "bind address")
|
||||
cmd.PersistentFlags().IntVarP(&c.BindPort, "bind_port", "p", 7000, "bind port")
|
||||
cmd.PersistentFlags().IntVarP(&c.KCPBindPort, "kcp_bind_port", "", 0, "kcp bind udp port")
|
||||
cmd.PersistentFlags().StringVarP(&c.ProxyBindAddr, "proxy_bind_addr", "", "0.0.0.0", "proxy bind address")
|
||||
cmd.PersistentFlags().IntVarP(&c.VhostHTTPPort, "vhost_http_port", "", 0, "vhost http port")
|
||||
cmd.PersistentFlags().IntVarP(&c.VhostHTTPSPort, "vhost_https_port", "", 0, "vhost https port")
|
||||
cmd.PersistentFlags().Int64VarP(&c.VhostHTTPTimeout, "vhost_http_timeout", "", 60, "vhost http response header timeout")
|
||||
cmd.PersistentFlags().StringVarP(&c.WebServer.Addr, "dashboard_addr", "", "0.0.0.0", "dashboard address")
|
||||
cmd.PersistentFlags().IntVarP(&c.WebServer.Port, "dashboard_port", "", 0, "dashboard port")
|
||||
cmd.PersistentFlags().StringVarP(&c.WebServer.User, "dashboard_user", "", "admin", "dashboard user")
|
||||
cmd.PersistentFlags().StringVarP(&c.WebServer.Password, "dashboard_pwd", "", "admin", "dashboard password")
|
||||
cmd.PersistentFlags().BoolVarP(&c.EnablePrometheus, "enable_prometheus", "", false, "enable prometheus dashboard")
|
||||
cmd.PersistentFlags().StringVarP(&c.Log.To, "log_file", "", "console", "log file")
|
||||
cmd.PersistentFlags().StringVarP(&c.Log.Level, "log_level", "", "info", "log level")
|
||||
cmd.PersistentFlags().Int64VarP(&c.Log.MaxDays, "log_max_days", "", 3, "log max days")
|
||||
cmd.PersistentFlags().BoolVarP(&c.Log.DisablePrintColor, "disable_log_color", "", false, "disable log color in console")
|
||||
cmd.PersistentFlags().StringVarP(&c.Auth.Token, "token", "t", "", "auth token")
|
||||
cmd.PersistentFlags().StringVarP(&c.SubDomainHost, "subdomain_host", "", "", "subdomain host")
|
||||
cmd.PersistentFlags().VarP(&PortsRangeSliceFlag{V: &c.AllowPorts}, "allow_ports", "", "allow ports")
|
||||
cmd.PersistentFlags().Int64VarP(&c.MaxPortsPerClient, "max_ports_per_client", "", 0, "max ports per client")
|
||||
cmd.PersistentFlags().BoolVarP(&c.Transport.TLS.Force, "tls_only", "", false, "frps tls only")
|
||||
|
||||
webServerTLS := v1.TLSConfig{}
|
||||
cmd.PersistentFlags().StringVarP(&webServerTLS.CertFile, "dashboard_tls_cert_file", "", "", "dashboard tls cert file")
|
||||
cmd.PersistentFlags().StringVarP(&webServerTLS.KeyFile, "dashboard_tls_key_file", "", "", "dashboard tls key file")
|
||||
cmd.PersistentFlags().VarP(&BoolFuncFlag{
|
||||
TrueFunc: func() { c.WebServer.TLS = &webServerTLS },
|
||||
}, "dashboard_tls_mode", "", "if enable dashboard tls mode")
|
||||
}
|
@ -1,72 +0,0 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
// custom define
|
||||
SSHClientLoginUserPrefix = "_frpc_ssh_client_"
|
||||
)
|
||||
|
||||
// encodePrivateKeyToPEM encodes Private Key from RSA to PEM format
|
||||
func GeneratePrivateKey() ([]byte, error) {
|
||||
privateKey, err := generatePrivateKey()
|
||||
if err != nil {
|
||||
return nil, errors.New("gen private key error")
|
||||
}
|
||||
|
||||
privBlock := pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Headers: nil,
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(&privBlock), nil
|
||||
}
|
||||
|
||||
// generatePrivateKey creates a RSA Private Key of specified byte size
|
||||
func generatePrivateKey() (*rsa.PrivateKey, error) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = privateKey.Validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return privateKey, nil
|
||||
}
|
||||
|
||||
func LoadSSHPublicKeyFilesInDir(dirPath string) (map[string]ssh.PublicKey, error) {
|
||||
fileMap := make(map[string]ssh.PublicKey)
|
||||
files, err := os.ReadDir(dirPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
filePath := filepath.Join(dirPath, file.Name())
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedAuthorizedKey, _, _, _, err := ssh.ParseAuthorizedKey(content)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fileMap[ssh.FingerprintSHA256(parsedAuthorizedKey)] = parsedAuthorizedKey
|
||||
}
|
||||
|
||||
return fileMap, nil
|
||||
}
|
@ -0,0 +1,149 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
utilnet "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
type Gateway struct {
|
||||
bindPort int
|
||||
ln net.Listener
|
||||
|
||||
serverPeerListener *utilnet.InternalListener
|
||||
|
||||
sshConfig *ssh.ServerConfig
|
||||
}
|
||||
|
||||
func NewGateway(
|
||||
cfg v1.SSHTunnelGateway, bindAddr string,
|
||||
serverPeerListener *utilnet.InternalListener,
|
||||
) (*Gateway, error) {
|
||||
sshConfig := &ssh.ServerConfig{}
|
||||
|
||||
// privateKey
|
||||
var (
|
||||
privateKeyBytes []byte
|
||||
err error
|
||||
)
|
||||
if cfg.PrivateKeyFile != "" {
|
||||
privateKeyBytes, err = os.ReadFile(cfg.PrivateKeyFile)
|
||||
} else {
|
||||
if cfg.AutoGenPrivateKeyPath != "" {
|
||||
privateKeyBytes, _ = os.ReadFile(cfg.AutoGenPrivateKeyPath)
|
||||
}
|
||||
if len(privateKeyBytes) == 0 {
|
||||
privateKeyBytes, err = transport.NewRandomPrivateKey()
|
||||
if err == nil && cfg.AutoGenPrivateKeyPath != "" {
|
||||
err = os.WriteFile(cfg.AutoGenPrivateKeyPath, privateKeyBytes, 0o600)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
privateKey, err := ssh.ParsePrivateKey(privateKeyBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sshConfig.AddHostKey(privateKey)
|
||||
|
||||
sshConfig.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
if cfg.AuthorizedKeysFile == "" {
|
||||
return &ssh.Permissions{
|
||||
Extensions: map[string]string{
|
||||
"user": "",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
authorizedKeysMap, err := loadAuthorizedKeysFromFile(cfg.AuthorizedKeysFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("internal error")
|
||||
}
|
||||
|
||||
user, ok := authorizedKeysMap[string(key.Marshal())]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown public key for remoteAddr %q", conn.RemoteAddr())
|
||||
}
|
||||
return &ssh.Permissions{
|
||||
Extensions: map[string]string{
|
||||
"user": user,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(bindAddr, strconv.Itoa(cfg.BindPort)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Gateway{
|
||||
bindPort: cfg.BindPort,
|
||||
ln: ln,
|
||||
serverPeerListener: serverPeerListener,
|
||||
sshConfig: sshConfig,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g *Gateway) Run() {
|
||||
for {
|
||||
conn, err := g.ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go g.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) handleConn(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
ts, err := NewTunnelServer(conn, g.sshConfig, g.serverPeerListener)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := ts.Run(); err != nil {
|
||||
log.Error("ssh tunnel server run error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func loadAuthorizedKeysFromFile(path string) (map[string]string, error) {
|
||||
authorizedKeysMap := make(map[string]string) // value is username
|
||||
authorizedKeysBytes, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for len(authorizedKeysBytes) > 0 {
|
||||
pubKey, comment, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authorizedKeysMap[string(pubKey.Marshal())] = strings.TrimSpace(comment)
|
||||
authorizedKeysBytes = rest
|
||||
}
|
||||
return authorizedKeysMap, nil
|
||||
}
|
@ -0,0 +1,279 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
libio "github.com/fatedier/golib/io"
|
||||
"github.com/samber/lo"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/fatedier/frp/pkg/config"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
utilnet "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/pkg/virtual"
|
||||
)
|
||||
|
||||
const (
|
||||
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
|
||||
ChannelTypeServerOpenChannel = "forwarded-tcpip"
|
||||
RequestTypeForward = "tcpip-forward"
|
||||
)
|
||||
|
||||
type tcpipForward struct {
|
||||
Host string
|
||||
Port uint32
|
||||
}
|
||||
|
||||
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
|
||||
type forwardedTCPPayload struct {
|
||||
Addr string
|
||||
Port uint32
|
||||
|
||||
// can be default empty value but do not delete it
|
||||
// because ssh protocol shoule be reserved
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}
|
||||
|
||||
type TunnelServer struct {
|
||||
underlyingConn net.Conn
|
||||
sshConn *ssh.ServerConn
|
||||
sc *ssh.ServerConfig
|
||||
|
||||
vc *virtual.Client
|
||||
serverPeerListener *utilnet.InternalListener
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
func NewTunnelServer(conn net.Conn, sc *ssh.ServerConfig, serverPeerListener *utilnet.InternalListener) (*TunnelServer, error) {
|
||||
s := &TunnelServer{
|
||||
underlyingConn: conn,
|
||||
sc: sc,
|
||||
serverPeerListener: serverPeerListener,
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *TunnelServer) Run() error {
|
||||
sshConn, channels, requests, err := ssh.NewServerConn(s.underlyingConn, s.sc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.sshConn = sshConn
|
||||
|
||||
addr, extraPayload, err := s.waitForwardAddrAndExtraPayload(channels, requests, 3*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientCfg, pc, err := s.parseClientAndProxyConfigurer(addr, extraPayload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User)
|
||||
pc.Complete(clientCfg.User)
|
||||
|
||||
s.vc = virtual.NewClient(clientCfg)
|
||||
// join workConn and ssh channel
|
||||
s.vc.SetInWorkConnCallback(func(base *v1.ProxyBaseConfig, workConn net.Conn, m *msg.StartWorkConn) bool {
|
||||
c, err := s.openConn(addr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
libio.Join(c, workConn)
|
||||
return false
|
||||
})
|
||||
// transfer connection from virtual client to server peer listener
|
||||
go func() {
|
||||
l := s.vc.PeerListener()
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = s.serverPeerListener.PutConn(conn)
|
||||
}
|
||||
}()
|
||||
xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100})
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
go func() {
|
||||
_ = s.vc.Run(ctx)
|
||||
}()
|
||||
|
||||
s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc})
|
||||
|
||||
_ = sshConn.Wait()
|
||||
_ = sshConn.Close()
|
||||
s.vc.Close()
|
||||
close(s.doneCh)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TunnelServer) waitForwardAddrAndExtraPayload(
|
||||
channels <-chan ssh.NewChannel,
|
||||
requests <-chan *ssh.Request,
|
||||
timeout time.Duration,
|
||||
) (*tcpipForward, string, error) {
|
||||
addrCh := make(chan *tcpipForward, 1)
|
||||
extraPayloadCh := make(chan string, 1)
|
||||
|
||||
// get forward address
|
||||
go func() {
|
||||
addrGot := false
|
||||
for req := range requests {
|
||||
switch req.Type {
|
||||
case RequestTypeForward:
|
||||
if !addrGot {
|
||||
payload := tcpipForward{}
|
||||
if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
|
||||
return
|
||||
}
|
||||
addrGot = true
|
||||
addrCh <- &payload
|
||||
}
|
||||
default:
|
||||
if req.WantReply {
|
||||
_ = req.Reply(true, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// get extra payload
|
||||
go func() {
|
||||
for newChannel := range channels {
|
||||
// extraPayload will send to extraPayloadCh
|
||||
go s.handleNewChannel(newChannel, extraPayloadCh)
|
||||
}
|
||||
}()
|
||||
|
||||
var (
|
||||
addr *tcpipForward
|
||||
extraPayload string
|
||||
)
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case v := <-addrCh:
|
||||
addr = v
|
||||
case extra := <-extraPayloadCh:
|
||||
extraPayload = extra
|
||||
case <-timer.C:
|
||||
return nil, "", fmt.Errorf("get addr and extra payload timeout")
|
||||
}
|
||||
if addr != nil && extraPayload != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
return addr, extraPayload, nil
|
||||
}
|
||||
|
||||
func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPayload string) (*v1.ClientCommonConfig, v1.ProxyConfigurer, error) {
|
||||
cmd := &cobra.Command{}
|
||||
args := strings.Split(extraPayload, " ")
|
||||
if len(args) < 1 {
|
||||
return nil, nil, fmt.Errorf("invalid extra payload")
|
||||
}
|
||||
proxyType := strings.TrimSpace(args[0])
|
||||
supportTypes := []string{"tcp", "http", "https", "tcpmux", "stcp"}
|
||||
if !lo.Contains(supportTypes, proxyType) {
|
||||
return nil, nil, fmt.Errorf("invalid proxy type: %s, support types: %v", proxyType, supportTypes)
|
||||
}
|
||||
pc := v1.NewProxyConfigurerByType(v1.ProxyType(proxyType))
|
||||
if pc == nil {
|
||||
return nil, nil, fmt.Errorf("new proxy configurer error")
|
||||
}
|
||||
config.RegisterProxyFlags(cmd, pc)
|
||||
|
||||
clientCfg := v1.ClientCommonConfig{}
|
||||
config.RegisterClientCommonConfigFlags(cmd, &clientCfg)
|
||||
|
||||
if err := cmd.ParseFlags(args); err != nil {
|
||||
return nil, nil, fmt.Errorf("parse flags from ssh client error: %v", err)
|
||||
}
|
||||
return &clientCfg, pc, nil
|
||||
}
|
||||
|
||||
func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh chan string) {
|
||||
ch, reqs, err := channel.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go s.keepAlive(ch)
|
||||
|
||||
for req := range reqs {
|
||||
if req.Type != "exec" {
|
||||
continue
|
||||
}
|
||||
if len(req.Payload) <= 4 {
|
||||
continue
|
||||
}
|
||||
end := 4 + binary.BigEndian.Uint32(req.Payload[:4])
|
||||
if len(req.Payload) < int(end) {
|
||||
continue
|
||||
}
|
||||
extraPayload := string(req.Payload[4:end])
|
||||
select {
|
||||
case extraPayloadCh <- extraPayload:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TunnelServer) keepAlive(ch ssh.Channel) {
|
||||
tk := time.NewTicker(time.Second * 30)
|
||||
defer tk.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-tk.C:
|
||||
_, err := ch.SendRequest("heartbeat", false, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case <-s.doneCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) {
|
||||
payload := forwardedTCPPayload{
|
||||
Addr: addr.Host,
|
||||
Port: addr.Port,
|
||||
}
|
||||
channel, reqs, err := s.sshConn.OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(&payload))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open ssh channel error: %v", err)
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
conn := utilnet.WrapReadWriteCloserToConn(channel, s.underlyingConn)
|
||||
return conn, nil
|
||||
}
|
@ -1,497 +0,0 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
gerror "github.com/fatedier/golib/errors"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
)
|
||||
|
||||
const (
|
||||
// ssh protocol define
|
||||
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
|
||||
ChannelTypeServerOpenChannel = "forwarded-tcpip"
|
||||
RequestTypeForward = "tcpip-forward"
|
||||
|
||||
// golang ssh package define.
|
||||
// https://pkg.go.dev/golang.org/x/crypto/ssh
|
||||
RequestTypeHeartbeat = "keepalive@openssh.com"
|
||||
)
|
||||
|
||||
// 当 proxy 失败会返回该错误
|
||||
type VProxyError struct{}
|
||||
|
||||
// ssh protocol define
|
||||
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
|
||||
// parse ssh client cmds input
|
||||
type forwardedTCPPayload struct {
|
||||
Addr string
|
||||
Port uint32
|
||||
|
||||
// can be default empty value but do not delete it
|
||||
// because ssh protocol shoule be reserved
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}
|
||||
|
||||
// custom define
|
||||
// parse ssh client cmds input
|
||||
type CmdPayload struct {
|
||||
Address string
|
||||
Port uint32
|
||||
}
|
||||
|
||||
// custom define
|
||||
// with frp control cmds
|
||||
type ExtraPayload struct {
|
||||
Type string
|
||||
|
||||
// TODO port can be set by extra message and priority to ssh raw cmd
|
||||
Address string
|
||||
Port uint32
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
tcpConn net.Conn
|
||||
cfg *ssh.ServerConfig
|
||||
|
||||
sshConn *ssh.ServerConn
|
||||
gChannel <-chan ssh.NewChannel
|
||||
gReq <-chan *ssh.Request
|
||||
|
||||
addrPayloadCh chan CmdPayload
|
||||
extraPayloadCh chan ExtraPayload
|
||||
|
||||
proxyPayloadCh chan v1.ProxyConfigurer
|
||||
replyCh chan interface{}
|
||||
|
||||
closeCh chan struct{}
|
||||
exit int32
|
||||
}
|
||||
|
||||
func NewSSHService(
|
||||
tcpConn net.Conn,
|
||||
cfg *ssh.ServerConfig,
|
||||
proxyPayloadCh chan v1.ProxyConfigurer,
|
||||
replyCh chan interface{},
|
||||
) (ss *Service, err error) {
|
||||
ss = &Service{
|
||||
tcpConn: tcpConn,
|
||||
cfg: cfg,
|
||||
|
||||
addrPayloadCh: make(chan CmdPayload),
|
||||
extraPayloadCh: make(chan ExtraPayload),
|
||||
|
||||
proxyPayloadCh: proxyPayloadCh,
|
||||
replyCh: replyCh,
|
||||
|
||||
closeCh: make(chan struct{}),
|
||||
exit: 0,
|
||||
}
|
||||
|
||||
ss.sshConn, ss.gChannel, ss.gReq, err = ssh.NewServerConn(tcpConn, cfg)
|
||||
if err != nil {
|
||||
log.Error("ssh handshake error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info("ssh connection success")
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
func (ss *Service) Run() {
|
||||
go ss.loopGenerateProxy()
|
||||
go ss.loopParseCmdPayload()
|
||||
go ss.loopParseExtraPayload()
|
||||
go ss.loopReply()
|
||||
}
|
||||
|
||||
func (ss *Service) Exit() <-chan struct{} {
|
||||
return ss.closeCh
|
||||
}
|
||||
|
||||
func (ss *Service) Close() {
|
||||
if atomic.LoadInt32(&ss.exit) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ss.closeCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
close(ss.closeCh)
|
||||
close(ss.addrPayloadCh)
|
||||
close(ss.extraPayloadCh)
|
||||
|
||||
_ = ss.sshConn.Wait()
|
||||
|
||||
ss.sshConn.Close()
|
||||
ss.tcpConn.Close()
|
||||
|
||||
atomic.StoreInt32(&ss.exit, 1)
|
||||
|
||||
log.Info("ssh service close")
|
||||
}
|
||||
|
||||
func (ss *Service) loopParseCmdPayload() {
|
||||
for {
|
||||
select {
|
||||
case req, ok := <-ss.gReq:
|
||||
if !ok {
|
||||
log.Info("global request is close")
|
||||
ss.Close()
|
||||
return
|
||||
}
|
||||
|
||||
switch req.Type {
|
||||
case RequestTypeForward:
|
||||
var addrPayload CmdPayload
|
||||
if err := ssh.Unmarshal(req.Payload, &addrPayload); err != nil {
|
||||
log.Error("ssh unmarshal error: %v", err)
|
||||
return
|
||||
}
|
||||
_ = gerror.PanicToError(func() {
|
||||
ss.addrPayloadCh <- addrPayload
|
||||
})
|
||||
default:
|
||||
if req.Type == RequestTypeHeartbeat {
|
||||
log.Debug("ssh heartbeat data")
|
||||
} else {
|
||||
log.Info("default req, data: %v", req)
|
||||
}
|
||||
}
|
||||
if req.WantReply {
|
||||
err := req.Reply(true, nil)
|
||||
if err != nil {
|
||||
log.Error("reply to ssh client error: %v", err)
|
||||
}
|
||||
}
|
||||
case <-ss.closeCh:
|
||||
log.Info("loop parse cmd payload close")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ss *Service) loopSendHeartbeat(ch ssh.Channel) {
|
||||
tk := time.NewTicker(time.Second * 60)
|
||||
defer tk.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-tk.C:
|
||||
ok, err := ch.SendRequest("heartbeat", false, nil)
|
||||
if err != nil {
|
||||
log.Error("channel send req error: %v", err)
|
||||
if err == io.EOF {
|
||||
ss.Close()
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Debug("heartbeat send success, ok: %v", ok)
|
||||
case <-ss.closeCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ss *Service) loopParseExtraPayload() {
|
||||
log.Info("loop parse extra payload start")
|
||||
|
||||
for newChannel := range ss.gChannel {
|
||||
ch, req, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
log.Error("channel accept error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go ss.loopSendHeartbeat(ch)
|
||||
|
||||
go func(req <-chan *ssh.Request) {
|
||||
for r := range req {
|
||||
if len(r.Payload) <= 4 {
|
||||
log.Info("r.payload is less than 4")
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(string(r.Payload), "tcp") && !strings.Contains(string(r.Payload), "http") {
|
||||
log.Info("ssh protocol exchange data")
|
||||
continue
|
||||
}
|
||||
|
||||
// [4byte data_len|data]
|
||||
end := 4 + binary.BigEndian.Uint32(r.Payload[:4])
|
||||
if end > uint32(len(r.Payload)) {
|
||||
end = uint32(len(r.Payload))
|
||||
}
|
||||
p := string(r.Payload[4:end])
|
||||
|
||||
msg, err := parseSSHExtraMessage(p)
|
||||
if err != nil {
|
||||
log.Error("parse ssh extra message error: %v, payload: %v", err, r.Payload)
|
||||
continue
|
||||
}
|
||||
_ = gerror.PanicToError(func() {
|
||||
ss.extraPayloadCh <- msg
|
||||
})
|
||||
return
|
||||
}
|
||||
}(req)
|
||||
}
|
||||
}
|
||||
|
||||
func (ss *Service) SSHConn() *ssh.ServerConn {
|
||||
return ss.sshConn
|
||||
}
|
||||
|
||||
func (ss *Service) TCPConn() net.Conn {
|
||||
return ss.tcpConn
|
||||
}
|
||||
|
||||
func (ss *Service) loopReply() {
|
||||
for {
|
||||
select {
|
||||
case <-ss.closeCh:
|
||||
log.Info("loop reply close")
|
||||
return
|
||||
case req := <-ss.replyCh:
|
||||
switch req.(type) {
|
||||
case *VProxyError:
|
||||
log.Error("run frp proxy error, close ssh service")
|
||||
ss.Close()
|
||||
default:
|
||||
// TODO
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ss *Service) loopGenerateProxy() {
|
||||
log.Info("loop generate proxy start")
|
||||
|
||||
for {
|
||||
if atomic.LoadInt32(&ss.exit) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
wg := new(sync.WaitGroup)
|
||||
wg.Add(2)
|
||||
|
||||
var p1 CmdPayload
|
||||
var p2 ExtraPayload
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-ss.closeCh:
|
||||
return
|
||||
case p1 = <-ss.addrPayloadCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-ss.closeCh:
|
||||
return
|
||||
case p2 = <-ss.extraPayloadCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if atomic.LoadInt32(&ss.exit) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
switch p2.Type {
|
||||
case "http":
|
||||
case "tcp":
|
||||
ss.proxyPayloadCh <- &v1.TCPProxyConfig{
|
||||
ProxyBaseConfig: v1.ProxyBaseConfig{
|
||||
Name: fmt.Sprintf("ssh-proxy-%v-%v", ss.tcpConn.RemoteAddr().String(), time.Now().UnixNano()),
|
||||
Type: p2.Type,
|
||||
|
||||
ProxyBackend: v1.ProxyBackend{
|
||||
LocalIP: p1.Address,
|
||||
},
|
||||
},
|
||||
RemotePort: int(p1.Port),
|
||||
}
|
||||
default:
|
||||
log.Warn("invalid frp proxy type: %v", p2.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseSSHExtraMessage(s string) (p ExtraPayload, err error) {
|
||||
sn := len(s)
|
||||
|
||||
log.Info("parse ssh extra message: %v", s)
|
||||
|
||||
ss := strings.Fields(s)
|
||||
if len(ss) == 0 {
|
||||
if sn != 0 {
|
||||
ss = append(ss, s)
|
||||
} else {
|
||||
return p, fmt.Errorf("invalid ssh input, args: %v", ss)
|
||||
}
|
||||
}
|
||||
|
||||
for i, v := range ss {
|
||||
ss[i] = strings.TrimSpace(v)
|
||||
}
|
||||
|
||||
if ss[0] != "tcp" && ss[0] != "http" {
|
||||
return p, fmt.Errorf("only support tcp/http now")
|
||||
}
|
||||
|
||||
switch ss[0] {
|
||||
case "tcp":
|
||||
tcpCmd, err := ParseTCPCommand(ss)
|
||||
if err != nil {
|
||||
return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err)
|
||||
}
|
||||
|
||||
port, _ := strconv.Atoi(tcpCmd.Port)
|
||||
|
||||
p = ExtraPayload{
|
||||
Type: "tcp",
|
||||
Address: tcpCmd.Address,
|
||||
Port: uint32(port),
|
||||
}
|
||||
case "http":
|
||||
httpCmd, err := ParseHTTPCommand(ss)
|
||||
if err != nil {
|
||||
return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err)
|
||||
}
|
||||
|
||||
_ = httpCmd
|
||||
|
||||
p = ExtraPayload{
|
||||
Type: "http",
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
type HTTPCommand struct {
|
||||
Domain string
|
||||
BasicAuthUser string
|
||||
BasicAuthPass string
|
||||
}
|
||||
|
||||
func ParseHTTPCommand(params []string) (*HTTPCommand, error) {
|
||||
if len(params) < 2 {
|
||||
return nil, errors.New("invalid HTTP command")
|
||||
}
|
||||
|
||||
var (
|
||||
basicAuth string
|
||||
domainURL string
|
||||
basicAuthUser string
|
||||
basicAuthPass string
|
||||
)
|
||||
|
||||
fs := flag.NewFlagSet("http", flag.ContinueOnError)
|
||||
fs.StringVar(&basicAuth, "basic-auth", "", "")
|
||||
fs.StringVar(&domainURL, "domain", "", "")
|
||||
|
||||
fs.SetOutput(&nullWriter{}) // Disables usage output
|
||||
|
||||
err := fs.Parse(params[2:])
|
||||
if err != nil {
|
||||
if !errors.Is(err, flag.ErrHelp) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if basicAuth != "" {
|
||||
authParts := strings.SplitN(basicAuth, ":", 2)
|
||||
basicAuthUser = authParts[0]
|
||||
if len(authParts) > 1 {
|
||||
basicAuthPass = authParts[1]
|
||||
}
|
||||
}
|
||||
|
||||
httpCmd := &HTTPCommand{
|
||||
Domain: domainURL,
|
||||
BasicAuthUser: basicAuthUser,
|
||||
BasicAuthPass: basicAuthPass,
|
||||
}
|
||||
return httpCmd, nil
|
||||
}
|
||||
|
||||
type TCPCommand struct {
|
||||
Address string
|
||||
Port string
|
||||
}
|
||||
|
||||
func ParseTCPCommand(params []string) (*TCPCommand, error) {
|
||||
if len(params) == 0 || params[0] != "tcp" {
|
||||
return nil, errors.New("invalid TCP command")
|
||||
}
|
||||
|
||||
if len(params) == 1 {
|
||||
return &TCPCommand{}, nil
|
||||
}
|
||||
|
||||
var (
|
||||
address string
|
||||
port string
|
||||
)
|
||||
|
||||
fs := flag.NewFlagSet("tcp", flag.ContinueOnError)
|
||||
fs.StringVar(&address, "address", "", "The IP address to listen on")
|
||||
fs.StringVar(&port, "port", "", "The port to listen on")
|
||||
fs.SetOutput(&nullWriter{}) // Disables usage output
|
||||
|
||||
args := params[1:]
|
||||
err := fs.Parse(args)
|
||||
if err != nil {
|
||||
if !errors.Is(err, flag.ErrHelp) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
parsedAddr, err := net.ResolveIPAddr("ip", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := net.LookupPort("tcp", port); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tcpCmd := &TCPCommand{
|
||||
Address: parsedAddr.String(),
|
||||
Port: port,
|
||||
}
|
||||
return tcpCmd, nil
|
||||
}
|
||||
|
||||
type nullWriter struct{}
|
||||
|
||||
func (w *nullWriter) Write(p []byte) (n int, err error) { return len(p), nil }
|
@ -1,185 +0,0 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/fatedier/frp/pkg/config"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
plugin "github.com/fatedier/frp/pkg/plugin/server"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
frp_net "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/server/controller"
|
||||
"github.com/fatedier/frp/server/proxy"
|
||||
)
|
||||
|
||||
// VirtualService is a client VirtualService run in frps
|
||||
type VirtualService struct {
|
||||
clientCfg v1.ClientCommonConfig
|
||||
pxyCfg v1.ProxyConfigurer
|
||||
serverCfg v1.ServerConfig
|
||||
|
||||
sshSvc *Service
|
||||
|
||||
// uniq id got from frps, attach it in loginMsg
|
||||
runID string
|
||||
loginMsg *msg.Login
|
||||
|
||||
// All resource managers and controllers
|
||||
rc *controller.ResourceController
|
||||
|
||||
exit uint32 // 0 means not exit
|
||||
// SSHService context
|
||||
ctx context.Context
|
||||
// call cancel to stop SSHService
|
||||
cancel context.CancelFunc
|
||||
|
||||
replyCh chan interface{}
|
||||
pxy proxy.Proxy
|
||||
}
|
||||
|
||||
func NewVirtualService(
|
||||
ctx context.Context,
|
||||
clientCfg v1.ClientCommonConfig,
|
||||
serverCfg v1.ServerConfig,
|
||||
logMsg msg.Login,
|
||||
rc *controller.ResourceController,
|
||||
pxyCfg v1.ProxyConfigurer,
|
||||
sshSvc *Service,
|
||||
replyCh chan interface{},
|
||||
) (svr *VirtualService, err error) {
|
||||
svr = &VirtualService{
|
||||
clientCfg: clientCfg,
|
||||
serverCfg: serverCfg,
|
||||
rc: rc,
|
||||
|
||||
loginMsg: &logMsg,
|
||||
|
||||
sshSvc: sshSvc,
|
||||
pxyCfg: pxyCfg,
|
||||
|
||||
ctx: ctx,
|
||||
exit: 0,
|
||||
|
||||
replyCh: replyCh,
|
||||
}
|
||||
|
||||
svr.runID, err = util.RandID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go svr.loopCheck()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (svr *VirtualService) Run(ctx context.Context) (err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
svr.ctx = xlog.NewContext(ctx, xlog.New())
|
||||
svr.cancel = cancel
|
||||
|
||||
remoteAddr, err := svr.RegisterProxy(&msg.NewProxy{
|
||||
ProxyName: svr.pxyCfg.(*v1.TCPProxyConfig).Name,
|
||||
ProxyType: svr.pxyCfg.(*v1.TCPProxyConfig).Type,
|
||||
RemotePort: svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("run a reverse proxy on port: %v", remoteAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svr *VirtualService) Close() {
|
||||
svr.GracefulClose(time.Duration(0))
|
||||
}
|
||||
|
||||
func (svr *VirtualService) GracefulClose(d time.Duration) {
|
||||
atomic.StoreUint32(&svr.exit, 1)
|
||||
svr.pxy.Close()
|
||||
|
||||
if svr.cancel != nil {
|
||||
svr.cancel()
|
||||
}
|
||||
|
||||
svr.replyCh <- &VProxyError{}
|
||||
}
|
||||
|
||||
func (svr *VirtualService) loopCheck() {
|
||||
<-svr.sshSvc.Exit()
|
||||
svr.pxy.Close()
|
||||
log.Info("virtual client service close")
|
||||
}
|
||||
|
||||
func (svr *VirtualService) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {
|
||||
var pxyConf v1.ProxyConfigurer
|
||||
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, &svr.serverCfg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// User info
|
||||
userInfo := plugin.UserInfo{
|
||||
User: svr.loginMsg.User,
|
||||
Metas: svr.loginMsg.Metas,
|
||||
RunID: svr.runID,
|
||||
}
|
||||
|
||||
svr.pxy, err = proxy.NewProxy(svr.ctx, &proxy.Options{
|
||||
LoginMsg: svr.loginMsg,
|
||||
UserInfo: userInfo,
|
||||
Configurer: pxyConf,
|
||||
ResourceController: svr.rc,
|
||||
|
||||
GetWorkConnFn: svr.GetWorkConn,
|
||||
PoolCount: 10,
|
||||
|
||||
ServerCfg: &svr.serverCfg,
|
||||
})
|
||||
if err != nil {
|
||||
return remoteAddr, err
|
||||
}
|
||||
|
||||
remoteAddr, err = svr.pxy.Run()
|
||||
if err != nil {
|
||||
log.Warn("proxy run error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.Warn("proxy close")
|
||||
svr.pxy.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (svr *VirtualService) GetWorkConn() (workConn net.Conn, err error) {
|
||||
// tell ssh client open a new stream for work
|
||||
payload := forwardedTCPPayload{
|
||||
Addr: svr.serverCfg.BindAddr, // TODO refine
|
||||
Port: uint32(svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort),
|
||||
}
|
||||
|
||||
channel, reqs, err := svr.sshSvc.SSHConn().OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(payload))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open ssh channel error: %v", err)
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
workConn = frp_net.WrapReadWriteCloserToConn(channel, svr.sshSvc.tcpConn)
|
||||
return workConn, nil
|
||||
}
|
@ -0,0 +1,92 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package virtual
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/fatedier/frp/client"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
utilnet "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
l *utilnet.InternalListener
|
||||
svr *client.Service
|
||||
}
|
||||
|
||||
func NewClient(cfg *v1.ClientCommonConfig) *Client {
|
||||
cfg.Complete()
|
||||
|
||||
ln := utilnet.NewInternalListener()
|
||||
|
||||
svr := client.NewService(cfg, nil, nil, "")
|
||||
svr.SetConnectorCreator(func(context.Context, *v1.ClientCommonConfig) client.Connector {
|
||||
return &pipeConnector{
|
||||
peerListener: ln,
|
||||
}
|
||||
})
|
||||
|
||||
return &Client{
|
||||
l: ln,
|
||||
svr: svr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) PeerListener() net.Listener {
|
||||
return c.l
|
||||
}
|
||||
|
||||
func (c *Client) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
|
||||
c.svr.SetInWorkConnCallback(cb)
|
||||
}
|
||||
|
||||
func (c *Client) UpdateProxyConfigurer(proxyCfgs []v1.ProxyConfigurer) {
|
||||
_ = c.svr.ReloadConf(proxyCfgs, nil)
|
||||
}
|
||||
|
||||
func (c *Client) Run(ctx context.Context) error {
|
||||
return c.svr.Run(ctx)
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
c.l.Close()
|
||||
c.svr.Close()
|
||||
}
|
||||
|
||||
type pipeConnector struct {
|
||||
peerListener *utilnet.InternalListener
|
||||
}
|
||||
|
||||
func (pc *pipeConnector) Open() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pc *pipeConnector) Connect() (net.Conn, error) {
|
||||
c1, c2 := net.Pipe()
|
||||
if err := pc.peerListener.PutConn(c1); err != nil {
|
||||
c1.Close()
|
||||
c2.Close()
|
||||
return nil, err
|
||||
}
|
||||
return c2, nil
|
||||
}
|
||||
|
||||
func (pc *pipeConnector) Close() error {
|
||||
pc.peerListener.Close()
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue