Initial commit
This commit is contained in:
11
internal/domains/deconnector.go
Normal file
11
internal/domains/deconnector.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package domains
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
const DomainNameDeconnector = "deconnector"
|
||||
|
||||
type Deconnector interface {
|
||||
Handle(clientConn net.Conn)
|
||||
}
|
||||
33
internal/domains/deconnector/connect.go
Normal file
33
internal/domains/deconnector/connect.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package deconnector
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (d *Deconnector) handleDeconnect(clientConn net.Conn, connectReq *http.Request, upstreamURL *url.URL) {
|
||||
// Tell client the tunnel is open
|
||||
fmt.Fprintf(clientConn, "HTTP/1.1 200 Connection established\r\n\r\n")
|
||||
|
||||
// Read the real HTTP request the client sends through the tunnel
|
||||
innerReq, err := http.ReadRequest(bufio.NewReader(clientConn))
|
||||
if err != nil {
|
||||
log.Printf("failed to read inner request after CONNECT:80: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
innerReq.URL.Scheme = "http"
|
||||
innerReq.URL.Host = connectReq.Host
|
||||
innerReq.RequestURI = ""
|
||||
|
||||
d.app.Logger().
|
||||
WithField("method", innerReq.Method).
|
||||
WithField("url", innerReq.URL).
|
||||
Info("Handling de-CONNECT request")
|
||||
|
||||
d.forwardHTTP(clientConn, innerReq, upstreamURL)
|
||||
}
|
||||
43
internal/domains/deconnector/deconnector.go
Normal file
43
internal/domains/deconnector/deconnector.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package deconnector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"source.hodakov.me/hdkv/deconnect/internal/application"
|
||||
"source.hodakov.me/hdkv/deconnect/internal/domains"
|
||||
)
|
||||
|
||||
var (
|
||||
_ domains.Deconnector = new(Deconnector)
|
||||
_ domains.Domain = new(Deconnector)
|
||||
)
|
||||
|
||||
type Deconnector struct {
|
||||
app *application.App
|
||||
|
||||
dialer domains.Dialer
|
||||
}
|
||||
|
||||
func New(app *application.App) *Deconnector {
|
||||
return &Deconnector{
|
||||
app: app,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Deconnector) ConnectDependencies() error {
|
||||
dialer, ok := d.app.RetrieveDomain(domains.DomainNameDialer).(domains.Dialer)
|
||||
if !ok {
|
||||
return fmt.Errorf(
|
||||
"%w: %w (%s)", ErrDeconnector, ErrConnectDependencies,
|
||||
"dialer domain interface conversion failed",
|
||||
)
|
||||
}
|
||||
|
||||
d.dialer = dialer
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Deconnector) Start() error {
|
||||
return nil
|
||||
}
|
||||
8
internal/domains/deconnector/errors.go
Normal file
8
internal/domains/deconnector/errors.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package deconnector
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrDeconnector = errors.New("deconnector")
|
||||
ErrConnectDependencies = errors.New("failed to connect dependencies")
|
||||
)
|
||||
36
internal/domains/deconnector/forward.go
Normal file
36
internal/domains/deconnector/forward.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package deconnector
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (d *Deconnector) forwardHTTP(clientConn net.Conn, req *http.Request, upstreamURL *url.URL) {
|
||||
upstreamConn, err := d.dialer.Dial()
|
||||
if err != nil {
|
||||
d.app.Logger().WithError(err).Error("upstream dial failed")
|
||||
fmt.Fprintf(clientConn, "HTTP/1.1 502 Bad Gateway\r\n\r\n")
|
||||
|
||||
return
|
||||
}
|
||||
defer upstreamConn.Close()
|
||||
|
||||
if err := req.WriteProxy(upstreamConn); err != nil {
|
||||
d.app.Logger().WithError(err).Error("failed to write request")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(upstreamConn), req)
|
||||
if err != nil {
|
||||
d.app.Logger().WithError(err).Error("failed to read response")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
resp.Write(clientConn)
|
||||
}
|
||||
46
internal/domains/deconnector/handle.go
Normal file
46
internal/domains/deconnector/handle.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package deconnector
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (d *Deconnector) Handle(clientConn net.Conn) {
|
||||
upstreamURL, err := d.dialer.UpstreamURL()
|
||||
if err != nil {
|
||||
d.app.Logger().WithError(err).Error("failed to get upstream URL")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer clientConn.Close()
|
||||
|
||||
req, err := http.ReadRequest(bufio.NewReader(clientConn))
|
||||
if err != nil {
|
||||
d.app.Logger().WithError(err).Error("failed to read request")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if req.Method == http.MethodConnect {
|
||||
_, port, _ := net.SplitHostPort(req.Host)
|
||||
if port == "443" {
|
||||
d.handleTunnel(clientConn, req.Host, upstreamURL)
|
||||
} else {
|
||||
d.handleDeconnect(clientConn, req, upstreamURL)
|
||||
}
|
||||
} else {
|
||||
req.RequestURI = ""
|
||||
if req.URL.Host == "" {
|
||||
req.URL.Host = req.Host
|
||||
}
|
||||
req.URL.Scheme = "http"
|
||||
d.app.Logger().
|
||||
WithField("method", req.Method).
|
||||
WithField("url", req.URL).
|
||||
Info("Forwarding HTTP request")
|
||||
|
||||
d.forwardHTTP(clientConn, req, upstreamURL)
|
||||
}
|
||||
}
|
||||
43
internal/domains/deconnector/tunnel.go
Normal file
43
internal/domains/deconnector/tunnel.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package deconnector
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (d *Deconnector) handleTunnel(clientConn net.Conn, host string, upstreamURL *url.URL) {
|
||||
d.app.Logger().WithField("host", host).Info("Handling CONNECT tunnel")
|
||||
|
||||
upstreamConn, err := d.dialer.Dial()
|
||||
if err != nil {
|
||||
d.app.Logger().WithError(err).Error("upstream dial failed")
|
||||
fmt.Fprintf(clientConn, "HTTP/1.1 502 Bad Gateway\r\n\r\n")
|
||||
|
||||
return
|
||||
}
|
||||
defer upstreamConn.Close()
|
||||
|
||||
connectLine := fmt.Sprintf(
|
||||
"CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", host, host,
|
||||
)
|
||||
fmt.Fprint(upstreamConn, connectLine)
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(upstreamConn), nil)
|
||||
if err != nil || resp.StatusCode > 499 {
|
||||
d.app.Logger().WithError(err).Error("upstream CONNECT failed")
|
||||
fmt.Fprintf(clientConn, "HTTP/1.1 502 Bad Gateway\r\n\r\n")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(clientConn, "HTTP/1.1 200 Connection established\r\n\r\n")
|
||||
|
||||
done := make(chan struct{}, 2)
|
||||
go func() { io.Copy(upstreamConn, clientConn); done <- struct{}{} }()
|
||||
go func() { io.Copy(clientConn, upstreamConn); done <- struct{}{} }()
|
||||
<-done
|
||||
}
|
||||
13
internal/domains/dialer.go
Normal file
13
internal/domains/dialer.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package domains
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const DomainNameDialer = "dialer"
|
||||
|
||||
type Dialer interface {
|
||||
Dial() (net.Conn, error)
|
||||
UpstreamURL() (*url.URL, error)
|
||||
}
|
||||
37
internal/domains/dialer/dial.go
Normal file
37
internal/domains/dialer/dial.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (d *Dialer) Dial() (net.Conn, error) {
|
||||
if d.dialURL.Scheme == "https" {
|
||||
return tls.Dial("tcp", d.dialURL.Host, &tls.Config{
|
||||
ServerName: d.dialURL.Hostname(),
|
||||
InsecureSkipVerify: d.app.Config().Upstream.InsecureTLS,
|
||||
})
|
||||
}
|
||||
|
||||
return net.Dial("tcp", d.dialURL.Host)
|
||||
}
|
||||
|
||||
func (d *Dialer) UpstreamURL() (*url.URL, error) {
|
||||
if d.app.Config().Upstream.URL == "" {
|
||||
return nil, fmt.Errorf(
|
||||
"%w: %w (%s)", ErrDialer, ErrParseURL,
|
||||
"upstream URL is empty",
|
||||
)
|
||||
}
|
||||
|
||||
u, err := url.Parse(d.app.Config().Upstream.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"%w: %w (%w)", ErrDialer, ErrParseURL, err,
|
||||
)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
43
internal/domains/dialer/dialer.go
Normal file
43
internal/domains/dialer/dialer.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"source.hodakov.me/hdkv/deconnect/internal/application"
|
||||
"source.hodakov.me/hdkv/deconnect/internal/domains"
|
||||
)
|
||||
|
||||
var (
|
||||
_ domains.Dialer = new(Dialer)
|
||||
_ domains.Domain = new(Dialer)
|
||||
)
|
||||
|
||||
type Dialer struct {
|
||||
app *application.App
|
||||
|
||||
dialURL *url.URL
|
||||
}
|
||||
|
||||
func New(app *application.App) *Dialer {
|
||||
return &Dialer{
|
||||
app: app,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dialer) ConnectDependencies() error {
|
||||
dialURL, err := d.UpstreamURL()
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"%w: %w (%w)", ErrDialer, ErrConnectDependencies, err,
|
||||
)
|
||||
}
|
||||
|
||||
d.dialURL = dialURL
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Dialer) Start() error {
|
||||
return nil
|
||||
}
|
||||
9
internal/domains/dialer/errors.go
Normal file
9
internal/domains/dialer/errors.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package dialer
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrDialer = errors.New("dialer")
|
||||
ErrConnectDependencies = errors.New("failed to connect dependencies")
|
||||
ErrParseURL = errors.New("failed to parse URL")
|
||||
)
|
||||
6
internal/domains/domain.go
Normal file
6
internal/domains/domain.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package domains
|
||||
|
||||
type Domain interface {
|
||||
ConnectDependencies() error
|
||||
Start() error
|
||||
}
|
||||
5
internal/domains/listener.go
Normal file
5
internal/domains/listener.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package domains
|
||||
|
||||
const DomainNameListener = "listener"
|
||||
|
||||
type Listener any
|
||||
10
internal/domains/listener/errors.go
Normal file
10
internal/domains/listener/errors.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package listener
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrListener = errors.New("listener")
|
||||
ErrConnectDependencies = errors.New("failed to connect dependencies")
|
||||
ErrFailedToGetWaitGroup = errors.New("failed to get global waitgroup")
|
||||
ErrFailedToListen = errors.New("failed to listen")
|
||||
)
|
||||
31
internal/domains/listener/listen.go
Normal file
31
internal/domains/listener/listen.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package listener
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
func (l *Listener) Listen() error {
|
||||
ln, err := net.Listen(
|
||||
"tcp",
|
||||
l.app.Config().Deconnect.Host+":"+l.app.Config().Deconnect.Port,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w (%w)", ErrListener, ErrFailedToListen, err)
|
||||
}
|
||||
|
||||
l.app.Logger().WithField("host", l.app.Config().Deconnect.Host).
|
||||
WithField("port", l.app.Config().Deconnect.Port).
|
||||
Info("Listening for incoming connections")
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
l.app.Logger().WithError(err).Error("accept error")
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
go l.deconnector.Handle(conn)
|
||||
}
|
||||
}
|
||||
52
internal/domains/listener/listener.go
Normal file
52
internal/domains/listener/listener.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package listener
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"source.hodakov.me/hdkv/deconnect/internal/application"
|
||||
"source.hodakov.me/hdkv/deconnect/internal/domains"
|
||||
)
|
||||
|
||||
var (
|
||||
_ domains.Listener = new(Listener)
|
||||
_ domains.Domain = new(Listener)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
app *application.App
|
||||
|
||||
deconnector domains.Deconnector
|
||||
}
|
||||
|
||||
func New(app *application.App) *Listener {
|
||||
return &Listener{
|
||||
app: app,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) ConnectDependencies() error {
|
||||
deconnector, ok := l.app.RetrieveDomain(domains.DomainNameDeconnector).(domains.Deconnector)
|
||||
if !ok {
|
||||
return fmt.Errorf(
|
||||
"%w: %w (%s)", ErrListener, ErrConnectDependencies,
|
||||
"deconnector domain interface conversion failed",
|
||||
)
|
||||
}
|
||||
|
||||
l.deconnector = deconnector
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) Start() error {
|
||||
wg := l.app.GetGlobalWaitGroup()
|
||||
if wg == nil {
|
||||
return fmt.Errorf("%w: %w (%s)", ErrListener, ErrFailedToGetWaitGroup, "got nil waitgroup")
|
||||
}
|
||||
|
||||
wg.Go(func() {
|
||||
l.Listen()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user