Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Riot-friendly mutex #75

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions riot/lib/lib.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ module Hashmap = Hashmap
module IO = Lib_io
module Logger = Logger_app
module Message = Message
module Mutex = Mutex
module Net = Net
module Pid = Pid
module Process = Process
Expand Down
160 changes: 160 additions & 0 deletions riot/lib/mutex.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
open Global
open Util
open Process.Messages

type 'a t = { mutable inner : 'a; process : Pid.t }

type state = { status : status; queue : Pid.t Lf_queue.t }
and status = Locked of Pid.t | Unlocked

let pp fmt inner_pp mutex =
Format.fprintf fmt "Mutex<inner: %a>" inner_pp mutex.inner

type error = [ `multiple_unlocks | `locked | `not_owner | `process_died ]

let pp_err fmt error =
let reason =
match error with
| `multiple_unlocks -> "Mutex received multiple unlock messages"
| `locked -> "Mutex is locked"
| `not_owner -> "Process does not own mutex"
| `process_died -> "Mutex process died"
in
Format.fprintf fmt "Mutex error: %s" reason

type Message.t +=
| Lock of Pid.t
| Unlock of Pid.t
| Try_lock of Pid.t
| Lock_accepted
| Unlock_accepted
| Failed of error

let rec loop ({ status; queue } as state) =
match receive_any () with
| (Lock owner | Try_lock owner) when status = Unlocked ->
monitor owner;
send owner Lock_accepted;
loop { state with status = Locked owner }
| Lock requesting ->
Lf_queue.push queue requesting;
loop state
| Try_lock requesting -> send requesting @@ Failed `locked
| Unlock pid when status = Locked pid ->
send pid Unlock_accepted;
demonitor pid;
check_queue { state with status = Unlocked }
| Unlock not_owner when status = Unlocked ->
Logger.error (fun f ->
f "Mutex (PID: %a) received unlock message while unlocked" Pid.pp
(self ()));
send not_owner @@ Failed `multiple_unlocks;
loop state
| Unlock not_owner ->
Logger.error (fun f ->
f "Mutex (PID: %a) received unlock message from non-owner process"
Pid.pp (self ()));
send not_owner @@ Failed `not_owner;
loop state
| Monitor (Process_down fell_pid) when status = Locked fell_pid ->
Logger.error (fun f -> f "Mutex owner crashed: %a" Pid.pp fell_pid);
check_queue { state with status = Unlocked }
| _ ->
Logger.debug (fun f ->
f "Mutex (PID: %a) received unexpected message" Pid.pp (self ()));
loop state

and check_queue ({ queue; _ } as state) =
match Lf_queue.pop queue with
| Some owner ->
send owner Lock_accepted;
monitor owner;
loop { state with status = Locked owner }
| None -> loop state

let selector = function
| (Lock_accepted | Unlock_accepted | Failed _ | Monitor (Process_down _)) as m
->
`select m
| _ -> `skip

(* Monitor mutex process to catch crashes *)
let wait_lock mutex : (unit, [> error ]) result =
monitor mutex.process;
send mutex.process @@ Lock (self ());
match[@warning "-8"] receive ~selector () with
| Monitor (Process_down _) -> Error `process_died
| Failed reason -> Error reason
| Lock_accepted -> Ok ()

let try_wait_lock mutex =
monitor mutex.process;
send mutex.process @@ Try_lock (self ());
match[@warning "-8"] receive ~selector () with
| Lock_accepted -> Ok ()
| Failed reason -> Error reason
| Monitor (Process_down _) -> Error `process_died

let wait_unlock mutex =
send mutex.process @@ Unlock (self ());
match[@warning "-8"] receive ~selector () with
| Unlock_accepted ->
demonitor mutex.process;
Ok ()
| Failed reason -> Error reason
| Monitor (Process_down _) -> Error `process_died

(* NOTE: (@faycarsons) Assuming that we do want functions like `get` to return
a copy of the wrapped value to prevent mutation once the mutex has been
unlocked: I'm not sure how we want to go about that copying. There are maybe
cheaper but less safe solutions using `Obj`, but if the serialization cost
is OK this seems to work fine *)
let clone (inner : 'a) : 'a =
let open Marshal in
let ser = to_bytes inner [ Closures; No_sharing ] in
from_bytes ser 0

(* Exposed API *)

let create inner =
let state = { status = Unlocked; queue = Lf_queue.create () } in
let process = spawn_link @@ fun () -> loop state in
{ inner; process }

let drop mutex = exit mutex.process Process.Normal

let lock mutex fn =
let* _ = wait_lock mutex in
mutex.inner <- fn mutex.inner;
wait_unlock mutex

let try_lock mutex fn =
let* _ = try_wait_lock mutex in
mutex.inner <- fn mutex.inner;
wait_unlock mutex

let iter mutex fn =
let* _ = wait_lock mutex in
fn mutex.inner;
wait_unlock mutex

let try_iter mutex fn =
let* _ = try_wait_lock mutex in
fn mutex.inner;
wait_unlock mutex

let get mutex =
let* _ = wait_lock mutex in
let inner = clone mutex.inner in
let* _ = wait_unlock mutex in
Ok inner

let try_get mutex =
let* _ = try_wait_lock mutex in
let inner = clone mutex.inner in
let* _ = wait_unlock mutex in
Ok inner

(* NOTE: (@faycarsons) not sure if we want this? *)
let unsafe_get mutex = mutex.inner
let unsafe_set mutex inner = mutex.inner <- inner
18 changes: 18 additions & 0 deletions riot/riot.mli
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,24 @@ module Hashmap : sig
module Make (B : Base) : Intf with type key = B.key
end

module Mutex : sig
type 'a t
type error

val pp : Format.formatter -> (Format.formatter -> 'a -> unit) -> 'a t -> unit
val pp_err : Format.formatter -> error -> unit
val create : 'a -> 'a t
val drop : 'a t -> unit
val lock : 'a t -> ('a -> 'a) -> (unit, error) result
val try_lock : 'a t -> ('a -> 'a) -> (unit, error) result
val iter : 'a t -> ('a -> unit) -> (unit, error) result
val try_iter : 'a t -> ('a -> unit) -> (unit, error) result
val get : 'a t -> ('a, error) result
val try_get : 'a t -> ('a, error) result
val unsafe_get : 'a t -> 'a
val unsafe_set : 'a t -> 'a -> unit
end

module Stream : sig
type 'v t = 'v Seq.t

Expand Down