Make DialTimeout configurable.
This commit is contained in:
@@ -121,17 +121,19 @@ func (s Server) Verify() error {
|
||||
|
||||
// Configuration contains configuration of the application
|
||||
type Configuration struct {
|
||||
HostName string
|
||||
SharedKey string
|
||||
Dialer network.Dial
|
||||
Servers []Server
|
||||
HostName string
|
||||
SharedKey string
|
||||
Dialer network.Dial
|
||||
DialTimeout time.Duration
|
||||
Servers []Server
|
||||
}
|
||||
|
||||
// Common settings shared by mulitple servers
|
||||
type Common struct {
|
||||
HostName string
|
||||
SharedKey string
|
||||
Dialer network.Dial
|
||||
HostName string
|
||||
SharedKey string
|
||||
Dialer network.Dial
|
||||
DialTimeout time.Duration
|
||||
}
|
||||
|
||||
// Verify verifies current setting
|
||||
@@ -155,24 +157,31 @@ func (c Configuration) Verify() error {
|
||||
|
||||
// Common returns common settings
|
||||
func (c Configuration) Common() Common {
|
||||
return Common{
|
||||
HostName: c.HostName,
|
||||
SharedKey: c.SharedKey,
|
||||
Dialer: c.Dialer,
|
||||
}
|
||||
}
|
||||
|
||||
// WithDefault build the configuration and fill the blank with default values
|
||||
func (c Common) WithDefault() Common {
|
||||
dialer := c.Dialer
|
||||
|
||||
if dialer == nil {
|
||||
dialer = network.TCPDial()
|
||||
}
|
||||
|
||||
dialTimeout := c.DialTimeout
|
||||
|
||||
if dialTimeout <= 1*time.Second {
|
||||
dialTimeout = 1 * time.Second
|
||||
}
|
||||
|
||||
return Common{
|
||||
HostName: c.HostName,
|
||||
SharedKey: c.SharedKey,
|
||||
Dialer: dialer,
|
||||
HostName: c.HostName,
|
||||
SharedKey: c.SharedKey,
|
||||
Dialer: c.Dialer,
|
||||
DialTimeout: c.DialTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// DecideDialTimeout will return a reasonable timeout for dialing
|
||||
func (c Common) DecideDialTimeout(max time.Duration) time.Duration {
|
||||
if c.DialTimeout > max {
|
||||
return max
|
||||
}
|
||||
|
||||
return c.DialTimeout
|
||||
}
|
||||
|
||||
@@ -22,9 +22,9 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/niruix/sshwifty/application/log"
|
||||
"github.com/niruix/sshwifty/application/network"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -46,12 +46,21 @@ func Enviro() Loader {
|
||||
return func(log log.Logger) (string, Configuration, error) {
|
||||
log.Info("Loading configuration from environment variables ...")
|
||||
|
||||
cfg := fileCfgCommon{
|
||||
dialTimeout, _ := strconv.ParseUint(
|
||||
parseEviro("SSHWIFTY_DIALTIMEOUT"), 10, 32)
|
||||
|
||||
cfg, dialer, cfgErr := fileCfgCommon{
|
||||
HostName: parseEviro("SSHWIFTY_HOSTNAME"),
|
||||
SharedKey: parseEviro("SSHWIFTY_SHAREDKEY"),
|
||||
DialTimeout: int(dialTimeout),
|
||||
Socks5: parseEviro("SSHWIFTY_SOCKS5"),
|
||||
Socks5User: parseEviro("SSHWIFTY_SOCKS5_USER"),
|
||||
Socks5Password: parseEviro("SSHWIFTY_SOCKS5_PASSWORD"),
|
||||
}.build()
|
||||
|
||||
if cfgErr != nil {
|
||||
return enviroTypeName, Configuration{}, fmt.Errorf(
|
||||
"Failed to build the configuration: %s", cfgErr)
|
||||
}
|
||||
|
||||
listenPort, listenPortErr := strconv.ParseUint(
|
||||
@@ -93,26 +102,12 @@ func Enviro() Loader {
|
||||
TLSCertificateKeyFile: parseEviro("SSHWIFTY_TLSCERTIFICATEKEYFILE"),
|
||||
}
|
||||
|
||||
var dialer network.Dial
|
||||
|
||||
if len(cfg.Socks5) <= 0 {
|
||||
dialer = network.TCPDial()
|
||||
} else {
|
||||
sDial, sDialErr := network.BuildSocks5Dial(
|
||||
cfg.Socks5, cfg.Socks5User, cfg.Socks5Password)
|
||||
|
||||
if sDialErr != nil {
|
||||
return enviroTypeName, Configuration{}, sDialErr
|
||||
}
|
||||
|
||||
dialer = sDial
|
||||
}
|
||||
|
||||
return enviroTypeName, Configuration{
|
||||
HostName: cfg.HostName,
|
||||
SharedKey: cfg.SharedKey,
|
||||
Dialer: dialer,
|
||||
Servers: []Server{cfgSer.build()},
|
||||
HostName: cfg.HostName,
|
||||
SharedKey: cfg.SharedKey,
|
||||
Dialer: dialer,
|
||||
DialTimeout: time.Duration(cfg.DialTimeout) * time.Second,
|
||||
Servers: []Server{cfgSer.build()},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ type fileCfgServer struct {
|
||||
TLSCertificateKeyFile string // Location of TLS certificate key
|
||||
}
|
||||
|
||||
func (f fileCfgServer) minDur(current, min int) int {
|
||||
func (f fileCfgServer) durationAtLeast(current, min int) int {
|
||||
if current > min {
|
||||
return current
|
||||
}
|
||||
@@ -62,17 +62,17 @@ func (f *fileCfgServer) build() Server {
|
||||
ListenInterface: f.ListenInterface,
|
||||
ListenPort: f.ListenPort,
|
||||
InitialTimeout: time.Duration(
|
||||
f.minDur(f.InitialTimeout, 5)) * time.Second,
|
||||
f.durationAtLeast(f.InitialTimeout, 5)) * time.Second,
|
||||
ReadTimeout: time.Duration(
|
||||
f.minDur(f.ReadTimeout, 30)) * time.Second,
|
||||
f.durationAtLeast(f.ReadTimeout, 30)) * time.Second,
|
||||
WriteTimeout: time.Duration(
|
||||
f.minDur(f.WriteTimeout, 30)) * time.Second,
|
||||
f.durationAtLeast(f.WriteTimeout, 30)) * time.Second,
|
||||
HeartbeatTimeout: time.Duration(
|
||||
f.minDur(f.HeartbeatTimeout, 10)) * time.Second,
|
||||
f.durationAtLeast(f.HeartbeatTimeout, 10)) * time.Second,
|
||||
ReadDelay: time.Duration(
|
||||
f.minDur(f.ReadDelay, 0)) * time.Millisecond,
|
||||
f.durationAtLeast(f.ReadDelay, 0)) * time.Millisecond,
|
||||
WriteDelay: time.Duration(
|
||||
f.minDur(f.WriteDelay, 0)) * time.Millisecond,
|
||||
f.durationAtLeast(f.WriteDelay, 0)) * time.Millisecond,
|
||||
TLSCertificateFile: f.TLSCertificateFile,
|
||||
TLSCertificateKeyFile: f.TLSCertificateKeyFile,
|
||||
}
|
||||
@@ -81,12 +81,46 @@ func (f *fileCfgServer) build() Server {
|
||||
type fileCfgCommon struct {
|
||||
HostName string // Host name
|
||||
SharedKey string // Shared key, empty to enable public access
|
||||
DialTimeout int // DialTimeout, min 5s
|
||||
Socks5 string // Socks5 server address, optional
|
||||
Socks5User string // Login user for socks5 server, optional
|
||||
Socks5Password string // Login pass for socks5 server, optional
|
||||
Servers []*fileCfgServer // Servers
|
||||
}
|
||||
|
||||
func (f fileCfgCommon) build() (fileCfgCommon, network.Dial, error) {
|
||||
dialTimeout := f.DialTimeout
|
||||
|
||||
if dialTimeout < 3 {
|
||||
dialTimeout = 3
|
||||
}
|
||||
|
||||
var dialer network.Dial
|
||||
|
||||
if len(f.Socks5) <= 0 {
|
||||
dialer = network.TCPDial()
|
||||
} else {
|
||||
sDial, sDialErr := network.BuildSocks5Dial(
|
||||
f.Socks5, f.Socks5User, f.Socks5Password)
|
||||
|
||||
if sDialErr != nil {
|
||||
return fileCfgCommon{}, nil, sDialErr
|
||||
}
|
||||
|
||||
dialer = sDial
|
||||
}
|
||||
|
||||
return fileCfgCommon{
|
||||
HostName: f.HostName,
|
||||
SharedKey: f.SharedKey,
|
||||
DialTimeout: dialTimeout,
|
||||
Socks5: f.Socks5,
|
||||
Socks5User: f.Socks5User,
|
||||
Socks5Password: f.Socks5Password,
|
||||
Servers: f.Servers,
|
||||
}, dialer, nil
|
||||
}
|
||||
|
||||
func loadFile(filePath string) (string, Configuration, error) {
|
||||
f, fErr := os.Open(filePath)
|
||||
|
||||
@@ -105,32 +139,24 @@ func loadFile(filePath string) (string, Configuration, error) {
|
||||
return fileTypeName, Configuration{}, jDecodeErr
|
||||
}
|
||||
|
||||
servers := make([]Server, len(cfg.Servers))
|
||||
finalCfg, dialer, cfgErr := cfg.build()
|
||||
|
||||
for i := range servers {
|
||||
servers[i] = cfg.Servers[i].build()
|
||||
if cfgErr != nil {
|
||||
return fileTypeName, Configuration{}, cfgErr
|
||||
}
|
||||
|
||||
var dialer network.Dial
|
||||
servers := make([]Server, len(finalCfg.Servers))
|
||||
|
||||
if len(cfg.Socks5) <= 0 {
|
||||
dialer = network.TCPDial()
|
||||
} else {
|
||||
sDial, sDialErr := network.BuildSocks5Dial(
|
||||
cfg.Socks5, cfg.Socks5User, cfg.Socks5Password)
|
||||
|
||||
if sDialErr != nil {
|
||||
return fileTypeName, Configuration{}, sDialErr
|
||||
}
|
||||
|
||||
dialer = sDial
|
||||
for i := range servers {
|
||||
servers[i] = finalCfg.Servers[i].build()
|
||||
}
|
||||
|
||||
return fileTypeName, Configuration{
|
||||
HostName: cfg.HostName,
|
||||
SharedKey: cfg.SharedKey,
|
||||
Dialer: dialer,
|
||||
Servers: servers,
|
||||
HostName: finalCfg.HostName,
|
||||
SharedKey: finalCfg.SharedKey,
|
||||
Dialer: dialer,
|
||||
DialTimeout: time.Duration(finalCfg.DialTimeout) * time.Second,
|
||||
Servers: servers,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user