Skip to content

Commit 80184a2

Browse files
committed
Update llama.cpp
1 parent 755f9fa commit 80184a2

File tree

2 files changed

+215
-15
lines changed

2 files changed

+215
-15
lines changed

llama_cpp/llama_cpp.py

Lines changed: 214 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ def _load_shared_library(lib_base_name):
6767
_lib = _load_shared_library(_lib_base_name)
6868

6969
# C types
70+
LLAMA_FILE_VERSION = ctypes.c_int(1)
71+
LLAMA_FILE_MAGIC = b"ggjt"
72+
LLAMA_FILE_MAGIC_UNVERSIONED = b"ggml"
73+
LLAMA_SESSION_MAGIC = b"ggsn"
74+
LLAMA_SESSION_VERSION = ctypes.c_int(0)
75+
7076
llama_context_p = c_void_p
7177

7278

@@ -77,13 +83,24 @@ def _load_shared_library(lib_base_name):
7783
class llama_token_data(Structure):
7884
_fields_ = [
7985
("id", llama_token), # token id
86+
("logit", c_float), # log-odds of the token
8087
("p", c_float), # probability of the token
81-
("plog", c_float), # log probability of the token
8288
]
8389

8490

8591
llama_token_data_p = POINTER(llama_token_data)
8692

93+
94+
class llama_token_data_array(Structure):
95+
_fields_ = [
96+
("data", llama_token_data_p),
97+
("size", c_size_t),
98+
("sorted", c_bool),
99+
]
100+
101+
102+
llama_token_data_array_p = POINTER(llama_token_data_array)
103+
87104
llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p)
88105

89106

@@ -118,7 +135,7 @@ class llama_context_params(Structure):
118135
4
119136
) # tok_embeddings.weight and output.weight are F16
120137
LLAMA_FTYPE_MOSTLY_Q4_2 = ctypes.c_int(5) # except 1d tensors
121-
LLAMA_FTYPE_MOSTYL_Q4_3 = ctypes.c_int(6) # except 1d tensors
138+
# LLAMA_FTYPE_MOSTYL_Q4_3 = ctypes.c_int(6) # except 1d tensors
122139
LLAMA_FTYPE_MOSTYL_Q8_0 = ctypes.c_int(7) # except 1d tensors
123140
LLAMA_FTYPE_MOSTYL_Q5_0 = ctypes.c_int(8) # except 1d tensors
124141
LLAMA_FTYPE_MOSTYL_Q5_1 = ctypes.c_int(9) # except 1d tensors
@@ -401,31 +418,214 @@ def llama_token_eos() -> llama_token:
401418
_lib.llama_token_eos.restype = llama_token
402419

403420

404-
# TODO: improve the last_n_tokens interface ?
405-
def llama_sample_top_p_top_k(
421+
def llama_token_nl() -> llama_token:
422+
return _lib.llama_token_nl()
423+
424+
425+
_lib.llama_token_nl.argtypes = []
426+
_lib.llama_token_nl.restype = llama_token
427+
428+
429+
# Sampling functions
430+
def llama_sample_repetition_penalty(
431+
ctx: llama_context_p,
432+
candidates,
433+
last_tokens_data,
434+
last_tokens_size: c_int,
435+
penalty: c_float,
436+
) -> llama_token:
437+
return _lib.llama_sample_repetition_penalty(
438+
ctx, candidates, last_tokens_data, last_tokens_size, penalty
439+
)
440+
441+
442+
_lib.llama_sample_repetition_penalty.argtypes = [
443+
llama_context_p,
444+
llama_token_data_array_p,
445+
llama_token_p,
446+
c_int,
447+
c_float,
448+
]
449+
_lib.llama_sample_repetition_penalty.restype = llama_token
450+
451+
452+
# LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
453+
def llama_sample_frequency_and_presence_penalties(
406454
ctx: llama_context_p,
407-
last_n_tokens_data, # type: Array[llama_token]
408-
last_n_tokens_size: c_int,
409-
top_k: c_int,
410-
top_p: c_float,
411-
temp: c_float,
412-
repeat_penalty: c_float,
455+
candidates,
456+
last_tokens_data,
457+
last_tokens_size: c_int,
458+
alpha_frequency: c_float,
459+
alpha_presence: c_float,
413460
) -> llama_token:
414-
return _lib.llama_sample_top_p_top_k(
415-
ctx, last_n_tokens_data, last_n_tokens_size, top_k, top_p, temp, repeat_penalty
461+
return _lib.llama_sample_frequency_and_presence_penalties(
462+
ctx,
463+
candidates,
464+
last_tokens_data,
465+
last_tokens_size,
466+
alpha_frequency,
467+
alpha_presence,
416468
)
417469

418470

419-
_lib.llama_sample_top_p_top_k.argtypes = [
471+
_lib.llama_sample_frequency_and_presence_penalties.argtypes = [
420472
llama_context_p,
473+
llama_token_data_array_p,
421474
llama_token_p,
422475
c_int,
476+
c_float,
477+
c_float,
478+
]
479+
_lib.llama_sample_frequency_and_presence_penalties.restype = llama_token
480+
481+
482+
# LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
483+
def llama_sample_softmax(ctx: llama_context_p, candidates) -> llama_token:
484+
return _lib.llama_sample_softmax(ctx, candidates)
485+
486+
487+
_lib.llama_sample_softmax.argtypes = [
488+
llama_context_p,
489+
llama_token_data_array_p,
490+
]
491+
_lib.llama_sample_softmax.restype = llama_token
492+
493+
494+
# LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
495+
def llama_sample_top_k(
496+
ctx: llama_context_p, candidates, k: c_int, min_keep: c_int
497+
) -> llama_token:
498+
return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
499+
500+
501+
_lib.llama_sample_top_k.argtypes = [
502+
llama_context_p,
503+
llama_token_data_array_p,
504+
c_int,
505+
c_int,
506+
]
507+
508+
509+
# LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
510+
def llama_sample_top_p(
511+
ctx: llama_context_p, candidates, p: c_float, min_keep: c_int
512+
) -> llama_token:
513+
return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
514+
515+
516+
_lib.llama_sample_top_p.argtypes = [
517+
llama_context_p,
518+
llama_token_data_array_p,
519+
c_float,
520+
c_int,
521+
]
522+
_lib.llama_sample_top_p.restype = llama_token
523+
524+
525+
# LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
526+
def llama_sample_tail_free(
527+
ctx: llama_context_p, candidates, z: c_float, min_keep: c_int
528+
) -> llama_token:
529+
return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
530+
531+
532+
_lib.llama_sample_tail_free.argtypes = [
533+
llama_context_p,
534+
llama_token_data_array_p,
535+
c_float,
536+
c_int,
537+
]
538+
_lib.llama_sample_tail_free.restype = llama_token
539+
540+
541+
# LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
542+
def llama_sample_typical(
543+
ctx: llama_context_p, candidates, p: c_float, min_keep: c_int
544+
) -> llama_token:
545+
return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
546+
547+
548+
_lib.llama_sample_typical.argtypes = [
549+
llama_context_p,
550+
llama_token_data_array_p,
551+
c_float,
423552
c_int,
553+
]
554+
_lib.llama_sample_typical.restype = llama_token
555+
556+
557+
# LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
558+
def llama_sample_temperature(
559+
ctx: llama_context_p, candidates, temp: c_float
560+
) -> llama_token:
561+
return _lib.llama_sample_temperature(ctx, candidates, temp)
562+
563+
564+
_lib.llama_sample_temperature.argtypes = [
565+
llama_context_p,
566+
llama_token_data_array_p,
424567
c_float,
568+
]
569+
_lib.llama_sample_temperature.restype = llama_token
570+
571+
572+
# LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
573+
def llama_sample_token_mirostat(
574+
ctx: llama_context_p, candidates, tau: c_float, eta: c_float, m: c_int, mu
575+
) -> llama_token:
576+
return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
577+
578+
579+
_lib.llama_sample_token_mirostat.argtypes = [
580+
llama_context_p,
581+
llama_token_data_array_p,
582+
c_float,
583+
c_float,
584+
c_int,
585+
POINTER(c_float),
586+
]
587+
_lib.llama_sample_token_mirostat.restype = llama_token
588+
589+
590+
# LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
591+
def llama_sample_token_mirostat_v2(
592+
ctx: llama_context_p, candidates, tau: c_float, eta: c_float, mu
593+
) -> llama_token:
594+
return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
595+
596+
597+
_lib.llama_sample_token_mirostat_v2.argtypes = [
598+
llama_context_p,
599+
llama_token_data_array_p,
425600
c_float,
426601
c_float,
602+
POINTER(c_float),
603+
]
604+
_lib.llama_sample_token_mirostat_v2.restype = llama_token
605+
606+
607+
# LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
608+
def llama_sample_token_greedy(ctx: llama_context_p, candidates) -> llama_token:
609+
return _lib.llama_sample_token_greedy(ctx, candidates)
610+
611+
612+
_lib.llama_sample_token_greedy.argtypes = [
613+
llama_context_p,
614+
llama_token_data_array_p,
615+
]
616+
_lib.llama_sample_token_greedy.restype = llama_token
617+
618+
619+
# LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
620+
def llama_sample_token(ctx: llama_context_p, candidates) -> llama_token:
621+
return _lib.llama_sample_token(ctx, candidates)
622+
623+
624+
_lib.llama_sample_token.argtypes = [
625+
llama_context_p,
626+
llama_token_data_array_p,
427627
]
428-
_lib.llama_sample_top_p_top_k.restype = llama_token
628+
_lib.llama_sample_token.restype = llama_token
429629

430630

431631
# Performance information

vendor/llama.cpp

0 commit comments

Comments
 (0)