Use GitHub U data for fork #1
Merged · 3 commits · Oct 24, 2024
4 changes: 2 additions & 2 deletions docs/evaluation.md
@@ -42,12 +42,12 @@ pip install -r requirements-dev.txt

## Generate ground truth data

-Modify the prompt in `evals/generate.txt` to match your database table and RAG scenario.
+Modify the prompt in `evals/generate_prompt.txt` to match your database table and RAG scenario.

Generate ground truth data by running the following command:

```bash
-python evals/generate_ground_truth_data.py
+python evals/generate_ground_truth.py --numquestions=50 --persource=50
```

Review the generated data after running that script, removing any question/answer pairs that don't seem like realistic user input.
2 changes: 1 addition & 1 deletion evals/eval_config.json
@@ -1,7 +1,7 @@
{
    "testdata_path": "ground_truth.jsonl",
    "results_dir": "results/experiment<TIMESTAMP>",
-   "requested_metrics": ["gpt_groundedness", "gpt_relevance", "answer_length", "latency", "citations_matched"],
+   "requested_metrics": ["gpt_groundedness", "gpt_relevance", "f1_score", "answer_length", "latency", "citations_matched"],
    "target_url": "http://127.0.0.1:8000/chat",
    "target_parameters": {
        "overrides": {
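Reviewer note: after editing `eval_config.json`, a quick sanity check that the newly requested metrics are actually listed can save a failed run. A minimal sketch using only the standard library; the check itself is mine, not part of the repo, and it assumes the script is run from the repository root:

```python
# Sketch: confirm the metrics requested in evals/eval_config.json include the new entries.
import json
from pathlib import Path

config = json.loads(Path("evals/eval_config.json").read_text())
metrics = config["requested_metrics"]
print("Requested metrics:", metrics)

for expected in ("f1_score", "citations_matched"):
    # Both should appear in requested_metrics after this change.
    assert expected in metrics, f"{expected} missing from requested_metrics"
```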
14 changes: 10 additions & 4 deletions evals/evaluate.py
@@ -24,9 +24,13 @@ def citations_overlap(*, response, ground_truth, **kwargs):
        if response is None:
            logger.warning("Received response of None, can't compute citation_match metric. Setting to -1.")
            return {cls.METRIC_NAME: -1}
-        truth_citations = set(re.findall(r"\[(\d+)\]", ground_truth))
-        response_citations = set(re.findall(r"\[(\d+)\]", response))
-        # Count the percentage of citations that are present in the response
+        citation_pattern = r"\[(\d+)\]"
+        truth_citations = set(re.findall(citation_pattern, ground_truth))
+        response_citations = set(re.findall(citation_pattern, response))
+        # Return the percentage of citations that are present in the response
+        if len(truth_citations) == 0:
+            logger.warning("No citations found in ground truth, setting metric to 1.0.")
+            return {cls.METRIC_NAME: 1.0}
        num_citations = len(truth_citations)
        num_matched_citations = len(truth_citations.intersection(response_citations))
        return {cls.METRIC_NAME: num_matched_citations / num_citations}
@@ -74,8 +78,10 @@ def get_openai_config() -> dict:

if __name__ == "__main__":
    logging.basicConfig(
-        level=logging.INFO, format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
+        level=logging.WARNING, format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
    )
+    logging.getLogger("evaltools").setLevel(logging.INFO)
+    logger.setLevel(logging.INFO)
    load_dotenv(".env", override=True)

    parser = argparse.ArgumentParser(description="Run evaluation with OpenAI configuration.")
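Reviewer note: to sanity-check the updated citations behaviour outside of evaltools, here is a minimal standalone sketch of the same computation. The helper name and demo strings are mine, not part of the repo; it assumes citations appear as bracketed integer IDs such as `[12]`:

```python
import re

CITATION_PATTERN = r"\[(\d+)\]"  # citations are bracketed integer IDs, e.g. [12]


def citations_overlap_ratio(response: str, ground_truth: str) -> float:
    """Fraction of ground-truth citations that also appear in the response."""
    truth_citations = set(re.findall(CITATION_PATTERN, ground_truth))
    response_citations = set(re.findall(CITATION_PATTERN, response))
    if not truth_citations:
        # Mirrors the new guard above: no expected citations counts as a full match.
        return 1.0
    return len(truth_citations & response_citations) / len(truth_citations)


if __name__ == "__main__":
    truth = "There are two Python sessions. [12][5]"
    answer = "Yes, see the workflow automation session. [12]"
    print(citations_overlap_ratio(answer, truth))  # 0.5
```

With no citations in the ground truth, the metric now reports 1.0 instead of dividing by zero, matching the new guard added in this hunk.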
63 changes: 37 additions & 26 deletions evals/generate_ground_truth.py
@@ -1,3 +1,4 @@
+import argparse
import json
import logging
import os
@@ -8,14 +9,18 @@
from dotenv_azd import load_azd_env
from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletionToolParam
-from sqlalchemy import create_engine, select
+from sqlalchemy import create_engine, select, func
from sqlalchemy.orm import Session
+from dotenv import load_dotenv
+from jinja2 import Environment, FileSystemLoader
+from rich.logging import RichHandler

from fastapi_app.postgres_models import Item

logger = logging.getLogger("ragapp")



def qa_pairs_tool(num_questions: int = 1) -> ChatCompletionToolParam:
    return {
        "type": "function",
@@ -47,26 +52,19 @@ def qa_pairs_tool(num_questions: int = 1) -> ChatCompletionToolParam:


def source_retriever() -> Generator[str, None, None]:
-    # Connect to the database
+    # Connect to the local database
    DBHOST = os.environ["POSTGRES_HOST"]
    DBUSER = os.environ["POSTGRES_USERNAME"]
    DBPASS = os.environ["POSTGRES_PASSWORD"]
    DBNAME = os.environ["POSTGRES_DATABASE"]
    DATABASE_URI = f"postgresql://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}"
    engine = create_engine(DATABASE_URI, echo=False)
    with Session(engine) as session:
-        # Fetch all products for a particular type
-        item_types = session.scalars(select(Item.type).distinct())
-        for item_type in item_types:
-            records = list(session.scalars(select(Item).filter(Item.type == item_type).order_by(Item.id)))
-            logger.info(f"Processing database records for type: {item_type}")
-            yield "\n\n".join([f"## Product ID: [{record.id}]\n" + record.to_str_for_rag() for record in records])
-        # Fetch each item individually
-        # records = list(session.scalars(select(Item).order_by(Item.id)))
-        # for record in records:
-        #     logger.info(f"Processing database record: {record.name}")
-        #     yield f"## Product ID: [{record.id}]\n" + record.to_str_for_rag()
-        # await self.openai_chat_client.chat.completions.create(
+        while True:
+            # Fetch all the rows from the database
+            random_rows = list(session.scalars(select(Item).order_by(func.random())))
+            logger.info("Fetched %d random rows", len(random_rows))
+            yield "\n\n".join([f"## Row ID: [{row.id}]\n" + row.to_str_for_rag() for row in random_rows])


def source_to_text(source) -> str:
@@ -108,31 +106,36 @@ def get_openai_client() -> tuple[AzureOpenAI | OpenAI, str]:
    return openai_client, model


-def generate_ground_truth_data(num_questions_total: int, num_questions_per_source: int = 5):
+def generate_ground_truth_data(num_questions_total: int, num_questions_per_source):
    logger.info("Generating %d questions total", num_questions_total)
    openai_client, model = get_openai_client()
    current_dir = Path(__file__).parent
-    generate_prompt = open(current_dir / "generate_prompt.txt").read()

+    # Load the template from the file system
+    jinja_file_loader = FileSystemLoader(current_dir)
+    jinja_env = Environment(loader=jinja_file_loader)
+    prompt_template = jinja_env.get_template('generate_prompt.jinja2')

    output_file = Path(__file__).parent / "ground_truth.jsonl"

    qa: list[dict] = []
-    for source in source_retriever():
-        if len(qa) > num_questions_total:
-            logger.info("Generated enough questions already, stopping")
-            break
+    while len(qa) < num_questions_total:
+        sources = next(source_retriever())
+        previous_questions = [qa_pair["question"] for qa_pair in qa]
        result = openai_client.chat.completions.create(
            model=model,
            messages=[
-                {"role": "system", "content": generate_prompt},
-                {"role": "user", "content": json.dumps(source)},
+                {"role": "system", "content": prompt_template.render(num_questions=num_questions_per_source, previous_questions=previous_questions)},
+                {"role": "user", "content": json.dumps(sources)},
            ],
-            tools=[qa_pairs_tool(num_questions=2)],
+            tools=[qa_pairs_tool(num_questions=num_questions_per_source)],
        )
        if not result.choices[0].message.tool_calls:
            logger.warning("No tool calls found in response, skipping")
            continue
        qa_pairs = json.loads(result.choices[0].message.tool_calls[0].function.arguments)["qa_list"]
        qa_pairs = [{"question": qa_pair["question"], "truth": qa_pair["answer"]} for qa_pair in qa_pairs]
+        logger.info("Received %d suggested questions", len(qa_pairs))
        qa.extend(qa_pairs)

    logger.info("Writing %d questions to %s", num_questions_total, output_file)
@@ -145,8 +148,16 @@ def generate_ground_truth_data(num_questions_total: int, num_questions_per_sourc


if __name__ == "__main__":
-    logging.basicConfig(level=logging.WARNING)
+    logging.basicConfig(
+        level=logging.WARNING, format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
+    )
+    logger.setLevel(logging.INFO)
    load_azd_env()
+    load_dotenv(".env", override=True)

+    parser = argparse.ArgumentParser(description="Generate ground truth data with OpenAI configuration.")
+    parser.add_argument("--numquestions", type=int, help="Specify the number of questions.", default=50)
+    parser.add_argument("--persource", type=int, help="Specify the number of questions per retrieved sources.", default=5)

+    args = parser.parse_args()

-    generate_ground_truth_data(num_questions_total=10)
+    generate_ground_truth_data(num_questions_total=args.numquestions, num_questions_per_source=args.persource)
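Reviewer note: the reworked `source_retriever()` now yields the whole table in a fresh random order on each iteration, and can be pulled from repeatedly thanks to the `while True` loop. A self-contained sketch of that query pattern, using an in-memory SQLite database and a stand-in `Item` model instead of the project's PostgreSQL models (assuming SQLAlchemy 2.x):

```python
# Minimal sketch of the random-row sampling used by the new source_retriever().
# The model below is a stand-in; the real code uses fastapi_app.postgres_models.Item
# against PostgreSQL.
from sqlalchemy import Integer, String, create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Item(Base):  # stand-in for the project's Item model
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String)

    def to_str_for_rag(self) -> str:  # simplified stand-in
        return f"Name: {self.name}"


engine = create_engine("sqlite://", echo=False)
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Item(name=f"Session {i}") for i in range(5)])
    session.commit()
    # ORDER BY RANDOM() shuffles the rows differently on every execution,
    # so each generated batch sees the sources in a new order.
    random_rows = list(session.scalars(select(Item).order_by(func.random())))
    print("\n\n".join(f"## Row ID: [{row.id}]\n" + row.to_str_for_rag() for row in random_rows))
```

Because every yield contains all rows, each generation round works from the full table; the randomized order just varies how the sources are presented to the model between rounds.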
22 changes: 22 additions & 0 deletions evals/generate_prompt.jinja2
@@ -0,0 +1,22 @@
Your job is to generate {{ num_questions }} example questions that a customer might ask about sessions at the GitHub Universe conference.
The conference has *not* yet happened.

You should come up with the {{ num_questions }} questions and answers based on the provided data.
Each answer should include the row ID in square brackets.
For example,
'Are there any sessions featuring Python?'
with answer:
'Yes, there is a session on Python at 10:00 AM on the first day, about how to use Python to automate your workflow. [12]
There is an additional session at 2:00 PM on the second day about how to use Python to build a web application. [5]
Finally, there is a session at 4:00 PM on the second day about how to use Python to analyze data. [3]'
'
Your answer should typically be a paragraph or two.

Your questions should NOT be about specific session titles, but instead be more general questions
that a conference attendee might ask when planning their schedule.
Your answers should reference specific session titles, however, to help the user pick sessions.

{% if previous_questions %}
You should NOT suggest any of these questions that have already been asked:
{{ previous_questions }}
{% endif %}
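Reviewer note: a quick way to see how `{{ num_questions }}` and the `{% if previous_questions %}` block behave is to render the template the same way `generate_ground_truth.py` does. A small sketch follows; the `"evals"` loader path and the example question are assumptions for illustration:

```python
# Minimal sketch: render generate_prompt.jinja2 as generate_ground_truth.py does.
# Assumes it is run from the repo root so that "evals/" contains the template.
from jinja2 import Environment, FileSystemLoader

jinja_env = Environment(loader=FileSystemLoader("evals"))
prompt_template = jinja_env.get_template("generate_prompt.jinja2")

# First batch: no previous questions, so the trailing {% if %} block is omitted.
print(prompt_template.render(num_questions=5, previous_questions=[]))

# Later batches: previously generated questions are passed in so the model
# is told not to repeat them.
print(prompt_template.render(
    num_questions=5,
    previous_questions=["Are there any sessions featuring Python?"],
))
```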
9 changes: 0 additions & 9 deletions evals/generate_prompt.txt

This file was deleted.
