module Data.Massiv.Array.Delayed.Windowed
( DW
, Array(..)
, makeWindowedArray
) where
import Control.Monad (when)
import Data.Massiv.Array.Delayed.Internal
import Data.Massiv.Core
import Data.Massiv.Core.Common
import Data.Massiv.Core.Scheduler
-- | Representation tag for delayed arrays that carry a rectangular "window":
-- a region whose elements are produced by a separate indexing function from
-- the one used for the rest of the array.
data DW
-- Element-wise operations on a 'DW' array are expressed through the plain
-- delayed representation 'D'.
type instance EltRepr DW ix = D
-- | A delayed array paired with a window description. Elements outside the
-- window are produced by the index function of 'wdArray'; elements inside
-- the window are produced by 'wdWindowUnsafeIndex'.
data instance Array DW ix e = DWArray { wdArray :: !(Array D ix e)
-- Optional stencil size (lazy field). When present, the loaders in this
-- module use its outermost component as the row-unrolling block height.
, wdStencilSize :: Maybe ix
-- Index at which the window begins.
, wdWindowStartIndex :: !ix
-- Extent of the window along every dimension (a size, not an end index).
, wdWindowSize :: !ix
-- Indexing function applied to positions that fall inside the window.
-- NOTE(review): judging by the name it performs no bounds checks - confirm.
, wdWindowUnsafeIndex :: ix -> e }
-- | Construction of windowed arrays. A freshly made array has no stencil and
-- an empty window positioned at the origin.
instance Index ix => Construct DW ix e where
  -- Computation strategy is stored in the underlying delayed array.
  getComp arr = dComp (wdArray arr)

  -- Replace the computation strategy of the underlying delayed array only.
  setComp c arr = arr {wdArray = inner {dComp = c}}
    where
      inner = wdArray arr

  -- The supplied function serves both as the border and the window indexer.
  unsafeMakeArray c sz f =
    DWArray
      { wdArray = unsafeMakeArray c sz f
      , wdStencilSize = Nothing
      , wdWindowStartIndex = zeroIndex
      , wdWindowSize = zeroIndex
      , wdWindowUnsafeIndex = f
      }
-- | Size related operations for windowed arrays. All of them are answered by
-- the underlying delayed array.
instance Index ix => Size DW ix e where
  size arr = size (wdArray arr)

  -- Resizing discards the stencil and the window: the result has an empty
  -- window at the origin and indexes through the resized delayed array.
  unsafeResize sz (DWArray {wdArray = inner}) =
    DWArray
      { wdArray = resized
      , wdStencilSize = Nothing
      , wdWindowStartIndex = zeroIndex
      , wdWindowSize = zeroIndex
      , wdWindowUnsafeIndex = evaluateAt resized
      }
    where
      resized = unsafeResize sz inner

  -- Extraction is delegated to the underlying delayed array, so the window
  -- indexing function plays no part in the result.
  unsafeExtract sIx newSz arr = unsafeExtract sIx newSz (wdArray arr)
-- | Mapping a function transforms both the underlying delayed array and the
-- window indexing function; window placement and stencil size are kept.
instance Functor (Array DW ix) where
  fmap f !arr =
    arr {wdArray = mappedBorder, wdWindowUnsafeIndex = mappedWindow}
    where
      mappedBorder = fmap f (wdArray arr)
      mappedWindow = f . wdWindowUnsafeIndex arr
-- | Attach a window with its own indexing function to an array. Positions
-- inside the window are produced by the supplied function, all remaining
-- positions by the (delayed) source array.
--
-- Calls 'error' when the window start is not a valid index of the source
-- array, or when the window does not fit inside the array along some
-- dimension.
makeWindowedArray
  :: Source r ix e
  => Array r ix e -- ^ Source array that the window is placed into
  -> ix -- ^ Index at which the window starts
  -> ix -- ^ Size of the window
  -> (ix -> e) -- ^ Indexing function used inside the window
  -> Array DW ix e
makeWindowedArray !arr !wIx !wSz wUnsafeIndex
  | not (isSafeIndex sz wIx) =
    error $
    "Incorrect window starting index: " ++ show wIx ++ " for: " ++ show (size arr)
  | windowOverflows =
    error $
    "Incorrect window size: " ++
    show wSz ++ " and/or placement: " ++ show wIx ++ " for: " ++ show (size arr)
  | otherwise =
    DWArray
      { wdArray = delay arr
      , wdStencilSize = Nothing
      , wdWindowStartIndex = wIx
      , wdWindowSize = wSz
      , wdWindowUnsafeIndex = wUnsafeIndex
      }
  where
    sz = size arr
    -- Room left in every dimension past the far corner of the window.
    slack = liftIndex2 (-) sz (liftIndex2 (+) wIx wSz)
    -- The fit has to be checked per dimension: comparing whole indices with
    -- '>' is lexicographic for tuple-like indices and would accept a window
    -- that overflows only in an inner dimension. Clamping the slack at zero
    -- leaves 'zeroIndex' exactly when no component is negative.
    windowOverflows = liftIndex2 min zeroIndex slack /= zeroIndex
-- | Loading of one-dimensional windowed arrays. The array is split in three
-- ranges: @[0, it)@ and @[it + wk, sz)@ are produced by the border indexing
-- function, while the window @[it, it + wk)@ is produced by the window
-- indexing function.
instance Load DW Ix1 e where
  loadS (DWArray (DArray _ sz indexB) _ it wk indexW) _ unsafeWrite = do
    -- @wk@ is the window size, so the window ends at @it + wk@, not at @wk@.
    let !wEnd = it + wk
    iterM_ 0 it 1 (<) $ \ !i -> unsafeWrite i (indexB i)
    iterM_ it wEnd 1 (<) $ \ !i -> unsafeWrite i (indexW i)
    iterM_ wEnd sz 1 (<) $ \ !i -> unsafeWrite i (indexB i)

  loadP wIds (DWArray (DArray _ sz indexB) _ it wk indexW) _ unsafeWrite = do
    -- The window of @wk@ elements is what gets divided among the workers.
    divideWork_ wIds wk $ \ !scheduler !chunkLength !totalLength !slackStart -> do
      let !wEnd = it + wk
      -- Border before the window.
      scheduleWork scheduler $
        iterM_ 0 it 1 (<) $ \ !ix ->
          unsafeWrite (toLinearIndex sz ix) (indexB ix)
      -- Border after the window. It has to start at the window end so that
      -- it does not overlap with the cells written by the window tasks.
      scheduleWork scheduler $
        iterM_ wEnd sz 1 (<) $ \ !ix ->
          unsafeWrite (toLinearIndex sz ix) (indexB ix)
      -- Equally sized chunks of the window, shifted by the window start.
      loopM_ it (< (slackStart + it)) (+ chunkLength) $ \ !start ->
        scheduleWork scheduler $
        iterM_ start (start + chunkLength) 1 (<) $ \ !k ->
          unsafeWrite k $ indexW k
      -- Whatever is left of the window after the even chunks.
      scheduleWork scheduler $
        iterM_ (slackStart + it) (totalLength + it) 1 (<) $ \ !k ->
          unsafeWrite k (indexW k)
-- | Loading of two-dimensional windowed arrays indexed with 'Ix2'. Four
-- border strips (above, below, left and right of the window) are produced by
-- the border indexing function; the window itself is loaded with
-- 'unrollAndJam' using the window indexing function.
instance Load DW Ix2 e where
  loadS arr _ unsafeWrite = do
    let DWArray (DArray _ sz@(rows :. cols) borderAt) mStencil (top :. left) (wRows :. wCols) windowAt =
          arr
        bottom = wRows + top
        right = wCols + left
        -- Rows are unrolled by the stencil height when a stencil is known.
        unrollBy = maybe 1 (\(h :. _) -> h) mStencil
        storeBorder from to =
          iterM_ from to 1 (<) $ \ !ix ->
            unsafeWrite (toLinearIndex sz ix) (borderAt ix)
    storeBorder (0 :. 0) (top :. cols) -- strip above the window
    storeBorder (bottom :. 0) (rows :. cols) -- strip below the window
    storeBorder (top :. 0) (bottom :. left) -- strip left of the window
    storeBorder (top :. right) (bottom :. cols) -- strip right of the window
    unrollAndJam unrollBy (top :. bottom) (left :. right) $ \ !ix ->
      unsafeWrite (toLinearIndex sz ix) (windowAt ix)

  loadP wIds arr _ unsafeWrite = do
    let DWArray (DArray _ sz@(rows :. cols) borderAt) mStencil (top :. left) (wRows :. wCols) windowAt =
          arr
    withScheduler_ wIds $ \scheduler -> do
      let bottom = wRows + top
          right = wCols + left
          !unrollBy = maybe 1 (\(h :. _) -> h) mStencil
          workers = numWorkers scheduler
          -- Window rows are split evenly; the remainder forms one more task.
          !(perWorker, leftover) = wRows `quotRem` workers
      let storeBorder from to =
            scheduleWork scheduler $
            iterM_ from to 1 (<) $ \ !ix ->
              unsafeWrite (toLinearIndex sz ix) (borderAt ix)
          storeWindowRows !rowFrom !rowTo =
            unrollAndJam unrollBy (rowFrom :. rowTo) (left :. right) $ \ !ix ->
              unsafeWrite (toLinearIndex sz ix) (windowAt ix)
      storeBorder (0 :. 0) (top :. cols)
      storeBorder (bottom :. 0) (rows :. cols)
      storeBorder (top :. 0) (bottom :. left)
      storeBorder (top :. right) (bottom :. cols)
      loopM_ 0 (< workers) (+ 1) $ \ !wid -> do
        let !rowFrom = wid * perWorker + top
        scheduleWork scheduler $ storeWindowRows rowFrom (rowFrom + perWorker)
      when (leftover > 0) $ do
        let !rowFrom = workers * perWorker + top
        scheduleWork scheduler $ storeWindowRows rowFrom (rowFrom + leftover)
-- | Generic loader for windowed arrays whose lower dimension can itself be
-- loaded as a windowed array: the work is done page by page by the recursive
-- helpers.
--
-- NOTE(review): this instance head overlaps with the more specific instances
-- in this module; presumably overlap resolution is enabled somewhere outside
-- of this view - confirm.
instance (Index ix, Load DW (Lower ix) e) => Load DW ix e where
  loadS arr unsafeRead unsafeWrite =
    loadWindowedSRec arr unsafeRead unsafeWrite
  loadP wIds arr unsafeRead unsafeWrite =
    loadWindowedPRec wIds arr unsafeRead unsafeWrite
-- | Sequentially load a windowed array by recursing on its outermost
-- dimension.
--
-- Pages (slices along the outermost dimension) that lie completely before or
-- after the window are filled with the border indexing function. Every page
-- that intersects the window is turned into a windowed array of the lower
-- dimension and loaded with 'loadS', which takes care of that page's own
-- border and window.
loadWindowedSRec :: (Index ix, Load DW (Lower ix) e, Monad m) =>
  Array DW ix e -> (Int -> m e) -> (Int -> e -> m ()) -> m ()
loadWindowedSRec (DWArray darr mStencilSz tix wSz indexW) _unsafeRead unsafeWrite = do
  let DArray _ sz indexB = darr
      !szL = tailDim sz
      !bix = liftIndex2 (+) tix wSz
      !(t, tixL) = unconsDim tix
      !b = headDim bix
      !pageElements = totalElem szL
      -- Linear offset inside page @i@ translated to an offset in the whole
      -- array.
      unsafeWriteLower i k val = unsafeWrite (k + pageElements * i) val
  -- Whole pages before the window. The upper bound must span the complete
  -- lower dimensions (@szL@): stopping at the window start index would fill
  -- only a corner box and leave the rest of these pages unwritten.
  iterM_ zeroIndex (consDim t szL) 1 (<) $ \ !ix ->
    unsafeWrite (toLinearIndex sz ix) (indexB ix)
  -- Whole pages after the window; for the same reason the lower bound starts
  -- at the origin of the lower dimensions.
  iterM_ (consDim b zeroIndex) sz 1 (<) $ \ !ix ->
    unsafeWrite (toLinearIndex sz ix) (indexB ix)
  -- Pages that intersect the window.
  loopM_ t (< b) (+ 1) $ \ !i ->
    let !lowerArr =
          (DWArray
             (DArray Seq szL (indexB . consDim i))
             (tailDim <$> mStencilSz)
             tixL
             (tailDim wSz)
             (indexW . consDim i))
     in loadS lowerArr _unsafeRead (unsafeWriteLower i)
-- | Load a windowed array in parallel by recursing on its outermost
-- dimension.
--
-- One task fills all pages before the window and one task all pages after
-- it, both with the border indexing function. Every page that intersects
-- the window is scheduled as a separate task, which loads that page
-- sequentially as a windowed array of the lower dimension.
loadWindowedPRec :: (Index ix, Load DW (Lower ix) e) =>
  [Int] -> Array DW ix e -> (Int -> IO e) -> (Int -> e -> IO ()) -> IO ()
loadWindowedPRec wIds (DWArray darr mStencilSz tix wSz indexW) _unsafeRead unsafeWrite = do
  withScheduler_ wIds $ \ scheduler -> do
    let DArray _ sz indexB = darr
        !szL = tailDim sz
        !bix = liftIndex2 (+) tix wSz
        !(t, tixL) = unconsDim tix
        !b = headDim bix
        !pageElements = totalElem szL
        -- Linear offset inside page @i@ translated to an offset in the whole
        -- array.
        unsafeWriteLower i k = unsafeWrite (k + pageElements * i)
    -- Whole pages before the window. The upper bound must span the complete
    -- lower dimensions (@szL@): stopping at the window start index would
    -- fill only a corner box and leave the rest of these pages unwritten.
    scheduleWork scheduler $
      iterM_ zeroIndex (consDim t szL) 1 (<) $ \ !ix ->
        unsafeWrite (toLinearIndex sz ix) (indexB ix)
    -- Whole pages after the window; for the same reason the lower bound
    -- starts at the origin of the lower dimensions.
    scheduleWork scheduler $
      iterM_ (consDim b zeroIndex) sz 1 (<) $ \ !ix ->
        unsafeWrite (toLinearIndex sz ix) (indexB ix)
    -- Pages that intersect the window, one task per page.
    loopM_ t (< b) (+ 1) $ \ !i ->
      let !lowerArr =
            (DWArray
               (DArray Seq szL (indexB . consDim i))
               (tailDim <$> mStencilSz)
               tixL
               (tailDim wSz)
               (indexW . consDim i))
       in scheduleWork scheduler $
          loadS
            lowerArr
            (_unsafeRead)
            (unsafeWriteLower i)
-- | Apply an action to every index of the rectangle with rows @[it, ib)@ and
-- columns @[jt, jb)@, processing rows in blocks.
--
-- The requested block height is clamped to the range @1..7@. For each block
-- of rows the columns are traversed once, and at every column the action is
-- applied to all rows of the block, top to bottom. Rows that do not fill a
-- complete block are processed one at a time afterwards.
unrollAndJam :: Monad m =>
  Int -- ^ Requested block height (clamped to @1..7@)
  -> Ix2 -- ^ First row (inclusive) and last row (exclusive)
  -> Ix2 -- ^ First column (inclusive) and last column (exclusive)
  -> (Ix2 -> m a) -- ^ Action to run at every index
  -> m ()
unrollAndJam !bH (it :. ib) (jt :. jb) f = do
  let !bH' = min (max 1 bH) 7
  -- @fN@ applies the action to N consecutive rows of the same column.
  let f2 (i :. j) = f (i :. j) >> f ((i + 1) :. j)
  let f3 (i :. j) = f (i :. j) >> f2 ((i + 1) :. j)
  let f4 (i :. j) = f (i :. j) >> f3 ((i + 1) :. j)
  let f5 (i :. j) = f (i :. j) >> f4 ((i + 1) :. j)
  let f6 (i :. j) = f (i :. j) >> f5 ((i + 1) :. j)
  let f7 (i :. j) = f (i :. j) >> f6 ((i + 1) :. j)
  let f' = case bH' of
             1 -> f
             2 -> f2
             3 -> f3
             4 -> f4
             5 -> f5
             6 -> f6
             _ -> f7
  -- Row at which the leftover rows begin: the row range shortened to a
  -- whole number of blocks.
  let !ibS = ib - ((ib - it) `mod` bH')
  -- Complete blocks of rows.
  loopM_ it (< ibS) (+ bH') $ \ !i ->
    loopM_ jt (< jb) (+ 1) $ \ !j ->
      f' (i :. j)
  -- Leftover rows, one by one.
  loopM_ ibS (< ib) (+ 1) $ \ !i ->
    loopM_ jt (< jb) (+ 1) $ \ !j ->
      f (i :. j)
-- | Loading of two-dimensional windowed arrays indexed with 'Ix2T'. Four
-- border strips (above, below, left and right of the window) are produced by
-- the border indexing function; the window itself is loaded with
-- 'unrollAndJamT' using the window indexing function.
instance Load DW Ix2T e where
  loadS arr _ unsafeWrite = do
    let DWArray (DArray _ sz@(rows, cols) borderAt) mStencil (top, left) (wRows, wCols) windowAt =
          arr
        bottom = wRows + top
        right = wCols + left
        -- Rows are unrolled by the stencil height when a stencil is known.
        unrollBy = maybe 1 fst mStencil
        storeBorder from to =
          iterM_ from to 1 (<) $ \ !ix ->
            unsafeWrite (toLinearIndex sz ix) (borderAt ix)
    storeBorder (0, 0) (top, cols) -- strip above the window
    storeBorder (bottom, 0) (rows, cols) -- strip below the window
    storeBorder (top, 0) (bottom, left) -- strip left of the window
    storeBorder (top, right) (bottom, cols) -- strip right of the window
    unrollAndJamT unrollBy (top, bottom) (left, right) $ \ !ix ->
      unsafeWrite (toLinearIndex sz ix) (windowAt ix)

  loadP wIds arr _ unsafeWrite = do
    let DWArray (DArray _ sz@(rows, cols) borderAt) mStencil (top, left) (wRows, wCols) windowAt =
          arr
    withScheduler_ wIds $ \scheduler -> do
      let bottom = wRows + top
          right = wCols + left
          unrollBy = maybe 1 fst mStencil
          workers = numWorkers scheduler
          -- Window rows are split evenly; the remainder forms one more task.
          !(perWorker, leftover) = wRows `quotRem` workers
      let storeBorder from to =
            scheduleWork scheduler $
            iterM_ from to 1 (<) $ \ !ix ->
              unsafeWrite (toLinearIndex sz ix) (borderAt ix)
          storeWindowRows !rowFrom !rowTo =
            unrollAndJamT unrollBy (rowFrom, rowTo) (left, right) $ \ !ix ->
              unsafeWrite (toLinearIndex sz ix) (windowAt ix)
      storeBorder (0, 0) (top, cols)
      storeBorder (bottom, 0) (rows, cols)
      storeBorder (top, 0) (bottom, left)
      storeBorder (top, right) (bottom, cols)
      loopM_ 0 (< workers) (+ 1) $ \ !wid -> do
        let !rowFrom = wid * perWorker + top
        scheduleWork scheduler $ storeWindowRows rowFrom (rowFrom + perWorker)
      when (leftover > 0) $ do
        let !rowFrom = workers * perWorker + top
        scheduleWork scheduler $ storeWindowRows rowFrom (rowFrom + leftover)
-- | Tuple-indexed counterpart of the row-blocked traversal: apply an action
-- to every index of the rectangle with rows @[it, ib)@ and columns
-- @[jt, jb)@, processing rows in blocks.
--
-- The requested block height is clamped to the range @1..7@. For each block
-- of rows the columns are traversed once, and at every column the action is
-- applied to all rows of the block, top to bottom. Rows that do not fill a
-- complete block are processed one at a time afterwards.
unrollAndJamT :: Monad m =>
  Int -- ^ Requested block height (clamped to @1..7@)
  -> Ix2T -- ^ First row (inclusive) and last row (exclusive)
  -> Ix2T -- ^ First column (inclusive) and last column (exclusive)
  -> (Ix2T -> m a) -- ^ Action to run at every index
  -> m ()
unrollAndJamT !bH (it, ib) (jt, jb) f = do
  let !bH' = min (max 1 bH) 7
  -- @fN@ applies the action to N consecutive rows of the same column.
  let f2 !(i, j) = f (i, j) >> f (i + 1, j)
  let f3 !(i, j) = f (i, j) >> f2 (i + 1, j)
  let f4 !(i, j) = f (i, j) >> f3 (i + 1, j)
  let f5 !(i, j) = f (i, j) >> f4 (i + 1, j)
  let f6 !(i, j) = f (i, j) >> f5 (i + 1, j)
  let f7 !(i, j) = f (i, j) >> f6 (i + 1, j)
  let f' = case bH' of
             1 -> f
             2 -> f2
             3 -> f3
             4 -> f4
             5 -> f5
             6 -> f6
             _ -> f7
  -- Row at which the leftover rows begin: the row range shortened to a
  -- whole number of blocks.
  let !ibS = ib - ((ib - it) `mod` bH')
  -- Complete blocks of rows.
  loopM_ it (< ibS) (+ bH') $ \ !i ->
    loopM_ jt (< jb) (+ 1) $ \ !j ->
      f' (i, j)
  -- Leftover rows, one by one.
  loopM_ ibS (< ib) (+ 1) $ \ !i ->
    loopM_ jt (< jb) (+ 1) $ \ !j ->
      f (i, j)