{-# LANGUAGE BangPatterns               #-}
{-# LANGUAGE ConstraintKinds            #-}
{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE DefaultSignatures          #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE TypeOperators              #-}
module Data.Massiv.Core.Index.Class where
import           Control.DeepSeq           (NFData (..))
import           Data.Functor.Identity     (runIdentity)
import           Data.Massiv.Core.Iterator
import           GHC.TypeLits
newtype Dim = Dim Int deriving (Show, Eq, Ord, Num, Real, Integral, Enum)
data Dimension (n :: Nat) where
  Dim1 :: Dimension 1
  Dim2 :: Dimension 2
  Dim3 :: Dimension 3
  Dim4 :: Dimension 4
  Dim5 :: Dimension 5
  DimN :: (6 <= n, KnownNat n) => Dimension n
type IsIndexDimension ix n = (1 <= n, n <= Dimensions ix, Index ix, KnownNat n)
data Ix0 = Ix0 deriving (Eq, Ord, Show)
type Ix1T = Int
type Ix2T = (Int, Int)
type Ix3T = (Int, Int, Int)
type Ix4T = (Int, Int, Int, Int)
type Ix5T = (Int, Int, Int, Int, Int)
type family Lower ix :: *
type instance Lower Ix1T = Ix0
type instance Lower Ix2T = Ix1T
type instance Lower Ix3T = Ix2T
type instance Lower Ix4T = Ix3T
type instance Lower Ix5T = Ix4T
class (Eq ix, Ord ix, Show ix, NFData ix) => Index ix where
  type Dimensions ix :: Nat
  
  dimensions :: ix -> Dim
  
  totalElem :: ix -> Int
  
  consDim :: Int -> Lower ix -> ix
  
  unconsDim :: ix -> (Int, Lower ix)
  
  snocDim :: Lower ix -> Int -> ix
  
  unsnocDim :: ix -> (Lower ix, Int)
  
  
  dropDim :: ix -> Dim -> Maybe (Lower ix)
  dropDim ix = fmap snd . pullOutDim ix
  {-# INLINE [1] dropDim #-}
  
  pullOutDim :: ix -> Dim -> Maybe (Int, Lower ix)
  
  insertDim :: Lower ix -> Dim -> Int -> Maybe ix
  
  getDim :: ix -> Dim -> Maybe Int
  getDim = getIndex
  {-# INLINE [1] getDim #-}
  
  setDim :: ix -> Dim -> Int -> Maybe ix
  setDim = setIndex
  {-# INLINE [1] setDim #-}
  
  getIndex :: ix -> Dim -> Maybe Int
  getIndex = getDim
  {-# INLINE [1] getIndex #-}
  
  setIndex :: ix -> Dim -> Int -> Maybe ix
  setIndex = setDim
  {-# INLINE [1] setIndex #-}
  
  pureIndex :: Int -> ix
  
  liftIndex2 :: (Int -> Int -> Int) -> ix -> ix -> ix
  
  liftIndex :: (Int -> Int) -> ix -> ix
  liftIndex f = liftIndex2 (\_ i -> f i) (pureIndex 0)
  {-# INLINE [1] liftIndex #-}
  foldlIndex :: (a -> Int -> a) -> a -> ix -> a
  default foldlIndex :: Index (Lower ix) => (a -> Int -> a) -> a -> ix -> a
  foldlIndex f !acc !ix = foldlIndex f (f acc i0) ixL
    where
      !(i0, ixL) = unconsDim ix
  {-# INLINE [1] foldlIndex #-}
  
  isSafeIndex :: ix 
              -> ix 
              -> Bool
  default isSafeIndex :: Index (Lower ix) => ix -> ix -> Bool
  isSafeIndex !sz !ix = isSafeIndex n0 i0 && isSafeIndex szL ixL
    where
      !(n0, szL) = unconsDim sz
      !(i0, ixL) = unconsDim ix
  {-# INLINE [1] isSafeIndex #-}
  
  toLinearIndex :: ix 
                -> ix 
                -> Int
  default toLinearIndex :: Index (Lower ix) => ix -> ix -> Int
  toLinearIndex !sz !ix = toLinearIndex szL ixL * n + i
    where !(szL, n) = unsnocDim sz
          !(ixL, i) = unsnocDim ix
  {-# INLINE [1] toLinearIndex #-}
  
  
  toLinearIndexAcc :: Int -> ix -> ix -> Int
  default toLinearIndexAcc :: Index (Lower ix) => Int -> ix -> ix -> Int
  toLinearIndexAcc !acc !sz !ix = toLinearIndexAcc (acc * n + i) szL ixL
    where !(n, szL) = unconsDim sz
          !(i, ixL) = unconsDim ix
  {-# INLINE [1] toLinearIndexAcc #-}
  
  fromLinearIndex :: ix -> Int -> ix
  default fromLinearIndex :: Index (Lower ix) => ix -> Int -> ix
  fromLinearIndex sz k = consDim q ixL
    where !(q, ixL) = fromLinearIndexAcc (snd (unconsDim sz)) k
  {-# INLINE [1] fromLinearIndex #-}
  
  
  fromLinearIndexAcc :: ix -> Int -> (Int, ix)
  default fromLinearIndexAcc :: Index (Lower ix) => ix -> Int -> (Int, ix)
  fromLinearIndexAcc ix' !k = (q, consDim r ixL)
    where !(m, ix) = unconsDim ix'
          !(kL, ixL) = fromLinearIndexAcc ix k
          !(q, r) = quotRem kL m
  {-# INLINE [1] fromLinearIndexAcc #-}
  
  
  repairIndex :: ix 
              -> ix 
              -> (Int -> Int -> Int) 
              -> (Int -> Int -> Int) 
              -> ix
  default repairIndex :: Index (Lower ix)
    => ix -> ix -> (Int -> Int -> Int) -> (Int -> Int -> Int) -> ix
  repairIndex !sz !ix rBelow rOver =
    consDim (repairIndex n i rBelow rOver) (repairIndex szL ixL rBelow rOver)
    where !(n, szL) = unconsDim sz
          !(i, ixL) = unconsDim ix
  {-# INLINE [1] repairIndex #-}
  
  iter :: ix -> ix -> ix -> (Int -> Int -> Bool) -> a -> (ix -> a -> a) -> a
  iter sIx eIx incIx cond acc f =
    runIdentity $ iterM sIx eIx incIx cond acc (\ix -> return . f ix)
  {-# INLINE iter #-}
  
  iterM :: Monad m =>
           ix 
        -> ix 
        -> ix 
        -> (Int -> Int -> Bool) 
        -> a 
        -> (ix -> a -> m a) 
        -> m a
  default iterM :: (Index (Lower ix), Monad m)
    => ix -> ix -> ix -> (Int -> Int -> Bool) -> a -> (ix -> a -> m a) -> m a
  iterM !sIx !eIx !incIx cond !acc f =
    loopM s (`cond` e) (+ inc) acc $ \ !i !acc0 ->
      iterM sIxL eIxL incIxL cond acc0 $ \ !ix ->
        f (consDim i ix)
    where
      !(s, sIxL) = unconsDim sIx
      !(e, eIxL) = unconsDim eIx
      !(inc, incIxL) = unconsDim incIx
  {-# INLINE iterM #-}
  
  
  iterM_ :: Monad m => ix -> ix -> ix -> (Int -> Int -> Bool) -> (ix -> m a) -> m ()
  default iterM_ :: (Index (Lower ix), Monad m)
    => ix -> ix -> ix -> (Int -> Int -> Bool) -> (ix -> m a) -> m ()
  iterM_ !sIx !eIx !incIx cond f =
    loopM_ s (`cond` e) (+ inc) $ \ !i ->
      iterM_ sIxL eIxL incIxL cond $ \ !ix ->
        f (consDim i ix)
    where
      !(s, sIxL) = unconsDim sIx
      !(e, eIxL) = unconsDim eIx
      !(inc, incIxL) = unconsDim incIx
  {-# INLINE iterM_ #-}
{-# DEPRECATED getIndex "In favor of 'getDim'" #-}
{-# DEPRECATED setIndex "In favor of 'setDim'" #-}
instance Index Ix1T where
  type Dimensions Ix1T = 1
  dimensions _ = 1
  {-# INLINE [1] dimensions #-}
  totalElem = id
  {-# INLINE [1] totalElem #-}
  isSafeIndex !k !i = 0 <= i && i < k
  {-# INLINE [1] isSafeIndex #-}
  toLinearIndex _ = id
  {-# INLINE [1] toLinearIndex #-}
  toLinearIndexAcc !acc m i  = acc * m + i
  {-# INLINE [1] toLinearIndexAcc #-}
  fromLinearIndex _ = id
  {-# INLINE [1] fromLinearIndex #-}
  fromLinearIndexAcc n k = k `quotRem` n
  {-# INLINE [1] fromLinearIndexAcc #-}
  repairIndex !k !i rBelow rOver
    | i < 0 = rBelow k i
    | i >= k = rOver k i
    | otherwise = i
  {-# INLINE [1] repairIndex #-}
  consDim i _ = i
  {-# INLINE [1] consDim #-}
  unconsDim i = (i, Ix0)
  {-# INLINE [1] unconsDim #-}
  snocDim _ i = i
  {-# INLINE [1] snocDim #-}
  unsnocDim i = (Ix0, i)
  {-# INLINE [1] unsnocDim #-}
  getIndex i 1 = Just i
  getIndex _ _ = Nothing
  {-# INLINE [1] getIndex #-}
  setIndex _ 1 i = Just i
  setIndex _ _ _ = Nothing
  {-# INLINE [1] setIndex #-}
  dropDim _ 1 = Just Ix0
  dropDim _ _ = Nothing
  {-# INLINE [1] dropDim #-}
  pullOutDim i 1 = Just (i, Ix0)
  pullOutDim _ _ = Nothing
  {-# INLINE [1] pullOutDim #-}
  insertDim Ix0 1 i = Just i
  insertDim _   _ _ = Nothing
  {-# INLINE [1] insertDim #-}
  pureIndex i = i
  {-# INLINE [1] pureIndex #-}
  liftIndex f = f
  {-# INLINE [1] liftIndex #-}
  liftIndex2 f = f
  {-# INLINE [1] liftIndex2 #-}
  foldlIndex f = f
  {-# INLINE [1] foldlIndex #-}
  iter k0 k1 inc cond = loop k0 (`cond` k1) (+inc)
  {-# INLINE iter #-}
  iterM k0 k1 inc cond = loopM k0 (`cond` k1) (+inc)
  {-# INLINE iterM #-}
  iterM_ k0 k1 inc cond = loopM_ k0 (`cond` k1) (+inc)
  {-# INLINE iterM_ #-}
instance Index Ix2T where
  type Dimensions Ix2T = 2
  dimensions _ = 2
  {-# INLINE [1] dimensions #-}
  totalElem (k2, k1) = k2 * k1
  {-# INLINE [1] totalElem #-}
  toLinearIndex (_, k1) (i2, i1) = k1 * i2 + i1
  {-# INLINE [1] toLinearIndex #-}
  fromLinearIndex (_, k1) !i = i `quotRem` k1
  {-# INLINE [1] fromLinearIndex #-}
  consDim = (,)
  {-# INLINE [1] consDim #-}
  unconsDim = id
  {-# INLINE [1] unconsDim #-}
  snocDim = (,)
  {-# INLINE [1] snocDim #-}
  unsnocDim = id
  {-# INLINE [1] unsnocDim #-}
  getIndex (i2,  _) 2 = Just i2
  getIndex ( _, i1) 1 = Just i1
  getIndex _      _ = Nothing
  {-# INLINE [1] getIndex #-}
  setIndex (_, i1) 2 i2 = Just (i2, i1)
  setIndex (i2, _) 1 i1 = Just (i2, i1)
  setIndex _      _ _ = Nothing
  {-# INLINE [1] setIndex #-}
  dropDim (_, i1) 2 = Just i1
  dropDim (i2, _) 1 = Just i2
  dropDim _      _ = Nothing
  {-# INLINE [1] dropDim #-}
  pullOutDim (i2, i1) 2 = Just (i2, i1)
  pullOutDim (i2, i1) 1 = Just (i1, i2)
  pullOutDim _        _ = Nothing
  {-# INLINE [1] pullOutDim #-}
  insertDim i1 2 i2 = Just (i2, i1)
  insertDim i2 1 i1 = Just (i2, i1)
  insertDim _  _  _ = Nothing
  {-# INLINE [1] insertDim #-}
  pureIndex i = (i, i)
  {-# INLINE [1] pureIndex #-}
  liftIndex2 f (i2, i1) (i2', i1') = (f i2 i2', f i1 i1')
  {-# INLINE [1] liftIndex2 #-}
instance Index Ix3T where
  type Dimensions Ix3T = 3
  dimensions _ = 3
  {-# INLINE [1] dimensions #-}
  totalElem  (k3, k2, k1) = k3 * k2 * k1
  {-# INLINE [1] totalElem #-}
  consDim i3 (i2, i1) = (i3, i2, i1)
  {-# INLINE [1] consDim #-}
  unconsDim (i3, i2, i1) = (i3, (i2, i1))
  {-# INLINE [1] unconsDim #-}
  snocDim (i3, i2) i1 = (i3, i2, i1)
  {-# INLINE [1] snocDim #-}
  unsnocDim (i3, i2, i1) = ((i3, i2), i1)
  {-# INLINE [1] unsnocDim #-}
  getIndex (i3,  _,  _) 3 = Just i3
  getIndex ( _, i2,  _) 2 = Just i2
  getIndex ( _,  _, i1) 1 = Just i1
  getIndex _            _ = Nothing
  {-# INLINE [1] getIndex #-}
  setIndex ( _, i2, i1) 3 i3 = Just (i3, i2, i1)
  setIndex (i3,  _, i1) 2 i2 = Just (i3, i2, i1)
  setIndex (i3, i2,  _) 1 i1 = Just (i3, i2, i1)
  setIndex _      _ _    = Nothing
  {-# INLINE [1] setIndex #-}
  dropDim ( _, i2, i1) 3 = Just (i2, i1)
  dropDim (i3,  _, i1) 2 = Just (i3, i1)
  dropDim (i3, i2,  _) 1 = Just (i3, i2)
  dropDim _      _    = Nothing
  {-# INLINE [1] dropDim #-}
  pullOutDim (i3, i2, i1) 3 = Just (i3, (i2, i1))
  pullOutDim (i3, i2, i1) 2 = Just (i2, (i3, i1))
  pullOutDim (i3, i2, i1) 1 = Just (i1, (i3, i2))
  pullOutDim _      _    = Nothing
  {-# INLINE [1] pullOutDim #-}
  insertDim (i2, i1) 3 i3 = Just (i3, i2, i1)
  insertDim (i3, i1) 2 i2 = Just (i3, i2, i1)
  insertDim (i3, i2) 1 i1 = Just (i3, i2, i1)
  insertDim _      _ _ = Nothing
  pureIndex i = (i, i, i)
  {-# INLINE [1] pureIndex #-}
  liftIndex2 f (i3, i2, i1) (i3', i2', i1') = (f i3 i3', f i2 i2', f i1 i1')
  {-# INLINE [1] liftIndex2 #-}
instance Index Ix4T where
  type Dimensions Ix4T = 4
  dimensions _ = 4
  {-# INLINE [1] dimensions #-}
  totalElem !(k4, k3, k2, k1) = k4 * k3 * k2 * k1
  {-# INLINE [1] totalElem #-}
  consDim i4 (i3, i2, i1) = (i4, i3, i2, i1)
  {-# INLINE [1] consDim #-}
  unconsDim (i4, i3, i2, i1) = (i4, (i3, i2, i1))
  {-# INLINE [1] unconsDim #-}
  snocDim (i4, i3, i2) i1 = (i4, i3, i2, i1)
  {-# INLINE [1] snocDim #-}
  unsnocDim (i4, i3, i2, i1) = ((i4, i3, i2), i1)
  {-# INLINE [1] unsnocDim #-}
  getIndex (i4,  _,  _,  _) 4 = Just i4
  getIndex ( _, i3,  _,  _) 3 = Just i3
  getIndex ( _,  _, i2,  _) 2 = Just i2
  getIndex ( _,  _,  _, i1) 1 = Just i1
  getIndex _                _ = Nothing
  {-# INLINE [1] getIndex #-}
  setIndex ( _, i3, i2, i1) 4 i4 = Just (i4, i3, i2, i1)
  setIndex (i4,  _, i2, i1) 3 i3 = Just (i4, i3, i2, i1)
  setIndex (i4, i3,  _, i1) 2 i2 = Just (i4, i3, i2, i1)
  setIndex (i4, i3, i2,  _) 1 i1 = Just (i4, i3, i2, i1)
  setIndex _                _  _ = Nothing
  {-# INLINE [1] setIndex #-}
  dropDim ( _, i3, i2, i1) 4 = Just (i3, i2, i1)
  dropDim (i4,  _, i2, i1) 3 = Just (i4, i2, i1)
  dropDim (i4, i3,  _, i1) 2 = Just (i4, i3, i1)
  dropDim (i4, i3, i2,  _) 1 = Just (i4, i3, i2)
  dropDim _                _ = Nothing
  {-# INLINE [1] dropDim #-}
  pullOutDim (i4, i3, i2, i1) 4 = Just (i4, (i3, i2, i1))
  pullOutDim (i4, i3, i2, i1) 3 = Just (i3, (i4, i2, i1))
  pullOutDim (i4, i3, i2, i1) 2 = Just (i2, (i4, i3, i1))
  pullOutDim (i4, i3, i2, i1) 1 = Just (i1, (i4, i3, i2))
  pullOutDim _                _ = Nothing
  {-# INLINE [1] pullOutDim #-}
  insertDim (i3, i2, i1) 4 i4 = Just (i4, i3, i2, i1)
  insertDim (i4, i2, i1) 3 i3 = Just (i4, i3, i2, i1)
  insertDim (i4, i3, i1) 2 i2 = Just (i4, i3, i2, i1)
  insertDim (i4, i3, i2) 1 i1 = Just (i4, i3, i2, i1)
  insertDim _            _  _ = Nothing
  {-# INLINE [1] insertDim #-}
  pureIndex i = (i, i, i, i)
  {-# INLINE [1] pureIndex #-}
  liftIndex2 f (i4, i3, i2, i1) (i4', i3', i2', i1') = (f i4 i4', f i3 i3', f i2 i2', f i1 i1')
  {-# INLINE [1] liftIndex2 #-}
instance Index Ix5T where
  type Dimensions Ix5T = 5
  dimensions _ = 5
  {-# INLINE [1] dimensions #-}
  totalElem !(n5, n4, n3, n2, n1) = n5 * n4 * n3 * n2 * n1
  {-# INLINE [1] totalElem #-}
  consDim i5 (i4, i3, i2, i1) = (i5, i4, i3, i2, i1)
  {-# INLINE [1] consDim #-}
  unconsDim (i5, i4, i3, i2, i1) = (i5, (i4, i3, i2, i1))
  {-# INLINE [1] unconsDim #-}
  snocDim (i5, i4, i3, i2) i1 = (i5, i4, i3, i2, i1)
  {-# INLINE [1] snocDim #-}
  unsnocDim (i5, i4, i3, i2, i1) = ((i5, i4, i3, i2), i1)
  {-# INLINE [1] unsnocDim #-}
  getIndex (i5,  _,  _,  _,  _) 5 = Just i5
  getIndex ( _, i4,  _,  _,  _) 4 = Just i4
  getIndex ( _,  _, i3,  _,  _) 3 = Just i3
  getIndex ( _,  _,  _, i2,  _) 2 = Just i2
  getIndex ( _,  _,  _,  _, i1) 1 = Just i1
  getIndex _                _     = Nothing
  {-# INLINE [1] getIndex #-}
  setIndex ( _, i4, i3, i2, i1) 5 i5 = Just (i5, i4, i3, i2, i1)
  setIndex (i5,  _, i3, i2, i1) 4 i4 = Just (i5, i4, i3, i2, i1)
  setIndex (i5, i4,  _, i2, i1) 3 i3 = Just (i5, i4, i3, i2, i1)
  setIndex (i5, i4, i3,  _, i1) 2 i2 = Just (i5, i4, i3, i2, i1)
  setIndex (i5, i4, i3, i2,  _) 1 i1 = Just (i5, i4, i3, i2, i1)
  setIndex _                    _  _ = Nothing
  {-# INLINE [1] setIndex #-}
  dropDim ( _, i4, i3, i2, i1) 5 = Just (i4, i3, i2, i1)
  dropDim (i5,  _, i3, i2, i1) 4 = Just (i5, i3, i2, i1)
  dropDim (i5, i4,  _, i2, i1) 3 = Just (i5, i4, i2, i1)
  dropDim (i5, i4, i3,  _, i1) 2 = Just (i5, i4, i3, i1)
  dropDim (i5, i4, i3, i2,  _) 1 = Just (i5, i4, i3, i2)
  dropDim _                    _ = Nothing
  {-# INLINE [1] dropDim #-}
  pullOutDim (i5, i4, i3, i2, i1) 5 = Just (i5, (i4, i3, i2, i1))
  pullOutDim (i5, i4, i3, i2, i1) 4 = Just (i4, (i5, i3, i2, i1))
  pullOutDim (i5, i4, i3, i2, i1) 3 = Just (i3, (i5, i4, i2, i1))
  pullOutDim (i5, i4, i3, i2, i1) 2 = Just (i2, (i5, i4, i3, i1))
  pullOutDim (i5, i4, i3, i2, i1) 1 = Just (i1, (i5, i4, i3, i2))
  pullOutDim _                    _ = Nothing
  {-# INLINE [1] pullOutDim #-}
  insertDim (i4, i3, i2, i1) 5 i5 = Just (i5, i4, i3, i2, i1)
  insertDim (i5, i3, i2, i1) 4 i4 = Just (i5, i4, i3, i2, i1)
  insertDim (i5, i4, i2, i1) 3 i3 = Just (i5, i4, i3, i2, i1)
  insertDim (i5, i4, i3, i1) 2 i2 = Just (i5, i4, i3, i2, i1)
  insertDim (i5, i4, i3, i2) 1 i1 = Just (i5, i4, i3, i2, i1)
  insertDim _            _  _     = Nothing
  {-# INLINE [1] insertDim #-}
  pureIndex i = (i, i, i, i, i)
  {-# INLINE [1] pureIndex #-}
  liftIndex2 f (i5, i4, i3, i2, i1) (i5', i4', i3', i2', i1') =
    (f i5 i5', f i4 i4', f i3 i3', f i2 i2', f i1 i1')
  {-# INLINE [1] liftIndex2 #-}
errorIx :: (Show ix, Show ix') => String -> ix -> ix' -> a
errorIx fName sz ix =
  error $
  fName ++
  ": Index out of bounds: (" ++ show ix ++ ") for Array of size: (" ++ show sz ++ ")"
{-# NOINLINE errorIx #-}
errorSizeMismatch :: (Show ix, Show ix') => String -> ix -> ix' -> a
errorSizeMismatch fName sz sz' =
  error $ fName ++ ": Mismatch in size of arrays " ++ show sz ++ " vs " ++ show sz'
{-# NOINLINE errorSizeMismatch #-}