| 1 | {-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, TypeFamilies, ScopedTypeVariables #-} |
|---|
| 2 | |
|---|
| 3 | import TypeNats2 |
|---|
| 4 | import qualified Data.List as List (transpose) |
|---|
| 5 | |
|---|
-- | Multiplication between operands of possibly different types. The type of
-- the result is selected per pair of operand types through the associated
-- type family 'Product'.
class Multiply a b where
  -- | Result type of multiplying an @a@ by a @b@.
  type Product a b
  -- | Multiply the first operand by the second.
  times :: a -> b -> Product a b
|---|
| 9 | |
|---|
-- | A matrix stored as a list of rows. The parameters @r@ and @c@ are phantom
-- type-level dimensions (row and column counts): they do not appear in the
-- representation. 'matrix' checks that the row list agrees with them, but the
-- raw 'Matrix' constructor performs no check. The derived 'Eq' compares only
-- the row lists.
data Matrix r c t = Matrix [[t]] deriving (Eq)
|---|
| 11 | |
|---|
-- | A column vector: a matrix with @n@ rows and a single (@One@) column.
type Vector n t = Matrix n One t
|---|
| 13 | |
|---|
-- | Renders one row per line, each row shown as a Haskell list and each line
-- terminated by a newline.
instance (Show t) => Show (Matrix r c t) where
  show (Matrix rs) = concatMap (\row -> show row ++ "\n") rs
|---|
| 16 | |
|---|
| 17 | -- creation functions |
|---|
| 18 | |
|---|
-- | Build an @r x c@ matrix from a list of rows.
--
-- Calls @error "dimension mismatch"@ when the number of rows differs from
-- @r@ or any row's length differs from @c@.
matrix :: (Natural r, Natural c) => r -> c -> [[t]] -> (Matrix r c t)
matrix r c rs
  | not shapeOk = error "dimension mismatch"
  | otherwise = Matrix rs
  where
    shapeOk = validNumRows r rs && all (validNumCols c) rs
|---|
| 22 | |
|---|
-- | Build an @n@-element column vector, one entry per row. Fails through
-- 'matrix' when the list length does not match @n@.
vector :: (Natural n) => n -> [t] -> Vector n t
vector n ts = matrix n one [[x] | x <- ts]
|---|
| 25 | |
|---|
-- | True when the row list has exactly as many rows as the natural @r@.
validNumRows r rs = actual == expected
  where
    actual = length rs
    expected = naturalToIntegral r
|---|
| 27 | |
|---|
-- | True when a single row has exactly as many entries as the natural @c@.
validNumCols c row = actual == expected
  where
    actual = length row
    expected = naturalToIntegral c
|---|
| 29 | |
|---|
-- element and dimension accessors
|---|
| 31 | |
|---|
-- | Zero-based lookup of the entry at row @i@, column @j@, with no bounds
-- check: out-of-range indices fail inside '(!!)'.
unsafeAt :: (Natural r, Natural c) => Int -> Int -> Matrix r c t -> t
unsafeAt i j (Matrix rs) =
  let row = rs !! i
  in row !! j
|---|
| 34 | |
|---|
-- | Zero-based entry lookup with indices given as type-level naturals, so the
-- bound checks @LEqNat rAt r@ and @LEqNat cAt c@ are done by the type checker.
--
-- NOTE(review): if 'LEqNat' means "less than or equal", an index equal to
-- the dimension (e.g. row @three@ of a @Vector Three@) still type-checks but
-- is out of range for the zero-based lookup; confirm against TypeNats2.
at :: (Natural r, Natural c, Natural rAt, Natural cAt, LEqNat rAt r, LEqNat cAt c) => rAt -> cAt -> Matrix r c t -> t
at rowIx colIx = unsafeAt (naturalToIntegral rowIx) (naturalToIntegral colIx)
|---|
| 37 | |
|---|
-- | Number of rows, read from the type-level dimension @r@; the matrix
-- argument itself is never inspected.
rows :: forall r c t. (Natural r, Natural c) => (Matrix r c t) -> Integer
rows = const (naturalToIntegral (undefined :: r))
|---|
| 40 | |
|---|
-- | Number of columns, read from the type-level dimension @c@; the matrix
-- argument itself is never inspected.
cols :: forall r c t. (Natural r, Natural c) => (Matrix r c t) -> Integer
cols = const (naturalToIntegral (undefined :: c))
|---|
| 43 | |
|---|
-- | Dimensions as a @(rows, columns)@ pair, both taken from the type.
size :: forall r c t. (Natural r, Natural c) => (Matrix r c t) -> (Integer,Integer)
size m =
  let nRows = rows m
      nCols = cols m
  in (nRows, nCols)
|---|
| 46 | |
|---|
| 47 | -- matrix is a functor |
|---|
| 48 | |
|---|
-- | Applies a function to every entry, keeping the shape unchanged.
instance (Natural r, Natural c) => Functor (Matrix r c) where
  fmap f (Matrix rs) = Matrix [map f row | row <- rs]
|---|
| 51 | |
|---|
| 52 | -- matrix products |
|---|
| 53 | |
|---|
-- | Entry @(i, j)@ of the product of @m1@ and @m2@: the sum over inner indices
-- @0 .. n-1@ of @m1[i][k] * m2[k][j]@. Uses 'unsafeAt', so out-of-range
-- indices fail at runtime.
matrixPartialProductElem n m1 m2 i j = sum (map term [0 .. n - 1])
  where
    term k = unsafeAt i k m1 * unsafeAt k j m2
|---|
| 55 | |
|---|
-- | Row @i@ of the product of @m1@ and @m2@, covering columns
-- @0 .. nCols-1@, with inner dimension @n@.
matrixPartialProductRow n m1 m2 i nCols =
  map (matrixPartialProductElem n m1 m2 i) [0 .. nCols - 1]
|---|
| 57 | |
|---|
-- | Matrix product: an @a x b@ matrix times a @b x c@ matrix is an @a x c@
-- matrix. Entry @(i, j)@ is the sum of products of row @i@ of the left operand
-- with column @j@ of the right operand.
instance (Natural a, Natural b, Natural c, Num t) => Multiply (Matrix a b t) (Matrix b c t) where
  type Product (Matrix a b t) (Matrix b c t) = Matrix a c t
  times (Matrix leftRows) (Matrix rightRows) =
    Matrix [[sum (zipWith (*) row col) | col <- columns] | row <- leftRows]
    where
      -- Transpose the right operand once so each column is walked as a list.
      -- This replaces per-entry '(!!)' indexing, which added an extra linear
      -- factor to every multiply-add.
      nCols = naturalToIntegral (undefined :: c)
      -- When the inner dimension b is zero, rightRows is empty and
      -- List.transpose yields no columns at all. Pad to the type-level
      -- column count so the result is still a c-column matrix whose entries
      -- are the empty sum, 0.
      columns = take nCols (List.transpose rightRows ++ repeat [])
|---|
| 65 | |
|---|
-- | Scalar multiplication: multiplies every entry of the matrix on the right
-- by the scalar (the entry is the left operand of '*').
instance (Natural r, Natural c, Num t) => Multiply t (Matrix r c t) where
  type Product t (Matrix r c t) = Matrix r c t
  times scalar = fmap (\entry -> entry * scalar)
|---|
| 69 | |
|---|
-- vector and matrix operations
|---|
| 71 | |
|---|
-- | Combine two same-shaped matrices entry by entry with @f@. Mismatched raw
-- row lists are truncated to the shorter one, as with 'zipWith'.
zipMatrixWith :: (Natural rows, Natural cols) => (a -> b -> c) -> (Matrix rows cols a) -> (Matrix rows cols b) -> (Matrix rows cols c)
zipMatrixWith f (Matrix xss) (Matrix yss) = Matrix (zipWith (zipWith f) xss yss)
|---|
| 74 | |
|---|
-- | Element-wise (Hadamard) product of two vectors of the same length.
--
-- NOTE(review): despite its name this does not sum the products, so the
-- result is a vector rather than the scalar inner product. Callers wanting
-- the usual dot product must sum the entries themselves. The return type is
-- left unchanged to avoid breaking existing callers.
dot :: (Natural n, Num t) => Vector n t -> Vector n t -> Vector n t
dot u v = zipMatrixWith (*) u v
|---|
| 77 | |
|---|
-- | Cross product of two three-element column vectors.
--
-- With @u = (x, y, z)@ and @v = (x', y', z')@ the result is
-- @(y*z' - z*y', z*x' - x*z', x*y' - y*x')@.
cross :: (Num t) => Vector Three t -> Vector Three t -> Vector Three t
cross u v = vector three [y * z' - z * y', z * x' - x * z', x * y' - y * x']
  where
    (x, y, z) = (at zero zero u, at one zero u, at two zero u)
    (x', y', z') = (at zero zero v, at one zero v, at two zero v)
|---|
| 86 | |
|---|
-- | Swap rows and columns: an @a x b@ matrix becomes a @b x a@ matrix.
--
-- 'List.transpose' cannot recover the column count of a matrix with no rows
-- (@List.transpose [] == []@). On its own it would turn a @0 x b@ matrix into
-- a @b x 0@ matrix with zero rows instead of @b@ empty rows. In that case the
-- row count is taken from the type-level dimension @b@, so the result has the
-- same shape, and compares equal, to one built with 'matrix'.
transpose :: forall a b t. (Natural a, Natural b) => (Matrix a b t) -> (Matrix b a t)
transpose (Matrix []) = Matrix (replicate (naturalToIntegral (undefined :: b)) [])
transpose (Matrix rs) = Matrix (List.transpose rs)
|---|
| 89 | |
|---|
-- | Entry-wise sum of two same-shaped matrices.
plus :: (Natural r, Natural c, Num t) => (Matrix r c t) -> (Matrix r c t) -> (Matrix r c t)
plus lhs rhs = zipMatrixWith (+) lhs rhs
|---|
| 92 | |
|---|
-- | Entry-wise difference of two same-shaped matrices (left minus right).
minus :: (Natural r, Natural c, Num t) => (Matrix r c t) -> (Matrix r c t) -> (Matrix r c t)
minus lhs rhs = zipMatrixWith (-) lhs rhs
|---|
| 95 | |
|---|
-- | The @n x n@ identity matrix: 1 on the main diagonal, 0 elsewhere.
identity :: (Natural n, Num t) => n -> Matrix n n t
identity n = matrix n n [[identity' (i == j) | j <- ixs] | i <- ixs]
  where
    -- Integer matches the type the unannotated index ranges default to.
    ixs = [0 .. naturalToIntegral n - 1] :: [Integer]
|---|
| 98 | |
|---|
-- | Identity-matrix entry: 1 when the position is on the diagonal, else 0.
identity' :: (Num t) => Bool -> t
identity' onDiagonal
  | onDiagonal = 1
  | otherwise = 0
|---|
| 102 | |
|---|
-- | Concatenate a list of rows into one list, preserving order.
--
-- Uses 'concat' (a right fold) instead of @foldl (++) []@. The left-nested
-- appends re-copy the growing prefix at every step, which is quadratic in the
-- number of rows. The lazy 'foldl' also builds a thunk chain and cannot
-- produce any output until the whole outer list has been traversed.
flatten :: [[a]] -> [a]
flatten = concat
|---|
| 105 | |
|---|
-- | Euclidean length of a vector: square root of the sum of squared entries,
-- taken in row order.
norm :: (Natural n, Floating t) => (Vector n t) -> t
norm (Matrix rs) = sqrt (sum [x * x | row <- rs, x <- row])
|---|