From ccea8a4ce53f7b0d1105c1c1e4875e443635d96a Mon Sep 17 00:00:00 2001 From: Irena Rindos Date: Fri, 13 Sep 2024 15:59:49 -0400 Subject: [PATCH] Client cache API paging (#5107) * internal/clientcache: add -force-reset-schema flag * clientcache: stream list pages directly to DB --------- Co-authored-by: Johan Brandhorst-Satzkorn --- api/accounts/account.gen.go | 14 +- api/accounts/option.gen.go | 8 + api/aliases/alias.gen.go | 14 +- api/aliases/option.gen.go | 8 + api/authmethods/authmethods.gen.go | 14 +- api/authmethods/option.gen.go | 8 + api/authtokens/authtokens.gen.go | 14 +- api/authtokens/option.gen.go | 8 + api/billing/option.gen.go | 8 + .../credential_library.gen.go | 14 +- api/credentiallibraries/option.gen.go | 8 + api/credentials/credential.gen.go | 14 +- api/credentials/option.gen.go | 8 + api/credentialstores/credential_store.gen.go | 14 +- api/credentialstores/option.gen.go | 8 + api/groups/group.gen.go | 14 +- api/groups/option.gen.go | 8 + api/hostcatalogs/host_catalog.gen.go | 14 +- api/hostcatalogs/option.gen.go | 8 + api/hosts/host.gen.go | 14 +- api/hosts/option.gen.go | 8 + api/hostsets/host_set.gen.go | 14 +- api/hostsets/option.gen.go | 8 + api/managedgroups/managedgroups.gen.go | 14 +- api/managedgroups/option.gen.go | 8 + api/policies/option.gen.go | 8 + api/policies/policy.gen.go | 14 +- api/roles/option.gen.go | 8 + api/roles/role.gen.go | 14 +- api/scopes/option.gen.go | 8 + api/scopes/scope.gen.go | 14 +- api/sessionrecordings/option.gen.go | 8 + .../session_recording.gen.go | 14 +- api/sessions/option.gen.go | 8 + api/sessions/session.gen.go | 14 +- api/storagebuckets/option.gen.go | 8 + api/storagebuckets/storage_bucket.gen.go | 14 +- api/targets/option.gen.go | 8 + api/targets/target.gen.go | 14 +- api/users/custom.go | 161 ++++----------- api/users/option.gen.go | 8 + api/users/user.gen.go | 14 +- api/workers/option.gen.go | 8 + api/workers/worker.gen.go | 7 +- internal/api/genapi/templates.go | 26 ++- .../internal/cache/options_test.go | 8 +- .../internal/cache/refresh_test.go | 184 ++++++++++-------- .../cache/repository_implicit_scopes_test.go | 2 +- .../cache/repository_resolvable_aliases.go | 116 ++++++----- .../repository_resolvable_aliases_test.go | 30 +-- .../internal/cache/repository_sessions.go | 124 +++++++----- .../cache/repository_sessions_test.go | 30 +-- .../clientcache/internal/cache/status_test.go | 8 +- .../clientcache/internal/daemon/testing.go | 26 ++- 54 files changed, 787 insertions(+), 391 deletions(-) diff --git a/api/accounts/account.gen.go b/api/accounts/account.gen.go index c55e8b1cb7..c80533ada1 100644 --- a/api/accounts/account.gen.go +++ b/api/accounts/account.gen.go @@ -321,7 +321,12 @@ func (c *Client) List(ctx context.Context, authMethodId string, opt ...Option) ( opts, apiOpts := getOpts(opt...) opts.queryMap["auth_method_id"] = authMethodId - req, err := c.client.NewRequest(ctx, "GET", "accounts", nil, apiOpts...) + requestPath := "accounts" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -470,7 +475,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *AccountListResul opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "accounts", nil, apiOpts...) + requestPath := "accounts" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/accounts/option.gen.go b/api/accounts/option.gen.go index 384121c547..304f7212c1 100644 --- a/api/accounts/option.gen.go +++ b/api/accounts/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -111,6 +112,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithAttributes(inAttributes map[string]interface{}) Option { return func(o *options) { o.postMap["attributes"] = inAttributes diff --git a/api/aliases/alias.gen.go b/api/aliases/alias.gen.go index 54b3803cfe..233e4a2a6b 100644 --- a/api/aliases/alias.gen.go +++ b/api/aliases/alias.gen.go @@ -327,7 +327,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Alia opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "aliases", nil, apiOpts...) + requestPath := "aliases" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -483,7 +488,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *AliasListResult, opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "aliases", nil, apiOpts...) + requestPath := "aliases" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/aliases/option.gen.go b/api/aliases/option.gen.go index 9ec819435b..aa7d59a4ad 100644 --- a/api/aliases/option.gen.go +++ b/api/aliases/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/authmethods/authmethods.gen.go b/api/authmethods/authmethods.gen.go index afd51e5d61..7e48a76564 100644 --- a/api/authmethods/authmethods.gen.go +++ b/api/authmethods/authmethods.gen.go @@ -327,7 +327,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Auth opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "auth-methods", nil, apiOpts...) + requestPath := "auth-methods" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -483,7 +488,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *AuthMethodListRe opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "auth-methods", nil, apiOpts...) + requestPath := "auth-methods" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/authmethods/option.gen.go b/api/authmethods/option.gen.go index 079c35a855..4ad2f57919 100644 --- a/api/authmethods/option.gen.go +++ b/api/authmethods/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/authtokens/authtokens.gen.go b/api/authtokens/authtokens.gen.go index ced3e321d6..90fbbd69a0 100644 --- a/api/authtokens/authtokens.gen.go +++ b/api/authtokens/authtokens.gen.go @@ -212,7 +212,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Auth opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "auth-tokens", nil, apiOpts...) + requestPath := "auth-tokens" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -368,7 +373,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *AuthTokenListRes opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "auth-tokens", nil, apiOpts...) + requestPath := "auth-tokens" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/authtokens/option.gen.go b/api/authtokens/option.gen.go index 9fa43031a5..83b09a7dba 100644 --- a/api/authtokens/option.gen.go +++ b/api/authtokens/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -105,6 +106,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/billing/option.gen.go b/api/billing/option.gen.go index 76babe8e89..04692ceb84 100644 --- a/api/billing/option.gen.go +++ b/api/billing/option.gen.go @@ -30,6 +30,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -112,6 +113,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithEndTime(inEndTime string) Option { return func(o *options) { o.queryMap["end_time"] = fmt.Sprintf("%v", inEndTime) diff --git a/api/credentiallibraries/credential_library.gen.go b/api/credentiallibraries/credential_library.gen.go index 7206eb892f..7a0d2998fd 100644 --- a/api/credentiallibraries/credential_library.gen.go +++ b/api/credentiallibraries/credential_library.gen.go @@ -327,7 +327,12 @@ func (c *Client) List(ctx context.Context, credentialStoreId string, opt ...Opti opts, apiOpts := getOpts(opt...) opts.queryMap["credential_store_id"] = credentialStoreId - req, err := c.client.NewRequest(ctx, "GET", "credential-libraries", nil, apiOpts...) + requestPath := "credential-libraries" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -476,7 +481,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *CredentialLibrar opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "credential-libraries", nil, apiOpts...) + requestPath := "credential-libraries" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/credentiallibraries/option.gen.go b/api/credentiallibraries/option.gen.go index a4dc80b1d4..ffbadffd77 100644 --- a/api/credentiallibraries/option.gen.go +++ b/api/credentiallibraries/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -111,6 +112,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithVaultSSHCertificateCredentialLibraryAdditionalValidPrincipals(inAdditionalValidPrincipals []string) Option { return func(o *options) { raw, ok := o.postMap["attributes"] diff --git a/api/credentials/credential.gen.go b/api/credentials/credential.gen.go index d77e7d1283..bb723871f2 100644 --- a/api/credentials/credential.gen.go +++ b/api/credentials/credential.gen.go @@ -325,7 +325,12 @@ func (c *Client) List(ctx context.Context, credentialStoreId string, opt ...Opti opts, apiOpts := getOpts(opt...) opts.queryMap["credential_store_id"] = credentialStoreId - req, err := c.client.NewRequest(ctx, "GET", "credentials", nil, apiOpts...) + requestPath := "credentials" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -474,7 +479,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *CredentialListRe opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "credentials", nil, apiOpts...) + requestPath := "credentials" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/credentials/option.gen.go b/api/credentials/option.gen.go index 7cf0fb75e0..c6508253cf 100644 --- a/api/credentials/option.gen.go +++ b/api/credentials/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -111,6 +112,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithAttributes(inAttributes map[string]interface{}) Option { return func(o *options) { o.postMap["attributes"] = inAttributes diff --git a/api/credentialstores/credential_store.gen.go b/api/credentialstores/credential_store.gen.go index 055af8150b..07b1bc1e33 100644 --- a/api/credentialstores/credential_store.gen.go +++ b/api/credentialstores/credential_store.gen.go @@ -326,7 +326,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Cred opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "credential-stores", nil, apiOpts...) + requestPath := "credential-stores" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -482,7 +487,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *CredentialStoreL opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "credential-stores", nil, apiOpts...) + requestPath := "credential-stores" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/credentialstores/option.gen.go b/api/credentialstores/option.gen.go index eb44bdecf4..86e670af0c 100644 --- a/api/credentialstores/option.gen.go +++ b/api/credentialstores/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/groups/group.gen.go b/api/groups/group.gen.go index ed0c31b65b..dc9b249412 100644 --- a/api/groups/group.gen.go +++ b/api/groups/group.gen.go @@ -320,7 +320,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Grou opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "groups", nil, apiOpts...) + requestPath := "groups" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -476,7 +481,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *GroupListResult, opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "groups", nil, apiOpts...) + requestPath := "groups" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/groups/option.gen.go b/api/groups/option.gen.go index d2e7b5a3d4..81553cd4a4 100644 --- a/api/groups/option.gen.go +++ b/api/groups/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/hostcatalogs/host_catalog.gen.go b/api/hostcatalogs/host_catalog.gen.go index 5e4feb45a8..43d907a11c 100644 --- a/api/hostcatalogs/host_catalog.gen.go +++ b/api/hostcatalogs/host_catalog.gen.go @@ -331,7 +331,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Host opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "host-catalogs", nil, apiOpts...) + requestPath := "host-catalogs" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -487,7 +492,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *HostCatalogListR opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "host-catalogs", nil, apiOpts...) + requestPath := "host-catalogs" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/hostcatalogs/option.gen.go b/api/hostcatalogs/option.gen.go index fe5d26fd1e..fdd1496b28 100644 --- a/api/hostcatalogs/option.gen.go +++ b/api/hostcatalogs/option.gen.go @@ -30,6 +30,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -116,6 +117,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/hosts/host.gen.go b/api/hosts/host.gen.go index 016869f447..015414e40d 100644 --- a/api/hosts/host.gen.go +++ b/api/hosts/host.gen.go @@ -327,7 +327,12 @@ func (c *Client) List(ctx context.Context, hostCatalogId string, opt ...Option) opts, apiOpts := getOpts(opt...) opts.queryMap["host_catalog_id"] = hostCatalogId - req, err := c.client.NewRequest(ctx, "GET", "hosts", nil, apiOpts...) + requestPath := "hosts" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -476,7 +481,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *HostListResult, opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "hosts", nil, apiOpts...) + requestPath := "hosts" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/hosts/option.gen.go b/api/hosts/option.gen.go index 32809d61c2..d15e5dd338 100644 --- a/api/hosts/option.gen.go +++ b/api/hosts/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -111,6 +112,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithStaticHostAddress(inAddress string) Option { return func(o *options) { raw, ok := o.postMap["attributes"] diff --git a/api/hostsets/host_set.gen.go b/api/hostsets/host_set.gen.go index 7ce7e6640b..688f4fc6c1 100644 --- a/api/hostsets/host_set.gen.go +++ b/api/hostsets/host_set.gen.go @@ -325,7 +325,12 @@ func (c *Client) List(ctx context.Context, hostCatalogId string, opt ...Option) opts, apiOpts := getOpts(opt...) opts.queryMap["host_catalog_id"] = hostCatalogId - req, err := c.client.NewRequest(ctx, "GET", "host-sets", nil, apiOpts...) + requestPath := "host-sets" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -474,7 +479,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *HostSetListResul opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "host-sets", nil, apiOpts...) + requestPath := "host-sets" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/hostsets/option.gen.go b/api/hostsets/option.gen.go index fa81eddb7b..ec67230598 100644 --- a/api/hostsets/option.gen.go +++ b/api/hostsets/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -111,6 +112,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithAttributes(inAttributes map[string]interface{}) Option { return func(o *options) { o.postMap["attributes"] = inAttributes diff --git a/api/managedgroups/managedgroups.gen.go b/api/managedgroups/managedgroups.gen.go index 02cc9d5be8..e0eeb668f1 100644 --- a/api/managedgroups/managedgroups.gen.go +++ b/api/managedgroups/managedgroups.gen.go @@ -321,7 +321,12 @@ func (c *Client) List(ctx context.Context, authMethodId string, opt ...Option) ( opts, apiOpts := getOpts(opt...) opts.queryMap["auth_method_id"] = authMethodId - req, err := c.client.NewRequest(ctx, "GET", "managed-groups", nil, apiOpts...) + requestPath := "managed-groups" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -470,7 +475,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *ManagedGroupList opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "managed-groups", nil, apiOpts...) + requestPath := "managed-groups" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/managedgroups/option.gen.go b/api/managedgroups/option.gen.go index 7da2f0803a..5c13b52fff 100644 --- a/api/managedgroups/option.gen.go +++ b/api/managedgroups/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -111,6 +112,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithAttributes(inAttributes map[string]interface{}) Option { return func(o *options) { o.postMap["attributes"] = inAttributes diff --git a/api/policies/option.gen.go b/api/policies/option.gen.go index 376e19c66b..53dccd6597 100644 --- a/api/policies/option.gen.go +++ b/api/policies/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/policies/policy.gen.go b/api/policies/policy.gen.go index bf105a2baa..3fb60ce7bd 100644 --- a/api/policies/policy.gen.go +++ b/api/policies/policy.gen.go @@ -325,7 +325,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Poli opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "policies", nil, apiOpts...) + requestPath := "policies" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -481,7 +486,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *PolicyListResult opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "policies", nil, apiOpts...) + requestPath := "policies" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/roles/option.gen.go b/api/roles/option.gen.go index 06cda67ead..2671a82d36 100644 --- a/api/roles/option.gen.go +++ b/api/roles/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/roles/role.gen.go b/api/roles/role.gen.go index 7b6b550711..b9c052c2f2 100644 --- a/api/roles/role.gen.go +++ b/api/roles/role.gen.go @@ -323,7 +323,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Role opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "roles", nil, apiOpts...) + requestPath := "roles" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -479,7 +484,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *RoleListResult, opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "roles", nil, apiOpts...) + requestPath := "roles" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/scopes/option.gen.go b/api/scopes/option.gen.go index d20936ba0c..aef8586a37 100644 --- a/api/scopes/option.gen.go +++ b/api/scopes/option.gen.go @@ -30,6 +30,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -116,6 +117,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/scopes/scope.gen.go b/api/scopes/scope.gen.go index 93119cd5e3..9752443562 100644 --- a/api/scopes/scope.gen.go +++ b/api/scopes/scope.gen.go @@ -321,7 +321,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Scop opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "scopes", nil, apiOpts...) + requestPath := "scopes" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -477,7 +482,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *ScopeListResult, opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "scopes", nil, apiOpts...) + requestPath := "scopes" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/sessionrecordings/option.gen.go b/api/sessionrecordings/option.gen.go index beae7de2ab..d59df635a7 100644 --- a/api/sessionrecordings/option.gen.go +++ b/api/sessionrecordings/option.gen.go @@ -28,6 +28,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -95,6 +96,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/sessionrecordings/session_recording.gen.go b/api/sessionrecordings/session_recording.gen.go index 164880cf93..1c9bcbe770 100644 --- a/api/sessionrecordings/session_recording.gen.go +++ b/api/sessionrecordings/session_recording.gen.go @@ -219,7 +219,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Sess opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "session-recordings", nil, apiOpts...) + requestPath := "session-recordings" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -375,7 +380,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *SessionRecording opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "session-recordings", nil, apiOpts...) + requestPath := "session-recordings" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/sessions/option.gen.go b/api/sessions/option.gen.go index 3887657af2..b9799ff3ee 100644 --- a/api/sessions/option.gen.go +++ b/api/sessions/option.gen.go @@ -30,6 +30,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -116,6 +117,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/sessions/session.gen.go b/api/sessions/session.gen.go index 5ffaa9f6fe..63cdb047a4 100644 --- a/api/sessions/session.gen.go +++ b/api/sessions/session.gen.go @@ -179,7 +179,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Sess opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "sessions", nil, apiOpts...) + requestPath := "sessions" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -335,7 +340,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *SessionListResul opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "sessions", nil, apiOpts...) + requestPath := "sessions" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/storagebuckets/option.gen.go b/api/storagebuckets/option.gen.go index fc97a45869..9d8e30aba5 100644 --- a/api/storagebuckets/option.gen.go +++ b/api/storagebuckets/option.gen.go @@ -30,6 +30,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -116,6 +117,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/storagebuckets/storage_bucket.gen.go b/api/storagebuckets/storage_bucket.gen.go index b232f5a2db..834671e02f 100644 --- a/api/storagebuckets/storage_bucket.gen.go +++ b/api/storagebuckets/storage_bucket.gen.go @@ -329,7 +329,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Stor opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "storage-buckets", nil, apiOpts...) + requestPath := "storage-buckets" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -485,7 +490,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *StorageBucketLis opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "storage-buckets", nil, apiOpts...) + requestPath := "storage-buckets" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/targets/option.gen.go b/api/targets/option.gen.go index 4d71168be6..80351506f7 100644 --- a/api/targets/option.gen.go +++ b/api/targets/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/targets/target.gen.go b/api/targets/target.gen.go index 6d7ad43315..4743dbe644 100644 --- a/api/targets/target.gen.go +++ b/api/targets/target.gen.go @@ -339,7 +339,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Targ opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "targets", nil, apiOpts...) + requestPath := "targets" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -495,7 +500,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *TargetListResult opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "targets", nil, apiOpts...) + requestPath := "targets" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/users/custom.go b/api/users/custom.go index 3c33c3fd5e..bd211cec83 100644 --- a/api/users/custom.go +++ b/api/users/custom.go @@ -5,10 +5,8 @@ package users import ( "context" - "encoding/json" "fmt" "net/url" - "slices" "github.com/hashicorp/boundary/api/aliases" ) @@ -25,135 +23,44 @@ func (c *Client) ListResolvableAliases(ctx context.Context, userId string, opt . return nil, fmt.Errorf("nil client") } - opts, apiOpts := getOpts(opt...) - req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("users/%s:list-resolvable-aliases", url.PathEscape(userId)), nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) - if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) - } + opts, _ := getOpts(opt...) + apiClient := aliases.NewClient(c.client) + return apiClient.List(ctx, "global", + aliases.WithAutomaticVersioning(opts.withAutomaticVersioning), + aliases.WithSkipCurlOutput(opts.withSkipCurlOutput), + aliases.WithFilter(opts.withFilter), + aliases.WithListToken(opts.withListToken), + aliases.WithClientDirectedPagination(opts.withClientDirectedPagination), + aliases.WithPageSize(opts.withPageSize), + aliases.WithRecursive(opts.withRecursive), + aliases.WithResourcePathOverride(fmt.Sprintf("users/%s:list-resolvable-aliases", url.PathEscape(userId))), + ) +} - target := new(aliases.AliasListResult) - apiErr, err := resp.Decode(target) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) +func (c *Client) ListResolvableAliasesNextPage(ctx context.Context, userId string, currentPage *aliases.AliasListResult, opt ...Option) (*aliases.AliasListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListResolvableAliasesNextPage request") } - if apiErr != nil { - return nil, apiErr + if userId == "" { + return nil, fmt.Errorf("empty userId value passed into ListResolvableAliasesNextPage request") } - target.Response = resp - if target.ResponseType == "complete" || target.ResponseType == "" { - return target, nil + if c.client == nil { + return nil, fmt.Errorf("nil client") } - // If there are more results, automatically fetch the rest of the results. - // idToIndex keeps a map from the ID of an item to its index in target.Items. - // This is used to update updated items in-place and remove deleted items - // from the result after pagination is done. - idToIndex := map[string]int{} - for i, item := range target.Items { - idToIndex[item.Id] = i + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListResolvableAliasesNextPage request") } - for { - req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("users/%s:list-resolvable-aliases", url.PathEscape(userId)), nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - resp, err := c.client.Do(req) - if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) - } - - page := new(aliases.AliasListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { - if i, ok := idToIndex[item.Id]; ok { - // Item has already been seen at index i, update in-place - target.Items[i] = item - } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 - } - } - // RemovedIds contain any Alias that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { - break - } - } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) - // Remove items that were deleted since the end of the last iteration. - // If an Alias has been updated and subsequently removed, we don't want - // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { - if i, ok := idToIndex[removedId]; ok { - // Remove the item at index i without preserving order - // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] - // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i - } - } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) - // Sort the results again since in-place updates and deletes - // may have shuffled items. We sort by created time descending - // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *aliases.Alias) int { - return j.CreatedTime.Compare(i.CreatedTime) - }) - // Finally, since we made at least 2 requests to the server to fulfill this - // function call, resp.Body and resp.Map will only contain the most recent response. - // Overwrite them with the true response. - target.GetResponse().Body.Reset() - if err := json.NewEncoder(target.GetResponse().Body).Encode(target); err != nil { - return nil, fmt.Errorf("error encoding final JSON list response: %w", err) - } - if err := json.Unmarshal(target.GetResponse().Body.Bytes(), &target.GetResponse().Map); err != nil { - return nil, fmt.Errorf("error encoding final map list response: %w", err) - } - // Note: the HTTP response body is consumed by resp.Decode in the loop, - // so it doesn't need to be updated (it will always be, and has always been, empty). - return target, nil + opts, _ := getOpts(opt...) + apiClient := aliases.NewClient(c.client) + return apiClient.ListNextPage(ctx, currentPage, + aliases.WithAutomaticVersioning(opts.withAutomaticVersioning), + aliases.WithSkipCurlOutput(opts.withSkipCurlOutput), + aliases.WithFilter(opts.withFilter), + aliases.WithListToken(opts.withListToken), + aliases.WithClientDirectedPagination(opts.withClientDirectedPagination), + aliases.WithPageSize(opts.withPageSize), + aliases.WithRecursive(opts.withRecursive), + aliases.WithResourcePathOverride(fmt.Sprintf("users/%s:list-resolvable-aliases", url.PathEscape(userId))), + ) } diff --git a/api/users/option.gen.go b/api/users/option.gen.go index a6349cde10..170754148a 100644 --- a/api/users/option.gen.go +++ b/api/users/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/users/user.gen.go b/api/users/user.gen.go index 51065721f9..0728043f6e 100644 --- a/api/users/user.gen.go +++ b/api/users/user.gen.go @@ -324,7 +324,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*User opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "users", nil, apiOpts...) + requestPath := "users" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -480,7 +485,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *UserListResult, opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "users", nil, apiOpts...) + requestPath := "users" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/workers/option.gen.go b/api/workers/option.gen.go index 1ca3ada826..d5af29ef43 100644 --- a/api/workers/option.gen.go +++ b/api/workers/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/workers/worker.gen.go b/api/workers/worker.gen.go index efe18d859e..c55636263c 100644 --- a/api/workers/worker.gen.go +++ b/api/workers/worker.gen.go @@ -377,7 +377,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Work opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "workers", nil, apiOpts...) + requestPath := "workers" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/internal/api/genapi/templates.go b/internal/api/genapi/templates.go index aa37311e5a..7899fb1c3e 100644 --- a/internal/api/genapi/templates.go +++ b/internal/api/genapi/templates.go @@ -155,7 +155,7 @@ func fillTemplates() { optionsMap[input.Package] = optionMap } // Override some defined options - if len(in.fieldOverrides) > 0 && optionsMap != nil { + if len(in.fieldOverrides) > 0 { for _, override := range in.fieldOverrides { inOpts := optionsMap[input.Package] if inOpts != nil { @@ -243,7 +243,12 @@ func (c *Client) List(ctx context.Context, {{ .CollectionFunctionArg }} string, opts, apiOpts := getOpts(opt...) opts.queryMap["{{ snakeCase .CollectionFunctionArg }}"] = {{ .CollectionFunctionArg }} - req, err := c.client.NewRequest(ctx, "GET", "{{ .CollectionPath }}", nil, apiOpts...) + requestPath := "{{ .CollectionPath }}" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -404,7 +409,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *{{ .Name }}ListR opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "{{ .CollectionPath }}", nil, apiOpts...) + requestPath := "{{ .CollectionPath }}" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -434,7 +444,7 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *{{ .Name }}ListR // Ensure values are carried forward to the next call nextPage.{{ .CollectionFunctionArg }} = currentPage.{{ .CollectionFunctionArg }} -{{ if .RecursiveListing }} +{{ if .RecursiveListing }} nextPage.recursive = currentPage.recursive {{ end }} nextPage.pageSize = currentPage.pageSize @@ -943,6 +953,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string {{ if .RecursiveListing }} withRecursive bool {{ end }} } @@ -1031,6 +1042,13 @@ func WithPageSize(with uint32) Option { o.withPageSize = with } } + +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} {{ if .RecursiveListing }} // WithRecursive tells the API to use recursion for listing operations on this // resource diff --git a/internal/clientcache/internal/cache/options_test.go b/internal/clientcache/internal/cache/options_test.go index 9f25faebe3..26810f8309 100644 --- a/internal/clientcache/internal/cache/options_test.go +++ b/internal/clientcache/internal/cache/options_test.go @@ -62,8 +62,8 @@ func Test_GetOpts(t *testing.T) { assert.Equal(t, opts, testOpts) }) t.Run("WithSessionRetrievalFunc", func(t *testing.T) { - var f SessionRetrievalFunc = func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) { - return nil, nil, "", nil + var f SessionRetrievalFunc = func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue, inPage *sessions.SessionListResult, opt ...Option) (*sessions.SessionListResult, RefreshTokenValue, error) { + return nil, "", nil } opts, err := getOpts(WithSessionRetrievalFunc(f)) require.NoError(t, err) @@ -75,8 +75,8 @@ func Test_GetOpts(t *testing.T) { assert.Equal(t, opts, testOpts) }) t.Run("WithAliasRetrievalFunc", func(t *testing.T) { - var f ResolvableAliasRetrievalFunc = func(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue) ([]*aliases.Alias, []string, RefreshTokenValue, error) { - return nil, nil, "", nil + var f ResolvableAliasRetrievalFunc = func(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue, inPage *aliases.AliasListResult, opt ...Option) (*aliases.AliasListResult, RefreshTokenValue, error) { + return nil, "", nil } opts, err := getOpts(WithAliasRetrievalFunc(f)) require.NoError(t, err) diff --git a/internal/clientcache/internal/cache/refresh_test.go b/internal/clientcache/internal/cache/refresh_test.go index 0d45187279..3895376132 100644 --- a/internal/clientcache/internal/cache/refresh_test.go +++ b/internal/clientcache/internal/cache/refresh_test.go @@ -70,6 +70,38 @@ func testTargetStaticResourceRetrievalFunc(inFunc func(ctx context.Context, s1, } } +func testSessionStaticResourceRetrievalFunc(inFunc func(ctx context.Context, s1, s2 string, refToken RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error)) SessionRetrievalFunc { + return func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue, inPage *sessions.SessionListResult, opt ...Option) (ret *sessions.SessionListResult, refreshToken RefreshTokenValue, err error) { + retSessions, removed, refreshToken, err := inFunc(ctx, addr, authTok, refreshTok) + if err != nil { + return nil, "", err + } + + ret = &sessions.SessionListResult{ + Items: retSessions, + RemovedIds: removed, + ResponseType: "complete", + } + return ret, refreshToken, nil + } +} + +func testResolvableAliasStaticResourceRetrievalFunc(inFunc func(ctx context.Context, s1, s2, s3 string, refToken RefreshTokenValue) ([]*aliases.Alias, []string, RefreshTokenValue, error)) ResolvableAliasRetrievalFunc { + return func(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue, inPage *aliases.AliasListResult, opt ...Option) (ret *aliases.AliasListResult, refreshToken RefreshTokenValue, err error) { + retSessions, removed, refreshToken, err := inFunc(ctx, addr, authTok, userId, refreshTok) + if err != nil { + return nil, "", err + } + + ret = &aliases.AliasListResult{ + Items: retSessions, + RemovedIds: removed, + ResponseType: "complete", + } + return ret, refreshToken, nil + } +} + // testNoRefreshRetrievalFunc simulates a controller that doesn't support refresh // since it does not return any refresh token. func testNoRefreshRetrievalFunc[T any](t *testing.T) func(context.Context, string, string, RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) { @@ -440,8 +472,8 @@ func TestRefreshForSearch(t *testing.T) { target("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{ retTargets[:3], @@ -495,8 +527,8 @@ func TestRefreshForSearch(t *testing.T) { target("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{ retTargets[:3], @@ -553,9 +585,9 @@ func TestRefreshForSearch(t *testing.T) { // Get the first set of resources, but no refresh tokens err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) assert.ErrorContains(t, err, ErrRefreshNotSupported.Error()) got, err := r.ListTargets(ctx, at.Id) @@ -569,15 +601,15 @@ func TestRefreshForSearch(t *testing.T) { // wont be refreshed any more, and we wont see the error when refreshing // any more. err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) assert.Nil(t, err) err = rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) assert.Nil(t, err) got, err = r.ListTargets(ctx, at.Id) @@ -590,9 +622,9 @@ func TestRefreshForSearch(t *testing.T) { // Now simulate the controller updating to support refresh tokens and // the resources starting to be cached. err = rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{retTargets}, [][]string{{}}))), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), ) assert.Nil(t, err, err) @@ -617,9 +649,9 @@ func TestRefreshForSearch(t *testing.T) { session("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, [][]*sessions.Session{ retSess[:3], retSess[3:], @@ -628,7 +660,7 @@ func TestRefreshForSearch(t *testing.T) { nil, {retSess[0].Id, retSess[1].Id}, }, - )), + ))), } // First call doesn't sync anything because no sessions were already synced yet @@ -669,9 +701,9 @@ func TestRefreshForSearch(t *testing.T) { session("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, [][]*sessions.Session{ retSess[:3], retSess[3:], @@ -680,7 +712,7 @@ func TestRefreshForSearch(t *testing.T) { nil, {retSess[0].Id, retSess[1].Id}, }, - )), + ))), } // First call doesn't sync anything because no sessions were already synced yet @@ -726,9 +758,9 @@ func TestRefreshForSearch(t *testing.T) { alias("4"), } opts := []Option{ - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, [][]*aliases.Alias{ retAl[:3], retAl[3:], @@ -737,7 +769,7 @@ func TestRefreshForSearch(t *testing.T) { nil, {retAl[0].Id, retAl[1].Id}, }, - )), + ))), } // First call doesn't sync anything because no aliases were already synced yet @@ -778,9 +810,9 @@ func TestRefreshForSearch(t *testing.T) { alias("4"), } opts := []Option{ - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, [][]*aliases.Alias{ retAls[:3], retAls[3:], @@ -789,7 +821,7 @@ func TestRefreshForSearch(t *testing.T) { nil, {retAls[0].Id, retAls[1].Id}, }, - )), + ))), } // First call doesn't sync anything because no aliases were already synced yet @@ -854,8 +886,8 @@ func TestRefreshNonBlocking(t *testing.T) { target("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{ retTargets[:3], @@ -915,9 +947,9 @@ func TestRefreshNonBlocking(t *testing.T) { session("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, [][]*sessions.Session{ retSess[:3], retSess[3:], @@ -926,7 +958,7 @@ func TestRefreshNonBlocking(t *testing.T) { nil, {retSess[0].Id, retSess[1].Id}, }, - )), + ))), } refreshWaitChs := &testRefreshWaitChs{ @@ -977,9 +1009,9 @@ func TestRefreshNonBlocking(t *testing.T) { alias("4"), } opts := []Option{ - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, [][]*aliases.Alias{ retAl[:3], retAl[3:], @@ -988,7 +1020,7 @@ func TestRefreshNonBlocking(t *testing.T) { nil, {retAl[0].Id, retAl[1].Id}, }, - )), + ))), } refreshWaitChs := &testRefreshWaitChs{ @@ -1056,8 +1088,8 @@ func TestRefresh(t *testing.T) { target("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{ retTargets[:3], @@ -1098,9 +1130,9 @@ func TestRefresh(t *testing.T) { session("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, [][]*sessions.Session{ retSess[:3], retSess[3:], @@ -1109,7 +1141,7 @@ func TestRefresh(t *testing.T) { nil, {retSess[0].Id, retSess[1].Id}, }, - )), + ))), } assert.NoError(t, rs.Refresh(ctx, opts...)) cachedSessions, err := r.ListSessions(ctx, at.Id) @@ -1139,9 +1171,9 @@ func TestRefresh(t *testing.T) { alias("4"), } opts := []Option{ - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, [][]*aliases.Alias{ retAls[:3], retAls[3:], @@ -1150,7 +1182,7 @@ func TestRefresh(t *testing.T) { nil, {retAls[0].Id, retAls[1].Id}, }, - )), + ))), } assert.NoError(t, rs.Refresh(ctx, opts...)) cachedAliases, err := r.ListResolvableAliases(ctx, at.Id) @@ -1175,8 +1207,8 @@ func TestRefresh(t *testing.T) { innerErr := errors.New("test error") err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) @@ -1184,13 +1216,13 @@ func TestRefresh(t *testing.T) { }))) assert.ErrorContains(t, err, innerErr.Error()) err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithSessionRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) { + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) return nil, nil, "", innerErr - })) + }))) assert.ErrorContains(t, err, innerErr.Error()) }) @@ -1217,8 +1249,8 @@ func TestRefresh(t *testing.T) { assert.Len(t, us, 1) require.NoError(t, rs.Refresh(ctx, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))))) ps, err = r.listTokens(ctx, u) @@ -1261,8 +1293,8 @@ func TestRecheckCachingSupport(t *testing.T) { // Since this user doesn't have any resources, the user's data will still // only get updated with a call to Refresh. assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))))) got, err := r.ListTargets(ctx, at.Id) @@ -1273,8 +1305,8 @@ func TestRecheckCachingSupport(t *testing.T) { assert.False(t, got.Incomplete) err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) assert.ErrorIs(t, err, ErrRefreshNotSupported) @@ -1287,8 +1319,8 @@ func TestRecheckCachingSupport(t *testing.T) { // now a full fetch will work since the user has resources and no refresh token assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))))) }) @@ -1302,9 +1334,9 @@ func TestRecheckCachingSupport(t *testing.T) { require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))))) got, err := r.ListSessions(ctx, at.Id) require.NoError(t, err) @@ -1314,9 +1346,9 @@ func TestRecheckCachingSupport(t *testing.T) { assert.False(t, got.Incomplete) err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) assert.ErrorIs(t, err, ErrRefreshNotSupported) got, err = r.ListSessions(ctx, at.Id) @@ -1327,9 +1359,9 @@ func TestRecheckCachingSupport(t *testing.T) { assert.False(t, got.Incomplete) assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))))) got, err = r.ListSessions(ctx, at.Id) require.NoError(t, err) assert.Empty(t, got.Targets) @@ -1348,9 +1380,9 @@ func TestRecheckCachingSupport(t *testing.T) { require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))))) got, err := r.ListResolvableAliases(ctx, at.Id) require.NoError(t, err) @@ -1360,9 +1392,9 @@ func TestRecheckCachingSupport(t *testing.T) { assert.False(t, got.Incomplete) err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) assert.ErrorIs(t, err, ErrRefreshNotSupported) got, err = r.ListResolvableAliases(ctx, at.Id) @@ -1373,9 +1405,9 @@ func TestRecheckCachingSupport(t *testing.T) { assert.False(t, got.Incomplete) assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))))) got, err = r.ListResolvableAliases(ctx, at.Id) require.NoError(t, err) assert.Empty(t, got.Targets) @@ -1394,15 +1426,15 @@ func TestRecheckCachingSupport(t *testing.T) { require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) assert.ErrorIs(t, err, ErrRefreshNotSupported) innerErr := errors.New("test error") err = rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) @@ -1411,8 +1443,8 @@ func TestRecheckCachingSupport(t *testing.T) { assert.ErrorContains(t, err, innerErr.Error()) err = rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) @@ -1431,9 +1463,9 @@ func TestRecheckCachingSupport(t *testing.T) { require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) assert.ErrorIs(t, err, ErrRefreshNotSupported) // Remove the token from the keyring, see that we can still see the @@ -1450,8 +1482,8 @@ func TestRecheckCachingSupport(t *testing.T) { assert.Len(t, us, 1) err = rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) assert.NoError(t, err) diff --git a/internal/clientcache/internal/cache/repository_implicit_scopes_test.go b/internal/clientcache/internal/cache/repository_implicit_scopes_test.go index ca5f5c0b64..6363f4f96f 100644 --- a/internal/clientcache/internal/cache/repository_implicit_scopes_test.go +++ b/internal/clientcache/internal/cache/repository_implicit_scopes_test.go @@ -100,7 +100,7 @@ func TestRepository_ImplicitScopes(t *testing.T) { }, } require.NoError(t, r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil})))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil}))))) expectedScopes = append(expectedScopes, &scopes.Scope{ Id: ss[0].ScopeId, diff --git a/internal/clientcache/internal/cache/repository_resolvable_aliases.go b/internal/clientcache/internal/cache/repository_resolvable_aliases.go index 39e3bf5f36..c0578e88e4 100644 --- a/internal/clientcache/internal/cache/repository_resolvable_aliases.go +++ b/internal/clientcache/internal/cache/repository_resolvable_aliases.go @@ -22,35 +22,45 @@ import ( // ResolvableAliasRetrievalFunc is a function that retrieves aliases // from the provided boundary addr using the provided token. -type ResolvableAliasRetrievalFunc func(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue) (ret []*aliases.Alias, removedIds []string, refreshToken RefreshTokenValue, err error) +type ResolvableAliasRetrievalFunc func(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue, inPage *aliases.AliasListResult, opt ...Option) (ret *aliases.AliasListResult, refreshToken RefreshTokenValue, err error) -func defaultResolvableAliasFunc(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue) ([]*aliases.Alias, []string, RefreshTokenValue, error) { +func defaultResolvableAliasFunc(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue, inPage *aliases.AliasListResult, opt ...Option) (*aliases.AliasListResult, RefreshTokenValue, error) { const op = "cache.defaultResolvableAliasFunc" conf, err := api.DefaultConfig() if err != nil { - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) + } + opts, err := getOpts(opt...) + if err != nil { + return nil, "", errors.Wrap(ctx, err, op) } conf.Addr = addr conf.Token = authTok client, err := api.NewClient(conf) if err != nil { - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) } aClient := users.NewClient(client) - l, err := aClient.ListResolvableAliases(ctx, userId, users.WithListToken(string(refreshTok))) + var l *aliases.AliasListResult + switch inPage { + case nil: + l, err = aClient.ListResolvableAliases(ctx, userId, users.WithRecursive(true), users.WithListToken(string(refreshTok)), users.WithClientDirectedPagination(!opts.withUseNonPagedListing)) + default: + l, err = aClient.ListResolvableAliasesNextPage(ctx, userId, inPage, users.WithListToken(string(refreshTok))) + } if err != nil { if api.ErrInvalidListToken.Is(err) { - return nil, nil, "", err + return nil, "", err } - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) } if l.ResponseType == "" { - return nil, nil, "", ErrRefreshNotSupported + return nil, "", ErrRefreshNotSupported } - return l.Items, l.RemovedIds, RefreshTokenValue(l.ListToken), nil + return l, RefreshTokenValue(l.ListToken), nil } -// refreshResolvableAliases attempts to refresh the resolvabl aliases for the +// refreshResolvableAliases attempts to refresh the resolvable aliases for the // provided user using the provided tokens. If available, it uses the refresh // tokens in storage to retrieve and apply only the delta. func (r *Repository) refreshResolvableAliases(ctx context.Context, u *user, tokens map[AuthToken]string, opt ...Option) error { @@ -83,13 +93,13 @@ func (r *Repository) refreshResolvableAliases(ctx context.Context, u *user, toke // Find and use a token for retrieving aliases var gotResponse bool - var resp []*aliases.Alias + var currentPage *aliases.AliasListResult var newRefreshToken RefreshTokenValue + var foundAuthToken string var unsupportedCacheRequest bool - var removedIds []string var retErr error for at, t := range tokens { - resp, removedIds, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, t, u.Id, oldRefreshTokenVal) + currentPage, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, t, u.Id, oldRefreshTokenVal, currentPage) if api.ErrInvalidListToken.Is(err) { event.WriteSysEvent(ctx, op, "old list token is no longer valid, starting new initial fetch", "user_id", u.Id) if err := r.deleteRefreshToken(ctx, u, resourceType); err != nil { @@ -97,7 +107,7 @@ func (r *Repository) refreshResolvableAliases(ctx context.Context, u *user, toke } // try again without the refresh token oldRefreshToken = nil - resp, removedIds, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, t, u.Id, "") + currentPage, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, t, u.Id, "", currentPage) } if err != nil { if err == ErrRefreshNotSupported { @@ -107,6 +117,7 @@ func (r *Repository) refreshResolvableAliases(ctx context.Context, u *user, toke continue } } + foundAuthToken = t gotResponse = true break } @@ -121,44 +132,57 @@ func (r *Repository) refreshResolvableAliases(ctx context.Context, u *user, toke } var numDeleted int - _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { - var err error - switch { - case oldRefreshToken == nil: - if numDeleted, err = w.Exec(ctx, "delete from resolvable_alias where fk_user_id = @fk_user_id", - []any{sql.Named("fk_user_id", u.Id)}); err != nil { - return err - } - case len(removedIds) > 0: - if numDeleted, err = w.Exec(ctx, "delete from resolvable_alias where id in @ids", - []any{sql.Named("ids", removedIds)}); err != nil { - return err - } - } - switch { - case unsupportedCacheRequest: - if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { - return err + var numUpserted int + var clearPerformed bool + for { + _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { + var err error + if (oldRefreshToken == nil || unsupportedCacheRequest) && !clearPerformed { + if numDeleted, err = w.Exec(ctx, "delete from resolvable_alias where fk_user_id = @fk_user_id", + []any{sql.Named("fk_user_id", u.Id)}); err != nil { + return err + } + clearPerformed = true } - case newRefreshToken != "": - if err := upsertResolvableAliases(ctx, w, u, resp); err != nil { - return err + switch { + case unsupportedCacheRequest: + if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { + return err + } + case newRefreshToken != "": + numUpserted += len(currentPage.Items) + if err := upsertResolvableAliases(ctx, w, u, currentPage.Items); err != nil { + return err + } + if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { + return err + } + default: + // controller supports caching, but doesn't have any resources } - if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { - return err + if !unsupportedCacheRequest && len(currentPage.RemovedIds) > 0 { + if numDeleted, err = w.Exec(ctx, "delete from resolvable_alias where id in @ids", + []any{sql.Named("ids", currentPage.RemovedIds)}); err != nil { + return err + } } - default: - // controller supports caching, but doesn't have any resources + return nil + }) + if unsupportedCacheRequest || currentPage.ResponseType == "" || currentPage.ResponseType == "complete" { + break } - return nil - }) + currentPage, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, foundAuthToken, u.Id, newRefreshToken, currentPage) + if err != nil { + break + } + } if err != nil { return errors.Wrap(ctx, err, op) } if unsupportedCacheRequest { return ErrRefreshNotSupported } - event.WriteSysEvent(ctx, op, "resolvable-aliases updated", "deleted", numDeleted, "upserted", len(resp), "user_id", u.Id) + event.WriteSysEvent(ctx, op, "resolvable-aliases updated", "deleted", numDeleted, "upserted", numUpserted, "user_id", u.Id) return nil } @@ -187,12 +211,12 @@ func (r *Repository) checkCachingResolvableAliases(ctx context.Context, u *user, // Find and use a token for retrieving aliases var gotResponse bool - var resp []*aliases.Alias + var resp *aliases.AliasListResult var newRefreshToken RefreshTokenValue var unsupportedCacheRequest bool var retErr error for at, t := range tokens { - resp, _, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, t, u.Id, "") + resp, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, t, u.Id, "", nil, WithUseNonPagedListing(true)) if err != nil { if err == ErrRefreshNotSupported { unsupportedCacheRequest = true @@ -227,7 +251,7 @@ func (r *Repository) checkCachingResolvableAliases(ctx context.Context, u *user, []any{sql.Named("fk_user_id", u.Id)}); err != nil { return err } - if err := upsertResolvableAliases(ctx, w, u, resp); err != nil { + if err := upsertResolvableAliases(ctx, w, u, resp.Items); err != nil { return err } if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { @@ -248,7 +272,7 @@ func (r *Repository) checkCachingResolvableAliases(ctx context.Context, u *user, if unsupportedCacheRequest { return ErrRefreshNotSupported } - event.WriteSysEvent(ctx, op, "resolvable-aliases updated", "deleted", numDeleted, "upserted", len(resp), "user_id", u.Id) + event.WriteSysEvent(ctx, op, "resolvable-aliases updated", "deleted", numDeleted, "upserted", len(resp.Items), "user_id", u.Id) return nil } diff --git a/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go b/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go index 086bf17c4c..3096d1d5f7 100644 --- a/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go +++ b/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go @@ -139,7 +139,7 @@ func TestRepository_refreshAliases(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { err := r.refreshResolvableAliases(ctx, tc.u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{tc.al}, [][]string{nil}))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{tc.al}, [][]string{nil})))) if tc.errorContains == "" { assert.NoError(t, err) rw := db.New(s) @@ -218,7 +218,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { } err = r.refreshResolvableAliases(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, ss, [][]string{nil, nil}))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, ss, [][]string{nil, nil})))) assert.NoError(t, err) got, err := r.ListResolvableAliases(ctx, at.Id) @@ -228,7 +228,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { // Refreshing again uses the refresh token and get additional aliases, appending // them to the response err = r.refreshResolvableAliases(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, ss, [][]string{nil, nil}))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, ss, [][]string{nil, nil})))) assert.NoError(t, err) got, err = r.ListResolvableAliases(ctx, at.Id) @@ -238,7 +238,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { // Refreshing again wont return any more resources, but also none should be // removed require.NoError(t, r.refreshResolvableAliases(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, ss, [][]string{nil, nil})))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, ss, [][]string{nil, nil}))))) assert.NoError(t, err) got, err = r.ListResolvableAliases(ctx, at.Id) @@ -247,7 +247,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { // Refresh again with the refresh token being reported as invalid. require.NoError(t, r.refreshResolvableAliases(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testErroringForRefreshTokenRetrievalFuncForId(t, ss[0])))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testErroringForRefreshTokenRetrievalFuncForId(t, ss[0]))))) assert.NoError(t, err) got, err = r.ListResolvableAliases(ctx, at.Id) @@ -328,7 +328,7 @@ func TestRepository_ListAliases(t *testing.T) { }, } require.NoError(t, r.refreshResolvableAliases(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ss}, [][]string{nil})))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ss}, [][]string{nil}))))) t.Run("wrong user gets no aliases", func(t *testing.T) { l, err := r.ListResolvableAliases(ctx, kt2.AuthTokenId) @@ -372,7 +372,7 @@ func TestRepository_ListAliasesLimiting(t *testing.T) { ts = append(ts, alias("s"+strconv.Itoa(i))) } require.NoError(t, r.refreshResolvableAliases(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ts}, [][]string{nil})))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ts}, [][]string{nil}))))) searchService, err := NewSearchService(ctx, r) require.NoError(t, err) @@ -498,7 +498,7 @@ func TestRepository_QueryAliases(t *testing.T) { }, } require.NoError(t, r.refreshResolvableAliases(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ss}, [][]string{nil})))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ss}, [][]string{nil}))))) t.Run("wrong token gets no aliases", func(t *testing.T) { l, err := r.QueryResolvableAliases(ctx, kt2.AuthTokenId, query) @@ -542,7 +542,7 @@ func TestRepository_QueryResolvableAliasesLimiting(t *testing.T) { ts = append(ts, alias("s"+strconv.Itoa(i))) } require.NoError(t, r.refreshResolvableAliases(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ts}, [][]string{nil})))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ts}, [][]string{nil}))))) searchService, err := NewSearchService(ctx, r) require.NoError(t, err) @@ -593,15 +593,15 @@ func TestDefaultAliasRetrievalFunc(t *testing.T) { require.NoError(t, err) require.NotNil(t, tar1) - got, removed, refTok, err := defaultResolvableAliasFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, tc.Token().UserId, "") + got, refTok, err := defaultResolvableAliasFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, tc.Token().UserId, "", nil) assert.NoError(t, err) assert.NotEmpty(t, refTok) - assert.Empty(t, removed) - assert.Len(t, got, 1) + assert.Empty(t, got.RemovedIds) + assert.Len(t, got.Items, 1) - got2, removed2, refTok2, err := defaultResolvableAliasFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, tc.Token().UserId, refTok) + got2, refTok2, err := defaultResolvableAliasFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, tc.Token().UserId, refTok, nil) assert.NoError(t, err) assert.NotEmpty(t, refTok2) - assert.Empty(t, removed2) - assert.Empty(t, got2) + assert.Empty(t, got2.RemovedIds) + assert.Empty(t, got2.Items) } diff --git a/internal/clientcache/internal/cache/repository_sessions.go b/internal/clientcache/internal/cache/repository_sessions.go index 904202252a..9d53498316 100644 --- a/internal/clientcache/internal/cache/repository_sessions.go +++ b/internal/clientcache/internal/cache/repository_sessions.go @@ -21,32 +21,42 @@ import ( // SessionRetrievalFunc is a function that retrieves sessions // from the provided boundary addr using the provided token. -type SessionRetrievalFunc func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) (ret []*sessions.Session, removedIds []string, refreshToken RefreshTokenValue, err error) +type SessionRetrievalFunc func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue, inPage *sessions.SessionListResult, opt ...Option) (ret *sessions.SessionListResult, refreshToken RefreshTokenValue, err error) -func defaultSessionFunc(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) { +func defaultSessionFunc(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue, inPage *sessions.SessionListResult, opt ...Option) (ret *sessions.SessionListResult, refreshToken RefreshTokenValue, err error) { const op = "cache.defaultSessionFunc" conf, err := api.DefaultConfig() if err != nil { - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) + } + opts, err := getOpts(opt...) + if err != nil { + return nil, "", errors.Wrap(ctx, err, op) } conf.Addr = addr conf.Token = authTok client, err := api.NewClient(conf) if err != nil { - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) } sClient := sessions.NewClient(client) - l, err := sClient.List(ctx, "global", sessions.WithIncludeTerminated(true), sessions.WithRecursive(true), sessions.WithListToken(string(refreshTok))) + var l *sessions.SessionListResult + switch inPage { + case nil: + l, err = sClient.List(ctx, "global", sessions.WithIncludeTerminated(true), sessions.WithRecursive(true), sessions.WithListToken(string(refreshTok)), sessions.WithClientDirectedPagination(!opts.withUseNonPagedListing)) + default: + l, err = sClient.ListNextPage(ctx, inPage, sessions.WithListToken(string(refreshTok))) + } if err != nil { if api.ErrInvalidListToken.Is(err) { - return nil, nil, "", err + return nil, "", err } - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) } if l.ResponseType == "" { - return nil, nil, "", ErrRefreshNotSupported + return nil, "", ErrRefreshNotSupported } - return l.Items, l.RemovedIds, RefreshTokenValue(l.ListToken), nil + return l, RefreshTokenValue(l.ListToken), nil } // refreshSessions uses attempts to refresh the sessions for the provided user @@ -59,8 +69,6 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au return errors.New(ctx, errors.InvalidParameter, op, "user is nil") case u.Id == "": return errors.New(ctx, errors.InvalidParameter, op, "user id is missing") - case u.Address == "": - return errors.New(ctx, errors.InvalidParameter, op, "user boundary address is missing") } const resourceType = sessionResourceType @@ -82,13 +90,13 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au // Find and use a token for retrieving sessions var gotResponse bool - var resp []*sessions.Session + var currentPage *sessions.SessionListResult var newRefreshToken RefreshTokenValue + var foundAuthToken string var unsupportedCacheRequest bool - var removedIds []string var retErr error for at, t := range tokens { - resp, removedIds, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, oldRefreshTokenVal) + currentPage, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, oldRefreshTokenVal, currentPage) if api.ErrInvalidListToken.Is(err) { event.WriteSysEvent(ctx, op, "old list token is no longer valid, starting new initial fetch", "user_id", u.Id) if err := r.deleteRefreshToken(ctx, u, resourceType); err != nil { @@ -96,7 +104,7 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au } // try again without the refresh token oldRefreshToken = nil - resp, removedIds, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "") + currentPage, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "", currentPage) } if err != nil { if err == ErrRefreshNotSupported { @@ -106,6 +114,7 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au continue } } + foundAuthToken = t gotResponse = true break } @@ -120,44 +129,56 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au } var numDeleted int - _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { - var err error - switch { - case oldRefreshToken == nil: - if numDeleted, err = w.Exec(ctx, "delete from session where fk_user_id = @fk_user_id", - []any{sql.Named("fk_user_id", u.Id)}); err != nil { - return err - } - case len(removedIds) > 0: - if numDeleted, err = w.Exec(ctx, "delete from session where id in @ids", - []any{sql.Named("ids", removedIds)}); err != nil { - return err - } - } - switch { - case unsupportedCacheRequest: - if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { - return err + var numUpserted int + var clearPerformed bool + for { + _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { + var err error + if (oldRefreshToken == nil || unsupportedCacheRequest) && !clearPerformed { + if numDeleted, err = w.Exec(ctx, "delete from session where fk_user_id = @fk_user_id", + []any{sql.Named("fk_user_id", u.Id)}); err != nil { + return err + } } - case newRefreshToken != "": - if err := upsertSessions(ctx, w, u, resp); err != nil { - return err + switch { + case unsupportedCacheRequest: + if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { + return err + } + case newRefreshToken != "": + numUpserted += len(currentPage.Items) + if err := upsertSessions(ctx, w, u, currentPage.Items); err != nil { + return err + } + if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { + return err + } + default: + // controller supports caching, but doesn't have any resources } - if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { - return err + if !unsupportedCacheRequest && len(currentPage.RemovedIds) > 0 { + if numDeleted, err = w.Exec(ctx, "delete from session where id in @ids", + []any{sql.Named("ids", currentPage.RemovedIds)}); err != nil { + return err + } } - default: - // controller supports caching, but doesn't have any resources + return nil + }) + if unsupportedCacheRequest || currentPage.ResponseType == "" || currentPage.ResponseType == "complete" { + break } - return nil - }) + currentPage, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, foundAuthToken, newRefreshToken, currentPage) + if err != nil { + break + } + } if err != nil { return errors.Wrap(ctx, err, op) } if unsupportedCacheRequest { return ErrRefreshNotSupported } - event.WriteSysEvent(ctx, op, "sessions updated", "deleted", numDeleted, "upserted", len(resp), "user_id", u.Id) + event.WriteSysEvent(ctx, op, "sessions updated", "deleted", numDeleted, "upserted", numUpserted, "user id", u.Id) return nil } @@ -186,12 +207,12 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m // Find and use a token for retrieving sessions var gotResponse bool - var resp []*sessions.Session + var resp *sessions.SessionListResult var newRefreshToken RefreshTokenValue var unsupportedCacheRequest bool var retErr error for at, t := range tokens { - resp, _, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "") + resp, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "", nil, WithUseNonPagedListing(true)) if err != nil { if err == ErrRefreshNotSupported { unsupportedCacheRequest = true @@ -217,24 +238,29 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error { switch { case unsupportedCacheRequest: + // Since we know the controller doesn't support caching, we mark the + // user as unable to cache the data. if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { return err } case newRefreshToken != "": + // Now that there is a refresh token, the data can be cached, so + // cache it and store the refresh token for future refreshes. First + // remove any values, then add the new ones var err error if numDeleted, err = w.Exec(ctx, "delete from session where fk_user_id = @fk_user_id", []any{sql.Named("fk_user_id", u.Id)}); err != nil { return err } - if err := upsertSessions(ctx, w, u, resp); err != nil { + if err := upsertSessions(ctx, w, u, resp.Items); err != nil { return err } if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { return err } default: - // This is no longer flagged as not supported, but we dont have a - // refresh token so clear out any refresh token we have stored. + // We know the controller supports caching, but doesn't have a + // refresh token so clear out any refresh token we have for this resource. if err := deleteRefreshToken(ctx, w, u, resourceType); err != nil { return err } @@ -247,7 +273,7 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m if unsupportedCacheRequest { return ErrRefreshNotSupported } - event.WriteSysEvent(ctx, op, "sessions updated", "deleted", numDeleted, "upserted", len(resp), "user_id", u.Id) + event.WriteSysEvent(ctx, op, "sessions updated", "deleted", numDeleted, "upserted", len(resp.Items), "user_id", u.Id) return nil } diff --git a/internal/clientcache/internal/cache/repository_sessions_test.go b/internal/clientcache/internal/cache/repository_sessions_test.go index 4627f8cd88..5cc31dca0a 100644 --- a/internal/clientcache/internal/cache/repository_sessions_test.go +++ b/internal/clientcache/internal/cache/repository_sessions_test.go @@ -150,7 +150,7 @@ func TestRepository_refreshSessions(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { err := r.refreshSessions(ctx, tc.u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{tc.sess}, [][]string{nil}))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{tc.sess}, [][]string{nil})))) if tc.errorContains == "" { assert.NoError(t, err) rw := db.New(s) @@ -235,7 +235,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { } err = r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil}))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil})))) assert.NoError(t, err) got, err := r.ListSessions(ctx, at.Id) @@ -245,7 +245,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { // Refreshing again uses the refresh token and get additional sessions, appending // them to the response err = r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil}))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil})))) assert.NoError(t, err) got, err = r.ListSessions(ctx, at.Id) @@ -255,7 +255,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { // Refreshing again wont return any more resources, but also none should be // removed require.NoError(t, r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil})))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil}))))) assert.NoError(t, err) got, err = r.ListSessions(ctx, at.Id) @@ -264,7 +264,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { // Refresh again with the refresh token being reported as invalid. require.NoError(t, r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testErroringForRefreshTokenRetrievalFunc(t, ss[0])))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testErroringForRefreshTokenRetrievalFunc(t, ss[0]))))) assert.NoError(t, err) got, err = r.ListSessions(ctx, at.Id) @@ -351,7 +351,7 @@ func TestRepository_ListSessions(t *testing.T) { }, } require.NoError(t, r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil})))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil}))))) t.Run("wrong user gets no sessions", func(t *testing.T) { l, err := r.ListSessions(ctx, kt2.AuthTokenId) @@ -395,7 +395,7 @@ func TestRepository_ListSessionsLimiting(t *testing.T) { ts = append(ts, session("s"+strconv.Itoa(i))) } require.NoError(t, r.refreshSessions(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ts}, [][]string{nil})))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ts}, [][]string{nil}))))) searchService, err := NewSearchService(ctx, r) require.NoError(t, err) @@ -527,7 +527,7 @@ func TestRepository_QuerySessions(t *testing.T) { }, } require.NoError(t, r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil})))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil}))))) t.Run("wrong token gets no sessions", func(t *testing.T) { l, err := r.QuerySessions(ctx, kt2.AuthTokenId, query) @@ -571,7 +571,7 @@ func TestRepository_QuerySessionsLimiting(t *testing.T) { ts = append(ts, session("t"+strconv.Itoa(i))) } require.NoError(t, r.refreshSessions(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ts}, [][]string{nil})))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ts}, [][]string{nil}))))) searchService, err := NewSearchService(ctx, r) require.NoError(t, err) @@ -637,15 +637,15 @@ func TestDefaultSessionRetrievalFunc(t *testing.T) { } } - got, removed, refTok, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, "") + got, refTok, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, "", nil) assert.NoError(t, err) assert.NotEmpty(t, refTok) - assert.Empty(t, removed) - assert.Len(t, got, 1) + assert.Empty(t, got.RemovedIds) + assert.Len(t, got.Items, 1) - got2, removed2, refTok2, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, refTok) + got2, refTok2, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, refTok, nil) assert.NoError(t, err) assert.NotEmpty(t, refTok2) - assert.Empty(t, removed2) - assert.Empty(t, got2) + assert.Empty(t, got2.RemovedIds) + assert.Empty(t, got2.Items) } diff --git a/internal/clientcache/internal/cache/status_test.go b/internal/clientcache/internal/cache/status_test.go index 67e34fc2ab..34b3d2fb2e 100644 --- a/internal/clientcache/internal/cache/status_test.go +++ b/internal/clientcache/internal/cache/status_test.go @@ -203,7 +203,7 @@ func TestStatus(t *testing.T) { session("3"), } err := r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{sess}, [][]string{nil}))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{sess}, [][]string{nil})))) require.NoError(t, err) als := []*aliases.Alias{ @@ -212,7 +212,7 @@ func TestStatus(t *testing.T) { alias("3"), } err = r.refreshResolvableAliases(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{als}, [][]string{nil}))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{als}, [][]string{nil})))) require.NoError(t, err) got, err := ss.Status(ctx) @@ -308,7 +308,7 @@ func TestStatus_unsupported(t *testing.T) { })) err = r.refreshResolvableAliases(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)))) require.ErrorIs(t, err, ErrRefreshNotSupported) err = r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, @@ -316,7 +316,7 @@ func TestStatus_unsupported(t *testing.T) { require.ErrorIs(t, err, ErrRefreshNotSupported) err = r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) require.ErrorIs(t, err, ErrRefreshNotSupported) got, err := ss.Status(ctx) diff --git a/internal/clientcache/internal/daemon/testing.go b/internal/clientcache/internal/daemon/testing.go index 0689d9cd4b..ffd24d7f59 100644 --- a/internal/clientcache/internal/daemon/testing.go +++ b/internal/clientcache/internal/daemon/testing.go @@ -91,11 +91,13 @@ func (s *TestServer) AddResources(t *testing.T, p *authtokens.AuthToken, alts [] r, err := cache.NewRepository(ctx, s.CacheServer.store.Load(), &sync.Map{}, s.cmd.ReadTokenFromKeyring, atReadFn) require.NoError(t, err) - altFn := func(ctx context.Context, _, tok, _ string, _ cache.RefreshTokenValue) ([]*aliases.Alias, []string, cache.RefreshTokenValue, error) { + altFn := func(ctx context.Context, _ string, tok, _ string, _ cache.RefreshTokenValue, inPage *aliases.AliasListResult, opt ...cache.Option) (*aliases.AliasListResult, cache.RefreshTokenValue, error) { if tok != p.Token { - return nil, nil, "", nil + return nil, "", nil } - return alts, nil, "addedaliases", nil + return &aliases.AliasListResult{ + Items: alts, + }, "addedaliases", nil } tarFn := func(ctx context.Context, _ string, tok string, _ cache.RefreshTokenValue, inPage *targets.TargetListResult, opt ...cache.Option) (*targets.TargetListResult, cache.RefreshTokenValue, error) { if tok != p.Token { @@ -105,11 +107,13 @@ func (s *TestServer) AddResources(t *testing.T, p *authtokens.AuthToken, alts [] Items: tars, }, "addedtargets", nil } - sessFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue) ([]*sessions.Session, []string, cache.RefreshTokenValue, error) { + sessFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue, inPage *sessions.SessionListResult, opt ...cache.Option) (*sessions.SessionListResult, cache.RefreshTokenValue, error) { if tok != p.Token { - return nil, nil, "", nil + return nil, "", nil } - return sess, nil, "addedsessions", nil + return &sessions.SessionListResult{ + Items: sess, + }, "addedsessions", nil } rs, err := cache.NewRefreshService(ctx, r, hclog.NewNullLogger(), 0, 0) require.NoError(t, err) @@ -136,11 +140,15 @@ func (s *TestServer) AddUnsupportedCachingData(t *testing.T, p *authtokens.AuthT }, }, "", cache.ErrRefreshNotSupported } - sessFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue) ([]*sessions.Session, []string, cache.RefreshTokenValue, error) { + sessFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue, inPage *sessions.SessionListResult, opt ...cache.Option) (*sessions.SessionListResult, cache.RefreshTokenValue, error) { if tok != p.Token { - return nil, nil, "", nil + return &sessions.SessionListResult{}, "", nil } - return []*sessions.Session{}, nil, "", cache.ErrRefreshNotSupported + return &sessions.SessionListResult{ + Items: []*sessions.Session{ + {Id: "s_unsupported"}, + }, + }, "", cache.ErrRefreshNotSupported } rs, err := cache.NewRefreshService(ctx, r, hclog.NewNullLogger(), 0, 0) require.NoError(t, err)