Add timeout for SSH handshake

This commit is contained in:
NI
2019-08-19 21:29:47 +08:00
parent 33076628a4
commit 1070c2bcf2

View File

@@ -22,6 +22,7 @@ import (
"io"
"net"
"sync"
"time"
"golang.org/x/crypto/ssh"
@@ -101,6 +102,57 @@ var (
"Unknown client signal")
)
var (
sshEmptyTime = time.Time{}
)
type sshRemoteConnWrapper struct {
net.Conn
readTimeout time.Duration
enableTimeout bool
}
func (s *sshRemoteConnWrapper) SetReadDeadline(t time.Time) error {
s.enableTimeout = false
return s.Conn.SetReadDeadline(t)
}
func (s *sshRemoteConnWrapper) SetWriteDeadline(t time.Time) error {
s.enableTimeout = false
return s.Conn.SetWriteDeadline(t)
}
func (s *sshRemoteConnWrapper) SetDeadline(t time.Time) error {
s.enableTimeout = false
return s.Conn.SetDeadline(t)
}
func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) {
rLen, rErr := s.Conn.Read(b)
if rErr == nil {
return rLen, nil
}
if !s.enableTimeout {
return rLen, rErr
}
netErr, isNetErr := rErr.(net.Error)
if !isNetErr || !netErr.Timeout() {
return rLen, rErr
}
s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))
return s.Conn.Read(b)
}
type sshRemoteConn struct {
writer io.Writer
closer func() error
@@ -295,20 +347,36 @@ func (d *sshClient) comfirmRemoteFingerprint(
}
func (d *sshClient) dialRemote(
network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
network,
addr string,
config *ssh.ClientConfig) (*ssh.Client, func(), error) {
conn, err := d.cfg.Dial(network, addr, config.Timeout)
if err != nil {
return nil, err
return nil, nil, err
}
c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
conn.SetReadDeadline(time.Now().Add(config.Timeout))
sshConn := sshRemoteConnWrapper{
Conn: conn,
readTimeout: config.Timeout,
enableTimeout: true,
}
c, chans, reqs, err := ssh.NewClientConn(&sshConn, addr, config)
if err != nil {
return nil, err
return nil, nil, err
}
return ssh.NewClient(c, chans, reqs), nil
return ssh.NewClient(c, chans, reqs), func() {
if sshConn.enableTimeout {
sshConn.SetReadDeadline(sshEmptyTime)
}
sshConn.enableTimeout = false
}, nil
}
func (d *sshClient) remote(
@@ -322,14 +390,15 @@ func (d *sshClient) remote(
buf := [4096]byte{}
conn, dErr := d.dialRemote("tcp", address, &ssh.ClientConfig{
User: user,
Auth: authMethodBuilder(buf[:]),
HostKeyCallback: func(h string, r net.Addr, k ssh.PublicKey) error {
return d.comfirmRemoteFingerprint(h, r, k, buf[:])
},
Timeout: d.cfg.DialTimeout,
})
conn, clearConnInitialDeadline, dErr :=
d.dialRemote("tcp", address, &ssh.ClientConfig{
User: user,
Auth: authMethodBuilder(buf[:]),
HostKeyCallback: func(h string, r net.Addr, k ssh.PublicKey) error {
return d.comfirmRemoteFingerprint(h, r, k, buf[:])
},
Timeout: d.cfg.DialTimeout,
})
if dErr != nil {
errLen := copy(buf[d.w.HeaderSize():], dErr.Error()) + d.w.HeaderSize()
@@ -418,6 +487,8 @@ func (d *sshClient) remote(
defer session.Wait()
clearConnInitialDeadline()
d.remoteConnReceive <- sshRemoteConn{
writer: in,
closer: func() error {