diff --git a/Dockerfile b/Dockerfile index 64bc4f4..fabf350 100644 --- a/Dockerfile +++ b/Dockerfile @@ -36,6 +36,7 @@ RUN set -ex && \ FROM alpine:latest ENV SSHWIFTY_HOSTNAME= \ SSHWIFTY_SHAREDKEY= \ + SSHWIFTY_DIALTIMEOUT=10 \ SSHWIFTY_SOCKS5= \ SSHWIFTY_SOCKS5_USER= \ SSHWIFTY_SOCKS5_PASSWORD= \ diff --git a/README.md b/README.md index b161492..f126780 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,12 @@ Here is the options that can be used in a configuration file and what it for: // Web interface access password. Set to empty to allow public access "SharedKey": "WEB_ACCESS_PASSWORD", + // Remote dial timeout. This limits how long of time the backend can spend + // to connect to a remote host. The max timeout will be determined by + // server configuration (ReadTimeout). + // (In Second) + "DialTimeout": 10, + // Socks5 proxy. When set, we will try to connect remote through the given // proxy "Socks5": "localhost:1080", @@ -136,23 +142,29 @@ Here is the options that can be used in a configuration file and what it for: // Timeout of initial request. HTTP handshake must be finished within // this time + // (In Second) "InitialTimeout": 3, // How long the connection can be idle before the server disconnects the // client + // (In Second) "ReadTimeout": 60, // How long the server will wait until the client connect is ready to // recieve new data + // (In Second) "WriteTimeout": 60, // The interval between internal echo requests + // (In Second) "HeartbeatTimeout": 20, // Forced delay between each request + // (In Milisecond) "ReadDelay": 10, // Forced delay between each write + // (In Milisecond) "WriteDelay": 10, // Path to TLS certificate file. Set empty to use HTTP diff --git a/application/command/commander.go b/application/command/commander.go index adb8fcb..174f676 100644 --- a/application/command/commander.go +++ b/application/command/commander.go @@ -27,6 +27,12 @@ import ( "github.com/niruix/sshwifty/application/rw" ) +// CommandConfiguration contains configuration data needed to run command +type CommandConfiguration struct { + Dial network.Dial + DialTimeout time.Duration +} + // Commander command control type Commander struct { commands Commands @@ -41,7 +47,7 @@ func New(cs Commands) Commander { // New Adds a new client func (c Commander) New( - dialer network.Dial, + cfg CommandConfiguration, receiver rw.FetchReader, sender io.Writer, senderLock *sync.Mutex, @@ -50,7 +56,7 @@ func (c Commander) New( l log.Logger, ) (Handler, error) { return newHandler( - dialer, + cfg, &c.commands, receiver, sender, diff --git a/application/command/commands.go b/application/command/commands.go index 40c6099..8e20942 100644 --- a/application/command/commands.go +++ b/application/command/commands.go @@ -22,7 +22,6 @@ import ( "fmt" "github.com/niruix/sshwifty/application/log" - "github.com/niruix/sshwifty/application/network" ) // Consts @@ -37,7 +36,11 @@ var ( ) // Command represents a command handler machine builder -type Command func(l log.Logger, w StreamResponder, d network.Dial) FSMMachine +type Command func( + l log.Logger, + w StreamResponder, + cfg CommandConfiguration, +) FSMMachine // Commands contains data of all commands type Commands [MaxCommandID + 1]Command @@ -57,7 +60,10 @@ func (c *Commands) Register(id byte, cb Command) { // Run creates command executer func (c Commands) Run( - id byte, l log.Logger, w StreamResponder, dial network.Dial) (FSM, error) { + id byte, + l log.Logger, + w StreamResponder, + cfg CommandConfiguration) (FSM, error) { if id > MaxCommandID { return FSM{}, ErrCommandRunUndefinedCommand } @@ -68,5 +74,5 @@ func (c Commands) Run( return FSM{}, ErrCommandRunUndefinedCommand } - return newFSM(cc(l, w, dial)), nil + return newFSM(cc(l, w, cfg)), nil } diff --git a/application/command/handler.go b/application/command/handler.go index 63a1fef..7fddb7a 100644 --- a/application/command/handler.go +++ b/application/command/handler.go @@ -25,7 +25,6 @@ import ( "time" "github.com/niruix/sshwifty/application/log" - "github.com/niruix/sshwifty/application/network" "github.com/niruix/sshwifty/application/rw" ) @@ -105,7 +104,7 @@ func (h streamHandlerSender) Write(b []byte) (int, error) { // Handler client stream control type Handler struct { - dialer network.Dial + cfg CommandConfiguration commands *Commands receiver rw.FetchReader sender handlerSender @@ -118,7 +117,7 @@ type Handler struct { } func newHandler( - dialer network.Dial, + cfg CommandConfiguration, commands *Commands, receiver rw.FetchReader, sender io.Writer, @@ -128,7 +127,7 @@ func newHandler( l log.Logger, ) Handler { return Handler{ - dialer: dialer, + cfg: cfg, commands: commands, receiver: receiver, sender: handlerSender{writer: sender, lock: senderLock}, @@ -235,7 +234,7 @@ func (e *Handler) handleStream(h Header, d byte, l log.Logger) error { return st.reinit(h, &e.receiver, streamHandlerSender{ handlerSender: &e.sender, sendDelay: e.sendDelay, - }, l, e.commands, e.dialer, e.rBuf[:]) + }, l, e.commands, e.cfg, e.rBuf[:]) } func (e *Handler) handleClose(h Header, d byte, l log.Logger) error { diff --git a/application/command/handler_echo_test.go b/application/command/handler_echo_test.go index 53c31a9..e3a657f 100644 --- a/application/command/handler_echo_test.go +++ b/application/command/handler_echo_test.go @@ -78,7 +78,7 @@ func TestHandlerHandleEcho(t *testing.T) { } lock := sync.Mutex{} handler := newHandler( - nil, + CommandConfiguration{}, nil, rw.NewFetchReader(testDummyFetchGen(s)), &w, diff --git a/application/command/handler_stream_test.go b/application/command/handler_stream_test.go index 4d28ebf..406334c 100644 --- a/application/command/handler_stream_test.go +++ b/application/command/handler_stream_test.go @@ -25,7 +25,6 @@ import ( "testing" "github.com/niruix/sshwifty/application/log" - "github.com/niruix/sshwifty/application/network" "github.com/niruix/sshwifty/application/rw" ) @@ -58,7 +57,7 @@ func testDummyFetchChainGen(dd <-chan []byte) rw.FetchReaderFetcher { } type dummyStreamCommand struct { - lock sync.Mutex + lock sync.Mutex l log.Logger w StreamResponder downWait sync.WaitGroup @@ -67,9 +66,12 @@ type dummyStreamCommand struct { } func newDummyStreamCommand( - l log.Logger, w StreamResponder, d network.Dial) FSMMachine { + l log.Logger, + w StreamResponder, + cfg CommandConfiguration, +) FSMMachine { return &dummyStreamCommand{ - lock:sync.Mutex{}, + lock: sync.Mutex{}, l: l, w: w, downWait: sync.WaitGroup{}, @@ -84,7 +86,7 @@ func (d *dummyStreamCommand) Bootup( ) (FSMState, FSMError) { d.downWait.Add(1) - echoTrans:=d.echoTrans + echoTrans := d.echoTrans go func() { defer func() { @@ -178,7 +180,7 @@ func TestHandlerHandleStream(t *testing.T) { lock := sync.Mutex{} hhd := newHandler( - nil, + CommandConfiguration{}, &cmds, rw.NewFetchReader(readerSource), wBuffer, diff --git a/application/command/streams.go b/application/command/streams.go index 3d23e35..475cbfb 100644 --- a/application/command/streams.go +++ b/application/command/streams.go @@ -22,7 +22,6 @@ import ( "io" "github.com/niruix/sshwifty/application/log" - "github.com/niruix/sshwifty/application/network" "github.com/niruix/sshwifty/application/rw" ) @@ -342,7 +341,7 @@ func (c *stream) reinit( w streamHandlerSender, l log.Logger, cc *Commands, - dialer network.Dial, + cfg CommandConfiguration, b []byte, ) error { hd := streamInitialHeader{} @@ -355,7 +354,8 @@ func (c *stream) reinit( l = l.Context("Command (%d)", hd.command()) - ccc, cccErr := cc.Run(hd.command(), l, newStreamResponder(w, h), dialer) + ccc, cccErr := cc.Run( + hd.command(), l, newStreamResponder(w, h), cfg) if cccErr != nil { hd.set(0, uint16(StreamErrorCommandUndefined), false) diff --git a/application/commands/ssh.go b/application/commands/ssh.go index e65608b..0e834bf 100644 --- a/application/commands/ssh.go +++ b/application/commands/ssh.go @@ -22,13 +22,11 @@ import ( "io" "net" "sync" - "time" "golang.org/x/crypto/ssh" "github.com/niruix/sshwifty/application/command" "github.com/niruix/sshwifty/application/log" - "github.com/niruix/sshwifty/application/network" "github.com/niruix/sshwifty/application/rw" ) @@ -116,8 +114,7 @@ func (s sshRemoteConn) isValid() bool { type sshClient struct { w command.StreamResponder l log.Logger - dial network.Dial - dialTimeout time.Duration + cfg command.CommandConfiguration remoteCloseWait sync.WaitGroup credentialReceive chan []byte credentialProcessed bool @@ -132,13 +129,12 @@ type sshClient struct { func newSSH( l log.Logger, w command.StreamResponder, - dial network.Dial, + cfg command.CommandConfiguration, ) command.FSMMachine { return &sshClient{ w: w, l: l, - dial: dial, - dialTimeout: 10 * time.Second, + cfg: cfg, remoteCloseWait: sync.WaitGroup{}, credentialReceive: make(chan []byte, 1), credentialProcessed: false, @@ -300,7 +296,7 @@ func (d *sshClient) comfirmRemoteFingerprint( func (d *sshClient) dialRemote( network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { - conn, err := d.dial(network, addr, config.Timeout) + conn, err := d.cfg.Dial(network, addr, config.Timeout) if err != nil { return nil, err @@ -332,7 +328,7 @@ func (d *sshClient) remote( HostKeyCallback: func(h string, r net.Addr, k ssh.PublicKey) error { return d.comfirmRemoteFingerprint(h, r, k, buf[:]) }, - Timeout: d.dialTimeout, + Timeout: d.cfg.DialTimeout, }) if dErr != nil { diff --git a/application/commands/telnet.go b/application/commands/telnet.go index 3e1d3cc..08b3368 100644 --- a/application/commands/telnet.go +++ b/application/commands/telnet.go @@ -21,11 +21,9 @@ import ( "errors" "io" "sync" - "time" "github.com/niruix/sshwifty/application/command" "github.com/niruix/sshwifty/application/log" - "github.com/niruix/sshwifty/application/network" "github.com/niruix/sshwifty/application/rw" ) @@ -48,28 +46,26 @@ const ( ) type telnetClient struct { - l log.Logger - w command.StreamResponder - dial network.Dial - remoteChan chan io.WriteCloser - remoteConn io.WriteCloser - closeWait sync.WaitGroup - dialTimeout time.Duration + l log.Logger + w command.StreamResponder + cfg command.CommandConfiguration + remoteChan chan io.WriteCloser + remoteConn io.WriteCloser + closeWait sync.WaitGroup } func newTelnet( l log.Logger, w command.StreamResponder, - dial network.Dial, + cfg command.CommandConfiguration, ) command.FSMMachine { return &telnetClient{ - l: l, - w: w, - dial: dial, - remoteChan: make(chan io.WriteCloser, 1), - remoteConn: nil, - closeWait: sync.WaitGroup{}, - dialTimeout: 10 * time.Second, + l: l, + w: w, + cfg: cfg, + remoteChan: make(chan io.WriteCloser, 1), + remoteConn: nil, + closeWait: sync.WaitGroup{}, } } @@ -101,7 +97,7 @@ func (d *telnetClient) remote(addr string) { buf := [4096]byte{} - clientConn, clientConnErr := d.dial("tcp", addr, d.dialTimeout) + clientConn, clientConnErr := d.cfg.Dial("tcp", addr, d.cfg.DialTimeout) if clientConnErr != nil { errLen := copy( diff --git a/application/configuration/config.go b/application/configuration/config.go index abd378b..aeb4718 100644 --- a/application/configuration/config.go +++ b/application/configuration/config.go @@ -121,17 +121,19 @@ func (s Server) Verify() error { // Configuration contains configuration of the application type Configuration struct { - HostName string - SharedKey string - Dialer network.Dial - Servers []Server + HostName string + SharedKey string + Dialer network.Dial + DialTimeout time.Duration + Servers []Server } // Common settings shared by mulitple servers type Common struct { - HostName string - SharedKey string - Dialer network.Dial + HostName string + SharedKey string + Dialer network.Dial + DialTimeout time.Duration } // Verify verifies current setting @@ -155,24 +157,31 @@ func (c Configuration) Verify() error { // Common returns common settings func (c Configuration) Common() Common { - return Common{ - HostName: c.HostName, - SharedKey: c.SharedKey, - Dialer: c.Dialer, - } -} - -// WithDefault build the configuration and fill the blank with default values -func (c Common) WithDefault() Common { dialer := c.Dialer if dialer == nil { dialer = network.TCPDial() } + dialTimeout := c.DialTimeout + + if dialTimeout <= 1*time.Second { + dialTimeout = 1 * time.Second + } + return Common{ - HostName: c.HostName, - SharedKey: c.SharedKey, - Dialer: dialer, + HostName: c.HostName, + SharedKey: c.SharedKey, + Dialer: c.Dialer, + DialTimeout: c.DialTimeout, } } + +// DecideDialTimeout will return a reasonable timeout for dialing +func (c Common) DecideDialTimeout(max time.Duration) time.Duration { + if c.DialTimeout > max { + return max + } + + return c.DialTimeout +} diff --git a/application/configuration/loader_enviro.go b/application/configuration/loader_enviro.go index 0514516..99574a2 100644 --- a/application/configuration/loader_enviro.go +++ b/application/configuration/loader_enviro.go @@ -22,9 +22,9 @@ import ( "os" "strconv" "strings" + "time" "github.com/niruix/sshwifty/application/log" - "github.com/niruix/sshwifty/application/network" ) const ( @@ -46,12 +46,21 @@ func Enviro() Loader { return func(log log.Logger) (string, Configuration, error) { log.Info("Loading configuration from environment variables ...") - cfg := fileCfgCommon{ + dialTimeout, _ := strconv.ParseUint( + parseEviro("SSHWIFTY_DIALTIMEOUT"), 10, 32) + + cfg, dialer, cfgErr := fileCfgCommon{ HostName: parseEviro("SSHWIFTY_HOSTNAME"), SharedKey: parseEviro("SSHWIFTY_SHAREDKEY"), + DialTimeout: int(dialTimeout), Socks5: parseEviro("SSHWIFTY_SOCKS5"), Socks5User: parseEviro("SSHWIFTY_SOCKS5_USER"), Socks5Password: parseEviro("SSHWIFTY_SOCKS5_PASSWORD"), + }.build() + + if cfgErr != nil { + return enviroTypeName, Configuration{}, fmt.Errorf( + "Failed to build the configuration: %s", cfgErr) } listenPort, listenPortErr := strconv.ParseUint( @@ -93,26 +102,12 @@ func Enviro() Loader { TLSCertificateKeyFile: parseEviro("SSHWIFTY_TLSCERTIFICATEKEYFILE"), } - var dialer network.Dial - - if len(cfg.Socks5) <= 0 { - dialer = network.TCPDial() - } else { - sDial, sDialErr := network.BuildSocks5Dial( - cfg.Socks5, cfg.Socks5User, cfg.Socks5Password) - - if sDialErr != nil { - return enviroTypeName, Configuration{}, sDialErr - } - - dialer = sDial - } - return enviroTypeName, Configuration{ - HostName: cfg.HostName, - SharedKey: cfg.SharedKey, - Dialer: dialer, - Servers: []Server{cfgSer.build()}, + HostName: cfg.HostName, + SharedKey: cfg.SharedKey, + Dialer: dialer, + DialTimeout: time.Duration(cfg.DialTimeout) * time.Second, + Servers: []Server{cfgSer.build()}, }, nil } } diff --git a/application/configuration/loader_file.go b/application/configuration/loader_file.go index 77a9ef7..2c58031 100644 --- a/application/configuration/loader_file.go +++ b/application/configuration/loader_file.go @@ -49,7 +49,7 @@ type fileCfgServer struct { TLSCertificateKeyFile string // Location of TLS certificate key } -func (f fileCfgServer) minDur(current, min int) int { +func (f fileCfgServer) durationAtLeast(current, min int) int { if current > min { return current } @@ -62,17 +62,17 @@ func (f *fileCfgServer) build() Server { ListenInterface: f.ListenInterface, ListenPort: f.ListenPort, InitialTimeout: time.Duration( - f.minDur(f.InitialTimeout, 5)) * time.Second, + f.durationAtLeast(f.InitialTimeout, 5)) * time.Second, ReadTimeout: time.Duration( - f.minDur(f.ReadTimeout, 30)) * time.Second, + f.durationAtLeast(f.ReadTimeout, 30)) * time.Second, WriteTimeout: time.Duration( - f.minDur(f.WriteTimeout, 30)) * time.Second, + f.durationAtLeast(f.WriteTimeout, 30)) * time.Second, HeartbeatTimeout: time.Duration( - f.minDur(f.HeartbeatTimeout, 10)) * time.Second, + f.durationAtLeast(f.HeartbeatTimeout, 10)) * time.Second, ReadDelay: time.Duration( - f.minDur(f.ReadDelay, 0)) * time.Millisecond, + f.durationAtLeast(f.ReadDelay, 0)) * time.Millisecond, WriteDelay: time.Duration( - f.minDur(f.WriteDelay, 0)) * time.Millisecond, + f.durationAtLeast(f.WriteDelay, 0)) * time.Millisecond, TLSCertificateFile: f.TLSCertificateFile, TLSCertificateKeyFile: f.TLSCertificateKeyFile, } @@ -81,12 +81,46 @@ func (f *fileCfgServer) build() Server { type fileCfgCommon struct { HostName string // Host name SharedKey string // Shared key, empty to enable public access + DialTimeout int // DialTimeout, min 5s Socks5 string // Socks5 server address, optional Socks5User string // Login user for socks5 server, optional Socks5Password string // Login pass for socks5 server, optional Servers []*fileCfgServer // Servers } +func (f fileCfgCommon) build() (fileCfgCommon, network.Dial, error) { + dialTimeout := f.DialTimeout + + if dialTimeout < 3 { + dialTimeout = 3 + } + + var dialer network.Dial + + if len(f.Socks5) <= 0 { + dialer = network.TCPDial() + } else { + sDial, sDialErr := network.BuildSocks5Dial( + f.Socks5, f.Socks5User, f.Socks5Password) + + if sDialErr != nil { + return fileCfgCommon{}, nil, sDialErr + } + + dialer = sDial + } + + return fileCfgCommon{ + HostName: f.HostName, + SharedKey: f.SharedKey, + DialTimeout: dialTimeout, + Socks5: f.Socks5, + Socks5User: f.Socks5User, + Socks5Password: f.Socks5Password, + Servers: f.Servers, + }, dialer, nil +} + func loadFile(filePath string) (string, Configuration, error) { f, fErr := os.Open(filePath) @@ -105,32 +139,24 @@ func loadFile(filePath string) (string, Configuration, error) { return fileTypeName, Configuration{}, jDecodeErr } - servers := make([]Server, len(cfg.Servers)) + finalCfg, dialer, cfgErr := cfg.build() - for i := range servers { - servers[i] = cfg.Servers[i].build() + if cfgErr != nil { + return fileTypeName, Configuration{}, cfgErr } - var dialer network.Dial + servers := make([]Server, len(finalCfg.Servers)) - if len(cfg.Socks5) <= 0 { - dialer = network.TCPDial() - } else { - sDial, sDialErr := network.BuildSocks5Dial( - cfg.Socks5, cfg.Socks5User, cfg.Socks5Password) - - if sDialErr != nil { - return fileTypeName, Configuration{}, sDialErr - } - - dialer = sDial + for i := range servers { + servers[i] = finalCfg.Servers[i].build() } return fileTypeName, Configuration{ - HostName: cfg.HostName, - SharedKey: cfg.SharedKey, - Dialer: dialer, - Servers: servers, + HostName: finalCfg.HostName, + SharedKey: finalCfg.SharedKey, + Dialer: dialer, + DialTimeout: time.Duration(finalCfg.DialTimeout) * time.Second, + Servers: servers, }, nil } diff --git a/application/controller/socket.go b/application/controller/socket.go index 374ac1f..8188949 100644 --- a/application/controller/socket.go +++ b/application/controller/socket.go @@ -358,7 +358,11 @@ func (s socket) Get( senderLock := sync.Mutex{} cmdExec, cmdExecErr := s.commander.New( - s.commonCfg.Dialer, rw.NewFetchReader(func() ([]byte, error) { + command.CommandConfiguration{ + Dial: s.commonCfg.Dialer, + DialTimeout: s.commonCfg.DecideDialTimeout(s.serverCfg.ReadTimeout), + }, + rw.NewFetchReader(func() ([]byte, error) { defer s.increaseNonce(readNonce[:]) // Size is unencrypted @@ -402,7 +406,8 @@ func (s socket) Get( readNonce[:], cipherReadBuf[:packageSize], nil) - }), socketPackageWriter{ + }), + socketPackageWriter{ w: wsWriter, packager: func(w websocketWriter, b []byte) error { start := 0 diff --git a/application/network/dial_socks5.go b/application/network/dial_socks5.go index 8459b9c..8308207 100644 --- a/application/network/dial_socks5.go +++ b/application/network/dial_socks5.go @@ -42,7 +42,8 @@ func BuildSocks5Dial( timeout time.Duration, ) (net.Conn, error) { dial, dialErr := proxy.SOCKS5("tcp", socks5Address, auth, &net.Dialer{ - Timeout: timeout, + Timeout: timeout, + Deadline: time.Now().Add(timeout), }) if dialErr != nil { diff --git a/application/server/server.go b/application/server/server.go index 75bc8f5..fcea3f1 100644 --- a/application/server/server.go +++ b/application/server/server.go @@ -80,7 +80,6 @@ func (s Server) Serve( closeCallback CloseCallback, handlerBuilder HandlerBuilder, ) *Serving { - ccCfg := commonCfg.WithDefault() ssCfg := serverCfg.WithDefault() l := s.logger.Context( @@ -88,7 +87,7 @@ func (s Server) Serve( ss := &Serving{ server: http.Server{ - Handler: handlerBuilder(ccCfg, ssCfg, l), + Handler: handlerBuilder(commonCfg, ssCfg, l), ReadTimeout: ssCfg.ReadTimeout, ReadHeaderTimeout: ssCfg.InitialTimeout, WriteTimeout: ssCfg.WriteTimeout, diff --git a/sshwifty.conf.example.json b/sshwifty.conf.example.json index 479ac8b..b92f2ca 100644 --- a/sshwifty.conf.example.json +++ b/sshwifty.conf.example.json @@ -1,6 +1,7 @@ { "HostName": "", "SharedKey": "WEB_ACCESS_PASSWORD", + "DialTimeout": 5, "Socks5": "", "Socks5User": "", "Socks5Password": "",