Skip to content

Commit 5cbce91

Browse files
committed
Support multiple devices
1 parent 8a6dea5 commit 5cbce91

File tree

3 files changed

+93
-33
lines changed

3 files changed

+93
-33
lines changed

pslab/sciencelab.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,13 @@ class ScienceLab(SerialHandler):
3232
nrf : pslab.peripherals.NRF24L01
3333
"""
3434

35-
def __init__(self):
36-
super().__init__()
35+
def __init__(
36+
self,
37+
port: str = None,
38+
baudrate: int = 1000000,
39+
timeout: float = 1.0,
40+
):
41+
super().__init__(port, baudrate, timeout)
3742
self.logic_analyzer = LogicAnalyzer(device=self)
3843
self.oscilloscope = Oscilloscope(device=self)
3944
self.waveform_generator = WaveformGenerator(device=self)
@@ -210,10 +215,10 @@ def _read_program_address(self, address: int):
210215
return data
211216

212217
def _device_id(self):
213-
a = self.read_program_address(0x800FF8)
214-
b = self.read_program_address(0x800FFA)
215-
c = self.read_program_address(0x800FFC)
216-
d = self.read_program_address(0x800FFE)
218+
a = self._read_program_address(0x800FF8)
219+
b = self._read_program_address(0x800FFA)
220+
c = self._read_program_address(0x800FFC)
221+
d = self._read_program_address(0x800FFE)
217222
val = d | (c << 16) | (b << 32) | (a << 48)
218223
return val
219224

pslab/serial_handler.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,40 @@
2323
logger = logging.getLogger(__name__)
2424

2525

26+
def detect():
27+
"""Detect connected PSLab devices.
28+
29+
Returns
30+
-------
31+
devices : dict of str: str
32+
Dictionary containing port name as keys and device version on that
33+
port as values.
34+
"""
35+
regex = []
36+
37+
for vid, pid in zip(SerialHandler._USB_VID, SerialHandler._USB_PID):
38+
regex.append(f"{vid:04x}:{pid:04x}")
39+
40+
regex = "(" + "|".join(regex) + ")"
41+
port_info_generator = list_ports.grep(regex)
42+
pslab_devices = {}
43+
44+
for port_info in port_info_generator:
45+
version = _get_version(port_info.device)
46+
if any(expected in version for expected in ["PSLab", "CSpark"]):
47+
pslab_devices[port_info.device] = version
48+
49+
return pslab_devices
50+
51+
52+
def _get_version(port: str) -> str:
53+
interface = serial.Serial(port=port, baudrate=1e6, timeout=1)
54+
interface.write(CP.COMMON)
55+
interface.write(CP.GET_VERSION)
56+
version = interface.readline()
57+
return version.decode("utf-8")
58+
59+
2660
class SerialHandler:
2761
"""Provides methods for communicating with the PSLab hardware.
2862
@@ -98,9 +132,11 @@ def connect(
98132
Parameters
99133
----------
100134
port : str, optional
101-
The name of the port to which the PSLab is connected as a string. On
102-
Posix this is a path, e.g. "/dev/ttyACM0". On Windows, it's a numbered
103-
COM port, e.g. "COM5". Will be autodetected if not specified.
135+
The name of the port to which the PSLab is connected as a string.
136+
On Posix this is a path, e.g. "/dev/ttyACM0". On Windows, it's a
137+
numbered COM port, e.g. "COM5". Will be autodetected if not
138+
specified. If multiple PSLab devices are connected, port must be
139+
specified.
104140
baudrate : int, optional
105141
Symbol rate in bit/s. The default value is 1000000.
106142
timeout : float, optional
@@ -111,6 +147,8 @@ def connect(
111147
------
112148
SerialException
113149
If connection could not be established.
150+
RuntimeError
151+
If ultiple devices are connected and no port was specified.
114152
"""
115153
# serial.Serial opens automatically if port is not None.
116154
self.interface = serial.Serial(
@@ -119,28 +157,31 @@ def connect(
119157
timeout=timeout,
120158
write_timeout=timeout,
121159
)
160+
pslab_devices = detect()
122161

123162
if self.interface.is_open:
124163
# User specified a port.
125164
version = self.get_version()
126165
else:
127-
regex = []
128-
for vid, pid in zip(self._USB_VID, self._USB_PID):
129-
regex.append(f"{vid:04x}:{pid:04x}")
130-
131-
regex = "(" + "|".join(regex) + ")"
132-
port_info_generator = list_ports.grep(regex)
133-
134-
for port_info in port_info_generator:
135-
self.interface.port = port_info.device
166+
if len(pslab_devices) == 1:
167+
self.interface.port = list(pslab_devices.keys())[0]
136168
self.interface.open()
137169
version = self.get_version()
138-
if any(expected in version for expected in ["PSLab", "CSpark"]):
139-
break
170+
elif len(pslab_devices) > 1:
171+
found = ""
172+
173+
for port, version in pslab_devices.items():
174+
found += f"{port}: {version}"
175+
176+
raise RuntimeError(
177+
"Multiple PSLab devices found:\n"
178+
f"{found}"
179+
"Please choose a device by specifying a port."
180+
)
140181
else:
141182
version = ""
142183

143-
if any(expected in version for expected in ["PSLab", "CSpark"]):
184+
if self.interface.port in pslab_devices:
144185
self.version = version
145186
logger.info(f"Connected to {self.version} on {self.interface.port}.")
146187
else:
@@ -174,13 +215,11 @@ def reconnect(
174215
port = self.interface.port if port is None else port
175216
timeout = self.interface.timeout if timeout is None else timeout
176217

177-
self.interface = serial.Serial(
218+
self.connect(
178219
port=port,
179220
baudrate=baudrate,
180221
timeout=timeout,
181-
write_timeout=timeout,
182222
)
183-
self.connect()
184223

185224
def get_version(self) -> str:
186225
"""Query PSLab for its version and return it as a decoded string.

tests/test_serial_handler.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
from serial.tools.list_ports_common import ListPortInfo
44

55
import pslab.protocol as CP
6-
from pslab.serial_handler import SerialHandler
6+
from pslab.serial_handler import detect, SerialHandler
77

88
VERSION = "PSLab vMOCK\n"
99
PORT = "mock_port"
10+
PORT2 = "mock_port_2"
1011

1112

12-
def mock_ListPortInfo(found=True):
13+
def mock_ListPortInfo(found=True, multiple=False):
1314
if found:
14-
yield ListPortInfo(device=PORT)
15+
if multiple:
16+
yield from [ListPortInfo(device=PORT), ListPortInfo(device=PORT2)]
17+
else:
18+
yield ListPortInfo(device=PORT)
1519
else:
1620
return
1721

@@ -20,12 +24,14 @@ def mock_ListPortInfo(found=True):
2024
def mock_serial(mocker):
2125
serial_patch = mocker.patch("pslab.serial_handler.serial.Serial")
2226
serial_patch().readline.return_value = VERSION.encode()
27+
serial_patch().is_open = False
2328
return serial_patch
2429

2530

2631
@pytest.fixture
27-
def mock_handler(mocker, mock_serial):
32+
def mock_handler(mocker, mock_serial, mock_list_ports):
2833
mocker.patch("pslab.serial_handler.SerialHandler._check_udev")
34+
mock_list_ports.grep.return_value = mock_ListPortInfo()
2935
return SerialHandler()
3036

3137

@@ -34,28 +40,39 @@ def mock_list_ports(mocker):
3440
return mocker.patch("pslab.serial_handler.list_ports")
3541

3642

43+
def test_detect(mocker, mock_serial, mock_list_ports):
44+
mock_list_ports.grep.return_value = mock_ListPortInfo(multiple=True)
45+
assert len(detect()) == 2
46+
47+
3748
def test_connect_scan_port(mocker, mock_serial, mock_list_ports):
38-
mock_serial().is_open = False
3949
mock_list_ports.grep.return_value = mock_ListPortInfo()
4050
mocker.patch("pslab.serial_handler.SerialHandler._check_udev")
4151
SerialHandler()
4252
mock_serial().open.assert_called()
4353

4454

4555
def test_connect_scan_failure(mocker, mock_serial, mock_list_ports):
46-
mock_serial().is_open = False
4756
mock_list_ports.grep.return_value = mock_ListPortInfo(found=False)
4857
mocker.patch("pslab.serial_handler.SerialHandler._check_udev")
4958
with pytest.raises(SerialException):
5059
SerialHandler()
5160

5261

62+
def test_connect_multiple_connected(mocker, mock_serial, mock_list_ports):
63+
mock_list_ports.grep.return_value = mock_ListPortInfo(multiple=True)
64+
mocker.patch("pslab.serial_handler.SerialHandler._check_udev")
65+
with pytest.raises(RuntimeError):
66+
SerialHandler()
67+
68+
5369
def test_disconnect(mock_serial, mock_handler):
5470
mock_handler.disconnect()
5571
mock_serial().close.assert_called()
5672

5773

58-
def test_reconnect(mock_serial, mock_handler):
74+
def test_reconnect(mock_serial, mock_handler, mock_list_ports):
75+
mock_list_ports.grep.return_value = mock_ListPortInfo()
5976
mock_handler.reconnect()
6077
mock_serial().close.assert_called()
6178

@@ -67,10 +84,9 @@ def test_get_version(mock_serial, mock_handler):
6784

6885

6986
def test_get_ack_success(mock_serial, mock_handler):
70-
H = SerialHandler()
7187
success = 1
7288
mock_serial().read.return_value = CP.Byte.pack(success)
73-
assert H.get_ack() == success
89+
assert mock_handler.get_ack() == success
7490

7591

7692
def test_get_ack_failure(mock_serial, mock_handler):

0 commit comments

Comments
 (0)