{-# LANGUAGE DeriveDataTypeable   #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE GADTs                #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE EmptyDataDecls       #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE OverlappingInstances #-}

{-# OPTIONS_GHC -Wall             #-}

-- |
-- Module      :  Generics.Instant.Rewriting
-- Copyright   :  (c) 2010, Universiteit Utrecht
-- License     :  BSD3
-- Maintainer  :  generics@haskell.org
-- Stability   :  experimental
-- Portability :  non-portable
-- This is the top module for the rewriting library. All functionality is
-- implemented in this module. For examples of how to use the library, see
-- the included files in the directory examples, or the benchmark in the
-- directory performance.

module Generics.Instant.Rewriting (
    -- * The class to signal availability of rewriting for a type
    Rewritable,

    -- * Top-level functions
    rewrite, rewriteM, validate, synthesise,

    -- * Building rewrite rules
    Template(..), (+->), (//), Rule,

    -- * Internal classes: might be necessary to add new base types
    Sampleable(..), Empty(..), HasRec(..), Finite, True, False,

    -- * Re-exported for convenience
    module Generics.Instant.Base
  ) where

import Control.Monad  (join, liftM, liftM2)
import Data.Maybe     (fromJust, fromMaybe)
import Data.Typeable  (Typeable, gcast)

import Generics.Instant.Base
import Generics.Instant.Instances ()

-- | Typed references into an environment. A @'Ref' a gam@ selects a
-- position holding type @a@ in the right-nested product @gam@: 'Rz' is the
-- head of the environment and 'Rs' skips one element.

data Ref :: * -> * -> * where
  Rz :: Ref a (a :*: gam)
  Rs :: Ref a gam -> Ref a (b :*: gam)

-- Two references are equal exactly when they point at the same position.
instance Eq (Ref a gam) where
  (==) Rz     Rz     = True
  (==) (Rs p) (Rs q) = p == q
  (==) _      _      = False

-- | The 'Rewritable' class is used to signal types that can be rewritten and
-- to \"tie the recursive knot\" of the generic functions. It has no methods.
-- Its superclasses require the type to be 'Representable', 'Typeable' and
-- comparable with 'Eq'. They also require every generic operation of this
-- module (default values, schemes, matching, substitution, sampling, diffing
-- and validation) to be available on its representation 'Rep'.

class (Representable a, Typeable a, Eq a, Empty (Rep a),
       Extensible (Rep a), Matchable (Rep a), Substitutable (Rep a),
       Sampleable (Rep a), Diffable (Rep a), Validatable (Rep a)) =>
  Rewritable a

-- The primitive base types are rewritable. The generic classes of this
-- module each provide Int, Float and Char instances for the superclasses.
instance Rewritable Int
instance Rewritable Float
instance Rewritable Char

-- Rewrite rules

-- | Metavariables are typed references into a rule's environment @gam@.
type Metavar a gam = Ref a gam

-- | Schemes. 'Ext' adds an environment parameter @gam@ to a representation
-- type, so that recursive and parameter positions ('Rec', 'Var') can hold a
-- 'Scheme': either a value or a metavariable. 'toExt' embeds a plain
-- representation value using the empty environment 'U'.
class Extensible a where
  data Ext a :: * -> *
  toExt :: a -> Ext a U

-- | A scheme for @a@ is either an extended representation of @a@ (which may
-- contain metavariables at nested positions) or a metavariable standing for
-- a whole value of type @a@.
type Scheme a gam = Ext (Rep a) gam :+: Metavar a gam

-- | Embed a plain value as a scheme that contains no metavariables.
toScheme :: Rewritable a => a -> Scheme a U
toScheme x = L (toExt (from x))

-- Primitive base types: the literal is wrapped unchanged.
instance Extensible Int where
  newtype Ext Int gam = ExtInt Int
  toExt n = ExtInt n

instance Extensible Float where
  newtype Ext Float gam = ExtFloat Float
  toExt f = ExtFloat f

instance Extensible Char where
  newtype Ext Char gam = ExtChar Char
  toExt c = ExtChar c

-- Sums keep the injection and extend the chosen alternative.
instance (Extensible a, Extensible b) => Extensible (a :+: b) where
  newtype Ext (a :+: b) gam = ExtSum (Ext a gam :+: Ext b gam)

  toExt s = ExtSum $ case s of
    L x -> L (toExt x)
    R y -> R (toExt y)

-- Products extend both components.
instance (Extensible a, Extensible b) => Extensible (a :*: b) where
  newtype Ext (a :*: b) gam = ExtCons (Ext a gam :*: Ext b gam)
  toExt p = case p of
    x :*: y -> ExtCons (toExt x :*: toExt y)

instance Extensible U where
  newtype Ext U gam = ExtNil U
  toExt u = ExtNil u

-- Recursive and parameter positions hold schemes of the nested type.
instance (Rewritable a) => Extensible (Rec a) where
  newtype Ext (Rec a) gam = ExtRec (Rec (Scheme a gam))
  toExt (Rec v) = ExtRec . Rec $ toScheme v

instance (Rewritable a) => Extensible (Var a) where
  newtype Ext (Var a) gam = ExtVar (Var (Scheme a gam))
  toExt (Var v) = ExtVar . Var $ toScheme v

-- Constructor information is preserved around the extended fields.
instance (Extensible a) => Extensible (C c a) where
  newtype Ext (C c a) gam = ExtC (C c (Ext a gam))
  toExt (C inner) = ExtC . C $ toExt inner

-- Guards. The guard of a rule over environment @gam@ is a curried predicate
-- that takes one argument per metavariable of @gam@ and returns 'Bool'.
type family Guard gam :: *
type instance Guard U         = Bool
type instance Guard (a :*: gam) = a -> Guard gam

-- Rules: a left-hand side and a right-hand side scheme over the same
-- metavariable environment @gam@, plus a guard over that environment.
data Rule' a gam =
  Rule' { lhs :: Scheme a gam, rhs :: Scheme a gam, guard :: Guard gam }

-- Rewriting

-- Substitutions hold one optional binding per metavariable of the
-- environment; 'Nothing' marks a metavariable that has not been bound. Each
-- binding carries the 'Rewritable' dictionary of its type.
data Subst :: * -> * where
  Sz :: Subst U
  Ss :: Rewritable a => Maybe a -> Subst gam -> Subst (a :*: gam)

-- | Look up the binding of a metavariable in a substitution. The result is
-- 'Nothing' when the metavariable is unbound.
(!!!) :: Subst gam -> Ref a gam -> Maybe a
(!!!) (Ss slot _)    Rz       = slot
(!!!) (Ss _    rest) (Rs ref) = rest !!! ref
(!!!) _              _        = error "(!!!) failure"

-- | Look up a metavariable that must be bound; fails in the monad when it
-- is unbound.
total :: Monad m => Subst gam -> Ref a gam -> m a
total s r = case s !!! r of
  Just x  -> return x
  Nothing -> fail "metavariable unbound"

-- | Environments for which a substitution with every metavariable unbound
-- can be built.
class Nillable gam where
  empty :: Subst gam

instance Nillable U where empty = Sz

instance (Rewritable a, Nillable gam) => Nillable (a :*: gam) where
  empty = Nothing `Ss` empty

-- | The substitution that binds one metavariable and leaves all the others
-- unbound.
singleton :: Nillable gam => Ref a gam -> a -> Subst gam
singleton ref val = update ref val unbound
  where unbound = empty

-- | Bind (or rebind) the metavariable a reference points at, keeping every
-- other binding.
update :: Ref a gam -> a -> Subst gam -> Subst gam
update Rz        val (Ss _    rest) = Ss (Just val) rest
update (Rs ref') val (Ss slot rest) = Ss slot (update ref' val rest)
update _         _   _              = error "update failure"

-- | Merge two substitutions over the same environment.
--
-- A metavariable bound on only one side keeps that binding. If it is bound
-- on both sides, the two bindings must be equal; otherwise merging fails in
-- the monad.
(+++) :: Monad m => Subst gam -> Subst gam -> m (Subst gam)
Sz +++ Sz = return Sz
Ss l s +++ Ss r s' =
  case (l, r) of
    (Just x, Just y) | not (x == y) -> fail "merging failure"
    (Nothing, _) -> liftM (Ss r) (s +++ s')
    _            -> liftM (Ss l) (s +++ s')
_ +++ _ = error "(+++) failure"

-- | Matching. 'match'' matches an extended representation, which may contain
-- metavariables, against a plain representation value. It returns the
-- resulting substitution. It fails in the monad when the structures differ
-- or when a metavariable would get conflicting bindings.
class Matchable a where
  match' :: (Nillable gam, Monad m) => Ext a gam -> a -> m (Subst gam)

-- | Match a scheme against a value. A metavariable binds the whole value;
-- otherwise the value's representation is matched structurally.
match :: (Rewritable a, Nillable gam, Monad m) =>
         Scheme a gam -> a -> m (Subst gam)
match scheme x = case scheme of
  L ext -> match' ext (from x)
  R ref -> return (singleton ref x)

-- | Shared 'match'' for the primitive base types: a literal in the pattern
-- matches only an equal value and binds no metavariables.
matchPrim :: (Eq b, Nillable gam, Monad m) => b -> b -> m (Subst gam)
matchPrim expected actual
  | expected == actual = return empty
  | otherwise          = fail "structure mismatch"

instance Matchable Int where
  match' (ExtInt n) n' = matchPrim n n'

instance Matchable Char where
  match' (ExtChar c) c' = matchPrim c c'

instance Matchable Float where
  match' (ExtFloat f) f' = matchPrim f f'

-- A sum pattern only matches a value built with the same injection.
instance (Matchable a, Matchable b) => Matchable (a :+: b) where
  match' (ExtSum pat) val = case (pat, val) of
    (L ext, L x) -> match' ext x
    (R ext, R y) -> match' ext y
    _            -> fail "structure mismatch"

-- Both components are matched and their substitutions are merged.
instance (Matchable a, Matchable b) => Matchable (a :*: b) where
  match' (ExtCons (ext :*: ext')) (x :*: y) =
    join $ liftM2 (+++) (match' ext x) (match' ext' y)

instance Matchable U where
  match' (ExtNil U) U = return empty

-- Nested positions hold schemes, matched with the scheme-level 'match'.
instance (Rewritable a) => Matchable (Var a) where
  match' (ExtVar (Var pat)) (Var val) = match pat val

instance (Rewritable a) => Matchable (Rec a) where
  match' (ExtRec (Rec pat)) (Rec val) = match pat val

instance (Matchable a) => Matchable (C c a) where
  match' (ExtC (C pat)) (C val) = match' pat val

-- | Substituting. 'subst'' instantiates an extended representation by
-- replacing the metavariables at nested positions with their bindings. A
-- metavariable without a binding makes it fail in the monad.
class Substitutable a where
  subst' :: Monad m => Subst gam -> Ext a gam -> m a

-- | Instantiate a scheme with a substitution. A metavariable is replaced by
-- its binding and fails in the monad when unbound; any other scheme is
-- instantiated structurally and converted back with 'to'.
subst :: (Rewritable a, Monad m) => Subst gam -> Scheme a gam -> m a
subst s scheme = case scheme of
  L ext -> subst' s ext >>= return . to
  R ref -> total s ref

-- Primitive base types contain no metavariables: the literal is returned.
instance Substitutable Int where
  subst' _ (ExtInt n) = return n

instance Substitutable Char where
  subst' _ (ExtChar c) = return c

instance Substitutable Float where
  subst' _ (ExtFloat f) = return f

-- Sums keep the injection of the pattern.
instance (Substitutable a, Substitutable b) => Substitutable (a :+: b) where
  subst' s (ExtSum e) = case e of
    L ext -> liftM L (subst' s ext)
    R ext -> liftM R (subst' s ext)

-- Products instantiate the left component first, then the right one.
instance (Substitutable a, Substitutable b) => Substitutable (a :*: b) where
  subst' s (ExtCons (ext :*: ext')) = do
    x <- subst' s ext
    y <- subst' s ext'
    return (x :*: y)

instance Substitutable U where
  subst' _ (ExtNil U) = return U

-- Nested positions hold schemes, instantiated with the scheme-level 'subst'.
instance (Rewritable a) => Substitutable (Rec a) where
  subst' s (ExtRec (Rec scheme)) = subst s scheme >>= return . Rec

instance (Rewritable a) => Substitutable (Var a) where
  subst' s (ExtVar (Var scheme)) = subst s scheme >>= return . Var

instance (Substitutable a) => Substitutable (C c a) where
  subst' s (ExtC (C ext)) = subst' s ext >>= return . C

-- | Testing preconditions: evaluate a rule's guard by feeding it the
-- bindings of a substitution, one metavariable at a time. An unbound
-- metavariable makes the rule invalid.
class Testable gam where
  test :: Subst gam -> Guard gam -> Bool

instance Testable U where
  test Sz result = result

instance Testable gam => Testable (a :*: gam) where
  test (Ss binding rest) g =
    maybe (error "invalid rule") (test rest . g) binding

-- Rewriting

-- | Apply a rule once to the whole term. The term is returned unchanged
-- when the rule does not apply.
rewrite' :: (Rewritable a, Nillable gam, Testable gam) => Rule' a gam -> a -> a
rewrite' rule x = fromMaybe x attempt
  where attempt = rewriteM' rule x

-- | Apply a rule once to the whole term. It matches the left-hand side,
-- tests the guard on the resulting substitution and instantiates the
-- right-hand side. It fails in the monad when matching, the guard or
-- substitution fails.
rewriteM' :: (Rewritable a, Nillable gam, Testable gam, Monad m)
         => Rule' a gam -> a -> m a
rewriteM' rule x =
  match (lhs rule) x >>= \s ->
    if test s (guard rule)
      then subst s (rhs rule)
      else fail "guard failed"

-- Synthesising rules

-- | Sampling: two sample values of a representation type. Rule synthesis
-- applies a rule function to the 'left' and 'right' samples of its argument
-- type and diffs the two results. Positions where they differ become
-- metavariables.
class Sampleable a where
  left'  :: a
  right' :: a

-- | The left sample of a rewritable type, built from its representation.
left :: Rewritable a => a
left = to left'

-- | The right sample of a rewritable type, built from its representation.
right :: Rewritable a => a
right = to right'

-- Bounded types use their bounds as the two samples.
instance Bounded a => Sampleable a where
  left'  = minBound
  right' = maxBound

-- Float is not Bounded, so 0 and 1 serve as the samples.
instance Sampleable Float where
  left'  = 0
  right' = 1

-- The two samples of a sum use different injections. Each holds the
-- default value ('gempty') of the corresponding alternative.
instance (Representable a, Empty (Rep a), Representable b, Empty (Rep b)) => Sampleable (a :+: b) where
  left'  = L gempty
  right' = R gempty

instance (Sampleable a, Sampleable b) => Sampleable (a :*: b) where
  left'  = (:*:) left'  left'
  right' = (:*:) right' right'

instance Sampleable U where
  left'  = U
  right' = U

-- Nested positions use the samples of the nested rewritable type.
instance (Rewritable a) => Sampleable (Rec a) where
  left'  = Rec left
  right' = Rec right

instance (Rewritable a) => Sampleable (Var a) where
  left'  = Var left
  right' = Var right

instance (Sampleable a) => Sampleable (C c a) where
  left'  = C left'
  right' = C right'

-- | Diffing. 'diff'' compares two extended representations over the same
-- environment. When they agree, it rebuilds their common structure in the
-- environment extended with one new metavariable of type @b@. It returns
-- 'Nothing' when they disagree.
class Diffable a where
  diff' :: Typeable b => Ext a gam -> Ext a gam -> Maybe (Ext a (b :*: gam))

-- | Diff two schemes, introducing a new metavariable of type @b@.
--
-- Where the structures disagree, the new metavariable is placed, provided
-- its type @b@ equals the scheme's type @a@. Existing metavariables must
-- coincide and are shifted past the new one. Anything else fails.
diff :: (Rewritable a, Typeable b) =>
        Scheme a gam -> Scheme a gam -> Maybe (Scheme a (b :*: gam))
diff (L ext) (L ext') =
  case diff' ext ext' of
    Just common -> Just (L common)
    Nothing     -> scast (R Rz)
diff (R r) (R r')
  | r == r' = Just (R (Rs r))
diff _ _ = Nothing

-- | 'Scheme' with its arguments flipped, so that 'gcast' can change the
-- object type while keeping the environment.
newtype FlipScheme gam a = Flip {unFlip :: Scheme a gam}

-- | Cast a scheme for @b@ to a scheme for @a@. This succeeds only when @a@
-- and @b@ are the same type.
scast :: (Typeable a, Typeable b) =>
  Scheme b (b :*: gam) -> Maybe (Scheme a (b :*: gam))
scast s = fmap unFlip (gcast (Flip s))

-- | Partial version of 'diff'.
--
-- Errors (through 'fromJust') when the two schemes cannot be diffed, e.g.
-- when they differ at a position where the new metavariable's type does
-- not fit.
(><) :: (Rewritable a, Typeable b) =>
        Scheme a gam -> Scheme a gam -> (Scheme a (b :*: gam))
(><) scheme scheme' = fromJust $ scheme `diff` scheme'

-- | Shared 'diff'' for the primitive base types. Equal literals are kept
-- (re-wrapped for the extended environment); unequal ones do not diff.
keepIfEqual :: Eq c => (c -> r) -> c -> c -> Maybe r
keepIfEqual wrap x y
  | x == y    = Just (wrap x)
  | otherwise = Nothing

instance Diffable Int where
  diff' (ExtInt n) (ExtInt n') = keepIfEqual ExtInt n n'

instance Diffable Char where
  diff' (ExtChar c) (ExtChar c') = keepIfEqual ExtChar c c'

instance Diffable Float where
  diff' (ExtFloat f) (ExtFloat f') = keepIfEqual ExtFloat f f'

-- Sums only diff when both sides use the same injection.
instance (Diffable a, Diffable b) => Diffable (a :+: b) where
  diff' (ExtSum e) (ExtSum e') = case (e, e') of
    (L ext, L ext') -> fmap (ExtSum . L) (diff' ext ext')
    (R ext, R ext') -> fmap (ExtSum . R) (diff' ext ext')
    _               -> Nothing

-- Products diff componentwise; both components must succeed.
instance (Diffable a, Diffable b) => Diffable (a :*: b) where
  diff' (ExtCons (x :*: xs)) (ExtCons (y :*: ys)) = do
    d  <- diff' x y
    ds <- diff' xs ys
    return (ExtCons (d :*: ds))

instance Diffable U where
  diff' (ExtNil U) (ExtNil U) = return (ExtNil U)

-- Nested positions hold schemes, diffed with the scheme-level 'diff'.
instance (Rewritable a) => Diffable (Rec a) where
  diff' (ExtRec (Rec s)) (ExtRec (Rec s')) = fmap (ExtRec . Rec) (diff s s')

instance (Rewritable a) => Diffable (Var a) where
  diff' (ExtVar (Var s)) (ExtVar (Var s')) = fmap (ExtVar . Var) (diff s s')

instance (Diffable a) => Diffable (C c a) where
  diff' (ExtC (C x)) (ExtC (C y)) = fmap (ExtC . C) (diff' x y)

-- | Templates: a left-hand side, a right-hand side and a guard. Build one
-- with '+->', which uses the guard @True@, and optionally replace the guard
-- with '//'.
data Template a = Template a a Bool

infix 0 //
infix 1 +->

-- | Build a template from a left- and right-hand side, with guard @True@.
(+->) :: a -> a -> Template a
(+->) l r = Template l r True

-- | Replace the guard of a template.
(//) :: Template a -> Bool -> Template a
(//) (Template l r _) g = Template l r g

-- | Synthesising rules from metasyntax specifications. A specification is
-- either a 'Template' or a function from a metavariable's value to another
-- specification. 'Obj' is the type being rewritten, and 'Env' collects the
-- types of the metavariables introduced by function arguments.
class (Rewritable (Obj a)) => IsRule a where
  type Obj a :: *
  type Env a :: *

  synthesise' :: a -> Rule' (Obj a) (Env a)

-- A template introduces no metavariables: both sides are embedded as plain
-- schemes and the guard is used directly.
instance (Rewritable a) => IsRule (Template a) where
  type Obj (Template a) = a
  type Env (Template a) = U

  synthesise' (Template l r g) = Rule' (toScheme l) (toScheme r) g

-- A function from a metavariable's value to a specification introduces one
-- metavariable of type @a@. The function is applied to the two samples
-- 'left' and 'right', and the two resulting rules are diffed with '><'.
-- Positions where they differ become the new metavariable. '><' errors at
-- runtime if the two results cannot be diffed.
instance (Rewritable a, IsRule b) => IsRule (a -> b) where
  type Obj (a -> b) = Obj b
  type Env (a -> b) = a :*: Env b

  synthesise' f = Rule'
    { lhs   = lhs l >< lhs r
    , rhs   = rhs l >< rhs r
      -- The guard receives the metavariable's value and continues with the
      -- guard of the rule obtained by applying f to that value.
    , guard = guard . synthesise' . f
    }
    where
      l = synthesise' (f left)
      r = synthesise' (f right)

-- | Validating synthesised rules. 'record' walks an extended representation
-- and marks, in a 'Record', every metavariable that occurs in it. The
-- default leaves the record unchanged; the primitive base types and 'U'
-- use it.
class Validatable a where
  record :: Ext a gam -> Record gam -> Record gam
  record _ = id

-- One flag per metavariable of the environment @gam@, recording whether
-- that metavariable has been seen.
data Record :: * -> * where
  RNil  :: Record U
  RCons :: Bool -> Record gam -> Record (a :*: gam)

-- | Environments for which a 'Record' with every flag unset can be built.
class Recordable gam where
  blank :: Record gam

instance Recordable U where blank = RNil

instance Recordable gam => Recordable (a :*: gam) where
  blank = False `RCons` blank

-- | Mark, in a 'Record', every metavariable that occurs in a scheme.
record' :: Rewritable a => Scheme a gam -> Record gam -> Record gam
record' (L e)   rec = record e rec
record' (R ref) rec = recordRef ref rec

-- | Set the flag of the metavariable that a reference points at.
recordRef :: Ref a gam -> Record gam -> Record gam
recordRef Rz       (RCons _ rest) = RCons True rest
recordRef (Rs ref) (RCons b rest) = RCons b (recordRef ref rest)
recordRef _        _              = error "record' failure"

-- Primitive base types and the unit contain no metavariables, so they use
-- the default 'record', which leaves the record unchanged.
instance Validatable Int
instance Validatable Float
instance Validatable Char
instance Validatable U

instance (Validatable a, Validatable b) => Validatable (a :+: b) where
  record (ExtSum e) = case e of
    L l -> record l
    R r -> record r

-- Records the right component first, then the left one.
instance (Validatable a, Validatable b) => Validatable (a :*: b) where
  record (ExtCons (e :*: es)) = \rec -> record e (record es rec)

-- Nested positions hold schemes, handled by the scheme-level record'.
instance (Rewritable a) => Validatable (Rec a) where
  record (ExtRec (Rec scheme)) = record' scheme

instance (Rewritable a) => Validatable (Var a) where
  record (ExtVar (Var scheme)) = record' scheme

instance (Validatable a) => Validatable (C c a) where
  record (ExtC (C inner)) = record inner

-- | Check that every metavariable of a rule occurs in its left-hand side,
-- so that a successful match binds all of them. Otherwise, testing the
-- guard or instantiating the right-hand side would meet an unbound
-- metavariable.
validate' :: forall a gam. (Rewritable a, Recordable gam)
          => Rule' a gam -> Bool
validate' r = conj (record' (lhs r) blank)
  where
    -- True when every flag of the record is set.
    conj :: forall gam'. Record gam' -> Bool
    conj RNil          = True
    conj (RCons b rec) = b && conj rec

-- | Default values of representation types, used to build sample values
-- (through 'gempty'). For a sum, the left alternative is used unless its
-- default value contains a recursive 'Rec' position; in that case the right
-- alternative is used. This way defaults avoid recursing through the
-- leftmost constructor where possible.

class Empty a where
  empty' :: a

instance Empty U where
  empty' = U

instance (HasRec a, Empty a, Empty b) => Empty (a :+: b) where
  empty'
    | hasRec' leftDefault = R empty'
    | otherwise           = L leftDefault
    where
      leftDefault = empty' :: a

instance (Empty a, Empty b) => Empty (a :*: b) where
  empty' = (:*:) empty' empty'

instance (Empty a) => Empty (C c a) where
  empty' = C empty'

-- Nested positions use the default value of the nested type.
instance (Rewritable a) => Empty (Var a) where
  empty' = Var gempty

instance (Rewritable a) => Empty (Rec a) where
  empty' = Rec gempty

instance Empty Int   where empty' = 0
instance Empty Float where empty' = 0
instance Empty Char  where empty' = '\NUL'

-- | Dispatcher: the default value of a representable type, obtained by
-- converting the default value of its representation with 'to'.
gempty :: (Representable a, Empty (Rep a)) => a
gempty = to defaultRep
  where defaultRep = empty'

-- | Whether a representation value contains a recursive 'Rec' position.
-- 'Empty' uses it to avoid producing infinite values. The default answer
-- is 'False'.
class HasRec a where
  hasRec' :: a -> Bool
  hasRec' _ = False

-- Base types, the unit and parameter positions use the default (False).
instance HasRec U
instance HasRec (Var a)
instance HasRec Int
instance HasRec Float
instance HasRec Char

instance (HasRec a, HasRec b) => HasRec (a :*: b) where
  hasRec' (x :*: y) = or [hasRec' x, hasRec' y]
instance (HasRec a, HasRec b) => HasRec (a :+: b) where
  hasRec' s = case s of
    L x -> hasRec' x
    R y -> hasRec' y

instance (HasRec a) => HasRec (C c a) where
  hasRec' (C x) = hasRec' x

-- A recursive position always counts, without inspecting its contents.
instance HasRec (Rec a) where
  hasRec' = const True

-- | @Finite r@ is 'True' when the representation @r@ has at least one
-- alternative without a recursive 'Rec' position. A sum is finite when
-- either side is and a product when both sides are. Base types, 'U' and
-- 'Var' positions count as finite.
type family Finite a :: *
type instance Finite Int       = True
type instance Finite Float     = True
type instance Finite Char      = True
type instance Finite U         = True
type instance Finite (a :+: b) = Or (Finite a) (Finite b)
type instance Finite (a :*: b) = And (Finite a) (Finite b)
type instance Finite (Rec a)   = False
type instance Finite (Var a)   = True
type instance Finite (C c a)   = Finite a

-- | Type-level truth, used by 'Finite' and 'FiniteEnv'.
data True
-- | Type-level falsity, used by 'Finite' and 'FiniteEnv'.
data False

-- Type-level conjunction.
type family And p q
type instance And True  True  = True
type instance And True  False = False
type instance And False True  = False
type instance And False False = False

-- Type-level disjunction.
type family Or p q
type instance Or True  True  = True
type instance Or True  False = True
type instance Or False True  = True
type instance Or False False = False

-- 'True' when the representation of every metavariable type in the
-- environment @gam@ is 'Finite'. 'synthesise' requires this.
-- NOTE(review): presumably this keeps the sample values built for the
-- metavariables finite — confirm.
type family FiniteEnv gam
type instance FiniteEnv U = True
type instance FiniteEnv (a :*: gam) = And (Finite (Rep a)) (FiniteEnv gam)

-- Convenience: hide the environment type of Rule' from the user.

-- | Rules over values of type @a@. The metavariable environment of the
-- underlying rule is hidden existentially. It is packed together with the
-- instances needed to rewrite with the rule and to validate it.
data Rule a where
  Rule :: (Nillable gam, Recordable gam, Testable gam)
       => Rule' a gam -> Rule a

-- | Validate a rewrite rule: every metavariable must occur in the left-hand
-- side.
validate :: Rewritable a => Rule a -> Bool
validate r = case r of
  Rule rule -> validate' rule

-- | Synthesise a function (or a plain 'Template') into a rewrite rule.
-- Each function argument becomes a metavariable, and the representations
-- of all metavariable types must be 'Finite'.
synthesise :: (IsRule a, Nillable (Env a), Recordable (Env a),
               Testable (Env a), FiniteEnv (Env a) ~ True)
           => a -> Rule (Obj a)
synthesise = Rule . synthesise'

-- | Rewrite a term. The term is unchanged if the rule cannot be applied.
rewrite :: Rewritable a => Rule a -> a -> a
rewrite r x = case r of
  Rule rule -> rewrite' rule x

-- | Rewrite a term. Monad 'fail' is used if the rule cannot be applied.
rewriteM :: (Rewritable a, Monad m) => Rule a -> a -> m a
rewriteM r x = case r of
  Rule rule -> rewriteM' rule x