diff --git a/application/commands/ssh.go b/application/commands/ssh.go index f5a8f06..c5eb695 100644 --- a/application/commands/ssh.go +++ b/application/commands/ssh.go @@ -109,27 +109,7 @@ var ( type sshRemoteConnWrapper struct { net.Conn - readTimeout time.Duration - enableTimeout bool - retryTimeout func() bool -} - -func (s *sshRemoteConnWrapper) SetReadDeadline(t time.Time) error { - s.enableTimeout = false - - return s.Conn.SetReadDeadline(t) -} - -func (s *sshRemoteConnWrapper) SetWriteDeadline(t time.Time) error { - s.enableTimeout = false - - return s.Conn.SetWriteDeadline(t) -} - -func (s *sshRemoteConnWrapper) SetDeadline(t time.Time) error { - s.enableTimeout = false - - return s.Conn.SetDeadline(t) + requestTimeoutRetry func(s *sshRemoteConnWrapper) bool } func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) { @@ -139,18 +119,12 @@ func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) { return rLen, nil } - if !s.enableTimeout { - return rLen, rErr - } - netErr, isNetErr := rErr.(net.Error) - if !isNetErr || !netErr.Timeout() { + if !isNetErr || !netErr.Timeout() || !s.requestTimeoutRetry(s) { return rLen, rErr } - s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout)) - for { rLen, rErr := s.Conn.Read(b) @@ -158,17 +132,11 @@ func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) { return rLen, nil } - if !s.enableTimeout { - return rLen, rErr - } - netErr, isNetErr := rErr.(net.Error) - if !isNetErr || !netErr.Timeout() || !s.retryTimeout() { + if !isNetErr || !netErr.Timeout() || !s.requestTimeoutRetry(s) { return rLen, rErr } - - s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout)) } } @@ -188,6 +156,7 @@ type sshClient struct { cfg command.Configuration remoteCloseWait sync.WaitGroup remoteReadTimeoutRetry bool + remoteReadForceRetryNextTimeout bool remoteReadTimeoutRetryLock sync.Mutex credentialReceive chan []byte credentialProcessed bool @@ -210,6 +179,7 @@ func newSSH( cfg: cfg, remoteCloseWait: sync.WaitGroup{}, remoteReadTimeoutRetry: false, + remoteReadForceRetryNextTimeout: false, remoteReadTimeoutRetryLock: sync.Mutex{}, credentialReceive: make(chan []byte, 1), credentialProcessed: false, @@ -390,6 +360,7 @@ func (d *sshClient) disableRemoteReadTimeoutRetry() { defer d.remoteReadTimeoutRetryLock.Unlock() d.remoteReadTimeoutRetry = false + d.remoteReadForceRetryNextTimeout = true } func (d *sshClient) dialRemote( @@ -402,21 +373,29 @@ func (d *sshClient) dialRemote( return nil, nil, err } - conn.SetReadDeadline(time.Now().Add(config.Timeout)) - - sshConn := sshRemoteConnWrapper{ - Conn: conn, - readTimeout: config.Timeout, - enableTimeout: true, - retryTimeout: func() bool { + sshConn := &sshRemoteConnWrapper{ + Conn: conn, + requestTimeoutRetry: func(s *sshRemoteConnWrapper) bool { d.remoteReadTimeoutRetryLock.Lock() defer d.remoteReadTimeoutRetryLock.Unlock() - return d.remoteReadTimeoutRetry + if !d.remoteReadTimeoutRetry { + if !d.remoteReadForceRetryNextTimeout { + return false + } + + d.remoteReadForceRetryNextTimeout = false + } + + s.SetReadDeadline(time.Now().Add(config.Timeout)) + + return true }, } - c, chans, reqs, err := ssh.NewClientConn(&sshConn, addr, config) + sshConn.SetReadDeadline(time.Now().Add(config.Timeout)) + + c, chans, reqs, err := ssh.NewClientConn(sshConn, addr, config) if err != nil { sshConn.Close() @@ -425,11 +404,13 @@ func (d *sshClient) dialRemote( } return ssh.NewClient(c, chans, reqs), func() { - if sshConn.enableTimeout { - sshConn.SetReadDeadline(sshEmptyTime) - } + d.remoteReadTimeoutRetryLock.Lock() + defer d.remoteReadTimeoutRetryLock.Unlock() - sshConn.enableTimeout = false + d.remoteReadTimeoutRetry = false + d.remoteReadForceRetryNextTimeout = true + + sshConn.SetReadDeadline(sshEmptyTime) }, nil }