Skip to content

Commit db3f352

Browse files
authored
Pass in keyword arguments in embedding utility functions (openai#405)
* Pass in kwargs in embedding util functions * Remove debug code
1 parent 794def3 commit db3f352

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

openai/embeddings_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,51 +15,51 @@
1515

1616

1717
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(text: str, engine="text-similarity-davinci-001", **kwargs) -> List[float]:
    """Return the embedding vector for a single piece of text.

    Args:
        text: The input string to embed.
        engine: The embedding engine/model name.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``openai.Embedding.create`` (e.g. per-request options).

    Returns:
        The embedding as a list of floats.
    """
    # Replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

    return openai.Embedding.create(input=[text], engine=engine, **kwargs)["data"][0]["embedding"]
2424

2525

2626
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(
    text: str, engine="text-similarity-davinci-001", **kwargs
) -> List[float]:
    """Async variant of ``get_embedding``: embed a single piece of text.

    Args:
        text: The input string to embed.
        engine: The embedding engine/model name.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``openai.Embedding.acreate``.

    Returns:
        The embedding as a list of floats.
    """
    # Replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

    return (await openai.Embedding.acreate(input=[text], engine=engine, **kwargs))["data"][0][
        "embedding"
    ]
3737

3838

3939
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
    list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
) -> List[List[float]]:
    """Return embedding vectors for a batch of texts (at most 2048).

    Args:
        list_of_text: The input strings to embed.
        engine: The embedding engine/model name.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``openai.Embedding.create``.

    Returns:
        One embedding (list of floats) per input text, in input order.
    """
    # NOTE(review): assert is stripped under ``python -O``; a ValueError would
    # be more robust, but the check is kept as-is to preserve behavior.
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = openai.Embedding.create(input=list_of_text, engine=engine, **kwargs).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]
5151

5252

5353
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
    list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
) -> List[List[float]]:
    """Async variant of ``get_embeddings``: embed a batch of texts (at most 2048).

    Args:
        list_of_text: The input strings to embed.
        engine: The embedding engine/model name.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``openai.Embedding.acreate``.

    Returns:
        One embedding (list of floats) per input text, in input order.
    """
    # NOTE(review): assert is stripped under ``python -O``; a ValueError would
    # be more robust, but the check is kept as-is to preserve behavior.
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, **kwargs)).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]
6565

0 commit comments

Comments
 (0)