@@ -130,11 +130,6 @@ def llama_sampling_sample(ctx_sampling: LlamaSamplingContext,
130
130
for key , value in params .logit_bias .items ():
131
131
logits [key ] += value
132
132
133
- # for token_id, logit in enumerate(logits):
134
- # # baii fix logit_bias is None
135
- # if params.logit_bias:
136
- # logit += params.logit_bias.get(token_id, 0.0)
137
-
138
133
# Performance optimization: switched to a numpy array implementation
139
134
cur .clear ()
140
135
# cur = None
@@ -145,14 +140,6 @@ def llama_sampling_sample(ctx_sampling: LlamaSamplingContext,
145
140
cur [token_id ].id = token_id
146
141
cur [token_id ].logit = logits [token_id ]
147
142
cur [token_id ].p = 0.0
148
- #
149
- # cur_p = ctypes.byref(llama_cpp.llama_token_data_array(cur, len(cur), False))
150
-
151
- # for token_id in range(n_vocab):
152
- # cur.append(
153
- # llama_cpp.llama_token_data(id=token_id, logit=logits[token_id], p=0.0)
154
- # # {'id': token_id, 'logit': logits[token_id], 'p': 0.0}
155
- # )
156
143
157
144
cur_p = ctypes .byref (llama_cpp .llama_token_data_array (cur , n_vocab , False ))
158
145
@@ -294,7 +281,7 @@ def _get_batch_view(batch: _LlamaBatch, n_tokens: int, offset: int) -> llama_cpp
294
281
295
282
def _move_pointer_offset (ptr , c_types , offset : int ):
296
283
"""
297
- 移动指针(指针算数 )
284
+ Move the pointer (pointer arithmetic)
298
285
:param ptr: the pointer to move
299
286
:param c_types: the type of the memory the pointer points to
300
287
:param offset: the offset to move by
@@ -368,7 +355,7 @@ class Client(object):
368
355
input : str = ""
369
356
prompt : str = ""
370
357
response : str = ""
371
- decode_err : bytes = b''
358
+ decode_err_buffer : bytes = b''
372
359
373
360
def __del__ (self ):
374
361
if self .ctx_sampling :
@@ -638,10 +625,10 @@ class GptParams():
638
625
token_str = llama .detokenize ([token_id ])
639
626
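# Simple incremental UTF-8 decode (e.g. for zh-cn text): bytes that fail to decode are buffered and retried together with the next token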
640
627
try :
641
- client .response += (client .decode_err + token_str ).decode ('utf8' )
642
- client .decode_err = b''
628
+ client .response += (client .decode_err_buffer + token_str ).decode ('utf8' )
629
+ client .decode_err_buffer = b''
643
630
except UnicodeDecodeError :
644
- client .decode_err += token_str
631
+ client .decode_err_buffer += token_str
645
632
# print(f'{id=} {token_str} 解码失败')
646
633
# client.response += token_str.decode('utf8', 'replace')
647
634
# print(f"\033[31mClient {client.id}, seq {client.seq_id}, response {client.response}, \033[0m")
0 commit comments