module Data.Tensor.Class.MTensor (
ReadTensor(..),
readElem,
WriteTensor(..),
writeElem,
modifyElem,
swapElems,
) where
import Data.Elem.BLAS ( Elem, conjugate )
import Data.Ix
import Data.Tensor.Class
-- | Monadic, read-only access to the elements of a tensor-like container
-- @x n e@, whose index type @i@ is determined by the container type.
--
-- Default methods: 'getElems' is written in terms of 'getAssocs', and
-- 'getAssocs' in terms of 'getElems' (the primed variants are paired the
-- same way). An instance must therefore override at least one method of
-- each pair. Otherwise the two defaults call each other forever.
class (Shaped x i, Monad m) => ReadTensor x i e m | x -> i where
    -- | Size of the object as an 'Int'.
    -- NOTE(review): presumably the element count; confirm per instance.
    getSize :: x n e -> m Int

    -- | Read the element at the given index. No bounds check is done here;
    -- see 'readElem' for the checked version.
    unsafeReadElem :: x n e -> i -> m e

    -- | The indices of the object, in instance-defined order.
    getIndices :: x n e -> m [i]

    -- | Primed counterpart of 'getIndices'.
    -- NOTE(review): the difference from the unprimed form is
    -- instance-defined (conventionally the stricter variant); confirm.
    getIndices' :: x n e -> m [i]

    -- | The elements of the object. The default takes the second
    -- components of 'getAssocs'.
    getElems :: x n e -> m [e]
    getElems x = do
        pairs <- getAssocs x
        return (snd (unzip pairs))

    -- | Primed counterpart of 'getElems'. The default takes the second
    -- components of 'getAssocs''.
    getElems' :: x n e -> m [e]
    getElems' x = do
        pairs <- getAssocs' x
        return (snd (unzip pairs))

    -- | Index/element pairs. The default runs 'getIndices' and then
    -- 'getElems' and zips the two results.
    getAssocs :: x n e -> m [(i,e)]
    getAssocs x =
        getIndices x >>= \ixs ->
        getElems x   >>= \els ->
        return (zip ixs els)

    -- | Primed counterpart of 'getAssocs'. The default runs 'getIndices''
    -- and then 'getElems'' and zips the two results.
    getAssocs' :: x n e -> m [(i,e)]
    getAssocs' x =
        getIndices' x >>= \ixs ->
        getElems' x   >>= \els ->
        return (zip ixs els)
-- | Read the element of @x@ at index @i@ with a bounds check.
--
-- If @i@ lies within @bounds x@ the read is delegated to 'unsafeReadElem'.
-- Otherwise the action fails through the monad's 'fail', naming the offending
-- index and the object's shape.
readElem :: (ReadTensor x i e m) => x n e -> i -> m e
readElem x i
    | inRange (bounds x) i = unsafeReadElem x i
    | otherwise =
        fail $ "tried to get element at a index `" ++ show i ++ "'"
            ++ " in an object with shape `" ++ show (shape x) ++ "'"
-- | Mutable tensor-like containers. This extends 'ReadTensor' with in-place
-- writes. The container type determines the monad @m@ the writes run in.
--
-- Methods without defaults that every instance must supply: 'setConstant',
-- 'canModifyElem', 'unsafeWriteElem' and 'modifyWith'.
class (ReadTensor x i e m) => WriteTensor x i e m | x -> m where
    -- | Maximum size of the object. Defaults to 'getSize'.
    getMaxSize :: x n e -> m Int
    getMaxSize x = getSize x

    -- | Overwrite the object with zeros. Defaults to @setConstant 0@.
    setZero :: (Num e) => x n e -> m ()
    setZero x = setConstant 0 x

    -- | Overwrite the object with the given value.
    setConstant :: e -> x n e -> m ()

    -- | Whether the element at the given index may be written. The checked
    -- operations in this module consult it before writing.
    canModifyElem :: x n e -> i -> m Bool

    -- | Write the element at the given index. Neither bounds nor
    -- 'canModifyElem' is checked.
    unsafeWriteElem :: x n e -> i -> e -> m ()

    -- | Replace the element at an index with @f@ applied to it. No checks are
    -- done. The default reads with 'unsafeReadElem' and then writes with
    -- 'unsafeWriteElem'.
    unsafeModifyElem :: x n e -> i -> (e -> e) -> m ()
    unsafeModifyElem x i f =
        unsafeReadElem x i >>= unsafeWriteElem x i . f

    -- | Apply a function to the elements of the object in place.
    modifyWith :: (e -> e) -> x n e -> m ()

    -- | Exchange the elements at two indices without any checks. The default
    -- reads both elements before writing either, so swapping an index with
    -- itself leaves it unchanged.
    unsafeSwapElems :: x n e -> i -> i -> m ()
    unsafeSwapElems x i j =
        unsafeReadElem x i >>= \atI ->
        unsafeReadElem x j >>= \atJ ->
        unsafeWriteElem x j atI >>
        unsafeWriteElem x i atJ

    -- | Conjugate in place by passing 'conjugate' to 'modifyWith'.
    doConj :: (Elem e) => x n e -> m ()
    doConj x = modifyWith conjugate x

    -- | Multiply the elements by @k@ (as @k * e@) via 'modifyWith'. A factor
    -- equal to 1 is a no-op that never touches the object.
    --
    -- NOTE(review): the comparison with the literal needs @Eq e@, which
    -- 'Num' only implied on old compilers (before GHC 7.4). Newer GHCs
    -- reject this default without an extra constraint.
    scaleBy :: (Num e) => e -> x n e -> m ()
    scaleBy k
        | k == 1    = const (return ())
        | otherwise = modifyWith (k *)

    -- | Add @k@ to the elements (as @k + e@) via 'modifyWith'. A shift equal
    -- to 0 is a no-op that never touches the object. The same @Eq e@ caveat
    -- as 'scaleBy' applies.
    shiftBy :: (Num e) => e -> x n e -> m ()
    shiftBy k
        | k == 0    = const (return ())
        | otherwise = modifyWith (k +)
-- | Write @e@ to the element of @x@ at index @i@, with checks.
--
-- The action fails (via 'fail') in two cases:
--
-- * @i@ lies outside @bounds x@, or
--
-- * 'canModifyElem' reports that the element cannot be modified.
--
-- Otherwise the write is delegated to 'unsafeWriteElem'. The bounds check
-- runs first, so 'canModifyElem' is only ever queried for in-range indices.
-- Each failure message also states the actual reason for the failure.
writeElem :: (WriteTensor x i e m) => x n e -> i -> e -> m ()
writeElem x i e
    | not (inRange (bounds x) i) = failWith ""
    | otherwise = do
        ok <- canModifyElem x i
        if ok
            then unsafeWriteElem x i e
            else failWith " but that element cannot be modified"
  where
    -- Shared message prefix. The suffix explains why the write was refused
    -- (empty for an out-of-range index, matching 'readElem').
    failWith reason =
        fail $ "tried to set element at index `" ++ show i ++ "'"
            ++ " in an object with shape `" ++ show (shape x) ++ "'"
            ++ reason
-- | Replace the element of @x@ at index @i@ with @f@ applied to it, with
-- checks.
--
-- The action fails (via 'fail') in two cases:
--
-- * @i@ lies outside @bounds x@, or
--
-- * 'canModifyElem' reports that the element cannot be modified.
--
-- Otherwise the update is delegated to 'unsafeModifyElem'. The bounds check
-- was previously missing, so an out-of-range index reached the unchecked
-- read/write. It now runs first, as in 'readElem', 'writeElem' and
-- 'swapElems'.
modifyElem :: (WriteTensor x i e m) => x n e -> i -> (e -> e) -> m ()
modifyElem x i f
    | not (inRange (bounds x) i) = failWith ""
    | otherwise = do
        ok <- canModifyElem x i
        if ok
            then unsafeModifyElem x i f
            else failWith " but that element cannot be modified"
  where
    -- Shared message prefix. The suffix explains why the update was refused
    -- (empty for an out-of-range index).
    failWith reason =
        fail $ "tried to modify element at index `" ++ show i ++ "'"
            ++ " in an object with shape `" ++ show (shape x) ++ "'"
            ++ reason
-- | Exchange the elements of @x@ at indices @i@ and @j@, with checks.
--
-- The action fails (via 'fail') in two cases:
--
-- * either index lies outside @bounds x@, or
--
-- * 'canModifyElem' rejects either index.
--
-- Otherwise the swap is delegated to 'unsafeSwapElems'. The swap writes to
-- both positions, so both must pass the same modifiability check that
-- 'writeElem' and 'modifyElem' apply to a single write.
swapElems :: (WriteTensor x i e m) => x n e -> i -> i -> m ()
swapElems x i j
    | not (inBounds i && inBounds j) =
        fail $ "Tried to swap elements `" ++ show i ++ "' and `"
            ++ show j ++ "' in a tensor of shape `" ++ show (shape x)
            ++ "'."
    | otherwise = do
        okI <- canModifyElem x i
        okJ <- canModifyElem x j
        if okI && okJ
            then unsafeSwapElems x i j
            else fail $ "Tried to swap elements `" ++ show i ++ "' and `"
                     ++ show j ++ "' in a tensor of shape `"
                     ++ show (shape x)
                     ++ "', but at least one of them cannot be modified."
  where
    inBounds = inRange (bounds x)