Prevent SetDeadline, SetReadDeadline and SetWriteDeadline from setting up unlimited deadline, and also impose our own timeout retry mechanism.

This commit is contained in:
NI
2019-09-18 18:51:17 +08:00
parent aba5993ffb
commit c7df33f14f
5 changed files with 361 additions and 10 deletions

View File

@@ -0,0 +1,89 @@
// Sshwifty - A Web SSH client
//
// Copyright (C) 2019 Rui NI <nirui@gmx.com>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package server
import (
"net"
"time"
"github.com/niruix/sshwifty/application/network"
)
var (
emptyTime = time.Time{}
)
type listener struct {
*net.TCPListener
readTimeout time.Duration
writeTimeout time.Duration
}
func (l listener) Accept() (net.Conn, error) {
acc, accErr := l.TCPListener.Accept()
if accErr != nil {
return nil, accErr
}
timeoutConn := network.NewTimeoutConn(acc, l.readTimeout, l.writeTimeout)
return conn{
TimeoutConn: &timeoutConn,
readTimeout: l.readTimeout,
writeTimeout: l.writeTimeout,
}, nil
}
// conn is a net.Conn hack, we use it prevent the upper to alter some important
// configuration of the connection, mainly the timeouts.
type conn struct {
*network.TimeoutConn
readTimeout time.Duration
writeTimeout time.Duration
}
func (c conn) normalizeTimeout(t time.Time, m time.Duration) time.Time {
max := time.Now().Add(m)
// You cannot set timeout that is longer than the given m
if t.After(max) {
return max
}
return t
}
func (c conn) SetDeadline(dl time.Time) error {
c.SetReadDeadline(dl)
c.SetWriteDeadline(dl)
return nil
}
func (c conn) SetReadDeadline(dl time.Time) error {
return c.TimeoutConn.SetReadDeadline(
c.normalizeTimeout(dl, c.readTimeout))
}
func (c conn) SetWriteDeadline(dl time.Time) error {
return c.TimeoutConn.SetWriteDeadline(
c.normalizeTimeout(dl, c.writeTimeout))
}

View File

@@ -26,6 +26,7 @@ import (
"net/http"
"strconv"
"sync"
"time"
"github.com/niruix/sshwifty/application/configuration"
"github.com/niruix/sshwifty/application/log"
@@ -111,11 +112,15 @@ func (s Server) Wait() {
}
func (s *Serving) buildListener(
ip string, port uint16) (*net.TCPListener, error) {
ip string,
port uint16,
readTimeout time.Duration,
writeTimeout time.Duration,
) (listener, error) {
ipAddr := net.ParseIP(ip)
if ipAddr == nil {
return nil, ErrInvalidIPAddress
return listener{}, ErrInvalidIPAddress
}
ipPort := net.JoinHostPort(
@@ -124,10 +129,20 @@ func (s *Serving) buildListener(
addr, addrErr := net.ResolveTCPAddr("tcp", ipPort)
if addrErr != nil {
return nil, addrErr
return listener{}, addrErr
}
return net.ListenTCP("tcp", addr)
ll, llErr := net.ListenTCP("tcp", addr)
if llErr != nil {
return listener{}, llErr
}
return listener{
TCPListener: ll,
readTimeout: readTimeout,
writeTimeout: writeTimeout,
}, nil
}
// run starts the server
@@ -137,7 +152,6 @@ func (s *Serving) run(
closeCallback CloseCallback,
) error {
var err error
var ls *net.TCPListener
defer func() {
if err == nil || err == http.ErrServerClosed {
@@ -151,7 +165,12 @@ func (s *Serving) run(
closeCallback(err)
}()
ls, err = s.buildListener(cfg.ListenInterface, cfg.ListenPort)
ls, err := s.buildListener(
cfg.ListenInterface,
cfg.ListenPort,
cfg.ReadTimeout,
cfg.WriteTimeout,
)
if err != nil {
return err