@@ -6,7 +6,7 @@ let newLine = 0x0A
66let headerPreamble = " codervpn "
77
88/// A message that has the `rpc` property for recording participation in a unary RPC.
9- protocol RPCMessage {
9+ protocol RPCMessage : Sendable {
1010 var rpc : Vpn_RPC { get set }
1111 /// Returns true if `rpc` has been explicitly set.
1212 var hasRpc : Bool { get }
@@ -49,8 +49,8 @@ struct ProtoVersion: CustomStringConvertible, Equatable, Codable {
4949 }
5050}
5151
52- /// An abstract base class for implementations that need to communicate using the VPN protocol.
53- class Speaker < SendMsg: RPCMessage & Message , RecvMsg: RPCMessage & Message > {
52+ /// An actor that communicates using the VPN protocol
53+ actor Speaker < SendMsg: RPCMessage & Message , RecvMsg: RPCMessage & Message > {
5454 private let logger = Logger ( subsystem: " com.coder.Coder-Desktop " , category: " proto " )
5555 private let writeFD : FileHandle
5656 private let readFD : FileHandle
@@ -59,6 +59,8 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
5959 private let sender : Sender < SendMsg >
6060 private let receiver : Receiver < RecvMsg >
6161 private let secretary = RPCSecretary < RecvMsg > ( )
62+ private var messageBuffer : MessageBuffer = . init( )
63+ private var readLoopTask : Task < Void , any Error > ?
6264 let role : ProtoRole
6365
6466 /// Creates an instance that communicates over the provided file handles.
@@ -93,41 +95,45 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
9395 try _ = await hndsh. handshake ( )
9496 }
9597
96- /// Reads and handles protocol messages.
97- func readLoop( ) async throws {
98- for try await msg in try await receiver. messages ( ) {
99- guard msg. hasRpc else {
100- handleMessage ( msg)
101- continue
102- }
103- guard msg. rpc. msgID == 0 else {
104- let req = RPCRequest < SendMsg , RecvMsg > ( req: msg, sender: sender)
105- handleRPC ( req)
106- continue
107- }
108- guard msg. rpc. responseTo == 0 else {
109- logger. debug ( " got RPC reply for msgID \( msg. rpc. responseTo) " )
110- do throws ( RPCError) {
111- try await self . secretary. route ( reply: msg)
112- } catch {
113- logger. error (
114- " couldn't route RPC reply for \( msg. rpc. responseTo) : \( error) " )
98+ public func start( ) {
99+ guard readLoopTask == nil else {
100+ logger. error ( " speaker is already running " )
101+ return
102+ }
103+ readLoopTask = Task {
104+ do throws ( ReceiveError) {
105+ for try await msg in try await self . receiver. messages ( ) {
106+ guard msg. hasRpc else {
107+ await messageBuffer. push ( . message( msg) )
108+ continue
109+ }
110+ guard msg. rpc. msgID == 0 else {
111+ let req = RPCRequest < SendMsg , RecvMsg > ( req: msg, sender: self . sender)
112+ await messageBuffer. push ( . RPC( req) )
113+ continue
114+ }
115+ guard msg. rpc. responseTo == 0 else {
116+ self . logger. debug ( " got RPC reply for msgID \( msg. rpc. responseTo) " )
117+ do throws ( RPCError) {
118+ try await self . secretary. route ( reply: msg)
119+ } catch {
120+ self . logger. error (
121+ " couldn't route RPC reply for \( msg. rpc. responseTo) : \( error) " )
122+ }
123+ continue
124+ }
115125 }
116- continue
126+ } catch {
127+ self . logger. error ( " failed to receive messages: \( error) " )
117128 }
118129 }
119130 }
120131
121- /// Handles a single non-RPC message. It is expected that subclasses override this method with their own handlers.
122- func handleMessage( _ msg: RecvMsg) {
123- // just log
124- logger. debug ( " got non-RPC message \( msg. textFormatString ( ) ) " )
125- }
126-
127- /// Handle a single RPC request. It is expected that subclasses override this method with their own handlers.
128- func handleRPC( _ req: RPCRequest < SendMsg , RecvMsg > ) {
129- // just log
130- logger. debug ( " got RPC message \( req. msg. textFormatString ( ) ) " )
132+ func wait( ) async throws {
133+ guard let task = readLoopTask else {
134+ return
135+ }
136+ try await task. value
131137 }
132138
133139 /// Send a unary RPC message and handle the response
@@ -166,10 +172,51 @@ class Speaker<SendMsg: RPCMessage & Message, RecvMsg: RPCMessage & Message> {
166172 logger. error ( " failed to close read file handle: \( error) " )
167173 }
168174 }
175+
176+ enum IncomingMessage {
177+ case message( RecvMsg )
178+ case RPC( RPCRequest < SendMsg , RecvMsg > )
179+ }
180+
181+ private actor MessageBuffer {
182+ private var messages : [ IncomingMessage ] = [ ]
183+ private var continuations : [ CheckedContinuation < IncomingMessage ? , Never > ] = [ ]
184+
185+ func push( _ message: IncomingMessage ? ) {
186+ if let continuation = continuations. first {
187+ continuations. removeFirst ( )
188+ continuation. resume ( returning: message)
189+ } else if let message = message {
190+ messages. append ( message)
191+ }
192+ }
193+
194+ func next( ) async -> IncomingMessage ? {
195+ if let message = messages. first {
196+ messages. removeFirst ( )
197+ return message
198+ }
199+ return await withCheckedContinuation { continuation in
200+ continuations. append ( continuation)
201+ }
202+ }
203+ }
169204}
170205
171- /// A class that performs the initial VPN protocol handshake and version negotiation.
172- class Handshaker {
206+ extension Speaker : AsyncSequence , AsyncIteratorProtocol {
207+ typealias Element = IncomingMessage
208+
209+ public nonisolated func makeAsyncIterator( ) -> Speaker < SendMsg , RecvMsg > {
210+ self
211+ }
212+
213+ func next( ) async throws -> IncomingMessage ? {
214+ return await messageBuffer. next ( )
215+ }
216+ }
217+
218+ /// An actor performs the initial VPN protocol handshake and version negotiation.
219+ actor Handshaker {
173220 private let writeFD : FileHandle
174221 private let dispatch : DispatchIO
175222 private var theirData : Data = . init( )
@@ -193,17 +240,19 @@ class Handshaker {
193240 func handshake( ) async throws -> ProtoVersion {
194241 // kick off the read async before we try to write, synchronously, so we don't deadlock, both
195242 // waiting to write with nobody reading.
196- async let theirs = try withCheckedThrowingContinuation { cont in
197- continuation = cont
198- // send in a nil read to kick us off
199- handleRead ( false , nil , 0 )
243+ let readTask = Task {
244+ try await withCheckedThrowingContinuation { cont in
245+ self . continuation = cont
246+ // send in a nil read to kick us off
247+ self . handleRead ( false , nil , 0 )
248+ }
200249 }
201250
202251 let vStr = versions. map { $0. description } . joined ( separator: " , " )
203252 let ours = String ( format: " \( headerPreamble) \( role) \( vStr) \n " )
204253 try writeFD. write ( contentsOf: ours. data ( using: . utf8) !)
205254
206- let theirData = try await theirs
255+ let theirData = try await readTask . value
207256 guard let theirsString = String ( bytes: theirData, encoding: . utf8) else {
208257 throw HandshakeError . invalidHeader ( " <unparsable: \( theirData) " )
209258 }
@@ -219,7 +268,7 @@ class Handshaker {
219268 private func handleRead( _: Bool , _ data: DispatchData ? , _ error: Int32 ) {
220269 guard error == 0 else {
221270 let errStrPtr = strerror ( error)
222- let errStr = String ( validatingUTF8 : errStrPtr!) !
271+ let errStr = String ( validatingCString : errStrPtr!) !
223272 continuation? . resume ( throwing: HandshakeError . readError ( errStr) )
224273 return
225274 }
0 commit comments