1
- use std:: { env, net:: SocketAddr , sync:: Arc , time:: Duration } ;
1
+ use std:: { env, io :: Cursor , net:: SocketAddr , path :: PathBuf , sync:: Arc , time:: Duration } ;
2
2
3
+ use anyhow:: * ;
4
+ use bytes:: Bytes ;
3
5
use futures_util:: { SinkExt , StreamExt } ;
6
+ use serde:: { de:: DeserializeOwned , Serialize } ;
4
7
use serde_json:: json;
5
- use tokio:: sync:: Mutex ;
6
- use tokio_tungstenite:: { connect_async, tungstenite:: protocol:: Message } ;
7
- use uuid:: Uuid ;
8
+ use tokio:: { net:: UnixStream , sync:: Mutex } ;
9
+ use tokio_util:: codec:: { Framed , LengthDelimitedCodec } ;
8
10
use warp:: Filter ;
9
11
10
12
const PING_INTERVAL : Duration = Duration :: from_secs ( 1 ) ;
@@ -18,20 +20,20 @@ async fn main() {
18
20
}
19
21
20
22
// Get manager connection details from env vars
21
- let manager_ip = env :: var ( "RIVET_MANAGER_IP" ) . expect ( "RIVET_MANAGER_IP not set" ) ;
22
- let manager_port = env:: var ( "RIVET_MANAGER_PORT " ) . expect ( "RIVET_MANAGER_PORT not set" ) ;
23
- let manager_addr = format ! ( "ws://{}:{}" , manager_ip , manager_port ) ;
23
+ let manager_socket_path = PathBuf :: from (
24
+ env:: var ( "RIVET_MANAGER_SOCKET_PATH " ) . expect ( "RIVET_MANAGER_SOCKET_PATH not set" ) ,
25
+ ) ;
24
26
25
27
// Get HTTP server port from env var or use default
26
28
let http_port = env:: var ( "PORT_MAIN" )
27
29
. expect ( "PORT_MAIN not set" )
28
30
. parse :: < u16 > ( )
29
31
. expect ( "bad PORT_MAIN" ) ;
30
32
31
- // Spawn the WebSocket client
33
+ // Spawn the unix socket client
32
34
tokio:: spawn ( async move {
33
- if let Err ( e) = run_websocket_client ( & manager_addr ) . await {
34
- eprintln ! ( "WebSocket client error: {}" , e) ;
35
+ if let Err ( e) = run_socket_client ( manager_socket_path ) . await {
36
+ eprintln ! ( "Socket client error: {}" , e) ;
35
37
}
36
38
} ) ;
37
39
@@ -53,25 +55,28 @@ async fn main() {
53
55
warp:: serve ( echo) . run ( http_addr) . await ;
54
56
}
55
57
56
- async fn run_websocket_client ( url : & str ) -> Result < ( ) , Box < dyn std :: error :: Error > > {
57
- println ! ( "Connecting to WebSocket at {}" , url ) ;
58
+ async fn run_socket_client ( socket_path : PathBuf ) -> Result < ( ) > {
59
+ println ! ( "Connecting to socket at {}" , socket_path . display ( ) ) ;
58
60
59
- // Connect to the WebSocket server
60
- let ( ws_stream , _ ) = connect_async ( url ) . await ?;
61
- println ! ( "WebSocket connection established" ) ;
61
+ // Connect to the socket server
62
+ let stream = UnixStream :: connect ( socket_path ) . await ?;
63
+ println ! ( "Socket connection established" ) ;
62
64
63
- // Split the stream
64
- let ( mut write, mut read) = ws_stream. split ( ) ;
65
+ let codec = LengthDelimitedCodec :: builder ( )
66
+ . length_field_type :: < u32 > ( )
67
+ . length_field_length ( 4 )
68
+ // No offset
69
+ . length_field_offset ( 0 )
70
+ // Header length is not included in the length calculation
71
+ . length_adjustment ( 4 )
72
+ // header is included in the returned bytes
73
+ . num_skip ( 0 )
74
+ . new_codec ( ) ;
65
75
66
- let payload = json ! ( {
67
- "init" : {
68
- "access_token" : env:: var( "RIVET_ACCESS_TOKEN" ) . expect( "RIVET_ACCESS_TOKEN not set" ) ,
69
- } ,
70
- } ) ;
76
+ let framed = Framed :: new ( stream, codec) ;
71
77
72
- let data = serde_json:: to_vec ( & payload) ?;
73
- write. send ( Message :: Binary ( data) ) . await ?;
74
- println ! ( "Sent init message" ) ;
78
+ // Split the stream
79
+ let ( write, mut read) = framed. split ( ) ;
75
80
76
81
// Ping thread
77
82
let write = Arc :: new ( Mutex :: new ( write) ) ;
@@ -80,10 +85,14 @@ async fn run_websocket_client(url: &str) -> Result<(), Box<dyn std::error::Error
80
85
loop {
81
86
tokio:: time:: sleep ( PING_INTERVAL ) . await ;
82
87
88
+ let payload = json ! ( {
89
+ "ping" : { }
90
+ } ) ;
91
+
83
92
if write2
84
93
. lock ( )
85
94
. await
86
- . send ( Message :: Ping ( Vec :: new ( ) ) )
95
+ . send ( encode_frame ( & payload ) . unwrap ( ) )
87
96
. await
88
97
. is_err ( )
89
98
{
@@ -93,53 +102,61 @@ async fn run_websocket_client(url: &str) -> Result<(), Box<dyn std::error::Error
93
102
} ) ;
94
103
95
104
// Process incoming messages
96
- while let Some ( message) = read. next ( ) . await {
97
- match message {
98
- Ok ( msg) => match msg {
99
- Message :: Pong ( _) => { }
100
- Message :: Binary ( buf) => {
101
- let packet = serde_json:: from_slice :: < serde_json:: Value > ( & buf) ?;
102
- println ! ( "Received packet: {packet:?}" ) ;
103
-
104
- if let Some ( packet) = packet. get ( "start_actor" ) {
105
- let payload = json ! ( {
106
- "actor_state_update" : {
107
- "actor_id" : packet[ "actor_id" ] ,
108
- "generation" : packet[ "generation" ] ,
109
- "state" : {
110
- "running" : null,
111
- } ,
112
- } ,
113
- } ) ;
114
-
115
- let data = serde_json:: to_vec ( & payload) ?;
116
- write. lock ( ) . await . send ( Message :: Binary ( data) ) . await ?;
117
- } else if let Some ( packet) = packet. get ( "signal_actor" ) {
118
- let payload = json ! ( {
119
- "actor_state_update" : {
120
- "actor_id" : packet[ "actor_id" ] ,
121
- "generation" : packet[ "generation" ] ,
122
- "state" : {
123
- "exited" : {
124
- "exit_code" : null,
125
- } ,
126
- } ,
127
- } ,
128
- } ) ;
129
-
130
- let data = serde_json:: to_vec ( & payload) ?;
131
- write. lock ( ) . await . send ( Message :: Binary ( data) ) . await ?;
132
- }
133
- }
134
- msg => eprintln ! ( "Unexpected message: {msg:?}" ) ,
135
- } ,
136
- Err ( e) => {
137
- eprintln ! ( "Error reading message: {}" , e) ;
138
- break ;
139
- }
105
+ while let Some ( frame) = read. next ( ) . await . transpose ( ) ? {
106
+ let ( _, packet) = decode_frame :: < serde_json:: Value > ( & frame. freeze ( ) ) ?;
107
+ println ! ( "Received packet: {packet:?}" ) ;
108
+
109
+ if let Some ( packet) = packet. get ( "start_actor" ) {
110
+ let payload = json ! ( {
111
+ "actor_state_update" : {
112
+ "actor_id" : packet[ "actor_id" ] ,
113
+ "generation" : packet[ "generation" ] ,
114
+ "state" : {
115
+ "running" : null,
116
+ } ,
117
+ } ,
118
+ } ) ;
119
+
120
+ write. lock ( ) . await . send ( encode_frame ( & payload) ?) . await ?;
121
+ } else if let Some ( packet) = packet. get ( "signal_actor" ) {
122
+ let payload = json ! ( {
123
+ "actor_state_update" : {
124
+ "actor_id" : packet[ "actor_id" ] ,
125
+ "generation" : packet[ "generation" ] ,
126
+ "state" : {
127
+ "exited" : {
128
+ "exit_code" : null,
129
+ } ,
130
+ } ,
131
+ } ,
132
+ } ) ;
133
+
134
+ write. lock ( ) . await . send ( encode_frame ( & payload) ?) . await ?;
140
135
}
141
136
}
142
137
143
- println ! ( "WebSocket connection closed" ) ;
138
+ println ! ( "Socket connection closed" ) ;
144
139
Ok ( ( ) )
145
140
}
141
+
142
+ fn decode_frame < T : DeserializeOwned > ( frame : & Bytes ) -> Result < ( [ u8 ; 4 ] , T ) > {
143
+ ensure ! ( frame. len( ) >= 4 , "Frame too short" ) ;
144
+
145
+ // Extract the header (first 4 bytes)
146
+ let header = [ frame[ 0 ] , frame[ 1 ] , frame[ 2 ] , frame[ 3 ] ] ;
147
+
148
+ // Deserialize the rest of the frame (payload after the header)
149
+ let payload = serde_json:: from_slice ( & frame[ 4 ..] ) ?;
150
+
151
+ Ok ( ( header, payload) )
152
+ }
153
+
154
+ fn encode_frame < T : Serialize > ( payload : & T ) -> Result < Bytes > {
155
+ let mut buf = Vec :: with_capacity ( 4 ) ;
156
+ buf. extend_from_slice ( & [ 0u8 ; 4 ] ) ; // header (currently unused)
157
+
158
+ let mut cursor = Cursor :: new ( & mut buf) ;
159
+ serde_json:: to_writer ( & mut cursor, payload) ?;
160
+
161
+ Ok ( buf. into ( ) )
162
+ }
0 commit comments