Skip to content

Commit 380fbb9

Browse files
authored
Enable newer encrypted discovery protocol (python-kasa#1168)
1 parent 7fd8c14 commit 380fbb9

File tree

7 files changed

+258
-70
lines changed

7 files changed

+258
-70
lines changed

kasa/aestransport.py

Lines changed: 55 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from enum import Enum, auto
1515
from typing import TYPE_CHECKING, Any, Dict, cast
1616

17-
from cryptography.hazmat.primitives import padding, serialization
17+
from cryptography.hazmat.primitives import hashes, padding, serialization
1818
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
1919
from cryptography.hazmat.primitives.asymmetric import rsa
2020
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
@@ -108,7 +108,9 @@ def __init__(
108108
self._key_pair: KeyPair | None = None
109109
if config.aes_keys:
110110
aes_keys = config.aes_keys
111-
self._key_pair = KeyPair(aes_keys["private"], aes_keys["public"])
111+
self._key_pair = KeyPair.create_from_der_keys(
112+
aes_keys["private"], aes_keys["public"]
113+
)
112114
self._app_url = URL(f"http://{self._host}:{self._port}/app")
113115
self._token_url: URL | None = None
114116

@@ -277,14 +279,14 @@ async def _generate_key_pair_payload(self) -> AsyncGenerator:
277279
if not self._key_pair:
278280
kp = KeyPair.create_key_pair()
279281
self._config.aes_keys = {
280-
"private": kp.get_private_key(),
281-
"public": kp.get_public_key(),
282+
"private": kp.private_key_der_b64,
283+
"public": kp.public_key_der_b64,
282284
}
283285
self._key_pair = kp
284286

285287
pub_key = (
286288
"-----BEGIN PUBLIC KEY-----\n"
287-
+ self._key_pair.get_public_key() # type: ignore[union-attr]
289+
+ self._key_pair.public_key_der_b64 # type: ignore[union-attr]
288290
+ "\n-----END PUBLIC KEY-----\n"
289291
)
290292
handshake_params = {"key": pub_key}
@@ -392,18 +394,11 @@ class AesEncyptionSession:
392394
"""Class for an AES encryption session."""
393395

394396
@staticmethod
395-
def create_from_keypair(handshake_key: str, keypair):
397+
def create_from_keypair(handshake_key: str, keypair: KeyPair):
396398
"""Create the encryption session."""
397-
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
398-
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
399+
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode())
399400

400-
private_key = cast(
401-
rsa.RSAPrivateKey,
402-
serialization.load_der_private_key(private_key_data, None, None),
403-
)
404-
key_and_iv = private_key.decrypt(
405-
handshake_key_bytes, asymmetric_padding.PKCS1v15()
406-
)
401+
key_and_iv = keypair.decrypt_handshake_key(handshake_key_bytes)
407402
if key_and_iv is None:
408403
raise ValueError("Decryption failed!")
409404

@@ -438,30 +433,59 @@ def create_key_pair(key_size: int = 1024):
438433
"""Create a key pair."""
439434
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
440435
public_key = private_key.public_key()
436+
return KeyPair(private_key, public_key)
437+
438+
@staticmethod
439+
def create_from_der_keys(private_key_der_b64: str, public_key_der_b64: str):
440+
"""Create a key pair."""
441+
key_bytes = base64.b64decode(private_key_der_b64.encode())
442+
private_key = cast(
443+
rsa.RSAPrivateKey, serialization.load_der_private_key(key_bytes, None)
444+
)
445+
key_bytes = base64.b64decode(public_key_der_b64.encode())
446+
public_key = cast(
447+
rsa.RSAPublicKey, serialization.load_der_public_key(key_bytes, None)
448+
)
441449

442-
private_key_bytes = private_key.private_bytes(
450+
return KeyPair(private_key, public_key)
451+
452+
def __init__(self, private_key: rsa.RSAPrivateKey, public_key: rsa.RSAPublicKey):
453+
self.private_key = private_key
454+
self.public_key = public_key
455+
self.private_key_der_bytes = self.private_key.private_bytes(
443456
encoding=serialization.Encoding.DER,
444457
format=serialization.PrivateFormat.PKCS8,
445458
encryption_algorithm=serialization.NoEncryption(),
446459
)
447-
public_key_bytes = public_key.public_bytes(
460+
self.public_key_der_bytes = self.public_key.public_bytes(
448461
encoding=serialization.Encoding.DER,
449462
format=serialization.PublicFormat.SubjectPublicKeyInfo,
450463
)
464+
self.private_key_der_b64 = base64.b64encode(self.private_key_der_bytes).decode()
465+
self.public_key_der_b64 = base64.b64encode(self.public_key_der_bytes).decode()
451466

452-
return KeyPair(
453-
private_key=base64.b64encode(private_key_bytes).decode("UTF-8"),
454-
public_key=base64.b64encode(public_key_bytes).decode("UTF-8"),
467+
def get_public_pem(self) -> bytes:
468+
"""Get public key in PEM encoding."""
469+
return self.public_key.public_bytes(
470+
encoding=serialization.Encoding.PEM,
471+
format=serialization.PublicFormat.SubjectPublicKeyInfo,
455472
)
456473

457-
def __init__(self, private_key: str, public_key: str):
458-
self.private_key = private_key
459-
self.public_key = public_key
460-
461-
def get_private_key(self) -> str:
462-
"""Get the private key."""
463-
return self.private_key
464-
465-
def get_public_key(self) -> str:
466-
"""Get the public key."""
467-
return self.public_key
474+
def decrypt_handshake_key(self, encrypted_key: bytes) -> bytes:
475+
"""Decrypt an aes handshake key."""
476+
decrypted = self.private_key.decrypt(
477+
encrypted_key, asymmetric_padding.PKCS1v15()
478+
)
479+
return decrypted
480+
481+
def decrypt_discovery_key(self, encrypted_key: bytes) -> bytes:
482+
"""Decrypt an aes discovery key."""
483+
decrypted = self.private_key.decrypt(
484+
encrypted_key,
485+
asymmetric_padding.OAEP(
486+
mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA1()), # noqa: S303
487+
algorithm=hashes.SHA1(), # noqa: S303
488+
label=None,
489+
),
490+
)
491+
return decrypted

kasa/cli/discover.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import asyncio
6+
from pprint import pformat as pf
67

78
import asyncclick as click
89
from pydantic.v1 import ValidationError
@@ -28,6 +29,7 @@ async def discover(ctx):
2829
password = ctx.parent.params["password"]
2930
discovery_timeout = ctx.parent.params["discovery_timeout"]
3031
timeout = ctx.parent.params["timeout"]
32+
host = ctx.parent.params["host"]
3133
port = ctx.parent.params["port"]
3234

3335
credentials = Credentials(username, password) if username and password else None
@@ -49,8 +51,6 @@ async def print_unsupported(unsupported_exception: UnsupportedDeviceError):
4951
echo(f"\t{unsupported_exception}")
5052
echo()
5153

52-
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
53-
5454
from .device import state
5555

5656
async def print_discovered(dev: Device):
@@ -68,6 +68,18 @@ async def print_discovered(dev: Device):
6868
discovered[dev.host] = dev.internal_state
6969
echo()
7070

71+
if host:
72+
echo(f"Discovering device {host} for {discovery_timeout} seconds")
73+
return await Discover.discover_single(
74+
host,
75+
port=port,
76+
credentials=credentials,
77+
timeout=timeout,
78+
discovery_timeout=discovery_timeout,
79+
on_unsupported=print_unsupported,
80+
)
81+
82+
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
7183
discovered_devices = await Discover.discover(
7284
target=target,
7385
discovery_timeout=discovery_timeout,
@@ -113,21 +125,31 @@ def _echo_discovery_info(discovery_info):
113125
_echo_dictionary(discovery_info)
114126
return
115127

128+
def _conditional_echo(label, value):
129+
if value:
130+
ws = " " * (19 - len(label))
131+
echo(f"\t{label}:{ws}{value}")
132+
116133
echo("\t[bold]== Discovery Result ==[/bold]")
117-
echo(f"\tDevice Type: {dr.device_type}")
118-
echo(f"\tDevice Model: {dr.device_model}")
119-
echo(f"\tIP: {dr.ip}")
120-
echo(f"\tMAC: {dr.mac}")
121-
echo(f"\tDevice Id (hash): {dr.device_id}")
122-
echo(f"\tOwner (hash): {dr.owner}")
123-
echo(f"\tHW Ver: {dr.hw_ver}")
124-
echo(f"\tSupports IOT Cloud: {dr.is_support_iot_cloud}")
125-
echo(f"\tOBD Src: {dr.obd_src}")
126-
echo(f"\tFactory Default: {dr.factory_default}")
127-
echo(f"\tEncrypt Type: {dr.mgt_encrypt_schm.encrypt_type}")
128-
echo(f"\tSupports HTTPS: {dr.mgt_encrypt_schm.is_support_https}")
129-
echo(f"\tHTTP Port: {dr.mgt_encrypt_schm.http_port}")
130-
echo(f"\tLV (Login Level): {dr.mgt_encrypt_schm.lv}")
134+
_conditional_echo("Device Type", dr.device_type)
135+
_conditional_echo("Device Model", dr.device_model)
136+
_conditional_echo("Device Name", dr.device_name)
137+
_conditional_echo("IP", dr.ip)
138+
_conditional_echo("MAC", dr.mac)
139+
_conditional_echo("Device Id (hash)", dr.device_id)
140+
_conditional_echo("Owner (hash)", dr.owner)
141+
_conditional_echo("FW Ver", dr.firmware_version)
142+
_conditional_echo("HW Ver", dr.hw_ver)
143+
_conditional_echo("HW Ver", dr.hardware_version)
144+
_conditional_echo("Supports IOT Cloud", dr.is_support_iot_cloud)
145+
_conditional_echo("OBD Src", dr.owner)
146+
_conditional_echo("Factory Default", dr.factory_default)
147+
_conditional_echo("Encrypt Type", dr.mgt_encrypt_schm.encrypt_type)
148+
_conditional_echo("Encrypt Type", dr.encrypt_type)
149+
_conditional_echo("Supports HTTPS", dr.mgt_encrypt_schm.is_support_https)
150+
_conditional_echo("HTTP Port", dr.mgt_encrypt_schm.http_port)
151+
_conditional_echo("Encrypt info", pf(dr.encrypt_info) if dr.encrypt_info else None)
152+
_conditional_echo("Decrypted", pf(dr.decrypted_data) if dr.decrypted_data else None)
131153

132154

133155
async def find_host_from_alias(alias, target="255.255.255.255", timeout=1, attempts=3):

kasa/cli/main.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def _legacy_type_to_class(_type):
158158
type=click.Choice(ENCRYPT_TYPES, case_sensitive=False),
159159
)
160160
@click.option(
161+
"-df",
161162
"--device-family",
162163
envvar="KASA_DEVICE_FAMILY",
163164
default="SMART.TAPOPLUG",
@@ -182,7 +183,7 @@ def _legacy_type_to_class(_type):
182183
@click.option(
183184
"--discovery-timeout",
184185
envvar="KASA_DISCOVERY_TIMEOUT",
185-
default=5,
186+
default=10,
186187
required=False,
187188
show_default=True,
188189
help="Timeout for discovery.",
@@ -326,15 +327,11 @@ async def cli(
326327
dev = await Device.connect(config=config)
327328
device_updated = True
328329
else:
329-
from kasa.discover import Discover
330+
from .discover import discover
330331

331-
dev = await Discover.discover_single(
332-
host,
333-
port=port,
334-
credentials=credentials,
335-
timeout=timeout,
336-
discovery_timeout=discovery_timeout,
337-
)
332+
dev = await ctx.invoke(discover)
333+
if not dev:
334+
error(f"Unable to create device for {host}")
338335

339336
# Skip update on specific commands, or if device factory,
340337
# that performs an update was used for the device.

0 commit comments

Comments
 (0)