diff --git a/application/commands/ssh.go b/application/commands/ssh.go index 5910aaf..756f0e4 100644 --- a/application/commands/ssh.go +++ b/application/commands/ssh.go @@ -111,6 +111,7 @@ type sshRemoteConnWrapper struct { readTimeout time.Duration enableTimeout bool + retryTimeout func() bool } func (s *sshRemoteConnWrapper) SetReadDeadline(t time.Time) error { @@ -150,7 +151,25 @@ func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) { s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout)) - return s.Conn.Read(b) + for { + rLen, rErr := s.Conn.Read(b) + + if rErr == nil { + return rLen, nil + } + + if !s.enableTimeout { + return rLen, rErr + } + + netErr, isNetErr := rErr.(net.Error) + + if !isNetErr || !netErr.Timeout() || !s.retryTimeout() { + return rLen, rErr + } + + s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout)) + } } type sshRemoteConn struct { @@ -168,6 +187,8 @@ type sshClient struct { l log.Logger cfg command.Configuration remoteCloseWait sync.WaitGroup + remoteReadTimeoutRetry bool + remoteReadTimeoutRetryLock sync.Mutex credentialReceive chan []byte credentialProcessed bool credentialReceiveClosed bool @@ -188,6 +209,8 @@ func newSSH( l: l, cfg: cfg, remoteCloseWait: sync.WaitGroup{}, + remoteReadTimeoutRetry: false, + remoteReadTimeoutRetryLock: sync.Mutex{}, credentialReceive: make(chan []byte, 1), credentialProcessed: false, credentialReceiveClosed: false, @@ -261,6 +284,9 @@ func (d *sshClient) buildAuthMethod( return func(b []byte) []ssh.AuthMethod { return []ssh.AuthMethod{ ssh.PasswordCallback(func() (string, error) { + d.enableRemoteReadTimeoutRetry() + defer d.disableRemoteReadTimeoutRetry() + wErr := d.w.SendManual( SSHServerConnectRequestCredential, b[d.w.HeaderSize():], @@ -285,6 +311,9 @@ func (d *sshClient) buildAuthMethod( return func(b []byte) []ssh.AuthMethod { return []ssh.AuthMethod{ ssh.PublicKeysCallback(func() ([]ssh.Signer, error) { + d.enableRemoteReadTimeoutRetry() + defer d.disableRemoteReadTimeoutRetry() + wErr := d.w.SendManual( SSHServerConnectRequestCredential, b[d.w.HeaderSize():], @@ -321,6 +350,9 @@ func (d *sshClient) comfirmRemoteFingerprint( key ssh.PublicKey, buf []byte, ) error { + d.enableRemoteReadTimeoutRetry() + defer d.disableRemoteReadTimeoutRetry() + fgp := ssh.FingerprintSHA256(key) fgpLen := copy(buf[d.w.HeaderSize():], fgp) @@ -346,6 +378,20 @@ func (d *sshClient) comfirmRemoteFingerprint( return nil } +func (d *sshClient) enableRemoteReadTimeoutRetry() { + d.remoteReadTimeoutRetryLock.Lock() + defer d.remoteReadTimeoutRetryLock.Unlock() + + d.remoteReadTimeoutRetry = true +} + +func (d *sshClient) disableRemoteReadTimeoutRetry() { + d.remoteReadTimeoutRetryLock.Lock() + defer d.remoteReadTimeoutRetryLock.Unlock() + + d.remoteReadTimeoutRetry = false +} + func (d *sshClient) dialRemote( network, addr string, @@ -362,6 +408,12 @@ func (d *sshClient) dialRemote( Conn: conn, readTimeout: config.Timeout, enableTimeout: true, + retryTimeout: func() bool { + d.remoteReadTimeoutRetryLock.Lock() + defer d.remoteReadTimeoutRetryLock.Unlock() + + return d.remoteReadTimeoutRetry + }, } c, chans, reqs, err := ssh.NewClientConn(&sshConn, addr, config)