Skip to content

Commit 673e4fa

Browse files
committed
Implement OpenAI-API-compatible authentication
1 parent cbce061 commit 673e4fa

File tree

1 file changed

+36
-2
lines changed

1 file changed

+36
-2
lines changed

llama_cpp/server/app.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414
import anyio
1515
from anyio.streams.memory import MemoryObjectSendStream
1616
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
17-
from fastapi import Depends, FastAPI, APIRouter, Request, Response
17+
from fastapi import Depends, FastAPI, APIRouter, Request, Response, HTTPException, status
1818
from fastapi.middleware import Middleware
1919
from fastapi.middleware.cors import CORSMiddleware
2020
from fastapi.responses import JSONResponse
2121
from fastapi.routing import APIRoute
22+
from fastapi.security import HTTPBearer
2223
from pydantic import BaseModel, Field
2324
from pydantic_settings import BaseSettings
2425
from sse_starlette.sse import EventSourceResponse
@@ -161,6 +162,10 @@ class Settings(BaseSettings):
161162
default=True,
162163
description="Whether to interrupt requests when a new request is received.",
163164
)
165+
api_key: Optional[str] = Field(
166+
default=None,
167+
description="API key for authentication. If set all requests need to be authenticated."
168+
)
164169

165170

166171
class ErrorResponse(TypedDict):
@@ -312,6 +317,9 @@ async def custom_route_handler(request: Request) -> Response:
312317
elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
313318
response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
314319
return response
320+
except HTTPException as unauthorized:
321+
# api key check failed
322+
raise unauthorized
315323
except Exception as exc:
316324
json_body = await request.json()
317325
try:
@@ -656,6 +664,27 @@ def _logit_bias_tokens_to_input_ids(
656664
return to_bias
657665

658666

667+
# Setup Bearer authentication scheme.
# auto_error=False so requests without an Authorization header reach
# authenticate(), which decides whether a key is required at all.
bearer_scheme = HTTPBearer(auto_error=False)


async def authenticate(
    settings: Settings = Depends(get_settings),
    # NOTE(review): despite the Optional[str] annotation, HTTPBearer yields an
    # HTTPAuthorizationCredentials object (we read .credentials below) or None
    # when no Authorization header was sent — annotation left as-is to avoid
    # importing a name this hunk does not bring into scope; confirm and fix.
    authorization: Optional[str] = Depends(bearer_scheme),
):
    """FastAPI dependency enforcing the optional server-wide API key.

    Returns:
        True when no API key is configured (auth disabled), otherwise the
        presented credential string when it matches ``settings.api_key``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            a key is configured and the request's bearer token is absent or
            does not match.
    """
    import secrets  # stdlib; local import keeps this hunk self-contained

    # Skip the API key check entirely if no key is configured in settings.
    if settings.api_key is None:
        return True

    # Compare the bearer credentials against the configured api_key using a
    # constant-time comparison to avoid leaking key bytes via timing.
    if authorization and secrets.compare_digest(
        authorization.credentials, settings.api_key
    ):
        # API key is valid.
        return authorization.credentials

    # Missing or wrong key: 401 with the standard Bearer challenge header
    # (RFC 6750) so OpenAI-compatible clients know how to authenticate.
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
        headers={"WWW-Authenticate": "Bearer"},
    )
686+
687+
659688
@router.post(
660689
"/v1/completions",
661690
summary="Completion"
@@ -665,6 +694,7 @@ async def create_completion(
665694
request: Request,
666695
body: CreateCompletionRequest,
667696
llama: llama_cpp.Llama = Depends(get_llama),
697+
authenticated: str = Depends(authenticate),
668698
) -> llama_cpp.Completion:
669699
if isinstance(body.prompt, list):
670700
assert len(body.prompt) <= 1
@@ -738,7 +768,9 @@ class CreateEmbeddingRequest(BaseModel):
738768
summary="Embedding"
739769
)
740770
async def create_embedding(
741-
request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
771+
request: CreateEmbeddingRequest,
772+
llama: llama_cpp.Llama = Depends(get_llama),
773+
authenticated: str = Depends(authenticate),
742774
):
743775
return await run_in_threadpool(
744776
llama.create_embedding, **request.model_dump(exclude={"user"})
@@ -832,6 +864,7 @@ async def create_chat_completion(
832864
body: CreateChatCompletionRequest,
833865
llama: llama_cpp.Llama = Depends(get_llama),
834866
settings: Settings = Depends(get_settings),
867+
authenticated: str = Depends(authenticate),
835868
) -> llama_cpp.ChatCompletion:
836869
exclude = {
837870
"n",
@@ -893,6 +926,7 @@ class ModelList(TypedDict):
893926
@router.get("/v1/models", summary="Models")
894927
async def get_models(
895928
settings: Settings = Depends(get_settings),
929+
authenticated: str = Depends(authenticate),
896930
) -> ModelList:
897931
assert llama is not None
898932
return {

0 commit comments

Comments
 (0)