{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeSynonymInstances #-} -- | -- Module : Data.Array.Accelerate.IO.Data.Vector.Generic.Mutable -- Copyright : [2017..2020] The Accelerate Team -- License : BSD3 -- -- Maintainer : Trevor L. McDonell -- Stability : experimental -- Portability : non-portable (GHC extensions) -- -- This module provides instance for 'Data.Vector.Generic.Mutable.MVector', for -- generic mutable vectors backed by Accelerate. -- -- @since 0.1.0.0 -- module Data.Array.Accelerate.IO.Data.Vector.Generic.Mutable where import Data.Array.Accelerate.Array.Data import Data.Array.Accelerate.Array.Unique import Data.Array.Accelerate.Lifetime import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Type import qualified Data.Vector.Generic.Mutable as V import Control.Monad.Primitive import Data.Typeable import Foreign.Marshal.Utils import Foreign.Storable import Prelude hiding ( length ) import GHC.Base import GHC.ForeignPtr -- | Dense, regular, mutable, multi-dimensional arrays -- data MArray sh s e where MArray :: EltR sh -- extent of dimensions == shape -> MutableArrayData (EltR e) -- mutable array payload -> MArray sh s e deriving instance Typeable MArray type MVector = MArray DIM1 instance Elt e => V.MVector MVector e where {-# INLINE basicLength #-} {-# INLINE basicUnsafeSlice #-} {-# INLINE basicOverlaps #-} {-# INLINE basicUnsafeNew #-} {-# INLINE basicInitialize #-} {-# INLINE basicUnsafeRead #-} {-# INLINE basicUnsafeWrite #-} {-# INLINE basicUnsafeCopy #-} basicLength (MArray ((), n) _) = n basicUnsafeSlice j m (MArray _ mad) = MArray ((),m) (go (eltR @e) mad) where go :: TypeR a -> MutableArrayData a -> MutableArrayData a go TupRunit () = () go (TupRpair aR1 aR2) (a1, a2) = (go aR1 a1, go aR2 a2) go (TupRsingle aR) a = scalar aR a scalar :: ScalarType a -> MutableArrayData a -> MutableArrayData a scalar (SingleScalarType t) a = single t a 1 scalar (VectorScalarType t) a = vector t a vector :: VectorType a -> MutableArrayData a -> MutableArrayData a vector (VectorType w t) a | SingleArrayDict <- singleArrayDict t = single t a w single :: SingleType a -> MutableArrayData a -> Int -> MutableArrayData a single (NumSingleType t) = num t num :: NumType a -> MutableArrayData a -> Int -> MutableArrayData a num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType a -> MutableArrayData a -> Int -> MutableArrayData a integral TypeInt = slice integral TypeInt8 = slice integral TypeInt16 = slice integral TypeInt32 = slice integral TypeInt64 = slice integral TypeWord = slice integral TypeWord8 = slice integral TypeWord16 = slice integral TypeWord32 = slice integral TypeWord64 = slice floating :: FloatingType a -> MutableArrayData a -> Int -> MutableArrayData a floating TypeHalf = slice floating TypeFloat = slice floating TypeDouble = slice slice :: forall a. Storable a => UniqueArray a -> Int -> UniqueArray a slice (UniqueArray uid (Lifetime lft w fp)) s = UniqueArray uid (Lifetime lft w (plusForeignPtr fp (j * s * sizeOf (undefined::a)))) basicOverlaps (MArray ((),m) mad1) (MArray ((),n) mad2) = go (eltR @e) mad1 mad2 where go :: TypeR a -> MutableArrayData a -> MutableArrayData a -> Bool go TupRunit () () = False go (TupRpair aR1 aR2) (a1, a2) (b1, b2) = go aR1 a1 b1 || go aR2 a2 b2 go (TupRsingle aR) a b = scalar aR a b scalar :: ScalarType a -> MutableArrayData a -> MutableArrayData a -> Bool scalar (SingleScalarType t) a b = single t a b 1 scalar (VectorScalarType t) a b = vector t a b vector :: VectorType a -> MutableArrayData a -> MutableArrayData a -> Bool vector (VectorType w t) a b | SingleArrayDict <- singleArrayDict t = single t a b w single :: SingleType a -> MutableArrayData a -> MutableArrayData a -> Int -> Bool single (NumSingleType t) = num t num :: NumType a -> MutableArrayData a -> MutableArrayData a -> Int -> Bool num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType a -> MutableArrayData a -> MutableArrayData a -> Int -> Bool integral TypeInt = overlaps integral TypeInt8 = overlaps integral TypeInt16 = overlaps integral TypeInt32 = overlaps integral TypeInt64 = overlaps integral TypeWord = overlaps integral TypeWord8 = overlaps integral TypeWord16 = overlaps integral TypeWord32 = overlaps integral TypeWord64 = overlaps floating :: FloatingType a -> MutableArrayData a -> MutableArrayData a -> Int -> Bool floating TypeHalf = overlaps floating TypeFloat = overlaps floating TypeDouble = overlaps overlaps :: forall a. Storable a => UniqueArray a -> UniqueArray a -> Int -> Bool overlaps (UniqueArray _ (Lifetime _ _ (ForeignPtr addr1# c1))) (UniqueArray _ (Lifetime _ _ (ForeignPtr addr2# c2))) s = let i = I# (addr2Int# addr1#) j = I# (addr2Int# addr2#) k = s * sizeOf (undefined::a) in same c1 c2 && (between i j (j + n*k) || between j i (i + m*k)) same :: ForeignPtrContents -> ForeignPtrContents -> Bool same (PlainPtr mba1#) (PlainPtr mba2#) = isTrue# (sameMutableByteArray# mba1# mba2#) same (MallocPtr mba1# _) (MallocPtr mba2# _) = isTrue# (sameMutableByteArray# mba1# mba2#) -- same PlainForeignPtr{} PlainForeignPtr{} = False -- probably? same _ _ = False -- probably? should we still check whether the address ranges overlap? between :: Int -> Int -> Int -> Bool between x y z = x >= y && x < z basicUnsafeNew n = unsafePrimToPrim $ MArray ((),n) <$> newArrayData (eltR @e) n basicInitialize (MArray ((),n) mad) = unsafePrimToPrim $ go (eltR @e) mad where go :: TypeR a -> MutableArrayData a -> IO () go TupRunit () = return () go (TupRpair aR1 aR2) (a1, a2) = go aR1 a1 >> go aR2 a2 go (TupRsingle aR) a = scalar aR a scalar :: ScalarType a -> MutableArrayData a -> IO () scalar (SingleScalarType t) a = single t a 1 scalar (VectorScalarType t) a = vector t a vector :: VectorType a -> MutableArrayData a -> IO () vector (VectorType w t) a | SingleArrayDict <- singleArrayDict t = single t a w single :: SingleType a -> MutableArrayData a -> Int -> IO () single (NumSingleType t) = num t num :: NumType a -> MutableArrayData a -> Int -> IO () num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType a -> MutableArrayData a -> Int -> IO () integral TypeInt = initialise integral TypeInt8 = initialise integral TypeInt16 = initialise integral TypeInt32 = initialise integral TypeInt64 = initialise integral TypeWord = initialise integral TypeWord8 = initialise integral TypeWord16 = initialise integral TypeWord32 = initialise integral TypeWord64 = initialise floating :: FloatingType a -> MutableArrayData a -> Int -> IO () floating TypeHalf = initialise floating TypeFloat = initialise floating TypeDouble = initialise initialise :: forall a. Storable a => UniqueArray a -> Int -> IO () initialise ua s = withUniqueArrayPtr ua $ \p -> fillBytes p 0 (n * s * sizeOf (undefined::a)) basicUnsafeRead (MArray _ mad) i = unsafePrimToPrim $ toElt <$> readArrayData (eltR @e) mad i basicUnsafeWrite (MArray _ mad) i v = unsafePrimToPrim $ writeArrayData (eltR @e) mad i (fromElt v) basicUnsafeCopy (MArray _ dst) (MArray ((),n) src) = unsafePrimToPrim $ go (eltR @e) dst src where go :: TypeR a -> MutableArrayData a -> MutableArrayData a -> IO () go TupRunit () () = return () go (TupRpair aR1 aR2) (a1, a2) (b1, b2) = go aR1 a1 b1 >> go aR2 a2 b2 go (TupRsingle aR) a b = scalar aR a b scalar :: ScalarType a -> MutableArrayData a -> MutableArrayData a -> IO () scalar (SingleScalarType t) a b = single t a b 1 scalar (VectorScalarType t) a b = vector t a b vector :: VectorType a -> MutableArrayData a -> MutableArrayData a -> IO () vector (VectorType w t) a b | SingleArrayDict <- singleArrayDict t = single t a b w single :: SingleType a -> MutableArrayData a -> MutableArrayData a -> Int -> IO () single (NumSingleType t) = num t num :: NumType a -> MutableArrayData a -> MutableArrayData a -> Int -> IO () num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType a -> MutableArrayData a -> MutableArrayData a -> Int -> IO () integral TypeInt = copy integral TypeInt8 = copy integral TypeInt16 = copy integral TypeInt32 = copy integral TypeInt64 = copy integral TypeWord = copy integral TypeWord8 = copy integral TypeWord16 = copy integral TypeWord32 = copy integral TypeWord64 = copy floating :: FloatingType a -> MutableArrayData a -> MutableArrayData a -> Int -> IO () floating TypeHalf = copy floating TypeFloat = copy floating TypeDouble = copy copy :: forall a. Storable a => UniqueArray a -> UniqueArray a -> Int -> IO () copy uu uv s = withUniqueArrayPtr uu $ \u -> withUniqueArrayPtr uv $ \v -> copyBytes u v (n * s * sizeOf (undefined::a)) #if !MIN_VERSION_base(4,10,0) plusForeignPtr :: ForeignPtr a -> Int -> ForeignPtr b plusForeignPtr (ForeignPtr addr# c) (I# i#) = ForeignPtr (plusAddr# addr# i#) c #endif