{-# LANGUAGE CPP
           , GADTs
           , Rank2Types
           , DataKinds
           , TypeFamilies
           , FlexibleContexts
           , UndecidableInstances
           , LambdaCase
           , MultiParamTypeClasses
           , OverloadedStrings
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs -fsimpl-tick-factor=1000 #-}
module Language.Hakaru.Runtime.LogFloatPrelude where

#if __GLASGOW_HASKELL__ < 710
import           Data.Functor                    ((<$>))
import           Control.Applicative             (Applicative(..))
#endif
import           Data.Foldable                   as F
import qualified System.Random.MWC               as MWC
import qualified System.Random.MWC.Distributions as MWCD
import           Data.Number.Natural
import           Data.Number.LogFloat            hiding (sum, product)
import qualified Data.Number.LogFloat            as LF
import           Data.STRef
import qualified Data.Vector                     as V
import qualified Data.Vector.Unboxed             as U
import qualified Data.Vector.Generic             as G
import qualified Data.Vector.Generic.Mutable     as M
import           Control.Monad
import           Control.Monad.ST
import           Prelude                         hiding (init, sum, product, exp, log, (**), pi)
import qualified Prelude                         as P

-- | The less specialised of two vector types: boxed 'V.Vector' when
-- either argument is 'V.Vector', and unboxed 'U.Vector' only when both
-- arguments are 'U.Vector'.
--
-- The first two instances overlap at @MinBoxVec V.Vector V.Vector@; this
-- is accepted for an open type family because both reduce to the same
-- result.
type family MinBoxVec (v1 :: * -> *) (v2 :: * -> *) :: * -> *
type instance MinBoxVec V.Vector v        = V.Vector
type instance MinBoxVec v        V.Vector = V.Vector
type instance MinBoxVec U.Vector U.Vector = U.Vector

-- | The vector type used to store elements of type @a@.
--
-- * Unboxed 'U.Vector' for the scalar types listed here. This includes
--   'LogFloat', via the unboxed instances in this module.
-- * Boxed 'V.Vector' for nested vectors.
-- * For pairs, 'U.Vector' only when both components are unboxed and
--   'V.Vector' otherwise (via 'MinBoxVec').
--
-- NOTE(review): only the element types listed here have an instance in
-- this module. Any other element type needs an instance before the
-- vector helpers can be used at it.
type family MayBoxVec a :: * -> *
type instance MayBoxVec ()           = U.Vector
type instance MayBoxVec Int          = U.Vector
type instance MayBoxVec Double       = U.Vector
type instance MayBoxVec LogFloat     = U.Vector
type instance MayBoxVec Bool         = U.Vector
type instance MayBoxVec (U.Vector a) = V.Vector
type instance MayBoxVec (V.Vector a) = V.Vector
type instance MayBoxVec (a,b)        = MinBoxVec (MayBoxVec a) (MayBoxVec b)

-- | Unboxed vector representation for 'LogFloat': a thin wrapper around an
-- unboxed 'Double' vector. The vector instances for 'LogFloat' store each
-- element as its log-domain value ('logFromLogFloat') and rebuild it on
-- reads ('logToLogFloat').
newtype instance U.MVector s LogFloat = MV_LogFloat (U.MVector s Double)
newtype instance U.Vector    LogFloat = V_LogFloat  (U.Vector    Double)

-- | 'U.Unbox' declares no methods of its own. This instance relies on the
-- 'M.MVector' and 'G.Vector' instances for 'LogFloat' defined in this
-- module.
-- NOTE(review): orphan instance. Neither 'U.Unbox' nor 'LogFloat' is
-- defined in this module.
instance U.Unbox LogFloat

-- | Mutable unboxed 'LogFloat' vectors. Every method unwraps the underlying
-- @U.MVector s Double@ and delegates to it. Elements are stored in the log
-- domain: written with 'logFromLogFloat' and read back with
-- 'logToLogFloat'.
--
-- NOTE(review): the 'basicInitialize' guard keys on the GHC version. That
-- only approximates whether the installed @vector@ package declares the
-- method; confirm against the package set in use.
instance M.MVector U.MVector LogFloat where
  -- Every method is a thin delegating wrapper, so all of them are marked
  -- INLINE (basicUnsafeMove included) so that calls reach the underlying
  -- Double implementation directly.
  {-# INLINE basicLength #-}
  {-# INLINE basicUnsafeSlice #-}
  {-# INLINE basicOverlaps #-}
  {-# INLINE basicUnsafeNew #-}
#if __GLASGOW_HASKELL__ > 710
  {-# INLINE basicInitialize #-}
#endif
  {-# INLINE basicUnsafeReplicate #-}
  {-# INLINE basicUnsafeRead #-}
  {-# INLINE basicUnsafeWrite #-}
  {-# INLINE basicClear #-}
  {-# INLINE basicSet #-}
  {-# INLINE basicUnsafeCopy #-}
  {-# INLINE basicUnsafeMove #-}
  {-# INLINE basicUnsafeGrow #-}
  basicLength (MV_LogFloat v) = M.basicLength v
  basicUnsafeSlice i n (MV_LogFloat v) = MV_LogFloat $ M.basicUnsafeSlice i n v
  basicOverlaps (MV_LogFloat v1) (MV_LogFloat v2) = M.basicOverlaps v1 v2
  basicUnsafeNew n = MV_LogFloat `liftM` M.basicUnsafeNew n
#if __GLASGOW_HASKELL__ > 710
  basicInitialize (MV_LogFloat v) = M.basicInitialize v
#endif
  basicUnsafeReplicate n x = MV_LogFloat `liftM` M.basicUnsafeReplicate n (logFromLogFloat x)
  basicUnsafeRead (MV_LogFloat v) i = logToLogFloat `liftM` M.basicUnsafeRead v i
  basicUnsafeWrite (MV_LogFloat v) i x = M.basicUnsafeWrite v i (logFromLogFloat x)
  basicClear (MV_LogFloat v) = M.basicClear v
  basicSet (MV_LogFloat v) x = M.basicSet v (logFromLogFloat x)
  basicUnsafeCopy (MV_LogFloat v1) (MV_LogFloat v2) = M.basicUnsafeCopy v1 v2
  basicUnsafeMove (MV_LogFloat v1) (MV_LogFloat v2) = M.basicUnsafeMove v1 v2
  basicUnsafeGrow (MV_LogFloat v) n = MV_LogFloat `liftM` M.basicUnsafeGrow v n

-- | Immutable unboxed 'LogFloat' vectors. Every method unwraps the
-- underlying @U.Vector Double@ of log-domain values and delegates to it.
instance G.Vector U.Vector LogFloat where
  -- Every method is a thin delegating wrapper, so all of them are marked
  -- INLINE (basicUnsafeCopy included).
  {-# INLINE basicUnsafeFreeze #-}
  {-# INLINE basicUnsafeThaw #-}
  {-# INLINE basicLength #-}
  {-# INLINE basicUnsafeSlice #-}
  {-# INLINE basicUnsafeIndexM #-}
  {-# INLINE basicUnsafeCopy #-}
  {-# INLINE elemseq #-}
  basicUnsafeFreeze (MV_LogFloat v) = V_LogFloat `liftM` G.basicUnsafeFreeze v
  basicUnsafeThaw (V_LogFloat v) = MV_LogFloat `liftM` G.basicUnsafeThaw v
  basicLength (V_LogFloat v) = G.basicLength v
  basicUnsafeSlice i n (V_LogFloat v) = V_LogFloat $ G.basicUnsafeSlice i n v
  basicUnsafeIndexM (V_LogFloat v) i
                = logToLogFloat `liftM` G.basicUnsafeIndexM v i
  basicUnsafeCopy (MV_LogFloat mv) (V_LogFloat v)
                = G.basicUnsafeCopy mv v
  -- Force the element through its stored Double representation using the
  -- Double vector instance. The 'undefined' proxy only selects that
  -- instance by its type.
  elemseq _ x z = G.elemseq (undefined :: U.Vector Double) (logFromLogFloat x) z


-- | Lambda abstraction in generated code: the function is used as-is.
lam :: (a -> b) -> a -> b
lam f = f
{-# INLINE lam #-}

-- | Function application in generated code.
app :: (a -> b) -> a -> b
app = ($)
{-# INLINE app #-}

-- | @let_ x body@ binds @x@ and passes it to @body@.
let_ :: a -> (a -> b) -> b
let_ x body = body x
{-# INLINE let_ #-}

-- | Type annotation: the first argument is ignored and the second is
-- returned unchanged.
ann_ :: a -> b -> b
ann_ = const id
{-# INLINE ann_ #-}

-- | Treat a 'Double' as a logarithm: @exp x@ is the 'LogFloat' whose
-- log-domain value is @x@ (built with 'logToLogFloat').
exp :: Double -> LogFloat
exp x = logToLogFloat x
{-# INLINE exp #-}

-- | Inverse of 'exp': the log-domain value of a 'LogFloat' (read with
-- 'logFromLogFloat').
log :: LogFloat -> Double
log p = logFromLogFloat p
{-# INLINE log #-}

-- | A sampler. Given a random generator, it either produces a sample
-- (@Just x@) or rejects (@Nothing@).
newtype Measure a = Measure { unMeasure :: MWC.GenIO -> IO (Maybe a) }

instance Functor Measure where
    fmap  = liftM
    {-# INLINE fmap #-}

instance Applicative Measure where
    pure x = Measure $ \_ -> return (Just x)
    {-# INLINE pure #-}
    (<*>)  = ap
    {-# INLINE (<*>) #-}

-- | Sequencing threads the same generator through both samplers. A
-- rejection in the first sampler short-circuits: the continuation is not
-- run, and the whole computation rejects.
instance Monad Measure where
    return  = pure
    {-# INLINE return #-}
    -- Match explicitly on the first result. A @Just x <- ...@ bind in IO
    -- would call IO fail on Nothing and throw an exception instead of
    -- propagating the rejection.
    m >>= f = Measure $ \g -> do
        mx <- unMeasure m g
        case mx of
          Nothing -> return Nothing
          Just x  -> unMeasure (f x) g
    {-# INLINE (>>=) #-}

-- | Lift a generator action that always produces a value into a 'Measure'
-- that never rejects.
makeMeasure :: (MWC.GenIO -> IO a) -> Measure a
makeMeasure sample = Measure (fmap Just . sample)
{-# INLINE makeMeasure #-}

-- | Uniform sample between @lo@ and @hi@ via 'MWC.uniformR'. See the
-- mwc-random documentation for how the endpoints are treated.
uniform :: Double -> Double -> Measure Double
uniform lo hi = makeMeasure (MWC.uniformR (lo, hi))
{-# INLINE uniform #-}

-- | Normal sample with mean @mu@ and standard deviation @sd@. The
-- deviation is converted out of 'LogFloat' first.
normal :: Double -> LogFloat -> Measure Double
normal mu sd = makeMeasure (MWCD.normal mu (fromLogFloat sd))
{-# INLINE normal #-}

-- | Beta sample with shape parameters @a@ and @b@. The result is lifted
-- into 'LogFloat'.
beta :: LogFloat -> LogFloat -> Measure LogFloat
beta a b =
  makeMeasure (fmap logFloat . MWCD.beta (fromLogFloat a) (fromLogFloat b))
{-# INLINE beta #-}

-- | Gamma sample. @a@ and @b@ are passed as the first and second
-- parameters of 'MWCD.gamma' (shape and scale in mwc-random). The result
-- is lifted into 'LogFloat'.
gamma :: LogFloat -> LogFloat -> Measure LogFloat
gamma a b =
  makeMeasure (fmap logFloat . MWCD.gamma (fromLogFloat a) (fromLogFloat b))
{-# INLINE gamma #-}

-- | Sample an index in proportion to the weights in @a@.
--
-- Each weight is divided by the largest weight before leaving log-space.
-- This keeps the relative weights unchanged, and presumably keeps
-- 'fromLogFloat' from underflowing on very small weights.
--
-- NOTE(review): 'G.maximum' fails on an empty vector. With all-zero
-- weights every scaled weight is 0/0; confirm how logfloat and
-- mwc-random handle that case.
categorical :: MayBoxVec LogFloat LogFloat -> Measure Int
categorical a = makeMeasure $ \g -> fromIntegral <$> MWCD.categorical scaled g
  where
    biggest = G.maximum a
    scaled  = U.map (\p -> fromLogFloat (p / biggest)) a
{-# INLINE categorical #-}

-- | Run @f i@ for every index @i@ in @[0, n)@, in order, and collect the
-- samples into the element type's vector representation ('MayBoxVec').
plate :: (G.Vector (MayBoxVec a) a) =>
         Int -> (Int -> Measure a) -> Measure (MayBoxVec a a)
plate n f = G.generateM (fromIntegral n) (f . fromIntegral)
{-# INLINE plate #-}

-- | Run a reducer over every index in the half-open range @[b, e)@, with
-- @()@ as its context, and return the reducer's result.
bucket :: Int -> Int -> (forall s. Reducer () s a) -> a
bucket b e r = runST $
  case r of
    Reducer { init = start, accum = step, done = finish } -> do
      cell <- start ()
      F.forM_ [b .. e - 1] $ \i -> step () i cell
      finish cell
{-# INLINE bucket #-}

-- | A reducer: a fold over 'Int' indices in the 'ST' monad, parameterised
-- by a context value of type @xs@. The accumulator type @cell@ is
-- existentially hidden, so reducers with different internal state share
-- one type and can be combined.
--
-- * @init@ allocates the accumulator for a given context.
-- * @accum@ updates the accumulator for one index.
-- * @done@ reads out the final result.
data Reducer xs s a =
    forall cell.
    Reducer { init  :: xs -> ST s cell
            , accum :: xs -> Int -> cell -> ST s ()
            , done  :: cell -> ST s a
            }

-- | Run two reducers side by side. Every index is fed to both, and the
-- result pairs their results.
r_fanout :: Reducer xs s a
         -> Reducer xs s b
         -> Reducer xs s (a,b)
r_fanout Reducer{init=initA,accum=accumA,done=doneA}
         Reducer{init=initB,accum=accumB,done=doneB} = Reducer
   { init  = \xs -> do
       cellA <- initA xs
       cellB <- initB xs
       return (cellA, cellB)
   , accum = \bs i (cellA, cellB) -> do
       accumA bs i cellA
       accumB bs i cellB
   , done  = \(cellA, cellB) -> do
       resA <- doneA cellA
       resB <- doneB cellB
       return (resA, resB)
   }
{-# INLINE r_fanout #-}

-- | Split the indices into @n xs@ buckets. Each index @i@ is routed by
-- @f (i, xs)@ to one bucket. Every bucket runs its own copy of the inner
-- reducer with context @(bucket, xs)@, and the per-bucket results are
-- collected into a vector.
--
-- NOTE(review): @f@ must return a bucket in @[0, n xs)@. 'V.!' is
-- bounds-checked and raises an error otherwise.
r_index :: (G.Vector (MayBoxVec a) a)
        => (xs -> Int)
        -> ((Int, xs) -> Int)
        -> Reducer (Int, xs) s a
        -> Reducer xs s (MayBoxVec a a)
r_index n f Reducer{init=initR,accum=accumR,done=doneR} = Reducer
   { init  = \xs -> V.generateM (n xs) $ \bucketIx -> initR (bucketIx, xs)
   , accum = \bs i cells ->
       let target = f (i, bs)
       in  accumR (target, bs) i (cells V.! target)
   , done  = \cells -> G.convert <$> V.mapM doneR cells
   }
{-# INLINE r_index #-}

-- | Route each index to exactly one of two reducers. Indices satisfying
-- the predicate go to the first reducer, and all others go to the
-- second. The result pairs both results.
r_split :: ((Int, xs) -> Bool)
        -> Reducer xs s a
        -> Reducer xs s b
        -> Reducer xs s (a,b)
r_split goesLeft Reducer{init=initA,accum=accumA,done=doneA}
                 Reducer{init=initB,accum=accumB,done=doneB} = Reducer
   { init  = \xs -> do
       cellA <- initA xs
       cellB <- initB xs
       return (cellA, cellB)
   , accum = \bs i (cellA, cellB) ->
       if goesLeft (i, bs)
         then accumA bs i cellA
         else accumB bs i cellB
   , done  = \(cellA, cellB) -> do
       resA <- doneA cellA
       resB <- doneB cellB
       return (resA, resB)
   }
{-# INLINE r_split #-}

-- | Sum @e (i, xs)@ over all indices. The running total lives in an
-- 'STRef' that starts at 0 and is updated with a strict modify.
r_add :: Num a => ((Int, xs) -> a) -> Reducer xs s a
r_add term = Reducer
   { init  = const (newSTRef 0)
   , accum = \bs i ref -> modifySTRef' ref (+ term (i, bs))
   , done  = readSTRef
   }
{-# INLINE r_add #-}

-- | A reducer that keeps no state, ignores every index, and yields @()@.
r_nop :: Reducer xs s ()
r_nop = Reducer
   { init  = const (return ())
   , accum = \_ _ _ -> return ()
   , done  = const (return ())
   }
{-# INLINE r_nop #-}

-- | Build a pair.
pair :: a -> b -> (a, b)
pair x y = (x, y)
{-# INLINE pair #-}

-- | Boolean literals for generated code.
true, false :: Bool
true  = True
false = False

-- | The empty 'Maybe'.
nothing :: Maybe a
nothing = Nothing

-- | Wrap a value in 'Just'.
just :: a -> Maybe a
just x = Just x

-- | The unit value.
unit :: ()
unit = ()

-- | Sub-pattern markers passed to the pattern constructors. 'PVar' binds
-- the matched component. 'PWild' is presumably a wildcard; note that
-- 'pjust' and 'ppair' currently call 'error' for anything but 'PVar'.
data Pattern = PVar | PWild
-- | A single case branch. 'extract' returns @Just@ the branch result when
-- the scrutinee matches, and 'Nothing' otherwise.
newtype Branch a b =
    Branch { extract :: a -> Maybe b }

-- | Branches that match the literal 'True' and 'False' respectively. On a
-- match the branch yields the given body.
ptrue, pfalse :: a -> Branch Bool a
ptrue  body = Branch (extractBool True  body)
pfalse body = Branch (extractBool False body)
{-# INLINE ptrue  #-}
{-# INLINE pfalse #-}

-- | @extractBool expected body actual@ is @Just body@ when @actual@ equals
-- @expected@, and 'Nothing' otherwise.
extractBool :: Bool -> a -> Bool -> Maybe a
extractBool expected body actual =
  if actual == expected then Just body else Nothing
{-# INLINE extractBool #-}

-- | Branch that matches 'Nothing' and yields @b@. Any 'Just' does not
-- match.
pnothing :: b -> Branch (Maybe a) b
pnothing b = Branch { extract = maybe (Just b) (const Nothing) }

-- | Branch that matches @Just x@ and yields @c x@. Only a 'PVar'
-- sub-pattern is supported; anything else is an error.
pjust :: Pattern -> (a -> b) -> Branch (Maybe a) b
pjust PVar c = Branch { extract = fmap c }
pjust _    _ = error "Runtime.Prelude pjust"


-- | Branch that always matches a pair and yields @c x y@. Only
-- 'PVar'/'PVar' sub-patterns are supported; anything else is an error.
-- The lambda matches the tuple strictly, so the scrutinee is forced
-- when the branch is tried.
ppair :: Pattern -> Pattern -> (x -> y -> b) -> Branch (x,y) b
ppair PVar PVar c = Branch (\(x, y) -> Just (c x y))
ppair _    _    _ = error "ppair: TODO"

-- | Unwrap the result of trying a case expression's branches, and fail
-- loudly if no branch matched.
uncase_ :: Maybe a -> a
uncase_ = maybe (error "case_: unable to match any branches") id
{-# INLINE uncase_ #-}

-- | Evaluate a case expression: return the body of the first branch that
-- matches the scrutinee, or fail if none matches. The one- and two-branch
-- forms are spelled out separately, presumably so they inline to
-- straight-line code.
case_ :: a -> [Branch a b] -> b
case_ scrut [c1]     = uncase_ (extract c1 scrut)
case_ scrut [c1, c2] = uncase_ (extract c1 scrut `mplus` extract c2 scrut)
case_ scrut branches = foldr tryBranch noMatch branches
  where
    tryBranch br rest = maybe rest id (extract br scrut)
    noMatch           = error "case_: unable to match any branches"
{-# INLINE case_ #-}

-- | Build a branch by applying a pattern constructor to the branch body.
branch :: (c -> Branch a b) -> c -> Branch a b
branch = ($)
{-# INLINE branch #-}

-- | The point-mass measure: it always yields its argument.
dirac :: a -> Measure a
dirac = pure
{-# INLINE dirac #-}

-- | Reweight a measure. This sampler discards the weight and runs the
-- measure unchanged.
-- NOTE(review): samples from a posed measure are therefore not
-- importance-weighted.
pose :: LogFloat -> Measure a -> Measure a
pose _ = id
{-# INLINE pose #-}

-- | Sample from a weighted mixture of measures.
--
-- An index is drawn with 'categorical' in proportion to the weights, and
-- the corresponding measure is then run.
--
-- An empty mixture is the zero measure, so it rejects. Previously it
-- crashed inside 'categorical', whose 'G.maximum' fails on an empty
-- weight vector.
superpose :: [(LogFloat, Measure a)]
          -> Measure a
superpose []  = reject
superpose pms = do
  i <- categorical (G.fromList (map fst pms))
  -- categorical returns an index into the non-empty weight vector, so the
  -- lookup is within bounds.
  snd (pms !! i)
{-# INLINE superpose #-}

-- | The measure that always rejects: it never produces a sample and
-- ignores the generator.
reject :: Measure a
reject = Measure (const (return Nothing))

-- | Natural-number literal. Naturals are represented as plain 'Int' here,
-- with no non-negativity check.
nat_ :: Int -> Int
nat_ n = n

-- | Integer literal.
int_ :: Int -> Int
int_ n = n

-- | Reinterpret an integer as a natural without any check.
unsafeNat :: Int -> Int
unsafeNat n = n

-- | Convert a natural to a probability through the 'Num' instance of
-- 'LogFloat'.
nat2prob :: Int -> LogFloat
nat2prob n = fromIntegral n

-- | Convert an integer to a real.
fromInt  :: Int -> Double
fromInt n = fromIntegral n

-- | Convert a natural to an integer. Both are 'Int' here.
nat2int  :: Int -> Int
nat2int n = n

-- | Convert a natural to a real.
nat2real :: Int -> Double
nat2real n = fromIntegral n

-- | Convert a probability back to an ordinary real with 'fromLogFloat'.
fromProb :: LogFloat -> Double
fromProb p = fromLogFloat p

-- | Convert a real to a probability with 'logFloat'.
-- NOTE(review): behaviour on negative input is whatever 'logFloat' does;
-- no check is made here.
unsafeProb :: Double -> LogFloat
unsafeProb x = logFloat x

-- | Real literal, given as an exact 'Rational'.
real_ :: Rational -> Double
real_ r = fromRational r

-- | Probability literal, given as an exact non-negative rational.
prob_ :: NonNegativeRational -> LogFloat
prob_ q = fromRational (fromNonNegativeRational q)

-- | Positive infinity.
infinity :: Double
infinity = recip 0

-- | Absolute value at any 'Num' type.
abs_ :: Num a => a -> a
abs_ x = abs x

-- | Raise a probability to a real power with logfloat 'pow'.
--
-- NOTE(review): there is no fixity declaration, so this operator defaults
-- to @infixl 9@, unlike the Prelude operator of the same name that it
-- shadows (@infixr 8@). Generated code is presumably fully parenthesised;
-- confirm before relying on precedence.
(**) :: LogFloat -> Double -> LogFloat
base ** ex = pow base ex
{-# INLINE (**) #-}

-- | The constant pi as a probability.
pi :: LogFloat
pi = logFloat (P.pi :: Double)
{-# INLINE pi #-}

-- | @thRootOf n x@ is the @n@-th root of @x@, computed as @x@ raised to
-- the power @1 / n@.
thRootOf :: Int -> LogFloat -> LogFloat
thRootOf n x = pow x (1 / fromIntegral n)
{-# INLINE thRootOf #-}

-- | Build an array of length @n@ whose @i@-th element is @f i@, using the
-- element type's vector representation ('MayBoxVec').
array
    :: (G.Vector (MayBoxVec a) a)
    => Int
    -> (Int -> a)
    -> MayBoxVec a a
array n f = G.generate (fromIntegral n) (\i -> f (fromIntegral i))
{-# INLINE array #-}

-- | Build an array from a list literal.
arrayLit :: (G.Vector (MayBoxVec a) a) => [a] -> MayBoxVec a a
arrayLit xs = G.fromList xs
{-# INLINE arrayLit #-}

-- | Bounds-checked indexing.
(!) :: (G.Vector (MayBoxVec a) a) => MayBoxVec a a -> Int -> a
vec ! ix = vec G.! fromIntegral ix
{-# INLINE (!) #-}

-- | Number of elements in an array.
size :: (G.Vector (MayBoxVec a) a) => MayBoxVec a a -> Int
size = fromIntegral . G.length
{-# INLINE size #-}

-- | Numeric types that support indexed products and sums over the
-- half-open index range @[a, b)@, that is @a <= i < b@.
class Num a => Num' a where
    -- | @product a b f@ multiplies @f i@ over the range. An empty range
    -- (@a >= b@) gives 1.
    product :: Int -> Int -> (Int -> a) -> a
    product a b f = F.foldl' (\x y -> x * f y) 1 [a .. b-1]
    {-# INLINE product #-}
    -- | @summate a b f@ adds @f i@ over the range. An empty range gives 0.
    summate :: Int -> Int -> (Int -> a) -> a
    summate a b f = F.foldl' (\x y -> x + f y) 0 [a .. b-1]
    {-# INLINE summate #-}

instance Num' Int
instance Num' Double

-- | 'LogFloat' uses the log-domain 'LF.product' and 'LF.sum' from the
-- logfloat package instead of the generic folds.
instance Num' LogFloat where
    product a b f = LF.product (map f [a .. b-1])
    {-# INLINE product #-}
    -- An empty sum must be 0, but 'LF.sum' is not safe on an empty list:
    -- the logfloat implementation takes the maximum of its input to
    -- rescale the terms, which fails on the empty list. NOTE(review):
    -- confirm this for the logfloat version in use. The guard is harmless
    -- either way.
    summate a b f
      | a >= b    = 0
      | otherwise = LF.sum (map f [a .. b-1])
    {-# INLINE summate #-}

-- | Run a sampler once with generator @g@. A produced sample is printed;
-- a rejection prints nothing.
run :: Show a
    => MWC.GenIO
    -> Measure a
    -> IO ()
run g k = unMeasure k g >>= maybe (return ()) print

-- | Apply @f@ repeatedly, feeding each result back in as the next input.
-- The loop never returns normally; it ends only if the monad
-- short-circuits (for example 'Left' in 'Either', or an exception in
-- 'IO').
iterateM_
    :: Monad m
    => (a -> m a)
    -> a
    -> m b
iterateM_ f x = f x >>= iterateM_ f

-- | Print the input, then pass it on to @f@.
withPrint :: Show a => (a -> IO b) -> a -> IO b
withPrint f x = do
  print x
  f x