diff --git a/application/network/conn.go b/application/network/conn.go new file mode 100644 index 0000000..294152c --- /dev/null +++ b/application/network/conn.go @@ -0,0 +1,26 @@ +// Sshwifty - A Web SSH client +// +// Copyright (C) 2019 Rui NI +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package network + +import ( + "time" +) + +var ( + emptyTime = time.Time{} +) diff --git a/application/network/conn_timeout.go b/application/network/conn_timeout.go new file mode 100644 index 0000000..2542c60 --- /dev/null +++ b/application/network/conn_timeout.go @@ -0,0 +1,221 @@ +// Sshwifty - A Web SSH client +// +// Copyright (C) 2019 Rui NI +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package network + +import ( + "net" + "time" +) + +// TimeoutConn read write +type TimeoutConn struct { + net.Conn + + readTimeout time.Duration + disableNextReadTimeout bool + writeTimeout time.Duration + disableNextWriteTimeout bool +} + +// NewTimeoutConn creates a new TimeoutConn +func NewTimeoutConn( + c net.Conn, + rTimeout time.Duration, + wTimeout time.Duration, +) TimeoutConn { + return TimeoutConn{ + Conn: c, + readTimeout: rTimeout, + disableNextReadTimeout: false, + writeTimeout: wTimeout, + disableNextWriteTimeout: false, + } +} + +// SetReadTimeout sets read timeout +func (c *TimeoutConn) SetReadTimeout(t time.Duration) { + c.readTimeout = t +} + +// SetReadDeadline sets the next read deadline +func (c *TimeoutConn) SetReadDeadline(t time.Time) error { + c.disableNextReadTimeout = t.Before(time.Now()) + + if t.Equal(emptyTime) { + return c.Conn.SetReadDeadline(time.Now().Add(c.readTimeout)) + } + + return c.Conn.SetReadDeadline(t) +} + +// Read reads data +func (c *TimeoutConn) Read(b []byte) (int, error) { + defer func() { + c.disableNextReadTimeout = false + }() + + cLen, cErr := c.Conn.Read(b) + + if cErr == nil { + return cLen, nil + } + + netErr, isNetErr := cErr.(net.Error) + + if !isNetErr || + c.disableNextReadTimeout || + c.readTimeout <= 0 || + !netErr.Timeout() { + return cLen, cErr + } + + cErr = c.Conn.SetReadDeadline(time.Now().Add(c.readTimeout)) + + if cErr != nil { + return cLen, cErr + } + + tryCLen, cErr := c.Conn.Read(b[cLen:]) + + return tryCLen + cLen, cErr +} + +// SetWriteTimeout sets write timeout +func (c *TimeoutConn) SetWriteTimeout(t time.Duration) { + c.writeTimeout = t +} + +// SetWriteDeadline sets the next read deadline +func (c *TimeoutConn) SetWriteDeadline(t time.Time) error { + c.disableNextWriteTimeout = t.Before(time.Now()) + + if t.Equal(emptyTime) { + return c.Conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + } + + return c.Conn.SetWriteDeadline(t) +} + +// Write writes data +func (c *TimeoutConn) Write(b []byte) (int, error) { + defer func() { + c.disableNextWriteTimeout = false + }() + + cLen, cErr := c.Conn.Write(b) + + if cErr == nil { + return cLen, nil + } + + netErr, isNetErr := cErr.(net.Error) + + if !isNetErr || + c.disableNextWriteTimeout || + c.writeTimeout <= 0 || + !netErr.Timeout() { + return cLen, cErr + } + + cErr = c.Conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + + if cErr != nil { + return cLen, cErr + } + + tryCLen, cErr := c.Conn.Write(b[cLen:]) + + return tryCLen + cLen, cErr +} + +// SetDeadline sets read and write deadline +func (c *TimeoutConn) SetDeadline(t time.Time) error { + c.SetReadDeadline(t) + c.SetWriteDeadline(t) + + return nil +} + +// ReadTimeoutConn is a reader that will enforce a timeout rules +type ReadTimeoutConn struct { + net.Conn + + reader TimeoutConn +} + +// NewReadTimeoutConn creates a ReadTimeoutConn +func NewReadTimeoutConn(c net.Conn, timeout time.Duration) ReadTimeoutConn { + return ReadTimeoutConn{ + Conn: c, + reader: TimeoutConn{ + Conn: c, + readTimeout: timeout, + writeTimeout: 0, + }, + } +} + +// SetReadDeadline sets read deadline +func (c *ReadTimeoutConn) SetReadDeadline(t time.Time) error { + return c.reader.SetReadDeadline(t) +} + +// SetReadTimeout sets write timeout +func (c *ReadTimeoutConn) SetReadTimeout(t time.Duration) { + c.reader.SetReadTimeout(t) +} + +// Read writes data +func (c ReadTimeoutConn) Read(b []byte) (int, error) { + return c.reader.Read(b) +} + +// WriteTimeoutConn is a writer that will enforce a timeout rules onto a +// net.Conn +type WriteTimeoutConn struct { + net.Conn + + writer TimeoutConn +} + +// NewWriteTimeoutConn creates a WriteTimeoutConnWriter +func NewWriteTimeoutConn(c net.Conn, timeout time.Duration) WriteTimeoutConn { + return WriteTimeoutConn{ + Conn: c, + writer: TimeoutConn{ + Conn: c, + readTimeout: 0, + writeTimeout: timeout, + }, + } +} + +// SetWriteDeadline sets write deadline +func (c *WriteTimeoutConn) SetWriteDeadline(t time.Time) error { + return c.writer.SetWriteDeadline(t) +} + +// SetWriteTimeout sets write timeout +func (c *WriteTimeoutConn) SetWriteTimeout(t time.Duration) { + c.writer.SetWriteTimeout(t) +} + +// Write writes data +func (c WriteTimeoutConn) Write(b []byte) (int, error) { + return c.writer.Write(b) +} diff --git a/application/network/dial_socks5.go b/application/network/dial_socks5.go index 548490f..8868d07 100644 --- a/application/network/dial_socks5.go +++ b/application/network/dial_socks5.go @@ -25,10 +25,6 @@ import ( "golang.org/x/net/proxy" ) -var ( - emptyTime = time.Time{} -) - type socks5Dial struct { net.Dialer } diff --git a/application/server/conn.go b/application/server/conn.go new file mode 100644 index 0000000..6f0c250 --- /dev/null +++ b/application/server/conn.go @@ -0,0 +1,89 @@ +// Sshwifty - A Web SSH client +// +// Copyright (C) 2019 Rui NI +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package server + +import ( + "net" + "time" + + "github.com/niruix/sshwifty/application/network" +) + +var ( + emptyTime = time.Time{} +) + +type listener struct { + *net.TCPListener + + readTimeout time.Duration + writeTimeout time.Duration +} + +func (l listener) Accept() (net.Conn, error) { + acc, accErr := l.TCPListener.Accept() + + if accErr != nil { + return nil, accErr + } + + timeoutConn := network.NewTimeoutConn(acc, l.readTimeout, l.writeTimeout) + + return conn{ + TimeoutConn: &timeoutConn, + readTimeout: l.readTimeout, + writeTimeout: l.writeTimeout, + }, nil +} + +// conn is a net.Conn hack, we use it prevent the upper to alter some important +// configuration of the connection, mainly the timeouts. +type conn struct { + *network.TimeoutConn + + readTimeout time.Duration + writeTimeout time.Duration +} + +func (c conn) normalizeTimeout(t time.Time, m time.Duration) time.Time { + max := time.Now().Add(m) + + // You cannot set timeout that is longer than the given m + if t.After(max) { + return max + } + + return t +} + +func (c conn) SetDeadline(dl time.Time) error { + c.SetReadDeadline(dl) + c.SetWriteDeadline(dl) + + return nil +} + +func (c conn) SetReadDeadline(dl time.Time) error { + return c.TimeoutConn.SetReadDeadline( + c.normalizeTimeout(dl, c.readTimeout)) +} + +func (c conn) SetWriteDeadline(dl time.Time) error { + return c.TimeoutConn.SetWriteDeadline( + c.normalizeTimeout(dl, c.writeTimeout)) +} diff --git a/application/server/server.go b/application/server/server.go index fcea3f1..f01a38a 100644 --- a/application/server/server.go +++ b/application/server/server.go @@ -26,6 +26,7 @@ import ( "net/http" "strconv" "sync" + "time" "github.com/niruix/sshwifty/application/configuration" "github.com/niruix/sshwifty/application/log" @@ -111,11 +112,15 @@ func (s Server) Wait() { } func (s *Serving) buildListener( - ip string, port uint16) (*net.TCPListener, error) { + ip string, + port uint16, + readTimeout time.Duration, + writeTimeout time.Duration, +) (listener, error) { ipAddr := net.ParseIP(ip) if ipAddr == nil { - return nil, ErrInvalidIPAddress + return listener{}, ErrInvalidIPAddress } ipPort := net.JoinHostPort( @@ -124,10 +129,20 @@ func (s *Serving) buildListener( addr, addrErr := net.ResolveTCPAddr("tcp", ipPort) if addrErr != nil { - return nil, addrErr + return listener{}, addrErr } - return net.ListenTCP("tcp", addr) + ll, llErr := net.ListenTCP("tcp", addr) + + if llErr != nil { + return listener{}, llErr + } + + return listener{ + TCPListener: ll, + readTimeout: readTimeout, + writeTimeout: writeTimeout, + }, nil } // run starts the server @@ -137,7 +152,6 @@ func (s *Serving) run( closeCallback CloseCallback, ) error { var err error - var ls *net.TCPListener defer func() { if err == nil || err == http.ErrServerClosed { @@ -151,7 +165,12 @@ func (s *Serving) run( closeCallback(err) }() - ls, err = s.buildListener(cfg.ListenInterface, cfg.ListenPort) + ls, err := s.buildListener( + cfg.ListenInterface, + cfg.ListenPort, + cfg.ReadTimeout, + cfg.WriteTimeout, + ) if err != nil { return err