From e38fb03387b01802fa7eb567a4dcfe52ae44b661 Mon Sep 17 00:00:00 2001 From: Kailash Bisht Date: Sun, 15 Oct 2023 22:28:02 +0530 Subject: [PATCH] feat: add project analytics api --- .../dashboard-backend/internal/api/api.go | 12 ++ .../internal/api/constants.go | 3 +- .../internal/api/projects.go | 27 +++- .../internal/api/projects_test.go | 124 +++++++++++++++++- .../dashboard-backend/internal/dto/dto.go | 5 + .../internal/repositories/application.go | 1 + .../internal/repositories/project.go | 72 ++++++++++ .../internal/repositories/project_test.go | 80 ++++++++++- shared/go/db/project.sql.go | 67 ++++++++++ sql/queries/project.sql | 21 +++ 10 files changed, 406 insertions(+), 6 deletions(-) diff --git a/services/dashboard-backend/internal/api/api.go b/services/dashboard-backend/internal/api/api.go index e2bf3d72..0911c3df 100644 --- a/services/dashboard-backend/internal/api/api.go +++ b/services/dashboard-backend/internal/api/api.go @@ -39,6 +39,18 @@ func RegisterHandlers(mux *chi.Mux) { subRouter.Patch("/", HandleUpdateProject) }) + router.Route(ProjectAnalyticsEndpoint, func(subRouter chi.Router) { + subRouter.Use(middleware.PathParameterMiddleware("projectId")) + subRouter.Use( + middleware.AuthorizationMiddleware( + middleware.MethodPermissionMap{ + http.MethodGet: allPermissions, + }, + ), + ) + subRouter.Get("/", HandleRetrieveProjectAnalytics) + }) + router.Route(ProjectUserListEndpoint, func(subRouter chi.Router) { subRouter.Use(middleware.PathParameterMiddleware("projectId")) subRouter.Use( diff --git a/services/dashboard-backend/internal/api/constants.go b/services/dashboard-backend/internal/api/constants.go index 5f226af5..925a4b7a 100644 --- a/services/dashboard-backend/internal/api/constants.go +++ b/services/dashboard-backend/internal/api/constants.go @@ -2,11 +2,12 @@ package api const ( ProjectsListEndpoint = "/projects" + ProjectAnalyticsEndpoint = "/projects/{projectId}/analytics" ProjectDetailEndpoint = "/projects/{projectId}" ProjectUserListEndpoint = "/projects/{projectId}/users" ProjectUserDetailEndpoint = "/projects/{projectId}/users/{userId}" - ApplicationAnalyticsEndpoint = "/projects/{projectId}/applications/{applicationId}/analytics" ApplicationsListEndpoint = "/projects/{projectId}/applications" + ApplicationAnalyticsEndpoint = "/projects/{projectId}/applications/{applicationId}/analytics" ApplicationDetailEndpoint = "/projects/{projectId}/applications/{applicationId}" ApplicationTokensListEndpoint = "/projects/{projectId}/applications/{applicationId}/tokens" //nolint: gosec ApplicationTokenDetailEndpoint = "/projects/{projectId}/applications/{applicationId}/tokens/{tokenId}" //nolint: gosec diff --git a/services/dashboard-backend/internal/api/projects.go b/services/dashboard-backend/internal/api/projects.go index d3b5078c..4c1c0ce1 100644 --- a/services/dashboard-backend/internal/api/projects.go +++ b/services/dashboard-backend/internal/api/projects.go @@ -1,15 +1,18 @@ package api import ( + "net/http" + "time" + "github.com/basemind-ai/monorepo/services/dashboard-backend/internal/dto" "github.com/basemind-ai/monorepo/services/dashboard-backend/internal/middleware" "github.com/basemind-ai/monorepo/services/dashboard-backend/internal/repositories" "github.com/basemind-ai/monorepo/shared/go/apierror" "github.com/basemind-ai/monorepo/shared/go/db" "github.com/basemind-ai/monorepo/shared/go/serialization" + "github.com/basemind-ai/monorepo/shared/go/timeutils" "github.com/jackc/pgx/v5/pgtype" "github.com/rs/zerolog/log" - "net/http" ) // HandleCreateProject - creates a new project and sets the user as an ADMIN. @@ -148,3 +151,25 @@ func HandleDeleteProject(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } + +func HandleRetrieveProjectAnalytics(w http.ResponseWriter, r *http.Request) { + projectID := r.Context().Value(middleware.ProjectIDContextKey).(pgtype.UUID) + + toDate := timeutils.ParseDate(r.URL.Query().Get("toDate"), time.Now()) + fromDate := timeutils.ParseDate(r.URL.Query().Get("fromDate"), timeutils.GetFirstDayOfMonth()) + + projectAnalytics, err := repositories.GetProjectAnalyticsByDateRange( + r.Context(), + projectID, + fromDate, + toDate, + ) + if err != nil { + log.Error().Err(err).Msg("failed to retrieve project analytics") + apierror.InternalServerError().Render(w, r) + return + } + + w.WriteHeader(http.StatusOK) + serialization.RenderJSONResponse(w, http.StatusOK, projectAnalytics) +} diff --git a/services/dashboard-backend/internal/api/projects_test.go b/services/dashboard-backend/internal/api/projects_test.go index fada445a..e693f69f 100644 --- a/services/dashboard-backend/internal/api/projects_test.go +++ b/services/dashboard-backend/internal/api/projects_test.go @@ -3,15 +3,18 @@ package api_test import ( "context" "fmt" + "net/http" + "strings" + "testing" + "time" + "github.com/basemind-ai/monorepo/e2e/factories" "github.com/basemind-ai/monorepo/services/dashboard-backend/internal/api" "github.com/basemind-ai/monorepo/services/dashboard-backend/internal/dto" + "github.com/basemind-ai/monorepo/services/dashboard-backend/internal/repositories" "github.com/basemind-ai/monorepo/shared/go/db" "github.com/basemind-ai/monorepo/shared/go/serialization" "github.com/stretchr/testify/assert" - "net/http" - "strings" - "testing" ) func TestProjectsAPI(t *testing.T) { @@ -281,4 +284,119 @@ func TestProjectsAPI(t *testing.T) { }, ) }) + + t.Run(fmt.Sprintf("GET: %s", api.ProjectAnalyticsEndpoint), func(t *testing.T) { + invalidUUID := "invalid" + projectID := createProject(t) + createUserProject(t, userAccount.FirebaseID, projectID, db.AccessPermissionTypeADMIN) + + applicationID := createApplication(t, projectID) + createPromptRequestRecord(t, applicationID) + + fromDate := time.Now().AddDate(0, 0, -1) + toDate := fromDate.AddDate(0, 0, 2) + + t.Run("retrieves project analytics", func(t *testing.T) { + response, requestErr := testClient.Get( + context.TODO(), + fmt.Sprintf( + "/v1%s", + strings.ReplaceAll( + api.ProjectAnalyticsEndpoint, + "{projectId}", + projectID, + ), + ), + ) + assert.NoError(t, requestErr) + assert.Equal(t, http.StatusOK, response.StatusCode) + + projectUUID, _ := db.StringToUUID(projectID) + promptReqAnalytics, _ := repositories.GetProjectAnalyticsByDateRange( + context.TODO(), + *projectUUID, + fromDate, + toDate, + ) + + responseAnalytics := dto.ProjectAnalyticsDTO{} + deserializationErr := serialization.DeserializeJSON( + response.Body, + &responseAnalytics, + ) + + assert.NoError(t, deserializationErr) + assert.Equal(t, promptReqAnalytics.TotalAPICalls, responseAnalytics.TotalAPICalls) + assert.Equal(t, promptReqAnalytics.ModelsCost, responseAnalytics.ModelsCost) + }) + + for _, permission := range []db.AccessPermissionType{ + db.AccessPermissionTypeMEMBER, db.AccessPermissionTypeADMIN, + } { + t.Run( + fmt.Sprintf( + "responds with status 200 OK if the user has %s permission", + permission, + ), + func(t *testing.T) { + newUserAccount, _ := factories.CreateUserAccount(context.TODO()) + newProjectID := createProject(t) + createUserProject(t, newUserAccount.FirebaseID, newProjectID, permission) + + newTestClient := createTestClient(t, newUserAccount) + + response, requestErr := newTestClient.Get( + context.TODO(), + fmt.Sprintf( + "/v1%s", + strings.ReplaceAll( + api.ProjectAnalyticsEndpoint, + "{projectId}", + newProjectID, + ), + ), + ) + assert.NoError(t, requestErr) + assert.Equal(t, http.StatusOK, response.StatusCode) + }, + ) + } + + t.Run( + "responds with status 403 FORBIDDEN if the user does not have projects access", + func(t *testing.T) { + newProjectID := createProject(t) + + response, requestErr := testClient.Get( + context.TODO(), + fmt.Sprintf( + "/v1%s", + strings.ReplaceAll( + api.ProjectAnalyticsEndpoint, + "{projectId}", + newProjectID, + ), + ), + ) + assert.NoError(t, requestErr) + assert.Equal(t, http.StatusForbidden, response.StatusCode) + }, + ) + + t.Run("responds with status 400 BAD REQUEST if projectID is invalid", func(t *testing.T) { + response, requestErr := testClient.Get( + context.TODO(), + fmt.Sprintf( + "/v1%s", + strings.ReplaceAll( + api.ProjectAnalyticsEndpoint, + "{projectId}", + invalidUUID, + ), + ), + ) + assert.NoError(t, requestErr) + assert.Equal(t, http.StatusBadRequest, response.StatusCode) + }) + }) } diff --git a/services/dashboard-backend/internal/dto/dto.go b/services/dashboard-backend/internal/dto/dto.go index 52897650..41dc33fa 100644 --- a/services/dashboard-backend/internal/dto/dto.go +++ b/services/dashboard-backend/internal/dto/dto.go @@ -74,3 +74,8 @@ type ApplicationAnalyticsDTO struct { TotalRequests int64 `json:"totalRequests"` ProjectedCost float64 `json:"projectedCost"` } + +type ProjectAnalyticsDTO struct { + TotalAPICalls int64 `json:"totalAPICalls"` + ModelsCost float64 `json:"modelsCost"` +} diff --git a/services/dashboard-backend/internal/repositories/application.go b/services/dashboard-backend/internal/repositories/application.go index e2c5e31e..5765bebb 100644 --- a/services/dashboard-backend/internal/repositories/application.go +++ b/services/dashboard-backend/internal/repositories/application.go @@ -3,6 +3,7 @@ package repositories import ( "context" "fmt" + "github.com/basemind-ai/monorepo/shared/go/db" "github.com/basemind-ai/monorepo/shared/go/rediscache" "github.com/jackc/pgx/v5/pgtype" diff --git a/services/dashboard-backend/internal/repositories/project.go b/services/dashboard-backend/internal/repositories/project.go index 82c9ccf5..011d654d 100644 --- a/services/dashboard-backend/internal/repositories/project.go +++ b/services/dashboard-backend/internal/repositories/project.go @@ -3,8 +3,11 @@ package repositories import ( "context" "fmt" + "time" + "github.com/basemind-ai/monorepo/services/dashboard-backend/internal/dto" "github.com/basemind-ai/monorepo/shared/go/db" + "github.com/basemind-ai/monorepo/shared/go/tokenutils" "github.com/jackc/pgx/v5/pgtype" "github.com/rs/zerolog/log" ) @@ -114,3 +117,72 @@ func DeleteProject(ctx context.Context, projectID pgtype.UUID) error { return nil } + +func GetTotalAPICountByDateRange( + ctx context.Context, + projectID pgtype.UUID, + fromDate, toDate time.Time, +) (int64, error) { + reqParam := db.RetrieveTotalPromptAPICallsParams{ + ProjectID: projectID, + FromDate: pgtype.Timestamptz{Time: fromDate, Valid: true}, + ToDate: pgtype.Timestamptz{Time: toDate, Valid: true}, + } + + totalAPICalls, dbErr := db.GetQueries().RetrieveTotalPromptAPICalls(ctx, reqParam) + if dbErr != nil { + return -1, dbErr + } + + return totalAPICalls, nil +} + +func GetTokenConsumedByProjectByDateRange( + ctx context.Context, + projectID pgtype.UUID, + fromDate, toDate time.Time, +) (map[db.ModelType]int64, error) { + reqParam := db.RetrieveTotalTokensConsumedParams{ + ProjectID: projectID, + FromDate: pgtype.Timestamptz{Time: fromDate, Valid: true}, + ToDate: pgtype.Timestamptz{Time: toDate, Valid: true}, + } + + tokensConsumed, dbErr := db.GetQueries().RetrieveTotalTokensConsumed(ctx, reqParam) + if dbErr != nil { + return nil, dbErr + } + + projectTokenCntMap := make(map[db.ModelType]int64) + for _, record := range tokensConsumed { + projectTokenCntMap[record.ModelType] += record.TotalTokens + } + + return projectTokenCntMap, nil +} + +func GetProjectAnalyticsByDateRange( + ctx context.Context, + projectID pgtype.UUID, + fromDate, toDate time.Time, +) (dto.ProjectAnalyticsDTO, error) { + totalApiCalls, dbErr := GetTotalAPICountByDateRange(ctx, projectID, fromDate, toDate) + if dbErr != nil { + return dto.ProjectAnalyticsDTO{}, dbErr + } + + projectTokenCntMap, dbErr := GetTokenConsumedByProjectByDateRange(ctx, projectID, fromDate, toDate) + if dbErr != nil { + return dto.ProjectAnalyticsDTO{}, dbErr + } + + var modelCost float64 + for model, tokenCnt := range projectTokenCntMap { + modelCost += tokenutils.GetCostByModelType(tokenCnt, model) + } + + return dto.ProjectAnalyticsDTO{ + TotalAPICalls: totalApiCalls, + ModelsCost: modelCost, + }, nil +} diff --git a/services/dashboard-backend/internal/repositories/project_test.go b/services/dashboard-backend/internal/repositories/project_test.go index 9c4ab6dd..16466731 100644 --- a/services/dashboard-backend/internal/repositories/project_test.go +++ b/services/dashboard-backend/internal/repositories/project_test.go @@ -2,11 +2,15 @@ package repositories_test import ( "context" + "testing" + "time" + "github.com/basemind-ai/monorepo/e2e/factories" "github.com/basemind-ai/monorepo/services/dashboard-backend/internal/repositories" "github.com/basemind-ai/monorepo/shared/go/db" + "github.com/basemind-ai/monorepo/shared/go/tokenutils" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" - "testing" ) func TestProjectRepository(t *testing.T) { @@ -121,4 +125,78 @@ func TestProjectRepository(t *testing.T) { assert.Error(t, err) }) }) + + t.Run("Project Analytics", func(t *testing.T) { + project, _ := factories.CreateProject(context.TODO()) + application, _ := factories.CreateApplication(context.TODO(), project.ID) + factories.CreatePromptRequestRecord(context.TODO(), application.ID) + + fromDate := time.Now().AddDate(0, 0, -1) + toDate := fromDate.AddDate(0, 0, 2) + totalTokensUsed := int64(20) + + t.Run("GetTotalAPICountByDateRange", func(t *testing.T) { + t.Run("get total api count by date range", func(t *testing.T) { + totalRequests, dbErr := repositories.GetTotalAPICountByDateRange( + context.TODO(), + project.ID, + time.Now().AddDate(0, 0, -1), + time.Now().AddDate(0, 0, 1), + ) + assert.NoError(t, dbErr) + assert.Equal(t, int64(1), totalRequests) + }) + t.Run("fails to get total api count for invalid project id", func(t *testing.T) { + invalidProjectId := pgtype.UUID{Bytes: [16]byte{}, Valid: false} + totalRequests, _ := repositories.GetTotalAPICountByDateRange( + context.TODO(), + invalidProjectId, + time.Now().AddDate(0, 0, -1), + time.Now().AddDate(0, 0, 1), + ) + assert.Equal(t, int64(0), totalRequests) + }) + }) + + t.Run("GetTokenConsumedByProjectByDateRange", func(t *testing.T) { + t.Run("get total api count by date range", func(t *testing.T) { + projectTokenCntMap, dbErr := repositories.GetTokenConsumedByProjectByDateRange( + context.TODO(), + project.ID, + time.Now().AddDate(0, 0, -1), + time.Now().AddDate(0, 0, 1), + ) + assert.NoError(t, dbErr) + assert.Equal(t, int64(20), projectTokenCntMap[db.ModelTypeGpt35Turbo]) + }) + t.Run("fails to get total api count for invalid project id", func(t *testing.T) { + invalidProjectId := pgtype.UUID{Bytes: [16]byte{}, Valid: false} + projectTokenCntMap, _ := repositories.GetTokenConsumedByProjectByDateRange( + context.TODO(), + invalidProjectId, + time.Now().AddDate(0, 0, -1), + time.Now().AddDate(0, 0, 1), + ) + assert.Equal(t, int64(0), projectTokenCntMap[db.ModelTypeGpt35Turbo]) + }) + }) + + t.Run("GetProjectAnalyticsByDateRange", func(t *testing.T) { + t.Run("get token usage for each model types by date range", func(t *testing.T) { + projectAnalytics, dbErr := repositories.GetProjectAnalyticsByDateRange( + context.TODO(), + project.ID, + fromDate, + toDate, + ) + assert.NoError(t, dbErr) + assert.Equal(t, int64(1), projectAnalytics.TotalAPICalls) + assert.Equal( + t, + tokenutils.GetCostByModelType(totalTokensUsed, db.ModelTypeGpt35Turbo), + projectAnalytics.ModelsCost, + ) + }) + }) + }) } diff --git a/shared/go/db/project.sql.go b/shared/go/db/project.sql.go index ecae437d..6b9442fb 100644 --- a/shared/go/db/project.sql.go +++ b/shared/go/db/project.sql.go @@ -142,6 +142,73 @@ func (q *Queries) RetrieveProjects(ctx context.Context, firebaseID string) ([]Re return items, nil } +const retrieveTotalPromptAPICalls = `-- name: RetrieveTotalPromptAPICalls :one +SELECT COUNT(prr.id) AS total_requests +FROM prompt_request_record AS prr +INNER JOIN prompt_config AS pc ON prr.prompt_config_id = pc.id +INNER JOIN application AS app ON pc.application_id = app.id +WHERE + app.project_id = $1 + AND prr.created_at BETWEEN $2 AND $3 +` + +type RetrieveTotalPromptAPICallsParams struct { + ProjectID pgtype.UUID `json:"projectId"` + FromDate pgtype.Timestamptz `json:"fromDate"` + ToDate pgtype.Timestamptz `json:"toDate"` +} + +func (q *Queries) RetrieveTotalPromptAPICalls(ctx context.Context, arg RetrieveTotalPromptAPICallsParams) (int64, error) { + row := q.db.QueryRow(ctx, retrieveTotalPromptAPICalls, arg.ProjectID, arg.FromDate, arg.ToDate) + var total_requests int64 + err := row.Scan(&total_requests) + return total_requests, err +} + +const retrieveTotalTokensConsumed = `-- name: RetrieveTotalTokensConsumed :many +SELECT + pc.model_type, + SUM(prr.request_tokens + prr.response_tokens) AS total_tokens +FROM prompt_request_record AS prr +INNER JOIN prompt_config AS pc ON prr.prompt_config_id = pc.id +INNER JOIN application AS app ON pc.application_id = app.id +WHERE + app.project_id = $1 + AND prr.created_at BETWEEN $2 AND $3 +GROUP BY pc.model_type +` + +type RetrieveTotalTokensConsumedParams struct { + ProjectID pgtype.UUID `json:"projectId"` + FromDate pgtype.Timestamptz `json:"fromDate"` + ToDate pgtype.Timestamptz `json:"toDate"` +} + +type RetrieveTotalTokensConsumedRow struct { + ModelType ModelType `json:"modelType"` + TotalTokens int64 `json:"totalTokens"` +} + +func (q *Queries) RetrieveTotalTokensConsumed(ctx context.Context, arg RetrieveTotalTokensConsumedParams) ([]RetrieveTotalTokensConsumedRow, error) { + rows, err := q.db.Query(ctx, retrieveTotalTokensConsumed, arg.ProjectID, arg.FromDate, arg.ToDate) + if err != nil { + return nil, err + } + defer rows.Close() + var items []RetrieveTotalTokensConsumedRow + for rows.Next() { + var i RetrieveTotalTokensConsumedRow + if err := rows.Scan(&i.ModelType, &i.TotalTokens); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const updateProject = `-- name: UpdateProject :one UPDATE project SET diff --git a/sql/queries/project.sql b/sql/queries/project.sql index 5b7cb70c..46960ffa 100644 --- a/sql/queries/project.sql +++ b/sql/queries/project.sql @@ -47,3 +47,24 @@ LEFT JOIN project AS p ON up.project_id = p.id LEFT JOIN user_account AS ua ON up.user_id = ua.id WHERE ua.firebase_id = $1 AND p.deleted_at IS NULL; + +-- name: RetrieveTotalPromptAPICalls :one +SELECT COUNT(prr.id) AS total_requests +FROM prompt_request_record AS prr +INNER JOIN prompt_config AS pc ON prr.prompt_config_id = pc.id +INNER JOIN application AS app ON pc.application_id = app.id +WHERE + app.project_id = $1 + AND prr.created_at BETWEEN $2 AND $3; + +-- name: RetrieveTotalTokensConsumed :many +SELECT + pc.model_type, + SUM(prr.request_tokens + prr.response_tokens) AS total_tokens +FROM prompt_request_record AS prr +INNER JOIN prompt_config AS pc ON prr.prompt_config_id = pc.id +INNER JOIN application AS app ON pc.application_id = app.id +WHERE + app.project_id = $1 + AND prr.created_at BETWEEN $2 AND $3 +GROUP BY pc.model_type;