Add timeout for SSH handshake

This commit is contained in:
NI
2019-08-19 21:29:47 +08:00
parent 33076628a4
commit 1070c2bcf2

View File

@@ -22,6 +22,7 @@ import (
"io" "io"
"net" "net"
"sync" "sync"
"time"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@@ -101,6 +102,57 @@ var (
"Unknown client signal") "Unknown client signal")
) )
var (
sshEmptyTime = time.Time{}
)
type sshRemoteConnWrapper struct {
net.Conn
readTimeout time.Duration
enableTimeout bool
}
func (s *sshRemoteConnWrapper) SetReadDeadline(t time.Time) error {
s.enableTimeout = false
return s.Conn.SetReadDeadline(t)
}
func (s *sshRemoteConnWrapper) SetWriteDeadline(t time.Time) error {
s.enableTimeout = false
return s.Conn.SetWriteDeadline(t)
}
func (s *sshRemoteConnWrapper) SetDeadline(t time.Time) error {
s.enableTimeout = false
return s.Conn.SetDeadline(t)
}
func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) {
rLen, rErr := s.Conn.Read(b)
if rErr == nil {
return rLen, nil
}
if !s.enableTimeout {
return rLen, rErr
}
netErr, isNetErr := rErr.(net.Error)
if !isNetErr || !netErr.Timeout() {
return rLen, rErr
}
s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))
return s.Conn.Read(b)
}
type sshRemoteConn struct { type sshRemoteConn struct {
writer io.Writer writer io.Writer
closer func() error closer func() error
@@ -295,20 +347,36 @@ func (d *sshClient) comfirmRemoteFingerprint(
} }
func (d *sshClient) dialRemote( func (d *sshClient) dialRemote(
network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { network,
addr string,
config *ssh.ClientConfig) (*ssh.Client, func(), error) {
conn, err := d.cfg.Dial(network, addr, config.Timeout) conn, err := d.cfg.Dial(network, addr, config.Timeout)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) conn.SetReadDeadline(time.Now().Add(config.Timeout))
sshConn := sshRemoteConnWrapper{
Conn: conn,
readTimeout: config.Timeout,
enableTimeout: true,
}
c, chans, reqs, err := ssh.NewClientConn(&sshConn, addr, config)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return ssh.NewClient(c, chans, reqs), nil return ssh.NewClient(c, chans, reqs), func() {
if sshConn.enableTimeout {
sshConn.SetReadDeadline(sshEmptyTime)
}
sshConn.enableTimeout = false
}, nil
} }
func (d *sshClient) remote( func (d *sshClient) remote(
@@ -322,7 +390,8 @@ func (d *sshClient) remote(
buf := [4096]byte{} buf := [4096]byte{}
conn, dErr := d.dialRemote("tcp", address, &ssh.ClientConfig{ conn, clearConnInitialDeadline, dErr :=
d.dialRemote("tcp", address, &ssh.ClientConfig{
User: user, User: user,
Auth: authMethodBuilder(buf[:]), Auth: authMethodBuilder(buf[:]),
HostKeyCallback: func(h string, r net.Addr, k ssh.PublicKey) error { HostKeyCallback: func(h string, r net.Addr, k ssh.PublicKey) error {
@@ -418,6 +487,8 @@ func (d *sshClient) remote(
defer session.Wait() defer session.Wait()
clearConnInitialDeadline()
d.remoteConnReceive <- sshRemoteConn{ d.remoteConnReceive <- sshRemoteConn{
writer: in, writer: in,
closer: func() error { closer: func() error {