diff --git a/application/application.go b/application/application.go index 3d4555b..0d75e4f 100644 --- a/application/application.go +++ b/application/application.go @@ -26,6 +26,7 @@ import ( "sync" "syscall" + "github.com/niruix/sshwifty/application/command" "github.com/niruix/sshwifty/application/configuration" "github.com/niruix/sshwifty/application/log" "github.com/niruix/sshwifty/application/server" @@ -66,11 +67,14 @@ func New(screen io.Writer, logger log.Logger) Application { func (a Application) run( cLoader configuration.Loader, closeSigBuilder ProccessSignallerBuilder, - handlerBuilder server.HandlerBuilder, + commands command.Commands, + handlerBuilder server.HandlerBuilderBuilder, ) (bool, error) { var err error - loaderName, c, cErr := cLoader(a.logger.Context("Configuration")) + loaderName, c, cErr := cLoader( + a.logger.Context("Configuration"), + commands.Reconfigure) if cErr != nil { a.logger.Error("\"%s\" loader cannot load configuration: %s", @@ -117,7 +121,7 @@ func (a Application) run( close(closeNotify) closeNotify = nil - }, handlerBuilder) + }, handlerBuilder(commands)) servers = append(servers, newServer) } @@ -148,7 +152,8 @@ func (a Application) run( func (a Application) Run( cLoader configuration.Loader, closeSigBuilder ProccessSignallerBuilder, - handlerBuilder server.HandlerBuilder, + commands command.Commands, + handlerBuilder server.HandlerBuilderBuilder, ) error { fmt.Fprintf(a.screen, banner, FullName, version, Author, URL) @@ -159,7 +164,8 @@ func (a Application) Run( defer a.logger.Info("Closed") for { - restart, runErr := a.run(cLoader, closeSigBuilder, handlerBuilder) + restart, runErr := a.run( + cLoader, closeSigBuilder, commands, handlerBuilder) if runErr != nil { a.logger.Error("Unable to start due to error: %s", runErr) diff --git a/application/command/commands.go b/application/command/commands.go index 6b29700..48ac390 100644 --- a/application/command/commands.go +++ b/application/command/commands.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" + "github.com/niruix/sshwifty/application/configuration" "github.com/niruix/sshwifty/application/log" ) @@ -42,20 +43,35 @@ type Command func( cfg Configuration, ) FSMMachine +// Builder builds a command +type Builder struct { + command Command + configurator configuration.Reconfigurator +} + +// Register builds a Builder for registration +func Register(c Command, p configuration.Reconfigurator) Builder { + return Builder{ + command: c, + configurator: p, + } +} + // Commands contains data of all commands -type Commands [MaxCommandID + 1]Command +type Commands [MaxCommandID + 1]Builder // Register registers a new command -func (c *Commands) Register(id byte, cb Command) { +func (c *Commands) Register( + id byte, cb Command, ps configuration.Reconfigurator) { if id > MaxCommandID { panic("Command ID must be not greater than MaxCommandID") } - if (*c)[id] != nil { + if (*c)[id].command != nil { panic(fmt.Sprintf("Command %d already been registered", id)) } - (*c)[id] = cb + (*c)[id] = Register(cb, ps) } // Run creates command executer @@ -63,16 +79,32 @@ func (c Commands) Run( id byte, l log.Logger, w StreamResponder, - cfg Configuration) (FSM, error) { + cfg Configuration, +) (FSM, error) { if id > MaxCommandID { return FSM{}, ErrCommandRunUndefinedCommand } cc := c[id] - if cc == nil { + if cc.command == nil { return FSM{}, ErrCommandRunUndefinedCommand } - return newFSM(cc(l, w, cfg)), nil + return newFSM(cc.command(l, w, cfg)), nil +} + +// Reconfigure lets commands reset configuration +func (c Commands) Reconfigure( + p configuration.Configuration, +) configuration.Configuration { + for i := range c { + if c[i].configurator == nil { + continue + } + + p = c[i].configurator(p) + } + + return p } diff --git a/application/command/handler_stream_test.go b/application/command/handler_stream_test.go index 0bf3587..154afd8 100644 --- a/application/command/handler_stream_test.go +++ b/application/command/handler_stream_test.go @@ -171,7 +171,7 @@ func (d *dummyStreamCommand) Release() error { func TestHandlerHandleStream(t *testing.T) { cmds := Commands{} - cmds.Register(0, newDummyStreamCommand) + cmds.Register(0, newDummyStreamCommand, nil) readerDataInput := make(chan []byte) diff --git a/application/commands/commands.go b/application/commands/commands.go index 474b03e..b2157f8 100644 --- a/application/commands/commands.go +++ b/application/commands/commands.go @@ -24,7 +24,7 @@ import ( // New creates a new commands group func New() command.Commands { return command.Commands{ - newTelnet, - newSSH, + command.Register(newTelnet, parseTelnetConfig), + command.Register(newSSH, parseSSHConfig), } } diff --git a/application/commands/ssh.go b/application/commands/ssh.go index 15fef7a..af3b0ca 100644 --- a/application/commands/ssh.go +++ b/application/commands/ssh.go @@ -27,6 +27,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/niruix/sshwifty/application/command" + "github.com/niruix/sshwifty/application/configuration" "github.com/niruix/sshwifty/application/log" "github.com/niruix/sshwifty/application/network" "github.com/niruix/sshwifty/application/rw" @@ -107,6 +108,10 @@ var ( sshEmptyTime = time.Time{} ) +const ( + sshDefaultPortString = "22" +) + type sshRemoteConnWrapper struct { net.Conn @@ -198,6 +203,30 @@ func newSSH( } } +func parseSSHConfig(p configuration.Configuration) configuration.Configuration { + for i := range p.Presets { + if p.Presets[i].Type != "SSH" { + continue + } + + oldHost := p.Presets[i].Host + + _, _, sErr := net.SplitHostPort(p.Presets[i].Host) + + if sErr != nil { + p.Presets[i].Host = net.JoinHostPort( + p.Presets[i].Host, + sshDefaultPortString) + } + + if len(p.Presets[i].Host) <= 0 { + p.Presets[i].Host = oldHost + } + } + + return p +} + func (d *sshClient) Bootup( r *rw.LimitedReader, b []byte, diff --git a/application/commands/telnet.go b/application/commands/telnet.go index 42a5fb3..d295b2d 100644 --- a/application/commands/telnet.go +++ b/application/commands/telnet.go @@ -24,6 +24,7 @@ import ( "time" "github.com/niruix/sshwifty/application/command" + "github.com/niruix/sshwifty/application/configuration" "github.com/niruix/sshwifty/application/log" "github.com/niruix/sshwifty/application/network" "github.com/niruix/sshwifty/application/rw" @@ -40,6 +41,10 @@ const ( TelnetRequestErrorBadRemoteAddress = command.StreamError(0x01) ) +const ( + telnetDefaultPortString = "23" +) + // Server signal codes const ( TelnetServerRemoteBand = 0x00 @@ -71,6 +76,32 @@ func newTelnet( } } +func parseTelnetConfig( + p configuration.Configuration, +) configuration.Configuration { + for i := range p.Presets { + if p.Presets[i].Type != "Telnet" { + continue + } + + oldHost := p.Presets[i].Host + + _, _, sErr := net.SplitHostPort(p.Presets[i].Host) + + if sErr != nil { + p.Presets[i].Host = net.JoinHostPort( + p.Presets[i].Host, + telnetDefaultPortString) + } + + if len(p.Presets[i].Host) <= 0 { + p.Presets[i].Host = oldHost + } + } + + return p +} + func (d *telnetClient) Bootup( r *rw.LimitedReader, b []byte) (command.FSMState, command.FSMError) { diff --git a/application/configuration/config.go b/application/configuration/config.go index 18425f7..be6c402 100644 --- a/application/configuration/config.go +++ b/application/configuration/config.go @@ -132,8 +132,10 @@ type Preset struct { type Configuration struct { HostName string SharedKey string - Dialer network.Dial DialTimeout time.Duration + Socks5 string + Socks5User string + Socks5Password string Servers []Server Presets []Preset OnlyAllowPresetRemotes bool @@ -168,12 +170,50 @@ func (c Configuration) Verify() error { return nil } +// Dialer builds a Dialer +func (c Configuration) Dialer() network.Dial { + dialTimeout := c.DialTimeout + + if dialTimeout < 3 { + dialTimeout = 3 + } + + dialer := network.TCPDial() + + if len(c.Socks5) > 0 { + sDial, sDialErr := network.BuildSocks5Dial( + c.Socks5, c.Socks5User, c.Socks5Password) + + if sDialErr != nil { + panic("Unable to build Socks5 Dialer: " + sDialErr.Error()) + } + + dialer = sDial + } + + if c.OnlyAllowPresetRemotes { + accessList := make(network.AllowedHosts, len(c.Presets)) + + for _, k := range c.Presets { + if len(k.Host) <= 0 { + continue + } + + accessList[k.Host] = struct{}{} + } + + dialer = network.AccessControlDial(accessList, dialer) + } + + return dialer +} + // Common returns common settings func (c Configuration) Common() Common { return Common{ HostName: c.HostName, SharedKey: c.SharedKey, - Dialer: c.Dialer, + Dialer: c.Dialer(), DialTimeout: c.DialTimeout, Presets: c.Presets, OnlyAllowPresetRemotes: c.OnlyAllowPresetRemotes, diff --git a/application/configuration/loader.go b/application/configuration/loader.go index d6cc2c4..3b2a773 100644 --- a/application/configuration/loader.go +++ b/application/configuration/loader.go @@ -21,5 +21,11 @@ import ( "github.com/niruix/sshwifty/application/log" ) +// Reconfigurator reloads configuration +type Reconfigurator func(p Configuration) Configuration + // Loader Configuration loader -type Loader func(log log.Logger) (name string, cfg Configuration, err error) +type Loader func( + log log.Logger, + r Reconfigurator, +) (name string, cfg Configuration, err error) diff --git a/application/configuration/loader_direct.go b/application/configuration/loader_direct.go index 3274460..948df76 100644 --- a/application/configuration/loader_direct.go +++ b/application/configuration/loader_direct.go @@ -28,7 +28,10 @@ const ( // Direct creates a loader that return raw configuration data directly. // Good for integration. func Direct(cfg Configuration) Loader { - return func(log log.Logger) (string, Configuration, error) { + return func( + log log.Logger, + r Reconfigurator, + ) (string, Configuration, error) { return directTypeName, cfg, nil } -} \ No newline at end of file +} diff --git a/application/configuration/loader_enviro.go b/application/configuration/loader_enviro.go index 64fe5bb..3370128 100644 --- a/application/configuration/loader_enviro.go +++ b/application/configuration/loader_enviro.go @@ -44,13 +44,16 @@ func parseEviro(name string) string { // Enviro creates an environment variable based configuration loader func Enviro() Loader { - return func(log log.Logger) (string, Configuration, error) { + return func( + log log.Logger, + r Reconfigurator, + ) (string, Configuration, error) { log.Info("Loading configuration from environment variables ...") dialTimeout, _ := strconv.ParseUint( parseEviro("SSHWIFTY_DIALTIMEOUT"), 10, 32) - cfg, dialer, cfgErr := fileCfgCommon{ + cfg, cfgErr := fileCfgCommon{ HostName: parseEviro("SSHWIFTY_HOSTNAME"), SharedKey: parseEviro("SSHWIFTY_SHAREDKEY"), DialTimeout: int(dialTimeout), @@ -68,14 +71,15 @@ func Enviro() Loader { "Failed to build the configuration: %s", cfgErr) } - listenPort, listenPortErr := strconv.ParseUint( - parseEviro("SSHWIFTY_LISTENPORT"), 10, 16) + listenIface := parseEviro("SSHWIFTY_LISTENINTERFACE") - if listenPortErr != nil { - return enviroTypeName, Configuration{}, fmt.Errorf( - "Invalid \"SSHWIFTY_LISTENPORT\": %s", listenPortErr) + if len(listenIface) <= 0 { + listenIface = "127.0.0.1" } + listenPort, _ := strconv.ParseUint( + parseEviro("SSHWIFTY_LISTENPORT"), 10, 16) + initialTimeout, _ := strconv.ParseUint( parseEviro("SSHWIFTY_INITIALTIMEOUT"), 10, 32) @@ -95,7 +99,7 @@ func Enviro() Loader { parseEviro("SSHWIFTY_WRITEELAY"), 10, 32) cfgSer := fileCfgServer{ - ListenInterface: parseEviro("SSHWIFTY_LISTENINTERFACE"), + ListenInterface: listenIface, ListenPort: uint16(listenPort), InitialTimeout: int(initialTimeout), ReadTimeout: int(readTimeout), @@ -122,8 +126,10 @@ func Enviro() Loader { return enviroTypeName, Configuration{ HostName: cfg.HostName, SharedKey: cfg.SharedKey, - Dialer: dialer, DialTimeout: time.Duration(cfg.DialTimeout) * time.Second, + Socks5: cfg.Socks5, + Socks5User: cfg.Socks5User, + Socks5Password: cfg.Socks5Password, Servers: []Server{cfgSer.build()}, Presets: presets, OnlyAllowPresetRemotes: cfg.OnlyAllowPresetRemotes, diff --git a/application/configuration/loader_file.go b/application/configuration/loader_file.go index 21384d6..6fec632 100644 --- a/application/configuration/loader_file.go +++ b/application/configuration/loader_file.go @@ -28,7 +28,6 @@ import ( "time" "github.com/niruix/sshwifty/application/log" - "github.com/niruix/sshwifty/application/network" ) const ( @@ -122,56 +121,24 @@ type fileCfgCommon struct { OnlyAllowPresetRemotes bool } -func (f fileCfgCommon) build() (fileCfgCommon, network.Dial, error) { - dialTimeout := f.DialTimeout - - if dialTimeout < 3 { - dialTimeout = 3 - } - - var dialer network.Dial - - if len(f.Socks5) <= 0 { - dialer = network.TCPDial() - } else { - sDial, sDialErr := network.BuildSocks5Dial( - f.Socks5, f.Socks5User, f.Socks5Password) - - if sDialErr != nil { - return fileCfgCommon{}, nil, sDialErr - } - - dialer = sDial - } - - if f.OnlyAllowPresetRemotes { - accessList := make(network.AllowedHosts, len(f.Presets)) - - for _, k := range f.Presets { - if len(k.Host) <= 0 { - continue - } - - accessList[k.Host] = struct{}{} - } - - dialer = network.AccessControlDial(accessList, dialer) - } - +func (f fileCfgCommon) build() (fileCfgCommon, error) { return fileCfgCommon{ HostName: f.HostName, SharedKey: f.SharedKey, - DialTimeout: dialTimeout, + DialTimeout: f.DialTimeout, Socks5: f.Socks5, Socks5User: f.Socks5User, Socks5Password: f.Socks5Password, Servers: f.Servers, Presets: f.Presets, OnlyAllowPresetRemotes: f.OnlyAllowPresetRemotes, - }, dialer, nil + }, nil } -func loadFile(filePath string) (string, Configuration, error) { +func loadFile( + filePath string, + r Reconfigurator, +) (string, Configuration, error) { f, fErr := os.Open(filePath) if fErr != nil { @@ -189,7 +156,7 @@ func loadFile(filePath string) (string, Configuration, error) { return fileTypeName, Configuration{}, jDecodeErr } - finalCfg, dialer, cfgErr := cfg.build() + finalCfg, cfgErr := cfg.build() if cfgErr != nil { return fileTypeName, Configuration{}, cfgErr @@ -207,25 +174,30 @@ func loadFile(filePath string) (string, Configuration, error) { presets[i] = finalCfg.Presets[i].build() } - return fileTypeName, Configuration{ + return fileTypeName, r(Configuration{ HostName: finalCfg.HostName, SharedKey: finalCfg.SharedKey, - Dialer: dialer, DialTimeout: time.Duration(finalCfg.DialTimeout) * time.Second, + Socks5: cfg.Socks5, + Socks5User: cfg.Socks5User, + Socks5Password: cfg.Socks5Password, Servers: servers, Presets: presets, OnlyAllowPresetRemotes: cfg.OnlyAllowPresetRemotes, - }, nil + }), nil } // File creates a configuration file loader func File(customPath string) Loader { - return func(log log.Logger) (string, Configuration, error) { + return func( + log log.Logger, + r Reconfigurator, + ) (string, Configuration, error) { if len(customPath) > 0 { log.Info("Loading configuration from: %s", customPath) - return loadFile(customPath) + return loadFile(customPath, r) } log.Info("Loading configuration from one of the default " + @@ -267,7 +239,7 @@ func File(customPath string) Loader { log.Info("Configuration file \"%s\" has been selected", fallbackFileSearchList[f]) - return loadFile(fallbackFileSearchList[f]) + return loadFile(fallbackFileSearchList[f], r) } return fileTypeName, Configuration{}, fmt.Errorf( diff --git a/application/configuration/loader_redundant.go b/application/configuration/loader_redundant.go index b85c99c..44605ca 100644 --- a/application/configuration/loader_redundant.go +++ b/application/configuration/loader_redundant.go @@ -30,11 +30,14 @@ const ( // Redundant creates a group of loaders. They will be executed one by one until // one of it successfully returned a configuration func Redundant(loaders ...Loader) Loader { - return func(log log.Logger) (string, Configuration, error) { + return func( + log log.Logger, + r Reconfigurator, + ) (string, Configuration, error) { ll := log.Context("Redundant") for i := range loaders { - lLoaderName, lCfg, lErr := loaders[i](ll) + lLoaderName, lCfg, lErr := loaders[i](ll, r) if lErr != nil { ll.Warning("Unable to load configuration from \"%s\": %s", diff --git a/application/server/server.go b/application/server/server.go index 1da13f5..b7da41c 100644 --- a/application/server/server.go +++ b/application/server/server.go @@ -28,6 +28,7 @@ import ( "sync" "time" + "github.com/niruix/sshwifty/application/command" "github.com/niruix/sshwifty/application/configuration" "github.com/niruix/sshwifty/application/log" ) @@ -50,6 +51,9 @@ type HandlerBuilder func( cfg configuration.Server, logger log.Logger) http.Handler +// HandlerBuilderBuilder builds HandlerBuilder +type HandlerBuilderBuilder func(command.Commands) HandlerBuilder + // CloseCallback will be called when the server has closed type CloseCallback func(error) diff --git a/sshwifty.go b/sshwifty.go index 8691f3f..00379c7 100644 --- a/sshwifty.go +++ b/sshwifty.go @@ -43,7 +43,8 @@ func main() { len(os.Getenv("SSHWIFTY_DEBUG")) > 0, application.Name, os.Stderr)). Run(configuration.Redundant(configLoaders...), application.DefaultProccessSignallerBuilder, - controller.Builder(commands.New())) + commands.New(), + controller.Builder) if e == nil { return