{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}

module Language.Prolog.NanoProlog.Lib (
     LowerCase
  ,  Result(..)
  ,  Rule((:<-:))
  ,  Subst(..)
  ,  Taggable(..)
  ,  Term(..)
  ,  emptyEnv
  ,  enumerateDepthFirst
  ,  pFun
  ,  pRule
  ,  pTerm
  ,  pTerms
  ,  show'
  ,  solve
  ,  startParse
  ,  unify
  ) where

import            Data.ListLike.Base (ListLike)
import            Data.List (intercalate)
import            Data.Map (Map)
import qualified  Data.Map as M
import            Text.ParserCombinators.UU
import            Text.ParserCombinators.UU.BasicInstances
import            Text.ParserCombinators.UU.Utils

-- * Types
-- | Name of a variable.  Variables read by 'pVar' consist of upper-case
-- letters only; 'tag' appends a number when a rule is renamed apart.
type UpperCase  = String
-- | Name of a functor or constant.  'pFun' reads a lower-case letter
-- followed by letters and digits.
type LowerCase  = String

-- | A Prolog term: either a variable, or a functor applied to a
-- (possibly empty) list of argument terms.  A constant is a 'Fun' with
-- no arguments.
data Term  =  Var UpperCase
           |  Fun LowerCase [Term]
           deriving (Eq, Ord)

-- | A clause: a head term and a body of goals (@head :- body@).  A fact
-- has an empty body.
data Rule  =  Term :<-: [Term]
           deriving Eq

-- | Things whose variables can be renamed apart by appending a number to
-- every variable name.  'solve' uses this to give each rule application
-- fresh variables.
class Taggable a where
  tag :: Int -> a -> a

-- | Only variables are renamed; functor names are left untouched.
instance Taggable Term where
  tag n term = case term of
    Var name       -> Var (name ++ show n)
    Fun name args  -> Fun name (map (tag n) args)

-- | Renames the head and every goal of the body with the same number.
instance Taggable Rule where
  tag n (hd :<-: body) = (:<-:) (tag n hd) (map (tag n) body)

-- | Lists are tagged element-wise.
instance Taggable a => Taggable [a] where
  tag n xs = [ tag n x | x <- xs ]

-- | A substitution: maps variable names to the terms they are bound to.
type Env = Map UpperCase Term

-- | The initial environment: a success with no bindings yet.  'Nothing'
-- is used by 'unify' and 'solve' to represent failure.
emptyEnv :: Maybe (Map UpperCase t)
emptyEnv = Just (M.fromList [])

-- * The Prolog machinery
-- | The search tree built by 'solve'.  It is either a successful leaf
-- carrying the final environment, or the list of (renamed) rules whose
-- heads unified with the current goal.  Each rule is paired with the
-- subtree that continues the search.  @ApplyRules []@ is a dead end.
data Result  =  Done Env
             |  ApplyRules [(Rule, Result)]

-- | A proof as accumulated by 'enumerateDepthFirst'.  Each element pairs
-- the label of the goal that was resolved with the rule applied to it.
type Proofs = [(String, Rule)]

-- | Types whose variables can be instantiated from an environment.
class Subst t where
  subst :: Env -> t -> t

-- | Lists are substituted element-wise.
instance Subst a => Subst [a] where
  subst env xs = [ subst env x | x <- xs ]

-- | A bound variable is replaced by its binding, which is itself
-- substituted, so chains of bindings are followed.  Unbound variables
-- are left as they are.
instance Subst Term where
  subst env term = case term of
    Fun name args  -> Fun name (subst env args)
    Var name       -> case M.lookup name env of
                        Just bound  -> subst env bound
                        Nothing     -> term

-- | Substitutes in both the head and the body of a rule.
instance Subst Rule where
  subst env (hd :<-: body) = (:<-:) (subst env hd) (map (subst env) body)

-- | Unify two terms, extending the given environment.
--
-- Both terms are first instantiated with the current bindings.  A
-- 'Nothing' environment stands for an earlier failure and is passed
-- through unchanged.  'Nothing' is also returned when the terms clash
-- (different functor names or different numbers of arguments).
--
-- There is no occurs check, as in standard Prolog.  Unifying a variable
-- with itself, however, leaves the environment unchanged.  Binding @X@
-- to @X@ would make 'subst' loop forever the next time @X@ is
-- instantiated (e.g. goal @eq(X,X)@ against rule @eq(Y,Y).@).
unify :: (Term, Term) -> Maybe Env -> Maybe Env
unify _       Nothing       = Nothing
unify (t, u)  env@(Just m)  = uni (subst m t) (subst m u)
  where  -- same (unbound) variable on both sides: nothing to bind
         uni  (Var x)  (Var y)
           |  x == y                            = env
         uni  (Var x)  y        = Just (M.insert x  y  m)
         uni  x        (Var y)  = Just (M.insert y  x  m)
         uni  (Fun x xs) (Fun y ys)
           |  x == y && length xs == length ys  = foldr unify env (zip xs ys)
           |  otherwise                         = Nothing

-- | Build the (lazy) search tree for a list of goals.
--
-- A failed environment yields a dead end, and an empty goal list is a
-- success.  Otherwise the first goal is tried against every rule,
-- renamed apart with the current depth.  Each rule whose head unifies
-- becomes a branch in which its body replaces the goal, one level
-- deeper.
solve :: [Rule] -> Maybe Env -> Int -> [Term] -> Result
solve rules menv depth goals =
  case (menv, goals) of
    (Nothing, _)      -> ApplyRules []
    (Just e, [])      -> Done e
    (_, goal : rest)  -> ApplyRules (concatMap (expand goal rest) (tag depth rules))
  where
    -- zero or one branch per rule, depending on whether its head unifies
    expand goal rest rule@(hd :<-: body) =
      case unify (goal, hd) menv of
        Nothing  -> []
        next     -> [(rule, solve rules next (depth + 1) (body ++ rest))]

-- ** Printing the solutions

-- | `enumerateDepthFirst` performs a depth-first walk over the `Result`
-- tree.  Along the way it accumulates the rules applied on the path from
-- the root to the current node, so at a successful leaf this contains
-- the full proof (most recent step first).
--
-- The second argument supplies one label per pending goal.  Each applied
-- rule consumes the label of the goal it resolves and pushes one child
-- label (@label.1@, @label.2@, ...) per goal in its body.  Dead branches
-- (@ApplyRules []@) produce no solutions.  Running out of labels while
-- goals remain means the caller passed too few labels; that is reported
-- with an explicit error.
enumerateDepthFirst :: Proofs -> [String] -> Result -> [(Proofs, Env)]
enumerateDepthFirst proofs _ (Done env) = [(proofs, env)]
enumerateDepthFirst _ _ (ApplyRules []) = []
enumerateDepthFirst proofs (pr:prefixes) (ApplyRules bs) =
  [ s  |  (rule@(_ :<-: cs), subTree) <- bs
       ,  let extraPrefixes = zipWith (\i _ -> pr ++ "." ++ show i) [1 :: Int ..] cs
       ,  s <- enumerateDepthFirst ((pr, rule):proofs) (extraPrefixes ++ prefixes) subTree
  ]
enumerateDepthFirst _ [] (ApplyRules _) =
  error "enumerateDepthFirst: ran out of goal labels; pass one label per goal"

{-
-- | `enumerateBreadthFirst` is still undefined, and is left as an
-- exercise to the JCU students
enumerateBreadthFirst :: Proofs -> [String] -> Result -> [(Proofs, Env)]
-}

-- | `show'` renders a single solution, showing only the variables that
-- were introduced in the original goal.  These are names that start with
-- an upper-case letter and do not end in a digit, since 'tag' appends a
-- number to every variable of a renamed rule.  Bindings are separated by
-- commas.  A goal variable that resolves to a still-unbound variable is
-- shown as that variable's name.
show' :: Env -> String
show' env = intercalate ", " [ x ++ " <- " ++ showTerm t
                             | (x, t) <- M.assocs env, isGlobVar x ]
  where  -- 'subst' resolves a variable completely; if the result is still a
         -- variable it is unbound, so print its name instead of recursing
         -- (recursing on it would never terminate)
         showTerm (Var x)    = case subst env (Var x) of
                                 Var y  -> y
                                 t      -> showTerm t
         showTerm (Fun f []) = f
         showTerm (Fun f ts) = f ++ "(" ++ intercalate ", " (map showTerm ts) ++ ")"
         isGlobVar x@(c:_)   = c `elem` ['A'..'Z'] && last x `notElem` ['0'..'9']
         isGlobVar []        = False

-- | Terms print in Prolog syntax: variables and constants by name, and
-- compound terms as @f(a, b)@.
instance Show Term where
  show term = case term of
    Var name -> name
    Fun name args
      | null args  -> name
      | otherwise  -> name ++ "(" ++ showCommas args ++ ")"

-- | Rules print as @head.@ (facts) or @head:-goal, goal.@.
instance Show Rule where
  show (hd :<-: body)
    | null body  = show hd ++ "."
    | otherwise  = concat [show hd, ":-", showCommas body, "."]

-- | Show every element and join the results with @", "@.
showCommas :: Show a => [a] -> String
showCommas = intercalate ", " . map show

-- ** Parsing Rules and Terms
-- | Run a parser over the complete input, starting at position (0, 0, 0).
-- Returns the parse result together with the errors that were corrected
-- along the way (collected by 'pEnd').
startParse :: (ListLike s b, Show b)  => P (Str b s LineColPos) a -> s
                                      -> (a, [Error LineColPos])
startParse p = parse ((,) <$> p <*> pEnd) . createStr (LineColPos 0 0 0)

-- | A term is a variable or a (possibly nullary) functor application.
-- Variables are one or more upper-case letters.  A functor name is a
-- lower-case letter followed by letters and digits, optionally followed
-- by a parenthesised, comma-separated argument list.
pTerm, pVar, pFun :: Parser Term
pTerm  = pVar  <|>  pFun
pVar   = fmap Var (lexeme (pList1 pUpper))
pFun   = Fun <$> lowerIdent <*> (pParens pTerms `opt` [])
  where  lowerIdent :: Parser String
         lowerIdent = (:) <$> pLower <*> lexeme (pList identChar)
         identChar :: Parser Char
         identChar = pLetter <|> pDigit

-- | A rule is a head, an optional @:-@ followed by body goals, and a
-- terminating dot.
pRule :: Parser Rule
pRule = (:<-:) <$> pFun <*> body <* pDot
  where  body :: Parser [Term]
         body = pSymbol ":-" *> pTerms `opt` []

-- | Zero or more terms separated by commas.
pTerms :: Parser [Term]
pTerms = pListSep separator pTerm
  where  separator = pComma