{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE CPP                   #-}
{-# LANGUAGE DefaultSignatures     #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MagicHash             #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE UnboxedTuples         #-}
{-# LANGUAGE UndecidableInstances  #-}
module Data.Massiv.Core.Common
  ( Array
  , Elt
  , EltRepr
  , Construct(..)
  , Source(..)
  , Load(..)
  , Size(..)
  , Slice(..)
  , OuterSlice(..)
  , InnerSlice(..)
  , Manifest(..)
  , Mutable(..)
  , State(..)
  , WorldState
  , Ragged(..)
  , Nested(..)
  , NestedStruct
  , makeArray
  , singleton
  
  , elemsCount
  , isEmpty
  
  , (!?)
  , index
  , indexWith
  , (!)
  , index'
  , (??)
  , defaultIndex
  , borderIndex
  , evaluateAt
  , module Data.Massiv.Core.Index
  
  , imapM_
  , module Data.Massiv.Core.Computation
  ) where
import           Control.Monad.Primitive
import           Data.Massiv.Core.Computation
import           Data.Massiv.Core.Index
import           Data.Massiv.Core.Scheduler
import           Data.Typeable
import           GHC.Prim
#include "massiv.h"
-- | The central array type: a data family indexed by a representation @r@,
-- an index type @ix@ and an element type @e@. Each representation provides
-- its own data instance.
data family Array r ix e :: *
-- | Representation of the lower-dimensional arrays produced from an array
-- with representation @r@ and index @ix@ (see 'Elt'); also the
-- representation of the result of 'unsafeExtract'.
type family EltRepr r ix :: *
-- | Type of one element taken along the outer dimension of an array.
-- Equations are tried in order:
--
-- * for a one-dimensional ('Ix1') array it is the plain element type @e@;
--
-- * otherwise it is an array of one lower dimension (@'Lower' ix@) whose
--   representation is @'EltRepr' r ix@.
type family Elt r ix e :: * where
  Elt r Ix1 e = e
  Elt r ix  e = Array (EltRepr r ix) (Lower ix) e
-- | Nested (e.g. list-of-lists style) structure an array can be converted
-- to and from via the 'Nested' class. The concrete shape is chosen per
-- instance.
type family NestedStruct r ix e :: *
-- | Representations whose arrays can be constructed from a size and a
-- generating function, and which carry a computation strategy ('Comp').
class (Typeable r, Index ix) => Construct r ix e where
  -- | Computation strategy stored with the array.
  getComp :: Array r ix e -> Comp
  -- | Replace the computation strategy of the array.
  setComp :: Comp -> Array r ix e -> Array r ix e
  -- | Build an array of the given size from a function of the index.
  -- NOTE(review): the @unsafe@ prefix and the fact that 'makeArray' clamps
  -- negative sizes before calling this suggest it assumes a non-negative
  -- size — confirm against instances.
  unsafeMakeArray :: Comp -> ix -> (ix -> e) -> Array r ix e
-- | Arrays with a known size that can be reshaped and from which sub-arrays
-- can be extracted.
class Construct r ix e => Size r ix e where
  -- | Size of the array.
  size :: Array r ix e -> ix
  -- | Reinterpret the array under a new size, possibly of a different index
  -- type. NOTE(review): presumably the new size must describe the same
  -- total number of elements; nothing here checks it — confirm per instance.
  unsafeResize :: Index ix' => ix' -> Array r ix e -> Array r ix' e
  -- | Extract a sub-array without checks; the result has representation
  -- 'EltRepr'. NOTE(review): the two @ix@ arguments are presumably the
  -- start index and the size of the region — confirm per instance.
  unsafeExtract :: ix -> ix -> Array r ix e -> Array (EltRepr r ix) ix e
-- | Arrays whose elements can be read one at a time, either by a
-- multi-dimensional index or by a linear 'Int' position.
--
-- Instances must define at least one of 'unsafeIndex' and
-- 'unsafeLinearIndex': each default is written in terms of the other.
class Size r ix e => Source r ix e where
  -- | Read the element at a multi-dimensional index without bounds checking.
  -- The default converts the index with 'toLinearIndex' and defers to
  -- 'unsafeLinearIndex'. The wrapping index-check macro comes from
  -- @massiv.h@. NOTE(review): it presumably adds a bounds check only in
  -- debug builds — confirm there.
  unsafeIndex :: Array r ix e -> ix -> e
  unsafeIndex =
    INDEX_CHECK("(Source r ix e).unsafeIndex",
                size, \ !arr -> unsafeLinearIndex arr . toLinearIndex (size arr))
  {-# INLINE unsafeIndex #-}
  -- | Read the element at a linear position without bounds checking. The
  -- default converts the position with 'fromLinearIndex' and defers to
  -- 'unsafeIndex'.
  unsafeLinearIndex :: Array r ix e -> Int -> e
  unsafeLinearIndex !arr = unsafeIndex arr . fromLinearIndex (size arr)
  {-# INLINE unsafeLinearIndex #-}
  -- Both defaults call each other, so an instance overriding neither would
  -- loop forever at runtime. MINIMAL turns that mistake into a compile-time
  -- warning.
  {-# MINIMAL unsafeIndex | unsafeLinearIndex #-}
-- | Arrays that can be loaded element by element into a target described by
-- a read action and a write action over linear 'Int' positions, either
-- sequentially ('loadS') or in parallel ('loadP').
class Size r ix e => Load r ix e where
  -- | Load the array sequentially. The default calls 'loadArray' with a
  -- single worker and 'id' as the scheduling function, so each chunk of
  -- work runs directly in the calling monad.
  loadS
    :: Monad m =>
       Array r ix e -- ^ Array to load
    -> (Int -> m e) -- ^ Read action at a linear position (ignored by the
                    -- default 'loadArray')
    -> (Int -> e -> m ()) -- ^ Write action at a linear position
    -> m ()
  loadS = loadArray 1 id
  {-# INLINE loadS #-}
  -- | Load the array in parallel. The default opens a scheduler with
  -- 'withScheduler_' and runs 'loadArray' with that scheduler's
  -- 'numWorkers' and 'scheduleWork'.
  loadP
    :: [Int] -- ^ Worker ids passed to 'withScheduler_'
    -> Array r ix e -- ^ Array to load
    -> (Int -> IO e) -- ^ Read action at a linear position
    -> (Int -> e -> IO ()) -- ^ Write action at a linear position
    -> IO ()
  loadP wIds arr unsafeRead unsafeWrite =
    withScheduler_ wIds $ \scheduler ->
      loadArray (numWorkers scheduler) (scheduleWork scheduler) arr unsafeRead unsafeWrite
  {-# INLINE loadP #-}
  -- | Load a strided view of the array into a target of size @resultSize@.
  -- In the default, the element placed at result index @i@ is the source
  -- element at @i@ multiplied component-wise by the stride. The total
  -- element count of @resultSize@ and a per-position reader are handed to
  -- 'splitLinearlyWith_' together with the write action.
  loadArrayWithStride
    :: Monad m =>
       Int -- ^ Number of workers
    -> (m () -> m ()) -- ^ Function used to schedule a chunk of work
    -> Stride ix -- ^ Stride applied to result indices to obtain source indices
    -> ix -- ^ Size of the result
    -> Array r ix e -- ^ Source array
    -> (Int -> m e) -- ^ Read action at a linear position (ignored by the default)
    -> (Int -> e -> m ()) -- ^ Write action at a linear position of the result
    -> m ()
  default loadArrayWithStride
    :: (Source r ix e, Monad m) =>
       Int
    -> (m () -> m ())
    -> Stride ix
    -> ix
    -> Array r ix e
    -> (Int -> m e)
    -> (Int -> e -> m ())
    -> m ()
  loadArrayWithStride numWorkers' scheduleWork' stride resultSize arr _ =
    splitLinearlyWith_ numWorkers' scheduleWork' (totalElem resultSize) readStrided
    where
      strideIx = unStride stride
      -- Source element that belongs at linear position @i@ of the result.
      -- This only reads from @arr@. The write action (the method's last,
      -- eta-reduced argument) is passed on to 'splitLinearlyWith_'.
      readStrided =
        unsafeIndex arr . liftIndex2 (*) strideIx . fromLinearIndex resultSize
      {-# INLINE readStrided #-}
  {-# INLINE loadArrayWithStride #-}
  -- | Load the whole array with the given number of workers and scheduling
  -- function. The default hands the element count and 'unsafeLinearIndex'
  -- to 'splitLinearlyWith_' together with the (eta-reduced) write action.
  -- NOTE(review): splitLinearlyWith_ presumably divides the linear range
  -- among the workers — confirm in the scheduler module.
  loadArray
    :: Monad m =>
       Int -- ^ Number of workers
    -> (m () -> m ()) -- ^ Function used to schedule a chunk of work
    -> Array r ix e -- ^ Array to load
    -> (Int -> m e) -- ^ Read action at a linear position (ignored by the default)
    -> (Int -> e -> m ()) -- ^ Write action at a linear position
    -> m ()
  default loadArray
    :: (Source r ix e, Monad m) =>
       Int
    -> (m () -> m ())
    -> Array r ix e
    -> (Int -> m e)
    -> (Int -> e -> m ())
    -> m ()
  loadArray numWorkers' scheduleWork' arr _ =
    splitLinearlyWith_ numWorkers' scheduleWork' (totalElem (size arr)) (unsafeLinearIndex arr)
  {-# INLINE loadArray #-}
-- | Arrays that can be sliced along their outer dimension, yielding values
-- of type 'Elt'.
class OuterSlice r ix e where
  -- | Take the slice at the given position of the outer dimension.
  -- NOTE(review): presumably unchecked, per the @unsafe@ prefix.
  unsafeOuterSlice :: Array r ix e -> Int -> Elt r ix e
  -- | Number of slices along the outer dimension. By default this is the
  -- 'headDim' of the array size.
  outerLength :: Array r ix e -> Int
  default outerLength :: Size r ix e => Array r ix e -> Int
  outerLength arr = headDim (size arr)
-- | Arrays that can be sliced along their inner dimension, yielding values
-- of type 'Elt'.
class Size r ix e => InnerSlice r ix e where
  -- | Take an inner slice at the given position. NOTE(review): the
  -- @(Lower ix, Int)@ pair is presumably the array size split into its
  -- outer part and inner length — confirm against instances.
  unsafeInnerSlice :: Array r ix e -> (Lower ix, Int) -> Int -> Elt r ix e
-- | Arrays that can be sliced along an arbitrary dimension ('Dim').
class Size r ix e => Slice r ix e where
  -- | Slice along the given dimension, returning 'Nothing' when the instance
  -- cannot produce the slice. NOTE(review): the two @ix@ arguments are
  -- presumably a start index and a size — confirm against instances.
  unsafeSlice :: Array r ix e -> ix -> ix -> Dim -> Maybe (Elt r ix e)
-- | Source arrays that are manifest. NOTE(review): presumably this means
-- their elements are already stored rather than computed on demand —
-- confirm against instances.
class Source r ix e => Manifest r ix e where
  -- | Read the element at a linear position. NOTE(review): presumably
  -- unchecked, per the @unsafe@ prefix.
  unsafeLinearIndexM :: Array r ix e -> Int -> e
-- | Lifted box around an unlifted primitive state token @State# s@, so the
-- token can be stored in ordinary data such as tuples.
data State s = State (State# s)
-- | Boxed state token for 'RealWorld', the state type threaded by 'IO'.
type WorldState = State RealWorld
-- | Manifest arrays with a mutable counterpart ('MArray') that can be
-- allocated, read, written, and converted to and from immutable arrays in
-- any 'PrimMonad'.
class Manifest r ix e => Mutable r ix e where
  -- | Mutable array for representation @r@, parameterised by state type @s@.
  data MArray s r ix e :: *
  -- | Size of a mutable array.
  msize :: MArray s r ix e -> ix
  -- | Turn an immutable array into a mutable one. NOTE(review): the @unsafe@
  -- prefix suggests no copy is made, so the source array must not be used
  -- afterwards — confirm against instances.
  unsafeThaw :: PrimMonad m =>
                Array r ix e -> m (MArray (PrimState m) r ix e)
  -- | Turn a mutable array into an immutable one with the given 'Comp'.
  -- NOTE(review): presumably no copy is made, so the mutable array must not
  -- be written afterwards — confirm against instances.
  unsafeFreeze :: PrimMonad m =>
                  Comp -> MArray (PrimState m) r ix e -> m (Array r ix e)
  -- | Allocate a mutable array of the given size. NOTE(review): presumably
  -- the contents are left uninitialized — confirm per instance.
  unsafeNew :: PrimMonad m =>
               ix -> m (MArray (PrimState m) r ix e)
  -- | Allocate a mutable array of the given size. NOTE(review): the name
  -- suggests the contents are zero-filled — confirm per instance.
  unsafeNewZero :: PrimMonad m =>
                   ix -> m (MArray (PrimState m) r ix e)
  -- | Read the element at a linear position of a mutable array.
  unsafeLinearRead :: PrimMonad m =>
                      MArray (PrimState m) r ix e -> Int -> m e
  -- | Write an element at a linear position of a mutable array.
  unsafeLinearWrite :: PrimMonad m =>
                       MArray (PrimState m) r ix e -> Int -> e -> m ()
  -- The ...A variants below run the corresponding IO action in an arbitrary
  -- Applicative. Each one unwraps the action with 'internal', feeds it the
  -- RealWorld token taken from the WorldState argument, and returns the
  -- resulting token boxed in a new WorldState.
  -- NOTE(review): effects are ordered only by threading that token, so
  -- callers are expected to pass each returned WorldState to the next call.
  -- The action runs when the call's result is evaluated (the case on the
  -- unboxed tuple), not when the Applicative sequences it.
  --
  -- | 'unsafeNew' with explicit 'WorldState' threading.
  unsafeNewA :: Applicative f => ix -> WorldState -> f (WorldState, MArray RealWorld r ix e)
  unsafeNewA sz (State s#) =
    case internal (unsafeNew sz :: IO (MArray RealWorld r ix e)) s# of
      (# s'#, ma #) -> pure (State s'#, ma)
  {-# INLINE unsafeNewA #-}
  -- | 'unsafeThaw' with explicit 'WorldState' threading.
  unsafeThawA :: Applicative m =>
                 Array r ix e -> WorldState -> m (WorldState, MArray RealWorld r ix e)
  unsafeThawA arr (State s#) =
    case internal (unsafeThaw arr :: IO (MArray RealWorld r ix e)) s# of
      (# s'#, ma #) -> pure (State s'#, ma)
  {-# INLINE unsafeThawA #-}
  -- | 'unsafeFreeze' with explicit 'WorldState' threading.
  unsafeFreezeA :: Applicative m =>
                   Comp -> MArray RealWorld r ix e -> WorldState -> m (WorldState, Array r ix e)
  unsafeFreezeA comp marr (State s#) =
    case internal (unsafeFreeze comp marr :: IO (Array r ix e)) s# of
      (# s'#, a #) -> pure (State s'#, a)
  {-# INLINE unsafeFreezeA #-}
  -- | 'unsafeLinearWrite' with explicit 'WorldState' threading; only the new
  -- token is returned.
  unsafeLinearWriteA :: Applicative m =>
                        MArray RealWorld r ix e -> Int -> e -> WorldState -> m WorldState
  unsafeLinearWriteA marr i val (State s#) =
    case internal (unsafeLinearWrite marr i val :: IO ()) s# of
      (# s'#, _ #) -> pure (State s'#)
  {-# INLINE unsafeLinearWriteA #-}
-- | Conversion between an array and its nested representation
-- ('NestedStruct').
class Nested r ix e where
  -- | Build an array from its nested representation.
  fromNested :: NestedStruct r ix e -> Array r ix e
  -- | Convert an array into its nested representation.
  toNested :: Array r ix e -> NestedStruct r ix e
-- | Representations built up from outer elements ('Elt') one at a time.
-- NOTE(review): the name suggests these arrays may be ragged, i.e. the outer
-- elements need not share a size — confirm against instances.
class Construct r ix e => Ragged r ix e where
  -- | Array with no elements, using the given computation strategy.
  empty :: Comp -> Array r ix e
  -- | NOTE(review): presumably 'True' when the array has no elements.
  isNull :: Array r ix e -> Bool
  -- | Prepend an outer element.
  cons :: Elt r ix e -> Array r ix e -> Array r ix e
  -- | Split off the first outer element, or 'Nothing' if there is none.
  uncons :: Array r ix e -> Maybe (Elt r ix e, Array r ix e)
  -- | Generate an array of the given size from a monadic function of the
  -- index.
  unsafeGenerateM :: Monad m => Comp -> ix -> (ix -> m e) -> m (Array r ix e)
  -- | NOTE(review): presumably the size measured along the first element of
  -- each dimension — confirm against instances.
  edgeSize :: Array r ix e -> ix
  -- | Collapse the array into a one-dimensional ('Ix1') array of the same
  -- representation.
  flatten :: Array r ix e -> Array r Ix1 e
  -- | Load the array through a scheduling function and a write action.
  -- NOTE(review): the two 'Int's are presumably the starting linear position
  -- and a chunk/end bound, and @ix@ the expected size — confirm against
  -- instances.
  loadRagged ::
    (IO () -> IO ()) -> (Int -> e -> IO a) -> Int -> Int -> ix -> Array r ix e -> IO ()
  -- | Render the array as text using the given element formatter.
  -- NOTE(review): the 'String' argument is presumably a separator or
  -- indentation prefix — confirm against instances.
  raggedFormat :: (e -> String) -> String -> Array r ix e -> String
-- | Create an array of the given size whose element at each index is
-- produced by the supplied function. Each component of the requested size
-- is clamped with @max 0@ before 'unsafeMakeArray' is called, so negative
-- dimensions become zero.
makeArray
  :: Construct r ix e
  => Comp      -- ^ Computation strategy passed to 'unsafeMakeArray'
  -> ix        -- ^ Requested size (negative components are clamped to 0)
  -> (ix -> e) -- ^ Element at each index
  -> Array r ix e
makeArray !comp = unsafeMakeArray comp . clampSize
  where
    -- The constructor itself is unchecked, so sanitize the size here.
    clampSize = liftIndex (max 0)
{-# INLINE makeArray #-}
-- | Create an array of size @pureIndex 1@ in which every index yields the
-- supplied value.
singleton
  :: Construct r ix e
  => Comp -- ^ Computation strategy passed to 'unsafeMakeArray'
  -> e    -- ^ The element
  -> Array r ix e
singleton !comp = \x -> unsafeMakeArray comp unitSize (\_ -> x)
  where
    unitSize = pureIndex 1
{-# INLINE singleton #-}
-- The three indexing operators share one fixity: left-associative at
-- precedence 4.
infixl 4 !, !?, ??
-- | Infix synonym for 'index'': index an array, raising an error when the
-- index is out of bounds.
(!) :: Manifest r ix e => Array r ix e -> ix -> e
(!) arr ix = index' arr ix
{-# INLINE (!) #-}
-- | Infix synonym for 'index': safe indexing that returns 'Nothing' for an
-- out-of-bounds index.
(!?) :: Manifest r ix e => Array r ix e -> ix -> Maybe e
(!?) arr ix = index arr ix
{-# INLINE (!?) #-}
-- | Safe indexing into an optional array, convenient for chaining lookups.
-- A missing array yields 'Nothing' for any index; otherwise this behaves
-- like '(!?)'.
(??) :: Manifest r ix e => Maybe (Array r ix e) -> ix -> Maybe e
(??) marr = maybe (const Nothing) (!?) marr
{-# INLINE (??) #-}
-- | Safe indexing. In-bounds indices yield @Just@ the element; out-of-bounds
-- indices are handled by 'handleBorderIndex' with a @Fill Nothing@ border,
-- which yields 'Nothing'.
index :: Manifest r ix e => Array r ix e -> ix -> Maybe e
index arr = handleBorderIndex (Fill Nothing) (size arr) lookupJust
  where
    lookupJust ix = Just (unsafeIndex arr ix)
{-# INLINE index #-}
-- | Index with a fallback: out-of-bounds indices yield the supplied default
-- value, via a 'Fill' border passed to 'borderIndex'.
defaultIndex :: Manifest r ix e => e -> Array r ix e -> ix -> e
defaultIndex defVal = borderIndex fallback
  where
    fallback = Fill defVal
{-# INLINE defaultIndex #-}
-- | Index using the given 'Border' strategy to resolve indices outside the
-- array; in-bounds lookups go through 'unsafeIndex'.
borderIndex :: Manifest r ix e => Border e -> Array r ix e -> ix -> e
borderIndex border arr = handleBorderIndex border sz (unsafeIndex arr)
  where
    sz = size arr
{-# INLINE borderIndex #-}
-- | Checked indexing: an out-of-bounds index evaluates an 'errorIx' error
-- that reports the array size and the offending index.
index' :: Manifest r ix e => Array r ix e -> ix -> e
index' arr ix = borderIndex (Fill outOfBounds) arr ix
  where
    -- Lazy: only evaluated if the border handler actually needs it.
    outOfBounds = errorIx "Data.Massiv.Array.index" (size arr) ix
{-# INLINE index' #-}
-- | Checked indexing for any 'Source' array (not just 'Manifest' ones). Both
-- the array and the index are forced; an out-of-bounds index evaluates an
-- 'errorIx' error.
evaluateAt :: Source r ix e => Array r ix e -> ix -> e
evaluateAt !arr !ix = handleBorderIndex (Fill failure) sz (unsafeIndex arr) ix
  where
    sz = size arr
    failure = errorIx "Data.Massiv.Array.evaluateAt" sz ix
{-# INLINE evaluateAt #-}
-- | Index with an explicit bounds check that reports the call site on
-- failure. If 'isSafeIndex' accepts @ix@ for the size obtained from
-- @getSize@, the result is @f arr ix@; otherwise 'errorIx' is called with a
-- label built from the file name, line number and function name.
indexWith ::
     Index ix
  => String           -- ^ Source file name for the error label
  -> Int              -- ^ Line number for the error label
  -> String           -- ^ Function name for the error label
  -> (arr -> ix)      -- ^ Obtain the size of the array
  -> (arr -> ix -> e) -- ^ Indexing function used once the index is known safe
  -> arr              -- ^ The array
  -> ix               -- ^ Index to look up
  -> e
indexWith fileName lineNo funName getSize f arr ix =
  if isSafeIndex sz ix
    then f arr ix
    else errorIx location sz ix
  where
    sz = getSize arr
    location = concat ["<", fileName, ":", show lineNo, "> ", funName]
{-# INLINE indexWith #-}
-- | Run a monadic action for every index of the array together with the
-- element stored there, discarding the results. Iteration goes from
-- 'zeroIndex' up to (excluding) the array size in steps of @pureIndex 1@,
-- driven by 'iterM_'.
imapM_ :: (Source r ix a, Monad m) => (ix -> a -> m b) -> Array r ix a -> m ()
imapM_ f !arr = iterM_ zeroIndex sz (pureIndex 1) (<) visit
  where
    sz = size arr
    visit !ix = f ix (unsafeIndex arr ix)
{-# INLINE imapM_ #-}
-- | Total number of elements in the array: 'totalElem' of its 'size'.
elemsCount :: Size r ix e => Array r ix e -> Int
elemsCount arr = totalElem (size arr)
{-# INLINE elemsCount #-}
-- | 'True' when the array has no elements. The array is forced first.
isEmpty :: Size r ix e => Array r ix e -> Bool
isEmpty !arr = elemsCount arr == 0
{-# INLINE isEmpty #-}