@@ -20,7 +20,10 @@ def __init__(self, credential=None):
20
20
self ._cached_token = None
21
21
22
22
def get_token (self ) -> str :
23
- if self ._cached_token is None or (self ._cached_token .expires_on - time .time ()) < 300 :
23
+ if (
24
+ self ._cached_token is None
25
+ or (self ._cached_token .expires_on - time .time ()) < 300
26
+ ):
24
27
self ._cached_token = self ._credential .get_token (
25
28
"https://cognitiveservices.azure.com/.default"
26
29
)
@@ -40,7 +43,6 @@ def get_token(self) -> str:
40
43
41
44
42
45
class OpenAIClient :
43
-
44
46
def __init__ (
45
47
self ,
46
48
* ,
@@ -81,8 +83,7 @@ def __init__(
81
83
)
82
84
83
85
def _populate_args (self , kwargs : typing .Dict [str , typing .Any ], ** overrides ) -> None :
84
- """Populate default arguments based on the current client configuration/defaults
85
- """
86
+ """Populate default arguments based on the current client configuration/defaults"""
86
87
kwargs .setdefault ("api_base" , self .api_base or openai .api_base )
87
88
kwargs .setdefault ("api_key" , self .auth .get_token ())
88
89
kwargs .setdefault ("api_type" , self .api_type )
@@ -98,19 +99,30 @@ def _populate_args(self, kwargs: typing.Dict[str, typing.Any], **overrides) -> N
98
99
99
100
def _normalize_model (self , kwargs : typing .Dict [str , typing .Any ]):
100
101
"""Normalize model/engine/deployment_id based on which backend the client is
101
- configured to target.
102
+ configured to target.
102
103
103
- Specifically, it will pass the provided `model` parameter as `deployment_id`
104
- unless `deployment_id` is explicitly passed in.
104
+ Specifically, it will pass the provided `model` parameter as `deployment_id`
105
+ unless `deployment_id` is explicitly passed in.
105
106
"""
106
- if len ([param for param in kwargs if param in ('deployment_id' , 'model' , 'engine' )]) != 1 :
107
- raise TypeError ('You can only specify one of `deployment_id`, `model` and `engine`' )
108
-
109
- if self .backend == 'azure' :
107
+ if (
108
+ len (
109
+ [
110
+ param
111
+ for param in kwargs
112
+ if param in ("deployment_id" , "model" , "engine" )
113
+ ]
114
+ )
115
+ != 1
116
+ ):
117
+ raise TypeError (
118
+ "You can only specify one of `deployment_id`, `model` and `engine`"
119
+ )
120
+
121
+ if self .backend == "azure" :
110
122
try :
111
123
# We'll try to "rename" the `model` keyword to fit azure's `deployment_id`
112
124
# paradigm
113
- kwargs [' deployment_id' ] = kwargs .pop (' model' )
125
+ kwargs [" deployment_id" ] = kwargs .pop (" model" )
114
126
except KeyError :
115
127
pass
116
128
@@ -139,7 +151,8 @@ async def aiter_completion(
139
151
self ._populate_args (kwargs , prompt = prompt , stream = True )
140
152
self ._normalize_model (kwargs )
141
153
return typing .cast (
142
- typing .AsyncIterable [openai .Completion ], await openai .Completion .acreate (** kwargs )
154
+ typing .AsyncIterable [openai .Completion ],
155
+ await openai .Completion .acreate (** kwargs ),
143
156
)
144
157
145
158
def chatcompletion (self , messages , ** kwargs ) -> openai .ChatCompletion :
@@ -186,50 +199,135 @@ async def aembeddings(self, input, **kwargs) -> openai.Embedding:
186
199
self ._normalize_model (kwargs )
187
200
return typing .cast (openai .Embedding , await openai .Embedding .acreate (** kwargs ))
188
201
189
- def image (self , prompt : str , * , n : int = ..., size : str = ...,
190
- response_format : str = ..., user : str = ...,
191
- ** kwargs ):
192
- self ._populate_args (kwargs , prompt = prompt , n = n , size = size ,
193
- response_format = response_format , user = user )
202
+ def image (
203
+ self ,
204
+ prompt : str ,
205
+ * ,
206
+ n : int = ...,
207
+ size : str = ...,
208
+ response_format : str = ...,
209
+ user : str = ...,
210
+ ** kwargs ,
211
+ ):
212
+ self ._populate_args (
213
+ kwargs ,
214
+ prompt = prompt ,
215
+ n = n ,
216
+ size = size ,
217
+ response_format = response_format ,
218
+ user = user ,
219
+ )
194
220
return typing .cast (openai .Image , openai .Image .create (** kwargs ))
195
-
196
- async def aimage (self , prompt : str , * , n : int = ..., size : str = ...,
197
- response_format : str = ..., user : str = ...,
198
- ** kwargs ):
199
- self ._populate_args (kwargs , prompt = prompt , n = n , size = size ,
200
- response_format = response_format , user = user )
221
+
222
+ async def aimage (
223
+ self ,
224
+ prompt : str ,
225
+ * ,
226
+ n : int = ...,
227
+ size : str = ...,
228
+ response_format : str = ...,
229
+ user : str = ...,
230
+ ** kwargs ,
231
+ ):
232
+ self ._populate_args (
233
+ kwargs ,
234
+ prompt = prompt ,
235
+ n = n ,
236
+ size = size ,
237
+ response_format = response_format ,
238
+ user = user ,
239
+ )
201
240
return typing .cast (openai .Image , await openai .Image .acreate (** kwargs ))
202
241
203
- def image_variation (self , image : bytes | typing .BinaryIO , * , n : int = ...,
204
- size : str = ..., response_format : str = ...,
205
- user : str = ..., ** kwargs ):
206
- self ._populate_args (kwargs , image = image , n = n , size = size ,
207
- response_format = response_format , user = user )
242
+ def image_variation (
243
+ self ,
244
+ image : bytes | typing .BinaryIO ,
245
+ * ,
246
+ n : int = ...,
247
+ size : str = ...,
248
+ response_format : str = ...,
249
+ user : str = ...,
250
+ ** kwargs ,
251
+ ):
252
+ self ._populate_args (
253
+ kwargs ,
254
+ image = image ,
255
+ n = n ,
256
+ size = size ,
257
+ response_format = response_format ,
258
+ user = user ,
259
+ )
208
260
return typing .cast (openai .Image , openai .Image .create_variation (** kwargs ))
209
261
210
- async def aimage_variation (self , image : bytes | typing .BinaryIO , * , n : int = ...,
211
- size : str = ..., response_format : str = ...,
212
- user : str = ..., ** kwargs ):
213
- self ._populate_args (kwargs , image = image , n = n , size = size ,
214
- response_format = response_format , user = user )
262
+ async def aimage_variation (
263
+ self ,
264
+ image : bytes | typing .BinaryIO ,
265
+ * ,
266
+ n : int = ...,
267
+ size : str = ...,
268
+ response_format : str = ...,
269
+ user : str = ...,
270
+ ** kwargs ,
271
+ ):
272
+ self ._populate_args (
273
+ kwargs ,
274
+ image = image ,
275
+ n = n ,
276
+ size = size ,
277
+ response_format = response_format ,
278
+ user = user ,
279
+ )
215
280
return typing .cast (openai .Image , await openai .Image .acreate_variation (** kwargs ))
216
281
217
- def image_edit (self , image : bytes | typing .BinaryIO , prompt : str , * , mask : str = ..., n : int = ...,
218
- size : str = ..., response_format : str = ...,
219
- user : str = ..., ** kwargs ):
220
- self ._populate_args (kwargs , image = image , n = n , size = size ,
221
- prompt = prompt , mask = mask ,
222
- response_format = response_format , user = user )
282
+ def image_edit (
283
+ self ,
284
+ image : bytes | typing .BinaryIO ,
285
+ prompt : str ,
286
+ * ,
287
+ mask : str = ...,
288
+ n : int = ...,
289
+ size : str = ...,
290
+ response_format : str = ...,
291
+ user : str = ...,
292
+ ** kwargs ,
293
+ ):
294
+ self ._populate_args (
295
+ kwargs ,
296
+ image = image ,
297
+ n = n ,
298
+ size = size ,
299
+ prompt = prompt ,
300
+ mask = mask ,
301
+ response_format = response_format ,
302
+ user = user ,
303
+ )
223
304
return typing .cast (openai .Image , openai .Image .create_edit (** kwargs ))
224
-
225
- async def aimage_edit (self , image : bytes | typing .BinaryIO , prompt : str , * , mask : str = ..., n : int = ...,
226
- size : str = ..., response_format : str = ...,
227
- user : str = ..., ** kwargs ):
228
- self ._populate_args (kwargs , image = image , n = n , size = size ,
229
- prompt = prompt , mask = mask ,
230
- response_format = response_format , user = user )
305
+
306
+ async def aimage_edit (
307
+ self ,
308
+ image : bytes | typing .BinaryIO ,
309
+ prompt : str ,
310
+ * ,
311
+ mask : str = ...,
312
+ n : int = ...,
313
+ size : str = ...,
314
+ response_format : str = ...,
315
+ user : str = ...,
316
+ ** kwargs ,
317
+ ):
318
+ self ._populate_args (
319
+ kwargs ,
320
+ image = image ,
321
+ n = n ,
322
+ size = size ,
323
+ prompt = prompt ,
324
+ mask = mask ,
325
+ response_format = response_format ,
326
+ user = user ,
327
+ )
231
328
return typing .cast (openai .Image , await openai .Image .acreate_edit (** kwargs ))
232
329
330
+
233
331
if __name__ == "__main__" :
234
332
client = OpenAIClient (
235
333
api_base = "https://achand-openai-0.openai.azure.com/" ,
@@ -240,16 +338,18 @@ async def aimage_edit(self, image: bytes | typing.BinaryIO, prompt: str, *, mask
240
338
# print(client.embeddings("What, or what is this?", model="arch")) # Doesn't work 'cause it is the wrong model...
241
339
242
340
import asyncio
341
+
243
342
async def stream_chat ():
244
- respco = await client .aiter_completion ("what is up, my friend?" , model = "chatgpt" )
343
+ respco = await client .aiter_completion (
344
+ "what is up, my friend?" , model = "chatgpt"
345
+ )
245
346
async for rsp in respco :
246
347
print (rsp )
247
348
248
349
asyncio .run (stream_chat ())
249
-
250
350
251
351
oaiclient = OpenAIClient ()
252
352
print (oaiclient .completion ("what is up, my friend?" , model = "text-davinci-003" ))
253
353
print (oaiclient .embeddings ("What are embeddings?" , model = "text-embedding-ada-002" ))
254
354
rsp = oaiclient .image ("Happy cattle" , response_format = "b64_json" )
255
- print (rsp )
355
+ print (rsp )
0 commit comments