Commit 5ec4ea3

Add test script + linting
Add example script (`create_disk_cache.py`) that creates a `LlamaStaticDiskCache` from example snippets, then does some basic sanity testing to make sure that the cache matches the prompt tokens as expected. Generating the example prompt is slightly complex, because the prompt has to be generated with each context included, but not the user question.
1 parent d056228 commit 5ec4ea3

File tree

2 files changed: +228 -1 lines changed
create_disk_cache.py

Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
"""
Creates a static disk cache given a dataset of snippets to use as contexts for a RAG application.

Background: We want to embed the fixed system prompt + first context and then store the state on disk.
That way, when a question is asked and the first context is provided, the KV cache can be looked up
to find the entry that matches the prompt tokens (including the first context).

This should save on prompt ingestion time, decreasing time to first token.
"""

import argparse
import os
import pathlib

import pandas as pd
import tqdm

from llama_cpp.llama import Llama
from llama_cpp.llama_cache import LlamaStaticDiskCache
from llama_cpp.llama_chat_format import format_nekomata

# Add additional formatters here as desired so that models can be swapped out.
CHAT_FORMATTER_MAP = {
    "rinna/nekomata-7b-instruction-gguf": format_nekomata,
}


def combine_question_ctx_nekomata(question, contexts):
    """
    Formats question and contexts for nekomata-7b.
    """
    output = ""
    for i, context in enumerate(contexts):
        output += f"- 資料{i+1}: '{context}'\n"

    output += "\n"

    output += question

    return output


# How to combine contexts + user question when creating a *full* prompt
CONTEXT_QUESTION_FORMATTER_MAP = {
    "rinna/nekomata-7b-instruction-gguf": combine_question_ctx_nekomata,
}

DEFAULT_SYSTEM_PROMPT = """
You are a virtual assistant. You will be provided with contexts and a user
question. Your job is to answer a user's question faithfully and concisely.

If the context provided does not contain enough information to answer the question,
respond with "I don't know" - do not make up information. If you are helpful and provide
accurate information, you will be provided with a $10,000 bonus. If you provide inaccurate
information, unhelpful responses, or information not grounded in the context
provided, you will be penalized $10,000 and fired - into the Sun.
""".strip()


def _create_nekomata_prompt_prefix(
    context: str, system_prompt=DEFAULT_SYSTEM_PROMPT
) -> str:
    """
    Override this if using a different model.

    This provides a partially formatted prompt for the Nekomata model.
    It passes in the system prompt and the first context to the model,
    but not the question or the assistant prompt.
    """

    return """
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。

### 指示:
{system_prompt}

### 入力:
{input}""".format(
        system_prompt=system_prompt, input=f"- 資料1: '{context.strip()}'\n"
    ).lstrip(
        "\n"
    )


PARTIAL_PROMPT_MODEL_MAP = {
    "rinna/nekomata-7b-instruction-gguf": _create_nekomata_prompt_prefix,
}


def main(args: argparse.Namespace):
    dataset_path: pathlib.Path = args.dataset
    assert dataset_path.exists(), f"Dataset path {dataset_path} does not exist"

    dataset = pd.read_csv(str(dataset_path))

    snippets = dataset.loc[:, args.column_name].tolist()

    model = Llama.from_pretrained(
        args.model,
        filename=args.model_filename,
        n_ctx=args.n_ctx,
        n_gpu_layers=-1,
        n_batch=1,
        n_threads=args.n_threads,
        n_threads_batch=args.n_threads,
        verbose=False,
    )

    prompt_formatter_func = PARTIAL_PROMPT_MODEL_MAP[args.model]

    # Have to format each snippet so that it includes the system prompt and the context.
    snippets = [prompt_formatter_func(context) for context in snippets]

    cache = LlamaStaticDiskCache.build_cache(args.output, tqdm.tqdm(snippets), model)
    snippet_tokens = model.tokenize(
        snippets[0].encode("utf-8"), add_bos=True, special=True
    )
    assert snippet_tokens in cache, "First snippet not in cache"

    # pylint: disable=protected-access
    cache_prefix_tokens = cache._find_longest_prefix_key(snippet_tokens)

    assert cache_prefix_tokens == tuple(
        snippet_tokens
    ), "Expected all snippet tokens to be in cache"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-d",
        "--dataset",
        type=pathlib.Path,
        required=True,
        help="Path to serialized dataframe with snippets to use",
    )

    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="rinna/nekomata-7b-instruction-gguf",
        help="Hugging Face model name",
    )

    parser.add_argument(
        "--model-filename",
        type=str,
        default="*Q4_K_M.gguf",
        help="Name of model weights file to load from repo - may contain wildcards (like '*Q4_K_M.gguf')",
    )

    parser.add_argument(
        "--n-ctx",
        type=int,
        required=True,
        help="Context size (in tokens) - must be fixed for KV cache",
    )

    parser.add_argument(
        "--n-threads",
        type=int,
        default=os.cpu_count(),
        help="Number of threads to use for inference + batch processing",
    )

    parser.add_argument(
        "--column-name",
        type=str,
        default="snippets",
        help="Column name identifying snippets to use as contexts",
    )

    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="static_cache",
        help="Output directory for static cache",
    )

    args = parser.parse_args()

    chat_formatter = CHAT_FORMATTER_MAP[args.model]
    question_ctx_combiner = CONTEXT_QUESTION_FORMATTER_MAP[args.model]

    DUMMY_CONTEXTS = [
        "The air speed of an unladen swallow is 24 miles per hour.",
        "Red pandas are not actually pandas, but are more closely related to raccoons.",
        "Directly observing a quantum system can change its state.",
        "The mitochondria is the powerhouse of the cell.",
        "The least common multiple of 6 and 8 is 24.",
    ]

    # Just a quick-and-dirty test so we can verify that a full prompt will contain
    # the partial prompt (and so prefix matching should work)
    def _generate_full_prompt(user_question: str):
        user_msg = question_ctx_combiner(user_question, DUMMY_CONTEXTS)
        msgs = [
            {
                "role": "system",
                "content": DEFAULT_SYSTEM_PROMPT,
            },
            {
                "role": "user",
                "content": user_msg,
            },
        ]

        full_prompt = chat_formatter(msgs).prompt

        return full_prompt

    question = "What is the velocity of an unladen swallow?"

    complete_prompt = _generate_full_prompt(question)
    partial_context = _create_nekomata_prompt_prefix(DUMMY_CONTEXTS[0])

    if partial_context not in complete_prompt:
        print("Partial context:\n")
        print(partial_context + "\n")
        print("not found in complete prompt:\n")
        print(complete_prompt)
        assert False, "Sanity check failed"

    main(args)
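
For orientation, here is a minimal sketch (not part of this commit) of how a RAG application might consume the cache this script builds at question time. It assumes the cache object returned by LlamaStaticDiskCache.build_cache (or an equivalently reloaded one) is attached with the standard Llama.set_cache hook, so that create_completion can restore the saved KV state for the matching prompt prefix; answer_question and max_tokens=256 are illustrative choices, and the formatting helpers are the ones defined in the script above.

def answer_question(model, cache, question, contexts):
    """Illustrative only: answer a question reusing the prefix cache built above."""
    # Attach the disk cache so generation can restore the stored KV state for the
    # longest matching prompt prefix instead of re-ingesting it token by token.
    model.set_cache(cache)

    # The full prompt starts with the cached prefix (system prompt + first context),
    # followed by the remaining contexts and the user question.
    user_msg = combine_question_ctx_nekomata(question, contexts)
    full_prompt = format_nekomata(
        [
            {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
            {"role": "user", "content": user_msg},
        ]
    ).prompt

    completion = model.create_completion(full_prompt, max_tokens=256)
    return completion["choices"][0]["text"]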

llama_cpp/llama_cache.py

Lines changed: 1 addition & 1 deletion
@@ -234,7 +234,7 @@ def build_cache(
             print("LlamaStaticDiskCache.build_cache: eval", file=sys.stderr)
             model.eval(toks)
             state = model.save_state()
-            cache._private_setitem(toks, state)
+            cache._private_setitem(toks, state)  # pylint: disable=protected-access

         # Set up Trie for efficient prefix search
         for key in cache.cache.iterkeys():
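
The surrounding context in this hunk is where build_cache evaluates each prompt, saves the model state, and registers the token keys in a Trie for prefix search. Purely as an illustration of what that prefix search provides (this is not the class's actual Trie-backed implementation), a brute-force equivalent of the longest-prefix lookup the test script exercises via _find_longest_prefix_key could look like:

def find_longest_prefix_key(cached_keys, prompt_tokens):
    """Illustrative brute-force stand-in for the Trie-backed prefix lookup."""
    prompt = tuple(prompt_tokens)
    best = None
    for key in cached_keys:  # each key is a tuple of token ids
        if len(key) <= len(prompt) and prompt[: len(key)] == key:
            if best is None or len(key) > len(best):
                best = key  # keep the longest cached key that prefixes the prompt
    return best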

0 commit comments
