{-# LANGUAGE TypeFamilies #-}
module LLVM.Extra.Scalar where

import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Arithmetic as A

import qualified Control.Monad as Monad


{- |
Marker that declares a type to be scalar,
even though the wrapped type might also be interpreted as a vector.

With this marker, generic vector operations can be written
against the 'A.PseudoModule' class
and then specialised to scalar types via the 'A.PseudoRing' class.

Seen from another angle, 'T' is the point
at which the 'A.Scalar' type function
stops reducing nested vector types to scalar types.
-}
newtype T a =
   Cons {
      decons :: a  -- ^ the value carrying the scalar marker
   }

-- | Run a monadic action on the value inside a 'T'
-- and wrap the action's result in 'Cons' again.
liftM :: (Monad m) => (a -> m b) -> T a -> m (T b)
liftM act x = do
   y <- act (decons x)
   return (Cons y)

-- | Binary variant of 'liftM': unwrap both arguments,
-- run the action on them and wrap its result in 'Cons'.
liftM2 :: (Monad m) => (a -> b -> m c) -> T a -> T b -> m (T c)
liftM2 act x y = act (decons x) (decons y) >>= return . Cons


-- | Counterpart of 'liftM': adapt a function on 'T'-wrapped values
-- so that it takes and returns plain values.
-- The argument is wrapped with 'Cons', and the result is unwrapped with 'decons'.
unliftM ::
   (Monad m) =>
   (T a -> m (T r)) ->
   a -> m r
unliftM g x = do
   wrapped <- g (Cons x)
   return (decons wrapped)

-- | Two-argument version of 'unliftM'.
-- Both arguments are wrapped with 'Cons', and the result is unwrapped with 'decons'.
unliftM2 ::
   (Monad m) =>
   (T a -> T b -> m (T r)) ->
   a -> b -> m r
unliftM2 g x y =
   g (Cons x) (Cons y) >>= \wrapped -> return (decons wrapped)

-- | Three-argument version of 'unliftM'.
-- Every argument is wrapped with 'Cons', and the result is unwrapped with 'decons'.
unliftM3 ::
   (Monad m) =>
   (T a -> T b -> T c -> m (T r)) ->
   a -> b -> c -> m r
unliftM3 g x y z =
   g (Cons x) (Cons y) (Cons z) >>= return . decons

-- | Four-argument version of 'unliftM'.
-- Every argument is wrapped with 'Cons', and the result is unwrapped with 'decons'.
unliftM4 ::
   (Monad m) =>
   (T a -> T b -> T c -> T d -> m (T r)) ->
   a -> b -> c -> d -> m r
unliftM4 g w x y z = do
   wrapped <- g (Cons w) (Cons x) (Cons y) (Cons z)
   return (decons wrapped)

-- | Five-argument version of 'unliftM'.
-- Every argument is wrapped with 'Cons', and the result is unwrapped with 'decons'.
unliftM5 ::
   (Monad m) =>
   (T a -> T b -> T c -> T d -> T e -> m (T r)) ->
   a -> b -> c -> d -> e -> m r
unliftM5 g v w x y z =
   run >>= return . decons
   where
      run = g (Cons v) (Cons w) (Cons x) (Cons y) (Cons z)


-- | The zero of @T a@ is the zero of @a@, wrapped in 'Cons'.
instance (Tuple.Zero a) => Tuple.Zero (T a) where
   zero = Cons {decons = Tuple.zero}

-- | The undefined value of @T a@ is the undefined value of @a@, wrapped in 'Cons'.
instance (Tuple.Undefined a) => Tuple.Undefined (T a) where
   undef = Cons {decons = Tuple.undef}

-- | Both phi operations act on the wrapped value.
-- 'Tuple.phi' re-wraps its result in 'Cons'.
instance (Tuple.Phi a) => Tuple.Phi (T a) where
   phi block (Cons x) = fmap Cons (Tuple.phi block x)
   addPhi block x y = Tuple.addPhi block (decons x) (decons y)

-- | Integer literals become constants of the wrapped type.
instance (A.IntegerConstant a) => A.IntegerConstant (T a) where
   fromInteger' n = Cons (A.fromInteger' n)

-- | Rational literals become constants of the wrapped type.
instance (A.RationalConstant a) => A.RationalConstant (T a) where
   fromRational' q = Cons (A.fromRational' q)

-- | Additive operations delegate to those of the wrapped type.
instance (A.Additive a) => A.Additive (T a) where
   zero = Cons {decons = A.zero}
   add x y = liftM2 A.add x y
   sub x y = liftM2 A.sub x y
   neg x = liftM A.neg x

-- | Multiplication delegates to that of the wrapped type.
instance (A.PseudoRing a) => A.PseudoRing (T a) where
   mul x y = liftM2 A.mul x y

-- | Division delegates to that of the wrapped type.
instance (A.Field a) => A.Field (T a) where
   fdiv x y = liftM2 A.fdiv x y

-- | A marked type is its own scalar type.
-- This is where reduction of nested vector types stops.
type instance A.Scalar (T a) = T a

-- | Because the scalar type is @T a@ itself, scaling is plain multiplication.
instance (A.PseudoRing a) => A.PseudoModule (T a) where
   scale s v = liftM2 A.mul s v


-- | Ordering and sign operations delegate to the wrapped type.
instance (A.Real a) => A.Real (T a) where
   min x y = liftM2 A.min x y
   max x y = liftM2 A.max x y
   abs x = liftM A.abs x
   signum x = liftM A.signum x

-- | Integer/fraction splitting delegates to the wrapped type.
instance (A.Fraction a) => A.Fraction (T a) where
   truncate x = liftM A.truncate x
   fraction x = liftM A.fraction x

-- | Square root delegates to the wrapped type.
instance (A.Algebraic a) => A.Algebraic (T a) where
   sqrt x = liftM A.sqrt x

-- | Transcendental functions delegate to the wrapped type.
-- The constant 'A.pi' has its result wrapped in 'Cons'.
instance (A.Transcendental a) => A.Transcendental (T a) where
   pi = fmap Cons A.pi
   sin x = liftM A.sin x
   cos x = liftM A.cos x
   exp x = liftM A.exp x
   log x = liftM A.log x
   pow x y = liftM2 A.pow x y