{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
module Optics.TH.Internal.Product
  ( LensRules(..)
  , FieldNamer
  , DefName(..)
  , ClassyNamer
  , makeFieldOptics
  , makeFieldOpticsForDec
  , makeFieldOpticsForDec'
  , makeFieldLabelsWith
  , makeFieldLabelsForDec
  , HasFieldClasses
  ) where

import Control.Monad
import Control.Monad.State
import Data.Either
import Data.List
import Data.Maybe
import Language.Haskell.TH
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Traversable as T
import qualified Language.Haskell.TH.Datatype as D

import Data.Either.Optics
import Data.Tuple.Optics
import Data.Set.Optics
import Language.Haskell.TH.Optics.Internal
import Optics.Core hiding (cons)
import Optics.TH.Internal.Utils

------------------------------------------------------------------------
-- Utilities
------------------------------------------------------------------------

typeSelf :: Traversal' Type Type
typeSelf = traversalVL $ \f -> \case
  ForallT tyVarBndrs ctx ty ->
    let go (KindedTV nam kind) = KindedTV <$> pure nam <*> f kind
        go (PlainTV nam)       = pure (PlainTV nam)
    in ForallT <$> traverse go tyVarBndrs <*> traverse f ctx <*> f ty
  AppT ty1 ty2              -> AppT <$> f ty1 <*> f ty2
  SigT ty kind              -> SigT <$> f ty <*> f kind
  InfixT ty1 nam ty2        -> InfixT <$> f ty1 <*> pure nam <*> f ty2
  UInfixT ty1 nam ty2       -> UInfixT <$> f ty1 <*> pure nam <*> f ty2
  ParensT ty                -> ParensT <$> f ty
  ty                        -> pure ty

------------------------------------------------------------------------
-- Field generation entry point
------------------------------------------------------------------------

-- | Compute the field optics for the type identified by the given type name.
-- Lenses will be computed when possible, Traversals otherwise.
makeFieldOptics :: LensRules -> Name -> DecsQ
makeFieldOptics rules = (`evalStateT` S.empty) . makeFieldOpticsForDatatype rules <=< D.reifyDatatype

makeFieldOpticsForDec :: LensRules -> Dec -> DecsQ
makeFieldOpticsForDec rules = (`evalStateT` S.empty) . makeFieldOpticsForDec' rules

makeFieldOpticsForDec' :: LensRules -> Dec -> HasFieldClasses [Dec]
makeFieldOpticsForDec' rules = makeFieldOpticsForDatatype rules <=< lift . D.normalizeDec

-- | Compute the field optics for a deconstructed datatype Dec
-- When possible build an Iso otherwise build one optic per field.
makeFieldOpticsForDatatype :: LensRules -> D.DatatypeInfo -> HasFieldClasses [Dec]
makeFieldOpticsForDatatype rules info =
  do perDef <- lift $ do
       fieldCons <- traverse normalizeConstructor cons
       let allFields  = toListOf (folded % _2 % folded % _1 % folded) fieldCons
       let defCons    = over normFieldLabels (expandName allFields) fieldCons
           allDefs    = setOf (normFieldLabels % folded) defCons
       T.sequenceA (M.fromSet (buildScaffold True rules s defCons) allDefs)

     let defs = M.toList perDef
     case _classyLenses rules tyName of
       Just (className, methodName) ->
         makeClassyDriver rules className methodName s defs
       Nothing -> do decss <- traverse (makeFieldOptic rules) defs
                     return (concat decss)

  where
  tyName = D.datatypeName info
  s      = D.datatypeType info
  cons   = D.datatypeCons info

  -- Traverse the field labels of a normalized constructor
  normFieldLabels :: Traversal [(Name,[(a,Type)])] [(Name,[(b,Type)])] a b
  normFieldLabels = traversed % _2 % traversed % _1

  -- Map a (possibly missing) field's name to zero-to-many optic definitions
  expandName :: [Name] -> Maybe Name -> [DefName]
  expandName allFields = concatMap (_fieldToDef rules tyName allFields) . maybeToList

makeFieldLabelsForDec :: LensRules -> Dec -> DecsQ
makeFieldLabelsForDec rules = makeFieldLabelsForDatatype rules <=< D.normalizeDec

-- | Build field optics as labels with a custom configuration.
makeFieldLabelsWith :: LensRules -> Name -> DecsQ
makeFieldLabelsWith rules = D.reifyDatatype >=> makeFieldLabelsForDatatype rules

-- | Compute the field optics for a deconstructed datatype Dec
-- When possible build an Iso otherwise build one optic per field.
makeFieldLabelsForDatatype :: LensRules -> D.DatatypeInfo -> Q [Dec]
makeFieldLabelsForDatatype rules info =
  do perDef <- do
       fieldCons <- traverse normalizeConstructor cons
       let allFields  = toListOf (folded % _2 % folded % _1 % folded) fieldCons
       let defCons    = over normFieldLabels (expandName allFields) fieldCons
           allDefs    = setOf (normFieldLabels % folded) defCons
       T.sequenceA (M.fromSet (buildScaffold False rules s defCons) allDefs)

     let defs = filter isRank1 $ M.toList perDef
     traverse (makeFieldLabel rules) defs

  where
    -- LabelOptic doesn't support higher rank fields because of functional
    -- dependencies (s -> a, t -> b), so just skip them.
    isRank1 = \case
      (_, (OpticSa rank1 _ _ _ _, _)) -> rank1
      _                               -> True

    tyName = D.datatypeName info
    s      = D.datatypeType info
    cons   = D.datatypeCons info

    -- Traverse the field labels of a normalized constructor
    normFieldLabels :: Traversal [(Name,[(a,Type)])] [(Name,[(b,Type)])] a b
    normFieldLabels = traversed % _2 % traversed % _1

    -- Map a (possibly missing) field's name to zero-to-many optic definitions
    expandName :: [Name] -> Maybe Name -> [DefName]
    expandName allFields = concatMap (_fieldToDef rules tyName allFields) . maybeToList

makeFieldLabel
  :: LensRules
  -> (DefName, (OpticStab, [(Name, Int, [Int])]))
  -> Q Dec
makeFieldLabel rules (defName, (defType, cons)) = do
  (context, instHead) <- case defType of
    OpticSa _ _ otype s a -> do
      (a', cxtA) <- eqSubst a "a"
      (b', cxtB) <- eqSubst a "b"
      pure (pure [cxtA, cxtB], pure $ conAppsT ''LabelOptic
        [LitT (StrTyLit fieldName), ConT $ opticTypeToTag otype, s, s, a', b'])
    OpticStab otype s t a b -> do
      ambiguousTypeFamilies <- containsAmbiguousTypeFamilyApplications s a
      -- If 'a' contains ambiguous type family applications, generate type
      -- preserving version as functional dependencies on LabelOptic demand it.
      let t' = if ambiguousTypeFamilies then s else t
      (a', cxtA) <- eqSubst a "a"
      (b', cxtB) <- if ambiguousTypeFamilies
                    then eqSubst a "b"
                    else eqSubst b "b"
      pure (pure [cxtA, cxtB], pure $ conAppsT ''LabelOptic
        [LitT (StrTyLit fieldName), ConT $ opticTypeToTag otype, s, t', a', b'])
  instanceD context instHead (fun 'labelOptic)
  where
    opticTypeToTag AffineFoldType      = ''An_AffineFold
    opticTypeToTag AffineTraversalType = ''An_AffineTraversal
    opticTypeToTag FoldType            = ''A_Fold
    opticTypeToTag GetterType          = ''A_Getter
    opticTypeToTag IsoType             = ''An_Iso
    opticTypeToTag LensType            = ''A_Lens
    opticTypeToTag TraversalType       = ''A_Traversal

    -- TODO: check for injectivity of encountered type families.
    containsAmbiguousTypeFamilyApplications s a = do
      -- We consider type family application ambiguous only if it's applied to a
      -- type variable not referenced anywhere else.
      (hasTypeFamilies, bareVars) <- (`runStateT` setOf typeVars s) $
        go =<< lift (D.resolveTypeSynonyms a)
      pure $ hasTypeFamilies && not (S.null bareVars)
      where
        go (ConT nm) = has (_FamilyI % _1 % _TypeFamilyD) <$> lift (reify nm)
        go (VarT n)  = modify' (S.delete n) *> pure False
        go ty        = or <$> traverse go (toListOf typeSelf ty)

    fieldName = case defName of
      TopName fname      -> nameBase fname
      MethodName _ fname -> nameBase fname

    fun :: Name -> [DecQ]
    fun n = funD n [funDef] : inlinePragma n

    funDef :: ClauseQ
    funDef = makeFieldClause rules (stabToOpticType defType) cons

-- | Normalized the Con type into a uniform positional representation,
-- eliminating the variance between records, infix constructors, and normal
-- constructors.
normalizeConstructor ::
  D.ConstructorInfo ->
  Q (Name, [(Maybe Name, Type)]) -- ^ constructor name, field name, field type

normalizeConstructor con =
  return (D.constructorName con,
          zipWith checkForExistentials fieldNames (D.constructorFields con))
  where
    fieldNames =
      case D.constructorVariant con of
        D.RecordConstructor xs -> fmap Just xs
        D.NormalConstructor    -> repeat Nothing
        D.InfixConstructor     -> repeat Nothing

    -- Fields mentioning existentially quantified types are not
    -- elligible for TH generated optics.
    checkForExistentials _ fieldtype
      | any (\tv -> D.tvName tv `S.member` used) unallowable
      = (Nothing, fieldtype)
      where
        used        = setOf typeVars fieldtype
        unallowable = D.constructorVars con
    checkForExistentials fieldname fieldtype = (fieldname, fieldtype)

-- | Compute the positional location of the fields involved in
-- each constructor for a given optic definition as well as the
-- type of clauses to generate and the type to annotate the declaration
-- with.
buildScaffold ::
  Bool                       {- ^ allow change of phantom type parameters -} ->
  LensRules                                                                  ->
  Type                              {- ^ outer type                       -} ->
  [(Name, [([DefName], Type)])]     {- ^ normalized constructors          -} ->
  DefName                           {- ^ target definition                -} ->
  Q (OpticStab, [(Name, Int, [Int])])
              {- ^ optic type, definition type, field count, target fields -}
buildScaffold allowPhantomsChange rules s cons defName =

  do (s',t,a,b) <- buildStab allowPhantomsChange s (concatMap snd consForDef)

     let defType
           | Just (tyvars,cx,a') <- preview _ForallT a =
               let optic | lensCase   = GetterType
                         | affineCase = AffineFoldType
                         | otherwise  = FoldType
               in OpticSa (null tyvars) cx optic s' a'

           -- Getter and Fold are always simple
           | not (_allowUpdates rules) =
               let optic | lensCase   = GetterType
                         | affineCase = AffineFoldType
                         | otherwise  = FoldType
               in OpticSa True [] optic s' a

           -- Generate simple Lens and Traversal where possible
           | _simpleLenses rules || s' == t && a == b =
               let optic | isoCase && _allowIsos rules = IsoType
                         | lensCase                    = LensType
                         | affineCase                  = AffineTraversalType
                         | otherwise                   = TraversalType
               in OpticSa True [] optic s' a

           -- Generate type-changing Lens and Traversal otherwise
           | otherwise =
               let optic | isoCase && _allowIsos rules = IsoType
                         | lensCase                    = LensType
                         | affineCase                  = AffineTraversalType
                         | otherwise                   = TraversalType
               in OpticStab optic s' t a b

     return (defType, scaffolds)
  where
    consForDef :: [(Name, [Either Type Type])]
    consForDef = over (mapped % _2 % mapped) categorize cons

    scaffolds :: [(Name, Int, [Int])]
    scaffolds = [ (n, length ts, rightIndices ts) | (n,ts) <- consForDef ]

    rightIndices :: [Either Type Type] -> [Int]
    rightIndices = findIndices (has _Right)

    -- Right: types for this definition
    -- Left : other types
    categorize :: ([DefName], Type) -> Either Type Type
    categorize (defNames, t)
      | defName `elem` defNames = Right t
      | otherwise               = Left  t

    affectedFields :: [Int]
    affectedFields = toListOf (folded % _3 % to length) scaffolds

    lensCase :: Bool
    lensCase = all (== 1) affectedFields

    affineCase :: Bool
    affineCase = all (<= 1) affectedFields

    isoCase :: Bool
    isoCase = case scaffolds of
                [(_,1,[0])] -> True
                _           -> False

data OpticType
  = AffineFoldType
  | AffineTraversalType
  | FoldType
  | GetterType
  | IsoType
  | LensType
  | TraversalType
  deriving Show

opticTypeName :: Bool -> OpticType -> Name
opticTypeName typeChanging  AffineTraversalType = if typeChanging
                                                  then ''AffineTraversal
                                                  else ''AffineTraversal'
opticTypeName _typeChanging AffineFoldType      = ''AffineFold
opticTypeName _typeChanging FoldType            = ''Fold
opticTypeName _typeChanging GetterType          = ''Getter
opticTypeName typeChanging  IsoType             = if typeChanging
                                                  then ''Iso
                                                  else ''Iso'
opticTypeName typeChanging  LensType            = if typeChanging
                                                  then ''Lens
                                                  else ''Lens'
opticTypeName typeChanging  TraversalType       = if typeChanging
                                                  then ''Traversal
                                                  else ''Traversal'

data OpticStab = OpticStab        OpticType Type Type Type Type
               | OpticSa Bool Cxt OpticType Type Type

stabToType :: OpticStab -> Type
stabToType (OpticStab  c s t a b) =
  quantifyType [] (opticTypeName True c `conAppsT` [s,t,a,b])
stabToType (OpticSa _ cx c s   a  ) =
  quantifyType cx (opticTypeName False c `conAppsT` [s,a])

stabToContext :: OpticStab -> Cxt
stabToContext OpticStab{}          = []
stabToContext (OpticSa _ cx _ _ _) = cx

stabToOpticType :: OpticStab -> OpticType
stabToOpticType (OpticStab c _ _ _ _) = c
stabToOpticType (OpticSa _ _ c _ _) = c

stabToOptic :: OpticStab -> Name
stabToOptic (OpticStab c _ _ _ _) = opticTypeName True c
stabToOptic (OpticSa _ _ c _ _) = opticTypeName False c

stabToS :: OpticStab -> Type
stabToS (OpticStab _ s _ _ _) = s
stabToS (OpticSa _ _ _ s _) = s

stabToA :: OpticStab -> Type
stabToA (OpticStab _ _ _ a _) = a
stabToA (OpticSa _ _ _ _ a) = a

-- | Compute the s t a b types given the outer type 's' and the
-- categorized field types. Left for fixed and Right for visited.
-- These types are "raw" and will be packaged into an 'OpticStab'
-- shortly after creation.
buildStab :: Bool -> Type -> [Either Type Type] -> Q (Type,Type,Type,Type)
buildStab allowPhantomsChange s categorizedFields = do
  -- compute possible type changes
  sub <- T.sequenceA (M.fromSet (newName . nameBase) unfixedTypeVars)
  let (t, b) = over each (substTypeVars sub) (s, a)
  pure (s, t, a, b)
  where
    -- Just take the type of the first field and let GHC do the unification.
    a = fromMaybe
      (error "buildStab: unexpected empty list of fields")
      (preview _head targetFields)

    phantomTypeVars =
      let allTypeVars = folded % chosen % typeVars
      in setOf typeVars s S.\\ setOf allTypeVars categorizedFields

    (fixedFields, targetFields) = partitionEithers categorizedFields
    unfixedTypeVars =
      let fixedTypeVars = setOf typeVars fixedFields
      in if allowPhantomsChange
         then setOf typeVars s S.\\ fixedTypeVars
         else setOf typeVars s S.\\ fixedTypeVars S.\\ phantomTypeVars

-- | Build the signature and definition for a single field optic.
-- In the case of a singleton constructor irrefutable matches are
-- used to enable the resulting lenses to be used on a bottom value.
makeFieldOptic ::
  LensRules ->
  (DefName, (OpticStab, [(Name, Int, [Int])])) ->
  HasFieldClasses [Dec]
makeFieldOptic rules (defName, (defType, cons)) = do
  locals <- get
  addName
  lift $ do cls <- mkCls locals
            T.sequenceA (cls ++ sig ++ def)
  where
  mkCls locals = case defName of
                 MethodName c n | _generateClasses rules ->
                  do classExists <- isJust <$> lookupTypeName (show c)
                     return (if classExists || S.member c locals then [] else [makeFieldClass defType c n])
                 _ -> return []

  addName = case defName of
            MethodName c _ -> addFieldClassName c
            _              -> return ()

  sig = case defName of
          _ | not (_generateSigs rules) -> []
          TopName n -> [sigD n (return (stabToType defType))]
          MethodName{} -> []

  fun n = funD n [funDef] : inlinePragma n

  def = case defName of
          TopName n      -> fun n
          MethodName c n -> [makeFieldInstance defType c (fun n)]

  funDef = makeFieldClause rules (stabToOpticType defType) cons

------------------------------------------------------------------------
-- Classy class generator
------------------------------------------------------------------------


makeClassyDriver ::
  LensRules ->
  Name ->
  Name ->
  Type {- ^ Outer 's' type -} ->
  [(DefName, (OpticStab, [(Name, Int, [Int])]))] ->
  HasFieldClasses [Dec]
makeClassyDriver rules className methodName s defs = T.sequenceA (cls ++ inst)

  where
  cls | _generateClasses rules = [lift $ makeClassyClass className methodName s defs]
      | otherwise = []

  inst = [makeClassyInstance rules className methodName s defs]


makeClassyClass ::
  Name ->
  Name ->
  Type {- ^ Outer 's' type -} ->
  [(DefName, (OpticStab, [(Name, Int, [Int])]))] ->
  DecQ
makeClassyClass className methodName s defs = do
  c <- newName "c"
  let vars = toListOf typeVars s
      fd   | null vars = []
           | otherwise = [FunDep [c] vars]


  classD (cxt[]) className (map PlainTV (c:vars)) fd
    $ sigD methodName (return (''Lens' `conAppsT` [VarT c, s]))
    : concat
      [ [sigD defName (return ty)
        ,valD (varP defName) (normalB body) []
        ] ++
        inlinePragma defName
      | (TopName defName, (stab, _)) <- defs
      , let body = infixApp (varE methodName) (varE '(%)) (varE defName)
      , let ty   = quantifyType' (S.fromList (c:vars))
                                 (stabToContext stab)
                 $ stabToOptic stab `conAppsT`
                       [VarT c, stabToA stab]
      ]


makeClassyInstance ::
  LensRules ->
  Name ->
  Name ->
  Type {- ^ Outer 's' type -} ->
  [(DefName, (OpticStab, [(Name, Int, [Int])]))] ->
  HasFieldClasses Dec
makeClassyInstance rules className methodName s defs = do
  methodss <- traverse (makeFieldOptic rules') defs

  lift $ instanceD (cxt[]) (return instanceHead)
           $ valD (varP methodName) (normalB (varE 'lensVL `appE` varE 'id)) []
           : map return (concat methodss)

  where
  instanceHead = className `conAppsT` (s : map VarT vars)
  vars         = toListOf typeVars s
  rules'       = rules { _generateSigs    = False
                       , _generateClasses = False
                       }

------------------------------------------------------------------------
-- Field class generation
------------------------------------------------------------------------

makeFieldClass :: OpticStab -> Name -> Name -> DecQ
makeFieldClass defType className methodName =
  classD (cxt []) className [PlainTV s, PlainTV a] [FunDep [s] [a]]
         [sigD methodName (return methodType)]
  where
  methodType = quantifyType' (S.fromList [s,a])
                             (stabToContext defType)
             $ stabToOptic defType `conAppsT` [VarT s,VarT a]
  s = mkName "s"
  a = mkName "a"

-- | Build an instance for a field. If the field’s type contains any type
-- families, will produce an equality constraint to avoid a type family
-- application in the instance head.
makeFieldInstance :: OpticStab -> Name -> [DecQ] -> DecQ
makeFieldInstance defType className decs =
  containsTypeFamilies a >>= pickInstanceDec
  where
  s = stabToS defType
  a = stabToA defType

  containsTypeFamilies = go <=< D.resolveTypeSynonyms
    where
    go (ConT nm) = has (_FamilyI % _1 % _TypeFamilyD) <$> reify nm
    go ty = or <$> traverse go (toListOf typeSelf ty)

  pickInstanceDec hasFamilies
    | hasFamilies = do
        placeholder <- VarT <$> newName "a"
        mkInstanceDec
          [return (D.equalPred placeholder a)]
          [s, placeholder]
    | otherwise = mkInstanceDec [] [s, a]

  mkInstanceDec context headTys =
    instanceD (cxt context) (return (className `conAppsT` headTys)) decs

------------------------------------------------------------------------
-- Optic clause generators
------------------------------------------------------------------------

makeFieldClause :: LensRules -> OpticType -> [(Name, Int, [Int])] -> ClauseQ
makeFieldClause rules opticType cons =
  case opticType of
    AffineFoldType      -> makeAffineFoldClause cons
    AffineTraversalType -> makeAffineTraversalClause cons irref
    FoldType            -> makeFoldClause cons
    IsoType             -> makeIsoClause cons irref
    GetterType          -> makeGetterClause cons
    LensType            -> makeLensClause cons irref
    TraversalType       -> makeTraversalClause cons irref
  where
    irref = _lazyPatterns rules && length cons == 1

makeAffineFoldClause :: [(Name, Int, [Int])] -> ClauseQ
makeAffineFoldClause cons = do
  s <- newName "s"
  clause
    []
    (normalB $ appsE
      [ varE 'afolding
      , lamE [varP s] $ caseE (varE s)
        [ makeAffineFoldMatch conName fieldCount fields
        | (conName, fieldCount, fields) <- cons
        ]
      ])
    []
  where
    makeAffineFoldMatch conName fieldCount fields = do
      xs <- newNames "x" $ length fields

      let args = foldr (\(i, x) -> set (ix i) (varP x))
                       (replicate fieldCount wildP)
                       (zip fields xs)

          body = case xs of
            -- Con _ .. _ -> Nothing
            []  -> conE 'Nothing
            -- Con _ .. x_i .. _ -> Just x_i
            [x] -> conE 'Just `appE` varE x
            _   -> error "AffineFold focuses on at most one field"

      match (conP conName args)
            (normalB body)
            []

makeFoldClause :: [(Name, Int, [Int])] -> ClauseQ
makeFoldClause cons = do
  f <- newName "f"
  s <- newName "s"
  clause
    []
    (normalB $ appsE
      [ varE 'foldVL
      , lamE [varP f, varP s] $ caseE (varE s)
        [ makeFoldMatch f conName fieldCount fields
        | (conName, fieldCount, fields) <- cons
        ]
      ])
    []
  where
    makeFoldMatch f conName fieldCount fields = do
      xs <- newNames "x" $ length fields

      let args = foldr (\(i, x) -> set (ix i) (varP x))
                       (replicate fieldCount wildP)
                       (zip fields xs)

          fxs = case xs of
            [] -> [varE 'pure `appE` conE '()]
            _  -> map (\x -> varE f `appE` varE x) xs

          -- Con _ .. x_1 .. _ .. x_k .. _ -> f x_1 *> .. f x_k
          body = appsE
            [ foldr1 (\fx -> infixApp fx (varE '(*>))) fxs
            ]

      match (conP conName args)
            (normalB body)
            []

-- | Build a getter clause that retrieves the field at the given index.
makeGetterClause :: [(Name, Int, [Int])] -> ClauseQ
makeGetterClause cons = do
  s <- newName "s"
  clause
    []
    (normalB $ appsE
      [ varE 'to
      , lamE [varP s] $ caseE (varE s)
        [ makeGetterMatch conName fieldCount fields
        | (conName, fieldCount, fields) <- cons
        ]
      ])
    []
  where
    makeGetterMatch conName fieldCount = \case
      [field] -> do
        x <- newName "x"
        -- Con _ .. x_i .. _ -> x_i
        match (conP conName . set (ix field) (varP x) $ replicate fieldCount wildP)
              (normalB $ varE x)
              []
      _       -> error "Getter focuses on exactly one field"

-- | Build a clause that constructs an Iso.
makeIsoClause :: [(Name, Int, [Int])] -> Bool -> ClauseQ
makeIsoClause fields irref = case fields of
  [(conName, 1, [0])] -> do
    x <- newName "x"
    clause []
           (normalB $ appsE
             [ varE 'iso
             , lamE [irrefP $ conP conName [varP x]] (varE x)
             , conE conName
             ])
           []
  _ -> error "Iso works only for types with one constructor and one field"
  where
    irrefP = if irref then tildeP else id

-- | Build a lens clause that updates the field at the given index. When irref
-- is 'True' the value with be matched with an irrefutable pattern.
makeLensClause :: [(Name, Int, [Int])] -> Bool -> ClauseQ
makeLensClause cons irref = do
  f <- newName "f"
  s <- newName "s"
  clause
    []
    (normalB $ appsE
      [ varE 'lensVL
      , lamE [varP f, varP s] $ caseE (varE s)
        [ makeLensMatch irrefP f conName fieldCount fields
        | (conName, fieldCount, fields) <- cons
        ]
      ])
    []
  where
    irrefP = if irref then tildeP else id

-- | Make a lens match. Used for both lens and affine traversal generation.
makeLensMatch :: (PatQ -> PatQ) -> Name -> Name -> Int -> [Int] -> Q Match
makeLensMatch irrefP f conName fieldCount = \case
  [field] -> do
    xs <- newNames "x" fieldCount
    y  <- newName "y"

    let body = appsE
          [ varE 'fmap
          , lamE [varP y] . appsE $
            conE conName : map varE (set (ix field) y xs)
          , appE (varE f) . varE $ xs !! field
          ]

    -- Con x_1 .. x_n -> fmap (\y_i -> Con x_1 .. y_i .. x_n) (f x_i)
    match (irrefP . conP conName $ map varP xs)
          (normalB body)
          []
  _       -> error "Lens focuses on exactly one field"

makeAffineTraversalClause :: [(Name, Int, [Int])] -> Bool -> ClauseQ
makeAffineTraversalClause cons irref = do
  point <- newName "point"
  f     <- newName "f"
  s     <- newName "s"
  clause
    []
    (normalB $ appsE
      [ varE 'atraversalVL
      , lamE [varP point, varP f, varP s] $ caseE (varE s)
        [ makeAffineTraversalMatch point f conName fieldCount fields
        | (conName, fieldCount, fields) <- cons
        ]
      ])
    []
  where
    irrefP = if irref then tildeP else id

    makeAffineTraversalMatch point f conName fieldCount = \case
      [] -> do
        xs <- newNames "x" fieldCount
        -- Con x_1 ... x_n -> point (Con x_1 .. x_n)
        match (irrefP . conP conName $ map varP xs)
              (normalB $ varE point `appE` appsE (conE conName : map varE xs))
              []
      [field] -> makeLensMatch irrefP f conName fieldCount [field]
      _ -> error "Affine traversal focuses on at most one field"

makeTraversalClause :: [(Name, Int, [Int])] -> Bool -> ClauseQ
makeTraversalClause cons irref = do
  f <- newName "f"
  s <- newName "s"
  clause
    []
    (normalB $ appsE
      [ varE 'traversalVL
      , lamE [varP f, varP s] $ caseE (varE s)
        [ makeTraversalMatch f conName fieldCount fields
        | (conName, fieldCount, fields) <- cons
        ]
      ])
    []
  where
    irrefP = if irref then tildeP else id

    makeTraversalMatch f conName fieldCount fields = do
      xs <- newNames "x" fieldCount
      case fields of
        [] -> -- Con x_1 .. x_n -> pure (Con x_1 .. x_n)
          match (irrefP . conP conName $ map varP xs)
                (normalB $ varE 'pure `appE` appsE (conE conName : map varE xs))
                []
        _ -> do
          ys <- newNames "y" $ length fields

          let xs' = foldr (\(i, x) -> set (ix i) x) xs (zip fields ys)

              mkFx i = varE f `appE` varE (xs !! i)

              body0 = appsE
                [ varE 'pure
                , lamE (map varP ys) $ appsE $ conE conName : map varE xs'
                ]

              body = foldl (\acc i -> infixApp acc (varE '(<*>)) $ mkFx i)
                           body0
                           fields

          -- Con x_1 .. x_n ->
          --  pure (\y_1 .. y_k -> Con x_1 .. y_1 .. x_l .. y_k .. x_n)
          --    <*> f x_i_y_1 <*> .. <*> f x_i_y_k
          match (irrefP . conP conName $ map varP xs)
                (normalB body)
                []

------------------------------------------------------------------------
-- Field generation parameters
------------------------------------------------------------------------

-- | Rules to construct lenses for data fields.
data LensRules = LensRules
  { _simpleLenses    :: Bool
  , _generateSigs    :: Bool
  , _generateClasses :: Bool
  , _allowIsos       :: Bool
  , _allowUpdates    :: Bool -- ^ Allow Lens/Traversal (otherwise Getter/Fold)
  , _lazyPatterns    :: Bool
  , _fieldToDef      :: FieldNamer
       -- ^ Type Name -> Field Names -> Target Field Name -> Definition Names
  , _classyLenses    :: ClassyNamer
       -- type name to class name and top method
  }

-- | The rule to create function names of lenses for data fields.
--
-- Although it's sometimes useful, you won't need the first two
-- arguments most of the time.
type FieldNamer = Name -- ^ Name of the data type that lenses are being generated for.
                  -> [Name] -- ^ Names of all fields (including the field being named) in the data type.
                  -> Name -- ^ Name of the field being named.
                  -> [DefName] -- ^ Name(s) of the lens functions. If empty, no lens is created for that field.

-- | Name to give to generated field optics.
data DefName
  = TopName Name -- ^ Simple top-level definiton name
  | MethodName Name Name -- ^ makeFields-style class name and method name
  deriving (Show, Eq, Ord)

-- | The optional rule to create a class and method around a
-- monomorphic data type. If this naming convention is provided, it
-- generates a "classy" lens.
type ClassyNamer = Name -- ^ Name of the data type that lenses are being generated for.
                   -> Maybe (Name, Name) -- ^ Names of the class and the main method it generates, respectively.

-- | Tracks the field class 'Name's that have been created so far. We consult
-- these so that we may avoid creating duplicate classes.

-- See #643 for more information.
type HasFieldClasses = StateT (S.Set Name) Q

addFieldClassName :: Name -> HasFieldClasses ()
addFieldClassName n = modify $ S.insert n

------------------------------------------------------------------------
-- Miscellaneous utility functions
------------------------------------------------------------------------

 -- We want to catch type families, but not *data* families. See #799.
_TypeFamilyD :: AffineFold Dec ()
_TypeFamilyD = _OpenTypeFamilyD % united `afailing` _ClosedTypeFamilyD % united

-- | Template Haskell wants type variables declared in a forall, so
-- we find all free type variables in a given type and declare them.
quantifyType :: Cxt -> Type -> Type
quantifyType = quantifyType' S.empty

-- | This function works like 'quantifyType' except that it takes
-- a list of variables to exclude from quantification.
quantifyType' :: S.Set Name -> Cxt -> Type -> Type
quantifyType' exclude c t = ForallT vs c t
  where
  vs = map PlainTV
     $ filter (`S.notMember` exclude)
     $ nub -- stable order
     $ toListOf typeVars t