{-# LANGUAGE TypeFamilies, RankNTypes, FlexibleInstances, ScopedTypeVariables, UndecidableInstances, MultiParamTypeClasses #-}
module Data.Sized.Matrix
( module Data.Sized.Matrix
, module Data.Sized.Ix
) where
import Data.Array as A hiding (indices,(!), ixmap, assocs)
import qualified Data.Array as A
import Prelude as P hiding (all)
import Control.Applicative
import qualified Data.Traversable as T
import qualified Data.Foldable as F
import qualified Data.List as L
import Numeric
import Data.Sized.Ix
-- | A matrix (or vector) whose shape is carried by its index type @ix@.
--
-- 'Matrix' wraps an immutable 'Array'; 'fromList' builds it over the bounds
-- @(minBound, maxBound)@ of the index type.  'NullMatrix' is what 'fromList'
-- produces when the index type has size zero.
data Matrix ix a
  = Matrix (Array ix a)
  | NullMatrix
  deriving (Eq, Ord)
-- | Look up the element stored at an index.
--
-- A 'NullMatrix' holds no elements, so indexing one is an error.
(!) :: (Size n) => Matrix n a -> n -> a
(!) (Matrix xs) n = xs A.! n
(!) NullMatrix  _ = error "Attempting to index into a Null Matrix, should *never* happen"
-- | Maps a function over every stored element, keeping the shape.
instance (Size i) => Functor (Matrix i) where
  fmap g m = case m of
    Matrix arr -> Matrix (fmap g arr)
    -- Nothing is stored, so there is nothing to map.
    NullMatrix -> NullMatrix
-- | The elements of a matrix in ascending index order.
-- A 'NullMatrix' yields the empty list.
toList :: (Size i) => Matrix i a -> [a]
toList m = case m of
  NullMatrix -> []
  Matrix arr -> elems arr
-- | Build a matrix from its elements, listed in ascending index order.
--
-- If the index type has size zero, the result is 'NullMatrix' and the list
-- is not inspected.  Otherwise the list must have exactly as many elements
-- as the index type, or an error is raised.
fromList :: forall i a . (Size i) => [a] -> Matrix i a
fromList xs
  | expected == 0     = NullMatrix
  | found == expected = Matrix (listArray (minBound, maxBound) xs)
  | otherwise         = error $ "bad length of fromList for Matrix, "
                             ++ "expecting " ++ show expected ++ " elements"
                             ++ ", found " ++ show found ++ " elements."
  where
    -- 'size' only needs the type, so an undefined witness suffices.
    expected = size (undefined :: i)
    found    = L.length xs
-- | Alias for 'fromList'.
matrix :: (Size i) => [a] -> Matrix i a
matrix xs = fromList xs
-- | Every index of the matrix's index type, as given by 'all'.
-- The matrix argument only fixes the type and is never inspected.
indices :: (Size i) => Matrix i a -> [i]
indices = const all
-- | Number of elements, computed by 'size' on the index type's 'minBound'.
-- The stored elements are not examined.
length :: (Size i) => Matrix i a -> Int
length m = size (zeroOf m)
-- | Index/element pairs in ascending index order.
-- A 'NullMatrix' yields no pairs.
assocs :: (Size i) => Matrix i a -> [(i,a)]
assocs m = case m of
  Matrix arr -> A.assocs arr
  NullMatrix -> []
-- | Overwrite the elements at the given indices.
-- A 'NullMatrix' is returned unchanged and the updates are ignored.
(//) :: (Size i) => Matrix i e -> [(i, e)] -> Matrix i e
m // updates = case m of
  Matrix arr -> Matrix (arr A.// updates)
  NullMatrix -> NullMatrix
-- | Fold new values into the existing elements at the given indices,
-- combining each existing element with its update via @f@ (see 'A.accum').
--
-- A 'NullMatrix' has no elements to update, so it is returned unchanged,
-- matching the behaviour of '(//)'.  Previously this case had no equation
-- and failed with a pattern-match error.
accum :: (Size i) => (e -> a -> e) -> Matrix i e -> [(i, a)] -> Matrix i e
accum f (Matrix arr) ixs = Matrix (A.accum f arr ixs)
accum _ NullMatrix   _   = NullMatrix
-- | The 'minBound' of the matrix's index type.
-- The matrix itself is never inspected.
zeroOf :: (Size i) => Matrix i a -> i
zeroOf = const minBound
-- | A matrix built from the list 'all' of indices.  Each element is meant to
-- equal its own index; this assumes 'all' enumerates the indices in
-- ascending order.
coord :: (Size i) => Matrix i i
coord = matrix all
-- | Combine two same-shaped matrices position by position.
zipWith :: (Size i) => (a -> b -> c) -> Matrix i a -> Matrix i b -> Matrix i c
zipWith f xs ys = forAll (\ ix -> f (xs ! ix) (ys ! ix))
-- | Map over a matrix, passing each element's index along with it.
-- The index comes from 'coord'.
forEach :: (Size i) => Matrix i a -> (i -> a -> b) -> Matrix i b
forEach m f = f <$> coord <*> m
-- | Tabulate a function: the result holds @f ix@ at each index @ix@
-- of 'coord'.
forAll :: (Size i) => (i -> a) -> Matrix i a
forAll f = f <$> coord
-- | Zip-style applicative: operations line up position by position.
instance (Size i) => Applicative (Matrix i) where
  -- The same value at every index.
  pure x = x <$ coord
  -- Apply the function at each index to the argument at that index.
  fs <*> xs = forAll (\ ix -> (fs ! ix) (xs ! ix))
-- | Matrix product of an @(m,n)@ matrix with an @(n,n')@ matrix.
-- Entry @(i,j)@ is the sum over @k@ of @a ! (i,k) * b ! (k,j)@.
mm :: (Size m, Size n, Size m', Size n', n ~ m', Num a) => Matrix (m,n) a -> Matrix (m',n') a -> Matrix (m,n') a
mm a b = forAll cell
  where
    cell (i, j) = sum [ a ! (i, k) * b ! (k, j) | k <- all ]
-- | Swap the two index components: the result at @(c,r)@ is the source
-- element at @(r,c)@.
transpose :: (Size x, Size y) => Matrix (x,y) a -> Matrix (y,x) a
transpose m = ixmap swapIx m
  where
    swapIx (r, c) = (c, r)
-- | Square identity matrix: 1 where both index components are equal,
-- 0 elsewhere.
identity :: (Size x, Num a) => Matrix (x,x) a
identity = forAll (\ (r, c) -> if r == c then 1 else 0)
-- | Stack two matrices that share a column type.  The result holds the
-- elements of the first matrix (in index order) followed by those of the
-- second.  With row-major tuple indices, this puts the first matrix's rows
-- on top.
above :: ( Size m, Size top, Size bottom, Size both
         , ADD top bottom ~ both
         , SUB both top ~ bottom
         , SUB both bottom ~ top
         )
      => Matrix (top,m) a -> Matrix (bottom,m) a -> Matrix (both,m) a
above upper lower = fromList (toList upper ++ toList lower)
-- | Place two matrices that share a row type side by side.  This is
-- implemented as 'above' on their transposes, transposed back.
beside :: ( Size m, Size left, Size right, Size both
          , ADD left right ~ both
          , SUB both left ~ right
          , SUB both right ~ left
          )
       => Matrix (m, left) a -> Matrix (m, right) a -> Matrix (m, both) a
beside lhs rhs = transpose (above (transpose lhs) (transpose rhs))
-- | Concatenate two matrices: the elements of the first (in index order)
-- followed by those of the second, re-indexed by the combined type @both@.
append :: ( Size left, Size right, Size both
          , ADD left right ~ both
          , SUB both left ~ right
          , SUB both right ~ left
          )
       => Matrix left a -> Matrix right a -> Matrix both a
append xs ys = fromList (toList xs ++ toList ys)
-- | Re-index a matrix: the result at index @i@ is the source element at
-- @f i@.
ixmap :: (Size i, Size j) => (i -> j) -> Matrix j a -> Matrix i a
ixmap f m = forAll ((m !) . f)
-- | Like 'ixmap', but the index function returns source indices inside a
-- functor @f@.  Each index in that structure is replaced by the source
-- element it names.
ixfmap :: (Size i, Size j, Functor f) => (i -> f j) -> Matrix j a -> Matrix i (f a)
ixfmap f m = forAll (\ ix -> (m !) <$> f ix)
-- | Extract a sub-matrix.  The result at index @i@ is the source element at
-- @addIndex corner (toIndex i)@, i.e. @corner@ offset by @i@.
cropAt :: (Index i ~ Index ix, Size i, Size ix) => Matrix ix a -> ix -> Matrix i a
cropAt m corner = ixmap (addIndex corner . toIndex) m
-- | Split a two-dimensional matrix into a vector of its rows.
-- Row @r@ holds @a ! (r,c)@ for every @c@ in 'all'.
rows :: (Bounded n, Size n, Bounded m, Size m) => Matrix (m,n) a -> Matrix m (Matrix n a)
rows a = forAll (\ r -> matrix [ a ! (r, c) | c <- all ])
-- | Split a two-dimensional matrix into a vector of its columns
-- (the rows of its transpose).
columns :: (Bounded n, Size n, Bounded m, Size m) => Matrix (m,n) a -> Matrix n (Matrix m a)
columns m = rows (transpose m)
-- | Assemble a vector of rows into a two-dimensional matrix
-- (the inverse of 'rows').
joinRows :: (Bounded n, Size n, Bounded m, Size m) => Matrix m (Matrix n a) -> Matrix (m,n) a
joinRows a = forAll (\ (r, c) -> a ! r ! c)
-- | Assemble a vector of columns into a two-dimensional matrix
-- (the inverse of 'columns').
joinColumns :: (Bounded n, Size n, Bounded m, Size m) => Matrix n (Matrix m a) -> Matrix (m,n) a
joinColumns a = forAll (\ (r, c) -> a ! c ! r)
-- | View a vector as a matrix with a single row.
unitRow :: (Size m, Bounded m) => Matrix m a -> Matrix (X1, m) a
unitRow m = ixmap (\ (_, c) -> c) m
-- | Extract row @0@ of a single-row matrix as a vector.
unRow :: (Size m, Bounded m) => Matrix (X1, m) a -> Matrix m a
unRow m = ixmap (\ c -> (0, c)) m
-- | View a vector as a matrix with a single column.
unitColumn :: (Size m, Bounded m) => Matrix m a -> Matrix (m, X1) a
unitColumn m = ixmap (\ (r, _) -> r) m
-- | Extract column @0@ of a single-column matrix as a vector.
unColumn :: (Size m, Bounded m) => Matrix (m, X1) a -> Matrix m a
unColumn m = ixmap (\ r -> (r, 0)) m
-- | Reinterpret the elements (in index order) under another index type.
-- 'fromList' raises an error when the sizes differ, unless the target
-- size is zero.
squash :: (Size n, Size m) => Matrix m a -> Matrix n a
squash m = fromList (toList m)
-- | Traverses the elements in index order, then rebuilds the matrix with
-- 'matrix'.
instance (Size ix) => T.Traversable (Matrix ix) where
  traverse f m = fmap matrix (T.traverse f (toList m))
-- | Folds over the elements in index order, via 'toList'.
instance (Size ix) => F.Foldable (Matrix ix) where
  foldMap f = F.foldMap f . toList
-- | Render a two-dimensional matrix of already-rendered cells as text.
--
-- Each row becomes one newline-terminated line.  Each cell is preceded by
-- a space and right-aligned to the widest cell in its column.  Every cell
-- except the final one of the whole matrix is followed by a comma.  The
-- first line starts with @[@, later lines with a space, and the last line
-- ends with @ ]@.
--
-- A matrix with no rows renders as @[ ]@.  Previously this case crashed in
-- 'init'/'last'.
showMatrix :: (Size n, Size m) => Matrix (m, n) String -> String
showMatrix m = joinLines (map showRow m_rows)
  where
    -- Tag each cell with whether it is the very last one (no trailing comma).
    m' = forEach m $ \ (x,y) a -> (x == maxBound && y == maxBound, a)
    -- Prefix the first line with "[" and the rest with " ", then close.
    joinLines = unlines . closeLast . L.zipWith (++) ("[" : repeat " ")
    -- Append " ]" to the final line.  This is total: no lines yields "[ ]".
    closeLast []       = ["[ ]"]
    closeLast [l]      = [l ++ " ]"]
    closeLast (l : ls) = l : closeLast ls
    showRow r = concat (toList $ Data.Sized.Matrix.zipWith showEle r m_cols_size)
    -- Pad on the left to the column width, then add the separator.
    showEle (isLast, str) width =
      replicate (width - L.length str) ' ' ++ " " ++ str ++ (if isLast then "" else ",")
    m_cols = columns m
    m_rows = toList $ rows m'
    -- Column widths are only forced when some row is rendered, so
    -- 'maximum' never sees an empty column.
    m_cols_size = fmap (maximum . map L.length . toList) m_cols
-- | Shows every matrix as a single row: 'show' is applied to each element,
-- and the result is laid out by 'showMatrix' after 'unitRow'.
instance (Show a, Size ix) => Show (Matrix ix a) where
  show m = showMatrix (fmap show (unitRow m))
-- | A string wrapper whose 'Show' instance emits the text verbatim,
-- without the quotes and escapes 'show' adds to a plain 'String'.
newtype S = S String

instance Show S where
  show s = case s of
    S str -> str
-- | Render a number in exponent notation with @i@ digits after the decimal
-- point ('showEFloat').  The result is wrapped in 'S' so it shows unquoted.
showAsE :: (RealFloat a) => Int -> a -> S
showAsE i a = S (showEFloat (Just i) a "")
-- | Render a number in fixed-point notation with @i@ digits after the
-- decimal point ('showFFloat').  The result is wrapped in 'S' so it shows
-- unquoted.
showAsF :: (RealFloat a) => Int -> a -> S
showAsF i a = S (showFFloat (Just i) a "")
-- | Bidirectional scan.  A @left@ carry flows from 'minBound' towards
-- 'maxBound', and a @right@ carry flows the other way.
--
-- At each index, @f@ receives the incoming left carry, the element and the
-- incoming right carry.  It returns the right carry to pass down, the output
-- element, and the left carry to pass up.  At the ends, the initial @l@ and
-- @r@ are used.  The result is the right carry leaving 'minBound', the
-- outputs, and the left carry leaving 'maxBound'.
scanM :: (Size ix, Bounded ix, Enum ix)
      => ((left,a,right) -> (right,b,left))
      -> (left, Matrix ix a,right)
      -> (right,Matrix ix b,left)
scanM f (l, m, r) = (fst3 (cells ! minBound), fmap snd3 cells, thd3 (cells ! maxBound))
  where
    -- Lazily tied knot: each cell reads its neighbours' results from 'cells'.
    cells = forEach m (\ ix x -> f (leftIn ix, x, rightIn ix))
    leftIn ix
      | ix == minBound = l
      | otherwise      = thd3 (cells ! pred ix)
    rightIn ix
      | ix == maxBound = r
      | otherwise      = fst3 (cells ! succ ix)
    fst3 (x, _, _) = x
    snd3 (_, y, _) = y
    thd3 (_, _, z) = z
-- | Right-to-left scan.  A carry of type @right@ enters at 'maxBound' and
-- flows towards 'minBound'.  This is the mirror image of 'scanR'.
--
-- At each index, @f@ receives the element and the carry from its right
-- neighbour (the initial carry at 'maxBound').  It returns the carry to pass
-- left and the output element.  The result pairs the carry leaving
-- 'minBound' with the matrix of outputs.  As with 'scanR', forcing the
-- returned carry on a zero-size index type hits the 'NullMatrix' indexing
-- error.
--
-- Previously this was a stub that always called 'error'.
scanL :: (Size ix, Bounded ix, Enum ix)
      => ((a,right) -> (right,b))
      -> (Matrix ix a,right)
      -> (right,Matrix ix b)
scanL f (m, r) = (fst (tmp ! minBound), snd `fmap` tmp)
  where
    -- Lazily tied knot: each cell reads its right neighbour's carry.
    tmp = forEach m $ \ i a -> f (a, next i)
    next i = if i == maxBound then r else fst (tmp ! succ i)
-- | Left-to-right scan.  A carry of type @left@ enters at 'minBound' and
-- flows towards 'maxBound'.
--
-- At each index, @f@ receives the carry from its left neighbour (the
-- initial @l@ at 'minBound') and the element.  It returns the output
-- element and the carry to pass right.  The result pairs the outputs with
-- the carry leaving 'maxBound'.
scanR :: (Size ix, Bounded ix, Enum ix)
      => ((left,a) -> (b,left))
      -> (left, Matrix ix a)
      -> (Matrix ix b,left)
scanR f (l, m) = (fmap fst cells, snd (cells ! maxBound))
  where
    -- Lazily tied knot: each cell reads its left neighbour's carry.
    cells = forEach m (\ ix x -> f (carryIn ix, x))
    carryIn ix
      | ix == minBound = l
      | otherwise      = snd (cells ! pred ix)