{-# OPTIONS_GHC -Wall #-}
module Transform.Canonicalize.Type (tipe) where

import Control.Arrow (second)
import Control.Applicative ((<$>),(<*>))
import Control.Monad.Error
import qualified Data.Map as Map
import Data.Traversable (traverse)

import qualified AST.Type as T
import qualified AST.Variable as Var

import Transform.Canonicalize.Environment
import qualified Transform.Canonicalize.Variable as Canonicalize

tipe :: Environment -> T.RawType -> Canonicalizer String T.CanonicalType
-- | Canonicalize a raw type by resolving the type constructors and aliases
-- it mentions against the given 'Environment'.
--
-- Type variables are returned unchanged. A constructor reference, whether
-- bare ('T.Type') or applied ('T.App'), is handed to 'canonicalizeApp'.
-- Functions, aliased types and records are rebuilt structurally from their
-- canonicalized children. Children are processed left to right: argument
-- before result, and record fields before the record extension.
tipe env typ =
    case typ of
      T.Var var ->
          return (T.Var var)

      T.Type _ ->
          canonicalizeApp env typ []

      T.App ctor ctorArgs ->
          canonicalizeApp env ctor ctorArgs

      T.Lambda argType resultType ->
          T.Lambda <$> recurse argType <*> recurse resultType

      T.Aliased aliasName body ->
          T.Aliased aliasName <$> recurse body

      T.Record recordFields extension ->
          T.Record <$> mapM canonicalizeField recordFields
                   <*> traverse recurse extension
  where
    -- Canonicalize a nested type in the same environment.
    recurse = tipe env

    -- Canonicalize the type of one record field and keep its label.
    canonicalizeField (label, fieldType) =
        fmap ((,) label) (recurse fieldType)

canonicalizeApp :: Environment -> T.RawType -> [T.RawType]
                -> Canonicalizer String T.CanonicalType
-- | Canonicalize the type head @f@ applied to the raw arguments @args@.
-- An empty @args@ means the head appears on its own.
--
-- If the head is a named type, 'Canonicalize.tvar' resolves the name:
--
-- * A 'Right' result (an alias) is expanded via 'canonicalizeAlias'.
-- * A 'Left' result is an ordinary type. It stays a bare 'T.Type' when there
--   are no arguments; otherwise it becomes a 'T.App' of the canonicalized
--   arguments.
--
-- Any other head is canonicalized recursively and applied to the
-- canonicalized arguments.
canonicalizeApp env f args =
  case f of
    T.Type (Var.Raw rawName) ->
        Canonicalize.tvar env rawName >>= either applyNamed expandAlias
    _ ->
        T.App <$> tipe env f <*> canonicalArgs
  where
    -- The arguments canonicalized in order; only run by the branch that uses it.
    canonicalArgs = mapM (tipe env) args

    expandAlias alias = canonicalizeAlias env alias args

    applyNamed typeName
      | null args = return (T.Type typeName)
      | otherwise = T.App (T.Type typeName) <$> canonicalArgs

canonicalizeAlias :: Environment -> (Var.Canonical, [String], T.CanonicalType)
                  -> [T.RawType]
                  -> Canonicalizer String T.CanonicalType
-- | Expand one use of a type alias.
--
-- The alias is given as a triple of three parts:
--
-- * its canonical name;
-- * the names of its type parameters;
-- * its already-canonical definition.
--
-- The raw arguments from the use site are canonicalized, then substituted
-- for the parameters throughout the definition. The result is wrapped in
-- 'T.Aliased' so that the alias name is kept.
--
-- If the number of arguments differs from the number of parameters, this
-- throws an error message instead. The check happens before any argument is
-- canonicalized.
canonicalizeAlias env (name, tvars, dealiasedTipe) tipes
  | given /= expected = throwError arityError
  | otherwise =
      do canonicalArgs <- mapM (tipe env) tipes
         let bindings = Map.fromList (zip tvars canonicalArgs)
         return (T.Aliased name (substitute bindings dealiasedTipe))
  where
    given = length tipes
    expected = length tvars

    plural = if expected == 1 then "" else "s"

    arityError :: String
    arityError =
        concat
          [ "Type alias '", Var.toString name, "' expects ", show expected
          , " type argument", plural
          , " but was given ", show given
          ]

    -- Replace every type variable that has a binding. All variables are
    -- replaced in a single pass, so substituted types are never rewritten
    -- again.
    substitute :: Map.Map String T.CanonicalType -> T.CanonicalType -> T.CanonicalType
    substitute table = walk
      where
        walk t =
          case t of
            T.Var x               -> Map.findWithDefault t x table
            T.Type _              -> t
            T.Lambda a b          -> T.Lambda (walk a) (walk b)
            T.App hd params       -> T.App (walk hd) (map walk params)
            T.Record fields ext   -> T.Record (map (second walk) fields) (fmap walk ext)
            T.Aliased original t' -> T.Aliased original (walk t')