Allowing Commands to reconfigure app Configuration, this is will allow default ports to be currently added to remote presets
This commit is contained in:
@@ -26,6 +26,7 @@ import (
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/niruix/sshwifty/application/command"
|
||||
"github.com/niruix/sshwifty/application/configuration"
|
||||
"github.com/niruix/sshwifty/application/log"
|
||||
"github.com/niruix/sshwifty/application/server"
|
||||
@@ -66,11 +67,14 @@ func New(screen io.Writer, logger log.Logger) Application {
|
||||
func (a Application) run(
|
||||
cLoader configuration.Loader,
|
||||
closeSigBuilder ProccessSignallerBuilder,
|
||||
handlerBuilder server.HandlerBuilder,
|
||||
commands command.Commands,
|
||||
handlerBuilder server.HandlerBuilderBuilder,
|
||||
) (bool, error) {
|
||||
var err error
|
||||
|
||||
loaderName, c, cErr := cLoader(a.logger.Context("Configuration"))
|
||||
loaderName, c, cErr := cLoader(
|
||||
a.logger.Context("Configuration"),
|
||||
commands.Reconfigure)
|
||||
|
||||
if cErr != nil {
|
||||
a.logger.Error("\"%s\" loader cannot load configuration: %s",
|
||||
@@ -117,7 +121,7 @@ func (a Application) run(
|
||||
|
||||
close(closeNotify)
|
||||
closeNotify = nil
|
||||
}, handlerBuilder)
|
||||
}, handlerBuilder(commands))
|
||||
|
||||
servers = append(servers, newServer)
|
||||
}
|
||||
@@ -148,7 +152,8 @@ func (a Application) run(
|
||||
func (a Application) Run(
|
||||
cLoader configuration.Loader,
|
||||
closeSigBuilder ProccessSignallerBuilder,
|
||||
handlerBuilder server.HandlerBuilder,
|
||||
commands command.Commands,
|
||||
handlerBuilder server.HandlerBuilderBuilder,
|
||||
) error {
|
||||
fmt.Fprintf(a.screen, banner, FullName, version, Author, URL)
|
||||
|
||||
@@ -159,7 +164,8 @@ func (a Application) Run(
|
||||
defer a.logger.Info("Closed")
|
||||
|
||||
for {
|
||||
restart, runErr := a.run(cLoader, closeSigBuilder, handlerBuilder)
|
||||
restart, runErr := a.run(
|
||||
cLoader, closeSigBuilder, commands, handlerBuilder)
|
||||
|
||||
if runErr != nil {
|
||||
a.logger.Error("Unable to start due to error: %s", runErr)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/niruix/sshwifty/application/configuration"
|
||||
"github.com/niruix/sshwifty/application/log"
|
||||
)
|
||||
|
||||
@@ -42,20 +43,35 @@ type Command func(
|
||||
cfg Configuration,
|
||||
) FSMMachine
|
||||
|
||||
// Builder builds a command
|
||||
type Builder struct {
|
||||
command Command
|
||||
configurator configuration.Reconfigurator
|
||||
}
|
||||
|
||||
// Register builds a Builder for registration
|
||||
func Register(c Command, p configuration.Reconfigurator) Builder {
|
||||
return Builder{
|
||||
command: c,
|
||||
configurator: p,
|
||||
}
|
||||
}
|
||||
|
||||
// Commands contains data of all commands
|
||||
type Commands [MaxCommandID + 1]Command
|
||||
type Commands [MaxCommandID + 1]Builder
|
||||
|
||||
// Register registers a new command
|
||||
func (c *Commands) Register(id byte, cb Command) {
|
||||
func (c *Commands) Register(
|
||||
id byte, cb Command, ps configuration.Reconfigurator) {
|
||||
if id > MaxCommandID {
|
||||
panic("Command ID must be not greater than MaxCommandID")
|
||||
}
|
||||
|
||||
if (*c)[id] != nil {
|
||||
if (*c)[id].command != nil {
|
||||
panic(fmt.Sprintf("Command %d already been registered", id))
|
||||
}
|
||||
|
||||
(*c)[id] = cb
|
||||
(*c)[id] = Register(cb, ps)
|
||||
}
|
||||
|
||||
// Run creates command executer
|
||||
@@ -63,16 +79,32 @@ func (c Commands) Run(
|
||||
id byte,
|
||||
l log.Logger,
|
||||
w StreamResponder,
|
||||
cfg Configuration) (FSM, error) {
|
||||
cfg Configuration,
|
||||
) (FSM, error) {
|
||||
if id > MaxCommandID {
|
||||
return FSM{}, ErrCommandRunUndefinedCommand
|
||||
}
|
||||
|
||||
cc := c[id]
|
||||
|
||||
if cc == nil {
|
||||
if cc.command == nil {
|
||||
return FSM{}, ErrCommandRunUndefinedCommand
|
||||
}
|
||||
|
||||
return newFSM(cc(l, w, cfg)), nil
|
||||
return newFSM(cc.command(l, w, cfg)), nil
|
||||
}
|
||||
|
||||
// Reconfigure lets commands reset configuration
|
||||
func (c Commands) Reconfigure(
|
||||
p configuration.Configuration,
|
||||
) configuration.Configuration {
|
||||
for i := range c {
|
||||
if c[i].configurator == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
p = c[i].configurator(p)
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -171,7 +171,7 @@ func (d *dummyStreamCommand) Release() error {
|
||||
|
||||
func TestHandlerHandleStream(t *testing.T) {
|
||||
cmds := Commands{}
|
||||
cmds.Register(0, newDummyStreamCommand)
|
||||
cmds.Register(0, newDummyStreamCommand, nil)
|
||||
|
||||
readerDataInput := make(chan []byte)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
// New creates a new commands group
|
||||
func New() command.Commands {
|
||||
return command.Commands{
|
||||
newTelnet,
|
||||
newSSH,
|
||||
command.Register(newTelnet, parseTelnetConfig),
|
||||
command.Register(newSSH, parseSSHConfig),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/niruix/sshwifty/application/command"
|
||||
"github.com/niruix/sshwifty/application/configuration"
|
||||
"github.com/niruix/sshwifty/application/log"
|
||||
"github.com/niruix/sshwifty/application/network"
|
||||
"github.com/niruix/sshwifty/application/rw"
|
||||
@@ -107,6 +108,10 @@ var (
|
||||
sshEmptyTime = time.Time{}
|
||||
)
|
||||
|
||||
const (
|
||||
sshDefaultPortString = "22"
|
||||
)
|
||||
|
||||
type sshRemoteConnWrapper struct {
|
||||
net.Conn
|
||||
|
||||
@@ -198,6 +203,30 @@ func newSSH(
|
||||
}
|
||||
}
|
||||
|
||||
func parseSSHConfig(p configuration.Configuration) configuration.Configuration {
|
||||
for i := range p.Presets {
|
||||
if p.Presets[i].Type != "SSH" {
|
||||
continue
|
||||
}
|
||||
|
||||
oldHost := p.Presets[i].Host
|
||||
|
||||
_, _, sErr := net.SplitHostPort(p.Presets[i].Host)
|
||||
|
||||
if sErr != nil {
|
||||
p.Presets[i].Host = net.JoinHostPort(
|
||||
p.Presets[i].Host,
|
||||
sshDefaultPortString)
|
||||
}
|
||||
|
||||
if len(p.Presets[i].Host) <= 0 {
|
||||
p.Presets[i].Host = oldHost
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (d *sshClient) Bootup(
|
||||
r *rw.LimitedReader,
|
||||
b []byte,
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/niruix/sshwifty/application/command"
|
||||
"github.com/niruix/sshwifty/application/configuration"
|
||||
"github.com/niruix/sshwifty/application/log"
|
||||
"github.com/niruix/sshwifty/application/network"
|
||||
"github.com/niruix/sshwifty/application/rw"
|
||||
@@ -40,6 +41,10 @@ const (
|
||||
TelnetRequestErrorBadRemoteAddress = command.StreamError(0x01)
|
||||
)
|
||||
|
||||
const (
|
||||
telnetDefaultPortString = "23"
|
||||
)
|
||||
|
||||
// Server signal codes
|
||||
const (
|
||||
TelnetServerRemoteBand = 0x00
|
||||
@@ -71,6 +76,32 @@ func newTelnet(
|
||||
}
|
||||
}
|
||||
|
||||
func parseTelnetConfig(
|
||||
p configuration.Configuration,
|
||||
) configuration.Configuration {
|
||||
for i := range p.Presets {
|
||||
if p.Presets[i].Type != "Telnet" {
|
||||
continue
|
||||
}
|
||||
|
||||
oldHost := p.Presets[i].Host
|
||||
|
||||
_, _, sErr := net.SplitHostPort(p.Presets[i].Host)
|
||||
|
||||
if sErr != nil {
|
||||
p.Presets[i].Host = net.JoinHostPort(
|
||||
p.Presets[i].Host,
|
||||
telnetDefaultPortString)
|
||||
}
|
||||
|
||||
if len(p.Presets[i].Host) <= 0 {
|
||||
p.Presets[i].Host = oldHost
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (d *telnetClient) Bootup(
|
||||
r *rw.LimitedReader,
|
||||
b []byte) (command.FSMState, command.FSMError) {
|
||||
|
||||
@@ -132,8 +132,10 @@ type Preset struct {
|
||||
type Configuration struct {
|
||||
HostName string
|
||||
SharedKey string
|
||||
Dialer network.Dial
|
||||
DialTimeout time.Duration
|
||||
Socks5 string
|
||||
Socks5User string
|
||||
Socks5Password string
|
||||
Servers []Server
|
||||
Presets []Preset
|
||||
OnlyAllowPresetRemotes bool
|
||||
@@ -168,12 +170,50 @@ func (c Configuration) Verify() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dialer builds a Dialer
|
||||
func (c Configuration) Dialer() network.Dial {
|
||||
dialTimeout := c.DialTimeout
|
||||
|
||||
if dialTimeout < 3 {
|
||||
dialTimeout = 3
|
||||
}
|
||||
|
||||
dialer := network.TCPDial()
|
||||
|
||||
if len(c.Socks5) > 0 {
|
||||
sDial, sDialErr := network.BuildSocks5Dial(
|
||||
c.Socks5, c.Socks5User, c.Socks5Password)
|
||||
|
||||
if sDialErr != nil {
|
||||
panic("Unable to build Socks5 Dialer: " + sDialErr.Error())
|
||||
}
|
||||
|
||||
dialer = sDial
|
||||
}
|
||||
|
||||
if c.OnlyAllowPresetRemotes {
|
||||
accessList := make(network.AllowedHosts, len(c.Presets))
|
||||
|
||||
for _, k := range c.Presets {
|
||||
if len(k.Host) <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
accessList[k.Host] = struct{}{}
|
||||
}
|
||||
|
||||
dialer = network.AccessControlDial(accessList, dialer)
|
||||
}
|
||||
|
||||
return dialer
|
||||
}
|
||||
|
||||
// Common returns common settings
|
||||
func (c Configuration) Common() Common {
|
||||
return Common{
|
||||
HostName: c.HostName,
|
||||
SharedKey: c.SharedKey,
|
||||
Dialer: c.Dialer,
|
||||
Dialer: c.Dialer(),
|
||||
DialTimeout: c.DialTimeout,
|
||||
Presets: c.Presets,
|
||||
OnlyAllowPresetRemotes: c.OnlyAllowPresetRemotes,
|
||||
|
||||
@@ -21,5 +21,11 @@ import (
|
||||
"github.com/niruix/sshwifty/application/log"
|
||||
)
|
||||
|
||||
// Reconfigurator reloads configuration
|
||||
type Reconfigurator func(p Configuration) Configuration
|
||||
|
||||
// Loader Configuration loader
|
||||
type Loader func(log log.Logger) (name string, cfg Configuration, err error)
|
||||
type Loader func(
|
||||
log log.Logger,
|
||||
r Reconfigurator,
|
||||
) (name string, cfg Configuration, err error)
|
||||
|
||||
@@ -28,7 +28,10 @@ const (
|
||||
// Direct creates a loader that return raw configuration data directly.
|
||||
// Good for integration.
|
||||
func Direct(cfg Configuration) Loader {
|
||||
return func(log log.Logger) (string, Configuration, error) {
|
||||
return func(
|
||||
log log.Logger,
|
||||
r Reconfigurator,
|
||||
) (string, Configuration, error) {
|
||||
return directTypeName, cfg, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,13 +44,16 @@ func parseEviro(name string) string {
|
||||
|
||||
// Enviro creates an environment variable based configuration loader
|
||||
func Enviro() Loader {
|
||||
return func(log log.Logger) (string, Configuration, error) {
|
||||
return func(
|
||||
log log.Logger,
|
||||
r Reconfigurator,
|
||||
) (string, Configuration, error) {
|
||||
log.Info("Loading configuration from environment variables ...")
|
||||
|
||||
dialTimeout, _ := strconv.ParseUint(
|
||||
parseEviro("SSHWIFTY_DIALTIMEOUT"), 10, 32)
|
||||
|
||||
cfg, dialer, cfgErr := fileCfgCommon{
|
||||
cfg, cfgErr := fileCfgCommon{
|
||||
HostName: parseEviro("SSHWIFTY_HOSTNAME"),
|
||||
SharedKey: parseEviro("SSHWIFTY_SHAREDKEY"),
|
||||
DialTimeout: int(dialTimeout),
|
||||
@@ -68,14 +71,15 @@ func Enviro() Loader {
|
||||
"Failed to build the configuration: %s", cfgErr)
|
||||
}
|
||||
|
||||
listenPort, listenPortErr := strconv.ParseUint(
|
||||
parseEviro("SSHWIFTY_LISTENPORT"), 10, 16)
|
||||
listenIface := parseEviro("SSHWIFTY_LISTENINTERFACE")
|
||||
|
||||
if listenPortErr != nil {
|
||||
return enviroTypeName, Configuration{}, fmt.Errorf(
|
||||
"Invalid \"SSHWIFTY_LISTENPORT\": %s", listenPortErr)
|
||||
if len(listenIface) <= 0 {
|
||||
listenIface = "127.0.0.1"
|
||||
}
|
||||
|
||||
listenPort, _ := strconv.ParseUint(
|
||||
parseEviro("SSHWIFTY_LISTENPORT"), 10, 16)
|
||||
|
||||
initialTimeout, _ := strconv.ParseUint(
|
||||
parseEviro("SSHWIFTY_INITIALTIMEOUT"), 10, 32)
|
||||
|
||||
@@ -95,7 +99,7 @@ func Enviro() Loader {
|
||||
parseEviro("SSHWIFTY_WRITEELAY"), 10, 32)
|
||||
|
||||
cfgSer := fileCfgServer{
|
||||
ListenInterface: parseEviro("SSHWIFTY_LISTENINTERFACE"),
|
||||
ListenInterface: listenIface,
|
||||
ListenPort: uint16(listenPort),
|
||||
InitialTimeout: int(initialTimeout),
|
||||
ReadTimeout: int(readTimeout),
|
||||
@@ -122,8 +126,10 @@ func Enviro() Loader {
|
||||
return enviroTypeName, Configuration{
|
||||
HostName: cfg.HostName,
|
||||
SharedKey: cfg.SharedKey,
|
||||
Dialer: dialer,
|
||||
DialTimeout: time.Duration(cfg.DialTimeout) * time.Second,
|
||||
Socks5: cfg.Socks5,
|
||||
Socks5User: cfg.Socks5User,
|
||||
Socks5Password: cfg.Socks5Password,
|
||||
Servers: []Server{cfgSer.build()},
|
||||
Presets: presets,
|
||||
OnlyAllowPresetRemotes: cfg.OnlyAllowPresetRemotes,
|
||||
|
||||
@@ -28,7 +28,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/niruix/sshwifty/application/log"
|
||||
"github.com/niruix/sshwifty/application/network"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -122,56 +121,24 @@ type fileCfgCommon struct {
|
||||
OnlyAllowPresetRemotes bool
|
||||
}
|
||||
|
||||
func (f fileCfgCommon) build() (fileCfgCommon, network.Dial, error) {
|
||||
dialTimeout := f.DialTimeout
|
||||
|
||||
if dialTimeout < 3 {
|
||||
dialTimeout = 3
|
||||
}
|
||||
|
||||
var dialer network.Dial
|
||||
|
||||
if len(f.Socks5) <= 0 {
|
||||
dialer = network.TCPDial()
|
||||
} else {
|
||||
sDial, sDialErr := network.BuildSocks5Dial(
|
||||
f.Socks5, f.Socks5User, f.Socks5Password)
|
||||
|
||||
if sDialErr != nil {
|
||||
return fileCfgCommon{}, nil, sDialErr
|
||||
}
|
||||
|
||||
dialer = sDial
|
||||
}
|
||||
|
||||
if f.OnlyAllowPresetRemotes {
|
||||
accessList := make(network.AllowedHosts, len(f.Presets))
|
||||
|
||||
for _, k := range f.Presets {
|
||||
if len(k.Host) <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
accessList[k.Host] = struct{}{}
|
||||
}
|
||||
|
||||
dialer = network.AccessControlDial(accessList, dialer)
|
||||
}
|
||||
|
||||
func (f fileCfgCommon) build() (fileCfgCommon, error) {
|
||||
return fileCfgCommon{
|
||||
HostName: f.HostName,
|
||||
SharedKey: f.SharedKey,
|
||||
DialTimeout: dialTimeout,
|
||||
DialTimeout: f.DialTimeout,
|
||||
Socks5: f.Socks5,
|
||||
Socks5User: f.Socks5User,
|
||||
Socks5Password: f.Socks5Password,
|
||||
Servers: f.Servers,
|
||||
Presets: f.Presets,
|
||||
OnlyAllowPresetRemotes: f.OnlyAllowPresetRemotes,
|
||||
}, dialer, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func loadFile(filePath string) (string, Configuration, error) {
|
||||
func loadFile(
|
||||
filePath string,
|
||||
r Reconfigurator,
|
||||
) (string, Configuration, error) {
|
||||
f, fErr := os.Open(filePath)
|
||||
|
||||
if fErr != nil {
|
||||
@@ -189,7 +156,7 @@ func loadFile(filePath string) (string, Configuration, error) {
|
||||
return fileTypeName, Configuration{}, jDecodeErr
|
||||
}
|
||||
|
||||
finalCfg, dialer, cfgErr := cfg.build()
|
||||
finalCfg, cfgErr := cfg.build()
|
||||
|
||||
if cfgErr != nil {
|
||||
return fileTypeName, Configuration{}, cfgErr
|
||||
@@ -207,25 +174,30 @@ func loadFile(filePath string) (string, Configuration, error) {
|
||||
presets[i] = finalCfg.Presets[i].build()
|
||||
}
|
||||
|
||||
return fileTypeName, Configuration{
|
||||
return fileTypeName, r(Configuration{
|
||||
HostName: finalCfg.HostName,
|
||||
SharedKey: finalCfg.SharedKey,
|
||||
Dialer: dialer,
|
||||
DialTimeout: time.Duration(finalCfg.DialTimeout) *
|
||||
time.Second,
|
||||
Socks5: cfg.Socks5,
|
||||
Socks5User: cfg.Socks5User,
|
||||
Socks5Password: cfg.Socks5Password,
|
||||
Servers: servers,
|
||||
Presets: presets,
|
||||
OnlyAllowPresetRemotes: cfg.OnlyAllowPresetRemotes,
|
||||
}, nil
|
||||
}), nil
|
||||
}
|
||||
|
||||
// File creates a configuration file loader
|
||||
func File(customPath string) Loader {
|
||||
return func(log log.Logger) (string, Configuration, error) {
|
||||
return func(
|
||||
log log.Logger,
|
||||
r Reconfigurator,
|
||||
) (string, Configuration, error) {
|
||||
if len(customPath) > 0 {
|
||||
log.Info("Loading configuration from: %s", customPath)
|
||||
|
||||
return loadFile(customPath)
|
||||
return loadFile(customPath, r)
|
||||
}
|
||||
|
||||
log.Info("Loading configuration from one of the default " +
|
||||
@@ -267,7 +239,7 @@ func File(customPath string) Loader {
|
||||
log.Info("Configuration file \"%s\" has been selected",
|
||||
fallbackFileSearchList[f])
|
||||
|
||||
return loadFile(fallbackFileSearchList[f])
|
||||
return loadFile(fallbackFileSearchList[f], r)
|
||||
}
|
||||
|
||||
return fileTypeName, Configuration{}, fmt.Errorf(
|
||||
|
||||
@@ -30,11 +30,14 @@ const (
|
||||
// Redundant creates a group of loaders. They will be executed one by one until
|
||||
// one of it successfully returned a configuration
|
||||
func Redundant(loaders ...Loader) Loader {
|
||||
return func(log log.Logger) (string, Configuration, error) {
|
||||
return func(
|
||||
log log.Logger,
|
||||
r Reconfigurator,
|
||||
) (string, Configuration, error) {
|
||||
ll := log.Context("Redundant")
|
||||
|
||||
for i := range loaders {
|
||||
lLoaderName, lCfg, lErr := loaders[i](ll)
|
||||
lLoaderName, lCfg, lErr := loaders[i](ll, r)
|
||||
|
||||
if lErr != nil {
|
||||
ll.Warning("Unable to load configuration from \"%s\": %s",
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/niruix/sshwifty/application/command"
|
||||
"github.com/niruix/sshwifty/application/configuration"
|
||||
"github.com/niruix/sshwifty/application/log"
|
||||
)
|
||||
@@ -50,6 +51,9 @@ type HandlerBuilder func(
|
||||
cfg configuration.Server,
|
||||
logger log.Logger) http.Handler
|
||||
|
||||
// HandlerBuilderBuilder builds HandlerBuilder
|
||||
type HandlerBuilderBuilder func(command.Commands) HandlerBuilder
|
||||
|
||||
// CloseCallback will be called when the server has closed
|
||||
type CloseCallback func(error)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user