Commit 43e9bac
1 parent: bc904b9
Showing 5 changed files with 262 additions and 0 deletions.
@@ -0,0 +1,86 @@
import os

from langchain.chat_models import ChatOpenAI
from ragas.llms import LangchainLLM
from ragas.embeddings import RagasEmbeddings
from ragas.metrics.base import Metric
from ragas.metrics import (
    context_precision,
    context_recall,
    context_relevancy,
    answer_relevancy,
    answer_correctness,
    answer_similarity,
    faithfulness
)

DEFAULT_METRICS = [
    "answer_relevancy",
    "context_precision",
    "faithfulness",
    "context_recall",
    "context_relevancy"
]

def wrap_langchain_llm(
    model: str,
    api_base: str | None,
    api_key: str | None
) -> LangchainLLM:
    # Wrap a LangChain chat model so ragas can use it as the judge LLM.
    # Validate the key first: setting the env var with a None key would
    # raise a confusing TypeError before the ValueError could fire.
    if api_key is None:
        raise ValueError("api_key must be provided")
    if api_base is None:
        # ChatOpenAI already defaults to the official OpenAI endpoint.
        print('api_base not provided, assuming OpenAI default')
        os.environ["OPENAI_API_KEY"] = api_key
        base = ChatOpenAI(model_name=model)
    else:
        os.environ["OPENAI_API_KEY"] = api_key
        os.environ["OPENAI_API_BASE"] = api_base
        base = ChatOpenAI(
            model_name=model,
            openai_api_key=api_key,
            openai_api_base=api_base
        )
    return LangchainLLM(llm=base)

def set_metrics(
    metrics: list[str],
    llm: LangchainLLM | None,
    embeddings: RagasEmbeddings | None
) -> list[Metric]:
    # The ragas metrics are module-level singletons, so attaching the
    # judge LLM / embeddings here mutates them for the whole process.
    ms = []
    if llm:
        context_precision.llm = llm
        context_recall.llm = llm
        context_relevancy.llm = llm
        answer_correctness.llm = llm
        answer_similarity.llm = llm
        faithfulness.llm = llm
    if embeddings:
        answer_relevancy.embeddings = embeddings
        answer_correctness.embeddings = embeddings
    if not metrics:
        metrics = DEFAULT_METRICS
    # Unknown metric names are silently skipped.
    for m in metrics:
        if m == 'context_precision':
            ms.append(context_precision)
        elif m == 'context_recall':
            ms.append(context_recall)
        elif m == 'context_relevancy':
            ms.append(context_relevancy)
        elif m == 'answer_relevancy':
            ms.append(answer_relevancy)
        elif m == 'answer_correctness':
            ms.append(answer_correctness)
        elif m == 'answer_similarity':
            ms.append(answer_similarity)
        elif m == 'faithfulness':
            ms.append(faithfulness)
    return ms
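
For reference, a minimal sketch of how the two helpers above compose. This is not part of the commit: the model name, API key, and dataset slice are placeholders, and the fiqa call mirrors the fallback used later in eval_data.

from datasets import load_dataset
from ragas import evaluate

# point the judge at the default OpenAI endpoint, then score a small
# slice of the fiqa baseline with two of the supported metrics
judge = wrap_langchain_llm(
    model="gpt-3.5-turbo",
    api_base=None,       # falls back to the OpenAI default
    api_key="sk-..."     # placeholder, not a real key
)
ms = set_metrics(["faithfulness", "context_recall"], llm=judge, embeddings=None)

fiqa = load_dataset("explodinggradients/fiqa", "ragas_eval")
print(evaluate(fiqa["baseline"].select(range(3)), metrics=ms))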
@@ -0,0 +1,136 @@
import os
from datasets import Dataset
from datasets import load_dataset
from langchain.chat_models import ChatOpenAI
from ragas.llms import LangchainLLM
from ragas.embeddings import RagasEmbeddings
from ragas.metrics.base import Metric
from ragas import evaluate
from ragas.evaluation import Result

class JudgeLLM:
    # thin holder for the judge LLM used to grade answers
    def __init__(self, llm: LangchainLLM):
        self.llm = llm


class ExamPaper:
    # thin holder for the QA dataset under evaluation
    def __init__(self, qa: Dataset):
        self.qa = qa


class EvalTask:
    # pairs a judge with an exam paper and the metrics to score it on
    def __init__(self, judge: JudgeLLM, exampaper: ExamPaper, metrics: list):
        self.judge = judge
        self.exampaper = exampaper
        self.metrics = metrics


def _wrap_llm(
    model: str,
    api_base: str | None,
    api_key: str | None
) -> LangchainLLM:
    # Validate the key before touching os.environ, so a missing key
    # raises the intended ValueError rather than a TypeError.
    if api_key is None:
        raise ValueError("api_key must be provided")
    if api_base is None:
        # ChatOpenAI already defaults to the official OpenAI endpoint.
        print('api_base not provided, assuming default')
        os.environ["OPENAI_API_KEY"] = api_key
        base = ChatOpenAI(model_name=model)
    else:
        os.environ["OPENAI_API_KEY"] = api_key
        os.environ["OPENAI_API_BASE"] = api_base
        base = ChatOpenAI(
            model_name=model,
            openai_api_key=api_key,
            openai_api_base=api_base
        )
    return LangchainLLM(llm=base)

def _set_metric(
    metrics: list[str],
    llm: LangchainLLM | None,
    embeddings: RagasEmbeddings | None
) -> list[Metric]:
    # build the metric list, attaching the judge LLM / embeddings as we go;
    # the ragas metric objects are module-level singletons, so this mutates
    # them in place for the whole process
    print(f'requested metrics: {metrics}')
    ms = []
    for m in metrics:
        if m == 'context_precision':
            from ragas.metrics import context_precision
            if llm is not None:
                context_precision.llm = llm
            ms.append(context_precision)
        elif m == 'context_recall':
            from ragas.metrics import context_recall
            if llm is not None:
                context_recall.llm = llm
            ms.append(context_recall)
        elif m == 'context_relevancy':
            from ragas.metrics import context_relevancy
            if llm is not None:
                context_relevancy.llm = llm
            ms.append(context_relevancy)
        elif m == 'answer_relevancy':
            from ragas.metrics import answer_relevancy
            if embeddings is not None:
                answer_relevancy.embeddings = embeddings
            ms.append(answer_relevancy)
        elif m == 'answer_correctness':
            from ragas.metrics import answer_correctness
            if llm is not None:
                answer_correctness.llm = llm
            if embeddings is not None:
                answer_correctness.embeddings = embeddings
            ms.append(answer_correctness)
        elif m == 'answer_similarity':
            from ragas.metrics import answer_similarity
            if llm is not None:
                answer_similarity.llm = llm
            ms.append(answer_similarity)
        elif m == 'faithfulness':
            from ragas.metrics import faithfulness
            if llm is not None:
                faithfulness.llm = llm
            ms.append(faithfulness)
    # unknown metric names are silently skipped
    return ms

def eval_data(
    test_set: Dataset | None,
    metrics: list[str] | None = None,
    model: str = 'gpt-3.5-turbo',
    embeddings: RagasEmbeddings | None = None,
    api_base: str | None = None,
    api_key: str | None = None
) -> Result:
    llm = _wrap_llm(model, api_base, api_key)

    default_ms = [
        "answer_relevancy",
        "context_precision",
        "faithfulness",
        "context_recall",
        "context_relevancy"
    ]

    if metrics:
        ms = _set_metric(metrics, llm, embeddings)
    else:
        print('metrics not provided, assuming default')
        ms = _set_metric(metrics=default_ms, llm=llm, embeddings=embeddings)
    print(f'using {len(ms)} metrics')

    if test_set is None:
        # fall back to a small slice of the fiqa baseline as a smoke test
        print('test_set not provided, assuming default')
        fiqa_eval = load_dataset("explodinggradients/fiqa", "ragas_eval")
        return evaluate(
            fiqa_eval["baseline"].select(range(5)),
            metrics=ms
        )
    return evaluate(
        test_set,
        metrics=ms
    )
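
To pass your own test set instead of the fiqa fallback, the Dataset needs the column layout that ragas' evaluate() expects. A minimal sketch, assuming the same columns as the fiqa ragas_eval split (question, answer, contexts, ground_truths); the strings and key are made up:

from datasets import Dataset

toy_set = Dataset.from_dict({
    "question": ["What is the capital of France?"],
    "answer": ["Paris is the capital of France."],
    "contexts": [["Paris is the capital and largest city of France."]],
    "ground_truths": [["Paris"]],
})

result = eval_data(
    test_set=toy_set,
    metrics=["faithfulness"],
    api_key="sk-..."  # placeholder, not a real key
)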
@@ -0,0 +1,16 @@
import argparse

# main logic:
# 1. get params from the cli
#    (including judgellm, metrics, QA dataset path)
# 2. validate params
# 3. run the evaluation


def cli():
    parser = argparse.ArgumentParser(description='RAGAS CLI')
    parser.add_argument("--judgellm", type=str, default="gpt-3.5-turbo")
    parser.add_argument("--metrics", type=str, default="default")
    parser.add_argument("--dataset", type=str, default="evaluation/qa_data/qa_data.csv")
    return parser.parse_args()


def eval_data(model: str, api_base: str | None, api_key: str | None):
    # stub: only wraps the judge model so far; wrap_langchain_llm is
    # defined in the evaluation module above and is not yet imported here
    llm = wrap_langchain_llm(model, api_base, api_key)
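
A sketch of how this stub might be wired together, covering steps 1-3 of the plan in the comments above. This is hypothetical: it assumes --metrics "default" means "use the package defaults" and leaves dataset loading as a placeholder.

if __name__ == "__main__":
    args = cli()
    # "default" maps to the package's default metric list
    names = None if args.metrics == "default" else args.metrics.split(",")
    # dataset loading / validation (step 2 of the plan) would go here,
    # e.g. turning args.dataset into a datasets.Dataset before evaluating
    print(f"judge={args.judgellm} metrics={names} dataset={args.dataset}")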
File renamed without changes.
@@ -0,0 +1,24 @@
from ragas.embeddings import HuggingfaceEmbeddings
from ragas_pkg import eval_data

judge_embedding = HuggingfaceEmbeddings(model_name="BAAI/bge-small-en")

# from ragas.metrics.critique import SUPPORTED_ASPECTS
# crit = SUPPORTED_ASPECTS

# base = ChatOpenAI(
#     model="6b33213a-2692-4c70-b203-5ec6e0542cda",
#     openai_api_base="http://fastchat.172.40.20.125.nip.io/v1"
# )

# gpt4 = ChatOpenAI(model_name="gpt-4")

# judge_llm = LangchainLLM(llm=gpt4)

result = eval_data(
    test_set=None,          # falls back to the fiqa baseline slice
    embeddings=judge_embedding,
    api_key=""              # fill in a real API key before running
)

df = result.to_pandas()
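
From here the per-sample scores can be inspected like any pandas DataFrame, for example:

print(df.head())  # one row per sample, one column per metric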