Skip to content

Commit 24cc9d3

Browse files
committed
add a simple python implementation of parallel.cpp
1 parent 3664f6e commit 24cc9d3

File tree

1 file changed

+5
-18
lines changed

1 file changed

+5
-18
lines changed

examples/low_level_api/parallel.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,6 @@ def llama_sampling_sample(ctx_sampling: LlamaSamplingContext,
130130
for key, value in params.logit_bias.items():
131131
logits[key] += value
132132

133-
# for token_id, logit in enumerate(logits):
134-
# # baii fix logit_bias is None
135-
# if params.logit_bias:
136-
# logit += params.logit_bias.get(token_id, 0.0)
137-
138133
# 性能优化换用 numpy 数组实现
139134
cur.clear()
140135
# cur = None
@@ -145,14 +140,6 @@ def llama_sampling_sample(ctx_sampling: LlamaSamplingContext,
145140
cur[token_id].id = token_id
146141
cur[token_id].logit = logits[token_id]
147142
cur[token_id].p = 0.0
148-
#
149-
# cur_p = ctypes.byref(llama_cpp.llama_token_data_array(cur, len(cur), False))
150-
151-
# for token_id in range(n_vocab):
152-
# cur.append(
153-
# llama_cpp.llama_token_data(id=token_id, logit=logits[token_id], p=0.0)
154-
# # {'id': token_id, 'logit': logits[token_id], 'p': 0.0}
155-
# )
156143

157144
cur_p = ctypes.byref(llama_cpp.llama_token_data_array(cur, n_vocab, False))
158145

@@ -294,7 +281,7 @@ def _get_batch_view(batch: _LlamaBatch, n_tokens: int, offset: int) -> llama_cpp
294281

295282
def _move_pointer_offset(ptr, c_types, offset: int):
296283
"""
297-
移动指针(指针算数)
284+
Move the pointer (pointer arithmetic)
298285
:param ptr: 要移动的指针
299286
:param c_types: 指针指向内存的类型
300287
:param offset: 移动的偏移量
@@ -368,7 +355,7 @@ class Client(object):
368355
input: str = ""
369356
prompt: str = ""
370357
response: str = ""
371-
decode_err: bytes = b''
358+
decode_err_buffer: bytes = b''
372359

373360
def __del__(self):
374361
if self.ctx_sampling:
@@ -638,10 +625,10 @@ class GptParams():
638625
token_str = llama.detokenize([token_id])
639626
# simple incremental UTF-8 decoding so multi-byte text (e.g. zh-CN) split across tokens is supported
640627
try:
641-
client.response += (client.decode_err + token_str).decode('utf8')
642-
client.decode_err = b''
628+
client.response += (client.decode_err_buffer + token_str).decode('utf8')
629+
client.decode_err_buffer = b''
643630
except UnicodeDecodeError:
644-
client.decode_err += token_str
631+
client.decode_err_buffer += token_str
645632
# print(f'{id=} {token_str} 解码失败')
646633
# client.response += token_str.decode('utf8', 'replace')
647634
# print(f"\033[31mClient {client.id}, seq {client.seq_id}, response {client.response}, \033[0m")

0 commit comments

Comments
 (0)