SSH: Don't timeout during handshake if we're waiting for user input

This commit is contained in:
NI
2019-08-29 21:34:23 +08:00
parent 6d349f133e
commit 3749460d14

View File

@@ -111,6 +111,7 @@ type sshRemoteConnWrapper struct {
readTimeout time.Duration
enableTimeout bool
retryTimeout func() bool
}
func (s *sshRemoteConnWrapper) SetReadDeadline(t time.Time) error {
@@ -150,7 +151,25 @@ func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) {
s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))
return s.Conn.Read(b)
for {
rLen, rErr := s.Conn.Read(b)
if rErr == nil {
return rLen, nil
}
if !s.enableTimeout {
return rLen, rErr
}
netErr, isNetErr := rErr.(net.Error)
if !isNetErr || !netErr.Timeout() || !s.retryTimeout() {
return rLen, rErr
}
s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))
}
}
type sshRemoteConn struct {
@@ -168,6 +187,8 @@ type sshClient struct {
l log.Logger
cfg command.Configuration
remoteCloseWait sync.WaitGroup
remoteReadTimeoutRetry bool
remoteReadTimeoutRetryLock sync.Mutex
credentialReceive chan []byte
credentialProcessed bool
credentialReceiveClosed bool
@@ -188,6 +209,8 @@ func newSSH(
l: l,
cfg: cfg,
remoteCloseWait: sync.WaitGroup{},
remoteReadTimeoutRetry: false,
remoteReadTimeoutRetryLock: sync.Mutex{},
credentialReceive: make(chan []byte, 1),
credentialProcessed: false,
credentialReceiveClosed: false,
@@ -261,6 +284,9 @@ func (d *sshClient) buildAuthMethod(
return func(b []byte) []ssh.AuthMethod {
return []ssh.AuthMethod{
ssh.PasswordCallback(func() (string, error) {
d.enableRemoteReadTimeoutRetry()
defer d.disableRemoteReadTimeoutRetry()
wErr := d.w.SendManual(
SSHServerConnectRequestCredential,
b[d.w.HeaderSize():],
@@ -285,6 +311,9 @@ func (d *sshClient) buildAuthMethod(
return func(b []byte) []ssh.AuthMethod {
return []ssh.AuthMethod{
ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
d.enableRemoteReadTimeoutRetry()
defer d.disableRemoteReadTimeoutRetry()
wErr := d.w.SendManual(
SSHServerConnectRequestCredential,
b[d.w.HeaderSize():],
@@ -321,6 +350,9 @@ func (d *sshClient) comfirmRemoteFingerprint(
key ssh.PublicKey,
buf []byte,
) error {
d.enableRemoteReadTimeoutRetry()
defer d.disableRemoteReadTimeoutRetry()
fgp := ssh.FingerprintSHA256(key)
fgpLen := copy(buf[d.w.HeaderSize():], fgp)
@@ -346,6 +378,20 @@ func (d *sshClient) comfirmRemoteFingerprint(
return nil
}
func (d *sshClient) enableRemoteReadTimeoutRetry() {
d.remoteReadTimeoutRetryLock.Lock()
defer d.remoteReadTimeoutRetryLock.Unlock()
d.remoteReadTimeoutRetry = true
}
func (d *sshClient) disableRemoteReadTimeoutRetry() {
d.remoteReadTimeoutRetryLock.Lock()
defer d.remoteReadTimeoutRetryLock.Unlock()
d.remoteReadTimeoutRetry = false
}
func (d *sshClient) dialRemote(
network,
addr string,
@@ -362,6 +408,12 @@ func (d *sshClient) dialRemote(
Conn: conn,
readTimeout: config.Timeout,
enableTimeout: true,
retryTimeout: func() bool {
d.remoteReadTimeoutRetryLock.Lock()
defer d.remoteReadTimeoutRetryLock.Unlock()
return d.remoteReadTimeoutRetry
},
}
c, chans, reqs, err := ssh.NewClientConn(&sshConn, addr, config)