Fix linter issues
This commit is contained in:
@@ -29,9 +29,9 @@ func (a *App) Logger() *logrus.Entry {
|
||||
}
|
||||
|
||||
func New(ctx context.Context) *App {
|
||||
var m runtime.MemStats
|
||||
var memStats runtime.MemStats
|
||||
|
||||
runtime.ReadMemStats(&m)
|
||||
runtime.ReadMemStats(&memStats)
|
||||
|
||||
app := new(App)
|
||||
|
||||
@@ -43,9 +43,9 @@ func New(ctx context.Context) *App {
|
||||
})
|
||||
|
||||
app.logger = logger.WithContext(ctx).WithFields(logrus.Fields{
|
||||
"memalloc": fmt.Sprintf("%dMB", m.Alloc/1024/1024),
|
||||
"memsys": fmt.Sprintf("%dMB", m.Sys/1024/1024),
|
||||
"numgc": strconv.FormatUint(uint64(m.NumGC), 10),
|
||||
"memalloc": fmt.Sprintf("%dMB", memStats.Alloc/1024/1024),
|
||||
"memsys": fmt.Sprintf("%dMB", memStats.Sys/1024/1024),
|
||||
"numgc": strconv.FormatUint(uint64(memStats.NumGC), 10),
|
||||
})
|
||||
|
||||
app.domains = make(map[string]domains.Domain)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
package domains
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
const DomainNameDeconnector = "deconnector"
|
||||
|
||||
type Deconnector interface {
|
||||
Handle(clientConn net.Conn)
|
||||
Handle(ctx context.Context, clientConn net.Conn)
|
||||
}
|
||||
|
||||
@@ -2,15 +2,15 @@ package deconnector
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (d *Deconnector) handleDeconnect(clientConn net.Conn, connectReq *http.Request, upstreamURL *url.URL) {
|
||||
func (d *Deconnector) handleDeconnect(ctx context.Context, clientConn net.Conn, connectReq *http.Request) {
|
||||
// Tell client the tunnel is open
|
||||
fmt.Fprintf(clientConn, "HTTP/1.1 200 Connection established\r\n\r\n")
|
||||
_, _ = fmt.Fprintf(clientConn, "HTTP/1.1 200 Connection established\r\n\r\n")
|
||||
|
||||
// Read the real HTTP request the client sends through the tunnel
|
||||
innerReq, err := http.ReadRequest(bufio.NewReader(clientConn))
|
||||
@@ -30,5 +30,5 @@ func (d *Deconnector) handleDeconnect(clientConn net.Conn, connectReq *http.Requ
|
||||
WithField("url", innerReq.URL).
|
||||
Info("Handling de-CONNECT request")
|
||||
|
||||
d.forwardHTTP(clientConn, innerReq, upstreamURL)
|
||||
d.forwardHTTP(ctx, clientConn, innerReq)
|
||||
}
|
||||
|
||||
@@ -2,14 +2,14 @@ package deconnector
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (d *Deconnector) forwardHTTP(clientConn net.Conn, req *http.Request, upstreamURL *url.URL) {
|
||||
upstreamConn, err := d.dialer.Dial()
|
||||
func (d *Deconnector) forwardHTTP(ctx context.Context, clientConn net.Conn, req *http.Request) {
|
||||
upstreamConn, err := d.dialer.Dial(ctx)
|
||||
if err != nil {
|
||||
d.app.Logger().WithError(err).Error("upstream dial failed")
|
||||
fmt.Fprintf(clientConn, "HTTP/1.1 502 Bad Gateway\r\n\r\n")
|
||||
@@ -18,6 +18,10 @@ func (d *Deconnector) forwardHTTP(clientConn net.Conn, req *http.Request, upstre
|
||||
}
|
||||
defer upstreamConn.Close()
|
||||
|
||||
if authHeader, ok := d.dialer.Auth(); ok {
|
||||
req.Header.Set("Proxy-Authorization", "Basic "+authHeader)
|
||||
}
|
||||
|
||||
if err := req.WriteProxy(upstreamConn); err != nil {
|
||||
d.app.Logger().WithError(err).Error("failed to write request")
|
||||
|
||||
@@ -33,5 +37,8 @@ func (d *Deconnector) forwardHTTP(clientConn net.Conn, req *http.Request, upstre
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp.Write(clientConn)
|
||||
err = resp.Write(clientConn)
|
||||
if err != nil {
|
||||
d.app.Logger().WithError(err).Error("failed to write response")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,18 +2,12 @@ package deconnector
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (d *Deconnector) Handle(clientConn net.Conn) {
|
||||
upstreamURL, err := d.dialer.UpstreamURL()
|
||||
if err != nil {
|
||||
d.app.Logger().WithError(err).Error("failed to get upstream URL")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Deconnector) Handle(ctx context.Context, clientConn net.Conn) {
|
||||
defer clientConn.Close()
|
||||
|
||||
req, err := http.ReadRequest(bufio.NewReader(clientConn))
|
||||
@@ -26,9 +20,9 @@ func (d *Deconnector) Handle(clientConn net.Conn) {
|
||||
if req.Method == http.MethodConnect {
|
||||
_, port, _ := net.SplitHostPort(req.Host)
|
||||
if port == "443" {
|
||||
d.handleTunnel(clientConn, req.Host, upstreamURL)
|
||||
d.handleTunnel(ctx, clientConn, req.Host)
|
||||
} else {
|
||||
d.handleDeconnect(clientConn, req, upstreamURL)
|
||||
d.handleDeconnect(ctx, clientConn, req)
|
||||
}
|
||||
} else {
|
||||
req.RequestURI = ""
|
||||
@@ -42,6 +36,6 @@ func (d *Deconnector) Handle(clientConn net.Conn) {
|
||||
WithField("url", req.URL).
|
||||
Info("Forwarding HTTP request")
|
||||
|
||||
d.forwardHTTP(clientConn, req, upstreamURL)
|
||||
d.forwardHTTP(ctx, clientConn, req)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,17 +2,18 @@ package deconnector
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (d *Deconnector) handleTunnel(clientConn net.Conn, host string, upstreamURL *url.URL) {
|
||||
//nolint:funlen
|
||||
func (d *Deconnector) handleTunnel(ctx context.Context, clientConn net.Conn, host string) {
|
||||
d.app.Logger().WithField("host", host).Info("Handling CONNECT tunnel")
|
||||
|
||||
upstreamConn, err := d.dialer.Dial()
|
||||
upstreamConn, err := d.dialer.Dial(ctx)
|
||||
if err != nil {
|
||||
d.app.Logger().WithError(err).Error("upstream dial failed")
|
||||
fmt.Fprintf(clientConn, "HTTP/1.1 502 Bad Gateway\r\n\r\n")
|
||||
@@ -24,6 +25,10 @@ func (d *Deconnector) handleTunnel(clientConn net.Conn, host string, upstreamURL
|
||||
connectLine := fmt.Sprintf(
|
||||
"CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", host, host,
|
||||
)
|
||||
if authHeader, ok := d.dialer.Auth(); ok {
|
||||
connectLine += fmt.Sprintf("Proxy-Authorization: Basic %s\r\n", authHeader)
|
||||
}
|
||||
|
||||
fmt.Fprint(upstreamConn, connectLine)
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(upstreamConn), nil)
|
||||
@@ -33,13 +38,40 @@ func (d *Deconnector) handleTunnel(clientConn net.Conn, host string, upstreamURL
|
||||
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
fmt.Fprintf(clientConn, "HTTP/1.1 200 Connection established\r\n\r\n")
|
||||
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() { io.Copy(upstreamConn, clientConn); done <- struct{}{} }()
|
||||
go func() { io.Copy(clientConn, upstreamConn); done <- struct{}{} }()
|
||||
go func() {
|
||||
defer func() {
|
||||
_ = upstreamConn.Close()
|
||||
_ = clientConn.Close()
|
||||
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(upstreamConn, clientConn); err != nil {
|
||||
d.app.Logger().
|
||||
WithError(err).
|
||||
Debug("client -> upstream copy stopped")
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer func() {
|
||||
_ = upstreamConn.Close()
|
||||
_ = clientConn.Close()
|
||||
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(clientConn, upstreamConn); err != nil {
|
||||
d.app.Logger().
|
||||
WithError(err).
|
||||
Debug("upstream -> client copy stopped")
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package domains
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/url"
|
||||
)
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
const DomainNameDialer = "dialer"
|
||||
|
||||
type Dialer interface {
|
||||
Dial() (net.Conn, error)
|
||||
Auth() (string, bool)
|
||||
Dial(ctx context.Context) (net.Conn, error)
|
||||
UpstreamURL() (*url.URL, error)
|
||||
}
|
||||
|
||||
@@ -1,21 +1,52 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (d *Dialer) Dial() (net.Conn, error) {
|
||||
if d.dialURL.Scheme == "https" {
|
||||
return tls.Dial("tcp", d.dialURL.Host, &tls.Config{
|
||||
ServerName: d.dialURL.Hostname(),
|
||||
InsecureSkipVerify: d.app.Config().Upstream.InsecureTLS,
|
||||
})
|
||||
func (d *Dialer) Auth() (string, bool) {
|
||||
url, _ := d.UpstreamURL()
|
||||
if url.User == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return net.Dial("tcp", d.dialURL.Host)
|
||||
username := url.User.Username()
|
||||
password, _ := url.User.Password()
|
||||
|
||||
return base64.StdEncoding.EncodeToString(
|
||||
[]byte(username + ":" + password),
|
||||
), true
|
||||
}
|
||||
|
||||
func (d *Dialer) Dial(ctx context.Context) (net.Conn, error) {
|
||||
dialer := &net.Dialer{}
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", d.dialURL.Host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %w (%w)", ErrDialer, ErrDial, err)
|
||||
}
|
||||
|
||||
if d.dialURL.Scheme != "https" {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(conn, &tls.Config{
|
||||
ServerName: d.dialURL.Hostname(),
|
||||
//nolint:gosec
|
||||
InsecureSkipVerify: d.app.Config().Upstream.InsecureTLS,
|
||||
})
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
_ = conn.Close()
|
||||
|
||||
return nil, fmt.Errorf("%w: %w (%w)", ErrDialer, ErrDial, err)
|
||||
}
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
func (d *Dialer) UpstreamURL() (*url.URL, error) {
|
||||
@@ -26,12 +57,12 @@ func (d *Dialer) UpstreamURL() (*url.URL, error) {
|
||||
)
|
||||
}
|
||||
|
||||
u, err := url.Parse(d.app.Config().Upstream.URL)
|
||||
upstreamURL, err := url.Parse(d.app.Config().Upstream.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"%w: %w (%w)", ErrDialer, ErrUpstreamURL, err,
|
||||
)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
return upstreamURL, nil
|
||||
}
|
||||
|
||||
@@ -5,5 +5,6 @@ import "errors"
|
||||
var (
|
||||
ErrDialer = errors.New("dialer")
|
||||
ErrConnectDependencies = errors.New("function ConnectDependencies()")
|
||||
ErrDial = errors.New("function Dial()")
|
||||
ErrUpstreamURL = errors.New("function UpstreamURL()")
|
||||
)
|
||||
|
||||
@@ -7,9 +7,15 @@ import (
|
||||
)
|
||||
|
||||
func (l *Listener) Listen(ctx context.Context) error {
|
||||
ln, err := net.Listen(
|
||||
listenerConfig := new(net.ListenConfig)
|
||||
|
||||
listener, err := listenerConfig.Listen(
|
||||
ctx,
|
||||
"tcp",
|
||||
l.app.Config().Deconnect.Host+":"+l.app.Config().Deconnect.Port,
|
||||
net.JoinHostPort(
|
||||
l.app.Config().Deconnect.Host,
|
||||
l.app.Config().Deconnect.Port,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w (%w)", ErrListener, ErrListen, err)
|
||||
@@ -24,16 +30,16 @@ func (l *Listener) Listen(ctx context.Context) error {
|
||||
|
||||
l.app.Logger().Info("Shutting down listener")
|
||||
|
||||
if err := ln.Close(); err != nil {
|
||||
if err := listener.Close(); err != nil {
|
||||
l.app.Logger().WithError(err).Error("listener close error")
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
return fmt.Errorf("%w: %w (%w)", ErrListener, ErrListen, err)
|
||||
}
|
||||
|
||||
l.app.Logger().WithError(err).Error("accept error")
|
||||
@@ -41,6 +47,6 @@ func (l *Listener) Listen(ctx context.Context) error {
|
||||
continue
|
||||
}
|
||||
|
||||
go l.deconnector.Handle(conn)
|
||||
go l.deconnector.Handle(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user