{-# LANGUAGE FunctionalDependencies #-}

-- | Class declarations for abelian groups, modules, vectors, norms and
-- matrices, together with a few helper functions defined in terms of them.
module Linear.Class
  ( AbelianGroup(..)
  , vecSum
  , MultSemiGroup(..)
  , Ring
  , semigroupProduct
  , LeftModule(..)
  , RightModule(..)
  , Vector(..)
  , DotProd(..)
  , Norm(..)
  , CrossProd(..)
  , normalize
  , distance
  , angle
  , angle'
  , UnitVector(..)
  , Pointwise(..)
  , Extend(..)
  , Dimension(..)
  , Transpose(..)
  , SquareMatrix(..)
  , Tensor(..)
  , Diagonal (..)
  , Determinant(..)
  , Orthogonal(..)
  , Projective(..)
  , MatrixNorms(..)
  , project
  , project'
  , projectUnsafe
  , flipNormal
  , householder
  , householderOrtho
  ) where

import Data.List (foldl')

-- | Additive structure: addition '&+', subtraction '&-', negation 'neg' and
-- the distinguished element 'zero' (the seed used by 'vecSum').
class AbelianGroup g where
  (&+) :: g -> g -> g
  (&-) :: g -> g -> g
  neg  :: g -> g
  zero :: g

infixl 6 &+
infixl 6 &-

-- | Sums a list with '&+', starting from 'zero'; the empty list yields 'zero'.
--
-- The fold is the strict 'foldl'' so the running sum is forced at every step
-- instead of building a chain of thunks as long as the input list.
vecSum :: AbelianGroup g => [g] -> g
vecSum l = foldl' (&+) zero l

-- | Multiplicative structure: the product '.*.' and the element 'one'
-- (the seed used by 'semigroupProduct').
class MultSemiGroup r where
  (.*.) :: r -> r -> r
  one   :: r

-- | A type carrying both the 'AbelianGroup' and the 'MultSemiGroup' structure.
-- The class adds no methods of its own.
class (AbelianGroup r, MultSemiGroup r) => Ring r

infixl 7 .*.

-- was: ringProduct :: Ring r => [r] -> r

-- | Multiplies a list together with '.*.', starting from 'one'; the empty
-- list yields 'one'.
--
-- Like 'vecSum', this uses the strict 'foldl'' to avoid accumulating thunks.
semigroupProduct :: MultSemiGroup r => [r] -> r
semigroupProduct l = foldl' (.*.) one l

-- | Action of @r@ on @m@ from the left. '*.' defaults to 'lmul'.
class LeftModule r m where
  lmul :: r -> m -> m
  (*.) :: r -> m -> m
  (*.) = lmul

-- | Action of @r@ on @m@ from the right. '.*' defaults to 'rmul'.
-- The functional dependencies make @m@ and @r@ determine each other.
class RightModule m r | m -> r, r -> m where
  rmul :: m -> r -> m
  (.*) :: m -> r -> m
  (.*) = rmul

-- I'm not really sure about this.. may actually degrade the performance in some cases?
{- RULES
"matrix multiplication left"   forall m n x. (n .*. m) *. x = n *. (m *. x)
"matrix multiplication right"  forall m n x. x .* (m .*. n) = (x .* m) .* n
  -}

infixr 7 *.
infixl 7 .*

-- | Containers @v@ over a scalar type @a@ that support scalar multiplication.
-- The operators '*&' (scalar on the left) and '&*' (scalar on the right) both
-- default to 'scalarMul'.
class AbelianGroup (v a) => Vector a v where
  mapVec    :: (a -> a) -> v a -> v a
  scalarMul :: a -> v a -> v a
  (*&)      :: a -> v a -> v a
  (&*)      :: v a -> a -> v a
  (*&) s v = scalarMul s v
  (&*) v s = scalarMul s v

infixr 7 *&
infixl 7 &*

{-# RULES
"scalar multiplication left"   forall (s :: Num s => s) (t :: Num t => t) x. t *& (s *& x) = (t*s) *& x
"scalar multiplication right"  forall (s :: Num s => s) (t :: Num t => t) x. (x &* s) &* t = x &* (s*t)
  #-}

-- | Dot product '&.' (alias 'dotprod'). 'normsqr' defaults to the dot product
-- of a vector with itself, and 'lensqr' defaults to 'normsqr'.
class Num a => DotProd a v where
  (&.)    :: v a -> v a -> a
  dotprod :: v a -> v a -> a
  dotprod = (&.)
  normsqr :: v a -> a
  normsqr v = (v &. v)
  lensqr  :: v a -> a
  lensqr = normsqr

-- | Norm of a vector. 'norm' defaults to the square root of 'lensqr', and
-- 'vlen' defaults to 'norm'.
class (Floating a, DotProd a v) => Norm a v where
  norm :: v a -> a
  norm = sqrt . lensqr
  vlen :: v a -> a
  vlen = norm

infix 7 &.

{-# RULES
"vlen/square 1"   forall x. (vlen x)*(vlen x) = lensqr x
"vlen/square 2"   forall x. (vlen x)^2 = lensqr x
"norm/square 1"   forall x. (norm x)*(norm x) = normsqr x
"norm/square 2"   forall x. (norm x)^2 = normsqr x
  #-}

-- | Scales a vector by the reciprocal of its 'vlen'. No check is made for a
-- zero-length input.
normalize :: (Vector a v, Norm a v) => v a -> v a
normalize v = scalarMul (recip (vlen v)) v

-- | The 'norm' of the difference of two vectors.
distance :: (Vector a v, Norm a v) => v a -> v a -> a
distance x y = norm (x &- y)

-- | the angle between two vectors
angle :: (Vector a v, Norm a v) => v a -> v a -> a
angle x y = acos $ (x &. y) / (norm x * norm y)

-- | the angle between two unit vectors
angle' {- ' CPP is sensitive to primes -} :: (Floating a, Vector a v, UnitVector a v u, DotProd a v) => u a -> u a -> a
angle' x y = acos (fromNormal x &. fromNormal y)

{-# RULES
"normalize is idempotent"  forall x. normalize (normalize x) = normalize x
  #-}

-- | A separate type @u@ for unit vectors of the vector type @v@; the
-- functional dependencies make @u@ and @v@ determine each other.
class (Vector a v, Norm a v) => UnitVector a v u | u -> v, v -> u where
  mkNormal         :: v a -> u a       -- ^ normalizes the input
  toNormalUnsafe   :: v a -> u a       -- ^ does not normalize the input!
  fromNormal       :: u a -> v a
  -- | Defaults to the underlying vector ('fromNormal') scaled by the given factor.
  fromNormalRadius :: a -> u a -> v a
  fromNormalRadius t n = t *& fromNormal n

-- | Projects the first vector down to the hyperplane orthogonal to the second (unit) vector
project' :: (Vector a v, UnitVector a v u, Norm a v) => v a -> u a -> v a
project' what dir = projectUnsafe what (fromNormal dir)

-- | Direction (second argument) is assumed to be a /unit/ vector!
projectUnsafe :: (Vector a v, DotProd a v) => v a -> v a -> v a
projectUnsafe what dir = what &- dir &* (what &. dir)

-- | Like 'projectUnsafe', but the coefficient is divided by @dir &. dir@, so
-- the direction does not have to be a unit vector.
project :: (Fractional a, Vector a v, DotProd a v) => v a -> v a -> v a
project what dir = what &- dir &* ((what &. dir) / (dir &. dir))

-- | Since unit vectors are not a group, we need a separate function.
flipNormal :: UnitVector a v n => n a -> n a
flipNormal n = toNormalUnsafe (neg (fromNormal n))  -- negation keeps the length, so no renormalisation

-- | Cross product
class CrossProd v where
  crossprod :: v -> v -> v
  -- | Operator form; defaults to 'crossprod'.
  (&^) :: v -> v -> v
  (&^) = crossprod

-- | Pointwise multiplication
class Pointwise v where
  pointwise :: v -> v -> v
  -- | Operator form; defaults to 'pointwise'.
  (&!) :: v -> v -> v
  (&!) = pointwise

infix 7 &^
infix 7 &!

-- | conversion between vectors (and matrices) of different dimensions
class Extend a u v where
  -- | example: @extendZero (V2 5 6) = V4 5 6 0 0@
  extendZero :: u a -> v a
  -- | example: @extendWith 1 (V2 5 6) = V4 5 6 1 1@
  extendWith :: a -> u a -> v a
  -- | example: @trim (V4 5 6 7 8) = V2 5 6@
  trim :: v a -> u a

-- | makes a diagonal matrix from a vector
class Diagonal s t | t -> s where
  diag :: s -> t

-- | Transposition between the types @m@ and @n@, which determine each other.
class Transpose m n | m -> n, n -> m where
  transpose :: m -> n

-- | Square matrices: an inverse and the identity matrix 'idmtx'.
class SquareMatrix m where
  inverse :: m -> m
  idmtx   :: m

{-# RULES
"transpose is an involution"  forall m. transpose (transpose m) = m
"inverse is an involution"    forall m. inverse (inverse m) = m
  #-}

-- | A separate type @o@ for orthogonal matrices of the matrix type @m@.
-- 'toOrthoUnsafe' only converts between the two types.
class SquareMatrix (m a) => Orthogonal a m o | m -> o, o -> m where
  fromOrtho     :: o a -> m a
  toOrthoUnsafe :: m a -> o a

-- | Norms on matrices. 'matrixDistance' defaults to the 'frobeniusNorm' of the
-- difference; 'operatorNorm' defaults to a runtime error.
class (AbelianGroup m, SquareMatrix m) => MatrixNorms a m where
  -- | the frobenius norm (= euclidean norm in the space of matrices)
  frobeniusNorm :: m -> a
  -- | euclidean distance in the space of matrices
  matrixDistance :: m -> m -> a
  -- | (euclidean) operator norm (not implemented yet)
  operatorNorm :: m -> a

  matrixDistance p q = frobeniusNorm (q &- p)
  operatorNorm = error "operatorNorm: not implemented yet"

-- | Outer product (could be unified with Diagonal?)
class Tensor t v | t -> v where
  outer :: v -> v -> t

-- | Determinant of a matrix of type @m@, with result type @a@.
class Determinant a m where
  det :: m -> a

-- | Dimension, as an 'Int'.
class Dimension a where
  dim :: a -> Int

-- | Householder matrix, see <https://en.wikipedia.org/wiki/Householder_transformation>.
-- In plain words, it is the reflection to the hyperplane orthogonal to the input vector.
householder
  :: ( Vector a v
     , UnitVector a v u
     , SquareMatrix (m a)
     , Vector a m
     , Tensor (m a) (v a)
     )
  => u a
  -> m a
householder u =
  -- identity minus twice the outer product of the unit vector with itself
  let axis = fromNormal u
  in  idmtx &- (2 *& outer axis axis)

-- | The same matrix as 'householder', converted with 'toOrthoUnsafe'.
householderOrtho
  :: ( Vector a v
     , UnitVector a v u
     , SquareMatrix (m a)
     , Vector a m
     , Tensor (m a) (v a)
     , Orthogonal a m o
     )
  => u a
  -> o a
householderOrtho u = toOrthoUnsafe (householder u)

-- | \"Projective\" matrices have the following form: the top left corner
-- is any matrix, the bottom right corner is 1, and the top-right
-- column is zero. These describe the affine orthogonal transformation of
-- the space one dimension less.
class (Vector a v, Orthogonal a n o, Diagonal (v a) (n a)) => Projective a v n o m p
    | m -> p, p -> m
    , p -> o, o -> p
    , p -> n, n -> p
    , p -> v, v -> p
    , n -> o
    , n -> v, v -> n
    where
  fromProjective     :: p a -> m a
  toProjectiveUnsafe :: m a -> p a
  orthogonal         :: o a -> p a
  linear             :: n a -> p a
  translation        :: v a -> p a
  scaling            :: v a -> p a