
Commit 14c2189

Author: DigiDecode@Youtube

Dirty patch to the llama-cpp-python server to get Codestral working with the continue.dev VS Code extension.

1 parent 165b4dc

File tree

    continue.dev/config.json
    continue.dev/fim.py
    llama_cpp/server/app.py
    run-codestral.py

4 files changed: +201 −0 lines

continue.dev/config.json

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
{
  "models": [
    {
      "title": "Codestral",
      "model": "codestral-latest",
      "apiBase": "http://127.0.0.1:4321/v1/",
      "provider": "openai",
      "apiKey": "key",

      "completionOptions": {
        "maxTokens": 8000
      }
    }
  ],
  "tabAutocompleteModel": {
    "title": "Codestral",
    "model": "codestral-latest",
    "apiBase": "http://127.0.0.1:4321/v1/",
    "provider": "openai",
    "apiKey": "key",

    "completionOptions": {
      "maxTokens": 200
    }
  },
  "tabAutocompleteOptions": {
    "useCache": true,
    "disable": false
  },
  "slashCommands": [
    {
      "name": "edit",
      "description": "Edit selected code"
    },
    {
      "name": "comment",
      "description": "Write comments for the selected code"
    },
    {
      "name": "share",
      "description": "Export this session as markdown"
    },
    {
      "name": "cmd",
      "description": "Generate a shell command"
    }
  ],
  "customCommands": [
    {
      "name": "test",
      "prompt": "Write a comprehensive set of unit tests for the selected code. It should setup, run tests that check for correctness including important edge cases, and teardown. Ensure that the tests are complete and sophisticated. Give the tests just as chat output, don't edit any file.",
      "description": "Write unit tests for highlighted code"
    }
  ],
  "contextProviders": [
    {
      "name": "diff",
      "params": {}
    },
    {
      "name": "open",
      "params": {}
    },
    {
      "name": "terminal",
      "params": {}
    },
    {
      "name": "problems",
      "params": {}
    },
    {
      "name": "codebase",
      "params": {}
    },
    {
      "name": "code",
      "params": {}
    },
    {
      "name": "docs",
      "params": {}
    }
  ],
  "allowAnonymousTelemetry": false
}
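
This config points both chat and tab autocomplete at the patched local server through the OpenAI-compatible provider; the apiKey is a placeholder the local server is not expected to check. A minimal smoke test of that endpoint, as a sketch only: it assumes run-codestral.py below is already listening on 127.0.0.1:4321, and the requests dependency plus the example prompt are illustrative, not part of the commit.

# Hedged smoke test for the endpoint config.json points at.
# Assumes run-codestral.py is already serving on 127.0.0.1:4321.
import requests

resp = requests.post(
    "http://127.0.0.1:4321/v1/completions",
    headers={"Authorization": "Bearer key"},  # matches the dummy "apiKey": "key"
    json={
        "model": "codestral-latest",
        "prompt": "def add(a, b):",  # illustrative prompt
        "max_tokens": 32,
    },
    timeout=120,
)
print(resp.json()["choices"][0]["text"])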

continue.dev/fim.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
def add(a, b):
    return a + b
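
fim.py looks like a scratch file for exercising tab autocomplete. Judging from the split logic in the app.py patch below, continue.dev appears to send fill-in-the-middle prompts as [SUFFIX]<text after cursor>[PREFIX]<text before cursor>; that template is inferred from the patch, not from continue.dev documentation. A sketch of the transformation for a cursor placed right after def add:

# Inferred shape of the autocomplete request for fim.py (cursor after "def add").
# The [SUFFIX]...[PREFIX]... template is an assumption based on the split in app.py.
raw_prompt = "[SUFFIX]\n    return a + b\n[PREFIX]def add"

prompt = raw_prompt.replace("[SUFFIX]", "")          # same first step as the patch
suffix_part, prefix_part = prompt.split("[PREFIX]")  # parts[0] = suffix, parts[1] = prefix
print(repr(prefix_part))  # 'def add'               -> FIMRequest(prompt=...)
print(repr(suffix_part))  # '\n    return a + b\n'  -> FIMRequest(suffix=...)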

llama_cpp/server/app.py

Lines changed: 34 additions & 0 deletions
@@ -41,6 +41,8 @@
     DetokenizeInputResponse,
 )
 from llama_cpp.server.errors import RouteErrorHandler
+from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from mistral_common.tokens.instruct.request import FIMRequest


 router = APIRouter(route_class=RouteErrorHandler)
@@ -264,6 +266,38 @@ async def create_completion(
     assert len(body.prompt) <= 1
     body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

+    # Strip continue.dev's marker tokens and rebuild the prompt with
+    # mistral_common's FIM encoding before handing it to llama.cpp.
+    prompt = body.prompt
+    prompt = prompt.replace('[SUFFIX]', '')
+
+    # Debug logging of the raw request.
+    print(f'##### PROMPT#:\n{json.dumps(body.prompt)}\n\n')
+    print(f'##### SUFFIX#:\n{json.dumps(body.suffix)}\n\n')
+    print('##### STOP REPR#:\n')
+    if body.stop is not None:
+        for s in body.stop:
+            print(repr(s))
+        print('##### STOP#:\n')
+        for s in body.stop:
+            print(s)
+    print('\n\n')
+
+    body.stop = ['[PREFIX]', '[/PREFIX]', '</s>', '[SUFFIX]', '[MIDDLE]']
+
+    tokenizer = MistralTokenizer.v3()
+
+    if '[PREFIX]' in prompt:
+        # Everything before [PREFIX] is the suffix (text after the cursor),
+        # everything after it is the prefix (text before the cursor).
+        prompt_parts = prompt.split('[PREFIX]')
+        prefix = prompt_parts[1] if len(prompt_parts) > 1 else ''
+        postfix = prompt_parts[0]
+    else:
+        prefix = prompt
+        postfix = ''
+
+    fim_request = FIMRequest(prompt=prefix, suffix=postfix)
+    fim_tokens = tokenizer.encode_fim(fim_request)
+    body.prompt = fim_tokens.text
+
+    print(f'##### prefix#:\n{repr(prefix)}\n\n')
+    print(f'##### postfix#:\n{repr(postfix)}\n\n')
+    print(f'##### fim_tokens.text#:\n{repr(fim_tokens.text)}\n\n')
+
     llama = llama_proxy(
         body.model
         if request.url.path != "/v1/engines/copilot-codex/completions"
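
The mistral_common calls above can be exercised on their own to see what the rewritten prompt looks like. A small sketch using the same imports as the patch; the example prefix/suffix are illustrative, and the .text/.tokens attributes are how mistral_common's tokenized result exposes the encoded request:

# Standalone check of the FIM encoding applied in create_completion above.
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.instruct.request import FIMRequest

tokenizer = MistralTokenizer.v3()
fim_request = FIMRequest(prompt="def add", suffix="\n    return a + b\n")
fim_tokens = tokenizer.encode_fim(fim_request)

print(repr(fim_tokens.text))   # control-token-formatted string that becomes body.prompt
print(len(fim_tokens.tokens))  # corresponding token ids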

run-codestral.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
from fastapi import Request
from fastapi.responses import JSONResponse
from llama_cpp.server.settings import ServerSettings, ModelSettings
from llama_cpp.server.app import create_app
import uvicorn
import llama_cpp
import asyncio

def dprint(*args, **kwargs):
    print(*args, **kwargs)

model_path = '/home/zbuntu/projects/Codestral-22B-v0.1-Q5_K_M.gguf'

server_settings = ServerSettings()
server_settings.host = "127.0.0.1"
server_settings.port = 4321

model_setting: dict = {
    "model": model_path,
    "model_alias": "codegemma-7b8bit",
    "n_gpu_layers": -1,
    "rope_scaling_type": llama_cpp.LLAMA_ROPE_SCALING_TYPE_LINEAR,
    "rope_freq_base": 1000000.0,
    "rope_freq_scale": 1,
    "n_ctx": 16000,
    "n_threads": 1,
}

app = create_app(
    server_settings=server_settings,
    model_settings=[ModelSettings(**model_setting)],
)

lock = asyncio.Lock()
queue = asyncio.Queue(maxsize=1)  # allow at most one in-flight request

async def wait_for_disconnect(request: Request, queue):
    # Poll until the client goes away, then free the queue slot.
    dprint('starting to wait for disconnect')
    while True:
        if await request.is_disconnected():
            break
        await asyncio.sleep(0.5)
    dprint('client disconnected')
    await queue.get()
    dprint('request removed from queue')


@app.middleware("http")
async def synchronize_requests(request: Request, call_next):
    try:
        dprint('adding request to queue')
        # put_nowait so a second concurrent request raises QueueFull instead
        # of blocking (await queue.put would never raise, leaving the 503
        # branch below unreachable).
        queue.put_nowait(request)

        dprint('request added to queue')
        async with lock:
            dprint('acquired lock, serving request')
            try:
                # Process the request
                response = await call_next(request)
                dprint('response received')
                asyncio.create_task(wait_for_disconnect(request, queue))
                dprint('wait-for-disconnect task created')
                return response
            finally:
                dprint('request processed')
    except asyncio.QueueFull:
        dprint('exception encountered')
        # If the queue is full, return a 503 error response
        return JSONResponse(status_code=503, content={"error": "Service unavailable, try again later"})


uvicorn.run(
    app,
    host=server_settings.host,
    port=server_settings.port,
    ssl_keyfile=server_settings.ssl_keyfile,
    ssl_certfile=server_settings.ssl_certfile,
)
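
Usage, as far as the pieces above suggest: start the server with python run-codestral.py, point continue.dev at it with the config.json above, and completions flow through the patched create_completion. Because the middleware keeps at most one request in flight, a second concurrent request should come back as 503. A hedged sketch to confirm that (assumes the server is running locally and requests is installed; the outcome is timing-dependent):

# Hedged check of the one-request-at-a-time middleware: fire two requests
# at once and expect one of them to be rejected with 503.
import threading
import requests

def hit() -> None:
    r = requests.post(
        "http://127.0.0.1:4321/v1/completions",
        json={"model": "codestral-latest", "prompt": "def add", "max_tokens": 256},
    )
    print(r.status_code)  # one 200, and likely one 503 while the first is in flight

threads = [threading.Thread(target=hit) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()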
