 import anyio
 from anyio.streams.memory import MemoryObjectSendStream
 from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
-from fastapi import Depends, FastAPI, APIRouter, Request, Response
+from fastapi import Depends, FastAPI, APIRouter, Request, Response, HTTPException, status
 from fastapi.middleware import Middleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from fastapi.routing import APIRoute
+from fastapi.security import HTTPBearer
 from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings
 from sse_starlette.sse import EventSourceResponse
@@ -161,6 +162,10 @@ class Settings(BaseSettings):
         default=True,
         description="Whether to interrupt requests when a new request is received.",
     )
+    api_key: Optional[str] = Field(
+        default=None,
+        description="API key for authentication. If set, all requests need to be authenticated.",
+    )


 class ErrorResponse(TypedDict):
@@ -312,6 +317,9 @@ async def custom_route_handler(request: Request) -> Response:
                 elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
                 response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
                 return response
+            except HTTPException as unauthorized:
+                # api key check failed
+                raise unauthorized
             except Exception as exc:
                 json_body = await request.json()
                 try:
@@ -656,6 +664,27 @@ def _logit_bias_tokens_to_input_ids(
     return to_bias


+# Setup Bearer authentication scheme
+bearer_scheme = HTTPBearer(auto_error=False)
+
+
+async def authenticate(settings: Settings = Depends(get_settings), authorization: Optional[str] = Depends(bearer_scheme)):
+    # Skip API key check if it's not set in settings
+    if settings.api_key is None:
+        return True
+
+    # check bearer credentials against the api_key
+    if authorization and authorization.credentials == settings.api_key:
+        # api key is valid
+        return authorization.credentials
+
+    # raise http error 401
+    raise HTTPException(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        detail="Invalid API key",
+    )
+
+
 @router.post(
     "/v1/completions",
     summary="Completion"
@@ -665,6 +694,7 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
+    authenticated: str = Depends(authenticate),
 ) -> llama_cpp.Completion:
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
@@ -738,7 +768,9 @@ class CreateEmbeddingRequest(BaseModel):
     summary="Embedding"
 )
 async def create_embedding(
-    request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
+    request: CreateEmbeddingRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
+    authenticated: str = Depends(authenticate),
 ):
     return await run_in_threadpool(
         llama.create_embedding, **request.model_dump(exclude={"user"})
@@ -832,6 +864,7 @@ async def create_chat_completion(
     body: CreateChatCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
     settings: Settings = Depends(get_settings),
+    authenticated: str = Depends(authenticate),
 ) -> llama_cpp.ChatCompletion:
     exclude = {
         "n",
@@ -893,6 +926,7 @@ class ModelList(TypedDict):
 @router.get("/v1/models", summary="Models")
 async def get_models(
     settings: Settings = Depends(get_settings),
+    authenticated: str = Depends(authenticate),
 ) -> ModelList:
     assert llama is not None
     return {
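For reference, here is a minimal client-side sketch (not part of the diff) of how the new api_key protection behaves end to end. The base URL, port, and key value below are placeholders, and it assumes the server was started with the api_key setting populated (for example via an API_KEY environment variable, which pydantic_settings.BaseSettings reads). Clients pass the key in an Authorization: Bearer header; any other request trips the authenticate dependency and receives the 401 added above.

import requests  # third-party HTTP client, used only to illustrate the calls

BASE_URL = "http://localhost:8000"  # assumed host/port of the running server
API_KEY = "my-secret-key"           # placeholder; must match the server's api_key setting

headers = {"Authorization": f"Bearer {API_KEY}"}

# Authenticated request: the bearer credentials match api_key, so authenticate() passes.
ok = requests.get(f"{BASE_URL}/v1/models", headers=headers)
print(ok.status_code, ok.json())  # expect 200 and the model list

# Missing or wrong token: authenticate() raises HTTPException(401, "Invalid API key"),
# which FastAPI normally serializes as {"detail": "Invalid API key"}.
rejected = requests.get(f"{BASE_URL}/v1/models")
print(rejected.status_code, rejected.json())  # expect 401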