Ticket #3005: Vectors2.hs

File Vectors2.hs, 4.0 KB (added by shelarcy, 4 months ago)
Line 
1{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, TypeFamilies, ScopedTypeVariables #-}
2
3import TypeNats2
4import qualified Data.List as List (transpose)
5
6class Multiply a b where
7  type Product a b
8  times :: a -> b -> Product a b
9
10data Matrix r c t = Matrix [[t]] deriving (Eq)
11
12type Vector n t = Matrix n One t
13
14instance (Show t) => Show (Matrix r c t) where
15  show (Matrix rows) = unlines (map show rows)
16
17-- creation functions
18
19matrix :: (Natural r, Natural c) => r -> c -> [[t]] -> (Matrix r c t)
20matrix r c rows | (validNumRows r rows) && (all (validNumCols c) rows) = Matrix rows
21matrix _ _ _ = error "dimension mismatch"
22
23vector :: (Natural n) => n -> [t] -> Vector n t
24vector n ts = matrix n one (map (\t -> [t]) ts)
25
26validNumRows r rows = (length rows) == (naturalToIntegral r)
27
28validNumCols c row = (length row) == (naturalToIntegral c)
29
30-- element accessors
31
32unsafeAt :: (Natural r, Natural c) => Int -> Int -> Matrix r c t -> t
33unsafeAt r c (Matrix rows) = rows !! r !! c
34
35at :: (Natural r, Natural c, Natural rAt, Natural cAt, LEqNat rAt r, LEqNat cAt c) => rAt -> cAt -> Matrix r c t -> t
36at r c m = unsafeAt (naturalToIntegral r) (naturalToIntegral c) m
37
38rows :: forall r c t. (Natural r, Natural c) => (Matrix r c t) -> Integer
39rows _ = naturalToIntegral (undefined :: r)
40
41cols :: forall r c t. (Natural r, Natural c) => (Matrix r c t) -> Integer
42cols _ = naturalToIntegral (undefined :: c)
43
44size :: forall r c t. (Natural r, Natural c) => (Matrix r c t) -> (Integer,Integer)
45size m = (rows m, cols m)
46
47-- matrix is a functor
48
49instance (Natural r, Natural c) => Functor (Matrix r c) where
50  fmap f (Matrix rows) = Matrix (map (map f) rows)
51
52-- matrix products
53
54matrixPartialProductElem n m1 m2 i j = sum [(unsafeAt i r m1) * (unsafeAt r j m2) | r <- [0..n-1]]
55
56matrixPartialProductRow n m1 m2 i cols = [matrixPartialProductElem n m1 m2 i j | j <- [0..cols-1]]
57
58instance (Natural a, Natural b, Natural c, Num t) => Multiply (Matrix a b t) (Matrix b c t) where
59  type Product (Matrix a b t) (Matrix b c t) = Matrix a c t
60  times m1 m2 = let rows = naturalToIntegral (undefined :: a)
61                    cols = naturalToIntegral (undefined :: c)
62                    n = naturalToIntegral (undefined :: b)
63                in
64                    Matrix [matrixPartialProductRow n m1 m2 i cols | i <- [0..rows-1]]
65
66instance (Natural r, Natural c, Num t) => Multiply t (Matrix r c t) where
67  type Product t (Matrix r c t) = Matrix r c t
68  times x m = fmap (* x) m
69
70-- useful math
71
72zipMatrixWith :: (Natural rows, Natural cols) => (a -> b -> c) -> (Matrix rows cols a) -> (Matrix rows cols b) -> (Matrix rows cols c)
73zipMatrixWith f (Matrix m1) (Matrix m2) = Matrix (zipWith (\row1 row2 -> zipWith f row1 row2) m1 m2)
74
75dot :: (Natural n, Num t) => Vector n t -> Vector n t -> Vector n t
76dot = zipMatrixWith (*)
77
78cross :: (Num t) => Vector Three t -> Vector Three t -> Vector Three t
79cross v1 v2 = vector three [a2 * b3 - a3 * b2, a3 * b1 - a1 * b3, a1 * b2 - a2 * b1]
80              where a1 = at zero zero v1
81                    a2 = at one zero v1
82                    a3 = at two zero v1
83                    b1 = at zero zero v2
84                    b2 = at one zero v2
85                    b3 = at two zero v2
86
87transpose :: (Natural a, Natural b) => (Matrix a b t) -> (Matrix b a t)
88transpose (Matrix rows) = Matrix (List.transpose rows)
89
90plus :: (Natural r, Natural c, Num t) => (Matrix r c t) -> (Matrix r c t) -> (Matrix r c t)
91plus = zipMatrixWith (+)
92
93minus :: (Natural r, Natural c, Num t) => (Matrix r c t) -> (Matrix r c t) -> (Matrix r c t)
94minus = zipMatrixWith (-)
95
96identity :: (Natural n, Num t) => n -> Matrix n n t
97identity n = matrix n n [[identity' (i == j) | j <- [0..(naturalToIntegral n)-1]] | i <- [0..(naturalToIntegral n)-1]]
98
99identity' :: (Num t) => Bool -> t
100identity' True = 1
101identity' False = 0
102
103flatten :: [[a]] -> [a]
104flatten = foldl (++) []
105
106norm :: (Natural n, Floating t) => (Vector n t) -> t
107norm (Matrix rows) = sqrt $ sum (map (\x -> x * x) (flatten rows))