{-# LANGUAGE GADTs, EmptyDataDecls, GeneralizedNewtypeDeriving #-}
module Control.Concurrent.CHS (
    CHS,
    Chan,
    newChan,
    readChan,
    writeChan,
    synchronize,
    
    initChs,
    
    -- testing only
    testCHS, test1, test2, test3, step
) where
import Unsafe.Coerce(unsafeCoerce)        -- only used in chanEq; relies on channel ids being unique
import System.IO.Unsafe (unsafePerformIO) -- only used to allocate global variables

import Control.Monad
import Control.Applicative
import Control.Concurrent (MVar, newMVar, putMVar, takeMVar, modifyMVar_, forkIO)
import Control.Concurrent.STM hiding (orElse)
import Data.Unique
import Control.Monad.Prompt
import qualified Data.List.Zipper as Z

-- Interface

-- Start the administrative thread that matches up blocked computations.
-- Calling it again once the thread has been started does nothing further.
initChs :: IO () -- starts the administrative thread if it hasn't yet; idempotent.

-- A synchronous channel transaction. It is a thin wrapper around a Prompt
-- program whose primitive requests are the CHSPrompt constructors
-- (Fail, Read, Write, Choice). Functor and Applicative are borrowed from
-- the wrapped Prompt via newtype deriving; the remaining instances are
-- written by hand further down.
newtype CHS a = CHS { runCHS :: Prompt CHSPrompt a }
   deriving (Functor, Applicative)
-- instance Monad CHS
-- instance Alternative CHS
-- instance MonadPlus CHS
-- Run a transaction: queue it for the administrative thread and block until
-- a result has been published for it. Nothing in synchronize starts that
-- thread, so initChs has to have been called for the call to ever return.
synchronize :: CHS a -> IO a

-- A typed channel handle. The payload type is a phantom parameter: the only
-- stored data is a Unique, so two handles compare equal exactly when they
-- wrap the same Unique.
data Chan a = Chan Unique deriving Eq
-- instance Show (Chan a)

-- Allocate a channel with a fresh identity.
newChan :: IO (Chan a)
-- Transaction step that receives a value from the channel.
readChan :: Chan a -> CHS a
-- Transaction step that sends a value on the channel.
writeChan :: Chan a -> a -> CHS ()

-----------------------------------------------
-- Internals
-----------------------------------------------


-- readChan / writeChan are simple prompts
-- Reading is just the Read request lifted into CHS.
readChan chan = CHS . prompt $ Read chan
-- Writing is the Write request lifted into CHS; note that the request
-- constructor takes the value first and the channel second.
writeChan chan value = CHS . prompt $ Write value chan

-- synchronize registers the computation with the administrative thread
-- (by pushing it onto the front of chsBlocked) and then sleeps until a
-- result has been stored for it.
synchronize computation = do
    resultVar <- newTVarIO Nothing
    -- Tag the final value with resultVar: the shared list only holds
    -- CHSRes results, and the tag records where this one must be delivered.
    let pending = viewPrompt (fmap (CHSRes resultVar) computation)
    -- Registration is committed in a transaction of its own. The wait below
    -- retries while the result is missing, and a retry inside the same
    -- transaction would roll the registration back.
    atomically (modifyTVar chsBlocked (pending :))
    -- Blocks until resultVar holds a Just.
    atomically (readTVar resultVar >>= maybe retry return)

-- Sequencing is delegated to the wrapped Prompt monad. A 'fail' (for example
-- from a failed pattern match in do-notation) becomes the Fail request and
-- its message is dropped.
instance Monad CHS where
    return x = CHS (return x)
    fail _   = CHS $ prompt Fail
    m >>= f  = CHS (runCHS m >>= \x -> runCHS (f x))

-- mzero issues the Fail request; mplus issues a Choice request carrying both
-- alternatives unevaluated.
instance MonadPlus CHS where
    mzero              = CHS $ prompt Fail
    left `mplus` right = CHS . prompt $ Choice left right

-- Same behaviour as the MonadPlus operations: empty is the Fail request and
-- (<|>) is a Choice request between its two arguments.
instance Alternative CHS where
    empty   = CHS (prompt Fail)
    a <|> b = CHS (prompt (Choice a b))

-- Channels are shown by the hash of their Unique; the payload type does not
-- appear in the output.
instance Show (Chan a) where
    show (Chan ident) = "Chan " ++ (show . hashUnique) ident

-- Represents the state of one computation taking part in a "synchronize",
-- as seen by the scheduler (stepSynchronize / runSynch).
data CHSState a where
    -- Finished, carrying the final result.
    Complete :: a -> CHSState a
    -- Waiting to receive on the channel; the function yields the next state
    -- once a value is supplied.
    BlockedRead :: Chan a -> (a -> CHSState b) -> CHSState b
    -- Waiting to send the value on the channel; the last field is the state
    -- to continue with after the send has been matched.
    BlockedWrite :: a -> Chan a -> CHSState b -> CHSState b
    -- Two alternatives; the scheduler explores each of them.
    OrElse :: CHSState a -> CHSState a -> CHSState a
    -- Dead end: a gang containing this state has no successors.
    Failed :: CHSState a

-- Prompting implementation: the primitive requests a CHS program can issue.
-- The type index is the type of the answer handed back to the program.
data CHSPrompt a where
    -- Abandon this branch (used by fail, mzero and empty).
    Fail :: CHSPrompt a
    -- Receive a value from the channel.
    Read :: Chan a -> CHSPrompt a
    -- Send the value on the channel (value first, channel second).
    Write :: a -> Chan a -> CHSPrompt ()
    -- Offer two alternative programs (used by mplus and <|>).
    Choice :: CHS a -> CHS a -> CHSPrompt a

-- Interpret a CHS program into the scheduler's CHSState representation by
-- running the wrapped Prompt with a handler for each request.
viewPrompt :: CHS a -> CHSState a
viewPrompt (CHS program) = runPromptC Complete interpret program
  where
    -- One request plus the continuation awaiting its answer.
    interpret :: CHSPrompt v -> (v -> CHSState a) -> CHSState a
    interpret Fail                _        = Failed
    interpret (Read chan)         continue = BlockedRead chan continue
    -- A write has no answer to wait for, so the continuation is resumed
    -- with () to obtain the state that follows the write.
    interpret (Write value chan)  continue = BlockedWrite value chan (continue ())
    -- Both alternatives are interpreted separately and each is followed by
    -- the same continuation.
    interpret (Choice left right) continue =
        OrElse (viewPrompt left `bindCHS` continue)
               (viewPrompt right `bindCHS` continue)

-- Debug rendering of a state. Continuations and written values cannot be
-- shown, so blocked states print only their channel.
instance Show a => Show (CHSState a) where
    show state = case state of
        Complete v           -> "Complete " ++ show v
        BlockedRead ch _     -> "BlockedRead " ++ show ch
        BlockedWrite _ ch _  -> "BlockedWrite " ++ show ch
        OrElse l r           -> show l ++ " `OrElse` " ++ show r
        Failed               -> "Failed"

-- Monadic bind for states: run the first state to completion and feed its
-- result to 'next'. Blocked states stay blocked, with 'next' attached to
-- whatever follows them; Failed stays Failed.
bindCHS :: CHSState a -> (a -> CHSState b) -> CHSState b
bindCHS state next = case state of
    Complete v              -> next v
    BlockedRead ch resume   -> BlockedRead ch (\v -> resume v `bindCHS` next)
    BlockedWrite v ch after -> BlockedWrite v ch (after `bindCHS` next)
    OrElse l r              -> OrElse (l `bindCHS` next) (r `bindCHS` next)
    Failed                  -> Failed

-- Pick each element of the list in turn, left to right. Every result pairs
-- the picked element with a zipper over the remaining elements whose cursor
-- sits where the element was removed, so Z.insert there restores the
-- original position.
select :: [a] -> [(a,Z.Zipper a)]
select xs =
    [ (Z.cursor pos, Z.delete pos)
    | pos <- takeWhile (not . Z.endp) (iterate Z.right (Z.fromList xs))
    ]

-- Evidence that two types are the same: matching on Refl lets the type
-- checker treat a and b as one type. Produced by chanEq.
data TypeEq a b where Refl :: TypeEq a a


-- All gangs reachable from the given gang in one step. A step either pairs a
-- blocked read with a blocked write on the same channel, or resolves one
-- OrElse into one of its alternatives. An empty result means no progress is
-- possible (this includes a gang that is already entirely Complete).
stepSynchronize :: [CHSState a] -> [[CHSState a]]
stepSynchronize [] = []
-- A failed member kills the whole gang.
stepSynchronize (Failed : _) = []
-- Nothing to do for a finished member; step the rest.
stepSynchronize (done@(Complete _) : rest) =
    [ done : rest' | rest' <- stepSynchronize rest ]
stepSynchronize (reader@(BlockedRead rc resume) : rest) =
    -- Pair this read with any later write on the same channel. Non-write
    -- members are skipped by the failing pattern match, and Refl proves the
    -- written value has the type the reader expects.
    (do (BlockedWrite val wc afterWrite, others) <- select rest
        Refl <- chanEq rc wc
        return (resume val : Z.toList (Z.insert afterWrite others)))
    -- ...or leave the read blocked and step the remaining members.
    ++ [ reader : rest' | rest' <- stepSynchronize rest ]
stepSynchronize (writer@(BlockedWrite val wc afterWrite) : rest) =
    -- Mirror image of the read case: look for a later read to receive val.
    (do (BlockedRead rc resume, others) <- select rest
        Refl <- chanEq wc rc
        return (afterWrite : Z.toList (Z.insert (resume val) others)))
    ++ [ writer : rest' | rest' <- stepSynchronize rest ]
-- A choice at the head is resolved both ways.
stepSynchronize (OrElse l r : rest) = [l : rest, r : rest]

-- Enumerate every way of splitting a list into (chosen, remainder), keeping
-- the original relative order on both sides.
--
-- The ordering of the result matters: every split that leaves the head
-- element in the remainder is listed before any split that chooses it.
-- synchronize pushes new computations onto the front of chsBlocked, so
-- groups made only of older computations are offered first, and the oldest
-- set that can complete with the current data wins.
--
-- (This probably wastes work retrying combinations that are guaranteed to
-- fail; an optimization would be to remember them and only retry when new
-- computations have joined.)
splitSets :: [a] -> [([a], [a])]
splitSets [] = [([], [])]
splitSets (x:rest) = map skip (splitSets rest) ++ map pick (splitSets rest)
  where
    skip (chosen, left) = (chosen, x : left)
    pick (chosen, left) = (x : chosen, left)

-- Every way of completing some non-empty subset of the gang: the results of
-- the chosen subset paired with the members that were left out. Candidates
-- appear in splitSets order.
trySynchronize :: [CHSState a] -> [([a], [CHSState a])]
trySynchronize gang =
    [ (results, leftOut)
    | (chosen, leftOut) <- splitSets gang
    , not (null chosen)
    , results <- runSynch chosen
    ]

-- Depth-first search of the connection space: keep stepping the gang until
-- every member is Complete, and return the results of each way of getting
-- there. A gang that gets stuck contributes nothing.
runSynch :: [CHSState a]-> [[a]]
runSynch gang
    | all finished gang = [[ x | Complete x <- gang ]]
    | otherwise         = concatMap runSynch (stepSynchronize gang)
  where
    finished (Complete _) = True
    finished _            = False

-- A finished result together with the TVar it must be delivered to. The
-- result type is hidden, which lets computations of different result types
-- share one list of CHSState CHSRes.
data CHSRes where
    CHSRes :: TVar (Maybe a) -> a -> CHSRes

-- Deliver a result by storing it, wrapped in Just, in the TVar it was
-- tagged with.
writeResult :: CHSRes -> STM ()
writeResult (CHSRes slot value) = writeTVar slot $ Just value

-- Body of the administrative thread. Each iteration is one transaction that
-- takes the first grouping trySynchronize offers, removes its members from
-- chsBlocked and publishes their results; when no grouping can complete it
-- retries, i.e. waits for chsBlocked to change.
chsThread :: IO ()
chsThread = forever (atomically matchOnce)
  where
    matchOnce = do
        waiting <- readTVar chsBlocked
        case trySynchronize waiting of
            (finished, stillWaiting) : _ -> do
                writeTVar chsBlocked stillWaiting
                mapM_ writeResult finished
            [] -> retry

-- Start the administrative thread on first use; later calls do nothing.
--
-- chsInited doubles as the lock and the "already started" flag.
-- modifyMVar_ puts a value back into the MVar even if forkIO throws or an
-- asynchronous exception arrives part-way through (restoring the old flag
-- in that case), so a failure here can no longer leave the MVar empty and
-- make every later initChs block forever.
initChs = modifyMVar_ chsInited $ \started -> do
    unless started $ do
        _ <- forkIO chsThread
        return ()
    return True

-- Unwrap a Maybe produced inside a MonadPlus: Just a continues with a,
-- Nothing becomes mzero.
fromJustM :: MonadPlus m => m (Maybe a) -> m a
fromJustM m = m >>= maybe mzero return


-- Replace the contents of a TVar with the result of applying f to them.
-- The new value is stored unevaluated.
modifyTVar :: TVar a -> (a -> a) -> STM ()
modifyTVar v f = readTVar v >>= writeTVar v . f


---------------------------------------
-- "unsafe" operating kernel
--
-- all uses of unsafe operations are
-- confined to this section for easier
-- reasoning.
---------------------------------------

-- Every channel wraps a Unique obtained from newUnique. That is what
-- justifies the unsafeCoerce in chanEq: matching ids are taken to mean the
-- two handles are the same channel.
newChan = fmap Chan newUnique

-- Compare two channels by identity. When the ids match, return evidence that
-- the two payload types are the same type; otherwise fail with mzero.
--
-- The evidence is manufactured with unsafeCoerce and is not checked by the
-- compiler. It is only sound as long as a given Unique is never wrapped at
-- two different payload types.
chanEq :: MonadPlus m => Chan a -> Chan b -> m (TypeEq a b)
chanEq (Chan a) (Chan b)
   | a == b    = return (unsafeCoerce Refl) -- same id: assert that the payload types coincide
   | otherwise = mzero


-- Global list of computations waiting to be matched. synchronize pushes new
-- entries onto the front; chsThread replaces the list with the members left
-- over after a group completes.
-- NOINLINE keeps the unsafePerformIO call from being inlined at use sites,
-- which could otherwise create more than one TVar.
chsBlocked :: TVar [CHSState CHSRes]
chsBlocked = unsafePerformIO $ newTVarIO []
{-# NOINLINE chsBlocked #-}

-- Global flag recording whether initChs has already started the
-- administrative thread; it starts out False. Being an MVar, it also
-- serialises concurrent initChs calls.
-- NOINLINE keeps the unsafePerformIO call from being inlined at use sites,
-- which could otherwise create more than one MVar.
chsInited :: MVar Bool
chsInited = unsafePerformIO $ newMVar False
{-# NOINLINE chsInited #-}

-- A single shared channel for the pure test fixture testGang.
-- NOINLINE keeps the unsafePerformIO call from being inlined at use sites,
-- which could otherwise allocate several distinct channels.
testCh :: Chan Int
testCh = unsafePerformIO newChan
{-# NOINLINE testCh #-}


-----------
-- TESTS --
-----------

-- Advance every gang in the list by one scheduler step and collect all of
-- the successors.
step :: [[CHSState a]] -> [[CHSState a]]
step gangs = gangs >>= stepSynchronize

test1, test2, test3 :: Chan Int -> CHS Int

-- Read a value; fail on 0, otherwise write back 100 divided by it and
-- return the value that was read.
test1 c = readChan c >>= respond
  where
    -- The zero case must fail before the division is attempted.
    respond 0 = mzero
    respond x = writeChan c (100 `div` x) >> return x

-- Write either 0 or 5 (as alternatives), then return whatever is read back.
test2 c = mplus (writeChan c 0) (writeChan c 5) >> readChan c

-- Ignores the channel and completes immediately with 100.
test3 = const (return 100)

-- A fixed gang for trying the scheduler by hand: each test program is
-- applied to the shared testCh and interpreted into its initial state.
testGang :: [CHSState Int]
testGang =
    [ viewPrompt (program testCh)
    | program <- [test3, test1, test2, test1, test1, test2]
    ]

-- End-to-end smoke test: run test1 and test2 against each other on a fresh
-- channel, one in a forked thread and one in the calling thread, and print
-- both results.
testCHS :: IO ()
testCHS = do
    initChs
    c <- newChan
    -- Set by the forked partner once it has printed its result.
    partnerDone <- newTVarIO False
    _ <- forkIO $ do
        synchronize (test1 c) >>= print
        atomically (writeTVar partnerDone True)
    synchronize (test2 c) >>= print
    -- Wait for the partner before returning. Both results become available
    -- at the same moment, so without this the caller could return (and the
    -- program exit) before the forked thread has printed anything.
    atomically $ do
        done <- readTVar partnerDone
        unless done retry