{-# LANGUAGE BangPatterns               #-}
{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE DefaultSignatures          #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE TypeFamilies               #-}
-- |
-- Module      : Data.Massiv.Core.Index.Class
-- Copyright   : (c) Alexey Kuleshevich 2018
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich
-- Stability   : experimental
-- Portability : non-portable
--
module Data.Massiv.Core.Index.Class where

import Control.DeepSeq (NFData (..))
import Data.Functor.Identity (runIdentity)
import Data.Massiv.Core.Iterator
import GHC.TypeLits

-- | A dimension number. In the instances of this module dimension @1@ selects
-- the last component of an index tuple and the highest dimension selects the
-- first component.
newtype Dim = Dim Int deriving (Show, Eq, Ord, Num, Real, Integral, Enum)

-- | Index with no components; it is the 'Lower' of a one dimensional index.
data Ix0 = Ix0 deriving (Eq, Ord, Show)

-- | One dimensional index: a plain 'Int'.
type Ix1T = Int

-- | Two dimensional index.
type Ix2T = (Int, Int)

-- | Three dimensional index.
type Ix3T = (Int, Int, Int)

-- | Four dimensional index.
type Ix4T = (Int, Int, Int, Int)

-- | Five dimensional index.
type Ix5T = (Int, Int, Int, Int, Int)

-- | The index type that has one dimension less than @ix@.
type family Lower ix :: *

type instance Lower Ix1T = Ix0
type instance Lower Ix2T = Ix1T
type instance Lower Ix3T = Ix2T
type instance Lower Ix4T = Ix3T
type instance Lower Ix5T = Ix4T

-- | Operations shared by all index types. The same type is used both for an
-- index and for a size. Most default implementations peel one component off
-- with 'unconsDim' or 'unsnocDim' and recurse into @'Lower' ix@.
class (Eq ix, Ord ix, Show ix, NFData ix) => Index ix where
  -- | Number of dimensions, as a type level natural.
  type Rank ix :: Nat

  -- | Number of dimensions, as a value.
  rank :: ix -> Dim

  -- | Total number of elements in an array of this size.
  totalElem :: ix -> Int

  -- | Put a new first component in front of a lower dimensional index.
  consDim :: Int -> Lower ix -> ix

  -- | Split an index into its first component and the remaining components.
  unconsDim :: ix -> (Int, Lower ix)

  -- | Append a new last component to a lower dimensional index.
  snocDim :: Lower ix -> Int -> ix

  -- | Split an index into its leading components and its last component.
  unsnocDim :: ix -> (Lower ix, Int)

  -- | Remove the component at the given dimension. 'Nothing' when the
  -- dimension does not exist in this index.
  dropDim :: ix -> Dim -> Maybe (Lower ix)

  -- | Read the component at the given dimension. 'Nothing' when the dimension
  -- does not exist in this index.
  getIndex :: ix -> Dim -> Maybe Int

  -- | Replace the component at the given dimension. 'Nothing' when the
  -- dimension does not exist in this index.
  setIndex :: ix -> Dim -> Int -> Maybe ix

  -- | Index that has the supplied value in every component.
  pureIndex :: Int -> ix

  -- | Zip together two indices with a function
  liftIndex2 :: (Int -> Int -> Int) -> ix -> ix -> ix

  -- | Index that has @0@ in every component.
  zeroIndex :: ix
  zeroIndex = pureIndex 0
  {-# INLINE [1] zeroIndex #-}

  -- | Map a function over an index
  liftIndex :: (Int -> Int) -> ix -> ix
  liftIndex f = liftIndex2 (\_ i -> f i) zeroIndex
  {-# INLINE [1] liftIndex #-}

  -- | Check whether index is within the size.
  isSafeIndex :: ix -- ^ Size
              -> ix -- ^ Index
              -> Bool
  default isSafeIndex :: Index (Lower ix) => ix -> ix -> Bool
  isSafeIndex !sz !ix =
    case unconsDim sz of
      (n, szRest) ->
        case unconsDim ix of
          (i, ixRest) -> isSafeIndex n i && isSafeIndex szRest ixRest
  {-# INLINE [1] isSafeIndex #-}

  -- | Produce linear index from size and index
  toLinearIndex :: ix -- ^ Size
                -> ix -- ^ Index
                -> Int
  default toLinearIndex :: Index (Lower ix) => ix -> ix -> Int
  toLinearIndex !sz !ix =
    case unsnocDim sz of
      (szInit, nLast) ->
        case unsnocDim ix of
          (ixInit, iLast) -> toLinearIndex szInit ixInit * nLast + iLast
  {-# INLINE [1] toLinearIndex #-}

  -- | Accumulator flavour of 'toLinearIndex': starting from the first
  -- component, every step replaces the accumulator with @acc * n + i@, where
  -- @n@ and @i@ are the matching components of the size and of the index.
  toLinearIndexAcc :: Int -> ix -> ix -> Int
  default toLinearIndexAcc :: Index (Lower ix) => Int -> ix -> ix -> Int
  toLinearIndexAcc !acc !sz !ix =
    case unconsDim sz of
      (n, szRest) ->
        case unconsDim ix of
          (i, ixRest) -> toLinearIndexAcc (acc * n + i) szRest ixRest
  {-# INLINE [1] toLinearIndexAcc #-}

  -- | Produce N Dim index from size and linear index
  fromLinearIndex :: ix -> Int -> ix
  default fromLinearIndex :: Index (Lower ix) => ix -> Int -> ix
  fromLinearIndex sz k =
    -- The first component of the size is not needed: whatever quotient is
    -- left after all the other dimensions are peeled off becomes the first
    -- component of the result.
    case fromLinearIndexAcc (snd (unconsDim sz)) k of
      (first, ixRest) -> consDim first ixRest
  {-# INLINE [1] fromLinearIndex #-}

  -- | Helper for 'fromLinearIndex'. Divides the linear index successively by
  -- the components of the size, last component first, and returns the final
  -- quotient together with the index assembled from the remainders.
  fromLinearIndexAcc :: ix -> Int -> (Int, ix)
  default fromLinearIndexAcc :: Index (Lower ix) => ix -> Int -> (Int, ix)
  fromLinearIndexAcc sz !k =
    case unconsDim sz of
      (n, szRest) ->
        case fromLinearIndexAcc szRest k of
          (kRest, ixRest) ->
            case quotRem kRest n of
              (q, r) -> (q, consDim r ixRest)
  {-# INLINE [1] fromLinearIndexAcc #-}

  -- | Adjust every component of an index that falls outside of the size. The
  -- arguments are the size, the index, a function used for a component that
  -- is negative and a function used for a component that is greater than or
  -- equal to the size. Both functions receive the size component first and
  -- the offending index component second. Components that are in range are
  -- left as they are.
  repairIndex :: ix -> ix -> (Int -> Int -> Int) -> (Int -> Int -> Int) -> ix
  default repairIndex :: Index (Lower ix)
    => ix -> ix -> (Int -> Int -> Int) -> (Int -> Int -> Int) -> ix
  repairIndex !sz !ix rBelow rOver =
    case unconsDim sz of
      (n, szRest) ->
        case unconsDim ix of
          (i, ixRest) ->
            consDim
              (repairIndex n i rBelow rOver)
              (repairIndex szRest ixRest rBelow rOver)
  {-# INLINE [1] repairIndex #-}

  -- | Pure version of 'iterM': the same iteration executed in the identity
  -- monad.
  iter :: ix -> ix -> Int -> (Int -> Int -> Bool) -> a -> (ix -> a -> a) -> a
  iter sIx eIx inc cond acc f =
    runIdentity (iterM sIx eIx inc cond acc (\ix a -> return (f ix a)))
  {-# INLINE iter #-}

  -- | Iterate over indices between a start and an end index while threading
  -- an accumulator through a monadic action. The default implementation runs
  -- one 'loopM' per dimension, first component outermost.
  iterM :: Monad m =>
           ix -- ^ Start index
        -> ix -- ^ End index
        -> Int -- ^ Increment
        -> (Int -> Int -> Bool) -- ^ Continue iteration while predicate is True (eg. until end of row)
        -> a -- ^ Initial value for an accumulator
        -> (ix -> a -> m a) -- ^ Accumulator function
        -> m a
  default iterM :: (Index (Lower ix), Monad m)
    => ix -> ix -> Int -> (Int -> Int -> Bool) -> a -> (ix -> a -> m a) -> m a
  iterM !sIx !eIx !inc cond !acc f =
    case unconsDim sIx of
      (kStart, sIxRest) ->
        case unconsDim eIx of
          (kEnd, eIxRest) ->
            loopM kStart (`cond` kEnd) (+ inc) acc $ \ !i !accI ->
              iterM sIxRest eIxRest inc cond accI $ \ !ixRest ->
                f (consDim i ixRest)
  {-# INLINE iterM #-}

  -- | Same traversal as 'iterM', but there is no accumulator and the results
  -- of the action are discarded.
  iterM_ :: Monad m => ix -> ix -> Int -> (Int -> Int -> Bool) -> (ix -> m a) -> m ()
  default iterM_ :: (Index (Lower ix), Monad m)
    => ix -> ix -> Int -> (Int -> Int -> Bool) -> (ix -> m a) -> m ()
  iterM_ !sIx !eIx !inc cond f =
    case unconsDim sIx of
      (kStart, sIxRest) ->
        case unconsDim eIx of
          (kEnd, eIxRest) ->
            loopM_ kStart (`cond` kEnd) (+ inc) $ \ !i ->
              iterM_ sIxRest eIxRest inc cond $ \ !ixRest ->
                f (consDim i ixRest)
  {-# INLINE iterM_ #-}


-- | One dimensional index. This is the base case of the recursive defaults,
-- therefore every method that has a recursive default is defined explicitly.
instance Index Ix1T where
  type Rank Ix1T = 1
  rank _ = 1
  {-# INLINE [1] rank #-}
  totalElem = id
  {-# INLINE [1] totalElem #-}
  isSafeIndex !sz !i = 0 <= i && i < sz
  {-# INLINE [1] isSafeIndex #-}
  toLinearIndex _ = id
  {-# INLINE [1] toLinearIndex #-}
  toLinearIndexAcc !acc n i = acc * n + i
  {-# INLINE [1] toLinearIndexAcc #-}
  fromLinearIndex _ = id
  {-# INLINE [1] fromLinearIndex #-}
  fromLinearIndexAcc n k = quotRem k n
  {-# INLINE [1] fromLinearIndexAcc #-}
  repairIndex !sz !i rBelow rOver
    | i < 0 = rBelow sz i
    | i >= sz = rOver sz i
    | otherwise = i
  {-# INLINE [1] repairIndex #-}
  consDim i _ = i
  {-# INLINE [1] consDim #-}
  unconsDim i = (i, Ix0)
  {-# INLINE [1] unconsDim #-}
  snocDim _ i = i
  {-# INLINE [1] snocDim #-}
  unsnocDim i = (Ix0, i)
  {-# INLINE [1] unsnocDim #-}
  getIndex i d =
    case d of
      1 -> Just i
      _ -> Nothing
  {-# INLINE [1] getIndex #-}
  setIndex _ d i =
    case d of
      1 -> Just i
      _ -> Nothing
  {-# INLINE [1] setIndex #-}
  dropDim _ d =
    case d of
      1 -> Just Ix0
      _ -> Nothing
  {-# INLINE [1] dropDim #-}
  pureIndex i = i
  {-# INLINE [1] pureIndex #-}
  liftIndex f = f
  {-# INLINE [1] liftIndex #-}
  liftIndex2 f = f
  {-# INLINE [1] liftIndex2 #-}
  iter kStart kEnd inc cond = loop kStart (`cond` kEnd) (+ inc)
  {-# INLINE iter #-}
  iterM kStart kEnd inc cond = loopM kStart (`cond` kEnd) (+ inc)
  {-# INLINE iterM #-}
  iterM_ kStart kEnd inc cond = loopM_ kStart (`cond` kEnd) (+ inc)
  {-# INLINE iterM_ #-}


-- | Two dimensional index. Dimension @2@ is the first component of the pair
-- and dimension @1@ is the second.
instance Index Ix2T where
  type Rank Ix2T = 2
  rank _ = 2
  {-# INLINE [1] rank #-}
  totalElem (n1, n2) = n1 * n2
  {-# INLINE [1] totalElem #-}
  toLinearIndex (_, n2) (i1, i2) = n2 * i1 + i2
  {-# INLINE [1] toLinearIndex #-}
  fromLinearIndex (_, n2) !k = quotRem k n2
  {-# INLINE [1] fromLinearIndex #-}
  consDim = (,)
  {-# INLINE [1] consDim #-}
  unconsDim = id
  {-# INLINE [1] unconsDim #-}
  snocDim = (,)
  {-# INLINE [1] snocDim #-}
  unsnocDim = id
  {-# INLINE [1] unsnocDim #-}
  getIndex (i1, i2) d =
    case d of
      2 -> Just i1
      1 -> Just i2
      _ -> Nothing
  {-# INLINE [1] getIndex #-}
  setIndex (i1, i2) d x =
    case d of
      2 -> Just (x, i2)
      1 -> Just (i1, x)
      _ -> Nothing
  {-# INLINE [1] setIndex #-}
  dropDim (i1, i2) d =
    case d of
      2 -> Just i2
      1 -> Just i1
      _ -> Nothing
  {-# INLINE [1] dropDim #-}
  pureIndex i = (i, i)
  {-# INLINE [1] pureIndex #-}
  liftIndex2 f (a1, a2) (b1, b2) = (f a1 b1, f a2 b2)
  {-# INLINE [1] liftIndex2 #-}


-- | Three dimensional index. Dimension @3@ is the first component of the
-- triple and dimension @1@ is the last.
instance Index Ix3T where
  type Rank Ix3T = 3
  rank _ = 3
  {-# INLINE [1] rank #-}
  totalElem (n1, n2, n3) = n1 * n2 * n3
  {-# INLINE [1] totalElem #-}
  consDim i1 (i2, i3) = (i1, i2, i3)
  {-# INLINE [1] consDim #-}
  unconsDim (i1, i2, i3) = (i1, (i2, i3))
  {-# INLINE [1] unconsDim #-}
  snocDim (i1, i2) i3 = (i1, i2, i3)
  {-# INLINE [1] snocDim #-}
  unsnocDim (i1, i2, i3) = ((i1, i2), i3)
  {-# INLINE [1] unsnocDim #-}
  getIndex (i1, i2, i3) d =
    case d of
      3 -> Just i1
      2 -> Just i2
      1 -> Just i3
      _ -> Nothing
  {-# INLINE [1] getIndex #-}
  setIndex (i1, i2, i3) d x =
    case d of
      3 -> Just (x, i2, i3)
      2 -> Just (i1, x, i3)
      1 -> Just (i1, i2, x)
      _ -> Nothing
  {-# INLINE [1] setIndex #-}
  dropDim (i1, i2, i3) d =
    case d of
      3 -> Just (i2, i3)
      2 -> Just (i1, i3)
      1 -> Just (i1, i2)
      _ -> Nothing
  {-# INLINE [1] dropDim #-}
  pureIndex i = (i, i, i)
  {-# INLINE [1] pureIndex #-}
  liftIndex2 f (a1, a2, a3) (b1, b2, b3) = (f a1 b1, f a2 b2, f a3 b3)
  {-# INLINE [1] liftIndex2 #-}


-- | Four dimensional index. Dimension @4@ is the first component of the tuple
-- and dimension @1@ is the last.
instance Index Ix4T where
  type Rank Ix4T = 4
  rank _ = 4
  {-# INLINE [1] rank #-}
  totalElem (n1, n2, n3, n4) = n1 * n2 * n3 * n4
  {-# INLINE [1] totalElem #-}
  consDim i1 (i2, i3, i4) = (i1, i2, i3, i4)
  {-# INLINE [1] consDim #-}
  unconsDim (i1, i2, i3, i4) = (i1, (i2, i3, i4))
  {-# INLINE [1] unconsDim #-}
  snocDim (i1, i2, i3) i4 = (i1, i2, i3, i4)
  {-# INLINE [1] snocDim #-}
  unsnocDim (i1, i2, i3, i4) = ((i1, i2, i3), i4)
  {-# INLINE [1] unsnocDim #-}
  getIndex (i1, i2, i3, i4) d =
    case d of
      4 -> Just i1
      3 -> Just i2
      2 -> Just i3
      1 -> Just i4
      _ -> Nothing
  {-# INLINE [1] getIndex #-}
  setIndex (i1, i2, i3, i4) d x =
    case d of
      4 -> Just (x, i2, i3, i4)
      3 -> Just (i1, x, i3, i4)
      2 -> Just (i1, i2, x, i4)
      1 -> Just (i1, i2, i3, x)
      _ -> Nothing
  {-# INLINE [1] setIndex #-}
  dropDim (i1, i2, i3, i4) d =
    case d of
      4 -> Just (i2, i3, i4)
      3 -> Just (i1, i3, i4)
      2 -> Just (i1, i2, i4)
      1 -> Just (i1, i2, i3)
      _ -> Nothing
  {-# INLINE [1] dropDim #-}
  pureIndex i = (i, i, i, i)
  {-# INLINE [1] pureIndex #-}
  liftIndex2 f (a1, a2, a3, a4) (b1, b2, b3, b4) =
    (f a1 b1, f a2 b2, f a3 b3, f a4 b4)
  {-# INLINE [1] liftIndex2 #-}


-- | Five dimensional index. Dimension @5@ is the first component of the tuple
-- and dimension @1@ is the last.
instance Index Ix5T where
  type Rank Ix5T = 5
  rank _ = 5
  {-# INLINE [1] rank #-}
  totalElem (n1, n2, n3, n4, n5) = n1 * n2 * n3 * n4 * n5
  {-# INLINE [1] totalElem #-}
  consDim i1 (i2, i3, i4, i5) = (i1, i2, i3, i4, i5)
  {-# INLINE [1] consDim #-}
  unconsDim (i1, i2, i3, i4, i5) = (i1, (i2, i3, i4, i5))
  {-# INLINE [1] unconsDim #-}
  snocDim (i1, i2, i3, i4) i5 = (i1, i2, i3, i4, i5)
  {-# INLINE [1] snocDim #-}
  unsnocDim (i1, i2, i3, i4, i5) = ((i1, i2, i3, i4), i5)
  {-# INLINE [1] unsnocDim #-}
  getIndex (i1, i2, i3, i4, i5) d =
    case d of
      5 -> Just i1
      4 -> Just i2
      3 -> Just i3
      2 -> Just i4
      1 -> Just i5
      _ -> Nothing
  {-# INLINE [1] getIndex #-}
  setIndex (i1, i2, i3, i4, i5) d x =
    case d of
      5 -> Just (x, i2, i3, i4, i5)
      4 -> Just (i1, x, i3, i4, i5)
      3 -> Just (i1, i2, x, i4, i5)
      2 -> Just (i1, i2, i3, x, i5)
      1 -> Just (i1, i2, i3, i4, x)
      _ -> Nothing
  {-# INLINE [1] setIndex #-}
  dropDim (i1, i2, i3, i4, i5) d =
    case d of
      5 -> Just (i2, i3, i4, i5)
      4 -> Just (i1, i3, i4, i5)
      3 -> Just (i1, i2, i4, i5)
      2 -> Just (i1, i2, i3, i5)
      1 -> Just (i1, i2, i3, i4)
      _ -> Nothing
  {-# INLINE [1] dropDim #-}
  pureIndex i = (i, i, i, i, i)
  {-# INLINE [1] pureIndex #-}
  liftIndex2 f (a1, a2, a3, a4, a5) (b1, b2, b3, b4, b5) =
    (f a1 b1, f a2 b2, f a3 b3, f a4 b4, f a5 b5)
  {-# INLINE [1] liftIndex2 #-}


-- | Abort with an out of bounds message. The message consists of the supplied
-- function name, the offending index and the size of the array, in that
-- order. Kept out of line with @NOINLINE@.
errorIx :: (Show ix, Show ix') => String -> ix -> ix' -> a
errorIx fName sz ix =
  error
    (concat
       [ fName
       , ": Index out of bounds: "
       , show ix
       , " for Array of size: "
       , show sz
       ])
{-# NOINLINE errorIx #-}