{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
module LLVM.Util.Arithmetic(
    TValue,
    (%==), (%/=), (%<), (%<=), (%>), (%>=),
    (%&&), (%||),
    (?), (??),
    retrn, set,
    ArithFunction, arithFunction,
    Return,
    ToArithFunction, toArithFunction, recursiveFunction,
    CallIntrinsic,
    ) where

import qualified LLVM.Util.Intrinsic as Intrinsic
import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (mapVector, mapVector2)
import LLVM.Core.CodeGen (UnValue, CodeValue, CodeResult)
import LLVM.Core

import qualified Type.Data.Num.Decimal.Number as Dec

import Control.Monad (liftM2)


-- |Shorthand for a code generator that yields a 'Value':
-- @TValue r a@ is @CodeGenFunction r (Value a)@.
type TValue r a = CodeGenFunction r (Value a)

infix 4 %==, %/=, %<, %<=, %>=, %>

-- |Comparisons on generated values.
-- Both operand generators are run (left first) and the results are
-- compared with 'LLVM.cmp' using the corresponding predicate.
(%==), (%/=), (%<), (%<=), (%>), (%>=) ::
    (CmpRet a) => TValue r a -> TValue r a -> TValue r (CmpResult a)
(%==) = binop (LLVM.cmp CmpEQ)
(%/=) = binop (LLVM.cmp CmpNE)
(%>)  = binop (LLVM.cmp CmpGT)
(%>=) = binop (LLVM.cmp CmpGE)
(%<)  = binop (LLVM.cmp CmpLT)
(%<=) = binop (LLVM.cmp CmpLE)

infixr 3 %&&
infixr 2 %||

-- |Lazy and: the code of the right operand is placed in the branch
-- that is taken only when the left operand is true.
(%&&) :: TValue r Bool -> TValue r Bool -> TValue r Bool
(%&&) a b = a ? (b, return (valueOf False))

-- |Lazy or: the code of the right operand is placed in the branch
-- that is taken only when the left operand is false.
(%||) :: TValue r Bool -> TValue r Bool -> TValue r Bool
(%||) a b = a ? (return (valueOf True), b)

infix 0 ?

-- |Branching conditional.
-- Yields the first component of the pair when the condition is true
-- and the second component otherwise.
-- Each alternative is generated in a basic block of its own
-- and the results are merged with a 'phi' node.
(?) :: (IsFirstClass a) =>
    TValue r Bool -> (TValue r a, TValue r a) -> TValue r a
cond ? (onTrue, onFalse) = do
    thenBlock <- newBasicBlock
    elseBlock <- newBasicBlock
    joinBlock <- newBasicBlock
    condValue <- cond
    condBr condValue thenBlock elseBlock

    defineBasicBlock thenBlock
    thenValue <- onTrue
    -- The alternative may itself have opened new blocks (e.g. a nested '?'),
    -- so the phi node must refer to the block that is current afterwards.
    thenEnd <- getCurrentBasicBlock
    br joinBlock

    defineBasicBlock elseBlock
    elseValue <- onFalse
    elseEnd <- getCurrentBasicBlock
    br joinBlock

    defineBasicBlock joinBlock
    phi [(thenValue, thenEnd), (elseValue, elseEnd)]

infix 0 ??

-- |Non-branching conditional.
-- The condition and both alternatives are generated unconditionally,
-- in that order, and the result is chosen with 'select'.
(??) :: (IsFirstClass a, CmpRet a) =>
    TValue r (CmpResult a) -> (TValue r a, TValue r a) -> TValue r a
cond ?? (onTrue, onFalse) = do
    condValue <- cond
    trueValue <- onTrue
    falseValue <- onFalse
    select condValue trueValue falseValue

-- | Return a value from an 'arithFunction':
-- run the generator and pass its result to 'ret'.
retrn :: TValue a a -> CodeGenFunction a ()
retrn code = ret =<< code

-- | Use @x <- set $ ...@ to make a binding.
-- The generator is run once; the action handed back merely returns
-- the value that was computed.
set :: TValue r a -> CodeGenFunction r (TValue r a)
set code = code >>= \value -> return (return value)

-- Code generators cannot be compared on the Haskell level.
-- NOTE(review): this instance and the 'Ord' instance below are presumably
-- present only to satisfy superclass requirements of the numeric classes.
instance Eq (CodeGenFunction r av) where
    (==) = error "CodeGenFunction Value: (==)"

instance Ord (CodeGenFunction r av) where
    compare = error "CodeGenFunction Value: compare"

-- Arithmetic on generators: operands are run left to right
-- and combined with the respective LLVM instruction.
instance (IsArithmetic a, CmpRet a, Num a, IsConst a, Value a ~ av) =>
        Num (CodeGenFunction r av) where
    (+) = binop add
    (-) = binop sub
    (*) = binop mul
    negate code = neg =<< code
    -- 'abs' and 'signum' use the non-branching conditional '??'.
    abs x = (x %< 0) ?? (-x, x)
    signum x = (x %< 0) ?? (-1, (x %> 0) ?? (1, 0))
    fromInteger n = return (valueOf (fromInteger n))

instance (IsArithmetic a, CmpRet a, Num a, IsConst a, Value a ~ av) =>
        Enum (CodeGenFunction r av) where
    succ code = code + 1
    pred code = code - 1
    fromEnum _ = error "CodeGenFunction Value: fromEnum"
    toEnum = fromIntegral

instance (IsArithmetic a, CmpRet a, Num a, IsConst a, Value a ~ av) =>
        Real (CodeGenFunction r av) where
    toRational _ = error "CodeGenFunction Value: toRational"

instance (CmpRet a, Num a, IsConst a, IsInteger a, Value a ~ av) =>
        Integral (CodeGenFunction r av) where
    quot = binop idiv
    rem  = binop irem
    quotRem n d = (n `quot` d, n `rem` d)
    toInteger _ = error "CodeGenFunction Value: toInteger"

instance (CmpRet a, Fractional a, IsConst a, IsFloating a, Value a ~ av) =>
        Fractional (CodeGenFunction r av) where
    (/) = binop fdiv
    fromRational q = return (valueOf (fromRational q))

instance (CmpRet a, Fractional a, IsConst a, IsFloating a, Value a ~ av) =>
        RealFrac (CodeGenFunction r av) where
    properFraction _ = error "CodeGenFunction Value: properFraction"

-- Functions that are passed to 'callIntrinsic1' or 'callIntrinsic2' are
-- looked up by name; the inverse trigonometric functions fail with 'error';
-- the hyperbolic functions are expressed via 'exp', 'log' and 'sqrt'.
instance (CmpRet a, CallIntrinsic a, Floating a, IsConst a, IsFloating a,
          Value a ~ av) =>
        Floating (CodeGenFunction r av) where
    pi = return (valueOf pi)
    sqrt = callIntrinsic1 "sqrt"
    sin = callIntrinsic1 "sin"
    cos = callIntrinsic1 "cos"
    (**) = callIntrinsic2 "pow"
    exp = callIntrinsic1 "exp"
    log = callIntrinsic1 "log"

    asin _ = error "LLVM missing intrinsic: asin"
    acos _ = error "LLVM missing intrinsic: acos"
    atan _ = error "LLVM missing intrinsic: atan"

    sinh x = (exp x - exp (-x)) / 2
    cosh x = (exp x + exp (-x)) / 2
    asinh x = log (x + sqrt (x*x + 1))
    acosh x = log (x + sqrt (x*x - 1))
    atanh x = (log (1 + x) - log (1 - x)) / 2

-- Static properties are taken from the underlying Haskell type @a@;
-- everything that would need to inspect a generated value fails with 'error'.
instance (CmpRet a, CallIntrinsic a, RealFloat a, IsConst a, IsFloating a,
          Value a ~ av) =>
        RealFloat (CodeGenFunction r av) where
    floatRadix _ = floatRadix (undefined :: a)
    floatDigits _ = floatDigits (undefined :: a)
    floatRange _ = floatRange (undefined :: a)
    decodeFloat _ = error "CodeGenFunction Value: decodeFloat"
    encodeFloat _ _ = error "CodeGenFunction Value: encodeFloat"
    exponent _ = 0
    scaleFloat 0 x = x
    scaleFloat _ _ = error "CodeGenFunction Value: scaleFloat"
    isNaN _ = error "CodeGenFunction Value: isNaN"
    isInfinite _ = error "CodeGenFunction Value: isInfinite"
    isDenormalized _ = error "CodeGenFunction Value: isDenormalized"
    isNegativeZero _ = error "CodeGenFunction Value: isNegativeZero"
    isIEEE _ = isIEEE (undefined :: a)

-- Run the left generator, then the right one,
-- and feed both results to the given operation.
binop ::
    (Value a -> Value b -> TValue r c) ->
    TValue r a -> TValue r b -> TValue r c
binop op genX genY =
    genX >>= \x ->
    genY >>= \y ->
    op x y

-------------------------------------------

{- |
Turn @(a -> b -> CodeGenFunction r c)@
into @(a -> b -> CodeGenFunction r ())@
for @r ~ Result c@
-}
class (RetB a ~ b, CodeValue a ~ z, RetA z b ~ a) => Return z a b where
    type RetA z b
    type RetB a
    addRet :: a -> b

-- Base case: run the generator and pass its result to 'ret'.
instance (Ret z, Result z ~ r, r ~ ra, r ~ rb, z ~ a, unit ~ ()) =>
        Return z (CodeGenFunction ra a) (CodeGenFunction rb unit) where
    type RetA z (CodeGenFunction rb unit) = CodeGenFunction (Result z) z
    type RetB (CodeGenFunction ra a) = CodeGenFunction ra ()
    addRet code = code >>= ret

-- Function case: leave the argument alone and recurse into the result.
instance (Return z b0 b1, a0 ~ a1) => Return z (a0 -> b0) (a1 -> b1) where
    type RetA z (a1 -> b1) = a1 -> RetA z b1
    type RetB (a0 -> b0) = a0 -> RetB b0
    addRet f x = addRet (f x)


class (FunA r b ~ a, FunB a ~ b, CodeResult a ~ r) => ArithFunction r a b where
    type FunA r b
    type FunB a
    arithFunction' :: a -> b

-- Base case: a plain generator is passed through unchanged.
instance (r ~ ra, r ~ rb, a ~ b) =>
        ArithFunction r (CodeGenFunction ra a) (CodeGenFunction rb b) where
    type FunA r (CodeGenFunction rb b) = CodeGenFunction r b
    type FunB (CodeGenFunction ra a) = CodeGenFunction ra a
    arithFunction' = id

-- Function case: the plain argument is wrapped with 'return'
-- before it is handed to the generator-level function.
instance (ArithFunction r b0 b1, a0 ~ CodeGenFunction r a1) =>
        ArithFunction r (a0 -> b0) (a1 -> b1) where
    type FunA r (a1 -> b1) = CodeGenFunction r a1 -> FunA r b1
    type FunB (a0 -> b0) = CodeValue a0 -> FunB b0
    arithFunction' f x = arithFunction' (f (return x))

-- |Unlift a function with @TValue@ to have @Value@ arguments.
-- The final result is additionally passed to 'ret' (see 'addRet').
arithFunction ::
    (ArithFunction r a b, r ~ Result z, Return z b c) => a -> c
arithFunction f = addRet (arithFunction' f)


class (TFunB r a ~ b, TFunA b ~ a, CodeResult b ~ r) =>
        ToArithFunction r a b where
    type TFunA b
    type TFunB r a
    toArithFunction' :: CodeGenFunction r (Call a) -> b

-- Base case: all arguments have been applied, so perform the call.
instance (Value a ~ b) => ToArithFunction r (IO a) (CodeGenFunction r b) where
    type TFunA (CodeGenFunction r b) = IO (UnValue b)
    type TFunB r (IO a) = CodeGenFunction r (Value a)
    toArithFunction' partialCall = partialCall >>= runCall

-- Function case: run the call built so far, then the argument generator,
-- and add the argument to the call with 'applyCall'.
instance (ToArithFunction r b0 b1, CodeGenFunction r (Value a0) ~ a1) =>
        ToArithFunction r (a0 -> b0) (a1 -> b1) where
    type TFunA (a1 -> b1) = UnValue (CodeValue a1) -> TFunA b1
    type TFunB r (a0 -> b0) = CodeGenFunction r (Value a0) -> TFunB r b0
    toArithFunction' partialCall arg =
        toArithFunction' $ liftM2 applyCall partialCall arg

-- Hand-written two-argument variant of 'toArithFunction'.
-- It is neither exported nor used within this part of the module.
_toArithFunction2 ::
    Function (a -> b -> IO c) -> TValue r a -> TValue r b -> TValue r c
_toArithFunction2 f =
    binop $ \x y -> runCall (callFromFunction f `applyCall` x `applyCall` y)

-- |Lift a function from having @Value@ arguments to having @TValue@ arguments.
toArithFunction :: (ToArithFunction r f g) => Function f -> g
toArithFunction fn = toArithFunction' (return (callFromFunction fn))

-------------------------------------------

-- |Define a recursive 'arithFunction'; the function being defined is passed to the builder as its first argument.
recursiveFunction ::
    (IsFunction f, FunctionArgs f, code ~ FunctionCodeGen f,
     ArithFunction r arith open, r ~ Result z, Return z open code,
     ToArithFunction r f g) =>
    (g -> arith) -> CodeGenModule (Function f)
recursiveFunction af = do
    -- Create the function (external linkage) before defining it,
    -- so that the body can refer to the function itself.
    self <- newFunction ExternalLinkage
    -- The builder receives the lifted ('toArithFunction') form of the function;
    -- its result is unlifted again with 'arithFunction' to become the body.
    defineFunction self (arithFunction (af (toArithFunction self)))
    return self

-------------------------------------------

-- |Types for which a unary or binary operation can be requested by name.
-- The 'String' is the operation name (this module passes names such as
-- @\"sqrt\"@ and @\"pow\"@); the remaining arguments are the operands.
class CallIntrinsic a where
    callIntrinsic1' :: String -> Value a -> TValue r a
    callIntrinsic2' :: String -> Value a -> Value a -> TValue r a

instance CallIntrinsic Float where
    callIntrinsic1' = Intrinsic.call1
    callIntrinsic2' = Intrinsic.call2

instance CallIntrinsic Double where
    callIntrinsic1' = Intrinsic.call1
    callIntrinsic2' = Intrinsic.call2

{-
I think such a special case for certain systems
would be better handled as in LLVM.Extra.Extension.
(lemming)
-}
-- 'True' exactly when the preprocessor symbol @__MACOS__@ is defined.
macOS :: Bool
#if defined(__MACOS__)
macOS = True
#else
macOS = False
#endif

-- Vectors are handled element-wise via 'mapVector' and 'mapVector2',
-- except for the special case described at 'callIntrinsic1'' below.
instance (Dec.Positive n, IsPrimitive a, CallIntrinsic a) =>
        CallIntrinsic (Vector n a) where
    -- When 'macOS' is set, the vector has exactly four elements and the
    -- operation is one of the listed names, a single external function
    -- called @v\<name\>f@ is applied to the whole vector.
    -- NOTE(review): the @f@ suffix looks like a single-precision routine, yet
    -- this branch is not restricted to 'Float' elements - confirm with callers.
    callIntrinsic1' s x
        | macOS && vectorSize == 4 && s `elem` externNames = do
            op <- externFunction ("v" ++ s ++ "f")
            call op x
        | otherwise = mapVector (callIntrinsic1' s) x
      where
        vectorSize =
            Dec.integerFromSingleton (Dec.singleton :: Dec.Singleton n)
        externNames = ["sqrt", "log", "exp", "sin", "cos", "tan"]
    callIntrinsic2' s = mapVector2 (callIntrinsic2' s)

-- Run the operand generator and apply the named unary operation to its result.
callIntrinsic1 :: (CallIntrinsic a) =>
    String -> TValue r a -> TValue r a
callIntrinsic1 name operand = operand >>= callIntrinsic1' name

-- Run both operand generators (left first)
-- and apply the named binary operation to the results.
callIntrinsic2 :: (CallIntrinsic a) =>
    String -> TValue r a -> TValue r a -> TValue r a
callIntrinsic2 name lhs rhs = binop (callIntrinsic2' name) lhs rhs