|
15 | 15 |
|
16 | 16 |
|
17 | 17 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
18 |
| -def get_embedding(text: str, engine="text-similarity-davinci-001") -> List[float]: |
| 18 | +def get_embedding(text: str, engine="text-similarity-davinci-001", **kwargs) -> List[float]: |
19 | 19 |
|
20 | 20 | # replace newlines, which can negatively affect performance.
|
21 | 21 | text = text.replace("\n", " ")
|
22 | 22 |
|
23 |
| - return openai.Embedding.create(input=[text], engine=engine)["data"][0]["embedding"] |
| 23 | + return openai.Embedding.create(input=[text], engine=engine, **kwargs)["data"][0]["embedding"] |
24 | 24 |
|
25 | 25 |
|
26 | 26 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
27 | 27 | async def aget_embedding(
|
28 |
| - text: str, engine="text-similarity-davinci-001" |
| 28 | + text: str, engine="text-similarity-davinci-001", **kwargs |
29 | 29 | ) -> List[float]:
|
30 | 30 |
|
31 | 31 | # replace newlines, which can negatively affect performance.
|
32 | 32 | text = text.replace("\n", " ")
|
33 | 33 |
|
34 |
| - return (await openai.Embedding.acreate(input=[text], engine=engine))["data"][0][ |
| 34 | + return (await openai.Embedding.acreate(input=[text], engine=engine, **kwargs))["data"][0][ |
35 | 35 | "embedding"
|
36 | 36 | ]
|
37 | 37 |
|
38 | 38 |
|
39 | 39 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
40 | 40 | def get_embeddings(
|
41 |
| - list_of_text: List[str], engine="text-similarity-babbage-001" |
| 41 | + list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs |
42 | 42 | ) -> List[List[float]]:
|
43 | 43 | assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
|
44 | 44 |
|
45 | 45 | # replace newlines, which can negatively affect performance.
|
46 | 46 | list_of_text = [text.replace("\n", " ") for text in list_of_text]
|
47 | 47 |
|
48 |
| - data = openai.Embedding.create(input=list_of_text, engine=engine).data |
| 48 | + data = openai.Embedding.create(input=list_of_text, engine=engine, **kwargs).data |
49 | 49 | data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
|
50 | 50 | return [d["embedding"] for d in data]
|
51 | 51 |
|
52 | 52 |
|
53 | 53 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
54 | 54 | async def aget_embeddings(
|
55 |
| - list_of_text: List[str], engine="text-similarity-babbage-001" |
| 55 | + list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs |
56 | 56 | ) -> List[List[float]]:
|
57 | 57 | assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
|
58 | 58 |
|
59 | 59 | # replace newlines, which can negatively affect performance.
|
60 | 60 | list_of_text = [text.replace("\n", " ") for text in list_of_text]
|
61 | 61 |
|
62 |
| - data = (await openai.Embedding.acreate(input=list_of_text, engine=engine)).data |
| 62 | + data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, **kwargs)).data |
63 | 63 | data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
|
64 | 64 | return [d["embedding"] for d in data]
|
65 | 65 |
|
|
0 commit comments