{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -- | -- Module : Data.Massiv.Array.Delayed.Windowed -- Copyright : (c) Alexey Kuleshevich 2018 -- License : BSD3 -- Maintainer : Alexey Kuleshevich -- Stability : experimental -- Portability : non-portable -- module Data.Massiv.Array.Delayed.Windowed ( DW , Array(..) , makeWindowedArray ) where import Control.Monad (when) import Data.Massiv.Array.Delayed.Internal import Data.Massiv.Core import Data.Massiv.Core.Common import Data.Massiv.Core.Scheduler -- | Delayed Windowed Array representation. data DW type instance EltRepr DW ix = D data instance Array DW ix e = DWArray { wdArray :: !(Array D ix e) , wdStencilSize :: Maybe ix -- ^ Setting this value during stencil -- application improves cache utilization -- while computing an array , wdWindowStartIndex :: !ix , wdWindowSize :: !ix , wdWindowUnsafeIndex :: ix -> e } instance Index ix => Construct DW ix e where getComp = dComp . wdArray {-# INLINE getComp #-} setComp c arr = arr { wdArray = (wdArray arr) { dComp = c } } {-# INLINE setComp #-} unsafeMakeArray c sz f = DWArray (unsafeMakeArray c sz f) Nothing zeroIndex zeroIndex f {-# INLINE unsafeMakeArray #-} -- | Any resize or extract on Windowed Array will hurt the performance. instance Index ix => Size DW ix e where size = size . wdArray {-# INLINE size #-} unsafeResize sz DWArray {..} = let dArr = unsafeResize sz wdArray in DWArray { wdArray = dArr , wdStencilSize = Nothing , wdWindowStartIndex = zeroIndex , wdWindowSize = zeroIndex , wdWindowUnsafeIndex = evaluateAt dArr } unsafeExtract sIx newSz = unsafeExtract sIx newSz . wdArray instance Functor (Array DW ix) where fmap f !arr = arr { wdArray = fmap f (wdArray arr) , wdWindowUnsafeIndex = f . wdWindowUnsafeIndex arr } {-# INLINE fmap #-} -- | Supply a separate generating function for interior of an array. This is -- very usful for stencil mapping, where interior function does not perform -- boundary checks, thus significantly speeding up computation process. makeWindowedArray :: Source r ix e => Array r ix e -- ^ Source array that will have a window inserted into it -> ix -- ^ Start index for the window -> ix -- ^ Size of the window -> (ix -> e) -- ^ Inside window indexing function -> Array DW ix e makeWindowedArray !arr !wIx !wSz wUnsafeIndex | not (isSafeIndex sz wIx) = error $ "Incorrect window starting index: " ++ show wIx ++ " for: " ++ show (size arr) | liftIndex2 (+) wIx wSz > sz = error $ "Incorrect window size: " ++ show wSz ++ " and/or placement: " ++ show wIx ++ " for: " ++ show (size arr) | otherwise = DWArray { wdArray = delay arr , wdStencilSize = Nothing , wdWindowStartIndex = wIx , wdWindowSize = wSz , wdWindowUnsafeIndex = wUnsafeIndex } where sz = size arr {-# INLINE makeWindowedArray #-} instance {-# OVERLAPPING #-} Load DW Ix1 e where loadS (DWArray (DArray _ sz indexB) _ it wk indexW) _ unsafeWrite = do iterM_ 0 it 1 (<) $ \ !i -> unsafeWrite i (indexB i) iterM_ it wk 1 (<) $ \ !i -> unsafeWrite i (indexW i) iterM_ wk sz 1 (<) $ \ !i -> unsafeWrite i (indexB i) {-# INLINE loadS #-} loadP wIds (DWArray (DArray _ sz indexB) _ it wk indexW) _ unsafeWrite = do divideWork_ wIds wk $ \ !scheduler !chunkLength !totalLength !slackStart -> do scheduleWork scheduler $ iterM_ 0 it 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) scheduleWork scheduler $ iterM_ wk sz 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) loopM_ it (< (slackStart + it)) (+ chunkLength) $ \ !start -> scheduleWork scheduler $ iterM_ start (start + chunkLength) 1 (<) $ \ !k -> unsafeWrite k $ indexW k scheduleWork scheduler $ iterM_ (slackStart + it) (totalLength + it) 1 (<) $ \ !k -> unsafeWrite k (indexW k) {-# INLINE loadP #-} instance {-# OVERLAPPING #-} Load DW Ix2 e where loadS arr _ unsafeWrite = do let (DWArray (DArray _ sz@(m :. n) indexB) mStencilSz (it :. jt) (wm :. wn) indexW) = arr let (ib :. jb) = (wm + it) :. (wn + jt) blockHeight = case mStencilSz of Just (i :. _) -> i _ -> 1 iterM_ (0 :. 0) (it :. n) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) iterM_ (ib :. 0) (m :. n) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) iterM_ (it :. 0) (ib :. jt) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) iterM_ (it :. jb) (ib :. n) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) unrollAndJam blockHeight (it :. ib) (jt :. jb) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexW ix) {-# INLINE loadS #-} loadP wIds arr _ unsafeWrite = do let (DWArray (DArray _ sz@(m :. n) indexB) mStencilSz (it :. jt) (wm :. wn) indexW) = arr withScheduler_ wIds $ \scheduler -> do let (ib :. jb) = (wm + it) :. (wn + jt) !blockHeight = case mStencilSz of Just (i :. _) -> i _ -> 1 !(chunkHeight, slackHeight) = wm `quotRem` numWorkers scheduler let loadBlock !it' !ib' = unrollAndJam blockHeight (it' :. ib') (jt :. jb) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexW ix) {-# INLINE loadBlock #-} scheduleWork scheduler $ iterM_ (0 :. 0) (it :. n) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) scheduleWork scheduler $ iterM_ (ib :. 0) (m :. n) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) scheduleWork scheduler $ iterM_ (it :. 0) (ib :. jt) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) scheduleWork scheduler $ iterM_ (it :. jb) (ib :. n) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) loopM_ 0 (< numWorkers scheduler) (+ 1) $ \ !wid -> do let !it' = wid * chunkHeight + it scheduleWork scheduler $ loadBlock it' (it' + chunkHeight) when (slackHeight > 0) $ do let !itSlack = (numWorkers scheduler) * chunkHeight + it scheduleWork scheduler $ loadBlock itSlack (itSlack + slackHeight) {-# INLINE loadP #-} -- instance {-# OVERLAPPING #-} Load DW Ix3 e where -- loadS = loadWindowedSRec -- {-# INLINE loadS #-} -- loadP = loadWindowedPRec -- {-# INLINE loadP #-} instance {-# OVERLAPPABLE #-} (Index ix, Load DW (Lower ix) e) => Load DW ix e where loadS = loadWindowedSRec {-# INLINE loadS #-} loadP = loadWindowedPRec {-# INLINE loadP #-} loadWindowedSRec :: (Index ix, Load DW (Lower ix) e, Monad m) => Array DW ix e -> (Int -> m e) -> (Int -> e -> m ()) -> m () loadWindowedSRec (DWArray darr mStencilSz tix wSz indexW) _unsafeRead unsafeWrite = do let DArray _ sz indexB = darr !szL = tailDim sz !bix = liftIndex2 (+) tix wSz !(t, tixL) = unconsDim tix !pageElements = totalElem szL unsafeWriteLower i k val = unsafeWrite (k + pageElements * i) val {-# INLINE unsafeWriteLower #-} iterM_ zeroIndex tix 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) iterM_ bix sz 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) loopM_ t (< headDim bix) (+ 1) $ \ !i -> let !lowerArr = (DWArray (DArray Seq szL (indexB . consDim i)) (tailDim <$> mStencilSz) -- can safely drop the dim, only -- last 2 matter anyways tixL (tailDim wSz) (indexW . consDim i)) in loadS lowerArr _unsafeRead (unsafeWriteLower i) {-# INLINE loadWindowedSRec #-} loadWindowedPRec :: (Index ix, Load DW (Lower ix) e) => [Int] -> Array DW ix e -> (Int -> IO e) -> (Int -> e -> IO ()) -> IO () loadWindowedPRec wIds (DWArray darr mStencilSz tix wSz indexW) _unsafeRead unsafeWrite = do withScheduler_ wIds $ \ scheduler -> do let DArray _ sz indexB = darr !szL = tailDim sz !bix = liftIndex2 (+) tix wSz !(t, tixL) = unconsDim tix !pageElements = totalElem szL unsafeWriteLower i k = unsafeWrite (k + pageElements * i) {-# INLINE unsafeWriteLower #-} scheduleWork scheduler $ iterM_ zeroIndex tix 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) scheduleWork scheduler $ iterM_ bix sz 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) loopM_ t (< headDim bix) (+ 1) $ \ !i -> let !lowerArr = (DWArray (DArray Seq szL (indexB . consDim i)) (tailDim <$> mStencilSz) -- can safely drop the dim, only -- last 2 matter anyways tixL (tailDim wSz) (indexW . consDim i)) in scheduleWork scheduler $ loadS lowerArr (_unsafeRead) (unsafeWriteLower i) {-# INLINE loadWindowedPRec #-} unrollAndJam :: Monad m => Int -> Ix2 -> Ix2 -> (Ix2 -> m a) -> m () unrollAndJam !bH (it :. ib) (jt :. jb) f = do let !bH' = min (max 1 bH) 7 let f2 (i :. j) = f (i :. j) >> f ((i + 1) :. j) let f3 (i :. j) = f (i :. j) >> f2 ((i + 1) :. j) let f4 (i :. j) = f (i :. j) >> f3 ((i + 1) :. j) let f5 (i :. j) = f (i :. j) >> f4 ((i + 1) :. j) let f6 (i :. j) = f (i :. j) >> f5 ((i + 1) :. j) let f7 (i :. j) = f (i :. j) >> f6 ((i + 1) :. j) let f' = case bH' of 1 -> f 2 -> f2 3 -> f3 4 -> f4 5 -> f5 6 -> f6 _ -> f7 let !ibS = ib - ((ib - it) `mod` bH') loopM_ it (< ibS) (+ bH') $ \ !i -> loopM_ jt (< jb) (+ 1) $ \ !j -> f' (i :. j) loopM_ ibS (< ib) (+ 1) $ \ !i -> loopM_ jt (< jb) (+ 1) $ \ !j -> f (i :. j) {-# INLINE unrollAndJam #-} -- TODO: Implement Hilbert curve instance {-# OVERLAPPING #-} Load DW Ix2T e where loadS arr _ unsafeWrite = do let (DWArray (DArray _ sz@(m, n) indexB) mStencilSz (it, jt) (wm, wn) indexW) = arr let (ib, jb) = (wm + it, wn + jt) blockHeight = case mStencilSz of Just (i, _) -> i _ -> 1 iterM_ (0, 0) (it, n) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) iterM_ (ib, 0) (m, n) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) iterM_ (it, 0) (ib, jt) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) iterM_ (it, jb) (ib, n) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) unrollAndJamT blockHeight (it, ib) (jt, jb) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexW ix) {-# INLINE loadS #-} loadP wIds arr _ unsafeWrite = do let (DWArray (DArray _ sz@(m, n) indexB) mStencilSz (it, jt) (wm, wn) indexW) = arr withScheduler_ wIds $ \ scheduler -> do let (ib, jb) = (wm + it, wn + jt) blockHeight = case mStencilSz of Just (i, _) -> i _ -> 1 !(chunkHeight, slackHeight) = wm `quotRem` numWorkers scheduler let loadBlock !it' !ib' = unrollAndJamT blockHeight (it', ib') (jt, jb) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexW ix) {-# INLINE loadBlock #-} scheduleWork scheduler $ iterM_ (0, 0) (it, n) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) scheduleWork scheduler $ iterM_ (ib, 0) (m, n) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) scheduleWork scheduler $ iterM_ (it, 0) (ib, jt) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) scheduleWork scheduler $ iterM_ (it, jb) (ib, n) 1 (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (indexB ix) loopM_ 0 (< numWorkers scheduler) (+ 1) $ \ !wid -> do let !it' = wid * chunkHeight + it scheduleWork scheduler $ loadBlock it' (it' + chunkHeight) when (slackHeight > 0) $ do let !itSlack = (numWorkers scheduler) * chunkHeight + it scheduleWork scheduler $ loadBlock itSlack (itSlack + slackHeight) {-# INLINE loadP #-} unrollAndJamT :: Monad m => Int -> Ix2T -> Ix2T -> (Ix2T -> m a) -> m () unrollAndJamT !bH (it, ib) (jt, jb) f = do let !bH' = min (max 1 bH) 7 let f2 !(i, j) = f (i, j) >> f (i+1, j) let f3 !(i, j) = f (i, j) >> f2 (i+1, j) let f4 !(i, j) = f (i, j) >> f3 (i+1, j) let f5 !(i, j) = f (i, j) >> f4 (i+1, j) let f6 !(i, j) = f (i, j) >> f5 (i+1, j) let f7 !(i, j) = f (i, j) >> f6 (i+1, j) let f' = case bH' of 1 -> f 2 -> f2 3 -> f3 4 -> f4 5 -> f5 6 -> f6 _ -> f7 let !ibS = ib - ((ib - it) `mod` bH') loopM_ it (< ibS) (+ bH') $ \ !i -> loopM_ jt (< jb) (+ 1) $ \ !j -> f' (i, j) loopM_ ibS (< ib) (+ 1) $ \ !i -> loopM_ jt (< jb) (+ 1) $ \ !j -> f (i, j) {-# INLINE unrollAndJamT #-}