{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TypeFamilies #-}
module LLVM.Util.Arithmetic(
    TValue,
    (%==), (%/=), (%<), (%<=), (%>), (%>=),
    (%&&), (%||),
    (?), (??),
    retrn, set,
    ArithFunction, arithFunction,
    ToArithFunction, toArithFunction, recursiveFunction,
    CallIntrinsic,
    ) where

import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (mapVector, mapVector2)
import LLVM.Util.Proxy (Proxy(Proxy))
import LLVM.Core

import qualified Type.Data.Num.Decimal.Number as Dec

import Control.Monad (liftM2)

-- | Synonym for @CodeGenFunction r (Value a)@: a code-generating action
-- (parameterised by @r@) that yields an LLVM value of type @a@.
--
-- The numeric class instances in this module are defined on this type, so
-- that ordinary arithmetic expressions emit the corresponding instructions.
type TValue r a = CodeGenFunction r (Value a)


infix  4  %==, %/=, %<, %<=, %>, %>=
-- | Lifted comparisons on 'TValue's.
--
-- Both operands are generated (left operand first) and the results are
-- compared with 'LLVM.cmp' using the matching predicate.  The result has
-- type @'CmpResult' a@ as given by the 'CmpRet' class.
(%==), (%/=), (%<), (%<=), (%>), (%>=) :: (CmpRet a) => TValue r a -> TValue r a -> TValue r (CmpResult a)
x %== y = binop (LLVM.cmp CmpEQ) x y
x %/= y = binop (LLVM.cmp CmpNE) x y
x %<  y = binop (LLVM.cmp CmpLT) x y
x %<= y = binop (LLVM.cmp CmpLE) x y
x %>  y = binop (LLVM.cmp CmpGT) x y
x %>= y = binop (LLVM.cmp CmpGE) x y

infixr 3  %&&
infixr 2  %||
-- | Short-circuiting (lazy) conjunction.
--
-- Built on '?', so the code for the right operand lives in the branch that
-- is taken only when the left operand is true.
(%&&) :: TValue r Bool -> TValue r Bool -> TValue r Bool
a %&& b = a ? (b, false)
  where false = return (valueOf False)

-- | Short-circuiting (lazy) disjunction.
--
-- Built on '?', so the code for the right operand lives in the branch that
-- is taken only when the left operand is false.
(%||) :: TValue r Bool -> TValue r Bool -> TValue r Bool
a %|| b = a ? (true, b)
  where true = return (valueOf True)

infix  0 ?
-- | Conditional expression: yields the first element of the pair when the
-- condition is true, otherwise the second.
--
-- Emits real control flow: a conditional branch into two fresh basic
-- blocks, each of which jumps to a join block where a @phi@ merges the two
-- results.  Only the selected alternative is executed at run time.
(?) :: (IsFirstClass a) => TValue r Bool -> (TValue r a, TValue r a) -> TValue r a
c ? (t, f) = do
    thenBlock <- newBasicBlock
    elseBlock <- newBasicBlock
    joinBlock <- newBasicBlock
    cond <- c
    condBr cond thenBlock elseBlock
    -- Generate one alternative.  The block we are in afterwards is recorded
    -- (rather than the entry block) because the alternative itself may have
    -- opened further blocks before jumping to the join block.
    let arm entry body = do
            defineBasicBlock entry
            result <- body
            exitBlock <- getCurrentBasicBlock
            br joinBlock
            return (result, exitBlock)
    thenArm <- arm thenBlock t
    elseArm <- arm elseBlock f
    defineBasicBlock joinBlock
    phi [thenArm, elseArm]

infix 0 ??
-- | Branch-free conditional built on LLVM @select@.
--
-- Unlike '?', the condition and /both/ alternatives are always generated
-- (in that order) and executed; @select@ then picks the result.  The
-- condition has the comparison result type @'CmpResult' a@.
(??) :: (IsFirstClass a, CmpRet a) => TValue r (CmpResult a) -> (TValue r a, TValue r a) -> TValue r a
c ?? (t, f) =
    c >>= \cond ->
    t >>= \onTrue ->
    f >>= select cond onTrue

-- | Return a value from an 'arithFunction': runs the computation and passes
-- its result to 'ret'.
retrn :: (Ret (Value a) r) => TValue r a -> CodeGenFunction r ()
retrn = (>>= ret)

-- | Use @x <- set $ ...@ to make a binding.
--
-- Runs the given computation once and hands back a 'TValue' that merely
-- returns the resulting value, so the binding can be used repeatedly
-- without regenerating the original code.
set :: TValue r a -> (CodeGenFunction r (TValue r a))
set x = x >>= return . return

-- | 'Eq' and 'Ord' exist only because they are (transitively) superclasses
-- of 'Real', which the 'Integral', 'RealFrac' and 'RealFloat' instances in
-- this module need.  A generated value cannot be compared while generating
-- code.  So these methods fail with a descriptive error, instead of looping
-- forever through the mutually recursive default definitions
-- (@==@ via @/=@ and @compare@ via @<=@) that empty instances would use.
instance Eq (TValue r a) where
    _ == _ = error "CodeGenFunction Value: (==)"

instance Ord (TValue r a) where
    compare _ _ = error "CodeGenFunction Value: compare"

-- | Arithmetic on generated values: each operator generates its operands
-- (left first) and emits the corresponding instruction.  Literals become
-- constants via 'valueOf'.
instance (IsArithmetic a, CmpRet a, Num a, IsConst a) => Num (TValue r a) where
    (+) = binop add
    (-) = binop sub
    (*) = binop mul
    negate = (>>= neg)
    -- 'abs' and 'signum' use their argument several times.  Run its code
    -- once and reuse the resulting value, so the instructions of the
    -- argument (and side effects such as calls) are emitted only once.
    -- Both use '??' (select), so every alternative is evaluated.
    abs x = do
        xv <- x
        let v = return xv
        v %< 0 ?? (-v, v)
    signum x = do
        xv <- x
        let v = return xv
        v %< 0 ?? (-1, v %> 0 ?? (1, 0))
    fromInteger = return . valueOf . fromInteger

-- | 'succ' and 'pred' add or subtract one in generated code.  Converting a
-- generated value back to an 'Int' is impossible and fails.
instance (IsArithmetic a, CmpRet a, Num a, IsConst a) => Enum (TValue r a) where
    succ = (+ 1)
    pred = subtract 1
    toEnum n = fromIntegral n
    fromEnum _ = error "CodeGenFunction Value: fromEnum"

-- | Needed as a superclass of 'Integral' and 'RealFrac'.  A generated value
-- has no 'Rational' counterpart, so 'toRational' fails.
instance (IsArithmetic a, CmpRet a, Num a, IsConst a) => Real (TValue r a) where
    toRational = const (error "CodeGenFunction Value: toRational")

-- | Truncating division and remainder are emitted with 'idiv' and 'irem'.
-- 'toInteger' cannot be computed from generated code and fails.
instance (CmpRet a, Num a, IsConst a, IsInteger a) => Integral (TValue r a) where
    quot x y = binop idiv x y
    rem x y = binop irem x y
    quotRem x y = (x `quot` y, x `rem` y)
    toInteger = const (error "CodeGenFunction Value: toInteger")

-- | Floating-point division via 'fdiv'; rational literals become constants.
instance (CmpRet a, Fractional a, IsConst a, IsFloating a) => Fractional (TValue r a) where
    x / y = binop fdiv x y
    fromRational q = return (valueOf (fromRational q))

-- | Needed as a superclass of 'RealFloat'; 'properFraction' is unsupported
-- and fails.
instance (CmpRet a, Fractional a, IsConst a, IsFloating a) => RealFrac (TValue r a) where
    properFraction = const (error "CodeGenFunction Value: properFraction")

-- | Transcendental functions: @sqrt@, @sin@, @cos@, @pow@, @exp@ and @log@
-- are emitted as LLVM intrinsic calls.  @asin@, @acos@ and @atan@ fail
-- because no intrinsic is available.  The hyperbolic functions are
-- expressed through @exp@, @log@ and @sqrt@.
instance (CmpRet a, CallIntrinsic a, Floating a, IsConst a, IsFloating a) => Floating (TValue r a) where
    pi = return $ valueOf pi
    sqrt = callIntrinsic1 "sqrt"
    sin = callIntrinsic1 "sin"
    cos = callIntrinsic1 "cos"
    (**) = callIntrinsic2 "pow"
    exp = callIntrinsic1 "exp"
    log = callIntrinsic1 "log"

    asin _ = error "LLVM missing intrinsic: asin"
    acos _ = error "LLVM missing intrinsic: acos"
    atan _ = error "LLVM missing intrinsic: atan"

    -- These formulas mention the argument several times.  Run its code once
    -- and reuse the resulting value, so the instructions of the argument
    -- (and side effects such as calls) are not duplicated.
    sinh x  = x >>= \x0 -> let v = return x0 in (exp v - exp (-v)) / 2
    cosh x  = x >>= \x0 -> let v = return x0 in (exp v + exp (-v)) / 2
    asinh x = x >>= \x0 -> let v = return x0 in log (v + sqrt (v*v + 1))
    acosh x = x >>= \x0 -> let v = return x0 in log (v + sqrt (v*v - 1))
    atanh x = x >>= \x0 -> let v = return x0 in (log (1 + v) - log (1 - v)) / 2

-- | Static properties ('floatRadix', 'floatDigits', 'floatRange', 'isIEEE')
-- are taken from the element type @a@.  'exponent' is always 0 and
-- 'scaleFloat' by 0 is the identity.  Everything that would need the
-- run-time value fails.
instance (CmpRet a, CallIntrinsic a, RealFloat a, IsConst a, IsFloating a) => RealFloat (TValue r a) where
    floatRadix = const (floatRadix (undefined :: a))
    floatDigits = const (floatDigits (undefined :: a))
    floatRange = const (floatRange (undefined :: a))
    isIEEE = const (isIEEE (undefined :: a))
    exponent = const 0
    scaleFloat n x
        | n == 0 = x
        | otherwise = error "CodeGenFunction Value: scaleFloat"
    decodeFloat _ = error "CodeGenFunction Value: decodeFloat"
    encodeFloat _ _ = error "CodeGenFunction Value: encodeFloat"
    isNaN _ = error "CodeGenFunction Value: isNaN"
    isInfinite _ = error "CodeGenFunction Value: isInfinite"
    isDenormalized _ = error "CodeGenFunction Value: isDenormalized"
    isNegativeZero _ = error "CodeGenFunction Value: isNegativeZero"

-- | Lift a binary code generator to 'TValue' operands: generate the first
-- operand, then the second, then apply @op@ to the two results.
binop :: (Value a -> Value b -> TValue r c) ->
         TValue r a -> TValue r b -> TValue r c
binop op x y = x >>= \vx -> y >>= op vx

{-
If we add the ReadNone attribute, then LLVM-2.8 complains:

llvm/examples$ Arith_dyn.exe
Attribute readnone only applies to the function!
  %2 = call readnone double @llvm.sin.f64(double %0)
Attribute readnone only applies to the function!
  %3 = call readnone double @llvm.exp.f64(double %2)
Broken module found, compilation aborted!
Stack dump:
0.      Running pass 'Function Pass Manager' on module '_module'.
1.      Running pass 'Module Verifier' on function '@_fun1'
Aborted
-}
-- | Hook for marking an intrinsic call result as @readnone@.  It is
-- currently the identity: attaching the attribute to the call made the
-- LLVM-2.8 verifier reject the module.  The disabled call was
-- @addAttributes x 0 [ReadNoneAttribute]@.
addReadNone :: Value a -> CodeGenFunction r (Value a)
addReadNone = return

-- | Call the one-argument LLVM intrinsic named @llvm.FN.TYPE@.  Here FN is
-- the given base name and TYPE is the 'intrinsicTypeName' of the argument
-- type @a@.
callIntrinsicP1 :: forall a b r . (IsFirstClass a, IsFirstClass b, IsPrimitive a) =>
                   String -> Value a -> TValue r b
callIntrinsicP1 fn x =
    -- Function attributes (e.g. NoUnwindAttribute, ReadOnlyAttribute via
    -- addFunctionAttributes) are deliberately not added.  The optimizer's
    -- verifier checks them against the attributes declared for the
    -- intrinsic, and when omitted the right ones are added automatically.
    externFunction intrinsic >>= \op ->
        runCall (callFromFunction op `applyCall` x) >>= addReadNone
  where
    intrinsic = "llvm." ++ fn ++ "." ++ intrinsicTypeName (Proxy :: Proxy a)

-- | Call the two-argument LLVM intrinsic named @llvm.FN.TYPE@.  TYPE is
-- derived from the type @a@ of the first argument.
callIntrinsicP2 :: forall a b c r . (IsFirstClass a, IsFirstClass b, IsFirstClass c, IsPrimitive a) =>
                   String -> Value a -> Value b -> TValue r c
callIntrinsicP2 fn x y =
    externFunction intrinsic >>= \op ->
        runCall (callFromFunction op `applyCall` x `applyCall` y) >>= addReadNone
  where
    intrinsic = "llvm." ++ fn ++ "." ++ intrinsicTypeName (Proxy :: Proxy a)

-------------------------------------------

-- | Relates a function over 'TValue' arguments (@a@) to the corresponding
-- function over plain arguments (@b@), for use by 'arithFunction'.
--
-- Each @CodeGenFunction r x@ argument of @a@ becomes an @x@ argument of
-- @b@.  The final @CodeGenFunction r z@ result becomes a
-- @CodeGenFunction r ()@ that passes the value to 'ret'.  The functional
-- dependencies let either side determine the other.
class ArithFunction r z a b | a -> b r z, b r z -> a where
    -- | Convert the 'TValue'-level function into the plain-argument form.
    arithFunction' :: a -> b

-- | Base case: no arguments left; return the computed value with 'ret'.
instance
    (Ret a r) =>
        ArithFunction r a (CodeGenFunction r a) (CodeGenFunction r ()) where
    arithFunction' = (>>= ret)

-- | Argument case: wrap the plain argument with 'return', feed it to the
-- 'TValue'-level function, and convert the remainder recursively.
instance
    (ArithFunction r z b0 b1) =>
        ArithFunction r z (CodeGenFunction r a -> b0) (a -> b1) where
    arithFunction' f arg = arithFunction' (f (return arg))

-- |Unlift a function with @TValue@ to have @Value@ arguments.
arithFunction :: ArithFunction r z a b => a -> b
arithFunction = arithFunction'


-- | Relates a callable function type @a@ (a curried type ending in @IO@) to
-- a curried function @b@ over 'TValue' arguments that builds and runs a
-- call to it inside @CodeGenFunction r@.  Used by 'toArithFunction'.
class ToArithFunction r a b | a r -> b, b -> a r where
    -- | Finish a partially applied call: accept the remaining arguments as
    -- 'TValue's, apply each to the call, and finally run it.
    toArithFunction' :: CodeGenFunction r (Call a) -> b

-- | Base case: every argument has been applied; run the call.
instance ToArithFunction r (IO b) (CodeGenFunction r (Value b)) where
    toArithFunction' = (>>= runCall)

-- | Argument case: after the call built so far, generate the next argument,
-- apply it, and continue with the remaining arguments.
instance
    ToArithFunction r b0 b1 =>
        ToArithFunction r (a -> b0) (CodeGenFunction r (Value a) -> b1) where
    toArithFunction' cl = toArithFunction' . liftM2 applyCall cl


-- | Two-argument special case of 'toArithFunction' (not exported).
-- Generates both arguments in order, then calls @f@ on them.
_toArithFunction2 ::
   Function (a -> b -> IO c) -> TValue r a -> TValue r b -> TValue r c
_toArithFunction2 f tx ty =
    tx >>= \x ->
    ty >>= \y ->
    runCall (callFromFunction f `applyCall` x `applyCall` y)

-- | Lift a function from having @Value@ arguments to having @TValue@
-- arguments.  Each argument is generated in turn and applied, and the
-- resulting call is run.
toArithFunction ::
    (ToArithFunction r f g) =>
    Function f -> g
toArithFunction = toArithFunction' . return . callFromFunction

-------------------------------------------

-- | Define a recursive 'arithFunction'.  The body receives the function
-- itself, lifted with 'toArithFunction', as its first argument.  A new
-- externally linked function is created, defined, and returned.
recursiveFunction ::
    (IsFunction f, FunctionArgs f, code ~ FunctionCodeGen f,
     ArithFunction r1 z arith code,
     ToArithFunction r0 f g) =>
    (g -> arith) -> CodeGenModule (Function f)
recursiveFunction af =
    newFunction ExternalLinkage >>= \self ->
    defineFunction self (arithFunction (af (toArithFunction self))) >>
    return self


-------------------------------------------

-- | Types for which the intrinsics used by the 'Floating' instance can be
-- emitted.  The 'String' argument is the base name of the intrinsic, as
-- passed by 'callIntrinsic1' and 'callIntrinsic2' (e.g. @\"sqrt\"@,
-- @\"pow\"@).
class CallIntrinsic a where
    -- | Call a one-argument intrinsic on an already generated value.
    callIntrinsic1' :: String -> Value a -> TValue r a
    -- | Call a two-argument intrinsic on two already generated values.
    callIntrinsic2' :: String -> Value a -> Value a -> TValue r a

-- | Scalar 'Float': call the type-suffixed LLVM intrinsic directly.
instance CallIntrinsic Float where
    callIntrinsic1' name = callIntrinsicP1 name
    callIntrinsic2' name = callIntrinsicP2 name

-- | Scalar 'Double': call the type-suffixed LLVM intrinsic directly.
instance CallIntrinsic Double where
    callIntrinsic1' name = callIntrinsicP1 name
    callIntrinsic2' name = callIntrinsicP2 name

{-
I think such a special case for certain systems
would be better handled as in LLVM.Extra.Extension.
(lemming)
-}
-- | True when compiled with the macOS preprocessor flag tested below.  It
-- enables the special vector routines in the 'CallIntrinsic' instance for
-- 'Vector'.
-- NOTE(review): GHC does not appear to predefine this flag itself.  It is
-- presumably supplied by the build configuration (e.g. cabal cpp-options).
-- Confirm this, otherwise the value is always False.
macOS :: Bool
#if defined(__MACOS__)
macOS = True
#else
macOS = False
#endif

-- | Vectors are processed element by element via 'mapVector' and
-- 'mapVector2'.  The exception is 'macOS' with four-element vectors: a
-- fixed set of one-argument functions then calls external routines named
-- \"v\" ++ name ++ \"f\" (e.g. @vsqrtf@).
-- NOTE(review): that test ignores the element type @a@, while the routine
-- names end in \"f\" (presumably single precision).  Confirm this branch is
-- never reached for vectors of 'Double'.
instance (Dec.Positive n, IsPrimitive a, CallIntrinsic a) => CallIntrinsic (Vector n a) where
    callIntrinsic1' s x
        | useVectorLib =
            externFunction ("v" ++ s ++ "f") >>= \op -> call op x >>= addReadNone
        | otherwise = mapVector (callIntrinsic1' s) x
      where
        useVectorLib = macOS && width == 4 && s `elem` vectorLibFunctions
        width = Dec.integerFromSingleton (Dec.singleton :: Dec.Singleton n)
        vectorLibFunctions = ["sqrt", "log", "exp", "sin", "cos", "tan"]
    callIntrinsic2' s = mapVector2 (callIntrinsic2' s)

-- | Generate the argument, then call the named one-argument intrinsic on it.
callIntrinsic1 :: (CallIntrinsic a) => String -> TValue r a -> TValue r a
callIntrinsic1 s x = x >>= callIntrinsic1' s

-- | Generate both arguments (left first), then call the named two-argument
-- intrinsic on them.
callIntrinsic2 :: (CallIntrinsic a) => String -> TValue r a -> TValue r a -> TValue r a
callIntrinsic2 s tx ty = do
    x <- tx
    y <- ty
    callIntrinsic2' s x y