{-# LANGUAGE ViewPatterns        #-}
--------------------------------------------------------------------------------
-- |
-- Module      : ArrayFire.BLAS
-- Copyright   : David Johnson (c) 2019-2020
-- License     : BSD3
-- Maintainer  : David Johnson <djohnson.m@gmail.com>
-- Stability   : Experimental
-- Portability : GHC
--
-- Basic Linear Algebra Subprograms (BLAS) API
--
-- @
-- main :: IO ()
-- main = print (matmul x y xProp yProp)
--  where
--     x,y :: Array Double
--     x = matrix (2,3) [[1,2],[3,4],[5,6]]
--     y = matrix (3,2) [[1,2,3],[4,5,6]]
--
--     xProp, yProp :: MatProp
--     xProp = None
--     yProp = None
-- @
-- @
--  ArrayFire Array
--  [2 2 1 1]
--     22.0000    49.0000
--     28.0000    64.0000
-- @
--------------------------------------------------------------------------------
module ArrayFire.BLAS where

import Data.Complex

import ArrayFire.FFI
import ArrayFire.Internal.BLAS
import ArrayFire.Internal.Types

-- | Matrix multiplication of two 2D arrays.
--
-- The following applies for sparse-dense matrix multiplication.
--
-- This function can be used with one sparse input. The sparse input must always be the lhs and the dense matrix must be rhs.
--
-- The sparse array can only be of 'CSR' format.
--
-- The returned array is always dense.
--
-- optLhs can only be one of AF_MAT_NONE, AF_MAT_TRANS, AF_MAT_CTRANS.
--
-- optRhs can only be AF_MAT_NONE.
--
-- >>> matmul (matrix @Double (2,2) [[1,2],[3,4]]) (matrix @Double (2,2) [[1,2],[3,4]]) None None
-- ArrayFire Array
-- [2 2 1 1]
--     7.0000    15.0000
--    10.0000    22.0000
matmul
  :: AFType a
  => Array a
  -- ^ 2D matrix of Array a, left-hand side (the only side that may be sparse)
  -> Array a
  -- ^ 2D matrix of Array a, right-hand side (must be dense)
  -> MatProp
  -- ^ Left-hand side matrix options
  -> MatProp
  -- ^ Right-hand side matrix options
  -> Array a
  -- ^ Output of 'matmul'
matmul arr1 arr2 prop1 prop2 =
  -- 'matmul' is a pure function, so no @do@ block is needed here.
  -- 'op2' hands the callback a @Ptr AFArray@ plus the two input handles;
  -- the Haskell-side 'MatProp' options are converted for the C call.
  op2 arr1 arr2 $ \p a b ->
    af_matmul p a b (toMatProp prop1) (toMatProp prop2)

-- | Scalar dot product between two vectors. Also referred to as the inner product.
--
-- >>> dot (vector @Double 10 [1..]) (vector @Double 10 [1..]) None None
-- ArrayFire Array
-- [1 1 1 1]
--   385.0000
dot
  :: AFType a
  => Array a
  -- ^ Left-hand side input
  -> Array a
  -- ^ Right-hand side input
  -> MatProp
  -- ^ Options for left-hand side. Currently only AF_MAT_NONE and AF_MAT_CONJ are supported.
  -> MatProp
  -- ^ Options for right-hand side. Currently only AF_MAT_NONE and AF_MAT_CONJ are supported.
  -> Array a
  -- ^ Output of 'dot'
dot lhs rhs lhsProp rhsProp = op2 lhs rhs callDot
  where
    -- Callback given to 'op2': forwards the output pointer and both input
    -- handles to the C routine, together with the converted options.
    callDot out x y = af_dot out x y (toMatProp lhsProp) (toMatProp rhsProp)

-- | Scalar dot product between two vectors. Also referred to as the inner product. Returns the result as a host scalar.
--
-- >>> dotAll (vector @Double 10 [1..]) (vector @Double 10 [1..]) None None
-- 385.0 :+ 0.0
dotAll
  :: AFType a
  => Array a
  -- ^ Left-hand side array
  -> Array a
  -- ^ Right-hand side array
  -> MatProp
  -- ^ Options for left-hand side. Currently only AF_MAT_NONE and AF_MAT_CONJ are supported.
  -> MatProp
  -- ^ Options for right-hand side. Currently only AF_MAT_NONE and AF_MAT_CONJ are supported.
  -> Complex Double
  -- ^ Real and imaginary component result
dotAll arr1 arr2 prop1 prop2 = real :+ imag
  where
    -- 'af_dot_all' receives two @Ptr Double@ outputs; 'infoFromArray22'
    -- returns the pair of values read back from them (presumably real part
    -- first, then imaginary part, matching the result type above).
    -- A pure @where@ binding replaces the previous do/let in non-monadic code.
    (real, imag) =
      infoFromArray22 arr1 arr2 $ \re im lhs rhs ->
        af_dot_all re im lhs rhs (toMatProp prop1) (toMatProp prop2)

-- | Transposes a matrix.
--
-- >>> array = matrix @Double (2,3) [[2,3],[3,4],[5,6]]
-- >>> array
-- ArrayFire Array
-- [2 3 1 1]
--     2.0000     3.0000     5.0000
--     3.0000     4.0000     6.0000
--
-- >>> transpose array True
-- ArrayFire Array
-- [3 2 1 1]
--     2.0000     3.0000
--     3.0000     4.0000
--     5.0000     6.0000
--
transpose
  :: AFType a
  => Array a
  -- ^ Input matrix to be transposed
  -> Bool
  -- ^ Should perform conjugate transposition
  -> Array a
  -- ^ The transposed matrix
transpose arr1 conj =
  op1 arr1 $ \out inp -> af_transpose out inp conjFlag
  where
    -- The C API takes a C boolean: 'fromEnum' maps False/True to 0/1.
    conjFlag = fromIntegral (fromEnum conj)

-- | Transposes a matrix.
--
-- * Warning: This function mutates the array in place; every existing reference to it will observe the change. Use with care.
--
-- >>> array = matrix @Double (2,2) [[1,2],[3,4]]
-- >>> array
-- ArrayFire Array
-- [2 2 1 1]
--    1.0000     3.0000
--    2.0000     4.0000
--
-- >>> transposeInPlace array False
-- >>> array
-- ArrayFire Array
-- [2 2 1 1]
--    1.0000     2.0000
--    3.0000     4.0000
--
transposeInPlace
  :: AFType a
  => Array a
  -- ^ Input matrix to be transposed
  -> Bool
  -- ^ Should perform conjugate transposition
  -> IO ()
transposeInPlace arr conj =
  inPlace arr (\handle -> af_transpose_inplace handle conjFlag)
  where
    -- The C API takes a C boolean: 'fromEnum' maps False/True to 0/1.
    conjFlag = fromIntegral (fromEnum conj)