{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Array.Knead.Expression where

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Monad as LMonad
import qualified LLVM.Core as LLVM
import LLVM.Extra.Multi.Value (PatternTuple, Decomposed, Atom, atom, )

import qualified Control.Monad as Monad
import qualified Data.Tuple.HT as Tuple

import Prelude hiding (fst, snd, min, max, zip, unzip, zip3, unzip3, )


{- |
An embedded expression.
It wraps a code generator that works for any LLVM function result type @r@
and yields a 'MultiValue.T'.
Generated code is not shared:
an expression that is used twice (e.g. via 'unzip')
has its generator run, and thus its code emitted, twice.
-}
newtype Exp a =
   Exp {
      unExp :: forall r. LLVM.CodeGenFunction r (MultiValue.T a)
   }


{- |
Wrappers around 'MultiValue.T' to which pure functions on
'MultiValue.T' can be lifted, for up to four arguments.
-}
class Value val where
   -- | Embed a single value.
   lift0 :: MultiValue.T a -> val a
   -- | Lift a unary function.
   lift1 ::
      (MultiValue.T a -> MultiValue.T b) ->
      (val a -> val b)
   -- | Lift a binary function.
   lift2 ::
      (MultiValue.T a -> MultiValue.T b -> MultiValue.T c) ->
      (val a -> val b -> val c)
   -- | Lift a ternary function.
   lift3 ::
      (MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
       MultiValue.T d) ->
      (val a -> val b -> val c -> val d)
   -- | Lift a function of four arguments.
   lift4 ::
      (MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
       MultiValue.T d -> MultiValue.T e) ->
      (val a -> val b -> val c -> val d -> val e)

-- On plain multi-values, lifting is just function application.
instance Value MultiValue.T where
   lift0 x = x
   lift1 f = f
   lift2 f = f
   lift3 f = f
   lift4 f = f

-- On expressions, the argument generators are run from left to right
-- and the pure function is applied to their results.
instance Value Exp where
   lift0 x = Exp (return x)
   lift1 f (Exp a) = Exp (Monad.liftM f a)
   lift2 f (Exp a) (Exp b) =
      Exp (Monad.liftM f a `Monad.ap` b)
   lift3 f (Exp a) (Exp b) (Exp c) =
      Exp (Monad.liftM f a `Monad.ap` b `Monad.ap` c)
   lift4 f (Exp a) (Exp b) (Exp c) (Exp d) =
      Exp (Monad.liftM f a `Monad.ap` b `Monad.ap` c `Monad.ap` d)


{- |
Lift a code-generating unary function to expressions.
The argument's code is generated first, and then @f@ is applied to its result.
-}
liftM ::
   (forall r.
    MultiValue.T a ->
    LLVM.CodeGenFunction r (MultiValue.T b)) ->
   (Exp a -> Exp b)
liftM f (Exp a) =
   Exp (do
      x <- a
      f x)

{- |
Lift a code-generating binary function to expressions.
The argument generators and @f@ are combined with 'LMonad.liftR2'.
-}
liftM2 ::
   (forall r.
    MultiValue.T a -> MultiValue.T b ->
    LLVM.CodeGenFunction r (MultiValue.T c)) ->
   (Exp a -> Exp b -> Exp c)
liftM2 f ea eb = Exp (LMonad.liftR2 f (unExp ea) (unExp eb))

{- |
Lift a code-generating ternary function to expressions.
The argument generators and @f@ are combined with 'LMonad.liftR3'.
-}
liftM3 ::
   (forall r.
    MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
    LLVM.CodeGenFunction r (MultiValue.T d)) ->
   (Exp a -> Exp b -> Exp c -> Exp d)
liftM3 f ea eb ec = Exp (LMonad.liftR3 f (unExp ea) (unExp eb) (unExp ec))


{- |
Turn a unary function on expressions back into a code generator
on plain multi-values.
The argument is wrapped as a constant expression, and the code of the
resulting expression is generated.
-}
unliftM1 ::
   (Exp a -> Exp b) ->
   MultiValue.T a -> LLVM.CodeGenFunction r (MultiValue.T b)
unliftM1 f ix = unExp (f (constant ix))
   where constant x = Exp (return x)

{- |
Binary variant of 'unliftM1'.
-}
unliftM2 ::
   (Exp a -> Exp b -> Exp c) ->
   MultiValue.T a -> MultiValue.T b ->
   LLVM.CodeGenFunction r (MultiValue.T c)
unliftM2 f ix jx = unExp (f (constant ix) (constant jx))
   where constant x = Exp (return x)

{- |
Ternary variant of 'unliftM1'.
-}
unliftM3 ::
   (Exp a -> Exp b -> Exp c -> Exp d) ->
   MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
   LLVM.CodeGenFunction r (MultiValue.T d)
unliftM3 f ix jx kx = unExp (f (constant ix) (constant jx) (constant kx))
   where constant x = Exp (return x)



{- | The smaller of two expressions, computed with 'A.min'. -}
min :: (MultiValue.Real a) => Exp a -> Exp a -> Exp a
min x y = liftM2 A.min x y

{- | The larger of two expressions, computed with 'A.max'. -}
max :: (MultiValue.Real a) => Exp a -> Exp a -> Exp a
max x y = liftM2 A.max x y


{- | Combine two values into a pair with 'MultiValue.zip'. -}
zip :: (Value val) => val a -> val b -> val (a, b)
zip a b = lift2 MultiValue.zip a b

{- | Combine three values into a triple with 'MultiValue.zip3'. -}
zip3 :: (Value val) => val a -> val b -> val c -> val (a, b, c)
zip3 a b c = lift3 MultiValue.zip3 a b c

{- | Combine four values into a quadruple with 'MultiValue.zip4'. -}
zip4 :: (Value val) => val a -> val b -> val c -> val d -> val (a, b, c, d)
zip4 a b c d = lift4 MultiValue.zip4 a b c d

{- |
Split a pair into its components.
For 'Exp', each component re-runs the generator of the whole pair.
-}
unzip :: (Value val) => val (a, b) -> (val a, val b)
unzip ab = (aPart, bPart)
   where
      aPart = lift1 MultiValue.fst ab
      bPart = lift1 MultiValue.snd ab

{- | Split a triple into its components. -}
unzip3 :: (Value val) => val (a, b, c) -> (val a, val b, val c)
unzip3 abc = (aPart, bPart, cPart)
   where
      aPart = lift1 MultiValue.fst3 abc
      bPart = lift1 MultiValue.snd3 abc
      cPart = lift1 MultiValue.thd3 abc

{- |
Split a quadruple into its components
by matching directly on the tuple representation inside 'MultiValue.Cons'.
-}
unzip4 :: (Value val) => val (a, b, c, d) -> (val a, val b, val c, val d)
unzip4 abcd = (aPart, bPart, cPart, dPart)
   where
      aPart = lift1 (\(MultiValue.Cons (x,_,_,_)) -> MultiValue.Cons x) abcd
      bPart = lift1 (\(MultiValue.Cons (_,x,_,_)) -> MultiValue.Cons x) abcd
      cPart = lift1 (\(MultiValue.Cons (_,_,x,_)) -> MultiValue.Cons x) abcd
      dPart = lift1 (\(MultiValue.Cons (_,_,_,x)) -> MultiValue.Cons x) abcd

{- | First component of a pair. -}
fst :: (Value val) => val (a, b) -> val a
fst ab = lift1 MultiValue.fst ab

{- | Second component of a pair. -}
snd :: (Value val) => val (a, b) -> val b
snd ab = lift1 MultiValue.snd ab

{- | Transform the first component of a pair expression. -}
mapFst :: (Exp a -> Exp b) -> Exp (a, c) -> Exp (b, c)
mapFst f = modify (atom, atom) step
   where step (a, c) = (f a, c)

{- | Transform the second component of a pair expression. -}
mapSnd :: (Exp b -> Exp c) -> Exp (a, b) -> Exp (a, c)
mapSnd f = modify (atom, atom) step
   where step (a, b) = (a, f b)

{- | Exchange the components of a pair. -}
swap :: (Value val) => val (a, b) -> val (b, a)
swap ab = lift1 MultiValue.swap ab


{- |
'MultiValue.modify' lifted to any 'Value':
the components selected by the pattern are passed to @f@,
and @f@'s result is composed into a single value again.
-}
modifyMultiValue ::
   (Value val,
    MultiValue.Compose a,
    MultiValue.Decompose pattern,
    MultiValue.PatternTuple pattern ~ tuple) =>
   pattern ->
   (Decomposed MultiValue.T pattern -> a) ->
   val tuple -> val (MultiValue.Composed a)
modifyMultiValue p f = lift1 (MultiValue.modify p f)

{- |
Two-argument variant of 'modifyMultiValue',
lifting 'MultiValue.modify2'.
-}
modifyMultiValue2 ::
   (Value val,
    MultiValue.Compose a,
    MultiValue.Decompose patternA,
    MultiValue.Decompose patternB,
    MultiValue.PatternTuple patternA ~ tupleA,
    MultiValue.PatternTuple patternB ~ tupleB) =>
   patternA ->
   patternB ->
   (Decomposed MultiValue.T patternA ->
    Decomposed MultiValue.T patternB -> a) ->
   val tupleA -> val tupleB -> val (MultiValue.Composed a)
modifyMultiValue2 pa pb f = lift2 (MultiValue.modify2 pa pb f)

{- |
Like 'modifyMultiValue', but @f@ may generate code.
It is built from 'MultiValue.modifyF' and 'liftM', so it is restricted to 'Exp'.
-}
modifyMultiValueM ::
   (MultiValue.Compose a,
    MultiValue.Decompose pattern,
    MultiValue.PatternTuple pattern ~ tuple) =>
   pattern ->
   (forall r.
    Decomposed MultiValue.T pattern ->
    LLVM.CodeGenFunction r a) ->
   Exp tuple -> Exp (MultiValue.Composed a)
modifyMultiValueM p f x = liftM (MultiValue.modifyF p f) x

{- |
Two-argument variant of 'modifyMultiValueM',
built from 'MultiValue.modifyF2' and 'liftM2'.
-}
modifyMultiValueM2 ::
   (MultiValue.Compose a,
    MultiValue.Decompose patternA,
    MultiValue.Decompose patternB,
    MultiValue.PatternTuple patternA ~ tupleA,
    MultiValue.PatternTuple patternB ~ tupleB) =>
   patternA ->
   patternB ->
   (forall r.
    Decomposed MultiValue.T patternA ->
    Decomposed MultiValue.T patternB ->
    LLVM.CodeGenFunction r a) ->
   Exp tupleA -> Exp tupleB -> Exp (MultiValue.Composed a)
modifyMultiValueM2 pa pb f x y = liftM2 (MultiValue.modifyF2 pa pb f) x y


{- |
(Possibly nested) tuples of expressions that can be combined
into one expression of the corresponding tuple type.
-}
class Compose multituple where
   type Composed multituple
   {- |
   Combine the components into a single expression;
   a nested 'zip'.
   -}
   compose :: multituple -> Exp (Composed multituple)

{- |
Patterns that describe how to split an expression of tuple type
into (possibly nested) tuples of component expressions.
-}
class
   (Composed (Decomposed Exp pattern) ~ PatternTuple pattern) =>
      Decompose pattern where
   {- |
   Split an expression according to the pattern.
   Analogous to 'MultiValue.decompose'.
   -}
   decompose ::
      pattern ->
      Exp (PatternTuple pattern) ->
      Decomposed Exp pattern


{- |
Apply a function to the components of a tuple expression.
The expression is split according to the pattern (see 'decompose'),
the function is applied to the resulting tuple of component expressions,
and its result is combined again with 'compose'.

Analogous to 'MultiValue.modify'.
-}
modify ::
   (Compose a, Decompose pattern) =>
   pattern ->
   (Decomposed Exp pattern -> a) ->
   Exp (PatternTuple pattern) -> Exp (Composed a)
modify p f = compose . f . decompose p

{- |
Two-argument variant of 'modify':
each expression is split according to its own pattern,
the function is applied to both decompositions,
and its result is combined again with 'compose'.

Analogous to 'MultiValue.modify2'.
-}
modify2 ::
   (Compose a, Decompose patternA, Decompose patternB) =>
   patternA ->
   patternB ->
   (Decomposed Exp patternA -> Decomposed Exp patternB -> a) ->
   Exp (PatternTuple patternA) -> Exp (PatternTuple patternB) -> Exp (Composed a)
modify2 pa pb f a b = compose $ f (decompose pa a) (decompose pb b)



-- A single expression is already composed.
instance Compose (Exp a) where
   type Composed (Exp a) = a
   compose e = e

-- An 'atom' pattern leaves the expression untouched.
instance Decompose (Atom a) where
   decompose _ e = e



-- The empty tuple composes to the constant unit value.
instance Compose () where
   type Composed () = ()
   compose u = lift0 (MultiValue.cons u)

-- Decomposing along the empty pattern yields the empty tuple.
instance Decompose () where
   decompose _p _e = ()


-- Pairs compose component-wise and are joined with 'zip'.
-- The lazy pattern keeps the pair argument unevaluated until needed.
instance (Compose a, Compose b) => Compose (a,b) where
   type Composed (a,b) = (Composed a, Composed b)
   compose ~(a, b) = zip (compose a) (compose b)

-- Pair patterns split the expression with 'unzip'
-- and decompose each half along its own sub-pattern.
instance (Decompose pa, Decompose pb) => Decompose (pa,pb) where
   decompose (pa,pb) x =
      Tuple.mapPair (decompose pa, decompose pb) (unzip x)


-- Triples compose component-wise and are joined with 'zip3'.
instance (Compose a, Compose b, Compose c) => Compose (a,b,c) where
   type Composed (a,b,c) = (Composed a, Composed b, Composed c)
   compose ~(a, b, c) = zip3 (compose a) (compose b) (compose c)

-- Triple patterns split the expression with 'unzip3'.
instance
   (Decompose pa, Decompose pb, Decompose pc) =>
      Decompose (pa,pb,pc) where
   decompose (pa,pb,pc) x =
      Tuple.mapTriple (decompose pa, decompose pb, decompose pc) (unzip3 x)


-- Quadruples compose component-wise and are joined with 'zip4'.
instance (Compose a, Compose b, Compose c, Compose d) => Compose (a,b,c,d) where
   type Composed (a,b,c,d) = (Composed a, Composed b, Composed c, Composed d)
   compose quad =
      case quad of
         (a,b,c,d) -> zip4 (compose a) (compose b) (compose c) (compose d)

-- Quadruple patterns split the expression with 'unzip4'.
instance
   (Decompose pa, Decompose pb, Decompose pc, Decompose pd) =>
      Decompose (pa,pb,pc,pd) where
   decompose (pa,pb,pc,pd) x =
      let (a,b,c,d) = unzip4 x
      in  (decompose pa a, decompose pb b, decompose pc c, decompose pd d)


{- | The constant unit expression. -}
unit :: Exp ()
unit = lift0 (MultiValue.cons ())

{- | The constant 'MultiValue.zero' as an expression. -}
zero :: (MultiValue.C a) => Exp a
zero = lift0 MultiValue.zero

{- | Sum of two expressions. -}
add :: (MultiValue.Additive a) => Exp a -> Exp a -> Exp a
add x y = liftM2 MultiValue.add x y

{- | Difference of two expressions. -}
sub :: (MultiValue.Additive a) => Exp a -> Exp a -> Exp a
sub x y = liftM2 MultiValue.sub x y

{- | Product of two expressions. -}
mul :: (MultiValue.PseudoRing a) => Exp a -> Exp a -> Exp a
mul x y = liftM2 MultiValue.mul x y

{- | Square: the argument's result multiplied by itself. -}
sqr :: (MultiValue.PseudoRing a) => Exp a -> Exp a
sqr x = liftM (\v -> MultiValue.mul v v) x

{- | Square root via 'MultiValue.sqrt'. -}
sqrt :: (MultiValue.Algebraic a) => Exp a -> Exp a
sqrt x = liftM MultiValue.sqrt x

{- | Integer division via 'MultiValue.idiv'. -}
idiv :: (MultiValue.Integral a) => Exp a -> Exp a -> Exp a
idiv x y = liftM2 MultiValue.idiv x y

{- | Integer remainder via 'MultiValue.irem'. -}
irem :: (MultiValue.Integral a) => Exp a -> Exp a -> Exp a
irem x y = liftM2 MultiValue.irem x y

{- | Integer literal as a constant expression. -}
fromInteger' :: (MultiValue.IntegerConstant a) => Integer -> Exp a
fromInteger' n = lift0 (MultiValue.fromInteger' n)

{- | Rational literal as a constant expression. -}
fromRational' :: (MultiValue.RationalConstant a) => Rational -> Exp a
fromRational' q = lift0 (MultiValue.fromRational' q)


{- |
Compare two expressions with the given LLVM comparison predicate.
-}
cmp ::
   (MultiValue.Comparison a) =>
   LLVM.CmpPredicate -> Exp a -> Exp a -> Exp Bool
cmp ord x y = liftM2 (MultiValue.cmp ord) x y

infix 4 ==*, /=*, <*, <=*, >*, >=*

{- |
Comparison operators on expressions, each an instance of 'cmp'
with the corresponding 'LLVM.CmpPredicate'.
-}
(==*), (/=*), (<*), (>=*), (>*), (<=*) ::
   (MultiValue.Comparison a) => Exp a -> Exp a -> Exp Bool
x ==* y = cmp LLVM.CmpEQ x y
x /=* y = cmp LLVM.CmpNE x y
x <*  y = cmp LLVM.CmpLT x y
x >=* y = cmp LLVM.CmpGE x y
x >*  y = cmp LLVM.CmpGT x y
x <=* y = cmp LLVM.CmpLE x y


{- | Boolean constants. -}
true, false :: Exp Bool
true = lift0 (MultiValue.cons True)
false = lift0 (MultiValue.cons False)

infixr 3 &&*
{- |
Boolean conjunction.
Code for both operands is generated; this is not short-circuiting.
-}
(&&*) :: Exp Bool -> Exp Bool -> Exp Bool
x &&* y = liftM2 MultiValue.and x y

infixr 2 ||*
{- |
Boolean disjunction.
Code for both operands is generated; this is not short-circuiting.
-}
(||*) :: Exp Bool -> Exp Bool -> Exp Bool
x ||* y = liftM2 MultiValue.or x y

{- | Boolean negation via 'MultiValue.inv'. -}
not :: Exp Bool -> Exp Bool
not x = liftM MultiValue.inv x

{- |
Like 'ifThenElse', but computes both alternative expressions
and then picks one with LLVM's efficient @select@ instruction.
-}
select :: (MultiValue.Select a) => Exp Bool -> Exp a -> Exp a -> Exp a
select c x y = liftM3 MultiValue.select c x y

{- |
Conditional expression built with 'C.ifThenElse':
the condition is computed first, then the code of the two alternatives
is placed in the branches of the generated conditional.
-}
ifThenElse :: (MultiValue.C a) => Exp Bool -> Exp a -> Exp a -> Exp a
ifThenElse ec ex ey =
   Exp
      (unExp ec >>= \(MultiValue.Cons c) ->
       C.ifThenElse c (unExp ex) (unExp ey))


-- Numeric literals and arithmetic operators for expressions;
-- every method delegates to the corresponding 'MultiValue' operation.
instance
   (MultiValue.PseudoRing a, MultiValue.Real a, MultiValue.IntegerConstant a) =>
      Num (Exp a) where
   fromInteger = lift0 . MultiValue.fromInteger'
   x + y = liftM2 MultiValue.add x y
   x - y = liftM2 MultiValue.sub x y
   x * y = liftM2 MultiValue.mul x y
   negate x = liftM MultiValue.neg x
   abs x = liftM MultiValue.abs x
   signum x = liftM MultiValue.signum x

-- Rational literals and division for expressions,
-- where division is 'MultiValue.fdiv'.
instance (MultiValue.Field a, MultiValue.Real a, MultiValue.RationalConstant a) =>
      Fractional (Exp a) where
   fromRational = lift0 . MultiValue.fromRational'
   x / y = liftM2 MultiValue.fdiv x y