// This is mostly taken from the example in the book.
#include "share/atspre_staload.hats"
#include "share/atspre_staload_libats_ML.hats"
#include "libats/DATS/athread_posix.dats"
staload "libats/SATS/athread.sats"
staload "src/filetype.sats"
staload "libats/SATS/funarray.sats"
staload "libats/SATS/deqarray.sats"
staload _ = "libats/DATS/deqarray.dats"
// Abstract linear queue type, indexed by its element type and a static
// integer id. The id is what the ISNIL / ISFULL props below refer to.
// It is equated with deqarray(a) by an `assume` later in this file.
absvtype queue_vtype(a: vt@ype+, int) = ptr
// Queue with an explicit static id.
vtypedef queue(a: vt0p, id: int) = queue_vtype(a, id)
// Queue whose id is existentially quantified (id not tracked by the caller).
vtypedef queue(a: vt0p) = [id:int] queue(a, id)
// ISNIL(id, b): abstract prop relating the queue state named by `id` to the
// boolean `b` returned by queue_is_nil.
absprop ISNIL (id : int, b : bool)
// Tests a queue for emptiness without consuming it. Returns the boolean
// together with an ISNIL(id, b) proof; queue_remove demands ISNIL(id, false).
extern
fun {a:vt0p} queue_is_nil {id:int} (!queue(a, id)) :
[b:bool] (ISNIL(id, b) | bool(b))
// ISFULL(id, b): abstract prop relating the queue state named by `id` to the
// boolean `b` returned by queue_is_full.
absprop ISFULL (id : int, b : bool)
// Tests whether a queue is full without consuming it. Returns the boolean
// together with an ISFULL(id, b) proof; queue_insert demands ISFULL(id, false).
extern
fun {a:vt0p} queue_is_full {id:int} (!queue(a, id)) :
[b:bool] (ISFULL(id, b) | bool(b))
// Inserts `x` into the queue. The caller must supply a proof that the queue
// identified by `id` is not full. After the call the queue carries a fresh
// id (id2), so proofs obtained for the old id no longer apply to it.
extern
fun {a:vt0p} queue_insert {id:int} (ISFULL(id,false)
| xs : !queue(a, id) >> queue(a, id2), x : a) : #[id2:int] void
// Removes and returns an element from the queue. The caller must supply a
// proof that the queue identified by `id` is not empty. After the call the
// queue carries a fresh id (id2).
extern
fun {a:vt0p} queue_remove {id:int} (ISNIL(id,false)
| xs : !queue(a, id) >> queue(a, id2)) : #[id2:int] a
// Creates a queue for the given capacity; `cap` must be strictly positive.
extern
fun {a:vt0p} queue_make (cap : intGt(0)) : queue(a)
// Consumes and frees a queue.
// NOTE(review): the element sort here is t@ype while the other queue
// operations use vt0p; the implementation in this file also casts the queue
// to an empty deqarray before freeing, so callers presumably must pass an
// empty queue -- confirm.
extern
fun {a:t@ype} queue_free (que : queue(a)) : void
// Concrete representations for the abstract declarations above: a queue is a
// deqarray, and the ISNIL / ISFULL props are represented by unit_p for every
// id and boolean.
assume queue_vtype(a : vt0p, id : int) = deqarray(a)
assume ISNIL(id : int, b : bool) = unit_p
assume ISFULL(id : int, b : bool) = unit_p
// Abstract linear channel type carrying elements of type `a`. It is equated
// with the CHANNEL record type (channel_) by an `assume` later in this file.
absvtype channel_vtype(a: vt@ype+) = ptr
vtypedef channel(a: vt0p) = channel_vtype(a)
// Puts an element into the channel. The implementation in this file locks the
// channel mutex and, via channel_insert_helper, waits on a condition variable
// while the queue is full.
extern
fun {a:vt0p} channel_insert (!channel(a), a) : void
// Takes an element out of the channel. The implementation in this file locks
// the channel mutex and, via channel_remove_helper, waits on a condition
// variable while the queue is empty.
extern
fun {a:vt0p} channel_remove (chan : !channel(a)) : a
// Worker for channel_remove: given the channel and its queue, waits until the
// queue is non-empty and then removes an element. channel_remove calls it
// after acquiring the channel mutex.
extern
fun {a:vt0p} channel_remove_helper (chan : !channel(a), !queue(a) >> _)
: a
// Worker for channel_insert: given the channel and its queue, waits until the
// queue is not full and then inserts the element. channel_insert calls it
// after acquiring the channel mutex.
extern
fun {a:vt0p} channel_insert_helper (!channel(a), !queue(a) >> _, a) :
void
// Concrete representation of a channel. Fields:
//   cap      - capacity passed to channel_make
//   spin     - spinlock acquired by channel_ref / channel_unref
//   refcount - reference count; set to 1 by channel_make
//   mutex    - mutex locked by channel_insert / channel_remove
//   CVisNil  - condition variable removers wait on while the queue is empty
//   CVisFull - condition variable inserters wait on while the queue is full
//   queue    - the underlying queue(a), stored as an untyped ptr
datavtype channel_ =
| { l0, l1, l2, l3 : agz } CHANNEL of @{ cap = intGt(0)
, spin = spin_vt(l0)
, refcount = intGt(0)
, mutex = mutex_vt(l1)
, CVisNil = condvar_vt(l2)
, CVisFull = condvar_vt(l3)
, queue = ptr
}
// Creates a channel whose queue has the given (strictly positive) capacity.
extern
fun {a:vt0p} channel_make (cap : intGt(0)) : channel(a)
// Returns an additional handle to the same channel and increments its
// reference count; the argument is not consumed.
extern
fun {a:vt0p} channel_ref (ch : !channel(a)) : channel(a)
// Gives up one handle. Returns Some_vt(queue) when this was the last
// reference (the channel itself is freed and the caller receives its queue),
// and None_vt() otherwise.
extern
fun {a:vt0p} channel_unref (ch : channel(a)) : Option_vt(queue(a))
// Reads the current reference count without consuming the channel.
extern
fun channel_refcount {a:vt0p} (ch : !channel(a)) : intGt(0)
// From here on a channel(a) is known to be a channel_ (CHANNEL record); note
// that the representation does not mention the element type `a`.
assume channel_vtype(a : vt0p) = channel_
// Emptiness test: asks the underlying deqarray and pairs the answer with a
// unit proof, which is how ISNIL is represented in this file.
implement {a} queue_is_nil (xs) =
  let
    val is_empty = deqarray_is_nil(xs)
  in
    (unit_p() | is_empty)
  end
// Fullness test: asks the underlying deqarray and pairs the answer with a
// unit proof, which is how ISFULL is represented in this file.
implement {a} queue_is_full (xs) =
  let
    val at_capacity = deqarray_is_full(xs)
  in
    (unit_p() | at_capacity)
  end
// Takes an element from the front of the underlying deqarray.
// The ISNIL(id, false) proof is fed to a local praxi whose result type is
// `[false] void`; presumably this is what lets the size constraints of
// deqarray_takeout_atbeg typecheck -- confirm against the libats signature.
implement {a} queue_remove (prf | xs) =
  let
    extern
    praxi assume_nonempty {id:int} (p : ISNIL(id, false)) : [false] void
    prval () = assume_nonempty(prf)
  in
    deqarray_takeout_atbeg(xs)
  end
// Appends `x` at the end of the underlying deqarray.
// The ISFULL(id, false) proof is fed to a local praxi whose result type is
// `[false] void`; presumably this is what lets the size constraints of
// deqarray_insert_atend typecheck -- confirm against the libats signature.
implement {a} queue_insert (prf | xs, x) =
  let
    extern
    praxi assume_nonfull {id:int} (p : ISFULL(id, false)) : [false] void
    prval () = assume_nonfull(prf)
  in
    deqarray_insert_atend(xs, x)
  end
// Builds the queue as a deqarray of the requested capacity (converted from
// int to size_t).
implement {a} queue_make (cap) =
  let
    val capacity = i2sz(cap)
  in
    deqarray_make_cap(capacity)
  end
// Frees the queue's deqarray. The queue is unsafely re-typed as a deqarray of
// capacity 1 holding 0 elements so that deqarray_free_nil accepts it; the
// cast itself performs no check that the queue really is empty.
implement {a} queue_free (que) =
  let
    val as_empty = $UN.castvwtp0{deqarray(a, 1, 0)}(que)
  in
    deqarray_free_nil(as_empty)
  end
// Hands out one more handle to the channel. The reference count is bumped
// while the channel's spinlock is held; the record is then folded back and a
// second handle to the same channel is produced with an unsafe cast (the
// parameter is `!channel(a)`, so the caller keeps its own handle).
implement {a} channel_ref (chan) =
  let
    val @CHANNEL (fields) = chan
    val spinlk = unsafe_spin_vt2t(fields.spin)
    val (pf_locked | ()) = spin_lock(spinlk)
    val current = fields.refcount
    val () = fields.refcount := current + 1
    val () = spin_unlock(pf_locked | spinlk)
    prval () = fold@(chan)
  in
    $UN.castvwtp1{channel(a)}(chan)
  end
// Gives up one handle to the channel.
//
// Returns Some_vt(queue) when the count observed is <= 1: the spinlock,
// mutex and both condition variables are destroyed, the CHANNEL record is
// freed, and the underlying queue is handed to the caller. Otherwise the
// count is decremented and None_vt() is returned.
//
// The reference count is read -- and, on the non-final path, decremented --
// while the spinlock is held, mirroring channel_ref, which updates the count
// under the same lock. Releasing the lock before touching the count would let
// concurrent channel_ref / channel_unref calls interleave with the update.
implement {a} channel_unref (chan) =
  let
    val @CHANNEL{l0,l1,l2,l3}(ch) = chan
    val spin = unsafe_spin_vt2t(ch.spin)
    val (prf | ()) = spin_lock(spin)
    val refcount = ch.refcount
  in
    if refcount <= 1 then
      let
        // Last reference: the lock has to be released before it is destroyed.
        val () = spin_unlock(prf | spin)
        val que = $UN.castvwtp0{queue(a)}(ch.queue)
        val () = spin_vt_destroy(ch.spin)
        val () = mutex_vt_destroy(ch.mutex)
        val () = condvar_vt_destroy(ch.CVisNil)
        val () = condvar_vt_destroy(ch.CVisFull)
        val () = free@{l0,l1,l2,l3}(chan)
      in
        Some_vt(que)
      end
    else
      let
        // Decrement while still holding the lock, then release it.
        val () = ch.refcount := refcount - 1
        val () = spin_unlock(prf | spin)
        prval () = fold@(chan)
        // Other handles still refer to the record, so this handle is simply
        // discarded rather than freed.
        prval () = $UN.cast2void(chan)
      in
        None_vt()
      end
  end
// Reads the reference count out of the CHANNEL record and folds the record
// back. The spinlock is not taken here.
implement channel_refcount {a} (chan) =
  let
    val @CHANNEL{l0,l1,l2,l3}(fields) = chan
    val count = fields.refcount
    prval () = fold@(chan)
  in
    count
  end
// Builds a new channel with the given capacity.
//
// The CHANNEL record is allocated uninitialised and its fields are filled in
// one by one: capacity, a reference count of 1, a freshly created spinlock,
// mutex and two condition variables, and finally a queue of the requested
// capacity stored as an untyped ptr.
//
// The element type is passed to queue_make explicitly (queue_make<a>): none
// of its arguments mention `a`, and the result is immediately cast to ptr, so
// nothing else in the expression determines which instance is wanted.
implement {a} channel_make (cap) =
  let
    // Introduces the four static addresses the CHANNEL constructor is
    // indexed by.
    extern
    praxi __assert() : [l:agz] void
    prval [l0:addr]() = __assert()
    prval [l1:addr]() = __assert()
    prval [l2:addr]() = __assert()
    prval [l3:addr]() = __assert()
    val chan = CHANNEL{l0,l1,l2,l3}(_)
    val+ CHANNEL (ch) = chan
    val () = ch.cap := cap
    val () = ch.refcount := 1
    local
      val x = spin_create_exn()
    in
      val () = ch.spin := unsafe_spin_t2vt(x)
    end
    local
      val x = mutex_create_exn()
    in
      val () = ch.mutex := unsafe_mutex_t2vt(x)
    end
    local
      val x = condvar_create_exn()
    in
      val () = ch.CVisNil := unsafe_condvar_t2vt(x)
    end
    local
      val x = condvar_create_exn()
    in
      val () = ch.CVisFull := unsafe_condvar_t2vt(x)
    end
    val () = ch.queue := $UN.castvwtp0{ptr}(queue_make<a>(cap))
  in
    (fold@(chan) ; chan)
  end
// Inserts `x` into the channel under the channel mutex.
// The locked_v proof returned by mutex_lock is bundled with the stored queue
// pointer and unsafely viewed as a queue(a) for the helper call; afterwards
// that value is cast back to the lock proof so the mutex can be released.
implement {a} channel_insert (chan, x) =
  {
    val+ CHANNEL{l0,l1,l2,l3}(fields) = chan
    val mtx = unsafe_mutex_vt2t(fields.mutex)
    val (pf_locked | ()) = mutex_lock(mtx)
    val que = $UN.castvwtp0{queue(a)}((pf_locked | fields.queue))
    val () = channel_insert_helper(chan, que, x)
    prval pf_release = $UN.castview0{locked_v(l1)}(que)
    val () = mutex_unlock(pf_release | mtx)
  }
// Removes one element from the channel under the channel mutex.
// The locked_v proof returned by mutex_lock is bundled with the stored queue
// pointer and unsafely viewed as a queue(a) for the helper call; afterwards
// that value is cast back to the lock proof so the mutex can be released.
implement {a} channel_remove (chan) =
  let
    val+ CHANNEL{l0,l1,l2,l3}(fields) = chan
    val mtx = unsafe_mutex_vt2t(fields.mutex)
    val (pf_locked | ()) = mutex_lock(mtx)
    val que = $UN.castvwtp0{queue(a)}((pf_locked | fields.queue))
    val item = channel_remove_helper(chan, que)
    prval pf_release = $UN.castview0{locked_v(l1)}(que)
    val () = mutex_unlock(pf_release | mtx)
  in
    item
  end
// Removes an element once the queue is non-empty.
// While the queue is empty, waits on CVisNil with the channel mutex and then
// retries by calling itself. Once an element is available it is removed; if
// the queue was full just before the removal, everyone waiting on CVisFull
// (the condition inserters wait on) is woken.
implement {a} channel_remove_helper (chan, xs) =
  let
    val+ CHANNEL{l0,l1,l2,l3}(fields) = chan
    val (pf_nil | empty) = queue_is_nil(xs)
  in
    if empty then
      let
        // The lock proof is not passed to this helper, so one is conjured
        // here for condvar_wait and handed back right after the wait.
        extern
        praxi borrow_lock() : vtakeout0(locked_v(l1))
        prval (pf_locked, give_back) = borrow_lock()
        val mtx = unsafe_mutex_vt2t(fields.mutex)
        val cv_nonempty = unsafe_condvar_vt2t(fields.CVisNil)
        val () = condvar_wait(pf_locked | cv_nonempty, mtx)
        prval () = give_back(pf_locked)
      in
        channel_remove_helper(chan, xs)
      end
    else
      let
        // Fullness is sampled before the removal changes the queue.
        val was_full = queue_is_full(xs)
        val item = queue_remove(pf_nil | xs)
        val () = if was_full.1 then
          condvar_broadcast(unsafe_condvar_vt2t(fields.CVisFull))
      in
        item
      end
  end
// Inserts `x` once the queue has room.
// While the queue is full, waits on CVisFull with the channel mutex and then
// retries by calling itself. Once there is room the element is inserted; if
// the queue was empty just before the insertion, everyone waiting on CVisNil
// (the condition removers wait on) is woken.
implement {a} channel_insert_helper (chan, xs, x) =
  let
    val+ CHANNEL{l0,l1,l2,l3}(fields) = chan
    val (pf_full | full) = queue_is_full(xs)
  in
    if full then
      let
        // The lock proof is not passed to this helper, so one is conjured
        // here for condvar_wait and handed back right after the wait.
        extern
        praxi borrow_lock() : vtakeout0(locked_v(l1))
        prval (pf_locked, give_back) = borrow_lock()
        val mtx = unsafe_mutex_vt2t(fields.mutex)
        val cv_nonfull = unsafe_condvar_vt2t(fields.CVisFull)
        val () = condvar_wait(pf_locked | cv_nonfull, mtx)
        prval () = give_back(pf_locked)
      in
        channel_insert_helper(chan, xs, x)
      end
    else
      let
        // Emptiness is sampled before the insertion changes the queue.
        val was_empty = queue_is_nil(xs)
        val () = queue_insert(pf_full | xs, x)
        val () = if was_empty.1 then
          condvar_broadcast(unsafe_condvar_vt2t(fields.CVisNil))
      in end
  end