{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Matrix.RowMajor where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Shape.Private (Order(RowMajor, ColumnMajor))
import Numeric.LAPACK.Matrix.Private (Full, ShapeInt, shapeInt)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated,Conjugated))
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Private (ComplexPart, pointerSeq)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Storable (Storable)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (liftA2)

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Data.Complex (Complex)
import Data.Foldable (forM_)


-- | A row-major matrix: an array whose shape is the pair @(h, w)@ of
-- height and width shapes.
type Matrix h w = Array (h, w)

-- | A vector is simply an array over an arbitrary shape.
type Vector = Array

-- | Extract the row with index @ix@ as a vector over the width shape.
--
-- Rows are contiguous in the row-major buffer, so this is one block copy
-- of @rowLen@ elements (the length supplied by
-- 'Array.unsafeCreateWithSize'), starting at element
-- @rowLen * Shape.offset height ix@.
takeRow ::
   (Shape.Indexed height, Shape.C width, Shape.Index height ~ ix,
    Storable a) =>
   ix -> Matrix height width a -> Vector width a
takeRow ix (Array (height,width) x) =
   Array.unsafeCreateWithSize width $ \rowLen dstPtr ->
      withForeignPtr x $ \srcPtr -> do
         let rowStart = rowLen * Shape.offset height ix
         copyArray dstPtr (advancePtr srcPtr rowStart) rowLen

-- | Extract the column with index @ix@ as a vector over the height shape.
--
-- Consecutive column elements lie one full row apart in the row-major
-- buffer, so BLAS @copy@ reads with stride @Shape.size width@ and writes
-- contiguously (stride 1).
takeColumn ::
   (Shape.C height, Shape.Indexed width, Shape.Index width ~ ix,
    Class.Floating a) =>
   ix -> Matrix height width a -> Vector height a
takeColumn ix (Array (height,width) x) =
   Array.unsafeCreateWithSize height $ \colLen dstPtr ->
   evalContT $ do
      srcPtr <- ContT $ withForeignPtr x
      countPtr <- Call.cint colLen
      srcStridePtr <- Call.cint (Shape.size width)
      dstStridePtr <- Call.cint 1
      -- first element of the requested column lives in row 0
      let colStartPtr = advancePtr srcPtr (Shape.offset width ix)
      liftIO $
         BlasGen.copy countPtr colStartPtr srcStridePtr dstPtr dstStridePtr


-- | Stack a list of row vectors into a matrix; the height is the number
-- of rows given.
--
-- Each row's shape is compared with @width@ via 'Call.assert' before its
-- @Shape.size width@ elements are copied into place.
fromRows ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Vector width a] -> Matrix ShapeInt width a
fromRows width rows =
   Array.unsafeCreate (shapeInt (length rows), width) $ \dstPtr ->
      mapM_ copyRow (zip (pointerSeq rowSize dstPtr) rows)
   where
      rowSize = Shape.size width
      -- copy one source vector into its destination row
      copyRow (dstRowPtr, Array.Array rowWidth srcFPtr) =
         withForeignPtr srcFPtr $ \srcPtr -> do
            Call.assert
               "Matrix.fromRows: non-matching vector size"
               (width == rowWidth)
            copyArray dstRowPtr srcPtr rowSize


-- | Outer product of two vectors as a row-major @height × width@ matrix,
-- computed by one BLAS @gemm@ call with inner dimension 1.
--
-- * @Left c@: element @(i,j)@ is @x_i * y_j@, with @y@ conjugated when
--   @c@ is 'Conjugated'.
--
-- * @Right c@: element @(i,j)@ is @x_i * y_j@, with @x@ conjugated when
--   @c@ is 'Conjugated'.
tensorProduct ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Either Conjugation Conjugation ->
   Vector height a -> Vector width a -> Matrix height width a
tensorProduct side (Array height x) (Array width y) =
   Array.unsafeCreate (height,width) $ \cPtr -> evalContT $ do
      -- BLAS is column-major, so it computes the transpose of the result:
      -- an (m × n) matrix with m = size of width, n = size of height.
      let m = Shape.size width
          n = Shape.size height
          transChar NonConjugated = 'T'
          transChar Conjugated = 'C'
          -- Left: the (transposed) y-row is op(A), x is a 1×n B.
          -- Right: y is an m×1 A, the (transposed) x-column is op(B).
          (transa, transb, lda, ldb) =
             either
                (\c -> (transChar c, 'N', 1, 1))
                (\c -> ('N', transChar c, m, n))
                side
      transaPtr <- Call.char transa
      transbPtr <- Call.char transb
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      kPtr <- Call.cint 1
      alphaPtr <- Call.number one
      yPtr <- ContT $ withForeignPtr y
      ldaPtr <- Call.leadingDim lda
      xPtr <- ContT $ withForeignPtr x
      ldbPtr <- Call.leadingDim ldb
      betaPtr <- Call.number zero
      ldcPtr <- Call.leadingDim m
      liftIO $
         BlasGen.gemm
            transaPtr transbPtr mPtr nPtr kPtr alphaPtr
            yPtr ldaPtr xPtr ldbPtr betaPtr cPtr ldcPtr


-- | View a complex matrix as a real matrix whose width is extended by a
-- 'ComplexPart' enumeration. No data is copied: the same buffer is
-- reinterpreted with 'castForeignPtr' (presumably each complex element
-- becomes its real and imaginary part, in 'Complex' storage order).
decomplex ::
   (Class.Real a) =>
   Matrix height width (Complex a) ->
   Matrix height (width, Shape.Enumeration ComplexPart) a
decomplex (Array (h, w) buffer) =
   Array (h, (w, Shape.Enumeration)) (castForeignPtr buffer)

-- | Inverse of 'decomplex': drop the 'ComplexPart' enumeration from the
-- width and reinterpret the same buffer as complex elements via
-- 'castForeignPtr'. No data is copied.
recomplex ::
   (Class.Real a) =>
   Matrix height (width, Shape.Enumeration ComplexPart) a ->
   Matrix height width (Complex a)
recomplex (Array (h, (w, Shape.Enumeration)) buffer) =
   Array (h, w) (castForeignPtr buffer)


-- | Multiply row @k@ of the matrix by element @k@ of the vector
-- (i.e. @diag x · a@).
--
-- The vector shape must equal the matrix height (checked with
-- 'Call.assert'). Each row is copied into the result with BLAS @copy@ and
-- then scaled in place with BLAS @scal@.
scaleRows ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix height width a -> Matrix height width a
scaleRows (Array heightX x) (Array shape@(height,width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "scaleRows: sizes mismatch" (heightX == height)
      evalContT $ do
         let numRows = Shape.size height
             rowLen = Shape.size width
         rowLenPtr <- Call.cint rowLen
         xPtr <- ContT $ withForeignPtr x
         aPtr <- ContT $ withForeignPtr a
         incaPtr <- Call.cint 1
         incbPtr <- Call.cint 1
         -- (scale factor, source row, destination row) for every row
         let rowTriples =
                take numRows $
                zip3 (pointerSeq 1 xPtr)
                     (pointerSeq rowLen aPtr)
                     (pointerSeq rowLen bPtr)
         liftIO $ forM_ rowTriples $ \(xkPtr, akPtr, bkPtr) -> do
            BlasGen.copy rowLenPtr akPtr incaPtr bkPtr incbPtr
            BlasGen.scal rowLenPtr xkPtr bkPtr incbPtr

-- | Multiply column @k@ of the matrix by element @k@ of the vector
-- (i.e. @a · diag x@).
--
-- The vector shape must equal the matrix width (checked with
-- 'Call.assert'). Every row is processed with one @gbmv@ call, with the
-- vector treated as a band matrix that has no sub- or super-diagonals
-- (kl = ku = 0, leading dimension 1), alpha = one and beta = zero.
-- Assumes 'Private.gbmv' follows the BLAS ?GBMV argument order.
scaleColumns ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a -> Matrix height width a -> Matrix height width a
scaleColumns (Array widthX x) (Array shape@(height,width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "scaleColumns: sizes mismatch" (widthX == width)
      evalContT $ do
         let numRows = Shape.size height
             rowLen = Shape.size width
         transPtr <- Call.char 'N'
         rowLenPtr <- Call.cint rowLen
         klPtr <- Call.cint 0
         kuPtr <- Call.cint 0
         alphaPtr <- Call.number one
         xPtr <- ContT $ withForeignPtr x
         ldxPtr <- Call.leadingDim 1
         aPtr <- ContT $ withForeignPtr a
         incaPtr <- Call.cint 1
         betaPtr <- Call.number zero
         incbPtr <- Call.cint 1
         -- (source row, destination row) for every row
         let rowPairs =
                take numRows $
                zip (pointerSeq rowLen aPtr) (pointerSeq rowLen bPtr)
         liftIO $ forM_ rowPairs $ \(akPtr, bkPtr) ->
            Private.gbmv transPtr
               rowLenPtr rowLenPtr klPtr kuPtr alphaPtr xPtr ldxPtr
               akPtr incaPtr betaPtr bkPtr incbPtr


-- | Kronecker product of a full matrix @A@ (stored in either order) and a
-- row-major matrix @B@, as a row-major matrix.
--
-- Element @((i,j),(k,l))@ of the result is @A[i,k] * B[j,l]@. Result row
-- @(i,j)@ is the outer product of row @i@ of @A@ with row @j@ of @B@,
-- laid out as an @na × nb@ row-major block. Each of these @ma*mb@ rows is
-- produced by one rank-1 BLAS @gemm@ call.
kronecker ::
   (Extent.C vert, Extent.C horiz,
    Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   Full vert horiz heightA widthA a ->
   Matrix heightB widthB a ->
   Matrix (heightA,heightB) (widthA,widthB) a
kronecker
      (Array (MatrixShape.Full orderA extentA) a) (Array (heightB,widthB) b) =
   Array.unsafeCreate ((heightA,heightB), (widthA,widthB)) $ \cPtr ->
   evalContT $ do
      let ma = Shape.size heightA
          na = Shape.size widthA
          mb = Shape.size heightB
          nb = Shape.size widthB
          -- lda: stride between consecutive elements of one row of A;
          -- rowStep: distance between the starts of rows i and i+1 of A.
          (lda, rowStep) =
             case orderA of
                RowMajor -> (1, na)
                ColumnMajor -> (ma, 1)
      -- BLAS sees each (na × nb) row-major block as its column-major
      -- transpose (nb × na) = (row j of B)^T · (row i of A).
      transBPtr <- Call.char 'T'
      transAPtr <- Call.char 'N'
      nbPtr <- Call.cint nb
      naPtr <- Call.cint na
      kPtr <- Call.cint 1
      alphaPtr <- Call.number one
      bPtr <- ContT $ withForeignPtr b
      ldbPtr <- Call.leadingDim 1
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.leadingDim lda
      betaPtr <- Call.number zero
      ldcPtr <- Call.leadingDim nb
      liftIO $
         forM_ [(i,j) | i <- take ma [0..], j <- take mb [0..]] $ \(i,j) ->
            BlasGen.gemm
               transBPtr transAPtr nbPtr naPtr kPtr alphaPtr
               (advancePtr bPtr (nb*j)) ldbPtr
               (advancePtr aPtr (rowStep*i)) ldaPtr
               betaPtr
               (advancePtr cPtr (na*nb*(mb*i + j))) ldcPtr
   where
      (heightA, widthA) = Extent.dimensions extentA