diff --git a/cmd/nginx-server/main.go b/cmd/nginx-server/main.go index 8aed329..b53d6e2 100644 --- a/cmd/nginx-server/main.go +++ b/cmd/nginx-server/main.go @@ -7,12 +7,15 @@ import ( "os" "path/filepath" "syscall" + "time" // Packages server "github.com/mutablelogic/go-server" ctx "github.com/mutablelogic/go-server/pkg/context" + auth "github.com/mutablelogic/go-server/pkg/handler/auth" nginx "github.com/mutablelogic/go-server/pkg/handler/nginx" router "github.com/mutablelogic/go-server/pkg/handler/router" + tokenjar "github.com/mutablelogic/go-server/pkg/handler/tokenjar" httpserver "github.com/mutablelogic/go-server/pkg/httpserver" logger "github.com/mutablelogic/go-server/pkg/middleware/logger" provider "github.com/mutablelogic/go-server/pkg/provider" @@ -46,6 +49,24 @@ func main() { log.Fatal(err) } + // Token Jar + jar, err := tokenjar.Config{ + DataPath: n.(nginx.Nginx).Config(), + WriteInterval: 30 * time.Second, + }.New() + if err != nil { + log.Fatal(err) + } + + // Auth handler + auth, err := auth.Config{ + TokenJar: jar.(auth.TokenJar), + TokenBytes: 8, + }.New() + if err != nil { + log.Fatal(err) + } + // Location of the FCGI unix socket socket := filepath.Join(n.(nginx.Nginx).Config(), "run/go-server.sock") @@ -58,6 +79,12 @@ func main() { logger.(server.Middleware), }, }, + "auth": { // /api/auth/... + Service: auth.(server.ServiceEndpoints), + Middleware: []server.Middleware{ + logger.(server.Middleware), + }, + }, }, }.New() if err != nil { @@ -75,7 +102,7 @@ func main() { } // Run until we receive an interrupt - provider := provider.NewProvider(logger, n, router, httpserver) + provider := provider.NewProvider(logger, n, jar, auth, router, httpserver) provider.Print(ctx, "Press CTRL+C to exit") if err := provider.Run(ctx); err != nil { log.Fatal(err) diff --git a/pkg/handler/auth/config.go b/pkg/handler/auth/config.go new file mode 100644 index 0000000..d75671e --- /dev/null +++ b/pkg/handler/auth/config.go @@ -0,0 +1,44 @@ +package auth + +import ( + // Packages + server "github.com/mutablelogic/go-server" +) + +//////////////////////////////////////////////////////////////////////////// +// TYPES + +type Config struct { + TokenJar TokenJar `hcl:"token_jar" description:"Persistent storage for tokens"` + TokenBytes int `hcl:"token_bytes" description:"Number of bytes in a token"` +} + +// Check interfaces are satisfied +var _ server.Plugin = Config{} + +//////////////////////////////////////////////////////////////////////////// +// GLOBALS + +const ( + defaultName = "auth-handler" + defaultTokenBytes = 16 + defaultRootNme = "root" +) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Name returns the name of the service +func (Config) Name() string { + return defaultName +} + +// Description returns the description of the service +func (Config) Description() string { + return "token and group management for authentication and authorisation" +} + +// Create a new task from the configuration +func (c Config) New() (server.Task, error) { + return New(c) +} diff --git a/pkg/handler/auth/endpoints.go b/pkg/handler/auth/endpoints.go new file mode 100644 index 0000000..21d8dfb --- /dev/null +++ b/pkg/handler/auth/endpoints.go @@ -0,0 +1,155 @@ +package auth + +import ( + "context" + "net/http" + "regexp" + "strings" + "time" + + // Packages + server "github.com/mutablelogic/go-server" + router "github.com/mutablelogic/go-server/pkg/handler/router" + httprequest "github.com/mutablelogic/go-server/pkg/httprequest" + httpresponse "github.com/mutablelogic/go-server/pkg/httpresponse" +) + +/////////////////////////////////////////////////////////////////////////////// +// GLOBALS + +const ( + jsonIndent = 2 + + // Token should be at least eight bytes (16 chars) + reTokenString = `[a-zA-Z0-9]{16}[a-zA-Z0-9]*` +) + +var ( + reRoot = regexp.MustCompile(`^/?$`) + reToken = regexp.MustCompile(`^/(` + reTokenString + `)/?$`) +) + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - ENDPOINTS + +// Add endpoints to the router +func (service *auth) AddEndpoints(ctx context.Context, router server.Router) { + // Path: / + // Methods: GET + // Scopes: read // TODO: Add scopes + // Description: Get current set of tokens and groups + router.AddHandlerFuncRe(ctx, reRoot, service.ListTokens, http.MethodGet) + + // Path: / + // Methods: POST + // Scopes: write // TODO: Add scopes + // Description: Create a new token + router.AddHandlerFuncRe(ctx, reRoot, service.CreateToken, http.MethodPost) + + // Path: / + // Methods: GET + // Scopes: read // TODO: Add scopes + // Description: Get a token + router.AddHandlerFuncRe(ctx, reToken, service.GetToken, http.MethodGet) + + // Path: / + // Methods: DELETE, PATCH + // Scopes: write // TODO: Add scopes + // Description: Delete or update a token + router.AddHandlerFuncRe(ctx, reToken, service.UpdateToken, http.MethodDelete, http.MethodPatch) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Get all tokens +func (service *auth) ListTokens(w http.ResponseWriter, r *http.Request) { + tokens := service.jar.Tokens() + result := make([]*Token, 0, len(tokens)) + for _, token := range tokens { + token.Value = "" + result = append(result, &token) + } + httpresponse.JSON(w, result, http.StatusOK, jsonIndent) +} + +// Get a token +func (service *auth) GetToken(w http.ResponseWriter, r *http.Request) { + urlParameters := router.Params(r.Context()) + token := service.jar.GetWithValue(strings.ToLower(urlParameters[0])) + if token.IsZero() { + httpresponse.Error(w, http.StatusNotFound) + return + } + + // Remove the token value before returning + token.Value = "" + + // Return the token + httpresponse.JSON(w, token, http.StatusOK, jsonIndent) +} + +// Create a token +func (service *auth) CreateToken(w http.ResponseWriter, r *http.Request) { + var req TokenCreate + + // Get the request + if err := httprequest.Read(r, &req); err != nil { + httpresponse.Error(w, http.StatusBadRequest, err.Error()) + return + } + + // Check for a valid name + req.Name = strings.TrimSpace(req.Name) + if req.Name == "" { + httpresponse.Error(w, http.StatusBadRequest, "missing 'name'") + } else if token := service.jar.GetWithName(req.Name); token.IsValid() { + httpresponse.Error(w, http.StatusConflict, "duplicate 'name'") + } + + // Create the token + token := NewToken(req.Name, service.tokenBytes, req.Duration.Duration, req.Scope...) + if !token.IsValid() { + httpresponse.Error(w, http.StatusInternalServerError) + return + } + + // Add the token + if err := service.jar.Create(token); err != nil { + httpresponse.Error(w, http.StatusInternalServerError, err.Error()) + return + } + + // Remove the access_time which doesn't make sense when + // creating a token + token.Time = time.Time{} + + // Return the token + httpresponse.JSON(w, token, http.StatusCreated, jsonIndent) +} + +// Update an existing token +func (service *auth) UpdateToken(w http.ResponseWriter, r *http.Request) { + urlParameters := router.Params(r.Context()) + token := service.jar.GetWithValue(strings.ToLower(urlParameters[0])) + if token.IsZero() { + httpresponse.Error(w, http.StatusNotFound) + return + } + + switch r.Method { + case http.MethodDelete: + if err := service.jar.Delete(token.Value); err != nil { + httpresponse.Error(w, http.StatusInternalServerError, err.Error()) + return + } + default: + // TODO: PATCH + // Patch can be with name, expire_time, scopes + httpresponse.Error(w, http.StatusMethodNotAllowed) + return + } + + // Respond with no content + httpresponse.Empty(w, http.StatusOK) +} diff --git a/pkg/handler/auth/interface.go b/pkg/handler/auth/interface.go new file mode 100644 index 0000000..bd81698 --- /dev/null +++ b/pkg/handler/auth/interface.go @@ -0,0 +1,32 @@ +package auth + +import ( + "context" +) + +type TokenJar interface { + // Run the token jar until cancelled + Run(context.Context) error + + // Return all tokens + Tokens() []Token + + // Return a token from the jar by value, or an invalid token + // if the token is not found. The method should update the access + // time of the token. + GetWithValue(string) Token + + // Return a token from the jar by name, or nil if the token + // is not found. The method should not update the access time + // of the token. + GetWithName(string) Token + + // Put a token into the jar, assuming it does not yet exist. + Create(Token) error + + // Update an existing token in the jar, assuming it already exists. + Update(Token) error + + // Remove a token from the jar, based on key. + Delete(string) error +} diff --git a/pkg/handler/auth/scope.go b/pkg/handler/auth/scope.go new file mode 100644 index 0000000..5657ef8 --- /dev/null +++ b/pkg/handler/auth/scope.go @@ -0,0 +1,8 @@ +package auth + +import "github.com/mutablelogic/go-server/pkg/version" + +var ( + // Root scope allows ANY operation + ScopeRoot = version.GitSource + "scope/root" +) diff --git a/pkg/handler/auth/task.go b/pkg/handler/auth/task.go new file mode 100644 index 0000000..4811772 --- /dev/null +++ b/pkg/handler/auth/task.go @@ -0,0 +1,81 @@ +package auth + +import ( + "context" + + // Packages + server "github.com/mutablelogic/go-server" + "github.com/mutablelogic/go-server/pkg/provider" + + // Namespace imports + . "github.com/djthorpe/go-errors" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type auth struct { + jar TokenJar + tokenBytes int +} + +// Check interfaces are satisfied +var _ server.Task = (*auth)(nil) +var _ server.ServiceEndpoints = (*auth)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Create a new auth task from the configuration +func New(c Config) (*auth, error) { + task := new(auth) + + // Set token jar + if c.TokenJar == nil { + return nil, ErrInternalAppError.With("missing 'tokenjar'") + } else { + task.jar = c.TokenJar + } + + // Set token bytes + if c.TokenBytes <= 0 { + task.tokenBytes = defaultTokenBytes + } else { + task.tokenBytes = c.TokenBytes + } + + // Return success + return task, nil +} + +///////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return the label +func (task *auth) Label() string { + // TODO + return defaultName +} + +// Run the task until the context is cancelled +func (task *auth) Run(ctx context.Context) error { + var result error + + // Logger + logger := provider.Logger(ctx) + + // If there are no tokens, then create a "root" token + if tokens := task.jar.Tokens(); len(tokens) == 0 { + token := NewToken(defaultRootNme, task.tokenBytes, 0, ScopeRoot) + logger.Printf(ctx, "Creating root token %q for scope %q", token.Value, ScopeRoot) + if err := task.jar.Create(token); err != nil { + return err + } + } + + // Run the task until cancelled + <-ctx.Done() + + // Return any errors + return result +} diff --git a/pkg/handler/auth/task_test.go b/pkg/handler/auth/task_test.go new file mode 100644 index 0000000..fbb29c6 --- /dev/null +++ b/pkg/handler/auth/task_test.go @@ -0,0 +1,39 @@ +package auth_test + +import ( + "context" + "testing" + "time" + + // Packages + + "github.com/mutablelogic/go-server/pkg/handler/auth" + "github.com/mutablelogic/go-server/pkg/handler/tokenjar" + "github.com/mutablelogic/go-server/pkg/provider" + "github.com/stretchr/testify/assert" +) + +func Test_auth_001(t *testing.T) { + assert := assert.New(t) + + // Create a token jar and auth object + jar, err := tokenjar.New(tokenjar.Config{ + DataPath: t.TempDir(), + }) + assert.NoError(err) + tokens, err := auth.New(auth.Config{ + TokenJar: jar, + }) + assert.NoError(err) + + // Create a provider + provider := provider.NewProvider(jar, tokens) + assert.NotNil(provider) + + // Run the provider + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + // Run the provider + assert.NoError(provider.Run(ctx)) +} diff --git a/pkg/middleware/tokenauth/token.go b/pkg/handler/auth/token.go similarity index 58% rename from pkg/middleware/tokenauth/token.go rename to pkg/handler/auth/token.go index 5df6eaf..ef2d071 100644 --- a/pkg/middleware/tokenauth/token.go +++ b/pkg/handler/auth/token.go @@ -1,4 +1,4 @@ -package tokenauth +package auth import ( "bytes" @@ -6,12 +6,9 @@ import ( "encoding/hex" "encoding/json" "fmt" + "slices" "strconv" "time" - - // Package imports - - slices "golang.org/x/exp/slices" ) ///////////////////////////////////////////////////////////////////// @@ -26,19 +23,27 @@ type Token struct { } type TokenCreate struct { - Duration time.Duration `json:"duration,omitempty"` // Duration of the token, or zero for no expiration - Scope []string `json:"scopes,omitempty"` // Authentication scopes + Name string `json:"name,omitempty"` // Name of the token + Duration duration `json:"duration,omitempty"` // Duration of the token, or zero for no expiration + Scope []string `json:"scopes,omitempty"` // Authentication scopes +} + +type duration struct { + time.Duration } ///////////////////////////////////////////////////////////////////// // LIFECYCLE -func NewToken(length int, duration time.Duration, scope ...string) *Token { +// Create a token of the specified number of bytes, with the specified duration and scope. +// If the duration is zero, the token will not expire. +func NewToken(name string, length int, duration time.Duration, scope ...string) Token { var expire time.Time if duration != 0 { expire = time.Now().Add(duration) } - return &Token{ + return Token{ + Name: name, Value: generateToken(length), Time: time.Now(), Scope: scope, @@ -49,37 +54,50 @@ func NewToken(length int, duration time.Duration, scope ...string) *Token { ///////////////////////////////////////////////////////////////////// // STRINGIFY -func (t *Token) String() string { - str := " 0 { - str += fmt.Sprintf(" scopes=%q", t.Scope) - } - if t.IsValid() { - str += " valid" - } - return str + ">" +func (t Token) String() string { + data, _ := json.MarshalIndent(t, "", " ") + return string(data) } ///////////////////////////////////////////////////////////////////// // PUBLIC METHODS +// Compares token name, value, expiry and scopes +func (t Token) Equals(other Token) bool { + if t.Name != other.Name || t.Value != other.Value || t.Expire != other.Expire { + return false + } + for _, scope := range other.Scope { + if !slices.Contains(t.Scope, scope) { + return false + } + } + for _, scope := range t.Scope { + if !slices.Contains(other.Scope, scope) { + return false + } + } + return true +} + // Return true if the token is valid (not expired) -func (t *Token) IsValid() bool { +func (t Token) IsValid() bool { if t.Expire.IsZero() || t.Expire.After(time.Now()) { return true } return false } +// Return true if the token is a zero token +func (t Token) IsZero() bool { + if t.Name == "" && t.Value == "" && t.Expire.IsZero() && t.Time.IsZero() && len(t.Scope) == 0 { + return true + } + return false +} + // Return true if the token has the specified scope, and is valid -func (t *Token) IsScope(scopes ...string) bool { +func (t Token) IsScope(scopes ...string) bool { if !t.IsValid() { return false } @@ -94,11 +112,9 @@ func (t *Token) IsScope(scopes ...string) bool { ///////////////////////////////////////////////////////////////////// // JSON MARSHAL -func (t *Token) MarshalJSON() ([]byte, error) { +func (t Token) MarshalJSON() ([]byte, error) { var buf bytes.Buffer - if t == nil { - return []byte("null"), nil - } + buf.WriteRune('{') // Write the fields @@ -141,6 +157,31 @@ func (t *Token) MarshalJSON() ([]byte, error) { return buf.Bytes(), nil } +func (d duration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.String()) +} + +func (d *duration) UnmarshalJSON(b []byte) error { + var v any + if err := json.Unmarshal(b, &v); err != nil { + return err + } + switch value := v.(type) { + case float64: + d.Duration = time.Duration(value) * time.Second + return nil + case string: + var err error + d.Duration, err = time.ParseDuration(value) + if err != nil { + return err + } + return nil + default: + return fmt.Errorf("invalid duration of type %T", v) + } +} + ///////////////////////////////////////////////////////////////////// // PRIVATE METHODS diff --git a/pkg/middleware/tokenauth/token_test.go b/pkg/handler/auth/token_test.go similarity index 60% rename from pkg/middleware/tokenauth/token_test.go rename to pkg/handler/auth/token_test.go index d90bee7..6c79c1d 100644 --- a/pkg/middleware/tokenauth/token_test.go +++ b/pkg/handler/auth/token_test.go @@ -1,10 +1,12 @@ -package tokenauth_test +package auth_test import ( + "encoding/json" "testing" "time" - "github.com/mutablelogic/go-server/pkg/middleware/tokenauth" + // Packages + "github.com/mutablelogic/go-server/pkg/handler/auth" "github.com/stretchr/testify/assert" ) @@ -16,8 +18,7 @@ const ( func Test_token_001(t *testing.T) { assert := assert.New(t) // Create a token with 100 bytes and no expiry - token := tokenauth.NewToken(100, 0) - assert.NotNil(token) + token := auth.NewToken("test", 100, 0) assert.Equal(200, len(token.Value)) assert.True(token.IsValid()) t.Log(token) @@ -26,8 +27,7 @@ func Test_token_001(t *testing.T) { func Test_token_002(t *testing.T) { assert := assert.New(t) // Create a token with 100 bytes and expiry in the past - token := tokenauth.NewToken(100, -time.Second) - assert.NotNil(token) + token := auth.NewToken("test", 100, -time.Second) assert.False(token.IsValid()) assert.Equal(200, len(token.Value)) t.Log(token) @@ -36,8 +36,7 @@ func Test_token_002(t *testing.T) { func Test_token_003(t *testing.T) { assert := assert.New(t) // Create a token with 100 bytes and one scope - token := tokenauth.NewToken(100, 0, ScopeRead) - assert.NotNil(token) + token := auth.NewToken("test", 100, 0, ScopeRead) assert.True(token.IsScope(ScopeRead)) assert.False(token.IsScope(ScopeWrite)) t.Log(token) @@ -46,8 +45,18 @@ func Test_token_003(t *testing.T) { func Test_token_004(t *testing.T) { assert := assert.New(t) // Create a token with 100 bytes and one scope - token := tokenauth.NewToken(100, -1, ScopeRead) - assert.NotNil(token) + token := auth.NewToken("test", 100, -1, ScopeRead) assert.False(token.IsScope(ScopeRead)) t.Log(token) } + +func Test_token_005(t *testing.T) { + assert := assert.New(t) + + // Create a token with 100 bytes and one scope + token := auth.NewToken("test", 100, -1, ScopeRead) + bytes, err := json.MarshalIndent(token, "", " ") + assert.NoError(err) + assert.NotNil(bytes) + t.Log(string(bytes)) +} diff --git a/pkg/handler/tokenjar/config.go b/pkg/handler/tokenjar/config.go new file mode 100644 index 0000000..d2216e6 --- /dev/null +++ b/pkg/handler/tokenjar/config.go @@ -0,0 +1,44 @@ +package tokenjar + +import ( + "time" + + // Packages + server "github.com/mutablelogic/go-server" +) + +//////////////////////////////////////////////////////////////////////////// +// TYPES + +type Config struct { + DataPath string `hcl:"datapath" description:"Path to persistent data"` + WriteInterval time.Duration `hcl:"write-interval" description:"Interval to write data to disk"` +} + +// Check interfaces are satisfied +var _ server.Plugin = Config{} + +//////////////////////////////////////////////////////////////////////////// +// GLOBALS + +const ( + defaultName = "tokenjar-handler" +) + +// ///////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Name returns the name of the service +func (Config) Name() string { + return defaultName +} + +// Description returns the description of the service +func (Config) Description() string { + return "on-disk token persistence" +} + +// Create a new task from the configuration +func (c Config) New() (server.Task, error) { + return New(c) +} diff --git a/pkg/handler/tokenjar/task.go b/pkg/handler/tokenjar/task.go new file mode 100644 index 0000000..6447324 --- /dev/null +++ b/pkg/handler/tokenjar/task.go @@ -0,0 +1,48 @@ +package tokenjar + +import ( + "context" + "time" + + "github.com/mutablelogic/go-server/pkg/provider" +) + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (jar *tokenjar) Label() string { + // TODO + return defaultName +} + +func (jar *tokenjar) Run(ctx context.Context) error { + // Get logger + logger := provider.Logger(ctx) + + // Ticker for writing to disk + ticker := time.NewTicker(jar.writeInterval) + defer ticker.Stop() + + // Loop until cancelled + for { + select { + case <-ticker.C: + if err := jar.write(); err != nil { + logger.Print(ctx, err) + } + case <-ctx.Done(): + return jar.write() + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func (jar *tokenjar) write() error { + if !jar.Modified() { + return nil + } else { + return jar.Write() + } +} diff --git a/pkg/handler/tokenjar/tokenjar.go b/pkg/handler/tokenjar/tokenjar.go new file mode 100644 index 0000000..89b2569 --- /dev/null +++ b/pkg/handler/tokenjar/tokenjar.go @@ -0,0 +1,286 @@ +/* +implements a token jar that stores tokens into memory, and potentially a file +on the file system +*/ +package tokenjar + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "time" + + // Package imports + server "github.com/mutablelogic/go-server" + auth "github.com/mutablelogic/go-server/pkg/handler/auth" + + // Namespace imports + . "github.com/djthorpe/go-errors" +) + +//////////////////////////////////////////////////////////////////////////////// +// TYPES + +type tokenjar struct { + sync.RWMutex + + // Set the write interval to persist the tokens to disk + writeInterval time.Duration + + // The filename to persist the tokens to + filename string + + // Tokens keyed by the token value + jar map[string]*auth.Token + + // Modified flag when the persistenr storage is updated + modified bool +} + +var _ auth.TokenJar = (*tokenjar)(nil) +var _ server.Task = (*tokenjar)(nil) + +//////////////////////////////////////////////////////////////////////////////// +// GLOBALS + +const ( + defaultCap = 20 + defaultFilename = "tokenauth.json" + defaultWriteInterval = time.Minute +) + +//////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Create a new tokenjar, with the specified path. If the path is empty, +// the tokenjar will be in-memory only. +func New(c Config) (*tokenjar, error) { + j := new(tokenjar) + + // Set filepath for persistent storage + if c.DataPath != "" { + if stat, err := os.Stat(c.DataPath); err != nil { + return nil, err + } else if !stat.IsDir() { + return nil, ErrBadParameter.Withf("not a directory: %v", c.DataPath) + } else { + j.filename = filepath.Join(c.DataPath, defaultFilename) + } + } + + // Set write interval + if c.WriteInterval == 0 { + j.writeInterval = defaultWriteInterval + } else { + j.writeInterval = c.WriteInterval + } + + // Read the tokens from persistent storage + var tokens []*auth.Token + if _, err := os.Stat(j.filename); os.IsNotExist(err) { + // Do nothing + } else if err != nil { + return nil, err + } else if tokens_, err := j.Read(); err != nil { + return nil, err + } else { + tokens = tokens_ + } + + // Create the token jar + j.jar = make(map[string]*auth.Token, len(tokens)+defaultCap) + + // Read persistent tokens, bail if there is an inconsistent file + for _, token := range tokens { + if _, exists := j.jar[token.Value]; exists { + return nil, ErrDuplicateEntry.With(token.Value) + } + if token.IsValid() { + j.jar[token.Value] = token + } else { + j.modified = true + } + } + + // Return success + return j, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return true if the jar has been modified +func (jar *tokenjar) Modified() bool { + jar.RLock() + defer jar.RUnlock() + + return jar.modified +} + +// Return all tokens +func (jar *tokenjar) Tokens() []auth.Token { + var result []auth.Token + + // Lock the jar for read + jar.RLock() + defer jar.RUnlock() + + // Copy the tokens + for _, token := range jar.jar { + result = append(result, *token) + } + + // Return the result + return result +} + +// Return a token from the jar. +// The method should update the access time of the token. +// If token is not found, return an empty token. +func (jar *tokenjar) GetWithValue(key string) auth.Token { + jar.Lock() + defer jar.Unlock() + + if token, ok := jar.jar[key]; ok { + token.Time = time.Now() + jar.modified = true + + // Make a copy of the token before returning + return *token + } else { + // Return an empty token - not found + return auth.Token{} + } +} + +// Return a token from the jar by name. +// The method does not update the access time of the token. +// If token is not found, return an empty token. +func (jar *tokenjar) GetWithName(name string) auth.Token { + jar.RLock() + defer jar.RUnlock() + + for _, token := range jar.jar { + if token.Name == name { + return *token + } + } + + return auth.Token{} +} + +// Put a token into the jar, assuming it does not yet exist. +func (jar *tokenjar) Create(token auth.Token) error { + jar.Lock() + defer jar.Unlock() + + // Check if the token already exists + if token.Value == "" { + return ErrBadParameter + } + if _, ok := jar.jar[token.Value]; ok { + return ErrDuplicateEntry + } + + // Update the token + token.Time = time.Now() + jar.jar[token.Value] = &token + jar.modified = true + + // Return success + return nil +} + +// Update an existing token in the jar, assuming it already exists. +func (jar *tokenjar) Update(token auth.Token) error { + jar.Lock() + defer jar.Unlock() + + // Check if the token already exists + if token.Value == "" { + return ErrBadParameter + } + dest, ok := jar.jar[token.Value] + if !ok { + return ErrNotFound + } + + // Update the token + dest.Name = token.Name + dest.Time = time.Now() + dest.Expire = token.Expire + dest.Scope = append([]string{}, token.Scope...) + jar.modified = true + + // Return success + return nil +} + +// Remove a token from the jar +func (jar *tokenjar) Delete(key string) error { + jar.Lock() + defer jar.Unlock() + + // Check if the token already exists + if _, ok := jar.jar[key]; !ok { + return ErrNotFound + } else { + delete(jar.jar, key) + jar.modified = true + } + + // Return success + return nil +} + +// Write the tokens to persistent storage +func (jar *tokenjar) Write() error { + jar.Lock() + defer jar.Unlock() + + // NOP if there is no filename + if jar.filename == "" { + return nil + } + + // Open the file for writing + w, err := os.Create(jar.filename) + if err != nil { + return err + } + defer w.Close() + + // Write the tokens, unset modified flag + jar.modified = false + if err := json.NewEncoder(w).Encode(jar.jar); err != nil { + return err + } + + // Return success + return nil +} + +// Run the token jar +func (jar *tokenjar) Read() ([]*auth.Token, error) { + // NOP if there is no filename + if jar.filename == "" { + return nil, nil + } + + // Open the file for reading + r, err := os.Open(jar.filename) + if err != nil { + return nil, err + } + defer r.Close() + + // Read the tokens + var tokens []*auth.Token + if err := json.NewDecoder(r).Decode(&tokens); err != nil { + return nil, err + } + + // Return success + return tokens, nil +} diff --git a/pkg/handler/tokenjar/tokenjar_test.go b/pkg/handler/tokenjar/tokenjar_test.go new file mode 100644 index 0000000..77edbcb --- /dev/null +++ b/pkg/handler/tokenjar/tokenjar_test.go @@ -0,0 +1,76 @@ +package tokenjar_test + +import ( + "context" + "sync" + "testing" + + // Packages + "github.com/mutablelogic/go-server/pkg/handler/auth" + "github.com/mutablelogic/go-server/pkg/handler/tokenjar" + "github.com/stretchr/testify/assert" +) + +const ( + ScopeRead = "read" + ScopeWrite = "write" +) + +func Test_tokenjar_001(t *testing.T) { + assert := assert.New(t) + + // Create a persistent token jar + tokens, err := tokenjar.New(tokenjar.Config{ + DataPath: t.TempDir(), + }) + assert.NoError(err) + assert.NotNil(tokens) +} + +func Test_tokenjar_002(t *testing.T) { + assert := assert.New(t) + + // Create a persistent token jar + tokens, err := tokenjar.New(tokenjar.Config{ + DataPath: t.TempDir(), + }) + assert.NoError(err) + assert.NotNil(tokens) + + // Run the token jar + var wg sync.WaitGroup + ctx, cancel := context.WithCancel(context.Background()) + + wg.Add(1) + go func() { + defer wg.Done() + assert.NoError(tokens.Run(ctx)) + }() + + // Add a token + token := auth.NewToken("test", 100, 0) + assert.NoError(tokens.Create(token)) + + // Get a token by value + token2 := tokens.GetWithValue(token.Value) + assert.True(token.Equals(token2)) + + // Get a token by name + token9 := tokens.GetWithName(token.Name) + assert.True(token.Equals(token9)) + + // Update a token + token2.Scope = []string{ScopeRead} + assert.NoError(tokens.Update(token2)) + + token3 := tokens.GetWithValue(token.Value) + assert.NotNil(token3) + assert.True(token3.Equals(token2)) + + // Remove a token + assert.NoError(tokens.Delete(token.Value)) + + // Cancel the context and wait + cancel() + wg.Wait() +}