Commit 3d6c479

Merge branch 'main' into D4ve-R/main
2 parents: 950f721 + 4a85442

File tree: 3 files changed, +21 -7 lines changed

llama_cpp/llama.py

Lines changed: 13 additions & 5 deletions
@@ -1551,7 +1551,9 @@ def logit_bias_processor(
                         "utf-8", errors="ignore"
                     )
                     text_offset = len(prompt) + len(
-                        self.detokenize(completion_tokens[:returned_tokens])
+                        self.detokenize(completion_tokens[:returned_tokens]).decode(
+                            "utf-8", errors="ignore"
+                        )
                     )
                     token_offset = len(prompt_tokens) + returned_tokens
                     logits = self._scores[token_offset - 1, :]
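
This hunk fixes the text_offset computation for streamed logprobs: Llama.detokenize() returns bytes, so len() on the raw result counts UTF-8 bytes, while the offset is meant to index into decoded text. A standalone sketch (not from the patch) of the difference:

# Byte length vs. decoded character length.
raw = "héllo".encode("utf-8")  # stands in for a detokenize() result
print(len(raw))                                   # 6: "é" occupies two bytes
print(len(raw.decode("utf-8", errors="ignore")))  # 5 characters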
@@ -1789,13 +1791,19 @@ def logit_bias_processor(
             ]
             all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
             # TODO: may be able to change this loop to use np.take_along_dim
-            for token, token_str, logprobs_token in zip(
-                all_tokens, all_token_strs, all_logprobs
+            for idx, (token, token_str, logprobs_token) in enumerate(
+                zip(all_tokens, all_token_strs, all_logprobs)
             ):
                 if token == self.token_bos():
                     continue
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
+                text_offsets.append(
+                    text_offset
+                    + len(
+                        self.detokenize(all_tokens[:idx]).decode(
+                            "utf-8", errors="ignore"
+                        )
+                    )
+                )
                 tokens.append(token_str)
                 sorted_logprobs = list(
                     sorted(
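
This hunk applies the same idea to the per-token offsets: rather than accumulating len(token_str) as the loop advances, each offset is recomputed by detokenizing the whole token prefix and decoding it once. Decoding token pieces individually can silently drop bytes when a multi-byte UTF-8 character straddles a token boundary; a standalone sketch with made-up byte strings:

# "é" is 0xC3 0xA9 in UTF-8; here it is split across two hypothetical tokens.
pieces = [b"h", b"\xc3", b"\xa9"]

# Summing per-piece decoded lengths loses the split character entirely:
print(sum(len(p.decode("utf-8", errors="ignore")) for p in pieces))  # 1

# Decoding the joined prefix recovers it, keeping offsets consistent:
print(len(b"".join(pieces).decode("utf-8", errors="ignore")))        # 2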

llama_cpp/llama_cpp.py

Lines changed: 7 additions & 1 deletion
@@ -733,8 +733,14 @@ def llama_n_ctx(ctx: llama_context_p) -> int:
 
 
 _lib.llama_n_ctx.argtypes = [llama_context_p]
-_lib.llama_n_ctx.restype = c_int
+_lib.llama_n_ctx.restype = c_uint32
 
+# LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
+def llama_n_batch(ctx: llama_context_p) -> int:
+    return _lib.llama_n_batch(ctx)
+
+
+_lib.llama_n_batch.argtypes = [llama_context_p]
+_lib.llama_n_batch.restype = c_uint32
 
 # LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
 def llama_vocab_type(model: llama_model_p) -> int:
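
The new llama_n_batch binding mirrors llama_n_ctx: a thin Python wrapper plus argtypes/restype declarations on the shared-library symbol, with restype now c_uint32 to match the uint32_t return type in the C header. A minimal usage sketch, assuming a model file at a made-up path and this module's existing loader bindings:

import llama_cpp

# Hypothetical model path; the loader takes a bytes path and default params.
model = llama_cpp.llama_load_model_from_file(
    b"./models/model.gguf", llama_cpp.llama_model_default_params()
)
ctx = llama_cpp.llama_new_context_with_model(
    model, llama_cpp.llama_context_default_params()
)
print(llama_cpp.llama_n_ctx(ctx))    # context length, returned as a Python int
print(llama_cpp.llama_n_batch(ctx))  # batch size, via the binding added here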

vendor/llama.cpp
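
Lines changed: 1 addition & 1 deletion (the remainder of the commit's +21 -7 totals: the submodule commit pointer)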
