{- |
Basic operations on dense vectors, implemented by calls into BLAS and LAPACK.

A 'Vector' is a storable array with an arbitrary shape;
all functions here treat the array as a flat sequence of
@Shape.size sh@ elements with unit stride.
-}
module Numeric.LAPACK.Vector (
   Vector,
   fromList,
   autoFromList,
   constant,

   dot,
   sum,
   absSum,
   norm1,
   norm2,
   argAbsMaximum,
   argAbs1Maximum,
   product,

   scale,
   add, sub,
   mac,
   mul,
   outer,

   conjugate,

   random, RandomDistribution(..),
   ) where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Private (RealOf, zero, one, minusOne, fill)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.BLAS.FFI.Complex as BlasComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, peek, poke, peekElemOff, pokeElemOff)
import Foreign.C.Types (CInt)

import System.IO.Unsafe (unsafePerformIO)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (Const(Const,getConst), (<$>))

import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array(Array))

import Data.Complex (Complex)
import Data.Word (Word64)
import Data.Bits (shiftR, (.&.))

import Prelude hiding (sum, product)


{- |
A vector is just a storable array; the shape type @sh@ is arbitrary.
-}
type Vector = Array


{- |
Create a vector of the given shape from a list of elements.
Alias for 'Array.fromList'.
-}
fromList :: (Shape.C sh, Storable a) => sh -> [a] -> Vector sh a
fromList = Array.fromList

{- |
Create a zero-based vector from a list of elements.
Alias for 'Array.vectorFromList'.
-}
autoFromList :: (Storable a) => [a] -> Vector (Shape.ZeroBased Int) a
autoFromList = Array.vectorFromList

{- |
@constant sh a@ creates a vector of shape @sh@
whose storage is initialised by @fill a@
(presumably every element equals @a@ - see 'fill').
-}
constant :: (Shape.C sh, Class.Floating a) => sh -> a -> Vector sh a
constant sh a = Array.unsafeCreateWithSize sh $ fill a


-- Wrapper that lets 'Class.switchFloating' select a 'dot' implementation.
newtype Dot sh a = Dot {runDot :: Vector sh a -> Vector sh a -> a}

{- |
Scalar product of two vectors of equal shape.
The complex variant multiplies the elements as they are,
that is, no operand is conjugated.

Fails with an assertion if the shapes differ.
-}
dot ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> a
dot =
   runDot $
   Class.switchFloating
      (Dot dotReal) (Dot dotReal) (Dot dotComplex) (Dot dotComplex)

-- Real scalar product via the BLAS @dot@ routine.
dotReal ::
   (Shape.C sh, Eq sh, Class.Real a) =>
   Vector sh a -> Vector sh a -> a
dotReal (Array shX x) (Array shY y) = unsafePerformIO $ do
   Call.assert "dot: shapes mismatch" (shX == shY)
   evalContT $ do
      nPtr <- Call.cint $ Shape.size shX
      sxPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      syPtr <- ContT $ withForeignPtr y
      incyPtr <- Call.cint 1
      liftIO $ BlasReal.dot nPtr sxPtr incxPtr syPtr incyPtr

{-
We cannot use 'cdot' because Haskell's FFI
does not support Complex numbers as return values.

Instead the first vector is interpreted as a 1 x n matrix
that is multiplied with the second vector using 'gemv'.
The product is a single number written to @zPtr@.
-}
dotComplex ::
   (Shape.C sh, Eq sh, Class.Real a) =>
   Vector sh (Complex a) -> Vector sh (Complex a) -> Complex a
dotComplex (Array shX x) (Array shY y) = unsafePerformIO $ do
   Call.assert "dot: shapes mismatch" (shX == shY)
   evalContT $ do
      transPtr <- Call.char 'N'
      mPtr <- Call.cint 1
      nPtr <- Call.cint $ Shape.size shX
      alphaPtr <- Call.number one
      xPtr <- ContT $ withForeignPtr x
      ldxPtr <- Call.cint 1
      yPtr <- ContT $ withForeignPtr y
      incyPtr <- Call.cint 1
      betaPtr <- Call.number zero
      zPtr <- Call.alloca
      inczPtr <- Call.cint 1
      -- The reference GEMV returns immediately if n == 0
      -- and leaves the output untouched even for beta == 0.
      -- Pre-initialise the result so that the scalar product
      -- of two empty vectors is zero instead of uninitialised memory.
      liftIO $ poke zPtr zero
      liftIO $
         BlasGen.gemv
            transPtr mPtr nPtr alphaPtr xPtr ldxPtr
            yPtr incyPtr betaPtr zPtr inczPtr
      liftIO $ peek zPtr

{- |
Sum of all elements, computed by 'Private.sum' with unit stride.
-}
sum :: (Shape.C sh, Class.Floating a) => Vector sh a -> a
sum (Array sh x) =
   unsafePerformIO $
   withForeignPtr x $ \xPtr ->
      Private.sum (Shape.size sh) xPtr 1


{- |
Sum of the absolute values of the elements.
For complex numbers this uses LAPACK's @sum1@ routine
in contrast to 'absSum', which uses BLAS' @asum@.
-}
norm1 :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
norm1 (Array sh x) =
   unsafePerformIO $
   evalContT $ do
      nPtr <- Call.cint $ Shape.size sh
      sxPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      liftIO $ csum1 nPtr sxPtr incxPtr

-- Dispatch: BLAS @asum@ for real types, LAPACK @sum1@ for complex types.
csum1 :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)
csum1 =
   getNorm $
   Class.switchFloating
      (Norm BlasReal.asum) (Norm BlasReal.asum)
      (Norm LapackComplex.sum1) (Norm LapackComplex.sum1)


{- |
Sum of the absolute values of real numbers or components of complex numbers.
For real numbers it is equivalent to 'norm1'.
-}
absSum :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
absSum (Array sh x) =
   unsafePerformIO $
   evalContT $ do
      nPtr <- Call.cint $ Shape.size sh
      sxPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      liftIO $ asum nPtr sxPtr incxPtr

-- Dispatch: BLAS @asum@ for real types, BLAS @casum@ for complex types.
asum :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)
asum =
   getNorm $
   Class.switchFloating
      (Norm BlasReal.asum) (Norm BlasReal.asum)
      (Norm BlasComplex.casum) (Norm BlasComplex.casum)


{- |
Euclidean norm of a vector or Frobenius norm of a matrix.
-}
norm2 :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
norm2 (Array sh x) =
   unsafePerformIO $
   evalContT $ do
      nPtr <- Call.cint $ Shape.size sh
      sxPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      liftIO $ nrm2 nPtr sxPtr incxPtr

-- Dispatch: BLAS @nrm2@ for real types, BLAS @cnrm2@ for complex types.
nrm2 :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)
nrm2 =
   getNorm $
   Class.switchFloating
      (Norm BlasReal.nrm2) (Norm BlasReal.nrm2)
      (Norm BlasComplex.cnrm2) (Norm BlasComplex.cnrm2)

-- Wrapper for norm-like FFI routines (size, data, stride -> real result)
-- that lets 'Class.switchFloating' select an implementation.
newtype Norm a =
   Norm {getNorm :: Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)}


{- |
Returns the index and value of the element with the maximal absolute value.
Caution: It actually returns the value of the element, not its absolute value!

Raises an error for an empty vector.
-}
argAbsMaximum ::
   (Shape.C sh, Class.Floating a) =>
   Vector sh a -> (Shape.Index sh, a)
argAbsMaximum (Array sh x) =
   unsafePerformIO $
   evalContT $ do
      nPtr <- Call.cint $ Shape.size sh
      sxPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      liftIO $
         peekElemOff1 "argAbsMaximum" sh sxPtr
            =<< absMax nPtr sxPtr incxPtr

-- Wrapper for index-of-maximum FFI routines
-- that lets 'Class.switchFloating' select an implementation.
newtype ArgMaximum a =
   ArgMaximum {runArgMaximum :: Ptr CInt -> Ptr a -> Ptr CInt -> IO CInt}

-- Dispatch: BLAS @iamax@ for real types, LAPACK @imax1@ for complex types.
absMax :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO CInt
absMax =
   runArgMaximum $
   Class.switchFloating
      (ArgMaximum BlasGen.iamax) (ArgMaximum BlasGen.iamax)
      (ArgMaximum LapackComplex.imax1) (ArgMaximum LapackComplex.imax1)


{- |
Returns the index and value of the element with the maximal absolute value.
The function does not strictly compare the absolute value of a complex number
but the sum of the absolute complex components.
Caution: It actually returns the value of the element, not its absolute value!

Raises an error for an empty vector.
-}
argAbs1Maximum ::
   (Shape.C sh, Class.Floating a) =>
   Vector sh a -> (Shape.Index sh, a)
argAbs1Maximum (Array sh x) =
   unsafePerformIO $
   evalContT $ do
      nPtr <- Call.cint $ Shape.size sh
      sxPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      liftIO $
         peekElemOff1 "argAbs1Maximum" sh sxPtr
            =<< BlasGen.iamax nPtr sxPtr incxPtr

{-
Convert the one-based position returned by a BLAS/LAPACK index routine
into the pair of shape index and element value.

These routines return 0 if the vector is empty.
In this case there is no element that could be returned,
thus we raise an error instead of reading in front of the buffer.
The first argument is the name of the calling function
and is only used for the error message.
-}
peekElemOff1 ::
   (Shape.C sh, Storable a) =>
   String -> sh -> Ptr a -> CInt -> IO (Shape.Index sh, a)
peekElemOff1 name sh ptr k1 =
   let k = fromIntegral k1 - 1
   in if k < 0
         then error $ "Vector." ++ name ++ ": empty vector"
         else do
            xk <- peekElemOff ptr k
            return (Shape.indices sh !! k, xk)


{- |
Product of all elements, computed by 'Private.product' with unit stride.
-}
product :: (Shape.C sh, Class.Floating a) => Vector sh a -> a
product (Array sh x) =
   unsafePerformIO $
   withForeignPtr x $ \xPtr ->
      Private.product (Shape.size sh) xPtr 1


{- |
@scale alpha x@ multiplies every element of @x@ by @alpha@.

'scale' copies the input and scales the copy in place (@copy@, @scal@).
@_scale@ is an alternative implementation of the same operation
using a matrix product with a 1 x 1 matrix; it is not exported.
-}
scale, _scale ::
   (Shape.C sh, Class.Floating a) =>
   a -> Vector sh a -> Vector sh a

scale alpha (Array sh x) = Array.unsafeCreateWithSize sh $ \n syPtr -> do
   evalContT $ do
      alphaPtr <- Call.number alpha
      nPtr <- Call.cint n
      sxPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      incyPtr <- Call.cint 1
      liftIO $ BlasGen.copy nPtr sxPtr incxPtr syPtr incyPtr
      liftIO $ BlasGen.scal nPtr alphaPtr syPtr incyPtr

_scale a (Array sh b) = Array.unsafeCreateWithSize sh $ \n cPtr -> do
   -- The scalar is the 1 x 1 matrix A, the vector is the 1 x n matrix B.
   let m = 1
   let k = 1
   evalContT $ do
      transaPtr <- Call.char 'N'
      transbPtr <- Call.char 'N'
      mPtr <- Call.cint m
      kPtr <- Call.cint k
      nPtr <- Call.cint n
      alphaPtr <- Call.number one
      aPtr <- Call.number a
      ldaPtr <- Call.cint m
      bPtr <- ContT $ withForeignPtr b
      ldbPtr <- Call.cint k
      betaPtr <- Call.number zero
      ldcPtr <- Call.cint m
      liftIO $
         BlasGen.gemm
            transaPtr transbPtr mPtr nPtr kPtr alphaPtr
            aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr


{- |
Element-wise sum and difference of two vectors of equal shape.
@sub x y@ computes @x - y@.
Both are special cases of 'mac'.
-}
add, sub ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> Vector sh a
add = mac one
sub x y = mac minusOne y x

{- |
@mac alpha x y@ computes @alpha * x + y@ element-wise.
It copies @y@ to the result and then applies @axpy@ to it.

Fails with an assertion if the shapes differ.
-}
mac ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   a -> Vector sh a -> Vector sh a -> Vector sh a
mac alpha (Array shX x) (Array shY y) =
   Array.unsafeCreateWithSize shX $ \n szPtr -> do
   Call.assert "mac: shapes mismatch" (shX == shY)
   evalContT $ do
      nPtr <- Call.cint n
      saPtr <- Call.number alpha
      sxPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      syPtr <- ContT $ withForeignPtr y
      incyPtr <- Call.cint 1
      inczPtr <- Call.cint 1
      liftIO $ BlasGen.copy nPtr syPtr incyPtr szPtr inczPtr
      liftIO $ BlasGen.axpy nPtr saPtr sxPtr incxPtr szPtr inczPtr


{- |
Element-wise product of two vectors of equal shape.

The first vector is passed to @gbmv@ as a band matrix
without sub- and super-diagonals, i.e. as a diagonal matrix,
and multiplied with the second vector.

Fails with an assertion if the shapes differ.
-}
mul ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> Vector sh a
mul (Array shA a) (Array shX x) =
   Array.unsafeCreateWithSize shX $ \n yPtr -> do
   Call.assert "mul: shapes mismatch" (shA == shX)
   evalContT $ do
      transPtr <- Call.char 'N'
      mPtr <- Call.cint n
      nPtr <- Call.cint n
      klPtr <- Call.cint 0
      kuPtr <- Call.cint 0
      alphaPtr <- Call.number one
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.cint 1
      xPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      betaPtr <- Call.number zero
      incyPtr <- Call.cint 1
      liftIO $
         BlasGen.gbmv transPtr mPtr nPtr klPtr kuPtr
            alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr


{- |
Outer product of two vectors, stored as a column-major general matrix.
The element at row @i@ and column @j@ is @x_i * y_j@;
no operand is conjugated.

It is computed by @gemm@ as the product
of an m x 1 matrix with a 1 x n matrix.
-}
outer ::
   (Shape.C shx, Eq shx, Shape.C shy, Eq shy, Class.Floating a) =>
   Vector shx a -> Vector shy a -> Matrix.General shx shy a
outer (Array shX x) (Array shY y) =
   Array.unsafeCreate
      (MatrixShape.General MatrixShape.ColumnMajor shX shY) $ \cPtr -> do
   let m = Shape.size shX
   let n = Shape.size shY
   -- GEMM demands leading dimensions of at least max(1,m),
   -- also if m is zero. A leading dimension of zero
   -- makes the BLAS routine report an illegal argument.
   let ld = max 1 m
   evalContT $ do
      transaPtr <- Call.char 'N'
      transbPtr <- Call.char 'N'
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      kPtr <- Call.cint 1
      alphaPtr <- Call.number one
      aPtr <- ContT $ withForeignPtr x
      ldaPtr <- Call.cint ld
      bPtr <- ContT $ withForeignPtr y
      ldbPtr <- Call.cint 1
      betaPtr <- Call.number zero
      ldcPtr <- Call.cint ld
      liftIO $
         BlasGen.gemm
            transaPtr transbPtr mPtr nPtr kPtr alphaPtr
            aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr


-- Wrapper that lets 'Class.switchFloating' select a 'conjugate' implementation.
newtype Conjugate sh a = Conjugate {getConjugate :: Vector sh a -> Vector sh a}

{- |
Complex conjugate of every element.
For real element types this is the identity function.
-}
conjugate ::
   (Shape.C sh, Class.Floating a) =>
   Vector sh a -> Vector sh a
conjugate =
   getConjugate $
   Class.switchFloating
      (Conjugate id) (Conjugate id)
      (Conjugate complexConjugate) (Conjugate complexConjugate)

-- Copy the vector and conjugate the copy in place using LAPACK's @lacgv@.
complexConjugate ::
   (Shape.C sh, Class.Real a) =>
   Vector sh (Complex a) -> Vector sh (Complex a)
complexConjugate (Array sh x) = Array.unsafeCreateWithSize sh $ \n syPtr ->
   evalContT $ do
      nPtr <- Call.cint n
      sxPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      incyPtr <- Call.cint 1
      liftIO $ do
         BlasGen.copy nPtr sxPtr incxPtr syPtr incyPtr
         LapackComplex.lacgv nPtr syPtr incyPtr


{- |
Distributions supported by 'random'.
They are translated to the distribution codes of LAPACK's @larnv@:
'UniformBox01' to 1, 'UniformBoxPM1' to 2, 'Normal' to 3,
and for complex numbers 'UniformDisc' to 4 and 'UniformCircle' to 5.
For real numbers 'UniformDisc' is mapped to code 2
and 'UniformCircle' is not supported.
-}
data RandomDistribution =
     UniformBox01
   | UniformBoxPM1
   | Normal
   | UniformDisc
   | UniformCircle
   deriving (Eq, Ord, Show, Enum)

{- |
@random distribution shape seed@

Fill a vector of the given shape with random numbers from LAPACK's @larnv@.

Only the least significant 47 bits of @seed@ are used.
They are spread over the four seed words that @larnv@ expects:
three words with 12 bits each
and a last word made from the lowest 11 bits by @2*b+1@,
which makes it odd.

Requesting 'UniformCircle' for a real element type raises an error.
-}
random ::
   (Shape.C sh, Class.Floating a) =>
   RandomDistribution -> sh -> Word64 -> Vector sh a
random dist sh seed = Array.unsafeCreateWithSize sh $ \n xPtr ->
   evalContT $ do
      nPtr <- Call.cint n
      distPtr <-
         Call.cint $
         case (getConst $ isComplexInFunctor xPtr, dist) of
            (_, UniformBox01) -> 1
            (_, UniformBoxPM1) -> 2
            (_, Normal) -> 3
            (True, UniformDisc) -> 4
            (True, UniformCircle) -> 5
            (False, UniformDisc) -> 2
            (False, UniformCircle) ->
               error
                  "Vector.random: UniformCircle not supported for real numbers"
      iseedPtr <- Call.allocaArray 4
      liftIO $ do
         pokeElemOff iseedPtr 0 $ fromIntegral ((seed `shiftR` 35) .&. 0xFFF)
         pokeElemOff iseedPtr 1 $ fromIntegral ((seed `shiftR` 23) .&. 0xFFF)
         pokeElemOff iseedPtr 2 $ fromIntegral ((seed `shiftR` 11) .&. 0xFFF)
         pokeElemOff iseedPtr 3 $ fromIntegral ((seed.&.0x7FF)*2+1)
         LapackGen.larnv distPtr iseedPtr nPtr xPtr

-- Tell whether the element type of a container is complex.
-- The argument is only used to determine the type, it is never evaluated.
isComplexInFunctor :: (Class.Floating a) => f a -> Const Bool a
isComplexInFunctor _ =
   Class.switchFloating
      (Const False) (Const False) (Const True) (Const True)