{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnboxedSums #-} {-# LANGUAGE UnboxedTuples #-} module Numeric.DataFrame.Internal.Array.Family.ArrayBase ( ArrayBase (..) ) where import Data.Int import Data.Word import GHC.Base hiding (foldr) import Numeric.DataFrame.Internal.Array.Class import Numeric.DataFrame.Internal.Array.PrimOps import Numeric.Dimensions import Numeric.PrimBytes -- | Generic Array implementation. -- This array can reside in plain `ByteArray#` and can share the @ByteArray#@ -- with other arrays. -- However, byte offset in the @ByteArray#@ must be multiple of the element size. data ArrayBase (t :: Type) (ds :: [Nat]) = ArrayBase (# t -- Same value for each element; -- this is the cheapest way to initialize an array. -- It is also used for Num instances to avoid dependency on Dimensions. | (# Int# -- Offset measured in elements. , Int# -- Number of elements. , ByteArray# -- Content. #) #) instance (PrimBytes t, Dimensions ds) => PrimBytes (ArrayBase t ds) where {-# SPECIALIZE instance Dimensions ds => PrimBytes (ArrayBase Float ds) #-} {-# SPECIALIZE instance Dimensions ds => PrimBytes (ArrayBase Double ds) #-} {-# SPECIALIZE instance Dimensions ds => PrimBytes (ArrayBase Int ds) #-} {-# SPECIALIZE instance Dimensions ds => PrimBytes (ArrayBase Word ds) #-} {-# SPECIALIZE instance Dimensions ds => PrimBytes (ArrayBase Int8 ds) #-} {-# SPECIALIZE instance Dimensions ds => PrimBytes (ArrayBase Int16 ds) #-} {-# SPECIALIZE instance Dimensions ds => PrimBytes (ArrayBase Int32 ds) #-} {-# SPECIALIZE instance Dimensions ds => PrimBytes (ArrayBase Int64 ds) #-} {-# SPECIALIZE instance Dimensions ds => PrimBytes (ArrayBase Word8 ds) #-} {-# SPECIALIZE instance Dimensions ds => PrimBytes (ArrayBase Word16 ds) #-} {-# SPECIALIZE instance Dimensions ds => PrimBytes (ArrayBase Word32 ds) #-} {-# SPECIALIZE instance Dimensions ds => PrimBytes (ArrayBase Word64 ds) #-} getBytes (ArrayBase a ) = case a of (# t | #) | W# nw <- totalDim' @ds , n <- word2Int# nw , tbs <- byteSize t -> go tbs (tbs *# n) t (# | (# _, _, arr #) #) -> -- very weird trick with touch# allows to workaround GHC bug -- "internal error: ARR_WORDS object entered!" -- TODO: report this case runRW# (\s -> (# touch# arr s, arr #)) of (# _, ba #) -> ba where go tbs bsize t = case runRW# ( \s0 -> case newByteArray# bsize s0 of (# s1, mba #) -> unsafeFreezeByteArray# mba ( loop# 0# tbs bsize (\i -> writeBytes mba i t) s1 ) ) of (# _, ba #) -> ba {-# NOINLINE go #-} {-# INLINE getBytes #-} fromBytes bOff ba | W# nw <- totalDim' @ds , n <- word2Int# nw , tbs <- byteSize (undefined :: t) , (# offN, offRem #) <- quotRemInt# bOff tbs = case offRem of 0# -> ArrayBase (# | (# offN, n , ba #) #) _ -> go n (tbs *# n) where go n bsize = case runRW# ( \s0 -> case ( if isTrue# (isByteArrayPinned# ba) then newAlignedPinnedByteArray# bsize (byteAlign @t undefined) else newByteArray# bsize ) s0 of (# s1, mba #) -> unsafeFreezeByteArray# mba (copyByteArray# ba bOff mba 0# bsize s1) ) of (# _, r #) -> ArrayBase (# | (# 0# , n , r #) #) {-# NOINLINE go #-} {-# INLINE fromBytes #-} readBytes mba bOff s0 | W# nw <- totalDim' @ds , n <- word2Int# nw , tbs <- byteSize (undefined :: t) , bsize <- tbs *# n = case newByteArray# bsize s0 of (# s1, mba1 #) -> case unsafeFreezeByteArray# mba1 (copyMutableByteArray# mba bOff mba1 0# bsize s1) of (# s2, ba #) -> (# s2, ArrayBase (# | (# 0# , n , ba #) #) #) {-# INLINE readBytes #-} writeBytes mba bOff (ArrayBase c) | tbs <- byteSize (undefined :: t) = case c of (# t | #) | W# n <- totalDim' @ds -> loop# bOff tbs (bOff +# word2Int# n *# tbs) (\i -> writeBytes mba i t) (# | (# offN, n, arr #) #) -> copyByteArray# arr (offN *# tbs) mba bOff (n *# tbs) {-# INLINE writeBytes #-} readAddr addr s0 | W# nw <- totalDim' @ds , n <- word2Int# nw , tbs <- byteSize (undefined :: t) , bsize <- tbs *# n = case newByteArray# bsize s0 of (# s1, mba1 #) -> case unsafeFreezeByteArray# mba1 (copyAddrToByteArray# addr mba1 0# bsize s1) of (# s2, ba #) -> (# s2, ArrayBase (# | (# 0# , n , ba #) #) #) {-# INLINE readAddr #-} writeAddr (ArrayBase c) addr | tbs <- byteSize (undefined :: t) = case c of (# t | #) | W# n <- totalDim' @ds -> loop# 0# tbs (word2Int# n *# tbs) (\i -> writeAddr t (plusAddr# addr i)) (# | (# offN, n, arr #) #) -> copyByteArrayToAddr# arr (offN *# tbs) addr (n *# tbs) {-# INLINE writeAddr #-} byteSize _ = case totalDim' @ds of -- WARNING: slow! W# n -> byteSize (undefined :: t) *# word2Int# n {-# INLINE byteSize #-} byteAlign _ = byteAlign (undefined :: t) {-# INLINE byteAlign #-} byteOffset (ArrayBase a) = case a of (# _ | #) -> 0# (# | (# off, _, _ #) #) -> off *# byteSize (undefined :: t) {-# INLINE byteOffset #-} indexArray ba off | W# nw <- totalDim' @ds , n <- word2Int# nw = ArrayBase (# | (# off *# n, n, ba #) #) {-# INLINE indexArray #-} -- | Accumulates only idempotent operations! -- Being applied to FromScalars, executes only once! -- Here, idempotance means: assuming @f a b = g @, @g (g x) = g x@ -- -- Also, I assume the size of arrays is the same accumV2Idempotent :: PrimBytes t => a -> (t -> t -> a -> a) -> ArrayBase t ds -> ArrayBase t ds -> a accumV2Idempotent x f (ArrayBase (# a | #)) (ArrayBase (# b | #)) = f a b x accumV2Idempotent x f a@(ArrayBase (# | (# _, nA, _ #) #)) b@(ArrayBase (# | (# _, nB, _ #) #)) = loop1a# (minInt# nA nB) (\i -> f (ix# i a) (ix# i b)) x accumV2Idempotent x f (ArrayBase (# a | #)) b@(ArrayBase (# | (# _, n, _ #) #)) = loop1a# n (\i -> f a (ix# i b)) x accumV2Idempotent x f a@(ArrayBase (# | (# _, n, _ #) #)) (ArrayBase (# b | #)) = loop1a# n (\i -> f (ix# i a) b) x {-# INLINE accumV2Idempotent #-} mapV :: PrimBytes t => (t -> t) -> ArrayBase t ds -> ArrayBase t ds mapV f (ArrayBase (# t | #)) = ArrayBase (# f t | #) mapV f x@(ArrayBase (# | (# offN, n, ba #) #)) | tbs <- byteSize (undefEl x) = go (tbs *# n) where go bsize = case runRW# ( \s0 -> case newByteArray# bsize s0 of (# s1, mba #) -> unsafeFreezeByteArray# mba ( loop1# n (\i -> writeArray mba i (f (indexArray ba (offN +# i)))) s1 ) ) of (# _, r #) -> ArrayBase (# | (# 0#, n, r #) #) {-# NOINLINE go #-} {-# INLINE mapV #-} zipV :: PrimBytes t => (t -> t -> t) -> ArrayBase t ds -> ArrayBase t ds -> ArrayBase t ds zipV f (ArrayBase (# x | #)) b = mapV (f x) b zipV f a (ArrayBase (# y | #)) = mapV (flip f y) a zipV f a@(ArrayBase (# | (# oa, na, ba #) #)) (ArrayBase (# | (# ob, nb, bb #) #)) | n <- (minInt# na nb) = go n (byteSize (undefEl a) *# n) where go n bsize = case runRW# ( \s0 -> case newByteArray# bsize s0 of (# s1, mba #) -> unsafeFreezeByteArray# mba ( loop1# n (\i -> writeArray mba i (f (indexArray ba (oa +# i)) (indexArray bb (ob +# i)) ) ) s1 ) ) of (# _, r #) -> ArrayBase (# | (# 0#, n, r #) #) {-# NOINLINE go #-} {-# INLINE zipV #-} -- TODO: to improve performance, I can either compare bytearrays using memcmp -- or implement early termination if the first elements do not match. -- On the other hand, hopefully @(&&)@ and @(||)@ ops take care of that. instance (Eq t, PrimBytes t) => Eq (ArrayBase t ds) where {-# SPECIALIZE instance Eq (ArrayBase Float ds) #-} {-# SPECIALIZE instance Eq (ArrayBase Double ds) #-} {-# SPECIALIZE instance Eq (ArrayBase Int ds) #-} {-# SPECIALIZE instance Eq (ArrayBase Word ds) #-} {-# SPECIALIZE instance Eq (ArrayBase Int8 ds) #-} {-# SPECIALIZE instance Eq (ArrayBase Int16 ds) #-} {-# SPECIALIZE instance Eq (ArrayBase Int32 ds) #-} {-# SPECIALIZE instance Eq (ArrayBase Int64 ds) #-} {-# SPECIALIZE instance Eq (ArrayBase Word8 ds) #-} {-# SPECIALIZE instance Eq (ArrayBase Word16 ds) #-} {-# SPECIALIZE instance Eq (ArrayBase Word32 ds) #-} {-# SPECIALIZE instance Eq (ArrayBase Word64 ds) #-} (==) = accumV2Idempotent True (\x y r -> r && x == y) (/=) = accumV2Idempotent False (\x y r -> r || x /= y) -- | Implement partial ordering for `>`, `<`, `>=`, `<=` -- and lexicographical ordering for `compare` instance (Ord t, PrimBytes t) => Ord (ArrayBase t ds) where {-# SPECIALIZE instance Ord (ArrayBase Float ds) #-} {-# SPECIALIZE instance Ord (ArrayBase Double ds) #-} {-# SPECIALIZE instance Ord (ArrayBase Int ds) #-} {-# SPECIALIZE instance Ord (ArrayBase Word ds) #-} {-# SPECIALIZE instance Ord (ArrayBase Int8 ds) #-} {-# SPECIALIZE instance Ord (ArrayBase Int16 ds) #-} {-# SPECIALIZE instance Ord (ArrayBase Int32 ds) #-} {-# SPECIALIZE instance Ord (ArrayBase Int64 ds) #-} {-# SPECIALIZE instance Ord (ArrayBase Word8 ds) #-} {-# SPECIALIZE instance Ord (ArrayBase Word16 ds) #-} {-# SPECIALIZE instance Ord (ArrayBase Word32 ds) #-} {-# SPECIALIZE instance Ord (ArrayBase Word64 ds) #-} -- | Partiall ordering: all elements GT (>) = accumV2Idempotent True (\x y r -> r && x > y) {-# INLINE (>) #-} -- | Partiall ordering: all elements LT (<) = accumV2Idempotent True (\x y r -> r && x < y) {-# INLINE (<) #-} -- | Partiall ordering: all elements GE (>=) = accumV2Idempotent True (\x y r -> r && x >= y) {-# INLINE (>=) #-} -- | Partiall ordering: all elements LE (<=) = accumV2Idempotent True (\x y r -> r && x <= y) {-# INLINE (<=) #-} -- | Compare lexicographically compare = accumV2Idempotent EQ (\x y -> flip mappend (compare x y)) {-# INLINE compare #-} -- | Element-wise minimum min = zipV min {-# INLINE min #-} -- | Element-wise maximum max = zipV max {-# INLINE max #-} instance (Dimensions ds, PrimBytes t, Show t) => Show (ArrayBase t ds) where show x = case dims @_ @ds of U -> "{ " ++ show (ix# 0# x) ++ " }" Dim :* U -> ('{' :) . drop 1 $ foldr (\i s -> ", " ++ show (ix i x) ++ s) " }" [minBound .. maxBound] (Dim :: Dim n) :* (Dim :: Dim m) :* (Dims :: Dims dss) -> let loopInner :: Idxs dss -> Idxs '[n,m] -> String loopInner ods (n:*m:*_) = ('{' :) . drop 2 $ foldr (\i ss -> '\n': foldr (\j s -> ", " ++ show (ix (i :* j :* ods) x) ++ s ) ss [1..m] ) " }" [1..n] loopOuter :: Idxs dss -> String -> String loopOuter U s = "\n" ++ loopInner U maxBound ++ s loopOuter ds s = "\n(i j" ++ drop 4 (show ds) ++ "):\n" ++ loopInner ds maxBound ++ s in drop 1 $ foldr loopOuter "" [minBound..maxBound] instance {-# OVERLAPPING #-} Bounded (ArrayBase Double ds) where maxBound = ArrayBase (# inftyD | #) minBound = ArrayBase (# negate inftyD | #) instance {-# OVERLAPPING #-} Bounded (ArrayBase Float ds) where maxBound = ArrayBase (# inftyF | #) minBound = ArrayBase (# negate inftyF | #) instance {-# OVERLAPPABLE #-} Bounded t => Bounded (ArrayBase t ds) where {-# SPECIALIZE instance Bounded (ArrayBase Int ds) #-} {-# SPECIALIZE instance Bounded (ArrayBase Word ds) #-} {-# SPECIALIZE instance Bounded (ArrayBase Int8 ds) #-} {-# SPECIALIZE instance Bounded (ArrayBase Int16 ds) #-} {-# SPECIALIZE instance Bounded (ArrayBase Int32 ds) #-} {-# SPECIALIZE instance Bounded (ArrayBase Int64 ds) #-} {-# SPECIALIZE instance Bounded (ArrayBase Word8 ds) #-} {-# SPECIALIZE instance Bounded (ArrayBase Word16 ds) #-} {-# SPECIALIZE instance Bounded (ArrayBase Word32 ds) #-} {-# SPECIALIZE instance Bounded (ArrayBase Word64 ds) #-} maxBound = ArrayBase (# maxBound | #) minBound = ArrayBase (# minBound | #) instance (Num t, PrimBytes t) => Num (ArrayBase t ds) where {-# SPECIALIZE instance Num (ArrayBase Float ds) #-} {-# SPECIALIZE instance Num (ArrayBase Double ds) #-} {-# SPECIALIZE instance Num (ArrayBase Int ds) #-} {-# SPECIALIZE instance Num (ArrayBase Word ds) #-} {-# SPECIALIZE instance Num (ArrayBase Int8 ds) #-} {-# SPECIALIZE instance Num (ArrayBase Int16 ds) #-} {-# SPECIALIZE instance Num (ArrayBase Int32 ds) #-} {-# SPECIALIZE instance Num (ArrayBase Int64 ds) #-} {-# SPECIALIZE instance Num (ArrayBase Word8 ds) #-} {-# SPECIALIZE instance Num (ArrayBase Word16 ds) #-} {-# SPECIALIZE instance Num (ArrayBase Word32 ds) #-} {-# SPECIALIZE instance Num (ArrayBase Word64 ds) #-} (+) = zipV (+) {-# INLINE (+) #-} (-) = zipV (-) {-# INLINE (-) #-} (*) = zipV (*) {-# INLINE (*) #-} negate = mapV negate {-# INLINE negate #-} abs = mapV abs {-# INLINE abs #-} signum = mapV signum {-# INLINE signum #-} fromInteger i = ArrayBase (# fromInteger i | #) {-# INLINE fromInteger #-} instance (Fractional t, PrimBytes t) => Fractional (ArrayBase t ds) where {-# SPECIALIZE instance Fractional (ArrayBase Float ds) #-} {-# SPECIALIZE instance Fractional (ArrayBase Double ds) #-} (/) = zipV (/) {-# INLINE (/) #-} recip = mapV recip {-# INLINE recip #-} fromRational r = ArrayBase (# fromRational r | #) {-# INLINE fromRational #-} instance (Floating t, PrimBytes t) => Floating (ArrayBase t ds) where {-# SPECIALIZE instance Floating (ArrayBase Float ds) #-} {-# SPECIALIZE instance Floating (ArrayBase Double ds) #-} pi = ArrayBase (# pi | #) {-# INLINE pi #-} exp = mapV exp {-# INLINE exp #-} log = mapV log {-# INLINE log #-} sqrt = mapV sqrt {-# INLINE sqrt #-} sin = mapV sin {-# INLINE sin #-} cos = mapV cos {-# INLINE cos #-} tan = mapV tan {-# INLINE tan #-} asin = mapV asin {-# INLINE asin #-} acos = mapV acos {-# INLINE acos #-} atan = mapV atan {-# INLINE atan #-} sinh = mapV sinh {-# INLINE sinh #-} cosh = mapV cosh {-# INLINE cosh #-} tanh = mapV tanh {-# INLINE tanh #-} (**) = zipV (**) {-# INLINE (**) #-} logBase = zipV logBase {-# INLINE logBase #-} asinh = mapV asinh {-# INLINE asinh #-} acosh = mapV acosh {-# INLINE acosh #-} atanh = mapV atanh {-# INLINE atanh #-} instance PrimBytes t => PrimArray t (ArrayBase t ds) where {-# SPECIALIZE instance PrimArray Float (ArrayBase Float ds) #-} {-# SPECIALIZE instance PrimArray Double (ArrayBase Double ds) #-} {-# SPECIALIZE instance PrimArray Int (ArrayBase Int ds) #-} {-# SPECIALIZE instance PrimArray Word (ArrayBase Word ds) #-} {-# SPECIALIZE instance PrimArray Int8 (ArrayBase Int8 ds) #-} {-# SPECIALIZE instance PrimArray Int16 (ArrayBase Int16 ds) #-} {-# SPECIALIZE instance PrimArray Int32 (ArrayBase Int32 ds) #-} {-# SPECIALIZE instance PrimArray Int64 (ArrayBase Int64 ds) #-} {-# SPECIALIZE instance PrimArray Word8 (ArrayBase Word8 ds) #-} {-# SPECIALIZE instance PrimArray Word16 (ArrayBase Word16 ds) #-} {-# SPECIALIZE instance PrimArray Word32 (ArrayBase Word32 ds) #-} {-# SPECIALIZE instance PrimArray Word64 (ArrayBase Word64 ds) #-} broadcast t = ArrayBase (# t | #) {-# INLINE broadcast #-} ix# i (ArrayBase a) = case a of (# t | #) -> t (# | (# off, _, arr #) #) -> indexArray arr (off +# i) {-# INLINE ix# #-} gen# n f z0 = go (byteSize @t undefined *# n) where go bsize = case runRW# ( \s0 -> case newByteArray# bsize s0 of (# s1, mba #) -> case loop0 mba 0# z0 s1 of (# s2, z1 #) -> case unsafeFreezeByteArray# mba s2 of (# s3, ba #) -> (# s3, (# z1, ba #) #) ) of (# _, (# z1, ba #) #) -> (# z1, ArrayBase (# | (# 0# , n , ba #) #) #) {-# NOINLINE go #-} loop0 mba i z s | isTrue# (i ==# n) = (# s, z #) | otherwise = case f z of (# z', x #) -> loop0 mba (i +# 1#) z' (writeArray mba i x s) {-# INLINE gen# #-} upd# n i x (ArrayBase (# a | #)) = go (byteSize x) where go tbs = case runRW# ( \s0 -> case newByteArray# (tbs *# n) s0 of (# s1, mba #) -> unsafeFreezeByteArray# mba (writeArray mba i x (loop1# n (\j -> writeArray mba j a) s1) ) ) of (# _, r #) -> ArrayBase (# | (# 0# , n , r #) #) {-# NOINLINE go #-} upd# _ i x (ArrayBase (# | (# offN , n , ba #) #)) = go (byteSize x) where go tbs = case runRW# ( \s0 -> case newByteArray# (tbs *# n) s0 of (# s1, mba #) -> unsafeFreezeByteArray# mba (writeArray mba i x (copyByteArray# ba (offN *# tbs) mba 0# (tbs *# n) s1) ) ) of (# _, r #) -> ArrayBase (# | (# 0# , n , r #) #) {-# NOINLINE go #-} {-# INLINE upd# #-} elemOffset (ArrayBase a) = case a of (# _ | #) -> 0# (# | (# off, _, _ #) #) -> off {-# INLINE elemOffset #-} elemSize0 (ArrayBase a) = case a of (# _ | #) -> 0# (# | (# _, n, _ #) #) -> n {-# INLINE elemSize0 #-} fromElems off n ba = ArrayBase (# | (# off , n , ba #) #) {-# INLINE fromElems #-} -------------------------------------------------------------------------------- -- * Utility functions -------------------------------------------------------------------------------- ix :: (PrimBytes t, Dimensions ds) => Idxs ds -> ArrayBase t ds -> t ix i (ArrayBase a) = case a of (# t | #) -> t (# | (# off, _, arr #) #) -> case fromEnum i of I# i# -> indexArray arr (off +# i#) {-# INLINE ix #-} undefEl :: ArrayBase t ds -> t undefEl = const undefined