@@ -807,8 +807,10 @@ def sample(
807
807
grammar = grammar ,
808
808
)
809
809
810
+ ridx = idx - self .n_tokens if idx is not None else - 1
811
+
810
812
assert self .ctx is not None
811
- token = self ._sampler .sample (self ._ctx , - 1 )
813
+ token = self ._sampler .sample (self ._ctx , ridx )
812
814
if tmp_sampler :
813
815
self ._sampler = None
814
816
return token
@@ -928,7 +930,7 @@ def generate(
928
930
929
931
sample_idx += 1
930
932
if stopping_criteria is not None and stopping_criteria (
931
- self ._input_ids , self ._scores [- 1 , :]
933
+ self ._input_ids [: sample_idx ] , self ._scores [sample_idx - self . n_tokens , :]
932
934
):
933
935
return
934
936
tokens_or_none = yield token
@@ -1517,15 +1519,15 @@ def logit_bias_processor(
1517
1519
1518
1520
if stream :
1519
1521
remaining_tokens = completion_tokens [returned_tokens :]
1520
- all_text = self .detokenize (
1522
+ remaining_text = self .detokenize (
1521
1523
remaining_tokens ,
1522
1524
prev_tokens = prompt_tokens + completion_tokens [:returned_tokens ],
1523
1525
)
1524
- any_stop = [s for s in stop_sequences if s in all_text ]
1526
+ any_stop = [s for s in stop_sequences if s in remaining_text ]
1525
1527
if len (any_stop ) > 0 :
1526
- end = min (all_text .index (stop ) for stop in any_stop )
1528
+ end = min (remaining_text .index (stop ) for stop in any_stop )
1527
1529
else :
1528
- end = len (all_text )
1530
+ end = len (remaining_text )
1529
1531
1530
1532
token_end_position = 0
1531
1533
for token in remaining_tokens :
0 commit comments