{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module LLVM.Extra.FastMath where

import qualified LLVM.Extra.Multi.Value.Private as MV
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Core as LLVM
import LLVM.Util.Proxy (Proxy(Proxy))

import Foreign.Storable (Storable)

import qualified Control.Monad.HT as Monad
import Control.Applicative ((<$>))


-- | Tag type selecting the "no NaNs" fast-math flag (see the 'Flags' class).
data NoNaNs = NoNaNs
   deriving (Eq, Show)

-- | Tag type selecting the "no infinities" fast-math flag.
data NoInfs = NoInfs
   deriving (Eq, Show)

-- | Tag type selecting the "no signed zeros" fast-math flag.
data NoSignedZeros = NoSignedZeros
   deriving (Eq, Show)

-- | Tag type selecting the "allow reciprocal" fast-math flag.
data AllowReciprocal = AllowReciprocal
   deriving (Eq, Show)

-- | Tag type selecting the umbrella "fast" setting.
data Fast = Fast
   deriving (Eq, Show)


-- | Type-level selection of LLVM fast-math flags.
--
-- 'setFlags' forwards its 'Bool' unchanged to the matching LLVM setter
-- ('LLVM.setHasNoNaNs', 'LLVM.setFastMath', ...) applied to the given
-- floating-point value. The 'Proxy' argument only fixes which flags are
-- meant.
class Flags flags where
   setFlags ::
      (LLVM.IsFloating a) =>
      Proxy flags -> Bool -> LLVM.Value a -> LLVM.CodeGenFunction r ()

instance Flags NoNaNs where
   setFlags Proxy = LLVM.setHasNoNaNs

instance Flags NoInfs where
   setFlags Proxy = LLVM.setHasNoInfs

instance Flags NoSignedZeros where
   setFlags Proxy = LLVM.setHasNoSignedZeros

instance Flags AllowReciprocal where
   setFlags Proxy = LLVM.setHasAllowReciprocal

instance Flags Fast where
   setFlags Proxy = LLVM.setFastMath

-- | A pair of flag sets: apply the first component's flags, then the
-- second's, to the same value.
instance (Flags f0, Flags f1) => Flags (f0,f1) where
   setFlags proxy enable value = do
      setFlags (fmap fst proxy) enable value
      setFlags (fmap snd proxy) enable value

-- Larger tuples are re-nested as @(first, (rest...))@ and handled by the
-- pair instance, so the flags are applied left to right.
instance (Flags f0, Flags f1, Flags f2) => Flags (f0,f1,f2) where
   setFlags = setSplitFlags (\(a,b,c) -> (a,(b,c)))

instance (Flags f0, Flags f1, Flags f2, Flags f3) => Flags (f0,f1,f2,f3) where
   setFlags = setSplitFlags (\(a,b,c,d) -> (a,(b,c,d)))

instance
   (Flags f0, Flags f1, Flags f2, Flags f3, Flags f4) =>
      Flags (f0,f1,f2,f3,f4) where
   setFlags = setSplitFlags (\(a,b,c,d,e) -> (a,(b,c,d,e)))

-- | Set the flags of @flags@ by viewing it as the type @split@ that
-- already has a 'Flags' instance. The regrouping function is only
-- applied through 'fmap' on the 'Proxy'.
setSplitFlags ::
   (Flags split, LLVM.IsFloating a) =>
   (flags -> split) ->
   Proxy flags -> Bool -> LLVM.Value a -> LLVM.CodeGenFunction r ()
setSplitFlags regroup = setFlags . fmap regroup


-- | A value of type @a@ (in this module: 'Float' or 'Double') tagged with
-- the fast-math @flags@ that operations on its multi-value representation
-- attach. The numeric and storage instances are inherited from @a@.
newtype Number flags a =
   Number {deconsNumber :: a}
   deriving (Storable, Floating, Fractional, Num, Show, Ord, Eq)

-- | Unwrap a 'Number'. The first argument only fixes the @flags@ type and
-- is never evaluated.
getNumber :: flags -> Number flags a -> a
getNumber _ = deconsNumber

-- | A 'Number' shares the multi-value representation of the wrapped type.
-- Every method delegates to the wrapped type's instance and retags.
instance MultiValue a => MV.C (Number flags a) where
   type Repr f (Number flags a) = MV.Repr f a
   cons (Number x) = mvNumber (MV.cons x)
   undef = mvNumber MV.undef
   zero = mvNumber MV.zero
   phis bb x = mvNumber <$> MV.phis bb (mvDenumber x)
   addPhis bb x y = MV.addPhis bb (mvDenumber x) (mvDenumber y)

-- | Retag a plain multi-value as a 'Number'. The representation is reused.
mvNumber :: MV.T a -> MV.T (Number flags a)
mvNumber (MV.Cons repr) = MV.Cons repr

-- | Drop the 'Number' tag from a multi-value. The representation is reused.
mvDenumber :: MV.T (Number flags a) -> MV.T a
mvDenumber (MV.Cons repr) = MV.Cons repr


-- | Multi-value types whose representation can carry fast-math flags.
class MV.C a => MultiValue a where
   -- | Apply the flags selected by the proxy (enabled or disabled as the
   -- 'Bool' says) to the LLVM value inside a tagged multi-value.
   setMultiValueFlags ::
      (Flags flags) =>
      Proxy flags -> Bool -> MV.T (Number flags a) -> LLVM.CodeGenFunction r ()

instance MultiValue Float where
   setMultiValueFlags proxy enable (MV.Cons v) = setFlags proxy enable v

instance MultiValue Double where
   setMultiValueFlags proxy enable (MV.Cons v) = setFlags proxy enable v


-- | Shorthand for an endomorphism.
type Id a = a -> a

-- | Run a generator producing a tagged multi-value, enable the fast-math
-- flags named by its tag on the result, and return the result unchanged.
attachMultiValueFlags ::
   (Flags flags, MultiValue a) =>
   Id (LLVM.CodeGenFunction r (MV.T (Number flags a)))
attachMultiValueFlags gen =
   gen >>= \result ->
      setMultiValueFlags Proxy True result >> return result

-- | Lift a unary multi-value operation to 'Number's. It runs on the
-- untagged operand, and its result is retagged and given the tag's flags.
liftNumberM ::
   (m ~ LLVM.CodeGenFunction r, Flags flags, MultiValue b) =>
   (MV.T a -> m (MV.T b)) ->
   MV.T (Number flags a) -> m (MV.T (Number flags b))
liftNumberM op x =
   attachMultiValueFlags (mvNumber <$> op (mvDenumber x))

-- | Binary counterpart of 'liftNumberM'.
liftNumberM2 ::
   (m ~ LLVM.CodeGenFunction r, Flags flags, MultiValue c) =>
   (MV.T a -> MV.T b -> m (MV.T c)) ->
   MV.T (Number flags a) -> MV.T (Number flags b) -> m (MV.T (Number flags c))
liftNumberM2 op x y =
   attachMultiValueFlags (mvNumber <$> op (mvDenumber x) (mvDenumber y))


-- | Composing a tagged value composes the payload and keeps the tag.
instance (Flags flags, MV.Compose a) => MV.Compose (Number flags a) where
   type Composed (Number flags a) = Number flags (MV.Composed a)
   compose (Number x) = mvNumber (MV.compose x)

-- | Decomposing along a tagged pattern decomposes the untagged multi-value
-- and tags the outcome.
instance (Flags flags, MV.Decompose pa) => MV.Decompose (Number flags pa) where
   decompose (Number p) v = Number (MV.decompose p (mvDenumber v))

type instance MV.Decomposed f (Number flags pa) =
   Number flags (MV.Decomposed f pa)
type instance MV.PatternTuple (Number flags pa) =
   Number flags (MV.PatternTuple pa)


-- | Integer literals are wrapped without attaching any flags.
instance
   (Flags flags, MultiValue a, MV.IntegerConstant a) =>
      MV.IntegerConstant (Number flags a) where
   fromInteger' n = mvNumber (MV.fromInteger' n)

-- | Rational literals are wrapped without attaching any flags.
instance
   (Flags flags, MultiValue a, MV.RationalConstant a) =>
      MV.RationalConstant (Number flags a) where
   fromRational' q = mvNumber (MV.fromRational' q)

-- | Addition, subtraction and negation. Each delegates to the untagged
-- operation and attaches the tag's flags to the result.
instance
   (Flags flags, MultiValue a, MV.Additive a) =>
      MV.Additive (Number flags a) where
   add x y = liftNumberM2 MV.add x y
   sub x y = liftNumberM2 MV.sub x y
   neg x = liftNumberM MV.neg x

-- | Multiplication, with the tag's flags attached to the result.
instance
   (Flags flags, MultiValue a, MV.PseudoRing a) =>
      MV.PseudoRing (Number flags a) where
   mul x y = liftNumberM2 MV.mul x y

-- | Division, with the tag's flags attached to the result.
instance
   (Flags flags, MultiValue a, MV.Field a) =>
      MV.Field (Number flags a) where
   fdiv x y = liftNumberM2 MV.fdiv x y

-- | The scalar of a tagged value is the tagged scalar, with the same flags.
type instance MV.Scalar (Number flags a) = Number flags (MV.Scalar a)

-- | Scaling, with the tag's flags attached to the result.
instance
   (Flags flags, MultiValue a, a ~ MV.Scalar v,
    MultiValue v, MV.PseudoModule v) =>
      MV.PseudoModule (Number flags v) where
   scale k x = liftNumberM2 MV.scale k x

-- | Ordering-based operations, each with the tag's flags attached.
instance
   (Flags flags, MultiValue a, MV.Real a) =>
      MV.Real (Number flags a) where
   min x y = liftNumberM2 MV.min x y
   max x y = liftNumberM2 MV.max x y
   abs x = liftNumberM MV.abs x
   signum x = liftNumberM MV.signum x

-- | Integer/fraction splitting, each with the tag's flags attached.
instance
   (Flags flags, MultiValue a, MV.Fraction a) =>
      MV.Fraction (Number flags a) where
   truncate x = liftNumberM MV.truncate x
   fraction x = liftNumberM MV.fraction x

-- | Square root, with the tag's flags attached.
instance
   (Flags flags, MultiValue a, MV.Algebraic a) =>
      MV.Algebraic (Number flags a) where
   sqrt x = liftNumberM MV.sqrt x

-- | Transcendental functions with the tag's flags attached to results.
-- The constant @pi@ is only retagged; no flags are attached to it.
instance
   (Flags flags, MultiValue a, MV.Transcendental a) =>
      MV.Transcendental (Number flags a) where
   pi = mvNumber <$> MV.pi
   sin x = liftNumberM MV.sin x
   cos x = liftNumberM MV.cos x
   exp x = liftNumberM MV.exp x
   log x = liftNumberM MV.log x
   pow x y = liftNumberM2 MV.pow x y

-- | Selection between two tagged values. The condition is passed through
-- and the tag's flags are attached to the selected result.
instance
   (Flags flags, MultiValue a, MV.Select a) =>
      MV.Select (Number flags a) where
   select cond = liftNumberM2 (MV.select cond)

-- | Comparison of the untagged operands. Its result is returned unchanged
-- and no flags are attached.
instance
   (Flags flags, MultiValue a, MV.Comparison a) =>
      MV.Comparison (Number flags a) where
   cmp rel x y = MV.cmp rel (mvDenumber x) (mvDenumber y)

-- | Floating-point comparison of the untagged operands, without flags.
instance
   (Flags flags, MultiValue a, MV.FloatingComparison a) =>
      MV.FloatingComparison (Number flags a) where
   fcmp rel x y = MV.fcmp rel (mvDenumber x) (mvDenumber y)



-- | Values to which fast-math flags can be applied. Only the single
-- 'LLVM.Value' instance is defined here. NOTE(review): the name suggests
-- tuple instances are intended elsewhere; confirm.
class Tuple a where
   setTupleFlags ::
      (Flags flags) => Proxy flags -> Bool -> a -> LLVM.CodeGenFunction r ()

instance (LLVM.IsFloating a) => Tuple (LLVM.Value a) where
   setTupleFlags proxy enable v = setFlags proxy enable v


-- | A value of the "LLVM.Extra.Arithmetic" classes tagged with the
-- fast-math @flags@ to attach to operations on it.
newtype Context flags a = Context a

-- | The flag tag of a 'Context' as a 'Proxy'. The payload is never inspected.
proxyFromContext :: Context flags a -> Proxy flags
proxyFromContext _ = Proxy

-- | The zero tuple is wrapped without attaching any flags.
instance
   (Flags flags, Class.Zero a, Tuple a) =>
      Class.Zero (Context flags a) where
   zeroTuple = Context Class.zeroTuple

-- | Additive operations with the context's flags attached to results.
-- 'A.zero' is only wrapped.
instance
   (Flags flags, Tuple a, A.Additive a) =>
      A.Additive (Context flags a) where
   zero = Context A.zero
   add x y = liftContext2 A.add x y
   sub x y = liftContext2 A.sub x y
   neg x = liftContext A.neg x

-- | Multiplication with the context's flags attached to the result.
instance
   (Flags flags, A.PseudoRing a, Tuple a) =>
      A.PseudoRing (Context flags a) where
   mul x y = liftContext2 A.mul x y

-- | The scalar of a flagged value carries the same flags.
type instance A.Scalar (Context flags a) = Context flags (A.Scalar a)

-- | Scaling with the context's flags attached to the result.
instance
   (Flags flags, A.PseudoModule v, Tuple v, A.Scalar v ~ a, Tuple a) =>
      A.PseudoModule (Context flags v) where
   scale k x = liftContext2 A.scale k x

-- | Integer literals are wrapped without attaching any flags.
instance
   (Flags flags, Tuple a, A.IntegerConstant a) =>
      A.IntegerConstant (Context flags a) where
   fromInteger' n = Context (A.fromInteger' n)

-- | Division with the context's flags attached to the result.
instance
   (Flags flags, Tuple v, A.Field v) =>
      A.Field (Context flags v) where
   fdiv x y = liftContext2 A.fdiv x y

-- | Rational literals are wrapped without attaching any flags.
instance
   (Flags flags, Tuple a, A.RationalConstant a) =>
      A.RationalConstant (Context flags a) where
   fromRational' q = Context (A.fromRational' q)

-- | Ordering-based operations, each with the context's flags attached.
instance (Flags flags, Tuple a, A.Real a) => A.Real (Context flags a) where
   min x y = liftContext2 A.min x y
   max x y = liftContext2 A.max x y
   abs x = liftContext A.abs x
   signum x = liftContext A.signum x

-- | Integer/fraction splitting, each with the context's flags attached.
instance
   (Flags flags, Tuple a, A.Fraction a) =>
      A.Fraction (Context flags a) where
   truncate x = liftContext A.truncate x
   fraction x = liftContext A.fraction x

-- | Comparison of the payloads. Its result is the payload's 'A.CmpResult',
-- returned unchanged and without flags.
instance
   (Flags flags, Tuple a, A.Comparison a) =>
      A.Comparison (Context flags a) where
   type CmpResult (Context flags a) = A.CmpResult a
   cmp rel (Context lhs) (Context rhs) = A.cmp rel lhs rhs

-- | Floating-point comparison of the payloads, without flags.
instance
   (Flags flags, Tuple a, A.FloatingComparison a) =>
      A.FloatingComparison (Context flags a) where
   fcmp rel (Context lhs) (Context rhs) = A.fcmp rel lhs rhs

-- | Square root with the context's flags attached to the result.
instance
   (Flags flags, Tuple a, A.Algebraic a) =>
      A.Algebraic (Context flags a) where
   sqrt x = liftContext A.sqrt x

-- | Transcendental functions on flagged values. 'sin', 'cos', 'exp', 'log'
-- and 'pow' attach the context's fast-math flags to their results via
-- 'liftContext' / 'liftContext2'.
instance
   (Flags flags, Tuple a, A.Transcendental a) =>
      A.Transcendental (Context flags a) where
   -- Wrap the payload's 'A.pi' without attaching flags. This is consistent
   -- with 'fromInteger'' / 'fromRational'' for 'Context' and with @pi@ in
   -- the 'MV.Transcendental' instance for 'Number'.
   -- @attachTupleFlags A.pi@ would be wrong: at result type
   -- @Context flags a@, 'A.pi' resolves to this very method and recurses
   -- forever.
   -- NOTE(review): 'A.pi' presumably yields a constant, which cannot carry
   -- fast-math flags; confirm before attaching any.
   pi = Context <$> A.pi
   sin = liftContext A.sin
   cos = liftContext A.cos
   exp = liftContext A.exp
   log = liftContext A.log
   pow = liftContext2 A.pow


-- | Run a generator for a flagged value, enable the context's fast-math
-- flags on its payload, and return the flagged value unchanged.
attachTupleFlags ::
   (Flags flags, Tuple a) =>
   Id (LLVM.CodeGenFunction r (Context flags a))
attachTupleFlags gen =
   gen >>= \ctx@(Context payload) ->
      setTupleFlags (proxyFromContext ctx) True payload >> return ctx

-- | Lift a unary code generator to 'Context's. Its result is wrapped and
-- given the context's flags.
liftContext :: (Flags flags, Tuple b) =>
   (a -> LLVM.CodeGenFunction r b) ->
   Context flags a -> LLVM.CodeGenFunction r (Context flags b)
liftContext op (Context x) = attachTupleFlags (fmap Context (op x))

-- | Lift a binary code generator to 'Context's. Its result is wrapped and
-- given the context's flags.
liftContext2 :: (Flags flags, Tuple c) =>
   (a -> b -> LLVM.CodeGenFunction r c) ->
   Context flags a -> Context flags b ->
   LLVM.CodeGenFunction r (Context flags c)
liftContext2 op (Context x) (Context y) =
   attachTupleFlags (fmap Context (op x y))