module Graphics.PS.Matrix ( Matrix(Matrix)
                          , identity
                          , translation, scaling, rotation ) where

-- | Scalar type used for every matrix coefficient in this module.
type R = Double

-- | Transformation matrix stored as six coefficients.
--
-- For @Matrix a b c d e f@ this module reads the coefficients as the
-- rows @(a, b, 0)@, @(c, d, 0)@ and @(e, f, 1)@ of a 3x3 matrix; the
-- third column is implied and never stored.
data Matrix = Matrix R R R R R R deriving (Eq, Show)

-- | Short internal alias for 'Matrix', used in the signatures of this module.
type M = Matrix

-- | Index selecting one of the three rows or columns of the 3x3 view
-- of a 'Matrix'.  The derived 'Enum' instance orders the constructors
-- @I0@, @I1@, @I2@.
data MIx = I0 | I1 | I2
    deriving (Eq, Show, Enum)

-- | Extract one row of the 3x3 view of a matrix.
--
-- The third component is the implied last column: @0@ for the first
-- two rows and @1@ for the last row.
row :: M -> MIx -> (R, R, R)
row (Matrix a b c d e f) ix =
    case ix of
      I0 -> (a, b, 0)
      I1 -> (c, d, 0)
      I2 -> (e, f, 1)

-- | Extract one column of the 3x3 view of a matrix.
--
-- The first two columns are read from the stored coefficients; the
-- third column is the constant @(0, 0, 1)@.
col :: M -> MIx -> (R, R, R)
col (Matrix a b c d e f) ix =
    case ix of
      I0 -> (a, c, e)
      I1 -> (b, d, f)
      I2 -> (0, 0, 1)

-- | Matrix product of the 3x3 views of the two operands.
--
-- Each stored entry @(i, j)@ of the result is the dot product of row
-- @i@ of the left operand with column @j@ of the right operand.  Only
-- the first two columns are computed, because the third column is not
-- stored in a 'Matrix'.
multiply :: M -> M -> M
multiply lhs rhs =
    Matrix (cell I0 I0) (cell I0 I1)
           (cell I1 I0) (cell I1 I1)
           (cell I2 I0) (cell I2 I1)
  where
    -- Entry at row i, column j of the product.
    cell i j = dot (row lhs i) (col rhs j)
    -- Three-component dot product, summed left to right.
    dot (x1, x2, x3) (y1, y2, y3) = x1 * y1 + x2 * y2 + x3 * y3

-- | Apply a function independently to each of the six stored
-- coefficients of a matrix.
pointwise :: (R -> R) -> M -> M
pointwise fn (Matrix m0 m1 m2 m3 m4 m5) =
    Matrix (fn m0) (fn m1) (fn m2) (fn m3) (fn m4) (fn m5)

-- | Combine two matrices coefficient by coefficient: each stored field
-- of the result is the given function applied to the corresponding
-- fields of the left and right operands, in that order.
pointwise2 :: (R -> R -> R) -> M -> M -> M
pointwise2 op (Matrix l0 l1 l2 l3 l4 l5) (Matrix r0 r1 r2 r3 r4 r5) =
    Matrix (l0 `op` r0) (l1 `op` r1)
           (l2 `op` r2) (l3 `op` r3)
           (l4 `op` r4) (l5 `op` r5)

-- | Arithmetic on matrices.
--
-- @(*)@ is the matrix product ('multiply'); @(+)@, @(-)@, 'abs' and
-- 'signum' act on the six stored coefficients individually.
-- @fromInteger n@ places @n@ in the first and fourth fields and zero
-- in the others.  'negate' is not defined here, so the class default
-- is used.
instance Num Matrix where
    lhs * rhs = multiply lhs rhs
    lhs + rhs = pointwise2 (+) lhs rhs
    lhs - rhs = pointwise2 (-) lhs rhs
    abs m = pointwise abs m
    signum m = pointwise signum m
    fromInteger n = Matrix k 0 0 k 0 0
      where
        k = fromInteger n

-- | Matrix whose first four coefficients are @1 0 0 1@ and whose last
-- two coefficients are the given offsets, in argument order.
translation :: R -> R -> M
translation dx dy = Matrix 1 0 0 1 dx dy

-- | Matrix with the two given factors in the first and fourth
-- coefficients and zero in every other stored coefficient.
scaling :: R -> R -> M
scaling sx sy =
    Matrix sx 0
           0  sy
           0  0

-- | Rotation matrix for the given angle.
--
-- The angle is passed directly to 'cos' and 'sin', so it is in
-- radians.  The result is @Matrix (cos a) (sin a) (-(sin a)) (cos a) 0 0@.
rotation :: R -> M
rotation a = Matrix cosA sinA (negate sinA) cosA 0 0
  where
    cosA = cos a
    sinA = sin a

-- | The identity matrix: ones in the first and fourth coefficients,
-- zeros in the rest.
identity :: M
identity =
    Matrix 1 0
           0 1
           0 0

{--
translate :: R -> R -> M -> M
translate x y m = m * (translation x y)

scale :: R -> R -> M -> M
scale x y m = m * (scaling x y)

rotate :: R -> M -> M
rotate r m = m * (rotation r)

scalarMultiply :: R -> M -> M
scalarMultiply scalar = pointwise (* scalar)

adjoint :: M -> M
adjoint (Matrix a b c d x y) = Matrix d (-b) (-c) a (c * y - d * x) (b * x - a * y)

invert :: M -> M
invert m = scalarMultiply (recip d) (adjoint m)
  where (Matrix xx yx xy yy _ _) = m
        d = xx*yy - yx*xy
--}