diff --git a/cmd/deconnect/deconnect.go b/cmd/deconnect/deconnect.go index fe2acf5..d31ded1 100644 --- a/cmd/deconnect/deconnect.go +++ b/cmd/deconnect/deconnect.go @@ -4,7 +4,6 @@ import ( "context" "os" "os/signal" - "sync" "syscall" "source.hodakov.me/hdkv/deconnect/internal/application" @@ -32,28 +31,24 @@ func main() { app.RegisterDomain(domains.DomainNameDialer, dialer.New(app)) app.RegisterDomain(domains.DomainNameDeconnector, deconnector.New(app)) - err = app.ConnectDependencies() + err = app.ConnectDependencies(ctx) if err != nil { app.Logger().Fatal(err) } + err = app.StartDomains(ctx) + if err != nil { + app.Logger().Fatal(err) + } + + app.Logger().Info("Started deconnect") + // CTRL+C handler. interrupt := make(chan os.Signal, 1) signal.Notify( interrupt, syscall.SIGINT, syscall.SIGTERM, ) - var wg sync.WaitGroup - app.RegisterGlobalWaitGroup(&wg) - - err = app.StartDomains() - if err != nil { - app.Logger().Fatal(err) - os.Exit(1) - } - - app.Logger().Info("Started deconnect") - go func() { signalThing := <-interrupt app.Logger().WithField("signal", signalThing.String()). @@ -62,8 +57,7 @@ func main() { cancel() }() - // Wait for all domains to finish their cleanup - wg.Wait() + <-ctx.Done() app.Logger().Info("deconnect shutdown complete") os.Exit(0) diff --git a/internal/application/application.go b/internal/application/application.go index 503ca87..4ad3e4e 100644 --- a/internal/application/application.go +++ b/internal/application/application.go @@ -13,10 +13,8 @@ import ( ) type App struct { - ctx context.Context logger *logrus.Entry config *configuration.Config - wg *sync.WaitGroup domains map[string]domains.Domain domainsMutex sync.RWMutex @@ -26,10 +24,6 @@ func (a *App) Config() *configuration.Config { return a.config } -func (a *App) Context() context.Context { - return a.ctx -} - func (a *App) Logger() *logrus.Entry { return a.logger } @@ -54,8 +48,6 @@ func New(ctx context.Context) *App { "numgc": strconv.FormatUint(uint64(m.NumGC), 10), }) - app.ctx = ctx - app.domains = make(map[string]domains.Domain) return app @@ -92,12 +84,12 @@ func (a *App) RetrieveDomain(name string) any { return a.domains[name] } -func (a *App) ConnectDependencies() error { +func (a *App) ConnectDependencies(ctx context.Context) error { a.domainsMutex.RLock() defer a.domainsMutex.RUnlock() for _, domain := range a.domains { - err := domain.ConnectDependencies() + err := domain.ConnectDependencies(ctx) if err != nil { return fmt.Errorf("%w: %w (%w)", ErrApplication, ErrConnectDependencies, err) } @@ -106,12 +98,12 @@ func (a *App) ConnectDependencies() error { return nil } -func (a *App) StartDomains() error { +func (a *App) StartDomains(ctx context.Context) error { a.domainsMutex.RLock() defer a.domainsMutex.RUnlock() for _, domain := range a.domains { - err := domain.Start() + err := domain.Start(ctx) if err != nil { return fmt.Errorf("%w: %w (%w)", ErrApplication, ErrDomainInit, err) } @@ -119,11 +111,3 @@ func (a *App) StartDomains() error { return nil } - -func (a *App) RegisterGlobalWaitGroup(wg *sync.WaitGroup) { - a.wg = wg -} - -func (a *App) GetGlobalWaitGroup() *sync.WaitGroup { - return a.wg -} diff --git a/internal/domains/deconnector/deconnector.go b/internal/domains/deconnector/deconnector.go index 7effc42..59c1c0d 100644 --- a/internal/domains/deconnector/deconnector.go +++ b/internal/domains/deconnector/deconnector.go @@ -1,6 +1,7 @@ package deconnector import ( + "context" "fmt" "source.hodakov.me/hdkv/deconnect/internal/application" @@ -24,7 +25,7 @@ func New(app *application.App) *Deconnector { } } -func (d *Deconnector) ConnectDependencies() error { +func (d *Deconnector) ConnectDependencies(_ context.Context) error { dialer, ok := d.app.RetrieveDomain(domains.DomainNameDialer).(domains.Dialer) if !ok { return fmt.Errorf( @@ -38,6 +39,6 @@ func (d *Deconnector) ConnectDependencies() error { return nil } -func (d *Deconnector) Start() error { +func (d *Deconnector) Start(_ context.Context) error { return nil } diff --git a/internal/domains/dialer/dialer.go b/internal/domains/dialer/dialer.go index b1d9d71..792e1f1 100644 --- a/internal/domains/dialer/dialer.go +++ b/internal/domains/dialer/dialer.go @@ -1,6 +1,7 @@ package dialer import ( + "context" "fmt" "net/url" @@ -25,7 +26,7 @@ func New(app *application.App) *Dialer { } } -func (d *Dialer) ConnectDependencies() error { +func (d *Dialer) ConnectDependencies(ctx context.Context) error { dialURL, err := d.UpstreamURL() if err != nil { return fmt.Errorf( @@ -38,6 +39,6 @@ func (d *Dialer) ConnectDependencies() error { return nil } -func (d *Dialer) Start() error { +func (d *Dialer) Start(_ context.Context) error { return nil } diff --git a/internal/domains/domain.go b/internal/domains/domain.go index c75fbf1..30bf9fd 100644 --- a/internal/domains/domain.go +++ b/internal/domains/domain.go @@ -1,6 +1,8 @@ package domains +import "context" + type Domain interface { - ConnectDependencies() error - Start() error + ConnectDependencies(ctx context.Context) error + Start(ctx context.Context) error } diff --git a/internal/domains/listener/listen.go b/internal/domains/listener/listen.go index f51e939..5be15a8 100644 --- a/internal/domains/listener/listen.go +++ b/internal/domains/listener/listen.go @@ -1,11 +1,12 @@ package listener import ( + "context" "fmt" "net" ) -func (l *Listener) Listen() error { +func (l *Listener) Listen(ctx context.Context) error { ln, err := net.Listen( "tcp", l.app.Config().Deconnect.Host+":"+l.app.Config().Deconnect.Port, @@ -19,7 +20,7 @@ func (l *Listener) Listen() error { Info("Listening for incoming connections") go func() { - <-l.app.Context().Done() + <-ctx.Done() l.app.Logger().Info("Shutting down listener") @@ -31,7 +32,7 @@ func (l *Listener) Listen() error { for { conn, err := ln.Accept() if err != nil { - if l.app.Context().Err() != nil { + if ctx.Err() != nil { return nil } diff --git a/internal/domains/listener/listener.go b/internal/domains/listener/listener.go index 92090ea..2481d4f 100644 --- a/internal/domains/listener/listener.go +++ b/internal/domains/listener/listener.go @@ -1,6 +1,7 @@ package listener import ( + "context" "fmt" "source.hodakov.me/hdkv/deconnect/internal/application" @@ -24,7 +25,7 @@ func New(app *application.App) *Listener { } } -func (l *Listener) ConnectDependencies() error { +func (l *Listener) ConnectDependencies(ctx context.Context) error { deconnector, ok := l.app.RetrieveDomain(domains.DomainNameDeconnector).(domains.Deconnector) if !ok { return fmt.Errorf( @@ -38,15 +39,14 @@ func (l *Listener) ConnectDependencies() error { return nil } -func (l *Listener) Start() error { - wg := l.app.GetGlobalWaitGroup() - if wg == nil { - return fmt.Errorf("%w: %w (%w)", ErrListener, ErrStart, ErrFailedToGetWaitGroup) - } - - wg.Go(func() { - l.Listen() - }) +func (l *Listener) Start(ctx context.Context) error { + go func() { + if err := l.Listen(ctx); err != nil { + l.app.Logger(). + WithError(err). + Error("Listener stopped with error") + } + }() return nil }