{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.BLAS.Matrix.RowMajor (
   Matrix,
   Vector,
   takeRow,
   takeColumn,
   fromRows,
   tensorProduct,
   decomplex,
   recomplex,
   scaleRows,
   scaleColumns,
   multiplyVectorLeft,
   multiplyVectorRight,
   ) where

import qualified Numeric.BLAS.Matrix.Modifier as Modifier
import qualified Numeric.BLAS.Private as Private
import Numeric.BLAS.Matrix.Modifier (Conjugation(NonConjugated,Conjugated))
import Numeric.BLAS.Scalar (zero, one)
import Numeric.BLAS.Private (ShapeInt, shapeInt, ComplexShape, pointerSeq, fill)

import qualified Numeric.BLAS.FFI.Generic as Blas
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Storable (Storable)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Data.Foldable (forM_)
import Data.Complex (Complex)
import Data.Tuple.HT (swap)


-- | A matrix is an array whose shape is the pair @(height,width)@.
-- The functions in this module treat the storage as row-major,
-- i.e. the elements of one row are adjacent in memory.
type Matrix height width = Array (height,width)
-- | A vector is a plain array over an arbitrary shape.
type Vector = Array

-- | Extract one row of a row-major matrix as a vector.
--
-- The row is selected by the index @ix@ into the @height@ shape.
-- With @n@ being the element count that 'Array.unsafeCreateWithSize'
-- passes for the @width@ shape, the @n@ elements starting at position
-- @n * 'Shape.offset' height ix@ of the matrix storage
-- are copied into the result.
takeRow ::
   (Shape.Indexed height, Shape.C width, Shape.Index height ~ ix,
    Storable a) =>
   ix -> Matrix height width a -> Vector width a
takeRow ix (Array (height,width) x) =
   Array.unsafeCreateWithSize width $ \n yPtr ->
   withForeignPtr x $ \xPtr ->
      -- rows are contiguous, thus a plain block copy suffices
      copyArray yPtr (advancePtr xPtr (n * Shape.offset height ix)) n

-- | Extract one column of a row-major matrix as a vector.
--
-- The column is selected by the index @ix@ into the @width@ shape.
-- The elements are gathered with 'Blas.copy',
-- starting at @'Shape.offset' width ix@
-- and using the size of @width@ as source stride
-- and 1 as destination stride.
takeColumn ::
   (Shape.C height, Shape.Indexed width, Shape.Index width ~ ix,
    Class.Floating a) =>
   ix -> Matrix height width a -> Vector height a
takeColumn ix (Array (height,width) x) =
   Array.unsafeCreateWithSize height $ \n yPtr -> evalContT $ do
      let offset = Shape.offset width ix
      nPtr <- Call.cint n
      xPtr <- ContT $ withForeignPtr x
      -- consecutive elements of a column are one full row apart
      incxPtr <- Call.cint $ Shape.size width
      incyPtr <- Call.cint 1
      liftIO $ Blas.copy nPtr (advancePtr xPtr offset) incxPtr yPtr incyPtr


-- | Build a matrix from a list of row vectors.
--
-- The height of the result is the length of the list,
-- the width is given explicitly by the first argument.
-- For every row it is asserted via 'Call.assert'
-- that its shape equals the given @width@
-- before its elements are copied into the matrix.
fromRows ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Vector width a] -> Matrix ShapeInt width a
fromRows width rows =
   Array.unsafeCreate (shapeInt $ length rows, width) $ \dstPtr ->
   let widthSize = Shape.size width
   -- 'pointerSeq' supplies the destination pointer of each row;
   -- 'zip' ends the traversal with the last source row
   in forM_ (zip (pointerSeq widthSize dstPtr) rows) $
         \(dstRowPtr, Array.Array rowWidth srcFPtr) ->
         withForeignPtr srcFPtr $ \srcPtr -> do
            Call.assert
               "Matrix.fromRows: non-matching vector size"
               (width == rowWidth)
            copyArray dstRowPtr srcPtr widthSize


-- ToDo: use lapack:Private.multiplyMatrix

-- | Outer product of two vectors computed by a single 'Blas.gemm' call
-- with inner dimension 1.
--
-- The first vector determines the height and the second one the width
-- of the resulting row-major matrix.
-- In terms of 'Blas.gemm' the second vector @y@ is the first operand
-- and the first vector @x@ is the second operand,
-- because the row-major result is written as a matrix
-- with leading dimension @size width@.
--
-- The 'Either' argument selects which operand gets a transposition flag:
-- @Left c@ flags the operand built from @y@,
-- @Right c@ flags the operand built from @x@.
-- The flag is @\'T\'@ for 'NonConjugated' and @\'C\'@ for 'Conjugated';
-- the other operand always gets @\'N\'@.
tensorProduct ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Either Conjugation Conjugation ->
   Vector height a -> Vector width a -> Matrix height width a
tensorProduct side (Array height x) (Array width y) =
   Array.unsafeCreate (height,width) $ \cPtr -> do
      let m = Shape.size width
      let n = Shape.size height
      let trans conjugated =
            case conjugated of
               NonConjugated -> 'T'
               Conjugated -> 'C'
      -- leading dimensions depend on whether the one-dimensional operands
      -- are passed as rows (1) or as columns (m and n, respectively)
      let ((transa,transb),(lda,ldb)) =
            case side of
               Left c -> ((trans c, 'N'),(1,1))
               Right c -> (('N', trans c),(m,n))
      evalContT $ do
         transaPtr <- Call.char transa
         transbPtr <- Call.char transb
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         kPtr <- Call.cint 1
         alphaPtr <- Call.number one
         aPtr <- ContT $ withForeignPtr y
         ldaPtr <- Call.leadingDim lda
         bPtr <- ContT $ withForeignPtr x
         ldbPtr <- Call.leadingDim ldb
         betaPtr <- Call.number zero
         ldcPtr <- Call.leadingDim m
         liftIO $
            Blas.gemm
               transaPtr transbPtr mPtr nPtr kPtr alphaPtr
               aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr


-- | Reinterpret a matrix of complex numbers as a matrix of real numbers
-- with an additional innermost 'ComplexShape' dimension.
--
-- No element is copied:
-- the result shares the storage of the argument via 'castForeignPtr'.
-- NOTE(review): this presumably relies on @Complex a@ being stored
-- as two adjacent values of type @a@ - confirm against the 'Storable' instance.
decomplex ::
   (Class.Real a) =>
   Matrix height width (Complex a) ->
   Matrix height (width, ComplexShape) a
decomplex (Array (height,width) a) =
   Array (height, (width, Shape.static)) (castForeignPtr a)

-- | Inverse direction of 'decomplex':
-- Reinterpret a real matrix with an innermost 'ComplexShape' dimension
-- as a matrix of complex numbers.
--
-- No element is copied:
-- the result shares the storage of the argument via 'castForeignPtr'.
recomplex ::
   (Class.Real a) =>
   Matrix height (width, ComplexShape) a ->
   Matrix height width (Complex a)
recomplex (Array (height, (width, Shape.NestedTuple _)) a) =
   Array (height,width) (castForeignPtr a)


-- | Multiply every row of a matrix by the corresponding vector element.
--
-- Row @k@ of the result is row @k@ of the matrix
-- scaled by the @k@-th element of the vector.
-- It is asserted via 'Call.assert'
-- that the shape of the vector equals the height of the matrix.
scaleRows ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix height width a -> Matrix height width a
scaleRows (Array heightX x) (Array shape@(height,width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "scaleRows: sizes mismatch" (heightX == height)
      evalContT $ do
         let m = Shape.size height
         let n = Shape.size width
         nPtr <- Call.cint n
         xPtr <- ContT $ withForeignPtr x
         aPtr <- ContT $ withForeignPtr a
         incaPtr <- Call.cint 1
         incbPtr <- Call.cint 1
         -- the pointer sequences are cut to the number of rows by 'take'
         liftIO $ sequence_ $ take m $
            zipWith3
               (\xkPtr akPtr bkPtr -> do
                  -- copy the source row, then scale the copy in place
                  Blas.copy nPtr akPtr incaPtr bkPtr incbPtr
                  Blas.scal nPtr xkPtr bkPtr incbPtr)
               (pointerSeq 1 xPtr)
               (pointerSeq n aPtr)
               (pointerSeq n bPtr)

-- | Multiply every column of a matrix by the corresponding vector element.
--
-- It is asserted via 'Call.assert'
-- that the shape of the vector equals the width of the matrix.
--
-- Each row is processed by one call to 'Private.gbmv'
-- where the scaling vector is passed in the place of the band matrix
-- with zero sub- and super-diagonals and leading dimension 1,
-- and the matrix row is passed as the vector operand.
-- NOTE(review): assuming 'Private.gbmv' follows the BLAS @gbmv@ convention,
-- this amounts to an element-wise product of the row with the vector.
scaleColumns ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a -> Matrix height width a -> Matrix height width a
scaleColumns (Array widthX x) (Array shape@(height,width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "scaleColumns: sizes mismatch" (widthX == width)
      evalContT $ do
         let m = Shape.size height
         let n = Shape.size width
         transPtr <- Call.char 'N'
         nPtr <- Call.cint n
         klPtr <- Call.cint 0
         kuPtr <- Call.cint 0
         alphaPtr <- Call.number one
         xPtr <- ContT $ withForeignPtr x
         ldxPtr <- Call.leadingDim 1
         aPtr <- ContT $ withForeignPtr a
         incaPtr <- Call.cint 1
         betaPtr <- Call.number zero
         incbPtr <- Call.cint 1
         -- the pointer sequences are cut to the number of rows by 'take'
         liftIO $ sequence_ $ take m $
            zipWith
               (\akPtr bkPtr ->
                  Private.gbmv transPtr
                     nPtr nPtr klPtr kuPtr alphaPtr xPtr ldxPtr
                     akPtr incaPtr betaPtr bkPtr incbPtr)
               (pointerSeq n aPtr)
               (pointerSeq n bPtr)


-- | Multiply a vector from the left with a row-major matrix.
--
-- The shape of the vector must equal the height of the matrix;
-- the result has the shape of the matrix width.
-- Implemented by 'multiplyVector' with 'nonTransposed'.
multiplyVectorLeft ::
   (Eq height, Shape.C height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix height width a -> Vector width a
multiplyVectorLeft = multiplyVector nonTransposed

-- | Multiply a row-major matrix with a vector from the right.
--
-- The shape of the vector must equal the width of the matrix;
-- the result has the shape of the matrix height.
-- Implemented by 'multiplyVector' with 'transposed'
-- and flipped argument order.
multiplyVectorRight ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Matrix height width a -> Vector width a -> Vector height a
multiplyVectorRight = flip $ multiplyVector transposed


-- | Description of how 'multiplyVector' shall interpret its matrix argument.
--
-- It bundles a 'Modifier.Transposition' flag,
-- the character that 'multiplyVector' passes to 'Blas.gemv'
-- as transposition parameter,
-- and a function that maps the matrix shape @(heightA,widthA)@
-- to the pair of the expected shape of the input vector
-- and the shape of the result vector.
data Transposition heightA widthA heightB widthB =
   Transposition
      Modifier.Transposition Char
      ((heightA,widthA) -> (heightB,widthB))

-- | Transposed interpretation of the matrix:
-- BLAS flag @\'T\'@ and swapped matrix dimensions.
transposed :: Transposition height width width height
transposed = Transposition Modifier.Transposed 'T' swap

-- | Non-transposed interpretation of the matrix:
-- BLAS flag @\'N\'@ and unchanged matrix dimensions.
nonTransposed :: Transposition height width height width
nonTransposed = Transposition Modifier.NonTransposed 'N' id

-- | Common implementation of 'multiplyVectorLeft' and 'multiplyVectorRight'
-- on top of 'Blas.gemv'.
--
-- The 'Transposition' argument provides the transposition character
-- for 'Blas.gemv' and the function that maps the matrix shape
-- to the pair @(heightB,widthB)@.
-- It is asserted via 'Call.assert' that @heightB@ equals the shape
-- of the input vector; @widthB@ is the shape of the result.
--
-- If the input vector is empty, the result is filled with zeros
-- without calling BLAS.
multiplyVector ::
   (Shape.C heightB, Shape.C widthB, Eq heightB, Class.Floating a) =>
   Transposition heightA widthA heightB widthB ->
   Vector heightB a -> Matrix heightA widthA a -> Vector widthB a
multiplyVector
      (Transposition trans transChar assignDims)
      (Array sh x) (Array shA a) =
   let (height,width) = assignDims shA
   in Array.unsafeCreateWithSize width $ \m0 yPtr -> do
      Call.assert
         "Matrix.RowMajor.multiplyVector: shapes mismatch"
         (height == sh)
      -- m0: number of result elements, n0: number of input elements
      let n0 = Shape.size height
      -- The empty-sum case must be decided on the untransposed sizes:
      -- the buffer behind yPtr holds m0 elements in both transposition modes.
      -- Reference BLAS gemv returns without writing the result
      -- if one of its dimensions is zero,
      -- thus the zeros have to be written here.
      if n0 == 0
         then fill zero m0 yPtr
         else evalContT $ do
            -- dimensions of the matrix as seen by column-major BLAS
            let (m,n) =
                  case trans of
                     Modifier.NonTransposed -> (m0,n0)
                     Modifier.Transposed -> (n0,m0)
            let lda = m
            transPtr <- Call.char transChar
            mPtr <- Call.cint m
            nPtr <- Call.cint n
            alphaPtr <- Call.number one
            aPtr <- ContT $ withForeignPtr a
            ldaPtr <- Call.leadingDim lda
            xPtr <- ContT $ withForeignPtr x
            incxPtr <- Call.cint 1
            betaPtr <- Call.number zero
            incyPtr <- Call.cint 1
            liftIO $
               Blas.gemv
                  transPtr mPtr nPtr alphaPtr aPtr ldaPtr
                  xPtr incxPtr betaPtr yPtr incyPtr