{-# OPTIONS_GHC -fglasgow-exts -fallow-overlapping-instances -fallow-undecidable-instances #-}

module Session
    (SType(..)
    ,SessionSpec(..)
    ,runSessionWithProcs
    ,Proc(..)
    ,send
    ,recv
    ,sendLog
    ,recvLog
    ,(>~>)
    ,(>~>=)
    ,returnS
    ,sliftIO
    ,sliftIO'
    ,SendT
    ,RecvT
    ,EndT
    ,SessT
    ,SessionState
    ,JustSendsRecvs
    ,ZeroOrMoreSteps
    ,mkLoopS
    ,mkLoop
    )
    where

import Control.Monad.Fix
import Control.Concurrent.MVar
import Control.Concurrent

-- Value-level descriptions of the types of messages in a session spec.
-- The type index is the Haskell type that the description stands for.
data SType :: * -> * where
    IntS    :: SType Int
    BoolS   :: SType Bool
    CharS   :: SType Char
    StringS :: SType String
    FloatS  :: SType Float
    DoubleS :: SType Double
    -- Stands for any type at all and carries no payload.
    AnyS    :: SType a
    -- Wraps a value of the described type.
    LiftS   :: a -> SType a
    -- A list whose elements are described by the argument.
    ListS   :: SType a -> SType [a]

-- Render a type description as its constructor name. The payload of
-- 'LiftS' is not shown; 'ListS' also shows its element description.
instance Show (SType a) where
    show ty = case ty of
        IntS         -> "IntS"
        BoolS        -> "BoolS"
        CharS        -> "CharS"
        StringS      -> "StringS"
        FloatS       -> "FloatS"
        DoubleS      -> "DoubleS"
        AnyS         -> "AnyS"
        LiftS _      -> "LiftS"
        ListS elemTy -> "ListS of " ++ show elemTy


-- Fixed point of a type constructor: 'In' wraps one layer of 'f' whose
-- recursive positions are again of type 'Mu f'.
newtype Mu f = In (f (Mu f))

-- Remove one 'In' wrapper from a fixed point, exposing the layer below it.
out :: Mu f -> f (Mu f)
out wrapped = case wrapped of
                In layer -> layer

-- One layer of a loop: the spec of the loop body paired with a value of the
-- recursive-position type 'a'. The first index is always a 'SessionSpec'.
data LoopF :: * -> * -> * where
    Again :: SessionSpec l -> a -> LoopF (SessionSpec l) a

-- A loop over 't': the fixed point ('Mu') of 'LoopF t'.
type Loop t = Mu (LoopF t)

-- Show only the loop body; the recursive position is ignored.
instance Show (LoopF (SessionSpec l) a) where
    show (Again body _) = concat ["Loop {", show body, "}"]

-- Build a loop whose first layer holds the given body and whose recursive
-- position is the given loop.
again :: SessionSpec l -> Loop (SessionSpec l) -> Loop (SessionSpec l)
again body rest = In $ Again body rest

-- Build a loop spec from a function that is handed the loop-end marker
-- ('LoopEndS') and returns the loop body. The resulting loop value refers
-- to itself in its recursive position.
mkLoopS :: (SessionSpec LoopEndT -> SessionSpec l) -> SessionSpec (Loop (SessionSpec l))
mkLoopS build =
    let body = build LoopEndS
        -- tie the knot: the loop's tail is the loop itself
        knot = again body knot
    in LoopS knot

-- Constructor-less types used only as type-level indices.

-- Index of 'SendS': a send of an 'a' followed by 'n'.
data SendT a n
-- Index of 'RecvS': a receive of an 'a' followed by 'n'.
data RecvT a n
-- Index of 'EndS'.
data EndT
-- Index of 'LoopEndS'.
data LoopEndT
-- Index used by 'Cell' and by the send/receive projections of
-- 'JustSendsRecvs': one transmitted 'a' followed by 'n'.
data SessT a n

-- Substitute the 'LoopEndS' marker at the end of a spec with another spec.
-- The functional dependency makes the result type a function of the
-- original type and the replacement type.
class ReplaceLoopEnd orig replacement result | orig replacement -> result where
    replaceLoopEnd :: orig -> replacement -> result
-- A spec that finishes with 'EndS' stays 'EndS'; the replacement is ignored.
instance ReplaceLoopEnd (SessionSpec EndT) (SessionSpec n) (SessionSpec EndT) where
    replaceLoopEnd EndS r = EndS
-- The loop-end marker itself is replaced.
instance ReplaceLoopEnd (SessionSpec LoopEndT) (SessionSpec n) (SessionSpec n) where
    replaceLoopEnd LoopEndS r = r
-- A nested loop is returned unchanged; the replacement is not pushed into it.
instance ReplaceLoopEnd (SessionSpec (Loop (SessionSpec l))) (SessionSpec replacement) (SessionSpec (Loop (SessionSpec l))) where
    replaceLoopEnd (LoopS l) r = (LoopS l)
-- A send step is kept and the substitution continues in its continuation.
instance (ReplaceLoopEnd (SessionSpec orig) (SessionSpec replacement) (SessionSpec result)) =>
    ReplaceLoopEnd (SessionSpec (SendT t orig)) (SessionSpec replacement) (SessionSpec (SendT t result)) where
        replaceLoopEnd (SendS t n) r = SendS t (replaceLoopEnd n r)
-- A receive step is kept and the substitution continues in its continuation.
instance (ReplaceLoopEnd (SessionSpec orig) (SessionSpec replacement) (SessionSpec result)) =>
    ReplaceLoopEnd (SessionSpec (RecvT t orig)) (SessionSpec replacement) (SessionSpec (RecvT t result)) where
        replaceLoopEnd (RecvS t n) r = RecvS t (replaceLoopEnd n r)
-- The method is undefined for 'SessT'-indexed specs. Note that none of the
-- 'SessionSpec' constructors declared in this module has a 'SessT' index.
instance (ReplaceLoopEnd (SessionSpec orig) (SessionSpec replacement) (SessionSpec result)) =>
    ReplaceLoopEnd (SessionSpec (SessT t orig)) (SessionSpec replacement) (SessionSpec (SessT t result)) where
        replaceLoopEnd = undefined

-- Expose one iteration of a loop. The functional dependency makes the
-- result type a function of the original type.
class UnrollLoop orig result | orig -> result where
    unroll :: orig -> result
-- Unrolling a loop takes the loop body and replaces its loop-end marker
-- with the whole loop again (see 'ReplaceLoopEnd').
instance (ReplaceLoopEnd (SessionSpec l) (SessionSpec (Loop (SessionSpec l))) (SessionSpec r)) =>
    UnrollLoop (SessionSpec (Loop (SessionSpec l))) (SessionSpec r) where
        unroll (LoopS l) = case out l of
                            (Again myLoop _) -> replaceLoopEnd myLoop (LoopS l)
-- Specs that are not loops unroll to themselves.
instance UnrollLoop (SessionSpec (EndT)) (SessionSpec (EndT)) where
    unroll = id
instance UnrollLoop (SessionSpec (LoopEndT)) (SessionSpec (LoopEndT)) where
    unroll = id
instance UnrollLoop (SessionSpec (SendT t n)) (SessionSpec (SendT t n)) where
    unroll = id
instance UnrollLoop (SessionSpec (RecvT t n)) (SessionSpec (RecvT t n)) where
    unroll = id
-- The method is undefined for 'SessT'-indexed specs; only the type-level
-- relation is provided.
instance UnrollLoop (SessionSpec (SessT t n)) (SessionSpec (SessT t n)) where
    unroll = undefined

-- A session specification: a sequence of receive and send steps that is
-- terminated by 'EndS' or 'LoopEndS', or a loop. The type index mirrors the
-- structure of the value.
data SessionSpec :: * -> * where
    -- Receive a value described by the 'SType', then continue.
    RecvS    :: SType t -> SessionSpec n -> SessionSpec (RecvT t n)
    -- Send a value described by the 'SType', then continue.
    SendS    :: SType t -> SessionSpec n -> SessionSpec (SendT t n)
    -- End of the session.
    EndS     :: SessionSpec EndT
    -- End-of-body marker of a loop.
    LoopEndS :: SessionSpec LoopEndT
    -- A loop, given as a 'Loop' over specs.
    LoopS    :: Loop (SessionSpec l) -> SessionSpec (Loop (SessionSpec l))

-- Render a spec as its steps separated by " . ". For a loop only the body
-- of the first layer is shown, so the output is finite.
instance Show (SessionSpec n) where
    show (RecvS ty rest) = concat ["RecvS ", show ty, " . ", show rest]
    show (SendS ty rest) = concat ["SendS ", show ty, " . ", show rest]
    show EndS            = "EndS"
    show LoopEndS        = "LoopEndS"
    show (LoopS knot)    = case out knot of
                             Again body _ -> concat ["LoopS {", show body, "}"]

-- Relates a spec to its dual, in which sends and receives are swapped.
-- The functional dependencies run in both directions, so each side
-- determines the other.
class DualT a b | a -> b, b -> a where
    dual :: a -> b
-- The two terminators are their own duals.
instance DualT (SessionSpec EndT) (SessionSpec EndT) where
    dual EndS = EndS
instance DualT (SessionSpec LoopEndT) (SessionSpec LoopEndT) where
    dual LoopEndS = LoopEndS
-- A receive becomes a send of the same type; the continuation is dualised.
instance (DualT (SessionSpec l) (SessionSpec r)) =>
    DualT (SessionSpec (RecvT t l)) (SessionSpec (SendT t r)) where
        dual (RecvS t n) = (SendS t (dual n))
-- A send becomes a receive of the same type; the continuation is dualised.
instance (DualT (SessionSpec l) (SessionSpec r)) =>
    DualT (SessionSpec (SendT t l)) (SessionSpec (RecvT t r)) where
        dual (SendS t n) = (RecvS t (dual n))
-- The dual of a loop is a new self-referential loop over the dual of the
-- body; the tail of the original loop value ('loopTail') is not used.
instance (DualT (SessionSpec l) (SessionSpec r)) =>
    DualT (SessionSpec (Loop (SessionSpec l))) (SessionSpec (Loop (SessionSpec r))) where
        dual (LoopS l) = case out l of
                          (Again myLoop loopTail) ->
                              let myLoop' = dual myLoop
                                  l' = again myLoop' l'
                              in LoopS l'

-- Relates a spec that starts with a single step to its continuation.
-- The functional dependency makes the continuation type a function of the
-- original type.
class NextOp a b | a -> b where
    nextOp :: a -> b
-- Drop the leading receive.
instance NextOp (SessionSpec (RecvT t l)) (SessionSpec l) where
    nextOp (RecvS t n) = n
-- Drop the leading send.
instance NextOp (SessionSpec (SendT t l)) (SessionSpec l) where
    nextOp (SendS t n) = n
-- The method is undefined for 'SessT'-indexed specs; only the type-level
-- relation is provided.
instance NextOp (SessionSpec (SessT t l)) (SessionSpec l) where
    nextOp = undefined

-- Relates a spec 'a' to a spec 'b' that is reached from it by the steps
-- encoded in the instances below. There is no functional dependency, and
-- several instances share the same source type; the OPTIONS_GHC pragma at
-- the top of the file enables overlapping instances.
class ZeroOrMoreSteps a b where
    stepN :: a -> b
-- Zero steps: every kind of spec reaches itself.
instance ZeroOrMoreSteps (SessionSpec EndT) (SessionSpec EndT) where
    stepN x = x
instance ZeroOrMoreSteps (SessionSpec LoopEndT) (SessionSpec LoopEndT) where
    stepN x = x
instance ZeroOrMoreSteps (SessionSpec (RecvT t l)) (SessionSpec (RecvT t l)) where
    stepN x = x
instance ZeroOrMoreSteps (SessionSpec (SendT t l)) (SessionSpec (SendT t l)) where
    stepN x = x
instance ZeroOrMoreSteps (SessionSpec (Loop (SessionSpec a))) (SessionSpec (Loop (SessionSpec a))) where
    stepN x = x
-- Exactly one step: drop the leading receive or send.
instance ZeroOrMoreSteps (SessionSpec (RecvT t l)) (SessionSpec l) where
    stepN (RecvS t n) = n
instance ZeroOrMoreSteps (SessionSpec (SendT t l)) (SessionSpec l) where
    stepN (SendS t n) = n
-- 'SessT' variant of the single step; the method is undefined.
instance ZeroOrMoreSteps (SessionSpec (SessT t l)) (SessionSpec l) where
    stepN = undefined
-- One step followed by further steps: drop the leading operation and keep
-- stepping through the continuation.
instance (ZeroOrMoreSteps (SessionSpec a) (SessionSpec b)) =>
    ZeroOrMoreSteps (SessionSpec (RecvT t a)) (SessionSpec b) where
        stepN (RecvS t n) = stepN n
instance (ZeroOrMoreSteps (SessionSpec a) (SessionSpec b)) =>
    ZeroOrMoreSteps (SessionSpec (SendT t a)) (SessionSpec b) where
        stepN (SendS t n) = stepN n
-- 'SessT' variant of the recursive step; the method is undefined.
instance (ZeroOrMoreSteps (SessionSpec a) (SessionSpec b)) =>
    ZeroOrMoreSteps (SessionSpec (SessT t a)) (SessionSpec b) where
        stepN = undefined
-- A loop reaches whatever its unrolling reaches.
instance (UnrollLoop (SessionSpec (Loop (SessionSpec a))) (SessionSpec b),
          ZeroOrMoreSteps (SessionSpec b) (SessionSpec c)) =>
    ZeroOrMoreSteps (SessionSpec (Loop (SessionSpec a))) (SessionSpec c) where
        stepN = stepN . unroll
-- A loop over 'a' is related to a loop over 'b' when 'a' reaches 'b'.
-- NOTE(review): the method is undefined even though 'LoopS' values exist,
-- so forcing 'stepN' at this instance raises an error - confirm intended.
instance (ZeroOrMoreSteps (SessionSpec a) (SessionSpec b)) =>
    ZeroOrMoreSteps (SessionSpec (Loop (SessionSpec a))) (SessionSpec (Loop (SessionSpec b))) where
        stepN = undefined

-- Method-less class relating a type whose head is a send, receive or 'SessT'
-- step to the type 't' carried by that step. The functional dependency
-- makes 't' a function of the first parameter.
class CommTy a b | a -> b where
instance CommTy (SessionSpec (SendT t n)) t where
instance CommTy (SessionSpec (RecvT t n)) t where
-- NOTE(review): unlike the two instances above, this head is the bare
-- 'SessT t n' rather than 'SessionSpec (SessT t n)' - confirm intended.
instance CommTy (SessT t n) t where

-- Type-level probe: always returns True without inspecting its arguments.
-- An application type checks only if a 'DualT' instance relates the two
-- argument types.
fooDual :: (DualT l r) => l -> r -> Bool
fooDual = \_ _ -> True

-- Type-level probe: always returns True without inspecting its arguments.
-- An application type checks only if a 'NextOp' instance relates the two
-- argument types.
fooNext :: (NextOp l r) => l -> r -> Bool
fooNext = \_ _ -> True

-- Type-level probe: always returns True without inspecting its arguments.
-- An application type checks only if a 'ZeroOrMoreSteps' instance relates
-- the two argument types.
fooZeroPlus :: (ZeroOrMoreSteps a b) => a -> b -> Bool
fooZeroPlus = \_ _ -> True

-- Type-level probe: always returns True without inspecting its arguments.
-- An application type checks only if a 'CommTy' instance relates the two
-- argument types.
fooCommTy :: (CommTy a b) => a -> b -> Bool
fooCommTy = \_ _ -> True

-- Method-less class relating a spec ('orig') to two projections of it: one
-- listing only what is sent ('sends') and one listing only what is received
-- ('recvs'), both expressed with 'SessT'. The functional dependencies make
-- the projections functions of 'orig'.
class JustSendsRecvs orig sends recvs | orig -> sends recvs, orig sends -> recvs, orig recvs -> sends where
-- The terminators project to themselves on both sides.
instance JustSendsRecvs (SessionSpec EndT) (SessionSpec EndT) (SessionSpec EndT) where
instance JustSendsRecvs (SessionSpec LoopEndT) (SessionSpec LoopEndT) (SessionSpec LoopEndT) where
-- A send adds a 'SessT t' step to the sends projection only.
instance (JustSendsRecvs (SessionSpec n) (SessionSpec s) (SessionSpec r)) =>
    JustSendsRecvs (SessionSpec (SendT t n)) (SessionSpec (SessT t s)) (SessionSpec r) where
-- A receive adds a 'SessT t' step to the recvs projection only.
instance (JustSendsRecvs (SessionSpec n) (SessionSpec s) (SessionSpec r)) =>
    JustSendsRecvs (SessionSpec (RecvT t n)) (SessionSpec s) (SessionSpec (SessT t r)) where
-- A loop projects to a loop over each projection of its body.
instance (JustSendsRecvs (SessionSpec l) (SessionSpec s) (SessionSpec r)) =>
    JustSendsRecvs (SessionSpec (Loop (SessionSpec l))) (SessionSpec (Loop (SessionSpec s))) (SessionSpec (Loop (SessionSpec r))) where
-- A 'SessionState' projects to its own outgoing and incoming indices.
instance JustSendsRecvs (SessionState s o i) (SessionSpec o) (SessionSpec i) where

-- Contents of one MVar of a channel, indexed by what is still to be
-- transmitted through it.
data Cell :: * -> * where
    -- A transmitted value together with the MVar for the rest of the channel.
    Cell   :: t -> MVar (Cell ct) -> Cell (SessT t ct)
    -- A pointer to another MVar whose index is the unrolling of this cell's
    -- index (see 'UnrollLoop').
    Branch :: UnrollLoop (SessionSpec l) (SessionSpec u) => MVar (Cell u) -> Cell l

-- One endpoint of a session: the MVar it writes to and the MVar it reads
-- from. The 'JustSendsRecvs' constraint ties the indices of the two MVars
-- to the spec.
data SessionState :: * -> * -> * -> * where
    SessionState :: JustSendsRecvs spec (SessionSpec outgoing) (SessionSpec incoming)
                 => MVar (Cell outgoing) -> MVar (Cell incoming)
                 -> SessionState spec outgoing incoming

-- Move an endpoint that is at the head of a loop to the unrolled loop body.
-- For each direction a fresh empty MVar is created and a 'Branch' pointing
-- at it is offered to the current MVar with the non-blocking 'tryPutMVar';
-- the returned state uses the two fresh MVars.
-- NOTE(review): the Bool results of 'tryPutMVar' are discarded. If a put
-- does not happen because the MVar is already full, the fresh MVar returned
-- here is not linked from the old one - confirm this is intended.
loop :: (UnrollLoop (SessionSpec (Loop (SessionSpec s))) (SessionSpec s'),
         UnrollLoop (SessionSpec (Loop (SessionSpec ol))) (SessionSpec o),
         UnrollLoop (SessionSpec (Loop (SessionSpec il))) (SessionSpec i),
         JustSendsRecvs (SessionSpec s') (SessionSpec o) (SessionSpec i)) =>
        SessionState (SessionSpec (Loop (SessionSpec s))) (Loop (SessionSpec ol)) (Loop (SessionSpec il)) ->
        IO ((), SessionState (SessionSpec s') o i)
loop (SessionState o i) = do
    freshOut <- newEmptyMVar
    _ <- tryPutMVar o (Branch freshOut)
    freshIn <- newEmptyMVar
    _ <- tryPutMVar i (Branch freshIn)
    return ((), SessionState freshOut freshIn)


-- Send one value. The value is stored in the outgoing MVar together with a
-- fresh empty MVar, and the returned state uses that fresh MVar as its
-- outgoing side. 'putMVar' blocks while the outgoing MVar is full.
send :: (NextOp (SessionSpec (SendT t s)) (SessionSpec s),
         NextOp (SessionSpec (SessT t o)) (SessionSpec o),
         ZeroOrMoreSteps (SessionSpec (SendT t s)) (SessionSpec s),
         JustSendsRecvs (SessionSpec (SendT t s)) (SessionSpec (SessT t o)) (SessionSpec i),
         JustSendsRecvs (SessionSpec s) (SessionSpec o) (SessionSpec i)) =>
        t -> SessionState (SessionSpec (SendT t s)) (SessT t o) i ->
        IO ((), SessionState (SessionSpec s) o i)
send val (SessionState o i) = do
    nextOut <- newEmptyMVar
    putMVar o (Cell val nextOut)
    return ((), SessionState nextOut i)

-- Receive one value. Takes the contents of the incoming MVar (blocking
-- until it is filled) and returns the value together with a state whose
-- incoming side is the MVar stored next to that value.
-- Only the 'Cell' constructor is matched; a 'Branch' in the incoming MVar
-- makes the pattern match in the do block fail.
recv :: (NextOp (SessionSpec (RecvT t s)) (SessionSpec s),
         NextOp (SessionSpec (SessT t i)) (SessionSpec i),
         ZeroOrMoreSteps (SessionSpec (RecvT t s)) (SessionSpec s),
         JustSendsRecvs (SessionSpec (RecvT t s)) (SessionSpec o) (SessionSpec (SessT t i)),
         JustSendsRecvs (SessionSpec s) (SessionSpec o) (SessionSpec i)) =>
        SessionState (SessionSpec (RecvT t s)) o (SessT t i) ->
        IO (t, SessionState (SessionSpec s) o i)
recv (SessionState o i) = do
    Cell received nextIn <- takeMVar i
    return (received, SessionState o nextIn)

-- Create the two endpoints of a session. Two empty MVars are allocated;
-- the first endpoint writes to one and reads from the other, and the second
-- endpoint uses them the other way round. The spec argument is used only
-- for its type.
mkState :: (JustSendsRecvs spec (SessionSpec outgoing) (SessionSpec incoming),
            JustSendsRecvs spec' (SessionSpec incoming) (SessionSpec outgoing),
            DualT spec spec') => spec ->
           IO (SessionState spec outgoing incoming, SessionState spec' incoming outgoing)
mkState _ = do
    there <- newEmptyMVar
    back <- newEmptyMVar
    return (SessionState there back, SessionState back there)

-- A process: a function from a session state at spec 's' to an IO action
-- yielding a state at spec 's''. The constructor carries the constraints
-- that relate both specs to each other and to their MVar indices.
data Proc :: * -> * -> * -> * -> * -> * -> * where
    Proc :: (ZeroOrMoreSteps s s',
             JustSendsRecvs s (SessionSpec o) (SessionSpec i),
             JustSendsRecvs s' (SessionSpec o') (SessionSpec i'))
         => (SessionState s o i -> IO ((), SessionState s' o' i'))
         -> Proc s s' o o' i i'

-- Run two processes against each other. A pair of connected endpoints is
-- created from the spec, each process is started in its own thread with
-- 'forkIO' (the final states and the thread ids are discarded), and the
-- result of 'stepN' applied to the spec is returned.
-- The action returns right after forking; it does not wait for either
-- thread to finish.
runSessionWithProcs :: (JustSendsRecvs spec (SessionSpec outgoing) (SessionSpec incoming),
                        JustSendsRecvs specD (SessionSpec incoming) (SessionSpec outgoing),
                        DualT spec specD,
                        JustSendsRecvs spec' (SessionSpec outgoing') (SessionSpec incoming'),
                        JustSendsRecvs specD' (SessionSpec incoming') (SessionSpec outgoing'),
                        ZeroOrMoreSteps spec spec',
                        ZeroOrMoreSteps specD specD'
                       ) =>
                       spec
                    -> Proc spec spec' outgoing outgoing' incoming incoming'
                    -> Proc specD specD' incoming incoming' outgoing outgoing'
                    -> IO spec'
runSessionWithProcs spec (Proc p1) (Proc p2) = do
    (s1, s2) <- mkState spec
    _ <- forkIO (fmap fst (p1 s1))
    _ <- forkIO (fmap fst (p2 s2))
    return (stepN spec)

-- Like 'send', but first prints the value (via 'show') to stdout.
sendLog :: (NextOp (SessionSpec (SendT t s)) (SessionSpec s),
            ZeroOrMoreSteps (SessionSpec (SendT t s)) (SessionSpec s),
            JustSendsRecvs (SessionSpec (SendT t s)) (SessionSpec (SessT t o)) (SessionSpec i),
            JustSendsRecvs (SessionSpec s) (SessionSpec o) (SessionSpec i),
            Show t) =>
           t -> SessionState (SessionSpec (SendT t s)) (SessT t o) i
             -> IO ((), SessionState (SessionSpec s) o i)
sendLog v =
    sliftIO (putStrLn ("Sending '" ++ show v ++ "'")) >~> send v

-- Like 'recv', but prints the received value (via 'show') to stdout before
-- returning it.
recvLog :: (NextOp (SessionSpec (RecvT t s)) (SessionSpec s),
            ZeroOrMoreSteps (SessionSpec (RecvT t s)) (SessionSpec s),
            JustSendsRecvs (SessionSpec s) (SessionSpec o) (SessionSpec i),
            JustSendsRecvs (SessionSpec (RecvT t s)) (SessionSpec o) (SessionSpec (SessT t i)),
            ZeroOrMoreSteps (SessionSpec s) (SessionSpec s),
            Show t) =>
           SessionState (SessionSpec (RecvT t s)) o (SessT t i)
               -> IO (t, SessionState (SessionSpec s) o i)
recvLog = recv >~>= \r ->
              sliftIO (putStrLn ("Received '" ++ show r ++ "'")) >~> returnS r

-- Both sequencing operators are left-associative at precedence 1, the same
-- fixity the Prelude gives to (>>) and (>>=).
infixl 1 >~>
infixl 1 >~>=

-- Session counterpart of (>>) :: (Monad m) => m b -> m c -> m c.
-- Roughly:
--   Session-from-a-to-b-with-result-n ->
--   Session-from-b-to-c-with-result-m ->
--   Session-from-a-to-c-with-result-m
-- The first step is run, its result is dropped, and the second step is run
-- on the state the first one produced.
(>~>) :: (JustSendsRecvs (SessionSpec a) (SessionSpec o) (SessionSpec i),
          JustSendsRecvs (SessionSpec b) (SessionSpec o') (SessionSpec i'),
          JustSendsRecvs (SessionSpec c) (SessionSpec o'') (SessionSpec i''),
          ZeroOrMoreSteps (SessionSpec a) (SessionSpec b),
          ZeroOrMoreSteps (SessionSpec b) (SessionSpec c),
          ZeroOrMoreSteps (SessionSpec a) (SessionSpec c)) =>
         (SessionState (SessionSpec a) o i -> IO (r, SessionState (SessionSpec b) o' i'))
      -> (SessionState (SessionSpec b) o' i' -> IO (r', SessionState (SessionSpec c) o'' i''))
      -> (SessionState (SessionSpec a) o i -> IO (r', SessionState (SessionSpec c) o'' i''))
first >~> second = \st -> do
    (_, st') <- first st
    second st'

-- Session counterpart of (>>=) :: (Monad m) => m b -> (b -> m c) -> m c.
-- Roughly:
--   Session-from-a-to-b-with-result-n ->
--   (n -> Session-from-b-to-c-with-result-m) ->
--   Session-from-a-to-c-with-result-m
-- The first step is run, and both its result and the state it produced are
-- handed to the second step.
(>~>=) :: (JustSendsRecvs (SessionSpec a) (SessionSpec o) (SessionSpec i),
           JustSendsRecvs (SessionSpec b) (SessionSpec o') (SessionSpec i'),
           JustSendsRecvs (SessionSpec c) (SessionSpec o'') (SessionSpec i''),
           ZeroOrMoreSteps (SessionSpec a) (SessionSpec b),
           ZeroOrMoreSteps (SessionSpec b) (SessionSpec c),
           ZeroOrMoreSteps (SessionSpec a) (SessionSpec c)) =>
          (SessionState (SessionSpec a) o i -> IO (r, SessionState (SessionSpec b) o' i'))
       -> (r -> SessionState (SessionSpec b) o' i' -> IO (r', SessionState (SessionSpec c) o'' i''))
       -> (SessionState (SessionSpec a) o i -> IO (r', SessionState (SessionSpec c) o'' i''))
first >~>= next = \st -> do
    (result, st') <- first st
    next result st'

-- Session counterpart of return :: (Monad m) => a -> m a.
-- returnS is pretty much the same as return: it lifts a value into a
-- session step that leaves the state untouched.
-- returnS :: r -> Session-from-a-to-a-with-result-r
returnS :: (JustSendsRecvs (SessionSpec a) (SessionSpec o) (SessionSpec i),
            ZeroOrMoreSteps (SessionSpec a) (SessionSpec a)) =>
           r -> SessionState (SessionSpec a) o i ->
           IO (r, SessionState (SessionSpec a) o i)
returnS v = \st -> return (v, st)

-- Lift an IO action straight into a session step: the action is run and
-- its result is returned alongside the unchanged state.
-- sliftIO :: IO r -> Session-from-a-to-a-with-result-r
sliftIO :: (JustSendsRecvs (SessionSpec a) (SessionSpec o) (SessionSpec i),
            ZeroOrMoreSteps (SessionSpec a) (SessionSpec a)) =>
           IO r -> (SessionState (SessionSpec a) o i ->
                    IO (r, SessionState (SessionSpec a) o i))
sliftIO action = \st -> do
    result <- action
    returnS result st

-- Variant of 'sliftIO' for an IO function that needs an argument from
-- outside: the argument is applied first and the resulting action is lifted.
-- sliftIO' :: (m -> IO n) -> (m -> Session-from-a-to-a-with-result-n)
sliftIO' :: (JustSendsRecvs (SessionSpec a) (SessionSpec o) (SessionSpec i),
             ZeroOrMoreSteps (SessionSpec a) (SessionSpec a)) =>
            (r -> IO r') -> (r -> SessionState (SessionSpec a) o i ->
                             IO (r', SessionState (SessionSpec a) o i))
sliftIO' f = sliftIO . f


-- Run a loop body repeatedly. Each round enters the loop with 'loop', runs
-- the body on the unrolled state, and starts over with the state the body
-- hands back. The recursion has no exit, so the result type is never
-- actually produced.
mkLoop :: (UnrollLoop (SessionSpec (Loop (SessionSpec s))) (SessionSpec s'),
           UnrollLoop (SessionSpec (Loop (SessionSpec ol))) (SessionSpec o),
           UnrollLoop (SessionSpec (Loop (SessionSpec il))) (SessionSpec i),
           JustSendsRecvs (SessionSpec s') (SessionSpec o) (SessionSpec i),
           ZeroOrMoreSteps (SessionSpec s') (SessionSpec (Loop (SessionSpec s)))) =>
         ((SessionState (SessionSpec s') o i) ->
           IO ((), (SessionState (SessionSpec (Loop (SessionSpec s))) (Loop (SessionSpec ol)) (Loop (SessionSpec il))))) ->
          (SessionState (SessionSpec (Loop (SessionSpec s))) (Loop (SessionSpec ol)) (Loop (SessionSpec il))) ->
          IO ((), SessionState (SessionSpec EndT) EndT EndT)
mkLoop body st = do
    ((), unrolled) <- loop st
    ((), next) <- body unrolled
    mkLoop body next
