Use context everywhere
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"sync"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"source.hodakov.me/hdkv/deconnect/internal/application"
|
"source.hodakov.me/hdkv/deconnect/internal/application"
|
||||||
@@ -32,28 +31,24 @@ func main() {
|
|||||||
app.RegisterDomain(domains.DomainNameDialer, dialer.New(app))
|
app.RegisterDomain(domains.DomainNameDialer, dialer.New(app))
|
||||||
app.RegisterDomain(domains.DomainNameDeconnector, deconnector.New(app))
|
app.RegisterDomain(domains.DomainNameDeconnector, deconnector.New(app))
|
||||||
|
|
||||||
err = app.ConnectDependencies()
|
err = app.ConnectDependencies(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.Logger().Fatal(err)
|
app.Logger().Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = app.StartDomains(ctx)
|
||||||
|
if err != nil {
|
||||||
|
app.Logger().Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
app.Logger().Info("Started deconnect")
|
||||||
|
|
||||||
// CTRL+C handler.
|
// CTRL+C handler.
|
||||||
interrupt := make(chan os.Signal, 1)
|
interrupt := make(chan os.Signal, 1)
|
||||||
signal.Notify(
|
signal.Notify(
|
||||||
interrupt, syscall.SIGINT, syscall.SIGTERM,
|
interrupt, syscall.SIGINT, syscall.SIGTERM,
|
||||||
)
|
)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
app.RegisterGlobalWaitGroup(&wg)
|
|
||||||
|
|
||||||
err = app.StartDomains()
|
|
||||||
if err != nil {
|
|
||||||
app.Logger().Fatal(err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
app.Logger().Info("Started deconnect")
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
signalThing := <-interrupt
|
signalThing := <-interrupt
|
||||||
app.Logger().WithField("signal", signalThing.String()).
|
app.Logger().WithField("signal", signalThing.String()).
|
||||||
@@ -62,8 +57,7 @@ func main() {
|
|||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Wait for all domains to finish their cleanup
|
<-ctx.Done()
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
app.Logger().Info("deconnect shutdown complete")
|
app.Logger().Info("deconnect shutdown complete")
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -13,10 +13,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type App struct {
|
type App struct {
|
||||||
ctx context.Context
|
|
||||||
logger *logrus.Entry
|
logger *logrus.Entry
|
||||||
config *configuration.Config
|
config *configuration.Config
|
||||||
wg *sync.WaitGroup
|
|
||||||
|
|
||||||
domains map[string]domains.Domain
|
domains map[string]domains.Domain
|
||||||
domainsMutex sync.RWMutex
|
domainsMutex sync.RWMutex
|
||||||
@@ -26,10 +24,6 @@ func (a *App) Config() *configuration.Config {
|
|||||||
return a.config
|
return a.config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) Context() context.Context {
|
|
||||||
return a.ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) Logger() *logrus.Entry {
|
func (a *App) Logger() *logrus.Entry {
|
||||||
return a.logger
|
return a.logger
|
||||||
}
|
}
|
||||||
@@ -54,8 +48,6 @@ func New(ctx context.Context) *App {
|
|||||||
"numgc": strconv.FormatUint(uint64(m.NumGC), 10),
|
"numgc": strconv.FormatUint(uint64(m.NumGC), 10),
|
||||||
})
|
})
|
||||||
|
|
||||||
app.ctx = ctx
|
|
||||||
|
|
||||||
app.domains = make(map[string]domains.Domain)
|
app.domains = make(map[string]domains.Domain)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
@@ -92,12 +84,12 @@ func (a *App) RetrieveDomain(name string) any {
|
|||||||
return a.domains[name]
|
return a.domains[name]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) ConnectDependencies() error {
|
func (a *App) ConnectDependencies(ctx context.Context) error {
|
||||||
a.domainsMutex.RLock()
|
a.domainsMutex.RLock()
|
||||||
defer a.domainsMutex.RUnlock()
|
defer a.domainsMutex.RUnlock()
|
||||||
|
|
||||||
for _, domain := range a.domains {
|
for _, domain := range a.domains {
|
||||||
err := domain.ConnectDependencies()
|
err := domain.ConnectDependencies(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w (%w)", ErrApplication, ErrConnectDependencies, err)
|
return fmt.Errorf("%w: %w (%w)", ErrApplication, ErrConnectDependencies, err)
|
||||||
}
|
}
|
||||||
@@ -106,12 +98,12 @@ func (a *App) ConnectDependencies() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) StartDomains() error {
|
func (a *App) StartDomains(ctx context.Context) error {
|
||||||
a.domainsMutex.RLock()
|
a.domainsMutex.RLock()
|
||||||
defer a.domainsMutex.RUnlock()
|
defer a.domainsMutex.RUnlock()
|
||||||
|
|
||||||
for _, domain := range a.domains {
|
for _, domain := range a.domains {
|
||||||
err := domain.Start()
|
err := domain.Start(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w (%w)", ErrApplication, ErrDomainInit, err)
|
return fmt.Errorf("%w: %w (%w)", ErrApplication, ErrDomainInit, err)
|
||||||
}
|
}
|
||||||
@@ -119,11 +111,3 @@ func (a *App) StartDomains() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) RegisterGlobalWaitGroup(wg *sync.WaitGroup) {
|
|
||||||
a.wg = wg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) GetGlobalWaitGroup() *sync.WaitGroup {
|
|
||||||
return a.wg
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package deconnector
|
package deconnector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"source.hodakov.me/hdkv/deconnect/internal/application"
|
"source.hodakov.me/hdkv/deconnect/internal/application"
|
||||||
@@ -24,7 +25,7 @@ func New(app *application.App) *Deconnector {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Deconnector) ConnectDependencies() error {
|
func (d *Deconnector) ConnectDependencies(_ context.Context) error {
|
||||||
dialer, ok := d.app.RetrieveDomain(domains.DomainNameDialer).(domains.Dialer)
|
dialer, ok := d.app.RetrieveDomain(domains.DomainNameDialer).(domains.Dialer)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
@@ -38,6 +39,6 @@ func (d *Deconnector) ConnectDependencies() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Deconnector) Start() error {
|
func (d *Deconnector) Start(_ context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package dialer
|
package dialer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
@@ -25,7 +26,7 @@ func New(app *application.App) *Dialer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Dialer) ConnectDependencies() error {
|
func (d *Dialer) ConnectDependencies(ctx context.Context) error {
|
||||||
dialURL, err := d.UpstreamURL()
|
dialURL, err := d.UpstreamURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
@@ -38,6 +39,6 @@ func (d *Dialer) ConnectDependencies() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Dialer) Start() error {
|
func (d *Dialer) Start(_ context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package domains
|
package domains
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
type Domain interface {
|
type Domain interface {
|
||||||
ConnectDependencies() error
|
ConnectDependencies(ctx context.Context) error
|
||||||
Start() error
|
Start(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
package listener
|
package listener
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (l *Listener) Listen() error {
|
func (l *Listener) Listen(ctx context.Context) error {
|
||||||
ln, err := net.Listen(
|
ln, err := net.Listen(
|
||||||
"tcp",
|
"tcp",
|
||||||
l.app.Config().Deconnect.Host+":"+l.app.Config().Deconnect.Port,
|
l.app.Config().Deconnect.Host+":"+l.app.Config().Deconnect.Port,
|
||||||
@@ -19,7 +20,7 @@ func (l *Listener) Listen() error {
|
|||||||
Info("Listening for incoming connections")
|
Info("Listening for incoming connections")
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-l.app.Context().Done()
|
<-ctx.Done()
|
||||||
|
|
||||||
l.app.Logger().Info("Shutting down listener")
|
l.app.Logger().Info("Shutting down listener")
|
||||||
|
|
||||||
@@ -31,7 +32,7 @@ func (l *Listener) Listen() error {
|
|||||||
for {
|
for {
|
||||||
conn, err := ln.Accept()
|
conn, err := ln.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if l.app.Context().Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package listener
|
package listener
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"source.hodakov.me/hdkv/deconnect/internal/application"
|
"source.hodakov.me/hdkv/deconnect/internal/application"
|
||||||
@@ -24,7 +25,7 @@ func New(app *application.App) *Listener {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Listener) ConnectDependencies() error {
|
func (l *Listener) ConnectDependencies(ctx context.Context) error {
|
||||||
deconnector, ok := l.app.RetrieveDomain(domains.DomainNameDeconnector).(domains.Deconnector)
|
deconnector, ok := l.app.RetrieveDomain(domains.DomainNameDeconnector).(domains.Deconnector)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
@@ -38,15 +39,14 @@ func (l *Listener) ConnectDependencies() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Listener) Start() error {
|
func (l *Listener) Start(ctx context.Context) error {
|
||||||
wg := l.app.GetGlobalWaitGroup()
|
go func() {
|
||||||
if wg == nil {
|
if err := l.Listen(ctx); err != nil {
|
||||||
return fmt.Errorf("%w: %w (%w)", ErrListener, ErrStart, ErrFailedToGetWaitGroup)
|
l.app.Logger().
|
||||||
|
WithError(err).
|
||||||
|
Error("Listener stopped with error")
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
wg.Go(func() {
|
|
||||||
l.Listen()
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user