{-# Language DataKinds, DefaultSignatures, FlexibleContexts, FlexibleInstances, GeneralizedNewtypeDeriving,
             InstanceSigs, MultiParamTypeClasses, PolyKinds, QuantifiedConstraints,
             RankNTypes, ScopedTypeVariables, StandaloneDeriving,
             TypeApplications, TypeFamilies, TypeOperators, UndecidableInstances #-}

-- | This module can be used to scrap the boilerplate attribute declarations. In particular:
--
-- * If an 'attribution' rule always merely copies the inherited attributes to the children's inherited attributes of
--   the same name, the rule can be left out by wrapping the transformation into an 'Auto' constructor and deriving
--   the 'Generic' instance of the inherited attributes.
-- * A synthesized attribute whose value is a fold of all same-named attributes of the children can be wrapped in the
--   'Folded' constructor and calculated automatically.
-- * A synthesized attribute that is a copy of the current node but with every child taken from the same-named
--   synthesized child attribute can be wrapped in the 'Mapped' constructor and calculated automatically.
-- * If the attribute additionally carries an applicative effect, the 'Mapped' wrapper can be replaced by 'Traversed'.

module Transformation.AG.Generics (-- * Type wrappers for automatic attribute inference
                                   Auto(..), Folded(..), Mapped(..), Traversed(..),
                                   -- * Type classes replacing 'Attribution'
                                   Bequether(..), Synthesizer(..), SynthesizedField(..),
                                   -- * The default behaviour on generic datatypes
                                   foldedField, mappedField, passDown, bequestDefault)
where

import Data.Functor.Compose (Compose(..))
import Data.Functor.Const (Const(..))
import Data.Kind (Type)
import Data.Generics.Product.Subtype (Subtype(upcast))
import Data.Proxy (Proxy(..))
import GHC.Generics
import GHC.Records
import GHC.TypeLits (Symbol, ErrorMessage (Text), TypeError)
import Unsafe.Coerce (unsafeCoerce)
import qualified Rank2
import Transformation (Transformation, Codomain)
import Transformation.AG
import qualified Transformation
import qualified Transformation.Shallow as Shallow

-- | Transformation wrapper that allows automatic inference of attribute rules.
--
-- @Auto t@ has exactly the same inherited and synthesized attribute types as the wrapped transformation @t@ (see the
-- two type instances below). Its attribute rules come, by default, from the overlappable 'Bequether' and
-- 'Synthesizer' instances for @Auto t@.
newtype Auto t = Auto t

-- The attribute types of @Auto t@ are simply those of @t@.
type instance Atts (Inherited (Auto t)) x = Atts (Inherited t) x
type instance Atts (Synthesized (Auto t)) x = Atts (Synthesized t) x

-- | An 'Auto' wrapper delegates both its 'Origin' functor and its 'unwrap' to the transformation it wraps.
instance Attribution t => Attribution (Auto t) where
   type Origin (Auto t) = Origin t
   unwrap (Auto wrapped) = unwrap wrapped

-- | The attribute rule of an 'Auto' transformation is assembled from its two halves.
--
-- * 'synthesis' produces the synthesized attributes of the node.
-- * 'bequest' produces the inherited attributes of its children.
--
-- The instance is overlappable, so a more specific hand-written rule takes precedence.
instance {-# overlappable #-} (Attribution t, Bequether (Auto t) g, Synthesizer (Auto t) g) =>
                              Auto t `At` g where
   attribution t node (Inherited inherited, childSynthesized) = (Synthesized synthesized, bequeathed)
      where synthesized = synthesis t node inherited childSynthesized
            bequeathed = bequest t node inherited childSynthesized

-- | A half of the 'Attribution' class used to specify all inherited attributes.
--
-- 'bequest' computes, for every child of the given node, the inherited attributes that child receives. The
-- overlappable instance for 'Auto' transformations delegates to 'bequestDefault'.
class Bequether t g where
   bequest     :: forall sem.
                  t                                -- ^ transformation
               -> Origin t (g sem sem)             -- ^ tree node
               -> Atts (Inherited t) g             -- ^ inherited attributes of this node
               -> g sem (Synthesized t)            -- ^ synthesized attributes of the node's children
               -> g sem (Inherited t)              -- ^ inherited attributes for each child

-- | A half of the 'Attribution' class used to specify all synthesized attributes.
--
-- 'synthesis' computes the whole synthesized-attribute record of the given node. The overlappable instance for
-- 'Auto' transformations builds that record field by field through its 'Generic' representation.
class Attribution t => Synthesizer t g where
   synthesis   :: forall sem.
                  t                                -- ^ transformation
               -> Origin t (g sem sem)             -- ^ tree node
               -> Atts (Inherited t) g             -- ^ inherited attributes of this node
               -> g sem (Synthesized t)            -- ^ synthesized attributes of the node's children
               -> Atts (Synthesized t) g           -- ^ synthesized attributes of this node

-- | Class for specifying a single named attribute: how the synthesized field called @name@, of type @result@, is
-- computed for a node.
--
-- The generic 'Synthesizer' for 'Auto' transformations calls 'synthesizedField' for every named field of the
-- synthesized attribute record. Overlappable instances provide the defaults for 'Folded', 'Mapped' and 'Traversed'
-- fields; any other field type needs an explicit instance.
class Attribution t => SynthesizedField (name :: Symbol) result t g where
   synthesizedField  :: forall sem.
                        Proxy name                      -- ^ attribute name
                     -> t                               -- ^ transformation
                     -> Origin t (g sem sem)            -- ^ tree node
                     -> Atts (Inherited t) g            -- ^ inherited attributes of this node
                     -> g sem (Synthesized t)           -- ^ synthesized attributes of the node's children
                     -> result                          -- ^ value of the field @name@

-- | By default an 'Auto' transformation passes its inherited attribute record down to all children via
-- 'bequestDefault'. The record type is named @a@ so that the quantified constraint can mention it; the instance is
-- overlappable, so a hand-written 'Bequether' instance takes precedence.
instance {-# overlappable #-} (Attribution t, a ~ Atts (Inherited (Auto t)) g,
                               forall deep. Shallow.Functor (PassDown (Auto t) deep a) (g deep)) =>
                              Bequether (Auto t) g where
   bequest = bequestDefault

-- | By default the synthesized attribute record of an 'Auto' transformation is assembled through its 'Generic'
-- representation. Each named field is delegated to 'SynthesizedField' via 'GenericSynthesizer'.
instance {-# overlappable #-} (Attribution t, Atts (Synthesized (Auto t)) g ~ result, Generic result,
                               GenericSynthesizer (Auto t) g (Rep result)) => Synthesizer (Auto t) g where
   synthesis t node inherited childSynthesized = to $ genericSynthesis t node inherited childSynthesized

-- | Wrapper for a field that should be automatically synthesized by folding together all child nodes' synthesized
-- attributes of the same name. Children whose synthesized attributes lack such a field contribute 'mempty'.
newtype Folded a = Folded{getFolded :: a} deriving (Eq, Ord, Show, Semigroup, Monoid)
-- | Wrapper for a field that should be automatically synthesized by replacing every child node by its synthesized
-- attribute of the same name. The wrapped value is the node rebuilt inside the functor @f@.
newtype Mapped f a = Mapped{getMapped :: f a}
                   deriving (Eq, Ord, Show, Semigroup, Monoid, Functor, Applicative, Monad, Foldable)

-- | Wrapper for a field that should be automatically synthesized by traversing over all child nodes, replacing each
-- node with the effectful contents of its synthesized attribute of the same name. The wrapped value is the rebuilt
-- node @g f f@ inside @f@, all under the effect @m@.
newtype Traversed m f g = Traversed{getTraversed :: m (f (g f f))} --deriving (Eq, Ord, Show, Semigroup, Monoid)

-- * Generic transformations

-- | Internal transformation for passing down the inherited attributes. It carries the parent's inherited record @a@;
-- @t@ is the attribute-grammar transformation and @f@ the functor of the children it is applied to.
newtype PassDown (t :: Type) (f :: Type -> Type) a = PassDown a
-- | Internal transformation for accumulating the 'Folded' attributes. It is a pure type-level tag: @name@ selects
-- the field and @a@ its content type.
data Accumulator (t :: Type) (name :: Symbol) (a :: Type) = Accumulator
-- | Internal transformation for replicating the 'Mapped' attributes; @name@ selects the field.
data Replicator (t :: Type) (f :: Type -> Type) (name :: Symbol) = Replicator
-- | Internal transformation for traversing the 'Traversed' attributes; @name@ selects the field and @m@ is the effect.
data Traverser (t :: Type) (m :: Type -> Type) (f :: Type -> Type) (name :: Symbol) = Traverser

-- PassDown maps the children (wrapped in @f@) to their inherited attributes.
instance Transformation (PassDown t f a) where
  type Domain (PassDown t f a) = f
  type Codomain (PassDown t f a) = Inherited t

-- Accumulator maps each child's synthesized attributes to a constant 'Folded' value, suitable for folding.
instance Transformation (Accumulator t name a) where
  type Domain (Accumulator t name a) = Synthesized t
  type Codomain (Accumulator t name a) = Const (Folded a)

-- Replicator maps each child's synthesized attributes back to a node wrapped in @f@.
instance Transformation (Replicator t f name) where
  type Domain (Replicator t f name) = Synthesized t
  type Codomain (Replicator t f name) = f

-- Traverser maps each child's synthesized attributes to an effectful node, @m (f x)@, via 'Compose'.
instance Transformation (Traverser t m f name) where
  type Domain (Traverser t m f name) = Synthesized t
  type Codomain (Traverser t m f name) = Compose m f

-- | Hands the parent's inherited attributes down to a child. The parent's record is converted with 'upcast' to the
-- child's own inherited-attribute type; the child itself is ignored.
instance Subtype (Atts (Inherited t) (NodeConstructor a)) b => Transformation.At (PassDown t f b) a where
   ($) (PassDown i) _ = Inherited (upcast i)

-- | Extracts the child's 'Folded' field named @name@ from the 'Generic' representation of its synthesized
-- attributes, wrapped in 'Const' so the children can be folded together.
instance (Monoid a, r ~ Atts (Synthesized t) (NodeConstructor x), Generic r,
          MayHaveMonoidalField name (Folded a) (Rep r)) =>
         Transformation.At (Accumulator t name a) x where
   _ $ Synthesized r = Const (getMonoidalField (Proxy :: Proxy name) $ from r)

-- | Replaces the child with the contents of its synthesized 'Mapped' field named @name@.
instance (HasField name (Atts (Synthesized t) (NodeConstructor a)) (Mapped f a)) => Transformation.At (Replicator t f name) a where
   _ $ Synthesized r = getMapped (getField @name r)

-- | Replaces the child with the effectful contents of its synthesized 'Traversed' field named @name@. The result is
-- wrapped in 'Compose', as the 'Traverser' codomain requires.
instance (HasField name (Atts (Synthesized t) g) (Traversed m f g)) =>
         Transformation.At (Traverser t m f name) (g f f) where
   _ $ Synthesized r = Compose (getTraversed $ getField @name r)

-- * Generic classes

-- | The 'Generic' mirror of 'Synthesizer'.
--
-- It builds the generic representation @result@ of a synthesized-attribute record. The 'Auto' 'Synthesizer'
-- instance converts that representation back with 'to'.
class GenericSynthesizer t g result where
   genericSynthesis  :: forall a sem.
                        t                                -- ^ transformation
                     -> Origin t (g sem sem)             -- ^ tree node
                     -> Atts (Inherited t) g             -- ^ inherited attributes of this node
                     -> g sem (Synthesized t)            -- ^ synthesized attributes of the node's children
                     -> result a                         -- ^ generic representation of the synthesized record

-- | The 'Generic' mirror of 'SynthesizedField'.
--
-- It produces the generic representation (a 'K1' wrapper in the instance provided here) of the single record field
-- called @name@.
class Attribution t => GenericSynthesizedField (name :: Symbol) result t g where
   genericSynthesizedField  :: forall a sem.
                               Proxy name                       -- ^ field name
                            -> t                                -- ^ transformation
                            -> Origin t (g sem sem)             -- ^ tree node
                            -> Atts (Inherited t) g             -- ^ inherited attributes of this node
                            -> g sem (Synthesized t)            -- ^ synthesized attributes of the node's children
                            -> result a                         -- ^ generic representation of the field

-- | Used for accumulating the 'Folded' fields.
--
-- 'getMonoidalField' looks up the field called @name@ in the generic representation @f@ of a record. Its instances
-- return that field's value when it is present and 'mempty' when it is absent.
class MayHaveMonoidalField (name :: Symbol) a f where
   getMonoidalField :: Proxy name -> f x -> a
-- | Extracts the value of type @a@ from the generic representation of a single field, looking through any 'M1'
-- metadata wrappers down to the 'K1' that holds it.
class FoundField a f where
   getFoundField :: f x -> a

-- Both halves of a product are searched, and their results are combined with '<>'.
instance {-# overlaps #-} (MayHaveMonoidalField name a x, MayHaveMonoidalField name a y, Semigroup a) =>
         MayHaveMonoidalField name a (x :*: y) where
   getMonoidalField name (x :*: y) = getMonoidalField name x <> getMonoidalField name y

-- Records with several constructors are rejected at compile time; the method body is never reached.
instance {-# overlaps #-} TypeError ('Text "Cannot get a single field value from a sum type") =>
         MayHaveMonoidalField name a (x :+: y) where
   getMonoidalField _ _ = error "getMonoidalField on sum type"

-- A selector whose name is exactly @name@: this is the field we are looking for.
instance {-# overlaps #-} FoundField a f => MayHaveMonoidalField name a (M1 i ('MetaSel ('Just name) su ss ds) f) where
   getMonoidalField _ (M1 x) = getFoundField x

-- An unnamed (positional) field never matches.
instance {-# overlaps #-} Monoid a => MayHaveMonoidalField name a (M1 i ('MetaSel 'Nothing su ss ds) f) where
   getMonoidalField _ _ = mempty

-- Datatype metadata is transparent: search inside it.
instance {-# overlaps #-} MayHaveMonoidalField name a f => MayHaveMonoidalField name a (M1 i ('MetaData n m p nt) f) where
   getMonoidalField name (M1 x) = getMonoidalField name x

-- Constructor metadata is transparent: search inside it.
instance {-# overlaps #-} MayHaveMonoidalField name a f => MayHaveMonoidalField name a (M1 i ('MetaCons n fi s) f) where
   getMonoidalField name (M1 x) = getMonoidalField name x

-- Fallback for anything else, including selectors with a different name: the field is absent, contribute 'mempty'.
instance {-# overlappable #-} Monoid a => MayHaveMonoidalField name a f where
   getMonoidalField _ _ = mempty

-- Metadata wrappers are peeled off.
instance FoundField a f => FoundField a (M1 i j f) where
     getFoundField (M1 f) = getFoundField f

-- The field value itself; its type must be exactly @a@.
instance FoundField a (K1 i a) where
     getFoundField (K1 a) = a

-- | A product of fields is synthesized by synthesizing each half independently.
instance (GenericSynthesizer t g x, GenericSynthesizer t g y) => GenericSynthesizer t g (x :*: y) where
   genericSynthesis t node i s = genericSynthesis t node i s :*: genericSynthesis t node i s

-- | A constructor without fields has nothing to synthesize. This instance lets an 'Auto' transformation have an
-- empty synthesized-attribute record, whose 'Rep' is @D1 (C1 U1)@.
instance GenericSynthesizer t g U1 where
   genericSynthesis _ _ _ _ = U1

-- | Datatype and constructor metadata (and any other 'M1' not matched below) is transparent.
instance {-# overlappable #-} GenericSynthesizer t g f =>
                              GenericSynthesizer t g (M1 i meta f) where
   genericSynthesis t node i s = M1 (genericSynthesis t node i s)

-- | A named field is computed by 'GenericSynthesizedField', keyed by its selector name.
instance {-# overlaps #-} GenericSynthesizedField name f t g =>
                          GenericSynthesizer t g (M1 i ('MetaSel ('Just name) su ss ds) f) where
   genericSynthesis t node i s = M1 (genericSynthesizedField (Proxy :: Proxy name) t node i s)

-- | The field value itself comes from the user-facing 'SynthesizedField' class.
instance SynthesizedField name a t g => GenericSynthesizedField name (K1 i a) t g where
   genericSynthesizedField name t node i s = K1 (synthesizedField name t node i s)

-- | Default for 'Folded' fields: the same-named 'Folded' fields of all children, folded together by 'foldedField'.
instance  {-# overlappable #-} (Attribution t, Monoid a,
                                forall sem. Shallow.Foldable (Accumulator t name a) (g sem)) =>
                               SynthesizedField name (Folded a) t g where
   synthesizedField name t _node _inherited childSynthesized = foldedField name t childSynthesized

-- | Default for 'Mapped' fields: inside the node's 'Origin' functor, the node is replaced by a copy whose children
-- are taken from their same-named 'Mapped' fields.
instance  {-# overlappable #-} (Attribution t, Origin t ~ f, Functor f,
                                Shallow.Functor (Replicator t f name) (g f)) =>
                               SynthesizedField name (Mapped f (g f f)) t g where
   synthesizedField name t node _inherited childSynthesized =
      Mapped (fmap (const $ mappedField name t childSynthesized) node)

-- | Default for 'Traversed' fields: like 'Mapped', but the copy is built by an effectful traversal over the
-- children's same-named 'Traversed' fields.
instance  {-# overlappable #-} (Attribution t, Origin t ~ f, Traversable f, Applicative m,
                                Shallow.Traversable (Traverser t m f name) (g f)) =>
                               SynthesizedField name (Traversed m f g) t g where
   synthesizedField name t node _inherited childSynthesized =
      Traversed (traverse (const $ traversedField name t childSynthesized) node)

-- | The default 'bequest' method definition: the node is unwrapped from its 'Origin' and every child receives the
-- parent's inherited attribute record through 'passDown'. The children's synthesized attributes are not consulted.
bequestDefault :: forall t g sem.
                  (Attribution t, Shallow.Functor (PassDown t sem (Atts (Inherited t) g)) (g sem))
               => t -> Origin t (g sem sem) -> Atts (Inherited t) g -> g sem (Synthesized t)
               -> g sem (Inherited t)
bequestDefault t local inheritance _ = passDown @t inheritance node
   where node :: g sem sem
         node = unwrap t local

-- | Pass down the given record of inherited fields to child nodes.
--
-- Every child of @local@ is mapped through the 'PassDown' transformation carrying @inheritance@, which yields that
-- child's 'Inherited' attributes. The mapping already has the result type @g deep (Inherited t)@, so no coercion is
-- needed. The explicit type application of @shallow@ pins the 'PassDown' domain to the children's functor.
passDown :: forall t g shallow deep atts. (Shallow.Functor (PassDown t shallow atts) (g deep)) =>
            atts -> g deep shallow -> g deep (Inherited t)
passDown inheritance local = PassDown @t @shallow inheritance Shallow.<$> local

-- | The default 'synthesizedField' method definition for 'Folded' fields.
--
-- The same-named 'Folded' fields of all children are folded together with 'Shallow.foldMap'. A child whose
-- synthesized attributes lack the field contributes 'mempty'.
foldedField :: forall name t g a sem. (Monoid a, Shallow.Foldable (Accumulator t name a) (g sem)) =>
               Proxy name -> t -> g sem (Synthesized t) -> Folded a
foldedField _ _ = Shallow.foldMap (Accumulator @t @name @a)

-- | The default 'synthesizedField' method definition for 'Mapped' fields.
--
-- The result is a copy of the node in which every child is replaced by the contents of its same-named 'Mapped'
-- synthesized field.
mappedField :: forall name t g f sem.
                  (Shallow.Functor (Replicator t f name) (g f)) =>
                  Proxy name -> t -> g sem (Synthesized t) -> g f f
mappedField _ _ synthesized = Replicator @t @f @name Shallow.<$> retagged
   -- unsafeCoerce is safe here because Synthesized doesn't refer to deep functor so the latter is a phantom
   where retagged :: g f (Synthesized t)
         retagged = unsafeCoerce synthesized

-- | The default 'synthesizedField' method definition for 'Traversed' fields.
--
-- Like 'mappedField', but each child is replaced by the effectful contents of its same-named 'Traversed' synthesized
-- field, with the effects sequenced by 'Shallow.traverse'.
traversedField :: forall name t g m f sem.
                     (Shallow.Traversable (Traverser t m f name) (g f)) =>
                     Proxy name -> t -> g sem (Synthesized t) -> m (g f f)
-- unsafeCoerce is safe here because Synthesized doesn't refer to deep functor so the latter is a phantom
traversedField _ _ synthesized = Shallow.traverse (Traverser @t @m @f @name) retagged
   where retagged :: g f (Synthesized t)
         retagged = unsafeCoerce synthesized
