From 020a1e8eaf31213086bf50e1c49c2a82d9129faa Mon Sep 17 00:00:00 2001 From: NI Date: Thu, 19 Sep 2019 14:54:54 +0800 Subject: [PATCH] Adding timeout detection to both SSH and Telnet command as well. This will prevent a dead remote connection to block backend request processing --- application/commands/ssh.go | 24 ++++++++++++++++-------- application/commands/telnet.go | 24 ++++++++++++++++-------- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/application/commands/ssh.go b/application/commands/ssh.go index c5eb695..6451610 100644 --- a/application/commands/ssh.go +++ b/application/commands/ssh.go @@ -28,6 +28,7 @@ import ( "github.com/niruix/sshwifty/application/command" "github.com/niruix/sshwifty/application/log" + "github.com/niruix/sshwifty/application/network" "github.com/niruix/sshwifty/application/rw" ) @@ -109,6 +110,7 @@ var ( type sshRemoteConnWrapper struct { net.Conn + writerConn network.WriteTimeoutConn requestTimeoutRetry func(s *sshRemoteConnWrapper) bool } @@ -140,6 +142,10 @@ func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) { } } +func (s *sshRemoteConnWrapper) Write(b []byte) (int, error) { + return s.writerConn.Write(b) +} + type sshRemoteConn struct { writer io.Writer closer func() error @@ -364,17 +370,18 @@ func (d *sshClient) disableRemoteReadTimeoutRetry() { } func (d *sshClient) dialRemote( - network, + networkName, addr string, config *ssh.ClientConfig) (*ssh.Client, func(), error) { - conn, err := d.cfg.Dial(network, addr, config.Timeout) + conn, err := d.cfg.Dial(networkName, addr, config.Timeout) if err != nil { return nil, nil, err } sshConn := &sshRemoteConnWrapper{ - Conn: conn, + Conn: conn, + writerConn: network.NewWriteTimeoutConn(conn, d.cfg.DialTimeout), requestTimeoutRetry: func(s *sshRemoteConnWrapper) bool { d.remoteReadTimeoutRetryLock.Lock() defer d.remoteReadTimeoutRetryLock.Unlock() @@ -393,6 +400,9 @@ func (d *sshClient) dialRemote( }, } + // Set timeout for writer, otherwise the Timeout writer will never + // be triggered + sshConn.SetWriteDeadline(time.Now().Add(d.cfg.DialTimeout)) sshConn.SetReadDeadline(time.Now().Add(config.Timeout)) c, chans, reqs, err := ssh.NewClientConn(sshConn, addr, config) @@ -527,11 +537,7 @@ func (d *sshClient) remote( d.remoteConnReceive <- sshRemoteConn{ writer: in, closer: func() error { - sErr := session.Close() - - if sErr != nil { - return sErr - } + session.Close() return conn.Close() }, @@ -626,6 +632,8 @@ func (d *sshClient) local( _, wErr := remote.writer.Write(rData) if wErr != nil { + remote.closer() + d.l.Debug("Failed to write data to remote: %s", wErr) } diff --git a/application/commands/telnet.go b/application/commands/telnet.go index 7e50ddf..275c006 100644 --- a/application/commands/telnet.go +++ b/application/commands/telnet.go @@ -19,11 +19,13 @@ package commands import ( "errors" - "io" + "net" "sync" + "time" "github.com/niruix/sshwifty/application/command" "github.com/niruix/sshwifty/application/log" + "github.com/niruix/sshwifty/application/network" "github.com/niruix/sshwifty/application/rw" ) @@ -49,8 +51,8 @@ type telnetClient struct { l log.Logger w command.StreamResponder cfg command.Configuration - remoteChan chan io.WriteCloser - remoteConn io.WriteCloser + remoteChan chan net.Conn + remoteConn net.Conn closeWait sync.WaitGroup } @@ -63,7 +65,7 @@ func newTelnet( l: l, w: w, cfg: cfg, - remoteChan: make(chan io.WriteCloser, 1), + remoteChan: make(chan net.Conn, 1), remoteConn: nil, closeWait: sync.WaitGroup{}, } @@ -79,8 +81,6 @@ func (d *telnetClient) Bootup( addrErr, TelnetRequestErrorBadRemoteAddress) } - // TODO: Test whether or not the address is allowed - d.closeWait.Add(1) go d.remote(addr.String()) @@ -118,7 +118,13 @@ func (d *telnetClient) remote(addr string) { return } - d.remoteChan <- clientConn + // Set timeout for writer, otherwise the Timeout writer will never + // be triggered + clientConn.SetWriteDeadline(time.Now().Add(d.cfg.DialTimeout)) + timeoutClientConn := network.NewWriteTimeoutConn( + clientConn, d.cfg.DialTimeout) + + d.remoteChan <- &timeoutClientConn for { rLen, rErr := clientConn.Read(buf[d.w.HeaderSize():]) @@ -136,7 +142,7 @@ func (d *telnetClient) remote(addr string) { } } -func (d *telnetClient) getRemote() (io.WriteCloser, error) { +func (d *telnetClient) getRemote() (net.Conn, error) { if d.remoteConn != nil { return d.remoteConn, nil } @@ -176,6 +182,8 @@ func (d *telnetClient) client( _, wErr := remoteConn.Write(rBuf) if wErr != nil { + remoteConn.Close() + d.l.Debug("Failed to write data to remote: %s", wErr) }