{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -- | -- Module : Data.Massiv.Array.Delayed.Internal -- Copyright : (c) Alexey Kuleshevich 2018 -- License : BSD3 -- Maintainer : Alexey Kuleshevich -- Stability : experimental -- Portability : non-portable -- module Data.Massiv.Array.Delayed.Internal ( D(..) , Array(..) , delay , eq , ord , liftArray , liftArray2 ) where import Data.Foldable (Foldable (..)) import Data.Massiv.Array.Ops.Fold.Internal as A import Data.Massiv.Core.Common import Data.Massiv.Core.Scheduler import Data.Monoid ((<>)) import GHC.Base (build) import Prelude hiding (zipWith) #include "massiv.h" -- | Delayed representation. data D = D deriving Show data instance Array D ix e = DArray { dComp :: !Comp , dSize :: !ix , dIndex :: ix -> e } type instance EltRepr D ix = D instance Index ix => Construct D ix e where getComp = dComp {-# INLINE getComp #-} setComp c arr = arr { dComp = c } {-# INLINE setComp #-} unsafeMakeArray = DArray {-# INLINE unsafeMakeArray #-} instance Index ix => Source D ix e where unsafeIndex = INDEX_CHECK("(Source D ix e).unsafeIndex", size, dIndex) {-# INLINE unsafeIndex #-} instance Index ix => Size D ix e where size = dSize {-# INLINE size #-} unsafeResize !sz !arr = DArray (getComp arr) sz $ \ !ix -> unsafeIndex arr (fromLinearIndex (size arr) (toLinearIndex sz ix)) {-# INLINE unsafeResize #-} unsafeExtract !sIx !newSz !arr = DArray (getComp arr) newSz $ \ !ix -> unsafeIndex arr (liftIndex2 (+) ix sIx) {-# INLINE unsafeExtract #-} instance ( Index ix , Index (Lower ix) , Elt D ix e ~ Array D (Lower ix) e ) => Slice D ix e where unsafeSlice arr start cutSz dim = do newSz <- dropDim cutSz dim return $ unsafeResize newSz (unsafeExtract start cutSz arr) {-# INLINE unsafeSlice #-} instance (Elt D ix e ~ Array D (Lower ix) e, Index ix) => OuterSlice D ix e where unsafeOuterSlice !arr !i = DArray (getComp arr) (tailDim (size arr)) (\ !ix -> unsafeIndex arr (consDim i ix)) {-# INLINE unsafeOuterSlice #-} instance (Elt D ix e ~ Array D (Lower ix) e, Index ix) => InnerSlice D ix e where unsafeInnerSlice !arr !(szL, _) !i = DArray (getComp arr) szL (\ !ix -> unsafeIndex arr (snocDim ix i)) {-# INLINE unsafeInnerSlice #-} instance (Eq e, Index ix) => Eq (Array D ix e) where (==) = eq (==) {-# INLINE (==) #-} instance (Ord e, Index ix) => Ord (Array D ix e) where compare = ord compare {-# INLINE compare #-} instance Functor (Array D ix) where fmap f (DArray c sz g) = DArray c sz (f . g) {-# INLINE fmap #-} instance Index ix => Applicative (Array D ix) where pure a = DArray Seq (liftIndex (+ 1) zeroIndex) (const a) {-# INLINE pure #-} (<*>) (DArray c1 sz1 uIndex1) (DArray c2 sz2 uIndex2) = DArray (c1 <> c2) (liftIndex2 min sz1 sz2) $ \ !ix -> (uIndex1 ix) (uIndex2 ix) {-# INLINE (<*>) #-} -- | Row-major sequential folding over a Delayed array. instance Index ix => Foldable (Array D ix) where foldl = lazyFoldlS {-# INLINE foldl #-} foldl' = foldlS {-# INLINE foldl' #-} foldr = foldrFB {-# INLINE foldr #-} foldr' = foldrS {-# INLINE foldr' #-} null (DArray _ sz _) = totalElem sz == 0 {-# INLINE null #-} sum = foldl' (+) 0 {-# INLINE sum #-} product = foldl' (*) 1 {-# INLINE product #-} length = totalElem . size {-# INLINE length #-} toList arr = build (\ c n -> foldrFB c n arr) {-# INLINE toList #-} instance Index ix => Load D ix e where loadS (DArray _ sz f) _ unsafeWrite = iterM_ zeroIndex sz (pureIndex 1) (<) $ \ !ix -> unsafeWrite (toLinearIndex sz ix) (f ix) {-# INLINE loadS #-} loadP wIds (DArray _ sz f) _ unsafeWrite = divideWork_ wIds sz $ \ !scheduler !chunkLength !totalLength !slackStart -> do loopM_ 0 (< slackStart) (+ chunkLength) $ \ !start -> scheduleWork scheduler $ iterLinearM_ sz start (start + chunkLength) 1 (<) $ \ !k !ix -> unsafeWrite k (f ix) scheduleWork scheduler $ iterLinearM_ sz slackStart totalLength 1 (<) $ \ !k !ix -> unsafeWrite k (f ix) {-# INLINE loadP #-} instance (Index ix, Num e) => Num (Array D ix e) where (+) = liftArray2 (+) {-# INLINE (+) #-} (-) = liftArray2 (-) {-# INLINE (-) #-} (*) = liftArray2 (*) {-# INLINE (*) #-} abs = liftArray abs {-# INLINE abs #-} signum = liftArray signum {-# INLINE signum #-} fromInteger = singleton Seq . fromInteger {-# INLINE fromInteger #-} instance (Index ix, Fractional e) => Fractional (Array D ix e) where (/) = liftArray2 (/) {-# INLINE (/) #-} fromRational = singleton Seq . fromRational {-# INLINE fromRational #-} instance (Index ix, Floating e) => Floating (Array D ix e) where pi = singleton Seq pi {-# INLINE pi #-} exp = liftArray exp {-# INLINE exp #-} log = liftArray log {-# INLINE log #-} sin = liftArray sin {-# INLINE sin #-} cos = liftArray cos {-# INLINE cos #-} asin = liftArray asin {-# INLINE asin #-} atan = liftArray atan {-# INLINE atan #-} acos = liftArray acos {-# INLINE acos #-} sinh = liftArray sinh {-# INLINE sinh #-} cosh = liftArray cosh {-# INLINE cosh #-} asinh = liftArray asinh {-# INLINE asinh #-} atanh = liftArray atanh {-# INLINE atanh #-} acosh = liftArray acosh {-# INLINE acosh #-} -- | /O(1)/ Conversion from a source array to `D` representation. delay :: Source r ix e => Array r ix e -> Array D ix e delay arr = DArray (getComp arr) (size arr) (unsafeIndex arr) {-# INLINE delay #-} -- | /O(n1 + n2)/ - Compute array equality by applying a comparing function to each element. eq :: (Source r1 ix e1, Source r2 ix e2) => (e1 -> e2 -> Bool) -> Array r1 ix e1 -> Array r2 ix e2 -> Bool eq f arr1 arr2 = (size arr1 == size arr2) && A.fold (&&) True (DArray (getComp arr1 <> getComp arr2) (size arr1) $ \ix -> f (unsafeIndex arr1 ix) (unsafeIndex arr2 ix)) {-# INLINE eq #-} -- | /O(n1 + n2)/ - Compute array ordering by applying a comparing function to each element. -- The exact ordering is unspecified so this is only intended for use in maps and the like where -- you need an ordering but do not care about which one is used. ord :: (Source r1 ix e1, Source r2 ix e2) => (e1 -> e2 -> Ordering) -> Array r1 ix e1 -> Array r2 ix e2 -> Ordering ord f arr1 arr2 = (compare (size arr1) (size arr2)) <> A.fold (<>) mempty (DArray (getComp arr1 <> getComp arr2) (size arr1) $ \ix -> f (unsafeIndex arr1 ix) (unsafeIndex arr2 ix)) {-# INLINE ord #-} -- | The usual map. liftArray :: Source r ix b => (b -> e) -> Array r ix b -> Array D ix e liftArray f !arr = DArray (getComp arr) (size arr) (f . unsafeIndex arr) {-# INLINE liftArray #-} -- | Similar to `Data.Massiv.Array.zipWith`, except dimensions of both arrays either have to be the -- same, or at least one of the two array must be a singleton array, in which case it will behave as -- a `Data.Massiv.Array.map`. -- -- @since 0.1.4 liftArray2 :: (Source r1 ix a, Source r2 ix b) => (a -> b -> e) -> Array r1 ix a -> Array r2 ix b -> Array D ix e liftArray2 f !arr1 !arr2 | sz1 == oneIndex = liftArray (f (unsafeIndex arr1 zeroIndex)) arr2 | sz2 == oneIndex = liftArray (`f` (unsafeIndex arr2 zeroIndex)) arr1 | sz1 == sz2 = DArray (getComp arr1) sz1 (\ !ix -> f (unsafeIndex arr1 ix) (unsafeIndex arr2 ix)) | otherwise = errorSizeMismatch "liftArray2" (size arr1) (size arr2) where oneIndex = pureIndex 1 sz1 = size arr1 sz2 = size arr2 {-# INLINE liftArray2 #-}