Fix the problematic SSH handshake timeout

This commit is contained in:
NI
2019-09-02 22:21:23 +08:00
parent c0ad0addad
commit bf68b88919

View File

@@ -109,27 +109,7 @@ var (
type sshRemoteConnWrapper struct { type sshRemoteConnWrapper struct {
net.Conn net.Conn
readTimeout time.Duration requestTimeoutRetry func(s *sshRemoteConnWrapper) bool
enableTimeout bool
retryTimeout func() bool
}
func (s *sshRemoteConnWrapper) SetReadDeadline(t time.Time) error {
s.enableTimeout = false
return s.Conn.SetReadDeadline(t)
}
func (s *sshRemoteConnWrapper) SetWriteDeadline(t time.Time) error {
s.enableTimeout = false
return s.Conn.SetWriteDeadline(t)
}
func (s *sshRemoteConnWrapper) SetDeadline(t time.Time) error {
s.enableTimeout = false
return s.Conn.SetDeadline(t)
} }
func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) { func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) {
@@ -139,18 +119,12 @@ func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) {
return rLen, nil return rLen, nil
} }
if !s.enableTimeout {
return rLen, rErr
}
netErr, isNetErr := rErr.(net.Error) netErr, isNetErr := rErr.(net.Error)
if !isNetErr || !netErr.Timeout() { if !isNetErr || !netErr.Timeout() || !s.requestTimeoutRetry(s) {
return rLen, rErr return rLen, rErr
} }
s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))
for { for {
rLen, rErr := s.Conn.Read(b) rLen, rErr := s.Conn.Read(b)
@@ -158,17 +132,11 @@ func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) {
return rLen, nil return rLen, nil
} }
if !s.enableTimeout {
return rLen, rErr
}
netErr, isNetErr := rErr.(net.Error) netErr, isNetErr := rErr.(net.Error)
if !isNetErr || !netErr.Timeout() || !s.retryTimeout() { if !isNetErr || !netErr.Timeout() || !s.requestTimeoutRetry(s) {
return rLen, rErr return rLen, rErr
} }
s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))
} }
} }
@@ -188,6 +156,7 @@ type sshClient struct {
cfg command.Configuration cfg command.Configuration
remoteCloseWait sync.WaitGroup remoteCloseWait sync.WaitGroup
remoteReadTimeoutRetry bool remoteReadTimeoutRetry bool
remoteReadForceRetryNextTimeout bool
remoteReadTimeoutRetryLock sync.Mutex remoteReadTimeoutRetryLock sync.Mutex
credentialReceive chan []byte credentialReceive chan []byte
credentialProcessed bool credentialProcessed bool
@@ -210,6 +179,7 @@ func newSSH(
cfg: cfg, cfg: cfg,
remoteCloseWait: sync.WaitGroup{}, remoteCloseWait: sync.WaitGroup{},
remoteReadTimeoutRetry: false, remoteReadTimeoutRetry: false,
remoteReadForceRetryNextTimeout: false,
remoteReadTimeoutRetryLock: sync.Mutex{}, remoteReadTimeoutRetryLock: sync.Mutex{},
credentialReceive: make(chan []byte, 1), credentialReceive: make(chan []byte, 1),
credentialProcessed: false, credentialProcessed: false,
@@ -390,6 +360,7 @@ func (d *sshClient) disableRemoteReadTimeoutRetry() {
defer d.remoteReadTimeoutRetryLock.Unlock() defer d.remoteReadTimeoutRetryLock.Unlock()
d.remoteReadTimeoutRetry = false d.remoteReadTimeoutRetry = false
d.remoteReadForceRetryNextTimeout = true
} }
func (d *sshClient) dialRemote( func (d *sshClient) dialRemote(
@@ -402,21 +373,29 @@ func (d *sshClient) dialRemote(
return nil, nil, err return nil, nil, err
} }
conn.SetReadDeadline(time.Now().Add(config.Timeout)) sshConn := &sshRemoteConnWrapper{
Conn: conn,
sshConn := sshRemoteConnWrapper{ requestTimeoutRetry: func(s *sshRemoteConnWrapper) bool {
Conn: conn,
readTimeout: config.Timeout,
enableTimeout: true,
retryTimeout: func() bool {
d.remoteReadTimeoutRetryLock.Lock() d.remoteReadTimeoutRetryLock.Lock()
defer d.remoteReadTimeoutRetryLock.Unlock() defer d.remoteReadTimeoutRetryLock.Unlock()
return d.remoteReadTimeoutRetry if !d.remoteReadTimeoutRetry {
if !d.remoteReadForceRetryNextTimeout {
return false
}
d.remoteReadForceRetryNextTimeout = false
}
s.SetReadDeadline(time.Now().Add(config.Timeout))
return true
}, },
} }
c, chans, reqs, err := ssh.NewClientConn(&sshConn, addr, config) sshConn.SetReadDeadline(time.Now().Add(config.Timeout))
c, chans, reqs, err := ssh.NewClientConn(sshConn, addr, config)
if err != nil { if err != nil {
sshConn.Close() sshConn.Close()
@@ -425,11 +404,13 @@ func (d *sshClient) dialRemote(
} }
return ssh.NewClient(c, chans, reqs), func() { return ssh.NewClient(c, chans, reqs), func() {
if sshConn.enableTimeout { d.remoteReadTimeoutRetryLock.Lock()
sshConn.SetReadDeadline(sshEmptyTime) defer d.remoteReadTimeoutRetryLock.Unlock()
}
sshConn.enableTimeout = false d.remoteReadTimeoutRetry = false
d.remoteReadForceRetryNextTimeout = true
sshConn.SetReadDeadline(sshEmptyTime)
}, nil }, nil
} }