{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.RowMajor (
   Matrix,
   Matrix.takeRow,
   Matrix.takeColumn,
   Matrix.fromRows,
   Matrix.tensorProduct,
   Matrix.decomplex,
   Matrix.recomplex,
   Matrix.scaleRows,
   Matrix.scaleColumns,
   kronecker,
   ) where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Private (Full)

import qualified Numeric.BLAS.Matrix.RowMajor as Matrix
import Numeric.BLAS.Matrix.RowMajor (Matrix)
import Numeric.BLAS.Matrix.Layout (Order(RowMajor, ColumnMajor))
import Numeric.BLAS.Scalar (zero, one)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (liftA2)

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Data.Foldable (forM_)


{- |
Kronecker product of a full matrix with a row-major matrix.

For @a@ of shape @heightA × widthA@ and @b@ of shape @heightB × widthB@
the result has shape @(heightA,heightB) × (widthA,widthB)@.
Its element in row @(i,j)@ and column @(k,l)@ is @a!(i,k) * b!(j,l)@.
The first factor may be stored in row-major or column-major order.
The result is always stored row-major.
-}
kronecker ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   Full meas vert horiz heightA widthA a ->
   Matrix heightB widthB a ->
   Matrix (heightA,heightB) (widthA,widthB) a
kronecker
      (Array (Layout.Full orderA extentA) a) (Array (heightB,widthB) b) =
   let (heightA,widthA) = Extent.dimensions extentA
   in Array.unsafeCreate ((heightA,heightB), (widthA,widthB)) $ \cPtr ->
      evalContT $ do
         let (ma,na) = (Shape.size heightA, Shape.size widthA)
         let (mb,nb) = (Shape.size heightB, Shape.size widthB)
         -- Row i of A starts at element offset istep*i, and consecutive
         -- elements of that row lie lda apart; both depend on A's order.
         let (lda,istep) =
               case orderA of
                  RowMajor -> (1,na)
                  ColumnMajor -> (ma,1)
         -- BLAS is column-major: the row-major na×nb outer product that
         -- forms one result row is the column-major nb×na matrix
         -- (row j of B)^T * (row i of A), hence B comes first and is
         -- transposed, and the gemm dimensions are M=nb, N=na, K=1.
         transaPtr <- Call.char 'N'
         transbPtr <- Call.char 'T'
         gemmMPtr <- Call.cint nb
         gemmNPtr <- Call.cint na
         gemmKPtr <- Call.cint 1
         alphaPtr <- Call.number one
         aPtr <- ContT $ withForeignPtr a
         ldaPtr <- Call.leadingDim lda
         bPtr <- ContT $ withForeignPtr b
         -- a contiguous row of B viewed as a 1×nb matrix
         ldbPtr <- Call.leadingDim 1
         betaPtr <- Call.number zero
         ldcPtr <- Call.leadingDim nb
         liftIO $
            forM_ (liftA2 (,) (take ma [0..]) (take mb [0..])) $ \(i,j) -> do
               let aiPtr = advancePtr aPtr (istep*i)
               let bjPtr = advancePtr bPtr (nb*j)
               -- result row (i,j) has linear row index i*mb+j
               -- and na*nb elements
               let cijPtr = advancePtr cPtr (na*nb*(j+mb*i))
               BlasGen.gemm
                  transbPtr transaPtr gemmMPtr gemmNPtr gemmKPtr alphaPtr
                  bjPtr ldbPtr aiPtr ldaPtr betaPtr cijPtr ldcPtr