Allowing Commands to reconfigure app Configuration, this is will allow default ports to be currently added to remote presets

This commit is contained in:
NI
2020-02-23 17:23:24 +08:00
parent 697ee0b2a1
commit e7ca6d95ed
14 changed files with 212 additions and 79 deletions

View File

@@ -26,6 +26,7 @@ import (
"sync" "sync"
"syscall" "syscall"
"github.com/niruix/sshwifty/application/command"
"github.com/niruix/sshwifty/application/configuration" "github.com/niruix/sshwifty/application/configuration"
"github.com/niruix/sshwifty/application/log" "github.com/niruix/sshwifty/application/log"
"github.com/niruix/sshwifty/application/server" "github.com/niruix/sshwifty/application/server"
@@ -66,11 +67,14 @@ func New(screen io.Writer, logger log.Logger) Application {
func (a Application) run( func (a Application) run(
cLoader configuration.Loader, cLoader configuration.Loader,
closeSigBuilder ProccessSignallerBuilder, closeSigBuilder ProccessSignallerBuilder,
handlerBuilder server.HandlerBuilder, commands command.Commands,
handlerBuilder server.HandlerBuilderBuilder,
) (bool, error) { ) (bool, error) {
var err error var err error
loaderName, c, cErr := cLoader(a.logger.Context("Configuration")) loaderName, c, cErr := cLoader(
a.logger.Context("Configuration"),
commands.Reconfigure)
if cErr != nil { if cErr != nil {
a.logger.Error("\"%s\" loader cannot load configuration: %s", a.logger.Error("\"%s\" loader cannot load configuration: %s",
@@ -117,7 +121,7 @@ func (a Application) run(
close(closeNotify) close(closeNotify)
closeNotify = nil closeNotify = nil
}, handlerBuilder) }, handlerBuilder(commands))
servers = append(servers, newServer) servers = append(servers, newServer)
} }
@@ -148,7 +152,8 @@ func (a Application) run(
func (a Application) Run( func (a Application) Run(
cLoader configuration.Loader, cLoader configuration.Loader,
closeSigBuilder ProccessSignallerBuilder, closeSigBuilder ProccessSignallerBuilder,
handlerBuilder server.HandlerBuilder, commands command.Commands,
handlerBuilder server.HandlerBuilderBuilder,
) error { ) error {
fmt.Fprintf(a.screen, banner, FullName, version, Author, URL) fmt.Fprintf(a.screen, banner, FullName, version, Author, URL)
@@ -159,7 +164,8 @@ func (a Application) Run(
defer a.logger.Info("Closed") defer a.logger.Info("Closed")
for { for {
restart, runErr := a.run(cLoader, closeSigBuilder, handlerBuilder) restart, runErr := a.run(
cLoader, closeSigBuilder, commands, handlerBuilder)
if runErr != nil { if runErr != nil {
a.logger.Error("Unable to start due to error: %s", runErr) a.logger.Error("Unable to start due to error: %s", runErr)

View File

@@ -21,6 +21,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/niruix/sshwifty/application/configuration"
"github.com/niruix/sshwifty/application/log" "github.com/niruix/sshwifty/application/log"
) )
@@ -42,20 +43,35 @@ type Command func(
cfg Configuration, cfg Configuration,
) FSMMachine ) FSMMachine
// Builder builds a command
type Builder struct {
command Command
configurator configuration.Reconfigurator
}
// Register builds a Builder for registration
func Register(c Command, p configuration.Reconfigurator) Builder {
return Builder{
command: c,
configurator: p,
}
}
// Commands contains data of all commands // Commands contains data of all commands
type Commands [MaxCommandID + 1]Command type Commands [MaxCommandID + 1]Builder
// Register registers a new command // Register registers a new command
func (c *Commands) Register(id byte, cb Command) { func (c *Commands) Register(
id byte, cb Command, ps configuration.Reconfigurator) {
if id > MaxCommandID { if id > MaxCommandID {
panic("Command ID must be not greater than MaxCommandID") panic("Command ID must be not greater than MaxCommandID")
} }
if (*c)[id] != nil { if (*c)[id].command != nil {
panic(fmt.Sprintf("Command %d already been registered", id)) panic(fmt.Sprintf("Command %d already been registered", id))
} }
(*c)[id] = cb (*c)[id] = Register(cb, ps)
} }
// Run creates command executer // Run creates command executer
@@ -63,16 +79,32 @@ func (c Commands) Run(
id byte, id byte,
l log.Logger, l log.Logger,
w StreamResponder, w StreamResponder,
cfg Configuration) (FSM, error) { cfg Configuration,
) (FSM, error) {
if id > MaxCommandID { if id > MaxCommandID {
return FSM{}, ErrCommandRunUndefinedCommand return FSM{}, ErrCommandRunUndefinedCommand
} }
cc := c[id] cc := c[id]
if cc == nil { if cc.command == nil {
return FSM{}, ErrCommandRunUndefinedCommand return FSM{}, ErrCommandRunUndefinedCommand
} }
return newFSM(cc(l, w, cfg)), nil return newFSM(cc.command(l, w, cfg)), nil
}
// Reconfigure lets commands reset configuration
func (c Commands) Reconfigure(
p configuration.Configuration,
) configuration.Configuration {
for i := range c {
if c[i].configurator == nil {
continue
}
p = c[i].configurator(p)
}
return p
} }

View File

@@ -171,7 +171,7 @@ func (d *dummyStreamCommand) Release() error {
func TestHandlerHandleStream(t *testing.T) { func TestHandlerHandleStream(t *testing.T) {
cmds := Commands{} cmds := Commands{}
cmds.Register(0, newDummyStreamCommand) cmds.Register(0, newDummyStreamCommand, nil)
readerDataInput := make(chan []byte) readerDataInput := make(chan []byte)

View File

@@ -24,7 +24,7 @@ import (
// New creates a new commands group // New creates a new commands group
func New() command.Commands { func New() command.Commands {
return command.Commands{ return command.Commands{
newTelnet, command.Register(newTelnet, parseTelnetConfig),
newSSH, command.Register(newSSH, parseSSHConfig),
} }
} }

View File

@@ -27,6 +27,7 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/niruix/sshwifty/application/command" "github.com/niruix/sshwifty/application/command"
"github.com/niruix/sshwifty/application/configuration"
"github.com/niruix/sshwifty/application/log" "github.com/niruix/sshwifty/application/log"
"github.com/niruix/sshwifty/application/network" "github.com/niruix/sshwifty/application/network"
"github.com/niruix/sshwifty/application/rw" "github.com/niruix/sshwifty/application/rw"
@@ -107,6 +108,10 @@ var (
sshEmptyTime = time.Time{} sshEmptyTime = time.Time{}
) )
const (
sshDefaultPortString = "22"
)
type sshRemoteConnWrapper struct { type sshRemoteConnWrapper struct {
net.Conn net.Conn
@@ -198,6 +203,30 @@ func newSSH(
} }
} }
func parseSSHConfig(p configuration.Configuration) configuration.Configuration {
for i := range p.Presets {
if p.Presets[i].Type != "SSH" {
continue
}
oldHost := p.Presets[i].Host
_, _, sErr := net.SplitHostPort(p.Presets[i].Host)
if sErr != nil {
p.Presets[i].Host = net.JoinHostPort(
p.Presets[i].Host,
sshDefaultPortString)
}
if len(p.Presets[i].Host) <= 0 {
p.Presets[i].Host = oldHost
}
}
return p
}
func (d *sshClient) Bootup( func (d *sshClient) Bootup(
r *rw.LimitedReader, r *rw.LimitedReader,
b []byte, b []byte,

View File

@@ -24,6 +24,7 @@ import (
"time" "time"
"github.com/niruix/sshwifty/application/command" "github.com/niruix/sshwifty/application/command"
"github.com/niruix/sshwifty/application/configuration"
"github.com/niruix/sshwifty/application/log" "github.com/niruix/sshwifty/application/log"
"github.com/niruix/sshwifty/application/network" "github.com/niruix/sshwifty/application/network"
"github.com/niruix/sshwifty/application/rw" "github.com/niruix/sshwifty/application/rw"
@@ -40,6 +41,10 @@ const (
TelnetRequestErrorBadRemoteAddress = command.StreamError(0x01) TelnetRequestErrorBadRemoteAddress = command.StreamError(0x01)
) )
const (
telnetDefaultPortString = "23"
)
// Server signal codes // Server signal codes
const ( const (
TelnetServerRemoteBand = 0x00 TelnetServerRemoteBand = 0x00
@@ -71,6 +76,32 @@ func newTelnet(
} }
} }
func parseTelnetConfig(
p configuration.Configuration,
) configuration.Configuration {
for i := range p.Presets {
if p.Presets[i].Type != "Telnet" {
continue
}
oldHost := p.Presets[i].Host
_, _, sErr := net.SplitHostPort(p.Presets[i].Host)
if sErr != nil {
p.Presets[i].Host = net.JoinHostPort(
p.Presets[i].Host,
telnetDefaultPortString)
}
if len(p.Presets[i].Host) <= 0 {
p.Presets[i].Host = oldHost
}
}
return p
}
func (d *telnetClient) Bootup( func (d *telnetClient) Bootup(
r *rw.LimitedReader, r *rw.LimitedReader,
b []byte) (command.FSMState, command.FSMError) { b []byte) (command.FSMState, command.FSMError) {

View File

@@ -132,8 +132,10 @@ type Preset struct {
type Configuration struct { type Configuration struct {
HostName string HostName string
SharedKey string SharedKey string
Dialer network.Dial
DialTimeout time.Duration DialTimeout time.Duration
Socks5 string
Socks5User string
Socks5Password string
Servers []Server Servers []Server
Presets []Preset Presets []Preset
OnlyAllowPresetRemotes bool OnlyAllowPresetRemotes bool
@@ -168,12 +170,50 @@ func (c Configuration) Verify() error {
return nil return nil
} }
// Dialer builds a Dialer
func (c Configuration) Dialer() network.Dial {
dialTimeout := c.DialTimeout
if dialTimeout < 3 {
dialTimeout = 3
}
dialer := network.TCPDial()
if len(c.Socks5) > 0 {
sDial, sDialErr := network.BuildSocks5Dial(
c.Socks5, c.Socks5User, c.Socks5Password)
if sDialErr != nil {
panic("Unable to build Socks5 Dialer: " + sDialErr.Error())
}
dialer = sDial
}
if c.OnlyAllowPresetRemotes {
accessList := make(network.AllowedHosts, len(c.Presets))
for _, k := range c.Presets {
if len(k.Host) <= 0 {
continue
}
accessList[k.Host] = struct{}{}
}
dialer = network.AccessControlDial(accessList, dialer)
}
return dialer
}
// Common returns common settings // Common returns common settings
func (c Configuration) Common() Common { func (c Configuration) Common() Common {
return Common{ return Common{
HostName: c.HostName, HostName: c.HostName,
SharedKey: c.SharedKey, SharedKey: c.SharedKey,
Dialer: c.Dialer, Dialer: c.Dialer(),
DialTimeout: c.DialTimeout, DialTimeout: c.DialTimeout,
Presets: c.Presets, Presets: c.Presets,
OnlyAllowPresetRemotes: c.OnlyAllowPresetRemotes, OnlyAllowPresetRemotes: c.OnlyAllowPresetRemotes,

View File

@@ -21,5 +21,11 @@ import (
"github.com/niruix/sshwifty/application/log" "github.com/niruix/sshwifty/application/log"
) )
// Reconfigurator reloads configuration
type Reconfigurator func(p Configuration) Configuration
// Loader Configuration loader // Loader Configuration loader
type Loader func(log log.Logger) (name string, cfg Configuration, err error) type Loader func(
log log.Logger,
r Reconfigurator,
) (name string, cfg Configuration, err error)

View File

@@ -28,7 +28,10 @@ const (
// Direct creates a loader that return raw configuration data directly. // Direct creates a loader that return raw configuration data directly.
// Good for integration. // Good for integration.
func Direct(cfg Configuration) Loader { func Direct(cfg Configuration) Loader {
return func(log log.Logger) (string, Configuration, error) { return func(
log log.Logger,
r Reconfigurator,
) (string, Configuration, error) {
return directTypeName, cfg, nil return directTypeName, cfg, nil
} }
} }

View File

@@ -44,13 +44,16 @@ func parseEviro(name string) string {
// Enviro creates an environment variable based configuration loader // Enviro creates an environment variable based configuration loader
func Enviro() Loader { func Enviro() Loader {
return func(log log.Logger) (string, Configuration, error) { return func(
log log.Logger,
r Reconfigurator,
) (string, Configuration, error) {
log.Info("Loading configuration from environment variables ...") log.Info("Loading configuration from environment variables ...")
dialTimeout, _ := strconv.ParseUint( dialTimeout, _ := strconv.ParseUint(
parseEviro("SSHWIFTY_DIALTIMEOUT"), 10, 32) parseEviro("SSHWIFTY_DIALTIMEOUT"), 10, 32)
cfg, dialer, cfgErr := fileCfgCommon{ cfg, cfgErr := fileCfgCommon{
HostName: parseEviro("SSHWIFTY_HOSTNAME"), HostName: parseEviro("SSHWIFTY_HOSTNAME"),
SharedKey: parseEviro("SSHWIFTY_SHAREDKEY"), SharedKey: parseEviro("SSHWIFTY_SHAREDKEY"),
DialTimeout: int(dialTimeout), DialTimeout: int(dialTimeout),
@@ -68,14 +71,15 @@ func Enviro() Loader {
"Failed to build the configuration: %s", cfgErr) "Failed to build the configuration: %s", cfgErr)
} }
listenPort, listenPortErr := strconv.ParseUint( listenIface := parseEviro("SSHWIFTY_LISTENINTERFACE")
parseEviro("SSHWIFTY_LISTENPORT"), 10, 16)
if listenPortErr != nil { if len(listenIface) <= 0 {
return enviroTypeName, Configuration{}, fmt.Errorf( listenIface = "127.0.0.1"
"Invalid \"SSHWIFTY_LISTENPORT\": %s", listenPortErr)
} }
listenPort, _ := strconv.ParseUint(
parseEviro("SSHWIFTY_LISTENPORT"), 10, 16)
initialTimeout, _ := strconv.ParseUint( initialTimeout, _ := strconv.ParseUint(
parseEviro("SSHWIFTY_INITIALTIMEOUT"), 10, 32) parseEviro("SSHWIFTY_INITIALTIMEOUT"), 10, 32)
@@ -95,7 +99,7 @@ func Enviro() Loader {
parseEviro("SSHWIFTY_WRITEELAY"), 10, 32) parseEviro("SSHWIFTY_WRITEELAY"), 10, 32)
cfgSer := fileCfgServer{ cfgSer := fileCfgServer{
ListenInterface: parseEviro("SSHWIFTY_LISTENINTERFACE"), ListenInterface: listenIface,
ListenPort: uint16(listenPort), ListenPort: uint16(listenPort),
InitialTimeout: int(initialTimeout), InitialTimeout: int(initialTimeout),
ReadTimeout: int(readTimeout), ReadTimeout: int(readTimeout),
@@ -122,8 +126,10 @@ func Enviro() Loader {
return enviroTypeName, Configuration{ return enviroTypeName, Configuration{
HostName: cfg.HostName, HostName: cfg.HostName,
SharedKey: cfg.SharedKey, SharedKey: cfg.SharedKey,
Dialer: dialer,
DialTimeout: time.Duration(cfg.DialTimeout) * time.Second, DialTimeout: time.Duration(cfg.DialTimeout) * time.Second,
Socks5: cfg.Socks5,
Socks5User: cfg.Socks5User,
Socks5Password: cfg.Socks5Password,
Servers: []Server{cfgSer.build()}, Servers: []Server{cfgSer.build()},
Presets: presets, Presets: presets,
OnlyAllowPresetRemotes: cfg.OnlyAllowPresetRemotes, OnlyAllowPresetRemotes: cfg.OnlyAllowPresetRemotes,

View File

@@ -28,7 +28,6 @@ import (
"time" "time"
"github.com/niruix/sshwifty/application/log" "github.com/niruix/sshwifty/application/log"
"github.com/niruix/sshwifty/application/network"
) )
const ( const (
@@ -122,56 +121,24 @@ type fileCfgCommon struct {
OnlyAllowPresetRemotes bool OnlyAllowPresetRemotes bool
} }
func (f fileCfgCommon) build() (fileCfgCommon, network.Dial, error) { func (f fileCfgCommon) build() (fileCfgCommon, error) {
dialTimeout := f.DialTimeout
if dialTimeout < 3 {
dialTimeout = 3
}
var dialer network.Dial
if len(f.Socks5) <= 0 {
dialer = network.TCPDial()
} else {
sDial, sDialErr := network.BuildSocks5Dial(
f.Socks5, f.Socks5User, f.Socks5Password)
if sDialErr != nil {
return fileCfgCommon{}, nil, sDialErr
}
dialer = sDial
}
if f.OnlyAllowPresetRemotes {
accessList := make(network.AllowedHosts, len(f.Presets))
for _, k := range f.Presets {
if len(k.Host) <= 0 {
continue
}
accessList[k.Host] = struct{}{}
}
dialer = network.AccessControlDial(accessList, dialer)
}
return fileCfgCommon{ return fileCfgCommon{
HostName: f.HostName, HostName: f.HostName,
SharedKey: f.SharedKey, SharedKey: f.SharedKey,
DialTimeout: dialTimeout, DialTimeout: f.DialTimeout,
Socks5: f.Socks5, Socks5: f.Socks5,
Socks5User: f.Socks5User, Socks5User: f.Socks5User,
Socks5Password: f.Socks5Password, Socks5Password: f.Socks5Password,
Servers: f.Servers, Servers: f.Servers,
Presets: f.Presets, Presets: f.Presets,
OnlyAllowPresetRemotes: f.OnlyAllowPresetRemotes, OnlyAllowPresetRemotes: f.OnlyAllowPresetRemotes,
}, dialer, nil }, nil
} }
func loadFile(filePath string) (string, Configuration, error) { func loadFile(
filePath string,
r Reconfigurator,
) (string, Configuration, error) {
f, fErr := os.Open(filePath) f, fErr := os.Open(filePath)
if fErr != nil { if fErr != nil {
@@ -189,7 +156,7 @@ func loadFile(filePath string) (string, Configuration, error) {
return fileTypeName, Configuration{}, jDecodeErr return fileTypeName, Configuration{}, jDecodeErr
} }
finalCfg, dialer, cfgErr := cfg.build() finalCfg, cfgErr := cfg.build()
if cfgErr != nil { if cfgErr != nil {
return fileTypeName, Configuration{}, cfgErr return fileTypeName, Configuration{}, cfgErr
@@ -207,25 +174,30 @@ func loadFile(filePath string) (string, Configuration, error) {
presets[i] = finalCfg.Presets[i].build() presets[i] = finalCfg.Presets[i].build()
} }
return fileTypeName, Configuration{ return fileTypeName, r(Configuration{
HostName: finalCfg.HostName, HostName: finalCfg.HostName,
SharedKey: finalCfg.SharedKey, SharedKey: finalCfg.SharedKey,
Dialer: dialer,
DialTimeout: time.Duration(finalCfg.DialTimeout) * DialTimeout: time.Duration(finalCfg.DialTimeout) *
time.Second, time.Second,
Socks5: cfg.Socks5,
Socks5User: cfg.Socks5User,
Socks5Password: cfg.Socks5Password,
Servers: servers, Servers: servers,
Presets: presets, Presets: presets,
OnlyAllowPresetRemotes: cfg.OnlyAllowPresetRemotes, OnlyAllowPresetRemotes: cfg.OnlyAllowPresetRemotes,
}, nil }), nil
} }
// File creates a configuration file loader // File creates a configuration file loader
func File(customPath string) Loader { func File(customPath string) Loader {
return func(log log.Logger) (string, Configuration, error) { return func(
log log.Logger,
r Reconfigurator,
) (string, Configuration, error) {
if len(customPath) > 0 { if len(customPath) > 0 {
log.Info("Loading configuration from: %s", customPath) log.Info("Loading configuration from: %s", customPath)
return loadFile(customPath) return loadFile(customPath, r)
} }
log.Info("Loading configuration from one of the default " + log.Info("Loading configuration from one of the default " +
@@ -267,7 +239,7 @@ func File(customPath string) Loader {
log.Info("Configuration file \"%s\" has been selected", log.Info("Configuration file \"%s\" has been selected",
fallbackFileSearchList[f]) fallbackFileSearchList[f])
return loadFile(fallbackFileSearchList[f]) return loadFile(fallbackFileSearchList[f], r)
} }
return fileTypeName, Configuration{}, fmt.Errorf( return fileTypeName, Configuration{}, fmt.Errorf(

View File

@@ -30,11 +30,14 @@ const (
// Redundant creates a group of loaders. They will be executed one by one until // Redundant creates a group of loaders. They will be executed one by one until
// one of it successfully returned a configuration // one of it successfully returned a configuration
func Redundant(loaders ...Loader) Loader { func Redundant(loaders ...Loader) Loader {
return func(log log.Logger) (string, Configuration, error) { return func(
log log.Logger,
r Reconfigurator,
) (string, Configuration, error) {
ll := log.Context("Redundant") ll := log.Context("Redundant")
for i := range loaders { for i := range loaders {
lLoaderName, lCfg, lErr := loaders[i](ll) lLoaderName, lCfg, lErr := loaders[i](ll, r)
if lErr != nil { if lErr != nil {
ll.Warning("Unable to load configuration from \"%s\": %s", ll.Warning("Unable to load configuration from \"%s\": %s",

View File

@@ -28,6 +28,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/niruix/sshwifty/application/command"
"github.com/niruix/sshwifty/application/configuration" "github.com/niruix/sshwifty/application/configuration"
"github.com/niruix/sshwifty/application/log" "github.com/niruix/sshwifty/application/log"
) )
@@ -50,6 +51,9 @@ type HandlerBuilder func(
cfg configuration.Server, cfg configuration.Server,
logger log.Logger) http.Handler logger log.Logger) http.Handler
// HandlerBuilderBuilder builds HandlerBuilder
type HandlerBuilderBuilder func(command.Commands) HandlerBuilder
// CloseCallback will be called when the server has closed // CloseCallback will be called when the server has closed
type CloseCallback func(error) type CloseCallback func(error)

View File

@@ -43,7 +43,8 @@ func main() {
len(os.Getenv("SSHWIFTY_DEBUG")) > 0, application.Name, os.Stderr)). len(os.Getenv("SSHWIFTY_DEBUG")) > 0, application.Name, os.Stderr)).
Run(configuration.Redundant(configLoaders...), Run(configuration.Redundant(configLoaders...),
application.DefaultProccessSignallerBuilder, application.DefaultProccessSignallerBuilder,
controller.Builder(commands.New())) commands.New(),
controller.Builder)
if e == nil { if e == nil {
return return