Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions async/websocket_async.ml
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ let src =

let server ?(name = "websocket.server")
?(check_request = fun _ -> Deferred.return true)
?(select_protocol = fun _ -> None) ~reader ~writer ~app_to_ws ~ws_to_app ()
?(select_protocol = fun _ -> None)
?max_frame_length
~reader ~writer ~app_to_ws ~ws_to_app ()
=
let handshake r w =
(Request.read r >>= function
Expand Down Expand Up @@ -276,7 +278,9 @@ let server ?(name = "websocket.server")
handshake reader writer)
|> Deferred.Or_error.bind ~f:(fun () ->
set_tcp_nodelay writer;
let read_frame = make_read_frame ~mode:Server reader writer in
let read_frame =
make_read_frame ?max_len:max_frame_length ~mode:Server reader writer
in
let rec loop () = read_frame () >>= Pipe.write ws_to_app >>= loop in
let transfer_end =
let buf = Buffer.create 128 in
Expand All @@ -303,7 +307,9 @@ let server ?(name = "websocket.server")
>>= Deferred.Or_error.return)

let upgrade_connection ?(select_protocol = fun _ -> None)
?(ping_interval = Time_ns.Span.of_int_sec 50) ~app_to_ws ~ws_to_app ~f
?(ping_interval = Time_ns.Span.of_int_sec 50)
?max_frame_length
~app_to_ws ~ws_to_app ~f
request =
let headers = Cohttp.Request.headers request in
let key =
Expand All @@ -328,7 +334,9 @@ let upgrade_connection ?(select_protocol = fun _ -> None)
()
in
let handler reader writer =
let read_frame = make_read_frame ~mode:Server reader writer in
let read_frame =
make_read_frame ?max_len:max_frame_length ~mode:Server reader writer
in
let rec loop () =
try_with read_frame >>= function
| Error _ -> Deferred.unit
Expand Down
9 changes: 6 additions & 3 deletions async/websocket_async.mli
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,14 @@ val server :
?name:string ->
?check_request:(Cohttp.Request.t -> bool Deferred.t) ->
?select_protocol:(string -> string option) ->
?max_frame_length:int ->
reader:Reader.t ->
writer:Writer.t ->
app_to_ws:Frame.t Pipe.Reader.t ->
ws_to_app:Frame.t Pipe.Writer.t ->
unit ->
unit Deferred.Or_error.t
(** [server ?request_cb reader writer app_to_ws
(** [server ?request_cb ?max_frame_length reader writer app_to_ws
ws_to_app ()] returns a thread that expects a websocket client
connected to [reader]/[writer] and, after performing the
handshake, will resp. read outgoing frames from [app_to_ws] and
Expand All @@ -71,17 +72,19 @@ val server :
reception of the client HTTP request, [request_cb] will be called
with the request as its argument. If [request_cb] returns true,
the connection will proceed, otherwise, the result is immediately
determined to [Error Exit]. *)
determined to [Error Exit]. If [max_frame_length] is specified and
the server receives a frame above this size, the connection is closed. *)

val upgrade_connection :
?select_protocol:(string -> string option) ->
?ping_interval:Core.Time_ns.Span.t ->
?max_frame_length:int ->
app_to_ws:Frame.t Pipe.Reader.t ->
ws_to_app:Frame.t Pipe.Writer.t ->
f:(unit -> unit Deferred.t) ->
Cohttp.Request.t ->
Cohttp.Response.t * (Reader.t -> Writer.t -> unit Deferred.t)
(** [upgrade_connection ?select_protocol ?ping_interval
(** [upgrade_connection ?select_protocol ?ping_interval ?max_frame_length
app_to_ws ws_to_app f request] returns a {!Cohttp_async.Server.response_action}.

Just wrap the return value of this function with [`Expert].
Expand Down
16 changes: 11 additions & 5 deletions core/websocket.ml
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ module type S = sig
type mode = Client of (int -> string) | Server

val make_read_frame :
?max_len:int ->
?buf:Buffer.t -> mode:mode -> IO.ic -> IO.oc -> unit -> Frame.t IO.t

val write_frame_to_buf : mode:mode -> Buffer.t -> Frame.t -> unit
Expand All @@ -177,6 +178,7 @@ module type S = sig
type t

val create :
?max_len:int ->
?read_buf:Buffer.t ->
?write_buf:Buffer.t ->
Cohttp.Request.t ->
Expand Down Expand Up @@ -265,7 +267,7 @@ module Make (IO : Cohttp.S.IO) = struct
write_frame_to_buf ~mode buf @@ Frame.close code;
write oc @@ Buffer.contents buf

let read_frame ic oc buf mode hdr =
let read_frame ?max_len ic oc buf mode hdr =
let hdr_part1 = EndianString.BigEndian.get_int8 hdr 0 in
let hdr_part2 = EndianString.BigEndian.get_int8 hdr 1 in
let final = is_bit_set 7 hdr_part1 in
Expand All @@ -291,6 +293,10 @@ module Make (IO : Cohttp.S.IO) = struct
else if extension <> 0 then
close_with_code mode buf oc 1002 >>= fun () ->
proto_error "unsupported extension"
else if (match max_len with Some max -> payload_len > max | None -> false)
then
close_with_code mode buf oc 1009 >>= fun () ->
proto_error "frame payload too big"
else if Frame.Opcode.is_ctrl opcode && payload_len > 125 then
close_with_code mode buf oc 1002 >>= fun () ->
proto_error "control frame too big"
Expand All @@ -315,11 +321,11 @@ module Make (IO : Cohttp.S.IO) = struct
let frame = Frame.of_bytes ~opcode ~extension ~final payload in
return frame)

let make_read_frame ?(buf = Buffer.create 128) ~mode ic oc () =
let make_read_frame ?max_len ?(buf = Buffer.create 128) ~mode ic oc () =
Buffer.clear buf;
read_exactly ic 2 buf >>= function
| None -> raise End_of_file
| Some hdr -> read_frame ic oc buf mode hdr
| Some hdr -> read_frame ?max_len ic oc buf mode hdr

module Request = Cohttp.Request.Make (IO)
module Response = Cohttp.Response.Make (IO)
Expand All @@ -337,9 +343,9 @@ module Make (IO : Cohttp.S.IO) = struct

let source { endp; _ } = endp

let create ?read_buf ?(write_buf = Buffer.create 128) http_request endp ic
let create ?max_len ?read_buf ?(write_buf = Buffer.create 128) http_request endp ic
oc =
let read_frame = make_read_frame ?buf:read_buf ~mode:Server ic oc in
let read_frame = make_read_frame ?max_len ?buf:read_buf ~mode:Server ic oc in
{
buffer = write_buf;
endp;
Expand Down
2 changes: 2 additions & 0 deletions core/websocket.mli
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ module type S = sig
type mode = Client of (int -> string) | Server

val make_read_frame :
?max_len:int ->
?buf:Buffer.t -> mode:mode -> IO.ic -> IO.oc -> unit -> Frame.t IO.t

val write_frame_to_buf : mode:mode -> Buffer.t -> Frame.t -> unit
Expand All @@ -106,6 +107,7 @@ module type S = sig
type t

val create :
?max_len:int ->
?read_buf:Buffer.t ->
?write_buf:Buffer.t ->
Cohttp.Request.t ->
Expand Down
8 changes: 4 additions & 4 deletions lwt/websocket_cohttp_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ let send_frames stream oc =
in
Lwt_stream.iter_s send_frame stream

let read_frames ic oc handler_fn =
let read_frame = Lwt_IO.make_read_frame ~mode:Server ic oc in
let read_frames ?max_len ic oc handler_fn =
let read_frame = Lwt_IO.make_read_frame ?max_len ~mode:Server ic oc in
let rec inner () = read_frame () >>= Lwt.wrap1 handler_fn >>= inner in
inner ()

let upgrade_connection request incoming_handler =
let upgrade_connection ?max_frame_length request incoming_handler =
let headers = Cohttp.Request.headers request in
(match Cohttp.Header.get headers "sec-websocket-key" with
| None ->
Expand All @@ -61,7 +61,7 @@ let upgrade_connection request incoming_handler =
[
(* input: data from the client is read from the input channel
* of the tcp connection; pass it to handler function *)
read_frames ic oc incoming_handler;
read_frames ?max_len:max_frame_length ic oc incoming_handler;
(* output: data for the client is written to the output
* channel of the tcp connection *)
send_frames frames_out_stream oc;
Expand Down
1 change: 1 addition & 0 deletions lwt/websocket_cohttp_lwt.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
open Websocket

val upgrade_connection :
?max_frame_length:int ->
Cohttp.Request.t ->
(Frame.t -> unit) ->
(Cohttp_lwt_unix.Server.response_action * (Frame.t option -> unit)) Lwt.t
Expand Down
18 changes: 11 additions & 7 deletions lwt/websocket_lwt_unix.ml
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,12 @@ let write { write_frame; _ } frame = write_frame frame
let close_transport { oc; _ } = Lwt_io.close oc

let connect ?(extra_headers = Cohttp.Header.init ())
?max_frame_length
?(random_string = Websocket.Rng.init ())
?(ctx = Lazy.force Conduit_lwt_unix.default_ctx) ?buf client url =
let nonce = Base64.encode_exn (random_string 16) in
connect ctx client url nonce extra_headers >|= fun (ic, oc) ->
let read_frame = make_read_frame ?buf ~mode:(Client random_string) ic oc in
let read_frame = make_read_frame ?max_len:max_frame_length ?buf ~mode:(Client random_string) ic oc in
let read_frame () =
Lwt.catch read_frame (fun exn ->
Lwt.async (fun () -> Lwt_io.close ic);
Expand Down Expand Up @@ -135,7 +136,7 @@ let write_failed_response oc =
let open Response in
write ~flush:true (fun writer -> write_body writer body) response oc

let server_fun ?read_buf ?write_buf check_request flow ic oc react =
let server_fun ?max_len ?read_buf ?write_buf check_request flow ic oc react =
let read = function
| `Ok r -> Lwt.return r
| `Eof ->
Expand Down Expand Up @@ -181,11 +182,11 @@ let server_fun ?read_buf ?write_buf check_request flow ic oc react =
in
Response.write (fun _writer -> Lwt.return_unit) response oc >>= fun () ->
let client =
Connected_client.create ?read_buf ?write_buf request flow ic oc
Connected_client.create ?max_len ?read_buf ?write_buf request flow ic oc
in
react client

let establish_server ?read_buf ?write_buf ?timeout ?stop
let establish_server ?max_frame_length ?read_buf ?write_buf ?timeout ?stop
?(on_exn = fun exn -> !Lwt.async_exception_hook exn)
?(check_request = check_origin_with_host)
?(ctx = Lazy.force Conduit_lwt_unix.default_ctx) ~mode react =
Expand All @@ -194,7 +195,8 @@ let establish_server ?read_buf ?write_buf ?timeout ?stop
set_tcp_nodelay flow;
Lwt.catch
(fun () ->
server_fun ?read_buf ?write_buf check_request
server_fun ?max_len:max_frame_length ?read_buf ?write_buf
check_request
(Conduit_lwt_unix.endp_of_flow flow)
ic oc react)
(function
Expand All @@ -211,9 +213,11 @@ let mk_frame_stream recv =
in
Lwt_stream.from f

let establish_standard_server ?read_buf ?write_buf ?timeout ?stop ?on_exn
let establish_standard_server ?max_frame_length ?read_buf ?write_buf ?timeout
?stop ?on_exn
?check_request ?(ctx = Lazy.force Conduit_lwt_unix.default_ctx) ~mode react
=
let f client = react (Connected_client.make_standard client) in
establish_server ?read_buf ?write_buf ?timeout ?stop ?on_exn ?check_request
establish_server ?max_frame_length ?read_buf ?write_buf ?timeout ?stop
?on_exn ?check_request
~ctx ~mode f
3 changes: 3 additions & 0 deletions lwt/websocket_lwt_unix.mli
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ val close_transport : conn -> unit Lwt.t

val connect :
?extra_headers:Cohttp.Header.t ->
?max_frame_length:int ->
?random_string:(int -> string) ->
?ctx:Conduit_lwt_unix.ctx ->
?buf:Buffer.t ->
Expand All @@ -47,6 +48,7 @@ val connect :
conn Lwt.t

val establish_server :
?max_frame_length:int ->
?read_buf:Buffer.t ->
?write_buf:Buffer.t ->
?timeout:int ->
Expand All @@ -71,6 +73,7 @@ val mk_frame_stream :
is received, the stream will be closed. *)

val establish_standard_server :
?max_frame_length:int ->
?read_buf:Buffer.t ->
?write_buf:Buffer.t ->
?timeout:int ->
Expand Down