Commit 971f235

Add LlamaChatHandler for more complex chat use cases

1 parent fd0ad07 commit 971f235
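
The caller-facing API is unchanged by this refactor; a call like the following still goes through Llama.create_chat_completion, which now delegates to a handler looked up by chat format. This is a usage sketch, not part of the commit: the model path and the chat_format value are illustrative.

from llama_cpp import Llama

# chat_format selects which chat completion handler create_chat_completion dispatches to
llm = Llama(model_path="./models/llama-2-7b-chat.gguf", chat_format="llama-2")
result = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name the planets in the solar system."},
    ],
)
print(result["choices"][0]["message"]["content"])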

File tree

3 files changed: +405 -93 lines changed

llama_cpp/llama.py

Lines changed: 5 additions & 87 deletions
@@ -24,7 +24,7 @@
 from . import llama_cpp
 from .llama_types import *
 from .llama_grammar import LlamaGrammar
-from . import llama_chat_format
+import llama_cpp.llama_chat_format as llama_chat_format
 
 import numpy as np
 import numpy.typing as npt
@@ -392,7 +392,7 @@ def __init__(
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
-
+
         self.chat_format = chat_format
 
         self._n_vocab = self.n_vocab()
@@ -1512,78 +1512,6 @@ def __call__(
             grammar=grammar,
         )
 
-    def _convert_text_completion_to_chat(
-        self, completion: Completion
-    ) -> ChatCompletion:
-        return {
-            "id": "chat" + completion["id"],
-            "object": "chat.completion",
-            "created": completion["created"],
-            "model": completion["model"],
-            "choices": [
-                {
-                    "index": 0,
-                    "message": {
-                        "role": "assistant",
-                        "content": completion["choices"][0]["text"],
-                    },
-                    "finish_reason": completion["choices"][0]["finish_reason"],
-                }
-            ],
-            "usage": completion["usage"],
-        }
-
-    def _convert_text_completion_chunks_to_chat(
-        self,
-        chunks: Iterator[CompletionChunk],
-    ) -> Iterator[ChatCompletionChunk]:
-        for i, chunk in enumerate(chunks):
-            if i == 0:
-                yield {
-                    "id": "chat" + chunk["id"],
-                    "model": chunk["model"],
-                    "created": chunk["created"],
-                    "object": "chat.completion.chunk",
-                    "choices": [
-                        {
-                            "index": 0,
-                            "delta": {
-                                "role": "assistant",
-                            },
-                            "finish_reason": None,
-                        }
-                    ],
-                }
-            yield {
-                "id": "chat" + chunk["id"],
-                "model": chunk["model"],
-                "created": chunk["created"],
-                "object": "chat.completion.chunk",
-                "choices": [
-                    {
-                        "index": 0,
-                        "delta": {
-                            "content": chunk["choices"][0]["text"],
-                        }
-                        if chunk["choices"][0]["finish_reason"] is None
-                        else {},
-                        "finish_reason": chunk["choices"][0]["finish_reason"],
-                    }
-                ],
-            }
-
-    def _convert_completion_to_chat(
-        self,
-        completion_or_chunks: Union[Completion, Iterator[CompletionChunk]],
-        stream: bool = False,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        if stream:
-            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
-            return self._convert_text_completion_chunks_to_chat(chunks)
-        else:
-            completion: Completion = completion_or_chunks  # type: ignore
-            return self._convert_text_completion_to_chat(completion)
-
     def create_chat_completion(
         self,
         messages: List[ChatCompletionRequestMessage],
@@ -1621,21 +1549,12 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-
-        format = llama_chat_format.get_chat_format(self.chat_format)
-        result = format(
+        handler = llama_chat_format.get_chat_completion_handler(self.chat_format)
+        return handler(
+            self,
             messages=messages,
             functions=functions,
             function_call=function_call,
-        )
-        prompt = result.prompt
-        if result.stop is not None:
-            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
-            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
-            stop = stop + rstop
-
-        completion_or_chunks = self.create_completion(
-            prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
@@ -1653,7 +1572,6 @@ def create_chat_completion(
             logits_processor=logits_processor,
             grammar=grammar,
         )
-        return self._convert_completion_to_chat(completion_or_chunks, stream=stream)  # type: ignore
 
     def __del__(self):
         if hasattr(self, "model") and self.model is not None:
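
The llama_chat_format side of this commit is not shown in this file's diff, but the new call site above implies a handler that receives the Llama instance first, followed by the chat-level arguments, and returns either a ChatCompletion or a stream of ChatCompletionChunks. Below is a minimal sketch of such a handler under that assumption; the name simple_chat_handler and the naive prompt template are illustrative, not part of the commit.

from typing import Iterator, List, Union

import llama_cpp
from llama_cpp.llama_types import (
    ChatCompletion,
    ChatCompletionChunk,
    ChatCompletionRequestMessage,
)


def simple_chat_handler(
    llama: llama_cpp.Llama,
    messages: List[ChatCompletionRequestMessage],
    functions=None,        # accepted to match the call site; unused in this sketch
    function_call=None,    # accepted to match the call site; unused in this sketch
    stream: bool = False,  # streaming path omitted for brevity
    **sampling_kwargs,     # temperature, top_p, top_k, stop, grammar, ...
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
    # Flatten the transcript into a plain prompt; a real chat format would
    # apply its model-specific template and stop sequences here.
    prompt = (
        "".join(f"{m['role']}: {m.get('content') or ''}\n" for m in messages)
        + "assistant: "
    )
    completion = llama.create_completion(
        prompt=prompt, stream=False, **sampling_kwargs
    )
    # Reshape the text completion into a chat completion, mirroring the
    # _convert_text_completion_to_chat helper deleted above.
    return {
        "id": "chat" + completion["id"],
        "object": "chat.completion",
        "created": completion["created"],
        "model": completion["model"],
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": completion["choices"][0]["text"],
                },
                "finish_reason": completion["choices"][0]["finish_reason"],
            }
        ],
        "usage": completion["usage"],
    }

Dispatching by name through get_chat_completion_handler lets create_chat_completion stay a thin wrapper, while formats with richer behavior (function calling, custom streaming) own the whole prompt-build/complete/convert pipeline instead of only producing a prompt string.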
