Commit caac671

Lanture1064 committed Jan 12, 2024
1 parent 4f04286 commit caac671
Showing 6 changed files with 262 additions and 53 deletions.
86 changes: 86 additions & 0 deletions evaluation/pkg/pkg.py
@@ -0,0 +1,86 @@
import os

from datasets import Dataset, load_dataset
from langchain.chat_models import ChatOpenAI
from ragas import evaluate
from ragas.evaluation import Result
from ragas.llms import LangchainLLM
from ragas.embeddings import RagasEmbeddings, OpenAIEmbeddings, HuggingfaceEmbeddings
from ragas.metrics.base import Metric
from ragas.metrics import (
    context_precision,
    context_recall,
    context_relevancy,
    answer_relevancy,
    answer_correctness,
    answer_similarity,
    faithfulness
)

DEFAULT_METRICS = [
    "answer_relevancy",
    "context_precision",
    "faithfulness",
    "context_recall",
    "context_relevancy"
]

def wrap_langchain_llm(
    model: str,
    api_base: str | None,
    api_key: str | None
) -> LangchainLLM:
    if api_key is None:
        raise ValueError("api_key must be provided")
    if api_base is None:
        # No endpoint given: fall back to the public OpenAI API.
        print('api_base not provided, assuming OpenAI default')
        api_base = 'https://api.openai.com/v1'
        os.environ["OPENAI_API_KEY"] = api_key
        base = ChatOpenAI(model_name=model)
    else:
        # Custom endpoint (e.g. a proxy or self-hosted gateway).
        os.environ["OPENAI_API_KEY"] = api_key
        os.environ["OPENAI_API_BASE"] = api_base
        base = ChatOpenAI(
            model_name=model,
            openai_api_key=api_key,
            openai_api_base=api_base
        )
    return LangchainLLM(llm=base)


def set_metrics(
    metrics: list[str],
    llm: LangchainLLM | None,
    embeddings: RagasEmbeddings | None
) -> list[Metric]:
    ms = []
    # Bind the judge LLM / embeddings to the metric singletons before selecting them.
    if llm:
        context_precision.llm = llm
        context_recall.llm = llm
        context_relevancy.llm = llm
        answer_correctness.llm = llm
        answer_similarity.llm = llm
        faithfulness.llm = llm
    if embeddings:
        answer_relevancy.embeddings = embeddings
        answer_correctness.embeddings = embeddings
    if not metrics:
        metrics = DEFAULT_METRICS
    for m in metrics:
        if m == 'context_precision':
            ms.append(context_precision)
        elif m == 'context_recall':
            ms.append(context_recall)
        elif m == 'context_relevancy':
            ms.append(context_relevancy)
        elif m == 'answer_relevancy':
            ms.append(answer_relevancy)
        elif m == 'answer_correctness':
            ms.append(answer_correctness)
        elif m == 'answer_similarity':
            ms.append(answer_similarity)
        elif m == 'faithfulness':
            ms.append(faithfulness)
    return ms
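
A minimal usage sketch for the two helpers above; the model name, endpoint handling, and the placeholder key are illustrative, not part of this commit:

judge = wrap_langchain_llm(
    model="gpt-3.5-turbo",
    api_base=None,          # None falls back to the public OpenAI endpoint
    api_key="sk-..."        # placeholder; supply a real key
)
ms = set_metrics(metrics=[], llm=judge, embeddings=None)   # empty list -> DEFAULT_METRICS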
136 changes: 136 additions & 0 deletions evaluation/pkg/ragas_pkg.py
@@ -0,0 +1,136 @@
import os

from datasets import Dataset, load_dataset
from langchain.chat_models import ChatOpenAI
from ragas import evaluate
from ragas.evaluation import Result
from ragas.llms import LangchainLLM
from ragas.embeddings import RagasEmbeddings
from ragas.metrics.base import Metric

class JudgeLLM:
    """Wrapper for the LLM that acts as the evaluation judge."""
    def __init__(self, llm: LangchainLLM):
        self.llm = llm


class ExamPaper:
    """Wrapper for the QA dataset to be evaluated."""
    def __init__(self, qa: Dataset):
        self.qa = qa


class EvalTask:
    """Bundles a judge, an exam paper, and the metrics to run."""
    def __init__(self, judge: JudgeLLM, exampaper: ExamPaper, metrics: list):
        self.judge = judge
        self.exampaper = exampaper
        self.metrics = metrics



def _wrap_llm(
    model: str,
    api_base: str | None,
    api_key: str | None
) -> LangchainLLM:
    if api_key is None:
        raise ValueError("api_key must be provided")
    if api_base is None:
        # No endpoint given: fall back to the public OpenAI API.
        print('api_base not provided, assuming default')
        api_base = 'https://api.openai.com/v1'
        os.environ["OPENAI_API_KEY"] = api_key
        base = ChatOpenAI(model_name=model)
    else:
        # Custom endpoint (e.g. a proxy or self-hosted gateway).
        os.environ["OPENAI_API_KEY"] = api_key
        os.environ["OPENAI_API_BASE"] = api_base
        base = ChatOpenAI(
            model_name=model,
            openai_api_key=api_key,
            openai_api_base=api_base
        )
    return LangchainLLM(llm=base)

def _set_metric(
    metrics: list[str],
    llm: LangchainLLM | None,
    embeddings: RagasEmbeddings | None
) -> list[Metric]:
    # Build the metric list, binding the judge LLM / embeddings where relevant.
    ms = []
    for m in metrics:
        if m == 'context_precision':
            from ragas.metrics import context_precision
            if llm is not None:
                context_precision.llm = llm
            ms.append(context_precision)
        elif m == 'context_recall':
            from ragas.metrics import context_recall
            if llm is not None:
                context_recall.llm = llm
            ms.append(context_recall)
        elif m == 'context_relevancy':
            from ragas.metrics import context_relevancy
            if llm is not None:
                context_relevancy.llm = llm
            ms.append(context_relevancy)
        elif m == 'answer_relevancy':
            from ragas.metrics import answer_relevancy
            if embeddings is not None:
                answer_relevancy.embeddings = embeddings
            ms.append(answer_relevancy)
        elif m == 'answer_correctness':
            from ragas.metrics import answer_correctness
            if llm is not None:
                answer_correctness.llm = llm
            if embeddings is not None:
                answer_correctness.embeddings = embeddings
            ms.append(answer_correctness)
        elif m == 'answer_similarity':
            from ragas.metrics import answer_similarity
            if llm is not None:
                answer_similarity.llm = llm
            ms.append(answer_similarity)
        elif m == 'faithfulness':
            from ragas.metrics import faithfulness
            if llm is not None:
                faithfulness.llm = llm
            ms.append(faithfulness)
    return ms
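
For reference, a direct call to this helper might look like the following; the judge wrapper comes from _wrap_llm above and the key is a placeholder:

judge = _wrap_llm(model='gpt-3.5-turbo', api_base=None, api_key='sk-...')  # placeholder key
selected = _set_metric(['faithfulness', 'context_recall'], llm=judge, embeddings=None)
# selected now holds the two metric objects with the judge LLM bound to them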


def eval_data(
    test_set: Dataset | None,
    metrics: list[str] | None = None,
    model: str = 'gpt-3.5-turbo',
    embeddings: RagasEmbeddings | None = None,
    api_base: str | None = None,
    api_key: str | None = None
) -> Result:
    llm = _wrap_llm(model, api_base, api_key)

    default_ms = [
        "answer_relevancy",
        "context_precision",
        "faithfulness",
        "context_recall",
        "context_relevancy"
    ]

    if metrics:
        ms = _set_metric(metrics, llm, embeddings)
    else:
        print('metrics not provided, assuming default')
        ms = _set_metric(metrics=default_ms, llm=llm, embeddings=embeddings)
    print(f'using {len(ms)} metrics')

    if test_set is None:
        # No dataset supplied: fall back to a 5-sample slice of the fiqa baseline.
        print('test_set not provided, assuming default')
        fiqa_eval = load_dataset("explodinggradients/fiqa", "ragas_eval")
        result = evaluate(
            fiqa_eval["baseline"].select(range(5)),
            metrics=ms
        )
        return result
    else:
        result = evaluate(
            test_set,
            metrics=ms
        )
        return result
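
A hedged usage sketch for eval_data with a small hand-built dataset; the column names mirror the fiqa ragas_eval split and every value below is illustrative:

from datasets import Dataset

# eval_data is the function defined above
qa = Dataset.from_dict({
    "question": ["What is RAG?"],
    "answer": ["Retrieval-augmented generation pairs a retriever with an LLM."],
    "contexts": [["RAG retrieves documents and passes them to a generator model."]],
    "ground_truths": [["RAG augments generation with retrieved context."]],
})

result = eval_data(
    test_set=qa,
    metrics=["faithfulness", "answer_relevancy"],
    api_key="sk-..."   # placeholder; supply a real key
)
print(result.to_pandas())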
53 changes: 0 additions & 53 deletions evaluation/ragas-fiqa.py

This file was deleted.

16 changes: 16 additions & 0 deletions evaluation/ragas_cli.py
@@ -0,0 +1,16 @@
import argparse

# The helpers live under evaluation/pkg/; this import path is an assumption and
# may need adjusting to the local package layout.
from pkg.pkg import wrap_langchain_llm

# main logic:
# 1. get params from the CLI
#    (including judge LLM, metrics, QA dataset path)
# 2. validate the params
# 3. run the evaluation

def cli():
    parser = argparse.ArgumentParser(description='RAGAS CLI')
    parser.add_argument("--judgellm", type=str, default="gpt-3.5-turbo")
    parser.add_argument("--metrics", type=str, default="default")
    parser.add_argument("--dataset", type=str, default="evaluation/qa_data/qa_data.csv")
    return parser.parse_args()

def eval_data(model: str, api_base: str | None = None, api_key: str | None = None):
    # Wrap the judge model; the rest of the evaluation flow is still a stub.
    llm = wrap_langchain_llm(model, api_base, api_key)
    return llm
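
The stub above does not yet connect the parsed arguments to an evaluation run; one possible wiring, where the dataset-loading step is an assumption and is not implemented in this commit:

if __name__ == "__main__":
    args = cli()
    metric_list = [] if args.metrics == "default" else args.metrics.split(",")
    # a real run would load args.dataset into a datasets.Dataset and pass it on;
    # for now just echo the parsed parameters
    print(args.judgellm, metric_list, args.dataset)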
File renamed without changes.
24 changes: 24 additions & 0 deletions evaluation/ragas_test_fiqa.py
@@ -0,0 +1,24 @@
import os

from ragas.embeddings import HuggingfaceEmbeddings
# The package lives under evaluation/pkg/; adjust the import path to the local layout.
from pkg.ragas_pkg import eval_data

judge_embedding = HuggingfaceEmbeddings(model_name="BAAI/bge-small-en")

# from ragas.metrics.critique import SUPPORTED_ASPECTS
# crit = SUPPORTED_ASPECTS

# base = ChatOpenAI(
#     model="6b33213a-2692-4c70-b203-5ec6e0542cda",
#     openai_api_base="http://fastchat.172.40.20.125.nip.io/v1"
# )

# gpt4 = ChatOpenAI(model_name="gpt-4")

# judge_llm = LangchainLLM(llm=gpt4)

result = eval_data(
    test_set=None,
    embeddings=judge_embedding,
    api_key=os.environ["OPENAI_API_KEY"]  # read the judge key from the environment
)

df = result.to_pandas()
print(df)
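
If the per-sample scores need to be kept, the dataframe could be written out; the filename here is illustrative:

df.to_csv("fiqa_eval_result.csv", index=False)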
