{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE ScopedTypeVariables   #-}

module Generics.MultiRec.Transformations.Explicit (
  diff, apply, Transformation, AnyInsert (..), WithRef (..), Path, 
  Transform, OrdI (..),
  HasRef (..), NiceTransformation, NiceInsert (..),
  toNiceTransformation, fromNiceTransformation
  ) where

import Generics.MultiRec.Any
import Generics.MultiRec.Eq
import Generics.MultiRec.Ord

import Generics.MultiRec hiding (show, foldM)
import Control.Applicative ( (<|>) )
import Control.Monad (foldM)
import Control.Monad.State hiding (foldM)
import Data.Monoid (mappend)
import qualified Data.Map as Map
import Data.Map (Map)

--------------------------------------------------------------------------------
-- Paths, annotations, edits and existentials
--------------------------------------------------------------------------------
-- | A node of an edit tree: either an ordinary node of the family's pattern
--   functor (whose recursive positions are again edit trees, via 'HFix'), or
--   a reference to the subtree found at a 'Path' in the original tree.
data WithRef phi f a
  = InR (PF phi f a)
  | Ref Path

-- | Location of a subtree: the child index chosen at each level, starting
--   from the root. The empty path denotes the root itself.
type Path = [Int]

-- | One edit: put the given edit tree (of family index @ix@) at the given
--   path, replacing whatever was there.
data AnyInsert phi where
  AnyInsert :: phi ix -> Path -> HFix (WithRef phi) ix -> AnyInsert phi

-- | A transformation is a list of inserts, applied one after the other.
type Transformation phi = [AnyInsert phi]

-- | All the class constraints that 'diff' and 'apply' need, bundled so that
--   a family only has to state @Transform phi@.
class ( Fam phi
      , Children phi (PF phi), CountI phi (PF phi), HFunctor phi (PF phi)
      , SEq phi (PF phi), ExtractN phi (PF phi), MapN phi (PF phi)
      , EqS phi, HEq phi (PF phi), HOrd phi (PF phi)
      , OrdI phi
      ) => Transform phi

--------------------------------------------------------------------------------
-- Applying
--------------------------------------------------------------------------------
-- | Apply the transformation to the given tree.
--
--   The inserts are applied in order, each one to the result of the previous
--   one. References ('Ref') inside inserted trees are always resolved against
--   the /original/ tree @t@, not against intermediate results.
--
--   Returns 'Nothing' when an insert cannot be applied: the family index at
--   its target does not match, a reference cannot be resolved, or its path
--   uses a child index that does not exist.
apply :: forall phi ix. (Transform phi)
         => phi ix -> ix -> Transformation phi -> Maybe ix
apply p t = foldM (apply' p) t where
  -- Apply a single insert to the subtree @a@ (of family index @p'@).
  apply' :: forall ix. phi ix -> ix -> AnyInsert phi -> Maybe ix
  -- Empty path: the insert replaces this subtree, provided the types agree.
  apply' p' _ (AnyInsert p'' [] c) = case eqS p' p'' of
    Just Refl -> lookupRefs p t p' c
    Nothing   -> Nothing
  -- Non-empty path: descend into child @i@. 'tmapN' numbers the children
  -- 0 .. countI-1, so an index outside that range would never be visited and
  -- the insert would be silently dropped; report it as a failure instead.
  apply' p' a (AnyInsert p'' (i:is) c)
    | i < 0 || i >= countI p' node = Nothing
    | otherwise                    = liftM (to p') $ tmapN f p' node
    where
      node = from p' a
      f :: forall ix. Int -> phi ix -> I0 ix -> Maybe (I0 ix)
      f j p''' x | i == j    = liftM I0 (apply' p''' (unI0 x) (AnyInsert p'' is c))
                 | otherwise = return x

-- | Rebuild a plain value from an edit tree.
--
--   'InR' nodes are converted back with 'to'. Each 'Ref' path is resolved by
--   extracting the subtree at that path from the original value @r@ (of
--   family index @p@) and matching it, with 'matchAny', against the family
--   index expected at that position. Any failure makes the result 'Nothing'.
lookupRefs :: forall phi ix ix'. (Fam phi, HFunctor phi (PF phi), ExtractN phi (PF phi), EqS phi)
              => phi ix -> ix -> phi ix' -> HFix (WithRef phi) ix' -> Maybe ix'
lookupRefs p r = rebuild
  where
    -- Children may live at other family indices, hence the polymorphic type.
    rebuild :: forall t. phi t -> HFix (WithRef phi) t -> Maybe t
    rebuild q (HIn node) = case node of
      InR x -> liftM (to q) (hmapM (\q' -> liftM I0 . rebuild q') q x)
      Ref l -> extract l p r >>= matchAny q

-- | Extract the subtree at the given path: starting from the root, take the
--   child with the next index of the path (via 'extractN') at every step.
--   The empty path yields the value itself.
extract :: (Fam phi, ExtractN phi (PF phi)) => Path -> phi ix -> ix -> Maybe (Any phi)
extract path p a = foldM (\(Any q x) i -> extractN i q x) (Any p a) path

--------------------------------------------------------------------------------
-- Memoisation
--------------------------------------------------------------------------------
-- | Ordering on the indices of a family, so that indices of different types
--   can be compared (used to order memoisation keys of different types).
--
--   Instances must define 'compareI' or 'indexI'. The MINIMAL pragma makes
--   the compiler warn when neither is given, instead of the default methods
--   calling each other and ending in a runtime error.
class OrdI phi where
  -- | Compare two family indices, possibly of different types.
  compareI :: phi ix -> phi ix' -> Ordering
  compareI p1 p2 = compare (indexI p1) (indexI p2)
  -- | Number a family index; only consulted by the default 'compareI'.
  indexI :: phi ix -> Int
  indexI = error "At least compareI or indexI should be implemented."
  {-# MINIMAL compareI | indexI #-}

-- | Key of the memoisation table: the family index and the two values being
--   compared, plus the 'Bool' flag passed to the memoised function.
data MemoKey phi where
  MemoKey :: phi ix -> Bool -> ix -> ix -> MemoKey phi

-- Keys are equal only when their family indices agree ('eqS'); the flags and
-- the values are then compared, the values structurally with 'eq'.
instance (EqS phi, Fam phi, HEq phi (PF phi)) => Eq (MemoKey phi) where
  MemoKey px fx lx rx == MemoKey py fy ly ry =
    case eqS px py of
      Just Refl -> fx == fy && eq px lx ly && eq px rx ry
      Nothing   -> False

-- Keys with different family indices are ordered by 'compareI'; keys with the
-- same index lexicographically by flag, first value, then second value.
instance (EqS phi, Fam phi, OrdI phi, HEq phi (PF phi), HOrd phi (PF phi))
         => Ord (MemoKey phi) where
  compare (MemoKey px fx lx rx) (MemoKey py fy ly ry) =
    case eqS px py of
      Just Refl -> compare fx fy `mappend` gcompare px lx ly
                                 `mappend` gcompare px rx ry
      Nothing   -> compareI px py

-- | The memo table: computed transformations keyed by 'MemoKey'.
type MemoTable phi = Map (MemoKey phi) (Transformation phi)

-- | Computations that read and update a 'MemoTable'.
type Memo phi a = State (MemoTable phi) a

-- | Run a memoised computation starting from an empty table; the final table
--   is discarded.
runMemo :: Memo phi a -> a
runMemo m = evalState m Map.empty

-- | Memoise one call of @f@. If the table already has an entry for the key
--   built from the arguments, that entry is returned. Otherwise @f@ is run
--   and its result is stored under the key before being returned.
recMemo :: (Fam phi, HEq phi (PF phi), HOrd phi (PF phi), EqS phi, OrdI phi) =>
           (forall ix. Bool -> phi ix -> ix -> ix -> Memo phi (Transformation phi))
           -> Bool -> phi ix -> ix -> ix -> Memo phi (Transformation phi)
recMemo f a p b c = gets (Map.lookup key) >>= maybe compute return
  where
    key = MemoKey p a b c
    compute = do
      result <- f a p b c
      modify (Map.insert key result)
      return result

--------------------------------------------------------------------------------
-- Diffing
--------------------------------------------------------------------------------
-- | Find a set of insertions to transform the first into the second tree.
--
--   @diff p a b@ returns inserts meant to turn @a@ into @b@ when run with
--   'apply'. Inserted trees may contain 'Ref's to subtrees of @a@. Candidate
--   edit lists are compared by length ('pickBest'), and sub-results are
--   memoised ('recMemo') for the whole run.
diff :: forall phi ix. (Transform phi)
        => phi ix -> ix -> ix -> Transformation phi
diff p a b = runMemo (build False p a b)
  where
    -- Every subtree of the original tree, paired with its path from the root.
    childPaths :: [(Any phi, Path)]
    childPaths = childrenPaths p a
    -- @build ins p' a' b'@: edits turning a' into b'. The flag 'ins' is set
    -- while diffing a tree that was just inserted (see 'insert' below).
    build :: forall ix. Bool -> phi ix -> ix -> ix -> Memo phi (Transformation phi)
    -- Outside an insertion, equal subtrees need no edits at all.
    build False p' a' b' | eq p' a' b' = return []
    -- If b' occurs somewhere in the original tree, one reference suffices.
    build ins   p' a' b' = case anyLookup p' b' childPaths of
      Just l  -> return [ AnyInsert p' [] (HIn $ Ref l) ]
      Nothing -> uses >>= maybe insert return -- Only insert when we cannot reuse
        where
          -- Edits turning c into b' when both have the same top-level shape
          -- ('shallowEq'): the immediate children are diffed pairwise through
          -- the memo table, and each child's edits get the child's index
          -- prefixed to their paths. Nothing when the shapes differ.
          -- NOTE(review): the 'case eqS' below has no Nothing branch, so it
          -- assumes shallowly-equal nodes have children of the same family
          -- indices in the same order. 'zipWithM' also stops at the shorter
          -- child list; that can happen for lists, which 'shallowEq' compares
          -- by length only.
          construct :: Bool -> ix -> Memo phi (Maybe (Transformation phi))
          construct ins' c = 
            if shallowEq p' (from p' c) (from p' b')
            then do r <- zipWithM (\(Any p1 c1) (Any p2 c2) -> case eqS p1 p2 of
                                      Just Refl -> recMemo build ins' p1 c1 c2)
                         (imChildren p' c) (imChildren p' b')
                    return $ Just $ concat $ updateChildPaths r
            else return Nothing
          -- Edits that reuse something. 'reuses' tries subtrees of the
          -- original tree; 'construct ins a'' edits the current subtree a' in
          -- place, and the shorter of the two is kept. When 'ins' is set and
          -- 'reuses' found an answer, editing a' in place is not tried, to
          -- keep inserts small.
          uses :: Memo phi (Maybe (Transformation phi))
          uses = reuses >>= \re -> case re of
              Just r | ins -> return re
              _            -> construct ins a' >>= return . pickBest re
          -- Consider each subtree x of the original tree that has the same
          -- family index as b'. The candidate is a reference to x followed by
          -- the edits turning x into b' ('construct False'). The shortest
          -- candidate found is kept.
          reuses :: Memo phi (Maybe (Transformation phi))
          reuses = foldM f Nothing childPaths where
            addRef :: Path -> Maybe (Transformation phi) 
                       -> Maybe (Transformation phi)
            addRef l = liftM ((AnyInsert p' [] (HIn $ Ref l)):)
            f c (Any p'' x, l) = case eqS p' p'' of
              Just Refl -> construct False x >>= return . pickBest c . addRef l
              Nothing   -> return c
          -- Fallback when nothing can be reused: insert b' itself. The edits
          -- of b' against itself (computed with ins = True) presumably turn
          -- parts of the copy into references to the original. 'partialApply'
          -- folds as many of them as possible into the inserted tree, and the
          -- remaining ones follow the root insert.
          -- NOTE(review): the pattern 'Just r' relies on 'shallowEq' being
          -- reflexive on b'. Otherwise the State monad's 'fail' raises an
          -- error, e.g. for a 'K' field whose (==) is not reflexive (NaN).
          insert :: Memo phi (Transformation phi)
          insert = do
            Just r <- construct True b'
            let (r',e')  = partialApply p' (annotate p' b') r
            return $ (AnyInsert p' [] r') : e'

-- | Choose between two optional edit lists: the shorter one when both are
--   present (the first one on a tie, as 'pickShortest' does), otherwise
--   whichever one is present.
pickBest :: Maybe (Transformation phi) -> Maybe (Transformation phi) -> Maybe (Transformation phi)
pickBest (Just xs) (Just ys) = Just (pickShortest xs ys)
pickBest e1        e2        = e1 <|> e2

-- | Return whichever list is shorter, preferring the first one on a tie.
--   The function is lazy: both lists are walked only as far as the end of the
--   shorter one, so the other list may be infinite.
pickShortest :: [a] -> [a] -> [a]
pickShortest xs ys = walk xs ys
  where
    walk (_ : xs') (_ : ys') = walk xs' ys'
    walk []        _         = xs
    walk _         []        = ys

-- | Look up a value of a specific family index in an association list keyed
--   by 'Any'. The first entry whose key has the same index as @p@ ('eqS') and
--   is equal to @x@ ('eq') wins.
anyLookup :: (Fam phi, EqS phi, HEq phi (PF phi))
             => phi ix -> ix -> [(Any phi, a)] -> Maybe a
anyLookup _ _ [] = Nothing
anyLookup p x ((Any q y, r) : rest)
  | Just Refl <- eqS p q, eq p x y = Just r
  | otherwise                      = anyLookup p x rest

-- | Turn a plain value into an edit tree with the same structure, built only
--   from 'InR' nodes (no references).
annotate :: (Fam phi, HFunctor phi (PF phi)) => phi ix -> ix -> HFix (WithRef phi) ix
annotate p x = HIn $ InR $ hmap (\q (I0 y) -> annotate q y) p (from p x)

-- | Prefix the paths of each child's edits with that child's index: every
--   edit in the first list gets 0 prepended, the second list 1, and so on.
updateChildPaths :: [Transformation phi] -> [Transformation phi]
updateChildPaths ts = [ map (prefix n) t | (n, t) <- zip [0 ..] ts ]
  where prefix n (AnyInsert p l c) = AnyInsert p (n : l) c

-- | Fold as many inserts as possible into an edit tree, to make the final
--   transformation smaller.
--
--   Each insert is tried in order with 'replace' against the tree built so
--   far. Inserts that succeed are absorbed into the tree; the others are
--   returned, in their original order, alongside the final tree.
partialApply :: (Fam phi, CountI phi (PF phi), ExtractN phi (PF phi), MapN phi (PF phi), EqS phi)
                => phi ix -> HFix (WithRef phi) ix -> Transformation phi -> (HFix (WithRef phi) ix, Transformation phi)
partialApply _ tree [] = (tree, [])
partialApply p tree (ins@(AnyInsert q path sub) : rest) =
  case replace q path sub p tree of
    Just tree' -> partialApply p tree' rest
    Nothing    -> let (final, kept) = partialApply p tree rest
                  in  (final, ins : kept)

-- | Replace a subtree in an edit tree.
--
--   @replace p path r p' a@ puts @r@ (of family index @p@) at @path@ inside
--   @a@ (of family index @p'@). Returns 'Nothing' when the family index at the
--   target does not match, or when the path runs into a 'Ref' node.
--   NOTE(review): a child index that does not exist is never visited by
--   'tmapN', so @a@ comes back unchanged (as 'Just').
replace :: forall phi ix ix'. (Fam phi, EqS phi, MapN phi (PF phi))
           => phi ix -> Path -> HFix (WithRef phi) ix
           -> phi ix' -> HFix (WithRef phi) ix' -> Maybe (HFix (WithRef phi) ix')
replace p path r p' a = case path of
  [] -> case eqS p p' of
    Just Refl -> Just r
    Nothing   -> Nothing
  i : is -> case hout a of
    Ref _     -> Nothing
    InR inner -> liftM (HIn . InR) (tmapN step p' inner)
      where
        step :: forall ix. Int -> phi ix -> HFix (WithRef phi) ix -> Maybe (HFix (WithRef phi) ix)
        step j q sub
          | j == i    = replace p is r q sub
          | otherwise = Just sub

--------------------------------------------------------------------------------
-- Shallow equality
--------------------------------------------------------------------------------

-- | Shallow equality: do two pattern functor values have the same top-level
--   shape? Constructors and constant ('K') fields are compared; the contents
--   of recursive positions ('I') are not inspected.
class SEq phi (f :: (* -> *) -> * -> *) where
  shallowEq :: phi ix -> f r ix -> f r ix -> Bool

instance El phi xi => SEq phi (I xi) where
  shallowEq _ (I _) (I _) = True

instance SEq phi U where
  shallowEq _ U U = True

instance Eq a => SEq phi (K a) where
  shallowEq _ (K x) (K y) = x == y

instance (SEq phi f, SEq phi g) => SEq phi (f :+: g) where
  shallowEq p lhs rhs = case (lhs, rhs) of
    (L x, L y) -> shallowEq p x y
    (R x, R y) -> shallowEq p x y
    _          -> False

instance (SEq phi f, SEq phi g) => SEq phi (f :*: g) where
  shallowEq p (x1 :*: y1) (x2 :*: y2) = shallowEq p x1 x2 && shallowEq p y1 y2

instance SEq phi f => SEq phi (f :>: ix) where
  shallowEq p (Tag x) (Tag y) = shallowEq p x y

instance SEq phi f => SEq phi (C c f) where
  shallowEq p (C x) (C y) = shallowEq p x y

-- Lists are compared by length only; their elements are not inspected.
-- Todo: is this the best choice?
instance SEq phi ([] :.: f) where
  shallowEq _ (D xs) (D ys) = length xs == length ys

--------------------------------------------------------------------------------
-- ExtractN
--------------------------------------------------------------------------------

-- | Extract the immediate recursive child with index @i@ of a value, wrapped
--   in 'Any'. Children are numbered from 0 as the 'ExtractN' instances do.
extractN :: (Fam phi, ExtractN phi (PF phi), Monad m)
            => Int -> phi ix -> ix -> m (Any phi)
extractN i p v = extractN' wrap i p (from p v)
  where wrap q (I0 y) = Any q y

-- | Extract the recursive child with a given index from a pattern functor
--   value. The function argument wraps the child that is found; indices count
--   the recursive positions from 0, left to right. Positions without such a
--   child fail in @m@.
class ExtractN phi (f :: (* -> *) -> * -> *) where
  extractN' :: Monad m => (forall ix. phi ix -> r ix -> r')
                          -> Int -> phi ix -> f r ix -> m r'

-- A single recursive position: only index 0 exists.
instance El phi xi => ExtractN phi (I xi) where
  extractN' mk n _ (I x)
    | n == 0    = return (mk proof x)
    | otherwise = fail "extractN"

-- Constants and unit contain no recursive positions.
instance ExtractN phi (K a) where
  extractN' _ _ _ (K _) = fail "extractN"

instance ExtractN phi U where
  extractN' _ _ _ U = fail "extractN"

instance (ExtractN phi f, ExtractN phi g) => ExtractN phi (f :+: g) where
  extractN' mk n p s = case s of
    L x -> extractN' mk n p x
    R y -> extractN' mk n p y

-- The left component owns the first 'countI' indices, the right one the rest.
instance (CountI phi f, ExtractN phi f, ExtractN phi g) => ExtractN phi (f :*: g) where
  extractN' mk n p (x :*: y)
    | n < k     = extractN' mk n       p x
    | otherwise = extractN' mk (n - k) p y
    where k = countI p x

instance ExtractN phi f => ExtractN phi (f :>: ix) where
  extractN' mk n p (Tag x) = extractN' mk n p x

instance ExtractN phi f => ExtractN phi (C c f) where
  extractN' mk n p (C x) = extractN' mk n p x

-- Children inside a list are numbered consecutively across its elements, in
-- the same way as the 'CountI', 'MapN' and 'Children' instances for lists:
-- each element owns the next 'countI' indices. The list must not be indexed
-- with the child number directly, because an element may hold zero or
-- several recursive children. An out-of-range (or negative) index fails in
-- @m@ instead of crashing.
instance (CountI phi f, ExtractN phi f) => ExtractN phi ([] :.: f) where
  extractN' _  _ _ (D [])       = fail "extractN"
  extractN' mk i p (D (x : xs))
    | i < n     = extractN' mk i       p x
    | otherwise = extractN' mk (i - n) p (D xs)
    where n = countI p x

--------------------------------------------------------------------------------
-- MapN
--------------------------------------------------------------------------------

-- | Monadic map over the immediate recursive positions of a top-level
--   pattern functor value. Each position also receives its child index,
--   counted from 0.
tmapN :: (Fam phi, MapN phi f, Monad m)
         => (forall ix. Int -> phi ix -> r ix -> m (r' ix))
         -> phi ix -> f r ix -> m (f r' ix)
tmapN g = mapN 0 g

-- | Monadic map over the recursive positions of a pattern functor value.
--   The function also receives each position's child index. The 'Int'
--   argument is the index of the first recursive position; later positions
--   are numbered consecutively (using 'countI').
class MapN phi (f :: (* -> *) -> * -> *) where
  mapN :: Monad m => Int -> (forall ix. Int -> phi ix -> r ix -> m (r' ix))
                           -> phi ix -> f r ix -> m (f r' ix)

instance El phi xi => MapN phi (I xi) where
  mapN n h _ (I x) = h n proof x >>= return . I

instance MapN phi (K a) where
  mapN _ _ _ (K k) = return (K k)

instance MapN phi U where
  mapN _ _ _ U = return U

instance (MapN phi f, MapN phi g) => MapN phi (f :+: g) where
  mapN n h p s = case s of
    L x -> liftM L (mapN n h p x)
    R y -> liftM R (mapN n h p y)

-- The right component's positions are numbered after the left component's.
-- Does not require right-nested products.
instance (CountI phi f, MapN phi f, MapN phi g) => MapN phi (f :*: g) where
  mapN n h p (x :*: y) = do
    x' <- mapN n h p x
    y' <- mapN (n + countI p x) h p y
    return (x' :*: y')

instance MapN phi f => MapN phi (f :>: ix) where
  mapN n h p (Tag x) = liftM Tag (mapN n h p x)

instance MapN phi f => MapN phi (C c f) where
  mapN n h p (C x) = liftM C (mapN n h p x)

-- List elements are processed left to right; each element's positions are
-- numbered after those of the elements before it.
-- Todo: is this the best choice?
instance (CountI phi f, MapN phi f) => MapN phi ([] :.: f) where
  mapN n0 h p (D xs0) = liftM D (go n0 xs0)
    where
      go _ []       = return []
      go n (x : xs) = liftM2 (:) (mapN n h p x) (go (n + countI p x) xs)

--------------------------------------------------------------------------------
-- CountI
--------------------------------------------------------------------------------

-- | Count the recursive positions (occurrences of 'I') in a pattern functor
--   value, i.e. how many child indices the value uses.
class CountI phi (f :: (* -> *) -> * -> *) where
  countI :: phi ix -> f r ix -> Int

instance El phi xi => CountI phi (I xi) where
  countI _ _ = 1

instance CountI phi (K a) where
  countI _ _ = 0

instance CountI phi U where
  countI _ _ = 0

instance (CountI phi f, CountI phi g) => CountI phi (f :+: g) where
  countI p s = case s of
    L x -> countI p x
    R y -> countI p y

instance (CountI phi f, CountI phi g) => CountI phi (f :*: g) where
  countI p (x :*: y) = countI p x + countI p y

instance CountI phi f => CountI phi (f :>: ix) where
  countI p (Tag x) = countI p x

instance CountI phi f => CountI phi (C c f) where
  countI p (C x) = countI p x

-- A list contributes the positions of all its elements.
-- Todo: is this the best choice?
instance CountI phi f => CountI phi ([] :.: f) where
  countI p (D xs) = foldr (\x acc -> countI p x + acc) 0 xs

--------------------------------------------------------------------------------
-- Children
--------------------------------------------------------------------------------

-- | The immediate recursive children of a value, each wrapped in 'Any', in
--   left-to-right order.
imChildren :: (Fam phi, Children phi (PF phi)) => phi ix -> ix -> [Any phi]
imChildren p x = children wrap p (from p x)
  where wrap q (I0 y) = Any q y

-- | Every subtree of a value, paired with its path from the root, in
--   pre-order: the value itself (with the empty path) comes first. A path
--   lists the child index taken at each level.
childrenPaths :: (Fam phi, Children phi (PF phi)) => phi ix -> ix -> [(Any phi, Path)]
childrenPaths p a = (Any p a, []) : concat
  [ [ (sub, n : path) | (sub, path) <- childrenPaths q c ]
  | (n, Any q c) <- zip [0 ..] (imChildren p a)
  ]

-- | The immediate recursive children of a pattern functor value, from left
--   to right, each wrapped in 'Any' by the given function.
class Children phi (f :: (* -> *) -> * -> *) where
  children :: (forall ix. phi ix -> r ix -> Any phi) -> phi ix -> f r ix -> [Any phi]

instance (Fam phi, El phi xi) => Children phi (I xi) where
  children mk _ (I x) = [mk proof x]

instance Children phi (K a) where
  children _ _ (K _) = []

instance Children phi U where
  children _ _ U = []

instance (Children phi f, Children phi g) => Children phi (f :+: g) where
  children mk p s = case s of
    L x -> children mk p x
    R y -> children mk p y

instance (Children phi f, Children phi g) => Children phi (f :*: g) where
  children mk p (x :*: y) = children mk p x ++ children mk p y

instance Children phi f => Children phi (C c f) where
  children mk p (C x) = children mk p x

instance Children phi f => Children phi (f :>: ix) where
  children mk p (Tag x) = children mk p x

-- The children of a list are those of its elements, in order.
-- Todo: is this the best choice?
instance Children phi f => Children phi ([] :.: f) where
  children mk p (D xs) = [ c | x <- xs, c <- children mk p x ]

--------------------------------------------------------------------------------
-- Nicer interface
--------------------------------------------------------------------------------

-- | Families that provide their own representation of edit trees
--   ('RefRep'), with conversions to and from the generic 'HFix' form.
class HasRef phi where
  -- The family's own representation of an edit tree of index @ix@.
  type RefRep phi ix
  -- NOTE: both conversions work on the whole 'HFix' tree. A single-step
  -- unwrapping would be preferable, but how to express it is still open.
  -- | Convert a generic edit tree to the family's representation.
  toRef   :: phi ix -> HFix (WithRef phi) ix -> RefRep phi ix
  -- | Convert the family's representation back to a generic edit tree.
  fromRef :: phi ix -> RefRep phi ix -> HFix (WithRef phi) ix

-- | Like 'AnyInsert', but the inserted tree is in the family's 'RefRep'
--   representation.
data NiceInsert phi where
  NiceInsert :: phi ix -> Path -> RefRep phi ix -> NiceInsert phi

-- | A transformation made of 'NiceInsert's.
type NiceTransformation phi = [ NiceInsert phi]

-- | Convert every insert's edit tree to the family's representation with
--   'toRef'; family indices and paths are kept as they are.
toNiceTransformation :: HasRef phi => Transformation phi -> NiceTransformation phi
toNiceTransformation ts = map (\(AnyInsert p l x) -> NiceInsert p l (toRef p x)) ts

-- | Convert every insert's edit tree back to the generic form with
--   'fromRef'; family indices and paths are kept as they are.
fromNiceTransformation :: HasRef phi => NiceTransformation phi -> Transformation phi
fromNiceTransformation ns = map (\(NiceInsert p l x) -> AnyInsert p l (fromRef p x)) ns