{-# language CPP
           , TemplateHaskell #-}

-- | This module re-exports a subset of `Yaya.Fold`, intended for when you want
--   to define recursion scheme instances for your existing recursive types.
--
--   This is /not/ the recommended way to use Yaya, but it solves some real
--   problems:
-- 1. you have existing directly-recursive types and you want to start taking
--    advantage of recursion schemes without having to rewrite your existing
--    code, or
-- 2. a directly-recursive type has been imposed on you by some other library
--    and you want to take advantage of recursion schemes.
--
--   The distinction between these two cases is whether you have control of the
--   @data@ declaration. In the first case, you probably do. In that case, you
--   should only generate the /safe/ instances, and ensure that all the
--   recursive type references are /strict/ (if you want a `Recursive`
--   instance). If you don't have control, then you /may/ need to generate all
--   instances.
--
--   Another difference when you have control is that it means you may migrate
--   away from direct recursion entirely, at which point this import should
--   disappear.
module Yaya.Retrofit
  ( module Yaya.Fold
  , PatternFunctorRules (..)
  , defaultRules
  , extractPatternFunctor
  ) where

import Control.Exception (Exception (..), throw)
import Control.Monad ((<=<))
import Data.Bifunctor (bimap)
import Data.Either.Validation (Validation (..), validationToEither)
import Data.Functor.Identity (Identity (..))
import Data.List.NonEmpty (NonEmpty)
import Language.Haskell.TH as TH
import Language.Haskell.TH.Datatype as TH.Abs
import Language.Haskell.TH.Syntax (mkNameG_tc)
import Text.Read.Lex (isSymbolChar)

import Yaya.Fold
       ( Corecursive (..)
       , Projectable (..)
       , Recursive (..)
       , Steppable (..)
       , recursiveEq
       , recursiveShowsPrec
       )

-- | Extract a pattern functor and relevant instances from a simply recursive type.
--
-- /e.g./
--
-- @
-- data Expr a
--     = Lit a
--     | Add (Expr a) (Expr a)
--     | Expr a :* [Expr a]
--   deriving (Show)
--
-- `extractPatternFunctor` `defaultRules` ''Expr
-- @
--
-- will create
--
-- @
-- data ExprF a x
--     = LitF a
--     | AddF x x
--     | x :*$ [x]
--   deriving ('Functor', 'Foldable', 'Traversable')
--
-- instance `Projectable` (->) (Expr a) (ExprF a) where
--   `project` (Lit x)   = LitF x
--   `project` (Add x y) = AddF x y
--   `project` (x :* y)  = x :*$ y
--
-- instance `Steppable` (->) (Expr a) (ExprF a) where
--   `embed` (LitF x)   = Lit x
--   `embed` (AddF x y) = Add x y
--   `embed` (x :*$ y)  = x :* y
--
-- instance `Recursive` (->) (Expr a) (ExprF a) where
--   `cata` φ = φ . `fmap` (`cata` φ) . `project`
--
-- instance `Corecursive` (->) (Expr a) (ExprF a) where
--   `ana` ψ = `embed` . `fmap` (`ana` ψ) . ψ
-- @
--
-- /Notes:/
--
-- - `extractPatternFunctor` works properly only with ordinary ADTs.
--   Existentials, GADTs and data family instances aren't supported,
--   as we don't try to do better than
--   <https://downloads.haskell.org/ghc/latest/docs/users_guide/exts/deriving_extra.html#deriving-functor-instances GHC's DeriveFunctor>.
-- - We always generate both `Recursive` and `Corecursive` instances, but for any
--   given type one of them is unsafe. In the future, we should check the strictness
--   of the recursive positions and generate only the appropriate instance (unless
--   overridden by a rule).
extractPatternFunctor :: PatternFunctorRules -> Name -> Q [Dec]
extractPatternFunctor rules = generate <=< reifyDatatype
  where
    -- Either hand back the declaration-generating action, or abort the splice
    -- by throwing the UnsupportedDatatype that explains the rejection.
    generate info = case makePrimForDI rules info of
      Left unsupported -> throw unsupported
      Right mkDecs     -> mkDecs

-- | Naming rules for the generated declarations: each field maps a name from
--   the original datatype to the name of its counterpart in the pattern functor.
data PatternFunctorRules = PatternFunctorRules
    { patternType  :: Name -> Name
      -- ^ Maps the original type name to the pattern functor type name.
    , patternCon   :: Name -> Name
      -- ^ Maps each original constructor name to its pattern functor
      --   constructor name (also used when generating 'project' and 'embed').
    , patternField :: Name -> Name
      -- ^ Maps each record field name of a record constructor to the
      --   corresponding field name in the pattern functor.
    }

-- | The default 'PatternFunctorRules': the type, every constructor and every
--   record field are all renamed the same way, by suffixing @F@ to
--   alphanumeric names and @$@ to symbolic (operator) names. For example,
--   @Expr@ becomes @ExprF@ and @:*@ becomes @:*$@.
defaultRules :: PatternFunctorRules
defaultRules = PatternFunctorRules toFName toFName toFName

-- | Derive a pattern-functor name from an original name. The qualifier is
--   dropped and a suffix is appended to the base name: @$@ when the name
--   consists only of symbol characters (an operator), @F@ otherwise.
toFName :: Name -> Name
toFName original = mkName (base ++ suffix)
  where
    base = nameBase original

    suffix
      | all isSymbolChar base = "$"
      | otherwise             = "F"

-- | Reasons 'extractPatternFunctor' rejects a datatype. Values are thrown as
--   exceptions from the splice, so 'show' provides the user-facing message.
data UnsupportedDatatype
  = UnsupportedInstTypes (NonEmpty Type)
    -- ^ Instance types that are not (optionally kind-annotated) type variables.
  | UnsupportedVariant DatatypeVariant
    -- ^ The datatype is a data or newtype family instance.

instance Show UnsupportedDatatype where
  show (UnsupportedInstTypes tys) =
    "extractPatternFunctor: Couldn't process the following types " <> show tys
  show (UnsupportedVariant _variant) =
    "extractPatternFunctor: Data families are currently not supported."

instance Exception UnsupportedDatatype

-- | Validate a reified datatype and, if it is supported, return the action
--   that generates its pattern functor and instances.
--
--   Data and newtype family instances are rejected outright. Otherwise every
--   instance type must be a type variable, optionally kind-annotated. All
--   offending instance types are collected and reported together.
makePrimForDI
  :: PatternFunctorRules -> DatatypeInfo -> Either UnsupportedDatatype (Q [Dec])
makePrimForDI rules info
  | isDataFamInstance = Left (UnsupportedVariant variant)
  | otherwise =
      bimap UnsupportedInstTypes generate
        . validationToEither
        $ traverse asBinder (datatypeInstTypes info)
  where
    variant = datatypeVariant info

    generate binders =
      makePrimForDI' rules (variant == Newtype) (datatypeName info) binders (datatypeCons info)

    isDataFamInstance = case variant of
      Datatype        -> False
      Newtype         -> False
      DataInstance    -> True
      NewtypeInstance -> True

    -- Only plain or kind-annotated type variables can become binders of the
    -- pattern functor; any other type is accumulated as a failure.
    asBinder :: Type -> Validation (NonEmpty Type) TyVarBndr
    asBinder ty = case ty of
      VarT n          -> Success (PlainTV n)
      SigT (VarT n) k -> Success (KindedTV n k)
      _               -> Failure (pure ty)

-- | Generate the pattern functor declaration followed by its 'Projectable',
--   'Steppable', 'Recursive' and 'Corecursive' instances.
--
--   @vars@ are the binders of the original type. The pattern functor gets
--   those binders plus one fresh variable for the recursive positions. When
--   @isNewtype@ holds and there is exactly one constructor, a @newtype@ is
--   emitted; otherwise a @data@ declaration is.
makePrimForDI'
  :: PatternFunctorRules -> Bool -> Name -> [TyVarBndr] -> [ConstructorInfo] -> Q [Dec]
makePrimForDI' rules isNewtype tyName vars cons = do
    -- The original type parameters as types (e.g. @a@ in @Expr a@).
    let vars' = map VarT (typeVars vars)
    -- Name of the pattern functor type (e.g. @ExprF@ with defaultRules).
    let tyNameF = patternType rules tyName
    -- The fully applied recursive type (e.g. @Expr a@); every field-type
    -- occurrence of it is replaced by @r@ below.
    let s = conAppsT tyName vars'
    -- Fresh type variable standing for the recursive positions.
    rName <- newName "r"
    let r = VarT rName

    -- Pattern functor binders: the original ones with @r@ appended last, so
    -- the derived Functor/Foldable/Traversable instances act on @r@.
    let varsF = vars ++ [PlainTV rName]

    -- Expand type synonyms in every field type first: substType matches @s@
    -- structurally (with ==), so an occurrence hidden behind a synonym would
    -- otherwise be missed. NOTE(review): "#33" was presumably an upstream issue
    -- number for this fix; confirm.
    cons' <- traverse (conTypeTraversal resolveTypeSynonyms) cons
    -- Pattern functor constructors: rename constructors and record fields,
    -- replace @s@ by @r@ in field types, then convert back to TH syntax.
    let consF
          = toCon
          . conNameMap (patternCon rules)
          . conFieldNameMap (patternField rules)
          . conTypeMap (substType s r)
          <$> cons'

    -- Data definition: a newtype only for a single-constructor newtype,
    -- otherwise a data declaration; both derive Functor, Foldable, Traversable.
    let dataDec = case consF of
            [conF] | isNewtype ->
                NewtypeD [] tyNameF varsF Nothing conF deriveds
            _ -> DataD [] tyNameF varsF Nothing consF deriveds
          where
            deriveds =
-- From TH 2.12.0 (GHC 8.2.1) the class list is wrapped in a DerivClause;
-- before that (back to GHC 8.0.1) the list is passed directly.
#if MIN_VERSION_template_haskell(2,12,0)
              pure $ DerivClause Nothing
#endif
              [ ConT functorTypeName
              , ConT foldableTypeName
              , ConT traversableTypeName ]

    -- The four instances. project/embed are \case expressions with one
    -- alternative per constructor (built by mkMorphism, in opposite
    -- directions); cata/ana are defined generically in terms of them.
    recursiveDec <-
      [d|
        instance Projectable (->) $(pure s) $(pure $ conAppsT tyNameF vars') where
          project = $(LamCaseE <$> mkMorphism id (patternCon rules) cons')

        instance Steppable (->) $(pure s) $(pure $ conAppsT tyNameF vars') where
          embed = $(LamCaseE <$> mkMorphism (patternCon rules) id cons')

        instance Recursive (->) $(pure s) $(pure $ conAppsT tyNameF vars') where
          cata φ = φ . fmap (cata φ) . project

        instance Corecursive (->) $(pure s) $(pure $ conAppsT tyNameF vars') where
          ana ψ = embed . fmap (ana ψ) . ψ
        |]
    -- The data declaration first, then the instance declarations.
    pure ([dataDec] <> recursiveDec)

-- | Build one @\\case@ alternative per constructor that rebuilds the value
--   under a renamed constructor. The pattern is @nFrom c@ applied to fresh
--   variables, and the body is @nTo c@ applied to the same variables in the
--   same order.
mkMorphism
    :: (Name -> Name)
    -> (Name -> Name)
    -> [ConstructorInfo]
    -> Q [Match]
mkMorphism nFrom nTo = traverse toMatch
  where
    toMatch ci = do
      let cname = constructorName ci
      xs <- traverse (\_ -> newName "x") (constructorFields ci)
      let lhs = ConP (nFrom cname) (VarP <$> xs)
          rhs = foldl AppE (ConE (nTo cname)) (VarE <$> xs)
      pure (Match lhs (NormalB rhs) [])

-------------------------------------------------------------------------------
-- Traversals
-------------------------------------------------------------------------------

-- | Focus on a constructor name.
conNameTraversal :: Traversal' ConstructorInfo Name
conNameTraversal = lens constructorName setName
  where
    setName ci n = ci { constructorName = n }

-- | Visit the record field names of a constructor. Only record constructors
--   have field names; prefix and infix constructors are returned unchanged.
conFieldNameTraversal :: Traversal' ConstructorInfo Name
conFieldNameTraversal = variantLens . fieldNames
  where
    variantLens :: Lens' ConstructorInfo ConstructorVariant
    variantLens = lens constructorVariant (\ci v -> ci { constructorVariant = v })

    fieldNames :: Traversal' ConstructorVariant Name
    fieldNames visit = \case
      RecordConstructor names -> RecordConstructor <$> traverse visit names
      NormalConstructor       -> pure NormalConstructor
      InfixConstructor        -> pure InfixConstructor

-- | Visit every field type of a constructor, left to right.
conTypeTraversal :: Traversal' ConstructorInfo Type
conTypeTraversal visit ci =
  (\tys -> ci { constructorFields = tys }) <$> traverse visit (constructorFields ci)

-- | Rename a constructor, or the record fields of a record constructor.
conNameMap, conFieldNameMap :: (Name -> Name) -> ConstructorInfo -> ConstructorInfo
conNameMap      = over conNameTraversal
conFieldNameMap = over conFieldNameTraversal

-- | Rewrite every field type of a constructor.
conTypeMap :: (Type -> Type) -> ConstructorInfo -> ConstructorInfo
conTypeMap = over conTypeTraversal

-------------------------------------------------------------------------------
-- Lenses
-------------------------------------------------------------------------------

-- | Minimal van Laarhoven optics (this module imports no optics library).
--   A 'Lens'' focuses on exactly one @a@ inside an @s@.
type Lens' s a = forall g. Functor g => (a -> g a) -> s -> g s

-- | A 'Traversal'' visits zero or more @a@s inside an @s@.
type Traversal' s a = forall g. Applicative g => (a -> g a) -> s -> g s

-- | Build a 'Lens'' from a getter and a setter.
lens :: (s -> a) -> (s -> a -> s) -> Lens' s a
lens getter setter focus whole = fmap (setter whole) (focus (getter whole))
{-# INLINE lens #-}

-- | Apply a function to every target of a 'Traversal'' by running it in the
--   'Identity' functor.
over :: Traversal' s a -> (a -> a) -> s -> s
over setter f = runIdentity . setter (pure . f)
{-# INLINE over #-}

-------------------------------------------------------------------------------
-- Type mangling
-------------------------------------------------------------------------------

-- | The names bound by a list of type variable binders, in order.
typeVars :: [TyVarBndr] -> [Name]
typeVars binders = [tvName binder | binder <- binders]

-- | Apply a type constructor to argument types, left to right:
--   @conAppsT T [a, b]@ builds @T a b@.
conAppsT :: Name -> [Type] -> Type
conAppsT conName = applyAll (ConT conName)
  where
    applyAll acc []         = acc
    applyAll acc (t : rest) = applyAll (acc `AppT` t) rest

-- | @substType old new t@ replaces every subterm of @t@ that is structurally
--   equal (@==@) to @old@ with @new@. Only the type forms matched below are
--   descended into; every other form is returned unchanged.
substType
    :: Type
    -> Type
    -> Type
    -> Type
substType old new = replace
  where
    replace t
      | t == old  = new
      | otherwise = case t of
          AppT f x          -> AppT (replace f) (replace x)
          -- Only the body is rewritten; binders and context are kept as-is.
          ForallT tvs cx b  -> ForallT tvs cx (replace b)
          -- The kind is not adjusted, so this may produce a kind error.
          SigT b k          -> SigT (replace b) k
          InfixT l op r     -> InfixT (replace l) op (replace r)
          UInfixT l op r    -> UInfixT (replace l) op (replace r)
          ParensT b         -> ParensT (replace b)
          -- Variables, constructors and all remaining forms are untouched.
          other             -> other

-- | Convert a (renamed and substituted) 'ConstructorInfo' back into a TH 'Con'.
--   It keeps each field's strictness and unpackedness annotations and the
--   constructor's syntactic form (prefix, record or infix).
--
--   Calls 'error', aborting the splice, for constructors that bind their own
--   type variables or carry a context (existentials and GADTs), which are not
--   supported.
toCon :: ConstructorInfo -> Con
toCon (ConstructorInfo { constructorName       = name
                       , constructorVars       = vars
                       , constructorContext    = ctxt
                       , constructorFields     = ftys
                       , constructorStrictness = fstricts
                       , constructorVariant    = variant })
  | not (null vars && null ctxt)
  = error "extractPatternFunctor: GADTs and existential constructors are not currently supported."
  | otherwise
  = let bangs = map toBang fstricts
     in case variant of
          NormalConstructor        -> NormalC name $ zip bangs ftys
          RecordConstructor fnames -> RecC name $ zip3 fnames bangs ftys
          InfixConstructor         ->
            -- An infix constructor always has exactly two fields; match
            -- totally rather than with partial list patterns.
            case zip bangs ftys of
              [lhsField, rhsField] -> InfixC lhsField name rhsField
              _ -> error "extractPatternFunctor: an infix constructor must have exactly two fields."
  where
    toBang (FieldStrictness upkd strct) = Bang (toSourceUnpackedness upkd)
                                               (toSourceStrictness strct)
      where
        toSourceUnpackedness :: Unpackedness -> SourceUnpackedness
        toSourceUnpackedness UnspecifiedUnpackedness = NoSourceUnpackedness
        toSourceUnpackedness NoUnpack                = SourceNoUnpack
        toSourceUnpackedness Unpack                  = SourceUnpack

        toSourceStrictness :: Strictness -> SourceStrictness
        toSourceStrictness UnspecifiedStrictness = NoSourceStrictness
        toSourceStrictness Lazy                  = SourceLazy
        toSourceStrictness TH.Abs.Strict         = SourceStrict

-------------------------------------------------------------------------------
-- Manually quoted names
-------------------------------------------------------------------------------
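-- These names are constructed with `mkNameG_tc` rather than with Template
-- Haskell name quotes (e.g. @''Functor@). That technique lets a library avoid
-- the TemplateHaskell extension, which helps stage1 cross-compilers. However,
-- this module enables TemplateHaskell anyway (see the language pragma at the
-- top and the @[d| ... |]@ quote in makePrimForDI'), so that benefit does not
-- currently apply here.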

-- | Global names of the classes derived for every generated pattern functor.
--
--   NOTE(review): each module string must be the defining module of the class
--   in the @base@ package. If base's internal module layout changes in a newer
--   GHC, these names may stop resolving; verify when adding GHC support.
functorTypeName, foldableTypeName, traversableTypeName :: Name
functorTypeName     = baseTypeName "GHC.Base"         "Functor"
foldableTypeName    = baseTypeName "Data.Foldable"    "Foldable"
traversableTypeName = baseTypeName "Data.Traversable" "Traversable"

-- | A global type-constructor (or class) name from the given module of @base@.
baseTypeName :: String -> String -> Name
baseTypeName = mkNameG_tc "base"