{-# LANGUAGE CPP #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# OPTIONS_HADDOCK hide #-}
-----------------------------------------------------------------------------
-- |
-- Module     : BLAS.Internal
-- Copyright  : Copyright (c) 2008, Patrick Perry
-- License    : BSD3
-- Maintainer : Patrick Perry
-- Stability  : experimental
--
-- Internal helpers: zeroing raw memory, an inlinable variant of
-- 'unsafePerformIO', diagonal index arithmetic, and shape/dimension checks
-- that either pass their continuation through unchanged or call 'error'
-- with a descriptive message.
--

module BLAS.Internal (
    clearArray,
    bzero,
    inlinePerformIO,

    checkedSubvector,
    checkedSubvectorWithStride,
    checkVecVecOp,

    checkedRow,
    checkedCol,
    checkedDiag,
    checkedSubmatrix,
    checkMatMatOp,

    checkMatVecMult,
    checkMatMatMult,
    checkMatVecMultAdd,
    checkMatMatMultAdd,

    checkMatVecSolv,
    checkMatMatSolv,
    checkMatVecSolvTo,
    checkMatMatSolvTo,

    checkSquare,
    checkFat,
    checkTall,

    checkBinaryOp,
    checkTernaryOp,

    diagStart,
    diagLen
    ) where

import Data.Ix ( inRange )
import Foreign ( Ptr, Storable, castPtr, sizeOf )
import Foreign.C.Types ( CInt, CSize )
import Text.Printf ( printf )

#if defined(__GLASGOW_HASKELL__)
import GHC.Base ( realWorld# )
-- GHC.IOBase was folded into GHC.IO in base-4.2 (GHC 6.12) and later removed.
#if __GLASGOW_HASKELL__ >= 612
import GHC.IO ( IO(IO) )
#else
import GHC.IOBase ( IO(IO) )
#endif
#else
import System.IO.Unsafe ( unsafePerformIO )
#endif


-- | @clearArray ptr n@ sets the @n@ elements starting at @ptr@ to all-zero
-- bytes.  Does nothing when @n <= 0@.
clearArray :: Storable e => Ptr e -> Int -> IO ()
clearArray = clearArray' undefined
  where
    -- The first argument is never evaluated; it only fixes the element
    -- type so that 'sizeOf' can report the element size.
    clearArray' :: Storable e => e -> Ptr e -> Int -> IO ()
    clearArray' e ptr n = bzero ptr (n * sizeOf e)
{-# INLINE clearArray #-}

-- | @bzero ptr n@ sets @n@ bytes starting at @ptr@ to zero.
-- Does nothing when @n <= 0@; a negative count would otherwise wrap to a
-- huge unsigned 'CSize' and overwrite unrelated memory.
bzero :: Ptr a -> Int -> IO ()
bzero ptr n
    | n <= 0    = return ()
    | otherwise = do
        -- memset is standard C; the legacy BSD bzero was dropped from
        -- POSIX.1-2008 and is unavailable on some platforms.
        _ <- memset_ (castPtr ptr) 0 (fromIntegral n)
        return ()

foreign import ccall "string.h memset"
    memset_ :: Ptr () -> CInt -> CSize -> IO (Ptr ())

-- | Run an 'IO' action as a pure value.
--
-- On GHC the action's state function is applied directly to 'realWorld#'.
-- Unlike 'unsafePerformIO' there is no guard against the computation being
-- duplicated or floated by the optimiser, so the action must be
-- effectively pure (same result, no observable side effects on repeat).
-- On other compilers this falls back to 'unsafePerformIO'.
inlinePerformIO :: IO a -> a
#if defined(__GLASGOW_HASKELL__)
inlinePerformIO (IO m) = case m realWorld# of (# _, r #) -> r
#else
inlinePerformIO = unsafePerformIO
#endif
{-# INLINE inlinePerformIO #-}


-- | @checkedSubvector n sub o n'@ validates a request for the contiguous
-- subvector of length @n'@ at offset @o@ of a vector of length @n@ and,
-- if valid, returns @sub o n'@.  A negative offset is accepted only for an
-- empty subvector.
checkedSubvector :: Int -> (Int -> Int -> v) -> Int -> Int -> v
checkedSubvector n sub o n'
    | o < 0 && n' /= 0 =
        error $ printf
            "tried to create a subvector starting at a negative offset: `%d'" o
    | n' < 0 =
        error $ printf
            "tried to create a subvector with a negative length `%d'" n'
    | n' + o > n =
        error $ printf
            ("tried to create a subvector of length `%d' and offset `%d'"
             ++ " from a vector of length `%d'") n' o n
    | otherwise = sub o n'

-- | @checkedSubvectorWithStride s n sub o n'@ validates a request for the
-- subvector with stride @s@, offset @o@ and length @n'@ (the elements at
-- indices @o, o+s, ..., o+s*(n'-1)@) of a vector of length @n@ and, if
-- valid, returns @sub o n'@.
checkedSubvectorWithStride :: Int -> Int -> (Int -> Int -> v) -> Int -> Int -> v
checkedSubvectorWithStride s n sub o n'
    | o < 0 && n' /= 0 =
        error $ printf
            "Tried to create a subvector starting at a negative offset: `%d'" o
    | n' < 0 =
        error $ printf
            "Tried to create a subvector with a negative length `%d'" n'
    | s <= 0 =
        error $ printf
            "Tried to create a subvector with non-positive stride `%d'" s
    | not fits =
        error $ printf
            ("tried to create a subvector of length `%d', offset `%d',"
             ++ " and stride `%d' from a vector of length `%d'") n' o s n
    | otherwise = sub o n'
  where
    -- A non-empty view fits when its last element, at o + s*(n'-1), is a
    -- valid index; the guards above already ensure o >= 0 and s > 0 here.
    -- An empty view accepts any offset in [-1, n].
    fits
        | n' == 0   = inRange (-1, n) o
        | otherwise = o + s * (n' - 1) < n

-- | @checkVecVecOp name n1 n2@ is 'id' when the two vector dimensions agree;
-- otherwise it raises an error prefixed with @name@.
checkVecVecOp :: String -> Int -> Int -> a -> a
checkVecVecOp name n1 n2
    | n1 /= n2 =
        error $ printf
            ("%s: x and y have different dimensions. x has dimension `%d',"
             ++ " and y has dimension `%d'") name n1 n2
    | otherwise = id
{-# INLINE checkVecVecOp #-}


-- | @checkedRow (m,n) row i@ returns @row i@ when @0 <= i < m@.
checkedRow :: (Int,Int) -> (Int -> v) -> Int -> v
checkedRow (m,n) row i
    | i < 0 || i >= m =
        error $ printf
            "Error in row index. Tried to get row `%d' in a matrix with shape `(%d,%d)'"
            i m n
    | otherwise = row i

-- | @checkedCol (m,n) col j@ returns @col j@ when @0 <= j < n@.
checkedCol :: (Int,Int) -> (Int -> v) -> Int -> v
checkedCol (m,n) col j
    | j < 0 || j >= n =
        error $ printf
            "Error in column index. Tried to get column `%d' in a matrix with shape `(%d,%d)'"
            j m n
    | otherwise = col j

-- | @checkedDiag (m,n) diag i@ returns @diag i@ when diagonal @i@ lies
-- inside an @m@-by-@n@ matrix (negative @i@: sub-diagonal @-i@, positive
-- @i@: super-diagonal @i@, zero: main diagonal).
checkedDiag :: (Int,Int) -> (Int -> v) -> Int -> v
checkedDiag (m,n) diag i
    | i < 0 && negate i >= m =
        error $ printf
            "Tried to get sub-diagonal `%d' of a matrix with shape `(%d,%d)'"
            (negate i) m n
    | i > 0 && i >= n =
        error $ printf
            "Tried to get super-diagonal `%d' of a matrix with shape `(%d,%d)'"
            i m n
    | otherwise = diag i

-- | @diagStart i@ is the (row, column) index of the first element of
-- diagonal @i@.
diagStart :: Int -> (Int,Int)
diagStart i
    | i <= 0    = (negate i, 0)
    | otherwise = (0, i)

-- | @diagLen (m,n) i@ is the number of elements on diagonal @i@ of an
-- @m@-by-@n@ matrix.  Diagonals lying entirely outside the matrix have
-- length 0 (never negative).
diagLen :: (Int,Int) -> Int -> Int
diagLen (m,n) i
    | i >= 0    = max 0 (min (n - i) m)
    | otherwise = max 0 (min (m + i) n)

-- | @checkedSubmatrix (m,n) sub (i,j) (m',n')@ returns @sub (i,j) (m',n')@
-- when the @m'@-by-@n'@ block at offset @(i,j)@ fits inside an @m@-by-@n@
-- matrix.
checkedSubmatrix :: (Int,Int) -> ((Int,Int) -> (Int,Int) -> a) -> (Int,Int) -> (Int,Int) -> a
checkedSubmatrix (m,n) sub (i,j) (m',n')
    | or [ i < 0, m' < 0, i + m' > m
         , j < 0, n' < 0, j + n' > n ] =
        error $ printf
            ("tried to create submatrix of a `(%d,%d)' matrix"
             ++ " using offset `(%d,%d)' and shape `(%d,%d)'") m n i j m' n'
    | otherwise = sub (i,j) (m',n')

-- | @checkMatMatOp name s1 s2@ is 'id' when the two matrix shapes are equal.
checkMatMatOp :: String -> (Int,Int) -> (Int,Int) -> a -> a
checkMatMatOp name mn1 mn2
    | mn1 /= mn2 =
        error $ printf
            ("%s: x and y have different shapes. x has shape `%s',"
             ++ " and y has shape `%s'") name (show mn1) (show mn2)
    | otherwise = id


-- | 'id' when an @(m,n)@ matrix can multiply a vector of dimension @n@.
checkMatVecMult :: (Int,Int) -> Int -> a -> a
checkMatVecMult mn n
    | snd mn /= n =
        error $ printf
            "Tried to multiply a matrix with shape `%s' by a vector of dimension `%d'"
            (show mn) n
    | otherwise = id

-- | 'id' when an @(m,k)@ matrix can multiply a @(k,n)@ matrix.
checkMatMatMult :: (Int,Int) -> (Int,Int) -> a -> a
checkMatMatMult mk kn
    | snd mk /= fst kn =
        error $ printf
            "Tried to multiply a matrix with shape `%s' by a matrix with shape `%s'"
            (show mk) (show kn)
    | otherwise = id

-- | @checkMatVecMultAdd (m,n) n' m'@ is 'id' when the matrix can multiply a
-- vector of dimension @n'@ and the product can be added to a vector of
-- dimension @m'@.
checkMatVecMultAdd :: (Int,Int) -> Int -> Int -> a -> a
checkMatVecMultAdd mn n m
    | snd mn /= n =
        error $ printf
            "Tried to multiply a matrix with shape `%s' by a vector of dimension `%d'"
            (show mn) n
    | fst mn /= m =
        error $ printf
            "Tried to add a vector of dimension `%d' to a vector of dimension `%d'"
            (fst mn) m
    | otherwise = id

-- | @checkMatMatMultAdd mk kn mn@ is 'id' when @mk@ times @kn@ is defined and
-- its shape equals @mn@.
checkMatMatMultAdd :: (Int,Int) -> (Int,Int) -> (Int,Int) -> a -> a
checkMatMatMultAdd mk kn mn
    | snd mk /= fst kn =
        error $ printf
            "Tried to multiply a matrix with shape `%s' by a matrix with shape `%s'"
            (show mk) (show kn)
    | (fst mk, snd kn) /= mn =
        error $ printf
            "Tried to add a matrix with shape `%s' to a matrix with shape `%s'"
            (show (fst mk, snd kn)) (show mn)
    | otherwise = id

-- | 'id' when an @(m,n)@ matrix can be solved against a vector of
-- dimension @m@.
checkMatVecSolv :: (Int,Int) -> Int -> a -> a
checkMatVecSolv mn m
    | fst mn /= m =
        error $ printf
            "Tried to solve a matrix with shape `%s' for a vector of dimension `%d'"
            (show mn) m
    | otherwise = id

-- | @checkMatVecSolvTo (m,n) m' n'@ is 'id' when the matrix can be solved
-- against a vector of dimension @m'@ and the solution (dimension @n@) can be
-- stored in a vector of dimension @n'@.
checkMatVecSolvTo :: (Int,Int) -> Int -> Int -> a -> a
checkMatVecSolvTo mn m n
    | fst mn /= m =
        error $ printf
            "Tried to solve a matrix with shape `%s' for a vector of dimension `%d'"
            (show mn) m
    | snd mn /= n =
        -- Both directives are %d, so pass the Int itself (not its 'show').
        error $ printf
            "Tried to store a vector of dimension `%d' in a vector of dimension `%d'"
            (snd mn) n
    | otherwise = id

-- | 'id' when an @(m,n)@ matrix can be solved against an @(m,k)@ matrix.
checkMatMatSolv :: (Int,Int) -> (Int,Int) -> a -> a
checkMatMatSolv mn mk
    | fst mn /= fst mk =
        error $ printf
            "Tried to solve a matrix with shape `%s' for a matrix with shape `%s'"
            (show mn) (show mk)
    | otherwise = id

-- | @checkMatMatSolvTo mk mn kn@ is 'id' when @mk@ can be solved against
-- @mn@ (equal row counts) and the solution, of shape @(snd mk, snd mn)@,
-- matches @kn@.
checkMatMatSolvTo :: (Int,Int) -> (Int,Int) -> (Int,Int) -> a -> a
checkMatMatSolvTo mk mn kn
    | fst mn /= fst mk =
        error $ printf
            "Tried to solve a matrix with shape `%s' for a matrix with shape `%s'"
            (show mk) (show mn)
    | kn /= (snd mk, snd mn) =
        error $ printf
            "Tried to store a matrix with shape `%s' in a matrix with shape `%s'"
            (show (snd mk, snd mn)) (show kn)
    | otherwise = id


-- | 'id' when the matrix shape is square.  The format string must have a
-- directive for every argument, or 'printf' itself fails.
checkSquare :: String -> (Int,Int) -> a -> a
checkSquare str (m,n)
    | m /= n =
        error $ printf
            "%s : matrix shape must be square, but has shape `(%d,%d)'." str m n
    | otherwise = id

-- | 'id' when the matrix has at least as many columns as rows.
checkFat :: String -> (Int,Int) -> a -> a
checkFat str (m,n)
    | m > n =
        error $ printf
            "%s : matrix must have at least as many columns as rows, but has shape `(%d,%d)'."
            str m n
    | otherwise = id

-- | 'id' when the matrix has at least as many rows as columns.
checkTall :: String -> (Int,Int) -> a -> a
checkTall str (m,n)
    | m < n =
        error $ printf
            "%s : matrix must have at least as many rows as columns, but has shape `(%d,%d)'."
            str m n
    | otherwise = id


-- | 'id' when the two operand shapes are equal.
checkBinaryOp :: (Eq i, Show i) => i -> i -> a -> a
checkBinaryOp m n
    | m /= n =
        error $ printf
            ("Shapes in binary operation do not match."
             ++ " First operand has shape `%s' and second has shape `%s'.")
            (show m) (show n)
    | otherwise = id
{-# INLINE checkBinaryOp #-}

-- | 'id' when all three operand shapes are equal.
checkTernaryOp :: (Eq i, Show i) => i -> i -> i -> a -> a
checkTernaryOp l m n
    | l == m && l == n = id
    | otherwise =
        error $ printf
            ("Shapes in ternary operation do not match."
             ++ " First operand has shape `%s', second has shape `%s',"
             ++ " and third has shape `%s'.")
            (show l) (show m) (show n)
{-# INLINE checkTernaryOp #-}