{-# LANGUAGE GADTs, ParallelListComp #-}

-- | Automatic deadlock prevention.
--
-- Automatic deadlock detection is inefficient, and computations cannot be rolled
-- back or aborted in general.
--
-- Instead, we prevent deadlocks before they happen.
module Control.CUtils.Deadlock (Res(Lift, Acq, Rel, Fork, Plus, Id), run, lft) where

import Control.Category
import Control.Arrow
import Control.Monad
import Data.Map (Map)
import Data.List (inits, tails, elemIndex, deleteBy)
import Data.Maybe
import Data.Function (on)
import qualified Data.Map as M
import System.IO.Unsafe
import Control.Concurrent
import Prelude hiding (id, (.))

-- The typical sequence that produces a deadlock is as follows:
--
-- (1) Thread 1 acquires lock A
-- (2) Thread 2 acquires lock B
-- (3) Thread 1 tries to acquire B
-- (4) Thread 2 tries to acquire A
-- Deadlock.
--
-- Standard deadlock detection intervenes after (4) has occurred.
-- We should intervene in a lock acquisition that is followed
-- by an unsafe schedule (here at (2)). We suspend thread 2
-- until a safe schedule is guaranteed -- in this case until
-- thread 1 relinquishes lock A.
--
-- We need to do some kind of static analysis on the threads
-- to do this. Haskell arrows make possible a kind of JIT
-- static analysis. We leverage the fact that considerable
-- computation has been done to reach a certain point --
-- we only have to analyse the immediate continuation of
-- a thread.

-- | The Res arrow: a computation from @t@ to @u@ represented as an
-- inspectable chain of steps. Because the steps are ordinary data, 'run'
-- can look ahead at what a thread will do after acquiring a lock.
--
-- Every constructor except 'Plus' and 'Id' carries the rest of the
-- computation as its last field.
data Res t u where
	-- | Run an IO step on the input, then continue with its result.
	Lift :: Kleisli IO t v -> Res v u -> Res t u
	-- | Acquire a lock, then continue. A lock is an @MVar ()@ that 'run'
	--   takes to acquire and fills to release.
	Acq :: MVar () -> Res t u -> Res t u
	-- | Release a lock, then continue.
	Rel :: MVar () -> Res t u -> Res t u
	-- | Fork a thread running the first computation on the same input,
	--   then continue with the second computation in the current thread.
	Fork :: Res t () -> Res t u -> Res t u
	-- | Choice: run the first branch on a 'Left' input, the second on a 'Right'.
	Plus :: Res t v -> Res u v -> Res (Either t u) v
	-- | The identity computation; returns its input unchanged.
	Id :: Res t t

-- Sequential composition. @outer . inner@ runs @inner@ first: composition
-- walks down the chain of steps in @inner@ and puts @outer@ in place of
-- every terminal 'Id' (in both branches of a 'Plus').
instance Category Res where
	id = Id
	outer . inner = case inner of
		Lift k rest -> Lift k (outer . rest)
		Acq l rest -> Acq l (outer . rest)
		Rel l rest -> Rel l (outer . rest)
		-- The forked computation is left alone; only the current thread's
		-- continuation is extended.
		Fork child rest -> Fork child (outer . rest)
		Plus onLeft onRight -> Plus (outer . onLeft) (outer . onRight)
		Id -> outer

-- Arrow structure. 'first' carries an untouched second component of the
-- input past every step of the computation.
instance Arrow Res where
	arr f = Lift (arr f) Id
	first (Lift k a) = Lift (first k) (first a)
	first (Acq m a) = Acq m (first a)
	first (Rel m a) = Rel m (first a)
	-- The forked computation only needs the first component of the input.
	first (Fork a a2) = Fork (a . arr fst) (first a2)
	-- A choice: move the second component inside the 'Either' so that
	-- whichever branch is taken can carry it through. Without this clause,
	-- 'first' / 'second' / '***' on anything built with 'left' or '|||'
	-- failed with a pattern-match error at runtime.
	first (Plus a a2) = Lift (arr distribute) (Plus (first a) (first a2))
		where
			distribute (Left t, d) = Left (t, d)
			distribute (Right u, d) = Right (u, d)
	first Id = Id

-- Choice: a 'Left' input goes through the wrapped arrow and is re-tagged
-- with 'Left'; a 'Right' input passes through untouched.
instance ArrowChoice Res where
	left f = Plus (f >>> arr Left) (arr Right)

-- Global lock table. Keys are threads; each value lists the locks that
-- thread currently holds, each paired with the locks it may go on to
-- acquire before releasing that lock.
--
-- The table is a process-wide global created with 'unsafePerformIO'. The
-- NOINLINE pragma stops GHC from inlining the definition, which could
-- otherwise create more than one MVar.
{-# NOINLINE resource #-}
resource :: MVar (Map ThreadId [(MVar (), [MVar ()])])
resource = unsafePerformIO $ newMVar M.empty

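-- A hazard is a HOLD-AND-ACQUIRE cycle among threads: a thread holding
-- one lock may acquire a second, a thread holding the second may acquire
-- a third, and so on back to the first lock. We enumerate such chains of
-- locks looking for a cycle.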

-- | Every way of picking one element out of a list, paired with the
--   remaining elements in their original order; picks are listed in
--   order of position.
selects :: [a] -> [(a, [a])]
selects [] = []
selects (x : xs) = (x, xs) : [ (y, x : ys) | (y, ys) <- selects xs ]

-- Enumerates chains of locks starting from @lock@. A chain
-- @[l1, l2, ..., ln]@ means: some thread holding @lock@ may acquire @l1@
-- while holding it, another thread holding @l1@ may acquire @l2@, and so
-- on; each thread in @ls@ is used at most once per chain.
--
-- Every non-empty chain is produced, not only chains that pass through
-- every thread in the table. Requiring all threads meant that one thread
-- unable to extend the chain (e.g. holding an unrelated lock, or holding
-- nothing) hid any cycle among the remaining threads.
generateSequences :: Eq a => [(t, [(a, [a])])] -> a -> [[a]]
generateSequences ls lock = do
	((_, held), rest) <- selects ls
	next <- fromMaybe [] (lookup lock held)
	[next] : map (next :) (generateSequences rest next)

-- If there is a hazard, returns a /guard/ for the hazard: a lock that the
-- acquiring thread waits on (and then releases) before retrying, which
-- avoids the hazard.
--
-- Each chain @x : xs@ from 'generateSequences' starts at a lock @x@ that
-- may be acquired while @m@ is held; the chain is hazardous when @m@
-- appears again in @xs@. The guard is the head of the first hazardous
-- chain. Chains are matched with a total pattern, so an empty chain is
-- skipped instead of crashing while the global lock table is taken.
hazard :: (Eq a, MonadPlus f) => Map k [(a, [a])] -> a -> f a
hazard mp m = msum [ return x | x : xs <- generateSequences (M.assocs mp) m, m `elem` xs ]

-- This is the static analysis bit: the locks that a computation may
-- acquire before it releases @held@. Every path is followed ('Lift'
-- continuations, both 'Plus' branches, and both sides of a 'Fork'), and a
-- path stops at the first release of @held@. The result is therefore
-- conservative: it includes locks from branches that may not be taken and
-- from forked computations.
acquired :: Res t u -> MVar () -> [MVar ()]
acquired res held = go res
	where
		-- Polymorphic signature needed: the type indices change from step to step.
		go :: Res v w -> [MVar ()]
		go (Lift _ k) = go k
		go (Acq l k) = l : go k
		go (Rel l k) = if l == held then [] else go k
		go (Fork child k) = go child ++ go k
		go (Plus onLeft onRight) = go onLeft ++ go onRight
		go Id = []

-- | Association-list upsert: replace the value of the first pair whose key
--   equals @x@, or append @(x, y)@ at the end when no key matches.
insert :: Eq a => a -> b -> [(a, b)] -> [(a, b)]
insert x y assoc = case break ((x ==) . fst) assoc of
	(before, _ : after) -> before ++ (x, y) : after
	(before, []) -> before ++ [(x, y)]

-- | Use this to run computations built in the Res arrow.
--   Pieces of the arrow that hold locks must be finitely examinable,
--   otherwise it doesn't terminate.
--
--   Lock bookkeeping lives in the global 'resource' table. Every update to
--   it goes through 'modifyMVar' / 'modifyMVar_'. An exception raised while
--   the table is taken (for instance while evaluating the hazard check)
--   then restores the table instead of leaving it empty, which would block
--   every other thread.
run :: Res t u -> t -> IO u
run (Lift k a) x = runKleisli k x >>= run a
run (Acq m a) x = do
	thd <- myThreadId
	-- Tentatively record that this thread holds m, together with the locks
	-- it may acquire before releasing m, and check whether that creates a
	-- hazard involving possibly acquired locks. The updated table is
	-- committed only when there is no hazard.
	guardLock <- modifyMVar resource $ \mp ->
		let mp' = M.alter (Just . insert m (acquired a m) . fromMaybe []) thd mp
		in case hazard mp' m of
			Nothing -> return (mp', Nothing)
			Just m' -> return (mp, Just m')
	case guardLock of
		Nothing -> do
			takeMVar m
			run a x
		-- Wait on the guard lock first. This has the effect of denying
		-- service to this thread until the hazard has passed; then retry m.
		Just m' -> run (Acq m' $ Rel m' $ Acq m a) x
run (Rel m a) x = do
	putMVar m ()
	thd <- myThreadId
	-- Forget m for this thread. Once the thread holds no locks, drop its
	-- entry entirely: an empty entry can never take part in a hazard chain,
	-- and keeping it would grow the table by one entry per thread forever.
	modifyMVar_ resource (return . M.update (mfilter (not . null) . Just . deleteBy ((==) `on` fst) (m, [])) thd)
	run a x
run (Fork a a2) x = forkIO (run a x) >> run a2 x
run (Plus a a2) ei = either (run a) (run a2) ei
run Id x = return x

-- | Lift a plain IO action into 'Res'. The action is run for its effect
--   only: its result is discarded, the input is passed on unchanged, and
--   the given computation continues.
lft :: IO a -> Res t u -> Res t u
lft act = Lift (Kleisli (<$ act))

-- Demonstration of the two-thread deadlock scenario described at the top
-- of this module, built from this library's primitives. The two threads
-- take the same two locks in opposite orders with a one-second pause in
-- between. The deadlock avoidance in 'run' is intended to make one thread
-- wait so that both finish. Steps compose right-to-left with (.), so the
-- rightmost step of each thread runs first.
test = do
	lock1 <- newMVar ()
	lock2 <- newMVar ()
	let
		pause = lft (threadDelay 1000000) Id
		thread1 = lft (print "Thd1 done") Id . Rel lock1 Id . Rel lock2 Id . Acq lock1 Id . pause . Acq lock2 Id
		thread2 = lft (print "Thd2 done") Id . Rel lock1 Id . Rel lock2 Id . Acq lock2 Id . pause . Acq lock1 Id
	run (Fork thread1 thread2) ()