-- | Function to compute the set of free type variables of a Type.
-- (This was adapted from code found elsewhere that still needs to be
-- credited; searching for the original source turned up only this
-- module itself.)
{-# LANGUAGE CPP, FlexibleContexts, FlexibleInstances, ScopedTypeVariables, TemplateHaskell #-}
module Language.Haskell.TH.TypeGraph.Free
    ( freeTypeVars
    ) where

import Control.Lens hiding (Strict, cons)
import Control.Monad.State (MonadState, execStateT)
import Data.Set as Set (Set, delete, difference, empty, fromList, insert, member, union)
import Language.Haskell.TH
import Language.Haskell.TH.Desugar ({- instances -})
import Language.Haskell.TH.Syntax (Quasi(qReify))
import Language.Haskell.TH.TypeGraph.Prelude (pprint1)

-- | State threaded through the free-variable search.
data St
    = St { _result :: Set Name   -- ^ free type variables collected so far
         , _visited :: Set Name  -- ^ type constructor names already reified
         } deriving Show

-- | Starting state: nothing collected yet and nothing visited yet.
st0 :: St
st0 = St empty empty

$(makeLenses ''St)

-- | Return the names of the type variables that are free in @x@, i.e.
-- type variables that appear in the type expression but are not bound
-- by an enclosing forall or by the type parameters of a Dec.
--
-- The traversal runs from 'st0' and only the collected result set is
-- returned; the set of visited type constructors is discarded.
freeTypeVars :: (FreeTypeVars t, Quasi m) => t -> m (Set Name)
freeTypeVars x = do
  final <- execStateT (ftv x) st0
  return (_result final)

-- | Things whose free type variables can be accumulated into the 'St'
-- state, reifying type constructors through 'Quasi' where needed.
--
-- This is based on the freeNamesOfTypes function from the
-- th-desugar package.
class FreeTypeVars t where
    -- | Add the free type variables of the argument to the state.
    ftv :: (MonadState St m, Quasi m) => t -> m ()

-- | A list contributes the free variables of each element, left to right.
instance FreeTypeVars a => FreeTypeVars [a] where
    ftv = mapM_ ftv

-- | Free type variables of a 'Type'.  Variables are recorded directly,
-- foralls remove the variables they bind, and type applications are
-- handed to 'go_app' so type constructors can be reified.
instance FreeTypeVars Type where
    -- A forall binds its variables only within its own body and context.
    -- Collect those free variables in isolation, drop the bound names,
    -- and only then merge what remains into the outer result.  (Subtracting
    -- the bound names from the whole accumulated set would also remove a
    -- same-named variable that is free elsewhere in the enclosing type.)
    ftv (ForallT tvbs cx ty) = do
      outer <- use result
      result .= Set.empty
      ftv ty
      mapM_ go_pred cx
      result %= (`Set.difference` Set.fromList (map tvbName tvbs))
      result %= Set.union outer
        where
#if __GLASGOW_HASKELL__ >= 709
          go_pred typ =
              -- This looks wrong as the one below looks wrong.  Wronger maybe.
              ftv typ
#else
          go_pred (ClassP _ tys) = ftv tys
          go_pred (EqualP t1 t2) = do
            -- This looks wrong - we need to unify t1 and t2 and look
            -- at the free type variables in the resulting bindings
            ftv t1
            ftv t2
#endif
    ftv (SigT ty _) = ftv ty
    ftv (VarT n) = result %= Set.insert n
    -- Applications and bare constructors are walked by go_app, which
    -- gathers the arguments and inspects the head.
    ftv (AppT t1 t2) = go_app [t2] t1
    ftv typ@(ConT _) = go_app [] typ
    ftv _ = return ()


-- | Walk down the spine of a type application, collecting the argument
-- types, then dispatch on the head.
--
-- For @T a b c@, represented as @AppT (AppT (AppT (ConT T) a) b) c@, the
-- head @ConT T@ is reached with @params == [a, b, c]@: consing each
-- argument onto the front while descending already yields source order,
-- so the list is passed on unreversed to line up with the declaration's
-- binders.  A type constructor head is reified at most once (tracked in
-- 'visited'); any other head is traversed along with its arguments.
go_app :: (Quasi m, MonadState St m) => [Type] -> Type -> m ()
go_app params (AppT t1 t2) = go_app (t2 : params) t1
go_app params (ConT n) = do
    seen <- use visited
    -- If n's declaration has already been examined (or is being examined,
    -- for a recursive type) it is not reified again, but the arguments at
    -- this occurrence may still contain free variables, e.g. the b in
    -- Either (Maybe a) (Maybe b).
    if Set.member n seen
      then mapM_ ftv params
      else do
        visited %= Set.insert n
        qReify n >>= go_info params
go_app params typ = mapM_ ftv (typ : params)
-- | Dispatch on the reified information for a type constructor head,
-- given the argument types it was applied to (in source order).
go_info :: (Quasi m, MonadState St m) => [Type] -> Info -> m ()
go_info params (TyConI dec) = go_dec params dec
go_info params (FamilyI dec _insts) = go_dec params dec
-- A class or a primitive type constructor has no declaration body worth
-- searching, but the arguments it is applied to can still mention free
-- variables (the a in a Show a constraint, the s in State# s).  Class
-- heads reach here from the constraint contexts of foralls.
go_info params (ClassI _dec _insts) = mapM_ ftv params
go_info params (PrimTyConI _name _arity _unlifted) = mapM_ ftv params
go_info _params info = error $ "go_info - unexpected: " ++ pprint1 info
-- | Examine the reified declaration of a type constructor that was
-- applied to @params@ (in source order).  The free variables of the
-- declaration's body are collected first, then 'go_params' pairs the
-- declaration's type parameters with @params@.
go_dec :: (Quasi m, MonadState St m) => [Type] -> Dec -> m ()
#if MIN_VERSION_template_haskell(2,11,0)
go_dec params (NewtypeD cx tname tvs m con supers) = go_dec params (DataD cx tname tvs m [con] supers)
go_dec params (DataD _ tname tvs _ _ _) | length params > length tvs = error $ "Too many arguments to " ++ show tname
go_dec params (DataD _cx tname tvs _ cons _supers) = do
  -- For each type variable bound to a type parameter,
  -- replace the type variable with the free variables
  -- in the parameter
  ftv cons
  go_params tname tvs params
#else
go_dec params (NewtypeD cx tname tvs con supers) = go_dec params (DataD cx tname tvs [con] supers)
go_dec params (DataD _ tname tvs _ _) | length params > length tvs = error $ "Too many arguments to " ++ show tname
go_dec params (DataD _cx tname tvs cons _supers) = do
  -- For each type variable bound to a type parameter,
  -- replace the type variable with the free variables
  -- in the parameter
  ftv cons
  go_params tname tvs params
#endif
go_dec params (TySynD tname tvs typ) = do
  -- Add the free variables in the type, then subtract the ones that
  -- are bound here.
  ftv typ
  go_params tname tvs params

-- I have a feeling this is utterly wrong.  Example, with this class:
--
-- class OrderKey k => OrderMap k where
--    data Order k :: * -> *
--    ...
--
-- the resulting declaration of Order is
--
--    FamilyD DataFam Language.Haskell.TH.Path.Order.Order [PlainTV k,PlainTV $a] (Just StarT)
--    params=[ConT AbbrevPairID]
--
-- so the parameter is bound to k, and $a should be free.
#if MIN_VERSION_template_haskell(2,11,0)
go_dec params (DataFamilyD tname tvs _mkind) = go_params tname tvs params
-- Since template-haskell 2.11 open and closed type families have their
-- own constructors instead of sharing FamilyD with data families.  Treat
-- them exactly as the FamilyD equation in the #else branch does: only
-- their parameters are bound.
go_dec params (OpenTypeFamilyD (TypeFamilyHead tname tvs _resultSig _injectivity)) = go_params tname tvs params
go_dec params (ClosedTypeFamilyD (TypeFamilyHead tname tvs _resultSig _injectivity) _eqns) = go_params tname tvs params
#else
go_dec params (FamilyD _flavour tname tvs _mkind) = go_params tname tvs params
#endif
go_dec params dec = error $ "go_dec - unexpected: " ++ pprint1 dec ++ ", params=" ++ show params

-- | Pair a declaration's type parameters with the argument types it was
-- applied to, in order, handing each pair to 'go_param'.  Parameters past
-- the end of the argument list are passed 'Nothing'; supplying more
-- arguments than there are parameters is an error.
go_params :: (Quasi m, MonadState St m) => Name -> [TyVarBndr] -> [Type] -> m ()
go_params tname tvs params
  | length params > length tvs = error $ "Too many arguments to " ++ show tname
  | otherwise = bind tvs params
  where
    bind [] _ = return ()
    bind (tv : rest) (arg : args) = go_param tv (Just arg) >> bind rest args
    bind (tv : rest) [] = go_param tv Nothing >> bind rest []

-- | Update the free variable set for one type parameter of a declaration.
--
-- With an argument, the argument's free variables are added and the
-- parameter's own name is removed.  Without one, the unbound parameter's
-- name is itself recorded as free.
--
-- (An earlier draft only added the argument's variables when the
-- parameter's name was already present in the result.)
go_param :: (Quasi m, MonadState St m) => TyVarBndr -> Maybe Type -> m ()
go_param tvb marg =
  case marg of
    Nothing -> result %= Set.insert var
    Just arg -> do
      ftv arg
      result %= Set.delete var
  where
    var = tvbName tvb

{-
instance FreeTypeVars Info where
    ftv (TyConI dec) = ftv dec

instance FreeTypeVars Dec where
    ftv dec@(DataD _ _ _ _ _ _) = ftv dec
#if __GLASGOW_HASKELL__ >= 709
    go_pred = go
#else
    go_pred (ClassP _ tys) = freeNamesOfTypes tys
    go_pred (EqualP t1 t2) = go t1 <> go t2
#endif
-}

-- | The free variables of a data constructor are those of its field types.
instance FreeTypeVars Con where
    ftv (NormalC _name sts) = ftv sts
    ftv (RecC _name vsts) = ftv vsts
    ftv (InfixC st1 _ st2) = ftv [st1, st2]
    -- An existential (or GADT-style) forall binds its variables within
    -- this constructor only.  Collect the constructor's free variables in
    -- isolation, drop the bound names, and merge what remains into the
    -- outer result.  The constraint context is not examined.
    ftv (ForallC tvbs _cx con) = do
      outer <- use result
      result .= Set.empty
      ftv con
      result %= (`Set.difference` Set.fromList (map tvbName tvbs))
      result %= Set.union outer
#if MIN_VERSION_template_haskell(2,11,0)
    -- GADT-syntax constructors (template-haskell >= 2.11): only the field
    -- types are examined, not the declared result type.
    ftv (GadtC _names sts _resultType) = ftv sts
    ftv (RecGadtC _names vsts _resultType) = ftv vsts
#endif

-- | A plain constructor field: only the field's type is searched.
instance FreeTypeVars (Strict, Type) where
    ftv = ftv . snd

-- | A record field: the field name and strictness are ignored and only the
-- field's type is searched.
instance FreeTypeVars (Name, Strict, Type) where
    ftv = \(_fieldName, _strictness, fieldType) -> ftv fieldType

-- | The variable name bound by a 'TyVarBndr', with or without a kind
-- annotation.
tvbName :: TyVarBndr -> Name
tvbName tvb =
  case tvb of
    KindedTV name _kind -> name
    PlainTV name        -> name