{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies, FlexibleInstances,
             GeneralizedNewtypeDeriving, FlexibleContexts, TypeSynonymInstances #-}

{-|
Functions to work with C-like arrays.
This is provided to have arrays which work with CBLAS and LAPACKE libraries.
-}
module Numeric.Jalla.CVector (
  -- * Classes
  -- ** Vectors
  GVector(..),
  CVector(..),
  -- ** Vector/vector operations
  VectorVector(..),
  -- ** Vector/scalar operations
  VectorScalar(..),
  -- ** Indexable
  module Numeric.Jalla.Indexable,

  -- * Construction, conversion, modification
  -- ** Monadic, efficient vector modification
  VMM,
  -- getVector,
  createVector,
  modifyVector,
  module Numeric.Jalla.IMM,

  -- ** Vector maps
  vectorAdd,
  vectorMap,
  vectorBinMap,

  -- * Conversion From And To Lists
  listVector,
  vectorList,

  -- * Getting Single Values
  vectorGetElem,

  -- * Inner product
  innerProduct,

  -- * IO Functions
  copyVector,

  -- * Unsafe Functions
  unsafeCopyVector,
  unsafeVectorAdd,
  unsafeVectorMap,
  unsafeVectorBinMap,

  -- * Re-exported
  CFloat,
  CDouble,
  Complex
) where

import Numeric.Jalla.Foreign.BLAS
import Numeric.Jalla.Foreign.BlasOps
import Numeric.Jalla.Foreign.LAPACKE
import Numeric.Jalla.Foreign.LapackeOps
import Numeric.Jalla.Internal
import Numeric.Jalla.IMM
import Numeric.Jalla.Indexable
import Numeric.Jalla.Types

import Foreign.C.Types
import Foreign.Marshal.Array
import Foreign hiding (unsafePerformIO)
import System.IO.Unsafe (unsafePerformIO)

import Data.Ix
import Data.Complex
import Control.Monad.State
import Data.Convertible


-- | Generic vectors: indexable containers with a known length.
class (Indexable (vec e) Index e, Field1 e) => GVector vec e where
  -- Disabled: a vector constructor. Not really useful, may be dropped in the future.
  -- vector :: Index -> vec e

  -- | Returns the length of a vector.
  vectorLength :: vec e -> Index

  -- Disabled operators:
  -- (-|) :: vec e -> vec e -> e
  -- (!) :: vec e -> Index -> e


-- | Vectors backed by a C array that can be handed to BLAS routines.
class (BlasOps e, GVector vec e, Show (vec e)) => CVector vec e where
  -- | Allocate a vector of a given length.
  vectorAlloc :: Index -> IO (vec e)
  -- | Operate on a vector with the given IO action.
  -- The action gets as parameter a pointer to the array.
  withCVector :: vec e -> (Ptr e -> IO a) -> IO a
  -- | Returns the increment per element for this vector (like the /inc/
  -- arguments for BLAS). For contiguous storage, this would simply be 1.
  inc :: vec e -> Index


infixl 7 ||*
infixl 6 ||+,||-

{-| Vector/vector operations. -}
class (CVector vec e) => VectorVector vec e where
  -- | Vector addition
  (||+) :: vec e -> vec e -> vec e
  v1 ||+ v2 = modifyVector v1 $ vectorAdd 1 v2
  -- | Vector subtraction
  (||-) :: vec e -> vec e -> vec e
  v1 ||- v2 = modifyVector v1 $ vectorAdd (-1) v2
  -- | Dot product
  (||*) :: vec e -> vec e -> e
  v1 ||* v2 = innerProduct v1 v2


-- | Inner product of two vectors of equal length, computed by the 'dot'
-- operation of 'BlasOps'. Calls 'error' if the lengths differ.
innerProduct :: (BlasOps e, CVector vec e) => vec e -> vec e -> e
innerProduct v1 v2
  | n == n2 =
      unsafePerformIO $
        withCVector v1 $ \p1 ->
          withCVector v2 $ \p2 ->
            dot n p1 (inc v1) p2 (inc v2)
  | otherwise = error "innerProduct: vectors must have same length."
  where
    n = vectorLength v1
    n2 = vectorLength v2


-- | Like 'innerProduct', but uses 'realdot' and converts the result to a
-- 'CDouble'. Calls 'error' if the lengths differ.
innerProductReal :: (BlasOpsReal e, CVector vec e) => vec e -> vec e -> CDouble
innerProductReal v1 v2
  | n == n2 =
      realToFrac $ unsafePerformIO $
        withCVector v1 $ \p1 ->
          withCVector v2 $ \p2 ->
            realdot n p1 (inc v1) p2 (inc v2)
  | otherwise = error "innerProduct: vectors must have same length."
  where
    n = vectorLength v1
    n2 = vectorLength v2


-- | Inner product for complex vectors via 'dotu_sub', which writes its result
-- through a pointer; the result cell is initialised to zero and read back.
-- Calls 'error' if the lengths differ.
innerProductC :: (RealFloat e, BlasOpsComplex e, CVector vec (Complex e)) => vec (Complex e) -> vec (Complex e) -> Complex e
innerProductC v1 v2
  | n == n2 =
      unsafePerformIO $
        withCVector v1 $ \p1 ->
          withCVector v2 $ \p2 ->
            with (0 :+ 0) $ \pret ->
              dotu_sub n p1 (inc v1) p2 (inc v2) pret >> peek pret
  | otherwise = error "innerProduct: vectors must have same length."
  where
    n = vectorLength v1
    n2 = vectorLength v2


infixl 7 |.*,|./
infixl 6 |.+,|.-

{-| Vector manipulations by a scalar.
Every operator applies the scalar as the /right/ operand to each element.
-}
class (CVector vec e) => VectorScalar vec e where
  -- | Multiply every element by the scalar.
  (|.*) :: vec e -> e -> vec e
  a |.* b = vectorMap (*b) a
  -- | Divide every element by the scalar.
  (|./) :: vec e -> e -> vec e
  a |./ b = vectorMap (/b) a
  -- | Add the scalar to every element.
  (|.+) :: vec e -> e -> vec e
  a |.+ b = vectorMap (+b) a
  -- | Subtract the scalar from every element.
  (|.-) :: vec e -> e -> vec e
  -- 'subtract b' is @\\x -> x - b@. The former @(-) b@ computed @b - x@,
  -- which had the operands the wrong way round compared to the operators above.
  a |.- b = vectorMap (subtract b) a


-- | Reads a single element. Calls 'error' if the index is out of bounds
-- (see 'unsafeGetElem').
vectorGetElem :: CVector vec e => vec e -> Index -> e
vectorGetElem v i = unsafePerformIO $ unsafeGetElem v i


-- | Converts a vector to the list of its elements, in index order.
vectorList :: GVector vec e => vec e -> [e]
vectorList v = map (v !) [0..n-1]
  where n = vectorLength v


-- | Builds a vector whose elements are those of the given list, in order.
listVector :: (CVector vec e) => [e] -> vec e
listVector es = createVector n $ setElems ies
  where
    n = length es
    ies = zip [0..n-1] es


{-| Maps a unary function over the elements of a vector and returns the resulting vector. -}
vectorMap :: (CVector vec1 e1, CVector vec2 e2) => (e1 -> e2) -> vec1 e1 -> vec2 e2
vectorMap f v1 = unsafePerformIO $ do
  -- The result vector is freshly allocated and only returned after it has
  -- been filled.
  v2 <- vectorAlloc n
  unsafeVectorMap f v1 v2
  return v2
  where n = vectorLength v1
{-# NOINLINE vectorMap #-}


{-| Maps a binary function to the elements of two vectors and returns the resulting vector. -}
vectorBinMap :: (CVector vec1 e1, CVector vec2 e2, CVector vec3 e3) =>
                (e1 -> e2 -> e3) -- ^ The function /f/ to map.
                -> vec1 e1 -- ^ The first input vector /v1/ for /f/.
                -> vec2 e2 -- ^ The second input vector /v2/ for /f/.
                -> vec3 e3 -- ^ The result vector. It will have length min(l1,l2), where l1,l2 are the lengths of /v1/ and /v2/.
vectorBinMap f v1 v2 = unsafePerformIO $ do
  v3 <- vectorAlloc n
  unsafeVectorBinMap f v1 v2 v3
  return v3
  where n = min (vectorLength v1) (vectorLength v2)


{-| Make a copy of the input vector.
Using the cblas_*copy functions.
-}
copyVector :: (BlasOps e, CVector vec e, CVector vec2 e) => vec e -> IO (vec2 e)
copyVector v =
  vectorAlloc n >>= \ret ->
    withCVector v $ \p ->
      withCVector ret $ \pret ->
        copy n p (inc v) pret (inc ret) >> return ret
  where n = vectorLength v


-------------------------------
-- Monadic vector manipulations
-------------------------------

-- | The underlying monad of 'VMM': 'IO' with the vector being modified as state.
type VMMMonad vec e a = StateT (vec e) IO a

-- | Monad for in-place vector modification, run by 'createVector' and
-- 'modifyVector'.
newtype VMM s vec e a = VMM { unVMM :: VMMMonad vec e a } deriving (Applicative, Functor, Monad)

-- | Runs a 'VMM' action with the given vector as its state.
runVMM :: CVector vec e => vec e -> VMM s vec e a -> IO a
runVMM v action = evalStateT action' v
  where action' = unVMM action


instance (BlasOps e, CVector vec e) => IMM (VMM s vec e) Index (vec e) e where
  -- create = createVector
  -- modify = modifyVector
  -- getO = getVector
  setElem = setElem'
  setElems = setElems'
  fill = fill'
  getElem = getElem'


-- | Allocates a vector of the given length, runs the action on it and returns
-- the vector. The action's own result is discarded.
createVector :: CVector vec e => Index -> VMM s vec e a -> vec e
createVector n action =
  unsafePerformIO $
    vectorAlloc n >>= \mv -> runVMM mv (action >> (VMM get))


-- | Returns the vector currently being modified.
getVector :: CVector vec e => VMM s vec e (vec e)
getVector = VMM get


-- | Runs the action on a copy of the given vector (made with 'copyVector')
-- and returns the modified copy.
modifyVector :: CVector vec e => vec e -> VMM s vec e a -> vec e
modifyVector v action =
  unsafePerformIO $
    copyVector v >>= \nv -> runVMM nv (action >> (VMM get))


{-| Adds alpha * v to the current vector. -}
vectorAdd :: CVector vec e => e -> vec e -> VMM s vec e ()
vectorAdd alpha x = VMM $ (get >>= \v -> liftIO $ unsafeVectorAdd alpha x v)


{-| Sets one element of the current vector.
The index is range-checked by 'unsafeSetElem', which calls 'error' when it is out of bounds.
-}
setElem' :: CVector vec e => Index -> e -> VMM s vec e ()
setElem' i e = VMM $ get >>= \v -> liftIO (unsafeSetElem v i e)


-- | Sets several elements of the current vector, in list order; each index is
-- range-checked by 'unsafeSetElem'.
setElems' :: CVector vec e => [(Index,e)] -> VMM s vec e ()
setElems' ies = VMM $ (get >>= \v -> liftIO $ mapM_ (\(i,e) -> unsafeSetElem v i e) ies)


{-| Reads one element of the current vector.
The element is returned directly (not wrapped in a Maybe); 'unsafeGetElem' calls 'error' when the index is out of bounds.
-}
getElem' :: CVector vec e => Index -> VMM s vec e e
getElem' i = VMM $ get >>= \v -> liftIO (unsafeGetElem v i)


-- | Sets every element of the current vector to the given value.
fill' :: CVector vec e => e -> VMM s vec e ()
fill' e = VMM $ get >>= \v -> liftIO (unsafeFillVector v e)


---------------------------------------------------------------------------------------
-- Unsafe functions.
---------------------------------------------------------------------------------------

-- | Maps a function over the first vector, writing into the memory of the
-- second one. Processes as many elements as the shorter vector has.
unsafeVectorMap :: (CVector vec1 e1, CVector vec2 e2) =>
                   (e1 -> e2)
                   -> vec1 e1
                   -> vec2 e2
                   -> IO ()
unsafeVectorMap f v1 v2 =
  withCVector v1 $ \v1p ->
    withCVector v2 $ \v2p ->
      unsafePtrMapInc i1 i2 f v1p v2p n
  where
    i1 = inc v1
    i2 = inc v2
    n = min (vectorLength v1) (vectorLength v2)
{-# INLINABLE unsafeVectorMap #-}


{-| Copies from one vector to the other, in-place and therefore unsafely.
Uses the BLAS 'copy' function. /min (vectorLength src) (vectorLength dest)/
elements are copied from the first to the second vector. -}
unsafeCopyVector :: (CVector vec e, CVector vec2 e) =>
                    vec e -- ^ The source vector.
                    -> vec2 e -- ^ The destination vector.
                    -> IO ()
unsafeCopyVector src dest =
  withCVector src $ \srcp ->
    withCVector dest $ \destp ->
      copy n srcp (inc src) destp (inc dest)
  where
    n = min (vectorLength src) (vectorLength dest)


-- | Maps a binary function over two vectors, writing into the memory of the
-- third one. Processes as many elements as the shortest of the three has.
unsafeVectorBinMap :: (CVector vec1 e1, CVector vec2 e2, CVector vec3 e3) =>
                      (e1 -> e2 -> e3)
                      -> vec1 e1
                      -> vec2 e2
                      -> vec3 e3
                      -> IO ()
unsafeVectorBinMap f v1 v2 v3 =
  withCVector v1 $ \v1p ->
    withCVector v2 $ \v2p ->
      withCVector v3 $ \v3p ->
        unsafe2PtrMapInc i1 i2 i3 f v1p v2p v3p n
  where
    i1 = inc v1
    i2 = inc v2
    i3 = inc v3
    n = minimum [(vectorLength v1), (vectorLength v2), vectorLength v3]


{-| Computes v2 <- alpha * v1 + v2. The result is stored in the memory of v2,
therefore this is unsafe and low level, only for internal use.
Calls 'error' if the two vectors differ in length. -}
unsafeVectorAdd :: (BlasOps e, CVector vec e) =>
                   e -- ^ alpha
                   -> vec e -- ^ Vector 1
                   -> vec e -- ^ Vector 2, will be changed in place!
                   -> IO ()
unsafeVectorAdd alpha v1 v2
  | n == n2 =
      withCVector v1 $ \p1 ->
        withCVector v2 $ \p2 ->
          axpy n alpha p1 (inc v1) p2 (inc v2)
  | otherwise =
      error $ "unsafeVectorAdd: Vector lengths must match, when adding " ++ show v1 ++ "\nand\n" ++ show v2
  where
    n = vectorLength v1
    n2 = vectorLength v2
{-# NOINLINE unsafeVectorAdd #-}


-- | Writes one element in place. The element at index @i@ lives at element
-- offset @i * inc v@ from the start of the array. Calls 'error' if the index
-- is out of bounds.
unsafeSetElem :: (BlasOps e, CVector vec e) => vec e -> Index -> e -> IO ()
unsafeSetElem v i e
  | i >= 0 && i < vectorLength v =
      withCVector v $ \p -> pokeElemOff p (i * inc v) e
  | otherwise = error "unsafeSetElem: out of bounds."


-- | Reads one element. The element at index @i@ lives at element offset
-- @i * inc v@ from the start of the array. Calls 'error' if the index is out
-- of bounds.
unsafeGetElem :: (BlasOps e, CVector vec e) => vec e -> Index -> IO e
unsafeGetElem v i
  | i >= 0 && i < vectorLength v =
      -- 'peekElemOff' scales the offset by the element size itself, so there
      -- is no need to read element 0 first just to learn its size.
      withCVector v $ \p -> peekElemOff p (i * inc v)
  | otherwise = error "unsafeGetElem: out of bounds."


-- | Overwrites every element of the vector, in place, with the given value.
unsafeFillVector :: (BlasOps e, CVector vec e) => vec e -> e -> IO ()
unsafeFillVector v e = withCVector v $ \p -> unsafePtrMap1Inc i (const e) p n
  where
    -- Using /f/ instead of the unsafePtrMap1Inc should be /slightly/ more efficient,
    -- but it is a good idea to have such dirty functions in a central place.
    -- f _ _ 0 = return ()
    -- f i p n = poke p e >> f i (advancePtr p i) (n - 1) where {p' = advancePtr p i; n' = n - 1 }
    i = inc v
    n = vectorLength v