1+ module Network.Transport.QUIC (
2+ createTransport ,
3+ QUICAddr (.. ),
4+ ) where
5+
6+ import Control.Concurrent.STM (atomically )
7+ import Control.Concurrent.STM.TQueue (
8+ TQueue ,
9+ newTQueueIO ,
10+ readTQueue ,
11+ writeTQueue ,
12+ )
13+ import Data.ByteString qualified as BS
14+ import Data.ByteString.Char8 qualified as BS8
15+ import Data.Set (Set )
16+ import Data.Set qualified as Set
17+ import Network.QUIC qualified as QUIC
18+ import Network.QUIC.Server (defaultServerConfig )
19+ import Network.QUIC.Server qualified as QUIC.Server
20+ import Network.Transport (ConnectionId , EndPoint (.. ), EndPointAddress (EndPointAddress ), Event (.. ), NewEndPointErrorCode , Transport (.. ), TransportError (.. ))
21+
22+ import Control.Concurrent (ThreadId , killThread , myThreadId )
23+ import Control.Exception (bracket )
24+ import Data.Foldable (traverse_ )
25+ import Data.Functor (($>) )
26+ import Data.IORef (IORef , newIORef , readIORef )
27+ import GHC.IORef (atomicModifyIORef'_ )
28+ import Network.Socket (HostName , ServiceName )
29+
30+ {- | Create a new Transport.
31+
32+ Only a single transport should be created per Haskell process
33+ (threads can, and should, create their own endpoints though).
34+ -}
35+ createTransport :: QUICAddr -> IO Transport
36+ createTransport quicAddr = do
37+ pure $ Transport (newEndpoint quicAddr) closeQUICTransport
38+
39+ data QUICAddr = QUICAddr
40+ { quicBindHost :: ! HostName
41+ , quicBindPort :: ! ServiceName
42+ }
43+
44+ newEndpoint :: QUICAddr -> IO (Either (TransportError NewEndPointErrorCode ) EndPoint )
45+ newEndpoint quicAddr = do
46+ eventQueue <- newTQueueIO
47+
48+ state <- EndpointState <$> newIORef mempty
49+
50+ QUIC.Server. run
51+ defaultServerConfig
52+ ( withQUICStream $
53+ -- TODO: create a bidirectional stream
54+ -- which can be re-used for sending
55+ \ stream ->
56+ -- We register which threads are actively receiving or sending
57+ -- data such that we can cleanly stop
58+ withThreadRegistered state $ do
59+ -- TODO: how to ensure positivity of ConnectionId? QUIC StreamID should be a 62 bit integer,
60+ -- so there's room to make it a positive 64 bit integer (ConnectionId ~ Word64)
61+ let connId = fromIntegral (QUIC. streamId stream)
62+ receiveLoop connId stream eventQueue
63+ )
64+
65+ pure . Right $
66+ EndPoint
67+ (atomically (readTQueue eventQueue))
68+ (encodeQUICAddr quicAddr)
69+ _
70+ _
71+ _
72+ (stopAllThreads state)
73+ where
74+ receiveLoop ::
75+ ConnectionId ->
76+ QUIC. Stream ->
77+ TQueue Event ->
78+ IO ()
79+ receiveLoop connId stream eventQueue = do
80+ incoming <- QUIC. recvStream stream 1024 -- TODO: variable length?
81+ -- TODO: check some state whether we should stop all connections
82+ if BS. null incoming
83+ then do
84+ atomically (writeTQueue eventQueue (ConnectionClosed connId))
85+ else do
86+ atomically (writeTQueue eventQueue (Received connId [incoming]))
87+ receiveLoop connId stream eventQueue
88+
89+ withQUICStream :: (QUIC. Stream -> IO a ) -> QUIC. Connection -> IO a
90+ withQUICStream f conn =
91+ bracket
92+ (QUIC. acceptStream conn)
93+ (\ stream -> QUIC. closeStream stream >> QUIC.Server. stop conn)
94+ f
95+
96+ encodeQUICAddr :: QUICAddr -> EndPointAddress
97+ encodeQUICAddr (QUICAddr host port) = EndPointAddress (BS8. pack $ host <> " :" <> port)
98+
99+ closeQUICTransport :: IO ()
100+ closeQUICTransport = pure ()
101+
102+ {- | We keep track of all threads actively listening on QUIC streams
103+ so that we can cleanly stop these threads when closing the endpoint.
104+
105+ See 'withThreadRegistered' for a combinator which automatically keeps
106+ track of these threads
107+ -}
108+ newtype EndpointState = EndpointState
109+ { threads :: IORef (Set ThreadId )
110+ }
111+
112+ withThreadRegistered :: EndpointState -> IO a -> IO a
113+ withThreadRegistered state f =
114+ bracket
115+ registerThread
116+ unregisterThread
117+ (const f)
118+ where
119+ registerThread =
120+ myThreadId
121+ >>= \ tid ->
122+ atomicModifyIORef'_ (threads state) (Set. insert tid)
123+ $> tid
124+
125+ unregisterThread tid =
126+ atomicModifyIORef'_ (threads state) (Set. insert tid)
127+
128+ stopAllThreads :: EndpointState -> IO ()
129+ stopAllThreads (EndpointState tds) =
130+ readIORef tds >>= traverse_ killThread
0 commit comments