Skip to content

Commit ba1f1b2

Browse files
committed
Add sampling chain
1 parent a757e2b commit ba1f1b2

File tree

3 files changed

+278
-54
lines changed

3 files changed

+278
-54
lines changed

llama_cpp/_internals.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import (
77
Dict,
88
List,
9+
Tuple,
910
Optional,
1011
Sequence,
1112
)
@@ -707,3 +708,136 @@ def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
707708
ctx_main.grammar_accept_token(self.grammar, id)
708709
self.prev.append(id)
709710

711+
712+
from typing import List, Callable, Optional, Union
713+
import ctypes
714+
import llama_cpp
715+
716+
class CustomSampler:
    """Expose a Python callable as a ``llama_cpp.llama_sampler`` structure.

    Only the ``apply`` slot of the sampler interface is backed by Python code;
    the ``name``, ``accept``, ``reset``, ``clone`` and ``free`` slots are built
    from ``0`` (presumably null function pointers -- confirm against the
    ``llama_cpp`` bindings).
    """

    def __init__(
        self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
    ):
        """Build the sampler interface and sampler structs around *apply_func*.

        Args:
            apply_func: Callable invoked from the ``apply`` callback. It is
                called with the ``cur_p`` argument exactly as the callback
                receives it (annotated below as a
                ``llama_token_data_array_p``).
        """
        self.apply_func = apply_func

        def apply_wrapper(
            sampler: llama_cpp.llama_sampler_p,
            cur_p: llama_cpp.llama_token_data_array_p,
        ):
            # Resolve ``apply_func`` through ``self`` on every call so that
            # reassigning the attribute later is honoured.
            self.apply_func(cur_p)

        sampler_i = llama_cpp.llama_sampler_i()
        apply_callback = llama_cpp.llama_sampler_i_apply(apply_wrapper)
        sampler_i.apply = apply_callback
        # Hold both the Python function and the ctypes callback object for the
        # lifetime of this instance: native code may call through the function
        # pointer for as long as the sampler is in use.
        self._apply_wrapper_ref = apply_wrapper
        self._apply_callback_ref = apply_callback

        sampler_i.name = llama_cpp.llama_sampler_i_name(0)
        sampler_i.accept = llama_cpp.llama_sampler_i_accept(0)
        sampler_i.reset = llama_cpp.llama_sampler_i_reset(0)
        sampler_i.clone = llama_cpp.llama_sampler_i_clone(0)
        sampler_i.free = llama_cpp.llama_sampler_i_free(0)
        # The sampler struct below only stores a pointer to the interface
        # table, so keep the table itself referenced explicitly as well.
        self._sampler_i_ref = sampler_i

        self.sampler = llama_cpp.llama_sampler()
        self.sampler.iface = ctypes.pointer(sampler_i)
        self.sampler.ctx = None

    def get_sampler(self) -> llama_cpp.llama_sampler_p:
        """Return a new ctypes pointer to the wrapped ``llama_sampler`` struct."""
        return ctypes.pointer(self.sampler)
742+
743+
class LlamaSampler:
    """Own a llama.cpp sampler chain and the samplers appended to it.

    Each ``add_*`` method creates one sampler through the matching
    ``llama_cpp.llama_sampler_init_*`` function and appends it to the chain.
    Samplers backed by Python callables are tracked separately so that
    ``close()`` can detach them from the chain before the chain is freed.
    """

    def __init__(self):
        """Create an empty sampler chain with default chain parameters."""
        # Assign every attribute before calling into native code so that
        # close() / __del__ still work if chain creation raises.
        self.sampler = None
        self.samplers: List[llama_cpp.llama_sampler_p] = []
        # (index in chain, wrapper) pairs; the wrapper is kept so the Python
        # objects behind the native sampler stay referenced.
        self.custom_samplers: List[Tuple[int, CustomSampler]] = []

        params = llama_cpp.llama_sampler_chain_params()
        self.sampler = llama_cpp.llama_sampler_chain_init(params)

    def add_greedy(self):
        """Append the sampler returned by ``llama_sampler_init_greedy``."""
        sampler = llama_cpp.llama_sampler_init_greedy()
        self._add_sampler(sampler)

    def add_dist(self, seed: int):
        """Append the sampler returned by ``llama_sampler_init_dist(seed)``."""
        sampler = llama_cpp.llama_sampler_init_dist(seed)
        self._add_sampler(sampler)

    def add_softmax(self):
        """Append the sampler returned by ``llama_sampler_init_softmax``."""
        sampler = llama_cpp.llama_sampler_init_softmax()
        self._add_sampler(sampler)

    def add_top_k(self, k: int):
        """Append the sampler returned by ``llama_sampler_init_top_k(k)``."""
        sampler = llama_cpp.llama_sampler_init_top_k(k)
        self._add_sampler(sampler)

    def add_top_p(self, p: float, min_keep: int):
        """Append the sampler returned by ``llama_sampler_init_top_p``."""
        sampler = llama_cpp.llama_sampler_init_top_p(p, min_keep)
        self._add_sampler(sampler)

    def add_min_p(self, p: float, min_keep: int):
        """Append the sampler returned by ``llama_sampler_init_min_p``."""
        sampler = llama_cpp.llama_sampler_init_min_p(p, min_keep)
        self._add_sampler(sampler)

    def add_tail_free(self, z: float, min_keep: int):
        """Append the sampler returned by ``llama_sampler_init_tail_free``."""
        sampler = llama_cpp.llama_sampler_init_tail_free(z, min_keep)
        self._add_sampler(sampler)

    def add_typical(self, p: float, min_keep: int):
        """Append the sampler returned by ``llama_sampler_init_typical``."""
        sampler = llama_cpp.llama_sampler_init_typical(p, min_keep)
        self._add_sampler(sampler)

    def add_temp(self, temp: float):
        """Append the sampler returned by ``llama_sampler_init_temp(temp)``."""
        sampler = llama_cpp.llama_sampler_init_temp(temp)
        self._add_sampler(sampler)

    def add_temp_ext(self, t: float, delta: float, exponent: float):
        """Append the sampler returned by ``llama_sampler_init_temp_ext``."""
        sampler = llama_cpp.llama_sampler_init_temp_ext(t, delta, exponent)
        self._add_sampler(sampler)

    def add_mirostat(self, n_vocab: int, seed: int, tau: float, eta: float, m: int):
        """Append the sampler returned by ``llama_sampler_init_mirostat``."""
        sampler = llama_cpp.llama_sampler_init_mirostat(
            n_vocab, seed, tau, eta, m
        )
        self._add_sampler(sampler)

    def add_mirostat_v2(self, seed: int, tau: float, eta: float):
        """Append the sampler returned by ``llama_sampler_init_mirostat_v2``."""
        sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
        self._add_sampler(sampler)

    def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
        """Append a grammar sampler built from *grammar* for *model*.

        The grammar text (``grammar._grammar``) and root rule name
        (``grammar._root``) are passed to llama.cpp as UTF-8 bytes.
        """
        sampler = llama_cpp.llama_sampler_init_grammar(
            model.model,
            grammar._grammar.encode("utf-8"),
            grammar._root.encode("utf-8"),
        )
        self._add_sampler(sampler)

    def add_penalties(
        self,
        n_vocab: int,
        special_eos_id: int,
        linefeed_id: int,
        penalty_last_n: int,
        penalty_repeat: float,
        penalty_freq: float,
        penalty_present: float,
        penalize_nl: bool,
        ignore_eos: bool,
    ):
        """Append the sampler returned by ``llama_sampler_init_penalties``.

        All arguments are forwarded positionally, in the order declared here.
        """
        sampler = llama_cpp.llama_sampler_init_penalties(
            n_vocab,
            special_eos_id,
            linefeed_id,
            penalty_last_n,
            penalty_repeat,
            penalty_freq,
            penalty_present,
            penalize_nl,
            ignore_eos,
        )
        self._add_sampler(sampler)

    def init_logit_bias(
        self, n_vocab: int, n_logit_bias: int, logit_bias: llama_cpp.llama_logit_bias_p
    ):
        """Append the sampler returned by ``llama_sampler_init_logit_bias``.

        Also available as ``add_logit_bias``, which matches the naming of the
        other ``add_*`` methods.
        """
        sampler = llama_cpp.llama_sampler_init_logit_bias(
            n_vocab, n_logit_bias, logit_bias
        )
        self._add_sampler(sampler)

    # Alias following the ``add_*`` convention; the original name is kept so
    # existing callers continue to work.
    add_logit_bias = init_logit_bias

    def add_custom(
        self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
    ):
        """Append a sampler whose ``apply`` step is the Python *apply_func*.

        The wrapper object and its position in the chain are recorded so that
        ``close()`` can remove it from the chain before freeing the chain.
        """
        custom_sampler = CustomSampler(apply_func)
        sampler = custom_sampler.get_sampler()
        self._add_sampler(sampler)
        # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
        # The sampler was just appended, so its index is the chain length - 1.
        self.custom_samplers.append(
            (llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler)
        )

    def _add_sampler(self, sampler: llama_cpp.llama_sampler_p):
        """Append *sampler* to the native chain and remember the pointer."""
        assert self.sampler is not None
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
        self.samplers.append(sampler)

    def get_seed(self) -> int:
        """Return ``llama_sampler_get_seed`` for the chain."""
        assert self.sampler is not None
        return llama_cpp.llama_sampler_get_seed(self.sampler)

    def sample(self, ctx: LlamaContext, idx: int) -> int:
        """Return ``llama_sampler_sample`` for the chain, ``ctx.ctx`` and *idx*."""
        assert self.sampler is not None
        return llama_cpp.llama_sampler_sample(self.sampler, ctx.ctx, idx)

    def close(self):
        """Free the native chain, if any, and drop all tracked samplers.

        Safe to call more than once: after the first call ``self.sampler`` is
        ``None`` and only the (already empty) bookkeeping lists are cleared.
        """
        if self.sampler:
            # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
            # Indices were recorded in ascending order, so removing from the
            # highest index down keeps the remaining indices valid.
            for i, _ in reversed(self.custom_samplers):
                llama_cpp.llama_sampler_chain_remove(self.sampler, i)
            llama_cpp.llama_sampler_free(self.sampler)
            self.sampler = None
        self.samplers.clear()
        self.custom_samplers.clear()

    def __del__(self):
        """Release the native chain when the wrapper is garbage collected."""
        self.close()

0 commit comments

Comments
 (0)