SSH: Don't timeout during handshake if we're waiting for user input

This commit is contained in:
NI
2019-08-29 21:34:23 +08:00
parent 6d349f133e
commit 3749460d14

View File

@@ -111,6 +111,7 @@ type sshRemoteConnWrapper struct {
readTimeout time.Duration readTimeout time.Duration
enableTimeout bool enableTimeout bool
retryTimeout func() bool
} }
func (s *sshRemoteConnWrapper) SetReadDeadline(t time.Time) error { func (s *sshRemoteConnWrapper) SetReadDeadline(t time.Time) error {
@@ -150,7 +151,25 @@ func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) {
s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout)) s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))
return s.Conn.Read(b) for {
rLen, rErr := s.Conn.Read(b)
if rErr == nil {
return rLen, nil
}
if !s.enableTimeout {
return rLen, rErr
}
netErr, isNetErr := rErr.(net.Error)
if !isNetErr || !netErr.Timeout() || !s.retryTimeout() {
return rLen, rErr
}
s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))
}
} }
type sshRemoteConn struct { type sshRemoteConn struct {
@@ -168,6 +187,8 @@ type sshClient struct {
l log.Logger l log.Logger
cfg command.Configuration cfg command.Configuration
remoteCloseWait sync.WaitGroup remoteCloseWait sync.WaitGroup
remoteReadTimeoutRetry bool
remoteReadTimeoutRetryLock sync.Mutex
credentialReceive chan []byte credentialReceive chan []byte
credentialProcessed bool credentialProcessed bool
credentialReceiveClosed bool credentialReceiveClosed bool
@@ -188,6 +209,8 @@ func newSSH(
l: l, l: l,
cfg: cfg, cfg: cfg,
remoteCloseWait: sync.WaitGroup{}, remoteCloseWait: sync.WaitGroup{},
remoteReadTimeoutRetry: false,
remoteReadTimeoutRetryLock: sync.Mutex{},
credentialReceive: make(chan []byte, 1), credentialReceive: make(chan []byte, 1),
credentialProcessed: false, credentialProcessed: false,
credentialReceiveClosed: false, credentialReceiveClosed: false,
@@ -261,6 +284,9 @@ func (d *sshClient) buildAuthMethod(
return func(b []byte) []ssh.AuthMethod { return func(b []byte) []ssh.AuthMethod {
return []ssh.AuthMethod{ return []ssh.AuthMethod{
ssh.PasswordCallback(func() (string, error) { ssh.PasswordCallback(func() (string, error) {
d.enableRemoteReadTimeoutRetry()
defer d.disableRemoteReadTimeoutRetry()
wErr := d.w.SendManual( wErr := d.w.SendManual(
SSHServerConnectRequestCredential, SSHServerConnectRequestCredential,
b[d.w.HeaderSize():], b[d.w.HeaderSize():],
@@ -285,6 +311,9 @@ func (d *sshClient) buildAuthMethod(
return func(b []byte) []ssh.AuthMethod { return func(b []byte) []ssh.AuthMethod {
return []ssh.AuthMethod{ return []ssh.AuthMethod{
ssh.PublicKeysCallback(func() ([]ssh.Signer, error) { ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
d.enableRemoteReadTimeoutRetry()
defer d.disableRemoteReadTimeoutRetry()
wErr := d.w.SendManual( wErr := d.w.SendManual(
SSHServerConnectRequestCredential, SSHServerConnectRequestCredential,
b[d.w.HeaderSize():], b[d.w.HeaderSize():],
@@ -321,6 +350,9 @@ func (d *sshClient) comfirmRemoteFingerprint(
key ssh.PublicKey, key ssh.PublicKey,
buf []byte, buf []byte,
) error { ) error {
d.enableRemoteReadTimeoutRetry()
defer d.disableRemoteReadTimeoutRetry()
fgp := ssh.FingerprintSHA256(key) fgp := ssh.FingerprintSHA256(key)
fgpLen := copy(buf[d.w.HeaderSize():], fgp) fgpLen := copy(buf[d.w.HeaderSize():], fgp)
@@ -346,6 +378,20 @@ func (d *sshClient) comfirmRemoteFingerprint(
return nil return nil
} }
func (d *sshClient) enableRemoteReadTimeoutRetry() {
d.remoteReadTimeoutRetryLock.Lock()
defer d.remoteReadTimeoutRetryLock.Unlock()
d.remoteReadTimeoutRetry = true
}
func (d *sshClient) disableRemoteReadTimeoutRetry() {
d.remoteReadTimeoutRetryLock.Lock()
defer d.remoteReadTimeoutRetryLock.Unlock()
d.remoteReadTimeoutRetry = false
}
func (d *sshClient) dialRemote( func (d *sshClient) dialRemote(
network, network,
addr string, addr string,
@@ -362,6 +408,12 @@ func (d *sshClient) dialRemote(
Conn: conn, Conn: conn,
readTimeout: config.Timeout, readTimeout: config.Timeout,
enableTimeout: true, enableTimeout: true,
retryTimeout: func() bool {
d.remoteReadTimeoutRetryLock.Lock()
defer d.remoteReadTimeoutRetryLock.Unlock()
return d.remoteReadTimeoutRetry
},
} }
c, chans, reqs, err := ssh.NewClientConn(&sshConn, addr, config) c, chans, reqs, err := ssh.NewClientConn(&sshConn, addr, config)