
Commit 9652372

performant reproduction with tuning
1 parent e83c965 commit 9652372

File tree

4 files changed: +31 -7 lines changed


examples/high_level_api/high_level_api_inference.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -5,13 +5,16 @@
 
 parser = argparse.ArgumentParser()
 parser.add_argument("-m", "--model", type=str, default="../models/7B/ggml-models.bin")
+parser.add_argument("-i", "--path_idx", type=str)
+parser.add_argument("-ngl", "--n_gpu_layers", type=int)
+
 args = parser.parse_args()
 
-llm = Llama(model_path=args.model)
+llm = Llama(model_path=args.model, n_gpu_layers=args.n_gpu_layers, path_idx=args.path_idx, n_ctx=128, n_batch=1)
 
 output = llm(
     "Question: What are the names of the planets in the solar system? Answer: ",
-    max_tokens=512,
+    max_tokens=128,
     stop=["Q:", "\n"],
     echo=True,
)
```
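Run standalone, the updated example is equivalent to the following minimal script. The model and index paths are placeholders and `n_gpu_layers=32` is an arbitrary choice; only the `path_idx` keyword is new in this commit:

```python
from llama_cpp import Llama

# Placeholder paths: point these at your converted model and the
# matching MLP/GPU index file (loaded via the new path_idx keyword).
llm = Llama(
    model_path="../models/7B/ggml-models.bin",
    n_gpu_layers=32,                       # assumption: any layer count works here
    path_idx="../models/7B/model.gpuidx",  # hypothetical index file name
    n_ctx=128,
    n_batch=1,
)

output = llm(
    "Question: What are the names of the planets in the solar system? Answer: ",
    max_tokens=128,
    stop=["Q:", "\n"],
    echo=True,
)
print(output["choices"][0]["text"])
```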

llama_cpp/llama.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -221,6 +221,7 @@ def __init__(
         path_model: str,
         params: llama_cpp.llama_model_params,
         verbose: bool = True,
+        path_idx: Optional[str] = None,
     ):
         self.path_model = path_model
         self.params = params
@@ -235,6 +236,11 @@ def __init__(
         self.model = llama_cpp.llama_load_model_from_file(
             self.path_model.encode("utf-8"), self.params
         )
+        if path_idx:
+            llama_cpp.llama_model_apply_mlp_from_file(
+                self.model, path_idx.encode("utf-8"), True
+            )
+            llama_cpp.llama_model_apply_augmentation(self.model)
 
     def __del__(self):
         with suppress_stdout_stderr(disable=self.verbose):
@@ -761,6 +767,8 @@ def __init__(
         chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
         # Misc
         verbose: bool = True,
+        # GPU index
+        path_idx: Optional[str] = None,
         # Extra Params
         **kwargs,  # type: ignore
     ):
@@ -887,7 +895,8 @@
             raise ValueError(f"Model path does not exist: {model_path}")
 
         self._model = _LlamaModel(
-            path_model=self.model_path, params=self.model_params, verbose=self.verbose
+            path_model=self.model_path, params=self.model_params, verbose=self.verbose,
+            path_idx=path_idx,
         )
 
         self._ctx = _LlamaContext(
```
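Stripped of class scaffolding, the loading path added to `_LlamaModel.__init__` reduces to this sketch (assuming, per the reconstruction above, that the augmentation step only runs when an index file is supplied):

```python
from typing import Optional
import llama_cpp

def load_model(path_model: str, params: "llama_cpp.llama_model_params",
               path_idx: Optional[str] = None):
    # Same call sequence as _LlamaModel.__init__ after this commit.
    model = llama_cpp.llama_load_model_from_file(path_model.encode("utf-8"), params)
    if path_idx:
        # use_mmap=True mirrors the hard-coded flag in the diff.
        llama_cpp.llama_model_apply_mlp_from_file(model, path_idx.encode("utf-8"), True)
        llama_cpp.llama_model_apply_augmentation(model)
    return model
```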

llama_cpp/llama_cpp.py

Lines changed: 15 additions & 3 deletions
```diff
@@ -305,15 +305,12 @@ class llama_model_params(Structure):
     _fields_ = [
         ("n_gpu_layers", c_int32),
         ("main_gpu", c_int32),
-        ("vram_budget_gb", c_float),
         ("tensor_split", c_float_p),
         ("progress_callback", llama_progress_callback),
         ("progress_callback_user_data", c_void_p),
         ("vocab_only", c_bool),
         ("use_mmap", c_bool),
         ("use_mlock", c_bool),
-        ("reset_gpu_index", c_bool),
-        ("disable_gpu_index", c_bool),
     ]
 
 
@@ -555,6 +552,21 @@ def llama_new_context_with_model(
 _lib.llama_new_context_with_model.restype = llama_context_p
 
 
+def llama_model_apply_mlp_from_file(
+    model: llama_model_p, path: bytes, use_mmap: Union[c_bool, bool]
+):
+    _lib.llama_model_apply_mlp_from_file(model, path, use_mmap)
+
+_lib.llama_model_apply_mlp_from_file.argtypes = [llama_model_p, c_char_p, c_bool]
+_lib.llama_model_apply_mlp_from_file.restype = None
+
+def llama_model_apply_augmentation(model: llama_model_p):
+    _lib.llama_model_apply_augmentation(model)
+
+_lib.llama_model_apply_augmentation.argtypes = [llama_model_p]
+_lib.llama_model_apply_augmentation.restype = None
+
+
 # // Frees all allocated memory
 # LLAMA_API void llama_free(struct llama_context * ctx);
 def llama_free(ctx: llama_context_p):
```
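Both wrappers follow the module's usual ctypes pattern: declare `argtypes`/`restype` on the `_lib` symbol and expose a thin Python shim. A low-level usage sketch, assuming a libllama build that actually exports these two symbols and that `llama_model_default_params()` is available as in the surrounding bindings (paths are placeholders):

```python
import llama_cpp

# Load the model through the raw bindings, bypassing the Llama wrapper.
params = llama_cpp.llama_model_default_params()
model = llama_cpp.llama_load_model_from_file(b"../models/7B/ggml-models.bin", params)

# Apply the MLP predictor index from file (use_mmap=True), then augmentation.
llama_cpp.llama_model_apply_mlp_from_file(model, b"../models/7B/model.gpuidx", True)
llama_cpp.llama_model_apply_augmentation(model)
```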
