|
@@ -6,6 +6,7 @@
 from typing import (
     Dict,
     List,
+    Tuple,
     Optional,
     Sequence,
 )
|
@@ -707,3 +708,136 @@ def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
             ctx_main.grammar_accept_token(self.grammar, id)
         self.prev.append(id)

+
+from typing import List, Callable, Optional, Union
+import ctypes
+import llama_cpp
+
+class CustomSampler:
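+    """Wrap a Python sampling function as a llama.cpp sampler by implementing the apply callback via ctypes."""
+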
+    def __init__(self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]):
+        self.apply_func = apply_func
+
+        def apply_wrapper(sampler: llama_cpp.llama_sampler_p, cur_p: llama_cpp.llama_token_data_array_p):
+            self.apply_func(cur_p)
+
+        def free_wrapper(sampler: llama_cpp.llama_sampler_p):
+            pass
+
+        sampler_i = llama_cpp.llama_sampler_i()
+        sampler_i.apply = llama_cpp.llama_sampler_i_apply(apply_wrapper)
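+        # Keep a reference to the ctypes callback so it is not garbage collected
+        # while llama.cpp still holds the function pointer.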
+        self._apply_wrapper_ref = apply_wrapper
+
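+        # The remaining interface callbacks are left unset (0 becomes a NULL function pointer)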
+        sampler_i.name = llama_cpp.llama_sampler_i_name(0)
+        sampler_i.accept = llama_cpp.llama_sampler_i_accept(0)
+        sampler_i.reset = llama_cpp.llama_sampler_i_reset(0)
+        sampler_i.clone = llama_cpp.llama_sampler_i_clone(0)
+        sampler_i.free = llama_cpp.llama_sampler_i_free(0)
+
+        self.sampler = llama_cpp.llama_sampler()
+        self.sampler.iface = ctypes.pointer(sampler_i)
+        self.sampler.ctx = None
+
+    def get_sampler(self) -> llama_cpp.llama_sampler_p:
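+        # The returned pointer borrows self.sampler; this CustomSampler must stay alive
+        # for as long as the chain uses it.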
+        return ctypes.pointer(self.sampler)
+
+class LlamaSampler:
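+    """Thin wrapper around a llama.cpp sampler chain (llama_sampler_chain_init / llama_sampler_chain_add)."""
+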
+    def __init__(self):
+        params = llama_cpp.llama_sampler_chain_params()
+        self.sampler = llama_cpp.llama_sampler_chain_init(params)
+        self.samplers: List[llama_cpp.llama_sampler_p] = []
+        self.custom_samplers: List[Tuple[int, CustomSampler]] = []
+
+    def add_greedy(self):
+        sampler = llama_cpp.llama_sampler_init_greedy()
+        self._add_sampler(sampler)
+
+    def add_dist(self, seed: int):
+        sampler = llama_cpp.llama_sampler_init_dist(seed)
+        self._add_sampler(sampler)
+
+    def add_softmax(self):
+        sampler = llama_cpp.llama_sampler_init_softmax()
+        self._add_sampler(sampler)
+
+    def add_top_k(self, k: int):
+        sampler = llama_cpp.llama_sampler_init_top_k(k)
+        self._add_sampler(sampler)
+
+    def add_top_p(self, p: float, min_keep: int):
+        sampler = llama_cpp.llama_sampler_init_top_p(p, min_keep)
+        self._add_sampler(sampler)
+
+    def add_min_p(self, p: float, min_keep: int):
+        sampler = llama_cpp.llama_sampler_init_min_p(p, min_keep)
+        self._add_sampler(sampler)
+
+    def add_tail_free(self, z: float, min_keep: int):
+        sampler = llama_cpp.llama_sampler_init_tail_free(z, min_keep)
+        self._add_sampler(sampler)
+
+    def add_typical(self, p: float, min_keep: int):
+        sampler = llama_cpp.llama_sampler_init_typical(p, min_keep)
+        self._add_sampler(sampler)
+
+    def add_temp(self, temp: float):
+        sampler = llama_cpp.llama_sampler_init_temp(temp)
+        self._add_sampler(sampler)
+
+    def add_temp_ext(self, t: float, delta: float, exponent: float):
+        sampler = llama_cpp.llama_sampler_init_temp_ext(t, delta, exponent)
+        self._add_sampler(sampler)
+
+    def add_mirostat(self, n_vocab: int, seed: int, tau: float, eta: float, m: int):
+        sampler = llama_cpp.llama_sampler_init_mirostat(
+            n_vocab, seed, tau, eta, m
+        )
+        self._add_sampler(sampler)
+
+    def add_mirostat_v2(self, seed: int, tau: float, eta: float):
+        sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
+        self._add_sampler(sampler)
+
+    def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
+        sampler = llama_cpp.llama_sampler_init_grammar(
+            model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
+        )
+        self._add_sampler(sampler)
+
+    def add_penalties(self, n_vocab: int, special_eos_id: int, linefeed_id: int,
+                      penalty_last_n: int, penalty_repeat: float, penalty_freq: float,
+                      penalty_present: float, penalize_nl: bool, ignore_eos: bool):
+        sampler = llama_cpp.llama_sampler_init_penalties(
+            n_vocab, special_eos_id, linefeed_id, penalty_last_n, penalty_repeat,
+            penalty_freq, penalty_present, penalize_nl, ignore_eos)
+        self._add_sampler(sampler)
+
+    def init_logit_bias(self, n_vocab: int, n_logit_bias: int, logit_bias: llama_cpp.llama_logit_bias_p):
+        sampler = llama_cpp.llama_sampler_init_logit_bias(n_vocab, n_logit_bias, logit_bias)
+        self._add_sampler(sampler)
+
+    def add_custom(self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]):
+        custom_sampler = CustomSampler(apply_func)
+        sampler = custom_sampler.get_sampler()
+        self._add_sampler(sampler)
+        # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
+        self.custom_samplers.append(
+            (llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler)
+        )
+
+    def _add_sampler(self, sampler: llama_cpp.llama_sampler_p):
+        assert self.sampler is not None
+        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
+        self.samplers.append(sampler)
+
+    def get_seed(self) -> int:
+        assert self.sampler is not None
+        return llama_cpp.llama_sampler_get_seed(self.sampler)
+
+    def sample(self, ctx: LlamaContext, idx: int) -> int:
+        assert self.sampler is not None
+        return llama_cpp.llama_sampler_sample(self.sampler, ctx.ctx, idx)
+
+    def close(self):
+        if self.sampler:
+            # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
+            for i, _ in reversed(self.custom_samplers):
+                llama_cpp.llama_sampler_chain_remove(self.sampler, i)
+            llama_cpp.llama_sampler_free(self.sampler)
+            self.sampler = None
+        self.samplers.clear()
+        self.custom_samplers.clear()
+
+    def __del__(self):
+        self.close()
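
For reference, a minimal usage sketch of the new sampler chain (parameter values are illustrative; it assumes a LlamaContext named ctx has already been created with the existing wrappers in this module):

    sampler = LlamaSampler()
    sampler.add_top_k(40)
    sampler.add_top_p(0.95, 1)
    sampler.add_temp(0.8)
    sampler.add_dist(seed=1234)      # terminal sampler: actually draws the token id
    token = sampler.sample(ctx, -1)  # idx -1 samples from the last evaluated token's logits
    sampler.close()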