{-# LANGUAGE DeriveDataTypeable   #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE GADTs                #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE EmptyDataDecls       #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE OverlappingInstances #-}

{-# OPTIONS_GHC -Wall             #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Generics.Instant.Rewriting
-- Copyright   :  (c) 2010, Universiteit Utrecht
-- License     :  BSD3
--
-- Maintainer  :  generics@haskell.org
-- Stability   :  experimental
-- Portability :  non-portable
--
-- This is the top module for the rewriting library; all functionality is
-- implemented here. For examples of how to use the library, see the files
-- in the @examples@ directory, or the benchmark in the @performance@
-- directory.
--
-----------------------------------------------------------------------------

module Generics.Instant.Rewriting (
    
    -- * The class to signal availability of rewriting for a type
    Rewritable,
  
    -- * Top-level functions
    rewrite, rewriteM, validate, synthesise,

    -- * Building rewrite rules
    Template(..), (+->), (//), Rule,
    
    -- * Internal classes: might be necessary to add new base types
    Extensible(..),
    Matchable(..),
    Substitutable(..),
    Sampleable(..), Empty(..), HasRec(..), Finite, True, False,
    Diffable(..),
    Validatable(..),
    Nillable(..),

    -- * Re-exported for convenience
    Typeable
    
  ) where

import Control.Monad  (join, liftM, liftM2)
import Data.Maybe     (fromJust, fromMaybe)
import Data.Typeable  (Typeable, gcast)

import Generics.Instant.Base
import Generics.Instant.Instances ()

-------------------------------------------------------------------------------
-- | Typed references
-------------------------------------------------------------------------------

-- Typed reference into an environment @gam@, a type-level list built with
-- ':*:'. 'Rz' refers to the head entry and 'Rs' skips over one entry; @a@ is
-- the type of the entry being referred to.
data Ref :: * -> * -> * where
  Rz :: Ref a (a :*: gam)
  Rs :: Ref a gam -> Ref a (b :*: gam)

-- Two references are equal exactly when they point at the same position.
instance Eq (Ref a gam) where
  x == y = case (x, y) of
    (Rz,   Rz  ) -> True
    (Rs p, Rs q) -> p == q
    _            -> False

-------------------------------------------------------------------------------
-- | The 'Rewritable' class is used to signal types that can be rewritten and 
-- to ``tie the recursive knot'' of the generic functions.
-------------------------------------------------------------------------------

class ( Representable a, Typeable a, Eq a
      , Empty (Rep a)
      , Extensible (Rep a), Matchable (Rep a), Substitutable (Rep a)
      , Sampleable (Rep a), Diffable (Rep a), Validatable (Rep a)
      ) => Rewritable a

-- Base types that can be rewritten directly.
instance Rewritable Int
instance Rewritable Float
instance Rewritable Char

-------------------------------------------------------------------------------
-- Rewrite rules
-------------------------------------------------------------------------------

-- Metavariables
-- A metavariable is a typed reference into a rule's environment.
type Metavar a gam = Ref a gam

-- | Schemes: terms whose representation may contain metavariables.
-- @Ext r gam@ mirrors the representation type @r@. Every recursive ('Rec')
-- and variable ('Var') position holds a 'Scheme' over the environment
-- @gam@. 'toExt' embeds a plain value, using the empty environment 'U'.
class Extensible a where
  data Ext a :: * -> *
  toExt :: a -> Ext a U

-- A scheme is either a structural node ('L') or a metavariable ('R').
type Scheme a gam = Ext (Rep a) gam :+: Metavar a gam

-- Lift a concrete value to a scheme that contains no metavariables.
toScheme :: Rewritable a => a -> Scheme a U
toScheme x = L (toExt (from x))

-- Base types are embedded as-is.
instance Extensible Int where
  newtype Ext Int gam = ExtInt Int
  toExt n = ExtInt n

instance Extensible Float where
  newtype Ext Float gam = ExtFloat Float
  toExt f = ExtFloat f

instance Extensible Char where
  newtype Ext Char gam = ExtChar Char
  toExt c = ExtChar c

-- Sums: embed whichever alternative is present.
instance (Extensible a, Extensible b) => Extensible (a :+: b) where
  newtype Ext (a :+: b) gam = ExtSum (Ext a gam :+: Ext b gam)

  toExt s = ExtSum $ case s of
    L x -> L (toExt x)
    R y -> R (toExt y)

-- Products: embed both components.
instance (Extensible a, Extensible b) => Extensible (a :*: b) where
  newtype Ext (a :*: b) gam = ExtCons (Ext a gam :*: Ext b gam)
  toExt p = case p of
    x :*: y -> ExtCons (toExt x :*: toExt y)

instance Extensible U where
  newtype Ext U gam = ExtNil U
  toExt u = ExtNil u

-- Recursive and variable positions hold full schemes, so a metavariable can
-- stand in for the whole subterm there.
instance (Rewritable a) => Extensible (Rec a) where
  newtype Ext (Rec a) gam = ExtRec (Rec (Scheme a gam))
  toExt (Rec x) = ExtRec . Rec $ toScheme x

instance (Rewritable a) => Extensible (Var a) where
  newtype Ext (Var a) gam = ExtVar (Var (Scheme a gam))
  toExt (Var x) = ExtVar . Var $ toScheme x

-- Constructor wrappers are carried through unchanged.
instance (Extensible a) => Extensible (C c a) where
  newtype Ext (C c a) gam = ExtC (C c (Ext a gam))
  toExt (C inner) = ExtC . C $ toExt inner

-- Guards
-- The guard of a rule over environment @gam@ takes one argument per
-- metavariable, in environment order, and returns a 'Bool'.
type family Guard gam :: *
type instance Guard U           = Bool
type instance Guard (a :*: gam) = a -> Guard gam

-- Rules
-- A rule over an explicit environment @gam@. Both sides are schemes that may
-- mention the metavariables of @gam@. The guard decides whether a successful
-- match is actually rewritten.
data Rule' a gam = Rule'
  { lhs   :: Scheme a gam
  , rhs   :: Scheme a gam
  , guard :: Guard gam
  }

-------------------------------------------------------------------------------
-- Rewriting
-------------------------------------------------------------------------------

-- Substitutions
-- A partial substitution: one optional binding per metavariable of the
-- environment @gam@.
data Subst :: * -> * where
  Sz :: Subst U
  Ss :: Rewritable a => Maybe a -> Subst gam -> Subst (a :*: gam)

-- Look up the binding of a metavariable ('Nothing' when it is unbound).
(!!!) :: Subst gam -> Ref a gam -> Maybe a
Ss binding _    !!! Rz       = binding
Ss _       rest !!! Rs inner = rest !!! inner
-- Only reachable with an undefined reference: 'Ref a U' has no constructors.
_               !!! _        = error "(!!!) failure"

-- Like '!!!', but report an unbound metavariable with 'fail' in @m@.
total :: Monad m => Subst gam -> Ref a gam -> m a
total s r = case s !!! r of
  Just x  -> return x
  Nothing -> fail "metavariable unbound"

-- | Substitutions in which every metavariable is unbound.
class Nillable gam where
  empty :: Subst gam

instance Nillable U where
  empty = Sz

instance (Rewritable a, Nillable gam) => Nillable (a :*: gam) where
  empty = Ss Nothing empty

-- A substitution that binds exactly one metavariable.
singleton :: Nillable gam => Ref a gam -> a -> Subst gam
singleton ref val = update ref val empty

-- Bind a single metavariable, overwriting any previous binding and leaving
-- the other entries untouched.
update :: Ref a gam -> a -> Subst gam -> Subst gam
update ref val s = case (ref, s) of
  (Rz,       Ss _     rest) -> Ss (Just val) rest
  (Rs inner, Ss entry rest) -> Ss entry (update inner val rest)
  _                         -> error "update failure"

-- Merge two substitutions over the same environment. A metavariable bound on
-- both sides must be bound to equal values; otherwise merging fails in @m@.
(+++) :: Monad m => Subst gam -> Subst gam -> m (Subst gam)
Sz +++ Sz = return Sz
Ss (Just x) s +++ Ss (Just y) t
  | x == y    = liftM (Ss (Just x)) (s +++ t)
  | otherwise = fail "merging failure"
Ss (Just x) s +++ Ss Nothing  t = liftM (Ss (Just x)) (s +++ t)
Ss Nothing  s +++ Ss other    t = liftM (Ss other) (s +++ t)
-- Unreachable: both arguments are indexed by the same environment @gam@.
_ +++ _ = error "(+++) failure"

-- | Matching
-- Structural matching of an 'Ext' pattern against a representation value.
-- It produces the substitution for the metavariables it encounters and
-- fails in @m@ on a structural mismatch.
class Matchable a where
  match' :: (Nillable gam, Monad m) => Ext a gam -> a -> m (Subst gam)

-- Match a scheme against a value. A metavariable matches anything and binds
-- it; otherwise the value's representation is matched structurally.
match :: (Rewritable a, Nillable gam, Monad m) =>
         Scheme a gam -> a -> m (Subst gam)
match scheme x = case scheme of
  L ext -> match' ext (from x)
  R ref -> return (singleton ref x)

-- Base types match only an equal value.
instance Matchable Int where
  match' (ExtInt n) m =
    if n == m then return empty else fail "structure mismatch"

instance Matchable Char where
  match' (ExtChar c) d =
    if c == d then return empty else fail "structure mismatch"

instance Matchable Float where
  match' (ExtFloat f) g =
    if f == g then return empty else fail "structure mismatch"

-- Sums: the pattern and the value must use the same alternative.
instance (Matchable a, Matchable b) => Matchable (a :+: b) where
  match' (ExtSum pat) val = case (pat, val) of
    (L p, L x) -> match' p x
    (R p, R y) -> match' p y
    _          -> fail "structure mismatch"

-- Products: match both components and merge the resulting substitutions.
instance (Matchable a, Matchable b) => Matchable (a :*: b) where
  match' (ExtCons (p :*: q)) (x :*: y) = do
    s1 <- match' p x
    s2 <- match' q y
    s1 +++ s2

instance Matchable U where
  match' (ExtNil U) U = return empty

instance (Rewritable a) => Matchable (Var a) where
  match' (ExtVar (Var scheme)) (Var val) = match scheme val

instance (Rewritable a) => Matchable (Rec a) where
  match' (ExtRec (Rec scheme)) (Rec val) = match scheme val

instance (Matchable a) => Matchable (C c a) where
  match' (ExtC (C p)) (C x) = match' p x

-- | Substituting
-- Instantiate an 'Ext' pattern using a substitution. It fails in @m@ when
-- the pattern mentions an unbound metavariable.
class Substitutable a where
  subst' :: Monad m => Subst gam -> Ext a gam -> m a

-- Instantiate a scheme. Metavariables are looked up in the substitution;
-- structural nodes are rebuilt and converted back with 'to'.
subst :: (Rewritable a, Monad m) => Subst gam -> Scheme a gam -> m a
subst s scheme = case scheme of
  L ext -> liftM to (subst' s ext)
  R ref -> total s ref

instance Substitutable Int where
  subst' _ (ExtInt n) = return n

instance Substitutable Char where
  subst' _ (ExtChar c) = return c

instance Substitutable Float where
  subst' _ (ExtFloat f) = return f

instance (Substitutable a, Substitutable b) => Substitutable (a :+: b) where
  subst' s (ExtSum e) = case e of
    L ext -> liftM L (subst' s ext)
    R ext -> liftM R (subst' s ext)

instance (Substitutable a, Substitutable b) => Substitutable (a :*: b) where
  subst' s (ExtCons (ext :*: ext')) = do
    x <- subst' s ext
    y <- subst' s ext'
    return (x :*: y)

instance Substitutable U where
  subst' _ (ExtNil U) = return U

instance (Rewritable a) => Substitutable (Rec a) where
  subst' s (ExtRec (Rec scheme)) = liftM Rec (subst s scheme)

instance (Rewritable a) => Substitutable (Var a) where
  subst' s (ExtVar (Var scheme)) = liftM Var (subst s scheme)

instance (Substitutable a) => Substitutable (C c a) where
  subst' s (ExtC (C ext)) = liftM C (subst' s ext)

-- | Testing preconditions
-- Evaluate a rule's guard under a substitution, passing it the bound value of
-- each metavariable in environment order. Every metavariable must be bound;
-- an unbound one raises an error.
class Testable gam where
  test :: Subst gam -> Guard gam -> Bool

instance Testable U where
  test s b = case s of
    Sz -> b

instance Testable gam => Testable (a :*: gam) where
  test (Ss binding rest) f = case binding of
    Just x  -> test rest (f x)
    Nothing -> error "invalid rule"

-- Rewriting
-- Apply a rule once, at the root of the term only. The term is returned
-- unchanged when the rule does not apply.
rewrite' :: (Rewritable a, Nillable gam, Testable gam) => Rule' a gam -> a -> a
rewrite' rule x = maybe x id (rewriteM' rule x)

-- Apply a rule once, at the root of the term only. The steps are: match the
-- left-hand side, check the guard, then instantiate the right-hand side.
-- Any failure is reported with 'fail' in @m@.
rewriteM' :: (Rewritable a, Nillable gam, Testable gam, Monad m)
          => Rule' a gam -> a -> m a
rewriteM' rule x =
  match (lhs rule) x >>= \s ->
    if test s (guard rule)
      then subst s (rhs rule)
      else fail "guard failed"

-------------------------------------------------------------------------------
-- Synthesising rules
-------------------------------------------------------------------------------

-- | Sampling
-- Two sample values per representation type. Rule synthesis instantiates a
-- rule function at both 'left' and 'right' and compares the results.
class Sampleable a where
  left'  :: a
  right' :: a

left :: Rewritable a => a
left = to left'

right :: Rewritable a => a
right = to right'

-- Fallback for any type without a more specific instance: it must be
-- 'Bounded', and its extreme values are used.
instance (Bounded a) => Sampleable a where
  left'  = minBound
  right' = maxBound

-- 'Float' is not 'Bounded', so two distinct constants are used instead.
instance Sampleable Float where
  left'  = 0
  right' = 1

-- Sums: the two samples use different alternatives, each filled with 'gempty'.
instance ( Representable a, Empty (Rep a)
         , Representable b, Empty (Rep b)
         ) => Sampleable (a :+: b) where
  left'  = L gempty
  right' = R gempty

instance (Sampleable a, Sampleable b) => Sampleable (a :*: b) where
  left'  = (:*:) left'  left'
  right' = (:*:) right' right'

instance Sampleable U where
  left'  = U
  right' = U

-- Recursive and variable positions use the samples of the field's type.
instance (Rewritable a) => Sampleable (Rec a) where
  left'  = Rec left
  right' = Rec right

instance (Rewritable a) => Sampleable (Var a) where
  left'  = Var left
  right' = Var right

instance (Sampleable a) => Sampleable (C c a) where
  left'  = C left'
  right' = C right'

-- | Diff

-- Compare two 'Ext' values over environment @gam@ and produce their common
-- structure over the extended environment @b :*: gam@. Differences inside
-- recursive or variable positions are delegated to 'diff'. Any other
-- difference yields 'Nothing'.
class Diffable a where
  diff' :: Typeable b => Ext a gam -> Ext a gam -> Maybe (Ext a (b :*: gam))

-- Compare two schemes.
--
-- * Where two structural nodes differ, the new metavariable 'Rz' is
--   introduced. This only succeeds when the scheme type @a@ equals the
--   metavariable type @b@, which 'scast' checks.
-- * Identical existing metavariables are shifted with 'Rs' into the
--   extended environment.
diff :: (Rewritable a, Typeable b) =>
        Scheme a gam -> Scheme a gam -> Maybe (Scheme a (b :*: gam))
diff (L e1) (L e2) = case diff' e1 e2 of
  Just common -> Just (L common)
  Nothing     -> scast (R Rz)
diff (R v1) (R v2)
  | v1 == v2 = Just (R (Rs v1))
diff _ _ = Nothing

-- Wrapper that puts the scheme's object type last, so 'gcast' can change it.
newtype FlipScheme gam a = Flip {unFlip :: Scheme a gam}

-- Cast a scheme's object type from @b@ to @a@. Succeeds only when the two
-- types are the same.
scast :: (Typeable a, Typeable b) =>
  Scheme b (b :*: gam) -> Maybe (Scheme a (b :*: gam))
scast s = fmap unFlip (gcast (Flip s))

-- Like 'diff', but for rule synthesis, where the two schemes come from
-- instantiating a rule at the 'left' and 'right' samples. If they cannot be
-- reconciled with one new metavariable, this raises an error naming the
-- cause. For example, the samples may differ at a position whose type is
-- not the metavariable's type.
(><) :: (Rewritable a, Typeable b) =>
        Scheme a gam -> Scheme a gam -> (Scheme a (b :*: gam))
scheme >< scheme' = fromMaybe (error msg) (diff scheme scheme')
  where
    msg = "Generics.Instant.Rewriting.(><): cannot synthesise rule: the two "
       ++ "sample instantiations differ in a way that cannot be abstracted "
       ++ "by a metavariable"

-- Base types: equal values are kept; different values cannot be abstracted.
instance Diffable Int where
  diff' (ExtInt n) (ExtInt m) = if n == m then Just (ExtInt n) else Nothing

instance Diffable Char where
  diff' (ExtChar c) (ExtChar d) = if c == d then Just (ExtChar c) else Nothing

instance Diffable Float where
  diff' (ExtFloat f) (ExtFloat g) = if f == g then Just (ExtFloat f) else Nothing

-- Sums: both sides must use the same alternative.
instance (Diffable a, Diffable b) => Diffable (a :+: b) where
  diff' (ExtSum x) (ExtSum y) = case (x, y) of
    (L e, L e') -> fmap (ExtSum . L) (diff' e e')
    (R e, R e') -> fmap (ExtSum . R) (diff' e e')
    _           -> Nothing

-- Products: both components must be comparable.
instance (Diffable a, Diffable b) => Diffable (a :*: b) where
  diff' (ExtCons (x :*: xs)) (ExtCons (y :*: ys)) = do
    d  <- diff' x y
    ds <- diff' xs ys
    return (ExtCons (d :*: ds))

instance Diffable U where
  diff' (ExtNil U) (ExtNil U) = Just (ExtNil U)

instance (Rewritable a) => Diffable (Rec a) where
  diff' (ExtRec (Rec s)) (ExtRec (Rec s')) = fmap (ExtRec . Rec) (diff s s')

instance (Rewritable a) => Diffable (Var a) where
  diff' (ExtVar (Var s)) (ExtVar (Var s')) = fmap (ExtVar . Var) (diff s s')

instance (Diffable a) => Diffable (C c a) where
  diff' (ExtC (C x)) (ExtC (C y)) = fmap (ExtC . C) (diff' x y)

-- | Templates
-- A rule template: left-hand side, right-hand side and guard.
data Template a = Template a a Bool

infix 0 //
infix 1 +->

-- | Build a template whose guard is 'True'.
(+->) :: a -> a -> Template a
(+->) l r = Template l r True

-- | Replace the guard of a template.
(//) :: Template a -> Bool -> Template a
(//) t g = case t of
  Template l r _ -> Template l r g

-- | Synthesising rules from metasyntax specifications
-- A rule specification is either a 'Template' or a function from
-- metavariable values to a rule specification. 'Obj' is the type being
-- rewritten. 'Env' is the environment of metavariables, one per argument.
class (Rewritable (Obj a)) => IsRule a where
  type Obj a :: *
  type Env a :: *

  synthesise' :: a -> Rule' (Obj a) (Env a)

-- A template has no metavariables.
instance (Rewritable a) => IsRule (Template a) where
  type Obj (Template a) = a
  type Env (Template a) = U

  synthesise' (Template l r g) = Rule' (toScheme l) (toScheme r) g

-- Each function argument adds one metavariable. The function is applied to
-- the samples 'left' and 'right', and the two results are compared with '><'
-- so that the positions where they differ become the new metavariable.
instance (Rewritable a, IsRule b) => IsRule (a -> b) where
  type Obj (a -> b) = Obj b
  type Env (a -> b) = a :*: Env b

  synthesise' f =
    let atLeft  = synthesise' (f left)
        atRight = synthesise' (f right)
    in  Rule' (lhs atLeft >< lhs atRight)
              (rhs atLeft >< rhs atRight)
              (\x -> guard (synthesise' (f x)))

-- | Validating synthesised rules
-- Collects which metavariables occur in an 'Ext' node. The default 'record'
-- leaves the record unchanged.
class Validatable a where
  record :: Ext a gam -> Record gam -> Record gam
  record _ = id

-- One flag per metavariable of the environment @gam@. A flag becomes 'True'
-- once that metavariable has been seen.
data Record :: * -> * where
  RNil  :: Record U
  RCons :: Bool -> Record gam -> Record (a :*: gam)

-- Records with every flag unset.
class Recordable gam where
  blank :: Record gam

instance Recordable U where
  blank = RNil

instance Recordable gam => Recordable (a :*: gam) where
  blank = RCons False blank

-- Set the flag of every metavariable that occurs in a scheme.
record' :: Rewritable a => Scheme a gam -> Record gam -> Record gam
record' scheme acc = case scheme of
  L ext -> record ext acc
  R ref -> case (ref, acc) of
    (Rz,       RCons _    rest) -> RCons True rest
    (Rs inner, RCons seen rest) -> RCons seen (record' (R inner) rest)
    _                           -> error "record' failure"

-- Base types and 'U' contain no metavariables (default 'record').
instance Validatable Int
instance Validatable Float
instance Validatable Char
instance Validatable U

instance (Validatable a, Validatable b) => Validatable (a :+: b) where
  record (ExtSum s) = case s of
    L e -> record e
    R e -> record e

instance (Validatable a, Validatable b) => Validatable (a :*: b) where
  record (ExtCons (e :*: es)) = \acc -> record e (record es acc)

-- Recursive and variable positions hold schemes, which may be metavariables.
instance (Rewritable a) => Validatable (Rec a) where
  record (ExtRec (Rec scheme)) = record' scheme

instance (Rewritable a) => Validatable (Var a) where
  record (ExtVar (Var scheme)) = record' scheme

instance (Validatable a) => Validatable (C c a) where
  record (ExtC (C inner)) = record inner

-- A rule is valid when every metavariable of its environment occurs in the
-- left-hand side, so that a successful match binds all of them.
validate' :: forall a gam. (Rewritable a, Recordable gam)
          => Rule' a gam -> Bool
validate' rule = allSeen (record' (lhs rule) blank)
  where
    allSeen :: forall g. Record g -> Bool
    allSeen flags = case flags of
      RNil           -> True
      RCons b others -> b && allSeen others

-------------------------------------------------------------------------------
-- | Finite default values for the datatypes to be rewritten: 'Empty' picks
-- the leftmost constructor without 'Rec' fields, and 'Finite' checks at the
-- type level that such a constructor exists
-------------------------------------------------------------------------------

-- Builds a default value of a representation type, avoiding 'Rec' fields
-- where possible.
class Empty a where
  empty' :: a

instance Empty U where
  empty' = U

-- Use the left alternative unless its value has a 'Rec' field at top level.
instance (HasRec a, Empty a, Empty b) => Empty (a :+: b) where
  empty'
    | hasRec' leftmost = R empty'
    | otherwise        = L leftmost
    where
      leftmost = empty' :: a

instance (Empty a, Empty b) => Empty (a :*: b) where
  empty' = (:*:) empty' empty'

instance (Empty a) => Empty (C c a) where
  empty' = C empty'

-- Variable and recursive positions hold the 'gempty' value of their type.
instance (Rewritable a) => Empty (Var a) where
  empty' = Var gempty

instance (Rewritable a) => Empty (Rec a) where
  empty' = Rec gempty

instance Empty Int   where empty' = 0
instance Empty Float where empty' = 0
instance Empty Char  where empty' = '\NUL'


-- Dispatcher: the 'Empty' value of a datatype, built via its representation.
gempty :: (Representable a, Empty (Rep a)) => a
gempty = to empty'


-- Used to avoid producing infinite values
-- | Whether a representation value has a 'Rec' field at top level. The
-- check does not look inside 'Rec' or 'Var' positions.
class HasRec a where
  hasRec' :: a -> Bool
  hasRec' _ = False

-- Leaves without recursion use the default ('False').
instance HasRec U
instance HasRec (Var a)
instance HasRec Int
instance HasRec Float
instance HasRec Char

instance HasRec (Rec a) where
  hasRec' _ = True

instance (HasRec a, HasRec b) => HasRec (a :*: b) where
  hasRec' (x :*: y) = or [hasRec' x, hasRec' y]

instance (HasRec a, HasRec b) => HasRec (a :+: b) where
  hasRec' s = case s of
    L x -> hasRec' x
    R y -> hasRec' y

instance (HasRec a) => HasRec (C c a) where
  hasRec' (C x) = hasRec' x

-- | Type-level finiteness check. @Finite r@ is 'True' when some alternative
-- of the representation @r@ has no 'Rec' field. Base types, 'U' and 'Var'
-- fields count as finite; a 'Rec' field does not. Sums need one finite
-- alternative ('Or'); products need all components finite ('And').
type family Finite a :: *
type instance Finite Int       = True
type instance Finite Float     = True
type instance Finite Char      = True
type instance Finite U         = True
type instance Finite (a :+: b) = Or (Finite a) (Finite b)
type instance Finite (a :*: b) = And (Finite a) (Finite b)
type instance Finite (Rec a)   = False
type instance Finite (Var a)   = True
type instance Finite (C c a)   = Finite a

-- | Type-level truth, used by 'Finite'.
data True
-- | Type-level falsity, used by 'Finite'.
data False

-- Type-level conjunction, given by its full truth table.
type family And p q
type instance And True  True  = True
type instance And True  False = False
type instance And False True  = False
type instance And False False = False

-- Type-level disjunction, given by its full truth table.
type family Or p q
type instance Or True  True  = True
type instance Or True  False = True
type instance Or False True  = True
type instance Or False False = False

-- 'True' when the representation of every metavariable type in the
-- environment is 'Finite'. It is required by 'synthesise'.
type family FiniteEnv gam
type instance FiniteEnv U = True
type instance FiniteEnv (a :*: gam) = And (Finite (Rep a)) (FiniteEnv gam)

-------------------------------------------------------------------------------
-- Convenience: hiding the environment type of Rule' from the user
-------------------------------------------------------------------------------

-- | Rules
-- A rule whose metavariable environment is hidden existentially. The class
-- dictionaries needed to validate and apply it are packed with it.
data Rule a where
  Rule :: (Nillable gam, Recordable gam, Testable gam)
       => Rule' a gam -> Rule a

-- | Validate a rewrite rule: every metavariable of the rule must occur in
-- its left-hand side.
validate :: Rewritable a => Rule a -> Bool
validate r = case r of
  Rule rule -> validate' rule

-- | Synthesise a function into a rewrite rule. The representation of every
-- argument type must be finite ('FiniteEnv').
synthesise :: (IsRule a, Nillable (Env a), Recordable (Env a),
               Testable (Env a), FiniteEnv (Env a) ~ True)
           => a -> Rule (Obj a)
synthesise spec = Rule (synthesise' spec)

-- | Rewrite a term. The rule is only tried at the root of the term, and the
-- term is unchanged if the rule cannot be applied.
rewrite :: Rewritable a => Rule a -> a -> a
rewrite r x = case r of
  Rule rule -> rewrite' rule x

-- | Rewrite a term. The rule is only tried at the root of the term. Monad
-- 'fail' is used if the rule cannot be applied.
rewriteM :: (Rewritable a, Monad m) => Rule a -> a -> m a
rewriteM r x = case r of
  Rule rule -> rewriteM' rule x