{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses #-} ----------------------------------------------------------------------------- -- | -- Module : Data.Matrix.Dense.Internal -- Copyright : Copyright (c) , Patrick Perry -- License : BSD3 -- Maintainer : Patrick Perry -- Stability : experimental -- module Data.Matrix.Dense.Internal ( -- * Dense matrix data types DMatrix(..), IOMatrix, Matrix, module BLAS.Matrix.Base, module BLAS.Tensor, -- * Converting to and from foreign pointers toForeignPtr, fromForeignPtr, ldaOf, isHerm, -- * Creating new matrices newMatrix, newMatrix_, newListMatrix, newColsMatrix, newRowsMatrix, listMatrix, -- * Special matrices newIdentity, setIdentity, -- * Row and column views row, col, rows, cols, -- * Diagonal views diag, -- * Matrix views submatrix, -- * Converting to/from vectors maybeFromRow, maybeFromCol, maybeToVector, -- * Lifting scalar and vector operations liftV, liftV2, -- * Casting matrices coerceMatrix, -- * Unsafe operations unsafeThaw, unsafeFreeze, unsafeNewMatrix, unsafeWithElemPtr, unsafeRow, unsafeCol, unsafeDiag, unsafeSubmatrix, ) where import Control.Monad ( forM_, zipWithM_ ) import Data.Ix ( inRange, range ) import Foreign import System.IO.Unsafe import Unsafe.Coerce import Data.AEq import Data.Vector.Dense.Internal hiding ( toForeignPtr, fromForeignPtr, unsafeFreeze, unsafeThaw, fptr, offset, unsafeWithElemPtr ) import qualified Data.Vector.Dense.Internal as V import qualified Data.Vector.Dense.Operations as V import BLAS.Access import BLAS.Internal ( inlinePerformIO, checkedRow, checkedCol, checkedDiag, checkedSubmatrix, diagStart, diagLen ) import BLAS.Elem ( Elem, BLAS1 ) import qualified BLAS.Elem as E import BLAS.Matrix.Base hiding ( Matrix ) import qualified BLAS.Matrix.Base as C import BLAS.Tensor import BLAS.Types -- | The mutable dense matrix data type. It can either store elements in -- column-major order, or provide a view into another matrix. The view -- transposes and conjugates the underlying matrix. data DMatrix t mn e = DM { fptr :: !(ForeignPtr e) -- ^ a pointer to the storage region , offset :: !Int -- ^ an offset (in elements, not bytes) to the first element in the matrix. , size1 :: !Int -- ^ the number of rows in the matrix , size2 :: !Int -- ^ the number of columns in the matrix , lda :: !Int -- ^ the leading dimension size of the matrix } | H !(DMatrix t mn e) -- ^ a transposed and conjugated matrix type Matrix = DMatrix Imm type IOMatrix = DMatrix Mut unsafeFreeze :: DMatrix t mn e -> Matrix mn e unsafeFreeze = unsafeCoerce unsafeThaw :: DMatrix t mn e -> IOMatrix mn e unsafeThaw = unsafeCoerce -- | Coerce the phantom shape type from one type to another. coerceMatrix :: DMatrix t mn e -> DMatrix t kl e coerceMatrix = unsafeCoerce -- | @fromForeignPtr f o mn l@ creates a matrix view of the data pointed to -- by @f@ starting at offset @o@ and having shape @mn@ and lda @l@. fromForeignPtr :: ForeignPtr e -> Int -> (Int,Int) -> Int -> DMatrix t (m,n) e fromForeignPtr f o (m,n) l = DM f o m n l -- | Convert a dense matrix to a pointer, offset, size, and lda. Note that this -- does not give the conjugacy/transpose information. For that, use 'isHerm'. toForeignPtr :: DMatrix t (m,n) e -> (ForeignPtr e, Int, (Int,Int), Int) toForeignPtr (H a) = toForeignPtr a toForeignPtr a@(DM _ _ _ _ _) = (fptr a, offset a, (size1 a, size2 a), lda a) -- | Get the lda of a matrix, defined as the number of elements in the underlying -- array that separate two consecutive elements in the same row of the matrix. ldaOf :: DMatrix t (m,n) e -> Int ldaOf (H a) = ldaOf a ldaOf a@(DM _ _ _ _ _) = lda a indexOf :: DMatrix t (m,n) e -> (Int,Int) -> Int indexOf (H a) (i,j) = indexOf a (j,i) indexOf a@(DM _ _ _ _ _) (i,j) = let o = offset a l = lda a in o + i + j*l -- | Get the storage order of the matrix. If 'isTrans' is true, this -- will be 'RowMajor'. Otherwise, it will be 'ColMajor'. orderOf :: DMatrix t (m,n) e -> Order orderOf (H a) = flipOrder (orderOf a) orderOf (DM _ _ _ _ _) = ColMajor -- | Get whether or not the matrix is transposed and conjugated. isHerm :: DMatrix t (m,n) e -> Bool isHerm (H a) = not (isHerm a) isHerm (DM _ _ _ _ _) = False -- | Create a new matrix of the given size and initialize the given elements to -- the given values. All other elements get initialized to zero. newMatrix :: (BLAS1 e) => (Int,Int) -> [((Int,Int), e)] -> IO (DMatrix t (m,n) e) newMatrix = newMatrixHelp writeElem -- | Same as 'newMatrix' but do not do any bounds-checking. unsafeNewMatrix :: (BLAS1 e) => (Int,Int) -> [((Int,Int), e)] -> IO (DMatrix t (m,n) e) unsafeNewMatrix = newMatrixHelp unsafeWriteElem newMatrixHelp :: (BLAS1 e) => (IOMatrix (m,n) e -> (Int,Int) -> e -> IO ()) -> (Int,Int) -> [((Int,Int),e)] -> IO (DMatrix t (m,n) e) newMatrixHelp set mn ijes = do x <- newZero mn io <- unsafeInterleaveIO $ mapM_ (uncurry $ set $ unsafeThaw x) ijes return $ io `seq` x -- | Create a new matrix of given shape, but do not initialize the elements. newMatrix_ :: (Elem e) => (Int,Int) -> IO (DMatrix t (m,n) e) newMatrix_ (m,n) | m < 0 || n < 0 = ioError $ userError $ "Tried to create a matrix with shape `" ++ show (m,n) ++ "'" | otherwise = do f <- mallocForeignPtrArray (m*n) return $ fromForeignPtr f 0 (m,n) (max 1 m) -- | Create a new matrix with the given elements in column-major order. newListMatrix :: (Elem e) => (Int,Int) -> [e] -> IO (DMatrix t (m,n) e) newListMatrix (m,n) es = do a <- newMatrix_ (m,n) unsafeWithElemPtr a (0,0) $ flip pokeArray (take (m*n) es) return a -- | Create a new matrix with the given elements in row-major order. listMatrix :: (Elem e) => (Int,Int) -> [e] -> Matrix (m,n) e listMatrix mn es = unsafePerformIO $ newListMatrix mn es {-# NOINLINE listMatrix #-} -- | Create a new matrix of the given shape with ones along the diagonal, -- and zeros everywhere else. newIdentity :: (BLAS1 e) => (Int,Int) -> IO (DMatrix t (m,n) e) newIdentity mn = do a <- newMatrix_ mn setIdentity (unsafeThaw a) return a -- | Set the diagonal to ones, and set everywhere else to zero. setIdentity :: (BLAS1 e) => IOMatrix (m,n) e -> IO () setIdentity a = do s <- getSize a case s of 0 -> return () _ -> setZero a >> setConstant 1 (diag a 0) -- | Form a matrix from a list of column vectors. newColsMatrix :: (BLAS1 e) => (Int,Int) -> [DVector t m e] -> IO (DMatrix r (m,n) e) newColsMatrix (m,n) cs = do a <- newZero (m,n) forM_ (zip [0..(n-1)] cs) $ \(j,c) -> V.copyVector (unsafeCol (unsafeThaw a) j) c return a -- | Form a matrix from a list of row vectors. newRowsMatrix :: (BLAS1 e) => (Int,Int) -> [DVector t n e] -> IO (DMatrix r (m,n) e) newRowsMatrix (m,n) rs = do a <- newZero (m,n) forM_ (zip [0..(m-1)] rs) $ \(i,r) -> V.copyVector (unsafeRow (unsafeThaw a) i) r return a -- | Evaluate a function with a pointer to the raw storage for the element -- at the given index. It may be necessary to conjugate or scale values before -- reading or writing to or from the location. unsafeWithElemPtr :: (Elem e) => DMatrix t (m,n) e -> (Int,Int) -> (Ptr e -> IO a) -> IO a unsafeWithElemPtr (H a) (i,j) f = unsafeWithElemPtr a (j,i) f unsafeWithElemPtr a@(DM _ _ _ _ _) ij f = withForeignPtr (fptr a) $ \ptr -> let ptr' = ptr `advancePtr` (indexOf a ij) in f ptr' -- | Get a vector view of the given row in a matrix. row :: (Elem e) => DMatrix t (m,n) e -> Int -> DVector t n e row a = checkedRow (shape a) (unsafeRow a) -- | Get a list of vector views of the rows of the matrix. rows :: (Elem e) => DMatrix t (m,n) e -> [DVector t n e] rows a = [ unsafeRow a i | i <- [0..numRows a - 1] ] -- | Get a list of vector views of the columns of the matrix. cols :: (Elem e) => DMatrix t (m,n) e -> [DVector t m e] cols a = [ unsafeCol a j | j <- [0..numCols a - 1] ] -- | Get a vector view of the given column in a matrix. col :: (Elem e) => DMatrix t (m,n) e -> Int -> DVector t m e col a = checkedCol (shape a) (unsafeCol a) -- | Same as 'row', but does not do any bounds checking. unsafeRow :: (Elem e) => DMatrix t (m,n) e -> Int -> DVector t n e unsafeRow a@(H _) i = conj $ unsafeCol (herm a) i unsafeRow a@(DM _ _ _ _ _) i = let f = fptr a o = indexOf a (i,0) n = size2 a s = lda a in V.fromForeignPtr f o n s -- | Same as 'col', but does not do any bounds checking. unsafeCol :: (Elem e) => DMatrix t (m,n) e -> Int -> DVector t m e unsafeCol a@(H _) j = conj $ unsafeRow (herm a) j unsafeCol a@(DM _ _ _ _ _) j = let f = fptr a o = indexOf a (0,j) m = size1 a s = 1 in V.fromForeignPtr f o m s -- | @diag a 0@ gets a vector view of the main diagonal of @a@. @diag a k@ for -- @k@ positive gets a view of the @k@th superdiagonal. For @k@ negative, it -- gets a view of the @(-k)@th subdiagonal. diag :: (Elem e) => DMatrix t (m,n) e -> Int -> DVector t k e diag a = checkedDiag (shape a) (unsafeDiag a) -- | Same as 'diag', but does not do any bounds checking. unsafeDiag :: (Elem e) => DMatrix t (m,n) e -> Int -> DVector t k e unsafeDiag (H a) i = conj $ unsafeDiag a (negate i) unsafeDiag a@(DM _ _ _ _ _) i = let f = fptr a o = indexOf a (diagStart i) n = diagLen (shape a) i s = lda a + 1 in V.fromForeignPtr f o n s -- | @submatrix a ij mn@ returns a view of the submatrix of @a@ with element @(0,0)@ -- being element @ij@ in @a@, and having shape @mn@. submatrix :: (Elem e) => DMatrix t (m,n) e -> (Int,Int) -> (Int,Int) -> DMatrix t (k,l) e submatrix a = checkedSubmatrix (shape a) (unsafeSubmatrix a) -- | Same as 'submatrix' but does not do any bounds checking. unsafeSubmatrix :: (Elem e) => DMatrix t (m,n) e -> (Int,Int) -> (Int,Int) -> DMatrix t (k,l) e unsafeSubmatrix a@(H _) (i,j) (m',n') = herm $ unsafeSubmatrix (herm a) (j,i) (n',m') unsafeSubmatrix a@(DM _ _ _ _ _) (i,j) mn' = let f = fptr a o = indexOf a (i,j) l = lda a in fromForeignPtr f o mn' l -- | Create a matrix view of a row vector. This will fail if the -- stride is not @1@ and the vector is conjugated. maybeFromRow :: (Elem e) => DVector t m e -> Maybe (DMatrix t (one,m) e) maybeFromRow (V.C (V.C x)) = maybeFromRow x maybeFromRow (V.C x@(V.DV _ _ _ _)) | V.stride x == 1 = let f = V.fptr x o = V.offset x n = V.dim x l = max 1 n in Just $ herm $ fromForeignPtr f o (n,1) l | otherwise = Nothing maybeFromRow x@(V.DV _ _ _ _) = let f = V.fptr x o = V.offset x n = V.dim x s = V.stride x l = max 1 s in Just $ fromForeignPtr f o (1,n) l -- | Possibly create a matrix view of a column vector. This will fail -- if the stride of the vector is not @1@ and the vector is not conjugated. maybeFromCol :: (Elem e) => DVector t n e -> Maybe (DMatrix t (n,one) e) maybeFromCol (V.C x) = maybeFromRow x >>= return . herm maybeFromCol x@(V.DV _ _ _ _) | V.stride x == 1 = let f = V.fptr x o = V.offset x m = dim x l = max 1 m in Just $ fromForeignPtr f o (m,1) l | otherwise = Nothing maybeToVector :: (Elem e) => DMatrix t (m,n) e -> Maybe (Order, DVector t k e) maybeToVector (H a) = maybeToVector a >>= (\(o,x) -> return (flipOrder o, conj x)) maybeToVector (DM f o m n ld) | ld == m = Just $ (ColMajor, V.fromForeignPtr f o (m*n) 1) | m == 1 = Just $ (ColMajor, V.fromForeignPtr f o n ld) | otherwise = Nothing -- | Modify each element in-place by applying a function to it. -- modifyWith :: (Elem e) => (e -> e) -> IOMatrix (m,n) e -> IO () -- | Take a unary elementwise vector operation and apply it to the elements of a matrix. liftV :: (Elem e) => (DVector t k e -> IO ()) -> DMatrix t (m,n) e -> IO () liftV f a = case maybeToVector a of Just (_,x) -> f x _ -> let xs = case orderOf a of RowMajor -> rows (coerceMatrix a) ColMajor -> cols (coerceMatrix a) in mapM_ f xs -- | Take a binary elementwise vector operation and apply it to the elements of a pair -- of matrices. liftV2 :: (Elem e) => (DVector s k e -> DVector t k e -> IO ()) -> DMatrix s (m,n) e -> DMatrix t (m,n) e -> IO () liftV2 f a b = case (maybeToVector a, maybeToVector b) of (Just (RowMajor,x), Just (RowMajor,y)) -> f x y (Just (ColMajor,x), Just (ColMajor,y)) -> f x y _ -> let (xs,ys) = case orderOf a of RowMajor -> (rows (coerceMatrix a), rows (coerceMatrix b)) ColMajor -> (cols (coerceMatrix a), cols (coerceMatrix b)) in zipWithM_ f xs ys instance C.Matrix (DMatrix t) where numRows = fst . shape numCols = snd . shape herm a = case a of (H a') -> coerceMatrix a' _ -> H (coerceMatrix a) instance Tensor (DMatrix t (m,n)) (Int,Int) e where shape a = case a of (H a') -> case shape a' of (m,n) -> (n,m) _ -> (size1 a, size2 a) bounds a = let (m,n) = shape a in ((0,0), (m-1,n-1)) instance (BLAS1 e) => ITensor (DMatrix Imm (m,n)) (Int,Int) e where size a = (numRows a * numCols a) unsafeAt a = inlinePerformIO . unsafeReadElem a {-# INLINE unsafeAt #-} indices a = [ (i,j) | j <- range (0,n-1), i <- range (0,m-1) ] where (m,n) = shape a {-# INLINE indices #-} elems = inlinePerformIO . getElems {-# INLINE elems #-} assocs = inlinePerformIO . getAssocs {-# INLINE assocs #-} (//) = replaceHelp writeElem unsafeReplace = replaceHelp unsafeWriteElem amap f a = listMatrix (shape a) (map f $ elems a) azipWith f a b | shape b /= mn = error ("azipWith: matrix shapes differ; first has shape `" ++ show mn ++ "' and second has shape `" ++ show (shape b) ++ "'") | otherwise = listMatrix mn (zipWith f (elems a) (elems b)) where mn = shape a replaceHelp :: (BLAS1 e) => (IOMatrix (m,n) e -> (Int,Int) -> e -> IO ()) -> Matrix (m,n) e -> [((Int,Int), e)] -> Matrix (m,n) e replaceHelp set x ies = unsafeFreeze $ unsafePerformIO $ do y <- newCopy (unsafeThaw x) mapM_ (uncurry $ set y) ies return y {-# NOINLINE replaceHelp #-} instance (BLAS1 e) => IDTensor (DMatrix Imm (m,n)) (Int,Int) e where zero = unsafePerformIO . newZero {-# NOINLINE zero #-} constant mn = unsafePerformIO . newConstant mn {-# NOINLINE constant #-} instance (BLAS1 e) => RTensor (DMatrix t (m,n)) (Int,Int) e IO where getSize a = return (numRows a * numCols a) newCopy a = case a of (H a') -> newCopy a' >>= return . H _ -> do a' <- newMatrix_ (shape a) liftV2 V.copyVector (unsafeThaw a') a return a' unsafeReadElem a (i,j) = case a of (H a') -> unsafeReadElem a' (j,i) >>= return . E.conj _ -> withForeignPtr (fptr a) $ \ptr -> peekElemOff ptr (indexOf a (i,j)) {-# INLINE unsafeReadElem #-} getIndices = return . indices . unsafeFreeze {-# INLINE getIndices #-} getElems a = return $ go (cols a) where go cs | cs `seq` False = undefined go [] = [] go (c:cs) = let e = inlinePerformIO $ getElems c es = go cs in e ++ es {-# NOINLINE getElems #-} getAssocs a = return $ go (cols a) 0 where go cs j | cs `seq` j `seq` False = undefined go [] _ = [] go (c:cs) j = let ie = inlinePerformIO $ getAssocs c ije = map (\(i,e) -> ((i,j),e)) ie ijes = go cs (j+1) in ije ++ ijes {-# NOINLINE getAssocs #-} instance (BLAS1 e) => RDTensor (DMatrix t (m,n)) (Int,Int) e IO where newZero mn = do a <- newMatrix_ mn setZero (unsafeThaw a) return a newConstant mn e = do a <- newMatrix_ mn setConstant e (unsafeThaw a) return a instance (BLAS1 e) => MTensor (DMatrix Mut (m,n)) (Int,Int) e IO where setZero = liftV setZero setConstant e = liftV (setConstant e) canModifyElem a ij = return $ inRange (bounds a) ij {-# INLINE canModifyElem #-} unsafeWriteElem a (i,j) e = case a of (H a') -> unsafeWriteElem a' (j,i) $ E.conj e _ -> withForeignPtr (fptr a) $ \ptr -> pokeElemOff ptr (indexOf a (i,j)) e modifyWith f = liftV (modifyWith f) instance (BLAS1 e, Show e) => Show (DMatrix Imm (m,n) e) where show a = case a of (H a') -> "herm (" ++ show a' ++ ")" _ -> "listMatrix " ++ show (shape a) ++ " " ++ show (elems a) compareHelp :: (BLAS1 e) => (e -> e -> Bool) -> Matrix (m,n) e -> Matrix (m,n) e -> Bool compareHelp cmp x y | isHerm x && isHerm y = compareHelp cmp (herm x) (herm y) compareHelp cmp x y = (shape x == shape y) && (and $ zipWith cmp (elems x) (elems y)) instance (BLAS1 e, Eq e) => Eq (DMatrix Imm (m,n) e) where (==) = compareHelp (==) instance (BLAS1 e, AEq e) => AEq (DMatrix Imm (m,n) e) where (===) = compareHelp (===) (~==) = compareHelp (~==)