module Numeric.Jalla.CVector
(
GVector(..),
CVector(..),
VectorVector(..),
VectorScalar(..),
module Numeric.Jalla.Indexable,
VMM,
createVector,
modifyVector,
module Numeric.Jalla.IMM,
vectorAdd,
vectorMap,
vectorBinMap,
listVector,
vectorList,
vectorGetElem,
innerProduct,
copyVector,
unsafeCopyVector,
unsafeVectorAdd,
unsafeVectorMap,
unsafeVectorBinMap,
CFloat,
CDouble,
Complex
) where
import Numeric.Jalla.Foreign.BLAS
import Numeric.Jalla.Foreign.BlasOps
import Numeric.Jalla.Foreign.LAPACKE
import Numeric.Jalla.Foreign.LapackeOps
import Numeric.Jalla.Internal
import Numeric.Jalla.IMM
import Numeric.Jalla.Indexable
import Numeric.Jalla.Types
import Foreign.C.Types
import Foreign.Marshal.Array
import Foreign hiding (unsafePerformIO)
import System.IO.Unsafe (unsafePerformIO)
import Data.Ix
import Data.Complex
import Control.Monad.State
import Data.Convertible
-- | Generic vector interface: a type that is 'Indexable' by 'Index', whose
-- elements satisfy 'Field1', and that can report its length.
class (Indexable (vec e) Index e, Field1 e) => GVector vec e where
  -- | Number of elements in the vector.  The element accessors in this
  -- module treat indices @0 <= i < vectorLength v@ as valid.
  vectorLength :: vec e -> Index
-- | Vectors whose elements can be reached through a raw 'Ptr', so that they
-- can be handed to the BLAS wrappers used in this module.
class (BlasOps e, GVector vec e, Show (vec e)) => CVector vec e where
  -- | Allocate a vector with the given number of elements.
  -- NOTE(review): whether the elements are initialised is not visible here.
  vectorAlloc :: Index -> IO (vec e)
  -- | Run an 'IO' action with a pointer to the vector's element storage.
  withCVector :: vec e -> (Ptr e -> IO a) -> IO a
  -- | Increment of the vector.  It is passed next to the pointer in every
  -- BLAS call of this module and multiplies element offsets in the element
  -- accessors.
  inc :: vec e -> Index
infixl 7 ||*
infixl 6 ||+, ||-

-- | Operations between two vectors of the same type.  Every method has a
-- default implementation built on 'modifyVector', 'vectorAdd' and
-- 'innerProduct'.
class (CVector vec e) => VectorVector vec e where
  -- | Vector sum: a copy of the left operand to which the right operand,
  -- scaled by @1@, is added.
  (||+) :: vec e -> vec e -> vec e
  v1 ||+ v2 = modifyVector v1 $ vectorAdd 1 v2

  -- | Vector difference: a copy of the left operand to which the right
  -- operand, scaled by @-1@, is added.  (The scale factor must be negative;
  -- with a factor of @1@ this operator would be identical to '||+'.)
  (||-) :: vec e -> vec e -> vec e
  v1 ||- v2 = modifyVector v1 $ vectorAdd (-1) v2

  -- | Inner product of the two vectors; see 'innerProduct'.
  (||*) :: vec e -> vec e -> e
  v1 ||* v2 = innerProduct v1 v2
-- | Inner product of two vectors, delegated to the 'dot' routine of
-- 'BlasOps'.  Both vectors must have the same length; otherwise 'error' is
-- called.
innerProduct :: (BlasOps e, CVector vec e) => vec e -> vec e -> e
innerProduct v1 v2
  | len1 == len2 =
      unsafePerformIO $
        withCVector v1 $ \ptr1 ->
          withCVector v2 $ \ptr2 ->
            dot len1 ptr1 (inc v1) ptr2 (inc v2)
  | otherwise = error "innerProduct: vectors must have same length."
  where
    len1 = vectorLength v1
    len2 = vectorLength v2
-- | Inner product for real element types, computed with 'realdot' and
-- converted to 'CDouble' via 'realToFrac'.  Both vectors must have the same
-- length; otherwise 'error' is called.
innerProductReal :: (BlasOpsReal e, CVector vec e) => vec e -> vec e -> CDouble
innerProductReal v1 v2
  | len1 == len2 =
      realToFrac . unsafePerformIO $
        withCVector v1 $ \ptr1 ->
          withCVector v2 $ \ptr2 ->
            realdot len1 ptr1 (inc v1) ptr2 (inc v2)
  | otherwise = error "innerProduct: vectors must have same length."
  where
    len1 = vectorLength v1
    len2 = vectorLength v2
-- | Inner product for complex element types.  The result is written by
-- 'dotu_sub' into a temporary cell (initialised to @0 :+ 0@) and read back
-- afterwards.  Both vectors must have the same length; otherwise 'error' is
-- called.
innerProductC ::
  (RealFloat e, BlasOpsComplex e, CVector vec (Complex e)) =>
  vec (Complex e) ->
  vec (Complex e) ->
  Complex e
innerProductC v1 v2
  | len1 == len2 =
      unsafePerformIO $
        withCVector v1 $ \ptr1 ->
          withCVector v2 $ \ptr2 ->
            with (0 :+ 0) $ \resultPtr -> do
              dotu_sub len1 ptr1 (inc v1) ptr2 (inc v2) resultPtr
              peek resultPtr
  | otherwise = error "innerProduct: vectors must have same length."
  where
    len1 = vectorLength v1
    len2 = vectorLength v2
infixl 7 |.*, |./
infixl 6 |.+, |.-

-- | Element-wise operations between a vector (left operand) and a scalar
-- (right operand).  Every default implementation maps over the vector with
-- 'vectorMap' and therefore produces a new vector.
class (CVector vec e) => VectorScalar vec e where
  -- | Multiply every element by the scalar.
  (|.*) :: vec e -> e -> vec e
  a |.* b = vectorMap (* b) a

  -- | Divide every element by the scalar.
  (|./) :: vec e -> e -> vec e
  a |./ b = vectorMap (/ b) a

  -- | Add the scalar to every element.
  (|.+) :: vec e -> e -> vec e
  a |.+ b = vectorMap (+ b) a

  -- | Subtract the scalar from every element.  'subtract' is used because
  -- @(- b)@ is negation in Haskell, not a section.
  (|.-) :: vec e -> e -> vec e
  a |.- b = vectorMap (subtract b) a
-- | Pure element lookup: runs 'unsafeGetElem' through 'unsafePerformIO'.
-- An index outside the vector makes 'unsafeGetElem' call 'error'.
vectorGetElem :: CVector vec e => vec e -> Index -> e
vectorGetElem v = unsafePerformIO . unsafeGetElem v
-- | All elements of the vector as a list, in index order starting at 0.
-- A vector of length 0 gives the empty list, because the index range
-- @[0 .. -1]@ is empty.
vectorList :: GVector vec e => vec e -> [e]
vectorList v = map (v !) [0 .. n - 1]
  where
    n = vectorLength v
-- | Build a vector from a list.  The vector has as many elements as the
-- list, and element @i@ of the list is stored at index @i@.
listVector :: (CVector vec e) => [e] -> vec e
listVector es = createVector n $ setElems ies
  where
    n = length es
    -- Pair each element with its 0-based position; 'zip' stops at the end of
    -- 'es', so the indices are exactly 0 .. n - 1.
    ies = zip [0 ..] es
-- | Apply a function to every element, producing a freshly allocated vector
-- of the same length as the input.  The element and vector types of the
-- result may differ from those of the input.
vectorMap :: (CVector vec1 e1, CVector vec2 e2) => (e1 -> e2) -> vec1 e1 -> vec2 e2
vectorMap f v1 = unsafePerformIO $ do
  result <- vectorAlloc (vectorLength v1)
  unsafeVectorMap f v1 result
  return result
-- | Combine two vectors element by element with a binary function.  The
-- freshly allocated result is as long as the shorter of the two inputs.
vectorBinMap ::
  (CVector vec1 e1, CVector vec2 e2, CVector vec3 e3) =>
  (e1 -> e2 -> e3) ->
  vec1 e1 ->
  vec2 e2 ->
  vec3 e3
vectorBinMap f v1 v2 = unsafePerformIO $ do
  result <- vectorAlloc shortest
  unsafeVectorBinMap f v1 v2 result
  return result
  where
    shortest = vectorLength v1 `min` vectorLength v2
-- | Allocate a new vector of the same length as the argument and copy all
-- elements into it with the 'copy' routine of 'BlasOps'.  The result may be
-- of a different vector type than the source.
copyVector :: (BlasOps e, CVector vec e, CVector vec2 e) => vec e -> IO (vec2 e)
copyVector v = do
  target <- vectorAlloc len
  withCVector v $ \srcPtr ->
    withCVector target $ \dstPtr -> do
      copy len srcPtr (inc v) dstPtr (inc target)
      return target
  where
    len = vectorLength v
-- | The monad stack underneath 'VMM': 'IO' with the vector being worked on
-- carried as state.
type VMMMonad vec e a = StateT (vec e) IO a

-- | Vector modification monad.  A thin wrapper around 'VMMMonad'; the type
-- parameter @s@ does not occur on the right-hand side.
newtype VMM s vec e a = VMM
  { unVMM :: VMMMonad vec e a
  }
  deriving (Functor, Applicative, Monad)
-- | Run a 'VMM' action with the given vector as initial state and return
-- the action's result; the final state is discarded.
runVMM :: CVector vec e => vec e -> VMM s vec e a -> IO a
runVMM v action = evalStateT (unVMM action) v
-- | 'VMM' supports the 'IMM' element interface; each method forwards to the
-- primed helper of the same name defined in this module.
instance (BlasOps e, CVector vec e) => IMM (VMM s vec e) Index (vec e) e where
  getElem = getElem'
  setElem = setElem'
  setElems = setElems'
  fill = fill'
-- | Allocate a vector with the given number of elements, run the 'VMM'
-- action on it and return the resulting vector as a pure value.
createVector :: CVector vec e => Index -> VMM s vec e a -> vec e
createVector n action = unsafePerformIO $ do
  fresh <- vectorAlloc n
  -- The action's own result is dropped; the vector held in the state is
  -- what gets returned.
  runVMM fresh $ action >> VMM get
-- | Return the vector currently held in the 'VMM' state.
getVector :: CVector vec e => VMM s vec e (vec e)
getVector = VMM get
-- | Run a 'VMM' action on a copy of the given vector and return the modified
-- copy as a pure value.  The argument itself is only read (by 'copyVector').
modifyVector :: CVector vec e => vec e -> VMM s vec e a -> vec e
modifyVector v action = unsafePerformIO $
  -- The action's own result is dropped; the vector held in the state is what
  -- gets returned.
  copyVector v >>= \nv -> runVMM nv (action >> VMM get)
-- | 'VMM' action that adds @alpha@ times @x@ to the vector held in the
-- state, via 'unsafeVectorAdd'.
vectorAdd :: CVector vec e => e -> vec e -> VMM s vec e ()
vectorAdd alpha x = VMM $ do
  target <- get
  liftIO (unsafeVectorAdd alpha x target)
-- | 'VMM' action that writes one element of the vector held in the state.
-- An out-of-range index makes 'unsafeSetElem' call 'error'.
setElem' :: CVector vec e => Index -> e -> VMM s vec e ()
setElem' i e = VMM $ do
  target <- get
  liftIO $ do
    unsafeSetElem target i e
    return ()
-- | 'VMM' action that writes every @(index, value)@ pair of the list, in
-- list order, into the vector held in the state.
setElems' :: CVector vec e => [(Index,e)] -> VMM s vec e ()
setElems' ies = VMM $ do
  target <- get
  let write (idx, x) = unsafeSetElem target idx x
  liftIO (mapM_ write ies)
-- | 'VMM' action that reads one element of the vector held in the state.
-- An out-of-range index makes 'unsafeGetElem' call 'error'.
getElem' :: CVector vec e => Index -> VMM s vec e e
getElem' i = VMM $ do
  current <- get
  liftIO (unsafeGetElem current i)
-- | 'VMM' action that sets every element of the vector held in the state to
-- the given value, via 'unsafeFillVector'.
fill' :: CVector vec e => e -> VMM s vec e ()
fill' e = VMM $ do
  target <- get
  liftIO (unsafeFillVector target e)
-- | Map a function over the first vector, writing the results into the
-- second one in place.  Only as many elements as the shorter of the two
-- vectors holds are processed; each vector's own increment is forwarded to
-- 'unsafePtrMapInc'.
unsafeVectorMap :: (CVector vec1 e1, CVector vec2 e2) => (e1 -> e2) -> vec1 e1 -> vec2 e2 -> IO ()
unsafeVectorMap f v1 v2 =
  withCVector v1 $ \srcPtr ->
    withCVector v2 $ \dstPtr ->
      unsafePtrMapInc (inc v1) (inc v2) f srcPtr dstPtr count
  where
    count = vectorLength v1 `min` vectorLength v2
-- | Copy elements from the source vector into an existing destination
-- vector with the 'copy' routine.  Only as many elements as the shorter of
-- the two vectors holds are copied.
unsafeCopyVector ::
  (CVector vec e, CVector vec2 e) =>
  vec e ->
  vec2 e ->
  IO ()
unsafeCopyVector src dest =
  withCVector src $ \srcPtr ->
    withCVector dest $ \destPtr ->
      copy count srcPtr (inc src) destPtr (inc dest)
  where
    count = vectorLength src `min` vectorLength dest
-- | Combine the first two vectors element by element with a binary function
-- and write the results into the third vector in place.  The number of
-- elements processed is the smallest of the three vector lengths.
unsafeVectorBinMap ::
  (CVector vec1 e1, CVector vec2 e2, CVector vec3 e3) =>
  (e1 -> e2 -> e3) ->
  vec1 e1 ->
  vec2 e2 ->
  vec3 e3 ->
  IO ()
unsafeVectorBinMap f v1 v2 v3 =
  withCVector v1 $ \ptrA ->
    withCVector v2 $ \ptrB ->
      withCVector v3 $ \ptrOut ->
        unsafe2PtrMapInc (inc v1) (inc v2) (inc v3) f ptrA ptrB ptrOut count
  where
    count = min (vectorLength v1) (min (vectorLength v2) (vectorLength v3))
-- | In-place scaled addition through the 'axpy' routine: the first vector
-- and the scale factor are the inputs, the second vector is the one passed
-- in the updated position (presumably the BLAS update @y := alpha*x + y@).
-- Both vectors must have the same length; otherwise 'error' is called with
-- a message that shows both vectors.
unsafeVectorAdd ::
  (BlasOps e, CVector vec e) =>
  e ->
  vec e ->
  vec e ->
  IO ()
unsafeVectorAdd alpha v1 v2
  | len1 == len2 =
      withCVector v1 $ \xPtr ->
        withCVector v2 $ \yPtr ->
          axpy len1 alpha xPtr (inc v1) yPtr (inc v2)
  | otherwise =
      error $
        "unsafeVectorAdd: Vector lengths must match, when adding "
          ++ show v1
          ++ "\nand\n"
          ++ show v2
  where
    len1 = vectorLength v1
    len2 = vectorLength v2
-- | Write one element of the vector in place.  The index must satisfy
-- @0 <= i < vectorLength v@; otherwise 'error' is called.
unsafeSetElem :: (BlasOps e, CVector vec e) => vec e -> Index -> e -> IO ()
unsafeSetElem v i e
  | inBounds = withCVector v $ \base -> poke (base `plusPtr` byteOffset) e
  | otherwise = error "unsafeSetElem: out of bounds."
  where
    inBounds = 0 <= i && i < vectorLength v
    -- Offset in bytes: index times element size times the vector increment.
    byteOffset = i * sizeOf e * inc v
-- | Read one element of the vector.  The index must satisfy
-- @0 <= i < vectorLength v@; otherwise 'error' is called.
unsafeGetElem :: (BlasOps e, CVector vec e) => vec e -> Index -> IO e
unsafeGetElem v i
  | i >= 0 && i < vectorLength v =
      -- 'peekElemOff' scales the element offset by the element size itself,
      -- so no dummy read of element 0 is needed just to obtain a value for
      -- 'sizeOf'.
      withCVector v $ \p -> peekElemOff p (i * inc v)
  | otherwise = error "unsafeGetElem: out of bounds."
-- | Overwrite every element of the vector with the given value, in place.
-- Implemented by mapping a constant function over the vector's storage with
-- 'unsafePtrMap1Inc', using the vector's own increment and length.
unsafeFillVector :: (BlasOps e, CVector vec e) => vec e -> e -> IO ()
unsafeFillVector v e =
  withCVector v $ \base ->
    unsafePtrMap1Inc (inc v) (const e) base (vectorLength v)