{-# OPTIONS -fno-warn-incomplete-patterns #-}
{-# LANGUAGE PackageImports #-}

-- | Algorithms operating on matrices.
-- 
--   These functions should give performance comparable with nested loop C
--   implementations. 
-- 
--   If you care deeply about runtime performance then you
--   may be better off using a binding to LAPACK, such as hvector.
--
module Data.Array.Repa.Algorithms.Matrix
        ( -- * Projections
          row
        , col

          -- * Matrix Multiplication.
        , mmultP,      mmultS

          -- * Transposition.
        , transpose2P, transpose2S

          -- * Trace.
        , trace2P, trace2S)

where
import Data.Array.Repa                  as R
import Data.Array.Repa.Eval             as R
import Data.Array.Repa.Unsafe           as R
import Control.Monad
import Control.Monad.ST.Strict


-- Projections ----------------------------------------------------------------
-- | Take the row number of a rank-2 index.
--
--   For an index @Z :. r :. c@ this returns @r@, the outer of the two
--   components.
row :: DIM2 -> Int
row (Z :. r :. _) = r
{-# INLINE row #-}


-- | Take the column number of a rank-2 index.
--
--   For an index @Z :. r :. c@ this returns @c@, the inner of the two
--   components.
col :: DIM2 -> Int
col (Z :. _ :. c) = c
{-# INLINE col #-}


-- MMult ----------------------------------------------------------------------
-- | Matrix matrix multiply, in parallel.
--
--   The result has as many rows as the first matrix and as many columns as
--   the second. The second matrix is transposed up front with 'transpose2P',
--   so every output element is computed from a row slice of the first matrix
--   and a row slice of the transposed second matrix.
--
--   This function does not itself check that the number of columns of the
--   first matrix equals the number of rows of the second.
mmultP  :: Monad m
        => Array U DIM2 Double 
        -> Array U DIM2 Double 
        -> m (Array U DIM2 Double)

mmultP arr brr 
 = [arr, brr] `deepSeqArrays` 
   do   trr      <- transpose2P brr
        -- Output extent: rows of the left operand, columns of the right.
        let (Z :. h1  :. _)  = extent arr
        let (Z :. _   :. w2) = extent brr
        trr `deepSeqArray` computeP 
         $ fromFunction (Z :. h1 :. w2)
         $ \ix   -> R.sumAllS 
                  $ R.zipWith (*)
                        -- Row 'row ix' of arr, and row 'col ix' of the
                        -- transposed brr (i.e. column 'col ix' of brr).
                        (unsafeSlice arr (Any :. (row ix) :. All))
                        (unsafeSlice trr (Any :. (col ix) :. All))
{-# NOINLINE mmultP #-}


-- | Matrix matrix multiply, sequentially.
--
--   The result has as many rows as the first matrix and as many columns as
--   the second. The second matrix is transposed up front with 'transpose2S',
--   so every output element is computed from a row slice of the first matrix
--   and a row slice of the transposed second matrix.
--
--   This function does not itself check that the number of columns of the
--   first matrix equals the number of rows of the second.
mmultS  :: Array U DIM2 Double 
        -> Array U DIM2 Double 
        -> Array U DIM2 Double

mmultS arr brr
 = [arr, brr]  `deepSeqArrays` (runST $
   do   -- Sequence the transposition inside ST before building the result.
        trr     <- R.now $ transpose2S brr
        -- Output extent: rows of the left operand, columns of the right.
        let (Z :. h1  :. _)  = extent arr
        let (Z :. _   :. w2) = extent brr
        return $ computeS 
         $ fromFunction (Z :. h1 :. w2)
         $ \ix   -> R.sumAllS 
                  $ R.zipWith (*)
                        -- Row 'row ix' of arr, and row 'col ix' of the
                        -- transposed brr (i.e. column 'col ix' of brr).
                        (unsafeSlice arr (Any :. (row ix) :. All))
                        (unsafeSlice trr (Any :. (col ix) :. All)))
{-# NOINLINE mmultS #-}


-- Transpose ------------------------------------------------------------------
-- | Transpose a 2D matrix, in parallel.
--
--   The result's extent is the source extent with its two components
--   swapped, and element @(i, j)@ of the result is element @(j, i)@ of the
--   source.
transpose2P
        :: Monad m 
        => Array U DIM2 Double 
        -> m (Array U DIM2 Double)

transpose2P arr
 = arr `deepSeqArray`
   do   computeUnboxedP 
         $ unsafeBackpermute new_extent swap arr
 where  -- Exchange the two components of a rank-2 index (or extent).
        swap (Z :. i :. j)      = Z :. j :. i
        -- Swapping the source extent keeps every mapped index in range.
        new_extent              = swap (extent arr)
{-# NOINLINE transpose2P #-}


-- | Transpose a 2D matrix, sequentially.
--
--   The result's extent is the source extent with its two components
--   swapped, and element @(i, j)@ of the result is element @(j, i)@ of the
--   source.
transpose2S
        :: Array U DIM2 Double 
        -> Array U DIM2 Double

transpose2S arr
 = arr `deepSeqArray`
   do   computeUnboxedS
         $ unsafeBackpermute new_extent swap arr
 where  -- Exchange the two components of a rank-2 index (or extent).
        swap (Z :. i :. j)      = Z :. j :. i
        -- Swapping the source extent keeps every mapped index in range.
        new_extent              = swap (extent arr)
{-# NOINLINE transpose2S #-}


-- Trace ------------------------------------------------------------------------
-- | Get the trace of a (square) 2D matrix, in parallel.
--
--   Sums the elements at indices @(i, i)@ for @i@ from 0 up to (but not
--   including) the smaller of the row and column counts, so for a non-square
--   matrix this is the sum of its leading diagonal.
trace2P :: Monad m => Array U DIM2 Double -> m Double
trace2P x
 = sumAllP
 $ unsafeBackpermute
        (Z :. (min nRows nColumns))     -- length of the diagonal
        (\(Z :. i) -> (Z :. i :. i))    -- position i on the diagonal
        x
 where
    (Z :. nRows :. nColumns) = extent x


-- | Get the trace of a (square) 2D matrix, sequentially.
--
--   Sums the elements at indices @(i, i)@ for @i@ from 0 up to (but not
--   including) the smaller of the row and column counts, so for a non-square
--   matrix this is the sum of its leading diagonal.
trace2S :: Array U DIM2 Double -> Double
trace2S x
 = sumAllS
 $ unsafeBackpermute
        (Z :. (min nRows nColumns))     -- length of the diagonal
        (\(Z :. i) -> (Z :. i :. i))    -- position i on the diagonal
        x
 where
    (Z :. nRows :. nColumns) = extent x