Use context everywhere

This commit is contained in:
2026-05-26 18:58:50 +03:00
parent 3ece07e11d
commit 11a357ebc4
7 changed files with 37 additions and 54 deletions

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"os" "os"
"os/signal" "os/signal"
"sync"
"syscall" "syscall"
"source.hodakov.me/hdkv/deconnect/internal/application" "source.hodakov.me/hdkv/deconnect/internal/application"
@@ -32,28 +31,24 @@ func main() {
app.RegisterDomain(domains.DomainNameDialer, dialer.New(app)) app.RegisterDomain(domains.DomainNameDialer, dialer.New(app))
app.RegisterDomain(domains.DomainNameDeconnector, deconnector.New(app)) app.RegisterDomain(domains.DomainNameDeconnector, deconnector.New(app))
err = app.ConnectDependencies() err = app.ConnectDependencies(ctx)
if err != nil { if err != nil {
app.Logger().Fatal(err) app.Logger().Fatal(err)
} }
err = app.StartDomains(ctx)
if err != nil {
app.Logger().Fatal(err)
}
app.Logger().Info("Started deconnect")
// CTRL+C handler. // CTRL+C handler.
interrupt := make(chan os.Signal, 1) interrupt := make(chan os.Signal, 1)
signal.Notify( signal.Notify(
interrupt, syscall.SIGINT, syscall.SIGTERM, interrupt, syscall.SIGINT, syscall.SIGTERM,
) )
var wg sync.WaitGroup
app.RegisterGlobalWaitGroup(&wg)
err = app.StartDomains()
if err != nil {
app.Logger().Fatal(err)
os.Exit(1)
}
app.Logger().Info("Started deconnect")
go func() { go func() {
signalThing := <-interrupt signalThing := <-interrupt
app.Logger().WithField("signal", signalThing.String()). app.Logger().WithField("signal", signalThing.String()).
@@ -62,8 +57,7 @@ func main() {
cancel() cancel()
}() }()
// Wait for all domains to finish their cleanup <-ctx.Done()
wg.Wait()
app.Logger().Info("deconnect shutdown complete") app.Logger().Info("deconnect shutdown complete")
os.Exit(0) os.Exit(0)

View File

@@ -13,10 +13,8 @@ import (
) )
type App struct { type App struct {
ctx context.Context
logger *logrus.Entry logger *logrus.Entry
config *configuration.Config config *configuration.Config
wg *sync.WaitGroup
domains map[string]domains.Domain domains map[string]domains.Domain
domainsMutex sync.RWMutex domainsMutex sync.RWMutex
@@ -26,10 +24,6 @@ func (a *App) Config() *configuration.Config {
return a.config return a.config
} }
func (a *App) Context() context.Context {
return a.ctx
}
func (a *App) Logger() *logrus.Entry { func (a *App) Logger() *logrus.Entry {
return a.logger return a.logger
} }
@@ -54,8 +48,6 @@ func New(ctx context.Context) *App {
"numgc": strconv.FormatUint(uint64(m.NumGC), 10), "numgc": strconv.FormatUint(uint64(m.NumGC), 10),
}) })
app.ctx = ctx
app.domains = make(map[string]domains.Domain) app.domains = make(map[string]domains.Domain)
return app return app
@@ -92,12 +84,12 @@ func (a *App) RetrieveDomain(name string) any {
return a.domains[name] return a.domains[name]
} }
func (a *App) ConnectDependencies() error { func (a *App) ConnectDependencies(ctx context.Context) error {
a.domainsMutex.RLock() a.domainsMutex.RLock()
defer a.domainsMutex.RUnlock() defer a.domainsMutex.RUnlock()
for _, domain := range a.domains { for _, domain := range a.domains {
err := domain.ConnectDependencies() err := domain.ConnectDependencies(ctx)
if err != nil { if err != nil {
return fmt.Errorf("%w: %w (%w)", ErrApplication, ErrConnectDependencies, err) return fmt.Errorf("%w: %w (%w)", ErrApplication, ErrConnectDependencies, err)
} }
@@ -106,12 +98,12 @@ func (a *App) ConnectDependencies() error {
return nil return nil
} }
func (a *App) StartDomains() error { func (a *App) StartDomains(ctx context.Context) error {
a.domainsMutex.RLock() a.domainsMutex.RLock()
defer a.domainsMutex.RUnlock() defer a.domainsMutex.RUnlock()
for _, domain := range a.domains { for _, domain := range a.domains {
err := domain.Start() err := domain.Start(ctx)
if err != nil { if err != nil {
return fmt.Errorf("%w: %w (%w)", ErrApplication, ErrDomainInit, err) return fmt.Errorf("%w: %w (%w)", ErrApplication, ErrDomainInit, err)
} }
@@ -119,11 +111,3 @@ func (a *App) StartDomains() error {
return nil return nil
} }
func (a *App) RegisterGlobalWaitGroup(wg *sync.WaitGroup) {
a.wg = wg
}
func (a *App) GetGlobalWaitGroup() *sync.WaitGroup {
return a.wg
}

View File

@@ -1,6 +1,7 @@
package deconnector package deconnector
import ( import (
"context"
"fmt" "fmt"
"source.hodakov.me/hdkv/deconnect/internal/application" "source.hodakov.me/hdkv/deconnect/internal/application"
@@ -24,7 +25,7 @@ func New(app *application.App) *Deconnector {
} }
} }
func (d *Deconnector) ConnectDependencies() error { func (d *Deconnector) ConnectDependencies(_ context.Context) error {
dialer, ok := d.app.RetrieveDomain(domains.DomainNameDialer).(domains.Dialer) dialer, ok := d.app.RetrieveDomain(domains.DomainNameDialer).(domains.Dialer)
if !ok { if !ok {
return fmt.Errorf( return fmt.Errorf(
@@ -38,6 +39,6 @@ func (d *Deconnector) ConnectDependencies() error {
return nil return nil
} }
func (d *Deconnector) Start() error { func (d *Deconnector) Start(_ context.Context) error {
return nil return nil
} }

View File

@@ -1,6 +1,7 @@
package dialer package dialer
import ( import (
"context"
"fmt" "fmt"
"net/url" "net/url"
@@ -25,7 +26,7 @@ func New(app *application.App) *Dialer {
} }
} }
func (d *Dialer) ConnectDependencies() error { func (d *Dialer) ConnectDependencies(ctx context.Context) error {
dialURL, err := d.UpstreamURL() dialURL, err := d.UpstreamURL()
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf(
@@ -38,6 +39,6 @@ func (d *Dialer) ConnectDependencies() error {
return nil return nil
} }
func (d *Dialer) Start() error { func (d *Dialer) Start(_ context.Context) error {
return nil return nil
} }

View File

@@ -1,6 +1,8 @@
package domains package domains
import "context"
type Domain interface { type Domain interface {
ConnectDependencies() error ConnectDependencies(ctx context.Context) error
Start() error Start(ctx context.Context) error
} }

View File

@@ -1,11 +1,12 @@
package listener package listener
import ( import (
"context"
"fmt" "fmt"
"net" "net"
) )
func (l *Listener) Listen() error { func (l *Listener) Listen(ctx context.Context) error {
ln, err := net.Listen( ln, err := net.Listen(
"tcp", "tcp",
l.app.Config().Deconnect.Host+":"+l.app.Config().Deconnect.Port, l.app.Config().Deconnect.Host+":"+l.app.Config().Deconnect.Port,
@@ -19,7 +20,7 @@ func (l *Listener) Listen() error {
Info("Listening for incoming connections") Info("Listening for incoming connections")
go func() { go func() {
<-l.app.Context().Done() <-ctx.Done()
l.app.Logger().Info("Shutting down listener") l.app.Logger().Info("Shutting down listener")
@@ -31,7 +32,7 @@ func (l *Listener) Listen() error {
for { for {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if err != nil {
if l.app.Context().Err() != nil { if ctx.Err() != nil {
return nil return nil
} }

View File

@@ -1,6 +1,7 @@
package listener package listener
import ( import (
"context"
"fmt" "fmt"
"source.hodakov.me/hdkv/deconnect/internal/application" "source.hodakov.me/hdkv/deconnect/internal/application"
@@ -24,7 +25,7 @@ func New(app *application.App) *Listener {
} }
} }
func (l *Listener) ConnectDependencies() error { func (l *Listener) ConnectDependencies(ctx context.Context) error {
deconnector, ok := l.app.RetrieveDomain(domains.DomainNameDeconnector).(domains.Deconnector) deconnector, ok := l.app.RetrieveDomain(domains.DomainNameDeconnector).(domains.Deconnector)
if !ok { if !ok {
return fmt.Errorf( return fmt.Errorf(
@@ -38,15 +39,14 @@ func (l *Listener) ConnectDependencies() error {
return nil return nil
} }
func (l *Listener) Start() error { func (l *Listener) Start(ctx context.Context) error {
wg := l.app.GetGlobalWaitGroup() go func() {
if wg == nil { if err := l.Listen(ctx); err != nil {
return fmt.Errorf("%w: %w (%w)", ErrListener, ErrStart, ErrFailedToGetWaitGroup) l.app.Logger().
WithError(err).
Error("Listener stopped with error")
} }
}()
wg.Go(func() {
l.Listen()
})
return nil return nil
} }