Skip to content

Commit

Permalink
Vector search respect db idx (#1582)
Browse files Browse the repository at this point in the history
* propagate the schema name (iDb) to the vector index so that it works with databases other than main

* add basic test

* zDbSName can sometimes be NULL, and this is fine

* prevent the test from writing files to disk

* build bundles
  • Loading branch information
sivukhin authored Jul 22, 2024
1 parent a9639c3 commit 5e6afb3
Show file tree
Hide file tree
Showing 11 changed files with 508 additions and 259 deletions.
235 changes: 156 additions & 79 deletions libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c

Large diffs are not rendered by default.

235 changes: 156 additions & 79 deletions libsql-ffi/bundled/src/sqlite3.c

Large diffs are not rendered by default.

18 changes: 13 additions & 5 deletions libsql-sqlite3/src/build.c
Original file line number Diff line number Diff line change
Expand Up @@ -3323,6 +3323,7 @@ static void destroyTable(Parse *pParse, Table *pTab){
Pgno iTab = pTab->tnum;
Pgno iDestroyed = 0;
Index *pIdx;
int iDb;

#ifndef SQLITE_OMIT_VECTOR
/*
Expand All @@ -3336,9 +3337,12 @@ static void destroyTable(Parse *pParse, Table *pTab){
* 3. Delete index during the parsing stage (implemented variant) - it's hacky
* and a bit dirty, but it seems to be a pretty safe and easy way to delete the index
*/
iDb = sqlite3SchemaToIndex(pParse->db, pTab->pSchema);

for(pIdx=pTab->pIndex; pIdx; pIdx=pIdx->pNext){
if( IsVectorIndex(pIdx) ){
vectorIndexDrop(pParse->db, pIdx->zName);
assert( 0 <= iDb && iDb < pParse->db->nDb );
vectorIndexDrop(pParse->db, pParse->db->aDb[iDb].zDbSName, pIdx->zName);
}
}
#endif
Expand Down Expand Up @@ -4305,7 +4309,7 @@ void sqlite3CreateIndex(


#ifndef SQLITE_OMIT_VECTOR
if( vectorIndexCreate(pParse, pIndex, pUsing) != SQLITE_OK ) {
if( vectorIndexCreate(pParse, pIndex, db->aDb[iDb].zDbSName, pUsing) != SQLITE_OK ) {
goto exit_create_index;
}
idxType = pIndex->idxType; // vectorIndexCreate can update idxType to 4 (VECTOR INDEX)
Expand Down Expand Up @@ -4662,6 +4666,7 @@ void sqlite3DropIndex(Parse *pParse, SrcList *pName, int ifExists){
"or PRIMARY KEY constraint cannot be dropped", 0);
goto exit_drop_index;
}
iDb = sqlite3SchemaToIndex(db, pIndex->pSchema);
#ifndef SQLITE_OMIT_VECTOR
/*
* There are several places to delete vector index:
Expand All @@ -4675,10 +4680,9 @@ void sqlite3DropIndex(Parse *pParse, SrcList *pName, int ifExists){
* and a bit dirty, but it seems to be a pretty safe and easy way to delete the index
*/
if( IsVectorIndex(pIndex) ){
vectorIndexDrop(pParse->db, pIndex->zName);
vectorIndexDrop(pParse->db, pParse->db->aDb[iDb].zDbSName, pIndex->zName);
}
#endif
iDb = sqlite3SchemaToIndex(db, pIndex->pSchema);
#ifndef SQLITE_OMIT_AUTHORIZATION
{
int code = SQLITE_DROP_INDEX;
Expand Down Expand Up @@ -5620,7 +5624,7 @@ void sqlite3Reindex(Parse *pParse, Token *pName1, Token *pName2){
** when it has finished using it.
*/
KeyInfo *sqlite3KeyInfoOfIndex(Parse *pParse, Index *pIdx){
int i;
int i, iDb;
int nCol = pIdx->nColumn;
int nKey = pIdx->nKeyCol;
KeyInfo *pKey;
Expand All @@ -5631,8 +5635,12 @@ KeyInfo *sqlite3KeyInfoOfIndex(Parse *pParse, Index *pIdx){
pKey = sqlite3KeyInfoAlloc(pParse->db, nCol, 0);
}
if( pKey ){
iDb = sqlite3SchemaToIndex(pParse->db, pIdx->pSchema);
assert( sqlite3KeyInfoIsWriteable(pKey) );
pKey->zIndexName = sqlite3DbStrDup(pParse->db, pIdx->zName);
if( 0 <= iDb && iDb < pParse->db->nDb ){
pKey->zDbSName = sqlite3DbStrDup(pParse->db, pParse->db->aDb[iDb].zDbSName);
}
for(i=0; i<nCol; i++){
const char *zColl = pIdx->azColl[i];
pKey->aColl[i] = zColl==sqlite3StrBINARY ? 0 :
Expand Down
4 changes: 4 additions & 0 deletions libsql-sqlite3/src/select.c
Original file line number Diff line number Diff line change
Expand Up @@ -1513,6 +1513,7 @@ KeyInfo *sqlite3KeyInfoAlloc(sqlite3 *db, int N, int X){
p->db = db;
p->nRef = 1;
p->zIndexName = NULL;
p->zDbSName = NULL;
memset(&p[1], 0, nExtra);
}else{
return (KeyInfo*)sqlite3OomFault(db);
Expand All @@ -1532,6 +1533,9 @@ void sqlite3KeyInfoUnref(KeyInfo *p){
if( p->zIndexName != NULL ){
sqlite3DbFree(p->db, p->zIndexName);
}
if( p->zDbSName != NULL ){
sqlite3DbFree(p->db, p->zDbSName);
}
sqlite3DbNNFreeNN(p->db, p);
}
}
Expand Down
1 change: 1 addition & 0 deletions libsql-sqlite3/src/sqliteInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -2641,6 +2641,7 @@ struct KeyInfo {
* vector indices as they operate with names rather than with page numbers
*/
char *zIndexName; /* Name of the index (might be NULL) */
char *zDbSName; /* Name of the database schema (might be NULL) */
u32 nRef; /* Number of references to this KeyInfo object */
u8 enc; /* Text encoding - one of the SQLITE_UTF* values */
u16 nKeyField; /* Number of key columns in the index */
Expand Down
5 changes: 3 additions & 2 deletions libsql-sqlite3/src/vdbe.c
Original file line number Diff line number Diff line change
Expand Up @@ -4229,13 +4229,14 @@ case OP_OpenVectorIdx: {
}else if( pOp->p4type==P4_INT32 ){
nField = pOp->p4.i;
}
assert( pKeyInfo->zDbSName != NULL );
if( pOp->p5 == OPFLAG_FORDELETE ){
rc = vectorIndexClear(db, pKeyInfo->zIndexName);
rc = vectorIndexClear(db, pKeyInfo->zDbSName, pKeyInfo->zIndexName);
if( rc ){
goto abort_due_to_error;
}
}
rc = vectorIndexCursorInit(db, &cursor, pKeyInfo->zIndexName);
rc = vectorIndexCursorInit(db, pKeyInfo->zDbSName, pKeyInfo->zIndexName, &cursor);
if( rc ) {
goto abort_due_to_error;
}
Expand Down
127 changes: 80 additions & 47 deletions libsql-sqlite3/src/vectorIndex.c
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ int vectorInRowAlloc(sqlite3 *db, const UnpackedRecord *pRecord, VectorInRow *pV

if( pVectorInRow->nKeys <= 0 ){
rc = SQLITE_ERROR;
goto out;
goto out;
}

if( sqlite3_value_type(pVectorValue)==SQLITE_NULL ){
Expand All @@ -233,7 +233,7 @@ int vectorInRowAlloc(sqlite3 *db, const UnpackedRecord *pRecord, VectorInRow *pV

if( sqlite3_value_type(pVectorValue) == SQLITE_BLOB ){
vectorInitFromBlob(pVectorInRow->pVector, sqlite3_value_blob(pVectorValue), sqlite3_value_bytes(pVectorValue));
} else if( sqlite3_value_type(pVectorValue) == SQLITE_TEXT ){
} else if( sqlite3_value_type(pVectorValue) == SQLITE_TEXT ){
// users can put strings (e.g. '[1,2,3]') in the table and we should process them correctly
if( vectorParse(pVectorValue, pVectorInRow->pVector, pzErrMsg) != 0 ){
rc = SQLITE_ERROR;
Expand Down Expand Up @@ -321,10 +321,10 @@ void vectorOutRowsGet(sqlite3_context *context, const VectorOutRows *pRows, int

void vectorOutRowsFree(sqlite3 *db, VectorOutRows *pRows) {
int i;

// both aIntValues and ppValues can be NULL if processing failed midway and VectorOutRows was never created
assert( pRows->aIntValues == NULL || pRows->ppValues == NULL );

if( pRows->aIntValues != NULL ){
sqlite3DbFree(db, pRows->aIntValues);
}else if( pRows->ppValues != NULL ){
Expand All @@ -337,8 +337,8 @@ void vectorOutRowsFree(sqlite3 *db, VectorOutRows *pRows) {
}
}

/*
* Internal type to represent VECTOR_COLUMN_TYPES array
/*
* Internal type to represent VECTOR_COLUMN_TYPES array
* We support both FLOATNN and FNN_BLOB type names for the following reasons:
* 1. FLOATNN is easy to type for humans and generally OK to use for column type names
* 2. FNN_BLOB is aligned with SQLite affinity rules and can be used in cases where compatibility with type affinity rules is important
Expand All @@ -349,15 +349,15 @@ struct VectorColumnType {
int nBits;
};

static struct VectorColumnType VECTOR_COLUMN_TYPES[] = {
{ "FLOAT32", 32 },
{ "FLOAT64", 64 },
{ "F32_BLOB", 32 },
{ "F64_BLOB", 64 }
static struct VectorColumnType VECTOR_COLUMN_TYPES[] = {
{ "FLOAT32", 32 },
{ "FLOAT64", 64 },
{ "F32_BLOB", 32 },
{ "F64_BLOB", 64 }
};

/*
* Internal type to represent VECTOR_PARAM_NAMES array with recognized parameters for index creation
* Internal type to represent VECTOR_PARAM_NAMES array with recognized parameters for index creation
* For example, libsql_vector_idx(embedding, 'type=diskann', 'metric=cosine')
*/
struct VectorParamName {
Expand All @@ -368,7 +368,7 @@ struct VectorParamName {
u64 value;
};

static struct VectorParamName VECTOR_PARAM_NAMES[] = {
static struct VectorParamName VECTOR_PARAM_NAMES[] = {
{ "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN },
{ "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS },
{ "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 },
Expand Down Expand Up @@ -550,15 +550,34 @@ int vectorIdxParseColumnType(const char *zType, int *pType, int *pDims, const ch
return -1;
}

int initVectorIndexMetaTable(sqlite3* db) {
static const char *zSql = "CREATE TABLE IF NOT EXISTS " VECTOR_INDEX_GLOBAL_META_TABLE " ( name TEXT PRIMARY KEY, metadata BLOB ) WITHOUT ROWID;";
return sqlite3_exec(db, zSql, 0, 0, 0);
int initVectorIndexMetaTable(sqlite3* db, const char *zDbSName) {
int rc;
static const char *zSqlTemplate = "CREATE TABLE IF NOT EXISTS \"%w\"." VECTOR_INDEX_GLOBAL_META_TABLE " ( name TEXT PRIMARY KEY, metadata BLOB ) WITHOUT ROWID;";
char* zSql;

assert( zDbSName != NULL );

zSql = sqlite3_mprintf(zSqlTemplate, zDbSName);
if( zSql == NULL ){
return SQLITE_NOMEM_BKPT;
}
rc = sqlite3_exec(db, zSql, 0, 0, 0);
sqlite3_free(zSql);
return rc;
}

int insertIndexParameters(sqlite3* db, const char *zName, const VectorIdxParams *pParameters) {
static const char *zSql = "INSERT INTO " VECTOR_INDEX_GLOBAL_META_TABLE " VALUES (?, ?)";
sqlite3_stmt* pStatement = NULL;
int insertIndexParameters(sqlite3* db, const char *zDbSName, const char *zName, const VectorIdxParams *pParameters) {
int rc = SQLITE_ERROR;
static const char *zSqlTemplate = "INSERT INTO \"%w\"." VECTOR_INDEX_GLOBAL_META_TABLE " VALUES (?, ?)";
sqlite3_stmt* pStatement = NULL;
char *zSql;

assert( zDbSName != NULL );

zSql = sqlite3_mprintf(zSqlTemplate, zDbSName);
if( zSql == NULL ){
return SQLITE_NOMEM_BKPT;
}

rc = sqlite3_prepare_v2(db, zSql, -1, &pStatement, 0);
if( rc != SQLITE_OK ){
Expand All @@ -579,6 +598,9 @@ int insertIndexParameters(sqlite3* db, const char *zName, const VectorIdxParams
rc = SQLITE_OK;
}
clear_and_exit:
if( zSql != NULL ){
sqlite3_free(zSql);
}
if( pStatement != NULL ){
sqlite3_finalize(pStatement);
}
Expand Down Expand Up @@ -672,24 +694,31 @@ int vectorIndexGetParameters(
}


int vectorIndexDrop(sqlite3 *db, const char *zIdxName) {
int vectorIndexDrop(sqlite3 *db, const char *zDbSName, const char *zIdxName) {
// we want to try to delete all traces of the index on every attempt
// this is done to prevent unrecoverable situations where the index was dropped but deletion of the index parameters failed, so that a second attempt would fail on the first step
int rcIdx = diskAnnDropIndex(db, zIdxName);
int rcParams = removeIndexParameters(db, zIdxName);
int rcIdx, rcParams;

assert( zDbSName != NULL );

rcIdx = diskAnnDropIndex(db, zDbSName, zIdxName);
rcParams = removeIndexParameters(db, zIdxName);
return rcIdx != SQLITE_OK ? rcIdx : rcParams;
}

int vectorIndexClear(sqlite3 *db, const char *zIdxName) {
return diskAnnClearIndex(db, zIdxName);
int vectorIndexClear(sqlite3 *db, const char *zDbSName, const char *zIdxName) {
assert( zDbSName != NULL );
return diskAnnClearIndex(db, zDbSName, zIdxName);
}

int vectorIndexCreate(Parse *pParse, Index *pIdx, const IdList *pUsing) {
int vectorIndexCreate(Parse *pParse, Index *pIdx, const char *zDbSName, const IdList *pUsing) {
int i, rc = SQLITE_OK;
int dims, type;
int hasLibsqlVectorIdxFn = 0, hasCollation = 0;
const char *pzErrMsg;

assert( zDbSName != NULL );

sqlite3 *db = pParse->db;
Table *pTable = pIdx->pTable;
struct ExprList_item *pListItem;
Expand Down Expand Up @@ -776,34 +805,33 @@ int vectorIndexCreate(Parse *pParse, Index *pIdx, const IdList *pUsing) {
return SQLITE_ERROR;
}

if( vectorIdxKeyGet(pTable, &idxKey, &pzErrMsg) != 0 ){
sqlite3ErrorMsg(pParse, "failed to detect underlying table key: %s", pzErrMsg);
return SQLITE_ERROR;
}
if( idxKey.nKeyColumns != 1 ){
sqlite3ErrorMsg(pParse, "vector index for tables without ROWID and composite primary key are not supported");
return SQLITE_ERROR;
}

// schema is locked while db is initializing and we need to just proceed here
if( db->init.busy == 1 ){
goto succeed;
}

rc = initVectorIndexMetaTable(db);
rc = initVectorIndexMetaTable(db, zDbSName);
if( rc != SQLITE_OK ){
return rc;
}
rc = parseVectorIdxParams(pParse, &idxParams, type, dims, pListItem + 1, pArgsList->nExpr - 1);
if( rc != SQLITE_OK ){
return rc;
}
rc = diskAnnCreateIndex(db, pIdx->zName, &idxKey, &idxParams);
if( vectorIdxKeyGet(pTable, &idxKey, &pzErrMsg) != 0 ){
sqlite3ErrorMsg(pParse, "failed to detect underlying table key: %s", pzErrMsg);
return SQLITE_ERROR;
}
if( idxKey.nKeyColumns != 1 ){
sqlite3ErrorMsg(pParse, "vector index for tables without ROWID and composite primary key are not supported");
return SQLITE_ERROR;
}
rc = diskAnnCreateIndex(db, zDbSName, pIdx->zName, &idxKey, &idxParams);
if( rc != SQLITE_OK ){
sqlite3ErrorMsg(pParse, "unable to initialize diskann vector index");
return rc;
}
rc = insertIndexParameters(db, pIdx->zName, &idxParams);
rc = insertIndexParameters(db, zDbSName, pIdx->zName, &idxParams);
if( rc != SQLITE_OK ){
sqlite3ErrorMsg(pParse, "unable to update global metadata table");
return rc;
Expand All @@ -815,7 +843,7 @@ int vectorIndexCreate(Parse *pParse, Index *pIdx, const IdList *pUsing) {
return SQLITE_OK;
}

int vectorIndexSearch(sqlite3 *db, int argc, sqlite3_value **argv, VectorOutRows *pRows, char **pzErrMsg) {
int vectorIndexSearch(sqlite3 *db, const char* zDbSName, int argc, sqlite3_value **argv, VectorOutRows *pRows, char **pzErrMsg) {
int type, dims, k, rc;
const char *zIdxName;
const char *zErrMsg;
Expand All @@ -826,6 +854,8 @@ int vectorIndexSearch(sqlite3 *db, int argc, sqlite3_value **argv, VectorOutRows
VectorIdxParams idxParams;
vectorIdxParamsInit(&idxParams, NULL, 0);

assert( zDbSName != NULL );

if( argc != 3 ){
*pzErrMsg = sqlite3_mprintf("vector search must have exactly 3 parameters");
rc = SQLITE_ERROR;
Expand Down Expand Up @@ -871,22 +901,22 @@ int vectorIndexSearch(sqlite3 *db, int argc, sqlite3_value **argv, VectorOutRows
rc = SQLITE_ERROR;
goto out;
}
pIndex = sqlite3FindIndex(db, zIdxName, db->aDb[0].zDbSName);
pIndex = sqlite3FindIndex(db, zIdxName, zDbSName);
if( pIndex == NULL ){
*pzErrMsg = sqlite3_mprintf("vector index not found");
rc = SQLITE_ERROR;
goto out;
}
rc = diskAnnOpenIndex(db, zDbSName, zIdxName, &idxParams, &pDiskAnn);
if( rc != SQLITE_OK ){
*pzErrMsg = sqlite3_mprintf("failed to open diskann index");
goto out;
}
if( vectorIdxKeyGet(pIndex->pTable, &pKey, &zErrMsg) != 0 ){
*pzErrMsg = sqlite3_mprintf("failed to extract table key: %s", zErrMsg);
rc = SQLITE_ERROR;
goto out;
}
rc = diskAnnOpenIndex(db, zIdxName, &idxParams, &pDiskAnn);
if( rc != SQLITE_OK ){
*pzErrMsg = sqlite3_mprintf("failed to open diskann index");
goto out;
}
rc = diskAnnSearch(pDiskAnn, pVector, k, &pKey, pRows, pzErrMsg);
out:
if( pDiskAnn != NULL ){
Expand Down Expand Up @@ -932,22 +962,25 @@ int vectorIndexDelete(

int vectorIndexCursorInit(
sqlite3 *db,
VectorIdxCursor **ppCursor,
const char *zIndexName
const char *zDbSName,
const char *zIndexName,
VectorIdxCursor **ppCursor
){
int rc;
VectorIdxCursor* pCursor;
VectorIdxParams params;
vectorIdxParamsInit(&params, NULL, 0);

assert( zDbSName != NULL );

if( vectorIndexGetParameters(db, zIndexName, &params) != 0 ){
return SQLITE_ERROR;
}
pCursor = sqlite3DbMallocZero(db, sizeof(VectorIdxCursor));
if( pCursor == 0 ){
return SQLITE_NOMEM_BKPT;
}
rc = diskAnnOpenIndex(db, zIndexName, &params, &pCursor->pIndex);
rc = diskAnnOpenIndex(db, zDbSName, zIndexName, &params, &pCursor->pIndex);
if( rc != SQLITE_OK ){
sqlite3DbFree(db, pCursor);
return rc;
Expand Down
Loading

0 comments on commit 5e6afb3

Please sign in to comment.