Skip to content

Commit 5287d63

Browse files
committed
fixes
2 parents 07e62b5 + 4341759 commit 5287d63

File tree

2 files changed

+38
-24
lines changed

2 files changed

+38
-24
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
*.local
2-
2+
test.ipynb
33
.python-version
44

55
.vscode/

llama_cpp/llama_chat_format.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2464,6 +2464,11 @@ def base_function_calling(
24642464
f"""root ::= functions | "</done>"\n"""
24652465
f"""functions ::= {function_names}\n"""
24662466
)
2467+
msg_gbnf_grammar = (
2468+
"""root ::= message | functions\n"""
2469+
f"""message ::= "message: " [^{end_token}]* "{end_token}"\n"""
2470+
f"""functions ::= {function_names}\n"""
2471+
)
24672472

24682473

24692474
prompt = template_renderer.render(
@@ -2499,10 +2504,9 @@ def base_function_calling(
24992504
completion: llama_types.CreateCompletionResponse = completion_or_chunks # type: ignore
25002505
text = completion["choices"][0]["text"]
25012506
print(text)
2502-
if "message" in text:
2503-
return _convert_completion_to_chat(
2504-
llama.create_completion(
2505-
prompt=prompt + "message:\n",
2507+
if "message:" in text:
2508+
message_output = llama.create_completion(
2509+
prompt=prompt,
25062510
temperature=temperature,
25072511
top_p=top_p,
25082512
top_k=top_k,
@@ -2521,11 +2525,15 @@ def base_function_calling(
25212525
mirostat_eta=mirostat_eta,
25222526
model=model,
25232527
logits_processor=logits_processor,
2524-
# grammar=llama_grammar.LlamaGrammar.from_string(
2525-
# follow_up_gbnf_tool_grammar, verbose=llama.verbose
2526-
# ),
2527-
),stream=stream
2528-
)
2528+
grammar=llama_grammar.LlamaGrammar.from_string(
2529+
msg_gbnf_grammar, verbose=llama.verbose
2530+
),
2531+
)
2532+
text: llama_types.CreateCompletionResponse = message_output["choices"][0]["text"] # type: ignore
2533+
# fallback
2534+
if not text.startswith("functions."):
2535+
message_output["choices"][0]["text"] = message_output["choices"][0]["text"].replace("message:", "")
2536+
return _convert_completion_to_chat(message_output,stream=stream)
25292537

25302538
# One or more function calls
25312539
tool_name = text[len("functions.") :].replace(":", "")
@@ -2802,15 +2810,17 @@ def vicuna_function_calling(
28022810
]:
28032811
function_calling_template = (
28042812
"{% for message in messages %}"
2813+
"{% if message.role != 'tool' %}"
28052814
"{{ message.role.upper() }}\n" # Vicuna uses upper case for roles
2815+
"{% endif %}"
28062816
# System message
28072817
"{% if message.role == 'system' %}"
28082818
"{{ message.content }}"
28092819
"{% if tool_calls %}"
28102820
"\n\nYou have access to the following functions:\n"
28112821
"{% for tool in tools %}"
28122822
"\nfunctions.{{ tool.function.name }}:\n"
2813-
"{{ tool.function.parameters | tojson }}"
2823+
"{{ tool.function.parameters }}"
28142824
"\n{% endfor %}"
28152825
"\n\nYou can respond to users messages with either a single message or multiple function calls, never both. If function calls are used, they must be the first part of the response."
28162826
"\n\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
@@ -2823,6 +2833,7 @@ def vicuna_function_calling(
28232833
"\n\nmessage:"
28242834
"\n<message> </s>"
28252835
"{% endif %}"
2836+
"\nAfter performing a function call, the function will send a response containing the return values of the function calls between <tool_output> tags. Present it to the user.\n"
28262837
"</s>\n"
28272838
"{% endif %}"
28282839
# User message
@@ -2844,7 +2855,7 @@ def vicuna_function_calling(
28442855
"{% if tool_calls %}"
28452856
"{% for tool_call in message.tool_calls %}"
28462857
"functions.{{ tool_call.function.name }}:\n"
2847-
"{{ (tool_call.arguments | default('{}') | tojson) }}"
2858+
"{{ (tool_call.function.parameters | default('{}') | tojson) }}"
28482859
"{% if not loop.last %};{% endif %}" # Semicolon separator if not the last function call
28492860
"{% endfor %}"
28502861
"</s>\n"
@@ -2853,7 +2864,7 @@ def vicuna_function_calling(
28532864
# Tool message (treated as Assistant response)
28542865
"{% if message.role == 'tool' %}"
28552866
"ASSISTANT:\n"
2856-
"Function response: {{ message.content | default('No response available') }}"
2867+
"<tool_output>: {{ message.content | default('No response available') }} </tool_output>"
28572868
"</s>\n"
28582869
"{% endif %}"
28592870
"{% endfor %}"
@@ -2910,29 +2921,32 @@ def llama3_function_calling(
29102921
"\nfunctions.{{ tool.function.name }}:\n"
29112922
"{{ tool.function.parameters | tojson }}"
29122923
"\n{% endfor %}"
2913-
"\nYou can respond to users messages with either a single message or one or more function calls. Never both. Prioritize function calls over messages."
2914-
"\nTo respond with a message begin the message with 'message:'"
2915-
'\n Example sending message: message: "Hello, how can I help you?"'
2916-
"\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
2924+
"\nYou can respond to user messages either by sending a single message or by making one or more function calls. You should never do both. Always prioritize function calls over messages."
2925+
"\nTo send a response message, start your message with 'message:'"
2926+
'\nExample of sending a message: message: "Hello, how can I help you?"'
2927+
"\nTo use one or more function calls, start your response with 'functions.<function_name>:', follow this format:"
29172928
"\nfunctions.<function_name>:"
29182929
'\n{ "arg1": "value1", "arg2": "value2" }'
29192930
"\nfunctions.<function_name>:"
29202931
'\n{ "arg1": "value1", "arg2": "value2" }'
2921-
"\nWhen you are done with the function calls, end the message with </done>."
2922-
'\nStart your output with either message: or functions. <|eot_id|>\n'
2932+
"\nWhen you have completed entering function calls, end your output with '</done>'."
2933+
'\nStart your output with either "message:" or "functions.". Do not mix the two.'
2934+
"\nAfter performing a function call, the function will send a response containing the return values of the function calls between <tool_output> tags. Present it to the user.\n"
2935+
#"Example: <tool_output> item: Cheeseburguer, price: 12 </tool_output> You should output: I found a Cheeseburguer that costs 12 dollars."
29232936
"{% endif %}"
2937+
"<|eot_id|>\n"
29242938
"{% for message in messages %}"
2925-
"{% if message.role == 'tool' %}"
2939+
"{% if message.role == 'tool'%}"
29262940
"<|start_header_id|>user<|end_header_id|>\n\n"
2927-
"here is the Function response, bring it to me in a nice way: {{ message.content | default('No response available') }}"
2941+
"<tool_output> {{ message.content | default('No response available') }} </tool_output>"
29282942
"<|eot_id|>\n"
29292943
"{% elif message.role == 'assistant' and message.function_call is defined%}"
2930-
"<|start_header_id|>{{ message.role }}<|end_header_id|>"
2944+
"<|start_header_id|>{{ message.role }}<|end_header_id|>\n\n"
29312945
"Function called: {{ message.function_call.name | default('No name') }}\n"
29322946
"Function argument: {{ message.function_call.arguments | default('No arguments') }}"
29332947
"<|eot_id|>\n"
2934-
"{% else %}"
2935-
"<|start_header_id|>{{ message.role }}<|end_header_id|>"
2948+
"{% elif message.role != 'system' %}"
2949+
"<|start_header_id|>{{ message.role }}<|end_header_id|>\n\n"
29362950
"{{ message.content }}"
29372951
"<|eot_id|>\n"
29382952
"{% endif %}"

0 commit comments

Comments
 (0)