module Monocle.Core (
    -- * Morphism
    Morphism (..),
    FuncT (..),
    Mor (..),
    -- * Basic functions on morphisms
    nrm,
    atomary,
    arrow,
    element,
    coelement,
    object,
    objectId,
    tid,
    -- * Tensor product functoriality
    vert,
    horz,
    -- * Utilities
    collect,
    -- * Rules
    Rule (..),
    (\==),
    apply
    )where

import Monocle.Utils
import Control.Monad.State
import qualified Data.Map as Map

-- | Morphisms of a (strict) monoidal category. Objects are represented by
-- their identity morphisms, so 'dom' and 'cod' return morphisms as well.
class (Eq a) => Morphism a where
    -- | Source object of a morphism, given as its identity morphism.
    dom :: a -> a
    -- | Target object of a morphism, given as its identity morphism.
    cod :: a -> a
    -- | True when the morphism is an identity.
    isId :: a -> Bool
    -- | Sequential composition; expected to be associative.
    -- @f \\. g@ runs @g@ first and then @f@ (the 'Mor' instance demands
    -- @dom f == cod g@).
    (\.) :: a -> a -> a
    -- | Parallel (tensor) product of two morphisms.
    (\*) :: a -> a -> a

-- Tensor binds tighter than composition: @f \\. g \\* h@ is @f \\. (g \\* h)@.
infixl 7 \*
infixl 6 \.

-- | Domain, codomain and identity flag carried by an atomary 'Arrow'.
-- NOTE(review): the field order determines the derived 'Ord' instance.
data ArrowData a = ArrowData { dom' :: Mor a, cod' :: Mor a, isId' :: Bool } deriving (Eq, Ord)

-- | Kinds of functional modifier ('Func').
-- NOTE(review): constructor order determines the derived 'Ord' instance.
data FuncT
    -- | Function on objects: such a 'Func' is its own 'dom' and 'cod'
    --   and always counts as an identity
    = Function
    -- | Covariant functor: 'dom' / 'cod' are mapped over the arguments
    | Functor
    -- | Contravariant functor: 'dom' / 'cod' of the arguments are swapped
    | Cofunctor
    deriving (Eq, Ord)

-- | Morphism data type: a term built from labelled atoms with tensor
-- product, composition and modifiers. Objects are represented by identity
-- terms ('Id', or tensors of them).
-- NOTE(review): constructor order determines the derived 'Ord' instance.
data Mor a
    -- | Atomary morphism: a label plus its domain/codomain data
    = Arrow a (ArrowData a)
    -- | Identity morphism of the object with the given label
    | Id a
    -- | Tensor product of morphisms; @Tensor []@ is the unit object
    | Tensor [Mor a]
    -- | Composition of morphisms, listed in application order
    --   (the head is applied first and supplies the domain)
    | Composition [Mor a]
    -- | Functional modifier: name, arguments, and kind of modifier
    | Func String [Mor a] FuncT
    -- | Natural-transformation modifier: name, the wrapped morphism
    --   (its 'dom' / 'cod' / 'isId' are used for the whole term) and a list
    --   of parameters (printed in square brackets).
    --   NOTE(review): exact role of the parameters is not visible here
    | Transform String (Mor a) [Mor a]
    deriving (Eq, Ord)

instance (Eq a) => Morphism (Mor a) where
    -- Domain: atoms carry it; identities and object functions are their
    -- own domain; tensors and functors map it over their parts (a cofunctor
    -- uses 'cod' instead); a composite takes it from its first-applied
    -- factor; a transformation takes it from the morphism it wraps.
    dom (Arrow _ info) = dom' info
    dom m@(Id _) = m
    dom (Tensor []) = Tensor []
    dom (Tensor parts) = nrm (Tensor (map dom parts))
    dom (Composition parts) = dom (head parts)
    dom m@(Func _ _ Function) = m
    dom (Func name args Functor) = Func name (map dom args) Functor
    dom (Func name args Cofunctor) = Func name (map cod args) Cofunctor
    dom (Transform _ base _) = dom base

    -- Codomain: the mirror image of 'dom' (a composite uses its last factor).
    cod (Arrow _ info) = cod' info
    cod m@(Id _) = m
    cod (Tensor []) = Tensor []
    cod (Tensor parts) = nrm (Tensor (map cod parts))
    cod (Composition parts) = cod (last parts)
    cod m@(Func _ _ Function) = m
    cod (Func name args Functor) = Func name (map cod args) Functor
    cod (Func name args Cofunctor) = Func name (map dom args) Cofunctor
    cod (Transform _ base _) = cod base

    -- A term is an identity when all of its constituents are; a function
    -- on objects always counts as one.
    isId (Arrow _ info) = isId' info
    isId (Id _) = True
    isId (Tensor parts) = all isId parts
    isId (Composition parts) = all isId parts
    isId (Func _ _ Function) = True
    isId (Func _ args _) = all isId args
    isId (Transform _ base _) = isId base

    -- f \. g applies g first. Identities are absorbed and nested composites
    -- are flattened into one list in application order.
    f \. g
        | dom f /= cod g = error "compose: domain and codomain of composing arrows do not coincide"
        | isId f = g
        | isId g = f
        | otherwise = Composition (steps g ++ steps f)
        where
            steps (Composition xs) = xs
            steps m = [m]

    -- The unit object is absorbed and nested products are flattened.
    f \* g = case (f, g) of
        (Tensor [], _) -> g
        (_, Tensor []) -> f
        _ -> Tensor (factors f ++ factors g)
      where
        factors (Tensor xs) = xs
        factors m = [m]

-- Basic functions on morphisms

-- | Normalizes the term representing a morphism, e.g. turns
-- @((a \\* b) \\* c)@ to @(a \\* b \\* c)@.
--
-- Products and composites are rebuilt from their normalized parts with
-- '\\*' and '\\.', which flatten nesting and drop identities / the unit.
-- A one-element product or composite is replaced by its (normalized)
-- single part. An 'Arrow' has its domain and codomain normalized.
nrm :: (Eq t) => Mor t -> Mor t
nrm f = case f of
    Arrow ff (ArrowData d c ii) -> Arrow ff (ArrowData (nrm d) (nrm c) ii)
    -- The single part must itself be normalized; returning it as-is left
    -- e.g. @Tensor [Composition [a]]@ only half-normalized.
    Tensor [x] -> nrm x
    Tensor xs -> untens (map nrm xs)
    Composition [x] -> nrm x
    Composition xs -> uncomp (map nrm xs)
    Func nm xs t -> Func nm (map nrm xs) t
    Transform nm x xs -> Transform nm (nrm x) (map nrm xs)
    _ -> f
    where
        -- Right fold with '\*'; the empty product is the unit object.
        untens = foldr (\*) (Tensor [])
        -- Compose in application order: the head is applied first.
        uncomp ys = case ys of
            [y] -> y
            y : rest -> uncomp rest \. y
            [] -> Composition []

-- | Checks whether a morphism is an atomary formula: an 'Arrow', an 'Id',
-- a 'Transform', or the unit object @Tensor []@.
atomary :: (Eq t) => Mor t -> Bool
atomary (Arrow _ _) = True
atomary (Id _) = True
atomary (Transform _ _ _) = True
atomary (Tensor []) = True
atomary _ = False

-- | Label of a term as a list. An 'Arrow' or 'Id' gives its label, and a
-- 'Func' / 'Transform' gives its name, each as a singleton; the unit object
-- gives @[]@. Any other term (non-empty 'Tensor', 'Composition') is an error.
getInfo :: Mor String -> [String]
getInfo f = case f of
    Arrow x _ -> [x]
    Id x -> [x]
    Tensor [] -> []
    Func x _ _ -> [x]
    Transform x _ _ -> [x]
    -- The message names this function so the failure is traceable.
    _ -> error "getInfo: wrong argument"

-- | Domain / codomain / identity data of any term. Arrows return their
-- stored data; an identity is its own domain and codomain; any other term
-- computes the data from 'dom', 'cod' and 'isId'.
getData :: (Eq a) => Mor a -> ArrowData a
getData m = case m of
    Arrow _ info -> info
    Id _ -> ArrowData m m True
    _ -> ArrowData (dom m) (cod m) (isId m)

-- | Creates an 'Arrow' from morphism information (e.g. a name), a domain
-- and a codomain. The result is never flagged as an identity.
arrow :: a -> Mor a -> Mor a -> Mor a
arrow nm adom acod = Arrow nm (ArrowData { dom' = adom, cod' = acod, isId' = False })

-- | Creates a generalized element: an arrow from the unit object
-- (@Tensor []@) to the given object.
element :: a -> Mor a -> Mor a
element nm = arrow nm (Tensor [])

-- | Creates a generalized coelement: an arrow from the given object to the
-- unit object (@Tensor []@).
coelement :: a -> Mor a -> Mor a
coelement nm adom = arrow nm adom (Tensor [])

-- | Creates an object, represented by its identity morphism.
-- Same as 'objectId'.
object :: a -> Mor a
object = Id

-- | Creates an object identity. Same as 'object'.
objectId :: a -> Mor a
objectId = Id

-- | The unit object, i.e. the empty tensor product: @tid \\* f == f@ in a
-- strict monoidal category.
tid :: Mor a
tid = Tensor []

-- | Number of object wires a morphism spans. An object identity or a
-- functor/function application counts as one wire. A tensor sums its
-- factors, a composite takes its widest factor, and any other term takes
-- the wider of its domain and codomain.
width :: (Eq a) => Mor a -> Int
width f = case f of
    Id _ -> 1
    -- 'dom' / 'cod' of a 'Func' is again a 'Func', so the generic case below
    -- recursed forever; a functor application is treated as a single wire.
    Func _ _ _ -> 1
    Tensor xs -> sum (map width xs)
    -- Seeded fold: an empty composite yields 0 instead of crashing in 'maximum'.
    Composition xs -> foldr (max . width) 0 xs
    _ -> max (width (dom f)) (width (cod f))

-- | Number of layers a morphism occupies. Identities and the unit occupy
-- none; a tensor takes its tallest factor; a composite stacks its factors;
-- a transformation is as tall as the morphism it wraps; anything else is
-- one layer.
height :: (Eq a) => Mor a -> Int
height (Id _) = 0
height (Tensor []) = 0
height (Tensor xs) = maximum (map height xs)
height (Composition xs) = sum (map height xs)
height (Transform _ x _) = height x
height _ = 1

-- Tensor product functoriality

-- | Turns recursively @(a \\* b) \\. (c \\* d)@ to @(a \\. c) \\* (b \\. d)@.
--
-- Works through a 'Composition' pair by pair from the first-applied end;
-- any other term only has its sub-terms processed.
-- NOTE(review): termination relies on 'vert' settling on already-processed
-- factors (see the retry in @vertPair@) — not verified here.
vert :: (Eq a) => Mor a -> Mor a
vert f = case f of
    -- Exactly two factors: interchange them directly.
    Composition (y1:y2:[]) -> vertPair y1 y2
    -- Three or more factors: try the first two. If that produced a tensor,
    -- put it in front of the remaining factors and continue; otherwise keep
    -- y1 unchanged and process the remainder starting from y2.
    Composition (y1:y2:ys) -> case (vertPair y1 y2) of
        s@(Tensor _) -> vert ((Composition ys) \. s)
        _ -> (vert ((Composition ys) \. y2)) \. y1
    -- Not a composite of two or more factors: recurse into sub-terms only.
    _ -> vertInside f
    where
        -- vertPair f g stands for g \. f (f is applied first). When both are
        -- tensors with at least two factors, it lines up their leading
        -- factors; otherwise it just composes them.
        vertPair f g = case f of
            Tensor (x1':x2:xs) -> case g of
                Tensor (y1':y2:ys) -> let x1 = vert x1'; y1 = vert y1' in
                    -- If processing the leading factors changed them, retry
                    -- with the processed versions.
                    if x1 /= x1' || y1 /= y1' then
                        vertPair (x1 \* (Tensor (x2:xs))) (y1 \* (Tensor (y2:ys)))
                    -- Compare the width where x1 ends with where y1 starts.
                    else let wx = width (cod x1); wy = width (dom y1) in
                        -- Equal: compose the leading pair and interchange the rest.
                        -- Otherwise widen the narrower side by merging its
                        -- first two factors, then try again.
                        if wx == wy
                            then (y1 \. x1) \*
                                nrm (vertPair (Tensor (x2:xs)) (Tensor (y2:ys)))
                            else if wx > wy
                                then nrm (vertPair f (Tensor ((y1 \* y2) : ys)))
                                else nrm (vertPair (Tensor ((x1 \* x2) : xs)) g)
                _ -> g \. f
            _ -> g \. f
        -- Recurse into tensor factors, the morphism wrapped by a
        -- 'Transform' (its parameters are left alone) and 'Func' arguments.
        vertInside f = case f of
            Tensor y -> Tensor (map vert y)
            Transform s x xs -> Transform s (vert x) xs
            Func s xs t -> Func s (map vert xs) t
            _ -> f

-- | Turns recursively @(a \\. c) \\* (b \\. d)@ to @(a \\* b) \\. (c \\* d)@.
--
-- Works through a 'Tensor' pair by pair from the left; any other term only
-- has its sub-terms processed.
-- NOTE(review): termination relies on 'horz' settling on already-processed
-- factors (see the retry in @horzPair@) — not verified here.
horz :: (Eq a) => Mor a -> Mor a
horz f = case f of
    -- Exactly two factors: combine them directly.
    Tensor (y1:y2:[]) -> horzPair y1 y2
    -- Three or more factors: try the first two. If that produced a
    -- composite, tensor it with the remaining factors and continue;
    -- otherwise keep y1 unchanged and process the rest starting from y2.
    Tensor (y1:y2:ys) -> case (horzPair y1 y2) of
        s@(Composition _) -> horz (s \* (Tensor ys))
        _ -> y1 \* (horz (y2 \* (Tensor ys)))
    -- Not a tensor of two or more factors: recurse into sub-terms only.
    _ -> horzInside f
    where
        -- horzPair f g stands for f \* g. When both are composites with at
        -- least two steps, it lines up their first-applied steps; otherwise
        -- it just tensors them.
        horzPair f g = case f of
            Composition (x1':x2:xs) -> case g of
                Composition (y1':y2:ys) -> let x1 = horz x1'; y1 = horz y1' in
                    -- If processing the first steps changed them, retry with
                    -- the processed versions.
                    if x1 /= x1' || y1 /= y1' then
                        horzPair ((Composition (x2:xs)) \. x1)
                            ((Composition (y2:ys)) \. y1)
                    -- Compare the heights of the first steps.
                    else let hx = height x1; hy = height y1 in
                        -- Equal: tensor the first steps and combine the rest
                        -- after them. Otherwise lengthen the shorter side by
                        -- composing its first two steps, then try again.
                        if hx == hy
                            then nrm (horzPair (Composition (x2:xs)) (Composition (y2:ys))) \.
                                 (x1 \* y1)
                            else if hx > hy
                                then nrm (horzPair f (Composition ((y2 \. y1) : ys)))
                                else nrm (horzPair (Composition ((x2 \. x1) : xs)) g)
                _ -> f \* g
            _ -> f \* g
        -- Recurse into composite steps, the morphism wrapped by a
        -- 'Transform' (its parameters are left alone) and 'Func' arguments.
        horzInside f = case f of
            Composition y -> Composition (map horz y)
            Transform s x xs -> Transform s (horz x) xs
            Func s xs t -> Func s (map horz xs) t
            _ -> f

-- Monad.State support

-- | Monadic bottom-up rewrite of a morphism term.
--
-- For every composite node (non-empty 'Tensor', 'Composition', 'Func',
-- 'Transform') it runs @prep@, rewrites the children left to right,
-- rebuilds the node, normalizes it with 'nrm' and hands it to @func@.
-- Leaves ('Arrow', 'Id', @Tensor []@) go straight to @func@; the domain and
-- codomain stored in an 'Arrow' are not visited.
mapMorM :: (Eq a, Monad m) => m () -> (Mor a -> m (Mor a)) -> Mor a -> m (Mor a)
mapMorM prep func = walk
    where
        walk term = case term of
            Tensor parts@(_ : _) -> rebuild Tensor parts
            Composition parts -> rebuild Composition parts
            Func name parts variance -> rebuild (\ps -> Func name ps variance) parts
            Transform name base parts -> do
                prep
                base' <- walk base
                parts' <- mapM walk parts
                func (nrm (Transform name base' parts'))
            _ -> func term
        -- Shared path for nodes whose only children are a list.
        rebuild build parts = do
            prep
            parts' <- mapM walk parts
            func (nrm (build parts'))

-- | Monadic rewrite that applies @func@ to leaves only ('Arrow', 'Id',
-- @Tensor []@) and may change the label type. Composite nodes are rebuilt
-- from their rewritten children (left to right) and normalized with 'nrm',
-- without calling @func@ on them.
mapMorM' :: (Eq a, Eq b, Monad m) => (Mor a -> m (Mor b)) -> Mor a -> m (Mor b)
mapMorM' func = walk
    where
        rebuilt build parts = mapM walk parts >>= return . nrm . build
        walk term = case term of
            Tensor parts@(_ : _) -> rebuilt Tensor parts
            Composition parts -> rebuilt Composition parts
            Func name parts variance -> rebuilt (\ps -> Func name ps variance) parts
            Transform name base parts -> do
                base' <- walk base
                parts' <- mapM walk parts
                return (nrm (Transform name base' parts'))
            _ -> func term

-- State-threading drivers over 'mapMorM' and 'mapMorM''.
-- The trans* variants return the rewritten morphism ('evalState'); the
-- calc* variants return the final state ('execState'). The *P variants run
-- @prep@ before descending into each composite node; the primed variants
-- use 'mapMorM'', which may change the label type.

transMor :: (Eq a) => Mor a -> s -> (Mor a -> State s (Mor a)) -> Mor a
transMor mor inits wlk = transMorP mor inits (return ()) wlk

calcMor :: (Eq a) => Mor a -> s -> (Mor a -> State s (Mor a)) -> s
calcMor mor inits wlk = calcMorP mor inits (return ()) wlk

transMorP :: (Eq a) => Mor a -> s -> State s () -> (Mor a -> State s (Mor a)) -> Mor a
transMorP mor inits prep wlk = evalState (mapMorM prep wlk mor) inits

calcMorP :: (Eq a) => Mor a -> s -> State s () -> (Mor a -> State s (Mor a)) -> s
calcMorP mor inits prep wlk = execState (mapMorM prep wlk mor) inits

transMor' :: (Eq a, Eq b) => Mor a -> s -> (Mor a -> State s (Mor b)) -> Mor b
transMor' mor inits wlk = evalState (mapMorM' wlk mor) inits

calcMor' :: (Eq a, Eq b) => Mor a -> s -> (Mor a -> State s (Mor b)) -> s
calcMor' mor inits wlk = execState (mapMorM' wlk mor) inits

-- Show

instance (Printable a, Eq a) => Printable (Mor a) where
    -- Atoms print their label with one leading '*' removed; the unit object
    -- prints as "I"; products and composites print as parenthesised lists;
    -- modifiers print as name(args) or name[params](morphism).
    str f = case f of
        Arrow fx _ -> str $ label fx
        Id fx -> str $ label fx
        Tensor [] -> "I"
        Tensor xs -> parens $ joinWith " * " (map show xs)
        Composition xs -> parens $ joinWith " . " (map show xs)
        Func name xs _ -> name ++ parens (joinWith ", " (map show xs))
        Transform name x xs -> name ++ brackets (joinWith ", " (map show xs)) ++ parens (show x)
      where
        -- Pattern match instead of head/tail, so an empty label no longer crashes.
        label x = case str x of
            '*' : rest -> rest
            s -> s
        -- Separator-joined concatenation built right-nested (linear in output size).
        joinWith _ [] = ""
        joinWith sep ys = foldr1 (\a b -> a ++ sep ++ b) ys
        parens s = "(" ++ s ++ ")"
        brackets s = "[" ++ s ++ "]"

-- | 'show' delegates to the 'Printable' rendering.
instance (Printable a, Eq a) => Show (Mor a) where
    show = str

-- Utilities

-- | Zips a concrete morphism (first argument) with a pattern (second
-- argument) of the same shape, pairing each pattern label with the concrete
-- subterm at that position.
--
-- An 'Id' in the pattern matches any subterm. An 'Arrow' in the pattern
-- matches any subterm whose domain and codomain merge with the arrow's.
-- 'Func' needs equal names and kinds, and 'Transform' needs equal names.
-- Every pair of child lists must have the same length. Returns 'Nothing'
-- on any mismatch.
merge :: (Eq a, Eq b) => Mor a -> Mor b -> Maybe (Mor (Mor a, b))
merge m1 m2 = case (m1, m2) of
    (_, Id f2) -> Just (Id (m1, f2))
    (_, Arrow f2 (ArrowData d2 c2 ii2)) -> do
        d' <- merge (dom m1) d2
        c' <- merge (cod m1) c2
        return $ Arrow (m1, f2) (ArrowData d' c' ii2)
    (Func f1 xs1 t1, Func f2 xs2 t2) ->
        if f1 /= f2 || t1 /= t2 then Nothing
        else do
            xs' <- mergeAll xs1 xs2
            return $ Func f1 xs' t1
    (Transform f1 x1 xs1, Transform f2 x2 xs2) ->
        if f1 /= f2 then Nothing
        else do
            x' <- merge x1 x2
            xs' <- mergeAll xs1 xs2
            return $ Transform f1 x' xs'
    (Tensor xs1, Tensor xs2) -> fmap Tensor (mergeAll xs1 xs2)
    (Composition xs1, Composition xs2) -> fmap Composition (mergeAll xs1 xs2)
    _ -> Nothing
    where
        -- Pairwise merge of child lists. 'zip' silently truncates, so the
        -- lengths are checked first: a different arity is a mismatch for
        -- every constructor, including 'Func' and 'Transform'.
        mergeAll xs ys
            | length xs /= length ys = Nothing
            | otherwise = mapM (uncurry merge) (zip xs ys)

-- | Matches a morphism against a pattern whose labels are variable names.
--
-- Returns @(True, bindings)@ when 'merge' succeeds and every pattern label
-- on an 'Arrow' or 'Id' is bound consistently (the same name always maps
-- to an equal subterm). Labels starting with @'*'@ are neither bound nor
-- checked. Returns @(False, Map.empty)@ when the shapes cannot be merged,
-- and @(False, _)@ on an inconsistent binding.
match :: (Eq a) => Mor a -> Mor String -> (Bool, Map.Map String (Mor a))
match f g = case merge f g of
    Nothing -> (False, Map.empty)
    -- The walker is a lambda so its State type is fixed by 'calcMor'.
    Just h -> calcMor h (True, Map.empty) $ \x -> case x of
        Arrow (x1, x2) _ -> modify (bind x1 x2) >> return x
        Id (x1, x2) -> modify (bind x1 x2) >> return x
        _ -> return x
    where
        -- Prefix test by pattern match: 'head' crashed on an empty label.
        skipped ('*' : _) = True
        skipped _ = False
        -- Records (or checks) name := term; after a mismatch the state is
        -- left unchanged.
        bind term name st@(ok, binds)
            | not ok || skipped name = st
            | otherwise = case Map.lookup name binds of
                Just bound -> (ok && bound == term, binds)
                Nothing -> (ok, Map.insert name term binds)

-- | Replaces every 'Arrow' / 'Id' leaf by the term bound to its label in
-- the map (a missing label is an error from 'Map.!'). The unit object is
-- kept as the unit object.
subst :: (Ord a, Eq b) => Map.Map a (Mor b) -> Mor a -> Mor b
subst mp f = transMor' f () $ \x ->
    case x of
        Arrow name _ -> return $ mp Map.! name
        Id name -> return $ mp Map.! name
        -- 'mapMorM'' hands the unit object to the walker as a leaf; without
        -- this branch any term containing it crashed with a pattern failure.
        Tensor [] -> return (Tensor [])
        _ -> error "subst: unexpected non-atomic term"

-- | Like 'subst', but requires each replacement to be of the same kind as
-- the leaf it replaces: an 'Arrow' must map to an 'Arrow' and an 'Id' to
-- an 'Id'. Any other replacement is an error. The unit object is kept as is.
subst' :: (Ord a, Eq b, Printable b) => Map.Map a (Mor b) -> Mor a -> Mor b
subst' mp f = transMor' f () $ \x ->
    case x of
        Arrow name _ -> case mp Map.! name of
            arr'@(Arrow _ _) -> return arr'
            arr' -> error $ "subst': no match in " ++ (show arr')
        Id name -> case mp Map.! name of
            arr'@(Id _) -> return arr'
            arr' -> error $ "subst': no match in " ++ (show arr')
        -- The unit object reaches the walker as a leaf (see 'mapMorM'').
        Tensor [] -> return (Tensor [])
        _ -> error "subst': unexpected non-atomic term"

-- | Collects the 'Arrow' and 'Id' leaves of the given morphism as keys of a
-- map. Each visit of such a leaf takes the next counter value, starting at
-- 1. A leaf that occurs repeatedly keeps the value of its last visit, so
-- the values may have gaps.
collect :: (Num b, Ord a) => Mor a -> Map.Map (Mor a) b
collect f = fst $ calcMor' f (Map.empty, 1) $ \x ->
    if isLeaf x then modify (number x) >> return x else return x
    where
        isLeaf x = case x of
            Arrow _ _ -> True
            Id _ -> True
            _ -> False
        number x (mp, n) = (Map.insert x n mp, n + 1)

-- Rules

-- | Rewrite rule over morphisms whose atom labels have type @a@.
data Rule a
    -- | Declares equality of two morphisms (left-hand side, right-hand side)
    = DefEqual (Mor a) (Mor a)

-- | Builds an equality rule: @x \\== y@ is @'DefEqual' x y@.
-- Binds more loosely than '\\*' (7) and '\\.' (6).
(\==) :: Mor a -> Mor a -> Rule a
(\==) = DefEqual

infix 4 \==

-- | Left-hand side of a rule.
r'left :: Rule a -> Mor a
r'left rule = case rule of DefEqual lhs _ -> lhs

-- | Right-hand side of a rule.
r'right :: Rule a -> Mor a
r'right rule = case rule of DefEqual _ rhs -> rhs

-- | Applies the 'Rule' to the given morphism, in either direction.
--
-- If the morphism matches the left-hand side, the bindings are substituted
-- into the right-hand side. Otherwise, if it matches the right-hand side,
-- they are substituted into the left-hand side. Calls 'error' when neither
-- side matches.
apply :: (Eq a) => Rule String -> Mor a -> Mor a
apply (DefEqual l r) f
    | matchedLeft = subst bindsLeft r
    -- Reverse direction: substituting back into r (as before) just rebuilt
    -- the input, so the rule had no effect.
    | matchedRight = subst bindsRight l
    | otherwise = error "apply: no match"
    where
        (matchedLeft, bindsLeft) = match f l
        (matchedRight, bindsRight) = match f r

-- | Renders a rule as @lhs == rhs@.
instance (Printable a, Eq a) => Show (Rule a) where
    show rule = case rule of
        DefEqual lhs rhs -> concat [show lhs, " == ", show rhs]