
Commit eb16072

Merge branch 'main' into expose-libggml
2 parents: 2c260b9 + dca0c9a

2 files changed: 9 additions & 7 deletions

llama_cpp/llama.py

Lines changed: 8 additions & 6 deletions
@@ -807,8 +807,10 @@ def sample(
             grammar=grammar,
         )
 
+        ridx = idx - self.n_tokens if idx is not None else -1
+
         assert self.ctx is not None
-        token = self._sampler.sample(self._ctx, -1)
+        token = self._sampler.sample(self._ctx, ridx)
         if tmp_sampler:
             self._sampler = None
         return token
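The first hunk makes sample() honor an optional logits index instead of always sampling from the last position: ridx converts the absolute token index idx into an index relative to the end of the evaluated tokens, falling back to -1 (last token) when idx is None. A minimal standalone sketch of that mapping; the function name here is hypothetical, only idx and n_tokens mirror the diff:

```python
def relative_logits_index(idx, n_tokens):
    """Convert an absolute position `idx` among the evaluated tokens into the
    end-relative index passed to the sampler; -1 means "the last token"."""
    return idx - n_tokens if idx is not None else -1

# With 10 tokens evaluated, asking for position 9 (the last one) gives -1,
# matching the previously hard-coded behaviour; earlier positions map to
# more negative offsets.
assert relative_logits_index(9, 10) == -1
assert relative_logits_index(None, 10) == -1
assert relative_logits_index(7, 10) == -3
```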
@@ -928,7 +930,7 @@ def generate(
 
             sample_idx += 1
             if stopping_criteria is not None and stopping_criteria(
-                self._input_ids, self._scores[-1, :]
+                self._input_ids[: sample_idx], self._scores[sample_idx - self.n_tokens, :]
            ):
                 return
             tokens_or_none = yield token
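The second hunk changes what generate() hands to a user-supplied stopping_criteria: instead of always passing the full self._input_ids and the last scores row, it passes the ids up to the position currently being sampled and the matching scores row. A hedged sketch of a criterion with that call shape; the criterion and the toy data are illustrative, not from the library:

```python
import numpy as np

# A toy stopping criterion with the same call signature used in the patched
# check: (token ids generated so far, logits row for the position being checked).
def stop_on_token(stop_id):
    def criteria(input_ids: np.ndarray, logits: np.ndarray) -> bool:
        return len(input_ids) > 0 and int(input_ids[-1]) == stop_id
    return criteria

input_ids = np.array([1, 5, 9, 2], dtype=np.intc)   # stand-in for self._input_ids
scores = np.zeros((4, 32), dtype=np.single)         # one logits row per evaluated token
n_tokens = 4                                        # stand-in for self.n_tokens
sample_idx = 3                                      # position currently being sampled

criteria = stop_on_token(9)
# Mirrors the patched call: the ids are truncated to sample_idx and the matching
# scores row is selected, rather than the full history and scores[-1, :].
should_stop = criteria(input_ids[:sample_idx], scores[sample_idx - n_tokens, :])
print(should_stop)  # True: the token at position sample_idx - 1 is the stop id
```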
@@ -1517,15 +1519,15 @@ def logit_bias_processor(
 
             if stream:
                 remaining_tokens = completion_tokens[returned_tokens:]
-                all_text = self.detokenize(
+                remaining_text = self.detokenize(
                     remaining_tokens,
                     prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
                 )
-                any_stop = [s for s in stop_sequences if s in all_text]
+                any_stop = [s for s in stop_sequences if s in remaining_text]
                 if len(any_stop) > 0:
-                    end = min(all_text.index(stop) for stop in any_stop)
+                    end = min(remaining_text.index(stop) for stop in any_stop)
                 else:
-                    end = len(all_text)
+                    end = len(remaining_text)
 
                 token_end_position = 0
                 for token in remaining_tokens:
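The third hunk is a rename for clarity in the streaming path: the text decoded from the not-yet-returned tokens is now called remaining_text instead of all_text, and the stop-sequence scan over it is otherwise unchanged. A small self-contained sketch of that scan, with made-up text and stop strings:

```python
# remaining_text / stop_sequences mirror the names in the diff; the data is invented.
remaining_text = "Hello world\nUser:"
stop_sequences = ["User:", "\n\n"]

# Find the earliest position of any stop string; if none occurs, keep everything.
any_stop = [s for s in stop_sequences if s in remaining_text]
if len(any_stop) > 0:
    end = min(remaining_text.index(stop) for stop in any_stop)
else:
    end = len(remaining_text)

print(repr(remaining_text[:end]))  # 'Hello world\n': text emitted before the first stop sequence
```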

vendor/llama.cpp

Lines changed: 1 addition & 1 deletion (submodule pointer updated)
