{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.BLAS.Matrix.RowMajor (
   Matrix,
   Vector,
   takeRow,
   takeColumn,
   fromRows,
   tensorProduct,
   decomplex,
   recomplex,
   scaleRows,
   scaleColumns,
   ) where

import qualified Numeric.BLAS.Private as Private
import Numeric.BLAS.Matrix.Modifier (Conjugation(NonConjugated,Conjugated))
import Numeric.BLAS.Scalar (zero, one)
import Numeric.BLAS.Private (ShapeInt, shapeInt, ComplexShape, pointerSeq)

import qualified Numeric.BLAS.FFI.Generic as Blas
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Storable (Storable)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Data.Complex (Complex)
import Data.Foldable (forM_)


{- |
A matrix whose shape is the pair @(height, width)@.
This module treats it as row-major: row @i@ starts at linear element
position @i * Shape.size width@ (see 'takeRow').
-}
type Matrix height width = Array (height,width)
-- | A one-dimensional storable array over the given shape.
type Vector = Array

{- |
Extract the row with index @ix@ as a vector over the matrix width.

In row-major storage all elements of one row are adjacent, so the row
is obtained with a single contiguous 'copyArray' of @Shape.size width@
elements, starting at element @Shape.offset height ix * Shape.size width@.
-}
takeRow ::
   (Shape.Indexed height, Shape.C width, Shape.Index height ~ ix,
    Storable a) =>
   ix -> Matrix height width a -> Vector width a
takeRow ix (Array (height,width) x) =
   Array.unsafeCreateWithSize width fill
   where
      -- rowLen is the size of the width shape, i.e. the row length
      fill rowLen dstPtr =
         withForeignPtr x $ \srcPtr ->
            let rowStart = advancePtr srcPtr (rowLen * Shape.offset height ix)
            in  copyArray dstPtr rowStart rowLen

{- |
Extract the column with index @ix@ as a vector over the matrix height.

Vertically adjacent elements are @Shape.size width@ elements apart in
the row-major buffer, so the column is gathered by a strided BLAS
@copy@ that starts at the column's offset within the first row.
-}
takeColumn ::
   (Shape.C height, Shape.Indexed width, Shape.Index width ~ ix,
    Class.Floating a) =>
   ix -> Matrix height width a -> Vector height a
takeColumn ix (Array (height,width) x) =
   Array.unsafeCreateWithSize height $ \colLen dstPtr ->
   evalContT $ do
      -- source stride: one full row; destination is contiguous
      let stride = Shape.size width
          firstElem = Shape.offset width ix
      srcBase <- ContT $ withForeignPtr x
      countPtr <- Call.cint colLen
      srcIncPtr <- Call.cint stride
      dstIncPtr <- Call.cint 1
      liftIO $
         Blas.copy countPtr (advancePtr srcBase firstElem)
            srcIncPtr dstPtr dstIncPtr


{- |
Stack a list of vectors as the rows of a matrix.

The result height is @shapeInt (length rows)@ and the width is the given
@width@. Every row's shape must equal @width@; a mismatch is reported
through 'Call.assert' when that row is copied.
-}
fromRows ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Vector width a] -> Matrix ShapeInt width a
fromRows width rows =
   Array.unsafeCreate (shapeInt (length rows), width) $ \dstPtr ->
      sequence_ $ zipWith copyRow (pointerSeq rowSize dstPtr) rows
   where
      rowSize = Shape.size width
      -- check the row shape, then copy it into its slot of the buffer
      copyRow dstRowPtr (Array.Array rowWidth srcFPtr) =
         withForeignPtr srcFPtr $ \srcPtr -> do
            Call.assert
               "Matrix.fromRows: non-matching vector size"
               (width == rowWidth)
            copyArray dstRowPtr srcPtr rowSize


-- ToDo: use lapack:Private.multiplyMatrix
{- |
Outer product of the height vector @x@ and the width vector @y@.

Write @conj c@ for complex conjugation when @c@ is 'Conjugated' and for
the identity when it is 'NonConjugated'. The element at row @r@,
column @s@ of the result is

* @x_r * conj c (y_s)@ for @Left c@,

* @conj c (x_r) * y_s@ for @Right c@.

It is computed by a BLAS @gemm@ with inner dimension 1. The row-major
@height x width@ result has the same memory layout as a column-major
@width x height@ matrix, which is why @y@ is passed as the first factor.
-}
tensorProduct ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Either Conjugation Conjugation ->
   Vector height a -> Vector width a -> Matrix height width a
tensorProduct side (Array height x) (Array width y) =
   Array.unsafeCreate (height,width) $ \cPtr ->
   evalContT $ do
      let widthSize = Shape.size width
          heightSize = Shape.size height
          transChar NonConjugated = 'T'
          transChar Conjugated = 'C'
          -- choose which operand gets (conjugate-)transposed and
          -- the matching leading dimensions of the 1-column operands
          (transA, transB, ldA, ldB) =
             either
                (\c -> (transChar c, 'N', 1, 1))
                (\c -> ('N', transChar c, widthSize, heightSize))
                side
      transaPtr <- Call.char transA
      transbPtr <- Call.char transB
      mPtr <- Call.cint widthSize
      nPtr <- Call.cint heightSize
      kPtr <- Call.cint 1
      alphaPtr <- Call.number one
      aPtr <- ContT $ withForeignPtr y
      ldaPtr <- Call.leadingDim ldA
      bPtr <- ContT $ withForeignPtr x
      ldbPtr <- Call.leadingDim ldB
      betaPtr <- Call.number zero
      ldcPtr <- Call.leadingDim widthSize
      liftIO $
         Blas.gemm
            transaPtr transbPtr mPtr nPtr kPtr alphaPtr
            aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr


{- |
View a complex matrix as a real matrix whose width is extended by a
complex-parts dimension.

No data is copied: the same buffer is reinterpreted via 'castForeignPtr',
relying on the 'Storable' layout of 'Complex' (real part followed by
imaginary part).
-}
decomplex ::
   (Class.Real a) =>
   Matrix height width (Complex a) ->
   Matrix height (width, ComplexShape) a
decomplex (Array (height,width) buffer) =
   Array shape $ castForeignPtr buffer
   where
      shape = (height, (width, Shape.static))

{- |
Inverse of 'decomplex': drop the complex-parts dimension and view the
buffer as complex numbers again.

No data is copied; the buffer is reinterpreted via 'castForeignPtr'.
-}
recomplex ::
   (Class.Real a) =>
   Matrix height (width, ComplexShape) a ->
   Matrix height width (Complex a)
recomplex arr =
   case arr of
      Array (height, (width, Shape.NestedTuple _)) buffer ->
         Array (height,width) $ castForeignPtr buffer


{- |
Multiply row @k@ of the matrix by element @k@ of the vector.

The vector's shape must equal the matrix height; this is checked with
'Call.assert' before any work is done. Each row is first copied to the
result with BLAS @copy@ and then scaled in place with BLAS @scal@.
-}
scaleRows ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix height width a -> Matrix height width a
scaleRows (Array heightX x) (Array shape@(height,width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "scaleRows: sizes mismatch" (heightX == height)
      let numRows = Shape.size height
          rowLen = Shape.size width
      evalContT $ do
         lenPtr <- Call.cint rowLen
         factorBase <- ContT $ withForeignPtr x
         srcBase <- ContT $ withForeignPtr a
         srcIncPtr <- Call.cint 1
         dstIncPtr <- Call.cint 1
         -- copy one source row to the destination, then scale it there
         let scaleRow (factorPtr, srcRowPtr, dstRowPtr) = do
                Blas.copy lenPtr srcRowPtr srcIncPtr dstRowPtr dstIncPtr
                Blas.scal lenPtr factorPtr dstRowPtr dstIncPtr
         liftIO $ mapM_ scaleRow $ take numRows $
            zip3
               (pointerSeq 1 factorBase)
               (pointerSeq rowLen srcBase)
               (pointerSeq rowLen bPtr)

{- |
Multiply column @j@ of the matrix by element @j@ of the vector.

The vector's shape must equal the matrix width; this is checked with
'Call.assert'. Each row is transformed by one call to 'Private.gbmv'.
The call uses a band matrix with no sub- or super-diagonals whose single
diagonal is the vector (band leading dimension 1), @alpha = one@ and
@beta = zero@. This amounts to an element-wise product with the vector,
assuming 'Private.gbmv' follows BLAS @?gbmv@ semantics (TODO confirm).
-}
scaleColumns ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a -> Matrix height width a -> Matrix height width a
scaleColumns (Array widthX x) (Array shape@(height,width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "scaleColumns: sizes mismatch" (widthX == width)
      let numRows = Shape.size height
          rowLen = Shape.size width
      evalContT $ do
         transPtr <- Call.char 'N'
         sizePtr <- Call.cint rowLen
         klPtr <- Call.cint 0
         kuPtr <- Call.cint 0
         alphaPtr <- Call.number one
         diagPtr <- ContT $ withForeignPtr x
         ldDiagPtr <- Call.leadingDim 1
         srcBase <- ContT $ withForeignPtr a
         srcIncPtr <- Call.cint 1
         betaPtr <- Call.number zero
         dstIncPtr <- Call.cint 1
         -- dstRow := diag(x) * srcRow
         let scaleRow srcRowPtr dstRowPtr =
                Private.gbmv transPtr
                   sizePtr sizePtr klPtr kuPtr alphaPtr diagPtr ldDiagPtr
                   srcRowPtr srcIncPtr betaPtr dstRowPtr dstIncPtr
         liftIO $ sequence_ $
            zipWith scaleRow
               (take numRows $ pointerSeq rowLen srcBase)
               (pointerSeq rowLen bPtr)