{-# LANGUAGE Safe, PatternGuards, MultiWayIf #-}
module Cryptol.TypeCheck.Solver.Numeric
  ( cryIsEqual, cryIsNotEqual, cryIsGeq
  ) where
import           Control.Applicative(Alternative(..))
import           Control.Monad (guard,mzero)
import           Data.List (sortBy)
import Cryptol.Utils.Patterns
import Cryptol.TypeCheck.PP
import Cryptol.TypeCheck.Type hiding (tMul)
import Cryptol.TypeCheck.TypePat
import Cryptol.TypeCheck.Solver.Types
import Cryptol.TypeCheck.Solver.InfNat
import Cryptol.TypeCheck.Solver.Numeric.Interval
import Cryptol.TypeCheck.SimpType as Simp
-- | Try to simplify or decide the numeric equality @t1 == t2@.
--
-- The rules below are combined with '<|>' in a fixed order; several are
-- tried in both orientations.  If no rule matches, the result is 'Unsolved'.
cryIsEqual :: Ctxt -> Type -> Type -> Solved
cryIsEqual ctxt t1 t2 = matchDefault Unsolved rules
  where
  rules =
        pBin PEqual (==) t1 t2
    <|> (tryEqK ctxt t2 =<< aNat' t1)
    <|> (tryEqK ctxt t1 =<< aNat' t2)
    <|> (tryEqVar t2 =<< aTVar t1)
    <|> (tryEqVar t1 =<< aTVar t2)
    <|> (SolvedIf [] <$ guard (t1 == t2))
    <|> tryEqMin t1 t2
    <|> tryEqMin t2 t1
    <|> tryEqMins t1 t2
    <|> tryEqMins t2 t1
    <|> tryEqMulConst t1 t2
    <|> tryEqAddInf ctxt t1 t2
    <|> tryAddConst (=#=) t1 t2
    <|> tryCancelVar ctxt (=#=) t1 t2
    <|> tryLinearSolution t1 t2
    <|> tryLinearSolution t2 t1
-- | Decide @t1 /= t2@; this is only possible when both sides are constants
-- (or one of them is an error type).  Otherwise the goal stays 'Unsolved'.
cryIsNotEqual :: Ctxt -> Type -> Type -> Solved
cryIsNotEqual _ t1 t2 = matchDefault Unsolved $ pBin PNeq (/=) t1 t2
-- | Try to simplify or decide the numeric ordering @t1 >= t2@.
--
-- Rules are combined with '<|>' in a fixed order.  If none matches, the
-- result is 'Unsolved'.
cryIsGeq :: Ctxt -> Type -> Type -> Solved
cryIsGeq i t1 t2 = matchDefault Unsolved rules
  where
  rules =
        pBin PGeq (>=) t1 t2
    <|> (tryGeqKThan i t2 =<< aNat' t1)
    <|> (tryGeqThanK i t1 =<< aNat' t2)
    <|> (tryGeqThanVar i t1 =<< aTVar t2)
    <|> tryGeqThanSub i t1 t2
    <|> geqByInterval i t1 t2
    <|> (SolvedIf [] <$ guard (t1 == t2))
    <|> tryAddConst (>==) t1 t2
    <|> tryCancelVar i (>==) t1 t2
    <|> tryMinIsGeq t1 t2
    
    
  
  
-- | Decide a binary numeric predicate when it can be decided outright.
--
-- An error type on either side makes the goal 'Unsolvable'.  Two constant
-- sides are compared with @p@; on failure the goal is rebuilt with the
-- predicate constructor @tf@ for the error message.
pBin :: PC -> (Nat' -> Nat' -> Bool) -> Type -> Type -> Match Solved
pBin tf p t1 t2 = errorIn t1 <|> errorIn t2 <|> bothConstant
  where
  errorIn t = Unsolvable <$> anError KNum t

  bothConstant =
    do x <- aNat' t1
       y <- aNat' t2
       return $
         if p x y
           then SolvedIf []
           else Unsolvable $ TCErrorMessage $
                  "Unsolvable constraint: " ++
                  show (pp (TCon (PC tf) [ tNat' x, tNat' y ]))
-- | Handle @K >= t@ where @K@ is a constant.
--
-- * @inf >= t@ always holds.
-- * @K >= c * t@ (with @c@ a constant) reduces to a bound on @t@.
tryGeqKThan :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqKThan _ ty bound =
  case bound of
    Inf   -> return (SolvedIf [])
    Nat n ->
      do (coeff, rest) <- aMul ty
         m             <- aNat' coeff
         return (SolvedIf (goalsFor n m rest))
  where
  -- K >= m * rest
  goalsFor n m rest =
    case m of
      Inf   -> [ rest =#= tZero ]              -- only 0 keeps inf * t finite
      Nat 0 -> []                              -- 0 * t is trivially small
      Nat k -> [ tNum (div n k) >== rest ]     -- rest <= floor (K / k)
-- | Handle @t >= K@ where @K@ is a constant.
--
-- * @t >= inf@ forces @t = inf@.
-- * @K1 + t >= K2@ is trivially true when @K1 >= K2@; otherwise it becomes
--   @t >= K2 - K1@.
tryGeqThanK :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqThanK _ t bound =
  case bound of
    Inf   -> return (SolvedIf [ t =#= tInf ])
    Nat k ->
      do (a, b) <- anAdd t
         n      <- aNat a
         -- the list is empty exactly when the constant alone already suffices
         return (SolvedIf [ b >== tNum (k - n) | n < k ])
-- | @t1 >= t1 - t2@ holds trivially.
tryGeqThanSub :: Ctxt -> Type -> Type -> Match Solved
tryGeqThanSub _ x y =
  do (minuend, _subtrahend) <- (|-|) y
     if x == minuend
       then return (SolvedIf [])
       else mzero
-- | @t + x >= x@ (in either summand position) holds trivially.
tryGeqThanVar :: Ctxt -> Type -> TVar -> Match Solved
tryGeqThanVar _ctxt ty x =
  do (a, b) <- anAdd ty
     isTheVar a <|> isTheVar b
  where
  isTheVar t =
    do v <- aTVar t
       guard (x == v)
       return (SolvedIf [])
-- | Prove @x >= y@ from interval information alone: this succeeds when the
-- lower bound of @x@'s interval is at least the (finite) upper bound of
-- @y@'s interval.
geqByInterval :: Ctxt -> Type -> Type -> Match Solved
geqByInterval ctxt x y =
  case iUpper (typeInterval ctxt y) of
    Just hi | iLower (typeInterval ctxt x) >= hi -> return (SolvedIf [])
    _                                            -> mzero
-- | Handle @min K1 t >= K2@.
--
-- When @K1 >= K2@ this reduces to @t >= K2@.  Otherwise the minimum is at
-- most @K1 < K2@, so the goal is unsolvable.
tryMinIsGeq :: Type -> Type -> Match Solved
tryMinIsGeq t1 t2 =
  do (a, b) <- aMin t1
     k1     <- aNat a
     k2     <- aNat t2
     return (verdict k1 k2 b)
  where
  verdict k1 k2 b
    | k1 >= k2  = SolvedIf [ b >== t2 ]
    | otherwise = Unsolvable $ TCErrorMessage $
                    show k1 ++ " can't be greater than " ++ show k2
-- | Cancel a common factor from both sides of a relation:
-- @x * t1 `p` x * t2 ~~> t1 `p` t2@.
--
-- Only type variables whose interval (from 'tvarInterval') is positive and
-- finite are cancelled.  Each side is split into its multiplicative factors.
-- The factors are sorted so that cancellable variables come first, ordered by
-- 'TVar'.  The two sorted lists are then walked in step, as in a merge,
-- looking for a shared variable.
tryCancelVar :: Ctxt -> (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryCancelVar ctxt p t1 t2 =
  -- 'check' works in 'Maybe'; a 'Nothing' simply means "rule not applicable".
  maybe mzero return (check [] [] (preproc t1) (preproc t2))
  where
  -- Merge-walk over the sorted factor lists.  The @done@ lists collect
  -- factors already stepped over; they stay in the residual products.
  check doneLHS doneRHS lhs@((a,mbA) : moreLHS) rhs@((b, mbB) : moreRHS) =
    -- Cancellable factors are sorted first, so once the head of a side is
    -- not cancellable, nothing after it is either.
    do x <- mbA
       y <- mbB
       case compare x y of
         LT -> check (a : doneLHS) doneRHS moreLHS rhs
         EQ -> return $ SolvedIf [ p (term (doneLHS ++ map fst moreLHS))
                                     (term (doneRHS ++ map fst moreRHS)) ]
         GT -> check doneLHS (b : doneRHS) lhs moreRHS
  check _ _ _ _ = Nothing

  -- Rebuild a product; the empty product is 1.
  term xs = case xs of
              [] -> tNum (1::Int)
              _  -> foldr1 tMul xs

  -- Pair every factor with the variable it would cancel as (if any), sorted.
  preproc t = let fs = splitMul t []
              in sortBy cmpFact (zip fs (map cancelVar fs))

  splitMul t rest = case matchMaybe (aMul t) of
                      Just (a,b) -> splitMul a (splitMul b rest)
                      Nothing    -> t : rest

  -- A factor is cancellable if it is a variable known to be positive & finite.
  cancelVar t = matchMaybe $ do x <- aTVar t
                                guard (iIsPosFin (tvarInterval ctxt x))
                                return x

  -- cancellable variables go first, sorted by variable
  cmpFact (_,mbA) (_,mbB) =
    case (mbA,mbB) of
      (Just x, Just y)  -> compare x y
      (Just _, Nothing) -> LT
      (Nothing, Just _) -> GT
      _                 -> EQ
-- | @min a b == a@ holds exactly when @b >= a@ (and symmetrically for @b@).
tryEqMin :: Type -> Type -> Match Solved
tryEqMin x y =
  do (a, b) <- aMin x
     pick a b <|> pick b a
  where
  pick chosen other
    | chosen == y = return (SolvedIf [ other >== chosen ])
    | otherwise   = mzero
-- | Simplify @x == min (K + x) ... t@ by dropping every argument of the
-- (possibly nested) minimum that has the form @K + x@ with @K > 0@.
--
-- For finite @x@, @K + x > x@, so such an argument can never be the
-- minimum.  For @x = inf@, @K + x = inf@ too, and the remaining arguments
-- must also be @inf@.  If every argument is dropped, the result is
-- @x == inf@.
--
-- The rule only matches when at least one argument is actually dropped.
-- Previously a non-applicable case returned 'Unsolved' as a *successful*
-- match.  That stopped the caller's '<|>' chain, so later rules, including
-- the mirrored @tryEqMins t2 t1@, were never tried.
tryEqMins :: Type -> Type -> Match Solved
tryEqMins x y =
  do (a, b) <- aMin y
     let ys  = splitMin a ++ splitMin b
     guard (any isGt ys)
     let ys' = filter (not . isGt) ys
         y'  = if null ys' then tInf else foldr1 Simp.tMin ys'
     return (SolvedIf [x =#= y'])
  where
    -- Flatten nested minimums into the list of their arguments.
    splitMin :: Type -> [Type]
    splitMin ty =
      case matchMaybe (aMin ty) of
        Just (t1, t2) -> splitMin t1 ++ splitMin t2
        Nothing       -> [ty]
    -- Is the argument of the form @K + x@ with a positive constant @K@?
    isGt :: Type -> Bool
    isGt t =
      case matchMaybe (asAddK t) of
        Just (k, t') -> k > 0 && t' == x
        Nothing      -> False
    asAddK :: Type -> Match (Integer, Type)
    asAddK t =
      do (t1, t2) <- anAdd t
         k <- aNat t1
         return (k, t2)
-- | Solve equalities of the form @x == ty@ where @ty@ mentions the variable
-- @x@ in one of a few recognised shapes (always with a constant @K >= 1@).
tryEqVar :: Type -> TVar -> Match Solved
tryEqVar ty x = selfPlusK <|> minWithSelfPlusK <|> kPlusMinWithSelf
  where
  -- x = K + x  ~~~>  x = inf
  selfPlusK =
    do (k, tv) <- matches ty (anAdd, aNat, aTVar)
       guard (tv == x && k >= 1)
       return (SolvedIf [ TVar x =#= tInf ])

  -- x = min (K + x) t  ~~~>  x = t
  minWithSelfPlusK =
    do (l, r) <- aMin ty
       plusSelf l r <|> plusSelf r l

  plusSelf this other =
    do (k, x') <- matches this (anAdd, aNat', aTVar)
       guard (x == x' && k >= Nat 1)
       return (SolvedIf [ TVar x =#= other ])

  -- x = K + min x t  ~~~>  x = K + t
  kPlusMinWithSelf =
    do (k, (l, r)) <- matches ty (anAdd, aNat, aMin)
       guard (k >= 1)
       isSelf k l r <|> isSelf k r l

  isSelf k a b =
    do x' <- aTVar a
       guard (x' == x)
       return (SolvedIf [ TVar x =#= tAdd (tNum k) b ])
-- | Solve @K == t@, where @K@ (here @lk@) is a known constant, possibly
-- @inf@, and @t@ is a sum, difference, product or power with a constant
-- operand.
tryEqK :: Ctxt -> Type -> Nat' -> Match Solved
tryEqK ctxt ty lk =

  -- (t1 + t2 = inf, fin t1) ~~~> t2 = inf
  do guard (lk == Inf)
     (a,b) <- anAdd ty
     let check x y = do guard (iIsFin (typeInterval ctxt x))
                        return $ SolvedIf [ y =#= tInf ]
     check a b <|> check b a
  <|>

  -- (K1 + t = K2) ~~~> (t = K2 - K1)
  do (rk, b) <- matches ty (anAdd, aNat', __)
     return $
       case (lk, rk) of
         -- inf + t = inf holds for every t.  This is handled explicitly
         -- because 'nSub' is presumably undefined (Nothing) for inf - inf,
         -- which would otherwise reject a valid constraint.
         (Inf, Inf) -> SolvedIf []
         _ ->
           case nSub lk rk of
             -- presumably K1 > K2: no value of t can close the gap
             Nothing -> Unsolvable
                          $ TCErrorMessage
                          $ "Adding " ++ showNat' rk ++ " will always exceed "
                                      ++ showNat' lk
             Just r -> SolvedIf [ b =#= tNat' r ]
  <|>

  -- (t - K1 = K2) ~~~> (t = K2 + K1)
  do (t,rk) <- matches ty ((|-|) , __, aNat')
     return (SolvedIf [ t =#= tNat' (nAdd lk rk) ])
  <|>

  -- (K1 * t = K2)
  do (rk, b) <- matches ty (aMul, aNat', __)
     return $
       case (lk,rk) of
         -- inf * t = inf ~~~> t >= 1
         (Inf,Inf)    -> SolvedIf [ b >== tOne ]

         -- 0 * t = inf ~~~> ERR   (0 * t is 0 for every t, including inf)
         (Inf,Nat 0)  -> Unsolvable
                       $ TCErrorMessage
                       $ showNat' lk ++ " != " ++ showNat' rk ++ " * anything"

         -- K * t = inf ~~~> t = inf   (K /= 0)
         (Inf,Nat _)  -> SolvedIf [ b =#= tInf ]

         -- inf * t = 0 ~~~> t = 0
         (Nat 0, Inf) -> SolvedIf [ b =#= tZero ]

         -- inf * t = K ~~~> ERR   (K /= 0)
         (Nat k, Inf) -> Unsolvable
                       $ TCErrorMessage
                       $ show k ++ " != inf * anything"

         (Nat lk', Nat rk')
           -- 0 * t = K2 ~~~> K2 = 0
           | rk' == 0 -> SolvedIf [ tNat' lk =#= tZero ]

           -- K1 * t = K2 ~~~> t = K2 / K1, when the division is exact
           | (q,0) <- divMod lk' rk' -> SolvedIf [ b =#= tNum q ]
           | otherwise ->
               Unsolvable
             $ TCErrorMessage
             $ showNat' lk ++ " != " ++ showNat' rk ++ " * anything"
  <|>

  -- K1 == K2 ^^ t ~~~> t = logBase K2 K1
  -- NOTE(review): bases 0 and 1 depend entirely on how 'genLog' treats
  -- them (e.g. 1 ^^ t = 1 for every t) — confirm they are either handled
  -- there or simplified away before reaching this rule.
  do (rk, b) <- matches ty ((|^|), aNat, __)
     return $ case lk of
                Inf | rk > 1 -> SolvedIf [ b =#= tInf ]
                Nat n | Just (a,True) <- genLog n rk -> SolvedIf [ b =#= tNum a]
                _ -> Unsolvable $ TCErrorMessage
                       $ show rk ++ " ^^ anything != " ++ showNat' lk
  
  
  
  
-- | Divide both sides of a linear equality by the GCD of all its
-- coefficients and constants, when that GCD exceeds 1.
--
-- For example, @2 * a + 4 == 6 * b@ becomes @1 * a + 2 == 3 * b@.
tryEqMulConst :: Type -> Type -> Match Solved
tryEqMulConst l r =
  do (lc, ls) <- matchLinear l
     (rc, rs) <- matchLinear r
     -- gcd with 0 is the identity, so folding from 0 matches 'foldr1 gcd'
     let d = foldr gcd 0 (lc : rc : map fst ls ++ map fst rs)
     guard (d > 1)
     return (SolvedIf [ rebuild d lc ls =#= rebuild d rc rs ])
  where
  shrink d k = tNum (k `div` d)

  -- Reassemble @k + c1 * t1 + ...@ with every constant divided by @d@.
  rebuild d k terms =
    foldr (\(c, t) acc -> tAdd (tMul (shrink d c) t) acc) (shrink d k) terms
-- | Handle @t1 + t2 == inf@ (in either orientation).
--
-- If one summand is known to be finite, the other one must be @inf@.  If
-- neither is known finite, the result is 'Unsolved'.
--
-- NOTE(review): returning 'Unsolved' (instead of failing the match)
-- presumably prevents the caller's later rules from firing on this goal.
-- This looks deliberate, so that a linear solution of the form @inf - t@ is
-- not produced — confirm.
tryEqAddInf :: Ctxt -> Type -> Type -> Match Solved
tryEqAddInf ctxt l r = sumIsInf l r <|> sumIsInf r l
  where
  sumIsInf s y =
    do (x1, x2) <- anAdd s
       aInf y
       return $! decide x1 x2 y

  decide x1 x2 y
    | isFinite x1 = SolvedIf [ x2 =#= y ]
    | isFinite x2 = SolvedIf [ x1 =#= y ]
    | otherwise   = Unsolved

  isFinite t = iIsFin (typeInterval ctxt t)
-- | Cancel constants from both sides of a relation:
-- @K1 + t1 `rel` K2 + t2@ keeps only the difference of the constants, on
-- the side that had the larger one.
tryAddConst :: (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryAddConst rel l r =
  do (x1, x2) <- anAdd l
     (y1, y2) <- anAdd r
     k1 <- aNat x1
     k2 <- aNat y1
     let goal | k1 > k2   = tAdd (tNum (k1 - k2)) x2 `rel` y2
              | otherwise = x2 `rel` tAdd (tNum (k2 - k1)) y2
     return (SolvedIf [ goal ])
-- | Solve @s1 == t@, where @t@ is a sum containing exactly one free
-- (unification) variable @a@ and @s1@ has no free variables.  The result is
-- @a = s1 - rest@ together with the side condition @s1 >= rest@.
tryLinearSolution :: Type -> Type -> Match Solved
tryLinearSolution s1 t =
  do (v, summands) <- matchLinearUnifier t
     if noFreeVariables s1
       -- 'matchLinearUnifier' only succeeds with a non-empty list, so
       -- 'foldr1' is safe here.
       then let rest = foldr1 Simp.tAdd summands
            in  return (SolvedIf [ TVar v =#= (Simp.tSub s1 rest)
                                 , s1 >== rest ])
       else mzero
-- | Match a sum @t1 + ... + a + ... + tn@ containing exactly one free
-- variable @a@.  Returns @a@ and the other summands, which all have no free
-- variables.  The list of other summands is never empty.
matchLinearUnifier :: Pat Type (TVar,[Type])
matchLinearUnifier = walk []
  where
  walk acc ty = varAtEnd <|> viaAdd
    where
    -- The free variable is the last summand (and there were others before it).
    varAtEnd =
      do v <- aFreeTVar ty
         guard (not (null acc))
         return (v, acc)

    -- Otherwise look inside an addition.
    viaAdd =
      do (lhs, rhs) <- anAdd ty
         varInMiddle lhs rhs <|> skipClosed lhs rhs

    -- The free variable sits on the left and everything to its right is closed.
    varInMiddle lhs rhs =
      do v <- aFreeTVar lhs
         guard (noFreeVariables rhs)
         return (v, reverse (rhs : acc))

    -- A closed left summand: remember it and keep walking to the right.
    skipClosed lhs rhs =
      do guard (noFreeVariables lhs)
         walk (lhs : acc) rhs
-- | Match a sum of constants and @K * t@ terms.  The result is the total of
-- all constants plus the list of @(K, t)@ pairs.
matchLinear :: Pat Type (Integer, [(Integer,Type)])
matchLinear = collect (0, [])
  where
  collect acc@(c, ts) t = constant <|> scaled <|> sumOf
    where
    constant =
      do n <- aNat t
         return (n + c, ts)

    scaled =
      do (x, y) <- aMul t
         n      <- aNat x
         return (c, (n, y) : ts)

    -- Thread the accumulator through the left operand, then the right.
    sumOf =
      do (l, r) <- anAdd t
         acc'   <- collect acc l
         collect acc' r
-- | Render a possibly-infinite natural for error messages.
showNat' :: Nat' -> String
showNat' n =
  case n of
    Nat k -> show k
    Inf   -> "inf"