{-# LANGUAGE CPP                        #-}
{-# LANGUAGE ConstraintKinds            #-}
{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE ExistentialQuantification  #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash                  #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE TypeApplications           #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE TypeFamilyDependencies     #-}
{-# LANGUAGE TypeOperators              #-}
{-# LANGUAGE UnboxedTuples              #-}
{-# LANGUAGE StandaloneDeriving         #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.Array.Family
-- Copyright   :  (c) Artem Chirkin
-- License     :  BSD3
--
-- Maintainer  :  chirkin@arch.ethz.ch
--
--
-----------------------------------------------------------------------------

module Numeric.Array.Family
  ( Array
  , ArrayF (..), ArrayD (..)
  , ArrayI (..), ArrayI8 (..), ArrayI16 (..), ArrayI32 (..), ArrayI64 (..)
  , ArrayW (..), ArrayW8 (..), ArrayW16 (..), ArrayW32 (..), ArrayW64 (..)
  , Scalar (..)
  , FloatX2 (..), FloatX3 (..), FloatX4 (..)
  , DoubleX2 (..), DoubleX3 (..), DoubleX4 (..)
  , ArrayInstanceInference, ElemType (..), ArraySize (..)
  , ElemTypeInference (..), ArraySizeInference (..), ArrayInstanceEvidence
  , getArrayInstance, ArrayInstance (..), inferArrayInstance
  ) where

#include "MachDeps.h"

import           Data.Int                  (Int16, Int32, Int64, Int8)
import           Data.Type.Equality        ((:~:) (..))
import           Data.Word                 (Word16, Word32, Word64, Word8)
import           GHC.Prim                  ( ByteArray#, Double#, Float#
#if WORD_SIZE_IN_BITS < 64
                                           , Int64#, Word64#
#endif
                                           , Int#, Word#, unsafeCoerce#)
import           GHC.Exts                  (RuntimeRep(..))

import           Numeric.Array.ElementWise
import           Numeric.Commons
import           Numeric.TypeLits
import           Numeric.Dimensions

-- | Full collection of n-order arrays.
--
--   Maps an element type @t@ and a dimension list @ds@ to the concrete data
--   type used to store such an array.  The family is injective
--   (@v -> t ds@): the element type and the dimensions can be recovered
--   from the resulting type.
--
--   This is a closed family, so the equations are tried from top to bottom:
--
--   * an empty dimension list always gives 'Scalar';
--   * one-dimensional 'Float' and 'Double' arrays of length 2, 3 and 4
--     get dedicated types;
--   * every other non-empty dimension list falls through to the generic
--     array type of the corresponding element type.
type family Array t (ds :: [Nat]) = v | v -> t ds where
  Array t      '[]          = Scalar t
  Array Float  '[2]         = FloatX2
  Array Float  '[3]         = FloatX3
  Array Float  '[4]         = FloatX4
  Array Double '[2]         = DoubleX2
  Array Double '[3]         = DoubleX3
  Array Double '[4]         = DoubleX4
  Array Float  (d ': ds)    = ArrayF   (d ': ds)
  Array Double (d ': ds)    = ArrayD   (d ': ds)
  Array Int    (d ': ds)    = ArrayI   (d ': ds)
  Array Int8   (d ': ds)    = ArrayI8  (d ': ds)
  Array Int16  (d ': ds)    = ArrayI16 (d ': ds)
  Array Int32  (d ': ds)    = ArrayI32 (d ': ds)
  Array Int64  (d ': ds)    = ArrayI64 (d ': ds)
  Array Word   (d ': ds)    = ArrayW   (d ': ds)
  Array Word8  (d ': ds)    = ArrayW8  (d ': ds)
  Array Word16 (d ': ds)    = ArrayW16 (d ': ds)
  Array Word32 (d ': ds)    = ArrayW32 (d ': ds)
  Array Word64 (d ': ds)    = ArrayW64 (d ': ds)


-- | A single value of type @t@ wrapped so that it can act as the
--   zero-dimensional member of the 'Array' family (no array storage involved).
--   The comparison, enumeration and numeric instances are derived.
newtype Scalar t = Scalar { _unScalar :: t }
  deriving ( Eq, Ord, Enum, Read
           , Num, Real, Integral
           , Fractional, RealFrac, Floating, RealFloat )

-- | Renders the wrapped value between curly braces, e.g. @{ 3 }@.
instance Show t => Show (Scalar t) where
  show x = concat ["{ ", show (_unScalar x), " }"]

-- | Generic case: a scalar is bounded exactly like the type it wraps.
instance {-# OVERLAPPABLE #-} Bounded t => Bounded (Scalar t) where
  minBound = Scalar minBound
  maxBound = Scalar maxBound

-- | 'Double' has no 'Bounded' instance of its own; the scalar version
--   ranges from negative to positive infinity.
instance {-# OVERLAPPING #-} Bounded (Scalar Double) where
  minBound = Scalar (negate inftyD)
  maxBound = Scalar inftyD

-- | Same as the 'Double' case, for 'Float'.
instance {-# OVERLAPPING #-} Bounded (Scalar Float) where
  minBound = Scalar (negate inftyF)
  maxBound = Scalar inftyF
-- | Positive infinity as a 'Double'.
--
--   Computed arithmetically (floating-point division of a positive number by
--   zero yields @Infinity@) instead of parsing the string @"Infinity"@ with
--   the partial function 'read'.
inftyD :: Double
inftyD = 1 / 0

-- | Positive infinity as a 'Float'; see 'inftyD'.
inftyF :: Float
inftyF = 1 / 0

-- Runtime representation of the primitive value behind every scalar type.

-- Floating point elements.
type instance ElemRep (Scalar Float ) = 'FloatRep
type instance ElemRep (Scalar Double) = 'DoubleRep

-- Signed integers that always fit into a machine word.
type instance ElemRep (Scalar Int   ) = 'IntRep
type instance ElemRep (Scalar Int8  ) = 'IntRep
type instance ElemRep (Scalar Int16 ) = 'IntRep
type instance ElemRep (Scalar Int32 ) = 'IntRep

-- Unsigned integers that always fit into a machine word.
type instance ElemRep (Scalar Word  ) = 'WordRep
type instance ElemRep (Scalar Word8 ) = 'WordRep
type instance ElemRep (Scalar Word16) = 'WordRep
type instance ElemRep (Scalar Word32) = 'WordRep

-- 64-bit integers: dedicated representations when the machine word is
-- narrower than 64 bits, the ordinary word-sized ones otherwise.
#if WORD_SIZE_IN_BITS < 64
type instance ElemRep (Scalar Int64 ) = 'Int64Rep
type instance ElemRep (Scalar Word64) = 'Word64Rep
#else
type instance ElemRep (Scalar Int64 ) = 'IntRep
type instance ElemRep (Scalar Word64) = 'WordRep
#endif

-- Unboxed primitive type that carries the value of every scalar type.

-- Floating point elements.
type instance ElemPrim (Scalar Float ) = Float#
type instance ElemPrim (Scalar Double) = Double#

-- Signed integers that always fit into a machine word.
type instance ElemPrim (Scalar Int   ) = Int#
type instance ElemPrim (Scalar Int8  ) = Int#
type instance ElemPrim (Scalar Int16 ) = Int#
type instance ElemPrim (Scalar Int32 ) = Int#

-- Unsigned integers that always fit into a machine word.
type instance ElemPrim (Scalar Word  ) = Word#
type instance ElemPrim (Scalar Word8 ) = Word#
type instance ElemPrim (Scalar Word16) = Word#
type instance ElemPrim (Scalar Word32) = Word#

-- 64-bit integers: dedicated primitives when the machine word is narrower
-- than 64 bits, the ordinary word-sized ones otherwise.
#if WORD_SIZE_IN_BITS < 64
type instance ElemPrim (Scalar Int64 ) = Int64#
type instance ElemPrim (Scalar Word64) = Word64#
#else
type instance ElemPrim (Scalar Int64 ) = Int#
type instance ElemPrim (Scalar Word64) = Word#
#endif

-- PrimBytes instances for every supported scalar element type, obtained
-- through standalone deriving rather than written by hand.
-- NOTE(review): presumably each one reuses the PrimBytes instance of the
-- wrapped type via GeneralizedNewtypeDeriving -- confirm.
deriving instance PrimBytes (Scalar Float)
deriving instance PrimBytes (Scalar Double)
deriving instance PrimBytes (Scalar Int)
deriving instance PrimBytes (Scalar Int8)
deriving instance PrimBytes (Scalar Int16)
deriving instance PrimBytes (Scalar Int32)
deriving instance PrimBytes (Scalar Int64)
deriving instance PrimBytes (Scalar Word)
deriving instance PrimBytes (Scalar Word8)
deriving instance PrimBytes (Scalar Word16)
deriving instance PrimBytes (Scalar Word32)
deriving instance PrimBytes (Scalar Word64)

-- | Element-wise access to a scalar: there is exactly one element, so every
--   operation ignores the index it receives and hands the empty index @Z@
--   to any callback that expects one.
instance ElementWise (Idx ('[] :: [Nat])) t (Scalar t) where
  indexOffset# (Scalar v) _ = v
  Scalar v ! _ = v
  {-# INLINE (!) #-}
  ewmap f = \(Scalar v) -> Scalar (f Z v)
  {-# INLINE ewmap #-}
  ewgen f = Scalar (f Z)
  {-# INLINE ewgen #-}
  ewgenA f = fmap Scalar (f Z)
  {-# INLINE ewgenA #-}
  ewfoldl f x0 = \(Scalar v) -> f Z x0 v
  {-# INLINE ewfoldl #-}
  ewfoldr f x0 (Scalar v) = f Z v x0
  {-# INLINE ewfoldr #-}
  elementWise f = \(Scalar v) -> Scalar <$> f v
  {-# INLINE elementWise #-}
  indexWise f = \(Scalar v) -> Scalar <$> f Z v
  {-# INLINE indexWise #-}
  broadcast = Scalar
  {-# INLINE broadcast #-}
  update _ v _ = Scalar v
  {-# INLINE update #-}


-- * Array implementations.
--   All array implementations have the same structure:
--   Array[Type] (element offset :: Int#) (element length :: Int#)
--                 (content :: ByteArray#)
--   All types can also be instantiated with a single scalar value.


-- | Generic array of 'Float' elements.
data ArrayF   (ds :: [Nat])
  = ArrayF# Int# Int# ByteArray#
  | FromScalarF# Float#

-- | Generic array of 'Double' elements.
data ArrayD   (ds :: [Nat])
  = ArrayD# Int# Int# ByteArray#
  | FromScalarD# Double#

-- | Generic array of 'Int' elements.
data ArrayI   (ds :: [Nat])
  = ArrayI# Int# Int# ByteArray#
  | FromScalarI# Int#

-- | Generic array of 'Int8' elements.
data ArrayI8  (ds :: [Nat])
  = ArrayI8# Int# Int# ByteArray#
  | FromScalarI8# Int#

-- | Generic array of 'Int16' elements.
data ArrayI16 (ds :: [Nat])
  = ArrayI16# Int# Int# ByteArray#
  | FromScalarI16# Int#

-- | Generic array of 'Int32' elements.
data ArrayI32 (ds :: [Nat])
  = ArrayI32# Int# Int# ByteArray#
  | FromScalarI32# Int#

-- | Generic array of 'Int64' elements.  The scalar payload is a dedicated
--   64-bit primitive when the machine word is narrower than 64 bits.
data ArrayI64 (ds :: [Nat])
  = ArrayI64# Int# Int# ByteArray#
#if WORD_SIZE_IN_BITS < 64
  | FromScalarI64# Int64#
#else
  | FromScalarI64# Int#
#endif

-- | Generic array of 'Word' elements.
data ArrayW   (ds :: [Nat])
  = ArrayW# Int# Int# ByteArray#
  | FromScalarW# Word#

-- | Generic array of 'Word8' elements.
data ArrayW8  (ds :: [Nat])
  = ArrayW8# Int# Int# ByteArray#
  | FromScalarW8# Word#

-- | Generic array of 'Word16' elements.
data ArrayW16 (ds :: [Nat])
  = ArrayW16# Int# Int# ByteArray#
  | FromScalarW16# Word#

-- | Generic array of 'Word32' elements.
data ArrayW32 (ds :: [Nat])
  = ArrayW32# Int# Int# ByteArray#
  | FromScalarW32# Word#

-- | Generic array of 'Word64' elements.  The scalar payload is a dedicated
--   64-bit primitive when the machine word is narrower than 64 bits.
data ArrayW64 (ds :: [Nat])
  = ArrayW64# Int# Int# ByteArray#
#if WORD_SIZE_IN_BITS < 64
  | FromScalarW64# Word64#
#else
  | FromScalarW64# Word#
#endif

-- * Specialized types
--   More efficient data types for small fixed-size tensors
-- | Two 'Float' components stored as unboxed fields;
--   the 'Array' family uses it for @Array Float '[2]@.
data FloatX2 = FloatX2# Float# Float#
-- | Three 'Float' components; used for @Array Float '[3]@.
data FloatX3 = FloatX3# Float# Float# Float#
-- | Four 'Float' components; used for @Array Float '[4]@.
data FloatX4 = FloatX4# Float# Float# Float# Float#

-- | Two 'Double' components stored as unboxed fields;
--   the 'Array' family uses it for @Array Double '[2]@.
data DoubleX2 = DoubleX2# Double# Double#
-- | Three 'Double' components; used for @Array Double '[3]@.
data DoubleX3 = DoubleX3# Double# Double# Double#
-- | Four 'Double' components; used for @Array Double '[4]@.
data DoubleX4 = DoubleX4# Double# Double# Double# Double#

-- * Recovering type instances at runtime
--   A combination of `ElemType t` and `ArraySize ds` should
--   define an instance of `Array t ds` unambiguously.


-- | Keep information about the element type instance.
--
--   Every constructor carries a proof that @t@ is one particular element
--   type, so pattern matching on a value of @ElemType t@ tells the type
--   checker which concrete type @t@ is inside that branch.
--
--   Warning! This part of the code is platform and flag dependent.
data ElemType t
  = t ~ Float  => ETFloat
  | t ~ Double => ETDouble
  | t ~ Int    => ETInt
  | t ~ Int8   => ETInt8
  | t ~ Int16  => ETInt16
  | t ~ Int32  => ETInt32
  | t ~ Int64  => ETInt64
  | t ~ Word   => ETWord
  | t ~ Word8  => ETWord8
  | t ~ Word16 => ETWord16
  | t ~ Word32 => ETWord32
  | t ~ Word64 => ETWord64

-- | Keep information about the array dimensionality
--
--   Pattern matching on a constructor brings the corresponding fact about
--   @ds@ into scope:
--
--   * 'ASScalar' -- @ds@ is the empty list;
--   * 'ASX2', 'ASX3', 'ASX4' -- @ds@ is a single dimension of size 2, 3 or 4;
--   * 'ASXN' -- @ds@ is a single dimension @n@ with @5 <= n@;
--   * 'ASArray' -- @ds@ has at least two dimensions.
--
--   Warning! This part of the code is platform and flag dependent.
data ArraySize (ds :: [Nat])
  = ds ~ '[]   => ASScalar
  | ds ~ '[2]  => ASX2
  | ds ~ '[3]  => ASX3
  | ds ~ '[4]  => ASX4
  | forall n . (ds ~ '[n], 5 <= n) => ASXN
  | forall n1 n2 ns . ds ~ (n1 ': n2 ': ns) => ASArray

-- | Keep information about the instance behind Array family
--
--   Each constructor names one possible result of the 'Array' family and
--   carries three pieces of evidence: what @Array t ds@ is equal to, the
--   shape of @ds@ (empty, non-empty, or one fixed length), and -- for every
--   constructor except 'AIScalar' -- the concrete element type @t@.
--   Pattern matching therefore lets code written against @Array t ds@ use
--   the operations of the concrete type.
--
--   Warning! This part of the code is platform and flag dependent.
data ArrayInstance t (ds :: [Nat])
  = ( Array t ds ~ Scalar t, ds ~ '[]) => AIScalar
  | forall n ns . ( Array t ds ~ ArrayF   ds, ds ~ (n ': ns), t ~ Float ) => AIArrayF
  | forall n ns . ( Array t ds ~ ArrayD   ds, ds ~ (n ': ns), t ~ Double) => AIArrayD
  | forall n ns . ( Array t ds ~ ArrayI   ds, ds ~ (n ': ns), t ~ Int   ) => AIArrayI
  | forall n ns . ( Array t ds ~ ArrayI8  ds, ds ~ (n ': ns), t ~ Int8  ) => AIArrayI8
  | forall n ns . ( Array t ds ~ ArrayI16 ds, ds ~ (n ': ns), t ~ Int16 ) => AIArrayI16
  | forall n ns . ( Array t ds ~ ArrayI32 ds, ds ~ (n ': ns), t ~ Int32 ) => AIArrayI32
  | forall n ns . ( Array t ds ~ ArrayI64 ds, ds ~ (n ': ns), t ~ Int64 ) => AIArrayI64
  | forall n ns . ( Array t ds ~ ArrayW   ds, ds ~ (n ': ns), t ~ Word  ) => AIArrayW
  | forall n ns . ( Array t ds ~ ArrayW8  ds, ds ~ (n ': ns), t ~ Word8 ) => AIArrayW8
  | forall n ns . ( Array t ds ~ ArrayW16 ds, ds ~ (n ': ns), t ~ Word16) => AIArrayW16
  | forall n ns . ( Array t ds ~ ArrayW32 ds, ds ~ (n ': ns), t ~ Word32) => AIArrayW32
  | forall n ns . ( Array t ds ~ ArrayW64 ds, ds ~ (n ': ns), t ~ Word64) => AIArrayW64
  | ( Array t ds ~ FloatX2, ds ~ '[2], t ~ Float) => AIFloatX2
  | ( Array t ds ~ FloatX3, ds ~ '[3], t ~ Float) => AIFloatX3
  | ( Array t ds ~ FloatX4, ds ~ '[4], t ~ Float) => AIFloatX4
  | ( Array t ds ~ DoubleX2, ds ~ '[2], t ~ Double) => AIDoubleX2
  | ( Array t ds ~ DoubleX3, ds ~ '[3], t ~ Double) => AIDoubleX3
  | ( Array t ds ~ DoubleX4, ds ~ '[4], t ~ Double) => AIDoubleX4

-- | Evidence that both halves of 'ArrayInstanceInference' hold for the given
--   element type and dimension list, i.e. that the instance behind
--   @Array t ds@ can be determined.  It is the 'Evidence' type applied to
--   the 'ArrayInstanceInference' constraint.
type ArrayInstanceEvidence t (ds :: [Nat])
  = Evidence (ArrayInstanceInference t ds)


-- | Element types whose identity can be recovered as an 'ElemType' value.
class ElemTypeInference t where
    -- | Pattern match against result to get specific element type
    elemTypeInstance  :: ElemType t

-- | Dimension lists whose shape can be recovered as an 'ArraySize' value,
--   together with ways to obtain the same evidence for closely related
--   dimension lists.
class ArraySizeInference ds where
    -- | Pattern match against result to get actual array dimensionality
    arraySizeInstance :: ArraySize ds
    -- | Evidence for the dimension list @ds +: z@ (presumably @ds@ with @z@
    --   appended at the end -- confirm against the definition of @+:@).
    --   The arguments are used only to fix the types.
    inferSnocArrayInstance :: (ElemTypeInference t, KnownDim z)
                           => p t ds -> q z -> ArrayInstanceEvidence t (ds +: z)
    -- | Evidence for the dimension list @z :+ ds@ (presumably @ds@ with @z@
    --   prepended -- confirm against the definition of @:+@).
    --   The arguments are used only to fix the types.
    inferConsArrayInstance :: (ElemTypeInference t, KnownDim z)
                           => q z -> p t ds -> ArrayInstanceEvidence t (z :+ ds)
    -- | Evidence for the dimension list @Init ds@ (presumably @ds@ without
    --   its last element -- confirm against the definition of @Init@).
    inferInitArrayInstance :: ElemTypeInference t
                           => p t ds -> ArrayInstanceEvidence t (Init ds)


-- | Use this typeclass constraint in library functions if there is a need
--   to select an instance of the Array family at runtime.
--   The combination of `elemTypeInstance` and `arraySizeInstance` allows
--   bringing any specific typeclass instance into the typechecker's scope.
type ArrayInstanceInference t ds = (ElemTypeInference t, ArraySizeInference ds)



-- One instance per supported element type; each returns the 'ElemType'
-- constructor that witnesses that type.

-- Floating point types.
instance ElemTypeInference Float  where elemTypeInstance = ETFloat
instance ElemTypeInference Double where elemTypeInstance = ETDouble

-- Signed integers.
instance ElemTypeInference Int    where elemTypeInstance = ETInt
instance ElemTypeInference Int8   where elemTypeInstance = ETInt8
instance ElemTypeInference Int16  where elemTypeInstance = ETInt16
instance ElemTypeInference Int32  where elemTypeInstance = ETInt32
instance ElemTypeInference Int64  where elemTypeInstance = ETInt64

-- Unsigned integers.
instance ElemTypeInference Word   where elemTypeInstance = ETWord
instance ElemTypeInference Word8  where elemTypeInstance = ETWord8
instance ElemTypeInference Word16 where elemTypeInstance = ETWord16
instance ElemTypeInference Word32 where elemTypeInstance = ETWord32
instance ElemTypeInference Word64 where elemTypeInstance = ETWord64

-- | The empty dimension list: always a scalar.
instance ArraySizeInference '[] where
    arraySizeInstance = ASScalar
    {-# INLINE arraySizeInstance #-}
    -- Adding one dimension gives a one-element list; the instances it needs
    -- are resolved by the type checker from the method's own constraints.
    inferSnocArrayInstance _ _ = Evidence
    {-# INLINE inferSnocArrayInstance #-}
    inferConsArrayInstance _ _ = Evidence
    {-# INLINE inferConsArrayInstance #-}
    -- There is nothing to drop from an empty list, so this fails at runtime
    -- with a call to 'error'.
    inferInitArrayInstance _ = error "Init -- empty type-level list"
    {-# INLINE inferInitArrayInstance #-}

-- | One-dimensional arrays: the size tag is picked at runtime from the value
--   returned by @dimVal' \@d@ (presumably the runtime value of the dimension
--   @d@ -- confirm).  The type checker cannot relate that runtime value to
--   the type-level @d@, so every branch fabricates its evidence with
--   'unsafeCoerce#'.
instance KnownDim d => ArraySizeInference '[d] where
    arraySizeInstance = case dimVal' @d of
        -- NOTE(review): lengths 0 and 1 are reported as 'ASScalar', which
        -- claims @ds ~ '[]@.  The 'Array' family, however, does not map
        -- @'[0]@ or @'[1]@ to 'Scalar' (e.g. @Array Float '[1]@ reduces to
        -- @ArrayF '[1]@), and 'getArrayInstance' turns 'ASScalar' into
        -- 'AIScalar'.  This looks inconsistent -- confirm it is intended.
        0 -> unsafeCoerce# ASScalar
        1 -> unsafeCoerce# ASScalar
        2 -> unsafeCoerce# ASX2
        3 -> unsafeCoerce# ASX3
        4 -> unsafeCoerce# ASX4
        -- Everything else is treated as "5 or more": the proof of @5 <= d@
        -- is a coerced 'Refl', and matching on it lets 'ASXN' type check.
        _ -> case (unsafeCoerce# Refl :: (5 <=? d) :~: 'True) of Refl -> ASXN
    {-# INLINE arraySizeInstance #-}
    -- The related dimension lists have zero or two elements; the instances
    -- they need are resolved by the type checker.
    inferSnocArrayInstance _ _ = Evidence
    {-# INLINE inferSnocArrayInstance #-}
    inferConsArrayInstance _ _ = Evidence
    {-# INLINE inferConsArrayInstance #-}
    inferInitArrayInstance _ = Evidence
    {-# INLINE inferInitArrayInstance #-}

-- | Two-dimensional arrays: always the generic multi-dimensional case.
--   Only @d1@ carries a 'KnownDim' constraint -- presumably because
--   'inferInitArrayInstance' has to produce the one-dimensional instance
--   for @'[d1]@, which requires it (confirm).
instance KnownDim d1 => ArraySizeInference '[d1, d2] where
    arraySizeInstance = ASArray
    {-# INLINE arraySizeInstance #-}
    -- The related dimension lists have one or three elements; the instances
    -- they need are resolved by the type checker.
    inferSnocArrayInstance _ _ = Evidence
    {-# INLINE inferSnocArrayInstance #-}
    inferConsArrayInstance _ _ = Evidence
    {-# INLINE inferConsArrayInstance #-}
    inferInitArrayInstance _ = Evidence
    {-# INLINE inferInitArrayInstance #-}


-- | Three or more dimensions: always the generic multi-dimensional case.
--   This instance has no constraints on the individual dimensions.
instance ArraySizeInference (d1 ': d2 ': d3 ': ds) where
    arraySizeInstance = ASArray
    {-# INLINE arraySizeInstance #-}
    -- I know that for dimensionality > 2 all instances are the same.
    -- Hence this dirty hack should work.
    -- I have to change this when I have customized N*M instances
    --
    -- The evidence produced for the "cons" list is reinterpreted as the
    -- evidence for the "snoc" list.
    inferSnocArrayInstance p q = unsafeCoerce# (inferConsArrayInstance q p)
    {-# INLINE inferSnocArrayInstance #-}
    -- Prepending a dimension keeps the list at three or more elements, so
    -- this same instance is resolved again by the type checker.
    inferConsArrayInstance _ _ = Evidence
    {-# INLINE inferConsArrayInstance #-}
    -- I know that for dimensionality > 2 all instances are the same.
    -- Hence this dirty hack should work.
    -- I have to change this when I have customized N*M instances
    --
    -- The evidence for @3 :+ ds@ (a list that is one element longer) is
    -- reinterpreted as the evidence for @Init ds@.
    -- NOTE(review): when the list has exactly three elements, @Init@ of it
    -- has two, which would normally be served by the @'[d1, d2]@ instance
    -- (constrained by @KnownDim d1@).  The coerced evidence is this
    -- instance's instead -- confirm that is acceptable, in particular if
    -- 'inferInitArrayInstance' is applied again to the result.
    inferInitArrayInstance p = unsafeCoerce# (inferConsArrayInstance (Proxy @3) p)
    {-# INLINE inferInitArrayInstance #-}



-- | Work out which concrete type is behind @Array t ds@.
--
--   The element type is inspected first and the dimensionality second;
--   every combination of the two maps to exactly one 'ArrayInstance'
--   constructor.  Branches cannot be merged with wildcards, because each
--   pattern supplies the type equalities its result needs.
--
--   For 'Float' and 'Double' with 'ASXN' the result is built at the
--   concrete length 5 and coerced: presumably the type checker cannot reduce
--   @Array t '[n]@ for an abstract @n@, since lengths 2, 3 and 4 have their
--   own equations in the family.
getArrayInstance :: forall t (ds :: [Nat])
                  . ArrayInstanceInference t ds
                 => ArrayInstance t ds
getArrayInstance = case elemTypeInstance @t of
    ETFloat -> case arraySizeInstance @ds of
        ASScalar -> AIScalar
        ASX2     -> AIFloatX2
        ASX3     -> AIFloatX3
        ASX4     -> AIFloatX4
        ASXN     -> unsafeCoerce# (AIArrayF :: ArrayInstance Float '[5])
        ASArray  -> AIArrayF

    ETDouble -> case arraySizeInstance @ds of
        ASScalar -> AIScalar
        ASX2     -> AIDoubleX2
        ASX3     -> AIDoubleX3
        ASX4     -> AIDoubleX4
        ASXN     -> unsafeCoerce# (AIArrayD :: ArrayInstance Double '[5])
        ASArray  -> AIArrayD

    ETInt -> case arraySizeInstance @ds of
        ASScalar -> AIScalar
        ASX2     -> AIArrayI
        ASX3     -> AIArrayI
        ASX4     -> AIArrayI
        ASXN     -> AIArrayI
        ASArray  -> AIArrayI

    ETInt8 -> case arraySizeInstance @ds of
        ASScalar -> AIScalar
        ASX2     -> AIArrayI8
        ASX3     -> AIArrayI8
        ASX4     -> AIArrayI8
        ASXN     -> AIArrayI8
        ASArray  -> AIArrayI8

    ETInt16 -> case arraySizeInstance @ds of
        ASScalar -> AIScalar
        ASX2     -> AIArrayI16
        ASX3     -> AIArrayI16
        ASX4     -> AIArrayI16
        ASXN     -> AIArrayI16
        ASArray  -> AIArrayI16

    ETInt32 -> case arraySizeInstance @ds of
        ASScalar -> AIScalar
        ASX2     -> AIArrayI32
        ASX3     -> AIArrayI32
        ASX4     -> AIArrayI32
        ASXN     -> AIArrayI32
        ASArray  -> AIArrayI32

    ETInt64 -> case arraySizeInstance @ds of
        ASScalar -> AIScalar
        ASX2     -> AIArrayI64
        ASX3     -> AIArrayI64
        ASX4     -> AIArrayI64
        ASXN     -> AIArrayI64
        ASArray  -> AIArrayI64

    ETWord -> case arraySizeInstance @ds of
        ASScalar -> AIScalar
        ASX2     -> AIArrayW
        ASX3     -> AIArrayW
        ASX4     -> AIArrayW
        ASXN     -> AIArrayW
        ASArray  -> AIArrayW

    ETWord8 -> case arraySizeInstance @ds of
        ASScalar -> AIScalar
        ASX2     -> AIArrayW8
        ASX3     -> AIArrayW8
        ASX4     -> AIArrayW8
        ASXN     -> AIArrayW8
        ASArray  -> AIArrayW8

    ETWord16 -> case arraySizeInstance @ds of
        ASScalar -> AIScalar
        ASX2     -> AIArrayW16
        ASX3     -> AIArrayW16
        ASX4     -> AIArrayW16
        ASXN     -> AIArrayW16
        ASArray  -> AIArrayW16

    ETWord32 -> case arraySizeInstance @ds of
        ASScalar -> AIScalar
        ASX2     -> AIArrayW32
        ASX3     -> AIArrayW32
        ASX4     -> AIArrayW32
        ASXN     -> AIArrayW32
        ASArray  -> AIArrayW32

    ETWord64 -> case arraySizeInstance @ds of
        ASScalar -> AIScalar
        ASX2     -> AIArrayW64
        ASX3     -> AIArrayW64
        ASX4     -> AIArrayW64
        ASXN     -> AIArrayW64
        ASArray  -> AIArrayW64

-- | Given element type instance and proper dimension list,
--   infer a corresponding array instance
--
--   The dimension list is inspected through @tList@ and split into the four
--   shapes for which an 'ArraySizeInference' instance exists (zero, one, two,
--   three or more dimensions).  Every branch returns the same 'Evidence'
--   constructor; the branches differ only in what the matched pattern lets
--   the type checker conclude about @ds@.
--   NOTE(review): presumably the 'KnownDim' constraints needed by the one-
--   and two-dimensional instances come from @KnownDims ds@ once the shape of
--   @ds@ is known -- confirm.
inferArrayInstance :: forall t ds
                    . ( FiniteList ds
                      , KnownDims ds
                      , ElemTypeInference t
                      )
                  => ArrayInstanceEvidence t ds
inferArrayInstance = case tList @_ @ds of
    -- no dimensions
    TLEmpty                          -> Evidence
    -- exactly one dimension
    TLCons _ TLEmpty                 -> Evidence
    -- exactly two dimensions
    TLCons _ (TLCons _ TLEmpty)      -> Evidence
    -- three or more dimensions
    TLCons _ (TLCons _ (TLCons _ _)) -> Evidence

-- Dummy binding whose type mentions an unboxed tuple, so that the
-- UnboxedTuples extension is visibly used in this module (as the name says,
-- this is there to keep HLint quiet).  It is 'undefined' and is not meant
-- to be evaluated.
_suppressHlintUnboxedTuplesWarning :: () -> (# (), () #)
_suppressHlintUnboxedTuplesWarning = undefined