{-
    Runtime.hs
        Copyright 2008 Matthew Sackman <matthew@wellquite.org>

    This file is part of Session Types for Haskell.

    Session Types for Haskell is free software: you can redistribute it
    and/or modify it under the terms of the GNU General Public License
    as published by the Free Software Foundation, either version 3 of
    the License, or (at your option) any later version.

    Session Types for Haskell is distributed in the hope that it will
    be useful, but WITHOUT ANY WARRANTY; without even the implied
    warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
    See the GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Session Types for Haskell.
    If not, see <http://www.gnu.org/licenses/>.
-}

{-# LANGUAGE KindSignatures, ExistentialQuantification, ScopedTypeVariables, GADTs, MultiParamTypeClasses, FunctionalDependencies, UndecidableInstances, FlexibleInstances, FlexibleContexts #-}

-- | Having actually described a session type, you'll now want to
-- implement it! Use the methods of 'SMonad' to chain functions
-- together.

module Control.Concurrent.Session.Runtime
    ( OfferImpls (..)
    , SessionState ()
    , SessionChain (..)
    , sjump
    , ssend
    , srecv
    , soffer
    , sselect
    , run
    ) where

import Control.Concurrent
import Control.Exception (SomeException, throwIO, try)

import Control.Concurrent.Session.List
import Control.Concurrent.Session.Number
import Control.Concurrent.Session.SessionType
import Control.Concurrent.Session.SMonad

-- | Use OfferImpls to construct the implementations of the branches
-- of an offer. Really, it's just a slightly fancy list: each element
-- is the 'SessionChain' implementing one branch, and the @jumps@ index
-- records, in order, the @Cons (Jump l) Nil@ type each branch starts
-- from.
--
-- Every branch must start from a jump state in both directions
-- (presumably with 'sjump', which accepts exactly that state), and all
-- branches must finish in the same @finalState@ with a result of the
-- same type @finalResult@.
data OfferImpls :: * -> * -> * -> * -> * -> * -> * where
                   -- The empty list of branches.
                   OfferImplsNil :: OfferImpls Nil prog progOut progIn finalState finalResult
                   -- Prepend the implementation of a branch that begins by jumping to label @l@.
                   (:~:) :: (SessionChain prog progOut progIn (Cons (Jump l) Nil, Cons (Jump l) Nil) finalState finalResult) ->
                            OfferImpls jumps prog progOut progIn finalState finalResult ->
                            OfferImpls (Cons (Cons (Jump l) Nil) jumps) prog progOut progIn finalState finalResult
infixr 5 :~:

-- Pick the branch at (0-based) index @n@ out of an 'OfferImpls' list and
-- retype it so that it can be run from any starting state @from@.
class WalkOfferImpls prog progOut progIn finalState finalResult where
    walkOfferImpls :: Int -> OfferImpls jumps prog progOut progIn finalState finalResult -> SessionChain prog progOut progIn from finalState finalResult
instance forall prog progOut progIn finalState finalResult .
    WalkOfferImpls prog progOut progIn finalState finalResult where
        walkOfferImpls 0 (chain :~: _) = SessionChain f
            where
              -- Keep the program and the per-label slot lists, but replace the
              -- two current-cell MVars with placeholders. The chosen branch
              -- starts from a jump state, and the jump is presumably what
              -- installs fresh cell MVars before any are read ('sjump' ignores
              -- these two fields). NOTE(review): forcing either placeholder
              -- would crash.
              f :: forall from . SessionState prog progOut progIn from ->
                   IO (finalResult, SessionState prog progOut progIn finalState)
              f (SessionState prog outgoingProg incomingProg _ _)
                  = runSessionChain chain (SessionState prog outgoingProg incomingProg undefined undefined)
        -- Not the requested branch yet: drop it and keep counting down.
        walkOfferImpls n (_ :~: rest) = walkOfferImpls (n - 1) rest
        -- Only reachable when the index is negative or not smaller than the
        -- number of branches.
        walkOfferImpls _ _ = error "The Truly Impossible Happened."

-- One link of the stream of values flowing from one party to the other.
--
-- * @Cell v next@ carries a value written by 'ssend' together with the
--   MVar into which the following cell will be put.
-- * @SelectCell n@ carries the branch index written by 'sselect'; its
--   type only allows it where the session type is a 'Choice'.
data Cell :: * -> * where
             Cell :: val -> MVar (Cell nxt) -> Cell (Cons val nxt)
             SelectCell :: Int -> Cell (Cons (Choice jumps) Nil)

-- One entry in the per-label chain of program cells. In
-- @ProgramCell cur next@, @cur@ is the MVar that becomes the current
-- cell after a jump, and @next@ is the MVar where the 'ProgramCell' for
-- the following jump to the same label will be placed (see
-- 'carefullySwapToNextCell').
data ProgramCell :: * -> * where
                    ProgramCell :: MVar a -> MVar (ProgramCell a) -> ProgramCell a

-- | Allocate one fresh, empty MVar per element of a program (a
-- type-level list built from 'Cons' and 'Nil'). The result has the same
-- shape and order as the program. Each MVar will later hold the
-- 'ProgramCell' for that element. The cell type stored for element
-- @val@ is @val'@, related to it by the 'OnlyOutgoing' constraint
-- (defined elsewhere).
class ProgramToMVarsOutgoing prog mvars | prog -> mvars where
    programToMVarsOutgoing :: prog -> IO mvars

-- The empty program needs no MVars.
instance ProgramToMVarsOutgoing Nil Nil where
    programToMVarsOutgoing Nil = return Nil

-- Allocate the head's MVar first, then those of the tail.
instance (ProgramToMVarsOutgoing nxt nxt', OnlyOutgoing val val') =>
    ProgramToMVarsOutgoing (Cons val nxt) (Cons (MVar (ProgramCell (Cell val'))) nxt') where
        programToMVarsOutgoing (Cons _ tailProg)
            = newEmptyMVar >>= \headSlot ->
              programToMVarsOutgoing tailProg >>= \tailSlots ->
              return (Cons headSlot tailSlots)

-- | The state threaded through a 'SessionChain'. Its fields are, in
-- order:
--
-- 1. the program value;
-- 2. the per-label list of outgoing 'ProgramCell' MVars;
-- 3. the matching incoming list;
-- 4. the MVar into which the next outgoing 'Cell' will be put;
-- 5. the MVar from which the next incoming 'Cell' will be taken.
--
-- The type index pairs the outgoing and incoming session types that
-- remain at the current position.
data SessionState :: * -> * -> * -> * -> * where
                     SessionState :: prog -> progOut -> progIn ->
                                     (MVar (Cell currentOutgoing)) ->
                                     (MVar (Cell currentIncoming)) ->
                                     SessionState prog progOut progIn
                                                  (currentOutgoing, currentIncoming)

-- | The representation of a computation that performs work using
-- session types. Again, really quite similar to a more-parameterized
-- State monad. Given the 'SessionState' at position @from@, it
-- produces (in 'IO') a result of type @res@ together with the state at
-- position @to@.
newtype SessionChain prog progOut progIn from to res
    = SessionChain { runSessionChain :: (SessionState prog progOut progIn from) ->
                                        IO (res, SessionState prog progOut progIn to)
                   }

-- Sequencing for session chains: each step receives the session state
-- left behind by the previous step, and their IO effects run in order.
instance ( Dual prog prog'
         , ProgramToMVarsOutgoing prog progOut
         , ProgramToMVarsOutgoing prog' progIn
         ) =>
    SMonad (SessionChain prog progOut progIn) where
        -- Run the first step, then hand its result and resulting state on.
        step ~>>= cont = SessionChain $ \st0 ->
            runSessionChain step st0 >>= \(val, st1) ->
            runSessionChain (cont val) st1
        -- Same as binding, but the first step's result is ignored.
        step ~>> next = step ~>>= \_ -> next
        -- Produce a value without touching the session state.
        sreturn val = SessionChain (\st -> return (val, st))

-- Lift an arbitrary IO action into a session chain; the session state
-- passes through unchanged.
instance ( Dual prog prog'
         , ProgramToMVarsOutgoing prog progOut
         , ProgramToMVarsOutgoing prog' progIn
         ) =>
    SMonadIO (SessionChain prog progOut progIn) where
        sliftIO action = SessionChain (\st -> action >>= \val -> return (val, st))

-- | Rendezvous on a per-label program slot and return the 'ProgramCell'
-- that both parties will use for this jump.
--
-- The slot MVar is shared by the two sides of a session. If the peer
-- has already published a cell there, take it, which leaves the slot
-- empty again. Otherwise publish a fresh cell ourselves. If the peer
-- fills the slot between our check and our put, take the peer's cell
-- instead, so both sides still agree.
carefullySwapToNextCell :: MVar (ProgramCell a) -> IO (ProgramCell a)
carefullySwapToNextCell slot
    = tryTakeMVar slot >>= maybe publishFresh return
    where
      publishFresh
          = do { cellMVar <- newEmptyMVar
               ; successorSlot <- newEmptyMVar
               ; let fresh = ProgramCell cellMVar successorSlot
               ; published <- tryPutMVar slot fresh
               ; if published then return fresh else takeMVar slot
               }

-- | Perform a jump. You might expect to have to say where you want to
-- jump to. Of course, that's actually specified by the session type, so
-- you don't have to specify it at all in the implementation.
--
-- Both the outgoing and the incoming program slots for the target label
-- are advanced via 'carefullySwapToNextCell', which meets the peer on
-- the same slots. The 'ProgramCell's obtained there supply the new
-- current cell MVars, and the slots to use for the next jump to this
-- label. (The 'SWellFormedConfig' constraints presumably check that the
-- label is valid for both programs; they are defined elsewhere.)
sjump :: forall l prog prog' progOut progIn outgoing incoming .
         ( Dual prog prog'
         , ProgramToMVarsOutgoing prog progOut
         , ProgramToMVarsOutgoing prog' progIn
         , SWellFormedConfig l (D0 E) prog
         , SWellFormedConfig l (D0 E) prog'
         , Elem progOut l (MVar (ProgramCell (Cell outgoing)))
         , Elem progIn l (MVar (ProgramCell (Cell incoming)))
         ) =>
        (SessionChain prog progOut progIn) ((Cons (Jump l) Nil), (Cons (Jump l) Nil)) (outgoing, incoming) ()
sjump = SessionChain f
    where
      f :: SessionState prog progOut progIn ((Cons (Jump l) Nil), (Cons (Jump l) Nil)) ->
           IO ((), SessionState prog progOut progIn (outgoing, incoming))
      -- The current cell MVars are ignored: after the jump they are replaced
      -- by the ones taken from the label's program cells.
      f (SessionState prog outgoingProg incomingProg _ _)
          = do { (ProgramCell outgoing outProgCellMVar') <- carefullySwapToNextCell outProgCellMVar
               ; (ProgramCell incoming inProgCellMVar') <- carefullySwapToNextCell inProgCellMVar
               -- Point label l at the successor slots, so the next jump to l
               -- uses a fresh rendezvous rather than this one.
               ; let outgoingProg' = tyListUpdate outgoingProg (undefined :: l) outProgCellMVar'
               ; let incomingProg' = tyListUpdate incomingProg (undefined :: l) inProgCellMVar'
               ; return ((), (SessionState prog outgoingProg' incomingProg' outgoing incoming))
               }
          where
            -- The current slots for label l; the 'undefined' only carries the
            -- label's type for the type-level list lookup.
            outProgCellMVar = tyListElem outgoingProg (undefined :: l)
            inProgCellMVar = tyListElem incomingProg (undefined :: l)

-- | Send a value to the other party. The session type dictates the
-- type of the value, so only a value of the expected type is accepted.
-- This does not wait for the other party to receive it.
ssend :: forall t prog prog' progOut progIn nxt incoming .
         ( Dual prog prog'
         , ProgramToMVarsOutgoing prog progOut
         , ProgramToMVarsOutgoing prog' progIn
         ) =>
         t -> (SessionChain prog progOut progIn) ((Cons t nxt), incoming) (nxt, incoming) ()
ssend t = SessionChain sendStep
    where
      sendStep :: SessionState prog progOut progIn ((Cons t nxt), incoming) ->
                  IO ((), SessionState prog progOut progIn (nxt, incoming))
      sendStep (SessionState prog outgoingProg incomingProg outMVar inMVar)
          = do { successor <- newEmptyMVar
                 -- The value travels with the MVar that the next outgoing
                 -- cell will be put into; that MVar becomes our new cursor.
               ; let payload = Cell t successor
               ; putMVar outMVar payload
               ; return ((), SessionState prog outgoingProg incomingProg successor inMVar)
               }

-- | Receive a value from the other party, blocking until one is
-- available. The type of the received value is given by the session
-- type, so no coercion is needed.
srecv :: forall t prog prog' progOut progIn nxt outgoing .
         ( Dual prog prog'
         , ProgramToMVarsOutgoing prog progOut
         , ProgramToMVarsOutgoing prog' progIn
         ) =>
        (SessionChain prog progOut progIn) (outgoing, (Cons t nxt)) (outgoing, nxt) t
srecv = SessionChain recvStep
    where
      recvStep :: SessionState prog progOut progIn (outgoing, (Cons t nxt)) ->
                  IO (t, SessionState prog progOut progIn (outgoing, nxt))
      recvStep (SessionState prog outgoingProg incomingProg outMVar inMVar)
          = do { Cell received following <- takeMVar inMVar
                 -- Advance the incoming cursor to the MVar the sender will fill next.
               ; return (received, SessionState prog outgoingProg incomingProg outMVar following)
               }

-- | Offer a number of branches. This is basically an external choice:
-- the other party uses 'sselect' to decide which branch to take.
-- Use OfferImpls in order to construct the list of implementations of
-- branches. Note that every implementation must result in the same
-- final state and produce a result of the same type.
soffer :: forall finalState finalResult prog prog' progOut progIn jumps .
          ( Dual prog prog'
          , ProgramToMVarsOutgoing prog progOut
          , ProgramToMVarsOutgoing prog' progIn
          ) =>
          OfferImpls jumps prog progOut progIn finalState finalResult
          -> (SessionChain prog progOut progIn) (Cons (Choice jumps) Nil, Cons (Choice jumps) Nil) finalState finalResult
soffer implementations = SessionChain f
    where
      f :: SessionState prog progOut progIn (Cons (Choice jumps) Nil, Cons (Choice jumps) Nil) ->
           IO (finalResult, SessionState prog progOut progIn finalState)
      -- Block until the peer's selection arrives, then run the chosen branch.
      -- The current cell MVars are replaced with placeholders; the branch
      -- starts with a jump, which is expected to install fresh ones.
      -- NOTE(review): if a plain 'Cell' arrived instead of a 'SelectCell',
      -- the do-pattern below would fail in IO.
      f (SessionState prog outgoingProg incomingProg _ inMVar)
          = do { (SelectCell n) <- takeMVar inMVar
               ; runSessionChain (walkOfferImpls n implementations) (SessionState prog outgoingProg incomingProg undefined undefined)
               }

-- | Select which branch we're taking at a branch point. Use a type
-- number ("Control.Concurrent.Session.Number") to indicate the branch
-- to take. The chosen index is sent to the peer, which is expected to be
-- at the matching 'soffer'. Both directions then jump to the target
-- label of the chosen branch.
sselect :: forall prog prog' progOut progIn label jumps outgoing incoming len jumpTarget .
           ( Dual prog prog'
           , ProgramToMVarsOutgoing prog progOut
           , ProgramToMVarsOutgoing prog' progIn
           , ListLength jumps len
           , SmallerThan label len
           , TypeNumberToInt label
           , Elem jumps label (Cons (Jump jumpTarget) Nil)
           , SWellFormedConfig jumpTarget (D0 E) prog
           , SWellFormedConfig jumpTarget (D0 E) prog'
           , Elem progOut jumpTarget (MVar (ProgramCell (Cell outgoing)))
           , Elem progIn jumpTarget (MVar (ProgramCell (Cell incoming)))
           ) =>
           label -> (SessionChain prog progOut progIn) (Cons (Choice jumps) Nil, Cons (Choice jumps) Nil) (outgoing, incoming) ()
sselect label = SessionChain f
    where
      f :: SessionState prog progOut progIn ((Cons (Choice jumps) Nil, Cons (Choice jumps) Nil)) ->
           IO ((), SessionState prog progOut progIn (outgoing, incoming))
      f (SessionState prog outgoingProg incomingProg outMVar _)
          -- Send the branch index as a runtime Int. Presumably ListLength and
          -- SmallerThan rule out an out-of-range label at compile time
          -- (those classes are defined elsewhere).
          = do { putMVar outMVar (SelectCell (tyNumToInt label))
               -- From here on this is the same slot swap that 'sjump' does,
               -- targeting the label the chosen branch jumps to.
               ; (ProgramCell outgoing outProgCellMVar') <- carefullySwapToNextCell outProgCellMVar
               ; (ProgramCell incoming inProgCellMVar') <- carefullySwapToNextCell inProgCellMVar
               ; let outgoingProg' = tyListUpdate outgoingProg (undefined :: jumpTarget) outProgCellMVar'
               ; let incomingProg' = tyListUpdate incomingProg (undefined :: jumpTarget) inProgCellMVar'
               ; return ((), (SessionState prog outgoingProg' incomingProg' outgoing incoming))
               }
          where
            -- Current slots for the target label; the 'undefined' only
            -- carries the label's type for the type-level list lookup.
            outProgCellMVar = tyListElem outgoingProg (undefined :: jumpTarget)
            inProgCellMVar = tyListElem incomingProg (undefined :: jumpTarget)

-- | Run! Provide a program and a start point within that program
-- (which also means that both implementations must start with
-- 'sjump'), plus two implementations which must be duals of each
-- other. Each implementation runs in its own thread. They communicate
-- through freshly allocated MVars, and the results of both are
-- returned.
--
-- The first implementation's outcome is awaited first. If an
-- implementation throws an exception, that exception is re-thrown from
-- 'run' instead of being lost in the forked thread. If the first
-- implementation fails, 'run' does not wait for the second.
run :: ( Dual prog prog'
       , ProgramToMVarsOutgoing prog progOut
       , ProgramToMVarsOutgoing prog' progIn
       ) => prog -> init ->
       SessionChain prog progOut progIn ((Cons (Jump init) Nil), (Cons (Jump init) Nil)) (toO, toI) res ->
       SessionChain prog' progIn progOut ((Cons (Jump init) Nil), (Cons (Jump init) Nil)) (toO', toI') res' ->
       IO (res, res')
run prog _ chain1 chain2
    = do { mvarsOut <- programToMVarsOutgoing prog
         ; mvarsIn <- programToMVarsOutgoing (dual prog)
           -- Initial current-cell MVars: chain1 writes into 'a' and reads
           -- from 'b'; chain2 does the opposite.
         ; a <- newEmptyMVar
         ; b <- newEmptyMVar
         ; aDone <- forkSession (runSessionChain chain1 (SessionState prog mvarsOut mvarsIn a b))
         ; bDone <- forkSession (runSessionChain chain2 (SessionState (dual prog) mvarsIn mvarsOut b a))
         ; aRes <- awaitSession aDone
         ; bRes <- awaitSession bDone
         ; return (aRes, bRes)
         }
    where
      -- Run one side in its own thread. Record either its result or the
      -- exception that terminated it, so a failure is never silently
      -- dropped by forkIO's default handler.
      forkSession :: IO (r, s) -> IO (MVar (Either SomeException r))
      forkSession session
          = do { done <- newEmptyMVar
               ; _ <- forkIO (try (fmap fst session) >>= putMVar done)
               ; return done
               }
      -- Wait for one side to finish; re-throw its exception if it failed.
      awaitSession :: MVar (Either SomeException r) -> IO r
      awaitSession done = takeMVar done >>= either throwIO return