diff --git a/dag.go b/dag.go index 27a3905..6036d7e 100644 --- a/dag.go +++ b/dag.go @@ -23,15 +23,17 @@ type EdgeService struct { // newDAG creates a new DAG (Directed Acyclic Graph) with initialized dependencies and dependents maps. func newDAG() *DAG { return &DAG{ - dependencies: new(sync.Map), - dependents: new(sync.Map), + mu: sync.RWMutex{}, + dependencies: map[EdgeService]map[EdgeService]struct{}{}, + dependents: map[EdgeService]map[EdgeService]struct{}{}, } } // DAG represents a Directed Acyclic Graph of services, tracking dependencies and dependents. type DAG struct { - dependencies *sync.Map - dependents *sync.Map + mu sync.RWMutex + dependencies map[EdgeService]map[EdgeService]struct{} + dependents map[EdgeService]map[EdgeService]struct{} } // addDependency adds a dependency relationship from one service to another in the DAG. @@ -39,43 +41,64 @@ func (d *DAG) addDependency(fromScopeID, fromScopeName, fromServiceName, toScope from := newEdgeService(fromScopeID, fromScopeName, fromServiceName) to := newEdgeService(toScopeID, toScopeName, toServiceName) - d.addToMap(d.dependencies, from, to) - d.addToMap(d.dependents, to, from) + d.mu.Lock() + defer d.mu.Unlock() + + // from -> to + if _, ok := d.dependencies[from]; !ok { + d.dependencies[from] = map[EdgeService]struct{}{} + } + d.dependencies[from][to] = struct{}{} + + // from <- to + if _, ok := d.dependents[to]; !ok { + d.dependents[to] = map[EdgeService]struct{}{} + } + d.dependents[to][from] = struct{}{} } -// addToMap is a helper function to add a key-value pair to a sync.Map, creating a new sync.Map for the value if necessary. -func (d *DAG) addToMap(dependencyMap *sync.Map, key, value interface{}) { - valueMap := new(sync.Map) - valueMap.Store(value, struct{}{}) +// removeService removes a dependency relationship between services in the DAG. +func (d *DAG) removeService(scopeID, scopeName, serviceName string) { + edge := newEdgeService(scopeID, scopeName, serviceName) + + d.mu.Lock() + defer d.mu.Unlock() - if actual, loaded := dependencyMap.LoadOrStore(key, valueMap); loaded { - actual.(*sync.Map).Store(value, struct{}{}) + dependencies, dependents := d.explainServiceImplem(edge) + + for _, dependency := range dependencies { + delete(d.dependents[dependency], edge) } + + // should be empty, because we remove dependencies in the inverse invocation order + for _, dependent := range dependents { + delete(d.dependencies[dependent], edge) + } + + delete(d.dependencies, edge) + delete(d.dependents, edge) } // explainService provides information about a service's dependencies and dependents in the DAG. func (d *DAG) explainService(scopeID, scopeName, serviceName string) (dependencies, dependents []EdgeService) { edge := newEdgeService(scopeID, scopeName, serviceName) - dependencies = d.getServicesFromMap(d.dependencies, edge) - dependents = d.getServicesFromMap(d.dependents, edge) + d.mu.RLock() + defer d.mu.RUnlock() - return dependencies, dependents + return d.explainServiceImplem(edge) } -// getServicesFromMap is a helper function to retrieve services related to a specific key from a sync.Map. -func (d *DAG) getServicesFromMap(serviceMap *sync.Map, edge EdgeService) []EdgeService { - var services []EdgeService - - if kv, ok := serviceMap.Load(edge); ok { - kv.(*sync.Map).Range(func(key, value interface{}) bool { - edgeService, ok := key.(EdgeService) - if ok { - services = append(services, edgeService) - } - return ok - }) +func (d *DAG) explainServiceImplem(edge EdgeService) (dependencies, dependents []EdgeService) { + dependencies, dependents = []EdgeService{}, []EdgeService{} + + if kv, ok := d.dependencies[edge]; ok { + dependencies = keys(kv) } - return services + if kv, ok := d.dependents[edge]; ok { + dependents = keys(kv) + } + + return dependencies, dependents } diff --git a/dag_test.go b/dag_test.go index 8d5d966..a5c8271 100644 --- a/dag_test.go +++ b/dag_test.go @@ -1,7 +1,6 @@ package do import ( - "sync" "testing" "github.com/stretchr/testify/assert" @@ -24,11 +23,11 @@ func TestNewDAG(t *testing.T) { is := assert.New(t) dag := newDAG() - expectedDependencies := unSyncMap(new(sync.Map)) - expectedDependents := unSyncMap(new(sync.Map)) + expectedDependencies := map[EdgeService]map[EdgeService]struct{}{} + expectedDependents := map[EdgeService]map[EdgeService]struct{}{} - is.Equal(expectedDependencies, unSyncMap(dag.dependencies)) - is.Equal(expectedDependents, unSyncMap(dag.dependents)) + is.Equal(expectedDependencies, dag.dependencies) + is.Equal(expectedDependents, dag.dependents) } // TestDAG_addDependency checks the addition of dependencies to the DAG. @@ -44,19 +43,42 @@ func TestDAG_addDependency(t *testing.T) { dag.addDependency("scope1", "scope1", "service1", "scope2", "scope2", "service2") - expectedDependencies := map[interface{}]interface{}{edge1: map[interface{}]interface{}{edge2: struct{}{}}} - expectedDependents := map[interface{}]interface{}{edge2: map[interface{}]interface{}{edge1: struct{}{}}} + expectedDependencies := map[EdgeService]map[EdgeService]struct{}{edge1: {edge2: {}}} + expectedDependents := map[EdgeService]map[EdgeService]struct{}{edge2: {edge1: {}}} - is.Equal(expectedDependencies, unSyncMap(dag.dependencies)) - is.Equal(expectedDependents, unSyncMap(dag.dependents)) + is.Equal(expectedDependencies, dag.dependencies) + is.Equal(expectedDependents, dag.dependents) dag.addDependency("scope3", "scope3", "service3", "scope2", "scope2", "service2") - expectedDependencies[edge3] = map[interface{}]interface{}{edge2: struct{}{}} - expectedDependents[edge2] = map[interface{}]interface{}{edge1: struct{}{}, edge3: struct{}{}} + expectedDependencies = map[EdgeService]map[EdgeService]struct{}{edge1: {edge2: {}}, edge3: {edge2: {}}} + expectedDependents = map[EdgeService]map[EdgeService]struct{}{edge2: {edge1: {}, edge3: {}}} - is.Equal(expectedDependencies, unSyncMap(dag.dependencies)) - is.Equal(expectedDependents, unSyncMap(dag.dependents)) + is.Equal(expectedDependencies, dag.dependencies) + is.Equal(expectedDependents, dag.dependents) +} + +// TestDAG_removeService checks the removal of dependencies to the DAG. +func TestDAG_removeService(t *testing.T) { + t.Parallel() + is := assert.New(t) + + edge1 := newEdgeService("scope1", "scope1", "service1") + // edge2 := newEdgeService("scope2", "scope2", "service2") + edge3 := newEdgeService("scope3", "scope3", "service3") + + dag := newDAG() + + dag.addDependency("scope1", "scope1", "service1", "scope2", "scope2", "service2") + dag.addDependency("scope3", "scope3", "service3", "scope2", "scope2", "service2") + + dag.removeService("scope2", "scope2", "service2") + + expectedDependencies := map[EdgeService]map[EdgeService]struct{}{edge1: {}, edge3: {}} + expectedDependents := map[EdgeService]map[EdgeService]struct{}{} + + is.Equal(expectedDependencies, dag.dependencies) + is.Equal(expectedDependents, dag.dependents) } // TestDAG_explainService checks the explanation of dependencies for a service in the DAG. @@ -92,19 +114,3 @@ func TestDAG_explainService(t *testing.T) { is.ElementsMatch([]EdgeService{}, a) is.ElementsMatch([]EdgeService{}, b) } - -func unSyncMap(syncMap *sync.Map) map[interface{}]interface{} { - result := make(map[interface{}]interface{}) - - syncMap.Range(func(key, value interface{}) bool { - if vSyncMap, ok := value.(*sync.Map); ok { - result[key] = unSyncMap(vSyncMap) - } else { - result[key] = value - } - - return true - }) - - return result -} diff --git a/docs/docs/service-lifecycle/shutdowner.md b/docs/docs/service-lifecycle/shutdowner.md index fe20338..cf56f7b 100644 --- a/docs/docs/service-lifecycle/shutdowner.md +++ b/docs/docs/service-lifecycle/shutdowner.md @@ -18,12 +18,12 @@ A shutdown can be triggered on a root scope: ```go // on demand -injector.Shutdown() map[string]error -injector.ShutdownWithContext(context.Context) map[string]error +injector.Shutdown() error +injector.ShutdownWithContext(context.Context) error // on signal -injector.ShutdownOnSignals(...os.Signal) (os.Signal, map[string]error) -injector.ShutdownOnSignalsWithContext(context.Context, ...os.Signal) (os.Signal, map[string]error) +injector.ShutdownOnSignals(...os.Signal) (os.Signal, error) +injector.ShutdownOnSignalsWithContext(context.Context, ...os.Signal) (os.Signal, error) ``` ...on a single service: @@ -90,9 +90,7 @@ Invoke(i, ...) ctx := context.WithTimeout(10 * time.Second) errors := i.ShutdownWithContext(ctx) -for _, err := range errors { - if err != nil { - log.Println("shutdown error:", err) - } +if err != nil { + log.Println("shutdown error:", err) } ``` diff --git a/errors.go b/errors.go index d4a08a2..828ff7c 100644 --- a/errors.go +++ b/errors.go @@ -1,7 +1,65 @@ package do -import "errors" +import ( + "errors" + "fmt" + "strings" +) var ErrServiceNotFound = errors.New("DI: could not find service") var ErrCircularDependency = errors.New("DI: circular dependency detected") var ErrHealthCheckTimeout = errors.New("DI: health check timeout") + +func newShutdownErrors() *ShutdownErrors { + return &ShutdownErrors{} +} + +type ShutdownErrors map[EdgeService]error + +func (e *ShutdownErrors) Add(scopeID string, scopeName string, serviceName string, err error) { + if err != nil { + (*e)[newEdgeService(scopeID, scopeName, serviceName)] = err + } +} + +func (e ShutdownErrors) Len() int { + out := 0 + for _, v := range e { + if v != nil { + out++ + } + } + return out +} + +func (e ShutdownErrors) Error() string { + lines := []string{} + for k, v := range e { + if v != nil { + lines = append(lines, fmt.Sprintf(" - %s > %s: %s", k.ScopeName, k.Service, v.Error())) + } + } + + if len(lines) == 0 { + return "DI: no shutdown errors" + } + + return "DI: shutdown errors:\n" + strings.Join(lines, "\n") +} + +func mergeShutdownErrors(ins ...*ShutdownErrors) *ShutdownErrors { + out := newShutdownErrors() + + for _, in := range ins { + if in != nil { + se := &ShutdownErrors{} + if ok := errors.As(in, &se); ok { + for k, v := range *se { + (*out)[k] = v + } + } + } + } + + return out +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..88ead04 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,67 @@ +package do + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestShutdownErrors_Add(t *testing.T) { + is := assert.New(t) + + se := newShutdownErrors() + is.Equal(0, len(*se)) + is.Equal(0, se.Len()) + + se.Add("scope-1", "scope-a", "service-a", nil) + is.Equal(0, len(*se)) + is.Equal(0, se.Len()) + is.EqualValues(&ShutdownErrors{}, se) + + se.Add("scope-2", "scope-b", "service-b", assert.AnError) + is.Equal(1, len(*se)) + is.Equal(1, se.Len()) + is.EqualValues(&ShutdownErrors{ + {ScopeID: "scope-2", ScopeName: "scope-b", Service: "service-b"}: assert.AnError, + }, se) +} + +func TestShutdownErrors_Error(t *testing.T) { + is := assert.New(t) + + se := newShutdownErrors() + is.Equal(0, len(*se)) + is.Equal(0, se.Len()) + is.EqualValues("DI: no shutdown errors", se.Error()) + + se.Add("scope-1", "scope-a", "service-a", nil) + is.Equal(0, len(*se)) + is.Equal(0, se.Len()) + is.EqualValues("DI: no shutdown errors", se.Error()) + + se.Add("scope-2", "scope-b", "service-b", assert.AnError) + is.Equal(1, len(*se)) + is.Equal(1, se.Len()) + is.EqualValues("DI: shutdown errors:\n - scope-b > service-b: assert.AnError general error for testing", se.Error()) +} + +func TestMergeShutdownErrors(t *testing.T) { + is := assert.New(t) + + se1 := newShutdownErrors() + se2 := newShutdownErrors() + se3 := newShutdownErrors() + + se1.Add("scope-1", "scope-a", "service-a", assert.AnError) + se2.Add("scope-2", "scope-b", "service-b", assert.AnError) + + result := mergeShutdownErrors(se1, se2, se3, nil) + is.Equal(2, result.Len()) + is.EqualValues( + &ShutdownErrors{ + {ScopeID: "scope-1", ScopeName: "scope-a", Service: "service-a"}: assert.AnError, + {ScopeID: "scope-2", ScopeName: "scope-b", Service: "service-b"}: assert.AnError, + }, + result, + ) +} diff --git a/examples/http/std/main.go b/examples/http/std/main.go index 97a9c72..72bf9c0 100644 --- a/examples/http/std/main.go +++ b/examples/http/std/main.go @@ -3,7 +3,7 @@ package main import ( "net/http" - "github.com/samber/do/http/std" + "github.com/samber/do/http/std/v2" ) func main() { diff --git a/examples/shutdownable/example.go b/examples/shutdownable/example.go index a4bdc40..8224f97 100644 --- a/examples/shutdownable/example.go +++ b/examples/shutdownable/example.go @@ -2,7 +2,7 @@ package main import ( "context" - "log" + "fmt" "github.com/samber/do/v2" ) @@ -45,7 +45,7 @@ type Car struct { func (c *Car) Shutdown() error { println("car stopped") - return nil + return fmt.Errorf("💥 BOOOOM!") } func (c *Car) Start() { @@ -95,9 +95,7 @@ func main() { car.Start() _, err := injector.ShutdownOnSignals() - for _, e := range err { - if e != nil { - log.Fatal(e.Error()) - } + if err != nil { + fmt.Println(err.Error()) } } diff --git a/injector.go b/injector.go index 2ea2f9a..871c6cb 100644 --- a/injector.go +++ b/injector.go @@ -17,8 +17,8 @@ type Injector interface { ListInvokedServices() []EdgeService HealthCheck() map[string]error HealthCheckWithContext(context.Context) map[string]error - Shutdown() map[string]error - ShutdownWithContext(context.Context) map[string]error + Shutdown() *ShutdownErrors + ShutdownWithContext(context.Context) *ShutdownErrors clone(*RootScope, *Scope) *Scope // service lifecycle diff --git a/root_scope.go b/root_scope.go index 1a62a3d..3f4f55e 100644 --- a/root_scope.go +++ b/root_scope.go @@ -66,8 +66,8 @@ func (s *RootScope) HealthCheck() map[string]error { return s.self.Heal func (s *RootScope) HealthCheckWithContext(ctx context.Context) map[string]error { return s.self.HealthCheckWithContext(ctx) } -func (s *RootScope) Shutdown() map[string]error { return s.ShutdownWithContext(context.Background()) } -func (s *RootScope) ShutdownWithContext(ctx context.Context) map[string]error { +func (s *RootScope) Shutdown() *ShutdownErrors { return s.ShutdownWithContext(context.Background()) } +func (s *RootScope) ShutdownWithContext(ctx context.Context) *ShutdownErrors { defer func() { if s.healthCheckPool != nil { s.healthCheckPool.stop() @@ -142,14 +142,14 @@ func (s *RootScope) CloneWithOpts(opts *InjectorOpts) *RootScope { // ShutdownOnSignals listens for signals defined in signals parameter in order to graceful stop service. // It will block until receiving any of these signal. // If no signal is provided in signals parameter, syscall.SIGTERM and os.Interrupt will be added as default signal. -func (s *RootScope) ShutdownOnSignals(signals ...os.Signal) (os.Signal, map[string]error) { +func (s *RootScope) ShutdownOnSignals(signals ...os.Signal) (os.Signal, *ShutdownErrors) { return s.ShutdownOnSignalsWithContext(context.Background(), signals...) } // ShutdownOnSignalsWithContext listens for signals defined in signals parameter in order to graceful stop service. // It will block until receiving any of these signal. // If no signal is provided in signals parameter, syscall.SIGTERM and os.Interrupt will be added as default signal. -func (s *RootScope) ShutdownOnSignalsWithContext(ctx context.Context, signals ...os.Signal) (os.Signal, map[string]error) { +func (s *RootScope) ShutdownOnSignalsWithContext(ctx context.Context, signals ...os.Signal) (os.Signal, *ShutdownErrors) { // Make sure there is at least syscall.SIGTERM and os.Interrupt as a signal if len(signals) < 1 { signals = append(signals, syscall.SIGTERM, os.Interrupt) diff --git a/scope.go b/scope.go index c947aef..916f62e 100644 --- a/scope.go +++ b/scope.go @@ -212,42 +212,94 @@ func (s *Scope) asyncHealthCheckWithContext(ctx context.Context) map[string]<-ch } // Shutdown shutdowns the scope and all its children. -func (s *Scope) Shutdown() map[string]error { +func (s *Scope) Shutdown() *ShutdownErrors { return s.ShutdownWithContext(context.Background()) } // ShutdownWithContext shutdowns the scope and all its children. -func (s *Scope) ShutdownWithContext(ctx context.Context) map[string]error { +func (s *Scope) ShutdownWithContext(ctx context.Context) *ShutdownErrors { + s.logf("requested shutdown") + err1 := s.shutdownChildrenInParallel(ctx) + err2 := s.shutdownServicesInParallel(ctx) + s.logf("shutdowned services") + + err := mergeShutdownErrors(err1, err2) + if err.Len() > 0 { + return err + } + + return nil +} + +// shutdownChildrenInParallel runs a parallel shutdown of children scopes. +func (s *Scope) shutdownChildrenInParallel(ctx context.Context) *ShutdownErrors { s.mu.RLock() children := s.childScopes - invocations := invertMap(s.orderedInvocation) s.mu.RUnlock() - s.logf("requested shutdown") + errors := make([]*ShutdownErrors, len(children)) + + var wg sync.WaitGroup + for index, scope := range values(children) { + wg.Add(1) + + go func(s *Scope, i int) { + errors[i] = s.ShutdownWithContext(ctx) + wg.Done() + }(scope, index) + } + wg.Wait() - err := map[string]error{} + s.mu.Lock() + defer s.mu.Unlock() + + s.childScopes = make(map[string]*Scope) // scopes are removed from DI container + return mergeShutdownErrors(errors...) +} - // first shutdown children - for k, child := range children { - err = mergeMaps(err, child.Shutdown()) +// shutdownServicesInParallel runs a parallel shutdown of scope services. +// +// We look for services having no dependents. Then we shutdown them. +// And repeat, until every scope services have been shutdown. +func (s *Scope) shutdownServicesInParallel(ctx context.Context) *ShutdownErrors { + err := newShutdownErrors() + addError := func(name string, e error) { s.mu.Lock() - delete(s.childScopes, k) // scope is removed from DI container + err.Add(s.id, s.name, name, e) s.mu.Unlock() } - // then shutdown scope services - for index := s.orderedInvocationIndex; index >= 0; index-- { - name, ok := invocations[index] - if !ok { - continue + listServices := func() []string { + s.mu.RLock() + defer s.mu.RUnlock() + return keys(s.services) + } + + var wg sync.WaitGroup + + for len(listServices()) > 0 { + // loop over the service that have not been shutdown already + for _, name := range listServices() { + // Check the service has no dependents (dependencies allowed here). + // Services having dependents must be shutdown first. + // The next iteration will shutdown current service. + _, dependents := s.rootScope.dag.explainService(s.id, s.name, name) + if len(dependents) > 0 { + continue + } + + wg.Add(1) + go func(n string) { + e := s.serviceShutdown(ctx, n) + addError(n, e) + wg.Done() + }(name) } - err[name] = s.serviceShutdown(ctx, name) + wg.Wait() } - s.logf("shutdowned services") - return err } @@ -393,23 +445,21 @@ func (s *Scope) serviceHealthCheck(ctx context.Context, name string) error { func (s *Scope) serviceShutdown(ctx context.Context, name string) error { s.mu.RLock() - serviceAny, ok := s.services[name] + s.mu.RUnlock() + if !ok { - s.mu.RUnlock() return serviceNotFound(s, []string{name}) } - s.mu.RUnlock() + var err error service, ok := serviceAny.(serviceShutdown) if ok { s.logf("requested shutdown for service %s", name) - err := service.shutdown(ctx) - if err != nil { - return err - } + err = service.shutdown(ctx) + s.onServiceShutdown(name) } else { panic(fmt.Errorf("DI: service `%s` is not shutdowner", name)) } @@ -417,11 +467,10 @@ func (s *Scope) serviceShutdown(ctx context.Context, name string) error { s.mu.Lock() delete(s.services, name) // service is removed from DI container delete(s.orderedInvocation, name) + s.RootScope().dag.removeService(s.id, s.name, name) s.mu.Unlock() - s.onServiceShutdown(name) - - return nil + return err } /********************************** diff --git a/scope_test.go b/scope_test.go index e645ea8..b960f62 100644 --- a/scope_test.go +++ b/scope_test.go @@ -420,7 +420,7 @@ func TestScope_Shutdown(t *testing.T) { _, _ = InvokeNamed[*lazyTestShutdownerOK](i, "lazy-ok") _, _ = InvokeNamed[*lazyTestShutdownerKO](i, "lazy-ko") - is.EqualValues(map[string]error{"lazy-ok": nil, "lazy-ko": assert.AnError}, i.Shutdown()) + is.EqualValues(&ShutdownErrors{EdgeService{ScopeID: i.self.id, ScopeName: i.self.name, Service: "lazy-ko"}: assert.AnError}, i.Shutdown()) } // @TODO: missing tests for context @@ -454,7 +454,7 @@ func TestScope_ShutdownWithContext(t *testing.T) { _, _ = invokeByName[*lazyTestShutdownerKO](child2a, "child2a-b") _, _ = invokeByName[*lazyTestShutdownerKO](child2b, "child2b-a") - // from rootScope POV + // // from rootScope POV is.Equal(assert.AnError, rootScope.serviceShutdown(ctx, "root-a")) is.ErrorContains(rootScope.serviceShutdown(ctx, "child1-a"), "could not find service") is.ErrorContains(rootScope.serviceShutdown(ctx, "child2a-a"), "could not find service") @@ -540,7 +540,7 @@ func TestScope_serviceHealthCheck(t *testing.T) { _, _ = invokeByName[int](child3, "child3-a") is.ElementsMatch([]EdgeService{newEdgeService(child3.id, child3.name, "child3-a"), newEdgeService(child2a.id, child2a.name, "child2a-a"), newEdgeService(child2a.id, child2a.name, "child2a-b"), newEdgeService(child1.id, child1.name, "child1-a")}, child3.ListInvokedServices()) - is.EqualValues(map[string]error{"child1-a": nil, "child2a-a": nil, "child2a-b": nil, "child2b-a": nil, "child3-a": nil}, child1.Shutdown()) + is.Nil(child1.Shutdown()) is.ElementsMatch([]EdgeService{}, child3.ListInvokedServices()) } diff --git a/virtual_scope.go b/virtual_scope.go index d260934..6e41f6e 100644 --- a/virtual_scope.go +++ b/virtual_scope.go @@ -30,8 +30,8 @@ func (s *virtualScope) HealthCheck() map[string]error { return s.self.H func (s *virtualScope) HealthCheckWithContext(ctx context.Context) map[string]error { return s.self.HealthCheckWithContext(ctx) } -func (s *virtualScope) Shutdown() map[string]error { return s.self.Shutdown() } -func (s *virtualScope) ShutdownWithContext(ctx context.Context) map[string]error { +func (s *virtualScope) Shutdown() *ShutdownErrors { return s.self.Shutdown() } +func (s *virtualScope) ShutdownWithContext(ctx context.Context) *ShutdownErrors { return s.self.ShutdownWithContext(ctx) } func (s *virtualScope) clone(r *RootScope, p *Scope) *Scope { return s.self.clone(r, p) }