
Commit fa7f1cd

truncate to n_batch, not n_ctx

1 parent ee84ca1 · commit fa7f1cd

File tree

1 file changed: +5 −5 lines changed


llama_cpp/llama.py

Lines changed: 5 additions & 5 deletions
@@ -762,7 +762,7 @@ def embed(
         """
         assert self._ctx.ctx is not None
         n_embd = self.n_embd()
-        n_ctx = self.n_ctx()
+        n_batch = self.n_batch

         if self.context_params.embedding == False:
             raise RuntimeError(
@@ -807,19 +807,19 @@ def decode_batch(n_seq: int):
         for text in inputs:
             tokens = self.tokenize(text.encode("utf-8"))
             if truncate:
-                tokens = tokens[:n_ctx]
+                tokens = tokens[:n_batch]

             n_tokens = len(tokens)
             total_tokens += n_tokens

             # check for overrun
-            if n_tokens > n_ctx:
+            if n_tokens > n_batch:
                 raise ValueError(
-                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                    f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
                 )

             # time to eval batch
-            if t_batch + n_tokens > self._n_ctx:
+            if t_batch + n_tokens > n_batch:
                 decode_batch(p_batch)
                 t_batch = 0
                 p_batch = 0
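Why the bound is n_batch rather than n_ctx: inside the embed() loop, each decode_batch() call submits at most n_batch tokens to llama_decode, so the batch size, not the full context window, is the hard limit on a single input. The sketch below is not the library's code; clamp_tokens is a hypothetical helper that only mirrors the truncate/overrun checks shown in the diff above.

    # Minimal sketch, not llama-cpp-python's API: a hypothetical helper that
    # mirrors the truncate/overrun logic this commit switches to n_batch.
    def clamp_tokens(tokens: list[int], n_batch: int, truncate: bool = True) -> list[int]:
        if truncate:
            # Cut the input to what a single batch can hold (the new behaviour).
            return tokens[:n_batch]
        # With truncate=False an oversized input is an error, now phrased in
        # terms of the batch size instead of the context window.
        if len(tokens) > n_batch:
            raise ValueError(
                f"Requested tokens ({len(tokens)}) exceed batch size of {n_batch}"
            )
        return tokens

    # Usage (hypothetical): bound each embedding input before it is queued.
    # tokens = clamp_tokens(llama.tokenize(text.encode("utf-8")), llama.n_batch)

The same reasoning applies to the flush condition at line 822: the pending batch is decoded as soon as adding the next input would push t_batch past n_batch, rather than past the context size.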

0 commit comments
