SSH: Don't timeout during handshake if we're waiting for user input
This commit is contained in:
@@ -111,6 +111,7 @@ type sshRemoteConnWrapper struct {
|
||||
|
||||
readTimeout time.Duration
|
||||
enableTimeout bool
|
||||
retryTimeout func() bool
|
||||
}
|
||||
|
||||
func (s *sshRemoteConnWrapper) SetReadDeadline(t time.Time) error {
|
||||
@@ -150,7 +151,25 @@ func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) {
|
||||
|
||||
s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))
|
||||
|
||||
return s.Conn.Read(b)
|
||||
for {
|
||||
rLen, rErr := s.Conn.Read(b)
|
||||
|
||||
if rErr == nil {
|
||||
return rLen, nil
|
||||
}
|
||||
|
||||
if !s.enableTimeout {
|
||||
return rLen, rErr
|
||||
}
|
||||
|
||||
netErr, isNetErr := rErr.(net.Error)
|
||||
|
||||
if !isNetErr || !netErr.Timeout() || !s.retryTimeout() {
|
||||
return rLen, rErr
|
||||
}
|
||||
|
||||
s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))
|
||||
}
|
||||
}
|
||||
|
||||
type sshRemoteConn struct {
|
||||
@@ -168,6 +187,8 @@ type sshClient struct {
|
||||
l log.Logger
|
||||
cfg command.Configuration
|
||||
remoteCloseWait sync.WaitGroup
|
||||
remoteReadTimeoutRetry bool
|
||||
remoteReadTimeoutRetryLock sync.Mutex
|
||||
credentialReceive chan []byte
|
||||
credentialProcessed bool
|
||||
credentialReceiveClosed bool
|
||||
@@ -188,6 +209,8 @@ func newSSH(
|
||||
l: l,
|
||||
cfg: cfg,
|
||||
remoteCloseWait: sync.WaitGroup{},
|
||||
remoteReadTimeoutRetry: false,
|
||||
remoteReadTimeoutRetryLock: sync.Mutex{},
|
||||
credentialReceive: make(chan []byte, 1),
|
||||
credentialProcessed: false,
|
||||
credentialReceiveClosed: false,
|
||||
@@ -261,6 +284,9 @@ func (d *sshClient) buildAuthMethod(
|
||||
return func(b []byte) []ssh.AuthMethod {
|
||||
return []ssh.AuthMethod{
|
||||
ssh.PasswordCallback(func() (string, error) {
|
||||
d.enableRemoteReadTimeoutRetry()
|
||||
defer d.disableRemoteReadTimeoutRetry()
|
||||
|
||||
wErr := d.w.SendManual(
|
||||
SSHServerConnectRequestCredential,
|
||||
b[d.w.HeaderSize():],
|
||||
@@ -285,6 +311,9 @@ func (d *sshClient) buildAuthMethod(
|
||||
return func(b []byte) []ssh.AuthMethod {
|
||||
return []ssh.AuthMethod{
|
||||
ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
|
||||
d.enableRemoteReadTimeoutRetry()
|
||||
defer d.disableRemoteReadTimeoutRetry()
|
||||
|
||||
wErr := d.w.SendManual(
|
||||
SSHServerConnectRequestCredential,
|
||||
b[d.w.HeaderSize():],
|
||||
@@ -321,6 +350,9 @@ func (d *sshClient) comfirmRemoteFingerprint(
|
||||
key ssh.PublicKey,
|
||||
buf []byte,
|
||||
) error {
|
||||
d.enableRemoteReadTimeoutRetry()
|
||||
defer d.disableRemoteReadTimeoutRetry()
|
||||
|
||||
fgp := ssh.FingerprintSHA256(key)
|
||||
fgpLen := copy(buf[d.w.HeaderSize():], fgp)
|
||||
|
||||
@@ -346,6 +378,20 @@ func (d *sshClient) comfirmRemoteFingerprint(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *sshClient) enableRemoteReadTimeoutRetry() {
|
||||
d.remoteReadTimeoutRetryLock.Lock()
|
||||
defer d.remoteReadTimeoutRetryLock.Unlock()
|
||||
|
||||
d.remoteReadTimeoutRetry = true
|
||||
}
|
||||
|
||||
func (d *sshClient) disableRemoteReadTimeoutRetry() {
|
||||
d.remoteReadTimeoutRetryLock.Lock()
|
||||
defer d.remoteReadTimeoutRetryLock.Unlock()
|
||||
|
||||
d.remoteReadTimeoutRetry = false
|
||||
}
|
||||
|
||||
func (d *sshClient) dialRemote(
|
||||
network,
|
||||
addr string,
|
||||
@@ -362,6 +408,12 @@ func (d *sshClient) dialRemote(
|
||||
Conn: conn,
|
||||
readTimeout: config.Timeout,
|
||||
enableTimeout: true,
|
||||
retryTimeout: func() bool {
|
||||
d.remoteReadTimeoutRetryLock.Lock()
|
||||
defer d.remoteReadTimeoutRetryLock.Unlock()
|
||||
|
||||
return d.remoteReadTimeoutRetry
|
||||
},
|
||||
}
|
||||
|
||||
c, chans, reqs, err := ssh.NewClientConn(&sshConn, addr, config)
|
||||
|
||||
Reference in New Issue
Block a user