Skip to content

Commit

Permalink
Merge pull request #12 from agoncear-mwb/main
Browse files Browse the repository at this point in the history
Implement QueryRow and Exec methods of sql driver interface
  • Loading branch information
auxten authored Aug 26, 2024
2 parents f3d6a72 + d1572fe commit 671214b
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 7 deletions.
104 changes: 97 additions & 7 deletions chdb/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,60 @@ func init() {
sql.Register("chdb", Driver{})
}

// Row is the result of calling [DB.QueryRow] to select a single row.
type singleRow struct {
// One of these two will be non-nil:
err error // deferred error for easy chaining
rows driver.Rows
}

// Scan copies the columns from the matched row into the values
// pointed at by dest. See the documentation on [Rows.Scan] for details.
// If more than one row matches the query,
// Scan uses the first row and discards the rest. If no row matches
// the query, Scan returns [ErrNoRows].
func (r *singleRow) Scan(dest ...any) error {
if r.err != nil {
return r.err
}
vals := make([]driver.Value, 0)
for _, v := range dest {
vals = append(vals, v)
}
err := r.rows.Next(vals)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return r.rows.Close()
}

// Err provides a way for wrapping packages to check for
// query errors without calling [Row.Scan].
// Err returns the error, if any, that was encountered while running the query.
// If this error is not nil, this error will also be returned from [Row.Scan].
func (r *singleRow) Err() error {
return r.err
}

type execResult struct {
err error
}

func (e *execResult) LastInsertId() (int64, error) {
if e.err != nil {
return 0, e.err
}
return -1, fmt.Errorf("does not support LastInsertId")

}
func (e *execResult) RowsAffected() (int64, error) {
if e.err != nil {
return 0, e.err
}
return -1, fmt.Errorf("does not support RowsAffected")
}

type queryHandle func(string, ...string) (*chdbstable.LocalResult, error)

type connector struct {
Expand Down Expand Up @@ -192,6 +246,18 @@ type conn struct {
QueryFun queryHandle
}

func prepareValues(values []driver.Value) []driver.NamedValue {
namedValues := make([]driver.NamedValue, len(values))
for i, value := range values {
namedValues[i] = driver.NamedValue{
// nb: Name field is optional
Ordinal: i,
Value: value,
}
}
return namedValues
}

func (c *conn) Close() error {
return nil
}
Expand All @@ -204,15 +270,39 @@ func (c *conn) SetupQueryFun() {
}

func (c *conn) Query(query string, values []driver.Value) (driver.Rows, error) {
namedValues := make([]driver.NamedValue, len(values))
for i, value := range values {
namedValues[i] = driver.NamedValue{
// nb: Name field is optional
Ordinal: i,
Value: value,
return c.QueryContext(context.Background(), query, prepareValues(values))
}

func (c *conn) QueryRow(query string, values []driver.Value) *singleRow {
return c.QueryRowContext(context.Background(), query, values)
}

func (c *conn) Exec(query string, values []driver.Value) (sql.Result, error) {
return c.ExecContext(context.Background(), query, prepareValues(values))
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
_, err := c.QueryContext(ctx, query, args)
if err != nil && err.Error() != "result is nil" {
return nil, err
}
return &execResult{
err: nil,
}, nil
}

func (c *conn) QueryRowContext(ctx context.Context, query string, values []driver.Value) *singleRow {

v, err := c.QueryContext(ctx, query, prepareValues(values))
if err != nil {
return &singleRow{
err: err,
rows: nil,
}
}
return c.QueryContext(context.Background(), query, namedValues)
return &singleRow{
rows: v,
}
}

func (c *conn) compileArguments(query string, args []driver.NamedValue) (string, error) {
Expand Down
90 changes: 90 additions & 0 deletions chdb/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,93 @@ func TestDbWithSession(t *testing.T) {
count++
}
}

func TestQueryRow(t *testing.T) {
sessionDir, err := os.MkdirTemp("", "unittest-sessiondata")
if err != nil {
t.Fatalf("create temp directory fail, err: %s", err)
}
defer os.RemoveAll(sessionDir)
session, err := chdb.NewSession(sessionDir)
if err != nil {
t.Fatalf("new session fail, err: %s", err)
}
defer session.Cleanup()

session.Query("USE testdb; INSERT INTO testtable VALUES (1), (2), (3);")

ret, err := session.Query("SELECT * FROM testtable;")
if err != nil {
t.Fatalf("Query fail, err: %s", err)
}
if string(ret.Buf()) != "1\n2\n3\n" {
t.Errorf("Query result should be 1\n2\n3\n, got %s", string(ret.Buf()))
}
db, err := sql.Open("chdb", fmt.Sprintf("session=%s", sessionDir))
if err != nil {
t.Fatalf("open db fail, err: %s", err)
}
if db.Ping() != nil {
t.Fatalf("ping db fail, err: %s", err)
}
rows := db.QueryRow("select * from testtable;")

var bar = 0
var count = 1
err = rows.Scan(&bar)
if err != nil {
t.Fatalf("scan fail, err: %s", err)
}
if bar != count {
t.Fatalf("result is not match, want: %d actual: %d", count, bar)
}
err2 := rows.Scan(&bar)
if err2 == nil {
t.Fatalf("QueryRow method should return only one item")
}

}

func TestExec(t *testing.T) {
sessionDir, err := os.MkdirTemp("", "unittest-sessiondata")
if err != nil {
t.Fatalf("create temp directory fail, err: %s", err)
}
defer os.RemoveAll(sessionDir)
session, err := chdb.NewSession(sessionDir)
if err != nil {
t.Fatalf("new session fail, err: %s", err)
}
defer session.Cleanup()
session.Query("CREATE DATABASE IF NOT EXISTS testdb; " +
"CREATE TABLE IF NOT EXISTS testdb.testtable (id UInt32) ENGINE = MergeTree() ORDER BY id;")

db, err := sql.Open("chdb", fmt.Sprintf("session=%s", sessionDir))
if err != nil {
t.Fatalf("open db fail, err: %s", err)
}
if db.Ping() != nil {
t.Fatalf("ping db fail, err: %s", err)
}

_, err = db.Exec("INSERT INTO testdb.testtable VALUES (1), (2), (3);")
if err != nil {
t.Fatalf("exec failed, err: %s", err)
}
rows := db.QueryRow("select * from testdb.testtable;")

var bar = 0
var count = 1
err = rows.Scan(&bar)
if err != nil {
t.Fatalf("scan fail, err: %s", err)
}
if bar != count {
t.Fatalf("result is not match, want: %d actual: %d", count, bar)
}
err2 := rows.Scan(&bar)
if err2 == nil {
t.Fatalf("QueryRow method should return only one item")
}

}

0 comments on commit 671214b

Please sign in to comment.