{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-- | Wrapper types whose arithmetic instances run the underlying LLVM
-- operation and afterwards call a flag setter on the result.
--
-- Two wrappers are provided:
--
-- * 'Number' for use inside @MV.T@ and @MultiVector.T@,
-- * 'Context' for values handled by the "LLVM.Extra.Arithmetic" classes.
--
-- Which setters are called is selected by the phantom @flags@ type,
-- see 'Flags'.
module LLVM.Extra.FastMath where

import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value.Private as MV
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Class as Class

import qualified LLVM.Core as LLVM
import LLVM.Util.Proxy (Proxy(Proxy))

import Foreign.Storable (Storable)

import qualified Control.Monad.HT as Monad
import Control.Applicative ((<$>))


-- | Flag token; its 'Flags' instance calls @LLVM.setHasNoNaNs@.
data NoNaNs = NoNaNs deriving (Show, Eq)

-- | Flag token; its 'Flags' instance calls @LLVM.setHasNoInfs@.
data NoInfs = NoInfs deriving (Show, Eq)

-- | Flag token; its 'Flags' instance calls @LLVM.setHasNoSignedZeros@.
data NoSignedZeros = NoSignedZeros deriving (Show, Eq)

-- | Flag token; its 'Flags' instance calls @LLVM.setHasAllowReciprocal@.
data AllowReciprocal = AllowReciprocal deriving (Show, Eq)

-- | Flag token; its 'Flags' instance calls @LLVM.setFastMath@.
data Fast = Fast deriving (Show, Eq)


-- | A type that names one flag or, as a tuple, several flags.
class Flags flags where
  -- | Apply the flag setter(s) named by @flags@ to a floating point value.
  -- The 'Bool' is forwarded unchanged to the LLVM setter
  -- (presumably on/off -- confirm against the llvm-tf documentation).
  setFlags ::
    (LLVM.IsFloating a) =>
    Proxy flags -> Bool -> LLVM.Value a -> LLVM.CodeGenFunction r ()

instance Flags NoNaNs where
  setFlags Proxy = LLVM.setHasNoNaNs

instance Flags NoInfs where
  setFlags Proxy = LLVM.setHasNoInfs

instance Flags NoSignedZeros where
  setFlags Proxy = LLVM.setHasNoSignedZeros

instance Flags AllowReciprocal where
  setFlags Proxy = LLVM.setHasAllowReciprocal

instance Flags Fast where
  setFlags Proxy = LLVM.setFastMath

-- A pair applies the first component's setter, then the second's.
instance (Flags f0, Flags f1) => Flags (f0,f1) where
  setFlags p b v = setFlags (fst <$> p) b v >> setFlags (snd <$> p) b v

-- Larger tuples are re-associated into a head and a smaller tuple
-- and handled by the instances above.
instance (Flags f0, Flags f1, Flags f2) => Flags (f0,f1,f2) where
  setFlags = setSplitFlags $ \(f0,f1,f2) -> (f0,(f1,f2))

instance
  (Flags f0, Flags f1, Flags f2, Flags f3) =>
    Flags (f0,f1,f2,f3) where
  setFlags = setSplitFlags $ \(f0,f1,f2,f3) -> (f0,(f1,f2,f3))

instance
  (Flags f0, Flags f1, Flags f2, Flags f3, Flags f4) =>
    Flags (f0,f1,f2,f3,f4) where
  setFlags = setSplitFlags $ \(f0,f1,f2,f3,f4) -> (f0,(f1,f2,f3,f4))

-- | Run 'setFlags' for the flag type obtained by mapping @split@ over the
-- proxy. Only the type of @split@ matters; the proxy carries no value.
setSplitFlags ::
  (Flags split, LLVM.IsFloating a) =>
  (flags -> split) ->
  Proxy flags -> Bool -> LLVM.Value a -> LLVM.CodeGenFunction r ()
setSplitFlags split p = setFlags (fmap split p)


-- | A value of type @a@ tagged with a phantom @flags@ type.
-- The numeric classes are derived from the wrapped type.
newtype Number flags a = Number {deconsNumber :: a}
  deriving (Eq, Ord, Show, Num, Fractional, Floating, Storable)

-- | Unwrap a 'Number'. The first argument only fixes the @flags@ type;
-- its value is ignored.
getNumber :: flags -> Number flags a -> a
getNumber _ (Number a) = a

-- 'Number' shares the representation of the wrapped type;
-- all methods delegate to the instance of @a@ and rewrap.
instance MultiValue a => MV.C (Number flags a) where
  type Repr f (Number flags a) = MV.Repr f a
  cons = mvNumber . MV.cons . deconsNumber
  undef = mvNumber MV.undef
  zero = mvNumber MV.zero
  phis bb = fmap mvNumber . MV.phis bb . mvDenumber
  addPhis bb a b = MV.addPhis bb (mvDenumber a) (mvDenumber b)

-- | Retag an @MV.T a@ as @MV.T (Number flags a)@; the payload is unchanged.
mvNumber :: MV.T a -> MV.T (Number flags a)
mvNumber (MV.Cons a) = MV.Cons a

-- | Inverse of 'mvNumber': drop the 'Number' tag, keeping the payload.
mvDenumber :: MV.T (Number flags a) -> MV.T a
mvDenumber (MV.Cons a) = MV.Cons a

-- | Element types whose 'Number'-tagged multi-values can receive flags.
class MV.C a => MultiValue a where
  -- | Apply the setters named by @flags@ to the given multi-value.
  setMultiValueFlags ::
    (Flags flags) =>
    Proxy flags -> Bool -> MV.T (Number flags a) -> LLVM.CodeGenFunction r ()

instance MultiValue Float where
  setMultiValueFlags p b (MV.Cons a) = setFlags p b a

instance MultiValue Double where
  setMultiValueFlags p b (MV.Cons a) = setFlags p b a

-- | Shorthand for an endomorphism.
type Id a = a -> a

-- | Run the action, call 'setMultiValueFlags' with 'True' on its result
-- and return that result. The flag set is taken from the result type.
attachMultiValueFlags ::
  (Flags flags, MultiValue a) =>
  Id (LLVM.CodeGenFunction r (MV.T (Number flags a)))
attachMultiValueFlags act = do
  mv <- act
  setMultiValueFlags Proxy True mv
  return mv

-- | Lift a unary operation to 'Number'-tagged multi-values:
-- untag the argument, run @f@, tag the result and attach the flags.
liftNumberM ::
  (m ~ LLVM.CodeGenFunction r, Flags flags, MultiValue b) =>
  (MV.T a -> m (MV.T b)) ->
  MV.T (Number flags a) -> m (MV.T (Number flags b))
liftNumberM f =
  attachMultiValueFlags . Monad.lift mvNumber . f . mvDenumber

-- | Binary counterpart of 'liftNumberM'.
liftNumberM2 ::
  (m ~ LLVM.CodeGenFunction r, Flags flags, MultiValue c) =>
  (MV.T a -> MV.T b -> m (MV.T c)) ->
  MV.T (Number flags a) -> MV.T (Number flags b) ->
  m (MV.T (Number flags c))
liftNumberM2 f a b =
  attachMultiValueFlags $ Monad.lift mvNumber $
    f (mvDenumber a) (mvDenumber b)


instance (Flags flags, MV.Compose a) => MV.Compose (Number flags a) where
  type Composed (Number flags a) = Number flags (MV.Composed a)
  compose = mvNumber . MV.compose . deconsNumber

instance
  (Flags flags, MV.Decompose pa) =>
    MV.Decompose (Number flags pa) where
  decompose (Number p) = Number . MV.decompose p . mvDenumber

type instance MV.Decomposed f (Number flags pa) =
  Number flags (MV.Decomposed f pa)

type instance MV.PatternTuple (Number flags pa) =
  Number flags (MV.PatternTuple pa)


-- Constants are only retagged; no flags are set on them.
instance
  (Flags flags, MultiValue a, MV.IntegerConstant a) =>
    MV.IntegerConstant (Number flags a) where
  fromInteger' = mvNumber . MV.fromInteger'

instance
  (Flags flags, MultiValue a, MV.RationalConstant a) =>
    MV.RationalConstant (Number flags a) where
  fromRational' = mvNumber . MV.fromRational'

-- The arithmetic instances below run the operation of the wrapped type
-- through 'liftNumberM' / 'liftNumberM2', i.e. with flags attached.
instance
  (Flags flags, MultiValue a, MV.Additive a) =>
    MV.Additive (Number flags a) where
  add = liftNumberM2 MV.add
  sub = liftNumberM2 MV.sub
  neg = liftNumberM MV.neg

instance
  (Flags flags, MultiValue a, MV.PseudoRing a) =>
    MV.PseudoRing (Number flags a) where
  mul = liftNumberM2 MV.mul

instance
  (Flags flags, MultiValue a, MV.Field a) =>
    MV.Field (Number flags a) where
  fdiv = liftNumberM2 MV.fdiv

type instance MV.Scalar (Number flags a) = Number flags (MV.Scalar a)

instance
  (Flags flags, MultiValue a, a ~ MV.Scalar v,
   MultiValue v, MV.PseudoModule v) =>
    MV.PseudoModule (Number flags v) where
  scale = liftNumberM2 MV.scale

instance
  (Flags flags, MultiValue a, MV.Real a) =>
    MV.Real (Number flags a) where
  min = liftNumberM2 MV.min
  max = liftNumberM2 MV.max
  abs = liftNumberM MV.abs
  signum = liftNumberM MV.signum

instance
  (Flags flags, MultiValue a, MV.Fraction a) =>
    MV.Fraction (Number flags a) where
  truncate = liftNumberM MV.truncate
  fraction = liftNumberM MV.fraction

instance
  (Flags flags, MultiValue a, MV.Algebraic a) =>
    MV.Algebraic (Number flags a) where
  sqrt = liftNumberM MV.sqrt

instance
  (Flags flags, MultiValue a, MV.Transcendental a) =>
    MV.Transcendental (Number flags a) where
  -- 'pi' is only retagged, no flags are attached.
  pi = fmap mvNumber MV.pi
  sin = liftNumberM MV.sin
  cos = liftNumberM MV.cos
  exp = liftNumberM MV.exp
  log = liftNumberM MV.log
  pow = liftNumberM2 MV.pow

instance
  (Flags flags, MultiValue a, MV.Select a) =>
    MV.Select (Number flags a) where
  select = liftNumberM2 . MV.select

-- Comparisons run on the untagged values; the result is not a 'Number'
-- and receives no flags.
instance
  (Flags flags, MultiValue a, MV.Comparison a) =>
    MV.Comparison (Number flags a) where
  cmp p a b = MV.cmp p (mvDenumber a) (mvDenumber b)

instance
  (Flags flags, MultiValue a, MV.FloatingComparison a) =>
    MV.FloatingComparison (Number flags a) where
  fcmp p a b = MV.fcmp p (mvDenumber a) (mvDenumber b)


-- | Retag a @MultiVector.T n a@ with 'Number'; the payload is unchanged.
mvecNumber :: MultiVector.T n a -> MultiVector.T n (Number flags a)
mvecNumber (MultiVector.Cons v) = MultiVector.Cons v

-- | Inverse of 'mvecNumber'.
mvecDenumber :: MultiVector.T n (Number flags a) -> MultiVector.T n a
mvecDenumber (MultiVector.Cons v) = MultiVector.Cons v

-- | Element types whose 'Number'-tagged multi-vectors can receive flags.
class (MultiValue a, MultiVector.C a) => MultiVector a where
  -- | Apply the setters named by @flags@ to the given multi-vector.
  setMultiVectorFlags ::
    (Flags flags, LLVM.Positive n) =>
    Proxy flags -> Bool ->
    MultiVector.T n (Number flags a) -> LLVM.CodeGenFunction r ()

instance MultiVector Float where
  setMultiVectorFlags p b =
    setFlags p b . MultiVector.deconsPrim . mvecDenumber

instance MultiVector Double where
  setMultiVectorFlags p b =
    setFlags p b . MultiVector.deconsPrim . mvecDenumber

-- | Run the action, call 'setMultiVectorFlags' with 'True' on its result
-- and return that result. The flag set is taken from the result type.
attachMultiVectorFlags ::
  (LLVM.Positive n, Flags flags, MultiVector a) =>
  Id (LLVM.CodeGenFunction r (MultiVector.T n (Number flags a)))
attachMultiVectorFlags act = do
  mv <- act
  setMultiVectorFlags Proxy True mv
  return mv

-- | Lift a unary operation to 'Number'-tagged multi-vectors:
-- untag the argument, run @f@, tag the result and attach the flags.
liftMultiVectorM ::
  (m ~ LLVM.CodeGenFunction r, LLVM.Positive n,
   Flags flags, MultiVector b) =>
  (MultiVector.T n a -> m (MultiVector.T n b)) ->
  MultiVector.T n (Number flags a) -> m (MultiVector.T n (Number flags b))
liftMultiVectorM f =
  attachMultiVectorFlags . Monad.lift mvecNumber . f . mvecDenumber

-- | Binary counterpart of 'liftMultiVectorM'.
liftMultiVectorM2 ::
  (m ~ LLVM.CodeGenFunction r, LLVM.Positive n,
   Flags flags, MultiVector c) =>
  (MultiVector.T n a -> MultiVector.T n b -> m (MultiVector.T n c)) ->
  MultiVector.T n (Number flags a) -> MultiVector.T n (Number flags b) ->
  m (MultiVector.T n (Number flags c))
liftMultiVectorM2 f a b =
  attachMultiVectorFlags $ Monad.lift mvecNumber $
    f (mvecDenumber a) (mvecDenumber b)

-- Structural vector operations delegate to the instance of @a@ and retag;
-- none of them sets flags.
instance
  (Flags flags, MultiVector a) =>
    MultiVector.C (Number flags a) where
  cons = mvecNumber . MultiVector.cons . fmap deconsNumber
  undef = mvecNumber MultiVector.undef
  zero = mvecNumber MultiVector.zero
  phis bb = fmap mvecNumber . MultiVector.phis bb . mvecDenumber
  addPhis bb a b =
    MultiVector.addPhis bb (mvecDenumber a) (mvecDenumber b)
  shuffle ks a b =
    fmap mvecNumber $
    MultiVector.shuffle ks (mvecDenumber a) (mvecDenumber b)
  extract k = fmap mvNumber . MultiVector.extract k . mvecDenumber
  insert k x =
    fmap mvecNumber . MultiVector.insert k (mvDenumber x) . mvecDenumber

instance
  (Flags flags, MultiVector a,
   MV.IntegerConstant a, MultiVector.IntegerConstant a) =>
    MultiVector.IntegerConstant (Number flags a) where
  fromInteger' = mvecNumber . MultiVector.fromInteger'

instance
  (Flags flags, MultiVector a,
   MV.RationalConstant a, MultiVector.RationalConstant a) =>
    MultiVector.RationalConstant (Number flags a) where
  fromRational' = mvecNumber . MultiVector.fromRational'

instance
  (Flags flags, MultiVector a, MultiVector.Additive a) =>
    MultiVector.Additive (Number flags a) where
  add = liftMultiVectorM2 MultiVector.add
  sub = liftMultiVectorM2 MultiVector.sub
  neg = liftMultiVectorM MultiVector.neg

instance
  (Flags flags, MultiVector a, MultiVector.PseudoRing a) =>
    MultiVector.PseudoRing (Number flags a) where
  mul = liftMultiVectorM2 MultiVector.mul

instance
  (Flags flags, MultiVector a, MultiVector.Field a) =>
    MultiVector.Field (Number flags a) where
  fdiv = liftMultiVectorM2 MultiVector.fdiv

{-
type instance MultiValue.Scalar (Number flags a) = Number flags (MultiValue.Scalar a)

instance
  (Flags flags, MultiVector a, MultiVector.PseudoModule a) =>
    MultiVector.PseudoModule (Number flags a) where
  scale = liftMultiVectorM2 MultiVector.mul
-}

instance
  (Flags flags, MultiVector a, MultiVector.Real a) =>
    MultiVector.Real (Number flags a) where
  min = liftMultiVectorM2 MultiVector.min
  max = liftMultiVectorM2 MultiVector.max
  abs = liftMultiVectorM MultiVector.abs
  signum = liftMultiVectorM MultiVector.signum

instance
  (Flags flags, MultiVector a, MultiVector.Fraction a) =>
    MultiVector.Fraction (Number flags a) where
  truncate = liftMultiVectorM MultiVector.truncate
  fraction = liftMultiVectorM MultiVector.fraction

instance
  (Flags flags, MultiVector a, MultiVector.Algebraic a) =>
    MultiVector.Algebraic (Number flags a) where
  sqrt = liftMultiVectorM MultiVector.sqrt

instance
  (Flags flags, MultiVector a, MultiVector.Transcendental a) =>
    MultiVector.Transcendental (Number flags a) where
  -- 'pi' is only retagged, no flags are attached.
  pi = fmap mvecNumber MultiVector.pi
  sin = liftMultiVectorM MultiVector.sin
  cos = liftMultiVectorM MultiVector.cos
  exp = liftMultiVectorM MultiVector.exp
  log = liftMultiVectorM MultiVector.log
  pow = liftMultiVectorM2 MultiVector.pow

instance
  (Flags flags, MultiVector a, MultiVector.Select a) =>
    MultiVector.Select (Number flags a) where
  select = liftMultiVectorM2 . MultiVector.select

-- Comparisons run on the untagged vectors; no flags are set.
instance
  (Flags flags, MultiVector a, MultiVector.Comparison a) =>
    MultiVector.Comparison (Number flags a) where
  cmp p a b = MultiVector.cmp p (mvecDenumber a) (mvecDenumber b)

instance
  (Flags flags, MultiVector a, MultiVector.FloatingComparison a) =>
    MultiVector.FloatingComparison (Number flags a) where
  fcmp p a b = MultiVector.fcmp p (mvecDenumber a) (mvecDenumber b)


-- | Values (or collections of values) that flags can be applied to.
class Tuple a where
  -- | Apply the setters named by @flags@ to the given value.
  setTupleFlags ::
    (Flags flags) =>
    Proxy flags -> Bool -> a -> LLVM.CodeGenFunction r ()

instance (LLVM.IsFloating a) => Tuple (LLVM.Value a) where
  setTupleFlags = setFlags

-- | A value of type @a@ tagged with a phantom @flags@ type, for use with
-- the "LLVM.Extra.Arithmetic" classes.
newtype Context flags a = Context a

-- | Recover a proxy for the @flags@ tag; the wrapped value is not inspected.
proxyFromContext :: Context flags a -> Proxy flags
proxyFromContext (Context _) = Proxy


instance
  (Flags flags, Class.Zero a, Tuple a) =>
    Class.Zero (Context flags a) where
  zeroTuple = Context Class.zeroTuple

-- Operations go through 'liftContext' / 'liftContext2' and hence get the
-- flags attached; constants such as 'zero' are only wrapped.
instance
  (Flags flags, Tuple a, A.Additive a) =>
    A.Additive (Context flags a) where
  zero = Context A.zero
  add = liftContext2 A.add
  sub = liftContext2 A.sub
  neg = liftContext A.neg

instance
  (Flags flags, A.PseudoRing a, Tuple a) =>
    A.PseudoRing (Context flags a) where
  mul = liftContext2 A.mul

type instance A.Scalar (Context flags a) = Context flags (A.Scalar a)

instance
  (Flags flags, A.PseudoModule v, Tuple v, A.Scalar v ~ a, Tuple a) =>
    A.PseudoModule (Context flags v) where
  scale = liftContext2 A.scale

instance
  (Flags flags, Tuple a, A.IntegerConstant a) =>
    A.IntegerConstant (Context flags a) where
  fromInteger' = Context . A.fromInteger'

instance
  (Flags flags, Tuple v, A.Field v) =>
    A.Field (Context flags v) where
  fdiv = liftContext2 A.fdiv

instance
  (Flags flags, Tuple a, A.RationalConstant a) =>
    A.RationalConstant (Context flags a) where
  fromRational' = Context . A.fromRational'

instance
  (Flags flags, Tuple a, A.Real a) =>
    A.Real (Context flags a) where
  min = liftContext2 A.min
  max = liftContext2 A.max
  abs = liftContext A.abs
  signum = liftContext A.signum

instance
  (Flags flags, Tuple a, A.Fraction a) =>
    A.Fraction (Context flags a) where
  truncate = liftContext A.truncate
  fraction = liftContext A.fraction

-- Comparisons run on the unwrapped values; the result type is that of the
-- wrapped type and no flags are set.
instance
  (Flags flags, Tuple a, A.Comparison a) =>
    A.Comparison (Context flags a) where
  type CmpResult (Context flags a) = A.CmpResult a
  cmp p (Context x) (Context y) = A.cmp p x y

instance
  (Flags flags, Tuple a, A.FloatingComparison a) =>
    A.FloatingComparison (Context flags a) where
  fcmp p (Context x) (Context y) = A.fcmp p x y

instance
  (Flags flags, Tuple a, A.Algebraic a) =>
    A.Algebraic (Context flags a) where
  sqrt = liftContext A.sqrt

instance
  (Flags flags, Tuple a, A.Transcendental a) =>
    A.Transcendental (Context flags a) where
  -- Wrap the constant of the underlying type, as 'zero', 'fromInteger''
  -- and 'fromRational'' do. Do not write @attachTupleFlags A.pi@ here:
  -- 'attachTupleFlags' expects an action returning @Context flags a@, so
  -- @A.pi@ would resolve to this very method and the definition would
  -- refer to itself without ever producing a value.
  pi = Context <$> A.pi
  sin = liftContext A.sin
  cos = liftContext A.cos
  exp = liftContext A.exp
  log = liftContext A.log
  pow = liftContext2 A.pow


-- | Run the action, call 'setTupleFlags' with 'True' on the value inside
-- the resulting 'Context' and return that 'Context'.
attachTupleFlags ::
  (Flags flags, Tuple a) =>
  Id (LLVM.CodeGenFunction r (Context flags a))
attachTupleFlags act = do
  c@(Context x) <- act
  setTupleFlags (proxyFromContext c) True x
  return c

-- | Lift a unary operation to 'Context': unwrap the argument, run @f@,
-- wrap the result and attach the flags.
liftContext ::
  (Flags flags, Tuple b) =>
  (a -> LLVM.CodeGenFunction r b) ->
  Context flags a -> LLVM.CodeGenFunction r (Context flags b)
liftContext f (Context x) = attachTupleFlags (Context <$> f x)

-- | Binary counterpart of 'liftContext'; both arguments must carry the
-- same @flags@ tag.
liftContext2 ::
  (Flags flags, Tuple c) =>
  (a -> b -> LLVM.CodeGenFunction r c) ->
  Context flags a -> Context flags b ->
  LLVM.CodeGenFunction r (Context flags c)
liftContext2 f (Context x) = liftContext $ f x