{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- | Raw FFI bindings to the CBLAS level 3 (matrix-matrix) routines.
--
-- Every routine is imported twice under the same C symbol:
--
-- * an @_unsafe@ binding (@foreign import ccall unsafe@), and
--
-- * a @_safe@ binding (@foreign import ccall safe@), intended for large
--   inputs.
--
-- For each routine family a type synonym describes the shared argument
-- list. In those synonyms @scale@ is the type used to pass the scalars
-- (a plain 'Float' / 'Double' for the real routines, a pointer to a complex
-- number for the complex routines, mirroring the C prototypes quoted below)
-- and @el@ is the matrix element type.
module Numerical.HBLAS.BLAS.FFI.Level3 where

import Foreign.Ptr
import Foreign()
import Foreign.C.Types
import Data.Complex
import Numerical.HBLAS.BLAS.FFI

--------------------------------------------------------------------------------
-- BLAS LEVEL 3 ROUTINES
--------------------------------------------------------------------------------
-- Level 3 ops are faster than Levels 1 or 2
--------------------------------------------------------------------------------

--void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const blasint M, const blasint N, const blasint K,
--       const float alpha, const float *A, const blasint lda, const float *B, const blasint ldb, const float beta, float *C, const blasint ldc);
--void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const blasint M, const blasint N, const blasint K,
--       const double alpha, const double *A, const blasint lda, const double *B, const blasint ldb, const double beta, double *C, const blasint ldc);
--void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const blasint M, const blasint N, const blasint K,
--       const float *alpha, const float *A, const blasint lda, const float *B, const blasint ldb, const float *beta, float *C, const blasint ldc);
--void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const blasint M, const blasint N, const blasint K,
--       const double *alpha, const double *A, const blasint lda, const double *B, const blasint ldb, const double *beta, double *C, const blasint ldc);

-- | Matrix mult for general dense matrices:
-- @C := alpha*op( A )*op( B ) + beta*C@.
type GemmFunFFI scale el
    =  CBLAS_ORDERT      -- Order
    -> CBLAS_TRANSPOSET  -- TransA
    -> CBLAS_TRANSPOSET  -- TransB
    -> CInt              -- M
    -> CInt              -- N
    -> CInt              -- K
    -> scale             -- alpha: scales A * B
    -> Ptr el            -- matrix A
    -> CInt              -- lda
    -> Ptr el            -- matrix B
    -> CInt              -- ldb
    -> scale             -- beta: scales the incoming C
    -> Ptr el            -- matrix C (overwritten with the result)
    -> CInt              -- ldc
    -> IO ()

-- matrix mult!
foreign import ccall unsafe "cblas_sgemm"
    cblas_sgemm_unsafe :: GemmFunFFI Float Float
foreign import ccall unsafe "cblas_dgemm"
    cblas_dgemm_unsafe :: GemmFunFFI Double Double
foreign import ccall unsafe "cblas_cgemm"
    cblas_cgemm_unsafe :: GemmFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall unsafe "cblas_zgemm"
    cblas_zgemm_unsafe :: GemmFunFFI (Ptr (Complex Double)) (Complex Double)

-- safe ffi variant for large inputs
-- ("safe" is the default calling safety; it is spelled out here so these
-- imports read the same way as the other routine families in this module)
foreign import ccall safe "cblas_sgemm"
    cblas_sgemm_safe :: GemmFunFFI Float Float
foreign import ccall safe "cblas_dgemm"
    cblas_dgemm_safe :: GemmFunFFI Double Double
foreign import ccall safe "cblas_cgemm"
    cblas_cgemm_safe :: GemmFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall safe "cblas_zgemm"
    cblas_zgemm_safe :: GemmFunFFI (Ptr (Complex Double)) (Complex Double)

--------------------------------------------------------------------------------
-- Matrix mult for Symmetric Matrices
--------------------------------------------------------------------------------

--void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const blasint M, const blasint N,
--                 const float alpha, const float *A, const blasint lda, const float *B, const blasint ldb, const float beta, float *C, const blasint ldc);
--void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const blasint M, const blasint N,
--                 const double alpha, const double *A, const blasint lda, const double *B, const blasint ldb, const double beta, double *C, const blasint ldc);
--void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const blasint M, const blasint N,
--                 const float *alpha, const float *A, const blasint lda, const float *B, const blasint ldb, const float *beta, float *C, const blasint ldc);
--void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, const enum CBLAS_UPLO Uplo, const blasint M, const blasint N,
--                 const double *alpha, const double *A, const blasint lda, const double *B, const blasint ldb, const double *beta, double *C, const blasint ldc);

-- | Matrix mult where one operand is a symmetric matrix.
type SymmFunFFI scale el
    =  CBLAS_ORDERT  -- Order
    -> CBLAS_SIDET   -- Side
    -> CBLAS_UPLOT   -- Uplo
    -> CInt          -- M
    -> CInt          -- N
    -> scale         -- alpha
    -> Ptr el        -- matrix A
    -> CInt          -- lda
    -> Ptr el        -- matrix B
    -> CInt          -- ldb
    -> scale         -- beta
    -> Ptr el        -- matrix C
    -> CInt          -- ldc
    -> IO ()

foreign import ccall unsafe "cblas_ssymm"
    cblas_ssymm_unsafe :: SymmFunFFI Float Float
foreign import ccall unsafe "cblas_dsymm"
    cblas_dsymm_unsafe :: SymmFunFFI Double Double
foreign import ccall unsafe "cblas_csymm"
    cblas_csymm_unsafe :: SymmFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall unsafe "cblas_zsymm"
    cblas_zsymm_unsafe :: SymmFunFFI (Ptr (Complex Double)) (Complex Double)

-- safe ffi variant ("safe" written explicitly; it is also the default)
foreign import ccall safe "cblas_ssymm"
    cblas_ssymm_safe :: SymmFunFFI Float Float
foreign import ccall safe "cblas_dsymm"
    cblas_dsymm_safe :: SymmFunFFI Double Double
foreign import ccall safe "cblas_csymm"
    cblas_csymm_safe :: SymmFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall safe "cblas_zsymm"
    cblas_zsymm_safe :: SymmFunFFI (Ptr (Complex Double)) (Complex Double)

--------------------------------------------------------------------------------
-- symmetric rank k matrix update, C := alpha*A*A' + beta*C
-- or C = alpha*A'*A + beta*C
--------------------------------------------------------------------------------

--void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans,
--         const blasint N, const blasint K, const float alpha, const float *A, const blasint lda, const float beta, float *C, const blasint ldc);
--void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans,
--         const blasint N, const blasint K, const double alpha, const double *A, const blasint lda, const double beta, double *C, const blasint ldc);
--void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans,
--         const blasint N, const blasint K, const float *alpha, const float *A, const blasint lda, const float *beta, float *C, const blasint ldc);
--void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans,
--         const blasint N, const blasint K, const double *alpha, const double *A, const blasint lda, const double *beta, double *C, const blasint ldc);

-- | Symmetric rank k matrix update.
type SyrkFunFFI scale el
    =  CBLAS_ORDERT      -- Order
    -> CBLAS_UPLOT       -- Uplo
    -> CBLAS_TRANSPOSET  -- Trans
    -> CInt              -- N
    -> CInt              -- K
    -> scale             -- alpha
    -> Ptr el            -- matrix A
    -> CInt              -- lda
    -> scale             -- beta
    -> Ptr el            -- matrix C
    -> CInt              -- ldc
    -> IO ()

foreign import ccall unsafe "cblas_ssyrk"
    cblas_ssyrk_unsafe :: SyrkFunFFI Float Float
foreign import ccall unsafe "cblas_dsyrk"
    cblas_dsyrk_unsafe :: SyrkFunFFI Double Double
foreign import ccall unsafe "cblas_csyrk"
    cblas_csyrk_unsafe :: SyrkFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall unsafe "cblas_zsyrk"
    cblas_zsyrk_unsafe :: SyrkFunFFI (Ptr (Complex Double)) (Complex Double)

foreign import ccall safe "cblas_ssyrk"
    cblas_ssyrk_safe :: SyrkFunFFI Float Float
foreign import ccall safe "cblas_dsyrk"
    cblas_dsyrk_safe :: SyrkFunFFI Double Double
foreign import ccall safe "cblas_csyrk"
    cblas_csyrk_safe :: SyrkFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall safe "cblas_zsyrk"
    cblas_zsyrk_safe :: SyrkFunFFI (Ptr (Complex Double)) (Complex Double)

--------------------------------------------------------------------------------
-- Symmetric Rank 2k matrix update, C= alpha* A*B' + alpha* B*A' + beta * C
-- or C= alpha* A'*B + alpha* B'*A + beta * C
--------------------------------------------------------------------------------

--void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans,
--          const blasint N, const blasint K, const float alpha, const float *A, const blasint lda, const float *B, const blasint ldb, const float beta, float *C, const blasint ldc);
--void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans,
--          const blasint N, const blasint K, const double alpha, const double *A, const blasint lda, const double *B, const blasint ldb, const double beta, double *C, const blasint ldc);
--void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans,
--          const blasint N, const blasint K, const float *alpha, const float *A, const blasint lda, const float *B, const blasint ldb, const float *beta, float *C, const blasint ldc);
--void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans,
--          const blasint N, const blasint K, const double *alpha, const double *A, const blasint lda, const double *B, const blasint ldb, const double *beta, double *C, const blasint ldc);

-- | Symmetric rank 2k matrix update.
type Syr2kFunFFI scale el
    =  CBLAS_ORDERT      -- Order
    -> CBLAS_UPLOT       -- Uplo
    -> CBLAS_TRANSPOSET  -- Trans
    -> CInt              -- N
    -> CInt              -- K
    -> scale             -- alpha
    -> Ptr el            -- matrix A
    -> CInt              -- lda
    -> Ptr el            -- matrix B
    -> CInt              -- ldb
    -> scale             -- beta
    -> Ptr el            -- matrix C
    -> CInt              -- ldc
    -> IO ()

foreign import ccall unsafe "cblas_ssyr2k"
    cblas_ssyr2k_unsafe :: Syr2kFunFFI Float Float
foreign import ccall unsafe "cblas_dsyr2k"
    cblas_dsyr2k_unsafe :: Syr2kFunFFI Double Double
foreign import ccall unsafe "cblas_csyr2k"
    cblas_csyr2k_unsafe :: Syr2kFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall unsafe "cblas_zsyr2k"
    cblas_zsyr2k_unsafe :: Syr2kFunFFI (Ptr (Complex Double)) (Complex Double)

foreign import ccall safe "cblas_ssyr2k"
    cblas_ssyr2k_safe :: Syr2kFunFFI Float Float
foreign import ccall safe "cblas_dsyr2k"
    cblas_dsyr2k_safe :: Syr2kFunFFI Double Double
foreign import ccall safe "cblas_csyr2k"
    cblas_csyr2k_safe :: Syr2kFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall safe "cblas_zsyr2k"
    cblas_zsyr2k_safe :: Syr2kFunFFI (Ptr (Complex Double)) (Complex Double)

--------------------------------------------------------------------------------
-- matrix matrix product for triangular matrices
--------------------------------------------------------------------------------

-- | Matrix-matrix product where @A@ is a triangular matrix.
type TrmmFunFFI scale el
    =  CBLAS_ORDERT      -- Order
    -> CBLAS_SIDET       -- Side
    -> CBLAS_UPLOT       -- Uplo
    -> CBLAS_TRANSPOSET  -- TransA
    -> CBLAS_DIAGT       -- Diag
    -> CInt              -- M
    -> CInt              -- N
    -> scale             -- alpha
    -> Ptr el            -- matrix A
    -> CInt              -- lda
    -> Ptr el            -- matrix B
    -> CInt              -- ldb
    -> IO ()

foreign import ccall unsafe "cblas_strmm"
    cblas_strmm_unsafe :: TrmmFunFFI Float Float
foreign import ccall unsafe "cblas_dtrmm"
    cblas_dtrmm_unsafe :: TrmmFunFFI Double Double
foreign import ccall unsafe "cblas_ctrmm"
    cblas_ctrmm_unsafe :: TrmmFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall unsafe "cblas_ztrmm"
    cblas_ztrmm_unsafe :: TrmmFunFFI (Ptr (Complex Double)) (Complex Double)

foreign import ccall safe "cblas_strmm"
    cblas_strmm_safe :: TrmmFunFFI Float Float
foreign import ccall safe "cblas_dtrmm"
    cblas_dtrmm_safe :: TrmmFunFFI Double Double
foreign import ccall safe "cblas_ctrmm"
    cblas_ctrmm_safe :: TrmmFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall safe "cblas_ztrmm"
    cblas_ztrmm_safe :: TrmmFunFFI (Ptr (Complex Double)) (Complex Double)

--void cblas_strmm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA,
--                  enum CBLAS_DIAG Diag, CInt M, CInt N, Float alpha, Float *A, CInt lda, Float *B, CInt ldb);
--void cblas_dtrmm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA,
--                  enum CBLAS_DIAG Diag, CInt M, CInt N, Double alpha, Double *A, CInt lda, Double *B, CInt ldb);
--void cblas_ctrmm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA,
--                  enum CBLAS_DIAG Diag, CInt M, CInt N, Float *alpha, Float *A, CInt lda, Float *B, CInt ldb);
--void cblas_ztrmm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA,
--                  enum CBLAS_DIAG Diag, CInt M, CInt N, Double *alpha, Double *A, CInt lda, Double *B, CInt ldb);

--------------------------------------------------------------------------------
-- triangular solvers
--------------------------------------------------------------------------------
--
-- TRSM solves op(A)*X = alpha*B or X*op(A) = alpha*B
-- op(A) is one of op(A) = A, or op(A) = A', or op(A) = conjg(A').
-- A is a unit, or non-unit, upper or lower triangular matrix
--

-- | Triangular solve with a matrix right-hand side.
type TrsmFunFFI scale el
    =  CBLAS_ORDERT      -- Order
    -> CBLAS_SIDET       -- Side
    -> CBLAS_UPLOT       -- Uplo
    -> CBLAS_TRANSPOSET  -- TransA
    -> CBLAS_DIAGT       -- Diag
    -> CInt              -- M
    -> CInt              -- N
    -> scale             -- alpha
    -> Ptr el            -- matrix A
    -> CInt              -- lda
    -> Ptr el            -- matrix B
    -> CInt              -- ldb
    -> IO ()

foreign import ccall unsafe "cblas_strsm"
    cblas_strsm_unsafe :: TrsmFunFFI Float Float
foreign import ccall unsafe "cblas_dtrsm"
    cblas_dtrsm_unsafe :: TrsmFunFFI Double Double
foreign import ccall unsafe "cblas_ctrsm"
    cblas_ctrsm_unsafe :: TrsmFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall unsafe "cblas_ztrsm"
    cblas_ztrsm_unsafe :: TrsmFunFFI (Ptr (Complex Double)) (Complex Double)

foreign import ccall safe "cblas_strsm"
    cblas_strsm_safe :: TrsmFunFFI Float Float
foreign import ccall safe "cblas_dtrsm"
    cblas_dtrsm_safe :: TrsmFunFFI Double Double
foreign import ccall safe "cblas_ctrsm"
    cblas_ctrsm_safe :: TrsmFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall safe "cblas_ztrsm"
    cblas_ztrsm_safe :: TrsmFunFFI (Ptr (Complex Double)) (Complex Double)

--void cblas_strsm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA,
--                  enum CBLAS_DIAG Diag, CInt M, CInt N, Float alpha, Float *A, CInt lda, Float *B, CInt ldb);
--void cblas_dtrsm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA,
--                  enum CBLAS_DIAG Diag, CInt M, CInt N, Double alpha, Double *A, CInt lda, Double *B, CInt ldb);
--void cblas_ctrsm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA,
--                  enum CBLAS_DIAG Diag, CInt M, CInt N, Float *alpha, Float *A, CInt lda, Float *B, CInt ldb);
--void cblas_ztrsm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA,
--                  enum CBLAS_DIAG Diag, CInt M, CInt N, Double *alpha, Double *A, CInt lda, Double *B, CInt ldb);

--------------------------------------------------------------------------------
-- hermitian matrix mult
--------------------------------------------------------------------------------

-- | Hermitian matrix mult. Only complex element types are bound, so both
-- scalars are always passed by pointer and there is no separate @scale@
-- parameter.
type HemmFunFFI el
    =  CBLAS_ORDERT  -- Order
    -> CBLAS_SIDET   -- Side
    -> CBLAS_UPLOT   -- Uplo
    -> CInt          -- M
    -> CInt          -- N
    -> Ptr el        -- alpha (by pointer)
    -> Ptr el        -- matrix A
    -> CInt          -- lda
    -> Ptr el        -- matrix B
    -> CInt          -- ldb
    -> Ptr el        -- beta (by pointer)
    -> Ptr el        -- matrix C
    -> CInt          -- ldc
    -> IO ()

foreign import ccall unsafe "cblas_chemm"
    cblas_chemm_unsafe :: HemmFunFFI (Complex Float)
foreign import ccall unsafe "cblas_zhemm"
    cblas_zhemm_unsafe :: HemmFunFFI (Complex Double)

foreign import ccall safe "cblas_chemm"
    cblas_chemm_safe :: HemmFunFFI (Complex Float)
foreign import ccall safe "cblas_zhemm"
    cblas_zhemm_safe :: HemmFunFFI (Complex Double)

--void cblas_chemm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, CInt M, CInt N,
--                  Float *alpha, Float *A, CInt lda, Float *B, CInt ldb, Float *beta, Float *C, CInt ldc);
--void cblas_zhemm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, CInt M, CInt N,
--                  Double *alpha, Double *A, CInt lda, Double *B, CInt ldb, Double *beta, Double *C, CInt ldc);

-- | Argument list of the @herk@ routines (Hermitian rank k update in the
-- reference BLAS naming scheme). Unlike the other complex routines, both
-- scalars are real and passed by value, so @scale@ is 'Float' or 'Double'
-- while @el@ is complex.
type HerkFunFFI scale el
    =  CBLAS_ORDERT      -- Order
    -> CBLAS_UPLOT       -- Uplo
    -> CBLAS_TRANSPOSET  -- Trans
    -> CInt              -- N
    -> CInt              -- K
    -> scale             -- alpha (real, by value)
    -> Ptr el            -- matrix A
    -> CInt              -- lda
    -> scale             -- beta (real, by value)
    -> Ptr el            -- matrix C
    -> CInt              -- ldc
    -> IO ()

foreign import ccall unsafe "cblas_cherk"
    cblas_cherk_unsafe :: HerkFunFFI Float (Complex Float)
foreign import ccall unsafe "cblas_zherk"
    cblas_zherk_unsafe :: HerkFunFFI Double (Complex Double)

foreign import ccall safe "cblas_cherk"
    cblas_cherk_safe :: HerkFunFFI Float (Complex Float)
foreign import ccall safe "cblas_zherk"
    cblas_zherk_safe :: HerkFunFFI Double (Complex Double)

--void cblas_cherk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, CInt N, CInt K,
--                  Float alpha, Float *A, CInt lda, Float beta, Float *C, CInt ldc);
--void cblas_zherk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, CInt N, CInt K,
--                  Double alpha, Double *A, CInt lda, Double beta, Double *C, CInt ldc);

-- | Argument list of the @her2k@ routines (Hermitian rank 2k update in the
-- reference BLAS naming scheme). Note the asymmetry between the scalars:
-- @alpha@ is complex and passed by pointer, whereas @beta@ is real and
-- passed by value as @scale@.
type Her2kFunFFI scale el
    =  CBLAS_ORDERT      -- Order
    -> CBLAS_UPLOT       -- Uplo
    -> CBLAS_TRANSPOSET  -- Trans
    -> CInt              -- N
    -> CInt              -- K
    -> Ptr el            -- alpha (complex, by pointer)
    -> Ptr el            -- matrix A
    -> CInt              -- lda
    -> Ptr el            -- matrix B
    -> CInt              -- ldb
    -> scale             -- beta (real, by value)
    -> Ptr el            -- matrix C
    -> CInt              -- ldc
    -> IO ()

foreign import ccall unsafe "cblas_cher2k"
    cblas_cher2k_unsafe :: Her2kFunFFI Float (Complex Float)
foreign import ccall unsafe "cblas_zher2k"
    cblas_zher2k_unsafe :: Her2kFunFFI Double (Complex Double)

foreign import ccall safe "cblas_cher2k"
    cblas_cher2k_safe :: Her2kFunFFI Float (Complex Float)
foreign import ccall safe "cblas_zher2k"
    cblas_zher2k_safe :: Her2kFunFFI Double (Complex Double)

--void cblas_cher2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, CInt N, CInt K,
--                   Float *alpha, Float *A, CInt lda, Float *B, CInt ldb, Float beta, Float *C, CInt ldc);
--void cblas_zher2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, CInt N, CInt K,
--                   Double *alpha, Double *A, CInt lda, Double *B, CInt ldb, Double beta, Double *C, CInt ldc);

----void cblas_xerbla(CInt p, char *rout, char *form, ...);