{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.LinearAlgebra.Class
-- Copyright   :  (c) Marco Zocca 2017
-- License     :  GPL-style (see the file LICENSE)
--
-- Maintainer  :  zocca marco gmail
-- Stability   :  experimental
-- Portability :  portable
--
-- Typeclasses for linear algebra and related concepts
--
-----------------------------------------------------------------------------
module Numeric.LinearAlgebra.Class where

-- import Control.Applicative
import Data.Complex
-- import Data.Ratio
-- import Foreign.C.Types (CSChar, CInt, CShort, CLong, CLLong, CIntMax, CFloat, CDouble)

-- import Control.Exception
-- import Control.Exception.Common
import Control.Monad.Catch
import Control.Monad.IO.Class
-- import Data.Typeable (Typeable)

import qualified Data.Vector as V (Vector)

import Data.VectorSpace hiding (magnitude)

import Data.Sparse.Types
import Numeric.Eps


-- * Matrix and vector elements (optionally Complex)

-- | Scalar element types usable as matrix or vector entries.
--
-- The associated type 'EltMag' is the type in which the magnitude of an
-- element is expressed (the element type itself for the real instances below,
-- the component type for 'Complex').
class (Eq e, Fractional e, Floating e, Num (EltMag e), Ord (EltMag e)) => Elt e where
  type EltMag e :: *
  -- | Complex conjugate, or identity function if its input is real-valued
  -- (the default implementation is 'id').
  conj :: e -> e
  conj = id
  -- | Magnitude
  mag :: e -> EltMag e

-- NOTE(review): for the two real instances 'mag' is the identity, so a
-- negative input yields a negative "magnitude"; confirm whether 'abs' was
-- intended before relying on non-negativity.
instance Elt Double where
  type EltMag Double = Double
  mag x = x

instance Elt Float where
  type EltMag Float = Float
  mag x = x

-- | Complex elements: conjugation and magnitude come from "Data.Complex".
instance RealFloat e => Elt (Complex e) where
  type EltMag (Complex e) = e
  conj = conjugate
  mag = magnitude


-- * Vector space

-- | Scale a vector (synonym for '*^')
(.*) :: VectorSpace v => Scalar v -> v -> v
a .* v = a *^ v

-- | Scale a vector by the reciprocal of a number (e.g. for normalization)
(./) :: (VectorSpace v, Fractional (Scalar v)) => v -> Scalar v -> v
(./) v n = recip n *^ v

-- | Convex combination of two vectors (NB: 0 <= `a` <= 1).
cvx :: (VectorSpace e, Num (Scalar e)) => Scalar e -> e -> e -> e
cvx a u v = a .* u ^+^ ((1 - a) .* v)

-- linearCombination :: (VectorSpace v , Foldable t) => t (Scalar v, v) -> v
-- linearCombination = foldr (\(a, x) (b, y) -> (a .* x) ^+^ (b .* y))

-- linComb a v = a .* v


-- * Hilbert space (inner product)

-- | Inner product (synonym for '<.>')
dot :: InnerSpace v => v -> v -> Scalar v
dot = (<.>)

-- ** Hilbert-space distance function

-- | @hilbertDistSq x y = || x - y ||^2@ computes the squared L2 distance between two vectors
hilbertDistSq :: InnerSpace v => v -> v -> Scalar v
hilbertDistSq x y = t <.> t
  where
    t = x ^-^ y


-- * Normed vector spaces

-- | Inner-product spaces equipped with a family of norms.
--
-- 'Magnitude' is the type of norm values; 'RealScalar' is the type of the
-- exponent @p@ accepted by the Lp-norm functions.
class (InnerSpace v, Num (RealScalar v), Eq (RealScalar v), Epsilon (Magnitude v), Show (Magnitude v), Ord (Magnitude v)) => Normed v where
  type Magnitude v :: *
  type RealScalar v :: *
  -- | L1 norm
  norm1 :: v -> Magnitude v
  -- | Euclidean (L2) norm squared
  norm2Sq :: v -> Magnitude v
  -- | Lp norm (p > 0)
  normP :: RealScalar v -> v -> Magnitude v
  -- | Normalize w.r.t. Lp norm
  normalize :: RealScalar v -> v -> v
  -- | Normalize w.r.t. L2 norm
  normalize2 :: v -> v
  -- | Normalize w.r.t. norm2' instead of norm2
  normalize2' :: Floating (Scalar v) => v -> v
  normalize2' x = x ./ norm2' x
  -- | Euclidean (L2) norm; by default the square root of 'norm2Sq'
  norm2 :: Floating (Magnitude v) => v -> Magnitude v
  norm2 x = sqrt (norm2Sq x)
  -- | Euclidean (L2) norm; returns a Complex (norm :+ 0) for Complex-valued vectors
  norm2' :: Floating (Scalar v) => v -> Scalar v
  norm2' x = sqrt $ x <.> x
  -- | Lp norm (p > 0); dispatches to 'norm1' for @p == 1@, to 'norm2' for
  -- @p == 2@ and to 'normP' otherwise
  norm :: Floating (Magnitude v) => RealScalar v -> v -> Magnitude v
  norm p v
    | p == 1 = norm1 v
    | p == 2 = norm2 v
    | otherwise = normP p v


-- | Infinity-norm (Real).
--
-- NB: this returns the largest element as-is; no absolute value is taken, so
-- it only coincides with the infinity norm when the largest-magnitude entry is
-- non-negative. It is partial wherever 'maximum' is (e.g. on an empty list).
normInftyR :: (Foldable t, Ord a) => t a -> a
normInftyR x = maximum x

-- | Infinity-norm (Complex): the largest 'magnitude' among the elements.
-- Partial wherever 'maximum' is (e.g. on an empty list).
normInftyC :: (Foldable t, RealFloat a, Functor t) => t (Complex a) -> a
normInftyC x = maximum (magnitude <$> x)


-- | A 'Double' seen as a one-dimensional vector: every Lp norm is 'abs'.
instance Normed Double where
  type Magnitude Double = Double
  type RealScalar Double = Double
  norm1 = abs
  -- Squared norm is x^2 (not |x|), so that @sqrt . norm2Sq@ agrees with 'norm2'.
  norm2Sq x = x * x
  normP _ = abs
  -- Normalizing keeps the direction (sign) of the input. Zero (and NaN) map
  -- to 1, which is what this instance has always returned for them.
  normalize _ x
    | x < 0 = -1
    | otherwise = 1
  normalize2 x
    | x < 0 = -1
    | otherwise = 1
  norm2 = abs
  norm2' = abs

-- | A 'Complex' 'Double' seen as a vector with a real and an imaginary
-- component: 'norm1' and 'normP' act on the two components, while the L2
-- functions use 'magnitude'.
instance Normed (Complex Double) where
  type Magnitude (Complex Double) = Double
  type RealScalar (Complex Double) = Double
  norm1 (r :+ i) = abs r + abs i
  norm2Sq = (** 2) . magnitude
  -- Component magnitudes are used (as in 'norm1'): raising a negative
  -- component to the power p gives NaN for fractional p and lets components
  -- cancel for odd p.
  normP p (r :+ i) = (abs r ** p + abs i ** p) ** (1 / p)
  normalize p c = c ./ normP p c
  normalize2 c = c ./ magnitude c
  norm2 = magnitude
  norm2' = magnitude


-- | Lp inner product (p > 0): elements are combined pairwise with
-- @(a * b) ** p@ via 'liftI2', summed, and the sum is raised to @1 / p@.
dotLp :: (Set t, Foldable t, Floating a) => a -> t a -> t a -> a
dotLp p v1 v2 = sum u ** (1 / p)
  where
    f a b = (a * b) ** p
    u = liftI2 f v1 v2


-- | Reciprocal (elementwise 'recip')
reciprocal :: (Functor f, Fractional b) => f b -> f b
reciprocal = fmap recip


-- | Scale (multiply every element by the given number)
scale :: (Num b, Functor f) => b -> f b -> f b
scale n = fmap (* n)


-- * Matrix ring

-- | A matrix ring is any collection of matrices over some ring R that form a ring under matrix addition and matrix multiplication
class (AdditiveGroup m, Epsilon (MatrixNorm m)) => MatrixRing m where
  type MatrixNorm m :: *
  -- | Matrix-matrix product
  (##) :: m -> m -> m
  -- | Matrix times matrix transpose (A B^T); defaults to @a ## transpose b@,
  -- mirroring the default of '#^#'
  (##^) :: m -> m -> m
  a ##^ b = a ## transpose b
  -- | Matrix transpose times matrix (A^T B); defaults to @transpose a ## b@
  (#^#) :: m -> m -> m
  a #^# b = transpose a ## b
  -- | Matrix transpose (Hermitian conjugate in the Complex case)
  transpose :: m -> m
  -- | Frobenius norm
  normFrobenius :: m -> MatrixNorm m


-- a "sparse matrix ring" ?

-- class MatrixRing m a => SparseMatrixRing m a where
--   (#~#) :: Epsilon a => Matrix m a -> Matrix m a -> Matrix m a


-- * Linear vector space

-- | Vector spaces with an associated matrix type acting on them.
class (VectorSpace v, MatrixRing (MatrixType v)) => LinearVectorSpace v where
  type MatrixType v :: *
  -- | Matrix-vector action
  (#>) :: MatrixType v -> v -> v
  -- | Dual matrix-vector action
  (<#) :: v -> MatrixType v -> v


-- ** LinearVectorSpace + Normed
type V v = (LinearVectorSpace v, Normed v)


-- ** Linear systems
class LinearVectorSpace v => LinearSystem v where
  -- | Solve a linear system; uses GMRES internally as default method
  (<\>) :: (MonadIO m, MonadThrow m) =>
           MatrixType v   -- ^ System matrix
        -> v              -- ^ Right-hand side
        -> m v            -- ^ Result


-- * FiniteDim : finite-dimensional objects

class Functor f => FiniteDim f where
  type FDSize f :: *
  -- | Dimension (i.e. Int for SpVector, (Int, Int) for SpMatrix)
  dim :: f a -> FDSize f

-- | Variant of 'FiniteDim' for types that are not parametrized by their
-- element type.
class FiniteDim' f where
  type FDSize' f :: *
  dim' :: f -> FDSize' f


-- -- | unary dimension-checking bracket
-- withDim :: (FiniteDim f, Show s) =>
--      f e
--   -> (FDSize f -> f e -> Bool)
--   -> (f e -> c)
--   -> String
--   -> (f e -> s)
--   -> c
-- withDim x p f e ef | p (dim x) x = f x
--                    | otherwise = error e' where e' = e ++ show (ef x)

-- -- | binary dimension-checking bracket
-- withDim2 :: (FiniteDim f, FiniteDim g, Show s) =>
--      f e
--   -> g e
--   -> (FDSize f -> FDSize g -> f e -> g e -> Bool)
--   -> (f e -> g e -> c)
--   -> String
--   -> (f e -> g e -> s)
--   -> c
-- withDim2 x y p f e ef | p (dim x) (dim y) x y = f x y
--                       | otherwise = error e' where e' = e ++ show (ef x y)


-- * HasData : accessing inner data (do not export)

class HasData f a where
  type HDData f a :: *
  -- | Number of nonzeros
  nnz :: f a -> Int
  -- | The inner data of the container
  dat :: f a -> HDData f a

-- | Variant of 'HasData' for types that are not parametrized by their
-- element type.
class HasData' f where
  type HDD f :: *
  nnz' :: f -> Int
  dat' :: f -> HDD f


-- * Sparse : sparse datastructures

class (FiniteDim f, HasData f a) => Sparse f a where
  -- | Sparsity (fraction of nonzero elements)
  spy :: Fractional b => f a -> b

-- | Variant of 'Sparse' for types that are not parametrized by their
-- element type.
class (FiniteDim' f, HasData' f) => Sparse' f where
  spy' :: Fractional b => f -> b


-- * Set : types that behave as sets

class Functor f => Set f where
  -- | Union binary lift : apply function on _union_ of two "sets"
  liftU2 :: (a -> a -> a) -> f a -> f a -> f a
  -- | Intersection binary lift : apply function on _intersection_ of two "sets"
  liftI2 :: (a -> a -> b) -> f a -> f a -> f b


-- * SpContainer : sparse container datastructures.
-- Insertion, lookup, toList, lookup with 0 default

-- | Sparse containers indexed by 'ScIx'.
class Sparse c a => SpContainer c a where
  -- | Index type of the container
  type ScIx c :: *
  -- | Insert an element at the given index
  scInsert :: ScIx c -> a -> c a -> c a
  -- | Look up the element at the given index, if any
  scLookup :: c a -> ScIx c -> Maybe a
  -- | List the (index, element) pairs
  scToList :: c a -> [(ScIx c, a)]
  -- | Lookup with default, infix form ("safe" : should throw an exception if lookup is outside matrix bounds)
  (@@) :: c a -> ScIx c -> a

-- | Variant of 'SpContainer' for containers that are not parametrized by
-- their element type.
class SpContainer' c where
  type ScIx' c :: *
  scInsert' :: ScIx' c -> a -> c -> c
  scLookup' :: c -> ScIx' c -> Maybe a
  scToList' :: c -> [a]
  -- (@@')


-- * SparseVector

-- | Construction of sparse vectors.
class SpContainer v e => SparseVector v e where
  type SpvIx v :: *
  -- | Build from an 'Int' and a list of (index, element) pairs
  -- (presumably the 'Int' is the vector dimension; confirm against instances)
  svFromList :: Int -> [(SpvIx v, e)] -> v e
  -- | Build from an 'Int' and a plain list of elements
  -- (presumably the 'Int' is the vector dimension; confirm against instances)
  svFromListDense :: Int -> [e] -> v e
  -- | Concatenate a collection of vectors
  svConcat :: Foldable t => t (v e) -> v e
  -- svZipWith :: (e -> e -> e) -> v e -> v e -> v e


-- * SparseMatrix

-- | Construction and index handling of sparse matrices.
class SpContainer m e => SparseMatrix m e where
  -- | Build from a lexical ordering, a pair of 'Int's and a vector of
  -- (row, column, element) triples
  smFromVector :: LexOrd -> (Int, Int) -> V.Vector (IxRow, IxCol, e) -> m e
  -- smFromFoldableDense :: Foldable t => t e -> m e
  -- | Matrix transpose
  smTranspose :: m e -> m e
  -- smExtractSubmatrix ::
  --   m e -> (IxRow, IxRow) -> (IxCol, IxCol) -> m e
  -- | Map a (row, column) index to a lexical index
  encodeIx :: m e -> LexOrd -> (IxRow, IxCol) -> LexIx
  -- | Map a lexical index back to a (row, column) index
  decodeIx :: m e -> LexOrd -> LexIx -> (IxRow, IxCol)


-- data RowsFirst = RowsFirst
-- data ColsFirst = ColsFirst

-- class SpContainer m e => SparseMatrix m o e where
--   smFromVector :: o -> (Int, Int) -> V.Vector (IxRow, IxCol, e) -> m e
--   -- smFromFoldableDense :: Foldable t => t e -> m e
--   smTranspose :: o -> m e -> m e
--   -- smExtractSubmatrix ::
--   --   m e -> (IxRow, IxRow) -> (IxCol, IxCol) -> m e
--   encodeIx :: m e -> o -> (IxRow, IxCol) -> LexIx
--   decodeIx :: m e -> o -> LexIx -> (IxRow, IxCol)


-- * SparseMatVec

-- Combining functions for relating (structurally) matrices and vectors, e.g.
-- extracting/inserting rows/columns/submatrices
-- class (SparseMatrix m o e, SparseVector v e) => SparseMatVec m o v e where
--   smvInsertRow :: m e -> v e -> IxRow -> m e
--   smvInsertCol :: m e -> v e -> IxCol -> m e
--   smvExtractRow :: m e -> IxRow -> v e
--   smvExtractCol :: m e -> IxCol -> v e


-- * Utilities

-- | Embed a real number in the complex plane: the argument becomes the real
-- part and the imaginary part is 0.
toC :: Num a => a -> Complex a
toC = (:+ 0)


-- -- | Instances for AdditiveGroup

-- instance Integral a => AdditiveGroup (Ratio a) where
--   {zero=0; (^+^) = (+); negated = negate}

-- instance (RealFloat v, AdditiveGroup v) => AdditiveGroup (Complex v) where
--   zero = zero :+ zero
--   (^+^) = (+)
--   negated = negate

-- -- | Standard instance for an applicative functor applied to a vector space.
-- instance AdditiveGroup v => AdditiveGroup (a -> v) where
--   zero = pure zero
--   (^+^) = liftA2 (^+^)
--   negated = fmap negated


-- -- | Instances for VectorSpace

-- instance (RealFloat v, VectorSpace v) => VectorSpace (Complex v) where
--   type Scalar (Complex v) = Scalar v
--   s .* (u :+ v) = s .* u :+ s .* v


-- #define ScalarType(t) \
--   instance AdditiveGroup (t) where {zero = 0; (^+^) = (+); negated = negate};\
--   instance VectorSpace (t) where {type Scalar (t) = (t); (.*) = (*) };\
--   instance Hilbert (t) where dot = (*)

-- ScalarType(Int)
-- ScalarType(Integer)
-- ScalarType(Float)
-- ScalarType(Double)
-- ScalarType(CSChar)
-- ScalarType(CInt)
-- ScalarType(CShort)
-- ScalarType(CLong)
-- ScalarType(CLLong)
-- ScalarType(CIntMax)
-- ScalarType(CFloat)
-- ScalarType(CDouble)