{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
module Synthesizer.LLVM.Simple.SignalPrivate where

import qualified Synthesizer.LLVM.Storable.ChunkIterator as ChunkIt

import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.MaybeContinuation as MaybeCont
import qualified LLVM.Extra.Either as Either
import qualified LLVM.Extra.Maybe as Maybe
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Extra.Control (ifThen, )
import LLVM.Extra.Class (MakeValueTuple, ValueTuple, Undefined, )

import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi, )
import LLVM.Core (CodeGenFunction, Value, valueOf, )

import Control.Monad (liftM2, )
import Control.Applicative (Applicative, pure, liftA2, (<*>), (<$>), )

import Foreign.Storable.Tuple ()
import Foreign.Storable (Storable, )
import Foreign.StablePtr (StablePtr, )
import Foreign.Ptr (Ptr, nullPtr, )

import Data.Monoid (Monoid, mempty, mappend, )
import Data.Semigroup (Semigroup, (<>), )
import Data.Word (Word32, )

import qualified Number.Ratio as Ratio
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring
import qualified Algebra.Additive as Additive

import NumericPrelude.Base hiding (and, iterate, map, zip, zipWith, )

import qualified Prelude as P


{-
We need the forall quantification for the @r@ parameter of 'CodeGenFunction'.
This type parameter will be unified with the result type of the final function.
Since one piece of code can be used in multiple functions
we cannot yet fix the type @r@ here.


We might avoid code duplication with Causal.Process by defining

> newtype T a = Cons (Causal.T () a)


In earlier versions the createIOContext method created only an ioContext
that was directly used to construct code for 'start' and 'next'.
This had the advantage that we did not need to pass
something via the Memory.C interface to the function.
However, creating both an ioContext and a low-level parameter has two advantages:
First, we can design Causal.Process such that a process
can be applied to multiple signals without recompilation.
Second, we can lift simple signals and processes to their parameterized counterparts.
-}
{- |
A signal generator.
It consists of LLVM code-generation pieces (the first three 'Cons' fields)
and IO actions that prepare and release its parameters (the last two fields).
The types @state@, @local@, @ioContext@ and @parameters@
are existentially quantified,
so signals with different internal types share the type @T a@.
-}
data T a =
   forall state local ioContext parameters.
      (Storable parameters,
       MakeValueTuple parameters,
       Memory.C (ValueTuple parameters),
       Memory.C state) =>
      Cons (forall r c.
            (Phi c) =>
            ValueTuple parameters -> local ->
            state -> MaybeCont.T r c (a, state))
               -- compute next value:
               -- receives the low-level parameters,
               -- the temporaries produced by the allocation action (next field)
               -- and the current state;
               -- yields an element together with the successor state,
               -- or ends the signal via MaybeCont
           (forall r.
            CodeGenFunction r local)
               -- allocate temporary variables before a loop
           (forall r.
            ValueTuple parameters ->
            CodeGenFunction r state)
               -- initial state, computed from the low-level parameters
           (IO (ioContext, parameters))
               {- initialization from IO monad
               This will be run within Unsafe.performIO,
               so no observable In/Out actions please!
               It returns a context that is later handed to the finalizer
               and the parameters whose value tuple is passed to
               the step function and the initial-state action
               (presumably via their Storable representation; TODO confirm).
               -}
           (ioContext -> IO ())
               -- finalization from IO monad, also run within Unsafe.performIO;
               -- receives the context created by the initialization above


{- |
The code-generation core of a signal, as exposed to 'alter'.
It consists of a step function that receives a @context@ and the current state,
an action computing the initial state from @initState@,
and a projection from the internal state to @exitState@.
The internal @state@ type is existentially quantified.
For 'T', 'alter' uses the pair of low-level parameters and locals as @context@,
the low-level parameters as @initState@,
and the internal state itself as @exitState@ (projection 'id').
-}
data Core context initState exitState a =
   forall state.
      (Memory.C state) =>
      Core (forall r c.
            (Phi c) =>
            context ->
            state -> MaybeCont.T r c (a, state))
               -- compute next value, or end the signal via MaybeCont
           (forall r.
            initState ->
            CodeGenFunction r state)
               -- initial state
           (state -> exitState)
               -- extract final state for cleanup
               -- (the 'T' instance of 'alter' discards this field)


{- |
Signal types that can be constructed from code-generating step functions
and transformed through their 'Core'.
-}
class Applicative signal => C signal where
   -- | Build a signal from a step function and an initial-state action.
   -- The step function may end the signal via MaybeCont.
   -- The default implementation delegates to 'simpleAlloca'
   -- with @()@ as the local value.
   simple ::
      (Memory.C state) =>
      (forall r c. state -> MaybeCont.T r c (a, state)) ->
      (forall r. CodeGenFunction r state) ->
      signal a
   simple next start =
      simpleAlloca (\() state -> next state) (return ()) start

   -- | Like 'simple', but the step function additionally receives
   -- the value produced by the second argument.
   -- For 'T' this second argument becomes the allocation field of 'Cons',
   -- which allocates temporary variables before a loop.
   simpleAlloca ::
      (Memory.C state) =>
      (forall r c. local -> state -> MaybeCont.T r c (a, state)) ->
      (forall r. CodeGenFunction r local) ->
      (forall r. CodeGenFunction r state) ->
      signal a

   -- | Transform the 'Core' of a signal.
   -- The transformation must work for all @contextLocal@, @initState@
   -- and @exitState@, so it cannot depend on their concrete types.
   alter ::
      (forall contextLocal initState exitState.
          Core contextLocal initState exitState a0 ->
          Core contextLocal initState exitState a1) ->
      signal a0 -> signal a1

-- | 'T' has no low-level parameters and no IO context of its own
-- for signals built by 'simpleAlloca';
-- 'alter' exposes parameters and locals together as the 'Core' context.
instance C T where
   alter transform (Cons next0 alloca0 start0 create delete) =
      -- parameters and locals are bundled into one context for 'Core'
      case transform (Core (uncurry next0) start0 id) of
         Core next1 start1 _exitState ->
            Cons (curry next1) alloca0 start1 create delete

   simpleAlloca stepFn allocaFn startFn =
      Cons
         (\() loc -> stepFn loc)
         allocaFn
         (\_params -> startFn)
         (return ((), ()))
         (\_ioContext -> return ())


{- |
Apply a code-generating function to every element of a signal.
The state handling and the termination of the input signal are kept as they are.
-}
map ::
   (C signal) =>
   (forall r. a -> CodeGenFunction r b) -> signal a -> signal b
map f =
   alter $ \(Core next start stop) ->
      Core
         (\ctx s0 ->
            next ctx s0 >>= \(x, s1) ->
            MaybeCont.lift (f x) >>= \y ->
            return (y, s1))
         start
         stop

{- |
Combine two signals element-wise with a code-generating function.
The signals are paired first ('liftA2' with @(,)@),
then the uncurried function is applied via 'map'.
-}
zipWith ::
   (C signal) =>
   (forall r. a -> b -> CodeGenFunction r c) ->
   signal a -> signal b -> signal c
zipWith f xs ys =
   map (uncurry f) (liftA2 (,) xs ys)


-- | Transpose a pair of pairs:
-- group the first components and the second components together.
-- Both argument pairs are matched strictly, first the left one.
zipPair :: (a,b) -> (c,d) -> ((a,c),(b,d))
zipPair ab cd =
   case (ab, cd) of
      ((a, b), (c, d)) -> ((a, c), (b, d))

{- |
Pair two signals element-wise.
The parameters, locals, states and IO contexts of both signals
are combined into pairs.
Each step sequences the step function of the first signal
before that of the second one (via 'liftM2').
-}
zip :: T a -> T b -> T (a,b)
zip (Cons stepL allocL startL createL deleteL)
    (Cons stepR allocR startR createR deleteR) =
   Cons
      (\(paramL, paramR) (localL, localR) (stateL, stateR) ->
         liftM2 zipPair
            (stepL paramL localL stateL)
            (stepR paramR localR stateR))
      (do tmpL <- allocL
          tmpR <- allocR
          return (tmpL, tmpR))
      (combineStart startL startR)
      (combineCreate createL createR)
      (combineDelete deleteL deleteR)

-- | Run two start actions on the respective halves of a parameter pair,
-- first the one for @paramA@, then the one for @paramB@,
-- and pair their results.
combineStart ::
   Monad m =>
   (paramA -> m stateA) ->
   (paramB -> m stateB) ->
   (paramA, paramB) -> m (stateA, stateB)
combineStart initA initB (pA, pB) = do
   sA <- initA pA
   sB <- initB pB
   return (sA, sB)

-- | Run two context-creating actions, first A then B,
-- and regroup their results into a pair of IO contexts
-- and a pair of parameters (via 'zipPair').
combineCreate ::
   Monad m =>
   m (ioContextA, contextA) ->
   m (ioContextB, contextB) ->
   m ((ioContextA, ioContextB), (contextA, contextB))
combineCreate createA createB = do
   resultA <- createA
   resultB <- createB
   return (zipPair resultA resultB)

-- | Release a pair of contexts: first the left one, then the right one.
combineDelete :: (Monad m) => (ca -> m ()) -> (cb -> m ()) -> (ca, cb) -> m ()
combineDelete releaseA releaseB (contextA, contextB) = do
   releaseA contextA
   releaseB contextB


-- | Apply a pure function to every element; implemented via 'map'.
instance Functor T where
   fmap f xs = map (\x -> return (f x)) xs

{- |
ZipList semantics:
the n-th function of the first signal is applied
to the n-th value of the second signal (via 'zip').
'pure' yields a signal that repeats its value forever.
-}
instance Applicative T where
   fs <*> xs = fmap (\ ~(g, x) -> g x) (zip fs xs)
   pure x = simple (\() -> return (x, ())) (return ())

-- | Element-wise addition and subtraction ('zipWith'), negation ('map');
-- 'zero' is the constant zero signal.
instance (A.Additive a) => Additive.C (T a) where
   (+) x y = zipWith A.add x y
   (-) x y = zipWith A.sub x y
   negate x = map A.neg x
   zero = pure A.zero

-- | Element-wise multiplication; integer literals become constant signals.
instance (A.PseudoRing a, A.IntegerConstant a) => Ring.C (T a) where
   (*) x y = zipWith A.mul x y
   fromInteger = pure . A.fromInteger'
   one = pure A.one

-- | Element-wise division; rational literals become constant signals.
instance (A.Field a, A.RationalConstant a) => Field.C (T a) where
   (/) x y = zipWith A.fdiv x y
   fromRational' = pure . A.fromRational' . Ratio.toRational98


-- | Prelude numeric operations, applied element-wise;
-- integer literals become constant signals.
instance (A.PseudoRing a, A.Real a, A.IntegerConstant a) => P.Num (T a) where
   (+) x y = zipWith A.add x y
   (-) x y = zipWith A.sub x y
   (*) x y = zipWith A.mul x y
   negate x = map A.neg x
   abs x = map A.abs x
   signum x = map A.signum x
   fromInteger = pure . A.fromInteger'

-- | Prelude element-wise division; rational literals become constant signals.
instance (A.Field a, A.Real a, A.RationalConstant a) => P.Fractional (T a) where
   (/) x y = zipWith A.fdiv x y
   fromRational = pure . A.fromRational'



-- | The signal that ends at its first step, without producing any value.
empty :: (C signal) => signal a
empty = simple (\_ -> MaybeCont.nothing) (return ())

{- |
Appending many signals is inefficient,
since in cascadingly appended signals the parts are counted in a unary way.
Concatenating infinitely many signals is impossible.
If you want to concatenate a lot of signals,
please render them to lazy storable vectors first.
-}
{-
We might save a little space by using a union
for the states of the first and the second signal generator.
If the concatenated generators allocate memory,
we save some memory by calling @startB@
only after the first generator has finished,
as the implementation below does.
However, for correct deallocation
we would need to track which of the @start@ blocks
have been executed so far.
This in turn might be difficult in connection with the garbage collector.
-}
-- Implementation notes:
-- The state is an Either.T that holds the state of the first generator (left)
-- or of the second generator (right).
-- Initially only @startA@ is run.
-- Each step first runs @nextA@ while the state is left.
-- When the first generator ends, @startB@ is run
-- and the state switches to the right alternative,
-- so the second generator emits its first element within the same step.
-- NOTE(review): the Phi and Undefined constraints on the element type
-- are presumably required by Either.run and Maybe.just on element values;
-- confirm against LLVM.Extra.
append :: (Phi a, Undefined a) => T a -> T a -> T a
append
      (Cons nextA allocaA startA createIOContextA deleteIOContextA)
      (Cons nextB allocaB startB createIOContextB deleteIOContextB) =
   Cons
      (\(parameterA, parameterB) (localA, localB) es0 ->
            MaybeCont.fromMaybe $ do
         -- Phase 1: if still in the first generator, advance it.
         -- On its end, start the second generator (right alternative);
         -- otherwise keep the produced element with the new state (left).
         es1 <-
            Either.run es0
               (\sa0 ->
                  MaybeCont.resolve
                     (nextA parameterA localA sa0)
                     (fmap Either.right $ startB parameterB)
                     (\(a1,sa1) -> return (Either.left (a1, sa1))))
               (return . Either.right)

         -- Phase 2: emit the element of the first generator,
         -- or advance the second generator, whose end ends the whole signal.
         Either.run es1
            (\(a1,s1) -> return (Maybe.just (a1, Either.left s1)))
            (\sb0 ->
               MaybeCont.toMaybe $
               fmap (\(b,sb1) -> (b, Either.right sb1)) $
               nextB parameterB localB sb0))
      -- temporaries of both generators are allocated up front
      (liftM2 (,) allocaA allocaB)
      -- initial state: only the first generator is started
      (\(parameterA, _parameterB) -> Either.left <$> startA parameterA)
      (combineCreate createIOContextA createIOContextB)
      (combineDelete deleteIOContextA deleteIOContextB)

-- | Concatenation of signals via 'append'.
instance (Phi a, Undefined a) => Semigroup (T a) where
   (<>) = append

-- | 'mempty' is the signal that ends immediately ('empty').
-- 'mappend' is defined canonically as '(<>)'.
-- This keeps it in sync with the 'Semigroup' instance
-- and avoids GHC's noncanonical-monoid-instances warning.
instance (Phi a, Undefined a) => Monoid (T a) where
   mempty = empty
   mappend = (<>)



{- |
One step of a chunk-reading signal.
It calls the callback 'ChunkIt.nextCallBack', declared under @callbackName@,
with the stable pointer to the chunk iterator and with @lenPtr@.
Afterwards it loads the chunk length from @lenPtr@.
A null buffer pointer ends the signal;
otherwise the buffer pointer and the length are emitted.
The signal state is @()@.
-}
storableVectorNextChunk ::
   (Phi c, MakeValueTuple a, ValueTuple a ~ value,
    Memory.C value, Memory.Struct value ~ struct) =>
   String ->
   Value (StablePtr (ChunkIt.T a)) -> Value (Ptr Word32) -> () ->
   MaybeCont.T r c ((Value (Ptr struct), Value Word32), ())
storableVectorNextChunk callbackName stable lenPtr () =
   MaybeCont.fromBool $ do
      callback <-
         LLVM.staticNamedFunction callbackName ChunkIt.nextCallBack
      chunkStart <- LLVM.call callback stable lenPtr
      chunkSize <- LLVM.load lenPtr
      nonNull <- A.cmp LLVM.CmpNE chunkStart (valueOf nullPtr)
      return (nonNull, ((chunkStart, chunkSize), ()))

{- |
Turn a signal of chunks, each given as a buffer pointer and an element count,
into the signal of the individual elements.
A new chunk is requested from the input signal
whenever the remaining count of the current chunk is zero.
The output ends when the input signal ends.

NOTE(review): after a new chunk has been fetched,
its first element is loaded without re-checking the new count.
A zero-length chunk would therefore be read past its end,
and the Word32 counter would wrap around when decremented.
This presumably relies on the chunk producer never delivering empty chunks;
TODO confirm.
-}
flattenChunks ::
   (C signal, Memory.C value, Memory.Struct value ~ struct) =>
   signal (Value (Ptr struct), Value Word32) -> signal value
flattenChunks = alter $ \(Core next start stop) ->
   Core
      (\context ((buffer0,length0), state0) -> do
         -- Keep the current chunk while elements remain,
         -- otherwise fetch the next chunk (its end ends this signal, too).
         ((buffer1,length1), state1) <- MaybeCont.fromBool $ do
            needNext <- A.cmp LLVM.CmpEQ length0 A.zero
            ifThen needNext
               (valueOf True, ((buffer0,length0), state0))
               (MaybeCont.toBool $ next context state0)
         -- Emit the element at the buffer pointer,
         -- step the pointer to the next element and decrement the count.
         MaybeCont.lift $ do
            x <- Memory.load buffer1
            buffer2 <- A.advanceArrayElementPtr buffer1
            length2 <- A.dec length1
            return (x, ((buffer2,length2), state1)))
      -- Start with an empty (null, 0) chunk, so the first step fetches a chunk.
      (\p -> (,) (valueOf nullPtr, A.zero) <$> start p)
      (stop . snd)

{- |
Signal that emits, at every step, the pointer created by
the local-allocation argument of 'simpleAlloca' ('LLVM.alloca').
The step function passes this pointer through unchanged
and never ends the signal.
-}
alloca :: (C signal, LLVM.IsSized a) => signal (LLVM.Value (Ptr a))
alloca = simpleAlloca yieldPtr LLVM.alloca (return ())
   where
      yieldPtr ptr () = return (ptr, ())