{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE TypeOperators              #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE UndecidableInstances       #-}
{-# LANGUAGE TypeFamilies               #-}

module Generics.Regular.Transformations.Explicit (
  diff, apply, Transformation, WithRef (..), Path, Transform, 
  HasRef (..), NiceTransformation, toNiceTransformation, fromNiceTransformation
  ) where

import Generics.Regular
import Generics.Regular.Functions.GOrd
import Control.Applicative ( (<|>) )
import Control.Monad (foldM, liftM, liftM2)
import Control.Monad.State
import Data.Monoid (mappend)
import qualified Data.Map as Map
import Data.Map (Map)
import qualified Generics.Regular.Functions.Eq as GEq

--------------------------------------------------------------------------------
-- Paths, annotations and edits
--------------------------------------------------------------------------------
-- | A location inside a tree: the child index to follow at each level,
--   starting at the root. The empty path denotes the root itself.
type Path             = [Int]
-- | One layer of a tree of type @a@ (with children of type @b@) that may,
--   instead of an ordinary layer, hold a reference to a subtree elsewhere.
data WithRef a b      = InR (PF a b)  -- ^ An ordinary pattern-functor layer of @a@
                      | Ref Path      -- ^ A reference, by path, to a subtree of another tree

-- | Maps over the children of an 'InR' layer; a 'Ref' holds no children
--   and is passed through unchanged.
instance Functor (PF a) => Functor (WithRef a) where
  fmap g w = case w of
    InR layer -> InR (fmap g layer)
    Ref path  -> Ref path

-- | A list of edits. Each edit pairs the path of the subtree to replace
--   with its replacement, which may contain 'Ref's.
type Transformation a = [ (Path, Fix (WithRef a)) ]

-- | Bundles every generic capability that the functions in this module
--   require of a type @a@ and its pattern functor. The class itself has no
--   methods; it only serves as a shorthand for the superclass constraints.
class (Regular a, Children (PF a), CountI (PF a), Functor (PF a),
       SEq (PF a), ExtractN (PF a), MapN (PF a), GMap (PF a), GOrd (PF a),
       GEq.Eq (PF a)) => Transform a

--------------------------------------------------------------------------------
-- Patching
--------------------------------------------------------------------------------

-- | Apply the edits, in list order, to the given tree.
--
--   Each edit replaces the subtree found at its path. References inside a
--   replacement are always resolved against the tree that was originally
--   passed in, not against the partially edited one. The result is
--   'Nothing' when such a reference cannot be resolved. A path index that
--   matches no child leaves the node it was looked up in unchanged.
apply :: Transform a => Transformation a -> a -> Maybe a
apply e t = foldM step t e
  where
    -- Path exhausted: the replacement goes here
    step _    ([],     c) = lookupRefs t c
    -- Otherwise rebuild this node, descending only into child number i
    step node (i : is, c) = liftM to (tmapN descend (from node))
      where
        descend j child
          | j == i    = step child (is, c)
          | otherwise = return child

-- | Turn a tree with references back into a plain tree, resolving every
--   'Ref' by extracting the subtree at that path from the first argument.
--   Yields 'Nothing' as soon as one reference cannot be extracted.
lookupRefs :: Transform a => a -> Fix (WithRef a) -> Maybe a
lookupRefs original annotated = case annotated of
  In (InR layer) -> liftM to (fmapM (lookupRefs original) layer)
  In (Ref path)  -> extract path original

-- | Follow the path from the root of the tree, one child index per step,
--   and return the subtree reached. 'Nothing' if some step fails.
extract :: Transform a => Path -> a -> Maybe a
extract p a = foldM descend a p
  where
    -- Select child number i of the current node
    descend node i = extractN i (from node)

--------------------------------------------------------------------------------
-- Diffing
--------------------------------------------------------------------------------
-- | Key of the memo table used by 'diff': the insert flag together with
--   the two trees that were passed to one @build@ call.
data MemoKey a where
  MemoKey :: Bool -> a -> a -> MemoKey a

-- | Keys are equal when the flags agree and both pairs of trees are
--   generically equal.
instance (Regular a, GEq.Eq (PF a)) => Eq (MemoKey a) where
  MemoKey flagL srcL tgtL == MemoKey flagR srcR tgtR =
    and [flagL == flagR, GEq.eq srcL srcR, GEq.eq tgtL tgtR]

-- | Lexicographic ordering: first the flag, then the first tree, then the
--   second tree, the trees being compared generically.
instance (Regular a, GEq.Eq (PF a), GOrd (PF a)) => Ord (MemoKey a) where
  compare (MemoKey flagL srcL tgtL) (MemoKey flagR srcR tgtR) =
      byFlag `mappend` (bySource `mappend` byTarget)
    where
      byFlag   = compare flagL flagR
      bySource = gcompare srcL srcR
      byTarget = gcompare tgtL tgtR

-- | Memo table for 'diff': maps the arguments of a @build@ call to the
--   edits that call produced.
type Memo a = Map (MemoKey a) (Transformation a)

-- | Find a set of edits to transform the first into the second tree.
--
--   The result is meant to be given to 'apply' together with the first
--   tree: every 'Ref' it contains is a path into that first tree.
--   Intermediate results are memoised in a 'Memo' table that is threaded
--   through the 'State' monad.
diff :: forall a. (Transform a) => a -> a -> Transformation a
diff a b = evalState (build False a b) Map.empty
  where
    -- Every subtree of the source tree, paired with its path from the root
    childPaths :: [(a,Path)]
    childPaths = childrenPaths a
    -- Memoising wrapper around 'build', keyed on all three arguments
    buildmem :: Bool -> a -> a -> State (Memo a) (Transformation a)
    buildmem ins src tgt = do
      memo <- get
      let k = MemoKey ins src tgt
      case Map.lookup k memo of
        Just r  -> return r
        Nothing -> do
          r <- build ins src tgt
          modify (Map.insert k r)
          return r
    -- Edits turning a' into b'. The flag is True when a' is part of a
    -- freshly inserted structure (see 'insert') rather than an existing one
    build :: Bool -> a -> a -> State (Memo a) (Transformation a)
    build False a' b' | GEq.eq a' b' = return []
    build ins a' b' = case lookupWith GEq.eq b' childPaths of
      -- The target occurs somewhere in the source tree: just refer to it
      Just p  -> return [([], In (Ref p))]
      Nothing -> uses >>= maybe insert return
        where
          -- Construct the edits for the children based on a root
          construct :: Bool -> a -> State (Memo a) (Maybe (Transformation a))
          construct ins' c =
            if shallowEq (from c) (from b')
            then do r <- zipWithM (buildmem ins') (imChildren c) (imChildren b')
                    return $ Just $ concat $ updateChildPaths r
            else return Nothing
          -- Possible edits reusing the existing tree or using a part of
          -- the original tree. The existing tree is only used if we didn't
          -- just insert it, since we want to keep the inserts small
          uses :: State (Memo a) (Maybe (Transformation a))
          uses = reuses >>= \re -> case re of
              Just _ | ins -> return re
              _            -> construct ins a' >>= return . best re
          -- Possible edits that include reusing a part of the original tree
          reuses :: State (Memo a) (Maybe (Transformation a))
          reuses = foldM f Nothing childPaths where
            addRef p = fmap (([], In (Ref p)):)
            f c (x,p) = construct False x >>= return . best c . addRef p
          -- Best edit including insertion, only chosen if nothing can be reused
          insert :: State (Memo a) (Transformation a)
          insert = do
            mr <- construct True b'
            return $ case mr of
              -- Fold as many child edits as possible into the inserted tree
              Just r  -> let (r', e') = partialApply (withRefs b') r
                         in ([], r') : e'
              -- 'construct' compares b' with itself here, so this is only
              -- reached when 'shallowEq' is not reflexive for b' (it relies
              -- on '==' of the constants). Insert the whole target tree
              -- instead of failing on a partial pattern match.
              Nothing -> [([], withRefs b')]

-- | Association-list lookup with a caller-supplied equality test.
--   The test receives the sought key first and the stored key second;
--   the value of the first matching entry is returned.
lookupWith :: (a -> a -> Bool) -> a -> [(a,b)] -> Maybe b
lookupWith f a = foldr pick Nothing
  where
    pick (key, value) later
      | f a key   = Just value
      | otherwise = later

-- | Choose between two optional edit lists: when both are present the
--   shorter one wins (the first on a tie), otherwise whichever is present.
best :: Maybe (Transformation a) -> Maybe (Transformation a) -> Maybe (Transformation a)
best (Just e1') (Just e2') = Just (pickShortest e1' e2')
best e1         e2         = e1 <|> e2

-- | Return the shorter of two lists, preferring the first when they are
--   equally long. Both lists are walked in lock-step only until one of
--   them ends, so one of the two may be infinite.
pickShortest :: [a] -> [a] -> [a]
pickShortest a b
  | a `endsNoLaterThan` b = a
  | otherwise             = b
  where
    endsNoLaterThan []       _        = True
    endsNoLaterThan _        []       = False
    endsNoLaterThan (_ : ls) (_ : rs) = ls `endsNoLaterThan` rs

-- | Embed a plain tree into the type of trees with references. Every
--   layer becomes an 'InR'; the result contains no 'Ref'.
withRefs :: Transform a => a -> Fix (WithRef a)
withRefs tree = In (InR (fmap withRefs (from tree)))

-- | Try to apply as many edits as possible directly to the edit structure,
--   to make the final edit smaller. Returns the updated structure together
--   with the edits that could not be applied, in their original order.
partialApply :: Transform a =>
                Fix (WithRef a) -> Transformation a -> (Fix (WithRef a), Transformation a)
partialApply tree edits = case edits of
  []             -> (tree, [])
  (p, r) : later -> case replace p r tree of
    -- The edit was absorbed into the structure; continue with the result
    Just tree' -> partialApply tree' later
    -- The edit could not be absorbed; keep it for the caller
    Nothing    -> (final, (p, r) : pending)
      where (final, pending) = partialApply tree later

-- | Replace the subtree at the given path of an edit structure by the
--   second argument. Fails (through the monad's 'fail') when the path
--   would have to descend into a 'Ref'.
replace :: (Transform a, Monad m)
           => Path -> Fix (WithRef a) -> Fix (WithRef a) -> m (Fix (WithRef a))
replace path new old = case path of
  []     -> return new
  i : is -> case old of
    In (Ref _)   -> fail "Replace"
    In (InR sub) -> liftM (In . InR) (tmapN step sub)
      where
        -- Recurse into child number i, keep every other child as it is
        step j | j == i    = replace is new
               | otherwise = return

-- | Given one edit list per child (in child order), prefix every path in
--   the n-th list with the child number n, counting from 0.
updateChildPaths :: [Transformation a] -> [Transformation a]
updateChildPaths = zipWith prefixAll [0..]
  where
    prefixAll n     = map (prefix n)
    prefix n (p, c) = (n : p, c)

--------------------------------------------------------------------------------
-- Shallow equality
--------------------------------------------------------------------------------

-- | Equality on the top layer of a structure only: the choice of
--   alternative and the constants are compared, while the contents of
--   recursive positions are ignored.
class SEq f where
  shallowEq :: f a -> f a -> Bool

-- Recursive positions always count as equal
instance SEq I where
  shallowEq l r = case (l, r) of
    (I _, I _) -> True

instance SEq U where
  shallowEq l r = case (l, r) of
    (U, U) -> True

-- Constants are compared with their own equality
instance Eq a => SEq (K a) where
  shallowEq (K x) (K y) = x == y

-- Different alternatives of a sum are never equal
instance (SEq f, SEq g) => SEq (f :+: g) where
  shallowEq l r = case (l, r) of
    (L x, L y) -> shallowEq x y
    (R x, R y) -> shallowEq x y
    _          -> False

instance (SEq f, SEq g) => SEq (f :*: g) where
  shallowEq (x1 :*: y1) (x2 :*: y2) = and [shallowEq x1 x2, shallowEq y1 y2]

instance SEq f => SEq (C c f) where
  shallowEq (C x) (C y) = shallowEq x y

instance SEq f => SEq (S s f) where
  shallowEq (S x) (S y) = shallowEq x y

--------------------------------------------------------------------------------
-- ExtractN
--------------------------------------------------------------------------------

-- | Select the n-th recursive position (counting from 0, left to right)
--   of one layer. Uses the monad's 'fail' when there is no such position.
class ExtractN f where
  extractN :: Monad m => Int -> f a -> m a

instance ExtractN I where
  extractN n (I child)
    | n == 0    = return child
    | otherwise = fail "extractN"

-- Constants hold no recursive positions
instance ExtractN (K a) where
  extractN _ (K _) = fail "extractN"

instance ExtractN U where
  extractN _ U = fail "extractN"

instance (ExtractN f, ExtractN g) => ExtractN (f :+: g) where
  extractN n alt = case alt of
    L x -> extractN n x
    R y -> extractN n y

-- Here we decrement our parameter. Does not require right-nested products
instance (CountI f, ExtractN f, ExtractN g) => ExtractN (f :*: g) where
  extractN n (x :*: y)
    | n < leftCount = extractN n x
    | otherwise     = extractN (n - leftCount) y
    where
      -- Number of recursive positions in the left component
      leftCount = countI x

instance ExtractN f => ExtractN (C c f) where
  extractN n (C inner) = extractN n inner

instance ExtractN f => ExtractN (S s f) where
  extractN n (S inner) = extractN n inner

--------------------------------------------------------------------------------
-- MapN
--------------------------------------------------------------------------------

-- | Monadic map over the recursive positions of one layer, passing each
--   position's index to the function. Indices start at 0.
tmapN :: (Monad m, MapN f) => (Int -> a -> m b) -> f a -> m (f b)
tmapN g layer = mapN 0 g layer

-- | Monadic map over the recursive positions of one layer. The 'Int'
--   argument is the index assigned to the first recursive position found;
--   the function receives each position's index along with its value.
class MapN f where
  mapN :: Monad m => Int -> (Int -> a -> m b) -> f a -> m (f b)

instance MapN I where
  mapN n g (I child) = g n child >>= return . I

-- Constants and units contain nothing to map over
instance MapN (K a) where
  mapN _ _ (K c)  = return (K c)

instance MapN U where
  mapN _ _ U = return U

instance (MapN f, MapN g) => MapN (f :+: g) where
  mapN n g alt = case alt of
    L x -> liftM L (mapN n g x)
    R y -> liftM R (mapN n g y)

-- Here we increment our parameter. Does not require right-nested products
instance (CountI f, MapN f, MapN g) => MapN (f :*: g) where
  mapN n g (x :*: y) = do
    x' <- mapN n g x
    -- The right component's indices continue after those of the left
    y' <- mapN (n + countI x) g y
    return (x' :*: y')

instance MapN f => MapN (C c f) where
  mapN n g (C inner) = liftM C (mapN n g inner)

instance MapN f => MapN (S s f) where
  mapN n g (S inner) = liftM S (mapN n g inner)


--------------------------------------------------------------------------------
-- CountI
--------------------------------------------------------------------------------

class CountI f where
  -- | Count the number of recursive occurrences in one layer
  countI :: f a -> Int

-- A recursive position counts for one
instance CountI I where
  countI _ = 1

-- Constants and units count for nothing
instance CountI (K a) where
  countI _ = 0

instance CountI U where
  countI _ = 0

instance (CountI f, CountI g) => CountI (f :+: g) where
  countI alt = case alt of
    L x -> countI x
    R y -> countI y

instance (CountI f, CountI g) => CountI (f :*: g) where
  countI (x :*: y) = sum [countI x, countI y]

instance CountI f => CountI (C c f) where
  countI (C inner) = countI inner

instance CountI f => CountI (S s f) where
  countI (S inner) = countI inner

--------------------------------------------------------------------------------
-- Children
--------------------------------------------------------------------------------

-- | The direct children of a value, i.e. the recursive positions of its
--   top layer, from left to right.
imChildren :: (Regular a, Children (PF a)) => a -> [a]
imChildren node = children (from node)

-- | Every subtree of a value paired with its path from the root. The
--   value itself comes first, with the empty path, followed by the
--   subtrees below each immediate child in child order.
childrenPaths :: (Regular a, Children (PF a)) => a -> [(a,Path)]
childrenPaths root = (root, []) : concat (zipWith below [0..] (imChildren root))
  where
    -- All subtrees under child number n, with n prefixed to their paths
    below n child = [ (sub, n : p) | (sub, p) <- childrenPaths child ]

-- | Collect the values at the recursive positions of one layer, from
--   left to right.
class Children f where
  children :: f a -> [a]

instance Children I where
  children (I child) = child : []

-- Constants and units have no children
instance Children (K a) where
  children (K _) = []

instance Children U where
  children U = []

instance (Children f, Children g) => Children (f :+: g) where
  children alt = case alt of
    L x -> children x
    R y -> children y

instance (Children f, Children g) => Children (f :*: g) where
  children (x :*: y) = concat [children x, children y]

instance Children f => Children (C c f) where
  children (C inner) = children inner

instance Children f => Children (S s f) where
  children (S inner) = children inner

--------------------------------------------------------------------------------
-- Nicer interface
--------------------------------------------------------------------------------
-- | Types @a@ whose trees-with-references have a dedicated representation
--   @'RefRep' a@ that can be converted to and from 'WithRef' one layer at
--   a time.
class HasRef a where
  -- | The representation of an @a@ tree that may contain references.
  type RefRep a
  
  -- | Build a representation from one 'WithRef' layer whose children are
  --   already in representation form.
  toRef   :: WithRef a (RefRep a) -> RefRep a
  -- | Expose the top 'WithRef' layer of a representation, leaving its
  --   children in representation form.
  fromRef :: RefRep a -> WithRef a (RefRep a)

-- | Like 'Transformation', but with each replacement tree given in its
--   'RefRep' representation.
type NiceTransformation a = [ (Path, RefRep a) ] 

-- | Convert the replacement tree of every edit into its 'RefRep'
--   representation; the paths are left untouched.
toNiceTransformation :: (Functor (PF a), HasRef a)
                        => Transformation a -> NiceTransformation a
toNiceTransformation = map convert
  where
    convert (path, tree) = (path, collapse tree)
    -- Bottom-up conversion: children first, then 'toRef' on the layer
    collapse = toRef . fmap collapse . out

-- | Convert the replacement of every edit from its 'RefRep'
--   representation back into a 'WithRef' tree; the paths are left
--   untouched.
fromNiceTransformation :: (Functor (PF a), HasRef a)
                          => NiceTransformation a -> Transformation a
fromNiceTransformation = map convert
  where
    convert (path, rep) = (path, expand rep)
    -- Top-down conversion: 'fromRef' on the layer, then its children
    expand = In . fmap expand . fromRef