-- |
-- Module    : Statistics.Matrix.Mutable
-- Copyright : (c) 2014 Bryan O'Sullivan
-- License   : BSD3
--
-- Basic mutable matrix operations.

module Statistics.Matrix.Mutable
    (
      MMatrix(..)
    , MVector
    , replicate
    , thaw
    , bounds
    , unsafeNew
    , unsafeFreeze
    , unsafeRead
    , unsafeWrite
    , unsafeModify
    , immutably
    , unsafeBounds
    ) where

import Control.Applicative ((<$>))
import Control.DeepSeq (NFData(..))
import Control.Monad.ST (ST)
import Statistics.Matrix.Types (Matrix(..), MMatrix(..), MVector)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
import Prelude hiding (replicate)

-- | Create a mutable matrix with the given number of rows and columns
-- in which every element holds the same value. The third field of the
-- resulting 'MMatrix' is set to 0.
replicate :: Int     -- ^ Number of rows
          -> Int     -- ^ Number of columns
          -> Double  -- ^ Value stored in every element
          -> ST s (MMatrix s)
replicate rows cols x = do
    cells <- M.replicate (rows * cols) x
    return (MMatrix rows cols 0 cells)

-- | Turn an immutable matrix into a mutable one. The element vector is
-- passed through 'U.thaw'; the remaining fields are carried over
-- unchanged.
thaw :: Matrix -> ST s (MMatrix s)
thaw (Matrix rows cols ex vec) = do
    cells <- U.thaw vec
    return (MMatrix rows cols ex cells)

-- | Turn a mutable matrix into an immutable one. The element vector is
-- passed through 'U.unsafeFreeze'; the remaining fields are carried
-- over unchanged.
unsafeFreeze :: MMatrix s -> ST s Matrix
unsafeFreeze (MMatrix rows cols ex cells) = do
    vec <- U.unsafeFreeze cells
    return (Matrix rows cols ex vec)

-- | Allocate a new matrix. Its contents are not initialised, hence
-- unsafe. Calls 'error' if either dimension is negative.
unsafeNew :: Int  -- ^ Number of rows
          -> Int  -- ^ Number of columns
          -> ST s (MMatrix s)
unsafeNew rows cols
    | rows < 0  = error "Statistics.Matrix.Mutable.unsafeNew: negative number of rows"
    | cols < 0  = error "Statistics.Matrix.Mutable.unsafeNew: negative number of columns"
    | otherwise = MMatrix rows cols 0 <$> M.new (rows * cols)

-- | Read the element at the given row and column. Neither index is
-- checked against the matrix dimensions.
unsafeRead :: MMatrix s -> Int -> Int -> ST s Double
unsafeRead mat row col = unsafeBounds mat row col readAt
  where
    readAt cells ix = M.unsafeRead cells ix
{-# INLINE unsafeRead #-}

-- | Overwrite the element at the given row and column. Neither index
-- is checked against the matrix dimensions.
unsafeWrite :: MMatrix s -> Int -> Int -> Double -> ST s ()
unsafeWrite mat row col k = unsafeBounds mat row col writeAt
  where
    writeAt cells ix = M.unsafeWrite cells ix k
{-# INLINE unsafeWrite #-}

-- | Replace the element at the given row and column with the result of
-- applying a function to it. Neither index is checked against the
-- matrix dimensions.
unsafeModify :: MMatrix s -> Int -> Int -> (Double -> Double) -> ST s ()
unsafeModify mat row col f = unsafeBounds mat row col updateAt
  where
    updateAt cells ix = M.unsafeRead cells ix >>= \old ->
        M.unsafeWrite cells ix (f old)
{-# INLINE unsafeModify #-}

-- | Given row and column numbers, calculate the offset into the flat
-- row-major vector.
bounds :: MMatrix s                 -- ^ Matrix being indexed
       -> Int                       -- ^ Row number
       -> Int                       -- ^ Column number
       -> (MVector s -> Int -> r)   -- ^ Continuation receiving the element
                                    --   vector and the computed offset
       -> r
bounds (MMatrix nrows ncols _ cells) row col cont
    | outside row nrows = error "row out of bounds"
    | outside col ncols = error "column out of bounds"
    | otherwise         = offset `seq` cont cells offset
  where
    -- True when index i does not lie in the half-open range [0, n).
    outside i n = i < 0 || i >= n
    -- Row-major position of (row, col) in the flat vector.
    offset = row * ncols + col
{-# INLINE bounds #-}

-- | Unchecked variant of 'bounds': compute the offset of the given row
-- and column in the flat row-major vector and hand it, together with
-- the vector, to the continuation. The offset is forced before the
-- continuation is called; the indices are not validated.
unsafeBounds :: MMatrix s -> Int -> Int -> (MVector s -> Int -> r) -> r
unsafeBounds (MMatrix _ ncols _ cells) row col cont =
    let offset = row * ncols + col
    in  offset `seq` cont cells offset
{-# INLINE unsafeBounds #-}

-- | Run a pure function over an immutable view of a mutable matrix.
-- The view is obtained with 'unsafeFreeze', and the function's result
-- is evaluated with 'rnf' before it is returned -- presumably so that
-- no unevaluated reference to the frozen view outlives this call.
immutably :: NFData a => MMatrix s -> (Matrix -> a) -> ST s a
immutably mmat f = do
    frozen <- unsafeFreeze mmat
    let result = f frozen
    rnf result `seq` return result
{-# INLINE immutably #-}