{-# language TypeFamilies, MultiParamTypeClasses, KindSignatures, FlexibleContexts, FlexibleInstances, ConstraintKinds #-}
{-# language AllowAmbiguousTypes #-}
{-# language CPP #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.LinearAlgebra.Class
-- Copyright   :  (c) Marco Zocca 2017
-- License     :  GPL-3 (see the file LICENSE)
--
-- Maintainer  :  zocca marco gmail
-- Stability   :  experimental
-- Portability :  portable
--
-- Typeclasses for linear algebra and related concepts
--
-----------------------------------------------------------------------------
module Numeric.LinearAlgebra.Class where

-- import Control.Applicative
import Data.Complex
-- import Control.Exception
-- import Control.Exception.Common
import Control.Monad.Catch
import Control.Monad.IO.Class
-- import Data.Typeable (Typeable)
import qualified Data.Vector as V (Vector)

import Data.Sparse.Types
import Numeric.Eps

-- * Matrix and vector elements (optionally Complex)

-- | Element types of matrices and vectors.
--
-- 'EltMag' is the type of an element's magnitude: the element type itself
-- for the real instances, and the underlying real type for 'Complex' numbers.
class (Eq e , Fractional e, Floating e, Num (EltMag e), Ord (EltMag e)) => Elt e where
  type EltMag e :: *
  -- | Complex conjugate, or identity function if its input is real-valued
  conj :: e -> e
  conj = id
  -- | Magnitude (absolute value / modulus). Non-negative for all instances
  -- in this module, so magnitudes compare meaningfully through the
  -- @Ord (EltMag e)@ superclass.
  mag :: e -> EltMag e

-- Real elements: the magnitude is 'abs' (not 'id'), so that it is never
-- negative and agrees with the 'magnitude' used by the 'Complex' instance.
instance Elt Double where
  type EltMag Double = Double
  mag = abs

instance Elt Float where
  type EltMag Float = Float
  mag = abs

-- Complex elements: conjugation is 'conjugate', magnitude is the modulus.
instance (RealFloat e) => Elt (Complex e) where
  type EltMag (Complex e) = e
  conj = conjugate
  mag = magnitude

infixl 6 ^+^, ^-^

-- * Additive group

-- | Types with vector addition, a zero element and additive inverses.
class AdditiveGroup v where
  -- | The zero element: identity for '(^+^)'
  zeroV :: v
  -- | Add vectors
  (^+^) :: v -> v -> v
  -- | Additive inverse
  negateV :: v -> v
  -- | Group subtraction; by default adds the additive inverse of the
  -- second argument.
  (^-^) :: v -> v -> v
  x ^-^ y = x ^+^ negateV y

infixr 7 .*

-- * Vector space @v@.

-- | Additive groups that can be scaled by values of the associated
-- 'Scalar' type.
class (AdditiveGroup v, Num (Scalar v)) => VectorSpace v where
  type Scalar v :: *
  -- | Scale a vector
  (.*) :: Scalar v -> v -> v

-- | Adds inner (dot) products.
class VectorSpace v => InnerSpace v where
  -- | Inner/dot product
  (<.>) :: v -> v -> Scalar v

-- | Inner product; a named synonym for '(<.>)'.
dot :: InnerSpace v => v -> v -> Scalar v
dot = (<.>)

-- NOTE(review): '(./)' is declared right-associative, so @v ./ a ./ b@
-- parses as @v ./ (a ./ b)@, not @(v ./ a) ./ b@ as with '(/)'.
-- Parenthesize chained divisions explicitly.
infixr 7 ./
infixl 7 *.

-- | Scale a vector by the reciprocal of a number (e.g. for normalization)
(./) :: (VectorSpace v, s ~ Scalar v, Fractional s) => v -> s -> v
v ./ s = recip s .* v

-- | Vector multiplied by scalar ('(.*)' with its arguments flipped)
(*.) :: (VectorSpace v, s ~ Scalar v) => v -> s -> v
(*.) = flip (.*)

-- | Convex combination of two vectors, @a .* u ^+^ (1 - a) .* v@
-- (NB: 0 <= @a@ <= 1 is assumed, not checked).
cvx :: VectorSpace v => Scalar v -> v -> v -> v
cvx a u v = a .* u ^+^ ((1-a) .* v)

-- ** Hilbert-space distance function

-- | @hilbertDistSq x y = || x - y ||^2@ computes the squared L2 distance
-- between two vectors, as the inner product of @x ^-^ y@ with itself.
hilbertDistSq :: InnerSpace v => v -> v -> Scalar v
hilbertDistSq x y = t <.> t
  where
    t = x ^-^ y

-- * Normed vector spaces

-- | Inner-product spaces with norms. 'Magnitude' is the type of norm values;
-- 'RealScalar' is the type of the exponent @p@ of Lp norms.
class (InnerSpace v, Num (RealScalar v), Eq (RealScalar v), Epsilon (Magnitude v), Show (Magnitude v), Ord (Magnitude v)) => Normed v where
  type Magnitude v :: *
  type RealScalar v :: *
  -- | L1 norm
  norm1 :: v -> Magnitude v
  -- | Euclidean (L2) norm squared
  norm2Sq :: v -> Magnitude v
  -- | Lp norm (p > 0)
  normP :: RealScalar v -> v -> Magnitude v
  -- | Normalize w.r.t. Lp norm
  normalize :: RealScalar v -> v -> v
  -- | Normalize w.r.t. L2 norm
  normalize2 :: v -> v
  -- | Normalize w.r.t. norm2' instead of norm2
  normalize2' :: Floating (Scalar v) => v -> v
  normalize2' x = x ./ norm2' x
  -- | Euclidean (L2) norm, by default the square root of 'norm2Sq'
  norm2 :: Floating (Magnitude v) => v -> Magnitude v
  norm2 x = sqrt (norm2Sq x)
  -- | Euclidean (L2) norm; returns a Complex (norm :+ 0) for Complex-valued vectors
  norm2' :: Floating (Scalar v) => v -> Scalar v
  norm2' x = sqrt $ x <.> x
  -- | Lp norm (p > 0); dispatches to 'norm1' for p == 1, 'norm2' for
  -- p == 2 and 'normP' otherwise.
  norm :: Floating (Magnitude v) => RealScalar v -> v -> Magnitude v
  norm p v
    | p == 1    = norm1 v
    | p == 2    = norm2 v
    | otherwise = normP p v

-- | Infinity-norm (Real): the largest element of the container.
--
-- NOTE(review): this is the maximum of the elements as given, not of their
-- absolute values, so it equals the infinity-norm only for non-negative
-- elements (e.g. pass @abs <$> v@). Like 'maximum', it fails on an empty
-- container.
normInftyR :: (Foldable t, Ord a) => t a -> a
normInftyR x = maximum x

-- | Infinity-norm (Complex): the largest modulus among the elements.
-- Like 'maximum', it fails on an empty container.
normInftyC :: (Foldable t, RealFloat a, Functor t) => t (Complex a) -> a
normInftyC x = maximum (magnitude <$> x)

-- | Lp inner product (p > 0): @(sum (a_i * b_i) ** p) ** (1/p)@, where the
-- products are taken over the intersection of the two supports ('liftI2').
dotLp :: (Set t, Foldable t, Floating a) => a -> t a -> t a -> a
dotLp p v1 v2 = sum u ** (1/p)
  where
    f a b = (a*b)**p
    u = liftI2 f v1 v2

-- | Reciprocal of every element
reciprocal :: (Functor f, Fractional b) => f b -> f b
reciprocal = fmap recip

-- | Scale every element by a constant (multiplying on the right)
scale :: (Num b, Functor f) => b -> f b -> f b
scale n = fmap (* n)

-- * Matrix ring

-- | A matrix ring is any collection of matrices over some ring R that form a ring under matrix addition and matrix multiplication
class (AdditiveGroup m, Epsilon (MatrixNorm m)) => MatrixRing m where
  type MatrixNorm m :: *
  -- | Matrix-matrix product
  (##) :: m -> m -> m
  -- | Matrix times matrix transpose (A B^T). The default uses 'transpose',
  -- mirroring the default of '(#^#)'; instances can override it with a
  -- more efficient implementation.
  (##^) :: m -> m -> m
  a ##^ b = a ## transpose b
  -- | Matrix transpose times matrix (A^T B)
  (#^#) :: m -> m -> m
  a #^# b = transpose a ## b
  -- | Matrix transpose (Hermitian conjugate in the Complex case)
  transpose :: m -> m
  -- | Frobenius norm
  normFrobenius :: m -> MatrixNorm m

-- a "sparse matrix ring" ?
-- class MatrixRing m a => SparseMatrixRing m a where
--   (#~#) :: Epsilon a => Matrix m a -> Matrix m a -> Matrix m a

-- * Linear vector space

-- | Vector spaces acted upon by an associated matrix type, from the left
-- ('#>') and from the right ('<#').
-- (A @MatrixRing (MatrixType u)@ superclass is currently disabled.)
class VectorSpace u => LinearVectorSpace u where
  type MatrixType u :: *
  -- | Matrix-vector action
  (#>) :: MatrixType u -> u -> u
  -- | Dual matrix-vector action
  (<#) :: u -> MatrixType u -> u

-- ** LinearVectorSpace + Normed

-- | Constraint synonym: both 'LinearVectorSpace' and 'Normed'.
type V u = (LinearVectorSpace u, Normed u)

-- ** Linear systems

-- | Vector types for which a linear system can be solved.
class LinearVectorSpace u => LinearSystem u where
  -- | Solve a linear system; uses GMRES internally as default method.
  -- NOTE(review): no default method is defined in this class, so the
  -- solver actually used is whatever each instance implements.
  (<\>) :: (MonadIO io, MonadThrow io)
        => MatrixType u  -- ^ System matrix
        -> u             -- ^ Right-hand side
        -> io u          -- ^ Result

-- * FiniteDim : finite-dimensional objects

-- | Objects whose size is described by the associated 'FDSize' type.
class FiniteDim d where
  type FDSize d
  -- | Dimension (i.e. Int for SpVector, (Int, Int) for SpMatrix)
  dim :: d -> FDSize d

-- * HasData : accessing inner data (do not export)

-- | Containers that report their nonzero count and expose their data.
class HasData d where
  type HDData d
  -- | Number of nonzeros
  nnz :: d -> Int
  -- | The inner data, of the associated type 'HDData'
  dat :: d -> HDData d

-- * Sparse : sparse datastructures

-- | Finite-dimensional containers with accessible data.
class (FiniteDim d, HasData d) => Sparse d where
  -- | Sparsity (fraction of nonzero elements)
  spy :: Fractional r => d -> r

-- * Set : types that behave as sets

-- | Functors supporting union-like and intersection-like binary lifts.
class Functor g => Set g where
  -- | Union binary lift : apply function on _union_ of two "sets"
  liftU2 :: (e -> e -> e) -> g e -> g e -> g e
  -- | Intersection binary lift : apply function on _intersection_ of two "sets"
  liftI2 :: (e -> e -> r) -> g e -> g e -> g r

-- * SpContainer : sparse container datastructures.
-- | Sparse container datastructures: insertion, lookup, conversion to a
-- list of (index, element) pairs, and lookup with a 0 default.
class Sparse s => SpContainer s where
  -- | Index type of the container
  type ScIx s :: *
  -- | Element type of the container
  type ScElem s
  -- | Insert an element at an index
  scInsert :: ScIx s -> ScElem s -> s -> s
  -- | Lookup returning a 'Maybe'
  scLookup :: s -> ScIx s -> Maybe (ScElem s)
  -- | The (index, element) pairs of the container
  scToList :: s -> [(ScIx s, ScElem s)]
  -- | Lookup with default, infix form ("safe" : should throw an exception if lookup is outside matrix bounds)
  (@@) :: s -> ScIx s -> ScElem s

-- * SparseVector

-- | Sparse vectors, indexed by the associated 'SpvIx' type.
class SpContainer w => SparseVector w where
  type SpvIx w :: *
  -- | Build from an @Int@ (presumably the vector length) and (index, element) pairs
  svFromList :: Int -> [(SpvIx w, ScElem w)] -> w
  -- | Build from an @Int@ (presumably the vector length) and a dense list of elements
  svFromListDense :: Int -> [ScElem w] -> w
  -- | Concatenate a collection of vectors
  svConcat :: Foldable t => t w -> w
  -- svZipWith :: (e -> e -> e) -> v e -> v e -> v e

-- * SparseMatrix

-- | Sparse matrices with row/column indexing and lexicographic index encoding.
class SpContainer mat => SparseMatrix mat where
  -- | Build from an ordering, a (rows, columns) pair and a vector of entries
  smFromVector :: LexOrd -> (Int, Int) -> V.Vector (IxRow, IxCol, ScElem mat) -> mat
  -- smFromFoldableDense :: Foldable t => t e -> m e
  -- | Transpose
  smTranspose :: mat -> mat
  -- smExtractSubmatrix ::
  --   m e -> (IxRow, IxRow) -> (IxCol, IxCol) -> m e
  -- | Encode a (row, column) pair as a 'LexIx' under the given 'LexOrd'
  encodeIx :: mat -> LexOrd -> (IxRow, IxCol) -> LexIx
  -- | Decode a 'LexIx' into a (row, column) pair under the given 'LexOrd'
  decodeIx :: mat -> LexOrd -> LexIx -> (IxRow, IxCol)

-- data RowsFirst = RowsFirst
-- data ColsFirst = ColsFirst

-- * SparseMatVec

-- Combining functions for relating (structurally) matrices and vectors, e.g.
-- extracting/inserting rows/columns/submatrices
-- class (SparseMatrix m o e, SparseVector v e) => SparseMatVec m o v e where
--   smvInsertRow :: m e -> v e -> IxRow -> m e
--   smvInsertCol :: m e -> v e -> IxCol -> m e
--   smvExtractRow :: m e -> IxRow -> v e
--   smvExtractCol :: m e -> IxCol -> v e

-- * Utilities

-- | Lift a real number onto the complex plane (zero imaginary part)
toC :: Num a => a -> Complex a
toC r = r :+ 0

-- | Instances for builtin types

-- The ScalarType macro below generates AdditiveGroup and VectorSpace
-- instances in which the type is its own scalar type.
#define ScalarType(t) \
instance AdditiveGroup (t) where {zeroV = 0; (^+^) = (+); negateV = negate};\
instance VectorSpace (t) where {type Scalar (t) = t; (.*) = (*) };

-- ScalarType(Int)
-- ScalarType(Integer)
ScalarType(Float)
ScalarType(Double)
ScalarType(Complex Float)
ScalarType(Complex Double)
-- ScalarType(CSChar)
-- ScalarType(CInt)
-- ScalarType(CShort)
-- ScalarType(CLong)
-- ScalarType(CLLong)
-- ScalarType(CIntMax)
-- ScalarType(CFloat)
-- ScalarType(CDouble)

#undef ScalarType

-- Real scalars: the inner product is plain multiplication.
instance InnerSpace Float where {(<.>) = (*)}
instance InnerSpace Double where {(<.>) = (*)}
-- Complex scalars: the second argument is conjugated.
instance InnerSpace (Complex Float) where {x <.> y = x * conjugate y}
instance InnerSpace (Complex Double) where {x <.> y = x * conjugate y}

-- Real scalars: every norm is 'abs' and every normalization is 'signum'.
-- (Comments cannot go inside the macro body: the expansion is a single
-- line, so a line comment would swallow the rest of the instance.)
#define SimpleNormedInstance(t) \
instance Normed (t) where {type Magnitude (t) = t; type RealScalar (t) = t;\
norm1 = abs; norm2Sq = (**2); normP _ = abs; normalize _ = signum;\
normalize2 = signum; normalize2' = signum; norm2 = abs; norm2' = abs; norm _ = abs};

SimpleNormedInstance(Float)
SimpleNormedInstance(Double)
#undef SimpleNormedInstance

-- Complex scalars over a real type t; norms are computed from the real part
-- r and the imaginary part i:
--   norm1  = |r| + |i|,   norm2Sq = r^2 + i^2,   norm2 = 'magnitude'
--   normP p = (|r|^p + |i|^p)^(1/p)
-- The absolute values in normP are required: '(**)' with a negative base
-- gives NaN for non-integer exponents and a negative term for odd integer
-- exponents.
#define ComplexNormedInstance(t) \
instance Normed (Complex (t)) where {\
type Magnitude (Complex (t)) = t;\
type RealScalar (Complex (t)) = t;\
norm1 (r :+ i) = abs r + abs i;\
norm2Sq (r :+ i) = r*r + i*i;\
normP p (r :+ i) = (abs r ** p + abs i ** p)**(1/p);\
normalize p c = toC (1 / normP p c) * c;\
normalize2 c = (1 / norm2' c) * c;\
norm2 = magnitude;\
norm2' = toC . magnitude;};

ComplexNormedInstance(Float)
ComplexNormedInstance(Double)
#undef ComplexNormedInstance