-- | A normalised representation of API type declarations that keeps only
-- type names and type structure, together with utilities for working with
-- it: computing referenced type names, (reverse) dependencies, validity
-- checks, and substitution of type names.
module Data.API.NormalForm
    ( -- * Normal form
      NormAPI
    , NormTypeDecl(..)
    , NormRecordType
    , NormUnionType
    , NormEnumType

      -- * Normalisation
    , apiNormalForm
    , declNF

      -- * Referenced type names and dependencies
    , typeDeclsFreeVars
    , typeDeclFreeVars
    , typeFreeVars
    , typeDeclaredInApi
    , typeUsedInApi
    , typeUsedInTransitiveDep
    , transitiveDeps
    , transitiveReverseDeps

      -- * Validity checks
    , apiInvariant
    , declIsValid
    , typeIsValid

      -- * Substitution
    , substTypeDecl
    , substType
    , renameTypeUses
    ) where
import           Data.API.PP
import           Data.API.Types
import           Control.DeepSeq
import           Data.Map (Map)
import qualified Data.Map as Map
import           Data.Set (Set)
import qualified Data.Set as Set
-- | The normal form of an API: a map from type names to their normalised
-- declarations.
type NormAPI = Map TypeName NormTypeDecl
-- | A single normalised type declaration.  Only names and types are kept.
data NormTypeDecl
    = NRecordType  NormRecordType  -- ^ record: field names mapped to field types
    | NUnionType   NormUnionType   -- ^ union: alternative names mapped to their types
    | NEnumType    NormEnumType    -- ^ enumeration: the set of constructor names
    | NTypeSynonym APIType         -- ^ synonym for another type expression
    | NNewtype     BasicType       -- ^ newtype wrapping a basic type
  deriving (Eq, Show)
-- | Fully evaluate a declaration by forcing the payload of whichever
-- constructor it was built with.
instance NFData NormTypeDecl where
  rnf decl = case decl of
    NRecordType  fields -> rnf fields
    NUnionType   alts   -> rnf alts
    NEnumType    elems  -> rnf elems
    NTypeSynonym ty     -> rnf ty
    NNewtype     basic  -> rnf basic
-- | Fields of a normalised record, keyed by field name.
type NormRecordType = Map FieldName APIType
-- | Alternatives of a normalised union, keyed by alternative name.
type NormUnionType  = Map FieldName APIType
-- | Constructor names of a normalised enumeration.
type NormEnumType   = Set FieldName
-- | Compute the normal form of an API.  Only 'ThNode' entries contribute;
-- any other entries in the list are skipped.  If the same type name appears
-- more than once, the later entry wins (the behaviour of 'Map.fromList').
apiNormalForm :: API -> NormAPI
apiNormalForm api = Map.fromList (concatMap entry api)
  where
    entry (ThNode node) = [(anName node, declNF (anSpec node))]
    entry _             = []
-- | Normalise a single type specification, keeping only names and types.
-- Record fields keep just their 'ftType', and union alternatives keep the
-- first component of their payload.  Enumerations keep only their constructor
-- names, and a newtype keeps only its underlying 'BasicType'.  Duplicate names
-- are resolved as 'Map.fromList' / 'Set.fromList' resolve them (last wins).
declNF :: Spec -> NormTypeDecl
declNF spec = case spec of
    SpRecord  (SpecRecord fields)  -> NRecordType  (Map.map ftType (Map.fromList fields))
    SpUnion   (SpecUnion  alts)    -> NUnionType   (Map.map fst    (Map.fromList alts))
    SpEnum    (SpecEnum   elems)   -> NEnumType    (Set.fromList (map fst elems))
    SpSynonym ty                   -> NTypeSynonym ty
    SpNewtype (SpecNewtype base _) -> NNewtype     base
-- | Every type name referred to by any declaration in the API, whether or
-- not that name is itself declared.
typeDeclsFreeVars :: NormAPI -> Set TypeName
typeDeclsFreeVars api = Map.foldr (Set.union . typeDeclFreeVars) Set.empty api
-- | Type names referred to by a single declaration.  Enumerations and
-- newtypes never mention other named types.
typeDeclFreeVars :: NormTypeDecl -> Set TypeName
typeDeclFreeVars decl = case decl of
    NRecordType  fields -> fromTypes fields
    NUnionType   alts   -> fromTypes alts
    NEnumType    _      -> Set.empty
    NTypeSynonym ty     -> typeFreeVars ty
    NNewtype     _      -> Set.empty
  where
    -- union of the names mentioned by every type stored in the map
    fromTypes = Map.foldr (Set.union . typeFreeVars) Set.empty
-- | Type names occurring anywhere inside a type expression.
typeFreeVars :: APIType -> Set TypeName
typeFreeVars ty = case ty of
    TyName  name  -> Set.singleton name
    TyList  inner -> typeFreeVars inner
    TyMaybe inner -> typeFreeVars inner
    TyBasic _     -> Set.empty
    TyJSON        -> Set.empty
-- | Does the API contain a declaration for the given type name?
typeDeclaredInApi :: TypeName -> NormAPI -> Bool
typeDeclaredInApi = Map.member
-- | Is the given type name referred to by at least one declaration in the
-- API?  Merely being declared does not count as being used.
typeUsedInApi :: TypeName -> NormAPI -> Bool
typeUsedInApi tname = Set.member tname . typeDeclsFreeVars
-- | Does @root@ depend, directly or transitively, on @tname@?  A type is
-- always considered to depend on itself.
typeUsedInTransitiveDep :: TypeName -> TypeName -> NormAPI -> Bool
typeUsedInTransitiveDep root tname api
    | tname == root = True
    | otherwise     = Set.member tname (transitiveDeps api (Set.singleton root))
-- | All type names reachable from the given names by repeatedly following
-- the references in their declarations.  A starting name is included only
-- if it is reachable from one of the starting names (e.g. through
-- recursion).  Referenced names with no declaration in the API are included
-- but lead nowhere further.
transitiveDeps :: NormAPI -> Set TypeName -> Set TypeName
transitiveDeps api = transitiveClosure usedBy
  where
    -- names referenced by the declarations of the given names
    usedBy names =
        typeDeclsFreeVars (Map.filterWithKey (\ key _ -> Set.member key names) api)
-- | All declared types that refer, directly or transitively, to any of the
-- given type names.  A starting name is included only if it itself
-- (transitively) refers to one of the starting names.
transitiveReverseDeps :: NormAPI -> Set TypeName -> Set TypeName
transitiveReverseDeps api = transitiveClosure users
  where
    -- declared types whose declarations mention at least one of the names
    users names = Map.keysSet (Map.filter (mentionsAny names) api)
    mentionsAny names decl =
        not (Set.null (Set.intersection names (typeDeclFreeVars decl)))
-- | @transitiveClosure rel start@ is the union of @rel start@,
-- @rel (rel start)@, and so on, stopping once a step produces no new
-- elements.  The start set itself is not included unless @rel@ leads back
-- to it.  Each step applies @rel@ only to the elements discovered in the
-- previous step, so @rel@ is assumed to distribute over set union.
transitiveClosure :: Ord a => (Set a -> Set a) -> Set a -> Set a
transitiveClosure rel start = go firstStep firstStep
  where
    firstStep = rel start
    -- acc: everything found so far; frontier: elements found in the last step
    go acc frontier =
        let fresh = rel frontier Set.\\ acc
        in if Set.null fresh
             then acc
             else go (Set.union acc fresh) fresh
-- | Check that every type name mentioned by a type expression is declared
-- in the API.  On failure, returns the set of undeclared names.
typeIsValid :: APIType -> NormAPI -> Either (Set TypeName) ()
typeIsValid t api
    | Set.null missing = Right ()
    | otherwise        = Left missing
  where
    missing = typeFreeVars t `Set.difference` Map.keysSet api
-- | Check that every type name a declaration refers to is declared in the
-- API.  On failure, returns the set of undeclared names.
declIsValid :: NormTypeDecl -> NormAPI -> Either (Set TypeName) ()
declIsValid decl api
    | Set.null undeclared = Right ()
    | otherwise           = Left undeclared
  where
    undeclared = Set.difference (typeDeclFreeVars decl) (Map.keysSet api)
-- | The API is closed: every type name referred to by any declaration is
-- itself declared.  On failure, returns the set of referenced but
-- undeclared names.
apiInvariant :: NormAPI -> Either (Set TypeName) ()
apiInvariant api
    | Set.null dangling = Right ()
    | otherwise         = Left dangling
  where
    dangling = typeDeclsFreeVars api Set.\\ Map.keysSet api
-- | Apply a substitution to every type name occurring in a declaration.
-- Enumerations and newtypes contain no type names and are returned as-is.
substTypeDecl :: (TypeName -> APIType) -> NormTypeDecl -> NormTypeDecl
substTypeDecl f decl = case decl of
    NRecordType  fields -> NRecordType  (Map.map sub fields)
    NUnionType   alts   -> NUnionType   (Map.map sub alts)
    NTypeSynonym ty     -> NTypeSynonym (sub ty)
    NEnumType    _      -> decl
    NNewtype     _      -> decl
  where
    sub = substType f
-- | Replace each 'TyName' in a type expression with the substitution's
-- result for that name, preserving the surrounding list/maybe structure.
substType :: (TypeName -> APIType) -> APIType -> APIType
substType f = go
  where
    go (TyName  n)    = f n
    go (TyList  ty)   = TyList  (go ty)
    go (TyMaybe ty)   = TyMaybe (go ty)
    go ty@(TyBasic _) = ty
    go ty@TyJSON      = ty
-- | Rename every use of the first type name to the second, throughout all
-- declarations.  The map keys (the declarations' own names) are untouched.
renameTypeUses :: TypeName -> TypeName -> NormAPI -> NormAPI
renameTypeUses from to = Map.map (substTypeDecl swap)
  where
    swap tn = TyName (if tn == from then to else tn)
-- | Render a declaration as lines: a keyword header ("record", "union" or
-- "enum") followed by one indented line per field, alternative or
-- constructor (in key order), or a single line for synonyms and newtypes.
instance PPLines NormTypeDecl where
  ppLines decl = case decl of
    NRecordType flds ->
      "record" : [ "  "   ++ pp f ++ " :: " ++ pp ty | (f, ty) <- Map.toList flds ]
    NUnionType alts ->
      "union"  : [ "  | " ++ pp f ++ " :: " ++ pp ty | (f, ty) <- Map.toList alts ]
    NEnumType vals ->
      "enum"   : [ "  | " ++ pp v | v <- Set.toList vals ]
    NTypeSynonym t ->
      [pp t]
    NNewtype b ->
      ["basic " ++ pp b]