1
+ {-# LANGUAGE LambdaCase #-}
2
+ {-# LANGUAGE ScopedTypeVariables #-}
3
+
4
+ module Network.Transport.QUIC.Internal (
5
+ createTransport ,
6
+ QUICAddr (.. ),
7
+ encodeQUICAddr ,
8
+ decodeQUICAddr ,
9
+ ) where
10
+
11
+ import Control.Concurrent (ThreadId , forkIO , killThread , myThreadId )
12
+ import Control.Concurrent.STM (atomically )
13
+ import Control.Concurrent.STM.TQueue (
14
+ TQueue ,
15
+ newTQueueIO ,
16
+ readTQueue ,
17
+ writeTQueue ,
18
+ )
19
+ import Control.Exception (bracket , try )
20
+ import Control.Monad (void )
21
+ import Data.Bifunctor (first )
22
+ import Data.ByteString (StrictByteString )
23
+ import Data.ByteString qualified as BS
24
+ import Data.Foldable (traverse_ )
25
+ import Data.Functor (($>) , (<&>) )
26
+ import Data.IORef (IORef , newIORef , readIORef , writeIORef )
27
+ import Data.Set (Set )
28
+ import Data.Set qualified as Set
29
+ import GHC.IORef (atomicModifyIORef'_ )
30
+ import Network.QUIC (Stream )
31
+ import Network.QUIC qualified as QUIC
32
+ import Network.QUIC.Client (defaultClientConfig )
33
+ import Network.QUIC.Client qualified as QUIC.Client
34
+ import Network.QUIC.Server (defaultServerConfig )
35
+ import Network.QUIC.Server qualified as QUIC.Server
36
+ import Network.TLS (Credentials (Credentials ))
37
+ import Network.Transport (ConnectErrorCode (ConnectNotFound ), ConnectHints , Connection (.. ), ConnectionId , EndPoint (.. ), EndPointAddress , Event (.. ), NewEndPointErrorCode (NewEndPointFailed ), NewMulticastGroupErrorCode (NewMulticastGroupUnsupported ), Reliability , ResolveMulticastGroupErrorCode (ResolveMulticastGroupUnsupported ), SendErrorCode (.. ), Transport (.. ), TransportError (.. ))
38
+ import Network.Transport.QUIC.Internal.QUICAddr (QUICAddr (.. ), decodeQUICAddr , encodeQUICAddr )
39
+ import Network.Transport.QUIC.Internal.TLS qualified as TLS
40
+ import Network.Transport.QUIC.Internal.TransportState (TransportState , newTransportState , registerEndpoint , traverseTransportState )
41
+
42
+ -- | Create a new Transport.
43
+ --
44
+ -- Only a single transport should be created per Haskell process
45
+ -- (threads can, and should, create their own endpoints though).
46
+ createTransport ::
47
+ QUICAddr ->
48
+ -- | Path to certificate
49
+ FilePath ->
50
+ -- | Path to key
51
+ FilePath ->
52
+ IO Transport
53
+ createTransport quicAddr certFile keyFile = do
54
+ transportState <- newTransportState
55
+ pure $
56
+ Transport
57
+ (newEndpoint transportState quicAddr certFile keyFile)
58
+ (closeQUICTransport transportState)
59
+
60
+ newEndpoint ::
61
+ TransportState ->
62
+ QUICAddr ->
63
+ -- | Path to certificate
64
+ FilePath ->
65
+ -- | Path to key
66
+ FilePath ->
67
+ IO (Either (TransportError NewEndPointErrorCode ) EndPoint )
68
+ newEndpoint transportState quicAddr@ (QUICAddr host port) certFile keyFile = do
69
+ eventQueue <- newTQueueIO
70
+
71
+ state <- EndpointState <$> newIORef mempty
72
+ tlsSessionManager <- TLS. sessionManager
73
+ TLS. credentialLoadX509 certFile keyFile >>= \ case
74
+ Left errmsg -> pure . Left $ TransportError NewEndPointFailed errmsg
75
+ Right creds -> do
76
+ serverThread <-
77
+ forkIO $
78
+ QUIC.Server. run
79
+ ( defaultServerConfig
80
+ { QUIC.Server. scAddresses = [(read host, read port)]
81
+ , QUIC.Server. scSessionManager = tlsSessionManager
82
+ , QUIC.Server. scCredentials = Credentials [creds]
83
+ }
84
+ )
85
+ ( withQUICStream $
86
+ -- TODO: create a bidirectional stream
87
+ -- which can be re-used for sending
88
+ \ stream ->
89
+ -- We register which threads are actively receiving or sending
90
+ -- data such that we can cleanly stop
91
+ withThreadRegistered state $ do
92
+ -- TODO: how to ensure positivity of ConnectionId? QUIC StreamID should be a 62 bit integer,
93
+ -- so there's room to make it a positive 64 bit integer (ConnectionId ~ Word64)
94
+ let connId = fromIntegral (QUIC. streamId stream)
95
+ receiveLoop connId stream eventQueue
96
+ )
97
+
98
+ let endpoint =
99
+ EndPoint
100
+ (atomically (readTQueue eventQueue))
101
+ (encodeQUICAddr quicAddr)
102
+ connectQUIC
103
+ (pure . Left $ TransportError NewMulticastGroupUnsupported " Multicast not supported" )
104
+ (pure . Left . const (TransportError ResolveMulticastGroupUnsupported " Multicast not supported" ))
105
+ (stopAllThreads state >> killThread serverThread)
106
+ void $ transportState `registerEndpoint` endpoint
107
+ pure $ Right endpoint
108
+ where
109
+ receiveLoop ::
110
+ ConnectionId ->
111
+ QUIC. Stream ->
112
+ TQueue Event ->
113
+ IO ()
114
+ receiveLoop connId stream eventQueue = do
115
+ incoming <- QUIC. recvStream stream 1024 -- TODO: variable length?
116
+ -- TODO: check some state whether we should stop all connections
117
+ if BS. null incoming
118
+ then do
119
+ atomically (writeTQueue eventQueue (ConnectionClosed connId))
120
+ else do
121
+ atomically (writeTQueue eventQueue (Received connId [incoming]))
122
+ receiveLoop connId stream eventQueue
123
+
124
+ withQUICStream :: (QUIC. Stream -> IO a ) -> QUIC. Connection -> IO a
125
+ withQUICStream f conn =
126
+ bracket
127
+ (QUIC. waitEstablished conn >> QUIC. acceptStream conn)
128
+ (\ stream -> QUIC. closeStream stream >> QUIC.Server. stop conn)
129
+ f
130
+
131
+ connectQUIC ::
132
+ EndPointAddress ->
133
+ Reliability ->
134
+ ConnectHints ->
135
+ IO (Either (TransportError ConnectErrorCode ) Connection )
136
+ connectQUIC endpointAddress _reliability _connectHints =
137
+ case decodeQUICAddr endpointAddress of
138
+ Left errmsg -> pure $ Left $ TransportError ConnectNotFound (" Could not decode QUIC address: " <> errmsg)
139
+ Right (QUICAddr hostname port) ->
140
+ try $ do
141
+ let clientConfig =
142
+ defaultClientConfig
143
+ { QUIC.Client. ccServerName = hostname
144
+ , QUIC.Client. ccPortName = port
145
+ }
146
+
147
+ -- TODO: why is the TLS handshake failing?
148
+ QUIC.Client. run clientConfig $ \ conn -> do
149
+ QUIC. waitEstablished conn
150
+ stream <- QUIC. stream conn
151
+
152
+ pure $
153
+ Connection
154
+ (sendQUIC stream)
155
+ (QUIC. closeStream stream)
156
+ where
157
+ sendQUIC :: Stream -> [StrictByteString ] -> IO (Either (TransportError SendErrorCode ) () )
158
+ sendQUIC stream payloads =
159
+ try (QUIC. sendStreamMany stream payloads)
160
+ <&> first
161
+ ( \ case
162
+ QUIC. StreamIsClosed -> TransportError SendClosed " QUIC stream is closed"
163
+ QUIC. ConnectionIsClosed reason -> TransportError SendClosed (show reason)
164
+ other -> TransportError SendFailed (show other)
165
+ )
166
+
167
+ closeQUICTransport :: TransportState -> IO ()
168
+ closeQUICTransport = flip traverseTransportState (\ _ endpoint -> closeEndPoint endpoint)
169
+
170
+ {- | We keep track of all threads actively listening on QUIC streams
171
+ so that we can cleanly stop these threads when closing the endpoint.
172
+
173
+ See 'withThreadRegistered' for a combinator which automatically keeps
174
+ track of these threads
175
+ -}
176
+ newtype EndpointState = EndpointState
177
+ { threads :: IORef (Set ThreadId )
178
+ }
179
+
180
+ withThreadRegistered :: EndpointState -> IO a -> IO a
181
+ withThreadRegistered state f =
182
+ bracket
183
+ registerThread
184
+ unregisterThread
185
+ (const f)
186
+ where
187
+ registerThread =
188
+ myThreadId
189
+ >>= \ tid ->
190
+ atomicModifyIORef'_ (threads state) (Set. insert tid)
191
+ $> tid
192
+
193
+ unregisterThread tid =
194
+ atomicModifyIORef'_ (threads state) (Set. insert tid)
195
+
196
+ stopAllThreads :: EndpointState -> IO ()
197
+ stopAllThreads (EndpointState tds) = do
198
+ readIORef tds >>= traverse_ killThread
199
+ writeIORef tds mempty -- so that we can call `closeQUICTransport` even after the endpoint has been closed
0 commit comments