{-# OPTIONS_GHC -fglasgow-exts -fallow-overlapping-instances -fallow-undecidable-instances #-}

module Session where

import Control.Concurrent.MVar
import Control.Concurrent

-- | Value-level descriptors for the payload type of a session message.
--
-- Each constructor is indexed by the Haskell type it describes, so an
-- @SType t@ stands in for @t@ inside a 'SessionSpec' step.
data SType :: * -> * where
              IntS    :: SType Int
              BoolS   :: SType Bool
              CharS   :: SType Char
              StringS :: SType String
              FloatS  :: SType Float
              DoubleS :: SType Double
              -- Descriptor usable at any payload type.
              AnyS    :: SType a
              -- Carries a value of the payload type (how the value is used
              -- is not visible in this module; 'show' ignores it).
              LiftS   :: a -> SType a
              -- Descriptor for a list whose elements are described by the argument.
              ListS   :: (SType a) -> SType [a]

-- Render a payload-type descriptor by its constructor name.  List
-- descriptors also render their element descriptor; 'LiftS' ignores the
-- value it carries.
instance Show (SType a) where
    show descriptor =
        case descriptor of
          IntS        -> "IntS"
          BoolS       -> "BoolS"
          CharS       -> "CharS"
          StringS     -> "StringS"
          FloatS      -> "FloatS"
          DoubleS     -> "DoubleS"
          AnyS        -> "AnyS"
          LiftS _     -> "LiftS"
          ListS inner -> "ListS of " ++ show inner

-- | A loop, represented as a list of copies of its body.  'mkLoopS' builds
-- it with 'repeat'; the consumers in this module ('show', 'unroll') only
-- ever inspect its first element.
type Loop a = [a]

-- | Build a looping session specification from a body builder.
--
-- The builder is applied once to 'LoopEndS', which marks the point where
-- 'unroll' later splices the loop back in.  That single body is then
-- repeated without end.
mkLoopS :: (SessionSpec LoopEndT -> SessionSpec l) -> SessionSpec (Loop (SessionSpec l))
mkLoopS mkBody = LoopS . repeat $ mkBody LoopEndS

-- Type-level tags used as indices of 'SessionSpec' and 'Cell'.  None of
-- them has constructors; they only ever appear in types.

-- | Send a value of type @a@, then continue as @n@ (index of 'SendS').
data SendT a n
-- | Receive a value of type @a@, then continue as @n@ (index of 'RecvS').
data RecvT a n
-- | The session is finished (index of 'EndS').
data EndT
-- | End of a loop body (index of 'LoopEndS'); 'unroll' substitutes the
-- enclosing loop for it via 'replaceLoopEnd'.
data LoopEndT
-- | One step of a single-direction chain: a value of type @a@ followed by
-- @n@.  Produced by 'JustSendsRecvs' and used to index 'Cell'; no
-- 'SessionSpec' constructor is indexed by it.
data SessT a n

-- | A session specification, indexed by its type-level protocol built from
-- the tags 'SendT', 'RecvT', 'EndT', 'LoopEndT' and 'Loop'.
data SessionSpec :: * -> * where
                    -- Receive a value described by the 'SType', then continue.
                    RecvS :: (SType t) -> (SessionSpec n) -> SessionSpec (RecvT t n)
                    -- Send a value described by the 'SType', then continue.
                    SendS :: (SType t) -> (SessionSpec n) -> SessionSpec (SendT t n)
                    -- Terminates the session.
                    EndS :: SessionSpec EndT
                    -- End of a loop body; replaced by the whole loop when unrolled.
                    LoopEndS :: SessionSpec LoopEndT
                    -- A loop, holding a list of copies of its body ('mkLoopS'
                    -- builds an infinite one).
                    LoopS :: (Loop (SessionSpec l)) -> SessionSpec (Loop (SessionSpec l))

-- Render a specification as a chain of steps separated by " . ".  A loop
-- is rendered from the first copy of its body only, because the body list
-- built by 'mkLoopS' is infinite.  An empty body list, which is well-typed,
-- renders as "LoopS {}" instead of crashing on an incomplete pattern.
instance Show (SessionSpec n) where
    show (RecvS t next) = "RecvS " ++ show t ++ " . " ++ show next
    show (SendS t next) = "SendS " ++ show t ++ " . " ++ show next
    show EndS           = "EndS"
    show LoopEndS       = "LoopEndS"
    show (LoopS l)      = case l of
                            (body:_) -> "LoopS {" ++ show body ++ "}"
                            []       -> "LoopS {}"


-- | Substitute @replacement@ for the 'LoopEndS' reached by following the
-- send/receive chain of @orig@.  The substitution does not descend into
-- nested loops.  @result@ is the type after substitution and is determined
-- by the other two parameters.
class ReplaceLoopEnd orig replacement result | orig replacement -> result where
    replaceLoopEnd :: orig -> replacement -> result
-- 'EndS' contains no loop end, so it is returned unchanged.
instance ReplaceLoopEnd (SessionSpec EndT) (SessionSpec n) (SessionSpec EndT) where
    replaceLoopEnd EndS _ = EndS
-- A loop end is swapped for the replacement.
instance ReplaceLoopEnd (SessionSpec LoopEndT) (SessionSpec n) (SessionSpec n) where
    replaceLoopEnd LoopEndS r = r
-- A nested loop is returned unchanged; the substitution does not enter it.
instance ReplaceLoopEnd (SessionSpec (Loop (SessionSpec l))) (SessionSpec replacement) (SessionSpec (Loop (SessionSpec l))) where
    replaceLoopEnd (LoopS l) _ = LoopS l
-- A send step is kept and the substitution continues in its tail.
instance (ReplaceLoopEnd (SessionSpec orig) (SessionSpec replacement) (SessionSpec result)) =>
    ReplaceLoopEnd (SessionSpec (SendT t orig)) (SessionSpec replacement) (SessionSpec (SendT t result)) where
        replaceLoopEnd (SendS t n) r = SendS t (replaceLoopEnd n r)
-- A receive step is kept and the substitution continues in its tail.
instance (ReplaceLoopEnd (SessionSpec orig) (SessionSpec replacement) (SessionSpec result)) =>
    ReplaceLoopEnd (SessionSpec (RecvT t orig)) (SessionSpec replacement) (SessionSpec (RecvT t result)) where
        replaceLoopEnd (RecvS t n) r = RecvS t (replaceLoopEnd n r)
instance (ReplaceLoopEnd (SessionSpec orig) (SessionSpec replacement) (SessionSpec result)) =>
    ReplaceLoopEnd (SessionSpec (SessT t orig)) (SessionSpec replacement) (SessionSpec (SessT t result)) where
        -- Unreachable: no 'SessionSpec' constructor is indexed by 'SessT',
        -- so the only possible argument is bottom.
        replaceLoopEnd = error "Session.replaceLoopEnd: SessionSpec (SessT t n) has no constructors"

-- | Expose the first step of a specification.  A loop is replaced by one
-- copy of its body in which 'LoopEndS' is substituted by the whole loop
-- (via 'replaceLoopEnd').  Every other specification is returned unchanged.
class UnrollLoop orig result | orig -> result where
    unroll :: orig -> result
-- Only the first element of the loop's body list is used.  An empty list
-- has no body to unroll, so fail with a descriptive message rather than an
-- anonymous pattern-match failure.
instance (ReplaceLoopEnd (SessionSpec l) (SessionSpec (Loop (SessionSpec l))) (SessionSpec r)) =>
    UnrollLoop (SessionSpec (Loop (SessionSpec l))) (SessionSpec r) where
        unroll (LoopS l) = case l of
                            (body:_) -> replaceLoopEnd body (LoopS l)
                            []       -> error "Session.unroll: LoopS has an empty body list"
-- Non-loop specifications are already unrolled.
instance UnrollLoop (SessionSpec (EndT)) (SessionSpec (EndT)) where
    unroll = id
instance UnrollLoop (SessionSpec (LoopEndT)) (SessionSpec (LoopEndT)) where
    unroll = id
instance UnrollLoop (SessionSpec (SendT t n)) (SessionSpec (SendT t n)) where
    unroll = id
instance UnrollLoop (SessionSpec (RecvT t n)) (SessionSpec (RecvT t n)) where
    unroll = id
instance UnrollLoop (SessionSpec (SessT t n)) (SessionSpec (SessT t n)) where
    -- Unreachable: no 'SessionSpec' constructor is indexed by 'SessT'.
    unroll = error "Session.unroll: SessionSpec (SessT t n) has no constructors"

-- | @ZeroOrMoreSteps a b@ relates a specification @a@ to a specification @b@
-- reachable from it by dropping zero or more leading send/receive steps,
-- unrolling loops along the way.  'stepN' performs that walk on values.
-- The instances overlap and rely on the overlapping/undecidable-instance
-- flags enabled at the top of this file.
class ZeroOrMoreSteps a b where
    stepN :: a -> b
-- Zero steps: each kind of specification reaches itself.
instance ZeroOrMoreSteps (SessionSpec EndT) (SessionSpec EndT) where
    stepN x = x
instance ZeroOrMoreSteps (SessionSpec LoopEndT) (SessionSpec LoopEndT) where
    stepN x = x
instance ZeroOrMoreSteps (SessionSpec (RecvT t l)) (SessionSpec (RecvT t l)) where
    stepN x = x
instance ZeroOrMoreSteps (SessionSpec (SendT t l)) (SessionSpec (SendT t l)) where
    stepN x = x
instance ZeroOrMoreSteps (SessionSpec (Loop (SessionSpec a))) (SessionSpec (Loop (SessionSpec a))) where
    stepN x = x
-- Exactly one step: drop a single receive or send.
instance ZeroOrMoreSteps (SessionSpec (RecvT t l)) (SessionSpec l) where
    stepN (RecvS _ n) = n
instance ZeroOrMoreSteps (SessionSpec (SendT t l)) (SessionSpec l) where
    stepN (SendS _ n) = n
instance ZeroOrMoreSteps (SessionSpec (SessT t l)) (SessionSpec l) where
    -- Unreachable: no 'SessionSpec' constructor is indexed by 'SessT'.
    stepN = error "Session.stepN: SessionSpec (SessT t l) has no constructors"
-- One step followed by any number of further steps.
instance (ZeroOrMoreSteps (SessionSpec a) (SessionSpec b)) =>
    ZeroOrMoreSteps (SessionSpec (RecvT t a)) (SessionSpec b) where
        stepN (RecvS _ n) = stepN n
instance (ZeroOrMoreSteps (SessionSpec a) (SessionSpec b)) =>
    ZeroOrMoreSteps (SessionSpec (SendT t a)) (SessionSpec b) where
        stepN (SendS _ n) = stepN n
instance (ZeroOrMoreSteps (SessionSpec a) (SessionSpec b)) =>
    ZeroOrMoreSteps (SessionSpec (SessT t a)) (SessionSpec b) where
        -- Unreachable: no 'SessionSpec' constructor is indexed by 'SessT'.
        stepN = error "Session.stepN: SessionSpec (SessT t a) has no constructors"
-- A loop is unrolled once, then the walk continues in the unrolled body.
instance (UnrollLoop (SessionSpec (Loop (SessionSpec a))) (SessionSpec b),
          ZeroOrMoreSteps (SessionSpec b) (SessionSpec c)) =>
    ZeroOrMoreSteps (SessionSpec (Loop (SessionSpec a))) (SessionSpec c) where
        stepN = stepN . unroll

-- | Type-level relation splitting a specification @orig@ into its outgoing
-- half @sends@ and incoming half @recvs@, each a chain of 'SessT' steps.
-- The class has no methods; it is used only as a constraint.
class JustSendsRecvs orig sends recvs | orig -> sends recvs where
-- End markers split into themselves on both sides.
instance JustSendsRecvs (SessionSpec EndT) (SessionSpec EndT) (SessionSpec EndT) where
instance JustSendsRecvs (SessionSpec LoopEndT) (SessionSpec LoopEndT) (SessionSpec LoopEndT) where
-- A send adds a 'SessT' step to the outgoing side only.
instance (JustSendsRecvs (SessionSpec n) (SessionSpec s) (SessionSpec r)) =>
    JustSendsRecvs (SessionSpec (SendT t n)) (SessionSpec (SessT t s)) (SessionSpec r) where
-- A receive adds a 'SessT' step to the incoming side only.
instance (JustSendsRecvs (SessionSpec n) (SessionSpec s) (SessionSpec r)) =>
    JustSendsRecvs (SessionSpec (RecvT t n)) (SessionSpec s) (SessionSpec (SessT t r)) where
-- A loop splits into a loop of the body's halves on each side.
instance (JustSendsRecvs (SessionSpec l) (SessionSpec s) (SessionSpec r)) =>
    JustSendsRecvs (SessionSpec (Loop (SessionSpec l))) (SessionSpec (Loop (SessionSpec s))) (SessionSpec (Loop (SessionSpec r))) where
-- A 'SessionState' already carries its outgoing and incoming indices.
instance JustSendsRecvs (SessionState s o i) (SessionSpec o) (SessionSpec i) where

-- | One link of a chain of 'MVar'-connected cells, indexed by the
-- type-level ('SessT') sequence of values the chain carries.
data Cell :: * -> * where
             -- A value of type @t@ plus the MVar holding the next cell.
             Cell :: t -> MVar (Cell ct) -> Cell (SessT t ct)
             -- A redirection to another cell MVar whose index is the unrolled
             -- form of @l@; 'loop' writes these when entering a loop.
             Branch :: (UnrollLoop (SessionSpec l) (SessionSpec u)) => MVar (Cell u) -> Cell l

-- | Runtime state for a session with specification @spec@: the 'MVar's
-- holding the current outgoing and incoming cells.  The 'JustSendsRecvs'
-- constraint ties @outgoing@ and @incoming@ to @spec@.
data SessionState :: * -> * -> * -> * where
                     SessionState :: (JustSendsRecvs spec (SessionSpec outgoing) (SessionSpec incoming)) =>
                            MVar (Cell outgoing) -> MVar (Cell incoming) ->
                            SessionState spec outgoing incoming

-- | Enter a loop.  Writes a 'Branch' cell into both the outgoing and the
-- incoming MVar, each pointing at a fresh empty MVar.  It then returns a
-- state built from those fresh MVars, whose specification and indices are
-- the unrolled loop.
--
-- NOTE(review): the 'Bool' results of 'tryPutMVar' ('didPutO', 'didPutI')
-- are ignored.  If either MVar is already full, its 'Branch' is silently
-- not written.  The author's TODO below marks this implementation as not
-- yet working.
loop :: (UnrollLoop (SessionSpec (Loop (SessionSpec s))) (SessionSpec s'),
         UnrollLoop (SessionSpec (Loop (SessionSpec ol))) (SessionSpec o),
         UnrollLoop (SessionSpec (Loop (SessionSpec il))) (SessionSpec i),
         JustSendsRecvs (SessionSpec s') (SessionSpec o) (SessionSpec i)) =>
        SessionState (SessionSpec (Loop (SessionSpec s))) (Loop (SessionSpec ol)) (Loop (SessionSpec il)) ->
        IO ((), SessionState (SessionSpec s') o i)
-- TODO: make the code below do something that will work!
loop (SessionState o i) = do { newEmptyO <- newEmptyMVar
                             ; didPutO <- tryPutMVar o (Branch newEmptyO)
                             ; newEmptyI <- newEmptyMVar
                             ; didPutI <- tryPutMVar i (Branch newEmptyI)
                             ; return ((), SessionState newEmptyO newEmptyI)
                             }

-- | A session-typed process.  It wraps an action that turns a state with
-- specification @s@ into one with specification @s'@.  @s'@ must be
-- reachable from @s@ ('ZeroOrMoreSteps'), and each state's outgoing and
-- incoming indices must match its specification ('JustSendsRecvs').
data Proc :: * -> * -> * -> * -> * -> * -> * where
             Proc :: (ZeroOrMoreSteps s s',
                      JustSendsRecvs s (SessionSpec o) (SessionSpec i),
                      JustSendsRecvs s' (SessionSpec o') (SessionSpec i')) =>
                     (SessionState s o i -> IO ((), SessionState s' o' i')) -> Proc s s' o o' i i'

-- | Session analogue of 'return' (@return :: Monad m => a -> m a@): yield
-- @r@ and hand back the session state untouched.  The result is a session
-- that goes from specification @a@ to @a@ with result type @r@.
returnS :: (JustSendsRecvs (SessionSpec a) (SessionSpec o) (SessionSpec i),
            ZeroOrMoreSteps (SessionSpec a) (SessionSpec a)) =>
           r -> (SessionState (SessionSpec a) o i) ->
           IO (r, (SessionState (SessionSpec a) o i))
returnS = curry return

-- | Repeatedly run a loop body.  Each iteration enters the loop with
-- 'loop', runs @f@ on the unrolled state, and recurses on the state that
-- @f@ returns.
--
-- NOTE(review): the recursive call is unconditional, so the resulting
-- action never returns normally.  The 'EndT' state in the result type is
-- never actually produced.
mkLoop :: (UnrollLoop (SessionSpec (Loop (SessionSpec s))) (SessionSpec s'),
           UnrollLoop (SessionSpec (Loop (SessionSpec ol))) (SessionSpec o),
           UnrollLoop (SessionSpec (Loop (SessionSpec il))) (SessionSpec i),
           JustSendsRecvs (SessionSpec s') (SessionSpec o) (SessionSpec i),
           ZeroOrMoreSteps (SessionSpec s') (SessionSpec (Loop (SessionSpec s)))) =>
         ((SessionState (SessionSpec s') o i) ->
           IO ((), (SessionState (SessionSpec (Loop (SessionSpec s))) (Loop (SessionSpec ol)) (Loop (SessionSpec il))))) ->
          (SessionState (SessionSpec (Loop (SessionSpec s))) (Loop (SessionSpec ol)) (Loop (SessionSpec il))) ->
          IO ((), SessionState (SessionSpec EndT) EndT EndT)
mkLoop f s = loop s >>= \((), s') -> f s' >>= \((), s'') -> mkLoop f s''
