{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
#if __GLASGOW_HASKELL__ < 820
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
#endif
module Data.Massiv.Core.Index.Internal
  ( Sz(SafeSz)
  , pattern Sz
  , pattern Sz1
  , type Sz1
  , unSz
  , zeroSz
  , oneSz
  , consSz
  , unconsSz
  , snocSz
  , unsnocSz
  , setSzM
  , insertSzM
  , pullOutSzM
  , Dim(..)
  , Dimension(DimN)
  , pattern Dim1
  , pattern Dim2
  , pattern Dim3
  , pattern Dim4
  , pattern Dim5
  , IsIndexDimension
  , Lower
  , Index(..)
  , Ix0(..)
  , type Ix1
  , pattern Ix1
  , IndexException(..)
  , SizeException(..)
  , ShapeException(..)
  ) where
import Control.DeepSeq
import Control.Exception (Exception(..))
import Control.Monad.Catch (MonadThrow(..))
import Data.Coerce
import Data.Massiv.Core.Iterator
import Data.Typeable
import GHC.TypeLits
-- | Size of an array, represented by a value of the index type @ix@.
--
-- The raw 'SafeSz' constructor performs no validation at all; use the 'Sz'
-- or 'Sz1' pattern synonyms to build a size with negative components
-- clamped to zero.
newtype Sz ix = SafeSz ix
  deriving (Eq, Ord, NFData)
-- | Bidirectional pattern for sizes. Matching exposes the wrapped index as
-- is; constructing replaces every negative component with @0@.
pattern Sz :: Index ix => ix -> Sz ix
pattern Sz ix <- SafeSz ix
  where
    Sz ix = SafeSz (liftIndex (\i -> if i < 0 then 0 else i) ix)
{-# COMPLETE Sz #-}
-- | Size of a one-dimensional array.
type Sz1 = Sz Ix1

-- | One-dimensional counterpart of 'Sz': construction turns a negative
-- length into @0@, matching returns the stored value unchanged.
pattern Sz1 :: Ix1 -> Sz1
pattern Sz1 ix <- SafeSz ix
  where
    Sz1 ix = SafeSz (if ix < 0 then 0 else ix)
{-# COMPLETE Sz1 #-}
-- | A one-dimensional size is rendered as @Sz1 n@, any other as @Sz (ix)@.
-- The result is wrapped in parentheses whenever the precedence is not 0.
instance Index ix => Show (Sz ix) where
  showsPrec n sz@(SafeSz usz) s = showParen (n /= 0) (showString rendered) s
    where
      rendered = "Sz" ++ suffix
      suffix
        | unDim (dimensions sz) == 1 = "1 " ++ show usz
        | otherwise = " (" ++ show usz ++ ")"
-- | Arithmetic on sizes.
--
-- '+', '-' and 'fromInteger' rebuild the result with the clamping 'Sz'
-- constructor, so they never produce negative components. 'abs' is the
-- identity, 'negate' always yields @0@ (that is, @fromInteger 0@), and
-- 'signum' is applied to the wrapped index.
instance (Num ix, Index ix) => Num (Sz ix) where
  SafeSz a + SafeSz b = Sz (a + b)
  {-# INLINE (+) #-}
  SafeSz a - SafeSz b = Sz (a - b)
  {-# INLINE (-) #-}
  -- No clamping here: a product of non-negative components stays
  -- non-negative (overflow aside).
  SafeSz a * SafeSz b = SafeSz (a * b)
  {-# INLINE (*) #-}
  abs !sz = sz
  {-# INLINE abs #-}
  negate !_sz = 0
  {-# INLINE negate #-}
  signum (SafeSz a) = SafeSz (signum a)
  {-# INLINE signum #-}
  fromInteger n = Sz (fromInteger n)
  {-# INLINE fromInteger #-}
-- | Unwrap a size into its underlying index without any checks.
unSz :: Sz ix -> ix
unSz = coerce
{-# INLINE unSz #-}

-- | The size built from @'pureIndex' 0@ (zero in every dimension).
zeroSz :: Index ix => Sz ix
zeroSz = SafeSz $ pureIndex 0
{-# INLINE zeroSz #-}

-- | The size built from @'pureIndex' 1@ (one in every dimension).
oneSz :: Index ix => Sz ix
oneSz = SafeSz $ pureIndex 1
{-# INLINE oneSz #-}
-- | Prepend a new first dimension (given as a 1-D size) to a lower size.
consSz :: Index ix => Sz1 -> Sz (Lower ix) -> Sz ix
consSz (SafeSz n) (SafeSz szL) = SafeSz $ consDim n szL
{-# INLINE consSz #-}

-- | Append a new last dimension (given as a 1-D size) to a lower size.
snocSz :: Index ix => Sz (Lower ix) -> Sz1 -> Sz ix
snocSz (SafeSz szL) (SafeSz n) = SafeSz $ snocDim szL n
{-# INLINE snocSz #-}
-- | Replace the size along the given dimension. Any failure of 'setDimM'
-- (such as an invalid dimension) is propagated through 'MonadThrow'.
setSzM :: (MonadThrow m, Index ix) => Sz ix -> Dim -> Sz Int -> m (Sz ix)
setSzM (SafeSz sz) dim (SafeSz n) = fmap SafeSz (setDimM sz dim n)
{-# INLINE setSzM #-}

-- | Insert a new dimension of the given size into a lower size. Any failure
-- of 'insertDimM' is propagated through 'MonadThrow'.
insertSzM :: (MonadThrow m, Index ix) => Sz (Lower ix) -> Dim -> Sz Int -> m (Sz ix)
insertSzM (SafeSz szL) dim (SafeSz n) = fmap SafeSz (insertDimM szL dim n)
{-# INLINE insertSzM #-}
-- | Split a size into the size of its first dimension and the remaining
-- lower size (via 'unconsDim').
unconsSz :: Index ix => Sz ix -> (Sz1, Sz (Lower ix))
unconsSz (SafeSz sz) =
  case unconsDim sz of
    (n, szL) -> (SafeSz n, SafeSz szL)
{-# INLINE unconsSz #-}

-- | Split a size into the remaining lower size and the size of its last
-- dimension (via 'unsnocDim').
unsnocSz :: Index ix => Sz ix -> (Sz (Lower ix), Sz1)
unsnocSz (SafeSz sz) =
  case unsnocDim sz of
    (szL, n) -> (SafeSz szL, SafeSz n)
{-# INLINE unsnocSz #-}

-- | Remove the given dimension from a size, returning its extent and the
-- remaining lower size. Failures of 'pullOutDimM' are thrown in @m@.
pullOutSzM :: (MonadThrow m, Index ix) => Sz ix -> Dim -> m (Sz Ix1, Sz (Lower ix))
pullOutSzM (SafeSz sz) dim = coerce <$> pullOutDimM sz dim
{-# INLINE pullOutSzM #-}
-- | A dimension number. It is always shown in parentheses, e.g. @(Dim 2)@,
-- regardless of the surrounding precedence.
newtype Dim = Dim
  { unDim :: Int
  } deriving (Eq, Ord, Num, Real, Integral, Enum)

instance Show Dim where
  showsPrec _ (Dim d) = showString "(Dim " . shows d . showString ")"
-- | A dimension number @n@ tracked at the type level. Pattern matching on
-- 'DimN' brings into scope the evidence that @1 <= n@ and that @n@ is a
-- 'KnownNat'.
data Dimension (n :: Nat) where
  DimN :: (1 <= n, KnownNat n) => Dimension n
-- | Type-level dimension number @1@.
pattern Dim1 :: Dimension 1
pattern Dim1 = DimN
-- | Type-level dimension number @2@.
pattern Dim2 :: Dimension 2
pattern Dim2 = DimN
-- | Type-level dimension number @3@.
pattern Dim3 :: Dimension 3
pattern Dim3 = DimN
-- | Type-level dimension number @4@.
pattern Dim4 :: Dimension 4
pattern Dim4 = DimN
-- | Type-level dimension number @5@.
pattern Dim5 :: Dimension 5
pattern Dim5 = DimN
-- | Constraint stating that @n@ is a valid dimension number for the index
-- type @ix@, i.e. @1 <= n <= 'Dimensions' ix@, with @n@ known at the type
-- level.
type IsIndexDimension ix n = (1 <= n, n <= Dimensions ix, Index ix, KnownNat n)
-- | The index type with one dimension fewer than @ix@
-- (for example @Lower Int = Ix0@).
type family Lower ix :: *
-- | Operations common to all index types. Every component of an index is an
-- 'Int', and @'Lower' ix@ is the index type with one component fewer. Array
-- sizes ('Sz') are represented with the same index type.
class ( Eq ix
      , Ord ix
      , Show ix
      , NFData ix
      , Eq (Lower ix)
      , Ord (Lower ix)
      , Show (Lower ix)
      , NFData (Lower ix)
      ) =>
      Index ix
  where
  -- | Type-level number of dimensions (for example @1@ for 'Ix1').
  type Dimensions ix :: Nat
  -- | Value-level number of dimensions. Only the type of the proxy matters,
  -- so any value whose type mentions @ix@ (such as an 'Sz') can be passed.
  dimensions :: proxy ix -> Dim
  -- | Total number of elements for the given size. For 'Ix1' this is the
  -- size value itself.
  totalElem :: Sz ix -> Int
  -- | Prepend a component to the front of a lower-dimensional index.
  consDim :: Int -> Lower ix -> ix
  -- | Split an index into its first component and the remaining lower index
  -- (the inverse of 'consDim').
  unconsDim :: ix -> (Int, Lower ix)
  -- | Append a component to the end of a lower-dimensional index.
  snocDim :: Lower ix -> Int -> ix
  -- | Split an index into the remaining lower index and its last component
  -- (the inverse of 'snocDim').
  unsnocDim :: ix -> (Lower ix, Int)
  -- | Remove the component at the given dimension and return it together
  -- with the remaining lower index. An invalid dimension is reported through
  -- 'MonadThrow' (the 'Ix1' instance throws 'IndexDimensionException').
  pullOutDimM :: MonadThrow m => ix -> Dim -> m (Int, Lower ix)
  -- | Insert a component at the given dimension of a lower index (the
  -- counterpart of 'pullOutDimM'). Invalid dimensions are thrown in @m@.
  insertDimM :: MonadThrow m => Lower ix -> Dim -> Int -> m ix
  -- | Read the component at the given dimension; throws in @m@ when the
  -- dimension is invalid.
  getDimM :: MonadThrow m => ix -> Dim -> m Int
  -- | Replace the component at the given dimension; throws in @m@ when the
  -- dimension is invalid.
  setDimM :: MonadThrow m => ix -> Dim -> Int -> m ix
  -- | Index with every component set to the given value (used by 'zeroSz',
  -- 'oneSz' and the default 'liftIndex').
  pureIndex :: Int -> ix
  -- | Combine two indices component-wise with a binary function.
  liftIndex2 :: (Int -> Int -> Int) -> ix -> ix -> ix
  -- | Apply a function to every component. The default goes through
  -- 'liftIndex2', pairing @ix@ with a @'pureIndex' 0@ whose components are
  -- ignored.
  liftIndex :: (Int -> Int) -> ix -> ix
  liftIndex f = liftIndex2 (\_ i -> f i) (pureIndex 0)
  {-# INLINE [1] liftIndex #-}
  -- | Left fold over the components, starting with the first one.
  foldlIndex :: (a -> Int -> a) -> a -> ix -> a
  -- The default peels off the first component with 'unconsDim' and recurses
  -- into the lower index, forcing the accumulator at every step.
  default foldlIndex :: Index (Lower ix) =>
    (a -> Int -> a) -> a -> ix -> a
  foldlIndex f !acc !ix = foldlIndex f (f acc i0) ixL
    where
      !(i0, ixL) = unconsDim ix
  {-# INLINE [1] foldlIndex #-}
  -- | Check whether an index lies within a size. For 'Ix1' this means
  -- @0 <= i && i < n@; the default applies that check to every component.
  isSafeIndex ::
       Sz ix -- ^ Size
    -> ix -- ^ Index
    -> Bool
  -- The default checks the first component with the 'Ix1' instance and
  -- recurses into the lower size and index.
  default isSafeIndex :: Index (Lower ix) =>
    Sz ix -> ix -> Bool
  isSafeIndex sz !ix = isSafeIndex n0 i0 && isSafeIndex szL ixL
    where
      !(n0, szL) = unconsSz sz
      !(i0, ixL) = unconsDim ix
  {-# INLINE [1] isSafeIndex #-}
  -- | Convert an index into a linear (flat) position for the given size.
  -- The default is row-major: the last component varies fastest. No bounds
  -- checking is done.
  toLinearIndex ::
       Sz ix -- ^ Size
    -> ix -- ^ Index
    -> Int
  default toLinearIndex :: Index (Lower ix) =>
    Sz ix -> ix -> Int
  toLinearIndex (SafeSz sz) !ix = toLinearIndex (SafeSz szL) ixL * n + i
    where
      !(szL, n) = unsnocDim sz
      !(ixL, i) = unsnocDim ix
  {-# INLINE [1] toLinearIndex #-}
  -- | Accumulating form of 'toLinearIndex' that takes the size as a raw
  -- index. Walking the components front to back, each component @i@ with
  -- size @n@ updates the accumulator to @acc * n + i@.
  toLinearIndexAcc :: Int -> ix -> ix -> Int
  default toLinearIndexAcc :: Index (Lower ix) =>
    Int -> ix -> ix -> Int
  toLinearIndexAcc !acc !sz !ix = toLinearIndexAcc (acc * n + i) szL ixL
    where
      !(n, szL) = unconsDim sz
      !(i, ixL) = unconsDim ix
  {-# INLINE [1] toLinearIndexAcc #-}
  -- | Convert a linear position back into an index for the given size (the
  -- inverse of 'toLinearIndex' for in-bounds positions).
  fromLinearIndex :: Sz ix -> Int -> ix
  -- The default decomposes the position against the lower dimensions only.
  -- The first component receives the remaining quotient, which is not
  -- reduced by the first size component.
  default fromLinearIndex :: Index (Lower ix) =>
    Sz ix -> Int -> ix
  fromLinearIndex (SafeSz sz) k = consDim q ixL
    where
      !(q, ixL) = fromLinearIndexAcc (snd (unconsDim sz)) k
  {-# INLINE [1] fromLinearIndex #-}
  -- | Helper for 'fromLinearIndex'. It decomposes a linear position against a
  -- size index and returns the quotient carried out to the next outer
  -- dimension together with the index. For 'Ix1' this is @k `quotRem` n@.
  fromLinearIndexAcc :: ix -> Int -> (Int, ix)
  -- The default resolves the lower (later) components first, then splits
  -- the carried quotient by the size of the first component.
  default fromLinearIndexAcc :: Index (Lower ix) =>
    ix -> Int -> (Int, ix)
  fromLinearIndexAcc ix' !k = (q, consDim r ixL)
    where
      !(m, ix) = unconsDim ix'
      !(kL, ixL) = fromLinearIndexAcc ix k
      !(q, r) = quotRem kL m
  {-# INLINE [1] fromLinearIndexAcc #-}
  -- | Repair an index component by component against a size. A component
  -- below zero is replaced by the first handler's result, and one at or
  -- above its size by the second handler's result. In-range components are
  -- kept as they are (see the 'Ix1' instance).
  repairIndex ::
       Sz ix -- ^ Size
    -> ix -- ^ Index
    -> (Sz Int -> Int -> Int) -- ^ Repair when below zero
    -> (Sz Int -> Int -> Int) -- ^ Repair when at or above the size
    -> ix
  default repairIndex :: Index (Lower ix) =>
    Sz ix -> ix -> (Sz Int -> Int -> Int) -> (Sz Int -> Int -> Int) -> ix
  repairIndex sz !ix rBelow rOver =
    consDim (repairIndex n i rBelow rOver) (repairIndex szL ixL rBelow rOver)
    where
      !(n, szL) = unconsSz sz
      !(i, ixL) = unconsDim ix
  {-# INLINE [1] repairIndex #-}
  -- | Monadic fold over a range of indices, threading an accumulator.
  -- The condition is applied as @cond i end@ for each component and passed
  -- to 'loopM' as the loop test.
  iterM ::
       Monad m
    => ix -- ^ Start index
    -> ix -- ^ End index
    -> ix -- ^ Per-component increment
    -> (Int -> Int -> Bool) -- ^ Loop condition, called as @cond i end@
    -> a -- ^ Initial accumulator
    -> (ix -> a -> m a) -- ^ Action for each index
    -> m a
  -- The default loops over the first component in the outermost 'loopM'
  -- and recurses into the lower dimensions for each of its values.
  default iterM :: (Index (Lower ix), Monad m) =>
    ix -> ix -> ix -> (Int -> Int -> Bool) -> a -> (ix -> a -> m a) -> m a
  iterM !sIx eIx !incIx cond !acc f =
    loopM s (`cond` e) (+ inc) acc $ \ !i !acc0 ->
      iterM sIxL eIxL incIxL cond acc0 $ \ !ix -> f (consDim i ix)
    where
      !(s, sIxL) = unconsDim sIx
      !(e, eIxL) = unconsDim eIx
      !(inc, incIxL) = unconsDim incIx
  {-# INLINE iterM #-}
  -- | Like 'iterM' but without an accumulator. The arguments are the start
  -- index, the end index, the increment, the condition and the action.
  iterM_ :: Monad m => ix -> ix -> ix -> (Int -> Int -> Bool) -> (ix -> m a) -> m ()
  -- The default nests 'loopM_' calls with the first component outermost.
  default iterM_ :: (Index (Lower ix), Monad m) =>
    ix -> ix -> ix -> (Int -> Int -> Bool) -> (ix -> m a) -> m ()
  iterM_ !sIx eIx !incIx cond f =
    loopM_ s (`cond` e) (+ inc) $ \ !i -> iterM_ sIxL eIxL incIxL cond $ \ !ix -> f (consDim i ix)
    where
      !(s, sIxL) = unconsDim sIx
      !(e, eIxL) = unconsDim eIx
      !(inc, incIxL) = unconsDim incIx
  {-# INLINE iterM_ #-}
-- | Zero-dimensional index with a single value; it is the 'Lower' index
-- of 'Ix1'.
data Ix0 = Ix0
  deriving (Eq, Ord, Show)

instance NFData Ix0 where
  rnf ix = ix `seq` ()
-- | One-dimensional index: a plain 'Int'.
type Ix1 = Int
-- | Identity pattern synonym, so a 1-D index can be written as @Ix1 i@.
pattern Ix1 :: Int -> Ix1
pattern Ix1 i = i
-- Removing the only dimension of a one-dimensional index leaves 'Ix0'.
type instance Lower Int = Ix0
-- | Base case of the index hierarchy: a single 'Int' component with 'Ix0'
-- as its lower index. The dimension-based accessors accept only @Dim 1@ and
-- throw 'IndexDimensionException' for any other dimension.
instance Index Ix1 where
  type Dimensions Ix1 = 1
  dimensions _proxy = 1
  {-# INLINE [1] dimensions #-}
  totalElem (SafeSz n) = n
  {-# INLINE [1] totalElem #-}
  -- In bounds means 0 <= i < n.
  isSafeIndex (SafeSz n) !i = i >= 0 && i < n
  {-# INLINE [1] isSafeIndex #-}
  -- A one-dimensional index is its own linear position.
  toLinearIndex _sz i = i
  {-# INLINE [1] toLinearIndex #-}
  toLinearIndexAcc !acc n i = acc * n + i
  {-# INLINE [1] toLinearIndexAcc #-}
  fromLinearIndex _sz k = k
  {-# INLINE [1] fromLinearIndex #-}
  fromLinearIndexAcc n k = quotRem k n
  {-# INLINE [1] fromLinearIndexAcc #-}
  repairIndex sz@(SafeSz n) !i rBelow rOver
    | i < 0 = rBelow sz i
    | i < n = i
    | otherwise = rOver sz i
  {-# INLINE [1] repairIndex #-}
  -- The lower part of a one-dimensional index is always Ix0.
  consDim i _lower = i
  {-# INLINE [1] consDim #-}
  unconsDim i = (i, Ix0)
  {-# INLINE [1] unconsDim #-}
  snocDim _lower i = i
  {-# INLINE [1] snocDim #-}
  unsnocDim i = (Ix0, i)
  {-# INLINE [1] unsnocDim #-}
  getDimM i d
    | d == 1 = pure i
    | otherwise = throwM $ IndexDimensionException i d
  {-# INLINE [1] getDimM #-}
  setDimM ix d i
    | d == 1 = pure i
    | otherwise = throwM $ IndexDimensionException ix d
  {-# INLINE [1] setDimM #-}
  pullOutDimM i d
    | d == 1 = pure (i, Ix0)
    | otherwise = throwM $ IndexDimensionException i d
  {-# INLINE [1] pullOutDimM #-}
  insertDimM Ix0 d i
    | d == 1 = pure i
    | otherwise = throwM $ IndexDimensionException Ix0 d
  {-# INLINE [1] insertDimM #-}
  pureIndex = id
  {-# INLINE [1] pureIndex #-}
  liftIndex = id
  {-# INLINE [1] liftIndex #-}
  liftIndex2 = id
  {-# INLINE [1] liftIndex2 #-}
  foldlIndex = id
  {-# INLINE [1] foldlIndex #-}
  iterM start end step cond = loopM start (\i -> cond i end) (\i -> i + step)
  {-# INLINE iterM #-}
  iterM_ start end step cond = loopM_ start (\i -> cond i end) (\i -> i + step)
  {-# INLINE iterM_ #-}
-- | Exceptions about invalid indices.
data IndexException where
  -- | Carries the offending index; by its name, an index with a zero
  -- component where that is not permitted.
  IndexZeroException :: Index ix => !ix -> IndexException
  -- | A dimension that is not valid for the given index (the 'Ix1' instance
  -- throws this for any dimension other than 1).
  IndexDimensionException :: (Show ix, Typeable ix) => !ix -> Dim -> IndexException
  -- | An index that is not safe for the given size.
  IndexOutOfBoundsException :: Index ix => !(Sz ix) -> !ix -> IndexException

instance Show IndexException where
  show exc =
    case exc of
      IndexZeroException ix -> "IndexZeroException: " ++ show ix
      IndexDimensionException ix dim ->
        "IndexDimensionException: " ++ show dim ++ " for " ++ show ix
      IndexOutOfBoundsException sz ix ->
        "IndexOutOfBoundsException: " ++
        showsPrec 1 ix " not safe for (" ++ show sz ++ ")"
  showsPrec prec exc rest
    | prec == 0 = show exc ++ rest
    | otherwise = '(' : show exc ++ ")" ++ rest

instance Exception IndexException
-- | Exceptions about incompatible or unusable sizes.
data SizeException where
  -- | Two sizes that were expected to be equal are not.
  SizeMismatchException :: Index ix => !(Sz ix) -> !(Sz ix) -> SizeException
  -- | Two sizes, possibly of different dimensionality, whose element counts
  -- were expected to agree.
  SizeElementsMismatchException :: (Index ix, Index ix') => !(Sz ix) -> !(Sz ix') -> SizeException
  -- | The first size is too small to hold a subregion that starts at the
  -- given index and has the second size.
  SizeSubregionException :: Index ix => !(Sz ix) -> !ix -> !(Sz ix) -> SizeException
  -- | A size that corresponds to an empty array where one was not allowed.
  SizeEmptyException :: Index ix => !(Sz ix) -> SizeException

instance Exception SizeException

instance Show SizeException where
  show (SizeMismatchException sz sz') =
    "SizeMismatchException: (" ++ show sz ++ ") vs (" ++ show sz' ++ ")"
  show (SizeElementsMismatchException sz sz') =
    "SizeElementsMismatchException: (" ++ show sz ++ ") vs (" ++ show sz' ++ ")"
  -- Fixed typo in the user-facing message: "to small" -> "too small".
  show (SizeSubregionException sz' ix sz) =
    "SizeSubregionException: (" ++
    show sz' ++ ") is too small for " ++ show ix ++ " (" ++ show sz ++ ")"
  show (SizeEmptyException sz) =
    "SizeEmptyException: (" ++ show sz ++ ") corresponds to an empty array"
  showsPrec 0 arr s = show arr ++ s
  showsPrec _ arr s = '(' : show arr ++ ")" ++ s
-- | Exceptions about the extent of data along a dimension.
-- 'DimTooShortException' carries the expected size followed by the actual
-- one; 'DimTooLongException' carries no data.
data ShapeException
  = DimTooShortException !Sz1 !Sz1
  | DimTooLongException
  deriving Eq

instance Show ShapeException where
  show exc =
    case exc of
      DimTooShortException expected actual ->
        "DimTooShortException: expected (" ++ show expected ++ "), got (" ++ show actual ++ ")"
      DimTooLongException ->
        "DimTooLongException"
  showsPrec prec exc rest
    | prec == 0 = show exc ++ rest
    | otherwise = '(' : show exc ++ ")" ++ rest

instance Exception ShapeException