|
14 | 14 | from enum import Enum, auto
|
15 | 15 | from typing import TYPE_CHECKING, Any, Dict, cast
|
16 | 16 |
|
17 |
| -from cryptography.hazmat.primitives import padding, serialization |
| 17 | +from cryptography.hazmat.primitives import hashes, padding, serialization |
18 | 18 | from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
19 | 19 | from cryptography.hazmat.primitives.asymmetric import rsa
|
20 | 20 | from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
@@ -108,7 +108,9 @@ def __init__(
|
108 | 108 | self._key_pair: KeyPair | None = None
|
109 | 109 | if config.aes_keys:
|
110 | 110 | aes_keys = config.aes_keys
|
111 |
| - self._key_pair = KeyPair(aes_keys["private"], aes_keys["public"]) |
| 111 | + self._key_pair = KeyPair.create_from_der_keys( |
| 112 | + aes_keys["private"], aes_keys["public"] |
| 113 | + ) |
112 | 114 | self._app_url = URL(f"http://{self._host}:{self._port}/app")
|
113 | 115 | self._token_url: URL | None = None
|
114 | 116 |
|
@@ -277,14 +279,14 @@ async def _generate_key_pair_payload(self) -> AsyncGenerator:
|
277 | 279 | if not self._key_pair:
|
278 | 280 | kp = KeyPair.create_key_pair()
|
279 | 281 | self._config.aes_keys = {
|
280 |
| - "private": kp.get_private_key(), |
281 |
| - "public": kp.get_public_key(), |
| 282 | + "private": kp.private_key_der_b64, |
| 283 | + "public": kp.public_key_der_b64, |
282 | 284 | }
|
283 | 285 | self._key_pair = kp
|
284 | 286 |
|
285 | 287 | pub_key = (
|
286 | 288 | "-----BEGIN PUBLIC KEY-----\n"
|
287 |
| - + self._key_pair.get_public_key() # type: ignore[union-attr] |
| 289 | + + self._key_pair.public_key_der_b64 # type: ignore[union-attr] |
288 | 290 | + "\n-----END PUBLIC KEY-----\n"
|
289 | 291 | )
|
290 | 292 | handshake_params = {"key": pub_key}
|
@@ -392,18 +394,11 @@ class AesEncyptionSession:
|
392 | 394 | """Class for an AES encryption session."""
|
393 | 395 |
|
394 | 396 | @staticmethod
|
395 |
| - def create_from_keypair(handshake_key: str, keypair): |
| 397 | + def create_from_keypair(handshake_key: str, keypair: KeyPair): |
396 | 398 | """Create the encryption session."""
|
397 |
| - handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8")) |
398 |
| - private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8")) |
| 399 | + handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode()) |
399 | 400 |
|
400 |
| - private_key = cast( |
401 |
| - rsa.RSAPrivateKey, |
402 |
| - serialization.load_der_private_key(private_key_data, None, None), |
403 |
| - ) |
404 |
| - key_and_iv = private_key.decrypt( |
405 |
| - handshake_key_bytes, asymmetric_padding.PKCS1v15() |
406 |
| - ) |
| 401 | + key_and_iv = keypair.decrypt_handshake_key(handshake_key_bytes) |
407 | 402 | if key_and_iv is None:
|
408 | 403 | raise ValueError("Decryption failed!")
|
409 | 404 |
|
@@ -438,30 +433,59 @@ def create_key_pair(key_size: int = 1024):
|
438 | 433 | """Create a key pair."""
|
439 | 434 | private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
|
440 | 435 | public_key = private_key.public_key()
|
| 436 | + return KeyPair(private_key, public_key) |
| 437 | + |
| 438 | + @staticmethod |
| 439 | + def create_from_der_keys(private_key_der_b64: str, public_key_der_b64: str): |
| 440 | + """Create a key pair.""" |
| 441 | + key_bytes = base64.b64decode(private_key_der_b64.encode()) |
| 442 | + private_key = cast( |
| 443 | + rsa.RSAPrivateKey, serialization.load_der_private_key(key_bytes, None) |
| 444 | + ) |
| 445 | + key_bytes = base64.b64decode(public_key_der_b64.encode()) |
| 446 | + public_key = cast( |
| 447 | + rsa.RSAPublicKey, serialization.load_der_public_key(key_bytes, None) |
| 448 | + ) |
441 | 449 |
|
442 |
| - private_key_bytes = private_key.private_bytes( |
| 450 | + return KeyPair(private_key, public_key) |
| 451 | + |
| 452 | + def __init__(self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey): |
| 453 | + self.private_key = private_key |
| 454 | + self.public_key = public_key |
| 455 | + self.private_key_der_bytes = self.private_key.private_bytes( |
443 | 456 | encoding=serialization.Encoding.DER,
|
444 | 457 | format=serialization.PrivateFormat.PKCS8,
|
445 | 458 | encryption_algorithm=serialization.NoEncryption(),
|
446 | 459 | )
|
447 |
| - public_key_bytes = public_key.public_bytes( |
| 460 | + self.public_key_der_bytes = self.public_key.public_bytes( |
448 | 461 | encoding=serialization.Encoding.DER,
|
449 | 462 | format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
450 | 463 | )
|
| 464 | + self.private_key_der_b64 = base64.b64encode(self.private_key_der_bytes).decode() |
| 465 | + self.public_key_der_b64 = base64.b64encode(self.public_key_der_bytes).decode() |
451 | 466 |
|
452 |
| - return KeyPair( |
453 |
| - private_key=base64.b64encode(private_key_bytes).decode("UTF-8"), |
454 |
| - public_key=base64.b64encode(public_key_bytes).decode("UTF-8"), |
| 467 | + def get_public_pem(self) -> bytes: |
| 468 | + """Get public key in PEM encoding.""" |
| 469 | + return self.public_key.public_bytes( |
| 470 | + encoding=serialization.Encoding.PEM, |
| 471 | + format=serialization.PublicFormat.SubjectPublicKeyInfo, |
455 | 472 | )
|
456 | 473 |
|
457 |
| - def __init__(self, private_key: str, public_key: str): |
458 |
| - self.private_key = private_key |
459 |
| - self.public_key = public_key |
460 |
| - |
461 |
| - def get_private_key(self) -> str: |
462 |
| - """Get the private key.""" |
463 |
| - return self.private_key |
464 |
| - |
465 |
| - def get_public_key(self) -> str: |
466 |
| - """Get the public key.""" |
467 |
| - return self.public_key |
| 474 | + def decrypt_handshake_key(self, encrypted_key: bytes) -> bytes: |
| 475 | + """Decrypt an aes handshake key.""" |
| 476 | + decrypted = self.private_key.decrypt( |
| 477 | + encrypted_key, asymmetric_padding.PKCS1v15() |
| 478 | + ) |
| 479 | + return decrypted |
| 480 | + |
| 481 | + def decrypt_discovery_key(self, encrypted_key: bytes) -> bytes: |
| 482 | + """Decrypt an aes discovery key.""" |
| 483 | + decrypted = self.private_key.decrypt( |
| 484 | + encrypted_key, |
| 485 | + asymmetric_padding.OAEP( |
| 486 | + mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303 |
| 487 | + algorithm=hashes.SHA1(), # noqa: S303 |
| 488 | + label=None, |
| 489 | + ), |
| 490 | + ) |
| 491 | + return decrypted |
0 commit comments