{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module LLVM.Extra.FastMath ( 
   NoNaNs(NoNaNs),
   NoInfs(NoInfs),
   NoSignedZeros(NoSignedZeros),
   AllowReciprocal(AllowReciprocal),
   Fast(Fast),
   Flags(setFlags),

   Number(Number, deconsNumber),
   getNumber,
   nvNumber,
   nvDenumber,
   mvNumber,
   mvDenumber,

   NiceValue(setMultiValueFlags, setNiceValueFlags),
   attachNiceValueFlags,
   attachMultiValueFlags,
   liftNumberM,
   liftNumberM2,
   nvecNumber,
   nvecDenumber,
   mvecNumber,
   mvecDenumber,

   NiceVector(setMultiVectorFlags, setNiceVectorFlags),
   attachNiceVectorFlags,
   liftNiceVectorM,
   liftNiceVectorM2,
   attachMultiVectorFlags,
   liftMultiVectorM,
   liftMultiVectorM2,

   Tuple(setTupleFlags),
   Context(Context),
   attachTupleFlags,
   liftContext,
   liftContext2,
   ) where

import qualified LLVM.Extra.Nice.Vector as NiceVector
import qualified LLVM.Extra.Nice.Value.Private as Nice
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Core as LLVM
import LLVM.Util.Proxy (Proxy(Proxy))

import Foreign.Storable (Storable)

import qualified Control.Monad.HT as Monad
import Control.Applicative ((<$>))


-- | Flag marker whose 'Flags' instance calls 'LLVM.setHasNoNaNs'.
data NoNaNs = NoNaNs
   deriving (Eq, Show)

-- | Flag marker whose 'Flags' instance calls 'LLVM.setHasNoInfs'.
data NoInfs = NoInfs
   deriving (Eq, Show)

-- | Flag marker whose 'Flags' instance calls 'LLVM.setHasNoSignedZeros'.
data NoSignedZeros = NoSignedZeros
   deriving (Eq, Show)

-- | Flag marker whose 'Flags' instance calls 'LLVM.setHasAllowReciprocal'.
data AllowReciprocal = AllowReciprocal
   deriving (Eq, Show)

-- | Flag marker whose 'Flags' instance calls 'LLVM.setFastMath'.
data Fast = Fast
   deriving (Eq, Show)


-- | A set of fast-math flags, selected purely by the type @flags@.
--
-- The 'Proxy' argument carries no runtime data; it only fixes @flags@.
class Flags flags where
   -- | Apply the flags selected by @flags@ to a floating-point value.
   -- The 'Bool' is forwarded to the underlying LLVM setter(s).
   setFlags ::
      LLVM.IsFloating a =>
      Proxy flags ->
      Bool ->
      LLVM.Value a ->
      LLVM.CodeGenFunction r ()

-- Each single-flag marker forwards to the matching LLVM setter.
instance Flags NoNaNs where
   setFlags Proxy enable value = LLVM.setHasNoNaNs enable value

instance Flags NoInfs where
   setFlags Proxy enable value = LLVM.setHasNoInfs enable value

instance Flags NoSignedZeros where
   setFlags Proxy enable value = LLVM.setHasNoSignedZeros enable value

instance Flags AllowReciprocal where
   setFlags Proxy enable value = LLVM.setHasAllowReciprocal enable value

instance Flags Fast where
   setFlags Proxy enable value = LLVM.setFastMath enable value

-- | A pair of flag sets: apply the first component, then the second.
instance (Flags f0, Flags f1) => Flags (f0,f1) where
   setFlags p b v = do
      setFlags (fmap fst p) b v
      setFlags (fmap snd p) b v

-- Wider tuples are peeled into (head, rest) and delegated to the pair
-- instance, so they recurse down to the pair case.
instance (Flags f0, Flags f1, Flags f2) => Flags (f0,f1,f2) where
   setFlags = setSplitFlags (\(f0,f1,f2) -> (f0,(f1,f2)))

instance (Flags f0, Flags f1, Flags f2, Flags f3) => Flags (f0,f1,f2,f3) where
   setFlags = setSplitFlags (\(f0,f1,f2,f3) -> (f0,(f1,f2,f3)))

instance
   (Flags f0, Flags f1, Flags f2, Flags f3, Flags f4) =>
      Flags (f0,f1,f2,f3,f4) where
   setFlags = setSplitFlags (\(f0,f1,f2,f3,f4) -> (f0,(f1,f2,f3,f4)))

-- | Re-express a flag-set type as another one via a type-level mapping
-- function (only its type is used; it is applied under 'Proxy').
setSplitFlags ::
   (Flags split, LLVM.IsFloating a) =>
   (flags -> split) ->
   Proxy flags -> Bool -> LLVM.Value a -> LLVM.CodeGenFunction r ()
setSplitFlags split = setFlags . fmap split


-- | A value of type @a@ tagged with the fast-math flag set @flags@.
--
-- @flags@ is a phantom parameter; the numeric and 'Storable' instances are
-- newtype-derived from @a@.
newtype Number flags a =
   Number {deconsNumber :: a}
   deriving (Show, Eq, Ord, Storable, Num, Fractional, Floating)

-- | Unwrap a 'Number'. The first argument is ignored; it only pins down the
-- @flags@ type at the call site.
getNumber :: flags -> Number flags a -> a
getNumber _flags = deconsNumber

-- | A flag-tagged number shares the nice-value representation of its
-- payload; every method delegates to the payload instance and re-tags.
instance NiceValue a => Nice.C (Number flags a) where
   type Repr (Number flags a) = Nice.Repr a
   cons (Number x) = nvNumber (Nice.cons x)
   undef = nvNumber Nice.undef
   zero = nvNumber Nice.zero
   phi bb x = nvNumber <$> Nice.phi bb (nvDenumber x)
   addPhi bb a b = Nice.addPhi bb (nvDenumber a) (nvDenumber b)

-- | Tag a nice value with a flag set. The payload is rewrapped unchanged
-- (@Repr (Number flags a) = Repr a@).
nvNumber :: Nice.T a -> Nice.T (Number flags a)
nvNumber (Nice.Cons a) = Nice.Cons a

-- | Remove the flag tag from a nice value; inverse of 'nvNumber'.
nvDenumber :: Nice.T (Number flags a) -> Nice.T a
nvDenumber (Nice.Cons a) = Nice.Cons a

{-# DEPRECATED mvNumber "Use nvNumber instead" #-}
-- | Deprecated alias; delegates to 'nvNumber' so the two cannot diverge.
mvNumber :: Nice.T a -> Nice.T (Number flags a)
mvNumber = nvNumber

{-# DEPRECATED mvDenumber "Use nvDenumber instead" #-}
-- | Deprecated alias; delegates to 'nvDenumber' so the two cannot diverge.
mvDenumber :: Nice.T (Number flags a) -> Nice.T a
mvDenumber = nvDenumber


{-# DEPRECATED setMultiValueFlags "use setNiceValueFlags instead" #-}
-- | Payload types of nice values on which fast-math flags can be set.
--
-- Instances define one of the two methods; each defaults to the other.
class Nice.C a => NiceValue a where
   {-# MINIMAL setNiceValueFlags | setMultiValueFlags #-}

   -- | Apply the flag set @flags@ to a flag-tagged nice value.
   setNiceValueFlags ::
      (Flags flags) =>
      Proxy flags -> Bool -> Nice.T (Number flags a) ->
      LLVM.CodeGenFunction r ()
   setNiceValueFlags = setMultiValueFlags

   -- | Former name of 'setNiceValueFlags'.
   setMultiValueFlags ::
      (Flags flags) =>
      Proxy flags -> Bool -> Nice.T (Number flags a) ->
      LLVM.CodeGenFunction r ()
   setMultiValueFlags = setNiceValueFlags

-- Scalar floating types: unwrap the nice value and set the flags on the
-- underlying LLVM value.
instance NiceValue Float where
   setNiceValueFlags proxy enable value =
      case value of Nice.Cons prim -> setFlags proxy enable prim

instance NiceValue Double where
   setNiceValueFlags proxy enable value =
      case value of Nice.Cons prim -> setFlags proxy enable prim


-- | Endomorphism on @a@; keeps the @attach*@ signatures short.
type Id a = a -> a

{-# DEPRECATED attachMultiValueFlags "Use attachNiceValueFlags instead." #-}
-- | Run a code generator, enable the flags @flags@ (with 'True') on the nice
-- value it produces, and return that value.
attachMultiValueFlags, attachNiceValueFlags ::
   (Flags flags, NiceValue a) =>
   Id (LLVM.CodeGenFunction r (Nice.T (Number flags a)))
attachMultiValueFlags = attachNiceValueFlags
attachNiceValueFlags act = do
   nv <- act
   -- Use the current method name rather than its deprecated alias.
   setNiceValueFlags Proxy True nv
   return nv

-- | Lift a unary nice-value operation to flag-tagged values: strip the tag,
-- run @f@, re-tag the result and enable the flags on it.
liftNumberM ::
   (m ~ LLVM.CodeGenFunction r, Flags flags, NiceValue b) =>
   (Nice.T a -> m (Nice.T b)) ->
   Nice.T (Number flags a) -> m (Nice.T (Number flags b))
liftNumberM f =
   attachNiceValueFlags . Monad.lift nvNumber . f . nvDenumber

-- | Binary counterpart of 'liftNumberM'.
liftNumberM2 ::
   (m ~ LLVM.CodeGenFunction r, Flags flags, NiceValue c) =>
   (Nice.T a -> Nice.T b -> m (Nice.T c)) ->
   Nice.T (Number flags a) -> Nice.T (Number flags b) ->
   m (Nice.T (Number flags c))
liftNumberM2 f a b =
   attachNiceValueFlags $ Monad.lift nvNumber $ f (nvDenumber a) (nvDenumber b)


-- Composition and decomposition pass straight through the flag tag.
instance (Flags flags, Nice.Compose a) => Nice.Compose (Number flags a) where
   type Composed (Number flags a) = Number flags (Nice.Composed a)
   compose (Number x) = nvNumber (Nice.compose x)

instance
      (Flags flags, Nice.Decompose pa) => Nice.Decompose (Number flags pa) where
   decompose (Number p) v = Number (Nice.decompose p (nvDenumber v))

type instance
   Nice.Decomposed f (Number flags pa) = Number flags (Nice.Decomposed f pa)
type instance
   Nice.PatternTuple (Number flags pa) = Number flags (Nice.PatternTuple pa)


-- Constants are built for the payload type and then tagged; no flags are set.
instance
   (Flags flags, NiceValue a, Nice.IntegerConstant a) =>
      Nice.IntegerConstant (Number flags a) where
   fromInteger' n = nvNumber (Nice.fromInteger' n)

instance
   (Flags flags, NiceValue a, Nice.RationalConstant a) =>
      Nice.RationalConstant (Number flags a) where
   fromRational' q = nvNumber (Nice.fromRational' q)

-- Arithmetic runs on the payload and the flags are attached to each result.
instance
   (Flags flags, NiceValue a, Nice.Additive a) =>
      Nice.Additive (Number flags a) where
   add x y = liftNumberM2 Nice.add x y
   sub x y = liftNumberM2 Nice.sub x y
   neg x = liftNumberM Nice.neg x

instance
   (Flags flags, NiceValue a, Nice.PseudoRing a) =>
      Nice.PseudoRing (Number flags a) where
   mul x y = liftNumberM2 Nice.mul x y

instance
   (Flags flags, NiceValue a, Nice.Field a) =>
      Nice.Field (Number flags a) where
   fdiv x y = liftNumberM2 Nice.fdiv x y

-- The scalar of a tagged value carries the same flag set.
type instance Nice.Scalar (Number flags a) = Number flags (Nice.Scalar a)

instance
   (Flags flags, NiceValue a, a ~ Nice.Scalar v,
    NiceValue v, Nice.PseudoModule v) =>
      Nice.PseudoModule (Number flags v) where
   scale s v = liftNumberM2 Nice.scale s v

instance
   (Flags flags, NiceValue a, Nice.Real a) =>
      Nice.Real (Number flags a) where
   min x y = liftNumberM2 Nice.min x y
   max x y = liftNumberM2 Nice.max x y
   abs x = liftNumberM Nice.abs x
   signum x = liftNumberM Nice.signum x

instance
   (Flags flags, NiceValue a, Nice.Fraction a) =>
      Nice.Fraction (Number flags a) where
   truncate x = liftNumberM Nice.truncate x
   fraction x = liftNumberM Nice.fraction x

instance
   (Flags flags, NiceValue a, Nice.Algebraic a) =>
      Nice.Algebraic (Number flags a) where
   sqrt x = liftNumberM Nice.sqrt x

-- pi is taken from the payload instance and only re-tagged (no flags set).
instance
   (Flags flags, NiceValue a, Nice.Transcendental a) =>
      Nice.Transcendental (Number flags a) where
   pi = nvNumber <$> Nice.pi
   sin x = liftNumberM Nice.sin x
   cos x = liftNumberM Nice.cos x
   exp x = liftNumberM Nice.exp x
   log x = liftNumberM Nice.log x
   pow x y = liftNumberM2 Nice.pow x y

-- Select attaches flags to its result. The comparisons only strip the tags
-- and return the payload instance's result as is.
instance
   (Flags flags, NiceValue a, Nice.Select a) =>
      Nice.Select (Number flags a) where
   select c x y = liftNumberM2 (Nice.select c) x y

instance
   (Flags flags, NiceValue a, Nice.Comparison a) =>
      Nice.Comparison (Number flags a) where
   cmp p x y = Nice.cmp p (nvDenumber x) (nvDenumber y)

instance
   (Flags flags, NiceValue a, Nice.FloatingComparison a) =>
      Nice.FloatingComparison (Number flags a) where
   fcmp p x y = Nice.fcmp p (nvDenumber x) (nvDenumber y)



-- | Tag a nice vector with a flag set; the payload is rewrapped unchanged.
nvecNumber :: NiceVector.T n a -> NiceVector.T n (Number flags a)
nvecNumber (NiceVector.Cons v) = NiceVector.Cons v

-- | Remove the flag tag from a nice vector; inverse of 'nvecNumber'.
nvecDenumber :: NiceVector.T n (Number flags a) -> NiceVector.T n a
nvecDenumber (NiceVector.Cons v) = NiceVector.Cons v

{-# DEPRECATED mvecNumber "Use nvecNumber instead" #-}
-- | Deprecated alias; delegates to 'nvecNumber' so the two cannot diverge.
mvecNumber :: NiceVector.T n a -> NiceVector.T n (Number flags a)
mvecNumber = nvecNumber

{-# DEPRECATED mvecDenumber "Use nvecDenumber instead" #-}
-- | Deprecated alias; delegates to 'nvecDenumber' so the two cannot diverge.
mvecDenumber :: NiceVector.T n (Number flags a) -> NiceVector.T n a
mvecDenumber = nvecDenumber

{-# DEPRECATED setMultiVectorFlags "use setNiceVectorFlags instead" #-}
-- | Element types of nice vectors on which fast-math flags can be set.
--
-- Instances define one of the two methods; each defaults to the other.
class (NiceValue a, NiceVector.C a) => NiceVector a where
   {-# MINIMAL setNiceVectorFlags | setMultiVectorFlags #-}

   -- | Apply the flag set @flags@ to a flag-tagged nice vector.
   setNiceVectorFlags ::
      (Flags flags, LLVM.Positive n) =>
      Proxy flags -> Bool ->
      NiceVector.T n (Number flags a) -> LLVM.CodeGenFunction r ()
   setNiceVectorFlags = setMultiVectorFlags

   -- | Former name of 'setNiceVectorFlags'.
   setMultiVectorFlags ::
      (Flags flags, LLVM.Positive n) =>
      Proxy flags -> Bool ->
      NiceVector.T n (Number flags a) -> LLVM.CodeGenFunction r ()
   setMultiVectorFlags = setNiceVectorFlags

-- Primitive vector types: strip the tag, unwrap the primitive LLVM vector
-- and set the flags on it. These instances define the current method name,
-- like the 'NiceValue' instances do; the deprecated 'setMultiVectorFlags'
-- comes from the class default.
instance NiceVector Float where
   setNiceVectorFlags p b =
      setFlags p b . NiceVector.deconsPrim . nvecDenumber

instance NiceVector Double where
   setNiceVectorFlags p b =
      setFlags p b . NiceVector.deconsPrim . nvecDenumber

{-# DEPRECATED attachMultiVectorFlags "Use attachNiceVectorFlags instead." #-}
-- | Run a code generator, enable the flags @flags@ (with 'True') on the nice
-- vector it produces, and return that vector.
attachNiceVectorFlags, attachMultiVectorFlags ::
   (LLVM.Positive n, Flags flags, NiceVector a) =>
   Id (LLVM.CodeGenFunction r (NiceVector.T n (Number flags a)))
attachMultiVectorFlags = attachNiceVectorFlags
attachNiceVectorFlags act = do
   nv <- act
   -- Use the current method name rather than its deprecated alias.
   setNiceVectorFlags Proxy True nv
   return nv

{-# DEPRECATED liftMultiVectorM "Use liftNiceVectorM instead." #-}
-- | Lift a unary nice-vector operation to flag-tagged vectors: strip the
-- tag, run @f@, re-tag the result and enable the flags on it.
liftNiceVectorM, liftMultiVectorM ::
   (m ~ LLVM.CodeGenFunction r, LLVM.Positive n, Flags flags, NiceVector b) =>
   (NiceVector.T n a -> m (NiceVector.T n b)) ->
   NiceVector.T n (Number flags a) -> m (NiceVector.T n (Number flags b))
liftMultiVectorM = liftNiceVectorM
liftNiceVectorM f =
   attachNiceVectorFlags . Monad.lift nvecNumber . f . nvecDenumber

{-# DEPRECATED liftMultiVectorM2 "Use liftNiceVectorM2 instead." #-}
-- | Binary counterpart of 'liftNiceVectorM'.
liftNiceVectorM2, liftMultiVectorM2 ::
   (m ~ LLVM.CodeGenFunction r, LLVM.Positive n, Flags flags, NiceVector c) =>
   (NiceVector.T n a -> NiceVector.T n b -> m (NiceVector.T n c)) ->
   NiceVector.T n (Number flags a) -> NiceVector.T n (Number flags b) ->
   m (NiceVector.T n (Number flags c))
liftMultiVectorM2 = liftNiceVectorM2
liftNiceVectorM2 f a b =
   attachNiceVectorFlags $
      Monad.lift nvecNumber $ f (nvecDenumber a) (nvecDenumber b)

-- | Vectors of tagged numbers share the representation of vectors of the
-- payload; each method strips the tags, delegates and re-tags.
instance (Flags flags, NiceVector a) => NiceVector.C (Number flags a) where
   type Repr n (Number flags a) = NiceVector.Repr n a
   cons xs = nvecNumber (NiceVector.cons (fmap deconsNumber xs))
   undef = nvecNumber NiceVector.undef
   zero = nvecNumber NiceVector.zero
   phi bb v = nvecNumber <$> NiceVector.phi bb (nvecDenumber v)
   addPhi bb a b = NiceVector.addPhi bb (nvecDenumber a) (nvecDenumber b)
   shuffle ks a b =
      nvecNumber <$> NiceVector.shuffle ks (nvecDenumber a) (nvecDenumber b)
   extract k v = nvNumber <$> NiceVector.extract k (nvecDenumber v)
   insert k x v =
      nvecNumber <$> NiceVector.insert k (nvDenumber x) (nvecDenumber v)

-- Vector constants are tagged without setting flags; arithmetic attaches
-- the flags to every result.
instance
   (Flags flags, NiceVector a, NiceVector.IntegerConstant a) =>
      NiceVector.IntegerConstant (Number flags a) where
   fromInteger' n = nvecNumber (NiceVector.fromInteger' n)

instance
   (Flags flags, NiceVector a, NiceVector.RationalConstant a) =>
      NiceVector.RationalConstant (Number flags a) where
   fromRational' q = nvecNumber (NiceVector.fromRational' q)

instance
   (Flags flags, NiceVector a, NiceVector.Additive a) =>
      NiceVector.Additive (Number flags a) where
   add x y = liftNiceVectorM2 NiceVector.add x y
   sub x y = liftNiceVectorM2 NiceVector.sub x y
   neg x = liftNiceVectorM NiceVector.neg x

instance
   (Flags flags, NiceVector a, NiceVector.PseudoRing a) =>
      NiceVector.PseudoRing (Number flags a) where
   mul x y = liftNiceVectorM2 NiceVector.mul x y

instance
   (Flags flags, NiceVector a, NiceVector.Field a) =>
      NiceVector.Field (Number flags a) where
   fdiv x y = liftNiceVectorM2 NiceVector.fdiv x y


{-
type instance NiceValue.Scalar (Number flags a) =
      Number flags (NiceValue.Scalar a)
instance
   (Flags flags, NiceVector a, NiceVector.PseudoModule a) =>
      NiceVector.PseudoModule (Number flags a) where
   scale = liftNiceVectorM2 NiceVector.mul
-}

instance
   (Flags flags, NiceVector a, NiceVector.Real a) =>
      NiceVector.Real (Number flags a) where
   min x y = liftNiceVectorM2 NiceVector.min x y
   max x y = liftNiceVectorM2 NiceVector.max x y
   abs x = liftNiceVectorM NiceVector.abs x
   signum x = liftNiceVectorM NiceVector.signum x

instance
   (Flags flags, NiceVector a, NiceVector.Fraction a) =>
      NiceVector.Fraction (Number flags a) where
   truncate x = liftNiceVectorM NiceVector.truncate x
   fraction x = liftNiceVectorM NiceVector.fraction x

instance
   (Flags flags, NiceVector a, NiceVector.Algebraic a) =>
      NiceVector.Algebraic (Number flags a) where
   sqrt x = liftNiceVectorM NiceVector.sqrt x

-- pi is taken from the payload instance and only re-tagged (no flags set).
instance
   (Flags flags, NiceVector a, NiceVector.Transcendental a) =>
      NiceVector.Transcendental (Number flags a) where
   pi = nvecNumber <$> NiceVector.pi
   sin x = liftNiceVectorM NiceVector.sin x
   cos x = liftNiceVectorM NiceVector.cos x
   exp x = liftNiceVectorM NiceVector.exp x
   log x = liftNiceVectorM NiceVector.log x
   pow x y = liftNiceVectorM2 NiceVector.pow x y

-- Select attaches flags to its result. The comparisons only strip the tags
-- and return the payload instance's result as is.
instance
   (Flags flags, NiceVector a, NiceVector.Select a) =>
      NiceVector.Select (Number flags a) where
   select c x y = liftNiceVectorM2 (NiceVector.select c) x y

instance
   (Flags flags, NiceVector a, NiceVector.Comparison a) =>
      NiceVector.Comparison (Number flags a) where
   cmp p x y = NiceVector.cmp p (nvecDenumber x) (nvecDenumber y)

instance
   (Flags flags, NiceVector a, NiceVector.FloatingComparison a) =>
      NiceVector.FloatingComparison (Number flags a) where
   fcmp p x y = NiceVector.fcmp p (nvecDenumber x) (nvecDenumber y)



-- | Tuples of LLVM values on which fast-math flags can be set.
class Tuple a where
   -- | Apply the flag set @flags@ to every component of the tuple.
   setTupleFlags ::
      (Flags flags) =>
      Proxy flags -> Bool -> a -> LLVM.CodeGenFunction r ()

-- | A single floating-point value is a one-element tuple.
instance (LLVM.IsFloating a) => Tuple (LLVM.Value a) where
   setTupleFlags proxy enable value = setFlags proxy enable value


-- | A tuple of values tagged with the fast-math flag set @flags@ (phantom).
newtype Context flags a = Context a

-- | Recover the flag set of a 'Context' as a 'Proxy'.
proxyFromContext :: Context flags a -> Proxy flags
proxyFromContext _ = Proxy

instance
   (Flags flags, Tuple.Zero a, Tuple a) =>
      Tuple.Zero (Context flags a) where
   zero = Context Tuple.zero

-- Arithmetic runs on the wrapped tuple and attaches flags to each result.
instance
   (Flags flags, Tuple a, A.Additive a) =>
      A.Additive (Context flags a) where
   zero = Context A.zero
   add x y = liftContext2 A.add x y
   sub x y = liftContext2 A.sub x y
   neg x = liftContext A.neg x

instance
   (Flags flags, A.PseudoRing a, Tuple a) =>
      A.PseudoRing (Context flags a) where
   mul x y = liftContext2 A.mul x y

type instance A.Scalar (Context flags a) = Context flags (A.Scalar a)

instance
   (Flags flags, A.PseudoModule v, Tuple v, A.Scalar v ~ a, Tuple a) =>
      A.PseudoModule (Context flags v) where
   scale s v = liftContext2 A.scale s v

-- Constants are wrapped without setting flags.
instance
   (Flags flags, Tuple a, A.IntegerConstant a) =>
      A.IntegerConstant (Context flags a) where
   fromInteger' n = Context (A.fromInteger' n)

instance
   (Flags flags, Tuple v, A.Field v) =>
      A.Field (Context flags v) where
   fdiv x y = liftContext2 A.fdiv x y

instance
   (Flags flags, Tuple a, A.RationalConstant a) =>
      A.RationalConstant (Context flags a) where
   fromRational' q = Context (A.fromRational' q)

instance (Flags flags, Tuple a, A.Real a) => A.Real (Context flags a) where
   min x y = liftContext2 A.min x y
   max x y = liftContext2 A.max x y
   abs x = liftContext A.abs x
   signum x = liftContext A.signum x

instance
   (Flags flags, Tuple a, A.Fraction a) =>
      A.Fraction (Context flags a) where
   truncate x = liftContext A.truncate x
   fraction x = liftContext A.fraction x

-- Comparisons unwrap both operands and return the wrapped type's result.
instance
   (Flags flags, Tuple a, A.Comparison a) =>
      A.Comparison (Context flags a) where
   type CmpResult (Context flags a) = A.CmpResult a
   cmp p (Context lhs) (Context rhs) = A.cmp p lhs rhs

instance
   (Flags flags, Tuple a, A.FloatingComparison a) =>
      A.FloatingComparison (Context flags a) where
   fcmp p (Context lhs) (Context rhs) = A.fcmp p lhs rhs

instance
   (Flags flags, Tuple a, A.Algebraic a) =>
      A.Algebraic (Context flags a) where
   sqrt x = liftContext A.sqrt x

instance
   (Flags flags, Tuple a, A.Transcendental a) =>
      A.Transcendental (Context flags a) where
   -- Take pi from the wrapped type's instance and only wrap it.
   -- 'attachTupleFlags A.pi' would resolve 'A.pi' at type
   -- @Context flags a@, i.e. to this very method, and never terminate.
   -- No flags are attached to the constant, matching the
   -- 'Nice.Transcendental' and 'NiceVector.Transcendental' instances for
   -- 'Number' in this module.
   pi = Context <$> A.pi
   sin = liftContext A.sin
   cos = liftContext A.cos
   exp = liftContext A.exp
   log = liftContext A.log
   pow = liftContext2 A.pow


-- | Run a code generator, enable the flags of the resulting 'Context'
-- (with 'True') on its wrapped tuple, and return the context.
attachTupleFlags ::
   (Flags flags, Tuple a) =>
   Id (LLVM.CodeGenFunction r (Context flags a))
attachTupleFlags act = do
   ctx <- act
   case ctx of
      Context x -> setTupleFlags (proxyFromContext ctx) True x
   return ctx

-- | Lift a unary tuple operation to a 'Context': unwrap, run @f@, rewrap and
-- attach the flags to the result.
liftContext :: (Flags flags, Tuple b) =>
   (a -> LLVM.CodeGenFunction r b) ->
   Context flags a -> LLVM.CodeGenFunction r (Context flags b)
liftContext f (Context x) = attachTupleFlags (fmap Context (f x))

-- | Binary counterpart of 'liftContext'; the first argument is partially
-- applied and the rest is handled by 'liftContext'.
liftContext2 :: (Flags flags, Tuple c) =>
   (a -> b -> LLVM.CodeGenFunction r c) ->
   Context flags a -> Context flags b ->
   LLVM.CodeGenFunction r (Context flags c)
liftContext2 f (Context x) y = liftContext (f x) y
