{-# LANGUAGE DataKinds, TypeOperators, GADTs, MultiParamTypeClasses, LambdaCase
           , FlexibleInstances, UndecidableInstances, DeriveFunctor, StandaloneDeriving
           , TypeFamilies #-}
-- | This module exposes the base @Signal@ and @SignalT@ types, but you probably want to use the
--   "Control.Monad.Signal.Class" module because it also re-exports these types.
module Control.Monad.Signal (Signal(..), signal, liftSignal, SignalT(..)) where

import Data.Union
import Control.Monad.Fix
import Control.Monad.Trans.Class
import Control.Monad.IO.Class
import Control.Monad.State.Class
import Control.Monad.Writer.Class
import Control.Monad.Reader.Class
import Control.Monad.Cont.Class
import GHC.Base

-- | A @Signal sigs a@ is either a @Union sigs@ or a value of type @a@.
--   It's basically the same as @Union (a ': sigs)@ but the @a@ parameter is special because
--   it's used in the Functor\/Applicative\/Monad instances.
data Signal sigs a
    = Value a              -- ^ An ordinary result; this is what fmap/<*>/>>= act on.
    | Signal (Union sigs)  -- ^ One of the signals listed in @sigs@.

-- Structural equality, ordering and parsing are derived; each needs the
-- matching instance for both the union of signals and the payload type.
-- NOTE(review): the hand-written Show instance renders a 'Signal' through
-- 'showSignal', which prints only the payload, so @read . show@ is probably
-- not a round trip for signals — confirm against Data.Union's Read instance.
deriving instance (Eq (Union sigs), Eq a) => Eq (Signal sigs a)
deriving instance (Ord (Union sigs), Ord a) => Ord (Signal sigs a)
deriving instance (Read (Union sigs), Read a) => Read (Signal sigs a)
-- | Renders @Value x@ as @Value (x)@ and a signal as @Signal (payload)@.
--   The payload is printed by 'showSignal', so only the active alternative
--   of the union is shown, not its position.
--
--   This is implemented via 'showsPrec' so that the whole expression is
--   parenthesised when it appears as a constructor argument (e.g. inside
--   @Just@). Defining only 'show' would fall back to the default 'showsPrec',
--   which ignores precedence and would print @Just Value (1)@.
instance (ShowSignal sigs, Show a) => Show (Signal sigs a) where
    showsPrec d (Value a) = showParen (d > 10) $
        showString "Value (" . shows a . showChar ')'
    showsPrec d (Signal u) = showParen (d > 10) $
        showString "Signal (" . showString (showSignal u) . showChar ')'

-- | Pretty-printing for a 'Union': only the payload of the alternative that
--   is actually present is rendered, not which alternative it is.
class ShowSignal sigs where
    showSignal
        :: Union sigs  -- ^ union to render
        -> String      -- ^ rendering of the payload it holds

-- | Peels the union one layer at a time. A 'Right' holds a payload of the
--   head type and is shown directly; a 'Left' is delegated to the tail.
instance (Show a, ShowSignal as) => ShowSignal (a ': as) where
    showSignal (Union layer) = either showSignal show layer

-- | Base case for the empty list of signals. @Union '[]@ presumably has no
--   inhabitants, so this should never be reached. A placeholder string is
--   returned rather than crashing.
instance ShowSignal '[] where
    showSignal = const "absurd"

-- | Raise a single signal. The argument type @a@ must satisfy
--   @Elem sigs a@ (presumably: be one of @sigs@). No value is carried, so
--   the result's value type @b@ is unconstrained.
signal :: Elem sigs a => a -> Signal sigs b
signal = Signal . liftSingle

-- | Widen the set of possible signals. A plain value passes through
--   untouched; a signal is re-injected into the larger union via 'liftUnion'.
liftSignal :: Subset as bs => Signal as a -> Signal bs a
liftSignal = \case
    Value a  -> Value a
    Signal u -> Signal (liftUnion u)

-- | Maps over the value; a signal is passed through unchanged.
instance Functor (Signal sigs) where
    fmap f = \case
        Value a  -> Value (f a)
        Signal u -> Signal u

-- | Applies the function only when both sides are values. A signal on the
--   left wins without inspecting the right; otherwise a signal on the right
--   is propagated.
instance Applicative (Signal sigs) where
    pure = Value
    sf <*> sx = case sf of
        Signal u -> Signal u
        Value f  -> fmap f sx

-- | Binding short-circuits on a signal; a value is fed to the continuation.
instance Monad (Signal sigs) where
    s >>= k = case s of
        Value a  -> k a
        Signal u -> Signal u

-- | A 'Signal' holds exactly one element (a value) or none (a signal).
instance Foldable (Signal sigs) where
    foldMap f = \case
        Value a  -> f a
        Signal _ -> mempty

-- | Runs the effect on a value and rewraps its result. A signal is returned
--   under 'pure' without running any effect.
instance Traversable (Signal sigs) where
    traverse f = \case
        Value a  -> Value <$> f a
        Signal u -> pure (Signal u)

-- | A transformer version of 'Signal': a computation in @m@ whose result is
--   either a value or a signal.
newtype SignalT sigs m a = SignalT
    { runSignalT :: m (Signal sigs a)  -- ^ unwrap to the underlying computation
    } deriving Functor

-- Instances derived through the wrapped computation @m (Signal sigs a)@.
deriving instance Eq   (m (Signal sigs a)) => Eq   (SignalT sigs m a)
deriving instance Ord  (m (Signal sigs a)) => Ord  (SignalT sigs m a)
deriving instance Show (m (Signal sigs a)) => Show (SignalT sigs m a)
deriving instance Read (m (Signal sigs a)) => Read (SignalT sigs m a)

-- | Sequential application that stops at the first signal.
--
--   This needs @Monad m@ (not just @Applicative m@). To skip the right-hand
--   computation's effects once the left one has produced a 'Signal', its
--   result must be inspected before the right side is run. The previous
--   @Applicative m@ version always ran both effects. That disagreed with
--   '>>=' (breaking the law @('<*>') = 'ap'@) and made '*>', 'traverse_',
--   'replicateM_', etc. keep running effects after a signal. This matches the
--   instances of @ExceptT@ and @MaybeT@ in transformers.
instance Monad m => Applicative (SignalT sigs m) where
    pure x = SignalT (pure (Value x))
    SignalT mcf <*> SignalT mcx = SignalT $ mcf >>= \case
        Signal u -> return (Signal u)          -- short-circuit: mcx is never run
        Value f  -> fmap (fmap f) mcx

-- | Runs the first computation and continues only if it produced a value.
--   A signal is re-returned and the continuation is never run.
instance Monad m => Monad (SignalT sigs m) where
    m >>= k = SignalT $ do
        result <- runSignalT m
        case result of
            Value a  -> runSignalT (k a)
            Signal u -> return (Signal u)

-- | Lifts a computation by tagging its result as a 'Value'.
instance MonadTrans (SignalT sigs) where
    lift m = SignalT (Value <$> m)

-- | Fixed points computed through the underlying monad's 'mfix'.
--
--   The value passed to @f@ is extracted lazily from the fixpoint, so @f@ may
--   refer to its own result without forcing it. The previous version
--   pattern-matched on the fixpoint strictly, which made every 'mfix' diverge.
--   If the knot is actually demanded while the computation produced a
--   'Signal', there is no value to return, and evaluation fails with an error.
--   This mirrors 'mfix' for @ExceptT@ in transformers.
instance MonadFix m => MonadFix (SignalT sigs m) where
    mfix f = SignalT (mfix (runSignalT . f . fromValue))
        where fromValue (Value a)  = a
              fromValue (Signal _) = error "mfix (SignalT): inner computation returned a Signal"

-- | Folds over the values produced inside @m@, skipping signals.
instance Foldable m => Foldable (SignalT sigs m) where
    foldMap f = foldMap (foldMap f) . runSignalT

-- | Traverses the values inside @m@; signals are kept in place.
instance Traversable m => Traversable (SignalT sigs m) where
    traverse f = fmap SignalT . traverse (traverse f) . runSignalT

-- | IO actions are lifted through the underlying monad and tagged as values.
instance MonadIO m => MonadIO (SignalT sigs m) where
    liftIO io = lift (liftIO io)

-- | State operations are delegated to @m@ and never raise a signal.
instance MonadState s m => MonadState s (SignalT sigs m) where
    state f = lift (state f)

-- | Writer operations are delegated to @m@.
--
--   * 'listen' pairs the output with the result only when it is a value.
--   * 'pass' applies the output transformer only when the computation
--     produced a value; for a signal the output is left unchanged (@id@).
instance MonadWriter w m => MonadWriter w (SignalT sigs m) where
    tell w = lift (tell w)
    listen m = SignalT (attach <$> listen (runSignalT m))
        where attach (result, out) = fmap (\a -> (a, out)) result
    pass m = SignalT (pass (split <$> runSignalT m))
        where split (Value (a, g)) = (Value a, g)
              split (Signal u)     = (Signal u, id)

-- | Reader operations are delegated to @m@. 'local' runs the wrapped
--   computation under the modified environment.
instance MonadReader r m => MonadReader r (SignalT sigs m) where
    reader f = lift (reader f)
    local f = SignalT . local f . runSignalT

-- | The escape continuation given to @f@ wraps its argument as a 'Value'
--   before jumping, so escaping never produces a signal by itself.
instance MonadCont m => MonadCont (SignalT sigs m) where
    callCC f = SignalT . callCC $ \k ->
        runSignalT (f (\a -> SignalT (k (Value a))))