{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}

-- | Shared plumbing for the Level 3 BLAS wrappers: the user-facing function
-- shapes (the @*Fun@ synonyms) and the @*Abstraction@ builders that turn a
-- pair of FFI entry points into one of those functions.
module Numerical.HBLAS.BLAS.Internal.Level3
    ( GemmFun
    , HemmFun
    , HerkFun
    , Her2kFun
    , SymmFun
    , SyrkFun
    , Syr2kFun
    , TrmmFun
    , TrsmFun
    , gemmAbstraction
    , hemmAbstraction
    , herkAbstraction
    , her2kAbstraction
    , symmAbstraction
    , syrkAbstraction
    , syr2kAbstraction
    , trmmAbstraction
    , trsmAbstraction
    ) where

import Numerical.HBLAS.Constants
import Numerical.HBLAS.UtilsFFI
import Numerical.HBLAS.BLAS.FFI.Level3
import Numerical.HBLAS.BLAS.Internal.Utility
import Numerical.HBLAS.MatrixTypes
import Control.Monad.Primitive
import qualified Data.Vector.Storable.Mutable as SM
import Data.Int
import Foreign.Ptr

-- | Shape of a GEMM wrapper: transpose flag for the first input, transpose
-- flag for the second input, two scalars, the two input matrices, and
-- finally the matrix that is written to.
type GemmFun el orient s m =
       Transpose
    -> Transpose
    -> el
    -> el
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> m ()

-- | Shape of a SYMM wrapper: equation side, which triangle is referenced,
-- two scalars, the two input matrices, and the matrix that is written to.
type SymmFun el orient s m =
       EquationSide
    -> MatUpLo
    -> el
    -> el
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> m ()

-- | Shape of a HEMM wrapper; argument for argument the same as 'SymmFun'.
type HemmFun el orient s m =
       EquationSide
    -> MatUpLo
    -> el
    -> el
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> m ()

-- | Shape of a HERK wrapper: triangle selector, transpose flag, two scalars
-- of the (separate) @scale@ type, the input matrix, and the matrix that is
-- written to.
type HerkFun scale el orient s m =
       MatUpLo
    -> Transpose
    -> scale
    -> scale
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> m ()

-- | Shape of a HER2K wrapper: triangle selector, transpose flag, a scalar of
-- the element type, a scalar of the @scale@ type, the two input matrices,
-- and the matrix that is written to.
type Her2kFun scale el orient s m =
       MatUpLo
    -> Transpose
    -> el
    -> scale
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> m ()

-- | Shape of a SYRK wrapper: triangle selector, transpose flag, two scalars,
-- the input matrix, and the matrix that is written to.
type SyrkFun el orient s m =
       MatUpLo
    -> Transpose
    -> el
    -> el
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> m ()

-- | Shape of a SYR2K wrapper: triangle selector, transpose flag, two
-- scalars, the two input matrices, and the matrix that is written to.
type Syr2kFun el orient s m =
       MatUpLo
    -> Transpose
    -> el
    -> el
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> m ()

-- | Shape of a TRMM wrapper: equation side, triangle selector, transpose
-- flag, diagonal kind, one scalar, the triangular input matrix, and the
-- matrix that is updated in place.
type TrmmFun el orient s m =
       EquationSide
    -> MatUpLo
    -> Transpose
    -> MatDiag
    -> el
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> m ()

-- | Shape of a TRSM wrapper; argument for argument the same as 'TrmmFun'.
type TrsmFun el orient s m =
       EquationSide
    -> MatUpLo
    -> Transpose
    -> MatDiag
    -> el
    -> MDenseMatrix s orient el
    -> MDenseMatrix s orient el
    -> m ()
-- | Rough cost estimate for a GEMM call: the product of the three extents,
-- widened to 'Int64' before multiplying.
--
-- This will be wrong by some constant factor, albeit a small one.
gemmComplexity :: Integral a => a -> a -> a -> Int64
gemmComplexity a b c = product [fromIntegral a, fromIntegral b, fromIntegral c]

-- | 'True' when the dimensions handed to GEMM are unusable.
--
-- The dimension pairs of the first two matrices are passed through
-- 'coordSwapper' with their respective transpose flags, after which one
-- check covers all of the transpose combinations: every extent must be
-- positive and the three matrices must agree with each other.
isBadGemm :: (Ord a, Num a) => Transpose -> Transpose -> a -> a -> a -> a -> a -> a -> Bool
isBadGemm tra trb ax ay bx by cx cy =
    case (coordSwapper tra (ax, ay), coordSwapper trb (bx, by)) of
        ((ax', ay'), (bx', by')) ->
            (minimum [ax', ay', bx', by', cx, cy] <= 0)
                || cy /= ay'
                || cx /= bx'
                || ax' /= by'

-- | 'True' when the dimensions handed to SYMM are unusable for the given
-- equation side: either the side-independent check fails, or the first
-- matrix does not match the second along the dimension that side uses.
isBadSymm :: (Ord a, Num a) => EquationSide -> a -> a -> a -> a -> a -> a -> Bool
isBadSymm side ax ay bx by cx cy =
    case side of
        LeftSide  -> badEitherSide || ax /= by
        RightSide -> badEitherSide || bx /= ay
  where
    badEitherSide = isBadSymmBothSide ax ay bx by cx cy

-- | Side-independent part of the SYMM dimension check: all extents must be
-- positive, the first matrix must be square, and the second and third
-- matrices must have identical dimensions.
isBadSymmBothSide :: (Ord a, Num a) => a -> a -> a -> a -> a -> a -> Bool
isBadSymmBothSide ax ay bx by cx cy =
    (minimum [ax, ay, bx, by, cx, cy] <= 0)
        || ax /= ay
        || bx /= cx
        || by /= cy

{-
A key design goal of this ffi is to provide *safe* throughput guarantees
for a concurrent application built on top of these apis, while evading
any overheads for providing such safety. Accordingly, on input sizes
where the estimated flops count will take more than 1-10 microseconds,
safe ffi calls are used. For inputs whose runtime is under that,
unsafe ffi calls are used.
-}

---- | Matrix mult for general dense matrices
--type GemmFunFFI scale el = CBLAS_ORDERT ->   CBLAS_TRANSPOSET -> CBLAS_TRANSPOSET->
        --CInt -> CInt -> CInt -> {- scal A * B -} scale  -> {- Matrix A-} Ptr el  -> CInt -> {- B -}  Ptr el -> CInt->
            --scale -> {- C -}  Ptr el -> CInt -> IO ()

--type GemmFun = MutDenseMatrix or el ->  MutDenseMatrix or el ->   MutDenseMatrix or el -> m ()

-- | Build a GEMM wrapper from a safe and an unsafe FFI binding.
--
-- Arguments: the routine name (used in the overlap error message), the safe
-- FFI entry point, the unsafe FFI entry point, and a continuation-style
-- handler that converts a scalar into the form the FFI call expects.
--
-- The resulting function calls 'error' when 'isBadGemm' rejects the
-- dimensions, or when either input buffer overlaps the output buffer.
-- Otherwise it uses the unsafe FFI entry point when the estimated cost is at
-- most 'flopsThreshold', and the safe one when it is larger.
{-# NOINLINE gemmAbstraction #-}
gemmAbstraction
    :: (SM.Storable el, PrimMonad m)
    => String
    -> GemmFunFFI scale el
    -> GemmFunFFI scale el
    -> (el -> (scale -> m ()) -> m ())
    -> forall orient . GemmFun el orient (PrimState m) m
gemmAbstraction gemmName gemmSafeFFI gemmUnsafeFFI constHandler = go
  where
    shouldCallFast :: Int -> Int -> Int -> Bool
    shouldCallFast cy cx ax = flopsThreshold >= gemmComplexity cy cx ax

    go tra trb alpha beta
        (MutableDenseMatrix ornta ax ay astride abuff)
        (MutableDenseMatrix _ bx by bstride bbuff)
        (MutableDenseMatrix _ cx cy cstride cbuff)
        | isBadGemm tra trb ax ay bx by cx cy =
            error $! "bad dimension args to GEMM: ax ay bx by cx cy: " ++ show [ax, ay, bx, by, cx, cy]
        | SM.overlaps abuff cbuff || SM.overlaps bbuff cbuff =
            error $ "the read and write inputs for: " ++ gemmName ++ " overlap. This is a programmer error. Please fix."
        | otherwise =
            unsafeWithPrim abuff $ \ap ->
            unsafeWithPrim bbuff $ \bp ->
            unsafeWithPrim cbuff $ \cp ->
            constHandler alpha $ \alphaPtr ->
            constHandler beta $ \betaPtr -> do
                -- Only the first matrix needs its dimensions swapped here:
                -- the second matrix's information is already in a and c, and
                -- c doesn't get implicitly transposed.
                let (axNew, _) = coordSwapper tra (ax, ay)
                    -- all three matrices share the same orientation
                    blasOrder = encodeNiceOrder ornta
                    rawTra = encodeFFITranspose tra
                    rawTrb = encodeFFITranspose trb
                -- example of why i want to switch to singletons
                unsafePrimToPrim $!
                    (if shouldCallFast cy cx axNew then gemmUnsafeFFI else gemmSafeFFI)
                        blasOrder rawTra rawTrb
                        (fromIntegral cy) (fromIntegral cx) (fromIntegral axNew)
                        alphaPtr ap (fromIntegral astride)
                        bp (fromIntegral bstride)
                        betaPtr cp (fromIntegral cstride)

-- | Build a SYMM wrapper from a safe and an unsafe FFI binding.
--
-- Arguments are as for 'gemmAbstraction'. The resulting function calls
-- 'error' when 'isBadSymm' rejects the dimensions for the requested side, or
-- when either input buffer overlaps the output buffer. The safe/unsafe choice
-- compares @cx * cy * ax@ against 'flopsThreshold'.
{-# NOINLINE symmAbstraction #-}
symmAbstraction
    :: (SM.Storable el, PrimMonad m)
    => String
    -> SymmFunFFI scale el
    -> SymmFunFFI scale el
    -> (el -> (scale -> m ()) -> m ())
    -> forall orient . SymmFun el orient (PrimState m) m
symmAbstraction symmName symmSafeFFI symmUnsafeFFI constHandler = symm
  where
    shouldCallFast :: Int -> Int -> Int -> Bool
    shouldCallFast cy cx ax =
        flopsThreshold >= (fromIntegral cx :: Int64) * (fromIntegral cy :: Int64) * (fromIntegral ax :: Int64)

    symm side uplo alpha beta
        (MutableDenseMatrix ornta ax ay astride abuff)
        (MutableDenseMatrix _ bx by bstride bbuff)
        (MutableDenseMatrix _ cx cy cstride cbuff)
        | isBadSymm side ax ay bx by cx cy =
            error $! "bad dimension args to SYMM: ax ay bx by cx cy side: " ++ show [ax, ay, bx, by, cx, cy] ++ " " ++ show side
        | SM.overlaps abuff cbuff || SM.overlaps bbuff cbuff =
            error $ "the read and write inputs for: " ++ symmName ++ " overlap. This is a programmer error. Please fix."
        | otherwise =
            unsafeWithPrim abuff $ \ap ->
            unsafeWithPrim bbuff $ \bp ->
            unsafeWithPrim cbuff $ \cp ->
            constHandler alpha $ \alphaPtr ->
            constHandler beta $ \betaPtr -> do
                let rawOrder = encodeNiceOrder ornta
                    rawUplo = encodeFFIMatrixHalf uplo
                    rawSide = encodeFFISide side
                unsafePrimToPrim $!
                    (if shouldCallFast cy cx ax then symmUnsafeFFI else symmSafeFFI)
                        rawOrder rawSide rawUplo
                        (fromIntegral cy) (fromIntegral cx)
                        alphaPtr ap (fromIntegral astride)
                        bp (fromIntegral bstride)
                        betaPtr cp (fromIntegral cstride)

-- | Build a HEMM wrapper from a safe and an unsafe FFI binding.
--
-- Same structure as 'symmAbstraction', except that the scalar handler yields
-- a @Ptr el@. The dimension requirements are identical to SYMM's, so the
-- shared 'isBadSymm' check is used.
{-# NOINLINE hemmAbstraction #-}
hemmAbstraction
    :: (SM.Storable el, PrimMonad m)
    => String
    -> HemmFunFFI el
    -> HemmFunFFI el
    -> (el -> (Ptr el -> m ()) -> m ())
    -> forall orient . HemmFun el orient (PrimState m) m
hemmAbstraction hemmName hemmSafeFFI hemmUnsafeFFI constHandler = hemm
  where
    shouldCallFast :: Int -> Int -> Int -> Bool
    shouldCallFast cy cx ax =
        flopsThreshold >= (fromIntegral cx :: Int64) * (fromIntegral cy :: Int64) * (fromIntegral ax :: Int64)

    hemm side uplo alpha beta
        (MutableDenseMatrix ornta ax ay astride abuff)
        (MutableDenseMatrix _ bx by bstride bbuff)
        (MutableDenseMatrix _ cx cy cstride cbuff)
        | isBadSymm side ax ay bx by cx cy =
            -- the value printed after the dimensions is the side, so label it as such
            error $! "bad dimension args to hemm: ax ay bx by cx cy side: " ++ show [ax, ay, bx, by, cx, cy] ++ " " ++ show side
        | SM.overlaps abuff cbuff || SM.overlaps bbuff cbuff =
            error $ "the read and write inputs for: " ++ hemmName ++ " overlap. This is a programmer error. Please fix."
        | otherwise =
            unsafeWithPrim abuff $ \ap ->
            unsafeWithPrim bbuff $ \bp ->
            unsafeWithPrim cbuff $ \cp ->
            constHandler alpha $ \alphaPtr ->
            constHandler beta $ \betaPtr -> do
                let rawOrder = encodeNiceOrder ornta
                    rawUplo = encodeFFIMatrixHalf uplo
                    rawSide = encodeFFISide side
                unsafePrimToPrim $!
                    (if shouldCallFast cy cx ax then hemmUnsafeFFI else hemmSafeFFI)
                        rawOrder rawSide rawUplo
                        (fromIntegral cy) (fromIntegral cx)
                        alphaPtr ap (fromIntegral astride)
                        bp (fromIntegral bstride)
                        betaPtr cp (fromIntegral cstride)

-- | Build a HERK wrapper from a safe and an unsafe FFI binding.
--
-- Only 'NoTranspose' and 'ConjTranspose' are accepted; any other transpose
-- flag makes the dimension check call 'error'. The resulting function also
-- calls 'error' on bad dimensions or when the input buffer overlaps the
-- output buffer.
{-# NOINLINE herkAbstraction #-}
herkAbstraction
    :: (SM.Storable el, PrimMonad m)
    => String
    -> HerkFunFFI scalePtr el
    -> HerkFunFFI scalePtr el
    -> (scale -> (scalePtr -> m ()) -> m ())
    -> forall orient . HerkFun scale el orient (PrimState m) m
herkAbstraction herkName herkSafeFFI herkUnsafeFFI constHandler = herk
  where
    -- all extents positive and the output matrix square
    isBadHerkBothSide :: (Ord a, Num a) => a -> a -> a -> a -> Bool
    isBadHerkBothSide ax ay cx cy = (minimum [ax, ay, cx, cy] <= 0) || (cx /= cy)

    isBadHerk :: (Ord a, Num a) => Transpose -> a -> a -> a -> a -> Bool
    isBadHerk NoTranspose ax ay cx cy = isBadHerkBothSide ax ay cx cy || (ay /= cx)
    isBadHerk ConjTranspose ax ay cx cy = isBadHerkBothSide ax ay cx cy || (ax /= cx)
    isBadHerk trans _ _ _ _ = error $ herkName ++ ": trans " ++ show trans ++ " is invalid."

    -- n * k * n
    shouldCallFast :: Int -> Int -> Int -> Bool
    shouldCallFast ax ay cx =
        flopsThreshold >= (fromIntegral ax :: Int64) * (fromIntegral ay :: Int64) * (fromIntegral cx :: Int64)

    herk uplo trans alpha beta
        (MutableDenseMatrix ornta ax ay astride abuff)
        (MutableDenseMatrix _ cx cy cstride cbuff)
        | isBadHerk trans ax ay cx cy =
            error $! "bad dimension args to " ++ herkName ++ ": ax ay cx cy trans: " ++ show [ax, ay, cx, cy] ++ " " ++ show trans
        | SM.overlaps abuff cbuff =
            error $ "the read and write inputs for: " ++ herkName ++ " overlap. This is a programmer error. Please fix."
        | otherwise = call
      where
        -- the extent of the input that is not tied to the output's size
        k = if trans == NoTranspose then ax else ay
        call =
            unsafeWithPrim abuff $ \ap ->
            unsafeWithPrim cbuff $ \cp ->
            constHandler alpha $ \alphaPtr ->
            constHandler beta $ \betaPtr -> do
                let rawOrder = encodeNiceOrder ornta
                    rawUplo = encodeFFIMatrixHalf uplo
                    rawTrans = encodeFFITranspose trans
                unsafePrimToPrim $!
                    (if shouldCallFast ax ay cx then herkUnsafeFFI else herkSafeFFI)
                        rawOrder rawUplo rawTrans
                        (fromIntegral cy) (fromIntegral k)
                        alphaPtr ap (fromIntegral astride)
                        betaPtr cp (fromIntegral cstride)

-- | Build a HER2K wrapper from a safe and an unsafe FFI binding.
--
-- Only 'NoTranspose' and 'ConjTranspose' are accepted. The first scalar goes
-- through the handler; the second scalar is handed to the FFI call directly.
-- The resulting function calls 'error' on bad dimensions or when /either/
-- input buffer overlaps the output buffer.
{-# NOINLINE her2kAbstraction #-}
her2kAbstraction
    :: (SM.Storable el, PrimMonad m)
    => String
    -> Her2kFunFFI scale el
    -> Her2kFunFFI scale el
    -> (el -> (Ptr el -> m ()) -> m ())
    -> forall orient . Her2kFun scale el orient (PrimState m) m
her2kAbstraction her2kName her2kSafeFFI her2kUnsafeFFI constHandler = her2k
  where
    -- all extents positive, output square, both inputs the same shape
    isBadHer2kBothSide :: (Ord a, Num a) => a -> a -> a -> a -> a -> a -> Bool
    isBadHer2kBothSide ax ay bx by cx cy =
        (minimum [ax, ay, bx, by, cx, cy] <= 0) || not (cx == cy && ax == bx && ay == by)

    isBadHer2k :: (Ord a, Num a) => Transpose -> a -> a -> a -> a -> a -> a -> Bool
    isBadHer2k NoTranspose ax ay bx by cx cy = isBadHer2kBothSide ax ay bx by cx cy || (ay /= cx)
    isBadHer2k ConjTranspose ax ay bx by cx cy = isBadHer2kBothSide ax ay bx by cx cy || (ax /= cx)
    isBadHer2k trans _ _ _ _ _ _ = error $ her2kName ++ ": trans " ++ show trans ++ " is invalid."

    -- n * k * n, doubled
    shouldCallFast :: Int -> Int -> Int -> Bool
    shouldCallFast ax ay cx =
        flopsThreshold >= (fromIntegral ax :: Int64) * (fromIntegral ay :: Int64) * (fromIntegral cx :: Int64) * 2

    her2k uplo trans alpha beta
        (MutableDenseMatrix ornta ax ay astride abuff)
        (MutableDenseMatrix _ bx by bstride bbuff)
        (MutableDenseMatrix _ cx cy cstride cbuff)
        | isBadHer2k trans ax ay bx by cx cy =
            error $! "bad dimension args to " ++ her2kName ++ ": ax ay bx by cx cy trans: " ++ show [ax, ay, bx, by, cx, cy] ++ " " ++ show trans
        -- both inputs are read while the output is written, so neither may alias it
        | SM.overlaps abuff cbuff || SM.overlaps bbuff cbuff =
            error $ "the read and write inputs for: " ++ her2kName ++ " overlap. This is a programmer error. Please fix."
        | otherwise = call
      where
        k = if trans == NoTranspose then ax else ay
        call =
            unsafeWithPrim abuff $ \ap ->
            unsafeWithPrim bbuff $ \bp ->
            unsafeWithPrim cbuff $ \cp ->
            constHandler alpha $ \alphaPtr -> do
                let rawOrder = encodeNiceOrder ornta
                    rawUplo = encodeFFIMatrixHalf uplo
                    rawTrans = encodeFFITranspose trans
                unsafePrimToPrim $!
                    (if shouldCallFast ax ay cx then her2kUnsafeFFI else her2kSafeFFI)
                        rawOrder rawUplo rawTrans
                        (fromIntegral cy) (fromIntegral k)
                        alphaPtr ap (fromIntegral astride)
                        bp (fromIntegral bstride)
                        beta cp (fromIntegral cstride)

-- | Build a SYRK wrapper from a safe and an unsafe FFI binding.
--
-- 'NoTranspose', 'Transpose' and 'ConjTranspose' are accepted; any other
-- transpose flag makes the dimension check call 'error'. The resulting
-- function also calls 'error' on bad dimensions or when the input buffer
-- overlaps the output buffer.
{-# NOINLINE syrkAbstraction #-}
syrkAbstraction
    :: (SM.Storable el, PrimMonad m)
    => String
    -> SyrkFunFFI scale el
    -> SyrkFunFFI scale el
    -> (el -> (scale -> m ()) -> m ())
    -> forall orient . SyrkFun el orient (PrimState m) m
syrkAbstraction syrkName syrkSafeFFI syrkUnsafeFFI constHandler = syrk
  where
    -- all extents positive and the output matrix square
    isBadSyrkBothSide :: (Ord a, Num a) => a -> a -> a -> a -> Bool
    isBadSyrkBothSide ax ay cx cy = (minimum [ax, ay, cx, cy] <= 0) || (cx /= cy)

    isBadSyrk :: (Ord a, Num a) => Transpose -> a -> a -> a -> a -> Bool
    isBadSyrk NoTranspose ax ay cx cy = isBadSyrkBothSide ax ay cx cy || (ay /= cx)
    isBadSyrk Transpose ax ay cx cy = isBadSyrkBothSide ax ay cx cy || (ax /= cx)
    isBadSyrk ConjTranspose ax ay cx cy = isBadSyrkBothSide ax ay cx cy || (ax /= cx)
    isBadSyrk trans _ _ _ _ = error $ syrkName ++ ": trans " ++ show trans ++ " is invalid."

    -- n * k * n
    shouldCallFast :: Int -> Int -> Int -> Bool
    shouldCallFast ax ay cx =
        flopsThreshold >= (fromIntegral ax :: Int64) * (fromIntegral ay :: Int64) * (fromIntegral cx :: Int64)

    syrk uplo trans alpha beta
        (MutableDenseMatrix ornta ax ay astride abuff)
        (MutableDenseMatrix _ cx cy cstride cbuff)
        | isBadSyrk trans ax ay cx cy =
            error $! "bad dimension args to " ++ syrkName ++ ": ax ay cx cy trans: " ++ show [ax, ay, cx, cy] ++ " " ++ show trans
        | SM.overlaps abuff cbuff =
            error $ "the read and write inputs for: " ++ syrkName ++ " overlap. This is a programmer error. Please fix."
        | otherwise = call
      where
        k = if trans == NoTranspose then ax else ay
        call =
            unsafeWithPrim abuff $ \ap ->
            unsafeWithPrim cbuff $ \cp ->
            constHandler alpha $ \alphaPtr ->
            constHandler beta $ \betaPtr -> do
                let rawOrder = encodeNiceOrder ornta
                    rawUplo = encodeFFIMatrixHalf uplo
                    rawTrans = encodeFFITranspose trans
                unsafePrimToPrim $!
                    (if shouldCallFast ax ay cx then syrkUnsafeFFI else syrkSafeFFI)
                        rawOrder rawUplo rawTrans
                        (fromIntegral cy) (fromIntegral k)
                        alphaPtr ap (fromIntegral astride)
                        betaPtr cp (fromIntegral cstride)

-- | Build a SYR2K wrapper from a safe and an unsafe FFI binding.
--
-- Accepts the same transpose flags as 'syrkAbstraction'. The resulting
-- function calls 'error' on bad dimensions or when /either/ input buffer
-- overlaps the output buffer.
{-# NOINLINE syr2kAbstraction #-}
syr2kAbstraction
    :: (SM.Storable el, PrimMonad m)
    => String
    -> Syr2kFunFFI scale el
    -> Syr2kFunFFI scale el
    -> (el -> (scale -> m ()) -> m ())
    -> forall orient . Syr2kFun el orient (PrimState m) m
syr2kAbstraction syr2kName syr2kSafeFFI syr2kUnsafeFFI constHandler = syr2k
  where
    -- all extents positive, output square, both inputs the same shape
    isBadSyr2kBothSide :: (Ord a, Num a) => a -> a -> a -> a -> a -> a -> Bool
    isBadSyr2kBothSide ax ay bx by cx cy =
        (minimum [ax, ay, bx, by, cx, cy] <= 0) || not (cx == cy && ax == bx && ay == by)

    isBadSyr2k :: (Ord a, Num a) => Transpose -> a -> a -> a -> a -> a -> a -> Bool
    isBadSyr2k NoTranspose ax ay bx by cx cy = isBadSyr2kBothSide ax ay bx by cx cy || (ay /= cx)
    isBadSyr2k Transpose ax ay bx by cx cy = isBadSyr2kBothSide ax ay bx by cx cy || (ax /= cx)
    isBadSyr2k ConjTranspose ax ay bx by cx cy = isBadSyr2kBothSide ax ay bx by cx cy || (ax /= cx)
    isBadSyr2k trans _ _ _ _ _ _ = error $ syr2kName ++ ": trans " ++ show trans ++ " is invalid."

    -- n * k * n, doubled
    shouldCallFast :: Int -> Int -> Int -> Bool
    shouldCallFast ax ay cx =
        flopsThreshold >= (fromIntegral ax :: Int64) * (fromIntegral ay :: Int64) * (fromIntegral cx :: Int64) * 2

    syr2k uplo trans alpha beta
        (MutableDenseMatrix ornta ax ay astride abuff)
        (MutableDenseMatrix _ bx by bstride bbuff)
        (MutableDenseMatrix _ cx cy cstride cbuff)
        | isBadSyr2k trans ax ay bx by cx cy =
            error $! "bad dimension args to " ++ syr2kName ++ ": ax ay bx by cx cy trans: " ++ show [ax, ay, bx, by, cx, cy] ++ " " ++ show trans
        -- both inputs are read while the output is written, so neither may alias it
        | SM.overlaps abuff cbuff || SM.overlaps bbuff cbuff =
            error $ "the read and write inputs for: " ++ syr2kName ++ " overlap. This is a programmer error. Please fix."
        | otherwise = call
      where
        k = if trans == NoTranspose then ax else ay
        call =
            unsafeWithPrim abuff $ \ap ->
            unsafeWithPrim bbuff $ \bp ->
            unsafeWithPrim cbuff $ \cp ->
            constHandler alpha $ \alphaPtr ->
            constHandler beta $ \betaPtr -> do
                let rawOrder = encodeNiceOrder ornta
                    rawUplo = encodeFFIMatrixHalf uplo
                    rawTrans = encodeFFITranspose trans
                unsafePrimToPrim $!
                    (if shouldCallFast ax ay cx then syr2kUnsafeFFI else syr2kSafeFFI)
                        rawOrder rawUplo rawTrans
                        (fromIntegral cy) (fromIntegral k)
                        alphaPtr ap (fromIntegral astride)
                        bp (fromIntegral bstride)
                        betaPtr cp (fromIntegral cstride)

-- | Build a TRMM wrapper from a safe and an unsafe FFI binding.
--
-- The first matrix must be square and must match the second matrix along the
-- dimension selected by the equation side. The resulting function calls
-- 'error' on bad dimensions or when the two buffers overlap.
{-# NOINLINE trmmAbstraction #-}
trmmAbstraction
    :: (SM.Storable el, PrimMonad m)
    => String
    -> TrmmFunFFI scale el
    -> TrmmFunFFI scale el
    -> (el -> (scale -> m ()) -> m ())
    -> forall orient . TrmmFun el orient (PrimState m) m
trmmAbstraction trmmName trmmSafeFFI trmmUnsafeFFI constHandler = trmm
  where
    -- all extents positive and the first matrix square
    isBadTrmmBothSide :: (Ord a, Num a) => a -> a -> a -> a -> Bool
    isBadTrmmBothSide ax ay cx cy = (minimum [ax, ay, cx, cy] <= 0) || not (ax == ay)

    isBadTrmm :: (Ord a, Num a) => EquationSide -> a -> a -> a -> a -> Bool
    isBadTrmm LeftSide ax ay cx cy = isBadTrmmBothSide ax ay cx cy || (ax /= cy)
    isBadTrmm RightSide ax ay cx cy = isBadTrmmBothSide ax ay cx cy || (ax /= cx)

    -- cost estimate: order of the square matrix times both extents of the
    -- other matrix
    shouldCallFast :: Int -> Int -> Int -> Bool
    shouldCallFast ax cx cy =
        flopsThreshold >= (fromIntegral ax :: Int64) * (fromIntegral cx :: Int64) * (fromIntegral cy :: Int64)

    trmm side uplo trans diag alpha
        (MutableDenseMatrix ornta ax ay astride abuff)
        (MutableDenseMatrix _ cx cy cstride cbuff)
        | isBadTrmm side ax ay cx cy =
            error $! "bad dimension args to " ++ trmmName ++ ": ax ay cx cy side: " ++ show [ax, ay, cx, cy] ++ " " ++ show side
        | SM.overlaps abuff cbuff =
            error $ "the read and write inputs for: " ++ trmmName ++ " overlap. This is a programmer error. Please fix."
        | otherwise = call
      where
        call =
            unsafeWithPrim abuff $ \ap ->
            unsafeWithPrim cbuff $ \cp ->
            constHandler alpha $ \alphaPtr -> do
                let rawOrder = encodeNiceOrder ornta
                    rawSide = encodeFFISide side
                    rawUplo = encodeFFIMatrixHalf uplo
                    rawTrans = encodeFFITranspose trans
                    rawDiag = encodeFFITriangleSort diag
                unsafePrimToPrim $!
                    -- arguments follow shouldCallFast's own parameter order
                    -- (ax cx cy); passing (ax ay cx) under-counts the
                    -- right-side case whenever cy exceeds ax
                    (if shouldCallFast ax cx cy then trmmUnsafeFFI else trmmSafeFFI)
                        rawOrder rawSide rawUplo rawTrans rawDiag
                        (fromIntegral cy) (fromIntegral cx)
                        alphaPtr ap (fromIntegral astride)
                        cp (fromIntegral cstride)

-- | Build a TRSM wrapper from a safe and an unsafe FFI binding.
--
-- Same dimension requirements and error behaviour as 'trmmAbstraction'.
{-# NOINLINE trsmAbstraction #-}
trsmAbstraction
    :: (SM.Storable el, PrimMonad m)
    => String
    -> TrsmFunFFI scale el
    -> TrsmFunFFI scale el
    -> (el -> (scale -> m ()) -> m ())
    -> forall orient . TrsmFun el orient (PrimState m) m
trsmAbstraction trsmName trsmSafeFFI trsmUnsafeFFI constHandler = trsm
  where
    -- all extents positive and the first matrix square
    isBadTrsmBothSide :: (Ord a, Num a) => a -> a -> a -> a -> Bool
    isBadTrsmBothSide ax ay cx cy = (minimum [ax, ay, cx, cy] <= 0) || not (ax == ay)

    isBadTrsm :: (Ord a, Num a) => EquationSide -> a -> a -> a -> a -> Bool
    isBadTrsm LeftSide ax ay cx cy = isBadTrsmBothSide ax ay cx cy || (ax /= cy)
    isBadTrsm RightSide ax ay cx cy = isBadTrsmBothSide ax ay cx cy || (ax /= cx)

    -- cost estimate: order of the square matrix times both extents of the
    -- other matrix
    shouldCallFast :: Int -> Int -> Int -> Bool
    shouldCallFast ax cx cy =
        flopsThreshold >= (fromIntegral ax :: Int64) * (fromIntegral cx :: Int64) * (fromIntegral cy :: Int64)

    trsm side uplo trans diag alpha
        (MutableDenseMatrix ornta ax ay astride abuff)
        (MutableDenseMatrix _ cx cy cstride cbuff)
        | isBadTrsm side ax ay cx cy =
            error $! "bad dimension args to " ++ trsmName ++ ": ax ay cx cy side: " ++ show [ax, ay, cx, cy] ++ " " ++ show side
        | SM.overlaps abuff cbuff =
            error $ "the read and write inputs for: " ++ trsmName ++ " overlap. This is a programmer error. Please fix."
        | otherwise = call
      where
        call =
            unsafeWithPrim abuff $ \ap ->
            unsafeWithPrim cbuff $ \cp ->
            constHandler alpha $ \alphaPtr -> do
                let rawOrder = encodeNiceOrder ornta
                    rawSide = encodeFFISide side
                    rawUplo = encodeFFIMatrixHalf uplo
                    rawTrans = encodeFFITranspose trans
                    rawDiag = encodeFFITriangleSort diag
                unsafePrimToPrim $!
                    -- arguments follow shouldCallFast's own parameter order
                    -- (ax cx cy); passing (ax ay cx) under-counts the
                    -- right-side case whenever cy exceeds ax
                    (if shouldCallFast ax cx cy then trsmUnsafeFFI else trsmSafeFFI)
                        rawOrder rawSide rawUplo rawTrans rawDiag
                        (fromIntegral cy) (fromIntegral cx)
                        alphaPtr ap (fromIntegral astride)
                        cp (fromIntegral cstride)