{-# LANGUAGE BangPatterns #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} -- | -- Module : Data.Massiv.Array.Numeric -- Copyright : (c) Alexey Kuleshevich 2018-2019 -- License : BSD3 -- Maintainer : Alexey Kuleshevich -- Stability : experimental -- Portability : non-portable -- module Data.Massiv.Array.Numeric ( -- * Num (.+.) , (.+) , (+.) , (.-.) , (.-) , (-.) , (.*.) , (.*) , (*.) , (.^) , (#>) , (|*|) , multiplyTransposed , identityMatrix , lowerTriangular , upperTriangular , negateA , absA , signumA , fromIntegerA -- * Integral , quotA , remA , divA , modA , quotRemA , divModA -- * Fractional , (./.) , (./) , (.^^) , recipA , fromRationalA -- * Floating , piA , expA , logA , sqrtA , (.**) , logBaseA , sinA , cosA , tanA , asinA , acosA , atanA , sinhA , coshA , tanhA , asinhA , acoshA , atanhA -- * RealFrac , truncateA , roundA , ceilingA , floorA -- * RealFloat , atan2A ) where import Data.Massiv.Array.Delayed.Pull import Data.Massiv.Array.Delayed.Push import Data.Massiv.Array.Manifest.Internal import Data.Massiv.Array.Ops.Fold as A import Data.Massiv.Array.Ops.Map as A import Data.Massiv.Array.Ops.Transform as A import Data.Massiv.Array.Ops.Construct import Data.Massiv.Core import Data.Massiv.Core.Common import Data.Massiv.Core.Operations import Prelude as P infixr 8 .^, .^^ infixl 7 .*., .*, *., ./., ./, `quotA`, `remA`, `divA`, `modA` infixl 6 .+., .+, +., .-., .-, -. 
-- | Combine two arrays elementwise with a binary function, producing a delayed array.
-- The computation strategy of the result is the combination of both inputs' strategies.
--
-- Throws 'SizeMismatchException' as a pure exception (via 'throw') when the two arrays
-- are not of the same size.
liftArray2Matching
  :: (Source r1 ix a, Source r2 ix b)
  => (a -> b -> e) -> Array r1 ix a -> Array r2 ix b -> Array D ix e
liftArray2Matching f !arr1 !arr2
  | sz1 == sz2 =
    makeArray
      (getComp arr1 <> getComp arr2)
      sz1
      (\ !ix -> f (unsafeIndex arr1 ix) (unsafeIndex arr2 ix))
  | otherwise = throw $ SizeMismatchException (size arr1) (size arr2)
  where
    sz1 = size arr1
    sz2 = size arr2
{-# INLINE liftArray2Matching #-}

-- | Combine two arrays of the same representation elementwise using 'unsafeLiftArray2'.
-- Sizes are checked first and a mismatch is reported with 'throwM'.
liftArray2M
  :: (Load r ix e, Numeric r e, MonadThrow m)
  => (e -> e -> e) -> Array r ix e -> Array r ix e -> m (Array r ix e)
liftArray2M f a1 a2
  | size a1 == size a2 = pure $ unsafeLiftArray2 f a1 a2
  | otherwise = throwM $ SizeMismatchException (size a1) (size a2)
{-# INLINE liftArray2M #-}

-- | Guard a whole-array binary operation (e.g. 'additionPointwise') with a size check.
-- A mismatch is reported with 'throwM'.
liftNumericArray2M
  :: (Load r ix e, MonadThrow m)
  => (Array r ix e -> Array r ix e -> Array r ix e)
  -> Array r ix e
  -> Array r ix e
  -> m (Array r ix e)
liftNumericArray2M f a1 a2
  | size a1 == size a2 = pure $ f a1 a2
  | otherwise = throwM $ SizeMismatchException (size a1) (size a2)
{-# INLINE liftNumericArray2M #-}

-- | Add two arrays together pointwise. Throws `SizeMismatchException` if array sizes do
-- not match.
--
-- @since 0.4.0
(.+.) ::
     (Load r ix e, Numeric r e, MonadThrow m) => Array r ix e -> Array r ix e -> m (Array r ix e)
(.+.) = liftNumericArray2M additionPointwise
{-# INLINE (.+.) #-}

-- | Add a scalar to each element of the array. Array is on the left.
--
-- @since 0.1.0
(.+) :: (Index ix, Numeric r e) => Array r ix e -> e -> Array r ix e
(.+) = plusScalar
{-# INLINE (.+) #-}

-- | Add a scalar to each element of the array. Array is on the right.
--
-- @since 0.4.0
(+.) :: (Index ix, Numeric r e) => e -> Array r ix e -> Array r ix e
(+.) = flip plusScalar
{-# INLINE (+.) #-}

-- | Subtract two arrays pointwise. Throws `SizeMismatchException` if array sizes do not
-- match.
--
-- @since 0.4.0
(.-.) ::
     (Load r ix e, Numeric r e, MonadThrow m) => Array r ix e -> Array r ix e -> m (Array r ix e)
(.-.) = liftNumericArray2M subtractionPointwise
{-# INLINE (.-.) #-}

-- | Subtract a scalar from each element of the array. Array is on the left, i.e.
-- @arr .- x@ computes @e - x@ for every element @e@ of @arr@.
--
-- @since 0.1.0
(.-) :: (Index ix, Numeric r e) => Array r ix e -> e -> Array r ix e
(.-) = minusScalar
{-# INLINE (.-) #-}

-- | Subtract each element of the array from a scalar. Array is on the right, i.e.
-- @x -. arr@ computes @x - e@ for every element @e@ of @arr@.
--
-- Subtraction is not commutative, so this cannot be @flip minusScalar@ (that would make
-- @x -. arr@ identical to @arr .- x@); the scalar has to be the minuend.
--
-- @since 0.4.0
(-.) :: (Index ix, Numeric r e) => e -> Array r ix e -> Array r ix e
x -. arr = unsafeLiftArray (x -) arr
{-# INLINE (-.) #-}

-- | Multiply two arrays together pointwise. Throws `SizeMismatchException` if array sizes
-- do not match.
--
-- @since 0.4.0
(.*.) ::
     (Load r ix e, Numeric r e, MonadThrow m) => Array r ix e -> Array r ix e -> m (Array r ix e)
(.*.) = liftNumericArray2M multiplicationPointwise
{-# INLINE (.*.) #-}

-- | Multiply each element of the array by a scalar. Array is on the left.
(.*) :: (Index ix, Numeric r e) => Array r ix e -> e -> Array r ix e
(.*) = multiplyScalar
{-# INLINE (.*) #-}

-- | Multiply each element of the array by a scalar. Array is on the right.
(*.) :: (Index ix, Numeric r e) => e -> Array r ix e -> Array r ix e
(*.) = flip multiplyScalar
{-# INLINE (*.) #-}

-- | Raise each element of the array to an 'Int' power (delegates to 'powerPointwise').
(.^) :: (Index ix, Numeric r e) => Array r ix e -> Int -> Array r ix e
(.^) = powerPointwise
{-# INLINE (.^) #-}

-- | Matrix multiplication
--
-- Inner dimensions must agree, otherwise `SizeMismatchException` is thrown with 'throwM'.
-- The exception carries the sizes of both operands exactly as they were passed in.
(|*|) ::
     (Mutable r Ix2 e, Source r' Ix2 e, OuterSlice r Ix2 e, Source (R r) Ix1 e, Num e, MonadThrow m)
  => Array r Ix2 e
  -> Array r' Ix2 e
  -> m (Array r Ix2 e)
(|*|) a1 a2 = compute <$> multArrs a1 a2
{-# INLINE [1] (|*|) #-}

{-# RULES
"multDoubleTranspose" [~1] forall arr1 arr2 .
  arr1 |*| transpose arr2 = multiplyTransposedFused arr1 (convert arr2)
  #-}

-- | Matrix-vector product
--
-- Inner dimensions must agree, otherwise `SizeMismatchException`
--
-- @since 0.5.2
(#>) ::
     (MonadThrow m, Num e, Source (R r) Ix1 e, Manifest r' Ix1 e, OuterSlice r Ix2 e)
  => Array r Ix2 e -- ^ Matrix
  -> Array r' Ix1 e -- ^ Vector
  -> m (Array D Ix1 e)
mm #> v
  | mCols /= n = throwM $ SizeMismatchException (size mm) (Sz2 n 1)
  | otherwise =
    pure $
    makeArray (getComp mm <> getComp v) (Sz1 mRows) $ \i ->
      A.foldlS (+) 0 (A.zipWith (*) (unsafeOuterSlice mm i) v)
  where
    Sz2 mRows mCols = size mm
    Sz1 n = size v
{-# INLINE (#>) #-}

-- | Target of the @multDoubleTranspose@ rewrite rule: computes @arr1 |*| transpose arr2@
-- without materializing the transpose.
--
-- The size check is done here, rather than relying on 'multiplyTransposed', so that a
-- mismatch reports the same operands as the unfused path. Those operands are @arr1@ and
-- @transpose arr2@, whose size is the swapped size of @arr2@.
multiplyTransposedFused ::
     ( Mutable r Ix2 e
     , OuterSlice r Ix2 e
     , Source (R r) Ix1 e
     , Num e
     , MonadThrow m
     )
  => Array r Ix2 e
  -> Array r Ix2 e
  -> m (Array r Ix2 e)
multiplyTransposedFused arr1 arr2
  | n1 /= n2 = throwM $ SizeMismatchException (size arr1) (Sz2 n2 m2)
  | otherwise = compute <$> multiplyTransposed arr1 arr2
  where
    Sz2 _ n1 = size arr1
    Sz2 m2 n2 = size arr2
{-# INLINE multiplyTransposedFused #-}

-- | Unfused matrix product. It first checks that the number of columns of @arr1@ equals
-- the number of rows of @arr2@. It then computes the transpose of @arr2@ into
-- representation @r@ and delegates to 'multiplyTransposed'.
--
-- Validating before transposing avoids building a transposed copy that would be thrown
-- away. It also keeps the exception payload in terms of the caller's own operands.
multArrs ::
     forall r r' e m.
     ( Mutable r Ix2 e
     , Source r' Ix2 e
     , OuterSlice r Ix2 e
     , Source (R r) Ix1 e
     , Num e
     , MonadThrow m
     )
  => Array r Ix2 e
  -> Array r' Ix2 e
  -> m (Array D Ix2 e)
multArrs arr1 arr2
  | n1 /= m2 = throwM $ SizeMismatchException (size arr1) (size arr2)
  | otherwise = multiplyTransposed arr1 arr2'
  where
    Sz2 _ n1 = size arr1
    Sz2 m2 _ = size arr2
    arr2' :: Array r Ix2 e
    arr2' = compute $ transpose arr2
{-# INLINE multArrs #-}

-- | Multiply the first matrix by the transpose of the second, i.e. compute
-- @arr1 * transpose arr2@. Passing the same matrix twice gives @A * A'@.
--
-- Element @(i, j)@ of the delayed result is the dot product of row @i@ of @arr1@ and
-- row @j@ of @arr2@. Both matrices must therefore have the same number of columns,
-- otherwise `SizeMismatchException` is thrown with 'throwM'.
multiplyTransposed ::
     ( Manifest r Ix2 e
     , OuterSlice r Ix2 e
     , Source (R r) Ix1 e
     , Num e
     , MonadThrow m
     )
  => Array r Ix2 e
  -> Array r Ix2 e
  -> m (Array D Ix2 e)
multiplyTransposed arr1 arr2
  | n1 /= m2 = throwM $ SizeMismatchException (size arr1) (size arr2)
  | otherwise =
    pure $
    DArray (getComp arr1 <> getComp arr2) (SafeSz (m1 :. n2)) $ \(i :. j) ->
      A.foldlS (+) 0 (A.zipWith (*) (unsafeOuterSlice arr1 i) (unsafeOuterSlice arr2 j))
  where
    SafeSz (m1 :. n1) = size arr1
    SafeSz (n2 :. m2) = size arr2
{-# INLINE multiplyTransposed #-}

-- | Create an identity matrix.
--
-- ==== __Example__
--
-- >>> import Data.Massiv.Array
-- >>> identityMatrix 5
-- Array DL Seq (Sz (5 :. 5))
--   [ [ 1, 0, 0, 0, 0 ]
--   , [ 0, 1, 0, 0, 0 ]
--   , [ 0, 0, 1, 0, 0 ]
--   , [ 0, 0, 0, 1, 0 ]
--   , [ 0, 0, 0, 0, 1 ]
--   ]
--
-- @since 0.3.6
identityMatrix :: Num e => Sz1 -> Matrix DL e
identityMatrix (Sz n) =
  makeLoadArrayS (Sz2 n n) 0 $ \ w -> loopM_ 0 (< n) (+1) $ \ i -> w (i :. i) 1
{-# INLINE identityMatrix #-}

-- | Create a lower triangular (L in LU decomposition) matrix of size @NxN@
--
-- ==== __Example__
--
-- >>> import Data.Massiv.Array
-- >>> lowerTriangular Seq 5 (\(i :. j) -> i + j)
-- Array DL Seq (Sz (5 :. 5))
--   [ [ 0, 0, 0, 0, 0 ]
--   , [ 1, 2, 0, 0, 0 ]
--   , [ 2, 3, 4, 0, 0 ]
--   , [ 3, 4, 5, 6, 0 ]
--   , [ 4, 5, 6, 7, 8 ]
--   ]
--
-- @since 0.5.2
lowerTriangular :: Num e => Comp -> Sz1 -> (Ix2 -> e) -> Matrix DL e
lowerTriangular comp (Sz1 n) f =
  let sz = Sz2 n n
   in unsafeMakeLoadArrayAdjusted comp sz (Just 0) $ \scheduler wr ->
        forM_ (0 ..: n) $ \i ->
          scheduleWork scheduler $
            -- columns 0..i inclusive: the diagonal and everything left of it
            forM_ (0 ... i) $ \j ->
              let ix = i :. j
               in wr (toLinearIndex sz ix) (f ix)
{-# INLINE lowerTriangular #-}

-- | Create an upper triangular (U in LU decomposition) matrix of size @NxN@
--
-- ==== __Example__
--
-- >>> import Data.Massiv.Array
-- >>> upperTriangular Par 5 (\(i :. j) -> i + j)
-- Array DL Par (Sz (5 :. 5))
--   [ [ 0, 1, 2, 3, 4 ]
--   , [ 0, 2, 3, 4, 5 ]
--   , [ 0, 0, 4, 5, 6 ]
--   , [ 0, 0, 0, 6, 7 ]
--   , [ 0, 0, 0, 0, 8 ]
--   ]
--
-- @since 0.5.2
upperTriangular :: Num e => Comp -> Sz1 -> (Ix2 -> e) -> Matrix DL e
upperTriangular comp (Sz1 n) f =
  let sz = Sz2 n n
   in unsafeMakeLoadArrayAdjusted comp sz (Just 0) $ \scheduler wr ->
        forM_ (0 ..: n) $ \i ->
          scheduleWork scheduler $
            -- columns i..n-1: the diagonal and everything right of it
            forM_ (i ..: n) $ \j ->
              let ix = i :. j
               in wr (toLinearIndex sz ix) (f ix)
{-# INLINE upperTriangular #-}

-- | Pointwise 'negate'.
negateA :: (Index ix, Numeric r e) => Array r ix e -> Array r ix e
negateA = unsafeLiftArray negate
{-# INLINE negateA #-}

-- | Pointwise absolute value (delegates to 'absPointwise').
absA :: (Index ix, Numeric r e) => Array r ix e -> Array r ix e
absA = absPointwise
{-# INLINE absA #-}

-- | Pointwise 'signum'.
signumA :: (Index ix, Numeric r e) => Array r ix e -> Array r ix e
signumA = unsafeLiftArray signum
{-# INLINE signumA #-}

-- | A 'singleton' array holding the converted 'Integer'.
fromIntegerA :: (Index ix, Num e) => Integer -> Array D ix e
fromIntegerA = singleton . fromInteger
{-# INLINE fromIntegerA #-}

-- | Divide two arrays pointwise. Throws `SizeMismatchException` if array sizes do not
-- match.
(./.) ::
     (Load r ix e, NumericFloat r e, MonadThrow m)
  => Array r ix e
  -> Array r ix e
  -> m (Array r ix e)
(./.) = liftNumericArray2M divisionPointwise
{-# INLINE (./.) #-}

-- | Divide each element of the array by a scalar. Array is on the left.
(./) :: (Index ix, NumericFloat r e) => Array r ix e -> e -> Array r ix e
(./) = divideScalar
{-# INLINE (./) #-}

-- | Raise each element of the array to an integral power using '(^^)'.
(.^^) :: (Index ix, Numeric r e, Fractional e, Integral b) => Array r ix e -> b -> Array r ix e
(.^^) arr n = unsafeLiftArray (^^ n) arr
{-# INLINE (.^^) #-}

-- | Pointwise reciprocal (delegates to 'recipPointwise').
recipA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
recipA = recipPointwise
{-# INLINE recipA #-}

-- | A 'singleton' array holding the converted 'Rational'.
fromRationalA :: (Index ix, Fractional e) => Rational -> Array D ix e
fromRationalA = singleton . fromRational
{-# INLINE fromRationalA #-}

-- | A 'singleton' array holding 'pi'.
piA :: (Index ix, Floating e) => Array D ix e
piA = singleton pi
{-# INLINE piA #-}

-- | Pointwise 'exp'.
expA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
expA = unsafeLiftArray exp
{-# INLINE expA #-}

-- | Pointwise 'sqrt'.
sqrtA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
sqrtA = unsafeLiftArray sqrt
{-# INLINE sqrtA #-}

-- | Pointwise natural logarithm ('log').
logA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
logA = unsafeLiftArray log
{-# INLINE logA #-}

-- | Pointwise 'logBase', with the first array supplying the bases. Throws
-- `SizeMismatchException` (via 'throw') if array sizes do not match.
logBaseA :: (Source r1 ix e, Source r2 ix e, Floating e) => Array r1 ix e -> Array r2 ix e -> Array D ix e
logBaseA = liftArray2Matching logBase
{-# INLINE logBaseA #-}

-- | Raise each element of the first array to the power of the corresponding element of
-- the second one ('(**)'). Throws `SizeMismatchException` (via 'throw') if array sizes do
-- not match.
(.**) :: (Source r1 ix e, Source r2 ix e, Floating e) => Array r1 ix e -> Array r2 ix e -> Array D ix e
(.**) = liftArray2Matching (**)
{-# INLINE (.**) #-}

-- | Pointwise 'sin'.
sinA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
sinA = unsafeLiftArray sin
{-# INLINE sinA #-}

-- | Pointwise 'cos'.
cosA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
cosA = unsafeLiftArray cos
{-# INLINE cosA #-}

-- | Pointwise 'tan'.
tanA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
tanA = unsafeLiftArray tan
{-# INLINE tanA #-}

-- | Pointwise 'asin'.
asinA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
asinA = unsafeLiftArray asin
{-# INLINE asinA #-}

-- | Pointwise 'atan'.
atanA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
atanA = unsafeLiftArray atan
{-# INLINE atanA #-}

-- | Pointwise 'acos'.
acosA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
acosA = unsafeLiftArray acos
{-# INLINE acosA #-}

-- | Pointwise 'sinh'.
sinhA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
sinhA = unsafeLiftArray sinh
{-# INLINE sinhA #-}

-- | Pointwise 'tanh'.
tanhA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
tanhA = unsafeLiftArray tanh
{-# INLINE tanhA #-}

-- | Pointwise 'cosh'.
coshA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
coshA = unsafeLiftArray cosh
{-# INLINE coshA #-}

-- | Pointwise 'asinh'.
asinhA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
asinhA = unsafeLiftArray asinh
{-# INLINE asinhA #-}
-- | Inverse hyperbolic cosine ('acosh') applied to every element.
acoshA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
acoshA arr = unsafeLiftArray acosh arr
{-# INLINE acoshA #-}

-- | Inverse hyperbolic tangent ('atanh') applied to every element.
atanhA :: (Index ix, NumericFloat r e) => Array r ix e -> Array r ix e
atanhA arr = unsafeLiftArray atanh arr
{-# INLINE atanhA #-}

-- | Elementwise 'quot' of two equally sized arrays. A size mismatch raises
-- `SizeMismatchException` as a pure exception ('throw').
quotA ::
     (Source r1 ix e, Source r2 ix e, Integral e)
  => Array r1 ix e
  -> Array r2 ix e
  -> Array D ix e
quotA arr1 arr2 = liftArray2Matching quot arr1 arr2
{-# INLINE quotA #-}

-- | Elementwise 'rem' of two equally sized arrays. A size mismatch raises
-- `SizeMismatchException` as a pure exception ('throw').
remA ::
     (Source r1 ix e, Source r2 ix e, Integral e)
  => Array r1 ix e
  -> Array r2 ix e
  -> Array D ix e
remA arr1 arr2 = liftArray2Matching rem arr1 arr2
{-# INLINE remA #-}

-- | Elementwise 'div' of two equally sized arrays. A size mismatch raises
-- `SizeMismatchException` as a pure exception ('throw').
divA ::
     (Source r1 ix e, Source r2 ix e, Integral e)
  => Array r1 ix e
  -> Array r2 ix e
  -> Array D ix e
divA arr1 arr2 = liftArray2Matching div arr1 arr2
{-# INLINE divA #-}

-- | Elementwise 'mod' of two equally sized arrays. A size mismatch raises
-- `SizeMismatchException` as a pure exception ('throw').
modA ::
     (Source r1 ix e, Source r2 ix e, Integral e)
  => Array r1 ix e
  -> Array r2 ix e
  -> Array D ix e
modA arr1 arr2 = liftArray2Matching mod arr1 arr2
{-# INLINE modA #-}

-- | Elementwise 'quotRem'. The quotients and remainders are returned as two separate
-- delayed arrays.
quotRemA ::
     (Source r1 ix e, Source r2 ix e, Integral e)
  => Array r1 ix e
  -> Array r2 ix e
  -> (Array D ix e, Array D ix e)
quotRemA arr1 arr2 = A.unzip (liftArray2Matching quotRem arr1 arr2)
{-# INLINE quotRemA #-}

-- | Elementwise 'divMod'. The quotients and moduli are returned as two separate
-- delayed arrays.
divModA ::
     (Source r1 ix e, Source r2 ix e, Integral e)
  => Array r1 ix e
  -> Array r2 ix e
  -> (Array D ix e, Array D ix e)
divModA arr1 arr2 = A.unzip (liftArray2Matching divMod arr1 arr2)
{-# INLINE divModA #-}

-- | Convert every element with 'truncate'.
truncateA :: (Index ix, Numeric r e, RealFrac a, Integral e) => Array r ix a -> Array r ix e
truncateA arr = unsafeLiftArray truncate arr
{-# INLINE truncateA #-}

-- | Convert every element with 'round'.
roundA :: (Index ix, Numeric r e, RealFrac a, Integral e) => Array r ix a -> Array r ix e
roundA arr = unsafeLiftArray round arr
{-# INLINE roundA #-}

-- | Convert every element with 'ceiling'.
ceilingA :: (Index ix, Numeric r e, RealFrac a, Integral e) => Array r ix a -> Array r ix e
ceilingA arr = unsafeLiftArray ceiling arr
{-# INLINE ceilingA #-}

-- | Convert every element with 'floor'.
floorA :: (Index ix, Numeric r e, RealFrac a, Integral e) => Array r ix a -> Array r ix e
floorA arr = unsafeLiftArray floor arr
{-# INLINE floorA #-}

-- | Elementwise 'atan2' of two arrays. Differing sizes are reported through 'throwM'
-- as a `SizeMismatchException`.
atan2A ::
     (Load r ix e, Numeric r e, RealFloat e, MonadThrow m)
  => Array r ix e
  -> Array r ix e
  -> m (Array r ix e)
atan2A arr1 arr2 = liftArray2M atan2 arr1 arr2
{-# INLINE atan2A #-}