module BLAS.Internal (
clearArray,
bzero,
inlinePerformIO,
checkedSubvector,
checkedSubvectorWithStride,
checkVecVecOp,
checkedRow,
checkedCol,
checkedDiag,
checkedSubmatrix,
checkMatMatOp,
checkMatVecMult,
checkMatMatMult,
checkMatVecMultAdd,
checkMatMatMultAdd,
checkMatVecSolv,
checkMatMatSolv,
checkMatVecSolvTo,
checkMatMatSolvTo,
checkSquare,
checkFat,
checkTall,
checkBinaryOp,
checkTernaryOp,
diagStart,
diagLen,
) where
import Data.Ix ( inRange )
import Foreign ( Ptr, Storable, castPtr, sizeOf )
import Foreign.C.Types ( CSize )
import Text.Printf ( printf )
#if defined(__GLASGOW_HASKELL__)
import GHC.Base ( realWorld# )
import GHC.IOBase ( IO(IO) )
#else
import System.IO.Unsafe ( unsafePerformIO )
#endif
-- | Zero the storage for @n@ elements of type @e@ starting at @ptr@.
--
-- The number of bytes handed to 'bzero' is @n * sizeOf e@.  The element
-- passed to 'sizeOf' is 'undefined'; it only selects the 'Storable'
-- instance, which relies on 'sizeOf' not inspecting its argument.
clearArray :: Storable e => Ptr e -> Int -> IO ()
clearArray ptr n = bzero ptr (n * sizeOf (pointee ptr))
  where
    -- Type-level witness for the element type behind the pointer.
    pointee :: Ptr a -> a
    pointee _ = undefined
-- | Call the C @bzero@ routine on @ptr@ with a byte count of @n@.
--
-- The pointer is cast to an untyped pointer and the count is converted to
-- 'CSize' with 'fromIntegral' before the foreign call is made.
bzero :: Ptr a -> Int -> IO ()
bzero ptr n = bzero_ (castPtr ptr) (fromIntegral n)

-- | Raw binding to @bzero@ as declared in @strings.h@.
foreign import ccall "strings.h bzero"
    bzero_ :: Ptr () -> CSize -> IO ()
-- | Run an 'IO' action and hand back its result as a pure value.
--
-- Under GHC the state-transformer function inside the 'IO' constructor is
-- applied directly to 'realWorld#' and the resulting state token is thrown
-- away; on any other compiler this is plain 'unsafePerformIO'.  Either way
-- the caller is responsible for the action being safe to run outside 'IO'.
inlinePerformIO :: IO a -> a
#if defined(__GLASGOW_HASKELL__)
inlinePerformIO (IO m) = case m realWorld# of (# _, r #) -> r
#else
inlinePerformIO = unsafePerformIO
#endif
-- | Bounds-checked construction of a contiguous subvector.
--
-- Arguments, in order: the length @n@ of the parent vector, the unchecked
-- constructor @sub@, the offset @o@ and the requested length @n'@.  When the
-- request is acceptable the result is @sub o n'@; otherwise 'error' is
-- called with a message describing the first failed check.
--
-- A negative offset is only rejected when the requested length is non-zero.
checkedSubvector :: Int -> (Int -> Int -> v) -> Int -> Int -> v
checkedSubvector n sub o n' = maybe (sub o n') error failure
  where
    -- First violated precondition, if any, rendered as an error message.
    failure :: Maybe String
    failure
        | o < 0 && n' /= 0 = Just $ printf
            "tried to create a subvector starting at a negative offset: `%d'" o
        | n' < 0 = Just $ printf
            "tried to create a subvector with a negative length `%d'" n'
        | n' + o > n = Just $ printf
            ("tried to create a subvector of length `%d' and offset `%d' "
             ++ " from a vector of length `%d'") n' o n
        | otherwise = Nothing
-- | Bounds-checked construction of a strided subvector.
--
-- Arguments, in order: the stride @s@, the length @n@ of the parent vector,
-- the unchecked constructor @sub@, the offset @o@ and the requested length
-- @n'@.  On success the result is @sub o n'@; otherwise 'error' is called
-- with a message describing the first failed check.
--
-- A non-empty request touches parent indices
-- @o, o + s, ..., o + s * (n' - 1)@ (assuming offset and stride are counted
-- in parent elements -- confirm against the @sub@ implementations), so only
-- the last of those has to lie inside the parent.  An empty request touches
-- nothing and is accepted as long as its offset does not exceed @n@, which
-- matches 'checkedSubvector' when @s == 1@.
checkedSubvectorWithStride :: Int -> Int -> (Int -> Int -> v)
                           -> Int -> Int -> v
checkedSubvectorWithStride s n sub o n'
    | (o < 0) && (n' /= 0) =
        error $ printf
            "Tried to create a subvector starting at a negative offset: `%d'" o
    | n' < 0 =
        error $ printf
            "Tried to create a subvector with a negative length `%d'" n'
    | s <= 0 =
        error $ printf
            "Tried to create a subvector with non-positive stride `%d'" s
    | end > n =
        error $ printf
            ("tried to create a subvector of length `%d', offset `%d',"
             ++ " and stride '%d' from a vector of length `%d'") n' o s n
    | otherwise =
        sub o n'
  where
    -- One past the last parent index the view touches.  Only consulted after
    -- the guards above have established n' >= 0 and s > 0.
    end
        | n' == 0   = o
        | otherwise = o + s * (n' - 1) + 1
-- | Guard for element-wise operations on two vectors.
--
-- Given the operation @name@ and the two dimensions, yields 'id' when the
-- dimensions agree and otherwise calls 'error' with a message naming the
-- operation and both dimensions.
checkVecVecOp :: String -> Int -> Int -> a -> a
checkVecVecOp name n1 n2
    | n1 == n2  = id
    | otherwise = error complaint
  where
    complaint :: String
    complaint = printf
        ("%s: x and y have different dimensions. x has dimension `%d',"
         ++ " and y has dimension `%d'") name n1 n2
-- | Bounds-checked row access.
--
-- For a matrix of shape @(m,n)@, returns @row i@ when @0 <= i < m@ and
-- otherwise calls 'error' reporting the index and the shape.
checkedRow :: (Int,Int) -> (Int -> v) -> Int -> v
checkedRow (m,n) row i =
    if 0 <= i && i < m
        then row i
        else error $ printf
            "Error in row index. Tried to get row `%d' in a matrix with shape `(%d,%d)'" i m n
-- | Bounds-checked column access.
--
-- For a matrix of shape @(m,n)@, returns @col j@ when @0 <= j < n@ and
-- otherwise calls 'error' reporting the index and the shape.
checkedCol :: (Int,Int) -> (Int -> v) -> Int -> v
checkedCol (m,n) col j =
    if 0 <= j && j < n
        then col j
        else error $ printf
            "Error in column index. Tried to get column `%d' in a matrix with shape `(%d,%d)'" j m n
-- | Bounds-checked diagonal access.
--
-- For a matrix of shape @(m,n)@ and a diagonal number @i@ (negative for
-- sub-diagonals, positive for super-diagonals), returns @diag i@ unless
-- @-i >= m@ for a negative @i@ or @i >= n@ for a positive @i@, in which
-- case 'error' is called.  The main diagonal (@i == 0@) is always accepted.
checkedDiag :: (Int,Int) -> (Int -> v) -> Int -> v
checkedDiag (m,n) diag i =
    case compare i 0 of
        LT | below >= m ->
            error $ printf
                "Tried to get sub-diagonal `%d' of a matrix with shape `(%d,%d)'" below m n
        GT | i >= n ->
            error $ printf
                "Tried to get super-diagonal `%d' of a matrix with shape `(%d,%d)'" i m n
        _ ->
            diag i
  where
    -- How far below the main diagonal the request lies.
    below = negate i
-- | Coordinates @(row, column)@ of the first element of diagonal @i@.
--
-- A positive @i@ gives @(0, i)@; zero or a negative @i@ gives
-- @(negate i, 0)@.
diagStart :: Int -> (Int,Int)
diagStart i =
    if i > 0
        then (0, i)
        else (negate i, 0)
-- | Number of elements on diagonal @i@ of a matrix with shape @(m,n)@.
--
-- @i@ is negative for sub-diagonals and positive for super-diagonals.
--
-- NOTE: only two of the four branches clamp at zero.  For a diagonal far
-- outside the matrix on the other side (@i > n@ when @m <= n@, or
-- @i < negate m@ when @m > n@) the result is negative.
diagLen :: (Int,Int) -> Int -> Int
diagLen (m,n) i
    | m <= n =
        if i <= 0
            then max (m + i) 0
            -- Super-diagonal of a wide matrix: limited by the columns left
            -- of the end, but never longer than the number of rows.
            else min (n - i) m
    | otherwise =
        if i > 0
            then max (n - i) 0
            -- Sub-diagonal of a tall matrix: limited by the rows left below
            -- the start, but never longer than the number of columns.
            else min (m + i) n
-- | Bounds-checked construction of a submatrix.
--
-- For a parent of shape @(m,n)@, an offset @(i,j)@ and a requested shape
-- @(m',n')@, returns @sub (i,j) (m',n')@ when both offsets and both extents
-- are non-negative and the requested window stays inside the parent along
-- each axis.  Otherwise 'error' is called with the parent shape, the offset
-- and the requested shape.
checkedSubmatrix :: (Int,Int) -> ((Int,Int) -> (Int,Int) -> a) -> (Int,Int) -> (Int,Int) -> a
checkedSubmatrix (m,n) sub (i,j) (m',n')
    | fits i m' m && fits j n' n =
        sub (i,j) (m',n')
    | otherwise =
        error $ printf ("tried to create submatrix of a `(%d,%d)' matrix " ++
                        " using offset `(%d,%d)' and shape (%d,%d)") m n i j m' n'
  where
    -- Does a window of the given extent at the given offset fit along one axis?
    fits :: Int -> Int -> Int -> Bool
    fits off extent bound = off >= 0 && extent >= 0 && off + extent <= bound
-- | Guard for element-wise operations on two matrices.
--
-- Given the operation @name@ and the two shapes, yields 'id' when the shapes
-- are equal and otherwise calls 'error' with a message naming the operation
-- and both shapes.
checkMatMatOp :: String -> (Int,Int) -> (Int,Int) -> a -> a
checkMatMatOp name mn1 mn2 =
    if mn1 == mn2
        then id
        else error complaint
  where
    complaint :: String
    complaint = printf
        ("%s: x and y have different shapes. x has shape `%s',"
         ++ " and y has shape `%s'") name (show mn1) (show mn2)
-- | Guard for a matrix-vector product.
--
-- Yields 'id' when the second component of the matrix shape @mn@ equals the
-- vector dimension @n@; otherwise calls 'error' reporting both.
checkMatVecMult :: (Int,Int) -> Int -> a -> a
checkMatVecMult mn n
    | cols == n = id
    | otherwise =
        error $ printf
            ("Tried to multiply a matrix with shape `%s' by a vector of dimension `%d'")
            (show mn) n
  where
    cols = snd mn
-- | Guard for a matrix-matrix product.
--
-- Yields 'id' when the second component of the left shape @mk@ equals the
-- first component of the right shape @kn@; otherwise calls 'error'
-- reporting both shapes.
checkMatMatMult :: (Int,Int) -> (Int,Int) -> a -> a
checkMatMatMult mk kn =
    case (mk, kn) of
        ((_, inner), (inner', _)) | inner == inner' -> id
        _ ->
            error $ printf
                ("Tried to multiply a matrix with shape `%s' by a matrix with shape `%s'")
                (show mk) (show kn)
-- | Guard for a matrix-vector multiply-and-add.
--
-- Arguments: the matrix shape @mn@, the dimension @n@ of the vector being
-- multiplied and the dimension @m@ of the vector being added to.  Yields
-- 'id' when @snd mn == n@ and @fst mn == m@.  Otherwise 'error' is called;
-- the multiplication mismatch is reported in preference to the addition
-- mismatch.
checkMatVecMultAdd :: (Int,Int) -> Int -> Int -> a -> a
checkMatVecMultAdd mn n m = maybe id error mismatch
  where
    (rows, cols) = mn

    -- First violated precondition, if any, rendered as an error message.
    mismatch :: Maybe String
    mismatch
        | cols /= n = Just $ printf
            ("Tried to multiply a matrix with shape `%s' by a vector of dimension `%d'")
            (show mn) n
        | rows /= m = Just $ printf
            ("Tried to add a vector of dimension `%d' to a vector of dimension `%d'")
            rows m
        | otherwise = Nothing
-- | Guard for a matrix-matrix multiply-and-add.
--
-- Arguments: the left shape @mk@, the right shape @kn@ and the shape @mn@ of
-- the matrix being added to.  Yields 'id' when @snd mk == fst kn@ and
-- @(fst mk, snd kn) == mn@.  Otherwise 'error' is called; the multiplication
-- mismatch is reported in preference to the addition mismatch.
checkMatMatMultAdd :: (Int,Int) -> (Int,Int) -> (Int,Int) -> a -> a
checkMatMatMultAdd mk kn mn = maybe id error mismatch
  where
    -- Shape of the product mk * kn.
    productShape = (fst mk, snd kn)

    -- First violated precondition, if any, rendered as an error message.
    mismatch :: Maybe String
    mismatch
        | snd mk /= fst kn = Just $ printf
            ("Tried to multiply a matrix with shape `%s' by a matrix with shape `%s'")
            (show mk) (show kn)
        | productShape /= mn = Just $ printf
            ("Tried to add a matrix with shape `%s' to a matrix with shape `%s'")
            (show productShape) (show mn)
        | otherwise = Nothing
-- | Guard for solving a matrix system against a vector.
--
-- Yields 'id' when the first component of the matrix shape @mn@ equals the
-- vector dimension @m@; otherwise calls 'error' reporting both.
checkMatVecSolv :: (Int,Int) -> Int -> a -> a
checkMatVecSolv mn m
    | rows == m = id
    | otherwise =
        error $ printf
            ("Tried to solve a matrix with shape `%s' for a vector of dimension `%d'")
            (show mn) m
  where
    rows = fst mn
-- | Guard for solving a matrix system against a vector and storing the
-- result in a destination vector.
--
-- Arguments: the matrix shape @mn@, the dimension @m@ of the right-hand
-- side and the dimension @n@ of the destination.  Yields 'id' when
-- @fst mn == m@ and @snd mn == n@.  Otherwise 'error' is called; the
-- right-hand-side mismatch is reported in preference to the destination
-- mismatch.
checkMatVecSolvTo :: (Int,Int) -> Int -> Int -> a -> a
checkMatVecSolvTo mn m n
    | fst mn /= m =
        error $ printf
            ("Tried to solve a matrix with shape `%s' for a vector of dimension `%d'")
            (show mn) m
    | snd mn /= n =
        -- Both placeholders are %d, so the dimension is passed as an Int.
        -- Passing a String here makes printf fail with its own formatting
        -- error instead of producing this message.
        error $ printf
            ("Tried to store a vector of dimension `%d' in a vector of dimension `%d'")
            (snd mn) n
    | otherwise = id
-- | Guard for solving a matrix system against a matrix of right-hand sides.
--
-- Yields 'id' when the two shapes @mn@ and @mk@ agree in their first
-- component; otherwise calls 'error' reporting both shapes.
checkMatMatSolv :: (Int,Int) -> (Int,Int) -> a -> a
checkMatMatSolv mn mk =
    if fst mn == fst mk
        then id
        else error $ printf
            ("Tried to solve a matrix with shape `%s' for a matrix with shape `%s'")
            (show mn) (show mk)
-- | Guard for solving a matrix system against a matrix of right-hand sides
-- and storing the result in a destination matrix.
--
-- Arguments: the system shape @mk@, the right-hand-side shape @mn@ and the
-- destination shape @kn@.  Yields 'id' when @fst mn == fst mk@ and
-- @kn == (snd mk, snd mn)@.  Otherwise 'error' is called; the
-- right-hand-side mismatch is reported in preference to the destination
-- mismatch.
checkMatMatSolvTo :: (Int,Int) -> (Int,Int) -> (Int,Int) -> a -> a
checkMatMatSolvTo mk mn kn = maybe id error mismatch
  where
    -- Shape the solution is expected to have.
    solutionShape = (snd mk, snd mn)

    -- First violated precondition, if any, rendered as an error message.
    mismatch :: Maybe String
    mismatch
        | fst mn /= fst mk = Just $ printf
            ("Tried to solve a matrix with shape `%s' for a matrix with shape `%s'")
            (show mk) (show mn)
        | kn /= solutionShape = Just $ printf
            ("Tried to store a matrix with shape `%s' in a matrix with shape `%s'")
            (show solutionShape) (show kn)
        | otherwise = Nothing
-- | Guard requiring a square matrix.
--
-- Yields 'id' when the two components of the shape are equal; otherwise
-- calls 'error' with a message prefixed by @str@ and showing the shape.
checkSquare :: String -> (Int,Int) -> a -> a
checkSquare str (m,n)
    | m == n    = id
    | otherwise = error complaint
  where
    complaint :: String
    complaint = printf
        "%s <matrix of shape (%d,%d)>: matrix shape must be square."
        str m n
-- | Guard requiring a matrix with at least as many columns as rows.
--
-- Yields 'id' when @m <= n@ for shape @(m,n)@; otherwise calls 'error' with
-- a message prefixed by @str@ and showing the shape.
checkFat :: String -> (Int,Int) -> a -> a
checkFat str (m,n)
    | m <= n    = id
    | otherwise = error complaint
  where
    complaint :: String
    complaint = printf
        "%s <matrix of shape (%d,%d)>: matrix must have at least as many columns as rows."
        str m n
-- | Guard requiring a matrix with at least as many rows as columns.
--
-- Yields 'id' when @m >= n@ for shape @(m,n)@; otherwise calls 'error' with
-- a message prefixed by @str@ and showing the shape.
checkTall :: String -> (Int,Int) -> a -> a
checkTall str (m,n)
    | m >= n    = id
    | otherwise = error complaint
  where
    complaint :: String
    complaint = printf
        "%s <matrix of shape (%d,%d)>: matrix must have at least as many rows as columns."
        str m n
-- | Guard for a binary operation on two shaped operands.
--
-- Yields 'id' when the two shapes are equal; otherwise calls 'error' with a
-- message that shows both shapes.
checkBinaryOp :: (Eq i, Show i) => i -> i -> a -> a
checkBinaryOp m n
    | m /= n =
        error $ printf
            ("Shapes in binary operation do not match. "
             ++ " First operand has shape `%s' and second has shape `%s'.")
            (show m)
            (show n)
    | otherwise = id
-- | Guard for a ternary operation on three shaped operands.
--
-- Yields 'id' when the first shape equals both the second and the third;
-- otherwise calls 'error' with a message that shows all three shapes.
checkTernaryOp :: (Eq i, Show i) => i -> i -> i -> a -> a
checkTernaryOp l m n
    | l == m && l == n = id
    | otherwise =
        error $ printf
            ("Shapes in ternary operation do not match. "
             ++ " First operand has shape `%s', second has shape `%s',"
             ++ " and third has shape `%s'.")
            (show l)
            (show m)
            (show n)