{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
module Synthesizer.LLVM.Causal.ProcessPrivate where

import qualified Synthesizer.LLVM.Simple.SignalPrivate as Sig
import qualified Synthesizer.Causal.Class as CausalClass
import qualified Synthesizer.Causal.Utility as ArrowUtil

import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.MaybeContinuation as MaybeCont
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Memory as Memory

import LLVM.Core (CodeGenFunction, Value)

import System.Random (Random, RandomGen, randomR)

import qualified Control.Arrow    as Arr
import qualified Control.Category as Cat
import qualified Control.Monad.Trans.State as MS
import Control.Arrow (Arrow, arr, (<<<), (>>>), (&&&))
import Control.Monad (liftM2, replicateM)
import Control.Applicative (Applicative, pure, (<*>))

import qualified Number.Ratio as Ratio
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring
import qualified Algebra.Additive as Additive

import NumericPrelude.Numeric
import NumericPrelude.Base hiding (and, map, zip, zipWith, init)

import qualified Prelude as P


-- | Parameter- and IO-independent core of a causal process.
--
-- The loop state type is existentially hidden; only its 'Memory.C'
-- instance is known.  The type parameters describe what the surrounding
-- process supplies: @context@ is passed to every step, @initState@ is the
-- input for building the initial state, and @exitState@ is what the last
-- field extracts from the state at the end.
--
-- 'alter' hands a 'Core' to a transformation, so that the transformation
-- can rewire the step and start functions without knowing how the context
-- and initial-state inputs are produced.
data Core context initState exitState a b =
   forall state.
      (Memory.C state) =>
      Core (forall r c.
            (Tuple.Phi c) =>
            context ->
            a -> state -> MaybeCont.T r c (b, state))
               -- compute next value
               -- (maps context, input and current state
               --  to output and next state)
           (forall r.
            initState ->
            CodeGenFunction r state)
               -- initial state
           (state -> exitState)
               -- extract final state for cleanup


-- | Causal process types that can be built from an LLVM step function and
-- transformed at the level of their 'Core'.
class
   (CausalClass.C process, Sig.C (CausalClass.SignalOf process)) =>
      C process where
   -- | Build a process from a step function and an initial-state generator.
   -- The step maps the current input and state to the output and the next
   -- state inside 'MaybeCont.T'.
   simple ::
      (Memory.C state) =>
      (forall r c.
       (Tuple.Phi c) =>
       a -> state -> MaybeCont.T r c (b, state)) ->
      (forall r. CodeGenFunction r state) ->
      process a b

   -- | Apply a transformation to the 'Core' of a process.  The
   -- transformation must work for every context, initial-state and
   -- exit-state type, so it can only rearrange the core's functions.
   alter ::
      (forall contextLocal initState exitState.
          Core contextLocal initState exitState a0 b0 ->
          Core contextLocal initState exitState a1 b1) ->
      process a0 b0 -> process a1 b1

   -- | Replicate a process @n@ times.  NOTE(review): presumably the copies
   -- are chained through @x@ while all share the control input @c@ —
   -- confirm against 'CausalClass.replicateControlled'.
   replicateControlled ::
      (Tuple.Undefined x, Tuple.Phi x) =>
      Int -> process (c,x) x -> process (c,x) x


-- | Derive a process from a signal by rewriting the signal's 'Sig.Core'.
--
-- The signal is first lifted to a process that ignores its input
-- ('CausalClass.fromSignal').  That process core is viewed as a signal core
-- by feeding the unit value as input to every step.  The supplied function
-- then turns the signal core into the core of the resulting process.
alterSignal ::
   (C process, CausalClass.SignalOf process ~ signal) =>
   (forall contextLocal initState exitState.
       Sig.Core contextLocal initState exitState a0 ->
       Core contextLocal initState exitState a1 b1) ->
   signal a0 -> process a1 b1
alterSignal f sig =
   alter
      (\core ->
         case core of
            Core next start stop ->
               f (Sig.Core (\ctx -> next ctx ()) start stop))
      (CausalClass.fromSignal sig)



-- | A causal process from a stream of @a@ to a stream of @b@.  The
-- per-element work is generated as LLVM code ('CodeGenFunction').
--
-- The constructor hides four types: the loop @state@, the @local@
-- temporaries, the IO-side @ioContext@, and the @parameters@ created in
-- IO.  The generated code sees the parameters as
-- @'Tuple.ValueOf' parameters@.  NOTE(review): the conversion presumably
-- goes through the 'Marshal.C' instance; it is not visible here.
data T a b =
   forall state local ioContext parameters.
      (Marshal.C parameters, Memory.C state) =>
      Cons (forall r c.
            (Tuple.Phi c) =>
            Tuple.ValueOf parameters -> local ->
            a -> state -> MaybeCont.T r c (b, state))
               -- compute next value
               -- (from parameter value, temporaries, input and state)
           (forall r.
            CodeGenFunction r local)
               -- allocate temporary variables before a loop
           (forall r.
            Tuple.ValueOf parameters ->
            CodeGenFunction r state)
               -- initial state, built from the parameter value
           (IO (ioContext, parameters))
               -- initialization from IO monad:
               -- yields the context for finalization and the parameters
           (ioContext -> IO ())
               -- finalization from IO monad


-- | The causal-process type that belongs to the simple LLVM signal type.
type instance CausalClass.ProcessOf Sig.T = T

-- | Signal conversions for 'T'.
--
-- "Synthesizer.Causal.Class" is imported qualified only.  The unqualified
-- names on the right-hand sides therefore refer to this module's top-level
-- 'toSignal' and 'fromSignal', not recursively to the methods.
instance CausalClass.C T where
   type SignalOf T = Sig.T
   fromSignal sig = fromSignal sig
   toSignal proc = toSignal proc

-- | 'T' keeps parameters, local temporaries and IO setup/teardown outside
-- of the 'Core' that it hands to transformations.
instance C T where
   -- A plain step needs no parameters and no temporaries, and it uses a
   -- unit IO context.  The step and the start function therefore ignore
   -- the (unit) parameter value.
   simple next start =
      Cons
         (\_params () -> next)
         (return ())
         (\_params -> start)
         (return ((), ()))
         (\_ioContext -> return ())

   -- The core sees the pair (parameter value, temporaries) as its context.
   -- It builds its initial state from the parameter value, and its exit
   -- function returns the state unchanged.  The exit function of the
   -- transformed core is dropped, because finalization stays in IO.
   alter f (Cons next0 alloca start0 create delete) =
      case f (Core (\ctx -> next0 (fst ctx) (snd ctx)) start0 id) of
         Core next1 start1 _stop ->
            Cons
               (\params local -> next1 (params, local))
               alloca start1 create delete

   {-
   Could be implemented with a machine code loop like in CausalParameterized.
   But to this end we would need a 'stop' function.
   -}
   replicateControlled = CausalClass.replicateControlled


-- | Turn a process with unit input into a signal by passing @()@ as the
-- input to every step.  Allocation, start and IO context handling are
-- reused unchanged.
toSignal :: T () a -> Sig.T a
toSignal proc =
   case proc of
      Cons next alloca start createIOContext deleteIOContext ->
         Sig.Cons
            (\params local -> next params local ())
            alloca start createIOContext deleteIOContext

-- | View a signal as a process that ignores its input.  Allocation, start
-- and IO context handling are reused unchanged.
fromSignal :: Sig.T b -> T a b
fromSignal sig =
   case sig of
      Sig.Cons next alloca start createIOContext deleteIOContext ->
         Cons
            (\params local _input -> next params local)
            alloca start createIOContext deleteIOContext


-- | Lift a per-element LLVM computation to a process.  A unit state is
-- threaded through unchanged.
map ::
   (C process) =>
   (forall r. a -> CodeGenFunction r b) ->
   process a b
map f =
   mapAccum
      (\a s -> fmap (\b -> (b, s)) (f a))
      (return ())

-- | Build a stateful process from a step and an initial-state generator.
-- The step returns a plain 'CodeGenFunction' and cannot signal
-- termination; it is embedded with 'MaybeCont.lift'.
mapAccum ::
   (C process, Memory.C state) =>
   (forall r.
    a -> state -> CodeGenFunction r (b, state)) ->
   (forall r. CodeGenFunction r state) ->
   process a b
mapAccum next start =
   simple (\a s -> MaybeCont.lift (next a s)) start

-- | Combine both components of a pair input with a binary LLVM computation.
zipWith ::
   (C process) =>
   (forall r. a -> b -> CodeGenFunction r c) ->
   process (a,b) c
zipWith f = map (\pair -> f (fst pair) (snd pair))


-- | Post-process the output of a process with a per-element computation.
mapProc ::
   (C process) =>
   (forall r. b -> CodeGenFunction r c) ->
   process a b ->
   process a c
mapProc f x = x >>> map f

-- | Run two processes on the same input and combine their outputs
-- element-wise.
zipProcWith ::
   (C process) =>
   (forall r. b -> c -> CodeGenFunction r d) ->
   process a b ->
   process a c ->
   process a d
zipProcWith f x y = (x &&& y) >>> zipWith f


-- | Pass the input through unchanged while @p@ holds.
--
-- The predicate result is fed to 'MaybeCont.guard'.  NOTE(review): this
-- presumably ends the process at the first element that fails @p@.
takeWhile ::
   (C process) =>
   (forall r. a -> CodeGenFunction r (Value Bool)) ->
   process a a
takeWhile p =
   simple
      (\a () -> do
         keep <- MaybeCont.lift (p a)
         MaybeCont.guard keep
         return (a, ()))
      (return ())


-- | Serial composition: on every step, process A runs first and its output
-- feeds process B.
--
-- Parameters, temporaries and states are paired as (A, B), and A's
-- temporaries are allocated before B's.  Start, create and delete are
-- combined by the corresponding 'Sig' helpers.
compose :: T a b -> T b c -> T a c
compose
      (Cons nextA allocaA startA createIOContextA deleteIOContextA)
      (Cons nextB allocaB startB createIOContextB deleteIOContextB) =
   Cons
      (\(paramsA, paramsB) (localA, localB) a (stateA0, stateB0) -> do
         (b, stateA1) <- nextA paramsA localA a stateA0
         (c, stateB1) <- nextB paramsB localB b stateB0
         return (c, (stateA1, stateB1)))
      (do
         localA <- allocaA
         localB <- allocaB
         return (localA, localB))
      (Sig.combineStart startA startB)
      (Sig.combineCreate createIOContextA createIOContextB)
      (Sig.combineDelete deleteIOContextA deleteIOContextB)


-- | Apply a process to the first component of a pair and pass the second
-- component through unchanged (see 'firstNext').
first :: (C process) => process b c -> process (b, d) (c, d)
first proc =
   alter
      (\core ->
         case core of
            Core next start stop -> Core (firstNext next) start stop)
      proc


-- | Identity is the pass-through map.  Composition is 'compose' with its
-- arguments swapped, as 'Cat..' takes the later stage first.
instance Cat.Category T where
   id = map (\x -> return x)
   (.) = \outer inner -> compose inner outer

-- | Pure functions become per-element maps.
instance Arr.Arrow T where
   arr f = map (\x -> return (f x))
   -- The right-hand side is this module's top-level 'first'; "Control.Arrow"
   -- does not bring 'first' into scope unqualified.
   first = first



-- | Mapping over the output, delegated to the generic arrow helper.
instance Functor (T a) where
   fmap f proc = ArrowUtil.map f proc

-- | Constant outputs and output-wise application, delegated to the generic
-- arrow helpers.
instance Applicative (T a) where
   pure x = ArrowUtil.pure x
   procF <*> procX = ArrowUtil.apply procF procX


-- | Output-wise addition of processes, using LLVM arithmetic.
instance (A.Additive b) => Additive.C (T a b) where
   zero = pure A.zero
   negate x = mapProc A.neg x
   x + y = zipProcWith A.add x y
   x - y = zipProcWith A.sub x y

-- | Output-wise multiplication; integer literals become constant processes.
instance (A.PseudoRing b, A.IntegerConstant b) => Ring.C (T a b) where
   one = pure A.one
   fromInteger = pure . A.fromInteger'
   x * y = zipProcWith A.mul x y

-- | Output-wise division; rational literals become constant processes.
instance (A.Field b, A.RationalConstant b) => Field.C (T a b) where
   fromRational' = pure . A.fromRational' . Ratio.toRational98
   x / y = zipProcWith A.fdiv x y


-- | Prelude 'P.Num' counterpart of the NumericPrelude instances.  All
-- operations act output-wise, using LLVM arithmetic.
instance (A.PseudoRing b, A.Real b, A.IntegerConstant b) => P.Num (T a b) where
   fromInteger = pure . A.fromInteger'
   negate x = mapProc A.neg x
   abs x = mapProc A.abs x
   signum x = mapProc A.signum x
   x + y = zipProcWith A.add x y
   x - y = zipProcWith A.sub x y
   x * y = zipProcWith A.mul x y

-- | Prelude 'P.Fractional' counterpart of the NumericPrelude field instance.
instance (A.Field b, A.Real b, A.RationalConstant b) => P.Fractional (T a b) where
   fromRational = pure . A.fromRational'
   x / y = zipProcWith A.fdiv x y



-- | Lift a step function to pair inputs.  The first component goes through
-- the given step, and the second is attached unchanged to the step's
-- result.
firstNext ::
   (Functor m) =>
   (context -> a -> s -> m (b, s)) ->
   context -> (a, c) -> s -> m ((b, c), s)
firstNext next context (x, passThrough) state =
   fmap attach (next context x state)
   where
      attach (y, state') = ((y, passThrough), state')

-- | Adapt a step with an extra feedback channel @c@ to a step whose state
-- also carries that channel.  The feedback value is read from the state,
-- passed to the step next to the input, and stored back together with the
-- inner state.
loopNext ::
   (Monad m) =>
   (context -> (a,c) -> state -> m ((b,c), state)) ->
   context -> a -> (c, state) -> m (b, (c, state))
loopNext next ctx a0 (c0, s0) =
   next ctx (a0, c0) s0 >>= \((b1, c1), s1) -> return (b1, (c1, s1))

-- | Wiring for a controlled feedback loop.  @forth@ maps
-- ((control, input), feedback) to an output @b@.  @back@ maps
-- (control, output) to the next feedback value.  The result emits the pair
-- (output, new feedback).
feedbackControlledAux ::
   Arrow arrow =>
   arrow ((ctrl,a),c) b ->
   arrow (ctrl,b) c ->
   arrow ((ctrl,a),c) (b,c)
feedbackControlledAux forth back =
   (arr snd &&& back) <<< (arr (fst . fst) &&& forth)


-- | Draw @num@ random pairs.  The first component comes from @gainRange@
-- and the second from @timeRange@, threading the generator from pair to
-- pair.  The final generator is discarded.
-- NOTE(review): judging by the name, these are presumably gain/delay
-- pairs for reverb taps.
reverbParams ::
   (RandomGen g, Random a) =>
   g -> Int -> (a, a) -> (Int, Int) -> [(a, Int)]
reverbParams rnd num gainRange timeRange =
   MS.evalState (replicateM num drawPair) rnd
   where
      drawPair = do
         gain <- MS.state (randomR gainRange)
         time <- MS.state (randomR timeRange)
         return (gain, time)