-- | Defines functions to ensure that only valid declarations and type
--   signatures are fed to the FreeTheorems library. The functions given here
--   are intended as the second stage after parsing declarations.

module Language.Haskell.FreeTheorems.Frontend (
    Checked
  , Parsed
  , runChecks
  , check
  , checkAgainst
) where

import Data.Generics (everything, extQ, mkQ)
import Data.List (partition, intersect)
import Data.Maybe (mapMaybe)

import Language.Haskell.FreeTheorems.Syntax
import Language.Haskell.FreeTheorems.ValidSyntax (ValidDeclaration (..))
import Language.Haskell.FreeTheorems.Frontend.Error (Checked, Parsed, runChecks)
import Language.Haskell.FreeTheorems.Frontend.TypeExpressions
    (replaceAllTypeSynonyms, closeTypeExpressions)
import Language.Haskell.FreeTheorems.Frontend.CheckLocal
    (checkLocal, checkDataAndNewtypeDeclarations)
import Language.Haskell.FreeTheorems.Frontend.CheckGlobal (checkGlobal)

-- | Checks a list of declarations on their own, i.e. without any previously
--   validated declarations to build on (it delegates to 'checkAgainst' with
--   an empty list of valid declarations).
--   The result holds every declaration that passed all checks, together with
--   an error message for each declaration that did not.
check :: [Declaration] -> Checked [ValidDeclaration]
check declarations = checkAgainst [] declarations

-- | Checks a list of declarations against a given list of valid
--   declarations.
--
--   The result contains all declarations from the second argument which are
--   valid, together with an error message for each of those declarations
--   which is not valid.
--
--   The declarations given in the second argument may be based on those of
--   the first argument. For example, if the first argument contains a valid
--   declaration of a type \"Foo\" and the second argument contains the
--   declaration
--
--   > type Bar = Foo
--
--   then the declaration of \"Bar\" is valid as well.
checkAgainst
    :: [ValidDeclaration]          -- ^ declarations already known to be valid
    -> [Declaration]               -- ^ declarations to be checked
    -> Checked [ValidDeclaration]
checkAgainst vds ds =
    -- perform local checks:
    --   * free variables of the right-hand side are declared on the left-hand
    --     side of declarations
    --   * type variables of the left-hand side are pairwise distinct
    --   * primitive types are not declared
    --   * FixedTypeExpression does not occur anywhere
    --   * type synonyms are not recursive
    --   * data and newtype are not nested
    --   * class methods are pairwise distinct, don't use the owning class
    --     and have the class variable as free variable
  checkLocal ds
    -- perform global checks:
    --   * at most one declaration per name
    --   * arity checks of type constructors in all type expressions
    --   * type class hierarchy is acyclic
    --   * type synonym declarations are not mutually recursive
    --   * all used constructors and classes are declared
  >>= checkGlobal vds
    -- replace all type synonyms, using also the type synonyms contained in
    -- the already valid declarations
  >>= return . replaceTypeSynonyms
    -- checks in data and newtype declarations: no abstractions, no functions
  >>= checkDataAndNewtypeDeclarations
    -- finally, close all type signatures and class methods and transform all
    -- declarations to valid ones
  >>= return . makeValid vds . closeTypeExpressions
  where
    -- Expands type synonyms in the given declarations, using the synonyms
    -- from the valid declarations 'vds' followed by those in the given list.
    replaceTypeSynonyms ds' =
      let typeSyns = mapMaybe getTypeSyn (map rawDeclaration vds ++ ds')
       in replaceAllTypeSynonyms typeSyns ds'

    -- Selects the payload of a type synonym declaration.
    getTypeSyn d = case d of
      TypeDecl t -> Just t
      _          -> Nothing

-- | Turns a list of declarations into valid declarations.
--
--   Additionally, every declaration is marked according to whether it depends
--   on any algebraic data type with strictness flags. A declaration is marked
--   if it
--
--   * contains a strictness flag ('Banged') itself, or
--
--   * mentions (via a type constructor 'Con' or a type class 'TC') the name of
--     a marked declaration, either one of the given valid declarations or one
--     of the given declarations; this is propagated transitively.
makeValid :: [ValidDeclaration] -> [Declaration] -> [ValidDeclaration]
makeValid vds ds =
  let strictValid = map rawDeclaration (filter isStrictDeclaration vds)
      knownStrict = map getDeclarationName
                        (strictValid ++ filter hasStrictnessFlags ds)
      allStrict   = propagate knownStrict ds
   in map (\d -> ValidDeclaration d (getDeclarationName d `elem` allStrict)) ds
  where
    -- Repeatedly moves every pending declaration that mentions a name from
    -- 'ss' into the strict set, until no pending declaration does (fixpoint).
    propagate ss pending =
      let (newlyStrict, rest) = partition (dependsOnStrictTypes ss) pending
       in if null newlyStrict
            then ss
            else propagate (ss ++ map getDeclarationName newlyStrict) rest

    -- True if a 'Banged' flag occurs anywhere inside the declaration.
    hasStrictnessFlags d =
      let hasBang (Banged _)   = True
          hasBang (Unbanged _) = False
       in everything (||) (False `mkQ` hasBang) d

    -- True if the declaration mentions any of the names 'ss' as a type
    -- constructor or a type class.
    dependsOnStrictTypes ss d =
      let getCons c = case c of { Con n -> [n] ; _ -> [] }
          getClasses (TC n) = [n]
          ns = everything (++) ([] `mkQ` getCons `extQ` getClasses) d
       in not (null (ns `intersect` ss))