Skip to content

Commit

Permalink
feat: graphql: add api to update rag
Browse files Browse the repository at this point in the history
  • Loading branch information
dayuy committed Jan 30, 2024
1 parent d74f39a commit 534b402
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 6 deletions.
4 changes: 2 additions & 2 deletions apiserver/graph/generated/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion apiserver/graph/generated/models_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions apiserver/graph/impl/rag.resolvers.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion apiserver/graph/schema/rag.graphqls
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ input UpdateRAGInput {
annotations: Map
displayName: String
description: String
application: TypedObjectReferenceInput!
application: TypedObjectReferenceInput
datasets: [RAGDatasetInput!]
judgeLLM: TypedObjectReferenceInput
metrics: [RAGMetricInput!]
Expand Down
99 changes: 99 additions & 0 deletions apiserver/pkg/rag/rag.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
evav1alpha1 "github.com/kubeagi/arcadia/api/evaluation/v1alpha1"
"github.com/kubeagi/arcadia/apiserver/graph/generated"
"github.com/kubeagi/arcadia/apiserver/pkg/common"
graphqlutils "github.com/kubeagi/arcadia/apiserver/pkg/utils"
"github.com/kubeagi/arcadia/pkg/utils"
)

Expand Down Expand Up @@ -334,6 +335,104 @@ func CreateRAG(ctx context.Context, kubeClient dynamic.Interface, input *generat
return rag2model(u1)
}

func UpdateRAG(ctx context.Context, kubeClient dynamic.Interface, input *generated.UpdateRAGInput) (*generated.Rag, error) {
obj, err := kubeClient.Resource(common.SchemaOf(&common.ArcadiaAPIGroup, "RAG")).Namespace(input.Namespace).Get(ctx, input.Name, v1.GetOptions{})
if err != nil {
return nil, err
}
rag := &evav1alpha1.RAG{}
if err := utils.UnstructuredToStructured(obj, rag); err != nil {
return nil, err
}

if input.Labels != nil {
rag.SetLabels(graphqlutils.MapAny2Str(input.Labels))
}
if input.Annotations != nil {
rag.SetAnnotations(graphqlutils.MapAny2Str(input.Annotations))
}
if input.DisplayName != nil {
rag.Spec.DisplayName = *input.DisplayName
}
if input.Description != nil {
rag.Spec.Description = *input.Description
}
if input.Application != nil {
rag.Spec.Application = &v1alpha1.TypedObjectReference{
APIGroup: input.Application.APIGroup,
Kind: input.Application.Kind,
Name: input.Application.Name,
Namespace: input.Application.Namespace,
}
}
if input.Datasets != nil {
rag.Spec.Datasets = make([]evav1alpha1.Dataset, 0)
for i, dataset := range input.Datasets {
ds := evav1alpha1.Dataset{
Source: &v1alpha1.TypedObjectReference{
APIGroup: input.Datasets[i].Source.APIGroup,
Kind: input.Datasets[i].Source.Kind,
Name: input.Datasets[i].Source.Name,
Namespace: input.Datasets[i].Source.Namespace,
},
Files: dataset.Files,
}
rag.Spec.Datasets = append(rag.Spec.Datasets, ds)
}
}
if input.JudgeLlm != nil {
rag.Spec.JudgeLLM = &v1alpha1.TypedObjectReference{
APIGroup: input.JudgeLlm.APIGroup,
Kind: input.JudgeLlm.Kind,
Name: input.JudgeLlm.Name,
Namespace: input.JudgeLlm.Namespace,
}
}
if input.Metrics != nil {
rag.Spec.Metrics = make([]evav1alpha1.Metric, 0)
for _, m := range input.Metrics {
mm := evav1alpha1.Metric{
Parameters: make([]evav1alpha1.Parameter, 0),
}
if m.MetricKind != nil {
mm.Kind = evav1alpha1.MetricsKind(*m.MetricKind)
}
if m.ToleranceThreshbold != nil {
mm.ToleranceThreshbold = *m.ToleranceThreshbold
}
for _, p := range m.Parameters {
mm.Parameters = append(mm.Parameters, evav1alpha1.Parameter{
Key: *p.Key,
Value: *p.Value,
})
}
rag.Spec.Metrics = append(rag.Spec.Metrics, mm)
}
}

if input.Storage != nil {
rag.Spec.Storage = gen2storage(*input.Storage)
}
if input.Suspend != nil {
rag.Spec.Suspend = *input.Suspend
}

unstructuredRag, err := runtime.DefaultUnstructuredConverter.ToUnstructured(&rag)
if err != nil {
return nil, err
}
updatedRag, err := common.ResouceUpdate(ctx, kubeClient, generated.TypedObjectReferenceInput{
APIGroup: &common.ArcadiaAPIGroup,
Kind: "RAG",
Namespace: &rag.Namespace,
Name: rag.Name,
}, unstructuredRag, v1.UpdateOptions{})
if err != nil {
return nil, err
}
return rag2model(updatedRag)
}

func ListRAG(ctx context.Context, kubeClient dynamic.Interface, input *generated.ListRAGInput) (*generated.PaginatedResult, error) {
listOptions := v1.ListOptions{
LabelSelector: fmt.Sprintf("%s=%s", evav1alpha1.EvaluationApplicationLabel, input.AppName),
Expand Down

0 comments on commit 534b402

Please sign in to comment.