{- |
Square matrices: matrices whose height and width share one shape @sh@.
-}
module Numeric.LAPACK.Matrix.Square (
   Square,
   size,
   toGeneral,
   fromGeneral,
   fromScalar,
   toScalar,
   fromList,
   autoFromList,
   transpose,
   adjoint,
   identity,
   identityFrom,
   diagonal,
   getDiagonal,
   trace,
   multiply,
   square,
   power,
   ) where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Shape.Private (Order(RowMajor, ColumnMajor))
import Numeric.LAPACK.Matrix.Private (General, ZeroInt, zeroInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Private (zero, one)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array(Array))

import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Storable (Storable, peek, poke)

import System.IO.Unsafe (unsafePerformIO)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Function.HT (powerAssociative)


-- | A storable array whose shape is a square matrix shape over @sh@.
type Square sh = Array (MatrixShape.Square sh)

-- | The shape shared by the height and the width of the matrix.
size :: Square sh a -> sh
size = MatrixShape.squareSize . Array.shape

-- | View a square matrix as a general matrix.
-- The element buffer is reused, not copied.
toGeneral :: Square sh a -> General sh sh a
toGeneral (Array sh a) = Array (MatrixShape.generalFromSquare sh) a

-- | Reinterpret a general matrix as a square one, keeping its storage order.
-- Calls 'error' when height and width differ.
fromGeneral :: (Eq sh) => General sh sh a -> Square sh a
fromGeneral (Array (MatrixShape.General order height width) a) =
   if height==width
     then Array (MatrixShape.Square order height) a
     else error "Square.fromGeneral: no square shape"

-- | A 1x1 matrix (shape @()@) holding the given element.
fromScalar :: (Storable a) => a -> Square () a
fromScalar a =
   Array.unsafeCreate (MatrixShape.Square RowMajor ()) $ flip poke a

-- | Extract the single element of a 1x1 matrix.
--
-- 'unsafePerformIO' is used to read the buffer purely.
-- NOTE(review): this assumes the array buffer is never mutated after
-- creation; confirm against the Array.Internal invariants.
toScalar :: (Storable a) => Square () a -> a
toScalar (Array (MatrixShape.Square _ ()) a) =
   unsafePerformIO $ withForeignPtr a peek

-- | Build a square matrix from elements given in row-major order.
fromList :: (Shape.C sh, Storable a) => sh -> [a] -> Square sh a
fromList sh = Array.fromList (MatrixShape.Square RowMajor sh)

-- | Like 'fromList', but derives the side length @m@ from the list length.
-- The length must be a perfect square @m*m@; otherwise 'error' is called.
autoFromList :: (Storable a) => [a] -> Square ZeroInt a
autoFromList xs =
   let n = length xs
       m = round $ sqrt (fromIntegral n :: Double)
   in  if n == m*m
         then fromList (zeroInt m) xs
         else error "Square.autoFromList: no quadratic number of elements"

-- | Transpose by changing only the shape's storage order.
-- The element buffer is not touched.
transpose :: Square sh a -> Square sh a
transpose = Array.mapShape MatrixShape.transposeSquare

{- | conjugate transpose -}
adjoint :: (Shape.C sh, Class.Floating a) => Square sh a -> Square sh a
adjoint = transpose . Vector.conjugate

-- | Identity matrix over the given shape, stored in column-major order.
identity :: (Shape.C sh, Class.Floating a) => sh -> Square sh a
identity = identityOrder ColumnMajor

-- | Identity matrix with the same shape and storage order as the argument.
-- The argument's elements are ignored.
identityFrom :: (Shape.C sh, Class.Floating a) => Square sh a -> Square sh a
identityFrom (Array (MatrixShape.Square order sh) _) = identityOrder order sh

-- | Identity matrix with an explicit storage order.
--
-- 'identityOrder' fills the matrix with LAPACK @laset@, using @uplo = 'A'@,
-- @alpha = zero@ (off-diagonal) and @beta = one@ (diagonal).
--
-- '_identityOrder' is an alternative BLAS-only implementation.
-- It is not exported and appears unused; it is kept for reference.
identityOrder, _identityOrder ::
   (Shape.C sh, Class.Floating a) => Order -> sh -> Square sh a
identityOrder order sh =
   Array.unsafeCreate (MatrixShape.Square order sh) $ \aPtr ->
   evalContT $ do
      uploPtr <- Call.char 'A'
      nPtr <- Call.cint $ Shape.size sh
      alphaPtr <- Call.number zero
      betaPtr <- Call.number one
      liftIO $ LapackGen.laset uploPtr nPtr nPtr alphaPtr betaPtr aPtr nPtr

_identityOrder order sh =
   Array.unsafeCreateWithSize (MatrixShape.Square order sh) $
   \blockSize yPtr -> evalContT $ do
      nPtr <- Call.alloca
      xPtr <- Call.number zero
      incxPtr <- Call.cint 0
      incyPtr <- Call.cint 1
      liftIO $ do
         -- incx = 0 repeats the single source element:
         -- zero out the whole block.
         poke nPtr $ fromIntegral blockSize
         BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr
         -- With target stride n+1, the n writes land on the diagonal.
         let n = fromIntegral $ Shape.size sh
         poke nPtr n
         poke xPtr one
         poke incyPtr (n+1)
         BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr

-- | Square matrix (column-major) with the vector on its diagonal
-- and zeros elsewhere.
diagonal :: (Shape.C sh, Class.Floating a) => Vector sh a -> Square sh a
diagonal (Array sh x) =
   Array.unsafeCreateWithSize (MatrixShape.Square ColumnMajor sh) $
   \blockSize yPtr -> evalContT $ do
      nPtr <- Call.alloca
      xPtr <- ContT $ withForeignPtr x
      zPtr <- Call.number zero
      incxPtr <- Call.cint 1
      incyPtr <- Call.cint 1
      inczPtr <- Call.cint 0
      liftIO $ do
         -- Clear the whole block by broadcasting the zero (stride 0).
         poke nPtr $ fromIntegral blockSize
         BlasGen.copy nPtr zPtr inczPtr yPtr incyPtr
         -- Scatter the vector along the diagonal (target stride n+1).
         let n = fromIntegral $ Shape.size sh
         poke nPtr n
         poke incyPtr (n+1)
         BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr

-- | The diagonal elements as a vector.
-- Read with source stride n+1, which is the same for either storage order.
getDiagonal :: (Shape.C sh, Class.Floating a) => Square sh a -> Vector sh a
getDiagonal (Array (MatrixShape.Square _ sh) x) =
   Array.unsafeCreateWithSize sh $ \n yPtr -> evalContT $ do
      nPtr <- Call.cint n
      xPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint (n+1)
      incyPtr <- Call.cint 1
      liftIO $ BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr

-- | Sum of the diagonal elements (stride n+1 through the buffer).
trace :: (Shape.C sh, Class.Floating a) => Square sh a -> a
trace (Array (MatrixShape.Square _ sh) x) = unsafePerformIO $ do
   let n = Shape.size sh
   withForeignPtr x $ \xPtr -> Private.sum n xPtr (n+1)

-- | Matrix product of two square matrices of the same shape,
-- computed via 'Private.multiplyMatrix'.
--
-- The operands may have any storage orders; the result is always
-- column-major. A shape mismatch is reported through 'Call.assert'.
multiply ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Square sh a -> Square sh a -> Square sh a
multiply
      (Array (MatrixShape.Square orderA shA) a)
      (Array (MatrixShape.Square orderB shB) b) =
   Array.unsafeCreate (MatrixShape.Square ColumnMajor shA) $ \cPtr -> do
      Call.assert "Square.multiply: shapes mismatch" (shA == shB)
      let n = Shape.size shA
      Private.multiplyMatrix orderA orderB n n n a b cPtr

-- | Multiply a matrix by itself. The result keeps the argument's storage order.
square :: (Shape.C sh, Class.Floating a) => Square sh a -> Square sh a
square a = multiplyCommutativeUnchecked a a

-- | @power n a@ raises @a@ to the non-negative integer power @n@.
-- @power 0 a@ is the identity with @a@'s shape and order.
--
-- The computation uses 'powerAssociative', starting from 'identityFrom' @a@.
-- All intermediate factors are powers of @a$, so they commute and share
-- @a@'s storage order, as 'multiplyCommutativeUnchecked' requires.
--
-- A negative exponent calls 'error'. Without this guard the call would not
-- terminate: 'powerAssociative' halves the exponent with 'div', and for
-- negative input that never reaches 0.
power ::
   (Shape.C sh, Class.Floating a) =>
   Integer -> Square sh a -> Square sh a
power n a
   | n < 0 = error "Square.power: negative exponent"
   | otherwise =
      powerAssociative multiplyCommutativeUnchecked (identityFrom a) a n

{- |
Product of two square matrices that are known to commute.

The storage orders of both operands must be equal, but this is not checked.
The shape (and thus the order) of the result is taken from the first operand.
For a row-major order the operands are swapped. This is harmless here,
because the factors commute.
-}
multiplyCommutativeUnchecked ::
   (Shape.C sh, Class.Floating a) =>
   Square sh a -> Square sh a -> Square sh a
multiplyCommutativeUnchecked
      (Array shape@(MatrixShape.Square order sh) a)
      (Array (MatrixShape.Square _order _sh) b) =
   Array.unsafeCreate shape $ \cPtr ->
      let n = Shape.size sh
          (at,bt) =
             case order of
                ColumnMajor -> (a,b)
                RowMajor -> (b,a)
      in  Private.multiplyMatrix ColumnMajor ColumnMajor n n n at bt cPtr