{-# LANGUAGE OverloadedStrings #-}
module Clash.Core.EqSolver where
import Data.Maybe (catMaybes, mapMaybe)
import Clash.Core.Name (Name(nameOcc))
import Clash.Core.Term
import Clash.Core.TyCon
import Clash.Core.Type
import Clash.Core.Var
-- | Outcome of trying to solve a single type equation.
data TypeEqSolution
  = Solution (TyVar, Type)
  -- ^ A solution was found: the type variable together with the type it
  -- was solved to.
  | AbsurdSolution
  -- ^ The equation can never hold. 'solveAdd' reports this when one of the
  -- literals involved, or their difference, is negative; 'solveEq' reports it
  -- for differing constants, literals or type constructors.
  | NoSolution
  -- ^ Nothing could be concluded. 'solveAdd' reports this when 'normalizeAdd'
  -- does not yield an equation whose unknown is a type variable.
    deriving (Show, Eq)
-- | Keep only the variable/type pairs carried by 'Solution' results,
-- preserving their order. 'AbsurdSolution' and 'NoSolution' entries are
-- dropped.
catSolutions :: [TypeEqSolution] -> [(TyVar, Type)]
catSolutions = mapMaybe solutionOf
 where
  solutionOf result =
    case result of
      Solution binding -> Just binding
      AbsurdSolution   -> Nothing
      NoSolution       -> Nothing
-- | Run both solvers ('solveAdd' and 'solveEq') over every equation and
-- collect, in equation order, the bindings of all 'Solution's they produce.
-- Absurd and unsolved results are silently discarded.
solveNonAbsurds :: TyConMap -> [(Type, Type)] -> [(TyVar, Type)]
solveNonAbsurds tcm eqs = concatMap bindingsFor eqs
 where
  -- 'solveAdd' always yields exactly one result and 'solveEq' any number of
  -- them; the result of the addition solver is listed first.
  bindingsFor eq = catSolutions (solveAdd eq : solveEq tcm eq)
-- | Solve a structural equality between two types.
--
-- Both sides are first passed through 'coreView'. After that:
--
--   * a type variable against a type constant (in either order) yields a
--     'Solution' binding the variable to the constant;
--   * two constants, or two literals, that differ yield 'AbsurdSolution',
--     and nothing when they are equal;
--   * if either side is a type family application (according to
--     'isTypeFamilyApplication') nothing is concluded;
--   * two type constructor applications with the same constructor are solved
--     argument by argument; with different constructors the equation is
--     absurd;
--   * anything else yields no results.
solveEq :: TyConMap -> (Type, Type) -> [TypeEqSolution]
solveEq tcm (lhs0, rhs0) = go (coreView tcm lhs0) (coreView tcm rhs0)
 where
  -- Variable on one side, constant on the other.
  go (VarTy tv) rhs@(ConstTy {}) = [Solution (tv, rhs)]
  go lhs@(ConstTy {}) (VarTy tv) = [Solution (tv, lhs)]
  -- Two constants or two literals: absurd exactly when they differ.
  go lhs@(ConstTy {}) rhs@(ConstTy {}) = [AbsurdSolution | lhs /= rhs]
  go lhs@(LitTy {}) rhs@(LitTy {}) = [AbsurdSolution | lhs /= rhs]
  go lhs rhs
    -- NOTE(review): presumably a type family application still present after
    -- 'coreView' is stuck and must not be compared structurally -- confirm.
    | isTypeFamilyApplication tcm lhs || isTypeFamilyApplication tcm rhs
    = []
    | TyConApp lhsNm lhsArgs <- tyView lhs
    , TyConApp rhsNm rhsArgs <- tyView rhs
    = if lhsNm == rhsNm
        then concatMap (solveEq tcm) (zip lhsArgs rhsArgs)
        else [AbsurdSolution]
    | otherwise
    = []
-- | Solve an equation of the shape recognised by 'normalizeAdd' whose
-- unknown is a type variable.
--
-- Given the normalised triple @(n, m, v)@ the variable @v@ is solved to the
-- literal @n - m@, provided @n@, @m@ and @n - m@ are all non-negative;
-- otherwise the result is 'AbsurdSolution'. Equations that do not normalise
-- to that shape give 'NoSolution'.
solveAdd
  :: (Type, Type)
  -> TypeEqSolution
solveAdd ab
  | Just (total, addend, VarTy tv) <- normalizeAdd ab
  = let remainder = total - addend
    in  if all (>= 0) [total, addend, remainder]
          then Solution (tv, LitTy (NumTy remainder))
          else AbsurdSolution
  | otherwise
  = NoSolution
-- | Normalise an equation between a numeric literal and an addition.
--
-- One side of the pair must be a numeric literal @n@; the other side must be
-- an application of @GHC.TypeNats.+@ to exactly two arguments, one of which
-- is a numeric literal @m@. The result is @(n, m, o)@ where @o@ is the
-- remaining argument of the addition. So equations such as @5 ~ x + 2@,
-- @5 ~ 2 + x@, @x + 2 ~ 5@ and @2 + x ~ 5@ all give @(5, 2, x)@.
--
-- Any other shape gives 'Nothing'.
normalizeAdd
  :: (Type, Type)
  -> Maybe (Integer, Integer, Type)
normalizeAdd (a, b) =
  case splitLit a b of
    Just (total, other)
      | TyConApp tcNm [x, y] <- tyView other
      , nameOcc tcNm == "GHC.TypeNats.+"
      , Just (addend, unknown) <- splitLit x y
      -> Just (total, addend, unknown)
    _ -> Nothing
 where
  -- Separate a numeric literal from its companion. When both arguments are
  -- literals the second one is taken as the literal.
  splitLit other (LitTy (NumTy k)) = Just (k, other)
  splitLit (LitTy (NumTy k)) other = Just (k, other)
  splitLit _ _ = Nothing
-- | Whether at least one of the type equalities that 'altEqs' extracts from
-- the alternative is absurd according to 'isAbsurdEq'.
isAbsurdAlt
  :: TyConMap
  -> Alt
  -> Bool
isAbsurdAlt tcm alt =
  or [isAbsurdEq tcm eq | eq <- altEqs tcm alt]
-- | Whether an equation between two types is absurd.
--
-- Both sides are passed through 'coreView'; the equation is absurd when
-- either 'solveAdd' or any result of 'solveEq' is 'AbsurdSolution'.
isAbsurdEq
  :: TyConMap
  -> (Type, Type)
  -> Bool
isAbsurdEq tcm (left0, right0) =
  -- The addition solver's verdict is listed first, so 'solveEq' is only
  -- consulted when 'solveAdd' did not already report an absurdity.
  AbsurdSolution `elem` (solveAdd viewed : solveEq tcm viewed)
 where
  viewed = (coreView tcm left0, coreView tcm right0)
-- | Collect the type equalities found among the types of the variables in
-- the second component of 'patIds' for the alternative's pattern. Variables
-- whose type is not an equality (see 'typeEq') are skipped; the term of the
-- alternative is ignored.
altEqs
  :: TyConMap
  -> Alt
  -> [(Type, Type)]
altEqs tcm (pat, _) =
  catMaybes [typeEq tcm (varType v) | v <- boundVars]
 where
  (_, boundVars) = patIds pat
-- | If the type, after 'coreView', is an application of @GHC.Prim.~#@ to
-- exactly four arguments, return its last two arguments (each passed through
-- 'coreView') as the two sides of the equation. Otherwise return 'Nothing'.
typeEq
  :: TyConMap
  -> Type
  -> Maybe (Type, Type)
typeEq tcm ty
  | TyConApp tcNm [_, _, lhs, rhs] <- tyView (coreView tcm ty)
  , nameOcc tcNm == "GHC.Prim.~#"
  = Just (coreView tcm lhs, coreView tcm rhs)
  | otherwise
  = Nothing