{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
module Synthesizer.LLVM.CausalParameterized.Functional (
   T,
   lift, fromSignal,
   ($&), (&|&),
   compile,
   compileSignal,
   withArgs, MakeArguments, Arguments, makeArgs,
   AnyArg(..),

   Ground(Ground),
   withGroundArgs, MakeGroundArguments, GroundArguments,
   makeGroundArgs,

   Atom(..), atom,
   withGuidedArgs, MakeGuidedArguments, GuidedArguments, PatternArguments,
   makeGuidedArgs,

   PrepareArguments(PrepareArguments), withPreparedArgs, withPreparedArgs2,
   atomArg, stereoArgs, pairArgs, tripleArgs,
   ) where

import qualified Synthesizer.LLVM.CausalParameterized.ProcessPrivate as CausalP
import qualified Synthesizer.LLVM.Causal.ProcessPrivate as Causal
import qualified Synthesizer.LLVM.Parameterized.Signal as Signal
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.LLVM.Frame.Stereo as Stereo

import qualified LLVM.Extra.MaybeContinuation as Maybe
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Arithmetic as A

import LLVM.Extra.Class (MakeValueTuple, ValueTuple, )
import LLVM.Util.Loop (Phi, )
import LLVM.Core (CodeGenFunction, )
import qualified LLVM.Core as LLVM

import qualified Number.Ratio as Ratio
import qualified Algebra.Transcendental as Trans
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring
import qualified Algebra.Additive as Additive

import qualified Control.Monad.Trans.State as State
import qualified Control.Monad.Trans.Class as Trans
import Control.Monad.Trans.State (StateT, )

import qualified Data.Vault.Lazy as Vault
import Data.Vault.Lazy (Vault, )
import qualified Control.Category as Cat
import Control.Arrow (Arrow, (>>^), (&&&), arr, first, )
import Control.Category (Category, (.), )
import Control.Applicative (Applicative, (<*>), pure, liftA2, )

import Foreign.Storable (Storable, )

import Data.Tuple.HT (fst3, snd3, thd3, )

import qualified System.Unsafe as Unsafe

import Prelude hiding ((.), )


-- | A functional-style causal process with parameter type @p@,
-- input @inp@ and output @out@.
-- Most values are built via 'tagUnique', which attaches a fresh
-- 'Vault' key (see 'tag'), so that a process referenced several times
-- within one step is computed only once in that step.
newtype T p inp out = Cons (Code p inp out)


-- | Existentially quantified implementation of 'T';
-- similar to @CausalP.T p a b@.
-- Judging from 'liftCode' and 'initialize', the fields correspond
-- one-to-one to those of 'CausalP.T', except that the step function
-- runs in @StateT Vault (Maybe.T r c)@, so that outputs of tagged
-- sub-processes can be cached within a single step (see 'tag').
data Code p a b =
   forall context local state ioContext parameters.
      (Storable parameters,
       MakeValueTuple parameters,
       Memory.C (ValueTuple parameters),
       Memory.C context,
       Memory.C state) =>
   Code
      (forall r c.
       (Phi c) =>
       context -> local -> a -> state ->
       StateT Vault (Maybe.T r c) (b, state))
          -- compute next value:
          -- from context, local temporaries, input and current state
          -- produce the output and the successor state
      (forall r.
       CodeGenFunction r local)
          -- allocate temporary variables before a loop
      (forall r.
       ValueTuple parameters ->
       CodeGenFunction r (context, state))
          -- build the context and the initial state from the parameters
      (forall r.
       context -> state ->
       CodeGenFunction r ())
          -- cleanup, given the context and a state
      (p -> IO (ioContext, parameters))
          {- initialization from the IO monad
          This will be run within Unsafe.performIO,
          so it must not perform observable input/output actions!
          -}
      (ioContext -> IO ())
          -- finalization from the IO monad, also run within Unsafe.performIO



-- | Sequential composition of 'Code':
-- in @b . a@ the output of @a@ is fed into @b@.
-- Every component (step, allocation, start, stop,
-- IO creation and deletion) is combined pairwise
-- via the corresponding 'CausalP' helper.
instance Category (Code p) where
   -- identity is the lifted Haskell identity function
   id = arr id
   Code nextB allocaB startB stopB createIOContextB deleteIOContextB .
      Code nextA allocaA startA stopA createIOContextA deleteIOContextA = Code
         -- 'Maybe.onFail' is lifted through the Vault state layer;
         -- the stop actions are handed over as well,
         -- presumably for cleanup when a step terminates early
         -- (NOTE(review): confirm against CausalP.composeNext).
         (CausalP.composeNext
             (State.mapStateT . Maybe.onFail)
             stopA stopB nextA nextB)
         -- allocate the temporaries of both processes as a pair
         (liftA2 (,) allocaA allocaB)
         (CausalP.composeStart startA startB)
         (CausalP.composeStop stopA stopB)
         (CausalP.composeCreate createIOContextA createIOContextB)
         (CausalP.composeDelete deleteIOContextA deleteIOContextB)


-- | Arrow structure for 'Code'.
instance Arrow (Code p) where
   -- A pure function lifted to a process whose context, local storage,
   -- state, IO context and parameters are all @()@.
   -- The context argument (named @_p@) is ignored.
   arr f = Code
      (\ _p () a () -> return (f a, ()))
      (return ())
      (\() -> return ((),()))
      (\() () -> return ())
      (const $ return ((),()))
      (const $ return ())
   -- Apply the process to the first component of a pair;
   -- only the step function is adapted (via 'Causal.firstNext'),
   -- all other components are reused unchanged.
   first (Code next alloca start stop create delete) = Code
      (curry $ Causal.firstNext $ uncurry next) alloca start stop
      create delete


{-
We must not define Category and Arrow instances for T,
because in @osci *** osci@ both copies of @osci@ carry the same tag:
the result of the first copy would be reused for the second one,
although the two copies receive different inputs.

instance Category (T p) where
   id = tagUnique Cat.id
   Cons a . Cons b = tagUnique (a . b)

instance Arrow (T p) where
   arr f = tagUnique $ arr f
   first (Cons a) = tagUnique $ first a
-}

-- | Apply a pure function to every output value;
-- the mapped process receives a fresh tag.
instance Functor (T p inp) where
   fmap g (Cons code) = tagUnique (arr g . code)

-- | 'pure' yields a constant process.
-- '<*>' runs both operands on the same input
-- and applies the resulting function to the resulting argument.
instance Applicative (T p inp) where
   pure x = tagUnique (arr (\_ -> x))
   fs <*> xs = fmap (uncurry id) (fs &|& xs)


-- | Lift a code generator that does not depend on the input
-- into a process that ignores its input.
lift0 :: (forall r. CodeGenFunction r out) -> T p inp out
lift0 gen = lift $ CausalP.mapSimple (\_ -> gen)

-- | Apply a unary code generator to the output of a process.
lift1 :: (forall r. a -> CodeGenFunction r out) -> T p inp a -> T p inp out
lift1 f = (CausalP.mapSimple f $&)

-- | Combine the outputs of two processes on the same input
-- with a binary code generator.
lift2 :: (forall r. a -> b -> CodeGenFunction r out) -> T p inp a -> T p inp b -> T p inp out
lift2 f x y =
   let pair = x &|& y
   in  CausalP.zipWithSimple f $& pair


-- | Element-wise arithmetic on processes, generated as LLVM code.
-- Integer literals become constant processes.
instance (A.PseudoRing b, A.Real b, A.IntegerConstant b) => Num (T p a b) where
   fromInteger n = pure (A.fromInteger' n)
   (+) = lift2 A.add
   (-) = lift2 A.sub
   (*) = lift2 A.mul
   -- Without this definition the class default @negate x = 0 - x@
   -- would be used.  That builds an extra constant process and is
   -- inconsistent with the 'Additive.C' instance, which uses 'A.neg'
   -- (for floating-point types presumably differing in the sign of zero).
   negate = lift1 A.neg
   abs = lift1 A.abs
   signum = lift1 A.signum

-- | Division is generated with 'A.fdiv';
-- rational literals become constant processes.
instance (A.Field b, A.Real b, A.RationalConstant b) => Fractional (T p a b) where
   x / y = lift2 A.fdiv x y
   fromRational r = pure (A.fromRational' r)


-- | Additive group structure (numeric-prelude), lifted element-wise.
instance (A.Additive b) => Additive.C (T p a b) where
   negate = lift1 A.neg
   x - y = lift2 A.sub x y
   x + y = lift2 A.add x y
   zero = pure A.zero

-- | Ring structure (numeric-prelude), lifted element-wise;
-- constants become constant processes.
instance (A.PseudoRing b, A.IntegerConstant b) => Ring.C (T p a b) where
   x * y = lift2 A.mul x y
   fromInteger k = pure (A.fromInteger' k)
   one = pure A.one

-- | Field structure (numeric-prelude): division via 'A.fdiv',
-- rational constants converted through 'Ratio.toRational98'.
instance (A.Field b, A.RationalConstant b) => Field.C (T p a b) where
   x / y = lift2 A.fdiv x y
   fromRational' r = pure $ A.fromRational' (Ratio.toRational98 r)

-- | Roots and rational powers, all expressed via 'A.pow'
-- (except 'sqrt', which uses 'A.sqrt').
instance (A.Transcendental b, A.RationalConstant b) => Algebraic.C (T p a b) where
   sqrt = lift1 A.sqrt
   x ^/ r = lift2 A.pow x (Field.fromRational' r)
   root n x =
      let invN = Field.recip (Ring.fromInteger n)
      in  lift2 A.pow x invN

-- | Transcendental functions mapped to the corresponding 'A' operations.
-- The inverse trigonometric functions are not supported;
-- their results raise an error when evaluated.
instance (A.Transcendental b, A.RationalConstant b) => Trans.C (T p a b) where
   pi = lift0 A.pi
   exp = lift1 A.exp
   log = lift1 A.log
   sin = lift1 A.sin
   cos = lift1 A.cos
   x ** y = lift2 A.pow x y

   -- per the error messages, LLVM provides no intrinsic for these
   asin = \_ -> error "LLVM missing intrinsic: asin"
   acos = \_ -> error "LLVM missing intrinsic: acos"
   atan = \_ -> error "LLVM missing intrinsic: atan"


infixr 0 $&

-- | Postcompose a functional process with a plain causal process:
-- @f $& x@ feeds the output of @x@ into @f@.
-- The composite receives a fresh tag of its own.
($&) :: CausalP.T p b c -> T p a b -> T p a c
f $& Cons code = tagUnique (liftCode f . code)


infixr 3 &|&

-- | Run two processes on the same input and pair their outputs.
(&|&) :: T p a b -> T p a c -> T p a (b,c)
(&|&) (Cons codeB) (Cons codeC) = tagUnique (codeB &&& codeC)


-- | Embed a plain causal process into 'Code'.
-- Its step function never consults the 'Vault', so it is simply lifted
-- into the 'StateT' layer; all other components are reused unchanged.
liftCode :: CausalP.T p inp out -> Code p inp out
liftCode (CausalP.Cons next alloca start stop create delete) =
   Code
      (\ctx loc x s -> Trans.lift $ next ctx loc x s)
      alloca start stop create delete

-- | Turn a plain causal process into a freshly tagged functional process.
lift :: CausalP.T p inp out -> T p inp out
lift process = tagUnique (liftCode process)

-- | Use a parameterized signal as a process;
-- its type is polymorphic in the input, which is therefore not used.
fromSignal :: Signal.T p out -> T p inp out
fromSignal sig = lift (CausalP.fromSignal sig)

-- | Wrap 'Code' so that its output is cached in the step's 'Vault'
-- under the given key.
-- On a cache hit the stored output is returned and the state is passed
-- through unchanged; otherwise the wrapped step runs and its output
-- is stored under the key.
tag :: Vault.Key out -> Code p inp out -> T p inp out
tag key (Code next alloca start stop create delete) =
   Cons $
   Code
      (\ctx loc x s0 ->
         State.gets (Vault.lookup key) >>=
         maybe
            (do result@(y, _) <- next ctx loc x s0
                State.modify (Vault.insert key y)
                return result)
            (\cached -> return (cached, s0)))
      alloca start stop create delete

-- | Debugging replacement for 'tag':
-- the key is ignored, so nothing is cached.
_tag :: Vault.Key out -> Code p inp out -> T p inp out
_tag _key code = Cons code

-- | Tag 'Code' with a freshly created 'Vault' key.
-- 'Unsafe.performIO' is used only to obtain a new key per application,
-- so that separately constructed processes get separate cache entries.
tagUnique :: Code p inp out -> T p inp out
tagUnique code =
   Unsafe.performIO $ do
      key <- Vault.newKey
      return (tag key code)

-- | Convert 'Code' back into a plain causal process.
-- Every step starts from an empty 'Vault',
-- so cached outputs are only shared within one step.
initialize :: Code p inp out -> CausalP.T p inp out
initialize (Code next alloca start stop create delete) =
   CausalP.Cons
      (\ctx loc x s -> flip State.evalStateT Vault.empty $ next ctx loc x s)
      alloca start stop create delete

-- | Turn a functional process description into a plain causal process.
compile :: T p inp out -> CausalP.T p inp out
compile t = case t of Cons code -> initialize code

-- | Compile a process without input into a signal.
compileSignal :: T p () out -> Signal.T p out
compileSignal = CausalP.toSignal . compile


{- |
Using 'withArgs' you can simplify

> let x = F.lift (arr fst)
>     y = F.lift (arr (fst.snd))
>     z = F.lift (arr (snd.snd))
> in  F.compile (f x y z)

to

> withArgs $ \(x,(y,z)) -> f x y z

The input is decomposed recursively according to its type,
as described by the 'MakeArguments' instances.
-}
withArgs ::
   (MakeArguments inp) =>
   (Arguments (T p inp) inp -> T p inp out) -> CausalP.T p inp out
withArgs f = withId (\ident -> f (makeArgs ident))

-- | Apply a process transformer to the identity process and compile the result.
withId :: (T p inp inp -> T p inp out) -> CausalP.T p inp out
withId f = compile (f (lift Cat.id))


-- | The structure of processes (over the functor @f@)
-- that 'makeArgs' produces for the input type @arg@.
type family Arguments (f :: * -> *) (arg :: *)

-- | Decompose a process with a structured output
-- into a structure of processes, as described by 'Arguments'.
class MakeArguments arg where
   makeArgs :: Functor f => f arg -> Arguments f arg


{-
I have considered an Arg type that marks where to stop descending.
This way we could throw away all of these FlexibleContexts instances
and the user could freely choose the granularity of arguments.
However, this does not work so easily,
because we would need a functional dependency from, say,
@(Arg a, Arg b)@ to @(a,b)@.
This is the opposite direction to the dependency we currently use.
The 'AnyArg' type provides a solution in this spirit.
-}
-- LLVM values are atomic: the process is returned unchanged.
type instance Arguments f (LLVM.Value a) = f (LLVM.Value a)
instance MakeArguments (LLVM.Value a) where
   makeArgs x = x

{- |
Consistent with the pair instance.
You may use 'AnyArg' or 'withGuidedArgs'
to stop descending into the stereo channels.
-}
type instance Arguments f (Stereo.T a) = Stereo.T (Arguments f a)
instance (MakeArguments a) => MakeArguments (Stereo.T a) where
   makeArgs stereo = fmap makeArgs (Stereo.sequence stereo)

-- Serial vectors are atomic.
type instance Arguments f (Serial.T v) = f (Serial.T v)
instance MakeArguments (Serial.T v) where
   makeArgs x = x

-- The unit input is atomic.
type instance Arguments f () = f ()
instance MakeArguments () where
   makeArgs x = x

-- Pairs are split into their components, which are decomposed further.
type instance Arguments f (a,b) = (Arguments f a, Arguments f b)
instance (MakeArguments a, MakeArguments b) =>
      MakeArguments (a,b) where
   makeArgs pair =
      (makeArgs (fmap fst pair),
       makeArgs (fmap snd pair))

-- Triples are split likewise.
type instance Arguments f (a,b,c) = (Arguments f a, Arguments f b, Arguments f c)
instance (MakeArguments a, MakeArguments b, MakeArguments c) =>
      MakeArguments (a,b,c) where
   makeArgs triple =
      (makeArgs (fmap fst3 triple),
       makeArgs (fmap snd3 triple),
       makeArgs (fmap thd3 triple))


{- |
You can use this to explicitly stop the decomposition of composite data types.
It might be more comfortable to do this using 'withGuidedArgs'.
-}
newtype AnyArg a = AnyArg {getAnyArg :: a}

-- The wrapper is removed and the process is not decomposed any further.
type instance Arguments f (AnyArg a) = f a
instance MakeArguments (AnyArg a) where
   makeArgs = fmap (\(AnyArg x) -> x)



{- |
This is similar to 'withArgs'
but it requires you to specify the decomposition depth
using 'Ground' constructors in the arguments.
-}
withGroundArgs ::
   (MakeGroundArguments (T p inp) args,
    GroundArguments args ~ inp) =>
   (args -> T p inp out) -> CausalP.T p inp out
withGroundArgs f = withId (\ident -> f (makeGroundArgs ident))


-- | Marks where 'withGroundArgs' stops decomposing:
-- the wrapped process of type @f a@ is passed as a whole.
data Ground f a = Ground (f a)


-- | The input type corresponding to an argument structure
-- of 'withGroundArgs', with the 'Ground' wrappers removed.
type family GroundArguments args

-- | Build the argument structure @args@ from a process
-- whose output is the corresponding input type.
class (Functor f) => MakeGroundArguments f args where
   makeGroundArgs :: f (GroundArguments args) -> args


-- 'Ground' stops the decomposition: the process is wrapped as is.
type instance GroundArguments (Ground f a) = a
instance (Functor f, f ~ g) => MakeGroundArguments f (Ground g a) where
   makeGroundArgs process = Ground process

-- Stereo signals are split into their two channels.
type instance GroundArguments (Stereo.T a) = Stereo.T (GroundArguments a)
instance MakeGroundArguments f a => MakeGroundArguments f (Stereo.T a) where
   makeGroundArgs stereo =
      Stereo.cons
         (makeGroundArgs (fmap Stereo.left  stereo))
         (makeGroundArgs (fmap Stereo.right stereo))

-- The unit pattern carries no process.
type instance GroundArguments () = ()
instance (Functor f) => MakeGroundArguments f () where
   makeGroundArgs = const ()


-- Pairs are split into their components.
type instance
   GroundArguments (a,b) =
      (GroundArguments a, GroundArguments b)
instance
   (MakeGroundArguments f a, MakeGroundArguments f b) =>
      MakeGroundArguments f (a,b) where
   makeGroundArgs pair =
      (makeGroundArgs (fmap fst pair),
       makeGroundArgs (fmap snd pair))

-- Triples are split likewise.
type instance
   GroundArguments (a,b,c) =
      (GroundArguments a, GroundArguments b, GroundArguments c)
instance
   (MakeGroundArguments f a, MakeGroundArguments f b, MakeGroundArguments f c) =>
      MakeGroundArguments f (a,b,c) where
   makeGroundArgs triple =
      (makeGroundArgs (fmap fst3 triple),
       makeGroundArgs (fmap snd3 triple),
       makeGroundArgs (fmap thd3 triple))



{- |
This is similar to 'withArgs'
but it lets you specify the decomposition depth using a pattern.
-}
withGuidedArgs ::
   (MakeGuidedArguments pat, PatternArguments pat ~ inp) =>
   pat ->
   (GuidedArguments (T p inp) pat -> T p inp out) -> CausalP.T p inp out
withGuidedArgs pat f = withId (\ident -> f (makeGuidedArgs pat ident))


-- | Pattern for 'withGuidedArgs' that stops the decomposition:
-- the process for @a@ is passed as a whole.
data Atom a = Atom

-- | Shorthand for 'Atom'; the type @a@ is determined by the context.
atom :: Atom a
atom = Atom


-- | The structure of processes (over the functor @f@)
-- that 'makeGuidedArgs' produces for the pattern @pat@.
type family GuidedArguments (f :: * -> *) pat
-- | The input type described by the pattern @pat@.
type family PatternArguments pat

-- | Decompose a process according to a pattern value.
class MakeGuidedArguments pat where
   makeGuidedArgs ::
      Functor f =>
      pat -> f (PatternArguments pat) -> GuidedArguments f pat


-- 'Atom' stops the decomposition: the process is returned unchanged.
type instance GuidedArguments f (Atom a) = f a
type instance PatternArguments (Atom a) = a
instance MakeGuidedArguments (Atom a) where
   makeGuidedArgs Atom = \process -> process

-- Stereo patterns split both channels, each guided by its own sub-pattern.
type instance GuidedArguments f (Stereo.T a) = Stereo.T (GuidedArguments f a)
type instance PatternArguments (Stereo.T a) = Stereo.T (PatternArguments a)
instance MakeGuidedArguments a => MakeGuidedArguments (Stereo.T a) where
   makeGuidedArgs pat stereo =
      Stereo.cons
         (makeGuidedArgs (Stereo.left  pat) (fmap Stereo.left  stereo))
         (makeGuidedArgs (Stereo.right pat) (fmap Stereo.right stereo))

-- For the unit pattern the process is returned unchanged.
type instance GuidedArguments f () = f ()
type instance PatternArguments () = ()
instance MakeGuidedArguments () where
   makeGuidedArgs () = \process -> process

-- Pair patterns split pairs; each component is guided separately.
type instance
   GuidedArguments f (a,b) =
      (GuidedArguments f a, GuidedArguments f b)
type instance
   PatternArguments (a,b) =
      (PatternArguments a, PatternArguments b)
instance (MakeGuidedArguments a, MakeGuidedArguments b) =>
      MakeGuidedArguments (a,b) where
   makeGuidedArgs (patA, patB) pair =
      (makeGuidedArgs patA (fmap fst pair),
       makeGuidedArgs patB (fmap snd pair))

-- Triple patterns split triples likewise.
type instance
   GuidedArguments f (a,b,c) =
      (GuidedArguments f a, GuidedArguments f b, GuidedArguments f c)
type instance
   PatternArguments (a,b,c) =
      (PatternArguments a, PatternArguments b, PatternArguments c)
instance
   (MakeGuidedArguments a, MakeGuidedArguments b, MakeGuidedArguments c) =>
      MakeGuidedArguments (a,b,c) where
   makeGuidedArgs (patA, patB, patC) triple =
      (makeGuidedArgs patA (fmap fst3 triple),
       makeGuidedArgs patB (fmap snd3 triple),
       makeGuidedArgs patC (fmap thd3 triple))



{- |
Alternative to 'withGuidedArgs'.
Patterns are ordinary values built with 'atomArg', 'stereoArgs',
'pairArgs' and 'tripleArgs', so this way of pattern construction
needs no type families; it is even Haskell 98.
-}
withPreparedArgs ::
   PrepareArguments (T p inp) inp a ->
   (a -> T p inp out) -> CausalP.T p inp out
withPreparedArgs (PrepareArguments split) body =
   withId (\ident -> body (split ident))

-- | Like 'withPreparedArgs' for a pair input,
-- with a separate pattern for each component
-- and a curried process function.
withPreparedArgs2 ::
   PrepareArguments (T p (inp0, inp1)) inp0 a ->
   PrepareArguments (T p (inp0, inp1)) inp1 b ->
   (a -> b -> T p (inp0, inp1) out) ->
   CausalP.T p (inp0, inp1) out
withPreparedArgs2 prepA prepB =
   withPreparedArgs (pairArgs prepA prepB) . uncurry

-- | A pattern for 'withPreparedArgs': a function that splits
-- a process with output @merged@ (over the functor @f@)
-- into the structure @separated@.
newtype PrepareArguments f merged separated =
   PrepareArguments (f merged -> separated)

-- | Pattern that does not decompose: the process is passed unchanged.
atomArg :: PrepareArguments f a (f a)
atomArg = PrepareArguments (\process -> process)

-- | Pattern for stereo processes: split into the two channels,
-- then prepare each channel with the given pattern.
stereoArgs ::
   (Functor f) =>
   PrepareArguments f a b ->
   PrepareArguments f (Stereo.T a) (Stereo.T b)
stereoArgs (PrepareArguments prepareChannel) =
   PrepareArguments (\stereo -> fmap prepareChannel (Stereo.sequence stereo))

-- | Pattern for pairs: each component is prepared with its own pattern.
pairArgs ::
   (Functor f) =>
   PrepareArguments f a0 b0 ->
   PrepareArguments f a1 b1 ->
   PrepareArguments f (a0,a1) (b0,b1)
pairArgs (PrepareArguments prepFst) (PrepareArguments prepSnd) =
   PrepareArguments $ \merged ->
      (prepFst (fmap fst merged),
       prepSnd (fmap snd merged))

-- | Pattern for triples: each component is prepared with its own pattern.
tripleArgs ::
   (Functor f) =>
   PrepareArguments f a0 b0 ->
   PrepareArguments f a1 b1 ->
   PrepareArguments f a2 b2 ->
   PrepareArguments f (a0,a1,a2) (b0,b1,b2)
tripleArgs (PrepareArguments prep0) (PrepareArguments prep1) (PrepareArguments prep2) =
   PrepareArguments $ \merged ->
      (prep0 (fmap fst3 merged),
       prep1 (fmap snd3 merged),
       prep2 (fmap thd3 merged))