-- |
-- Module:     Control.Wire.Types
-- Copyright:  (c) 2011 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
--
-- Types used in the netwire library.

module Control.Wire.Types
    ( -- * The wire
      Wire(..),
      WireM,

      -- * Construction and destruction
      WireGen(..),
      WirePure(..),
      WireToGen(..),
      mkFixM,
      toGenM,

      -- * Inhibition
      LastException,
      inhibitException,
      inhibitMsg,

      -- * Utilities
      mapInputM
    )
    where

import qualified Control.Exception as Ex
import Control.Applicative
import Control.Arrow
import Control.Arrow.Operations
import Control.Arrow.Transformer
import Control.Category
import Control.Monad
import Control.Monad.Fix
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Writer.Class
import Control.Wire.Classes
import Data.Monoid
import Prelude hiding ((.), id)


-- | Convenience type for wire exceptions: the 'Last' monoid over
-- 'Ex.SomeException'.  When two such values are combined with
-- 'mappend', the rightmost @Just@ exception is kept.

type LastException = Last Ex.SomeException


-- | Signal networks.
--
-- The parameters of @Wire e (>~) a b@ are, in order: the inhibition
-- value type @e@, the underlying arrow @(>~)@, the input signal type @a@
-- and the output signal type @b@.

data family Wire :: * -> (* -> * -> *) -> * -> * -> *

-- Wires over the Kleisli arrow of a monad @m@.  Each step maps an input
-- to either an output ('Right') or an inhibition value ('Left'),
-- together with the wire to use at the next instant.
data instance Wire e (Kleisli m) a b where
    -- A step that runs in the underlying monad @m@.
    WmGen  :: (a -> m (Either e b, Wire e (Kleisli m) a b)) -> Wire e (Kleisli m) a b
    -- A step that is a plain function, with no monadic effects.
    WmPure :: (a -> (Either e b, Wire e (Kleisli m) a b)) -> Wire e (Kleisli m) a b


-- | Choice at the functor level: 'empty' is the always-inhibiting
-- 'zeroArrow', and '<|>' is the arrow-level '<+>'.

instance (Monad m, Monoid e) => Alternative (Wire e (Kleisli m) a) where
    empty = zeroArrow

    wl <|> wr = wl <+> wr


-- | Applicative interface.  Applying a function wire to a value wire
-- feeds the same input to both.  If the function wire inhibits, its
-- inhibition value becomes the result and the value wire is neither
-- stepped nor advanced.  Otherwise the function is mapped over the value
-- wire's result, so an inhibition of the value wire passes through.
-- The combined wire is pure ('WmPure') only when both operands are pure.

instance Monad m => Applicative (Wire e (Kleisli m) a) where
    -- A stateless wire that always produces the given value.
    pure = mkPureFix . const . Right

    -- Both pure: the step stays pure.  The value wire's step is bound
    -- lazily and only demanded when the function wire produces.
    WmPure ff <*> wx'@(WmPure fx) =
        WmPure $ \x' ->
            case ff x' of
              (Left ex, wf) -> (Left ex, wf <*> wx')
              (Right f, wf) ->
                  let (mx, wx) = fx x'
                  in (fmap f mx, wf <*> wx)

    -- Pure function wire, monadic value wire: the monadic step only
    -- runs when the function wire produces.
    WmPure ff <*> wx'@(WmGen fx) =
        WmGen $ \x' ->
            case ff x' of
              (Left ex, wf) -> return (Left ex, wf <*> wx')
              (Right f, wf) -> liftM (fmap f *** (wf <*>)) (fx x')

    -- Monadic function wire, pure value wire.
    WmGen ff <*> wx'@(WmPure fx) =
        WmGen $ \x' -> do
            (mf, wf) <- ff x'
            return $
                case mf of
                  Left ex -> (Left ex, wf <*> wx')
                  Right f ->
                      let (mx, wx) = fx x'
                      in (fmap f mx, wf <*> wx)

    -- Both monadic: the function wire's effects run first, and the value
    -- wire's effects run only if the function wire produced.
    WmGen ff <*> wx'@(WmGen fx) =
        WmGen $ \x' -> do
            (mf, wf) <- ff x'
            case mf of
              Left ex -> return (Left ex, wf <*> wx')
              Right f -> liftM (fmap f *** (wf <*>)) (fx x')


-- | Wire side channels (the 'Arrow' interface).
--
-- * 'arr' lifts a function into a stateless wire that never inhibits.
--
-- * 'first' and 'second' step the wrapped wire on one component of the
--   input pair and pass the other component through.
--
-- * In '&&&' and '***' the left wire runs before the right one.  If the
--   left wire inhibits, its inhibition value is the result and the right
--   wire is kept unchanged for the next instant.  If only the right wire
--   inhibits, its inhibition value is the result.
--
-- A combination is pure ('WmPure') only when both operands are pure.

instance Monad m => Arrow (Wire e (Kleisli m)) where
    arr f = mkPureFix $ Right . f

    first (WmGen c) =
        WmGen $ \(x', y) -> do
            (mx, w) <- c x'
            return (fmap (, y) mx, first w)
    first (WmPure f) =
        WmPure $ \(x', y) ->
            let (mx, w) = f x'
            in (fmap (, y) mx, first w)

    second (WmGen c) =
        WmGen $ \(x, y') -> do
            (my, w) <- c y'
            return (fmap (x,) my, second w)
    second (WmPure f) =
        WmPure $ \(x, y') ->
            let (my, w) = f y'
            in (fmap (x,) my, second w)

    -- (&&&) combinator.
    -- Both operands receive the same input.
    WmGen c1 &&& w2'@(WmGen c2) =
        WmGen $ \x' -> do
            (mx1, w1) <- c1 x'
            case mx1 of
              Left ex  -> return (Left ex, w1 &&& w2')
              Right x1 -> do
                  (mx2, w2) <- c2 x'
                  return (fmap (x1,) mx2, w1 &&& w2)

    WmGen c1 &&& w2'@(WmPure g) =
        WmGen $ \x' -> do
            (mx1, w1) <- c1 x'
            case mx1 of
              Left ex  -> return (Left ex, w1 &&& w2')
              Right x1 ->
                  let (mx2, w2) = g x' in
                  return (fmap (x1,) mx2, w1 &&& w2)

    WmPure f &&& w2'@(WmGen c2) =
        WmGen $ \x' ->
            let (mx1, w1) = f x' in
            case mx1 of
              Left ex  -> return (Left ex, w1 &&& w2')
              Right x1 -> do
                  (mx2, w2) <- c2 x'
                  return (fmap (x1,) mx2, w1 &&& w2)

    -- Both pure: the right step is bound lazily and only demanded when
    -- the left wire produces.
    WmPure f &&& w2'@(WmPure g) =
        WmPure $ \x' ->
            let (mx1, w1) = f x'
                (mx2, w2) = g x' in
            case mx1 of
              Left ex  -> (Left ex, w1 &&& w2')
              Right x1 -> (fmap (x1,) mx2, w1 &&& w2)

    -- (***) combinator.
    -- The left wire receives the first input component, the right wire
    -- the second.
    WmGen c1 *** w2'@(WmGen c2) =
        WmGen $ \(x', y') -> do
            (mx, w1) <- c1 x'
            case mx of
              Left ex -> return (Left ex, w1 *** w2')
              Right x -> do
                  (my, w2) <- c2 y'
                  return (fmap (x,) my, w1 *** w2)

    -- The pure right step is applied through a view pattern, so its
    -- result pair is forced when the step is entered, even if the left
    -- wire then inhibits (its new state is discarded in that case).
    WmGen c1 *** w2'@(WmPure g) =
        WmGen $ \(x', g -> (my, w2)) -> do
            (mx, w1) <- c1 x'
            return $
                case mx of
                  Left ex -> (Left ex, w1 *** w2')
                  Right x -> (fmap (x,) my, w1 *** w2)

    -- The pure left step is forced through a view pattern; the monadic
    -- right step only runs when the left wire produces.
    WmPure f *** w2'@(WmGen c2) =
        WmGen $ \(f -> (mx, w1), y') -> do
            case mx of
              Left ex -> return (Left ex, w1 *** w2')
              Right x -> do
                  (my, w2) <- c2 y'
                  return (fmap (x,) my, w1 *** w2)

    -- Both pure: both step results are forced through view patterns.
    WmPure f *** w2'@(WmPure g) =
        WmPure $ \(f -> (mx, w1), g -> (my, w2)) ->
            case mx of
              Left ex -> (Left ex, w1 *** w2')
              Right x -> (fmap (x,) my, w1 *** w2)


-- | Support for choice (signal redirection).  Only the wire selected by
-- the input's 'Left' / 'Right' tag is stepped; the other wire is kept
-- unchanged for the next instant.  'left' and 'right' pass an
-- unselected input through as a (non-inhibiting) output.  When both
-- wires are pure, '+++' and '|||' produce a pure wire; otherwise both
-- wires are run through 'toGenM' and the result is monadic.

instance Monad m => ArrowChoice (Wire e (Kleisli m)) where
    -- The pair is mapped with the (lazy) function-arrow '***', so the
    -- step result is not forced here.
    left w'@(WmPure f) =
        WmPure $ \mx' ->
            case mx' of
              Left x'  -> fmap Left *** left $ f x'
              Right x' -> (Right (Right x'), left w')

    left w'@(WmGen c) =
        WmGen $ \mx' ->
            case mx' of
              Left x'  -> liftM (fmap Left *** left) (c x')
              Right x' -> return (Right (Right x'), left w')

    right w'@(WmPure f) =
        WmPure $ \mx' ->
            case mx' of
              Right x' -> fmap Right *** right $ f x'
              Left x'  -> (Right (Left x'), right w')

    right w'@(WmGen c) =
        WmGen $ \mx' ->
            case mx' of
              Right x' -> liftM (fmap Right *** right) (c x')
              Left x'  -> return (Right (Left x'), right w')

    -- Pure fast path; the unselected wire is carried over unchanged.
    wl'@(WmPure f) +++ wr'@(WmPure g) =
        WmPure $ \mx' ->
            case mx' of
              Left x'  -> (fmap Left  *** (+++ wr')) . f $ x'
              Right x' -> (fmap Right *** (wl' +++)) . g $ x'

    -- General case: at least one wire is monadic.
    wl' +++ wr' =
        WmGen $ \mx' ->
            case mx' of
              Left x'  -> liftM (fmap Left  *** (+++ wr')) (toGenM wl' x')
              Right x' -> liftM (fmap Right *** (wl' +++)) (toGenM wr' x')

    -- Pure fast path; both branches produce the same output type.
    wl'@(WmPure f) ||| wr'@(WmPure g) =
        WmPure $ \mx' ->
            case mx' of
              Left x'  -> second (||| wr') . f $ x'
              Right x' -> second (wl' |||) . g $ x'

    -- General case: at least one wire is monadic.
    wl' ||| wr' =
        WmGen $ \mx' ->
            case mx' of
              Left x'  -> liftM (second (||| wr')) (toGenM wl' x')
              Right x' -> liftM (second (wl' |||)) (toGenM wr' x')


-- | Support for one-instant delays: @delay x0@ outputs @x0@ at the
-- first instant and afterwards, at each instant, the input received at
-- the previous instant.  It never inhibits.

instance (MonadFix m, Monoid e) => ArrowCircuit (Wire e (Kleisli m)) where
    delay held = WmPure step
      where
        step next = (Right held, delay next)


-- | Inhibition handling interface.  See also the
-- "Control.Wire.Trans.Exhibit" and "Control.Wire.Prefab.Event" modules.
--
-- * 'raise' inhibits with its input.
--
-- * @handle w h@ steps @w@; when @w@ inhibits, the handler @h@ is
--   stepped on the input paired with the inhibition value and its result
--   is used.  @h@ is only advanced on instants where it ran.
--
-- * 'newError' never inhibits: the wrapped wire's result, including any
--   inhibition value, becomes the output.
--
-- * @tryInUnless w s h@ steps @w@; if it produces, @s@ is stepped on the
--   input paired with that output, otherwise @h@ is stepped on the input
--   paired with the inhibition value.
--
-- 'handle' and 'tryInUnless' yield a pure wire only when all argument
-- wires are pure; 'newError' preserves purity.

instance Monad m => ArrowError e (Wire e (Kleisli m)) where
    raise = mkPureFix Left

    -- All-pure fast path.
    handle (WmPure f) wh'@(WmPure fh) =
        WmPure $ \x' ->
            let (mx, w) = f x' in
            case mx of
              Left ex ->
                  let (mxh, wh) = fh (x', ex)
                  in (mxh, handle w wh)
              Right _ -> (mx, handle w wh')

    -- General case: at least one wire is monadic.
    handle w' wh' =
        WmGen $ \x' -> do
            (mx, w) <- toGenM w' x'
            case mx of
              Left ex -> do
                  (mxh, wh) <- toGenM wh' (x', ex)
                  return (mxh, handle w wh)
              Right _ -> return (mx, handle w wh')

    -- The wrapped wire's result is wrapped in 'Right', so it never inhibits.
    newError (WmPure f) = WmPure $ (Right *** newError) . f
    newError (WmGen c) = WmGen $ liftM (Right *** newError) . c

    -- All-pure fast path; only the branch wire that ran is advanced.
    tryInUnless (WmPure f) ws'@(WmPure fs) we'@(WmPure fe) =
        WmPure $ \x' ->
            let (mx, w) = f x' in
            case mx of
              Left ex ->
                  let (mxe, we) = fe (x', ex)
                  in (mxe, tryInUnless w ws' we)
              Right x ->
                  let (mxs, ws) = fs (x', x)
                  in (mxs, tryInUnless w ws we')

    -- General case: at least one wire is monadic.
    tryInUnless w' ws' we' =
        WmGen $ \x' -> do
            (mx, w) <- toGenM w' x'
            case mx of
              Left ex -> do
                  (mxe, we) <- toGenM we' (x', ex)
                  return (mxe, tryInUnless w ws' we)
              Right x -> do
                  (mxs, ws) <- toGenM ws' (x', x)
                  return (mxs, tryInUnless w ws we')


-- | When the target arrow is an 'ArrowKleisli', the wire arrow is one
-- as well: the underlying 'arrM' is lifted into a stateless wire whose
-- result is always wrapped in 'Right', so it never inhibits.

instance Monad m => ArrowKleisli m (Wire e (Kleisli m)) where
    arrM = mkFix (arr Right . arrM)


-- | Value recursion in the wire arrows.
--
-- __NOTE__: Wires with feedback must /never/ inhibit.  When the wrapped
-- wire inhibits there is no second output component to feed back, so
-- any demand on the feedback value raises the runtime error
-- @Loop data dependency broken by inhibition@.
--
-- Pure wires stay pure under 'loop' and tie the knot with a recursive
-- @let@.  Monadic wires tie the knot through 'mfix' in the underlying
-- monad.

instance MonadFix m => ArrowLoop (Wire e (Kleisli m)) where
    loop (WmPure f) =
        WmPure $ \x' ->
            -- Recursive, lazy binding: 'd' is only forced if the wire
            -- demands its own feedback value.
            let (mx, w) = f (x', d)
                d = either (error "Loop data dependency broken by inhibition") snd mx
            in (fmap fst mx, loop w)

    loop w' =
        WmGen $ \x' -> do
            rec (mx, w) <- toGenM w' (x', d)
                let d = either (error "Loop data dependency broken by inhibition") snd mx
            return (fmap fst mx, loop w)


-- | Combining possibly inhibiting wires with '<+>'.  The left wire is
-- stepped first.  If it produces, that output is used and the right wire
-- is not advanced.  If it inhibits, the right wire is stepped on the
-- same input.  If the right wire also inhibits, the two inhibition
-- values are combined with 'mappend', with the left wire's value as the
-- left operand.  The result is pure only when both wires are pure.

instance (Monad m, Monoid e) => ArrowPlus (Wire e (Kleisli m)) where
    WmGen c1 <+> w2'@(WmGen c2) =
        WmGen $ \x' -> do
            (mx1, w1) <- c1 x'
            case mx1 of
              Right _ -> return (mx1, w1 <+> w2')
              Left ex1 -> do
                  (mx2, w2) <- c2 x'
                  return (mapLeft (mappend ex1) mx2, w1 <+> w2)

    WmGen c1 <+> w2'@(WmPure g) =
        WmGen $ \x' -> do
            (mx1, w1) <- c1 x'
            case mx1 of
              Right _ -> return (mx1, w1 <+> w2')
              Left ex1 ->
                  let (mx2, w2) = g x' in
                  return (mapLeft (mappend ex1) mx2, w1 <+> w2)

    -- The monadic right step only runs when the pure left wire inhibits.
    WmPure f <+> w2'@(WmGen c2) =
        WmGen $ \x' ->
            let (mx1, w1) = f x' in
            case mx1 of
              Right _ -> return (mx1, w1 <+> w2')
              Left ex1 -> do
                  (mx2, w2) <- c2 x'
                  return (mapLeft (mappend ex1) mx2, w1 <+> w2)

    -- Both pure: the right step is bound lazily and only demanded on
    -- left inhibition.
    WmPure f <+> w2'@(WmPure g) =
        WmPure $ \x' ->
            let (mx1, w1) = f x'
                (mx2, w2) = g x' in
            case mx1 of
              Right _  -> (mx1, w1 <+> w2')
              Left ex1 -> (mapLeft (mappend ex1) mx2, w1 <+> w2)


-- | If the underlying monad is a reader monad, the wire arrow is a
-- reader arrow: 'readState' outputs the current environment, and
-- 'newReader' runs a wire with an environment taken from the second
-- component of its input.

instance MonadReader r m => ArrowReader r (Wire e (Kleisli m)) where
    readState = mkFixM $ \_ -> liftM Right ask

    newReader w' =
        case w' of
          -- A pure step cannot observe the environment, so it is ignored.
          WmPure step ->
              WmPure $ \input -> second newReader (step (fst input))
          -- A monadic step runs under the supplied environment.
          WmGen step ->
              WmGen $ \(x', env) ->
                  liftM (second newReader) (local (const env) (step x'))


-- | If the underlying monad is a state monad, the wire arrow is a state
-- arrow.  'fetch' outputs the current state and 'store' replaces it
-- with its input.  Neither ever inhibits.

instance MonadState s m => ArrowState s (Wire e (Kleisli m)) where
    fetch = mkFixM $ \_ -> liftM Right get
    store = mkFixM $ \s' -> liftM Right (put s')


-- | Wire arrows are arrow transformers: 'lift' turns a Kleisli arrow
-- into a stateless wire that runs it at every instant and never
-- inhibits.

instance Monad m => ArrowTransformer (Wire e) (Kleisli m) where
    lift k = mkFixM (\x -> liftM Right (runKleisli k x))


-- | If the underlying monad is a writer monad, the wire arrow is a
-- writer arrow.  'write' appends its input to the log.  'newWriter'
-- additionally outputs, next to the wrapped wire's output, what that
-- wire wrote during the current instant.

instance MonadWriter w m => ArrowWriter w (Wire e (Kleisli m)) where
    write = mkFixM $ \entry -> liftM Right (tell entry)

    -- A pure step writes nothing, so the captured log is 'mempty'.
    newWriter (WmPure step) =
        WmPure $ \x' ->
            let (mx, next) = step x'
            in (fmap (, mempty) mx, newWriter next)
    -- 'listen' captures the step's log while still passing it on to the
    -- enclosing writer.
    newWriter (WmGen step) =
        WmGen $ \x' ->
            listen (step x') >>= \((mx, next), written) ->
                return (fmap (, written) mx, newWriter next)


-- | The always inhibiting wire: it inhibits with 'mempty' at every
-- instant.  @zeroArrow@ is equivalent to
-- 'Control.Wire.Prefab.Event.never'.

instance (Monad m, Monoid e) => ArrowZero (Wire e (Kleisli m)) where
    zeroArrow = mkPureFix (\_ -> Left mempty)


-- | Sequencing of wires.  In @w2 . w1@ the wire @w1@ is stepped first
-- and its output is fed to @w2@.  If @w1@ inhibits, its inhibition
-- value is the result and @w2@ is neither stepped nor advanced.  'id'
-- passes its input through and never inhibits.  The composition is pure
-- only when both wires are pure.

instance Monad m => Category (Wire e (Kleisli m)) where
    id = WmPure $ \x -> (Right x, id)

    w2'@(WmGen c2) . WmGen c1 =
        WmGen $ \x'' -> do
            (mx', w1) <- c1 x''
            case mx' of
              Left ex  -> return (Left ex, w2' . w1)
              Right x' -> do
                  (mx, w2) <- c2 x'
                  return (mx, w2 . w1)

    -- The pure first step is forced through a view pattern.
    w2'@(WmGen c2) . WmPure g =
        WmGen $ \(g -> (mx', w1)) -> do
            case mx' of
              Left ex  -> return (Left ex, w2' . w1)
              Right x' -> do
                  (mx, w2) <- c2 x'
                  return (mx, w2 . w1)

    -- The pure second step is applied via a view pattern on the output.
    w2'@(WmPure f) . WmGen c1 =
        WmGen $ \x'' -> do
            (mx', w1) <- c1 x''
            return $
                case mx' of
                  Left ex               -> (Left ex, w2' . w1)
                  Right (f -> (mx, w2)) -> (mx, w2 . w1)

    w2'@(WmPure f) . WmPure g =
        WmPure $ \(g -> (mx', w1)) ->
            case mx' of
              Left ex               -> (Left ex, w2' . w1)
              Right (f -> (mx, w2)) -> (mx, w2 . w1)


-- | Map a function over the output signal.  Inhibition values pass
-- through unchanged, and the mapping carries over to the wire used at
-- the next instant.

instance Monad m => Functor (Wire e (Kleisli m) a) where
    fmap f w =
        case w of
          WmPure step -> WmPure (\x -> mapStep (step x))
          WmGen step  -> WmGen (\x -> liftM mapStep (step x))
      where
        mapStep = fmap f *** fmap f


-- | Create a wire from the given transformation computation.

class Arrow (>~) => WireGen (>~) where
    -- | Stateful variant: the computation yields the output (or
    -- inhibition value) together with the wire to use at the next
    -- instant.
    mkGen :: (a >~ (Either e b, Wire e (>~) a b)) -> Wire e (>~) a b

    -- | Stateless variant: the resulting wire always continues as itself.
    -- ('Arrow' is already a superclass, so no extra method constraint is
    -- needed.)
    mkFix :: (a >~ Either e b) -> Wire e (>~) a b
    mkFix c = let w = mkGen (arr (, w) . c) in w

instance Monad m => WireGen (Kleisli m) where
    mkGen (Kleisli c) = WmGen c
    -- Same construction as 'mkFixM'; reuse it instead of duplicating it.
    mkFix (Kleisli c) = mkFixM c


-- | Monad-based wires: wires over the Kleisli arrow of the monad @m@,
-- with inhibition values of type @e@.

type WireM e m = Wire e (Kleisli m)


-- | Create a pure wire, whose step is a plain function rather than a
-- computation in the underlying arrow.

class Arrow (>~) => WirePure (>~) where
    -- | Stateful variant: the function yields the output (or inhibition
    -- value) together with the wire to use at the next instant.
    mkPure :: (a -> (Either e b, Wire e (>~) a b)) -> Wire e (>~) a b

    -- | Stateless variant: the resulting wire always continues as itself.
    mkPureFix :: (a -> Either e b) -> Wire e (>~) a b
    mkPureFix f = self
      where
        self = mkPure (\x -> (f x, self))

instance Monad m => WirePure (Kleisli m) where
    mkPure step = WmPure step


-- | Convert the given wire to a generic arrow computation that performs
-- one step and yields the result together with the next wire.

class WireToGen (>~) where
    toGen :: Wire e (>~) a b -> (a >~ (Either e b, Wire e (>~) a b))

instance Monad m => WireToGen (Kleisli m) where
    toGen w = Kleisli (toGenM w)


-- | Turn an arbitrary exception into a wire exception by wrapping it in
-- 'Ex.SomeException' and then in 'Last'.

inhibitException :: Ex.Exception e => e -> LastException
inhibitException ex = Last (Just (Ex.toException ex))


-- | Build a wire exception from a message: the message becomes a
-- 'userError', which is then wrapped by 'LastException'.

inhibitMsg :: String -> LastException
inhibitMsg msg = inhibitException (userError msg)


-- | Map a function over the input: the resulting wire applies @f@ to
-- each input before stepping, at this and every later instant.

mapInputM :: Monad m => (a' -> a) -> Wire e (Kleisli m) a b -> Wire e (Kleisli m) a' b
mapInputM f w =
    case w of
      WmPure step -> WmPure $ \x -> second (mapInputM f) (step (f x))
      WmGen step  -> WmGen $ \x -> liftM (second (mapInputM f)) (step (f x))


-- | Apply a function to the 'Left' payload of an 'Either'; a 'Right'
-- value is passed through unchanged.

mapLeft :: (e' -> e) -> Either e' b -> Either e b
mapLeft f (Left err) = Left (f err)
mapLeft _ (Right x)  = Right x


-- | Create a stateless wire from the given monadic computation: at every
-- instant the computation is run on the input, and the wire continues
-- as itself.

mkFixM ::
    Monad m
    => (a -> m (Either e b))
    -> Wire e (Kleisli m) a b
mkFixM f = self
  where
    self = WmGen $ \x -> f x >>= \mx -> return (mx, self)


-- | Convert the given wire to a monadic step function.  A pure step is
-- lifted into the monad with 'return'.

toGenM ::
    Monad m
    => Wire e (Kleisli m) a b  -- ^ Wire to convert.
    -> a                       -- ^ Input value.
    -> m (Either e b, Wire e (Kleisli m) a b)
toGenM w =
    case w of
      WmGen step  -> step
      WmPure step -> \x -> return (step x)