
Commit 3d0a079

Merge branch 'abetlen:main' into main
Parents: e200577 + 4b11fa8 — commit 3d0a079

File tree: 10 files changed, +163 −24 lines


CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [0.2.29]
+
+- feat: Update llama.cpp to ggerganov/llama.cpp@4483396751c79dea540808b9cb9238245d06da2b
+- feat: Add split_mode option by @abetlen in 84615adbc6855c8384807c42f0130f9a1763f99d
+- feat: Implement GGUF metadata KV overrides by @phiharri in #1011
+- fix: Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor by @yieldthought in #1012
+- fix: Fix low_level_api_chat_cpp example to match current API by @aniljava in #1086
+- fix: Fix Pydantic model parsing by @DeNeutoy in #1087
+
 ## [0.2.28]
 
 - feat: Update llama.cpp to ggerganov/llama.cpp@6efb8eb30e7025b168f3fda3ff83b9b386428ad6

llama_cpp/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from .llama_cpp import *
 from .llama import *
 
-__version__ = "0.2.28"
+__version__ = "0.2.29"

llama_cpp/_utils.py

Lines changed: 9 additions & 11 deletions
@@ -1,11 +1,15 @@
 import os
 import sys
 
+import sys, traceback
+
+# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
+outnull_file = open(os.devnull, "w")
+errnull_file = open(os.devnull, "w")
 
 class suppress_stdout_stderr(object):
     # NOTE: these must be "saved" here to avoid exceptions when using
     # this context manager inside of a __del__ method
-    open = open
     sys = sys
     os = os
 
@@ -21,9 +25,6 @@ def __enter__(self):
         if not hasattr(self.sys.stdout, 'fileno') or not hasattr(self.sys.stderr, 'fileno'):
             return self  # Return the instance without making changes
 
-        self.outnull_file = self.open(self.os.devnull, "w")
-        self.errnull_file = self.open(self.os.devnull, "w")
-
         self.old_stdout_fileno_undup = self.sys.stdout.fileno()
         self.old_stderr_fileno_undup = self.sys.stderr.fileno()
 
@@ -33,11 +34,11 @@ def __enter__(self):
         self.old_stdout = self.sys.stdout
         self.old_stderr = self.sys.stderr
 
-        self.os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
-        self.os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
+        self.os.dup2(outnull_file.fileno(), self.old_stdout_fileno_undup)
+        self.os.dup2(errnull_file.fileno(), self.old_stderr_fileno_undup)
 
-        self.sys.stdout = self.outnull_file
-        self.sys.stderr = self.errnull_file
+        self.sys.stdout = outnull_file
+        self.sys.stderr = errnull_file
         return self
 
     def __exit__(self, *_):
@@ -54,6 +55,3 @@ def __exit__(self, *_):
 
         self.os.close(self.old_stdout_fileno)
         self.os.close(self.old_stderr_fileno)
-
-        self.outnull_file.close()
-        self.errnull_file.close()
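
The point of this change is that the two os.devnull handles are now opened once at import time, so the context manager never calls open() from inside a __del__ that runs during interpreter shutdown. A minimal sketch of the pattern this protects (the wrapper class below is illustrative, not part of the library):

import llama_cpp._utils as _utils

class NoisyNativeObject:
    """Hypothetical object whose cleanup would otherwise print to stderr."""

    def __del__(self):
        # Safe even at interpreter shutdown: the devnull files were opened
        # when llama_cpp._utils was imported, so no open() call happens here.
        with _utils.suppress_stdout_stderr():
            pass  # release native resources that write to stdout/stderr

obj = NoisyNativeObject()
del obj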

llama_cpp/llama.py

Lines changed: 38 additions & 1 deletion
@@ -730,11 +730,13 @@ def __init__(
         *,
         # Model Params
         n_gpu_layers: int = 0,
+        split_mode: int = llama_cpp.LLAMA_SPLIT_LAYER,
         main_gpu: int = 0,
         tensor_split: Optional[List[float]] = None,
         vocab_only: bool = False,
         use_mmap: bool = True,
         use_mlock: bool = False,
+        kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None,
         # Context Params
         seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
         n_ctx: int = 512,
@@ -798,11 +800,13 @@ def __init__(
        Args:
            model_path: Path to the model.
            n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
-           main_gpu: The GPU that is used for scratch and small tensors.
+           split_mode: How to split the model across GPUs. See llama_cpp.LLAMA_SPLIT_* for options.
+           main_gpu: main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model. LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results. LLAMA_SPLIT_LAYER: ignored
            tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
            vocab_only: Only load the vocabulary no weights.
            use_mmap: Use mmap if possible.
            use_mlock: Force the system to keep the model in RAM.
+           kv_overrides: Key-value overrides for the model.
            seed: RNG seed, -1 for random
            n_ctx: Text context, 0 = from model
            n_batch: Prompt processing maximum batch size
@@ -848,6 +852,7 @@ def __init__(
        self.model_params.n_gpu_layers = (
            0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
        )  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
+       self.model_params.split_mode = split_mode
        self.model_params.main_gpu = main_gpu
        self.tensor_split = tensor_split
        self._c_tensor_split = None
@@ -866,6 +871,34 @@ def __init__(
        self.model_params.use_mmap = use_mmap if lora_path is None else False
        self.model_params.use_mlock = use_mlock
 
+       self.kv_overrides = kv_overrides
+       if kv_overrides is not None:
+           n_overrides = len(kv_overrides)
+           self._kv_overrides_array = (llama_cpp.llama_model_kv_override * (n_overrides + 1))()
+           self._kv_overrides_array_keys = []
+
+           for i, (k, v) in enumerate(kv_overrides.items()):
+               key_buf = ctypes.create_string_buffer(k.encode("utf-8"))
+               self._kv_overrides_array_keys.append(key_buf)
+               self._kv_overrides_array[i].key = key_buf
+               if isinstance(v, int):
+                   self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT
+                   self._kv_overrides_array[i].value.int_value = v
+               elif isinstance(v, float):
+                   self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_FLOAT
+                   self._kv_overrides_array[i].value.float_value = v
+               elif isinstance(v, bool):
+                   self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_BOOL
+                   self._kv_overrides_array[i].value.bool_value = v
+               else:
+                   raise ValueError(f"Unknown value type for {k}: {v}")
+
+           self._kv_overrides_array_sentinel_key = b'\0'
+
+           # null array sentinel
+           self._kv_overrides_array[n_overrides].key = self._kv_overrides_array_sentinel_key
+           self.model_params.kv_overrides = self._kv_overrides_array
+
        self.n_batch = min(n_ctx, n_batch)  # ???
        self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
        self.n_threads_batch = n_threads_batch or max(
@@ -2143,11 +2176,13 @@ def __getstate__(self):
            model_path=self.model_path,
            # Model Params
            n_gpu_layers=self.model_params.n_gpu_layers,
+           split_mode=self.model_params.split_mode,
            main_gpu=self.model_params.main_gpu,
            tensor_split=self.tensor_split,
            vocab_only=self.model_params.vocab_only,
            use_mmap=self.model_params.use_mmap,
            use_mlock=self.model_params.use_mlock,
+           kv_overrides=self.kv_overrides,
            # Context Params
            seed=self.context_params.seed,
            n_ctx=self.context_params.n_ctx,
@@ -2185,11 +2220,13 @@ def __setstate__(self, state):
            model_path=state["model_path"],
            # Model Params
            n_gpu_layers=state["n_gpu_layers"],
+           split_mode=state["split_mode"],
            main_gpu=state["main_gpu"],
            tensor_split=state["tensor_split"],
            vocab_only=state["vocab_only"],
            use_mmap=state["use_mmap"],
            use_mlock=state["use_mlock"],
+           kv_overrides=state["kv_overrides"],
            # Context Params
            seed=state["seed"],
            n_ctx=state["n_ctx"],
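
Taken together, the new constructor arguments are used like this; a minimal usage sketch, assuming a local GGUF model at a placeholder path and purely illustrative override keys:

import llama_cpp
from llama_cpp import Llama

llm = Llama(
    model_path="./models/example-7b.Q4_K_M.gguf",  # placeholder path
    n_gpu_layers=-1,                               # offload all layers
    split_mode=llama_cpp.LLAMA_SPLIT_LAYER,        # new in 0.2.29: how to split across GPUs
    main_gpu=0,                                    # ignored for LLAMA_SPLIT_LAYER
    kv_overrides={                                 # new in 0.2.29: GGUF metadata overrides
        "llama.context_length": 4096,              # int override (illustrative key)
        "tokenizer.ggml.add_bos_token": True,      # bool override (illustrative key)
    },
)

Because __getstate__ and __setstate__ now carry split_mode and kv_overrides, a pickled Llama instance round-trips these settings as well.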

llama_cpp/llama_cpp.py

Lines changed: 39 additions & 8 deletions
@@ -229,6 +229,7 @@ def _load_shared_library(lib_base_name: str):
 LLAMA_SPLIT_LAYER = 1
 LLAMA_SPLIT_ROW = 2
 
+
 # typedef struct llama_token_data {
 #     llama_token id; // token id
 #     float logit;    // log-odds of the token
@@ -395,6 +396,7 @@ class llama_model_kv_override(Structure):
 #     // override key-value pairs of the model meta data
 #     const struct llama_model_kv_override * kv_overrides;
 
+
 #     // Keep the booleans together to avoid misalignment during copy-by-value.
 #     bool vocab_only; // only load the vocabulary, no weights
 #     bool use_mmap;   // use mmap if possible
@@ -407,7 +409,7 @@ class llama_model_params(Structure):
         n_gpu_layers (int): number of layers to store in VRAM
         split_mode (int): how to split the model across multiple GPUs
         main_gpu (int): the GPU that is used for the entire model. main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results LLAMA_SPLIT_LAYER: ignored
-        tensor_split (ctypes.Array[ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES 
+        tensor_split (ctypes.Array[ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES
         progress_callback (llama_progress_callback): called with a progress value between 0.0 and 1.0. Pass NULL to disable. If the provided progress_callback returns true, model loading continues. If it returns false, model loading is immediately aborted.
         progress_callback_user_data (ctypes.c_void_p): context pointer passed to the progress callback
         kv_overrides (ctypes.Array[llama_model_kv_override]): override key-value pairs of the model meta data
@@ -526,6 +528,7 @@ class llama_context_params(Structure):
 #     bool quantize_output_tensor; // quantize output.weight
 #     bool only_copy;              // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
 #     bool pure;                   // disable k-quant mixtures and quantize all tensors to the same type
+#     void * imatrix;              // pointer to importance matrix data
 # } llama_model_quantize_params;
 class llama_model_quantize_params(Structure):
     """Parameters for llama_model_quantize
@@ -537,6 +540,7 @@ class llama_model_quantize_params(Structure):
         quantize_output_tensor (bool): quantize output.weight
         only_copy (bool): only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
         pure (bool): disable k-quant mixtures and quantize all tensors to the same type
+        imatrix (ctypes.c_void_p): pointer to importance matrix data
     """
 
     _fields_ = [
@@ -545,6 +549,8 @@ class llama_model_quantize_params(Structure):
         ("allow_requantize", c_bool),
         ("quantize_output_tensor", c_bool),
         ("only_copy", c_bool),
+        ("pure", c_bool),
+        ("imatrix", c_void_p),
     ]
 
 
@@ -1956,14 +1962,39 @@ def llama_sample_repetition_penalties(
 
 
 # /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
-# /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
-# /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
-# /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
-# LLAMA_API void llama_sample_classifier_free_guidance(
-#           struct llama_context * ctx,
+# /// @param logits Logits extracted from the original generation context.
+# /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
+# /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
+# LLAMA_API void llama_sample_apply_guidance(
+#           struct llama_context * ctx,
+#                  float * logits,
+#                  float * logits_guidance,
+#                  float   scale);
+def llama_sample_apply_guidance(
+    ctx: llama_context_p,
+    logits,  # type: _Pointer[c_float]
+    logits_guidance,  # type: _Pointer[c_float]
+    scale: Union[c_float, float],
+):
+    """Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
+    return _lib.llama_sample_apply_guidance(ctx, logits, logits_guidance, scale)
+
+
+_lib.llama_sample_apply_guidance.argtypes = [
+    llama_context_p,
+    c_float_p,
+    c_float_p,
+    c_float,
+]
+_lib.llama_sample_apply_guidance.restype = None
+
+
+# LLAMA_API DEPRECATED(void llama_sample_classifier_free_guidance(
+#           struct llama_context * ctx,
 #         llama_token_data_array * candidates,
-#   struct llama_context * guidance_ctx,
-#                    float   scale);
+#   struct llama_context * guidance_ctx,
+#                    float   scale),
+#           "use llama_sample_apply_guidance() instead");
 def llama_sample_classifier_free_guidance(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
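
For orientation, a hedged sketch of driving the new low-level binding; it assumes ctx and guidance_ctx are already-evaluated llama_context_p handles from the same model and only shows the pointer plumbing:

import ctypes
import llama_cpp

def apply_cfg(ctx, guidance_ctx, scale: float = 1.5):
    """Blend guidance logits into the main context's logits in place."""
    # Both calls return float* arrays of length n_vocab, valid after decode/eval.
    logits = llama_cpp.llama_get_logits(ctx)
    logits_guidance = llama_cpp.llama_get_logits(guidance_ctx)
    llama_cpp.llama_sample_apply_guidance(
        ctx,
        logits,
        logits_guidance,
        ctypes.c_float(scale),  # 1.0 means no guidance; higher is stronger
    )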

llama_cpp/llama_grammar.py

Lines changed: 0 additions & 1 deletion
@@ -1433,7 +1433,6 @@ def _add_rule(self, name: str, rule: str):
 
     def visit(self, schema: Dict[str, Any], name: str) -> str:
         schema_type: Optional[str] = schema.get("type")  # type: ignore
-        assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
         rule_name = name or "root"
 
         if "$defs" in schema:

llama_cpp/server/model.py

Lines changed: 19 additions & 1 deletion
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Optional, Union, List
+from typing import Dict, Optional, Union, List
 
 import llama_cpp
 
@@ -71,6 +71,23 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
         chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
             clip_model_path=settings.clip_model_path, verbose=settings.verbose
         )
+
+    kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
+    if settings.kv_overrides is not None:
+        assert isinstance(settings.kv_overrides, list)
+        kv_overrides = {}
+        for kv in settings.kv_overrides:
+            key, value = kv.split("=")
+            if ":" in value:
+                value_type, value = value.split(":")
+                if value_type == "bool":
+                    kv_overrides[key] = value.lower() in ["true", "1"]
+                elif value_type == "int":
+                    kv_overrides[key] = int(value)
+                elif value_type == "float":
+                    kv_overrides[key] = float(value)
+                else:
+                    raise ValueError(f"Unknown value type {value_type}")
 
     _model = llama_cpp.Llama(
         model_path=settings.model,
@@ -81,6 +98,7 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
         vocab_only=settings.vocab_only,
         use_mmap=settings.use_mmap,
         use_mlock=settings.use_mlock,
+        kv_overrides=kv_overrides,
        # Context Params
        seed=settings.seed,
        n_ctx=settings.n_ctx,
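
The parsing above turns each "key=type:value" string into a typed Python value before it reaches Llama(). A small self-contained sketch of the same logic with illustrative keys:

# Mirrors the server-side parsing: "key=type:value" -> {key: typed value}.
samples = [
    "tokenizer.ggml.add_bos_token=bool:false",  # illustrative key
    "llama.context_length=int:4096",            # illustrative key
]

kv_overrides = {}
for kv in samples:
    key, value = kv.split("=")
    value_type, value = value.split(":")
    if value_type == "bool":
        kv_overrides[key] = value.lower() in ["true", "1"]
    elif value_type == "int":
        kv_overrides[key] = int(value)
    elif value_type == "float":
        kv_overrides[key] = float(value)
    else:
        raise ValueError(f"Unknown value type {value_type}")

print(kv_overrides)
# {'tokenizer.ggml.add_bos_token': False, 'llama.context_length': 4096}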

llama_cpp/server/settings.py

Lines changed: 8 additions & 0 deletions
@@ -28,6 +28,10 @@ class ModelSettings(BaseSettings):
         ge=-1,
         description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.",
     )
+    split_mode: int = Field(
+        default=llama_cpp.LLAMA_SPLIT_LAYER,
+        description="The split mode to use.",
+    )
     main_gpu: int = Field(
         default=0,
         ge=0,
@@ -48,6 +52,10 @@ class ModelSettings(BaseSettings):
         default=llama_cpp.llama_mlock_supported(),
         description="Use mlock.",
     )
+    kv_overrides: Optional[List[str]] = Field(
+        default=None,
+        description="List of model kv overrides in the format key=type:value where type is one of (bool, int, float). Valid true values are (true, TRUE, 1), otherwise false.",
+    )
     # Context Params
     seed: int = Field(
         default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."
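
Because ModelSettings is a pydantic settings class, the new fields can also be set programmatically (or through the matching environment variables and CLI flags); a minimal sketch with a placeholder model path and an illustrative override:

import llama_cpp
from llama_cpp.server.settings import ModelSettings

settings = ModelSettings(
    model="./models/example-7b.Q4_K_M.gguf",         # placeholder path
    split_mode=llama_cpp.LLAMA_SPLIT_LAYER,
    kv_overrides=["llama.context_length=int:4096"],  # illustrative override
)
print(settings.split_mode, settings.kv_overrides)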

tests/test_grammar.py

Lines changed: 39 additions & 0 deletions
@@ -1,13 +1,52 @@
 import llama_cpp
+import json
 
 tree = """
 leaf ::= "."
 node ::= leaf | "(" node node ")"
 root ::= node
 """
 
+
 def test_grammar_from_string():
     grammar = llama_cpp.LlamaGrammar.from_string(tree)
     assert grammar._n_rules == 3
     assert grammar._start_rule_index == 2
     assert grammar.grammar is not None
+
+
+def test_composed_pydantic_grammar():
+    """
+    from pydantic import BaseModel
+
+    class A(BaseModel):
+        a: int
+
+    class B(BaseModel):
+        a: A
+        b: int
+    """
+
+    # This schema corresponds to the grammar in the comment above.
+    # We don't use the pydantic models directly to avoid the dependency.
+    schema = {
+        "$defs": {
+            "A": {
+                "properties": {"a": {"title": "A", "type": "integer"}},
+                "required": ["a"],
+                "title": "A",
+                "type": "object",
+            }
+        },
+        "properties": {
+            "a": {"$ref": "#/$defs/A"},
+            "b": {"title": "B", "type": "integer"},
+        },
+        "required": ["a", "b"],
+        "title": "B",
+        "type": "object",
+    }
+
+    grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))
+
+    assert grammar.grammar is not None
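
Beyond the assertion in the test, the composed grammar can constrain actual generation; a hedged sketch, assuming a local GGUF model at a placeholder path and the schema dict from the test above:

import json
import llama_cpp

grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))

llm = llama_cpp.Llama(model_path="./models/example-7b.Q4_K_M.gguf")  # placeholder
out = llm(
    "Produce a JSON object matching the schema: ",
    grammar=grammar,   # generation is restricted to the schema's grammar
    max_tokens=64,
)
print(out["choices"][0]["text"])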

vendor/llama.cpp (submodule updated; see the llama.cpp update noted in the CHANGELOG above)
