{-# LANGUAGE CPP #-}
module Data.Deriving.Via.Internal where
#if MIN_VERSION_template_haskell(2,12,0)
import           Control.Monad ((<=<), unless)
import           Data.Deriving.Internal
import qualified Data.Map as M
import           Data.Map (Map)
import           Data.Maybe (catMaybes)
import           Language.Haskell.TH
import           Language.Haskell.TH.Datatype
-- | Build a single instance declaration from a (possibly @forall@-quantified
-- and constrained) class-constraint type. The method bodies come from
-- 'deriveViaDecs' with no @via@ type, so the representation type is taken
-- from the newtype that the class's last argument is headed by.
--
-- Infix type applications and type synonyms in the instance head are
-- resolved before the body is generated. The emitted instance header,
-- however, uses the head exactly as the caller wrote it.
deriveGND :: Q Type -> Q [Dec]
deriveGND qty = do
  (tvbs, ctxt, headTy) <- decomposeType <$> qty
  resolvedHeadTy       <- resolveInfixT headTy >>= resolveTypeSynonyms
  bodyDecs             <- deriveViaDecs resolvedHeadTy Nothing
  -- The user's context lives inside the ForallT header, so the separate
  -- context argument of 'instanceD' is left empty.
  instDec              <- instanceD (return [])
                                    (return (ForallT tvbs ctxt headTy))
                                    (map return bodyDecs)
  return [instDec]
-- | Build a single instance declaration from a type of the shape
-- @forall tvs. ctxt => C args `Via` V@. The method bodies come from
-- 'deriveViaDecs' with @V@ as the representation type.
--
-- Infix type applications and type synonyms are resolved before the 'Via'
-- application is taken apart. The resolved instance head is used both for
-- the generated body and for the emitted header. Fails in 'Q' when the type
-- is not 'Via' applied to exactly two arguments.
deriveVia :: Q Type -> Q [Dec]
deriveVia qty = do
  (tvbs, ctxt, viaApp) <- decomposeType <$> qty
  resolvedViaApp       <- resolveInfixT viaApp >>= resolveTypeSynonyms
  (headTy, viaTy)      <- splitVia resolvedViaApp
  bodyDecs             <- deriveViaDecs headTy (Just viaTy)
  -- The user's context lives inside the ForallT header, so the separate
  -- context argument of 'instanceD' is left empty.
  instDec              <- instanceD (return [])
                                    (return (ForallT tvbs ctxt headTy))
                                    (map return bodyDecs)
  return [instDec]
  where
    -- Split @C args `Via` V@ into the instance head @C args@ and @V@.
    splitVia :: Type -> Q (Type, Type)
    splitVia app
      | [viaCon, headTy, viaTy] <- unapplyTy app
      , viaCon == ConT viaTypeName
      = return (headTy, viaTy)
      | otherwise
      = fail $ unlines
          [ "Failure to meet ‘deriveVia‘ specification"
          , "\tThe ‘Via‘ type must be used, e.g."
          , "\t[t| forall a. C (T a) `Via` V a |]"
          ]
-- | Compute the body declarations (associated type family instances and
-- method definitions) of a derived instance.
--
-- The instance head is split into a class constructor and its arguments.
-- The class is reified, and its last argument is treated as the data type
-- the instance is for. The representation type that everything is coerced
-- through is chosen as follows:
--
-- * If a @via@ type is supplied, it is used as-is.
--
-- * Otherwise, the last argument must be headed by a newtype (or newtype
--   instance). Its single field type is eta-reduced by one less than the
--   length of the list 'uncurryTy' returns for the kind of the class's last
--   type variable. The newtype's declared type variables are then
--   substituted by the arguments actually given in the instance head.
--
-- Each class-body declaration is then handed to 'deriveViaDecs''.
--
-- Fails in 'Q' when:
--
-- * the head is not headed by a type constructor;
-- * that constructor does not reify to a class;
-- * the class or the instance has no arguments;
-- * the last argument is not headed by a type constructor, or (when no
--   @via@ type is given) that constructor is not a newtype;
-- * the newtype's field type cannot be eta-reduced.
deriveViaDecs :: Type       -- ^ The instance head (e.g., @Eq (Foo a)@)
              -> Maybe Type -- ^ 'Nothing' when called from 'deriveGND';
                            --   'Just' the @via@ type when called from 'deriveVia'.
              -> Q [Dec]
deriveViaDecs instanceTy mbViaTy = do
  let (clsTy:clsArgs) = unapplyTy instanceTy
  case clsTy of
    ConT clsName -> do
      clsInfo <- reify clsName
      case clsInfo of
        ClassI (ClassD _ _ clsTvbs _ clsDecs) _ ->
          -- Only the LAST class argument / type variable matters here: it is
          -- the type whose instance is being derived.
          case (unsnoc clsArgs, unsnoc clsTvbs) of
            (Just (_, dataApp), Just (_, clsLastTvb)) -> do
              let (dataTy:dataArgs)  = unapplyTy dataApp
                  clsLastTvbKind     = tvbKind clsLastTvb
                  (_, kindList)      = uncurryTy clsLastTvbKind
                  -- How many trailing arguments must be dropped from the
                  -- representation type so it fits the class's last parameter.
                  numArgsToEtaReduce = length kindList - 1
              repTy <-
                case mbViaTy of
                  Just viaTy -> return viaTy
                  Nothing ->
                    case dataTy of
                      ConT dataName -> do
                        DatatypeInfo {
                                       datatypeInstTypes = dataInstTypes
                                     , datatypeVariant   = dv
                                     , datatypeCons      = cons
                                     } <- reifyDatatype dataName
                        case newtypeRepType dv cons of
                          Just newtypeRepTy ->
                            case etaReduce numArgsToEtaReduce newtypeRepTy of
                              Just etaRepTy ->
                                -- Replace the newtype's declared type variables
                                -- with the arguments used in the instance head.
                                let repTySubst =
                                      M.fromList $
                                      zipWith (\var arg -> (varTToName var, arg))
                                              dataInstTypes dataArgs
                                in return $ applySubstitution repTySubst etaRepTy
                              Nothing -> etaReductionError instanceTy
                          Nothing -> fail $ "Not a newtype: " ++ nameBase dataName
                      _ -> fail $ "Not a data type: " ++ pprint dataTy
              -- Declarations for which deriveViaDecs' yields Nothing are
              -- dropped; the rest are concatenated into the instance body.
              concat . catMaybes <$> traverse (deriveViaDecs' clsName clsTvbs clsArgs repTy) clsDecs
            (_, _) -> fail $ "Cannot derive instance for nullary class " ++ pprint clsTy
        _ -> fail $ "Not a type class: " ++ pprint clsTy
    _ -> fail $ "Malformed instance: " ++ pprint instanceTy
-- | Translate one declaration from the class body into the corresponding
-- declaration(s) of the derived instance.
--
-- * An open type family becomes a type family instance. Its left-hand side
--   substitutes the class variables with the instance's arguments. Its
--   right-hand side uses the same substitution, except that the last class
--   argument is replaced by the representation type.
--
-- * A method signature becomes two declarations:
--
--   - an instance signature at the instance's own argument types;
--   - a definition that applies the function named by 'coerceValName' to
--     the class method, with the method type at the representation type and
--     at the instance's types given as explicit type applications.
--
-- * Any other declaration yields 'Nothing'.
--
-- Fails in 'Q' if the number of class type variables differs from the number
-- of class arguments supplied.
deriveViaDecs' :: Name -> [TyVarBndr] -> [Type] -> Type -> Dec -> Q (Maybe [Dec])
deriveViaDecs' clsName clsTvbs clsArgs repTy dec = do
    unless (expectedArity == actualArity) $
      fail $ "Mismatched number of class arguments"
          ++ "\n\tThe class " ++ nameBase clsName ++ " expects " ++ show expectedArity ++ " argument(s),"
          ++ "\n\tbut was provided " ++ show actualArity ++ " argument(s)."
    translate dec
  where
    expectedArity, actualArity :: Int
    expectedArity = length clsTvbs
    actualArity   = length clsArgs

    -- Class variables mapped to the instance's own arguments.
    instSubst = zipTvbSubst clsTvbs clsArgs
    -- Same mapping, but with the last argument swapped for the rep type.
    repSubst  = zipTvbSubst clsTvbs (changeLast clsArgs repTy)

    translate :: Dec -> Q (Maybe [Dec])
    translate (OpenTypeFamilyD (TypeFamilyHead famName famTvbs _ _)) = do
      let famArgTys = map tvbToType famTvbs
          lhsTys    = map (applySubstitution instSubst) famArgTys
          rhsTy     = applyTy (ConT famName)
                              (map (applySubstitution repSubst) famArgTys)
      famInst <- tySynInstDCompat famName Nothing (map pure lhsTys) (pure rhsTy)
      pure (Just [famInst])
    translate (SigD methName methTy) =
      return (Just [SigD methName toTy, methDef])
      where
        (fromTy, toTy) = mkCoerceClassMethEqn clsTvbs clsArgs repTy
                                              (stripOuterForallT methTy)
        coerceAt = (VarE coerceValName `AppTypeE` stripOuterForallT fromTy)
                     `AppTypeE` stripOuterForallT toTy
        methDef  = ValD (VarP methName)
                        (NormalB (coerceAt `AppE` VarE methName))
                        []
    translate _ = return Nothing
-- | Instantiate a class method's type twice. The first component is the
-- type with the class variables mapped to the instance's arguments, except
-- that the last argument is replaced by the representation type. The second
-- component is the type with the class variables mapped to the instance's
-- arguments unchanged.
mkCoerceClassMethEqn :: [TyVarBndr] -> [Type] -> Type -> Type -> (Type, Type)
mkCoerceClassMethEqn clsTvbs clsArgs repTy methTy =
  (instantiateAt (changeLast clsArgs repTy), instantiateAt clsArgs)
  where
    instantiateAt args = applySubstitution (zipTvbSubst clsTvbs args) methTy
-- | Pair each binder's name with the type at the same position and build a
-- substitution map. Surplus elements on either side are ignored.
zipTvbSubst :: [TyVarBndr] -> [Type] -> Map Name Type
zipTvbSubst tvbs tys = M.fromList (zip (map tvName tvbs) tys)
-- | Replace the final element of a list with the given value, leaving the
-- other elements untouched. Calls 'error' on an empty list.
changeLast :: [a] -> a -> [a]
changeLast xs replacement =
  case xs of
    [] -> error "changeLast"
    _  -> init xs ++ [replacement]
-- | On GHCs older than 8.7 (per the CPP guard), drop one outermost 'ForallT'
-- (binders and context) from a type. On newer GHCs, return the type unchanged.
-- NOTE(review): presumably this compensates for an extra outer forall in
-- method types reified by older GHCs — confirm.
stripOuterForallT :: Type -> Type
stripOuterForallT ty =
  case ty of
#if __GLASGOW_HASKELL__ < 807
    ForallT _ _ inner -> inner
#endif
    _                 -> ty
-- | Split an outermost 'ForallT' into its binders, context and body. Any
-- other type is returned with no binders and an empty context.
decomposeType :: Type -> ([TyVarBndr], Cxt, Type)
decomposeType ty =
  case ty of
    ForallT tvbs ctxt body -> (tvbs, ctxt, body)
    _                      -> ([], [], ty)
-- | Return the field type of a newtype or newtype instance, provided the
-- constructor list is exactly one constructor with no constructor-bound type
-- variables, an empty context and exactly one field. Any other declaration
-- gives 'Nothing'.
newtypeRepType :: DatatypeVariant -> [ConstructorInfo] -> Maybe Type
newtypeRepType dv cons
  | isNewtypeLike
  , [ConstructorInfo { constructorVars    = []
                     , constructorContext = []
                     , constructorFields  = [repTy]
                     }] <- cons
  = Just repTy
  | otherwise
  = Nothing
  where
    isNewtypeLike :: Bool
    isNewtypeLike =
      case dv of
        Newtype         -> True
        NewtypeInstance -> True
        _               -> False
-- | Eta-reduce a type application by dropping its last @num@ arguments.
--
-- Returns the head applied to the remaining arguments only when both hold:
--
-- * there are at least @num@ arguments;
-- * 'canEtaReduce' accepts the split between remaining and dropped arguments.
--
-- Otherwise returns 'Nothing'.
etaReduce :: Int -> Type -> Maybe Type
etaReduce num ty =
  case unapplyTy ty of
    -- NOTE(review): unapplyTy presumably always returns the head first, so
    -- this case should not occur; treat it as "cannot reduce" rather than
    -- crashing on a partial pattern.
    [] -> Nothing
    (tyHead:tyArgs)
      -- Too few arguments to drop. Without this check, 'splitAt' gets a
      -- negative index and drops fewer than @num@ arguments. A result that
      -- was not reduced enough would then be reported as a success.
      | num > numArgs -> Nothing
      | canEtaReduce tyArgsRemaining tyArgsDropped ->
          Just $ applyTy tyHead tyArgsRemaining
      | otherwise -> Nothing
      where
        numArgs = length tyArgs
        (tyArgsRemaining, tyArgsDropped) = splitAt (numArgs - num) tyArgs
#endif