Skip to content

Commit

Permalink
atlasexec: add version command (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
masseelch authored Sep 5, 2023
1 parent 5cd15f2 commit b09354b
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 19 deletions.
76 changes: 57 additions & 19 deletions atlasexec/atlas.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
)
Expand Down Expand Up @@ -119,19 +120,19 @@ func NewClient(workingDir, execPath string) (*Client, error) {
}

// Apply runs the 'migrate apply' command.
// @deprecated use MigrateApply instead.
// Deprecated: use MigrateApply instead.
func (c *Client) Apply(ctx context.Context, params *MigrateApplyParams) (*MigrateApply, error) {
return c.MigrateApply(ctx, params)
}

// Lint runs the 'migrate lint' command.
// @deprecated use MigrateLint instead.
// Deprecated: use MigrateLint instead.
func (c *Client) Lint(ctx context.Context, params *MigrateLintParams) (*SummaryReport, error) {
return c.MigrateLint(ctx, params)
}

// Status runs the 'migrate status' command.
// @deprecated use MigrateStatus instead.
// Deprecated: use MigrateStatus instead.
func (c *Client) Status(ctx context.Context, params *MigrateStatusParams) (*MigrateStatus, error) {
return c.MigrateStatus(ctx, params)
}
Expand All @@ -141,7 +142,7 @@ func (c *Client) Login(ctx context.Context, params *LoginParams) error {
if params.Token == "" {
return errors.New("token cannot be empty")
}
_, err := c.runCommand(ctx, []string{"login", "--token", params.Token})
_, err := c.runCommand(ctx, []string{"login", "--token", params.Token}, validJSON)
return err
}

Expand Down Expand Up @@ -184,7 +185,7 @@ func (c *Client) MigratePush(ctx context.Context, params *MigratePushParams) (st
} else {
args = append(args, params.Name)
}
return stringVal(c.runCommand(ctx, args))
return stringVal(c.runCommand(ctx, args, validJSON))
}

// MigrateApply runs the 'migrate apply' command.
Expand Down Expand Up @@ -215,7 +216,7 @@ func (c *Client) MigrateApply(ctx context.Context, params *MigrateApplyParams) (
args = append(args, strconv.FormatUint(params.Amount, 10))
}
args = append(args, params.Vars.AsArgs()...)
return jsonDecode[MigrateApply](c.runCommand(ctx, args))
return jsonDecode[MigrateApply](c.runCommand(ctx, args, validJSON))
}

// SchemaApply runs the 'schema apply' command.
Expand Down Expand Up @@ -248,7 +249,7 @@ func (c *Client) SchemaApply(ctx context.Context, params *SchemaApplyParams) (*S
args = append(args, "--exclude", strings.Join(params.Exclude, ","))
}
args = append(args, params.Vars.AsArgs()...)
return jsonDecode[SchemaApply](c.runCommand(ctx, args))
return jsonDecode[SchemaApply](c.runCommand(ctx, args, validJSON))
}

// SchemaInspect runs the 'schema inspect' command.
Expand Down Expand Up @@ -298,7 +299,7 @@ func (c *Client) MigrateLint(ctx context.Context, params *MigrateLintParams) (*S
args = append(args, "--latest", strconv.FormatUint(params.Latest, 10))
}
args = append(args, params.Vars.AsArgs()...)
return jsonDecode[SummaryReport](c.runCommand(ctx, args))
return jsonDecode[SummaryReport](c.runCommand(ctx, args, validJSON))
}

// MigrateStatus runs the 'migrate status' command.
Expand All @@ -320,12 +321,38 @@ func (c *Client) MigrateStatus(ctx context.Context, params *MigrateStatusParams)
args = append(args, "--revisions-schema", params.RevisionsSchema)
}
args = append(args, params.Vars.AsArgs()...)
return jsonDecode[MigrateStatus](c.runCommand(ctx, args))
return jsonDecode[MigrateStatus](c.runCommand(ctx, args, validJSON))
}

var reVersion = regexp.MustCompile(`^atlas version v(\d+\.\d+.\d+)-?([a-z0-9]*)?`)

// Version runs the 'version' command.
func (c *Client) Version(ctx context.Context) (*Version, error) {
r, err := c.runCommand(ctx, []string{"version"})
if err != nil {
return nil, err
}
out, err := io.ReadAll(r)
if err != nil {
return nil, err
}
v := reVersion.FindSubmatch(out)
if v == nil {
return nil, errors.New("unexpected output format")
}
var sha string
if len(v) > 2 {
sha = string(v[2])
}
return &Version{
Version: string(v[1]),
SHA: sha,
Canary: strings.Contains(string(out), "canary"),
}, nil
}

// runCommand runs the given command and unmarshals the output into the given
// interface.
func (c *Client) runCommand(ctx context.Context, args []string) (io.Reader, error) {
// runCommand runs the given command and returns its output.
func (c *Client) runCommand(ctx context.Context, args []string, vs ...validator) (io.Reader, error) {
var stdout, stderr bytes.Buffer
cmd := exec.CommandContext(ctx, c.execPath, args...)
cmd.Dir = c.workingDir
Expand All @@ -343,13 +370,6 @@ func (c *Client) runCommand(ctx context.Context, args []string) (io.Reader, erro
summary: strings.TrimSpace(stderr.String()),
detail: strings.TrimSpace(stdout.String()),
}
case !json.Valid(stdout.Bytes()):
// When the output is not valid JSON, it means that
// the command failed.
return nil, &cliError{
summary: "Atlas CLI",
detail: strings.TrimSpace(stdout.String()),
}
case cmd.ProcessState.ExitCode() == 1:
// When the exit code is 1, it means that the command
// failed but the output is still valid JSON.
Expand All @@ -362,6 +382,12 @@ func (c *Client) runCommand(ctx context.Context, args []string) (io.Reader, erro
return nil, err
}
}
out := stdout.Bytes()
for _, v := range vs {
if err := v(out); err != nil {
return nil, err
}
}
return &stdout, nil
}

Expand Down Expand Up @@ -470,3 +496,15 @@ func jsonDecode[T any](r io.Reader, err error) (*T, error) {
}
return &dst, nil
}

type validator func([]byte) error

func validJSON(d []byte) error {
if !json.Valid(d) {
return &cliError{
summary: "Atlas CLI",
detail: strings.TrimSpace(string(d)),
}
}
return nil
}
6 changes: 6 additions & 0 deletions atlasexec/atlas_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ type (
ErrorStmt string `json:"ErrorStmt,omitempty"` // ErrorStmt is the statement that raised Error.
OperatorVersion string `json:"OperatorVersion"` // OperatorVersion that executed this migration.
}
// Version contains the result of an 'atlas version' run.
Version struct {
Version string `json:"Version"`
SHA string `json:"SHA,omitempty"`
Canary bool `json:"Canary,omitempty"`
}
)

type (
Expand Down
39 changes: 39 additions & 0 deletions atlasexec/atlas_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,45 @@ schema "main" {
`, s)
}

func TestVersion(t *testing.T) {
wd, err := os.Getwd()
require.NoError(t, err)
c, err := atlasexec.NewClient(t.TempDir(), filepath.Join(wd, "./mock-atlas.sh"))
require.NoError(t, err)

for _, tt := range []struct {
env string
expect *atlasexec.Version
}{
{
env: "",
expect: &atlasexec.Version{Version: "1.2.3"},
},
{
env: "v0.14.1-abcdef-canary",
expect: &atlasexec.Version{
Version: "0.14.1",
SHA: "abcdef",
Canary: true,
},
},
{
env: "v11.22.33-sha",
expect: &atlasexec.Version{
Version: "11.22.33",
SHA: "sha",
},
},
} {
t.Run(tt.env, func(t *testing.T) {
t.Setenv("TEST_ATLAS_VERSION", tt.env)
v, err := c.Version(context.Background())
require.NoError(t, err)
require.Equal(t, tt.expect, v)
})
}
}

func sqlitedb(t *testing.T) string {
td := t.TempDir()
dbpath := filepath.Join(td, "file.db")
Expand Down
5 changes: 5 additions & 0 deletions atlasexec/mock-atlas.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash

TEST_ATLAS_VERSION="${TEST_ATLAS_VERSION:-v1.2.3}"

echo "atlas version $TEST_ATLAS_VERSION"

0 comments on commit b09354b

Please sign in to comment.