
Commit 7fedf16

Add support for chat completion
1 parent 3dec778 commit 7fedf16

File tree

3 files changed: +211 -4 lines changed


examples/fastapi_server.py

Lines changed: 76 additions & 2 deletions
@@ -1,7 +1,18 @@
 """Example FastAPI server for llama.cpp.
+
+To run this example:
+
+```bash
+pip install fastapi uvicorn sse-starlette
+export MODEL=../models/7B/...
+uvicorn fastapi_server_chat:app --reload
+```
+
+Then visit http://localhost:8000/docs to see the interactive API docs.
+
 """
 import json
-from typing import List, Optional, Iterator
+from typing import List, Optional, Literal, Union, Iterator
 
 import llama_cpp
 
@@ -95,4 +106,67 @@ class Config:
     response_model=CreateEmbeddingResponse,
 )
 def create_embedding(request: CreateEmbeddingRequest):
-    return llama.create_embedding(request.input)
+    return llama.create_embedding(**request.dict(exclude={"model", "user"}))
+
+
+class ChatCompletionRequestMessage(BaseModel):
+    role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
+    content: str
+    user: Optional[str] = None
+
+
+class CreateChatCompletionRequest(BaseModel):
+    model: Optional[str]
+    messages: List[ChatCompletionRequestMessage]
+    temperature: float = 0.8
+    top_p: float = 0.95
+    stream: bool = False
+    stop: List[str] = []
+    max_tokens: int = 128
+    repeat_penalty: float = 1.1
+
+    class Config:
+        schema_extra = {
+            "example": {
+                "messages": [
+                    ChatCompletionRequestMessage(
+                        role="system", content="You are a helpful assistant."
+                    ),
+                    ChatCompletionRequestMessage(
+                        role="user", content="What is the capital of France?"
+                    ),
+                ]
+            }
+        }
+
+
+CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
+
+
+@app.post(
+    "/v1/chat/completions",
+    response_model=CreateChatCompletionResponse,
+)
+async def create_chat_completion(
+    request: CreateChatCompletionRequest,
+) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
+    completion_or_chunks = llama.create_chat_completion(
+        **request.dict(exclude={"model"}),
+    )
+
+    if request.stream:
+
+        async def server_sent_events(
+            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
+        ):
+            for chat_chunk in chat_chunks:
+                yield dict(data=json.dumps(chat_chunk))
+            yield dict(data="[DONE]")
+
+        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore
+
+        return EventSourceResponse(
+            server_sent_events(chunks),
+        )
+    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
+    return completion
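
With this route in place, any HTTP client can exercise the new endpoint. Below is a minimal sketch of a non-streaming request, assuming the example server above is running at http://localhost:8000 and that the `requests` package is installed (it is not one of the example's listed dependencies); the request body mirrors CreateChatCompletionRequest and the response mirrors llama_cpp.ChatCompletion.

```python
# Hypothetical client call; the host/port and the use of `requests` are
# assumptions, not part of this commit.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"},
        ],
        "max_tokens": 64,
        "stream": False,
    },
)
resp.raise_for_status()
completion = resp.json()
# The payload follows the ChatCompletion TypedDict added in llama_cpp/llama_types.py.
print(completion["choices"][0]["message"]["content"])
```

Sending `"stream": true` instead returns a Server-Sent Events response produced by EventSourceResponse, where each event's data field holds a JSON-encoded ChatCompletionChunk followed by a final `[DONE]` sentinel.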

llama_cpp/llama.py

Lines changed: 93 additions & 0 deletions
@@ -517,6 +517,99 @@ def __call__(
             stream=stream,
         )
 
+    def _convert_text_completion_to_chat(
+        self, completion: Completion
+    ) -> ChatCompletion:
+        return {
+            "id": "chat" + completion["id"],
+            "object": "chat.completion",
+            "created": completion["created"],
+            "model": completion["model"],
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": completion["choices"][0]["text"],
+                    },
+                    "finish_reason": completion["choices"][0]["finish_reason"],
+                }
+            ],
+            "usage": completion["usage"],
+        }
+
+    def _convert_text_completion_chunks_to_chat(
+        self,
+        chunks: Iterator[CompletionChunk],
+    ) -> Iterator[ChatCompletionChunk]:
+        for i, chunk in enumerate(chunks):
+            if i == 0:
+                yield {
+                    "id": "chat" + chunk["id"],
+                    "model": chunk["model"],
+                    "created": chunk["created"],
+                    "object": "chat.completion.chunk",
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {
+                                "role": "assistant",
+                            },
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+            yield {
+                "id": "chat" + chunk["id"],
+                "model": chunk["model"],
+                "created": chunk["created"],
+                "object": "chat.completion.chunk",
+                "choices": [
+                    {
+                        "index": 0,
+                        "delta": {
+                            "content": chunk["choices"][0]["text"],
+                        },
+                        "finish_reason": chunk["choices"][0]["finish_reason"],
+                    }
+                ],
+            }
+
+    def create_chat_completion(
+        self,
+        messages: List[ChatCompletionMessage],
+        temperature: float = 0.8,
+        top_p: float = 0.95,
+        top_k: int = 40,
+        stream: bool = False,
+        stop: List[str] = [],
+        max_tokens: int = 128,
+        repeat_penalty: float = 1.1,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions."""
+        chat_history = "\n".join(
+            f'{message["role"]} {message.get("user", "")}: {message["content"]}'
+            for message in messages
+        )
+        PROMPT = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "
+        PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "]
+        completion_or_chunks = self(
+            prompt=PROMPT,
+            stop=PROMPT_STOP + stop,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            stream=stream,
+            max_tokens=max_tokens,
+            repeat_penalty=repeat_penalty,
+        )
+        if stream:
+            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
+            return self._convert_text_completion_chunks_to_chat(chunks)
+        else:
+            completion: Completion = completion_or_chunks  # type: ignore
+            return self._convert_text_completion_to_chat(completion)
+
     def __del__(self):
         if self.ctx is not None:
             llama_cpp.llama_free(self.ctx)
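
The same functionality is reachable without the server by calling `Llama.create_chat_completion` directly: the messages are flattened into a single instruction-style prompt ("### Instructions / ### Inputs / ### Response") and the role markers are appended to the stop list, so the existing text-completion path does the actual generation. A minimal sketch, assuming a local ggml model file (the path below is a placeholder):

```python
# Direct use of the new method; the model path is a placeholder, not part of this commit.
import llama_cpp

llama = llama_cpp.Llama(model_path="./models/7B/ggml-model.bin")

# Non-streaming call: returns a ChatCompletion dict.
result = llama.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    max_tokens=64,
)
print(result["choices"][0]["message"]["content"])

# Streaming call: yields ChatCompletionChunk dicts; the first chunk's delta
# carries the assistant role, later deltas carry content fragments.
for chunk in llama.create_chat_completion(
    messages=[{"role": "user", "content": "Count to five."}],
    stream=True,
):
    delta = chunk["choices"][0]["delta"]
    print(delta.get("content", ""), end="", flush=True)
```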

llama_cpp/llama_types.py

Lines changed: 42 additions & 2 deletions
@@ -1,5 +1,5 @@
-from typing import List, Optional, Dict, Literal
-from typing_extensions import TypedDict
+from typing import List, Optional, Dict, Literal, Union
+from typing_extensions import TypedDict, NotRequired
 
 
 class EmbeddingUsage(TypedDict):
@@ -55,3 +55,43 @@ class Completion(TypedDict):
     model: str
     choices: List[CompletionChoice]
     usage: CompletionUsage
+
+
+class ChatCompletionMessage(TypedDict):
+    role: Union[Literal["assistant"], Literal["user"], Literal["system"]]
+    content: str
+    user: NotRequired[str]
+
+
+class ChatCompletionChoice(TypedDict):
+    index: int
+    message: ChatCompletionMessage
+    finish_reason: Optional[str]
+
+
+class ChatCompletion(TypedDict):
+    id: str
+    object: Literal["chat.completion"]
+    created: int
+    model: str
+    choices: List[ChatCompletionChoice]
+    usage: CompletionUsage
+
+
+class ChatCompletionChunkDelta(TypedDict):
+    role: NotRequired[Literal["assistant"]]
+    content: NotRequired[str]
+
+
+class ChatCompletionChunkChoice(TypedDict):
+    index: int
+    delta: ChatCompletionChunkDelta
+    finish_reason: Optional[str]
+
+
+class ChatCompletionChunk(TypedDict):
+    id: str
+    model: str
+    object: Literal["chat.completion.chunk"]
+    created: int
+    choices: List[ChatCompletionChunkChoice]
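
These TypedDicts only describe dictionary shapes for type checkers; they add no runtime behaviour. A small sketch of how a type-checked consumer might fold streamed chunks back into a single message (the helper function is illustrative and not part of this commit):

```python
# Illustrative helper, not part of the library: accumulate streamed deltas
# into one ChatCompletionMessage.
from typing import Iterator

from llama_cpp.llama_types import ChatCompletionChunk, ChatCompletionMessage


def collect_message(chunks: Iterator[ChatCompletionChunk]) -> ChatCompletionMessage:
    content = ""
    for chunk in chunks:
        delta = chunk["choices"][0]["delta"]
        # The first chunk's delta holds only the role; later deltas hold text.
        content += delta.get("content", "")
    return {"role": "assistant", "content": content}
```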
