From 534b4027175de8abc4670e5291a4d4cfd1817425 Mon Sep 17 00:00:00 2001 From: dayuy <973860441@qq.com> Date: Tue, 30 Jan 2024 15:54:01 +0800 Subject: [PATCH] feat: graphql: add api to update rag --- apiserver/graph/generated/generated.go | 4 +- apiserver/graph/generated/models_gen.go | 2 +- apiserver/graph/impl/rag.resolvers.go | 7 +- apiserver/graph/schema/rag.graphqls | 2 +- apiserver/pkg/rag/rag.go | 99 +++++++++++++++++++++++++ 5 files changed, 108 insertions(+), 6 deletions(-) diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index 421c80fbf..a30aeadfb 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -6808,7 +6808,7 @@ input UpdateRAGInput { annotations: Map displayName: String description: String - application: TypedObjectReferenceInput! + application: TypedObjectReferenceInput datasets: [RAGDatasetInput!] judgeLLM: TypedObjectReferenceInput metrics: [RAGMetricInput!] @@ -36432,7 +36432,7 @@ func (ec *executionContext) unmarshalInputUpdateRAGInput(ctx context.Context, ob var err error ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("application")) - data, err := ec.unmarshalNTypedObjectReferenceInput2githubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐTypedObjectReferenceInput(ctx, v) + data, err := ec.unmarshalOTypedObjectReferenceInput2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐTypedObjectReferenceInput(ctx, v) if err != nil { return it, err } diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index 34483e2d8..a8b410ccb 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ -1679,7 +1679,7 @@ type UpdateRAGInput struct { Annotations map[string]interface{} `json:"annotations,omitempty"` DisplayName *string `json:"displayName,omitempty"` Description *string `json:"description,omitempty"` - Application TypedObjectReferenceInput `json:"application"` + Application *TypedObjectReferenceInput `json:"application,omitempty"` Datasets []*RAGDatasetInput `json:"datasets,omitempty"` JudgeLlm *TypedObjectReferenceInput `json:"judgeLLM,omitempty"` Metrics []*RAGMetricInput `json:"metrics,omitempty"` diff --git a/apiserver/graph/impl/rag.resolvers.go b/apiserver/graph/impl/rag.resolvers.go index e2cda9e54..598827d49 100644 --- a/apiserver/graph/impl/rag.resolvers.go +++ b/apiserver/graph/impl/rag.resolvers.go @@ -6,7 +6,6 @@ package impl import ( "context" - "fmt" "github.com/kubeagi/arcadia/apiserver/graph/generated" "github.com/kubeagi/arcadia/apiserver/pkg/application" @@ -87,7 +86,11 @@ func (r *rAGMutationResolver) CreateRag(ctx context.Context, obj *generated.RAGM // UpdateRag is the resolver for the updateRAG field. func (r *rAGMutationResolver) UpdateRag(ctx context.Context, obj *generated.RAGMutation, input generated.UpdateRAGInput) (*generated.Rag, error) { - panic(fmt.Errorf("not implemented: UpdateRag - updateRAG")) + c, err := getClientFromCtx(ctx) + if err != nil { + return nil, err + } + return rag.UpdateRAG(ctx, c, &input) } // DeleteRag is the resolver for the deleteRAG field. diff --git a/apiserver/graph/schema/rag.graphqls b/apiserver/graph/schema/rag.graphqls index 08faa4089..0ec408ce9 100644 --- a/apiserver/graph/schema/rag.graphqls +++ b/apiserver/graph/schema/rag.graphqls @@ -143,7 +143,7 @@ input UpdateRAGInput { annotations: Map displayName: String description: String - application: TypedObjectReferenceInput! + application: TypedObjectReferenceInput datasets: [RAGDatasetInput!] judgeLLM: TypedObjectReferenceInput metrics: [RAGMetricInput!] diff --git a/apiserver/pkg/rag/rag.go b/apiserver/pkg/rag/rag.go index 4ce5a7699..2c4055da8 100644 --- a/apiserver/pkg/rag/rag.go +++ b/apiserver/pkg/rag/rag.go @@ -33,6 +33,7 @@ import ( evav1alpha1 "github.com/kubeagi/arcadia/api/evaluation/v1alpha1" "github.com/kubeagi/arcadia/apiserver/graph/generated" "github.com/kubeagi/arcadia/apiserver/pkg/common" + graphqlutils "github.com/kubeagi/arcadia/apiserver/pkg/utils" "github.com/kubeagi/arcadia/pkg/utils" ) @@ -334,6 +335,104 @@ func CreateRAG(ctx context.Context, kubeClient dynamic.Interface, input *generat return rag2model(u1) } +func UpdateRAG(ctx context.Context, kubeClient dynamic.Interface, input *generated.UpdateRAGInput) (*generated.Rag, error) { + obj, err := kubeClient.Resource(common.SchemaOf(&common.ArcadiaAPIGroup, "RAG")).Namespace(input.Namespace).Get(ctx, input.Name, v1.GetOptions{}) + if err != nil { + return nil, err + } + rag := &evav1alpha1.RAG{} + if err := utils.UnstructuredToStructured(obj, rag); err != nil { + return nil, err + } + + if input.Labels != nil { + rag.SetLabels(graphqlutils.MapAny2Str(input.Labels)) + } + if input.Annotations != nil { + rag.SetAnnotations(graphqlutils.MapAny2Str(input.Annotations)) + } + if input.DisplayName != nil { + rag.Spec.DisplayName = *input.DisplayName + } + if input.Description != nil { + rag.Spec.Description = *input.Description + } + if input.Application != nil { + rag.Spec.Application = &v1alpha1.TypedObjectReference{ + APIGroup: input.Application.APIGroup, + Kind: input.Application.Kind, + Name: input.Application.Name, + Namespace: input.Application.Namespace, + } + } + if input.Datasets != nil { + rag.Spec.Datasets = make([]evav1alpha1.Dataset, 0) + for i, dataset := range input.Datasets { + ds := evav1alpha1.Dataset{ + Source: &v1alpha1.TypedObjectReference{ + APIGroup: input.Datasets[i].Source.APIGroup, + Kind: input.Datasets[i].Source.Kind, + Name: input.Datasets[i].Source.Name, + Namespace: input.Datasets[i].Source.Namespace, + }, + Files: dataset.Files, + } + rag.Spec.Datasets = append(rag.Spec.Datasets, ds) + } + } + if input.JudgeLlm != nil { + rag.Spec.JudgeLLM = &v1alpha1.TypedObjectReference{ + APIGroup: input.JudgeLlm.APIGroup, + Kind: input.JudgeLlm.Kind, + Name: input.JudgeLlm.Name, + Namespace: input.JudgeLlm.Namespace, + } + } + if input.Metrics != nil { + rag.Spec.Metrics = make([]evav1alpha1.Metric, 0) + for _, m := range input.Metrics { + mm := evav1alpha1.Metric{ + Parameters: make([]evav1alpha1.Parameter, 0), + } + if m.MetricKind != nil { + mm.Kind = evav1alpha1.MetricsKind(*m.MetricKind) + } + if m.ToleranceThreshbold != nil { + mm.ToleranceThreshbold = *m.ToleranceThreshbold + } + for _, p := range m.Parameters { + mm.Parameters = append(mm.Parameters, evav1alpha1.Parameter{ + Key: *p.Key, + Value: *p.Value, + }) + } + rag.Spec.Metrics = append(rag.Spec.Metrics, mm) + } + } + + if input.Storage != nil { + rag.Spec.Storage = gen2storage(*input.Storage) + } + if input.Suspend != nil { + rag.Spec.Suspend = *input.Suspend + } + + unstructuredRag, err := runtime.DefaultUnstructuredConverter.ToUnstructured(&rag) + if err != nil { + return nil, err + } + updatedRag, err := common.ResouceUpdate(ctx, kubeClient, generated.TypedObjectReferenceInput{ + APIGroup: &common.ArcadiaAPIGroup, + Kind: "RAG", + Namespace: &rag.Namespace, + Name: rag.Name, + }, unstructuredRag, v1.UpdateOptions{}) + if err != nil { + return nil, err + } + return rag2model(updatedRag) +} + func ListRAG(ctx context.Context, kubeClient dynamic.Interface, input *generated.ListRAGInput) (*generated.PaginatedResult, error) { listOptions := v1.ListOptions{ LabelSelector: fmt.Sprintf("%s=%s", evav1alpha1.EvaluationApplicationLabel, input.AppName),