Skip to content

Commit 6b68f96

Browse files
authored
Refactor to support push updates (#49)
- Fixes the assumption that address were mac addresses on MacOS they are UUIDs - Adds the ability to consume an update from the scanner that is always running
1 parent 85cfd3c commit 6b68f96

File tree

1 file changed

+109
-103
lines changed

1 file changed

+109
-103
lines changed

switchbot/__init__.py

Lines changed: 109 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
import asyncio
55
import binascii
66
import logging
7+
from dataclasses import dataclass
78
from typing import Any
89
from uuid import UUID
910

1011
import bleak
12+
from bleak.backends.device import BLEDevice
13+
from bleak.backends.scanner import AdvertisementData
1114

1215
DEFAULT_RETRY_COUNT = 3
1316
DEFAULT_RETRY_TIMEOUT = 1
@@ -100,62 +103,78 @@ def _process_wosensorth(data: bytes) -> dict[str, object]:
100103
return _wosensorth_data
101104

102105

106+
@dataclass
107+
class SwitchBotAdvertisement:
108+
"""Switchbot advertisement."""
109+
110+
address: str
111+
data: dict[str, Any]
112+
device: BLEDevice
113+
114+
115+
def parse_advertisement_data(
116+
device: BLEDevice, advertisement_data: AdvertisementData
117+
) -> SwitchBotAdvertisement | None:
118+
"""Parse advertisement data."""
119+
_services = list(advertisement_data.service_data.values())
120+
if not _services:
121+
return
122+
_service_data = _services[0]
123+
_model = chr(_service_data[0] & 0b01111111)
124+
125+
supported_types: dict[str, dict[str, Any]] = {
126+
"H": {"modelName": "WoHand", "func": _process_wohand},
127+
"c": {"modelName": "WoCurtain", "func": _process_wocurtain},
128+
"T": {"modelName": "WoSensorTH", "func": _process_wosensorth},
129+
}
130+
131+
data = {
132+
"address": device.address, # MacOS uses UUIDs
133+
"rawAdvData": list(advertisement_data.service_data.values())[0],
134+
"data": {
135+
"rssi": device.rssi,
136+
},
137+
}
138+
139+
if _model in supported_types:
140+
141+
data.update(
142+
{
143+
"isEncrypted": bool(_service_data[0] & 0b10000000),
144+
"model": _model,
145+
"modelName": supported_types[_model]["modelName"],
146+
"data": supported_types[_model]["func"](_service_data),
147+
}
148+
)
149+
150+
data["data"]["rssi"] = device.rssi
151+
152+
return SwitchBotAdvertisement(device.address, data, device)
153+
154+
103155
class GetSwitchbotDevices:
104156
"""Scan for all Switchbot devices and return by type."""
105157

106158
def __init__(self, interface: int = 0) -> None:
107159
"""Get switchbot devices class constructor."""
108160
self._interface = f"hci{interface}"
109-
self._adv_data: dict[str, Any] = {}
161+
self._adv_data: dict[str, SwitchBotAdvertisement] = {}
110162

111163
def detection_callback(
112164
self,
113-
device: bleak.backends.device.BLEDevice,
114-
advertisement_data: bleak.backends.scanner.AdvertisementData,
165+
device: BLEDevice,
166+
advertisement_data: AdvertisementData,
115167
) -> None:
116-
"""BTLE adv scan callback."""
117-
_services = list(advertisement_data.service_data.values())
118-
if not _services:
119-
return
120-
_service_data = _services[0]
121-
122-
_device = device.address.replace(":", "").lower()
123-
_model = chr(_service_data[0] & 0b01111111)
124-
125-
supported_types: dict[str, dict[str, Any]] = {
126-
"H": {"modelName": "WoHand", "func": _process_wohand},
127-
"c": {"modelName": "WoCurtain", "func": _process_wocurtain},
128-
"T": {"modelName": "WoSensorTH", "func": _process_wosensorth},
129-
}
130-
131-
self._adv_data[_device] = {
132-
"mac_address": device.address.lower(),
133-
"rawAdvData": list(advertisement_data.service_data.values())[0],
134-
"data": {
135-
"rssi": device.rssi,
136-
},
137-
}
138-
139-
if _model in supported_types:
140-
141-
self._adv_data[_device].update(
142-
{
143-
"isEncrypted": bool(_service_data[0] & 0b10000000),
144-
"model": _model,
145-
"modelName": supported_types[_model]["modelName"],
146-
"data": supported_types[_model]["func"](_service_data),
147-
}
148-
)
149-
150-
self._adv_data[_device]["data"]["rssi"] = device.rssi
168+
discovery = parse_advertisement_data(device, advertisement_data)
169+
if discovery:
170+
self._adv_data[discovery.address] = discovery
151171

152172
async def discover(
153173
self, retry: int = DEFAULT_RETRY_COUNT, scan_timeout: int = DEFAULT_SCAN_TIMEOUT
154174
) -> dict:
155175
"""Find switchbot devices and their advertisement data."""
156176

157177
devices = None
158-
159178
devices = bleak.BleakScanner(
160179
# TODO: Find new UUIDs to filter on. For example, see
161180
# https://github.com/OpenWonderLabs/SwitchBotAPI-BLE/blob/4ad138bb09f0fbbfa41b152ca327a78c1d0b6ba9/devicetypes/meter.md
@@ -184,54 +203,44 @@ async def discover(
184203

185204
return self._adv_data
186205

187-
async def get_curtains(self) -> dict:
188-
"""Return all WoCurtain/Curtains devices with services data."""
206+
async def _get_devices_by_model(
207+
self,
208+
model: str,
209+
) -> dict:
210+
"""Get switchbot devices by type."""
189211
if not self._adv_data:
190212
await self.discover()
191213

192-
_curtain_devices = {
193-
device: data
194-
for device, data in self._adv_data.items()
195-
if data.get("model") == "c"
214+
return {
215+
address: adv
216+
for address, adv in self._adv_data.items()
217+
if adv.data.get("model") == model
196218
}
197219

198-
return _curtain_devices
220+
async def get_curtains(self) -> dict[str, SwitchBotAdvertisement]:
221+
"""Return all WoCurtain/Curtains devices with services data."""
222+
return await self._get_devices_by_model("c")
199223

200-
async def get_bots(self) -> dict[str, Any] | None:
224+
async def get_bots(self) -> dict[str, SwitchBotAdvertisement]:
201225
"""Return all WoHand/Bot devices with services data."""
202-
if not self._adv_data:
203-
await self.discover()
204-
205-
_bot_devices = {
206-
device: data
207-
for device, data in self._adv_data.items()
208-
if data.get("model") == "H"
209-
}
210-
211-
return _bot_devices
226+
return await self._get_devices_by_model("H")
212227

213-
async def get_tempsensors(self) -> dict[str, Any] | None:
228+
async def get_tempsensors(self) -> dict[str, SwitchBotAdvertisement]:
214229
"""Return all WoSensorTH/Temp sensor devices with services data."""
215-
if not self._adv_data:
216-
await self.discover()
217-
218-
_bot_temp = {
219-
device: data
220-
for device, data in self._adv_data.items()
221-
if data.get("model") == "T"
222-
}
223-
224-
return _bot_temp
230+
return await self._get_devices_by_model("T")
225231

226-
async def get_device_data(self, mac: str) -> dict[str, Any] | None:
232+
async def get_device_data(
233+
self, address: str
234+
) -> dict[str, SwitchBotAdvertisement] | None:
227235
"""Return data for specific device."""
228236
if not self._adv_data:
229237
await self.discover()
230238

231239
_switchbot_data = {
232240
device: data
233241
for device, data in self._adv_data.items()
234-
if data.get("mac_address") == mac
242+
# MacOS uses UUIDs instead of MAC addresses
243+
if data.get("address") == address
235244
}
236245

237246
return _switchbot_data
@@ -242,15 +251,15 @@ class SwitchbotDevice:
242251

243252
def __init__(
244253
self,
245-
mac: str,
254+
device: BLEDevice,
246255
password: str | None = None,
247256
interface: int = 0,
248257
**kwargs: Any,
249258
) -> None:
250259
"""Switchbot base class constructor."""
251260
self._interface = f"hci{interface}"
252-
self._mac = mac.replace("-", ":").lower()
253-
self._sb_adv_data: dict[str, Any] = {}
261+
self._device = device
262+
self._sb_adv_data: SwitchBotAdvertisement | None = None
254263
self._scan_timeout: int = kwargs.pop("scan_timeout", DEFAULT_SCAN_TIMEOUT)
255264
self._retry_count: int = kwargs.pop("retry_count", DEFAULT_RETRY_COUNT)
256265
if password is None or password == "":
@@ -279,13 +288,11 @@ async def _sendcommand(self, key: str, retry: int) -> bytes:
279288
notify_msg = b""
280289
_LOGGER.debug("Sending command to switchbot %s", command)
281290

282-
if len(self._mac.split(":")) != 6:
283-
raise ValueError("Expected MAC address, got %s" % repr(self._mac))
284-
285291
async with CONNECT_LOCK:
286292
try:
287293
async with bleak.BleakClient(
288-
address_or_ble_device=self._mac, timeout=float(self._scan_timeout)
294+
address_or_ble_device=self._device,
295+
timeout=float(self._scan_timeout),
289296
) as client:
290297
_LOGGER.debug("Connnected to switchbot: %s", client.is_connected)
291298

@@ -334,15 +341,24 @@ async def _sendcommand(self, key: str, retry: int) -> bytes:
334341
await asyncio.sleep(DEFAULT_RETRY_TIMEOUT)
335342
return await self._sendcommand(key, retry - 1)
336343

337-
def get_mac(self) -> str:
338-
"""Return mac address of device."""
339-
return self._mac
344+
def get_address(self) -> str:
345+
"""Return address of device."""
346+
return self._device.address
340347

341-
def get_battery_percent(self) -> Any:
342-
"""Return device battery level in percent."""
348+
def _get_adv_value(self, key: str) -> Any:
349+
"""Return value from advertisement data."""
343350
if not self._sb_adv_data:
344351
return None
345-
return self._sb_adv_data["data"]["battery"]
352+
return self._sb_adv_data.data["data"][key]
353+
354+
def get_battery_percent(self) -> Any:
355+
"""Return device battery level in percent."""
356+
return self._get_adv_value("battery")
357+
358+
def update_from_advertisement(self, advertisement: SwitchBotAdvertisement) -> None:
359+
"""Update device data from advertisement."""
360+
self._sb_adv_data = advertisement
361+
self._device = advertisement.device
346362

347363
async def get_device_data(
348364
self, retry: int = DEFAULT_RETRY_COUNT, interface: int | None = None
@@ -353,14 +369,12 @@ async def get_device_data(
353369
else:
354370
_interface = int(self._interface.replace("hci", ""))
355371

356-
dev_id = self._mac.replace(":", "")
357-
358372
_data = await GetSwitchbotDevices(interface=_interface).discover(
359373
retry=retry, scan_timeout=self._scan_timeout
360374
)
361375

362-
if _data.get(dev_id):
363-
self._sb_adv_data = _data[dev_id]
376+
if self._device.address in _data:
377+
self._sb_adv_data = _data[self._device.address]
364378

365379
return self._sb_adv_data
366380

@@ -493,20 +507,18 @@ async def get_basic_info(self) -> dict[str, Any] | None:
493507
def switch_mode(self) -> Any:
494508
"""Return true or false from cache."""
495509
# To get actual position call update() first.
496-
if not self._sb_adv_data.get("data"):
497-
return None
498-
return self._sb_adv_data["data"].get("switchMode")
510+
return self._get_adv_value("switchMode")
499511

500512
def is_on(self) -> Any:
501513
"""Return switch state from cache."""
502514
# To get actual position call update() first.
503-
if not self._sb_adv_data.get("data"):
515+
value = self._get_adv_value("isOn")
516+
if value is None:
504517
return None
505518

506519
if self._inverse:
507-
return not self._sb_adv_data["data"].get("isOn")
508-
509-
return self._sb_adv_data["data"].get("isOn")
520+
return not value
521+
return value
510522

511523

512524
class SwitchbotCurtain(SwitchbotDevice):
@@ -570,9 +582,7 @@ async def update(self, interface: int | None = None) -> None:
570582
def get_position(self) -> Any:
571583
"""Return cached position (0-100) of Curtain."""
572584
# To get actual position call update() first.
573-
if not self._sb_adv_data.get("data"):
574-
return None
575-
return self._sb_adv_data["data"].get("position")
585+
return self._get_adv_value("position")
576586

577587
async def get_basic_info(self) -> dict[str, Any] | None:
578588
"""Get device basic settings."""
@@ -676,9 +686,7 @@ async def get_extended_info_adv(self) -> dict[str, Any] | None:
676686
def get_light_level(self) -> Any:
677687
"""Return cached light level."""
678688
# To get actual light level call update() first.
679-
if not self._sb_adv_data.get("data"):
680-
return None
681-
return self._sb_adv_data["data"].get("lightLevel")
689+
return self._get_adv_value("lightLevel")
682690

683691
def is_reversed(self) -> bool:
684692
"""Return True if curtain position is opposite from SB data."""
@@ -687,6 +695,4 @@ def is_reversed(self) -> bool:
687695
def is_calibrated(self) -> Any:
688696
"""Return True curtain is calibrated."""
689697
# To get actual light level call update() first.
690-
if not self._sb_adv_data.get("data"):
691-
return None
692-
return self._sb_adv_data["data"].get("calibration")
698+
return self._get_adv_value("calibration")

0 commit comments

Comments
 (0)