Skip to content

Commit d316004

Browse files
committed
Add API for LISTEN; document get_settings()
1 parent 2b17b14 commit d316004

File tree

10 files changed

+198
-31
lines changed

10 files changed

+198
-31
lines changed

asyncpg/_testbase.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818

1919
from asyncpg import cluster as pg_cluster
20+
from asyncpg import pool as pg_pool
2021

2122

2223
@contextlib.contextmanager
@@ -127,6 +128,11 @@ def setUp(self):
127128
'log_connections': 'on'
128129
})
129130

131+
def create_pool(self, **kwargs):
132+
addr = self.cluster.get_connection_addr()
133+
return pg_pool.create_pool(host=addr[0], port=addr[1],
134+
loop=self.loop, **kwargs)
135+
130136

131137
class ConnectedTestCase(ClusterTestCase):
132138

asyncpg/connection.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class Connection:
2929
__slots__ = ('_protocol', '_transport', '_loop', '_types_stmt',
3030
'_type_by_name_stmt', '_top_xact', '_uid', '_aborted',
3131
'_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close',
32-
'_addr', '_opts', '_command_timeout')
32+
'_addr', '_opts', '_command_timeout', '_listeners')
3333

3434
def __init__(self, protocol, transport, loop, addr, opts, *,
3535
statement_cache_size, command_timeout):
@@ -51,7 +51,44 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
5151

5252
self._command_timeout = command_timeout
5353

54+
self._listeners = {}
55+
56+
async def add_listener(self, channel, callback):
57+
"""Add a listener for Postgres notifications.
58+
59+
:param str channel: Channel to listen on.
60+
:param callable callback:
61+
A callable receiving the following arguments:
62+
**connection**: a Connection the callback is registered with;
63+
**pid**: PID of the Postgres server that sent the notification;
64+
**channel**: name of the channel the notification was sent to;
65+
**payload**: the payload.
66+
"""
67+
if channel not in self._listeners:
68+
await self.fetch('LISTEN {}'.format(channel))
69+
self._listeners[channel] = set()
70+
self._listeners[channel].add(callback)
71+
72+
async def remove_listener(self, channel, callback):
73+
"""Remove a listening callback on the specified channel."""
74+
if channel not in self._listeners:
75+
return
76+
if callback not in self._listeners[channel]:
77+
return
78+
self._listeners[channel].remove(callback)
79+
if not self._listeners[channel]:
80+
del self._listeners[channel]
81+
await self.fetch('UNLISTEN {}'.format(channel))
82+
83+
def get_server_pid(self):
84+
"""Return the PID of the Postgres server the connection is bound to."""
85+
return self._protocol.get_server_pid()
86+
5487
def get_settings(self):
88+
"""Return connection settings.
89+
90+
:return: :class:`~asyncpg.ConnectionSettings`.
91+
"""
5592
return self._protocol.get_settings()
5693

5794
def transaction(self, *, isolation='read_committed', readonly=False,
@@ -269,17 +306,20 @@ async def close(self):
269306
if self.is_closed():
270307
return
271308
self._close_stmts()
309+
self._listeners = {}
272310
self._aborted = True
273311
protocol = self._protocol
274312
await protocol.close()
275313

276314
def terminate(self):
277315
"""Terminate the connection without waiting for pending data."""
278316
self._close_stmts()
317+
self._listeners = {}
279318
self._aborted = True
280319
self._protocol.abort()
281320

282321
async def reset(self):
322+
self._listeners = {}
283323
await self.execute('''
284324
SET SESSION AUTHORIZATION DEFAULT;
285325
RESET ALL;
@@ -351,6 +391,20 @@ async def cancel():
351391

352392
self._loop.create_task(cancel())
353393

394+
def _notify(self, pid, channel, payload):
395+
if channel not in self._listeners:
396+
return
397+
398+
for cb in self._listeners[channel]:
399+
try:
400+
cb(self, pid, channel, payload)
401+
except Exception as ex:
402+
self._loop.call_exception_handler({
403+
'message': 'Unhandled exception in asyncpg notification '
404+
'listener callback {!r}'.format(cb),
405+
'exception': ex
406+
})
407+
354408

355409
async def connect(dsn=None, *,
356410
host=None, port=None,

asyncpg/protocol/coreproto.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ cdef class CoreProtocol:
101101

102102
cdef _parse_msg_authentication(self)
103103
cdef _parse_msg_parameter_status(self)
104+
cdef _parse_msg_notification(self)
104105
cdef _parse_msg_backend_key_data(self)
105106
cdef _parse_msg_ready_for_query(self)
106107
cdef _parse_data_msgs(self)
@@ -140,5 +141,6 @@ cdef class CoreProtocol:
140141
cdef _decode_row(self, const char* buf, int32_t buf_len)
141142

142143
cdef _on_result(self)
144+
cdef _on_notification(self, pid, channel, payload)
143145
cdef _set_server_parameter(self, name, val)
144146
cdef _on_connection_lost(self, exc)

asyncpg/protocol/coreproto.pyx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ cdef class CoreProtocol:
4646
# ParameterStatus
4747
self._parse_msg_parameter_status()
4848
continue
49+
elif mtype == b'A':
50+
# NotificationResponse
51+
self._parse_msg_notification()
52+
continue
4953

5054
if state == PROTOCOL_AUTH:
5155
self._process__auth(mtype)
@@ -304,6 +308,12 @@ cdef class CoreProtocol:
304308

305309
self._set_server_parameter(name, val)
306310

311+
cdef _parse_msg_notification(self):
312+
pid = self.buffer.read_int32()
313+
channel = self.buffer.read_cstr().decode(self.encoding)
314+
payload = self.buffer.read_cstr().decode(self.encoding)
315+
self._on_notification(pid, channel, payload)
316+
307317
cdef _parse_msg_authentication(self):
308318
cdef:
309319
int32_t status
@@ -617,6 +627,9 @@ cdef class CoreProtocol:
617627
cdef _on_result(self):
618628
pass
619629

630+
cdef _on_notification(self, pid, channel, payload):
631+
pass
632+
620633
cdef _on_connection_lost(self, exc):
621634
pass
622635

asyncpg/protocol/protocol.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ cdef class BaseProtocol(CoreProtocol):
113113
def set_connection(self, connection):
114114
self.connection = connection
115115

116+
def get_server_pid(self):
117+
return self.backend_pid
118+
116119
def get_settings(self):
117120
return self.settings
118121

@@ -445,6 +448,9 @@ cdef class BaseProtocol(CoreProtocol):
445448
self.last_query = None
446449
self.return_extra = False
447450

451+
cdef _on_notification(self, pid, channel, payload):
452+
self.connection._notify(pid, channel, payload)
453+
448454
cdef _on_connection_lost(self, exc):
449455
if self.closing:
450456
# The connection was lost because

asyncpg/protocol/settings.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ cdef class ConnectionSettings:
3232
cpdef inline register_data_types(self, types):
3333
self._data_codecs.add_types(types)
3434

35-
cpdef inline add_python_codec(self, typeoid, typename, typeschema, typekind,
36-
encoder, decoder, binary):
35+
cpdef inline add_python_codec(self, typeoid, typename, typeschema,
36+
typekind, encoder, decoder, binary):
3737
self._data_codecs.add_python_codec(typeoid, typename, typeschema,
3838
typekind, encoder, decoder, binary)
3939

40-
cpdef inline set_builtin_type_codec(self, typeoid, typename, typeschema, typekind,
41-
alias_to):
40+
cpdef inline set_builtin_type_codec(self, typeoid, typename, typeschema,
41+
typekind, alias_to):
4242
self._data_codecs.set_builtin_type_codec(typeoid, typename, typeschema,
4343
typekind, alias_to)
4444

docs/api/index.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,24 @@ of values either by a numeric index or by a field name:
301301

302302
Return an iterator over ``(field, value)`` pairs.
303303

304+
305+
.. class:: ConnectionSettings()
306+
307+
A read-only collection of Connection settings.
308+
309+
.. describe:: settings.setting_name
310+
311+
Return the value of the "setting_name" setting. Raises an
312+
``AttributeError`` if the setting is not defined.
313+
314+
Example:
315+
316+
.. code-block:: pycon
317+
318+
>>> connection.get_settings().client_encoding
319+
'UTF8'
320+
321+
304322
Introspection
305323
=============
306324

tests/test_connect.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ async def test_connect_1(self):
2222
await asyncpg.connect(user="__does_not_exist__", loop=self.loop)
2323

2424

25+
class TestSettings(tb.ConnectedTestCase):
26+
27+
async def test_get_settings_01(self):
28+
self.assertEqual(
29+
self.con.get_settings().client_encoding,
30+
'UTF8')
31+
32+
2533
class TestAuthentication(tb.ConnectedTestCase):
2634
def setUp(self):
2735
super().setUp()

tests/test_listeners.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (C) 2016-present the ayncpg authors and contributors
2+
# <see AUTHORS file>
3+
#
4+
# This module is part of asyncpg and is released under
5+
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
6+
7+
8+
import asyncio
9+
10+
from asyncpg import _testbase as tb
11+
12+
13+
class TestListeners(tb.ClusterTestCase):
14+
15+
async def test_listen_01(self):
16+
async with self.create_pool(database='postgres') as pool:
17+
async with pool.acquire() as con:
18+
19+
q1 = asyncio.Queue(loop=self.loop)
20+
q2 = asyncio.Queue(loop=self.loop)
21+
22+
def listener1(*args):
23+
q1.put_nowait(args)
24+
25+
def listener2(*args):
26+
q2.put_nowait(args)
27+
28+
await con.add_listener('test', listener1)
29+
await con.add_listener('test', listener2)
30+
31+
await con.execute("NOTIFY test, 'aaaa'")
32+
33+
self.assertEqual(
34+
await q1.get(),
35+
(con, con.get_server_pid(), 'test', 'aaaa'))
36+
self.assertEqual(
37+
await q2.get(),
38+
(con, con.get_server_pid(), 'test', 'aaaa'))
39+
40+
await con.remove_listener('test', listener2)
41+
42+
await con.execute("NOTIFY test, 'aaaa'")
43+
44+
self.assertEqual(
45+
await q1.get(),
46+
(con, con.get_server_pid(), 'test', 'aaaa'))
47+
with self.assertRaises(asyncio.TimeoutError):
48+
await asyncio.wait_for(q2.get(),
49+
timeout=0.05, loop=self.loop)
50+
51+
await con.reset()
52+
await con.remove_listener('test', listener1)
53+
await con.execute("NOTIFY test, 'aaaa'")
54+
55+
with self.assertRaises(asyncio.TimeoutError):
56+
await asyncio.wait_for(q1.get(),
57+
timeout=0.05, loop=self.loop)
58+
with self.assertRaises(asyncio.TimeoutError):
59+
await asyncio.wait_for(q2.get(),
60+
timeout=0.05, loop=self.loop)
61+
62+
async def test_listen_02(self):
63+
async with self.create_pool(database='postgres') as pool:
64+
async with pool.acquire() as con1, pool.acquire() as con2:
65+
66+
q1 = asyncio.Queue(loop=self.loop)
67+
68+
def listener1(*args):
69+
q1.put_nowait(args)
70+
71+
await con1.add_listener('ipc', listener1)
72+
await con2.execute("NOTIFY ipc, 'hello'")
73+
74+
self.assertEqual(
75+
await q1.get(),
76+
(con1, con2.get_server_pid(), 'ipc', 'hello'))

tests/test_pool.py

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77

88
import asyncio
9-
import asyncpg
109

1110
from asyncpg import _testbase as tb
1211

@@ -16,11 +15,8 @@ class TestPool(tb.ClusterTestCase):
1615
async def test_pool_01(self):
1716
for n in {1, 3, 5, 10, 20, 100}:
1817
with self.subTest(tasksnum=n):
19-
addr = self.cluster.get_connection_addr()
20-
pool = await asyncpg.create_pool(host=addr[0], port=addr[1],
21-
database='postgres',
22-
loop=self.loop, min_size=5,
23-
max_size=10)
18+
pool = await self.create_pool(database='postgres',
19+
min_size=5, max_size=10)
2420

2521
async def worker():
2622
con = await pool.acquire()
@@ -34,11 +30,8 @@ async def worker():
3430
async def test_pool_02(self):
3531
for n in {1, 3, 5, 10, 20, 100}:
3632
with self.subTest(tasksnum=n):
37-
addr = self.cluster.get_connection_addr()
38-
async with asyncpg.create_pool(host=addr[0], port=addr[1],
39-
database='postgres',
40-
loop=self.loop, min_size=5,
41-
max_size=5) as pool:
33+
async with self.create_pool(database='postgres',
34+
min_size=5, max_size=5) as pool:
4235

4336
async def worker():
4437
con = await pool.acquire(timeout=1)
@@ -49,11 +42,8 @@ async def worker():
4942
await asyncio.gather(*tasks, loop=self.loop)
5043

5144
async def test_pool_03(self):
52-
addr = self.cluster.get_connection_addr()
53-
pool = await asyncpg.create_pool(host=addr[0], port=addr[1],
54-
database='postgres',
55-
loop=self.loop, min_size=1,
56-
max_size=1)
45+
pool = await self.create_pool(database='postgres',
46+
min_size=1, max_size=1)
5747

5848
con = await pool.acquire(timeout=1)
5949
with self.assertRaises(asyncio.TimeoutError):
@@ -63,11 +53,8 @@ async def test_pool_03(self):
6353
del con
6454

6555
async def test_pool_04(self):
66-
addr = self.cluster.get_connection_addr()
67-
pool = await asyncpg.create_pool(host=addr[0], port=addr[1],
68-
database='postgres',
69-
loop=self.loop, min_size=1,
70-
max_size=1)
56+
pool = await self.create_pool(database='postgres',
57+
min_size=1, max_size=1)
7158

7259
con = await pool.acquire(timeout=0.1)
7360
con.terminate()
@@ -84,11 +71,8 @@ async def test_pool_04(self):
8471
async def test_pool_05(self):
8572
for n in {1, 3, 5, 10, 20, 100}:
8673
with self.subTest(tasksnum=n):
87-
addr = self.cluster.get_connection_addr()
88-
pool = await asyncpg.create_pool(host=addr[0], port=addr[1],
89-
database='postgres',
90-
loop=self.loop, min_size=5,
91-
max_size=10)
74+
pool = await self.create_pool(database='postgres',
75+
min_size=5, max_size=10)
9276

9377
async def worker():
9478
async with pool.acquire() as con:

0 commit comments

Comments
 (0)