```{-# LANGUAGE OverloadedStrings #-}

module Clash.Core.EqSolver where

import Data.Maybe (catMaybes, mapMaybe)

import Clash.Core.Name (Name(nameOcc))
import Clash.Core.Term
import Clash.Core.TyCon
import Clash.Core.Type
import Clash.Core.Var

-- | Outcome of trying to solve a single type equation.
data TypeEqSolution
  = -- | A solution was found: the type variable must equal the given type.
    Solution (TyVar, Type)
  | -- | The equation can be decided, but only in a contradictory way, e.g.
    -- the solution would involve a negative natural.
    AbsurdSolution
  | -- | Given type wasn't an equation, or it was unsolvable.
    NoSolution
  deriving (Show, Eq)

-- | Keep the payload of every 'Solution', preserving order; all other
-- results are discarded.
catSolutions :: [TypeEqSolution] -> [(TyVar, Type)]
catSolutions results = [binding | Solution binding <- results]

-- | Run every solver over every equation and collect the variable bindings
-- that were found. Results that are not a 'Solution' are dropped.
--
-- Bindings are returned in equation order; for a single equation the result
-- of 'solveAdd' comes before the results of 'solveEq'.
solveNonAbsurds :: TyConMap -> [(Type, Type)] -> [(TyVar, Type)]
solveNonAbsurds tcm = concatMap solveOne
  where
    -- 'solveAdd' yields exactly one result, 'solveEq' yields a list.
    solveOne eq = catSolutions (solveAdd eq : solveEq tcm eq)

-- | Solve simple equalities such as:
--
--   * a ~ 3
--   * 3 ~ a
--   * SomeType a b ~ SomeType 3 5
--   * SomeType 3 5 ~ SomeType a b
--   * SomeType a 5 ~ SomeType 3 b
--
-- Both sides are passed through 'coreView' before being inspected. The
-- result holds one entry per sub-equation that could be decided; an empty
-- list means nothing could be concluded.
--
-- NOTE(review): the variable arms below match 'ConstTy', whereas numeric
-- literals are compared through the 'LitTy' arm. Confirm whether @a ~ 3@
-- with a literal @3@ is really meant to be solved here.
solveEq :: TyConMap -> (Type, Type) -> [TypeEqSolution]
solveEq tcm (coreView tcm -> lhs, coreView tcm -> rhs) =
  case (lhs, rhs) of
    -- a ~ 3
    (VarTy tv, ConstTy {}) -> [Solution (tv, rhs)]
    -- 3 ~ a
    (ConstTy {}, VarTy tv) -> [Solution (tv, lhs)]
    -- Int /= Char
    (ConstTy {}, ConstTy {}) -> absurdUnlessEqual
    -- 3 /= 5
    (LitTy {}, LitTy {}) -> absurdUnlessEqual
    _
      -- The call to 'coreView' at the start of 'solveEq' should have reduced
      -- all solvable type families. If we encounter one here that means the
      -- type family is stuck (and that we shouldn't compare it to anything!).
      | isTypeFamilyApplication tcm lhs || isTypeFamilyApplication tcm rhs -> []
      | otherwise -> solveApps (tyView lhs) (tyView rhs)
  where
    -- Two distinct constants/literals can never be equal.
    absurdUnlessEqual = [AbsurdSolution | lhs /= rhs]

    -- SomeType a b ~ SomeType 3 5 (or other way around): same constructor
    -- means the arguments are solved pairwise, a different one is absurd.
    solveApps (TyConApp lNm lArgs) (TyConApp rNm rArgs)
      | lNm == rNm = concatMap (solveEq tcm) (zip lArgs rArgs)
      | otherwise = [AbsurdSolution]
    solveApps _ _ = []

-- | Solve equations supported by 'normalizeAdd'. See documentation of
-- 'TypeEqSolution' to understand the return value.
--
-- For an equation normalized to @(n, m, var)@ (i.e. @n ~ m + var@):
--
--   * 'Solution' binding @var@ to the literal @n - m@ when @n@, @m@ and
--     @n - m@ are all non-negative;
--   * 'AbsurdSolution' when any of them is negative;
--   * 'NoSolution' when 'normalizeAdd' fails or the remaining operand is not
--     a type variable.
solveAdd
  :: (Type, Type)
  -> TypeEqSolution
solveAdd ab =
  case normalizeAdd ab of
    Just (n, m, VarTy tyVar)
      | n >= 0 && m >= 0 && n - m >= 0 ->
          Solution (tyVar, LitTy (NumTy (n - m)))
      | otherwise ->
          AbsurdSolution
    _ ->
      NoSolution

-- | Given the left and right side of an equation, normalize it such that
-- equations of the following forms:
--
--     * 5     ~ n + 2
--     * 5     ~ 2 + n
--     * n + 2 ~ 5
--     * 2 + n ~ 5
--
-- are returned as @(5, 2, n)@.
--
-- Returns 'Nothing' when neither side of the equation is a numeric literal,
-- when the other side is not an application of @GHC.TypeNats.+@ to exactly
-- two arguments, or when neither operand of that addition is a numeric
-- literal. The third component is whatever the non-literal operand is; it is
-- not checked to be a type variable.
normalizeAdd
  :: (Type, Type)
  -> Maybe (Integer, Integer, Type)
normalizeAdd (a, b) = do
  -- One side of the equation must be a literal; the other is the sum.
  (n, rhs) <- lhsLit a b
  case tyView rhs of
    TyConApp (nameOcc -> "GHC.TypeNats.+") [left, right] -> do
      -- Likewise one operand of the sum must be a literal.
      (m, o) <- lhsLit left right
      return (n, m, o)
    _ ->
      Nothing
  where
    -- Split two types into (literal value, the other type). If both are
    -- literals, the second one is taken as the literal.
    lhsLit x                 (LitTy (NumTy n)) = Just (n, x)
    lhsLit (LitTy (NumTy n)) y                 = Just (n, y)
    lhsLit _                 _                 = Nothing

-- | Tests for unreachable alternative due to types being "absurd": the
-- alternative is absurd as soon as one of the equations collected by
-- 'altEqs' is absurd according to 'isAbsurdEq'.
isAbsurdAlt
  :: TyConMap
  -> Alt
  -> Bool
isAbsurdAlt tcm = any (isAbsurdEq tcm) . altEqs tcm

-- | Determines if an "equation" obtained through 'altEqs' or 'typeEq' is
-- absurd. That is, it tests if two types that are definitely not equal are
-- asserted to be equal OR if the computation of the types yield some absurd
-- (intermediate) result such as -1.
--
-- Both sides go through 'coreView' and are then handed to 'solveEq'; the
-- equation is absurd when any of the results is an 'AbsurdSolution'.
isAbsurdEq
  :: TyConMap
  -> (Type, Type)
  -> Bool
isAbsurdEq tcm (left0, right0) =
  AbsurdSolution `elem` solveEq tcm (coreView tcm left0, coreView tcm right0)

-- | Get constraint equations: every identifier bound by the pattern whose
-- type is recognised by 'typeEq' contributes one (LHS, RHS) pair. The term
-- of the alternative is not inspected.
altEqs
  :: TyConMap
  -> Alt
  -> [(Type, Type)]
altEqs tcm (pat, _term) = mapMaybe eqOfId patternIds
  where
    -- Only the second component of 'patIds' is used.
    (_, patternIds) = patIds pat
    eqOfId = typeEq tcm . varType

-- | If type is an equation, return LHS and RHS.
--
-- A type counts as an equation when, after 'coreView', it is an application
-- of @GHC.Prim.~#@ to exactly four arguments; the last two are returned,
-- each passed through 'coreView'. Anything else gives 'Nothing'.
typeEq
  :: TyConMap
  -> Type
  -> Maybe (Type, Type)
typeEq tcm ty
  | TyConApp (nameOcc -> "GHC.Prim.~#") [_, _, left, right] <- tyView (coreView tcm ty)
  = Just (coreView tcm left, coreView tcm right)
  | otherwise
  = Nothing

```