-----------------------------------------------------------
-- |
-- Module      : Control.Imperative.Internal
-- Copyright   : (C) 2015, Yu Fukuzawa
-- License     : BSD3
-- Maintainer  : minpou.primer@email.com
-- Stability   : experimental
-- Portability : portable
--
-----------------------------------------------------------

{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE CPP                   #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies          #-}

module Control.Imperative.Internal where
import           Control.Monad
import           Control.Monad.Base
import qualified Control.Monad.ST                  as Strict
import qualified Control.Monad.ST.Lazy             as Lazy
import           Control.Monad.Trans.Cont          (ContT)
import           Control.Monad.Trans.Identity      (IdentityT)
import           Control.Monad.Trans.List          (ListT)
import           Control.Monad.Trans.Loop          (LoopT)
import           Control.Monad.Trans.Maybe         (MaybeT)
import           Control.Monad.Trans.Reader        (ReaderT)
import qualified Control.Monad.Trans.RWS.Lazy      as Lazy
import qualified Control.Monad.Trans.RWS.Strict    as Strict
import qualified Control.Monad.Trans.State.Lazy    as Lazy
import qualified Control.Monad.Trans.State.Strict  as Strict
import qualified Control.Monad.Trans.Writer.Lazy   as Lazy
import qualified Control.Monad.Trans.Writer.Strict as Strict
import           Data.Functor.Identity
import           Data.Monoid
import           GHC.Exts

#ifndef MIN_VERSION_transformers
#define MIN_VERSION_transformers(x,y,z) 1
#endif

#if MIN_VERSION_transformers(0,4,0)
import           Control.Monad.Trans.Except        (ExceptT)
#else
import           Control.Monad.Trans.Error         (Error, ErrorT)
#endif

-- | Maps a monad to the monad at the bottom of its transformer stack.
--
-- Every transformer instance below peels off one layer
-- (@BaseEff (t m) = BaseEff m@), while the concrete monads listed first map
-- to themselves. 'ref' and 'assign' rely on this to 'liftBase' the actions
-- stored in a 'Ref' into the caller's monad.
type family BaseEff (m :: * -> *) :: * -> *

-- Concrete monads are their own base.
type instance BaseEff []            = []
type instance BaseEff IO            = IO
type instance BaseEff Maybe         = Maybe
type instance BaseEff Identity      = Identity
type instance BaseEff (Either e)    = Either e
type instance BaseEff (Lazy.ST s)   = Lazy.ST s
type instance BaseEff (Strict.ST s) = Strict.ST s

-- Transformers defer to the monad they wrap.
type instance BaseEff (ListT m)              = BaseEff m
type instance BaseEff (MaybeT m)             = BaseEff m
type instance BaseEff (IdentityT m)          = BaseEff m
#if MIN_VERSION_transformers(0,4,0)
type instance BaseEff (ExceptT e m)          = BaseEff m
#else
type instance BaseEff (ErrorT e m)           = BaseEff m
#endif
type instance BaseEff (Lazy.WriterT w m)     = BaseEff m
type instance BaseEff (Strict.WriterT w m)   = BaseEff m
type instance BaseEff (ContT r m)            = BaseEff m
type instance BaseEff (Lazy.StateT s m)      = BaseEff m
type instance BaseEff (Strict.StateT s m)    = BaseEff m
type instance BaseEff (ReaderT r m)          = BaseEff m
type instance BaseEff (Lazy.RWST r w s m)    = BaseEff m
type instance BaseEff (Strict.RWST r w s m)  = BaseEff m
type instance BaseEff (LoopT c e m)          = BaseEff m

-- | A reference living in the monad @m@, represented as a pair of actions:
-- one that reads the current value and one that writes a new value.
data Ref m a = Ref
  { get :: m a         -- ^ Action reading the current value.
  , set :: a -> m ()   -- ^ Action writing a new value. 'Ref's created with
                       --   'val' or 'expr' ignore writes.
  }

-- | Read the value held by a 'Ref' whose actions run in the base monad of
-- @m@, lifting the read into @m@ with 'liftBase'.
ref :: (MonadBase (BaseEff m) m) => Ref (BaseEff m) a -> m a
ref = liftBase . get
{-# INLINE ref #-}

-- | Write a value into a 'Ref' whose actions run in the base monad of @m@.
--
-- The value is forced to weak head normal form before the write action is
-- lifted into @m@ with 'liftBase'.
assign :: MonadBase (BaseEff m) m => Ref (BaseEff m) a -> a -> m ()
assign r x = x `seq` liftBase (set r x)
{-# INLINE assign #-}

-- | Derive a read-only 'Ref' whose value is @f@ applied to the value of the
-- source 'Ref'.
--
-- The source is re-read on every read of the result. Writes to the result
-- are discarded and never reach the source.
liftOp :: Monad m => (a -> b) -> Ref m a -> Ref m b
liftOp f = expr . liftM f . get
{-# INLINE liftOp #-}

-- | Derive a read-only 'Ref' combining the values of two 'Ref's with a
-- binary function.
--
-- Each read of the result reads the first operand, then the second, and
-- applies the function. Writes to the result are discarded.
liftOp2 :: Monad m => (a -> b -> c) -> Ref m a -> Ref m b -> Ref m c
liftOp2 f r s = expr $ do
  lhs <- get r
  rhs <- get s
  return (f lhs rhs)
{-# INLINE liftOp2 #-}

-- | Build an immutable 'Ref' holding a constant: reads always return the
-- given value, and writes do nothing.
val :: Monad m => a -> Ref m a
val x = Ref (return x) (\_ -> return ())
{-# INLINE val #-}

-- | Wrap a monadic computation as a read-only 'Ref'.
--
-- Reading the 'Ref' runs the computation each time. Writing to it is a
-- no-op.
expr :: Monad m => m a -> Ref m a
expr action = Ref action (\_ -> return ())
{-# INLINE expr #-}

-- | Arithmetic on 'Ref's is lifted pointwise. Operators produce read-only
-- 'Ref's over their operands, and literals become immutable 'Ref's.
instance (Num a, Monad m) => Num (Ref m a) where
  x + y         = liftOp2 (+) x y
  x - y         = liftOp2 (-) x y
  x * y         = liftOp2 (*) x y
  negate x      = liftOp negate x
  abs x         = liftOp abs x
  signum x      = liftOp signum x
  fromInteger n = val (fromInteger n)

-- | Division is lifted pointwise over 'Ref's. Rational literals become
-- immutable 'Ref's.
instance (Fractional a, Monad m) => Fractional (Ref m a) where
  x / y          = liftOp2 (/) x y
  recip x        = liftOp recip x
  fromRational q = val (fromRational q)

-- | Every 'Floating' operation is lifted pointwise. Constants become
-- immutable 'Ref's, and functions yield read-only 'Ref's over their
-- operands.
instance (Floating a, Monad m) => Floating (Ref m a) where
  -- constants
  pi      = val pi
  -- exponentials, logarithms and powers
  exp     = liftOp exp
  log     = liftOp log
  sqrt    = liftOp sqrt
  (**)    = liftOp2 (**)
  logBase = liftOp2 logBase
  -- trigonometric functions
  sin     = liftOp sin
  cos     = liftOp cos
  tan     = liftOp tan
  asin    = liftOp asin
  acos    = liftOp acos
  atan    = liftOp atan
  -- hyperbolic functions
  sinh    = liftOp sinh
  cosh    = liftOp cosh
  tanh    = liftOp tanh
  asinh   = liftOp asinh
  acosh   = liftOp acosh
  atanh   = liftOp atanh

#if __GLASGOW_HASKELL__ >= 804
-- | Appending two 'Ref's yields a read-only 'Ref'. It reads the left
-- operand, then the right, and combines the two values with '<>'.
--
-- Required since base-4.11 (GHC 8.4), where 'Semigroup' is a superclass of
-- 'Monoid'. Without it, the 'Monoid' instance below does not compile.
instance (Semigroup w, Monad m) => Semigroup (Ref m w) where
  (<>) = liftOp2 (<>)
#endif

-- | 'mempty' is an immutable 'Ref' holding 'mempty'. Appending is the
-- underlying monoid operation lifted pointwise over the read actions.
instance (Monoid w, Monad m) => Monoid (Ref m w) where
  mempty = val mempty
#if __GLASGOW_HASKELL__ < 804
  -- Before base-4.11 'mappend' must be given explicitly. Afterwards it
  -- defaults to '<>', and defining it here would be non-canonical.
  mappend = liftOp2 mappend
#endif

-- | String literals become immutable 'Ref's holding the converted value.
instance (IsString a, Monad m) => IsString (Ref m a) where
  fromString str = val (fromString str)

-- | Containers whose elements can be looked up by an index, as in an array.
class Indexable v where
  -- | The type produced by indexing into @v@ with '!'.
  type Element v
  -- | The type of index accepted by '!' for @v@.
  type IndexType v
  -- | Look up the element of the container at the given index.
  (!) :: v -> IndexType v -> Element v
  infixl 9 !