diff --git a/.golangci.yml b/.golangci.yml index e87e08d..a458a28 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -4,6 +4,8 @@ run: linters: default: all disable: + - gomodguard + - wsl - revive - noinlineerr - mnd diff --git a/internal/domains/deconnector/connect.go b/internal/domains/deconnector/connect.go index 74613ee..509f9a3 100644 --- a/internal/domains/deconnector/connect.go +++ b/internal/domains/deconnector/connect.go @@ -3,7 +3,6 @@ package deconnector import ( "bufio" "fmt" - "log" "net" "net/http" "net/url" @@ -16,7 +15,9 @@ func (d *Deconnector) handleDeconnect(clientConn net.Conn, connectReq *http.Requ // Read the real HTTP request the client sends through the tunnel innerReq, err := http.ReadRequest(bufio.NewReader(clientConn)) if err != nil { - log.Printf("failed to read inner request after CONNECT:80: %v", err) + d.app.Logger().WithError(err). + Error("Failed to read inner request after CONNECT:80") + return } diff --git a/internal/domains/deconnector/errors.go b/internal/domains/deconnector/errors.go index ebdb47e..36a4bd1 100644 --- a/internal/domains/deconnector/errors.go +++ b/internal/domains/deconnector/errors.go @@ -4,5 +4,5 @@ import "errors" var ( ErrDeconnector = errors.New("deconnector") - ErrConnectDependencies = errors.New("failed to connect dependencies") + ErrConnectDependencies = errors.New("function ConnectDependencies()") ) diff --git a/internal/domains/deconnector/forward.go b/internal/domains/deconnector/forward.go index d92aca9..28ca792 100644 --- a/internal/domains/deconnector/forward.go +++ b/internal/domains/deconnector/forward.go @@ -32,5 +32,6 @@ func (d *Deconnector) forwardHTTP(clientConn net.Conn, req *http.Request, upstre } defer resp.Body.Close() + resp.Write(clientConn) } diff --git a/internal/domains/deconnector/handle.go b/internal/domains/deconnector/handle.go index b138026..7aff2b0 100644 --- a/internal/domains/deconnector/handle.go +++ b/internal/domains/deconnector/handle.go @@ -35,6 +35,7 @@ func (d *Deconnector) Handle(clientConn net.Conn) { if req.URL.Host == "" { req.URL.Host = req.Host } + req.URL.Scheme = "http" d.app.Logger(). WithField("method", req.Method). diff --git a/internal/domains/deconnector/tunnel.go b/internal/domains/deconnector/tunnel.go index 5f09bc7..97f83ba 100644 --- a/internal/domains/deconnector/tunnel.go +++ b/internal/domains/deconnector/tunnel.go @@ -37,7 +37,9 @@ func (d *Deconnector) handleTunnel(clientConn net.Conn, host string, upstreamURL fmt.Fprintf(clientConn, "HTTP/1.1 200 Connection established\r\n\r\n") done := make(chan struct{}, 2) + go func() { io.Copy(upstreamConn, clientConn); done <- struct{}{} }() go func() { io.Copy(clientConn, upstreamConn); done <- struct{}{} }() + <-done } diff --git a/internal/domains/dialer/dial.go b/internal/domains/dialer/dial.go index aeb09e8..64e0213 100644 --- a/internal/domains/dialer/dial.go +++ b/internal/domains/dialer/dial.go @@ -21,7 +21,7 @@ func (d *Dialer) Dial() (net.Conn, error) { func (d *Dialer) UpstreamURL() (*url.URL, error) { if d.app.Config().Upstream.URL == "" { return nil, fmt.Errorf( - "%w: %w (%s)", ErrDialer, ErrParseURL, + "%w: %w (%s)", ErrDialer, ErrUpstreamURL, "upstream URL is empty", ) } @@ -29,7 +29,7 @@ func (d *Dialer) UpstreamURL() (*url.URL, error) { u, err := url.Parse(d.app.Config().Upstream.URL) if err != nil { return nil, fmt.Errorf( - "%w: %w (%w)", ErrDialer, ErrParseURL, err, + "%w: %w (%w)", ErrDialer, ErrUpstreamURL, err, ) } diff --git a/internal/domains/dialer/errors.go b/internal/domains/dialer/errors.go index 4b2d261..0a25dd8 100644 --- a/internal/domains/dialer/errors.go +++ b/internal/domains/dialer/errors.go @@ -4,6 +4,6 @@ import "errors" var ( ErrDialer = errors.New("dialer") - ErrConnectDependencies = errors.New("failed to connect dependencies") - ErrParseURL = errors.New("failed to parse URL") + ErrConnectDependencies = errors.New("function ConnectDependencies()") + ErrUpstreamURL = errors.New("function UpstreamURL()") ) diff --git a/internal/domains/listener/errors.go b/internal/domains/listener/errors.go index b0a365c..c2137a9 100644 --- a/internal/domains/listener/errors.go +++ b/internal/domains/listener/errors.go @@ -4,7 +4,9 @@ import "errors" var ( ErrListener = errors.New("listener") - ErrConnectDependencies = errors.New("failed to connect dependencies") + ErrConnectDependencies = errors.New("function ConnectDependencies()") + ErrStart = errors.New("function Start()") + ErrListen = errors.New("function Listen()") ErrFailedToGetWaitGroup = errors.New("failed to get global waitgroup") ErrFailedToListen = errors.New("failed to listen") ) diff --git a/internal/domains/listener/listen.go b/internal/domains/listener/listen.go index 75fbd4a..f51e939 100644 --- a/internal/domains/listener/listen.go +++ b/internal/domains/listener/listen.go @@ -11,16 +11,30 @@ func (l *Listener) Listen() error { l.app.Config().Deconnect.Host+":"+l.app.Config().Deconnect.Port, ) if err != nil { - return fmt.Errorf("%w: %w (%w)", ErrListener, ErrFailedToListen, err) + return fmt.Errorf("%w: %w (%w)", ErrListener, ErrListen, err) } l.app.Logger().WithField("host", l.app.Config().Deconnect.Host). WithField("port", l.app.Config().Deconnect.Port). Info("Listening for incoming connections") + go func() { + <-l.app.Context().Done() + + l.app.Logger().Info("Shutting down listener") + + if err := ln.Close(); err != nil { + l.app.Logger().WithError(err).Error("listener close error") + } + }() + for { conn, err := ln.Accept() if err != nil { + if l.app.Context().Err() != nil { + return nil + } + l.app.Logger().WithError(err).Error("accept error") continue diff --git a/internal/domains/listener/listener.go b/internal/domains/listener/listener.go index ced5b13..92090ea 100644 --- a/internal/domains/listener/listener.go +++ b/internal/domains/listener/listener.go @@ -41,7 +41,7 @@ func (l *Listener) ConnectDependencies() error { func (l *Listener) Start() error { wg := l.app.GetGlobalWaitGroup() if wg == nil { - return fmt.Errorf("%w: %w (%s)", ErrListener, ErrFailedToGetWaitGroup, "got nil waitgroup") + return fmt.Errorf("%w: %w (%w)", ErrListener, ErrStart, ErrFailedToGetWaitGroup) } wg.Go(func() {