module Numerical.HBLAS.MatrixTypes where
import qualified Data.Vector.Storable as S
import qualified Data.Vector.Storable.Mutable as SM
import Control.Monad.Primitive
import Data.Typeable
-- | Storage layout of a dense matrix buffer: 'Row' is row-major (elements of
-- one row are adjacent), 'Column' is column-major. The constructors are also
-- used promoted (@'Row@ / @'Column@) as a type-level index.
data Orientation
  = Row
  | Column
  deriving (Show, Eq, Typeable)
-- | Type-level shorthands for the promoted 'Orientation' constructors, so
-- signatures can write @Row@ / @Column@ without the promotion tick.
type Row = ('Row :: Orientation)
type Column = ('Column :: Orientation)
-- | Value-level singleton witness for a type-level 'Orientation'. Each index
-- has exactly one inhabitant: 'SRow' for @'Row@, 'SColumn' for @'Column@.
-- Pattern matching on it recovers the orientation of a matrix at runtime.
data SOrientation :: Orientation -> * where
    SRow    :: SOrientation 'Row
    SColumn :: SOrientation 'Column
  deriving (Typeable)
-- | For a fixed index @a@ there is only one constructor of @SOrientation a@,
-- so any two values of the same type are equal. Arguments are not forced.
instance Eq (SOrientation a) where
  (==) _ _ = True
  (/=) _ _ = False
-- | Trivial ordering, consistent with the 'Eq' instance: all values of a
-- given @SOrientation a@ compare 'EQ'. Arguments are not forced.
instance Ord (SOrientation a) where
  compare _ = const EQ
-- | Renders the singleton as its bare constructor name, ignoring precedence.
instance Show (SOrientation a) where
  showsPrec _ orientation = showString $ case orientation of
    SRow    -> "SRow"
    SColumn -> "SColumn"
-- | Flip an orientation witness: row-major becomes column-major and vice
-- versa. The two equality constraints tie the input and output indices to
-- each other through 'TransposeF' in both directions.
sTranpose :: (x ~ TransposeF y, y ~ TransposeF x) => SOrientation x -> SOrientation y
sTranpose orientation = case orientation of
  SRow    -> SColumn
  SColumn -> SRow
-- | Transpose / conjugation mode for a matrix operand. The constructor names
-- suggest these mirror the BLAS @trans@ flags — NOTE(review): confirm the
-- mapping against the FFI layer.
data Transpose
  = NoTranspose
  | Transpose
  | ConjTranspose
  | ConjNoTranspose
  deriving (Eq, Show, Typeable)
-- | Which triangle of a matrix is meant (upper or lower); presumably the
-- BLAS @uplo@ flag — confirm against the FFI layer.
data MatUpLo
  = MatUpper
  | MatLower
  deriving (Eq, Show, Typeable)
-- | Diagonal kind of a triangular matrix (unit vs. non-unit); presumably the
-- BLAS @diag@ flag — confirm against the FFI layer.
data MatDiag
  = MatUnit
  | MatNonUnit
  deriving (Eq, Show, Typeable)
-- | Side on which a matrix operand appears in an operation; presumably the
-- BLAS @side@ flag — confirm against the FFI layer.
data EquationSide
  = LeftSide
  | RightSide
  deriving (Eq, Show, Typeable)
-- | Type-level transpose of an orientation: swaps row-major and
-- column-major. Kept as an open family; the two instances cover both
-- constructors of 'Orientation'.
type family TransposeF (x :: Orientation) :: Orientation
type instance TransposeF 'Column = 'Row
type instance TransposeF 'Row = 'Column
-- | Vector storage variant. It is used promoted as the index of 'SVariant';
-- only the implicit variant carries padding information there.
data Variant
  = Direct
  | Implicit
  deriving (Eq, Show, Typeable)
-- | Value-level witness of a vector's 'Variant'.
--
-- 'SImplicit' carries two strict padding counts. NOTE(review): their exact
-- meaning (element counts before/after the logical data?) is not visible
-- here — confirm. The record selectors are partial: they have no value for
-- 'SDirect'.
data SVariant :: Variant -> * where
    SImplicit :: { _frontPadding :: !Int
                 , _endPadding   :: !Int
                 } -> SVariant 'Implicit
    SDirect   :: SVariant 'Direct
-- | Immutable strided dense vector over a storable buffer.
--
-- Fields: the variant witness, the logical number of elements, the stride
-- between consecutive logical elements, and the backing buffer. All fields
-- are strict.
--
-- The guard must test @__GLASGOW_HASKELL__@ (two trailing underscores). The
-- previous spelling @__GLASGOW_HASKELL_@ is never defined, so the Typeable
-- deriving below was silently compiled out on every GHC.
data DenseVector :: Variant -> * -> * where
    DenseVector :: { _VariantDenseVect      :: !(SVariant varnt)
                   , _LogicalDimDenseVector :: !Int
                   , _StrideDenseVector     :: !Int
                   , _bufferDenseVector     :: !(S.Vector elem)
                   } -> DenseVector varnt elem
#if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 707)
  deriving (Typeable)
#endif
-- | Mutable strided dense vector. @s@ is the state token of the monad that
-- owns the buffer; 'generateMutableDenseVector' builds one with
-- @s ~ PrimState m@.
--
-- Fields: the variant witness, the logical number of elements, the stride
-- between consecutive logical elements (1 for a contiguous vector), and the
-- mutable backing buffer. All fields are strict.
--
-- The guard must test @__GLASGOW_HASKELL__@ (two trailing underscores). The
-- previous spelling @__GLASGOW_HASKELL_@ is never defined, so the Typeable
-- deriving below was silently compiled out on every GHC.
data MDenseVector :: * -> Variant -> * -> * where
    MutableDenseVector :: { _VariantMutDenseVect      :: !(SVariant varnt)
                          , _LogicalDimMutDenseVector :: !Int
                          , _StrideMutDenseVector     :: !Int
                          , _bufferMutDenseVector     :: !(S.MVector s elem)
                          } -> MDenseVector s varnt elem
#if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 707)
  deriving (Typeable)
#endif
-- | Immutable dense matrix over a storable buffer, indexed by its storage
-- orientation.
--
-- * '_OrientationMat' — runtime witness of the layout (the only lazy field).
-- * '_XdimDenMat'     — number of columns (x is the column coordinate).
-- * '_YdimDenMat'     — number of rows (y is the row coordinate).
-- * '_StrideDenMat'   — leading dimension: element @(x, y)@ lives at
--   @x + y * stride@ (row-major) or @y + x * stride@ (column-major).
-- * '_bufferDenMat'   — backing storage.
data DenseMatrix :: Orientation -> * -> * where
    DenseMatrix :: { _OrientationMat :: SOrientation ornt
                   , _XdimDenMat     :: !Int
                   , _YdimDenMat     :: !Int
                   , _StrideDenMat   :: !Int
                   , _bufferDenMat   :: !(S.Vector elem)
                   } -> DenseMatrix ornt elem
  deriving (Typeable, Eq, Ord, Show)
-- | Snapshot the current contents of a mutable storable vector as a list.
--
-- The buffer is copied with 'S.freeze' before conversion. 'S.toList'
-- produces its result lazily, so converting an aliasing 'S.unsafeFreeze'
-- view would let writes made to @mv@ after this call show up in the parts
-- of the list the caller has not consumed yet.
mutableVectorToList :: (PrimMonad m, S.Storable a) => S.MVector (PrimState m) a -> m [a]
mutableVectorToList mv = S.toList <$> S.freeze mv
-- | Mutable counterpart of 'DenseMatrix', with the same field meanings.
-- @s@ is the state token of the owning monad; the functions in this module
-- instantiate it as @PrimState m@.
data MDenseMatrix :: * -> Orientation -> * -> * where
    MutableDenseMatrix :: { _OrientationMutMat :: SOrientation ornt
                          , _XdimDenMutMat     :: !Int
                          , _YdimDenMutMat     :: !Int
                          , _StrideDenMutMat   :: !Int
                          , _bufferDenMutMat   :: !(SM.MVector s elem)
                          } -> MDenseMatrix s ornt elem

-- | A mutable dense matrix whose state token is 'RealWorld', i.e. usable
-- from 'IO'. Left eta-reduced so it can be partially applied.
type IODenseMatrix = MDenseMatrix RealWorld
-- | View a mutable matrix as an immutable one without copying its buffer.
-- Dimensions, stride and orientation are carried over unchanged. Callers
-- must not write to the mutable matrix afterwards, because the result
-- aliases the same buffer ('S.unsafeFreeze').
unsafeFreezeDenseMatrix :: (SM.Storable elem, PrimMonad m) => MDenseMatrix (PrimState m) or elem -> m (DenseMatrix or elem)
unsafeFreezeDenseMatrix (MutableDenseMatrix orientation cols rows stride buffer) =
  S.unsafeFreeze buffer >>= \frozen ->
    return $! DenseMatrix orientation cols rows stride frozen
-- | View an immutable matrix as a mutable one without copying its buffer
-- ('S.unsafeThaw'). Dimensions, stride and orientation are carried over
-- unchanged. The input matrix must not be used afterwards, since writes
-- through the result alias its storage.
unsafeThawDenseMatrix :: (SM.Storable elem, PrimMonad m) => DenseMatrix or elem -> m (MDenseMatrix (PrimState m) or elem)
unsafeThawDenseMatrix (DenseMatrix orientation cols rows stride buffer) =
  S.unsafeThaw buffer >>= \thawed ->
    return $! MutableDenseMatrix orientation cols rows stride thawed
-- TODO(review): placeholder for a (presumably copying) freezeDenseMatrix; not implemented here.
-- | Number of rows of the matrix (its y dimension).
getDenseMatrixRow :: DenseMatrix or elem -> Int
getDenseMatrixRow = _YdimDenMat
-- | Number of columns of the matrix (its x dimension).
getDenseMatrixColumn :: DenseMatrix or elem -> Int
getDenseMatrixColumn = _XdimDenMat
-- | Leading-dimension stride: the buffer distance between consecutive rows
-- (row-major) or consecutive columns (column-major).
getDenseMatrixLeadingDimStride :: DenseMatrix or elem -> Int
getDenseMatrixLeadingDimStride = _StrideDenMat
-- | The backing storable buffer, including any stride padding.
getDenseMatrixArray :: DenseMatrix or elem -> S.Vector elem
getDenseMatrixArray = _bufferDenMat
-- | The runtime orientation witness stored in the matrix.
getDenseMatrixOrientation :: DenseMatrix or elem -> SOrientation or
getDenseMatrixOrientation (DenseMatrix orientation _ _ _ _) = orientation
-- | Read element @(x, y)@ (column, row) without any bounds checking.
-- The buffer offset is @x + y * stride@ for row-major storage and
-- @y + x * stride@ for column-major, where @stride@ is the leading dimension.
-- The orientation is inspected once, when the matrix is supplied.
uncheckedDenseMatrixIndex :: (S.Storable elem) => DenseMatrix or elem -> (Int,Int) -> elem
uncheckedDenseMatrixIndex (DenseMatrix orientation _ _ stride arr) =
  case orientation of
    SRow    -> \(col, row) -> S.unsafeIndex arr (row * stride + col)
    SColumn -> \(col, row) -> S.unsafeIndex arr (col * stride + row)
-- | Monadic variant of 'uncheckedDenseMatrixIndex'. The element is forced
-- before it is returned, so the read happens when the action runs rather
-- than when the result is later demanded. No bounds checking.
uncheckedDenseMatrixIndexM :: (Monad m, S.Storable elem) => DenseMatrix or elem -> (Int,Int) -> m elem
uncheckedDenseMatrixIndexM (DenseMatrix orientation _ _ stride arr) =
  case orientation of
    SRow    -> \(col, row) -> return $! S.unsafeIndex arr (row * stride + col)
    SColumn -> \(col, row) -> return $! S.unsafeIndex arr (col * stride + row)
-- | Read element @(x, y)@ (column, row) of a mutable matrix with
-- 'SM.unsafeRead', i.e. without bounds checking. Offsets follow the same
-- layout rules as 'uncheckedDenseMatrixIndex'.
uncheckedMutableDenseMatrixIndexM :: (PrimMonad m, S.Storable elem) => MDenseMatrix (PrimState m) or elem -> (Int,Int) -> m elem
uncheckedMutableDenseMatrixIndexM (MutableDenseMatrix orientation _ _ stride arr) =
  case orientation of
    SRow    -> \(col, row) -> SM.unsafeRead arr (row * stride + col)
    SColumn -> \(col, row) -> SM.unsafeRead arr (col * stride + row)
-- | Exchange the components of a pair. Both components are forced to WHNF
-- when the result is demanded.
swap :: (a,b) -> (b,a)
swap (!first, !second) = (second, first)
-- | Apply a function to every element, producing a fresh, tightly packed
-- matrix with the same orientation and dimensions. The result's stride
-- equals the column count (row-major) or row count (column-major), whatever
-- the input's stride was. This is the index-oblivious special case of
-- 'imapDenseMatrix'.
mapDenseMatrix :: (S.Storable a, S.Storable b) => (a -> b) -> DenseMatrix or a -> DenseMatrix or b
mapDenseMatrix f = imapDenseMatrix (const f)
-- | Like 'mapDenseMatrix', but the function also receives the element's
-- @(x, y)@ (column, row) index. The source element is forced before the
-- function is applied. The result is tightly packed (see
-- 'generateDenseMatrix').
imapDenseMatrix :: (S.Storable a, S.Storable b) => ((Int,Int) -> a -> b) -> DenseMatrix or a -> DenseMatrix or b
imapDenseMatrix f rm@(DenseMatrix orientation cols rows _ _) =
    generateDenseMatrix orientation (cols, rows) cell
  where
    cell ix = f ix $! uncheckedDenseMatrixIndex rm ix
-- | Step an in-bounds @(x, y)@ (column, row) index to the next position in
-- the matrix's storage order, or 'Nothing' once the last element is passed.
-- Row-major advances along a row and wraps to the next row; column-major
-- advances down a column and wraps to the next column.
--
-- The end test is done on the linear offset. A per-coordinate test such as
-- @x >= xdim && y >= ydim@ never fires for in-bounds input, because the
-- wrapped coordinate always stays below its bound, so iteration would run
-- past the end forever. This form also avoids dividing by a zero dimension
-- for empty matrices. Inputs satisfying @x >= xdim && y >= ydim@ still give
-- 'Nothing'.
--
-- "Unchecked": the input index is assumed to lie inside the matrix.
uncheckedDenseMatrixNextTuple :: DenseMatrix or elem -> (Int,Int) -> Maybe (Int,Int)
uncheckedDenseMatrixNextTuple (DenseMatrix SRow xdim ydim _ _) =
  \(!x, !y) ->
    let !next = x + xdim * y + 1
    in if next >= xdim * ydim
         then Nothing
         else Just $! swap (next `quotRem` xdim)
uncheckedDenseMatrixNextTuple (DenseMatrix SColumn xdim ydim _ _) =
  \(!x, !y) ->
    let !next = y + ydim * x + 1
    in if next >= xdim * ydim
         then Nothing
         else Just $! next `quotRem` ydim
-- | Build a tightly packed matrix of @xdim@ columns and @ydim@ rows, where
-- element @(x, y)@ is @f (x, y)@.
--
-- Row-major results have stride @xdim@, and linear index @i@ maps to
-- @(i `rem` xdim, i `quot` xdim)@. Column-major results have stride
-- @ydim@, and @i@ maps to @(i `quot` ydim, i `rem` ydim)@.
generateDenseMatrix :: (S.Storable a) => SOrientation x -> (Int,Int) -> ((Int,Int) -> a) -> DenseMatrix x a
generateDenseMatrix orientation (xdim, ydim) f =
    case orientation of
      SRow ->
        DenseMatrix SRow xdim ydim xdim $!
          S.generate cells (\i -> let !ix = swap (i `quotRem` xdim) in f ix)
      SColumn ->
        DenseMatrix SColumn xdim ydim ydim $!
          S.generate cells (\i -> let ix@(!_, !_) = i `quotRem` ydim in f ix)
  where
    cells = xdim * ydim
-- | Mutable version of 'generateDenseMatrix': element @(x, y)@ (column,
-- row) is initialised to @fun (x, y)@.
--
-- Thawing without a copy is safe here: the freshly generated immutable
-- matrix is not reachable from anywhere else. The result of the thaw is
-- returned directly rather than re-wrapped with @x <- ...; return x@.
generateMutableDenseMatrix :: (S.Storable a, PrimMonad m) =>
  SOrientation x -> (Int,Int) -> ((Int,Int) -> a) -> m (MDenseMatrix (PrimState m) x a)
generateMutableDenseMatrix sor dims fun =
  unsafeThawDenseMatrix $! generateDenseMatrix sor dims fun
-- | Like 'generateMutableDenseMatrix', but @fun (x, y)@ is kept only where
-- the column index is at least the row index (@x >= y@, on or above the
-- diagonal). Every other element is 0.
--
-- Thawing without a copy is safe: the generated matrix is not shared.
generateMutableUpperTriangular :: forall a x m . (Num a, S.Storable a, PrimMonad m) =>
  SOrientation x -> (Int,Int) -> ((Int,Int) -> a) -> m (MDenseMatrix (PrimState m) x a)
generateMutableUpperTriangular sor dims fun =
    unsafeThawDenseMatrix $! generateDenseMatrix sor dims upperOnly
  where
    upperOnly (col, row)
      | col >= row = fun (col, row)
      | otherwise  = 0
-- | Like 'generateMutableDenseMatrix', but @fun (x, y)@ is kept only where
-- the column index is at most the row index (@x <= y@, on or below the
-- diagonal). Every other element is 0.
--
-- Thawing without a copy is safe: the generated matrix is not shared.
generateMutableLowerTriangular :: forall a x m . (Num a, S.Storable a, PrimMonad m) =>
  SOrientation x -> (Int,Int) -> ((Int,Int) -> a) -> m (MDenseMatrix (PrimState m) x a)
generateMutableLowerTriangular sor dims fun =
    unsafeThawDenseMatrix $! generateDenseMatrix sor dims lowerOnly
  where
    lowerOnly (col, row)
      | col <= row = fun (col, row)
      | otherwise  = 0
-- | Allocate a contiguous ('SDirect', stride 1) mutable vector of @size@
-- elements, with element @i@ set to @gen i@. The freshly generated buffer is
-- not shared, so it is thawed without copying.
generateMutableDenseVector :: (S.Storable a, PrimMonad m) => Int -> (Int -> a) ->
  m (MDenseVector (PrimState m) 'Direct a)
generateMutableDenseVector size gen =
  S.unsafeThaw (S.generate size gen) >>= \buffer ->
    return $! MutableDenseVector SDirect size 1 buffer
-- | Allocate a strided ('SDirect') mutable vector with @size@ logical
-- elements spaced @stride@ slots apart, in a buffer of @size * stride@
-- slots. Logical element @k@ (stored at slot @k * stride@) is @gen k@; the
-- slots in between are filled with 0. The fresh buffer is thawed without
-- copying.
generateMutableDenseVectorWithStride :: (Num a, S.Storable a, PrimMonad m)
  => Int -> Int -> (Int -> a) -> m (MDenseVector (PrimState m) 'Direct a)
generateMutableDenseVectorWithStride size stride gen =
    S.unsafeThaw (S.generate (size * stride) slot) >>= \buffer ->
      return $! MutableDenseVector SDirect size stride buffer
  where
    -- only every stride-th slot holds a logical element; gaps are zero
    slot i = case i `divMod` stride of
      (k, 0) -> gen k
      _      -> 0
-- | View the rectangular sub-matrix with inclusive corners @(xstart, ystart)@
-- and @(xend, yend)@ (each given as (column, row)) without copying. The
-- result shares the parent's buffer.
--
-- The view keeps the parent's leading-dimension stride, because view
-- element @(x, y)@ must land on parent element @(x + xstart, y + ystart)@.
-- With @ixStart@ as the view's origin, that offset differs from the
-- parent's only by @ixStart@. The view's buffer runs from the first through
-- the last (inclusive) element of the rectangle, hence the @+ 1@ in the
-- slice length.
--
-- "Unchecked": the corners are not validated here. Only whatever check
-- 'S.slice' itself performs applies.
uncheckedDenseMatrixSlice :: (S.Storable elem) => DenseMatrix or elem -> (Int,Int) -> (Int,Int) -> DenseMatrix or elem
uncheckedDenseMatrixSlice (DenseMatrix SRow _ _ ystride arr) (xstart, ystart) (xend, yend) =
    DenseMatrix SRow (xend - xstart + 1) (yend - ystart + 1) ystride
      (S.slice ixStart (ixEnd - ixStart + 1) arr)
  where
    -- row-major: offset of (x, y) is x + y * ystride
    !ixStart = xstart + ystart * ystride
    !ixEnd   = xend + yend * ystride
uncheckedDenseMatrixSlice (DenseMatrix SColumn _ _ xstride arr) (xstart, ystart) (xend, yend) =
    DenseMatrix SColumn (xend - xstart + 1) (yend - ystart + 1) xstride
      (S.slice ixStart (ixEnd - ixStart + 1) arr)
  where
    -- column-major: offset of (x, y) is y + x * xstride
    !ixStart = ystart + xstart * xstride
    !ixEnd   = yend + xend * xstride
-- | O(1) transpose by reinterpretation. The buffer and stride are reused
-- unchanged; only the orientation is flipped and the column/row counts are
-- swapped. A row-major m×n matrix is read as a column-major n×m matrix, and
-- vice versa.
transposeDenseMatrix :: (inor ~ (TransposeF outor), outor ~ (TransposeF inor)) => DenseMatrix inor elem -> DenseMatrix outor elem
transposeDenseMatrix matrix = case matrix of
  DenseMatrix orientation cols rows leading buffer ->
    DenseMatrix (sTranpose orientation) rows cols leading buffer