diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index 288dd9591..7d75da14d 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -4441,6 +4441,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { ec.unmarshalInputPersistentVolumeClaimSpecInput, ec.unmarshalInputRAGDatasetInput, ec.unmarshalInputRAGMetricInput, + ec.unmarshalInputRemoveDuplicateConfig, ec.unmarshalInputResourceInput, ec.unmarshalInputResourcesInput, ec.unmarshalInputSelectorInput, @@ -4984,6 +4985,7 @@ input FileItem { input DataProcessConfigItem { type: String! llm_config: LLMConfigItem + remove_duplicate_config: RemoveDuplicateConfig } # LLM for 数据处理配置条目 @@ -4998,6 +5000,14 @@ input LLMConfigItem { provider: String } +input RemoveDuplicateConfig { + embedding_name: String! + embedding_namespace: String! + embedding_model: String! + embedding_provider: String! + similarity: String! +} + input DeleteDataProcessInput { id: String! } @@ -33534,7 +33544,7 @@ func (ec *executionContext) unmarshalInputDataProcessConfigItem(ctx context.Cont asMap[k] = v } - fieldsInOrder := [...]string{"type", "llm_config"} + fieldsInOrder := [...]string{"type", "llm_config", "remove_duplicate_config"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -33559,6 +33569,15 @@ func (ec *executionContext) unmarshalInputDataProcessConfigItem(ctx context.Cont return it, err } it.LlmConfig = data + case "remove_duplicate_config": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("remove_duplicate_config")) + data, err := ec.unmarshalORemoveDuplicateConfig2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐRemoveDuplicateConfig(ctx, v) + if err != nil { + return it, err + } + it.RemoveDuplicateConfig = data } } @@ -35205,6 +35224,71 @@ func (ec *executionContext) unmarshalInputRAGMetricInput(ctx context.Context, ob return it, nil } +func (ec *executionContext) unmarshalInputRemoveDuplicateConfig(ctx context.Context, obj interface{}) (RemoveDuplicateConfig, error) { + var it RemoveDuplicateConfig + asMap := map[string]interface{}{} + for k, v := range obj.(map[string]interface{}) { + asMap[k] = v + } + + fieldsInOrder := [...]string{"embedding_name", "embedding_namespace", "embedding_model", "embedding_provider", "similarity"} + for _, k := range fieldsInOrder { + v, ok := asMap[k] + if !ok { + continue + } + switch k { + case "embedding_name": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("embedding_name")) + data, err := ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + it.EmbeddingName = data + case "embedding_namespace": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("embedding_namespace")) + data, err := ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + it.EmbeddingNamespace = data + case "embedding_model": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("embedding_model")) + data, err := ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + it.EmbeddingModel = data + case "embedding_provider": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("embedding_provider")) + data, err := ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + it.EmbeddingProvider = data + case "similarity": + var err error + + ctx := graphql.WithPathContext(ctx, 
graphql.NewPathWithField("similarity")) + data, err := ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + it.Similarity = data + } + } + + return it, nil +} + func (ec *executionContext) unmarshalInputResourceInput(ctx context.Context, obj interface{}) (ResourceInput, error) { var it ResourceInput asMap := map[string]interface{}{} @@ -45946,6 +46030,14 @@ func (ec *executionContext) marshalORayClusterQuery2ᚖgithubᚗcomᚋkubeagiᚋ return ec._RayClusterQuery(ctx, sel, v) } +func (ec *executionContext) unmarshalORemoveDuplicateConfig2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐRemoveDuplicateConfig(ctx context.Context, v interface{}) (*RemoveDuplicateConfig, error) { + if v == nil { + return nil, nil + } + res, err := ec.unmarshalInputRemoveDuplicateConfig(ctx, v) + return &res, graphql.ErrorOnPath(ctx, err) +} + func (ec *executionContext) marshalOResource2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐResource(ctx context.Context, sel ast.SelectionSet, v *Resource) graphql.Marshaler { if v == nil { return graphql.Null diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index 2d3e2a542..fc58b4e12 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ -408,8 +408,9 @@ type DataProcessConfigChildren struct { } type DataProcessConfigItem struct { - Type string `json:"type"` - LlmConfig *LLMConfigItem `json:"llm_config,omitempty"` + Type string `json:"type"` + LlmConfig *LLMConfigItem `json:"llm_config,omitempty"` + RemoveDuplicateConfig *RemoveDuplicateConfig `json:"remove_duplicate_config,omitempty"` } type DataProcessConfigpreFileProgress struct { @@ -1385,6 +1386,14 @@ type RayClusterQuery struct { ListRayClusters PaginatedResult `json:"listRayClusters"` } +type RemoveDuplicateConfig struct { + EmbeddingName string `json:"embedding_name"` + EmbeddingNamespace string `json:"embedding_namespace"` + EmbeddingModel string `json:"embedding_model"` + EmbeddingProvider string `json:"embedding_provider"` + Similarity string `json:"similarity"` +} + type Resource struct { Limits map[string]interface{} `json:"limits,omitempty"` Requests map[string]interface{} `json:"requests,omitempty"` diff --git a/apiserver/graph/schema/dataprocessing.graphqls b/apiserver/graph/schema/dataprocessing.graphqls index 447d4a56b..0bd8c7197 100644 --- a/apiserver/graph/schema/dataprocessing.graphqls +++ b/apiserver/graph/schema/dataprocessing.graphqls @@ -63,6 +63,7 @@ input FileItem { input DataProcessConfigItem { type: String! llm_config: LLMConfigItem + remove_duplicate_config: RemoveDuplicateConfig } # LLM for 数据处理配置条目 @@ -77,6 +78,14 @@ input LLMConfigItem { provider: String } +input RemoveDuplicateConfig { + embedding_name: String! + embedding_namespace: String! + embedding_model: String! + embedding_provider: String! + similarity: String! +} + input DeleteDataProcessInput { id: String! 
} diff --git a/deploy/charts/arcadia/templates/pg-init-data-configmap.yaml b/deploy/charts/arcadia/templates/pg-init-data-configmap.yaml index 0a2c03d88..7be0d0193 100644 --- a/deploy/charts/arcadia/templates/pg-init-data-configmap.yaml +++ b/deploy/charts/arcadia/templates/pg-init-data-configmap.yaml @@ -233,6 +233,10 @@ data: file_name character varying(512) COLLATE pg_catalog."default", question text COLLATE pg_catalog."default", answer text COLLATE pg_catalog."default", + question_score character varying(32) COLLATE pg_catalog."default", + answer_score character varying(32) COLLATE pg_catalog."default", + duplicated_flag character varying(32) COLLATE pg_catalog."default", + compare_with_id character varying(32) COLLATE pg_catalog."default", create_datetime character varying(32) COLLATE pg_catalog."default", create_user character varying(32) COLLATE pg_catalog."default", create_program character varying(64) COLLATE pg_catalog."default", @@ -250,6 +254,10 @@ data: COMMENT ON COLUMN public.data_process_task_question_answer_clean.file_name IS '文件名称'; COMMENT ON COLUMN public.data_process_task_question_answer_clean.question IS '问题'; COMMENT ON COLUMN public.data_process_task_question_answer_clean.answer IS '答案'; + COMMENT ON COLUMN public.data_process_task_question_answer_clean.question_score IS 'question向量化后比对分数'; + COMMENT ON COLUMN public.data_process_task_question_answer_clean.answer_score IS 'answer向量化后比对分数'; + COMMENT ON COLUMN public.data_process_task_question_answer_clean.duplicated_flag IS '是否重复'; + COMMENT ON COLUMN public.data_process_task_question_answer_clean.compare_with_id IS '和那条数据进行的比较'; COMMENT ON COLUMN public.data_process_task_question_answer_clean.create_datetime IS '创建时间'; COMMENT ON COLUMN public.data_process_task_question_answer_clean.create_user IS '创建用户'; COMMENT ON COLUMN public.data_process_task_question_answer_clean.create_program IS '创建程序'; @@ -376,6 +384,8 @@ data: id varchar(32), task_id varchar(32), document_id varchar(32), + document_chunk_id varchar(32), + file_name varchar(64), question text, answer text, question_vector vector, diff --git a/pypi/data-processing/db-scripts/init-database-schema.sql b/pypi/data-processing/db-scripts/init-database-schema.sql index 84e802867..9b33ca07c 100644 --- a/pypi/data-processing/db-scripts/init-database-schema.sql +++ b/pypi/data-processing/db-scripts/init-database-schema.sql @@ -228,6 +228,10 @@ file_name character varying(512) COLLATE pg_catalog."default", question text COLLATE pg_catalog."default", answer text COLLATE pg_catalog."default", + question_score character varying(32) COLLATE pg_catalog."default", + answer_score character varying(32) COLLATE pg_catalog."default", + duplicated_flag character varying(32) COLLATE pg_catalog."default", + compare_with_id character varying(32) COLLATE pg_catalog."default", create_datetime character varying(32) COLLATE pg_catalog."default", create_user character varying(32) COLLATE pg_catalog."default", create_program character varying(64) COLLATE pg_catalog."default", @@ -245,6 +249,10 @@ COMMENT ON COLUMN public.data_process_task_question_answer_clean.file_name IS '文件名称'; COMMENT ON COLUMN public.data_process_task_question_answer_clean.question IS '问题'; COMMENT ON COLUMN public.data_process_task_question_answer_clean.answer IS '答案'; + COMMENT ON COLUMN public.data_process_task_question_answer_clean.question_score IS 'question向量化后比对分数'; + COMMENT ON COLUMN public.data_process_task_question_answer_clean.answer_score IS 'answer向量化后比对分数'; + COMMENT ON COLUMN 
public.data_process_task_question_answer_clean.duplicated_flag IS '是否重复'; + COMMENT ON COLUMN public.data_process_task_question_answer_clean.compare_with_id IS '和那条数据进行的比较'; COMMENT ON COLUMN public.data_process_task_question_answer_clean.create_datetime IS '创建时间'; COMMENT ON COLUMN public.data_process_task_question_answer_clean.create_user IS '创建用户'; COMMENT ON COLUMN public.data_process_task_question_answer_clean.create_program IS '创建程序'; @@ -371,6 +379,8 @@ id varchar(32), task_id varchar(32), document_id varchar(32), + document_chunk_id varchar(32), + file_name varchar(64), question text, answer text, question_vector vector, diff --git a/pypi/data-processing/src/common/log_tag_const.py b/pypi/data-processing/src/common/log_tag_const.py index fe1722c33..a5eed79da 100644 --- a/pypi/data-processing/src/common/log_tag_const.py +++ b/pypi/data-processing/src/common/log_tag_const.py @@ -42,3 +42,7 @@ CONFIG = "Config" WEB_CRAWLING = "Web Url Utils" + +PDF_LOADER = "PDF Loader" +DOCX_LOADER = "Docx Loader" +WEB_LOADER = "Web Loader" \ No newline at end of file diff --git a/pypi/data-processing/src/data_store_process/minio_store_process.py b/pypi/data-processing/src/data_store_process/minio_store_process.py index 69cbf9e14..772a2cb6b 100644 --- a/pypi/data-processing/src/data_store_process/minio_store_process.py +++ b/pypi/data-processing/src/data_store_process/minio_store_process.py @@ -27,7 +27,8 @@ data_process_document_db_operate, data_process_log_db_operate, data_process_stage_log_db_operate) -from file_handle import common_handle, pdf_handle, web_handle, word_handle +from file_handle import common_handle, web_handle, word_handle +from file_handle.pdf_handle import PDFHandle from kube import dataset_cr from utils import date_time_utils, file_utils, json_utils @@ -147,7 +148,7 @@ async def text_manipulate( file_extension = file_utils.get_file_extension(file_name) if file_extension in ["pdf"]: # 处理PDF文件 - result = pdf_handle.pdf_manipulate( + pdf_handle = PDFHandle( chunk_size=req_json.get("chunk_size"), chunk_overlap=req_json.get("chunk_overlap"), file_name=file_name, @@ -157,6 +158,7 @@ async def text_manipulate( task_id=id, create_user=req_json["creator"], ) + result = pdf_handle.handle() elif file_extension in ["docx"]: # 处理.docx文件 @@ -999,7 +1001,7 @@ def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creat document_type = document.get("document_type") if document_type in ["pdf"]: # 处理PDF文件 - result = pdf_handle.pdf_manipulate( + pdf_handle = PDFHandle( file_name=file_name, document_id=document.get("id"), support_type=support_type, @@ -1007,6 +1009,7 @@ def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creat task_id=task_id, create_user=creator, ) + result = pdf_handle.handle() elif document_type in ["docx"]: # 处理.docx文件 diff --git a/pypi/data-processing/src/database_operate/data_process_detail_db_operate.py b/pypi/data-processing/src/database_operate/data_process_detail_db_operate.py index 29c2c80ee..0f6dbd395 100644 --- a/pypi/data-processing/src/database_operate/data_process_detail_db_operate.py +++ b/pypi/data-processing/src/database_operate/data_process_detail_db_operate.py @@ -264,9 +264,10 @@ def top_n_list_qa_for_preview(req_json, pool): update_user, update_program from - public.data_process_task_question_answer + public.data_process_task_question_answer_clean where - task_id = %(task_id)s + task_id = %(task_id)s and + duplicated_flag = '1' order by random() limit 10 """.strip() @@ -356,6 +357,10 @@ def 
insert_question_answer_clean_info(req_json, pool): "file_name": req_json["file_name"], "question": req_json["question"], "answer": req_json["answer"], + "question_score": req_json["question_score"], + "answer_score": req_json["answer_score"], + "duplicated_flag": req_json["duplicated_flag"], + "compare_with_id": req_json["compare_with_id"], "create_datetime": now, "create_user": user, "create_program": program, @@ -373,6 +378,10 @@ def insert_question_answer_clean_info(req_json, pool): file_name, question, answer, + question_score, + answer_score, + duplicated_flag, + compare_with_id, create_datetime, create_user, create_program, @@ -388,6 +397,10 @@ def insert_question_answer_clean_info(req_json, pool): %(file_name)s, %(question)s, %(answer)s, + %(question_score)s, + %(answer_score)s, + %(duplicated_flag)s, + %(compare_with_id)s, %(create_datetime)s, %(create_program)s, %(create_user)s, @@ -414,22 +427,16 @@ def query_question_answer_list(document_id, pool): sql = """ select - dptqa.id, - dptqa.task_id, - dptqa.document_id, - dptqa.document_chunk_id, - dptqa.file_name, - dptqa.question, - dptqa.answer, - dptdc.content, - dptdc.meta_info, - dptdc.page_number - from public.data_process_task_question_answer dptqa - left join public.data_process_task_document_chunk dptdc - on - dptdc.id = dptqa.document_chunk_id + id, + task_id, + document_id, + document_chunk_id, + file_name, + question, + answer + from public.data_process_task_question_answer where - dptqa.document_id = %(document_id)s + document_id = %(document_id)s """.strip() res = postgresql_pool_client.execute_query(pool, sql, params) @@ -539,3 +546,38 @@ def delete_transform_by_document_chunk(req_json, pool): res = postgresql_pool_client.execute_update(pool, sql, params) return res + +def query_question_answer_clean_list(document_id, pool): + """List question answer with document id. + + req_json is a dictionary object. 
for example: + { + "document_id": "01HGWBE48DT3ADE9ZKA62SW4WS" + } + pool: databasec connection pool; + """ + params = {"document_id": document_id} + + sql = """ + select + dptqac.id, + dptqac.task_id, + dptqac.document_id, + dptqac.document_chunk_id, + dptqac.file_name, + dptqac.question, + dptqac.answer, + dptdc.content, + dptdc.meta_info, + dptdc.page_number + from public.data_process_task_question_answer_clean dptqac + left join public.data_process_task_document_chunk dptdc + on + dptdc.id = dptqac.document_chunk_id + where + dptqac.document_id = %(document_id)s and + dptqac.duplicated_flag = '1' + """.strip() + + res = postgresql_pool_client.execute_query(pool, sql, params) + return res diff --git a/pypi/data-processing/src/database_operate/dp_document_qa_remove_duplicate_db_operate.py b/pypi/data-processing/src/database_operate/dp_document_qa_remove_duplicate_db_operate.py index 3d783a7d9..f84c2928f 100644 --- a/pypi/data-processing/src/database_operate/dp_document_qa_remove_duplicate_db_operate.py +++ b/pypi/data-processing/src/database_operate/dp_document_qa_remove_duplicate_db_operate.py @@ -15,17 +15,20 @@ from database_clients import postgresql_pool_client + def add( params, pool ): """Add a new record""" - + sql = """ insert into public.data_process_task_question_answer_remove_duplicate_tmp ( id, task_id, document_id, + document_chunk_id, + file_name, question, question_vector, answer, @@ -36,6 +39,8 @@ def add( %(id)s, %(task_id)s, %(document_id)s, + %(document_chunk_id)s, + %(file_name)s, %(question)s, %(question_vector)s, %(answer)s, @@ -58,18 +63,20 @@ def filter_by_distance( id, task_id, document_id, + document_chunk_id, + file_name, question, answer, - (q1.question_vector <#> q2.question_vector) * -1 as question_distance, - (q1.answer_vector <#> q2.answer_vector) * -1 as answer_distance + 1 - (q1.question_vector <=> q2.question_vector) as question_distance, + 1 - (q1.answer_vector <=> q2.answer_vector) as answer_distance from - data_process_task_question_answer_remove_duplicate_tmp q1, - (select + data_process_task_question_answer_remove_duplicate_tmp q1, + (select question_vector, - answer_vector - from + answer_vector + from data_process_task_question_answer_remove_duplicate_tmp - where + where id = %(id)s limit 1) q2 where diff --git a/pypi/data-processing/src/document_chunks/base.py b/pypi/data-processing/src/document_chunks/base.py new file mode 100644 index 000000000..76fd6ecbc --- /dev/null +++ b/pypi/data-processing/src/document_chunks/base.py @@ -0,0 +1,26 @@ +# Copyright 2024 KubeAGI. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC, abstractmethod +from typing import List + +from langchain_core.documents import Document + + +class TextSplitter(ABC): + """Interface for splitting text into chunks.""" + + @abstractmethod + def split_documents(self, documents: List[Document]) -> List[Document]: + """Split document.""" \ No newline at end of file diff --git a/pypi/data-processing/src/document_chunks/spacy_text_splitter.py b/pypi/data-processing/src/document_chunks/spacy_text_splitter.py new file mode 100644 index 000000000..df05fe41b --- /dev/null +++ b/pypi/data-processing/src/document_chunks/spacy_text_splitter.py @@ -0,0 +1,50 @@ +# Copyright 2024 KubeAGI. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from langchain_core.documents import Document + +from document_chunks.base import TextSplitter + + +class SpacyTextSplitter(TextSplitter): + def __init__( + self, + separator: str = "\n\n", + pipeline: str = "zh_core_web_sm", + chunk_size: int = 500, + chunk_overlap: int = 10, + ): + """Initialize the spacy text splitter.""" + if chunk_overlap > chunk_size: + raise ValueError( + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " + f"({chunk_size}), should be smaller." + ) + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._separator = separator + self._pipeline = pipeline + + def split_documents(self, documents: List[Document]) -> List[Document]: + from langchain.text_splitter import SpacyTextSplitter + text_splitter = SpacyTextSplitter( + separator=self._separator, + pipeline=self._pipeline, + chunk_size=self._chunk_size, + chunk_overlap=self._chunk_overlap, + ) + documents = text_splitter.split_documents(documents) + return documents diff --git a/pypi/data-processing/src/document_loaders/async_playwright.py b/pypi/data-processing/src/document_loaders/async_playwright.py index f7e6ae6ea..78754cd15 100644 --- a/pypi/data-processing/src/document_loaders/async_playwright.py +++ b/pypi/data-processing/src/document_loaders/async_playwright.py @@ -17,7 +17,6 @@ import traceback from typing import List -import playwright from langchain_community.document_transformers import Html2TextTransformer from langchain_core.documents import Document from playwright.async_api import async_playwright @@ -46,9 +45,6 @@ def __init__( max_count (int): Maximum Number of Website URLs. max_depth (int): Website Crawling Depth. interval_time (int): Interval Time. - - Raises: - ImportError: If the required 'playwright' package is not installed. """ if max_count is None: max_count = 100 @@ -73,8 +69,8 @@ async def ascrape_playwright(self, url: str) -> str: str: The scraped HTML content or an error message if an exception occurs. """ - logger.info("Starting scraping...") + results = "" async with async_playwright() as p: browser = await p.chromium.launch(headless=True) @@ -97,11 +93,13 @@ async def load(self) -> List[Document]: containing the scraped content from each URL. 
""" + logger.info(f"{log_tag_const.WEB_LOADER} Async start to load Website data") + docs = [] all_url = await self.get_all_url() for url in all_url: html_content = await self.ascrape_playwright(url) - metadata = {"source": url} + metadata = {"source": url, "page": 0} docs.append(Document(page_content=html_content, metadata=metadata)) html2text = Html2TextTransformer() @@ -131,7 +129,7 @@ async def get_all_url(self): all_url = [self._url] sub_urls = [self._url] try: - for i in range(1, self._max_depth): + for _ in range(1, self._max_depth): for sub_url in sub_urls: children_urls = await self._get_children_url( url=sub_url, diff --git a/pypi/data-processing/src/document_loaders/docx.py b/pypi/data-processing/src/document_loaders/docx.py new file mode 100644 index 000000000..f2b91961e --- /dev/null +++ b/pypi/data-processing/src/document_loaders/docx.py @@ -0,0 +1,63 @@ +# Copyright 2024 KubeAGI. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List + +import docx +from langchain_core.documents import Document + +from common import log_tag_const +from document_loaders.base import BaseLoader +from utils import file_utils + +logger = logging.getLogger(__name__) + +class DocxLoader(BaseLoader): + """Load docx files.""" + + def __init__( + self, + file_path: str, + ): + """ + Initialize the loader with a list of URL paths. + + Args: + file_path (str): File Path. + """ + self._file_path = file_path + + def load(self) -> List[Document]: + """ + Load and return all Documents from the docx file. + + Returns: + List[Document]: A list of Document objects. + + """ + logger.info(f"{log_tag_const.DOCX_LOADER} Start to load docx file") + + # Get file name + file_name = file_utils.get_file_name(self._file_path) + + docs = [] + doc = docx.Document(self._file_path) + for i in range(len(doc.paragraphs)): + para = doc.paragraphs[i] + content = para.text + metadata = {"source": file_name, "page": i} + docs.append(Document(page_content=content, metadata=metadata)) + + return docs diff --git a/pypi/data-processing/src/document_loaders/pdf.py b/pypi/data-processing/src/document_loaders/pdf.py new file mode 100644 index 000000000..6ac5301dc --- /dev/null +++ b/pypi/data-processing/src/document_loaders/pdf.py @@ -0,0 +1,60 @@ +# Copyright 2024 KubeAGI. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import List + +from langchain.document_loaders import PyPDFLoader +from langchain_core.documents import Document + +from common import log_tag_const +from document_loaders.base import BaseLoader +from utils import file_utils + +logger = logging.getLogger(__name__) + +class PDFLoader(BaseLoader): + """Load pdf file.""" + + def __init__( + self, + file_path: str, + ): + """ + Initialize the loader with a list of URL paths. + + Args: + file_path (str): File Path. + """ + self._file_path = file_path + + def load(self) -> List[Document]: + """ + Load and return all Documents from the docx file. + + Returns: + List[Document]: A list of Document objects. + + """ + logger.info(f"{log_tag_const.PDF_LOADER} Start to load pdf file") + + # Get file name + file_name = file_utils.get_file_name(self._file_path) + + pdf_loader = PyPDFLoader(self._file_path) + documents = pdf_loader.load() + for document in documents: + document.metadata["source"] = file_name + + return documents diff --git a/pypi/data-processing/src/embeddings/embeddings.py b/pypi/data-processing/src/embeddings/embeddings.py index 52a097ccf..a4e576b70 100644 --- a/pypi/data-processing/src/embeddings/embeddings.py +++ b/pypi/data-processing/src/embeddings/embeddings.py @@ -15,6 +15,7 @@ from abc import ABC, abstractmethod from typing import List + class Embeddings(ABC): """Interface for embedding models.""" diff --git a/pypi/data-processing/src/embeddings/openai_embeddings.py b/pypi/data-processing/src/embeddings/openai_embeddings.py index e3cd30960..258c56991 100644 --- a/pypi/data-processing/src/embeddings/openai_embeddings.py +++ b/pypi/data-processing/src/embeddings/openai_embeddings.py @@ -37,7 +37,7 @@ def __init__( Raises: ImportError: If the required 'openai' package is not installed. """ - + self.base_url = base_url self.api_key = api_key self.model = model @@ -53,7 +53,7 @@ def __init__( base_url=base_url, api_key=api_key ) - + def embed_documents(self, texts: List[str]) -> List[List[float]]: """Embed search texts.""" logger.debug(texts) diff --git a/pypi/data-processing/src/file_handle/base.py b/pypi/data-processing/src/file_handle/base.py new file mode 100644 index 000000000..81744aa0d --- /dev/null +++ b/pypi/data-processing/src/file_handle/base.py @@ -0,0 +1,26 @@ +# Copyright 2024 KubeAGI. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + + +class BaseHandle(ABC): + """Interface for Document Handle. + + The `handle` method will remain as is for backwards compatibility. 
+ """ + + @abstractmethod + def handle(self): + """handle document.""" diff --git a/pypi/data-processing/src/file_handle/common_handle.py b/pypi/data-processing/src/file_handle/common_handle.py index 56aaa0533..ab7b96a24 100644 --- a/pypi/data-processing/src/file_handle/common_handle.py +++ b/pypi/data-processing/src/file_handle/common_handle.py @@ -24,10 +24,12 @@ from database_operate import (data_process_detail_db_operate, data_process_document_chunk_db_operate, data_process_document_db_operate) +from embeddings.openai_embeddings import OpenAIEmbeddings from kube import model_cr from llm_api_service.qa_provider_open_ai import QAProviderOpenAI from llm_api_service.qa_provider_zhi_pu_ai_online import \ QAProviderZhiPuAIOnline +from service.data_process_qa_remove_duplicate import QARemoveDuplicate from transform.text import clean_transform, privacy_transform from utils import csv_utils, date_time_utils, file_utils, json_utils @@ -119,48 +121,67 @@ def text_manipulate( id=document_id, status="success", conn_pool=conn_pool ) - # 通过documentId查询生成的所有QA数据 - qa_list = data_process_detail_db_operate.query_question_answer_list( - document_id=document_id, pool=conn_pool - ) + if support_type_map.get("qa_split"): + # 是否选择了QA拆分 + qa_list_dict = support_type_map.get("qa_split") + remove_duplicate_config = qa_list_dict.get("remove_duplicate_config") + if remove_duplicate_config: + # 进行了QA去重配置 + logger.debug(f"{log_tag_const.QA_SPLIT} Start to QA remove duplicate.") + remove_duplicate_response = _remove_duplicate( + document_id=document_id, + remove_duplicate_config=remove_duplicate_config, + conn_pool=conn_pool, + create_user=create_user + ) - qa_data_dict = [["q", "a", "file_name", "page_number", "chunk_content"]] - for item in qa_list.get("data"): - meta_info = item.get("meta_info") - if meta_info: - meta_json = json_utils.loads(meta_info) - meta_source = meta_json.get("source") - else: - meta_source = item.get("file_name") + if remove_duplicate_response.get("status") != 200: + return remove_duplicate_response - qa_data_dict.append( - [ - item.get("question"), - item.get("answer"), - meta_source, - item.get("page_number"), - item.get("content"), - ] + # 通过documentId查询生成的所有QA数据 + qa_list = data_process_detail_db_operate.query_question_answer_clean_list( + document_id=document_id, pool=conn_pool ) - # Save the csv file. - file_name_without_extension = file_utils.get_file_name_without_extension( - file_name - ) - file_name_csv = file_name_without_extension + ".csv" - csv_utils.save_csv( - file_name=file_name_csv, phase_value="final", data=qa_data_dict - ) + qa_data_dict = [["q", "a", "file_name", "page_number", "chunk_content"]] + for item in qa_list.get("data"): + meta_info = item.get("meta_info") + if meta_info: + meta_json = json_utils.loads(meta_info) + meta_source = meta_json.get("source") + else: + meta_source = item.get("file_name") + + qa_data_dict.append( + [ + item.get("question"), + item.get("answer"), + meta_source, + item.get("page_number"), + item.get("content"), + ] + ) - logger.debug(f"{log_tag_const.COMMON_HANDLE} Finish manipulating the text") - return { - "status": 200, - "message": "", - "data": { - "object_name": file_name_csv, - "object_count": len(qa_list.get("data")), - }, - } + # Save the csv file. 
+ file_name_without_extension = file_utils.get_file_name_without_extension( + file_name + ) + file_name_csv = file_name_without_extension + ".csv" + csv_utils.save_csv( + file_name=file_name_csv, phase_value="final", data=qa_data_dict + ) + + logger.debug(f"{log_tag_const.COMMON_HANDLE} Finish manipulating the text") + return { + "status": 200, + "message": "", + "data": { + "object_name": file_name_csv, + "object_count": len(qa_list.get("data")), + }, + } + + return {"status": 200, "message": "", "data": ""} except Exception as ex: logger.error( "".join( @@ -883,10 +904,6 @@ def _qa_split( qa_insert_item, pool=conn_pool ) - data_process_detail_db_operate.insert_question_answer_clean_info( - qa_insert_item, pool=conn_pool - ) - # 更新data_process_task_document_chunk中的状态 _updata_document_chunk_status_and_end_time( id=document_chunk_id, @@ -1172,3 +1189,111 @@ def _updata_document_chunk_status_and_end_time(id, status, update_user, conn_poo ) ) return {"status": 1000, "message": str(ex), "data": traceback.format_exc()} + + +def _remove_duplicate(document_id, remove_duplicate_config, conn_pool, create_user): + # 通过documentId查询生成的所有QA数据 + qa_list = data_process_detail_db_operate.query_question_answer_list( + document_id=document_id, pool=conn_pool + ) + + remove_duplicate_res = _qa_remove_duplicate( + qa_list=qa_list.get("data"), + remove_duplicate_config=remove_duplicate_config, + conn_pool=conn_pool, + ) + if remove_duplicate_res.get("status") != 200: + # 更新data_process_task_document中的文件状态 + _updata_document_status_and_end_time( + id=document_id, status="fail", conn_pool=conn_pool + ) + return remove_duplicate_res + + # 将QA去重的数据存入question_answer_clean表中 + qa_data = remove_duplicate_res.get("data") + for _, item in enumerate(qa_data): + duplicated_flag = 1 + if item.get("duplicated_flag") is not None: + duplicated_flag = item.get("duplicated_flag") + qa_insert_item = { + "id": item.get("id"), + "task_id": item.get("task_id"), + "document_id": item.get("document_id"), + "document_chunk_id": item.get("document_chunk_id"), + "file_name": item.get("file_name"), + "question": item.get("question"), + "answer": item.get("answer"), + "question_score": item.get("question_distance"), + "answer_score": item.get("answer_distance"), + "duplicated_flag": duplicated_flag, + "compare_with_id": item.get("compare_with_id"), + "create_user": create_user, + } + data_process_detail_db_operate.insert_question_answer_clean_info( + qa_insert_item, pool=conn_pool + ) + return remove_duplicate_res + +def _qa_remove_duplicate(qa_list, remove_duplicate_config, conn_pool): + name = remove_duplicate_config.get("embedding_name") + namespace = remove_duplicate_config.get("embedding_namespace") + model = remove_duplicate_config.get("embedding_model") + provider = remove_duplicate_config.get("embedding_provider") + similarity = float(remove_duplicate_config.get("similarity")) + + # llms cr 中模型相关信息 + llm_spec_info = model_cr.get_spec_for_embedding_k8s_cr(name=name, namespace=namespace) + + if provider == "worker": + # get base url for configmap + base_url = model_cr.get_worker_base_url_k8s_configmap( + name=config.k8s_default_config, namespace=config.k8s_pod_namespace + ) + logger.debug( + "".join( + [ + f"worker embedding \n", + f"name: {name}\n", + f"namespace: {namespace}\n", + f"model: {model}\n", + f"base_url: {base_url}\n", + ] + ) + ) + + qa_embeddings = OpenAIEmbeddings( + api_key="fake", + base_url=base_url, + model=model, + ) + + remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool) + 
return remove_duplicate_loader.qa_remove_duplicate(qa_list, similarity) + else: + endpoint = llm_spec_info.get("data").get("provider").get("endpoint") + base_url = endpoint.get("url") + llm_type = llm_spec_info.get("data").get("type") + + logger.debug( + "".join( + [ + f"3rd_party embedding \n", + f"name: {name}\n", + f"namespace: {namespace}\n", + f"model: {model}\n", + f"llm_type: {llm_type}\n", + ] + ) + ) + + if llm_type == "openai": + qa_embeddings = OpenAIEmbeddings( + api_key="fake", + base_url=base_url, + model=model, + ) + + remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool) + return remove_duplicate_loader.qa_remove_duplicate(qa_list, similarity) + else: + return {"status": 1000, "message": f"暂时不支持{llm_type}类型的向量化模型模型", "data": ""} diff --git a/pypi/data-processing/src/file_handle/pdf_handle.py b/pypi/data-processing/src/file_handle/pdf_handle.py index 28f6e9cfc..85975d3e9 100644 --- a/pypi/data-processing/src/file_handle/pdf_handle.py +++ b/pypi/data-processing/src/file_handle/pdf_handle.py @@ -17,113 +17,126 @@ import traceback import ulid -from langchain.document_loaders import PyPDFLoader -from langchain.text_splitter import SpacyTextSplitter from common import log_tag_const from common.config import config from database_operate import data_process_document_chunk_db_operate +from document_chunks.spacy_text_splitter import SpacyTextSplitter +from document_loaders.pdf import PDFLoader from file_handle import common_handle +from file_handle.base import BaseHandle from utils import file_utils, json_utils logger = logging.getLogger(__name__) -def pdf_manipulate( - file_name, - document_id, - support_type, - conn_pool, - task_id, - create_user, - chunk_size=None, - chunk_overlap=None, -): - """Manipulate the text content from a pdf file. - - file_name: file name; - support_type: support type; - conn_pool: database connection pool; - task_id: data process task id; - chunk_size: chunk size; - chunk_overlap: chunk overlap; - """ - - logger.debug(f"{log_tag_const.PDF_HANDLE} Start to manipulate the text in pdf") - - try: - pdf_file_path = file_utils.get_temp_file_path() - file_path = pdf_file_path + "original/" + file_name - - # Text splitter - documents = _get_documents_by_langchain( - chunk_size=chunk_size, chunk_overlap=chunk_overlap, file_path=file_path - ) - - # step 2 - # save all chunk info to database - all_document_for_process = [] - for document in documents: - chunck_id = ulid.ulid() - page = document.metadata.get("page") + 1 - content = document.page_content.replace("\n", "") - meta_info = document.metadata - meta_info["source"] = file_name - chunk_insert_item = { - "id": chunck_id, - "document_id": document_id, - "task_id": task_id, - "status": "not_start", - "content": content, - "meta_info": json_utils.dumps(meta_info), - "page_number": page, - "creator": create_user, - } - all_document_for_process.append(chunk_insert_item) - - data_process_document_chunk_db_operate.add( - chunk_insert_item, pool=conn_pool +class PDFHandle(BaseHandle): + def __init__( + self, + file_name, + document_id, + support_type, + conn_pool, + task_id, + create_user, + chunk_size=None, + chunk_overlap=None, + ): + """ + Initialize the pdf handle. + + Args: + file_name: file name. + document_id: document id. + support_type: data processing support config type. + conn_pool: PostgreSQL connect pool. + task_id: data processing task id. + create_user: create user. + chunk_size: chunk size. + chunk_overlap: chunk overlap. 
+ """ + if chunk_size is None: + chunk_size = config.knowledge_chunk_size + + if chunk_overlap is None: + chunk_overlap = config.knowledge_chunk_overlap + self._file_name = file_name + self._document_id = document_id + self._support_type = support_type + self._conn_pool = conn_pool + self._task_id = task_id + self._create_user = create_user + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + + def handle( + self, + ): + """handle the text content from a pdf file.""" + + logger.debug(f"{log_tag_const.PDF_HANDLE} Start to handle the text in pdf") + + try: + pdf_file_path = file_utils.get_temp_file_path() + file_path = pdf_file_path + "original/" + self._file_name + + # Text splitter + documents = self._get_documents(file_path=file_path) + + # step 2 + # save all chunk info to database + all_document_for_process = [] + for document in documents: + chunck_id = ulid.ulid() + content = document.page_content.replace("\n", "") + chunk_insert_item = { + "id": chunck_id, + "document_id": self._document_id, + "task_id": self._task_id, + "status": "not_start", + "content": content, + "meta_info": json_utils.dumps(document.metadata), + "page_number": document.metadata.get("page") + 1, + "creator": self._create_user, + } + all_document_for_process.append(chunk_insert_item) + + data_process_document_chunk_db_operate.add( + chunk_insert_item, pool=self._conn_pool + ) + + response = common_handle.text_manipulate( + file_name=self._file_name, + all_document_for_process=all_document_for_process, + support_type=self._support_type, + conn_pool=self._conn_pool, + create_user=self._create_user, ) - response = common_handle.text_manipulate( - file_name=file_name, - all_document_for_process=all_document_for_process, - support_type=support_type, - conn_pool=conn_pool, - create_user=create_user, - ) - - return response - except Exception as ex: - logger.error( - "".join( - [ - f"{log_tag_const.PDF_HANDLE} There is an error when manipulate ", - f"the text in pdf handler. \n{traceback.format_exc()}", - ] + return response + except Exception as ex: + logger.error( + "".join( + [ + f"{log_tag_const.PDF_HANDLE} There is an error when manipulate ", + f"the text in pdf handler. \n{traceback.format_exc()}", + ] + ) ) - ) - logger.debug(f"{log_tag_const.PDF_HANDLE} Finish manipulating the text in pdf") - return {"status": 400, "message": str(ex), "data": traceback.format_exc()} - + logger.debug(f"{log_tag_const.PDF_HANDLE} Finish manipulating the text in pdf") + return {"status": 400, "message": str(ex), "data": traceback.format_exc()} -def _get_documents_by_langchain(chunk_size, chunk_overlap, file_path): - # Split the text. 
- if chunk_size is None: - chunk_size = config.knowledge_chunk_size - if chunk_overlap is None: - chunk_overlap = config.knowledge_chunk_overlap + def _get_documents(self, file_path): + pdf_loader = PDFLoader(file_path) + docs = pdf_loader.load() - source_reader = PyPDFLoader(file_path) - pdf_pages = source_reader.load() - - text_splitter = SpacyTextSplitter( - separator="\n\n", - pipeline="zh_core_web_sm", - chunk_size=int(chunk_size), - chunk_overlap=int(chunk_overlap), - ) - documents = text_splitter.split_documents(pdf_pages) + text_splitter = SpacyTextSplitter( + separator="\n\n", + pipeline="zh_core_web_sm", + chunk_size=int(self._chunk_size), + chunk_overlap=int(self._chunk_overlap), + ) + documents = text_splitter.split_documents(docs) - return documents + return documents diff --git a/pypi/data-processing/src/file_handle/web_handle.py b/pypi/data-processing/src/file_handle/web_handle.py index 460fbd6a8..3c6fe5b34 100644 --- a/pypi/data-processing/src/file_handle/web_handle.py +++ b/pypi/data-processing/src/file_handle/web_handle.py @@ -17,11 +17,11 @@ import ujson import ulid -from langchain.text_splitter import SpacyTextSplitter from common import log_tag_const from common.config import config from database_operate import data_process_document_chunk_db_operate +from document_chunks.spacy_text_splitter import SpacyTextSplitter from document_loaders.async_playwright import AsyncPlaywrightLoader from file_handle import common_handle from utils import file_utils, json_utils @@ -54,7 +54,7 @@ async def web_manipulate( file_path = pdf_file_path + "original/" + file_name # Text splitter - documents = await _get_documents_by_langchain( + documents = await _get_documents( chunk_size=chunk_size, chunk_overlap=chunk_overlap, file_path=file_path ) @@ -71,7 +71,7 @@ async def web_manipulate( "status": "not_start", "content": content, "meta_info": json_utils.dumps(document.metadata), - "page_number": "1", + "page_number": document.metadata.get("page") + 1, "creator": create_user, } all_document_for_process.append(chunk_insert_item) @@ -102,7 +102,7 @@ async def web_manipulate( return {"status": 400, "message": str(ex), "data": traceback.format_exc()} -async def _get_documents_by_langchain(chunk_size, chunk_overlap, file_path): +async def _get_documents(chunk_size, chunk_overlap, file_path): # Split the text. 
if chunk_size is None: chunk_size = config.knowledge_chunk_size diff --git a/pypi/data-processing/src/file_handle/word_handle.py b/pypi/data-processing/src/file_handle/word_handle.py index 5251a1299..bd3e10722 100644 --- a/pypi/data-processing/src/file_handle/word_handle.py +++ b/pypi/data-processing/src/file_handle/word_handle.py @@ -17,13 +17,14 @@ import traceback import ulid -from langchain.text_splitter import SpacyTextSplitter from common import log_tag_const from common.config import config from database_operate import data_process_document_chunk_db_operate +from document_chunks.spacy_text_splitter import SpacyTextSplitter +from document_loaders.docx import DocxLoader from file_handle import common_handle -from utils import docx_utils, file_utils +from utils import file_utils, json_utils logger = logging.getLogger(__name__) @@ -55,7 +56,7 @@ def docx_manipulate( file_path = word_file_path + "original/" + file_name # Text splitter - documents = _get_documents_by_langchain( + documents = _get_documents( chunk_size=chunk_size, chunk_overlap=chunk_overlap, file_path=file_path ) @@ -64,15 +65,15 @@ def docx_manipulate( all_document_for_process = [] for document in documents: chunck_id = ulid.ulid() - content = document.replace("\n", "") + content = document.page_content.replace("\n", "") chunk_insert_item = { "id": chunck_id, "document_id": document_id, "task_id": task_id, "status": "not_start", "content": content, - "meta_info": "", - "page_number": "", + "meta_info": json_utils.dumps(document.metadata), + "page_number": document.metadata.get("page") + 1, "creator": create_user, } all_document_for_process.append(chunk_insert_item) @@ -105,7 +106,7 @@ def docx_manipulate( return {"status": 400, "message": str(ex), "data": traceback.format_exc()} -def _get_documents_by_langchain(chunk_size, chunk_overlap, file_path): +def _get_documents(chunk_size, chunk_overlap, file_path): # Split the text. 
if chunk_size is None: chunk_size = config.knowledge_chunk_size @@ -113,13 +114,14 @@ def _get_documents_by_langchain(chunk_size, chunk_overlap, file_path): if chunk_overlap is None: chunk_overlap = config.knowledge_chunk_overlap - content = docx_utils.get_content(file_path) + docx_loader = DocxLoader(file_path) + docs = docx_loader.load() text_splitter = SpacyTextSplitter( separator="\n\n", pipeline="zh_core_web_sm", chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap), ) - documents = text_splitter.split_text(content) + documents = text_splitter.split_documents(docs) return documents diff --git a/pypi/data-processing/src/kube/client.py b/pypi/data-processing/src/kube/client.py index cc488d993..8917aacc4 100644 --- a/pypi/data-processing/src/kube/client.py +++ b/pypi/data-processing/src/kube/client.py @@ -24,6 +24,7 @@ from .custom_resources import (arcadia_resource_datasets, arcadia_resource_datasources, + arcadia_resource_embedding, arcadia_resource_models, arcadia_resource_versioneddatasets) @@ -143,3 +144,12 @@ def get_datasource_object(self, namespace: str, name: str): plural=arcadia_resource_datasources.get_name(), name=name, ) + + def get_versionedembedding_status(self, namespace: str, name: str): + return CustomObjectsApi().get_namespaced_custom_object_status( + arcadia_resource_embedding.get_group(), + arcadia_resource_embedding.get_version(), + namespace, + arcadia_resource_embedding.get_name(), + name, + ) diff --git a/pypi/data-processing/src/kube/custom_resources.py b/pypi/data-processing/src/kube/custom_resources.py index 495d5a771..4e6ab5c09 100644 --- a/pypi/data-processing/src/kube/custom_resources.py +++ b/pypi/data-processing/src/kube/custom_resources.py @@ -44,3 +44,5 @@ def get_name(self): arcadia_resource_models = CustomResource(arcadia_group, "llms") # CRD Versioneddataset arcadia_resource_versioneddatasets = CustomResource(arcadia_group, "versioneddatasets") +# CRD Embedding +arcadia_resource_embedding = CustomResource(arcadia_group, "embedders") diff --git a/pypi/data-processing/src/kube/dataset_cr.py b/pypi/data-processing/src/kube/dataset_cr.py index a6c61b17f..70d060189 100644 --- a/pypi/data-processing/src/kube/dataset_cr.py +++ b/pypi/data-processing/src/kube/dataset_cr.py @@ -72,4 +72,3 @@ def update_dataset_k8s_cr(namespace, version_data_set_name, reason, message): except Exception as ex: logger.error(str(ex)) return {"status": 400, "message": "更新数据集状态失败", "data": ""} - diff --git a/pypi/data-processing/src/kube/model_cr.py b/pypi/data-processing/src/kube/model_cr.py index 573b45750..4170ebb8d 100644 --- a/pypi/data-processing/src/kube/model_cr.py +++ b/pypi/data-processing/src/kube/model_cr.py @@ -107,3 +107,21 @@ def get_llm_qa_retry_count_in_k8s_configmap(namespace, config_map_name): ) return None + +def get_spec_for_embedding_k8s_cr(name, namespace): + """get embedding. 
+ + name: model name; + namespace: namespace; + """ + try: + kube = client.KubeEnv() + + one_cr_llm = kube.get_versionedembedding_status(namespace=namespace, name=name) + + provider = one_cr_llm["spec"] + + return {"status": 200, "message": "获取embedding中的provider成功", "data": provider} + except Exception as ex: + logger.error(str(ex)) + return {"status": 400, "message": "获取embedding中的provider失败", "data": ""} \ No newline at end of file diff --git a/pypi/data-processing/src/service/data_process_qa_remove_duplicate.py b/pypi/data-processing/src/service/data_process_qa_remove_duplicate.py index 5c0f14d39..ef290ad19 100644 --- a/pypi/data-processing/src/service/data_process_qa_remove_duplicate.py +++ b/pypi/data-processing/src/service/data_process_qa_remove_duplicate.py @@ -13,6 +13,8 @@ # limitations under the License. import logging +import traceback + from database_operate import dp_document_qa_remove_duplicate_db_operate from utils import date_time_utils @@ -44,28 +46,44 @@ def _import_qa_embedding_data( qa_pairs: QA datasets """ logger.debug(f"Starting to QA vectorize: {qa_pairs}") - texts = [] - for qa in qa_pairs: - texts.append(qa["question"]) - texts.append(qa["answer"]) - embeddings = self.embeddings.embed_documents(texts) - logger.debug(f"completed QA vectorize") - for index, qa_pair in enumerate(qa_pairs): - create_datetime = date_time_utils.now_str() - params = { - "id": qa_pair.get("id"), - "task_id": qa_pair.get("task_id"), - "document_id": qa_pair.get("document_id"), - "question": qa_pair.get("question"), - "question_vector": embeddings.data[index * 2].embedding, - "answer": qa_pair.get("answer"), - "answer_vector": embeddings.data[index * 2 + 1].embedding, - "create_datetime": create_datetime - } - dp_document_qa_remove_duplicate_db_operate.add( - params, - self.pool + + try: + texts = [] + for qa in qa_pairs: + texts.append(qa["question"]) + texts.append(qa["answer"]) + embeddings = self.embeddings.embed_documents(texts) + logger.debug(f"completed QA vectorize") + for index, qa_pair in enumerate(qa_pairs): + create_datetime = date_time_utils.now_str() + params = { + "id": qa_pair.get("id"), + "task_id": qa_pair.get("task_id"), + "document_id": qa_pair.get("document_id"), + "document_chunk_id": qa_pair.get("document_chunk_id"), + "file_name": qa_pair.get("file_name"), + "question": qa_pair.get("question"), + "question_vector": embeddings.data[index * 2].embedding, + "answer": qa_pair.get("answer"), + "answer_vector": embeddings.data[index * 2 + 1].embedding, + "create_datetime": create_datetime + } + dp_document_qa_remove_duplicate_db_operate.add( + params, + self.pool + ) + + return {"status": 200, "message": "", "data": ""} + except Exception as ex: + logger.error( + "".join( + [ + f"qa embedding fail\n", + f"The tracing error is: \n{traceback.format_exc()}\n", + ] + ) ) + return {"status": 1000, "message": "QA数据向量化失败,请检查向量化模型是否正常!", "data": ""} def _remove_qa_embedding_data_by_id( self, @@ -108,35 +126,47 @@ def _remove_duplicate_qa_data( qa_pairs (list): QA datasets distance (float): similarity threshold """ - qa_pairs_dict = {} - for qa in qa_pairs: - qa_pairs_dict[qa["id"]] = qa - for id, qa_pair in qa_pairs_dict.items(): - logger.debug(f"Querying similarity of QA item: {qa_pair}") - if qa_pair.get("duplicated_flag") is not None and qa_pair.get("duplicated_flag"): - logger.debug(f"QA Duplicate Skip") - continue - params = { - "task_id": qa_pair.get("task_id"), - "document_id": qa_pair.get("document_id"), - "id": qa_pair.get("id"), - } - res = 
dp_document_qa_remove_duplicate_db_operate.filter_by_distance( - params, - self.pool + try: + qa_pairs_dict = {} + for qa in qa_pairs: + qa_pairs_dict[qa["id"]] = qa + for id, qa_pair in qa_pairs_dict.items(): + logger.debug(f"Querying similarity of QA item: {qa_pair}") + if qa_pair.get("duplicated_flag") is not None and qa_pair.get("duplicated_flag"): + logger.debug(f"QA Duplicate Skip") + continue + params = { + "task_id": qa_pair.get("task_id"), + "document_id": qa_pair.get("document_id"), + "id": qa_pair.get("id"), + } + res = dp_document_qa_remove_duplicate_db_operate.filter_by_distance( + params, + self.pool + ) + self._remove_qa_embedding_data_by_id( + qa_pair["id"] + ) + logger.debug(f"Querying similarity of QA result: {res}") + for qa in res["data"]: + if qa["question_distance"] > distance and qa["answer_distance"] > distance: + qa["duplicated_flag"] = 0 + qa["compare_with_id"] = id + qa_pairs_dict[qa["id"]] = qa + self._remove_qa_embedding_data_by_id( + qa["id"] + ) + return {"status": 200, "message": "", "data": list(qa_pairs_dict.values())} + except Exception as ex: + logger.error( + "".join( + [ + f"qa remove duplicate fail\n", + f"The tracing error is: \n{traceback.format_exc()}\n", + ] + ) ) - self._remove_qa_embedding_data_by_id( - qa_pair["id"] - ) - logger.debug(f"Querying similarity of QA result: {res}") - for qa in res["data"]: - if qa["question_distance"] > distance and qa["answer_distance"] > distance: - qa["duplicated_flag"] = True - qa_pairs_dict[qa["id"]] = qa - self._remove_qa_embedding_data_by_id( - qa["id"] - ) - return list(qa_pairs_dict.values()) + return {"status": 1000, "message": "QA去重失败,未知原因,请联系管理员!", "data": ""} def qa_remove_duplicate( self, @@ -203,9 +233,13 @@ def qa_remove_duplicate( } ] """ - self._import_qa_embedding_data( + qa_embedding_res = self._import_qa_embedding_data( qa_pairs ) + + if qa_embedding_res.get("status") != 200: + return qa_embedding_res + return self._remove_duplicate_qa_data( qa_pairs, distance diff --git a/pypi/data-processing/src/utils/file_utils.py b/pypi/data-processing/src/utils/file_utils.py index fc38288ad..90e5c2fe3 100644 --- a/pypi/data-processing/src/utils/file_utils.py +++ b/pypi/data-processing/src/utils/file_utils.py @@ -54,3 +54,11 @@ def get_file_name_without_extension(file_name): file_name_without_extension = path.stem return file_name_without_extension + + +def get_file_name(file_path): + """Get file name""" + path = Path(file_path) + file_name = path.name + + return file_name diff --git a/pypi/data-processing/src/utils/json_utils.py b/pypi/data-processing/src/utils/json_utils.py index 962a0d410..590d3316c 100644 --- a/pypi/data-processing/src/utils/json_utils.py +++ b/pypi/data-processing/src/utils/json_utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import ujson @@ -46,4 +45,3 @@ def loads( return ujson.loads( data, ) -
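A minimal sketch of how the remove_duplicate_config introduced in this diff could reach the data-processing service, and how the refactored PDFHandle is driven, based on the GraphQL schema and the updated call sites above. All concrete values (embedder name, namespace, model, ids, similarity threshold) are illustrative placeholders, not values shipped by this change, and the shape passed as support_type is an assumption.

# Illustrative sketch only -- every concrete value below is a placeholder.
# A qa_split config item carrying the remove_duplicate_config fields defined in
# dataprocessing.graphqls; "similarity" stays a string and is parsed with float()
# inside common_handle._qa_remove_duplicate.
qa_split_item = {
    "type": "qa_split",
    "remove_duplicate_config": {
        "embedding_name": "example-embedder",   # Embedders CR name (placeholder)
        "embedding_namespace": "arcadia",       # CR namespace (placeholder)
        "embedding_model": "example-model",     # embedding model name (placeholder)
        "embedding_provider": "worker",         # "worker", or a 3rd-party CR whose type is "openai"
        "similarity": "0.9",                    # QA pairs whose question and answer similarity both
                                                # exceed this value are flagged as duplicates
    },
}

# The refactored PDF path builds a handler object and calls handle(), mirroring the
# updated call sites in minio_store_process.py; support_type is assumed here to be
# the parsed list of config items such as qa_split_item above.
from file_handle.pdf_handle import PDFHandle

pdf_handle = PDFHandle(
    file_name="example.pdf",       # placeholder
    document_id="doc-0001",        # placeholder
    support_type=[qa_split_item],
    conn_pool=None,                # pass a real PostgreSQL connection pool here
    task_id="task-0001",           # placeholder
    create_user="admin",           # placeholder
    chunk_size=500,                # falls back to config.knowledge_chunk_size when None
    chunk_overlap=10,              # falls back to config.knowledge_chunk_overlap when None
)
result = pdf_handle.handle()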