module Monocle.Core (
-- * Morphism
Morphism (..),
FuncT (..),
Mor (..),
-- * Basic functions on morphisms
nrm,
atomary,
arrow,
element,
coelement,
object,
objectId,
tid,
-- * Tensor product functoriality
vert,
horz,
-- * Utilities
collect,
-- * Rules
Rule (..),
(\==),
apply
)where

import Monocle.Utils
import qualified Data.Map as Map

-- | Interface of anything that behaves like a morphism: it has a source and
-- a target, it may be an identity, and it can be composed and tensored.
class (Eq a) => Morphism a where
  -- | Source object of a morphism.
  dom :: a -> a
  -- | Target object of a morphism.
  cod :: a -> a
  -- | 'True' when the morphism is an identity.
  isId :: a -> Bool
  -- | Composition: @g \\. f@ runs @f@ first and then @g@, so it requires
  -- @dom g == cod f@. Implementations should make it associative.
  (\.) :: a -> a -> a
  -- | Tensor (monoidal) product of two morphisms.
  (\*) :: a -> a -> a

-- Tensor binds tighter than composition.
infixl 7 \*
infixl 6 \.

-- | Data attached to an atomary 'Arrow': its domain, its codomain and a flag
-- saying whether the arrow is to be treated as an identity.
data ArrowData a = ArrowData
  { dom' :: Mor a
  , cod' :: Mor a
  , isId' :: Bool
  }
  deriving (Eq, Ord)

-- | Kind of a functional modifier ('Func'). The constructor order matters
-- for the derived 'Ord' instance.
data FuncT
  = Function  -- ^ Function on objects
  | Functor   -- ^ Covariant functor
  | Cofunctor -- ^ Contravariant functor
  deriving (Eq, Ord)

-- | Terms describing morphisms, built from atoms of type @a@.
-- The constructor order matters for the derived 'Ord' instance.
data Mor a
  = Arrow a (ArrowData a)
    -- ^ Atomary morphism together with its domain/codomain data
  | Id a
    -- ^ Identity morphism (objects are represented by their identities)
  | Tensor [Mor a]
    -- ^ Tensor product of morphisms; @Tensor []@ is the tensor unit
  | Composition [Mor a]
    -- ^ Composition of morphisms; the first list element is applied first
  | Func String [Mor a] FuncT
    -- ^ Functional modifier: name, arguments and kind
  | Transform String (Mor a) [Mor a]
    -- ^ Naturally transformational modifier: name, the morphism supplying
    -- its domain/codomain, and further arguments
  deriving (Eq, Ord)

-- | 'Mor' terms are morphisms. Objects ('Id') are their own domain and
-- codomain; compound terms compute theirs component-wise.
instance (Eq a) => Morphism (Mor a) where
  dom (Arrow _ info) = dom' info
  dom x@(Id _) = x
  dom (Tensor []) = Tensor []
  dom (Tensor parts) = nrm (Tensor (map dom parts))
  -- The first step of a composition is applied first.
  dom (Composition steps) = dom (head steps)
  dom x@(Func _ _ Function) = x
  dom (Func name args Functor) = Func name (map dom args) Functor
  -- A contravariant functor swaps domain and codomain of its arguments.
  dom (Func name args Cofunctor) = Func name (map cod args) Cofunctor
  dom (Transform _ base _) = dom base

  cod (Arrow _ info) = cod' info
  cod x@(Id _) = x
  cod (Tensor []) = Tensor []
  cod (Tensor parts) = nrm (Tensor (map cod parts))
  cod (Composition steps) = cod (last steps)
  cod x@(Func _ _ Function) = x
  cod (Func name args Functor) = Func name (map cod args) Functor
  cod (Func name args Cofunctor) = Func name (map dom args) Cofunctor
  cod (Transform _ base _) = cod base

  isId (Arrow _ info) = isId' info
  isId (Id _) = True
  isId (Tensor parts) = all isId parts
  isId (Composition steps) = all isId steps
  isId (Func _ _ Function) = True
  isId (Func _ args _) = all isId args
  isId (Transform _ base _) = isId base

  -- Composition checks the types first, drops identities, and splices the
  -- step lists of nested compositions together (g's steps come first).
  f \. g
    | dom f /= cod g =
        error "compose: domain and codomain of composing arrows do not coincide"
    | isId f = g
    | isId g = f
    | otherwise = Composition (steps g ++ steps f)
    where
      steps (Composition xs) = xs
      steps other = [other]

  -- The tensor unit is neutral; otherwise the factor lists are concatenated.
  f \* g = case (f, g) of
    (Tensor [], _) -> g
    (_, Tensor []) -> f
    _ -> Tensor (factors f ++ factors g)
    where
      factors (Tensor xs) = xs
      factors other = [other]

-- Basic functions on morphisms

-- | Normalizes the term representing a morphism, e.g. turns
-- @((a \\* b) \\* c)@ into @(a \\* b \\* c)@.
--
-- Normalization is applied recursively: to the domain and codomain stored in
-- an 'Arrow', to the arguments of 'Func' and 'Transform', and to the sole
-- element of a singleton tensor or composition, which is unwrapped.
nrm :: (Eq t) => Mor t -> Mor t
nrm f = case f of
  Arrow ff (ArrowData d c ii) -> Arrow ff (ArrowData (nrm d) (nrm c) ii)
  -- The single element may itself be an unnormalized term.
  Tensor [x] -> nrm x
  -- Rebuilding with '\*' flattens nested tensors and drops tensor units.
  Tensor x -> foldr (\*) (Tensor []) (map nrm x)
  Composition [x] -> nrm x
  Composition x -> uncomp (map nrm x)
  Func nm xs t -> Func nm (map nrm xs) t
  Transform nm x xs -> Transform nm (nrm x) (map nrm xs)
  _ -> f
  where
    -- Re-compose the steps, first one applied first:
    -- @[g1, g2, g3]@ becomes @(g3 \\. g2) \\. g1@. '\.' flattens nested
    -- compositions, drops identities and rejects mismatched types.
    uncomp x = case x of
      [y] -> y
      y : ys -> uncomp ys \. y
      _ -> Composition x

-- | 'True' for terms that are not built from other morphisms by tensoring,
-- composing or applying a functional modifier: arrows, identities, natural
-- transformations and the tensor unit.
atomary :: (Eq t) => Mor t -> Bool
atomary (Arrow _ _) = True
atomary (Id _) = True
atomary (Transform _ _ _) = True
atomary (Tensor []) = True
atomary _ = False

-- | Returns the name carried by a named term as a singleton list, or an empty
-- list for the tensor unit. Compositions and non-empty tensors carry no name
-- and are rejected with an error.
getInfo :: Mor String -> [String]
getInfo f = case f of
  Arrow x _ -> [x]
  Id x -> [x]
  Tensor [] -> []
  Func x _ _ -> [x]
  Transform x _ _ -> [x]
  _ -> error "getInfo: wrong argument"

-- | 'ArrowData' describing any term: an arrow's stored data, an identity's
-- trivial data, or data computed from 'dom', 'cod' and 'isId' otherwise.
getData :: (Eq a) => Mor a -> ArrowData a
getData f = case f of
  Arrow _ info -> info
  Id _ -> ArrowData f f True
  _ -> ArrowData (dom f) (cod f) (isId f)

-- | Creates 'Arrow' by morphism information (e.g. name), domain and codomain.
-- The resulting arrow is not flagged as an identity.
arrow :: a -> Mor a -> Mor a -> Mor a
arrow nm adom acod = Arrow nm (ArrowData adom acod False)

-- | Creates a generalized element: an arrow whose domain is the tensor unit
-- and whose codomain is the given object.
element :: a -> Mor a -> Mor a
element nm = arrow nm (Tensor [])

-- | Creates generalized coelement, i.e. an arrow from the given object to the
-- tensorial Id (the tensor unit).
coelement :: a -> Mor a -> Mor a
coelement nm adom = arrow nm adom (Tensor [])

-- | Creates an object, represented by its identity morphism.
-- Same as 'objectId'.
object :: a -> Mor a
object = Id

-- | Creates the identity morphism of an object. Same as 'object'.
objectId :: a -> Mor a
objectId = Id

-- | Tensorial Id (the tensor unit): @tid \\* f == f@ in a strict monoidal
-- category.
tid :: Mor a
tid = Tensor []

-- | Horizontal size of a term, as compared by 'vert'.
--
-- Identities and functional modifiers count as 1, the tensor unit as 0,
-- tensors add the sizes of their factors, compositions take the widest step,
-- and other terms (arrows, transformations) take the wider of their domain
-- and codomain.
width :: (Eq a) => Mor a -> Int
width f = case f of
  Id _ -> 1
  Tensor [] -> 0
  Tensor x -> sum (map width x)
  -- The 0 seed keeps an empty composition from crashing 'maximum';
  -- widths are never negative, so it does not change other results.
  Composition x -> maximum (0 : map width x)
  -- Must be handled here: 'dom' of a 'Func' is again a 'Func' (itself, for
  -- 'Function'), so the generic case below would recurse forever.
  Func _ _ _ -> 1
  _ -> max (width (dom f)) (width (cod f))

-- | Vertical size of a term, as compared by 'horz'.
--
-- Identities and the tensor unit have height 0, tensors take their tallest
-- factor, compositions add up their steps, a natural transformation has the
-- height of its underlying morphism, and every other term has height 1.
height :: (Eq a) => Mor a -> Int
height (Id _) = 0
height (Tensor []) = 0
height (Tensor parts) = maximum (map height parts)
height (Composition steps) = sum (map height steps)
height (Transform _ inner _) = height inner
height _ = 1

-- Tensor product functoriality

-- | Turns recursively @(a \\* b) \\. (c \\* d)@ to @(a \\. c) \\* (b \\. d)@.
--
-- Adjacent steps of a composition are combined pairwise with @vertPair@.
-- When a pair turns into a tensor, the result is fed back together with the
-- remaining steps. Terms that are not compositions are traversed by
-- @vertInside@.
vert :: (Eq a) => Mor a -> Mor a
vert f = case f of
  Composition (y1:y2:[]) -> vertPair y1 y2
  Composition (y1:y2:ys) -> case (vertPair y1 y2) of
    -- The first two steps merged into a tensor: continue with the rest.
    s@(Tensor _) -> vert ((Composition ys) \. s)
    -- Otherwise keep y1 as is and retry starting from y2.
    _ -> (vert ((Composition ys) \. y2)) \. y1
  _ -> vertInside f
  where
    -- vertPair f g handles "first f, then g" (the local f and g shadow the
    -- outer argument). Both must be tensors with at least two factors;
    -- otherwise the pair is simply composed back as g \. f.
    vertPair f g = case f of
      Tensor (x1':x2:xs) -> case g of
        Tensor (y1':y2:ys) -> let x1 = vert x1'; y1 = vert y1' in
          -- If rewriting the leading factors changed them, start over
          -- with the rewritten factors.
          if x1 /= x1' || y1 /= y1' then
            vertPair (x1 \* (Tensor (x2:xs))) (y1 \* (Tensor (y2:ys)))
          else let wx = width (cod x1); wy = width (dom y1) in
            -- Leading factors line up: compose them and recurse on the
            -- remaining factors.
            if wx == wy
              then (y1 \. x1) \*
                nrm (vertPair (Tensor (x2:xs)) (Tensor (y2:ys)))
              -- Otherwise widen the narrower side by tensoring its first
              -- two factors together, then try again.
              else if wx > wy
                then nrm (vertPair f (Tensor ((y1 \* y2) : ys)))
                else nrm (vertPair (Tensor ((x1 \* x2) : xs)) g)
        _ -> g \. f
      _ -> g \. f
    -- Applies vert to the components of tensors, to the underlying morphism
    -- of a natural transformation and to the arguments of a functional
    -- modifier; anything else is returned unchanged.
    vertInside f = case f of
      Tensor y -> Tensor (map vert y)
      Transform s x xs -> Transform s (vert x) xs
      Func s xs t -> Func s (map vert xs) t
      _ -> f

-- | Turns recursively @(a \\. c) \\* (b \\. d)@ to @(a \\* b) \\. (c \\* d)@.
--
-- Adjacent factors of a tensor are combined pairwise with @horzPair@. When a
-- pair turns into a composition, the result is fed back together with the
-- remaining factors. Terms that are not tensors are traversed by
-- @horzInside@.
horz :: (Eq a) => Mor a -> Mor a
horz f = case f of
  Tensor (y1:y2:[]) -> horzPair y1 y2
  Tensor (y1:y2:ys) -> case (horzPair y1 y2) of
    -- The first two factors merged into a composition: continue with it.
    s@(Composition _) -> horz (s \* (Tensor ys))
    -- Otherwise keep y1 as is and retry starting from y2.
    _ -> y1 \* (horz (y2 \* (Tensor ys)))
  _ -> horzInside f
  where
    -- horzPair f g handles the tensor factors f (left) and g (right); the
    -- local f and g shadow the outer argument. Both must be compositions
    -- with at least two steps; otherwise they are simply tensored back.
    horzPair f g = case f of
      Composition (x1':x2:xs) -> case g of
        Composition (y1':y2:ys) -> let x1 = horz x1'; y1 = horz y1' in
          -- If rewriting the first steps changed them, start over with the
          -- rewritten steps.
          if x1 /= x1' || y1 /= y1' then
            horzPair ((Composition (x2:xs)) \. x1)
              ((Composition (y2:ys)) \. y1)
          else let hx = height x1; hy = height y1 in
            -- First steps have equal height: tensor them, then compose
            -- with the pairing of the remaining steps.
            if hx == hy
              then nrm (horzPair (Composition (x2:xs)) (Composition (y2:ys))) \.
                (x1 \* y1)
              -- Otherwise lengthen the shorter side by composing its first
              -- two steps together, then try again.
              else if hx > hy
                then nrm (horzPair f (Composition ((y2 \. y1) : ys)))
                else nrm (horzPair (Composition ((x2 \. x1) : xs)) g)
        _ -> f \* g
      _ -> f \* g
    -- Applies horz to the steps of compositions, to the underlying morphism
    -- of a natural transformation and to the arguments of a functional
    -- modifier; anything else is returned unchanged.
    horzInside f = case f of
      Composition y -> Composition (map horz y)
      Transform s x xs -> Transform s (horz x) xs
      Func s xs t -> Func s (map horz xs) t
      _ -> f

-- | Monadic bottom-up rewrite of a morphism term.
--
-- At every non-empty tensor, composition, functional modifier and natural
-- transformation, @prep@ runs first. Then the children are rewritten, and the
-- node is rebuilt, normalized with 'nrm' and handed to @func@. Any other term
-- (arrows, identities, the tensor unit) is handed to @func@ directly.
mapMorM :: (Eq a, Monad m) => m () -> (Mor a -> m (Mor a)) -> Mor a -> m (Mor a)
mapMorM prep func = walk
  where
    walk f = case f of
      Tensor xs@(_ : _) -> rebuild xs Tensor
      Composition xs -> rebuild xs Composition
      Func ff xs t -> rebuild xs (\ys -> Func ff ys t)
      Transform ff x xs -> do
        prep
        x' <- walk x
        xs' <- mapM walk xs
        func (nrm (Transform ff x' xs'))
      _ -> func f
    -- Shared path for nodes whose only children are a list of terms.
    rebuild xs mk = do
      prep
      ys <- mapM walk xs
      func (nrm (mk ys))

-- | Monadic rewrite of the leaves of a morphism term, possibly changing the
-- atom type.
--
-- Only leaves (arrows, identities, the tensor unit) are handed to @func@.
-- Non-empty tensors, compositions, functional modifiers and natural
-- transformations are rebuilt from their rewritten children and normalized
-- with 'nrm'.
mapMorM' :: (Eq a, Eq b, Monad m) => (Mor a -> m (Mor b)) -> Mor a -> m (Mor b)
mapMorM' func = walk
  where
    walk f = case f of
      Tensor xs@(_ : _) -> rebuild Tensor xs
      Composition xs -> rebuild Composition xs
      Func ff xs t -> rebuild (\ys -> Func ff ys t) xs
      Transform ff x xs -> do
        x' <- walk x
        xs' <- mapM walk xs
        return (nrm (Transform ff x' xs'))
      _ -> func f
    -- Rewrite a list of children and rebuild the node from the results.
    rebuild mk xs = do
      ys <- mapM walk xs
      return (nrm (mk ys))

-- State-monad runners for the traversals above. The trans* variants return
-- the rewritten morphism (evalState); the calc* variants return the final
-- state (execState). The *P variants take an extra preparation action for
-- 'mapMorM'; the primed variants walk leaves only via 'mapMorM''.
transMorP mor inits prep wlk = evalState (mapMorM prep wlk mor) inits
calcMorP mor inits prep wlk = execState (mapMorM prep wlk mor) inits

transMor mor inits = transMorP mor inits (return ())
calcMor mor inits = calcMorP mor inits (return ())

transMor' mor inits wlk = evalState (mapMorM' wlk mor) inits
calcMor' mor inits wlk = execState (mapMorM' wlk mor) inits

-- Show

-- | Textual rendering of morphism terms.
--
-- Atoms print through their own 'str', with a leading @*@ stripped. The tensor
-- unit prints as @I@. Tensors and compositions print as parenthesized lists
-- joined by @ * @ and @ . @. Functional modifiers print as @name(args)@, and
-- natural transformations as @name[args](morphism)@.
instance (Printable a, Eq a) => Printable (Mor a) where
  str f = case f of
    Arrow fx _ -> str (nm fx)
    Id fx -> str (nm fx)
    Tensor [] -> "I"
    Tensor xs -> close (op " * " (map show xs))
    Composition xs -> close (op " . " (map show xs))
    Func name xs _ -> name ++ close (op ", " (map show xs))
    Transform name x xs -> name ++ close1 (op ", " (map show xs)) ++ close (show x)
    where
      -- Drop a leading '*' marker; pattern matching (rather than 'head')
      -- keeps empty names from crashing.
      nm x = case str x of
        '*' : rest -> rest
        s -> s
      -- Join with a separator without left-nesting (++).
      op _ [] = ""
      op s (x : xs) = x ++ concatMap (s ++) xs
      close x = "(" ++ x ++ ")"
      close1 x = "[" ++ x ++ "]"

-- | 'show' is the 'Printable' rendering.
instance (Printable a, Eq a) => Show (Mor a) where
  show = str

-- Utilities

-- | Structurally aligns a morphism with a pattern.
--
-- An 'Id' or 'Arrow' leaf of the pattern pairs with the whole corresponding
-- subterm of the morphism. For an 'Arrow' leaf, the domain and codomain must
-- also align. Tensors, compositions, functional modifiers and natural
-- transformations must agree in shape: same constructor, same name and kind,
-- and the same number of children. The result is 'Nothing' when the shapes
-- differ.
merge :: (Eq a, Eq b) => Mor a -> Mor b -> Maybe (Mor (Mor a, b))
merge m1 m2 = case (m1, m2) of
  (_, Id f2) -> Just (Id (m1, f2))
  (_, Arrow f2 (ArrowData d2 c2 ii2)) -> do
    d' <- merge (dom m1) d2
    c' <- merge (cod m1) c2
    return (Arrow (m1, f2) (ArrowData d' c' ii2))
  (Func f1 xs1 t1, Func f2 xs2 t2)
    | f1 /= f2 || t1 /= t2 -> Nothing
    | otherwise -> do
        xs' <- mergeAll xs1 xs2
        return (Func f1 xs' t1)
  (Transform f1 x1 xs1, Transform f2 x2 xs2)
    | f1 /= f2 -> Nothing
    | otherwise -> do
        x' <- merge x1 x2
        xs' <- mergeAll xs1 xs2
        return (Transform f1 x' xs')
  (Tensor xs1, Tensor xs2) -> fmap Tensor (mergeAll xs1 xs2)
  (Composition xs1, Composition xs2) -> fmap Composition (mergeAll xs1 xs2)
  _ -> Nothing
  where
    -- Pairwise merge; lists of different length never match (a bare 'zip'
    -- would silently drop the surplus arguments).
    mergeAll xs ys
      | length xs /= length ys = Nothing
      | otherwise = sequence (zipWith merge xs ys)

-- | Matches a morphism against a pattern whose leaves are named by strings.
--
-- Returns @(True, bindings)@ when the shapes align ('merge') and every
-- pattern name is bound consistently. Each name maps to the subterm it
-- matched. Names starting with @*@ are neither bound nor checked. Returns
-- @(False, ...)@ otherwise.
match :: (Eq a) => Mor a -> Mor String -> (Bool, Map.Map String (Mor a))
match f g = case merge f g of
  Nothing -> (False, Map.empty)
  Just h -> calcMor h (True, Map.empty) $ \x -> case x of
    Arrow (x1, x2) _ -> cmp x1 x2 x
    Id (x1, x2) -> cmp x1 x2 x
    _ -> return x
  where
    cmp x1 x2 x' = do
      (tv, mp) <- get
      put (bind tv mp x1 x2)
      return x'
    -- Once a mismatch is seen the state stays failed; a name seen before
    -- must match the same subterm again.
    bind tv mp x1 x2
      | not tv || isStarred x2 = (tv, mp)
      | otherwise = case Map.lookup x2 mp of
          Just bound -> (bound == x1, mp)
          Nothing -> (tv, Map.insert x2 x1 mp)
    -- Pattern match instead of 'head' so an empty name does not crash.
    isStarred ('*' : _) = True
    isStarred _ = False

-- | Replaces every named leaf ('Arrow' or 'Id') of a term by the morphism
-- bound to its name. An unbound name is an error (raised by 'Map.!').
subst :: (Ord a, Eq b) => Map.Map a (Mor b) -> Mor a -> Mor b
subst mp f = transMor' f () $ \x ->
  case x of
    Arrow name _ -> return (mp Map.! name)
    Id name -> return (mp Map.! name)
    -- 'mapMorM'' only hands leaves to this visitor. The one leaf without a
    -- name is the tensor unit, which is kept (previously this crashed).
    _ -> return (Tensor [])

-- | Like 'subst', but checks the kind of each replacement: an 'Arrow' leaf
-- must be replaced by an 'Arrow' and an 'Id' leaf by an 'Id'. Any other
-- replacement is reported as an error.
subst' :: (Ord a, Eq b, Printable b) => Map.Map a (Mor b) -> Mor a -> Mor b
subst' mp f = transMor' f () $ \x ->
  case x of
    Arrow name _ -> case mp Map.! name of
      arr'@(Arrow _ _) -> return arr'
      arr' -> error $ "subst': no match in " ++ show arr'
    Id name -> case mp Map.! name of
      arr'@(Id _) -> return arr'
      arr' -> error $ "subst': no match in " ++ show arr'
    -- The only other leaf 'mapMorM'' passes here is the tensor unit.
    _ -> return (Tensor [])

-- | Collects the 'Arrow' and 'Id' leaves of the given term as keys of a map.
-- Leaves are numbered 1, 2, ... in traversal order; a leaf that occurs
-- several times keeps the number of its last occurrence.
collect :: (Num b, Ord a) => Mor a -> Map.Map (Mor a) b
collect f = fst $ calcMor' f (Map.empty, 1) $ \x ->
  if isLeaf x
    then do
      (seen, next) <- get
      put (Map.insert x next seen, next + 1)
      return x
    else return x
  where
    isLeaf (Arrow _ _) = True
    isLeaf (Id _) = True
    isLeaf _ = False

-- Rules

-- | Rewriting rules between morphism terms.
data Rule a
  = DefEqual (Mor a) (Mor a) -- ^ Declares the two morphisms equal

-- | Infix spelling of 'DefEqual': @x \\== y@ means @'DefEqual' x y@.
(\==) :: Mor a -> Mor a -> Rule a
(\==) = DefEqual

infix 4 \==

-- Left-hand side of a rule.
r'left :: Rule a -> Mor a
r'left rule = case rule of
  DefEqual lhs _ -> lhs

-- Right-hand side of a rule.
r'right :: Rule a -> Mor a
r'right rule = case rule of
  DefEqual _ rhs -> rhs

-- | Applies the 'Rule' to the given morphism, in either direction.
--
-- If the morphism matches the left-hand side, the result is the right-hand
-- side with the matched bindings substituted. Otherwise, if it matches the
-- right-hand side, the result is the left-hand side with the bindings
-- substituted. If neither side matches, an error is raised.
apply :: (Eq a) => Rule String -> Mor a -> Mor a
apply (DefEqual l r) f =
  case match f l of
    (True, mp) -> subst mp r
    _ -> case match f r of
      -- Rewrite towards the opposite side; substituting into r again would
      -- just rebuild the input.
      (True, mp) -> subst mp l
      _ -> error "apply: no match"

-- | Renders a rule as @lhs == rhs@.
instance (Printable a, Eq a) => Show (Rule a) where
  show (DefEqual lhs rhs) = concat [show lhs, " == ", show rhs]