{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
module LLVM.Extra.Tuple (
   Phi(..), phiTraversable, addPhiFoldable,
   Undefined(..), undefPointed,
   Zero(..), zeroPointed,
   Value(..), valueOfFunctor,
   VectorValue(..),
   ) where

import LLVM.Extra.TuplePrivate (
   Phi(..), phiTraversable, addPhiFoldable,
   Undefined(..), undefPointed,
   Zero(..), zeroPointed,
   )
import qualified LLVM.Extra.EitherPrivate as Either
import qualified LLVM.Extra.MaybePrivate as Maybe
import qualified LLVM.Core as LLVM
import LLVM.Core (IsType, Vector)

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal ((:*:))

import qualified Control.Monad.Trans.State as MS
import qualified Control.Applicative as App
import qualified Control.Functor.HT as FuncHT

import qualified Data.Foldable as Fold
import qualified Data.Traversable as Trav

import qualified Foreign.Storable.Record.Tuple as StoreTuple
import Foreign.StablePtr (StablePtr, )
import Foreign.Ptr (FunPtr, Ptr, )

import qualified Data.EnumBitSet as EnumBitSet
import qualified Data.Enum.Storable as Enum
import qualified Data.Bool8 as Bool8
import Data.Complex (Complex((:+)))
import Data.Tagged (Tagged(unTagged))
import Data.Word (Word8, Word16, Word32, Word64, Word)
import Data.Int  (Int8,  Int16,  Int32,  Int64, )
import Data.Bool8 (Bool8)

import Prelude2010
import Prelude ()


-- * class for creating tuples of constant values

-- | Haskell types @a@ whose values can be turned into a corresponding
-- structure of constant values, described by the associated type 'ValueOf'.
-- The superclass constraint requires that the target structure
-- is an instance of 'Undefined'.
class (Undefined (ValueOf a)) => Value a where
   -- | The representation that a value of type @a@ is converted to.
   type ValueOf a
   -- | Convert a Haskell value into its 'ValueOf' representation.
   valueOf :: a -> ValueOf a

-- | Pairs are converted component-wise.
-- The argument is not forced before the result pair is constructed.
instance (Value a, Value b) => Value (a,b) where
   type ValueOf (a,b) = (ValueOf a, ValueOf b)
   valueOf pair = (valueOf (fst pair), valueOf (snd pair))

-- | Triples are converted component-wise.
-- The components are bound lazily,
-- so the argument is not forced before the result triple is constructed.
instance (Value a, Value b, Value c) => Value (a,b,c) where
   type ValueOf (a,b,c) = (ValueOf a, ValueOf b, ValueOf c)
   valueOf triple =
      let (x,y,z) = triple
      in  (valueOf x, valueOf y, valueOf z)

-- | Quadruples are converted component-wise.
-- The components are bound lazily,
-- so the argument is not forced before the result tuple is constructed.
instance (Value a, Value b, Value c, Value d) => Value (a,b,c,d) where
   type ValueOf (a,b,c,d) = (ValueOf a, ValueOf b, ValueOf c, ValueOf d)
   valueOf quad =
      let (w,x,y,z) = quad
      in  (valueOf w, valueOf x, valueOf y, valueOf z)

-- | The 'StoreTuple.Tuple' wrapper is removed
-- and the wrapped tuple is converted by its own 'Value' instance.
instance (Value tuple) => Value (StoreTuple.Tuple tuple) where
   type ValueOf (StoreTuple.Tuple tuple) = ValueOf tuple
   valueOf wrapped =
      case wrapped of
         StoreTuple.Tuple inner -> valueOf inner

-- | 'Nothing' is converted to 'Maybe.nothing' applied to 'undef',
-- 'Just' is converted to 'Maybe.just' applied to the converted payload.
instance (Value a) => Value (Maybe a) where
   type ValueOf (Maybe a) = Maybe.T (ValueOf a)
   valueOf optional =
      case optional of
         Nothing -> Maybe.nothing undef
         Just payload -> Maybe.just (valueOf payload)

-- | The payload of either alternative is converted
-- and passed to 'Either.left' or 'Either.right', respectively,
-- together with 'undef' as first argument
-- (presumably the placeholder for the absent alternative -
-- confirm against "LLVM.Extra.EitherPrivate").
instance (Value a, Value b) => Value (Either a b) where
   type ValueOf (Either a b) = Either.T (ValueOf a) (ValueOf b)
   valueOf alternative =
      case alternative of
         Left l -> Either.left undef (valueOf l)
         Right r -> Either.right undef (valueOf r)

-- Scalar types. Except for 'Bool8' and @()@
-- every type is converted to an 'LLVM.Value' of the same type
-- using 'LLVM.valueOf'.

instance Value Float where
   type ValueOf Float = LLVM.Value Float
   valueOf = LLVM.valueOf

instance Value Double where
   type ValueOf Double = LLVM.Value Double
   valueOf = LLVM.valueOf

{-
instance Value FP128 where
   type ValueOf FP128 = LLVM.Value FP128
   valueOf = LLVM.valueOf
-}

instance Value Bool where
   type ValueOf Bool = LLVM.Value Bool
   valueOf = LLVM.valueOf

-- | 'Bool8' is converted to 'Bool' first
-- and thus shares the representation with 'Bool'.
instance Value Bool8 where
   type ValueOf Bool8 = LLVM.Value Bool
   valueOf flag = LLVM.valueOf (Bool8.toBool flag)

instance Value Int where
   type ValueOf Int = LLVM.Value Int
   valueOf = LLVM.valueOf

instance Value Int8 where
   type ValueOf Int8 = LLVM.Value Int8
   valueOf = LLVM.valueOf

instance Value Int16 where
   type ValueOf Int16 = LLVM.Value Int16
   valueOf = LLVM.valueOf

instance Value Int32 where
   type ValueOf Int32 = LLVM.Value Int32
   valueOf = LLVM.valueOf

instance Value Int64 where
   type ValueOf Int64 = LLVM.Value Int64
   valueOf = LLVM.valueOf

instance Value Word where
   type ValueOf Word = LLVM.Value Word
   valueOf = LLVM.valueOf

instance Value Word8 where
   type ValueOf Word8 = LLVM.Value Word8
   valueOf = LLVM.valueOf

instance Value Word16 where
   type ValueOf Word16 = LLVM.Value Word16
   valueOf = LLVM.valueOf

instance Value Word32 where
   type ValueOf Word32 = LLVM.Value Word32
   valueOf = LLVM.valueOf

instance Value Word64 where
   type ValueOf Word64 = LLVM.Value Word64
   valueOf = LLVM.valueOf

-- | The unit value is its own representation.
instance Value () where
   type ValueOf () = ()
   valueOf unit = unit


-- | Signed integers with a type-level bit count @n@
-- are converted to an 'LLVM.Value' of the same type.
instance (TypeNum.Positive n) => Value (LLVM.IntN n) where
   type ValueOf (LLVM.IntN n) = LLVM.Value (LLVM.IntN n)
   valueOf i = LLVM.valueOf i

-- | Unsigned integers with a type-level bit count @n@
-- are converted to an 'LLVM.Value' of the same type.
instance (TypeNum.Positive n) => Value (LLVM.WordN n) where
   type ValueOf (LLVM.WordN n) = LLVM.Value (LLVM.WordN n)
   valueOf w = LLVM.valueOf w


-- | A "Foreign.Ptr" pointer is converted to an 'LLVM.Value' of the same type.
instance Value (Ptr a) where
   type ValueOf (Ptr a) = LLVM.Value (Ptr a)
   valueOf ptr = LLVM.valueOf ptr

-- | An 'LLVM.Ptr' is converted to an 'LLVM.Value' of the same type.
instance IsType a => Value (LLVM.Ptr a) where
   type ValueOf (LLVM.Ptr a) = LLVM.Value (LLVM.Ptr a)
   valueOf ptr = LLVM.valueOf ptr

-- | A function pointer is converted to an 'LLVM.Value' of the same type.
instance LLVM.IsFunction a => Value (FunPtr a) where
   type ValueOf (FunPtr a) = LLVM.Value (FunPtr a)
   valueOf funPtr = LLVM.valueOf funPtr

-- | A stable pointer is converted to an 'LLVM.Value' of the same type.
instance Value (StablePtr a) where
   type ValueOf (StablePtr a) = LLVM.Value (StablePtr a)
   valueOf stablePtr = LLVM.valueOf stablePtr

-- | Vectors are converted according to the 'VectorValue' instance
-- of their element type.
instance
   (TypeNum.Positive n, VectorValue n a, Undefined (VectorValueOf n a)) =>
      Value (Vector n a) where
   type ValueOf (Vector n a) = VectorValueOf n a
   valueOf vec = vectorValueOf vec


-- | The tag is dropped and the wrapped value is converted
-- by its own 'Value' instance.
instance Value a => Value (Tagged tag a) where
   type ValueOf (Tagged tag a) = ValueOf a
   valueOf tagged = valueOf (unTagged tagged)

-- | An enumeration value is represented by a value of the integer type @w@.
-- The number is obtained from 'fromEnum' on the plain enumeration value
-- and is converted to @w@ using 'fromIntegral'.
instance
   (LLVM.IsInteger w, LLVM.IsConst w, Num w, Enum e) =>
      Value (Enum.T w e) where
   type ValueOf (Enum.T w e) = LLVM.Value w
   valueOf wrapped = LLVM.valueOf (fromIntegral ordinal)
      where ordinal = fromEnum (Enum.toPlain wrapped)

-- | A bit set is represented by the value of type @w@
-- that 'EnumBitSet.decons' extracts from it.
instance (LLVM.IsInteger w, LLVM.IsConst w) => Value (EnumBitSet.T w i) where
   type ValueOf (EnumBitSet.T w i) = LLVM.Value w
   valueOf set = LLVM.valueOf (EnumBitSet.decons set)

-- | Real and imaginary part are converted separately
-- and the results are combined to a 'Complex' number again.
instance (Value a) => Value (Complex a) where
   type ValueOf (Complex a) = Complex (ValueOf a)
   valueOf z =
      case z of
         re :+ im -> valueOf re :+ valueOf im


-- * class for vectors of tuples and other complex types

-- | Element types @a@ for which a 'Vector' of @n@ elements
-- can be converted to a representation described by
-- the associated type 'VectorValueOf'.
-- The superclass constraints require a positive vector size
-- and an 'Undefined' instance for the representation.
class
   (TypeNum.Positive n, Undefined (VectorValueOf n a)) =>
      VectorValue n a where
   -- | The representation of a vector of @n@ elements of type @a@.
   type VectorValueOf n a
   -- | Convert a vector into its 'VectorValueOf' representation.
   vectorValueOf :: Vector n a -> VectorValueOf n a

-- may be simplified using a fake proof of TypeNum.Positive (n :*: m)
-- | A vector of @n@ vectors of size @m@ is flattened
-- to a single vector of size @n :*: m@.
-- The elements are collected in the order given by the 'Fold.Foldable'
-- instances of the outer and the inner vectors.
instance
   (TypeNum.Positive n, TypeNum.Positive m, TypeNum.Positive (n :*: m),
    Undefined (Vector (n :*: m) a)) =>
      VectorValue n (Vector m a) where
   type VectorValueOf n (Vector m a) = Vector (n :*: m) a
   vectorValueOf nested =
      vectorFromList (Fold.foldMap Fold.toList nested)

-- | Fill a vector of size @n@ with the leading elements of a list.
--
-- Every vector position fetches the next element from the list
-- that is threaded through as state.
-- The remaining list is discarded by 'MS.evalState',
-- thus surplus list elements are silently ignored.
--
-- If the list contains fewer elements than the vector has positions,
-- then accessing a position that could not be filled
-- fails with an 'error' that names this function.
vectorFromList :: (TypeNum.Positive n) => [a] -> Vector n a
vectorFromList =
   MS.evalState $ Trav.sequence $ App.pure $ MS.state next
   where
      -- Pop the head of the remaining list.
      next (y:ys) = (y,ys)
      next [] =
         error
            "LLVM.Extra.Tuple.vectorFromList: list has fewer elements than the vector"

-- | A vector of pairs is split into a pair of vectors using 'FuncHT.unzip'
-- and each of the vectors is converted separately.
instance (VectorValue n a, VectorValue n b) => VectorValue n (a,b) where
   type VectorValueOf n (a,b) = (VectorValueOf n a, VectorValueOf n b)
   vectorValueOf =
      (\(firsts, seconds) -> (vectorValueOf firsts, vectorValueOf seconds))
         . FuncHT.unzip

-- | A vector of triples is split into a triple of vectors
-- using 'FuncHT.unzip3' and each of the vectors is converted separately.
instance
   (VectorValue n a, VectorValue n b, VectorValue n c) =>
      VectorValue n (a,b,c) where
   type VectorValueOf n (a,b,c) =
         (VectorValueOf n a, VectorValueOf n b, VectorValueOf n c)
   vectorValueOf =
      (\(xs, ys, zs) ->
         (vectorValueOf xs, vectorValueOf ys, vectorValueOf zs))
         . FuncHT.unzip3

-- | The 'StoreTuple.Tuple' wrapper is removed from every element
-- and the resulting vector is converted
-- by the 'VectorValue' instance of the wrapped tuple type.
instance (VectorValue n tuple) => VectorValue n (StoreTuple.Tuple tuple) where
   type VectorValueOf n (StoreTuple.Tuple tuple) = VectorValueOf n tuple
   vectorValueOf wrapped =
      vectorValueOf (fmap StoreTuple.getTuple wrapped)

-- Vectors of primitive element types. Except for 'Bool8'
-- the vector is converted as a whole to an 'LLVM.Value'
-- of the same vector type using 'LLVM.valueOf'.

instance (TypeNum.Positive n) => VectorValue n Float where
   type VectorValueOf n Float = LLVM.Value (Vector n Float)
   vectorValueOf v = LLVM.valueOf v

instance (TypeNum.Positive n) => VectorValue n Double where
   type VectorValueOf n Double = LLVM.Value (Vector n Double)
   vectorValueOf v = LLVM.valueOf v

{-
instance (TypeNum.Positive n) => VectorValue n FP128 where
   type VectorValueOf n FP128 = LLVM.Value (Vector n FP128)
   vectorValueOf v = LLVM.valueOf v
-}

instance (TypeNum.Positive n) => VectorValue n Bool where
   type VectorValueOf n Bool = LLVM.Value (Vector n Bool)
   vectorValueOf v = LLVM.valueOf v

-- | Every 'Bool8' element is converted to 'Bool' first,
-- thus the representation is the same as for vectors of 'Bool'.
instance (TypeNum.Positive n) => VectorValue n Bool8 where
   type VectorValueOf n Bool8 = LLVM.Value (Vector n Bool)
   vectorValueOf v = LLVM.valueOf (fmap Bool8.toBool v)

instance (TypeNum.Positive n) => VectorValue n Int where
   type VectorValueOf n Int = LLVM.Value (Vector n Int)
   vectorValueOf v = LLVM.valueOf v

instance (TypeNum.Positive n) => VectorValue n Int8 where
   type VectorValueOf n Int8 = LLVM.Value (Vector n Int8)
   vectorValueOf v = LLVM.valueOf v

instance (TypeNum.Positive n) => VectorValue n Int16 where
   type VectorValueOf n Int16 = LLVM.Value (Vector n Int16)
   vectorValueOf v = LLVM.valueOf v

instance (TypeNum.Positive n) => VectorValue n Int32 where
   type VectorValueOf n Int32 = LLVM.Value (Vector n Int32)
   vectorValueOf v = LLVM.valueOf v

instance (TypeNum.Positive n) => VectorValue n Int64 where
   type VectorValueOf n Int64 = LLVM.Value (Vector n Int64)
   vectorValueOf v = LLVM.valueOf v

instance (TypeNum.Positive n) => VectorValue n Word where
   type VectorValueOf n Word = LLVM.Value (Vector n Word)
   vectorValueOf v = LLVM.valueOf v

instance (TypeNum.Positive n) => VectorValue n Word8 where
   type VectorValueOf n Word8 = LLVM.Value (Vector n Word8)
   vectorValueOf v = LLVM.valueOf v

instance (TypeNum.Positive n) => VectorValue n Word16 where
   type VectorValueOf n Word16 = LLVM.Value (Vector n Word16)
   vectorValueOf v = LLVM.valueOf v

instance (TypeNum.Positive n) => VectorValue n Word32 where
   type VectorValueOf n Word32 = LLVM.Value (Vector n Word32)
   vectorValueOf v = LLVM.valueOf v

instance (TypeNum.Positive n) => VectorValue n Word64 where
   type VectorValueOf n Word64 = LLVM.Value (Vector n Word64)
   vectorValueOf v = LLVM.valueOf v


-- * default methods for LLVM classes

-- | Apply 'valueOf' to every element of a 'Functor' structure.
-- The shape of the container is preserved,
-- only the elements are replaced by their 'ValueOf' representation.
valueOfFunctor :: (Value h, Functor f) => f h -> f (ValueOf h)
valueOfFunctor container = fmap valueOf container