{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NoImplicitPrelude #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra (
Numeric, Scalar, Vector, Matrix,
(<.>),
(><),
(#>), (<#),
(<>),
identity, diagonal, trace,
) where
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Data.Array.Accelerate.Numeric.LinearAlgebra.BLAS.Level1
import Data.Array.Accelerate.Numeric.LinearAlgebra.BLAS.Level2
import Data.Array.Accelerate.Numeric.LinearAlgebra.BLAS.Level3
-- | Dot product of two vectors, delivered as an Accelerate 'Scalar' array.
--
-- This is a thin infix alias for 'dotu'.
-- NOTE(review): by BLAS naming convention @dotu@ is the /unconjugated/
-- product, so complex inputs would not be conjugated here — confirm against
-- "Data.Array.Accelerate.Numeric.LinearAlgebra.BLAS.Level1".
(<.>) :: Numeric e => Acc (Vector e) -> Acc (Vector e) -> Acc (Scalar e)
u <.> v = dotu u v

infixr 8 <.>
-- | Outer product of two vectors.
--
-- @x >< y@ views @x@ as a @(length x)×1@ column and @y@ as a
-- @1×(length y)@ row. It then multiplies the two with '<>', which yields a
-- @(length x)×(length y)@ matrix.
(><) :: Numeric e => Acc (Vector e) -> Acc (Vector e) -> Acc (Matrix e)
x >< y =
  let column = reshape (I2 (length x) 1) x
      row    = reshape (I2 1 (length y)) y
  in  column <> row

infixr 8 ><
-- | Matrix–vector product: @m #> x@ computes @m · x@.
--
-- Implemented as a single 'gemv' call with scaling factor @1@ and the
-- 'N' (non-transposed) flag for the matrix.
(#>) :: Numeric e => Acc (Matrix e) -> Acc (Vector e) -> Acc (Vector e)
(#>) = gemv 1 N

infixr 8 #>
-- | Vector–matrix product: @x <# m@ computes the row vector @xᵀ · m@.
--
-- Because @xᵀ · m = (mᵀ · x)ᵀ@, this is a single 'gemv' call with the 'T'
-- (transpose) flag applied to @m@.
--
-- Note: '<#' is declared @infixr 8@, the same as the other products in this
-- module, so @x <# a <> b@ parses as @x <# (a <> b)@.
(<#) :: Numeric e => Acc (Vector e) -> Acc (Matrix e) -> Acc (Vector e)
x <# m = gemv 1 T m x

infixr 8 <#
-- | Matrix–matrix product: @a <> b@ computes @a · b@.
--
-- Implemented as a single 'gemm' call with scaling factor @1@ and neither
-- operand transposed ('N' for both).
(<>) :: Numeric e => Acc (Matrix e) -> Acc (Matrix e) -> Acc (Matrix e)
a <> b = gemm 1 N a N b

infixr 8 <>
-- | The @n×n@ identity matrix.
--
-- A length-@n@ vector of ones is handed to 'diagonal', which places it on
-- the main diagonal of a zero-filled square matrix.
identity :: Num e => Exp Int -> Acc (Matrix e)
identity n = diagonal ones
  where
    ones = fill (I1 n) 1
-- | Square matrix whose main diagonal is @v@ and whose other entries are zero.
--
-- With @n = length v@, element @i@ of @v@ is scattered with 'permute' to
-- index @(i, i)@ of an @n×n@ matrix pre-filled with zeros. The mapping
-- @i ↦ (i, i)@ is injective, so no target is written twice. The 'const'
-- combiner therefore just lets each incoming element replace the zero
-- already stored there.
diagonal :: Num e => Acc (Vector e) -> Acc (Matrix e)
diagonal v = permute const zeros toDiagonal v
  where
    n                 = length v
    zeros             = fill (I2 n n) 0
    toDiagonal (I1 i) = Just_ (I2 i i)
-- | Sum of the main-diagonal entries of a matrix.
--
-- For an @h×w@ matrix, 'backpermute' gathers the first @min h w@ entries
-- @(i, i)@ into a vector, and 'sum' reduces that vector. Non-square inputs
-- are therefore accepted: only the diagonal that fits is summed.
trace :: Num e => Acc (Matrix e) -> Acc (Scalar e)
trace m = sum (backpermute (I1 (min h w)) fromDiagonal m)
  where
    I2 h w              = shape m
    fromDiagonal (I1 i) = I2 i i