{-# LANGUAGE Trustworthy, GADTs, ParallelListComp, Arrows, ImplicitParams, ScopedTypeVariables #-}

-- | Automatic deadlock prevention.
--
-- Automatic deadlock detection is inefficient, and computations cannot be rolled
-- back or aborted in general.
--
-- Instead, I prevent deadlocks before they happen.
module Control.CUtils.Deadlock (Res(Lift, Acq, Rel, Fork, Plus, Id), liftK, lft, acq, rel, fork, run) where

import Control.Category
import Control.Arrow
import Control.Monad
import Data.Map (Map)
import Data.List (inits, tails, elemIndex, deleteBy)
import Data.Array.IO (IOArray, newArray_, writeArray)
import Data.Array.Unsafe
import Data.Array
import Data.Maybe
import Data.Function (on)
import Data.BellmanFord
import qualified Data.Map as M
import System.IO.Unsafe
import Unsafe.Coerce
import Control.Concurrent
import Control.CUtils.Conc
import Control.CUtils.StrictArrow
import Prelude hiding (id, (.))

-- | The typical sequence that produces a deadlock is as follows:
--
-- (1) Thread 1 acquires lock A
--
-- (2) Thread 2 acquires lock B
--
-- (3) Thread 1 tries to acquire B
--
-- (4) Thread 2 tries to acquire A
--
-- Deadlock.
--
-- Standard deadlock detection intervenes after (4) has occurred.
-- I intervene at the lock acquisition after which a deadlocking
-- schedule becomes possible (here at (2)). I suspend thread 2
-- until a safe schedule is guaranteed -- in this case until
-- thread 1 relinquishes lock A.
--
-- The Res arrow.
--
-- Computations are built with these constructors (and the arrow
-- interface). Pieces of the arrow that hold locks have to be finitely
-- examinable, because the lock analysis walks them; locks have to be
-- acquired and released with the 'Acq' and 'Rel' constructors.
data Res t u where
	-- | @Lift k a@: run the Kleisli IO arrow @k@ on the input, then continue
	-- with @a@ on its result.
	Lift :: Kleisli IO t v -> Res v u -> Res t u
	-- | @Acq m a@: acquire the lock @m@ (after 'run' has checked that doing
	-- so creates no hazard), then continue with @a@.
	Acq :: MVar () -> Res t u -> Res t u
	-- | @Rel m a@: release the lock @m@, then continue with @a@.
	Rel :: MVar () -> Res t u -> Res t u
	-- | @Fork c a@: start @c@ in a new thread on the current input and
	-- continue with @a@ on the same input; @c@'s result is discarded.
	Fork :: Res t () -> Res t u -> Res t u
	-- | @Plus a b@: choice; run @a@ on a 'Left' input or @b@ on a 'Right' one.
	Plus :: Res t v -> Res u v -> Res (Either t u) v
	-- | The computation that returns its input unchanged.
	Id :: Res t t

-- | Lift an IO function into 'Res' as a single step.
liftK :: (t -> IO u) -> Res t u
liftK f = Lift (Kleisli f) Id

-- | @lft m k@ runs the IO action @m@, ignoring the current input, and then
-- continues with @k@ on @m@'s result.
lft :: IO v -> Res v u -> Res t u
lft m = Lift (Kleisli (const m))

-- | Acquire a lock. 'run' postpones the acquisition for as long as taking
-- the lock would create a hazard.
acq :: MVar () -> Res t t
acq m = Acq m Id

-- | Release a lock previously taken with 'acq'.
rel :: MVar () -> Res t t
rel m = Rel m Id

-- | @fork child k@ runs @child@ in a new thread on the current input, while
-- the current thread continues with @k@ on the same input.
fork :: Res t () -> Res t u -> Res t u
fork = Fork

-- | Sequential composition for 'Res'.
--
-- @f . g@ rebuilds the structure of @g@ and attaches @f@ wherever @g@
-- finishes: at its 'Id' leaves, after the parent side of a 'Fork' (the
-- forked child is left untouched) and at the end of both branches of a
-- 'Plus'.
instance Category Res where
	id = Id
	f . g = case g of
		Id -> f
		Lift k rest -> Lift k (f . rest)
		Acq lock rest -> Acq lock (f . rest)
		Rel lock rest -> Rel lock (f . rest)
		Fork child rest -> Fork child (f . rest)
		Plus l r -> Plus (f . l) (f . r)

-- | Arrow structure for 'Res'.
--
-- 'first' rebuilds each constructor with 'first' applied to its
-- continuation(s): a 'Lift' step lifts its Kleisli arrow with 'first', a
-- forked child only sees the first component of the pair, and a 'Plus'
-- distributes the extra component into whichever branch is taken.
instance Arrow Res where
	arr f = Lift (arr f) Id
	first (Lift k a) = Lift (first k) (first a)
	first (Acq m a) = Acq m (first a)
	first (Rel m a) = Rel m (first a)
	first (Fork a a2) = Fork (a . arr fst) (first a2)
	-- Without this case 'first' (and everything built on it: 'second',
	-- '***', '&&&', proc-notation environments) fails with a pattern-match
	-- error on any computation containing a choice.
	first (Plus a a2) = Lift (arr distribute) (Plus (first a) (first a2)) where
		distribute :: (Either l r, d) -> Either (l, d) (r, d)
		distribute (Left x, d) = Left (x, d)
		distribute (Right y, d) = Right (y, d)
	first Id = Id

-- | Choice for 'Res' maps directly onto 'Plus'.
--
-- All four methods are given explicitly so that each branch of a choice
-- contains only its own computation. The class defaults build '+++' and
-- '|||' out of 'left' and 'arr'. Because '.' copies its left operand into
-- both branches of a 'Plus', that nests the second computation inside the
-- branches of the first. The trees then grow larger, and 'acquired' reports
-- locks from the branch that is not actually taken.
instance ArrowChoice Res where
	left a = Plus (arr Left . a) (arr Right)
	right a = Plus (arr Left) (arr Right . a)
	a +++ a2 = Plus (arr Left . a) (arr Right . a2)
	(|||) = Plus

-- | 'unsafeFreeze' pinned to boxed 'IOArray's indexed by 'Int', so that the
-- call in 'arr_concF' needs no further type annotations. Per the
-- Data.Array.Unsafe documentation, the result may share storage with the
-- source array, so the source must not be written afterwards.
unsafeFreeze' :: IOArray Int e -> IO (Array Int e)
unsafeFreeze' mutable = unsafeFreeze mutable

instance Concurrent Res where
	-- This is a cheesy reimplementation of the stuff in .Conc, so that we
	-- can use (and examine) the 'Fork' primitive.
	--
	-- arr_concF_: for every index i in [0 .. n-1], run mnds on (parm, i) in
	-- its own forked thread, and finish only after all of them are done. The
	-- QSem starts at 0; each forked thread signals it once and the caller
	-- waits on it n times (no waits at all when n <= 0).
	arr_concF_ mnds = proc (parm, n) -> do
		sem <- Lift (Kleisli $ const $ newQSem 0) Id -< ()
		recurse -< (sem, parm, n)
		Lift (Kleisli $ \(sem, n) -> sequence_ (replicate n (waitQSem sem))) Id -< (sem, n) where
		-- Counts n down to 0, forking one thread per step; the thread forked
		-- at count n runs mnds on (parm, n - 1) and then signals sem.
		--
		-- NOTE(review): recurse is a recursive (infinite) Res value, so
		-- walking it with 'acquired' would not terminate. Also, the statement
		-- after `recurse -< ...` above still uses sem and n, so the proc
		-- desugaring presumably threads them past recurse with 'first'; since
		-- recurse contains a choice ('Plus'), this requires the Arrow
		-- instance's 'first' to handle 'Plus' -- confirm.
		recurse = proc (sem, parm, n) -> if n <= 0 then
				returnA -< ()
			else
				(Fork (proc (sem, parm, n) -> do
					mnds -< (parm, n)
					Lift (Kleisli signalQSem) Id -< sem)
					recurse) -< (sem, parm, pred n)
	-- arr_concF: like arr_concF_, but collects the results. An IOArray with
	-- bounds (0, n-1) is allocated, each thread writes its result at its own
	-- index, and the array is frozen with unsafeFreeze' once every thread
	-- has finished.
	arr_concF mnds = proc (parm, n) -> do
		ar <- Lift (Kleisli newArray_) Id -< (0, n-1)
		arr_concF_ (proc ((ar, parm), n) -> do
			x <- mnds -< (parm, n)
			Lift (Kleisli $ \(n, x, ar) -> writeArray ar n x) Id -< (n, x, ar))
			-< ((ar, parm), n)
		Lift (Kleisli unsafeFreeze') Id -< ar
	-- arr_oneOfF: binds ?seq = False (its meaning is defined by the
	-- Concurrent class, not visible here) and returns element 0 of
	-- arr_concF's result array.
	-- NOTE(review): this waits for all n computations and always picks
	-- index 0 rather than the first one to finish, and it fails with an
	-- out-of-range index when n <= 0 -- confirm against the class contract.
	arr_oneOfF mnds = let ?seq = False in arr_concF mnds >>> arr (! 0)

-- | Strictness for 'Res'.
--
-- NOTE(review): 'Strict' and its 'force' come from
-- Control.CUtils.StrictArrow, whose signature is not visible here. This
-- definition puts the Kleisli IO arrow's own @force id@ as a 'Lift' step in
-- front of the continuation. Presumably that forces the input before the
-- rest of the Res computation runs -- confirm.
instance Strict Res where
	force = Lift (force id)

-- For each thread, we need to track what resources it currently
-- holds, and for each resource, the resources it may
-- potentially acquire while holding that resource.
--
-- This is the global table used by 'run'. Each thread's entry lists the
-- locks it holds (maintained with 'insert' and 'deleteBy'). Each lock is
-- paired with the locks that 'acquired' says the thread may still take
-- while holding it. Threads take this MVar to read and update the table.
--
-- NOINLINE is needed because the table is created with unsafePerformIO.
-- If the definition were inlined, the allocation could happen more than
-- once, and different use sites could end up with different tables.
resource :: MVar (Map ThreadId [(MVar (), [MVar ()])])
{-# NOINLINE resource #-}
resource = unsafePerformIO (newMVar M.empty)

-- A hazard is a HOLD-ACQUIRE cycle among threads.
-- We generate all sequences looking for a cycle.

-- | Orders locks by reinterpreting the MVar value as an 'Int' with
-- 'unsafeCoerce'. This instance is needed because 'hazard' uses pairs of
-- locks as 'Map' keys. Only '<=' is defined, so 'compare' comes from the
-- class default, which also consults the 'Eq' instance of 'MVar'.
--
-- NOTE(review): this depends on the heap representation of 'MVar' and is
-- not a documented conversion. The resulting Int is presumably derived
-- from the object's address. If the garbage collector can move the object,
-- the ordering is not stable and Map invariants could break -- confirm.
-- This is also an orphan instance.
instance Ord (MVar t) where
	x <= y = (unsafeCoerce x :: Int) <= (unsafeCoerce y :: Int)

-- | Shows a lock as the same 'unsafeCoerce'd 'Int' that the 'Ord' instance
-- compares; presumably for debugging output only.
--
-- NOTE(review): relies on the undocumented heap representation of 'MVar';
-- the printed number is not guaranteed to be stable or meaningful. Orphan
-- instance.
instance Show (MVar t) where
	show x = show (unsafeCoerce x :: Int)

-- | Checks lock @m@ against the current hold/acquire table @mp@.
--
-- Every (held lock, lock that may be requested while holding it) pair
-- recorded for any thread becomes an edge of a lock graph. That graph and
-- @m@ are handed to 'cycles', whose answer is returned unchanged. 'run'
-- treats @Just g@ as a /guard/: a lock to wait on before retrying. It
-- treats @Nothing@ as "no hazard".
hazard mp m = cycles lockGraph m where
	lockGraph = M.fromList [((held, next), ()) | claims <- M.elems mp, (held, nexts) <- claims, next <- nexts]

-- | Static analysis: the locks @program@ may acquire before it releases
-- @target@.
--
-- The walk covers every constructor:
--
-- * 'Acq' contributes its lock.
-- * A 'Rel' of @target@ ends that path.
-- * Both sides of a 'Fork' and both branches of a 'Plus' are searched.
--
-- The result may contain duplicates. It is only finite if every path
-- reaches such a 'Rel' or an 'Id' in finitely many steps.
acquired :: Res t u -> MVar () -> [MVar ()]
acquired program target = go program where
	go :: Res a b -> [MVar ()]
	go node = case node of
		Lift _ rest -> go rest
		Acq lock rest -> lock : go rest
		Rel lock rest
			| lock == target -> []
			| otherwise -> go rest
		Fork child rest -> go child ++ go rest
		Plus l r -> go l ++ go r
		Id -> []

-- | Association-list update. Replaces the value of the first entry whose key
-- equals @x@, keeping the rest of the list as is. Appends @(x, y)@ at the
-- end when no entry matches.
insert :: Eq a => a -> b -> [(a, b)] -> [(a, b)]
insert x y = go where
	go [] = [(x, y)]
	go (entry@(key, _) : rest)
		| x == key = (x, y) : rest
		| otherwise = entry : go rest

-- | Use this to run computations built in the Res arrow.
--
-- Each lock taken with 'Acq' is registered in the global 'resource' table,
-- together with the locks the rest of the computation may acquire before
-- releasing it (see 'acquired'). If registering the lock would create a
-- hazard (see 'hazard'), the thread does not take it. Instead it waits for
-- the guard lock reported by 'hazard' to be free, then retries.
run :: Res t u -> t -> IO u
run (Lift k a) x = runKleisli k x >>= run a
run (Acq m a) x = do
	thd <- myThreadId
	-- Tentatively add this lock to the thread's held locks and check whether
	-- that creates a hazard involving possibly acquired locks. modifyMVar
	-- puts the old table back if the check throws. Without that, 'resource'
	-- could be left empty and every other thread would block forever.
	verdict <- modifyMVar resource $ \mp -> do
		let mp' = M.alter (Just . insert m (acquired a m) . fromMaybe []) thd mp
		return $! case hazard mp' m of
			-- Safe: publish the new hold.
			Nothing -> (mp', Nothing)
			-- Hazard: leave the table as it was.
			Just m' -> (mp, Just m')
	case verdict of
		Nothing -> do
			takeMVar m
			run a x
		-- Waits on the lock. This has the effect of denying service
		-- to this thread until the hazard has passed. readMVar waits for the
		-- guard to be free without emptying it, so an exception at this point
		-- cannot leave the guard lock taken.
		Just m' -> do
			readMVar m'
			run (Acq m a) x
run (Rel m a) x = do
	putMVar m ()
	thd <- myThreadId
	-- Forget this lock. A thread's entry is dropped once it holds nothing,
	-- so the table does not keep the ThreadIds of finished threads alive.
	let forget held = case deleteBy ((==) `on` fst) (m, []) held of { [] -> Nothing; rest -> Just rest }
	modifyMVar_ resource (return . M.update forget thd)
	run a x
run (Fork a a2) x = forkIO (run a x) >> run a2 x
run (Plus a a2) ei = either (run a) (run a2) ei
run Id x = return x

-- | Runs the example from the 'Res' documentation: two threads take the same
-- two locks in opposite orders, sleeping in between. The forked thread takes
-- m2 and then m1; the main thread takes m1 and then m2.
test :: IO ()
test = do
	m1 <- newMVar ()
	m2 <- newMVar ()
	let thread1 = acq m2 >>> lft (threadDelay 1000000) Id >>> acq m1 >>> rel m2 >>> rel m1 >>> lft (print "Thd1 done") Id
	let thread2 = acq m1 >>> lft (threadDelay 1000000) Id >>> acq m2 >>> rel m2 >>> rel m1 >>> lft (print "Thd2 done") Id
	run (fork thread1 thread2) ()

-- | The forked thread keeps releasing and re-taking two locks ("waltzing"),
-- pausing half a second per round. Meanwhile the main thread waits one
-- second, then takes m2, prints "Done" and releases it.
test2 :: IO ()
test2 = do
	m1 <- newMVar ()
	m2 <- newMVar ()
	-- The pause has to come before the recursive call. '.' binds tighter
	-- than '>>>', so @lft (threadDelay 500000) Id . waltz@ would mean
	-- "waltz, then pause": the pause would never run and the loop would
	-- spin without sleeping.
	let waltz = rel m1 >>> acq m1 >>> rel m2 >>> acq m2 >>> lft (threadDelay 500000) Id >>> waltz
	run (fork (acq m1 >>> acq m2 >>> waltz) (lft (threadDelay 1000000) Id >>> acq m2 >>> lft (print "Done") Id >>> rel m2))
		()