@@ -23,10 +23,12 @@

 from . import llama_cpp
 from .llama_types import *
+from .llama_grammar import LlamaGrammar

 import numpy as np
 import numpy.typing as npt

+from .utils import suppress_stdout_stderr

 class BaseLlamaCache(ABC):
     """Base cache class for a llama.cpp model."""
@@ -231,7 +233,8 @@ def __init__(
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
         n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
-        rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
+        rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
+        mul_mat_q: Optional[bool] = None,  # (TEMPORARY)
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -241,6 +244,7 @@ def __init__(
             n_ctx: Maximum context size.
             n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined.
             seed: Random seed. -1 for random.
+            n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
             f16_kv: Use half-precision for key/value cache.
             logits_all: Return logits for all tokens, not just the last token.
             vocab_only: Only load the vocabulary no weights.
@@ -269,7 +273,7 @@ def __init__(

         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
-        self.params.n_gpu_layers = n_gpu_layers
+        self.params.n_gpu_layers = 0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
         self.params.seed = seed
         self.params.f16_kv = f16_kv
         self.params.logits_all = logits_all
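Note: the -1 sentinel above is simply mapped to INT32_MAX before it reaches llama.cpp, which (per the in-line comment) then treats it as "offload all layers". A standalone sketch of the arithmetic this hunk performs, not the library's own code:

    INT32_MAX = 0x7FFFFFFF  # per the hunk's comment, llama.cpp auto-caps this to "all layers"

    def resolve_n_gpu_layers(n_gpu_layers: int) -> int:
        # -1 is the user-facing "offload everything" sentinel
        return INT32_MAX if n_gpu_layers == -1 else n_gpu_layers

    assert resolve_n_gpu_layers(-1) == 0x7FFFFFFF
    assert resolve_n_gpu_layers(35) == 35

Callers can therefore write Llama(model_path=..., n_gpu_layers=-1) instead of looking up each model's layer count.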
@@ -280,7 +284,7 @@ def __init__(
         self.params.low_vram = low_vram

         self.tensor_split = tensor_split
-        self._c_tensor_split = None
+        self._p_tensor_split = None

         if self.tensor_split is not None:
             # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
@@ -299,6 +303,9 @@ def __init__(
         if rms_norm_eps is not None:
             self.params.rms_norm_eps = rms_norm_eps

+        if mul_mat_q is not None:
+            self.params.mul_mat_q = mul_mat_q
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)

@@ -316,12 +323,25 @@ def __init__(
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")

-        self.model = llama_cpp.llama_load_model_from_file(
-            self.model_path.encode("utf-8"), self.params
-        )
+        if verbose:
+            self.model = llama_cpp.llama_load_model_from_file(
+                self.model_path.encode("utf-8"), self.params
+            )
+        else:
+            with suppress_stdout_stderr():
+                self.model = llama_cpp.llama_load_model_from_file(
+                    self.model_path.encode("utf-8"), self.params
+                )
         assert self.model is not None

-        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
+        if verbose:
+            self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
+        else:
+            with suppress_stdout_stderr():
+                print("here")
+                self.ctx = llama_cpp.llama_new_context_with_model(
+                    self.model, self.params
+                )

         assert self.ctx is not None

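Note: the quiet path wraps both native calls in suppress_stdout_stderr() from .utils, whose implementation is not part of this diff. As an assumption about what such a helper typically looks like (not the actual .utils code), a context manager of this shape redirects the process-level stdout/stderr file descriptors, which silences llama.cpp's C-side logging as well as Python output such as the stray print("here") above when verbose=False:

    import os
    import sys
    from contextlib import contextmanager

    @contextmanager
    def suppress_stdout_stderr():
        """Hypothetical sketch: silence both Python-level and C-level output."""
        with open(os.devnull, "w") as devnull:
            old_stdout = os.dup(sys.stdout.fileno())
            old_stderr = os.dup(sys.stderr.fileno())
            try:
                # Point file descriptors 1 and 2 at /dev/null for the duration.
                os.dup2(devnull.fileno(), sys.stdout.fileno())
                os.dup2(devnull.fileno(), sys.stderr.fileno())
                yield
            finally:
                # Restore the original descriptors.
                os.dup2(old_stdout, sys.stdout.fileno())
                os.dup2(old_stderr, sys.stderr.fileno())
                os.close(old_stdout)
                os.close(old_stderr)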
@@ -358,8 +378,8 @@ def __init__(
             sorted=sorted,
         )
         self._candidates = candidates
-        self._token_nl = Llama.token_nl()
-        self._token_eos = Llama.token_eos()
+        self._token_nl = self.token_nl()
+        self._token_eos = self.token_eos()
         self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc)  # type: ignore
         self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single)

@@ -437,10 +457,14 @@ def detokenize(self, tokens: List[int]) -> bytes:
        """
        assert self.ctx is not None
        output = b""
+        buffer_size = 32
+        buffer = (ctypes.c_char * buffer_size)()
        for token in tokens:
-            output += llama_cpp.llama_token_to_str(
-                self.ctx, llama_cpp.llama_token(token)
+            n = llama_cpp.llama_token_to_str(
+                self.ctx, llama_cpp.llama_token(token), buffer, buffer_size
            )
+            assert n <= buffer_size
+            output += bytes(buffer[:n])
        return output

    def set_cache(self, cache: Optional[BaseLlamaCache]):
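Note: the binding now follows the usual C convention of a caller-supplied buffer: llama_token_to_str receives a char buffer plus its size and returns the number of bytes written, and the assert n <= buffer_size above assumes no single token piece exceeds 32 bytes. A self-contained toy illustration of that calling pattern with ctypes; fake_token_to_str is invented for this sketch and is not part of llama.cpp:

    import ctypes

    def fake_token_to_str(token: int, buf, buf_size: int) -> int:
        # Invented stand-in for a C function of the form
        #   int token_to_str(token, char *buf, int buf_size)
        # returning the number of bytes it wrote into buf.
        piece = f"<tok{token}>".encode("utf-8")
        n = min(len(piece), buf_size)
        buf[:n] = piece[:n]
        return n

    buffer_size = 32
    buffer = (ctypes.c_char * buffer_size)()
    output = b""
    for token in [1, 2, 3]:
        n = fake_token_to_str(token, buffer, buffer_size)
        assert n <= buffer_size
        output += bytes(buffer[:n])
    print(output)  # b'<tok1><tok2><tok3>'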
@@ -506,6 +530,7 @@ def _sample(
        mirostat_eta: llama_cpp.c_float,
        penalize_nl: bool = True,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ):
        assert self.ctx is not None
        assert self.n_tokens > 0
@@ -548,8 +573,16 @@ def _sample(
        )
        if not penalize_nl:
            candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
+
+        if grammar is not None:
+            llama_cpp.llama_sample_grammar(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+                grammar=grammar.grammar,
+            )
+
        if temp.value == 0.0:
-            return llama_cpp.llama_sample_token_greedy(
+            id = llama_cpp.llama_sample_token_greedy(
                ctx=self.ctx,
                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
            )
@@ -561,7 +594,7 @@ def _sample(
                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                temp=temp,
            )
-            return llama_cpp.llama_sample_token_mirostat(
+            id = llama_cpp.llama_sample_token_mirostat(
                ctx=self.ctx,
                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                tau=mirostat_tau,
@@ -576,7 +609,7 @@ def _sample(
                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                temp=temp,
            )
-            return llama_cpp.llama_sample_token_mirostat_v2(
+            id = llama_cpp.llama_sample_token_mirostat_v2(
                ctx=self.ctx,
                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                tau=mirostat_tau,
@@ -613,10 +646,17 @@ def _sample(
                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                temp=temp,
            )
-            return llama_cpp.llama_sample_token(
+            id = llama_cpp.llama_sample_token(
                ctx=self.ctx,
                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
            )
+        if grammar is not None:
+            llama_cpp.llama_grammar_accept_token(
+                ctx=self.ctx,
+                grammar=grammar.grammar,
+                token=llama_cpp.ctypes.c_int(id),
+            )
+        return id

    def sample(
        self,
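Note: grammar support in _sample follows a two-phase protocol: llama_sample_grammar first discards candidate tokens that would violate the grammar, and llama_grammar_accept_token then advances the grammar's parse state with the token that was actually sampled. A self-contained Python illustration of that constrain-then-advance pattern; ToyGrammar is invented for this sketch and is unrelated to the real GBNF machinery:

    class ToyGrammar:
        def __init__(self, allowed_sequences):
            self.allowed = allowed_sequences  # every legal complete token sequence
            self.prefix = []                  # tokens accepted so far

        def filter_candidates(self, candidates):
            # Phase 1: keep only tokens that can legally extend the current prefix.
            k = len(self.prefix)
            return [t for t in candidates
                    if any(seq[:k] == self.prefix and len(seq) > k and seq[k] == t
                           for seq in self.allowed)]

        def accept_token(self, token):
            # Phase 2: advance the grammar state with the token that was sampled.
            self.prefix.append(token)

    grammar = ToyGrammar(allowed_sequences=[[1, 2, 3], [1, 4]])
    candidates = [1, 2, 3, 4, 5]
    legal = grammar.filter_candidates(candidates)  # [1]
    grammar.accept_token(legal[0])
    print(grammar.filter_candidates(candidates))   # [2, 4]

The real implementation masks llama.cpp's candidate array in place rather than returning a new list, but the control flow mirrors the two calls added above.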
@@ -632,6 +672,7 @@ def sample(
        mirostat_tau: float = 5.0,
        penalize_nl: bool = True,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ):
        """Sample a token from the model.

@@ -665,6 +706,7 @@ def sample(
            mirostat_eta=llama_cpp.c_float(mirostat_eta),
            penalize_nl=penalize_nl,
            logits_processor=logits_processor,
+            grammar=grammar,
        )

    def generate(
@@ -683,6 +725,7 @@ def generate(
        mirostat_eta: float = 0.1,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ) -> Generator[int, Optional[Sequence[int]], None]:
        """Create a generator of tokens from a prompt.

@@ -704,7 +747,6 @@ def generate(
            The generated tokens.
        """
        assert self.ctx is not None
-
        if reset and len(self._input_ids) > 0:
            longest_prefix = 0
            for a, b in zip(self._input_ids, tokens[:-1]):
@@ -722,6 +764,9 @@ def generate(
        if reset:
            self.reset()

+        if grammar is not None:
+            grammar.reset()
+
        while True:
            self.eval(tokens)
            token = self.sample(
@@ -736,6 +781,7 @@ def generate(
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                logits_processor=logits_processor,
+                grammar=grammar,
            )
            if stopping_criteria is not None and stopping_criteria(
                self._input_ids, self._scores[-1, :]
@@ -838,6 +884,7 @@ def _create_completion(
        model: Optional[str] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
        assert self.ctx is not None

@@ -915,6 +962,7 @@ def _create_completion(
            repeat_penalty=repeat_penalty,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
+            grammar=grammar,
        ):
            if token == self._token_eos:
                text = self.detokenize(completion_tokens)
@@ -965,9 +1013,7 @@ def _create_completion(
            for token in remaining_tokens:
                token_end_position += len(self.detokenize([token]))
                # Check if stop sequence is in the token
-                if token_end_position >= (
-                    remaining_length - first_stop_position
-                ):
+                if token_end_position >= (remaining_length - first_stop_position):
                    break
                logprobs_or_none: Optional[CompletionLogprobs] = None
                if logprobs is not None:
@@ -1261,6 +1307,7 @@ def create_completion(
        model: Optional[str] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        """Generate text from a prompt.

@@ -1305,6 +1352,7 @@ def create_completion(
            model=model,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
+            grammar=grammar
        )
        if stream:
            chunks: Iterator[CompletionChunk] = completion_or_chunks
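Note: with the parameter threaded from __call__ and create_completion down to _sample, grammar-constrained generation becomes caller-visible. A hedged usage sketch: the model path is a placeholder, and it assumes LlamaGrammar exposes a constructor such as from_string, which is not shown in this diff:

    from llama_cpp import Llama
    from llama_cpp.llama_grammar import LlamaGrammar

    # GBNF grammar that only allows the literal answers "yes" or "no".
    grammar = LlamaGrammar.from_string('root ::= "yes" | "no"')

    llm = Llama(model_path="./models/model.bin", n_gpu_layers=-1, verbose=False)
    out = llm("Is the sky blue? Answer yes or no: ", max_tokens=4, grammar=grammar)
    print(out["choices"][0]["text"])  # constrained to "yes" or "no"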
@@ -1334,6 +1382,7 @@ def __call__(
        model: Optional[str] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        """Generate text from a prompt.

@@ -1378,6 +1427,7 @@ def __call__(
            model=model,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
+            grammar=grammar,
        )

    def _convert_text_completion_to_chat(
@@ -1460,6 +1510,7 @@ def create_chat_completion(
        mirostat_eta: float = 0.1,
        model: Optional[str] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Generate a chat completion from a list of messages.

@@ -1502,6 +1553,7 @@ def create_chat_completion(
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
+            grammar=grammar,
        )
        if stream:
            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
@@ -1511,10 +1563,10 @@ def create_chat_completion(
            return self._convert_text_completion_to_chat(completion)

    def __del__(self):
-        if self.model is not None:
+        if hasattr(self, "model") and self.model is not None:
            llama_cpp.llama_free_model(self.model)
            self.model = None
-        if self.ctx is not None:
+        if hasattr(self, "ctx") and self.ctx is not None:
            llama_cpp.llama_free(self.ctx)
            self.ctx = None

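Note: the hasattr guards matter because __del__ can run on a partially constructed object: if __init__ raises before self.model or self.ctx is assigned (for example on a bad model_path), the previous code produced an AttributeError during garbage collection. A minimal reproduction of that failure mode, independent of llama.cpp:

    class Fragile:
        def __init__(self, path: str):
            if not path:
                raise ValueError("bad path")  # raised before self.handle exists
            self.handle = object()

        def __del__(self):
            # Without the hasattr check this raises AttributeError whenever
            # __init__ failed early and the half-built instance is collected.
            if hasattr(self, "handle") and self.handle is not None:
                self.handle = None

    try:
        Fragile("")
    except ValueError:
        pass  # the partially built object is collected without noise from __del__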
@@ -1638,20 +1690,20 @@ def tokenizer(self) -> "LlamaTokenizer":
        assert self.ctx is not None
        return LlamaTokenizer(self)

-    @staticmethod
-    def token_eos() -> int:
+    def token_eos(self) -> int:
        """Return the end-of-sequence token."""
-        return llama_cpp.llama_token_eos()
+        assert self.ctx is not None
+        return llama_cpp.llama_token_eos(self.ctx)

-    @staticmethod
-    def token_bos() -> int:
+    def token_bos(self) -> int:
        """Return the beginning-of-sequence token."""
-        return llama_cpp.llama_token_bos()
+        assert self.ctx is not None
+        return llama_cpp.llama_token_bos(self.ctx)

-    @staticmethod
-    def token_nl() -> int:
+    def token_nl(self) -> int:
        """Return the newline token."""
-        return llama_cpp.llama_token_nl()
+        assert self.ctx is not None
+        return llama_cpp.llama_token_nl(self.ctx)

    @staticmethod
    def logits_to_logprobs(logits: List[float]) -> List[float]:
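Note: because the token ids are now read from a live llama_context, these getters become instance methods, which is why the constructor above switched from Llama.token_nl() and Llama.token_eos() to self.token_nl() and self.token_eos(). External call sites change the same way (hedged example; llm stands for any loaded Llama instance):

    eos = llm.token_eos()  # was: Llama.token_eos()
    bos = llm.token_bos()  # was: Llama.token_bos()
    nl = llm.token_nl()    # was: Llama.token_nl()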