{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE TypeOperators              #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE UndecidableInstances       #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeSynonymInstances       #-}
{-# LANGUAGE OverlappingInstances       #-}

module Generics.Regular.Transformations.Main
  ( diff, apply
  , Transformation, WithRef (..), Path (..), Transform, HasRef (..)
  , NiceTransformation, toNiceTransformation, fromNiceTransformation
  ) where

import Prelude as P
import Generics.Regular
import Generics.Regular.Functions.Show hiding ( show, shows, Show )
import qualified Generics.Regular.Functions.Show as R
import Generics.Regular.Zipper
import Generics.Regular.Functions.GOrd
import Control.Applicative ( (<|>) )
import Control.Monad (foldM, liftM, liftM2)
import Control.Monad.State
import Data.Monoid (mappend)
import qualified Data.Map as Map
import Data.Map (Map)
import qualified Generics.Regular.Functions.Eq as GEq

-- Paths, annotations and edits
-- | One step into a child position of a pattern functor: a zipper context
-- whose sibling recursive positions are instantiated to ().
type Dir fun = Ctx fun ()

-- | A path from the root of a value of type @t@ down to one of its subtrees,
-- as a list of single steps.
type Path t = [Dir (PF t)]

-- | One layer of a tree with references: either an ordinary node built from
-- the pattern functor of @a@ (with children of type @b@), or a reference
-- (a 'Path') to a subtree that is looked up elsewhere.
data WithRef a b where
  InR :: PF a b -> WithRef a b
  Ref :: Path a -> WithRef a b

-- | Maps over the children of an 'InR' layer. A 'Ref' has no children, so
-- it is simply rebuilt at the new element type.
instance Functor (PF a) => Functor (WithRef a) where
  fmap g node = case node of
    InR layer -> InR (fmap g layer)
    Ref path  -> Ref path

-- | An edit script: each entry pairs a location (a 'Path') with the tree,
-- possibly containing references, that is to be put there.
type Transformation t = [(Path t, Fix (WithRef t))]

-- | Bundle of every generic-programming constraint that diffing and patching
-- need for a type. The class has no methods of its own.
class ( Regular t
      , Functor (PF t), GMap (PF t)
      , Children (PF t), ZipChildren (PF t)
      , ExtractP (PF t), MapP (PF t)
      , SEq (PF t), GEq.Eq (PF t), GOrd (PF t)
      ) => Transform t

-- Showing paths
-- | Index of a recursive child within a constructor, used when rendering
-- paths. The value -1 renders as the empty string.
newtype ConIndex = CI Int deriving (Eq, Num)

instance Show ConIndex where
  show (CI i)
    | i == -1   = ""
    | otherwise = concat ["_", show i, " "]

-- | Rendering of one path step ('Dir') through a pattern functor.
--
-- @showsPrecPath rest d n step@ renders @step@ at precedence @n@. @d@ is the
-- index offset contributed by the recursive positions to the left of the
-- current one. @rest@ renders the remaining steps of the path; it is emitted
-- once the step reaches the recursive position 'I'.
class ShowPath f where
  showsPrecPath :: ShowS -> ConIndex -> Int -> Dir f -> ShowS

instance (ShowPath f, ShowPath g) => ShowPath (f :+: g) where
  showsPrecPath r d n (CL p) = showsPrecPath r d n p
  showsPrecPath r d n (CR p) = showsPrecPath r d n p

instance (ShowPath f, ShowPath g, CountIs f) => ShowPath (f :*: g) where
  -- Going left on a product is unproblematic
  showsPrecPath r d n (C1 p _) = showsPrecPath r d n p
  -- Going right, however, we have to increase |d| by the number of children
  -- to our left, i.e. the recursive positions of the LEFT factor @f@.
  -- The left value stored in 'C2' may be a placeholder (see 'children'),
  -- so the count is taken from the type via a proxy.
  showsPrecPath r d n (C2 _ p) =
    let newd = d + CI (countIs (undefined :: f r))
    in showsPrecPath r newd n p

-- | A constructor prints its name and restarts the child index at 0.
instance (ShowPath f, Constructor c) => ShowPath (C c f) where
  showsPrecPath r _ n (CC p) = let name = conName (undefined :: C c f r)
                               in showParen (n > 10) $ showString name
                                                     . showsPrecPath r 0 11 p

-- | Record selectors are transparent in paths; just descend.
instance (ShowPath f) => ShowPath (S s f) where
  showsPrecPath r d n (CS p) = showsPrecPath r d n p

-- 'children' never yields steps into constant or unit positions.
instance ShowPath (K a) where showsPrecPath _ _ _ _ = id
instance ShowPath U     where showsPrecPath _ _ _ _ = id

-- | The recursive position: print the index, then the rest of the path.
instance ShowPath I where
  showsPrecPath r d n CId = shows d . r

-- | Render a whole path. Each step receives the rendering of the remaining
-- steps, and the empty path renders as "End".
showsPrecPathC :: (ShowPath f) => ConIndex -> Int -> [Dir f] -> ShowS
showsPrecPathC d n = foldr step (showString "End")
  where step dir rest = showsPrecPath rest d n dir

-- | Paths use the dedicated path renderer rather than the generic list
-- instance (this relies on OverlappingInstances), starting at index 0.
instance (ShowPath f) => Show [Dir f] where
  showsPrec prec path = showsPrecPathC (CI 0) prec path

-- | Shows a tree with references. A reference prints as @Ref <path>@; an
-- ordinary node prints as @InR <node>@ using regular's generic show.
instance (ShowPath (PF a), Functor (PF a), R.Show (PF a))
    => Show (Fix (WithRef a)) where
  showsPrec n (In node) = showParen (n > 10) $ case node of
    Ref path  -> showString "Ref " . showsPrec 11 path
    InR layer -> showString "InR " . R.hshowsPrec showsPrec False 11 layer

-- | Join renderers with a single space between consecutive ones.
spaces :: [ShowS] -> ShowS
spaces []  = id
spaces shs = foldr1 (\sh acc -> sh . showChar ' ' . acc) shs

-- | Join renderers, placing the string @s@ between every pair of
-- consecutive ones. An empty list renders nothing.
intersperse :: String -> [ShowS] -> ShowS
intersperse _ []     = id
intersperse _ [x]    = x
-- Recurse with the same separator; the old code recursed through 'spaces',
-- which silently replaced every later separator with " ".
intersperse s (x:xs) = x . showString s . intersperse s xs

-- | Number of recursive positions ('I') in a pattern functor.
--
-- 'ShowPath' calls this on type proxies (@undefined :: f r@), so instances
-- must not force their argument. Products, constructors and selectors
-- therefore use lazy patterns. A sum has no count independent of the
-- alternative present, so the ':+:' instance does inspect its value.
class CountIs f where
  countIs :: f r -> Int

instance CountIs I     where countIs _ = 1
instance CountIs U     where countIs _ = 0
instance CountIs (K a) where countIs _ = 0

instance (CountIs f) => CountIs (C c f) where
  countIs ~(C x) = countIs x

instance (CountIs f) => CountIs (S s f) where
  countIs ~(S x) = countIs x

instance (CountIs f, CountIs g) => CountIs (f :+: g) where
  countIs (L x) = countIs x
  countIs (R x) = countIs x

instance (CountIs f, CountIs g) => CountIs (f :*: g) where
  countIs ~(x :*: y) = countIs x + countIs y

-- Patching

-- | Apply the edits to the given tree.
--
-- Edits are applied in order, each replacing the subtree at its path in the
-- tree built so far. References inside an edit are paths into the
-- /original/ tree: 'diff' takes them from the 'childrenPaths' of its first
-- argument. They are therefore resolved against the input @t@, not against
-- the subtree currently at the edit's location. A reference that cannot be
-- resolved makes the result 'Nothing'.
apply :: Transform a => Transformation a -> a -> Maybe a
apply e t = foldM apply' t e where
  apply' a (p, c) = mapP (const (lookupRefs t c)) p a

-- | Look up the references using the original structure. @r@ is the tree
-- that reference paths point into: every 'Ref' becomes the subtree 'extract'ed
-- from @r@, and every 'InR' layer is converted back with 'to'.
lookupRefs :: Transform a => a -> Fix (WithRef a) -> Maybe a
lookupRefs r (In node) = case node of
  Ref path  -> extract path r
  InR layer -> do kids <- fmapM (lookupRefs r) layer
                  return (to kids)

-- Diffing
-- | Memoisation key for the differ: the differ's @ins@ flag together with the
-- source and target trees.
data MemoKey a = MemoKey Bool a a

-- | Keys are equal when the flags match and both trees are generically equal.
instance (Regular a, GEq.Eq (PF a)) => Eq (MemoKey a) where
  (MemoKey flagL srcL dstL) == (MemoKey flagR srcR dstR) =
    and [flagL == flagR, GEq.eq srcL srcR, GEq.eq dstL dstR]

-- | Lexicographic order: flag first, then source tree, then target tree.
instance (Regular a, GEq.Eq (PF a), GOrd (PF a)) => Ord (MemoKey a) where
  compare (MemoKey flagL srcL dstL) (MemoKey flagR srcR dstR) =
    case compare flagL flagR of
      EQ -> case gcompare srcL srcR of
              EQ    -> gcompare dstL dstR
              other -> other
      other -> other

-- | Memo table from keys to the transformation computed for them.
type Memo a = Map (MemoKey a) (Transformation a)

-- | Find a set of edits to transform the first into the second tree.
--
-- Every 'Ref' in the result is a path into the first tree, because reuse
-- candidates come from its 'childrenPaths'. Edit locations are paths into
-- the tree being rebuilt. Intermediate results are memoised in a 'State' map.
diff :: forall a. (Transform a) => a -> a -> Transformation a
diff a b = evalState (build False a b) Map.empty
  where
    -- All subtrees of the source tree, with their paths.
    childPaths :: [(a,Path a)]
    childPaths = childrenPaths a

    -- Memoised version of 'build'.
    buildmem :: Bool -> a -> a -> State (Memo a) (Transformation a)
    buildmem ins x y = do
      mp <- get
      let k = MemoKey ins x y
      case Map.lookup k mp of
        Just r  -> return r
        Nothing -> do
          r <- build ins x y
          modify (Map.insert k r)
          return r

    -- Edits turning a' into b'. The flag is True while building the children
    -- of a node that is being inserted.
    build :: Bool -> a -> a -> State (Memo a) (Transformation a)
    build False a' b' | GEq.eq a' b' = return []
    build ins a' b' = case lookupWith GEq.eq b' childPaths of
      -- b' occurs verbatim in the source tree: just reference it
      Just p  -> return [([], In (Ref p))]
      Nothing -> uses >>= maybe insert return
        where
          -- Construct the edits for the children based on a root; only
          -- possible when c has the same shape as b' (shallowEq).
          construct :: Bool -> a -> State (Memo a) (Maybe (Transformation a))
          construct ins' c =
            if shallowEq (from c) (from b')
            then do r <- zipChildrenM (\p c1 c2 ->
                                         liftM (updateChildPaths p)
                                               (buildmem ins' c1 c2)) c b'
                    return $ Just $ concat r
            else return Nothing
          -- Possible edits reusing the existing tree or using a part of
          -- the original tree. The existing tree is only used if we didn't
          -- just insert it, since we want to keep the inserts small
          uses :: State (Memo a) (Maybe (Transformation a))
          uses = reuses >>= \re -> case re of
              Just _ | ins -> return re
              _            -> construct ins a' >>= return . best re
          -- Possible edits that include reusing a part of the original tree
          reuses :: State (Memo a) (Maybe (Transformation a))
          reuses = foldM f Nothing childPaths where
            addRef p = fmap (([], In (Ref p)):)
            f c (x,p) = construct False x >>= return . best c . addRef p
          -- Best edit including insertion, only chosen if nothing can be
          -- reused. Rebuilding b' from itself only fails if shallowEq is not
          -- reflexive for some constant (the K instance uses '==', so e.g. a
          -- NaN would do it); fail loudly with a clear message in that case.
          insert :: State (Memo a) (Transformation a)
          insert = do
            mr <- construct True b'
            case mr of
              Just r  -> let (r', e') = partialApply (withRefs b') r
                         in return (([], r') : e')
              Nothing -> error "diff: cannot rebuild the target tree from itself"

-- | Like 'lookup', but with a caller-supplied equality test @f key listKey@.
-- The first matching pair wins.
lookupWith :: (a -> a -> Bool) -> a -> [(a,b)] -> Maybe b
lookupWith f a bs = case [ r | (b, r) <- bs, f a b ] of
  (r:_) -> Just r
  []    -> Nothing

-- | Pick the best edit: the shorter one when both exist, otherwise whichever
-- one exists (if any).
best :: Maybe (Transformation a) -> Maybe (Transformation a) -> Maybe (Transformation a)
best (Just l) (Just r) = Just (pickShortest l r)
best l        r        = l <|> r

-- | Pick the shorter of two lists lazily. The lists are walked in lock step,
-- so the walk stops at the end of the shorter one (the other may even be
-- infinite). Ties go to the first list.
pickShortest :: [a] -> [a] -> [a]
pickShortest xs ys
  | notLonger xs ys = xs
  | otherwise       = ys
  where
    notLonger (_:ps) (_:qs) = notLonger ps qs
    notLonger []     _      = True
    notLonger _      []     = False

-- | Lift a tree to a tree with references. Every layer becomes an 'InR' node;
-- no 'Ref' is introduced.
withRefs :: Transform a => a -> Fix (WithRef a)
withRefs t = In (InR (fmap withRefs (from t)))

-- | Try to apply as many edits as possible to the edit structure itself,
--   to make the final edit smaller. Edits that cannot be applied (for example
--   because their path runs into a 'Ref') are returned in their original
--   order.
partialApply :: Transform a =>
                Fix (WithRef a) -> Transformation a -> (Fix (WithRef a), Transformation a)
partialApply a [] = (a, [])
partialApply a (edit@(p, r) : rest) =
  maybe skip (\a' -> partialApply a' rest) (replace p r a)
  where
    skip = let (a', rest') = partialApply a rest in (a', edit : rest')

-- | Replace a subtree in an edit structure. Fails in @m@ when the path runs
-- into a 'Ref'.
replace :: (Transform a, Monad m)
           => Path a -> Fix (WithRef a) -> Fix (WithRef a) -> m (Fix (WithRef a))
replace p r = mapPR (\_ -> return r) p

-- | Prefix the location of every edit with @p@, the path to the child whose
-- edits these are.
updateChildPaths :: Path a -> Transformation a -> Transformation a
updateChildPaths p edits = [ (p ++ loc, tree) | (loc, tree) <- edits ]
-- Shallow equality

-- | Shallow equality: do two layers have the same shape (same constructor,
-- equal constants), ignoring their recursive children?
class SEq f where
  shallowEq :: f a -> f a -> Bool

instance SEq I where
  shallowEq (I _) (I _) = True

instance SEq U where
  shallowEq U U = True

-- | Constants are compared with their own 'Eq' instance.
instance Eq a => SEq (K a) where
  shallowEq (K x) (K y) = x == y

-- | Different alternatives are never shallowly equal.
instance (SEq f, SEq g) => SEq (f :+: g) where
  shallowEq l r = case (l, r) of
    (L x, L y) -> shallowEq x y
    (R x, R y) -> shallowEq x y
    _          -> False

instance (SEq f, SEq g) => SEq (f :*: g) where
  shallowEq (l1 :*: r1) (l2 :*: r2) = shallowEq l1 l2 && shallowEq r1 r2

instance SEq f => SEq (C c f) where
  shallowEq (C x) (C y) = shallowEq x y

instance SEq f => SEq (S s f) where
  shallowEq (S x) (S y) = shallowEq x y

-- Extract
-- | Extract the subtree at the given path. Fails in @m@ when a step does not
-- fit the tree.
extract :: (Transform a, Monad m) => Path a -> a -> m a
extract path t = case path of
  []         -> return t
  (d : rest) -> extractP (extract rest) d (from t)

-- | Follow one path step into a layer and run the continuation @k@ on the
-- child found there. Fails in @m@ when the step does not fit the layer.
class ExtractP f where
  extractP :: Monad m => (a -> m a) -> Dir f -> f a -> m a

instance ExtractP I where
  extractP k CId (I child) = k child

instance ExtractP (K a) where
  extractP _ _ (K _) = fail "extractP"

instance ExtractP U where
  extractP _ _ U = fail "extractP"

instance (ExtractP f, ExtractP g) => ExtractP (f :+: g) where
  extractP k dir layer = case (dir, layer) of
    (CL d, L x) -> extractP k d x
    (CR d, R x) -> extractP k d x
    _           -> fail "extractP"

instance (ExtractP f, ExtractP g) => ExtractP (f :*: g) where
  extractP k (C1 d _) (x :*: _) = extractP k d x
  extractP k (C2 _ d) (_ :*: y) = extractP k d y

instance ExtractP f => ExtractP (C c f) where
  extractP k (CC d) (C x) = extractP k d x

instance ExtractP f => ExtractP (S s f) where
  extractP k (CS d) (S x) = extractP k d x

-- MapP
-- | Map a function over the child in a specific path. The subtree at the end
-- of the path is replaced by the (monadic) result of @f@.
mapP :: (MapP (PF a), Monad m, Regular a) => (a -> m a) -> Path a -> a -> m a
mapP f path t = case path of
  []         -> f t
  (d : rest) -> liftM to (mapP' (mapP f rest) d (from t))

-- | Version of 'mapP' for trees with references. Reaching a 'Ref' node, either
-- on the way down or at the end of the path, fails in @m@ with "mapPR".
mapPR :: (Transform a, Monad m) =>
         (Fix (WithRef a) -> m (Fix (WithRef a)))
         -> Path a -> Fix (WithRef a) -> m (Fix (WithRef a))
mapPR f path (In node) = case node of
  Ref _     -> fail "mapPR"
  InR layer -> case path of
    []         -> f (In node)
    (d : rest) -> liftM (In . InR) (mapP' (mapPR f rest) d layer)

-- | Apply a monadic function to the child selected by a single path step,
-- rebuilding the layer around the result.
class MapP f where
  mapP' :: Monad m => (b -> m b) -> Dir f -> f b -> m (f b)

instance MapP I where
  mapP' f CId (I r) = liftM I (f r)

instance MapP (K a) where
  mapP' _ _ (K x)  = liftM K (return x)

instance MapP U where
  mapP' _ _ U = return U

-- A step into one alternative cannot be applied to a value built with the
-- other alternative. Fail in @m@ (as 'extractP' does) instead of crashing
-- with a pattern-match error, so e.g. 'apply' yields 'Nothing'.
instance (MapP f, MapP g) => MapP (f :+: g) where
  mapP' f (CL p) (L x) = liftM L (mapP' f p x)
  mapP' f (CR p) (R x) = liftM R (mapP' f p x)
  mapP' _ _      _     = fail "mapP'"

instance (MapP f, MapP g) => MapP (f :*: g) where
  mapP' f (C1 p _) (x :*: y) = liftM2 (:*:) (mapP' f p x) (return y)
  mapP' f (C2 _ p) (x :*: y) = liftM2 (:*:) (return x) (mapP' f p y)

instance MapP f => MapP (C c f) where
  mapP' f (CC p) (C x) = liftM C (mapP' f p x)

instance MapP f => MapP (S s f) where
  mapP' f (CS p) (S x) = liftM S (mapP' f p x)

-- Children
-- | Get the immediate recursive children of a value, left to right.
imChildren :: (Regular a, Children (PF a)) => a -> [a]
imChildren t = [ child | (child, _) <- children (from t) ]

-- | Get all subtrees with their paths, in pre-order: the value itself at the
-- empty path first, then all subtrees of each child in turn.
childrenPaths :: (Regular a, Children (PF a)) => a -> [(a,Path a)]
childrenPaths t = (t, []) : concatMap below (children (from t))
  where
    below (child, step) =
      [ (sub, step : rest) | (sub, rest) <- childrenPaths child ]

-- | List the recursive children of a layer, left to right, each with the
-- path step that reaches it.
class Children f where
  children :: f a -> [(a, Dir f)]

instance Children I where
  children (I r) = [(r, CId)]

instance Children (K a) where
  children (K _) = []

instance Children U where
  children U = []

instance (Children f, Children g) => Children (f :+: g) where
  children (L x) = map (\(a, p) -> (a, CL p)) (children x)
  children (R x) = map (\(a, p) -> (a, CR p)) (children x)

instance (Children f, Children g) => Children (f :*: g) where
  children (x :*: y) = map left (children x) ++ map right (children y)
    where left  (a, p) = (a, C1 p nullY)
          right (a, p) = (a, C2 nullX p)
          nullX = error "nullX" -- fmap (const ()) x
          nullY = error "nullY" -- fmap (const ()) y
-- The sibling placeholders above are never inspected by the path consumers
-- in this module (they match that field with @_@), so the errors are not hit.

instance Children f => Children (C c f) where
  children (C x) = map (\(a, p) -> (a, CC p)) (children x)

instance Children f => Children (S s f) where
  children (S x) = map (\(a, p) -> (a, CS p)) (children x)

-- ZipChildren
-- | Pair up the immediate children of two values and run @f@ on each pair,
-- together with the one-step path to that child. Both values are expected
-- to have the same shape (callers check 'shallowEq' first).
zipChildrenM :: (Transform a, Monad m) => (Path a -> a -> a -> m b) -> a -> a -> m [b]
zipChildrenM f a b = zipChildren f (\d -> [d]) (from a) (from b)

-- | Walk two layers of the same shape in parallel. @f@ is called on each pair
-- of recursive children together with the path to that child, which is built
-- by the path-builder @p@.
class ZipChildren f where
  zipChildren :: Monad m => (Path a -> a -> a -> m b) -> (Dir f -> Path a) -> f a -> f a -> m [b]

instance ZipChildren I where
  zipChildren f p (I a) (I b) = liftM (:[]) (f (p CId) a b)

instance ZipChildren (K a) where
  zipChildren _ _ _ _ = return []

instance ZipChildren U where
  zipChildren _ _ _ _ = return []

-- Mismatched alternatives are not handled; in this module 'diff' only zips
-- after checking 'shallowEq'.
instance (ZipChildren f, ZipChildren g) => ZipChildren (f :+: g) where
  zipChildren f p (L x) (L y) = zipChildren f (\d -> p (CL d)) x y
  zipChildren f p (R x) (R y) = zipChildren f (\d -> p (CR d)) x y

instance (ZipChildren f, ZipChildren g) => ZipChildren (f :*: g) where
  zipChildren f p (x1 :*: y1) (x2 :*: y2) = do
      lefts  <- zipChildren f (\d -> p (C1 d nullY)) x1 x2
      rights <- zipChildren f (\d -> p (C2 nullX d)) y1 y2
      return (lefts ++ rights)
    where nullX = error "nullX" -- fmap (const ()) x
          nullY = error "nullY" -- fmap (const ()) y
-- The sibling placeholders above are never inspected by the path consumers
-- in this module (they match that field with @_@), so the errors are not hit.

instance ZipChildren f => ZipChildren (C c f) where
  zipChildren f p (C x) (C y) = zipChildren f (\d -> p (CC d)) x y

instance ZipChildren f => ZipChildren (S s f) where
  zipChildren f p (S x) (S y) = zipChildren f (\d -> p (CS d)) x y

-- Nicer interface
-- | Types with a user-facing representation of trees with references.
-- 'RefRep' is that representation; 'fromRef' and 'toRef' convert one layer
-- between it and 'WithRef'.
class HasRef a where
  type RefRep a

  fromRef :: RefRep a -> WithRef a (RefRep a)
  toRef   :: WithRef a (RefRep a) -> RefRep a

-- | A 'Transformation' whose trees use the 'RefRep' representation.
type NiceTransformation t = [(Path t, RefRep t)]

-- | Convert every tree of a transformation into its 'RefRep' form, layer by
-- layer with 'toRef'. Paths are kept unchanged.
toNiceTransformation :: (Functor (PF a), HasRef a)
                        => Transformation a -> NiceTransformation a
toNiceTransformation edits = [ (path, convert tree) | (path, tree) <- edits ]
  where convert (In layer) = toRef (fmap convert layer)

-- | Convert in the opposite direction of 'toNiceTransformation': rebuild each
-- tree from its 'RefRep' form, layer by layer with 'fromRef'. Paths are kept
-- unchanged.
fromNiceTransformation :: (Functor (PF a), HasRef a)
                          => NiceTransformation a -> Transformation a
fromNiceTransformation nice = [ (path, rebuild rep) | (path, rep) <- nice ]
  where rebuild rep = In (fmap rebuild (fromRef rep))