SSH: Don't timeout during handshake if we're waiting for user input
This commit is contained in:
@@ -111,6 +111,7 @@ type sshRemoteConnWrapper struct {
|
|||||||
|
|
||||||
readTimeout time.Duration
|
readTimeout time.Duration
|
||||||
enableTimeout bool
|
enableTimeout bool
|
||||||
|
retryTimeout func() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sshRemoteConnWrapper) SetReadDeadline(t time.Time) error {
|
func (s *sshRemoteConnWrapper) SetReadDeadline(t time.Time) error {
|
||||||
@@ -150,7 +151,25 @@ func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) {
|
|||||||
|
|
||||||
s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))
|
s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))
|
||||||
|
|
||||||
return s.Conn.Read(b)
|
for {
|
||||||
|
rLen, rErr := s.Conn.Read(b)
|
||||||
|
|
||||||
|
if rErr == nil {
|
||||||
|
return rLen, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.enableTimeout {
|
||||||
|
return rLen, rErr
|
||||||
|
}
|
||||||
|
|
||||||
|
netErr, isNetErr := rErr.(net.Error)
|
||||||
|
|
||||||
|
if !isNetErr || !netErr.Timeout() || !s.retryTimeout() {
|
||||||
|
return rLen, rErr
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type sshRemoteConn struct {
|
type sshRemoteConn struct {
|
||||||
@@ -168,6 +187,8 @@ type sshClient struct {
|
|||||||
l log.Logger
|
l log.Logger
|
||||||
cfg command.Configuration
|
cfg command.Configuration
|
||||||
remoteCloseWait sync.WaitGroup
|
remoteCloseWait sync.WaitGroup
|
||||||
|
remoteReadTimeoutRetry bool
|
||||||
|
remoteReadTimeoutRetryLock sync.Mutex
|
||||||
credentialReceive chan []byte
|
credentialReceive chan []byte
|
||||||
credentialProcessed bool
|
credentialProcessed bool
|
||||||
credentialReceiveClosed bool
|
credentialReceiveClosed bool
|
||||||
@@ -188,6 +209,8 @@ func newSSH(
|
|||||||
l: l,
|
l: l,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
remoteCloseWait: sync.WaitGroup{},
|
remoteCloseWait: sync.WaitGroup{},
|
||||||
|
remoteReadTimeoutRetry: false,
|
||||||
|
remoteReadTimeoutRetryLock: sync.Mutex{},
|
||||||
credentialReceive: make(chan []byte, 1),
|
credentialReceive: make(chan []byte, 1),
|
||||||
credentialProcessed: false,
|
credentialProcessed: false,
|
||||||
credentialReceiveClosed: false,
|
credentialReceiveClosed: false,
|
||||||
@@ -261,6 +284,9 @@ func (d *sshClient) buildAuthMethod(
|
|||||||
return func(b []byte) []ssh.AuthMethod {
|
return func(b []byte) []ssh.AuthMethod {
|
||||||
return []ssh.AuthMethod{
|
return []ssh.AuthMethod{
|
||||||
ssh.PasswordCallback(func() (string, error) {
|
ssh.PasswordCallback(func() (string, error) {
|
||||||
|
d.enableRemoteReadTimeoutRetry()
|
||||||
|
defer d.disableRemoteReadTimeoutRetry()
|
||||||
|
|
||||||
wErr := d.w.SendManual(
|
wErr := d.w.SendManual(
|
||||||
SSHServerConnectRequestCredential,
|
SSHServerConnectRequestCredential,
|
||||||
b[d.w.HeaderSize():],
|
b[d.w.HeaderSize():],
|
||||||
@@ -285,6 +311,9 @@ func (d *sshClient) buildAuthMethod(
|
|||||||
return func(b []byte) []ssh.AuthMethod {
|
return func(b []byte) []ssh.AuthMethod {
|
||||||
return []ssh.AuthMethod{
|
return []ssh.AuthMethod{
|
||||||
ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
|
ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
|
||||||
|
d.enableRemoteReadTimeoutRetry()
|
||||||
|
defer d.disableRemoteReadTimeoutRetry()
|
||||||
|
|
||||||
wErr := d.w.SendManual(
|
wErr := d.w.SendManual(
|
||||||
SSHServerConnectRequestCredential,
|
SSHServerConnectRequestCredential,
|
||||||
b[d.w.HeaderSize():],
|
b[d.w.HeaderSize():],
|
||||||
@@ -321,6 +350,9 @@ func (d *sshClient) comfirmRemoteFingerprint(
|
|||||||
key ssh.PublicKey,
|
key ssh.PublicKey,
|
||||||
buf []byte,
|
buf []byte,
|
||||||
) error {
|
) error {
|
||||||
|
d.enableRemoteReadTimeoutRetry()
|
||||||
|
defer d.disableRemoteReadTimeoutRetry()
|
||||||
|
|
||||||
fgp := ssh.FingerprintSHA256(key)
|
fgp := ssh.FingerprintSHA256(key)
|
||||||
fgpLen := copy(buf[d.w.HeaderSize():], fgp)
|
fgpLen := copy(buf[d.w.HeaderSize():], fgp)
|
||||||
|
|
||||||
@@ -346,6 +378,20 @@ func (d *sshClient) comfirmRemoteFingerprint(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *sshClient) enableRemoteReadTimeoutRetry() {
|
||||||
|
d.remoteReadTimeoutRetryLock.Lock()
|
||||||
|
defer d.remoteReadTimeoutRetryLock.Unlock()
|
||||||
|
|
||||||
|
d.remoteReadTimeoutRetry = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *sshClient) disableRemoteReadTimeoutRetry() {
|
||||||
|
d.remoteReadTimeoutRetryLock.Lock()
|
||||||
|
defer d.remoteReadTimeoutRetryLock.Unlock()
|
||||||
|
|
||||||
|
d.remoteReadTimeoutRetry = false
|
||||||
|
}
|
||||||
|
|
||||||
func (d *sshClient) dialRemote(
|
func (d *sshClient) dialRemote(
|
||||||
network,
|
network,
|
||||||
addr string,
|
addr string,
|
||||||
@@ -362,6 +408,12 @@ func (d *sshClient) dialRemote(
|
|||||||
Conn: conn,
|
Conn: conn,
|
||||||
readTimeout: config.Timeout,
|
readTimeout: config.Timeout,
|
||||||
enableTimeout: true,
|
enableTimeout: true,
|
||||||
|
retryTimeout: func() bool {
|
||||||
|
d.remoteReadTimeoutRetryLock.Lock()
|
||||||
|
defer d.remoteReadTimeoutRetryLock.Unlock()
|
||||||
|
|
||||||
|
return d.remoteReadTimeoutRetry
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c, chans, reqs, err := ssh.NewClientConn(&sshConn, addr, config)
|
c, chans, reqs, err := ssh.NewClientConn(&sshConn, addr, config)
|
||||||
|
|||||||
Reference in New Issue
Block a user