{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
-----------------------------------------------------------------------------
--
-- Module      :  Numeric.Jalla.Types
-- Copyright   :  2011 by Christian Gosch
-- License     :  BSD3
--
-- Maintainer  :  Christian Gosch
-- Stability   :  Experimental
-- Portability :  GHC only
--
-- | Contains some types used by Jalla, including some BLAS/LAPACK related ones.
-----------------------------------------------------------------------------

module Numeric.Jalla.Types (
    -- * Classes
    Field1(..),
    -- ** BLAS And LAPACK
    BLASEnum(..),
    LAPACKEEnum(..),
    -- * Indexing
    Index,
    Shape,
    IndexPair,
    rowCountTrans,
    colCountTrans,
    shapeTrans,
    diagIndices,
    -- * Information About Matrices And Storage
    Order(..),
    Transpose(..),
    UpLo(..),

    module Data.Complex,
) where

import Data.Complex
import Data.Orphans ()
import Foreign.C.Types
import Foreign.Marshal.Array
import Foreign
import qualified Data.Tuple as T (swap)

{-| A single (row or column) index. -}
type Index = Int

{-| Shape of a matrix as @(rows, columns)@. -}
type Shape = (Index,Index)

{-| A pair of indices addressing one matrix element. -}
type IndexPair = (Index,Index)


{-| Row count of a matrix with given transposedness and shape.
For 'Trans' this is the second component of the shape, otherwise the first. -}
rowCountTrans :: Transpose -> Shape -> Index
rowCountTrans Trans   (_, c) = c
rowCountTrans NoTrans (r, _) = r


{-| Column count of a matrix with given transposedness and shape.
For 'Trans' this is the first component of the shape, otherwise the second. -}
colCountTrans :: Transpose -> Shape -> Index
colCountTrans Trans   (r, _) = r
colCountTrans NoTrans (_, c) = c


{-| Shape of a matrix with given transposedness and shape.
For 'Trans' the two components are swapped, otherwise the shape is returned
unchanged. -}
shapeTrans :: Transpose -> Shape -> Shape
shapeTrans Trans   s = T.swap s
shapeTrans NoTrans s = s


{-| Generate indices of a diagonal in a matrix of given shape.

The result is empty if the requested diagonal does not exist, which includes
the case where the matrix has zero rows or zero columns. -}
diagIndices :: Shape       -- ^ The shape of the matrix (rows,columns)
            -> Index       -- ^ The index of the diagonal.
                           --   0: main diagonal; < 0: lower diagonals; > 0: upper diagonals
            -> [IndexPair] -- ^ Index list. Empty if there is no such diagonal.
diagIndices (r, c) d
  | d >= 0 && d < c    = walk 0 d (min (c - d) r)
  | d < 0  && d > (-r) = walk (-d) 0 (min (r + d) c)
  | otherwise          = []
  where
    -- Enumerate @n@ consecutive diagonal positions starting at the given
    -- row/column.  The range @[0 .. n - 1]@ is empty for @n <= 0@, so a
    -- zero-length diagonal (e.g. a matrix without rows) yields no indices.
    walk :: Index -> Index -> Index -> [IndexPair]
    walk rowStart colStart n =
      [ (rowStart + i, colStart + i) | i <- [0 .. n - 1] ]


{-| Storage order of a matrix. -}
data Order = RowMajor | ColumnMajor deriving (Eq, Show)
-- type Order = CblasOrder

{-| Transposedness of a matrix, as consumed by 'rowCountTrans',
'colCountTrans' and 'shapeTrans'. -}
data Transpose = Trans | NoTrans deriving (Eq, Show)

{-| Two-valued selector 'Up' / 'Lo'.
NOTE(review): presumably selects the upper or lower triangle of a matrix
for BLAS/LAPACK routines -- confirm against the callers. -}
data UpLo = Up | Lo deriving (Eq, Show)


{-| Conversion between a type @e@ and a second type @be@.
NOTE(review): presumably @be@ is the corresponding BLAS enumeration type;
no instances are defined in this module. -}
class BLASEnum e be where
    toBlas :: e -> be
    fromBlas :: be -> e


{-| Conversion between a type @e@ and a second type @le@.
NOTE(review): presumably @le@ is the corresponding LAPACKE enumeration type;
no instances are defined in this module. -}
class LAPACKEEnum e le where
    toLapacke :: e -> le
    fromLapacke :: le -> e


-- Unexported stub: evaluating the result is always bottom.
-- NOTE(review): nothing in this module uses it -- confirm before removing.
f :: Complex a -> a
f _ = undefined


{-| Defines a scalar type for each field type.
The field types are 'Complex' 'CFloat' and 'CFloat' (both with scalar type
'CFloat'), as well as 'Complex' 'CDouble' and 'CDouble' (both with scalar
type 'CDouble'). -}
class (Num e, Floating e, Show e) => Field1 e where
    type FieldScalar e :: *

instance Field1 CFloat where
    type FieldScalar CFloat = CFloat

instance Field1 CDouble where
    type FieldScalar CDouble = CDouble

instance Field1 (Complex CFloat) where
    type FieldScalar (Complex CFloat) = CFloat

instance Field1 (Complex CDouble) where
    type FieldScalar (Complex CDouble) = CDouble