Fix Ctrl-C handling
This commit is contained in:
@@ -4,6 +4,8 @@ run:
|
|||||||
linters:
|
linters:
|
||||||
default: all
|
default: all
|
||||||
disable:
|
disable:
|
||||||
|
- gomodguard
|
||||||
|
- wsl
|
||||||
- revive
|
- revive
|
||||||
- noinlineerr
|
- noinlineerr
|
||||||
- mnd
|
- mnd
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package deconnector
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -16,7 +15,9 @@ func (d *Deconnector) handleDeconnect(clientConn net.Conn, connectReq *http.Requ
|
|||||||
// Read the real HTTP request the client sends through the tunnel
|
// Read the real HTTP request the client sends through the tunnel
|
||||||
innerReq, err := http.ReadRequest(bufio.NewReader(clientConn))
|
innerReq, err := http.ReadRequest(bufio.NewReader(clientConn))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to read inner request after CONNECT:80: %v", err)
|
d.app.Logger().WithError(err).
|
||||||
|
Error("Failed to read inner request after CONNECT:80")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,5 +4,5 @@ import "errors"
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrDeconnector = errors.New("deconnector")
|
ErrDeconnector = errors.New("deconnector")
|
||||||
ErrConnectDependencies = errors.New("failed to connect dependencies")
|
ErrConnectDependencies = errors.New("function ConnectDependencies()")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -32,5 +32,6 @@ func (d *Deconnector) forwardHTTP(clientConn net.Conn, req *http.Request, upstre
|
|||||||
}
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
resp.Write(clientConn)
|
resp.Write(clientConn)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ func (d *Deconnector) Handle(clientConn net.Conn) {
|
|||||||
if req.URL.Host == "" {
|
if req.URL.Host == "" {
|
||||||
req.URL.Host = req.Host
|
req.URL.Host = req.Host
|
||||||
}
|
}
|
||||||
|
|
||||||
req.URL.Scheme = "http"
|
req.URL.Scheme = "http"
|
||||||
d.app.Logger().
|
d.app.Logger().
|
||||||
WithField("method", req.Method).
|
WithField("method", req.Method).
|
||||||
|
|||||||
@@ -37,7 +37,9 @@ func (d *Deconnector) handleTunnel(clientConn net.Conn, host string, upstreamURL
|
|||||||
fmt.Fprintf(clientConn, "HTTP/1.1 200 Connection established\r\n\r\n")
|
fmt.Fprintf(clientConn, "HTTP/1.1 200 Connection established\r\n\r\n")
|
||||||
|
|
||||||
done := make(chan struct{}, 2)
|
done := make(chan struct{}, 2)
|
||||||
|
|
||||||
go func() { io.Copy(upstreamConn, clientConn); done <- struct{}{} }()
|
go func() { io.Copy(upstreamConn, clientConn); done <- struct{}{} }()
|
||||||
go func() { io.Copy(clientConn, upstreamConn); done <- struct{}{} }()
|
go func() { io.Copy(clientConn, upstreamConn); done <- struct{}{} }()
|
||||||
|
|
||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ func (d *Dialer) Dial() (net.Conn, error) {
|
|||||||
func (d *Dialer) UpstreamURL() (*url.URL, error) {
|
func (d *Dialer) UpstreamURL() (*url.URL, error) {
|
||||||
if d.app.Config().Upstream.URL == "" {
|
if d.app.Config().Upstream.URL == "" {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"%w: %w (%s)", ErrDialer, ErrParseURL,
|
"%w: %w (%s)", ErrDialer, ErrUpstreamURL,
|
||||||
"upstream URL is empty",
|
"upstream URL is empty",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -29,7 +29,7 @@ func (d *Dialer) UpstreamURL() (*url.URL, error) {
|
|||||||
u, err := url.Parse(d.app.Config().Upstream.URL)
|
u, err := url.Parse(d.app.Config().Upstream.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"%w: %w (%w)", ErrDialer, ErrParseURL, err,
|
"%w: %w (%w)", ErrDialer, ErrUpstreamURL, err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,6 @@ import "errors"
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrDialer = errors.New("dialer")
|
ErrDialer = errors.New("dialer")
|
||||||
ErrConnectDependencies = errors.New("failed to connect dependencies")
|
ErrConnectDependencies = errors.New("function ConnectDependencies()")
|
||||||
ErrParseURL = errors.New("failed to parse URL")
|
ErrUpstreamURL = errors.New("function UpstreamURL()")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import "errors"
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrListener = errors.New("listener")
|
ErrListener = errors.New("listener")
|
||||||
ErrConnectDependencies = errors.New("failed to connect dependencies")
|
ErrConnectDependencies = errors.New("function ConnectDependencies()")
|
||||||
|
ErrStart = errors.New("function Start()")
|
||||||
|
ErrListen = errors.New("function Listen()")
|
||||||
ErrFailedToGetWaitGroup = errors.New("failed to get global waitgroup")
|
ErrFailedToGetWaitGroup = errors.New("failed to get global waitgroup")
|
||||||
ErrFailedToListen = errors.New("failed to listen")
|
ErrFailedToListen = errors.New("failed to listen")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,16 +11,30 @@ func (l *Listener) Listen() error {
|
|||||||
l.app.Config().Deconnect.Host+":"+l.app.Config().Deconnect.Port,
|
l.app.Config().Deconnect.Host+":"+l.app.Config().Deconnect.Port,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w (%w)", ErrListener, ErrFailedToListen, err)
|
return fmt.Errorf("%w: %w (%w)", ErrListener, ErrListen, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
l.app.Logger().WithField("host", l.app.Config().Deconnect.Host).
|
l.app.Logger().WithField("host", l.app.Config().Deconnect.Host).
|
||||||
WithField("port", l.app.Config().Deconnect.Port).
|
WithField("port", l.app.Config().Deconnect.Port).
|
||||||
Info("Listening for incoming connections")
|
Info("Listening for incoming connections")
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-l.app.Context().Done()
|
||||||
|
|
||||||
|
l.app.Logger().Info("Shutting down listener")
|
||||||
|
|
||||||
|
if err := ln.Close(); err != nil {
|
||||||
|
l.app.Logger().WithError(err).Error("listener close error")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := ln.Accept()
|
conn, err := ln.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if l.app.Context().Err() != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
l.app.Logger().WithError(err).Error("accept error")
|
l.app.Logger().WithError(err).Error("accept error")
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ func (l *Listener) ConnectDependencies() error {
|
|||||||
func (l *Listener) Start() error {
|
func (l *Listener) Start() error {
|
||||||
wg := l.app.GetGlobalWaitGroup()
|
wg := l.app.GetGlobalWaitGroup()
|
||||||
if wg == nil {
|
if wg == nil {
|
||||||
return fmt.Errorf("%w: %w (%s)", ErrListener, ErrFailedToGetWaitGroup, "got nil waitgroup")
|
return fmt.Errorf("%w: %w (%w)", ErrListener, ErrStart, ErrFailedToGetWaitGroup)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Go(func() {
|
wg.Go(func() {
|
||||||
|
|||||||
Reference in New Issue
Block a user