Skip to content

Commit b4fdcd6

Browse files
committed
Formatting client.py (black)
1 parent b894768 commit b4fdcd6

File tree

1 file changed

+150
-50
lines changed

1 file changed

+150
-50
lines changed

openai/client.py

Lines changed: 150 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ def __init__(self, credential=None):
2020
self._cached_token = None
2121

2222
def get_token(self) -> str:
23-
if self._cached_token is None or (self._cached_token.expires_on - time.time()) < 300:
23+
if (
24+
self._cached_token is None
25+
or (self._cached_token.expires_on - time.time()) < 300
26+
):
2427
self._cached_token = self._credential.get_token(
2528
"https://cognitiveservices.azure.com/.default"
2629
)
@@ -40,7 +43,6 @@ def get_token(self) -> str:
4043

4144

4245
class OpenAIClient:
43-
4446
def __init__(
4547
self,
4648
*,
@@ -81,8 +83,7 @@ def __init__(
8183
)
8284

8385
def _populate_args(self, kwargs: typing.Dict[str, typing.Any], **overrides) -> None:
84-
"""Populate default arguments based on the current client configuration/defaults
85-
"""
86+
"""Populate default arguments based on the current client configuration/defaults"""
8687
kwargs.setdefault("api_base", self.api_base or openai.api_base)
8788
kwargs.setdefault("api_key", self.auth.get_token())
8889
kwargs.setdefault("api_type", self.api_type)
@@ -98,19 +99,30 @@ def _populate_args(self, kwargs: typing.Dict[str, typing.Any], **overrides) -> N
9899

99100
def _normalize_model(self, kwargs: typing.Dict[str, typing.Any]):
100101
"""Normalize model/engine/deployment_id based on which backend the client is
101-
configured to target.
102+
configured to target.
102103
103-
Specifically, it will pass the provided `model` parameter as `deployment_id`
104-
unless `deployment_id` is explicitly passed in.
104+
Specifically, it will pass the provided `model` parameter as `deployment_id`
105+
unless `deployment_id` is explicitly passed in.
105106
"""
106-
if len([param for param in kwargs if param in ('deployment_id', 'model', 'engine')]) != 1:
107-
raise TypeError('You can only specify one of `deployment_id`, `model` and `engine`')
108-
109-
if self.backend == 'azure':
107+
if (
108+
len(
109+
[
110+
param
111+
for param in kwargs
112+
if param in ("deployment_id", "model", "engine")
113+
]
114+
)
115+
!= 1
116+
):
117+
raise TypeError(
118+
"You can only specify one of `deployment_id`, `model` and `engine`"
119+
)
120+
121+
if self.backend == "azure":
110122
try:
111123
# We'll try to "rename" the `model` keyword to fit azure's `deployment_id`
112124
# paradigm
113-
kwargs['deployment_id'] = kwargs.pop('model')
125+
kwargs["deployment_id"] = kwargs.pop("model")
114126
except KeyError:
115127
pass
116128

@@ -139,7 +151,8 @@ async def aiter_completion(
139151
self._populate_args(kwargs, prompt=prompt, stream=True)
140152
self._normalize_model(kwargs)
141153
return typing.cast(
142-
typing.AsyncIterable[openai.Completion], await openai.Completion.acreate(**kwargs)
154+
typing.AsyncIterable[openai.Completion],
155+
await openai.Completion.acreate(**kwargs),
143156
)
144157

145158
def chatcompletion(self, messages, **kwargs) -> openai.ChatCompletion:
@@ -186,50 +199,135 @@ async def aembeddings(self, input, **kwargs) -> openai.Embedding:
186199
self._normalize_model(kwargs)
187200
return typing.cast(openai.Embedding, await openai.Embedding.acreate(**kwargs))
188201

189-
def image(self, prompt: str, *, n: int = ..., size: str = ...,
190-
response_format: str = ..., user: str = ...,
191-
**kwargs):
192-
self._populate_args(kwargs, prompt = prompt, n = n, size = size,
193-
response_format = response_format, user = user)
202+
def image(
203+
self,
204+
prompt: str,
205+
*,
206+
n: int = ...,
207+
size: str = ...,
208+
response_format: str = ...,
209+
user: str = ...,
210+
**kwargs,
211+
):
212+
self._populate_args(
213+
kwargs,
214+
prompt=prompt,
215+
n=n,
216+
size=size,
217+
response_format=response_format,
218+
user=user,
219+
)
194220
return typing.cast(openai.Image, openai.Image.create(**kwargs))
195-
196-
async def aimage(self, prompt: str, *, n: int = ..., size: str = ...,
197-
response_format: str = ..., user: str = ...,
198-
**kwargs):
199-
self._populate_args(kwargs, prompt = prompt, n = n, size = size,
200-
response_format = response_format, user = user)
221+
222+
async def aimage(
223+
self,
224+
prompt: str,
225+
*,
226+
n: int = ...,
227+
size: str = ...,
228+
response_format: str = ...,
229+
user: str = ...,
230+
**kwargs,
231+
):
232+
self._populate_args(
233+
kwargs,
234+
prompt=prompt,
235+
n=n,
236+
size=size,
237+
response_format=response_format,
238+
user=user,
239+
)
201240
return typing.cast(openai.Image, await openai.Image.acreate(**kwargs))
202241

203-
def image_variation(self, image: bytes | typing.BinaryIO, *, n: int = ...,
204-
size: str = ..., response_format: str = ...,
205-
user: str = ..., **kwargs):
206-
self._populate_args(kwargs, image = image, n = n, size = size,
207-
response_format = response_format, user = user)
242+
def image_variation(
243+
self,
244+
image: bytes | typing.BinaryIO,
245+
*,
246+
n: int = ...,
247+
size: str = ...,
248+
response_format: str = ...,
249+
user: str = ...,
250+
**kwargs,
251+
):
252+
self._populate_args(
253+
kwargs,
254+
image=image,
255+
n=n,
256+
size=size,
257+
response_format=response_format,
258+
user=user,
259+
)
208260
return typing.cast(openai.Image, openai.Image.create_variation(**kwargs))
209261

210-
async def aimage_variation(self, image: bytes | typing.BinaryIO, *, n: int = ...,
211-
size: str = ..., response_format: str = ...,
212-
user: str = ..., **kwargs):
213-
self._populate_args(kwargs, image = image, n = n, size = size,
214-
response_format = response_format, user = user)
262+
async def aimage_variation(
263+
self,
264+
image: bytes | typing.BinaryIO,
265+
*,
266+
n: int = ...,
267+
size: str = ...,
268+
response_format: str = ...,
269+
user: str = ...,
270+
**kwargs,
271+
):
272+
self._populate_args(
273+
kwargs,
274+
image=image,
275+
n=n,
276+
size=size,
277+
response_format=response_format,
278+
user=user,
279+
)
215280
return typing.cast(openai.Image, await openai.Image.acreate_variation(**kwargs))
216281

217-
def image_edit(self, image: bytes | typing.BinaryIO, prompt: str, *, mask: str = ..., n: int = ...,
218-
size: str = ..., response_format: str = ...,
219-
user: str = ..., **kwargs):
220-
self._populate_args(kwargs, image = image, n = n, size = size,
221-
prompt = prompt, mask = mask,
222-
response_format = response_format, user = user)
282+
def image_edit(
283+
self,
284+
image: bytes | typing.BinaryIO,
285+
prompt: str,
286+
*,
287+
mask: str = ...,
288+
n: int = ...,
289+
size: str = ...,
290+
response_format: str = ...,
291+
user: str = ...,
292+
**kwargs,
293+
):
294+
self._populate_args(
295+
kwargs,
296+
image=image,
297+
n=n,
298+
size=size,
299+
prompt=prompt,
300+
mask=mask,
301+
response_format=response_format,
302+
user=user,
303+
)
223304
return typing.cast(openai.Image, openai.Image.create_edit(**kwargs))
224-
225-
async def aimage_edit(self, image: bytes | typing.BinaryIO, prompt: str, *, mask: str = ..., n: int = ...,
226-
size: str = ..., response_format: str = ...,
227-
user: str = ..., **kwargs):
228-
self._populate_args(kwargs, image = image, n = n, size = size,
229-
prompt = prompt, mask = mask,
230-
response_format = response_format, user = user)
305+
306+
async def aimage_edit(
307+
self,
308+
image: bytes | typing.BinaryIO,
309+
prompt: str,
310+
*,
311+
mask: str = ...,
312+
n: int = ...,
313+
size: str = ...,
314+
response_format: str = ...,
315+
user: str = ...,
316+
**kwargs,
317+
):
318+
self._populate_args(
319+
kwargs,
320+
image=image,
321+
n=n,
322+
size=size,
323+
prompt=prompt,
324+
mask=mask,
325+
response_format=response_format,
326+
user=user,
327+
)
231328
return typing.cast(openai.Image, await openai.Image.acreate_edit(**kwargs))
232329

330+
233331
if __name__ == "__main__":
234332
client = OpenAIClient(
235333
api_base="https://achand-openai-0.openai.azure.com/",
@@ -240,16 +338,18 @@ async def aimage_edit(self, image: bytes | typing.BinaryIO, prompt: str, *, mask
240338
# print(client.embeddings("What, or what is this?", model="arch")) # Doesn't work 'cause it is the wrong model...
241339

242340
import asyncio
341+
243342
async def stream_chat():
244-
respco = await client.aiter_completion("what is up, my friend?", model="chatgpt")
343+
respco = await client.aiter_completion(
344+
"what is up, my friend?", model="chatgpt"
345+
)
245346
async for rsp in respco:
246347
print(rsp)
247348

248349
asyncio.run(stream_chat())
249-
250350

251351
oaiclient = OpenAIClient()
252352
print(oaiclient.completion("what is up, my friend?", model="text-davinci-003"))
253353
print(oaiclient.embeddings("What are embeddings?", model="text-embedding-ada-002"))
254354
rsp = oaiclient.image("Happy cattle", response_format="b64_json")
255-
print(rsp)
355+
print(rsp)

0 commit comments

Comments
 (0)