{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
module LLVM.Extra.Class where

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, value, valueOf, undef,
    Vector,
    IsConst, IsType, IsFirstClass, IsPrimitive,
    CodeGenFunction, BasicBlock, )
import LLVM.Util.Loop (Phi, phis, addPhis, )
import qualified Data.TypeLevel.Num as TypeNum

import Control.Applicative (pure, liftA2, )
import qualified Control.Applicative as App
import qualified Data.Foldable as Fold
import qualified Data.Traversable as Trav

import Foreign.StablePtr (StablePtr, )
import Foreign.Ptr (Ptr, )

import Data.Word (Word8, Word16, Word32, Word64, )
import Data.Int  (Int8,  Int16,  Int32,  Int64, )

import Prelude hiding (and, iterate, map, zipWith, writeFile, )


-- * class for tuples of undefined values

-- | Types with a canonical \"undefined\" value: LLVM values built from
-- LLVM's 'undef', and tuples whose components are themselves 'Undefined'.
class Undefined a where
   -- | The undefined value of type @a@; tuples are built component-wise.
   undefTuple :: a

-- | The empty tuple has no components.
instance Undefined () where
   undefTuple = ()

-- | A single LLVM value: LLVM's 'undef' lifted with 'value'.
instance (IsFirstClass a) => Undefined (Value a) where
   undefTuple = LLVM.value LLVM.undef

-- | Pairs are undefined in each component.
instance (Undefined a, Undefined b) => Undefined (a, b) where
   undefTuple = (,) undefTuple undefTuple

-- | Triples are undefined in each component.
instance (Undefined a, Undefined b, Undefined c) => Undefined (a, b, c) where
   undefTuple = (,,) undefTuple undefTuple undefTuple


-- * class for tuples of zero values

-- | Types with a canonical zero value: LLVM values built from 'LLVM.zero',
-- and tuples whose components are themselves 'Zero'.
class Zero a where
   -- | The zero value of type @a@; tuples are built component-wise.
   zeroTuple :: a

-- | The empty tuple has no components.
instance Zero () where
   zeroTuple = ()

-- | A single LLVM value: 'LLVM.zero' lifted with 'value'.
instance (IsFirstClass a) => Zero (Value a) where
   zeroTuple = value LLVM.zero

-- | Pairs are zero in each component.
instance (Zero a, Zero b) => Zero (a, b) where
   zeroTuple = (,) zeroTuple zeroTuple

-- | Triples are zero in each component.
instance (Zero a, Zero b, Zero c) => Zero (a, b, c) where
   zeroTuple = (,,) zeroTuple zeroTuple zeroTuple

-- | Default-method helper: 'zeroTuple' wrapped into any 'App.Applicative'
-- with 'pure'.
zeroTuplePointed ::
   (Zero a, App.Applicative f) =>
   f a
zeroTuplePointed = App.pure zeroTuple


-- * class for creating tuples of constant values

{-
ToDo: flip the order of the type parameters to match good style.
-}
-- class (IsTuple haskellValue, ValueTuple llvmValue) =>
--      MakeValueTuple haskellValue llvmValue | haskellValue -> llvmValue where
-- | Conversion of a Haskell value into the matching (tuple of) LLVM
-- constant value(s).  The functional dependency lets the Haskell type
-- determine the LLVM type; the 'Undefined' superclass guarantees that
-- 'undefTuple' is also available for that LLVM type.
class (Undefined llvmValue) =>
      MakeValueTuple haskellValue llvmValue | haskellValue -> llvmValue where
   valueTupleOf :: haskellValue -> llvmValue

-- | Pairs are converted component-wise.  The argument is only taken apart
-- through 'fst' and 'snd', so it is never forced by the conversion itself.
instance (MakeValueTuple ah al, MakeValueTuple bh bl) =>
      MakeValueTuple (ah,bh) (al,bl) where
   valueTupleOf pair =
      (valueTupleOf (fst pair), valueTupleOf (snd pair))

-- | Triples are converted component-wise.  The @let@ pattern is lazy, so
-- the argument is only forced when a result component is demanded.
instance (MakeValueTuple ah al, MakeValueTuple bh bl, MakeValueTuple ch cl) =>
      MakeValueTuple (ah,bh,ch) (al,bl,cl) where
   valueTupleOf triple =
      let (a, b, c) = triple
      in  (valueTupleOf a, valueTupleOf b, valueTupleOf c)

-- Scalar Haskell types become a single LLVM value of the same type,
-- produced by 'valueOf'.  The unit type maps to itself.
instance MakeValueTuple Float (Value Float) where
   valueTupleOf = valueOf
instance MakeValueTuple Double (Value Double) where
   valueTupleOf = valueOf
-- instance MakeValueTuple FP128 (Value FP128) where valueTupleOf = valueOf
instance MakeValueTuple Bool (Value Bool) where
   valueTupleOf = valueOf
instance MakeValueTuple Int8 (Value Int8) where
   valueTupleOf = valueOf
instance MakeValueTuple Int16 (Value Int16) where
   valueTupleOf = valueOf
instance MakeValueTuple Int32 (Value Int32) where
   valueTupleOf = valueOf
instance MakeValueTuple Int64 (Value Int64) where
   valueTupleOf = valueOf
instance MakeValueTuple Word8 (Value Word8) where
   valueTupleOf = valueOf
instance MakeValueTuple Word16 (Value Word16) where
   valueTupleOf = valueOf
instance MakeValueTuple Word32 (Value Word32) where
   valueTupleOf = valueOf
instance MakeValueTuple Word64 (Value Word64) where
   valueTupleOf = valueOf
instance MakeValueTuple () () where
   valueTupleOf = id

{-
I'm not sure about this instance.
Maybe it is better to convert the pointer target type
according to a class that maps Haskell tuples to LLVM structs.
-}
-- | A raw pointer keeps its target type @a@ unchanged and is converted
-- with 'valueOf'.
instance (IsType a) =>
         MakeValueTuple (Ptr a) (Value (Ptr a)) where
   valueTupleOf = valueOf

-- | A stable pointer is converted with 'valueOf'.
instance MakeValueTuple (StablePtr a) (Value (StablePtr a)) where
   valueTupleOf = valueOf

{-
instance (MakeValueTuple haskellValue llvmValue, Memory llvmValue llvmStruct) =>
         MakeValueTuple (Ptr haskellValue) (Value (Ptr llvmStruct)) where
   valueTupleOf = valueOf . castStorablePtr
instance (Pos n) =>
         MakeValueTuple (IntN n)     (Value (IntN n)) where
instance (Pos n) =>
         MakeValueTuple (WordN n)    (Value (WordN n)) where
-}
-- | An LLVM 'Vector' is converted as one value with 'valueOf'; the Haskell
-- and LLVM sides share the type @Vector n a@.
instance (TypeNum.Pos n, IsPrimitive a, IsConst a) =>
         MakeValueTuple (Vector n a) (Value (Vector n a)) where
   valueTupleOf = valueOf


-- * default methods for LLVM classes

{-
buildTupleTraversable ::
   (Undefined a, Trav.Traversable f, App.Applicative f) =>
   FunctionRef -> State Int (f a)
buildTupleTraversable f =
   Trav.sequence (pure (buildTuple f))
-}
{-
buildTupleTraversable ::
   (Trav.Traversable f, App.Applicative f) =>
   State Int a ->
   State Int (f a)
buildTupleTraversable build =
   Trav.sequence (pure build)
-}
{- this is the version I used
buildTupleTraversable ::
   (Monad m, Trav.Traversable f, App.Applicative f) =>
   m a ->
   m (f a)
buildTupleTraversable build =
   Trav.sequence (pure build)
-}

-- | Default-method helper: 'undefTuple' wrapped into any 'App.Applicative'
-- with 'pure'.
undefTuplePointed ::
   (Undefined a, App.Applicative f) =>
   f a
undefTuplePointed = App.pure undefTuple

-- | Default-method helper: apply 'valueTupleOf' to every element of a
-- 'Functor', keeping its shape.
valueTupleOfFunctor ::
   (MakeValueTuple h l, Functor f) =>
   f h -> f l
valueTupleOfFunctor container =
   fmap valueTupleOf container

{-
tupleDescFoldable ::
   (IsTuple a, Fold.Foldable f) =>
   f a -> [TypeDesc]
tupleDescFoldable =
   Fold.foldMap tupleDesc
-}

-- | Default-method helper: run @'phis' bb@ on every element in traversal
-- order and collect the results in a container of the same shape.
phisTraversable ::
   (Phi a, Trav.Traversable f) =>
   BasicBlock -> f a -> CodeGenFunction r (f a)
phisTraversable bb =
   Trav.mapM (phis bb)

-- | Default-method helper: combine @x@ and @y@ with 'liftA2', giving one
-- @'addPhis' bb@ action per combined element, then run those actions in
-- fold order and discard their results.
--
-- NOTE(review): how elements of @x@ and @y@ are paired is decided by the
-- 'App.Applicative' instance of @f@.  For a zip-like, fixed-shape container
-- this is element-wise.  For lists it would be every combination of
-- elements.  Confirm that callers only use zip-like containers.
addPhisFoldable ::
   (Phi a, Fold.Foldable f, App.Applicative f) =>
   BasicBlock -> f a -> f a -> CodeGenFunction r ()
addPhisFoldable bb x y =
   Fold.sequence_ actions
   where actions = liftA2 (addPhis bb) x y