{-# LANGUAGE ConstraintKinds   #-}
{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE NoImplicitPrelude #-}
-- |
-- Module      : Data.Array.Accelerate.Numeric.LinearAlgebra
-- Copyright   : [2017..2020] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Numeric.LinearAlgebra (

  -- * Types
  Numeric, Scalar, Vector, Matrix,

  -- * Products
  -- ** Vector-vector
  (<.>),
  (><),

  -- ** Matrix-vector
  (#>), (<#),

  -- ** Matrix-matrix
  (<>),

  -- * Diagonal
  identity, diagonal, trace,

) where

import Data.Array.Accelerate                                        as A

import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Data.Array.Accelerate.Numeric.LinearAlgebra.BLAS.Level1
import Data.Array.Accelerate.Numeric.LinearAlgebra.BLAS.Level2
import Data.Array.Accelerate.Numeric.LinearAlgebra.BLAS.Level3


-- Level 1
-- -------

-- | An infix synonym for 'dotu': the dot product of two vectors.
--
-- Neither argument is conjugated, as the complex-valued example below
-- shows (@(1+i)*1 + 1*(1-i) = 2@).
--
-- >>> let a = fromList (Z:.4) [1..]
-- >>> let b = fromList (Z:.4) [-2,0,1,1]
-- >>> a <.> b
-- Scalar Z [5.0]
--
-- >>> let c = fromList (Z:.2) [1:+1, 1:+0]
-- >>> let d = fromList (Z:.2) [1:+0, 1:+(-1)]
-- >>> c <.> d
-- Scalar Z [2.0 :+ 0.0]
--
infixr 8 <.>
(<.>) :: Numeric e => Acc (Vector e) -> Acc (Vector e) -> Acc (Scalar e)
(<.>) = dotu


-- | Outer product of two vectors
--
-- Computed as the matrix product ('<>') of @x@ reshaped into a column
-- matrix (@length x@ rows, one column) and @y@ reshaped into a row matrix
-- (one row, @length y@ columns).
--
-- >>> let a = fromList (Z :. 3) [1,2,3]
-- >>> let b = fromList (Z :. 3) [5,2,3]
-- >>> a >< b
-- Matrix (Z :. 3 :. 3)
--  [  5.0, 2.0, 3.0
--  , 10.0, 4.0, 6.0
--  , 15.0, 6.0, 9.0 ]
--
infixr 8 ><
(><) :: Numeric e => Acc (Vector e) -> Acc (Vector e) -> Acc (Matrix e)
(><) x y = xc <> yr
  where
    -- column matrix: (length x) rows by 1 column
    xc = reshape (index2 (length x) 1) x
    -- row matrix: 1 row by (length y) columns
    yr = reshape (index2 1 (length y)) y


-- Level 2
-- -------

-- | Dense matrix-vector product
--
-- Implemented as @'gemv' 1 N m x@: the matrix is used as given
-- (untransposed) with a scaling factor of one.
--
-- >>> let m = fromList (Z :. 2 :. 3) [1..]
-- >>> m
-- Matrix (Z :. 2 :. 3)
--  [ 1.0, 2.0, 3.0
--  , 4.0, 5.0, 6.0 ]
--
-- >>> let x = fromList (Z :. 3) [10,20,30]
--
-- >>> m #> x
-- Vector (Z :. 2) [140.0,320.0]
--
-- See 'gemv' for a more general version of this operation.
--
infixr 8 #>
(#>) :: Numeric e => Acc (Matrix e) -> Acc (Vector e) -> Acc (Vector e)
(#>) m x = gemv 1 N m x


-- | Dense vector-matrix product
--
-- Implemented as @'gemv' 1 T m x@: the vector is multiplied by the
-- transposed matrix with a scaling factor of one. The example below checks
-- this: @[5,10] <# m == [5*1+10*4, 5*2+10*5, 5*3+10*6]@.
--
-- >>> let m = fromList (Z :. 2 :. 3) [1..]
-- >>> m
-- Matrix (Z :. 2 :. 3)
--  [ 1.0, 2.0, 3.0
--  , 4.0, 5.0, 6.0 ]
--
-- >>> let v = fromList (Z :. 2) [5,10]
--
-- >>> v <# m
-- Vector (Z :. 3) [45.0,60.0,75.0]
--
-- See 'gemv' for a more general version of this operation.
--
infixr 8 <#
(<#) :: Numeric e => Acc (Vector e) -> Acc (Matrix e) -> Acc (Vector e)
(<#) x m = gemv 1 T m x


-- Level 3
-- -------

-- | Dense matrix-matrix product
--
-- Implemented as @'gemm' 1 N matA N matB@: neither operand is transposed
-- and the product is scaled by one.
--
-- >>> let a = fromList (Z :. 3 :. 5) [1..]
-- >>> a
-- Matrix (Z :. 3 :. 5)
--  [  1.0,  2.0,  3.0,  4.0,  5.0
--  ,  6.0,  7.0,  8.0,  9.0, 10.0
--  , 11.0, 12.0, 13.0, 14.0, 15.0 ]
--
-- >>> let b = fromList (Z :. 5 :. 2) [1,3, 0,2, -1,5, 7,7, 6,0]
-- >>> b
-- Matrix (Z :. 5 :. 2)
--  [  1.0, 3.0
--  ,  0.0, 2.0
--  , -1.0, 5.0
--  ,  7.0, 7.0
--  ,  6.0, 0.0 ]
--
-- >>> a <> b
-- Matrix (Z :. 3 :. 2)
--  [  56.0,  50.0
--  , 121.0, 135.0
--  , 186.0, 220.0 ]
--
-- See 'gemm' for a more general version of this operation.
--
infixr 8 <>
(<>) :: Numeric e => Acc (Matrix e) -> Acc (Matrix e) -> Acc (Matrix e)
(<>) matA matB = gemm 1 N matA N matB


-- | Create a square identity matrix of the given dimension
--
-- Built by placing a length-@n@ vector of ones on the diagonal of an
-- otherwise zero @n@-by-@n@ matrix (see 'diagonal').
--
identity :: Num e => Exp Int -> Acc (Matrix e)
identity n = diagonal (fill (index1 n) 1)

-- | Create a square matrix with the given diagonal
--
-- For an input vector of length @n@ the result is an @n@-by-@n@ matrix
-- whose element @(i,i)@ is element @i@ of the input; every other element
-- is zero.
--
diagonal :: Num e => Acc (Vector e) -> Acc (Matrix e)
diagonal v =
  let n     = length v
      zeros = fill (I2 n n) 0
  in
  -- Scatter each v[i] onto position (i,i) of the all-zero matrix. The index
  -- mapping i -> (i,i) is injective, so each diagonal slot receives exactly
  -- one element of v; 'const' is intended to keep that incoming element in
  -- place of the zero default.
  permute const zeros (\(I1 i) -> Just_ (I2 i i)) v

-- | The sum of the diagonal elements of a (square) matrix
--
-- For a non-square matrix with @h@ rows and @w@ columns, the elements
-- @(i,i)@ for @0 <= i < min h w@ are summed.
--
trace :: Num e => Acc (Matrix e) -> Acc (Scalar e)
trace m =
  let Z :. h :. w = unlift (shape m)
  in  -- gather the main diagonal into a vector, then reduce it
      sum (backpermute (index1 (min h w)) (\(I1 i) -> I2 i i) m)