module BLAS.Tensor.Write (
WriteTensor(..),
writeElem,
modifyElem,
swapElems,
) where
import Data.Ix( inRange )
import BLAS.Tensor.Base
import BLAS.Tensor.Read
import BLAS.Elem( Elem, BLAS1, conj )
-- | Tensors whose elements can be written in the monad @m@.
--
-- Methods prefixed with @unsafe@ do no index or modifiability checking of
-- their own; the checked entry points in this module ('writeElem',
-- 'modifyElem', 'swapElems') validate before delegating to them.
class (ReadTensor x i m) => WriteTensor x i m | x -> m where
    -- | Maximum size of the tensor.  Unless an instance overrides it, this
    -- is just 'getSize'.
    getMaxSize :: x n e -> m Int
    getMaxSize t = getSize t

    -- | Fill with zero; by default @setConstant 0@.
    setZero :: (Elem e) => x n e -> m ()
    setZero t = setConstant 0 t

    -- | Fill with the given constant (no default; every instance supplies
    -- it).
    setConstant :: (Elem e) => e -> x n e -> m ()

    -- | Whether the element at the given index may be written.  The checked
    -- writers in this module refuse to write when this returns 'False'.
    canModifyElem :: x n e -> i -> m Bool

    -- | Write one element with no bounds or modifiability checks.
    unsafeWriteElem :: (Elem e) => x n e -> i -> e -> m ()

    -- | Update one element with no checks.  The default reads the current
    -- value, applies the function, and writes the result back.
    unsafeModifyElem :: (Elem e) => x n e -> i -> (e -> e) -> m ()
    unsafeModifyElem t k g = unsafeReadElem t k >>= unsafeWriteElem t k . g

    -- | Apply a function to the tensor's elements (no default).  It is used
    -- below by 'doConj', 'scaleBy' and 'shiftBy'.
    modifyWith :: (Elem e) => (e -> e) -> x n e -> m ()

    -- | Swap two elements with no checks.  Both values are read (first
    -- index, then second) before either write happens.  The second index is
    -- written before the first.
    unsafeSwapElems :: (Elem e) => x n e -> i -> i -> m ()
    unsafeSwapElems t a b = do
        atA <- unsafeReadElem t a
        atB <- unsafeReadElem t b
        unsafeWriteElem t b atA
        unsafeWriteElem t a atB

    -- | Conjugate the elements via @modifyWith conj@.
    doConj :: (BLAS1 e) => x n e -> m ()
    doConj t = modifyWith conj t

    -- | Multiply the elements by a scalar.  A factor equal to 1 is a no-op
    -- and does not call 'modifyWith'.
    scaleBy :: (BLAS1 e) => e -> x n e -> m ()
    scaleBy k
        | k == 1    = \_ -> return ()
        | otherwise = modifyWith (k *)

    -- | Add a scalar to the elements.  A shift equal to 0 is a no-op and
    -- does not call 'modifyWith'.
    shiftBy :: (BLAS1 e) => e -> x n e -> m ()
    shiftBy k
        | k == 0    = \_ -> return ()
        | otherwise = modifyWith (k +)
-- | Write @e@ into @x@ at index @i@, with checking.
--
-- The index is first checked against @bounds x@.  Only an in-range index is
-- then passed to 'canModifyElem', and only if that returns 'True' is
-- 'unsafeWriteElem' called.  Either failure is reported with 'fail' in the
-- tensor's monad.  The message states the actual reason: an out-of-range
-- index is no longer misreported as "cannot be modified".
writeElem :: (WriteTensor x i m, Elem e) => x n e -> i -> e -> m ()
writeElem x i e
    | not (inRange (bounds x) i) =
        failWith "that index is out of bounds"
    | otherwise = do
        ok <- canModifyElem x i
        if ok
            then unsafeWriteElem x i e
            else failWith "that element cannot be modified"
  where
    -- Shared message prefix so both failures use the original format.
    failWith reason =
        fail $ "tried to set element at index `" ++ show i ++ "'"
               ++ " in an object with shape `" ++ show (shape x) ++ "'"
               ++ " but " ++ reason
-- | Replace the element of @x@ at index @i@ with @f@ applied to it, with
-- checking.
--
-- The index must lie within @bounds x@, and 'canModifyElem' must return
-- 'True' for it.  Only then is 'unsafeModifyElem' called.  The bounds check
-- was previously missing, so an out-of-range index reached the unchecked
-- 'unsafeModifyElem'.  Failures are reported with 'fail' in the tensor's
-- monad.
modifyElem :: (WriteTensor x i m, Elem e) => x n e -> i -> (e -> e) -> m ()
modifyElem x i f
    | not (inRange (bounds x) i) =
        failWith "that index is out of bounds"
    | otherwise = do
        ok <- canModifyElem x i
        if ok
            then unsafeModifyElem x i f
            else failWith "that element cannot be modified"
  where
    -- Shared message prefix so both failures use the original format.
    failWith reason =
        fail $ "tried to modify element at index `" ++ show i ++ "'"
               ++ " in an object with shape `" ++ show (shape x) ++ "'"
               ++ " but " ++ reason
-- | Swap the elements of @x@ at indices @i@ and @j@, with checking.
--
-- Both indices must lie within @bounds x@; the out-of-bounds message is
-- unchanged.  Since both positions are written, 'canModifyElem' must also
-- return 'True' for each of them, matching 'writeElem' and 'modifyElem'.
-- Only then is 'unsafeSwapElems' called.  Failures are reported with 'fail'
-- in the tensor's monad.
swapElems :: (WriteTensor x i m, Elem e) => x n e -> i -> i -> m ()
swapElems x i j
    | not (inRange (bounds x) i && inRange (bounds x) j) =
        failWith "'."
    | otherwise = do
        okI <- canModifyElem x i
        okJ <- canModifyElem x j
        if okI && okJ
            then unsafeSwapElems x i j
            else failWith "' but at least one of them cannot be modified."
  where
    -- Shared message prefix; the suffix distinguishes the failure reason.
    failWith suffix =
        fail $ "Tried to swap elements `" ++ show i ++ "' and `"
               ++ show j ++ "' in a tensor of shape `" ++ show (shape x)
               ++ suffix