{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.RowMajor (
   Matrix,
   Matrix.takeRow,
   Matrix.takeColumn,
   Matrix.fromRows,
   Matrix.tensorProduct,
   Matrix.decomplex,
   Matrix.recomplex,
   Matrix.scaleRows,
   Matrix.scaleColumns,
   kronecker,
   ) where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Private (Full)

import qualified Numeric.BLAS.Matrix.RowMajor as Matrix
import Numeric.BLAS.Matrix.RowMajor (Matrix)
import Numeric.BLAS.Matrix.Layout (Order(RowMajor, ColumnMajor))
import Numeric.BLAS.Scalar (zero, one)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (liftA2)

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Data.Foldable (forM_)


{- |
Kronecker product of a full matrix @a@ (stored in either row-major
or column-major order) and a row-major matrix @b@.

The result is a row-major matrix.
Its rows are indexed by pairs @(i,j)@ of a row @i@ of @a@ and a row @j@ of @b@.
Its columns are indexed by pairs @(k,l)@ of a column @k@ of @a@
and a column @l@ of @b@.
The element at @((i,j),(k,l))@ is @a!(i,k) * b!(j,l)@ (no conjugation).
-}
kronecker ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   Full meas vert horiz heightA widthA a ->
   Matrix heightB widthB a ->
   Matrix (heightA,heightB) (widthA,widthB) a
kronecker (Array (Layout.Full orderA extentA) a) (Array (heightB,widthB) b) =
   let (heightA,widthA) = Extent.dimensions extentA
   in Array.unsafeCreate ((heightA,heightB), (widthA,widthB)) $ \cPtr ->
   evalContT $ do
      let (ma,na) = (Shape.size heightA, Shape.size widthA)
      let (mb,nb) = (Shape.size heightB, Shape.size widthB)
      -- Row i of 'a' starts at element offset istep*i, and its consecutive
      -- entries are lda elements apart: contiguous for row-major storage,
      -- strided by the height ma for column-major storage.
      let (lda,istep) =
            case orderA of
               RowMajor -> (1,na)
               ColumnMajor -> (ma,1)
      -- Note: transaPtr ('N') and transbPtr ('T') are passed to GEMM in
      -- swapped positions below, see the comment at the call site.
      transaPtr <- Call.char 'N'
      transbPtr <- Call.char 'T'
      mPtr <- Call.cint na
      nPtr <- Call.cint nb
      kPtr <- Call.cint 1
      alphaPtr <- Call.number one
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.leadingDim lda
      bPtr <- ContT $ withForeignPtr b
      ldbPtr <- Call.leadingDim 1
      betaPtr <- Call.number zero
      ldcPtr <- Call.leadingDim nb
      liftIO $
         -- All pairs (i,j) with i outer and j inner, i.e. result rows
         -- in ascending memory order.
         forM_ (liftA2 (,) (take ma [0..]) (take mb [0..])) $ \(i,j) -> do
            let aiPtr = advancePtr aPtr (istep*i)
            let bjPtr = advancePtr bPtr (nb*j)
            -- Result row (i,j) holds na*nb elements; it is row number
            -- i*mb + j of the result.
            let cijPtr = advancePtr cPtr (na*nb*(j+mb*i))
            -- Result row (i,j) is the outer product of row i of 'a' and
            -- row j of 'b', laid out as an na x nb block in row-major order.
            -- Column-major BLAS sees that block as an nb x na matrix with
            -- leading dimension nb, namely (row j of b)^T * (row i of a).
            -- That is a rank-1 GEMM (K = 1), hence the operand order:
            -- Fortran A = row j of 'b' (1 x nb, transposed via 'T'),
            -- Fortran B = row i of 'a' (1 x na, taken as is via 'N').
            BlasGen.gemm
               transbPtr transaPtr nPtr mPtr kPtr alphaPtr
               bjPtr ldbPtr aiPtr ldaPtr betaPtr cijPtr ldcPtr