1
+ {-# LANGUAGE LambdaCase #-}
2
+ {-# LANGUAGE OverloadedStrings #-}
3
+ {-# LANGUAGE ScopedTypeVariables #-}
4
+
5
+ module Network.Transport.QUIC.Internal (
6
+ createTransport ,
7
+ QUICAddr (.. ),
8
+ encodeQUICAddr ,
9
+ decodeQUICAddr ,
10
+
11
+ -- * Re-export to generate credentials
12
+ Credential ,
13
+ credentialLoadX509 ,
14
+ ) where
15
+
16
+ import Control.Concurrent (ThreadId , forkFinally , killThread , myThreadId )
17
+ import Control.Concurrent.STM (atomically )
18
+ import Control.Concurrent.STM.TQueue (
19
+ TQueue ,
20
+ newTQueueIO ,
21
+ readTQueue ,
22
+ writeTQueue ,
23
+ )
24
+ import Control.Exception (bracket , try )
25
+ import Control.Monad (void )
26
+ import Data.Bifunctor (first )
27
+ import Data.ByteString (StrictByteString )
28
+ import Data.ByteString qualified as BS
29
+ import Data.Foldable (traverse_ )
30
+ import Data.Functor (($>) , (<&>) )
31
+ import Data.IORef (IORef , newIORef , readIORef , writeIORef )
32
+ import Data.List.NonEmpty (NonEmpty )
33
+ import Data.Set (Set )
34
+ import Data.Set qualified as Set
35
+ import GHC.IORef (atomicModifyIORef'_ )
36
+ import Network.QUIC (Stream )
37
+ import Network.QUIC qualified as QUIC
38
+ import Network.QUIC.Client qualified as QUIC.Client
39
+ import Network.QUIC.Server qualified as QUIC.Server
40
+ import Network.TLS (Credential )
41
+ import Network.Transport (
42
+ ConnectErrorCode (ConnectNotFound ),
43
+ ConnectHints ,
44
+ Connection (.. ),
45
+ ConnectionId ,
46
+ EndPoint (.. ),
47
+ EndPointAddress ,
48
+ Event (.. ),
49
+ EventErrorCode (EventEndPointFailed ),
50
+ NewEndPointErrorCode ,
51
+ NewMulticastGroupErrorCode (NewMulticastGroupUnsupported ),
52
+ Reliability ,
53
+ ResolveMulticastGroupErrorCode (ResolveMulticastGroupUnsupported ),
54
+ SendErrorCode (.. ),
55
+ Transport (.. ),
56
+ TransportError (.. ),
57
+ )
58
+ import Network.Transport.QUIC.Internal.Configuration (credentialLoadX509 , mkClientConfig , mkServerConfig )
59
+ import Network.Transport.QUIC.Internal.QUICAddr (QUICAddr (.. ), decodeQUICAddr , encodeQUICAddr )
60
+ import Network.Transport.QUIC.Internal.TransportState (TransportState , newTransportState , registerEndpoint , traverseTransportState )
61
+
62
+ {- | Create a new Transport.
63
+
64
+ Only a single transport should be created per Haskell process
65
+ (threads can, and should, create their own endpoints though).
66
+ -}
67
+ createTransport ::
68
+ QUICAddr ->
69
+ NonEmpty Credential ->
70
+ IO Transport
71
+ createTransport quicAddr creds = do
72
+ transportState <- newTransportState
73
+ pure $
74
+ Transport
75
+ (newEndpoint transportState quicAddr creds)
76
+ (closeQUICTransport transportState)
77
+
78
+ newEndpoint ::
79
+ TransportState ->
80
+ QUICAddr ->
81
+ NonEmpty Credential ->
82
+ IO (Either (TransportError NewEndPointErrorCode ) EndPoint )
83
+ newEndpoint transportState quicAddr@ (QUICAddr host port) creds = do
84
+ eventQueue <- newTQueueIO
85
+
86
+ state <- EndpointState <$> newIORef mempty
87
+
88
+ serverConfig <- mkServerConfig host port creds
89
+ serverThread <-
90
+ forkFinally
91
+ ( QUIC.Server. run
92
+ serverConfig
93
+ ( withQUICStream $
94
+ -- TODO: create a bidirectional stream
95
+ -- which can be re-used for sending
96
+ \ stream ->
97
+ -- We register which threads are actively receiving or sending
98
+ -- data such that we can cleanly stop
99
+ withThreadRegistered state $ do
100
+ -- TODO: how to ensure positivity of ConnectionId? QUIC StreamID should be a 62 bit integer,
101
+ -- so there's room to make it a positive 64 bit integer (ConnectionId ~ Word64)
102
+ let connId = fromIntegral (QUIC. streamId stream)
103
+ receiveLoop connId stream eventQueue
104
+ )
105
+ )
106
+ ( \ case
107
+ Left exc -> atomically (writeTQueue eventQueue (ErrorEvent (TransportError EventEndPointFailed (show exc))))
108
+ Right _ -> pure ()
109
+ )
110
+
111
+ let endpoint =
112
+ EndPoint
113
+ (atomically (readTQueue eventQueue))
114
+ (encodeQUICAddr quicAddr)
115
+ (connectQUIC creds)
116
+ (pure . Left $ TransportError NewMulticastGroupUnsupported " Multicast not supported" )
117
+ (pure . Left . const (TransportError ResolveMulticastGroupUnsupported " Multicast not supported" ))
118
+ (stopAllThreads state >> killThread serverThread >> atomically (writeTQueue eventQueue EndPointClosed ))
119
+ void $ transportState `registerEndpoint` endpoint
120
+ pure $ Right endpoint
121
+ where
122
+ receiveLoop ::
123
+ ConnectionId ->
124
+ QUIC. Stream ->
125
+ TQueue Event ->
126
+ IO ()
127
+ receiveLoop connId stream eventQueue = do
128
+ incoming <- QUIC. recvStream stream 1024 -- TODO: variable length?
129
+ -- TODO: check some state whether we should stop all connections
130
+ if BS. null incoming
131
+ then do
132
+ atomically (writeTQueue eventQueue (ConnectionClosed connId))
133
+ else do
134
+ atomically (writeTQueue eventQueue (Received connId [incoming]))
135
+ receiveLoop connId stream eventQueue
136
+
137
+ withQUICStream :: (QUIC. Stream -> IO a ) -> QUIC. Connection -> IO a
138
+ withQUICStream f conn =
139
+ bracket
140
+ (QUIC. waitEstablished conn >> QUIC. acceptStream conn)
141
+ (\ stream -> QUIC. closeStream stream >> QUIC.Server. stop conn)
142
+ f
143
+
144
+ connectQUIC ::
145
+ NonEmpty Credential ->
146
+ EndPointAddress ->
147
+ Reliability ->
148
+ ConnectHints ->
149
+ IO (Either (TransportError ConnectErrorCode ) Connection )
150
+ connectQUIC creds endpointAddress _reliability _connectHints =
151
+ case decodeQUICAddr endpointAddress of
152
+ Left errmsg -> pure $ Left $ TransportError ConnectNotFound (" Could not decode QUIC address: " <> errmsg)
153
+ Right (QUICAddr hostname port) ->
154
+ try $ do
155
+ clientConfig <- mkClientConfig hostname port creds
156
+
157
+ QUIC.Client. run clientConfig $ \ conn -> do
158
+ QUIC. waitEstablished conn
159
+ stream <- QUIC. stream conn
160
+
161
+ pure $
162
+ Connection
163
+ (sendQUIC stream)
164
+ (QUIC. closeStream stream)
165
+ where
166
+ sendQUIC :: Stream -> [StrictByteString ] -> IO (Either (TransportError SendErrorCode ) () )
167
+ sendQUIC stream payloads =
168
+ try (QUIC. sendStreamMany stream payloads)
169
+ <&> first
170
+ ( \ case
171
+ QUIC. StreamIsClosed -> TransportError SendClosed " QUIC stream is closed"
172
+ QUIC. ConnectionIsClosed reason -> TransportError SendClosed (show reason)
173
+ other -> TransportError SendFailed (show other)
174
+ )
175
+
176
+ closeQUICTransport :: TransportState -> IO ()
177
+ closeQUICTransport = flip traverseTransportState (\ _ endpoint -> closeEndPoint endpoint)
178
+
179
+ {- | We keep track of all threads actively listening on QUIC streams
180
+ so that we can cleanly stop these threads when closing the endpoint.
181
+
182
+ See 'withThreadRegistered' for a combinator which automatically keeps
183
+ track of these threads
184
+ -}
185
+ newtype EndpointState = EndpointState
186
+ { threads :: IORef (Set ThreadId )
187
+ }
188
+
189
+ withThreadRegistered :: EndpointState -> IO a -> IO a
190
+ withThreadRegistered state f =
191
+ bracket
192
+ registerThread
193
+ unregisterThread
194
+ (const f)
195
+ where
196
+ registerThread =
197
+ myThreadId
198
+ >>= \ tid ->
199
+ atomicModifyIORef'_ (threads state) (Set. insert tid)
200
+ $> tid
201
+
202
+ unregisterThread tid =
203
+ atomicModifyIORef'_ (threads state) (Set. insert tid)
204
+
205
+ stopAllThreads :: EndpointState -> IO ()
206
+ stopAllThreads (EndpointState tds) = do
207
+ readIORef tds >>= traverse_ killThread
208
+ writeIORef tds mempty -- so that we can call `closeQUICTransport` even after the endpoint has been closed
0 commit comments