{-# LANGUAGE LambdaCase #-}
module Data.Syntax.Node (
  module Data.Syntax,

  Node(..),ThunkN,Env,nil,shape,dict,

  funcall,builtin,builtin2,builtin3,lambda,lambdaSum,

  reduce
  ) where

import Definitive
import Data.Syntax
import Language.Syntax.Regex

-- | Syntax-tree node functor. @k@ is the key/text type, @b@ the payload held
-- by 'Function' nodes, and @a@ the type of child positions (what the
-- 'Functor' / 'Traversable' instances below range over).
data Node k b a = ValList [a]           -- ^ List of children; 'reduce' treats a non-empty one as a call @[f, x1, ..., xn]@.
                | Dictionary (Map k a)  -- ^ Keyed children; 'reduce' evaluates entries with the dictionary itself in scope.
                | Quote (Node k b a)    -- ^ Quoted node; 'reduce' strips one level of quoting without evaluating.
                | Text k                -- ^ Literal text; as a pattern it is used as a regex (see 'matchPat').
                | Function b            -- ^ Function value; 'builtin' stores a 'Lambda' here.
-- | Variable environment: keys mapped to syntax thunks. 'reduce' and 'lambda'
-- read it through 'MonadReader' and extend it with 'local'.
type Env k m = Map k (ThunkT (Node k) m ())

-- | Maps over child positions only. 'Text' and 'Function' carry no children
-- and are rebuilt unchanged (re-constructed so the child type can change).
instance Functor (Node k b) where
  map f = \case
    ValList xs -> ValList (map f xs)
    Dictionary m -> Dictionary (map f m)
    Quote q -> Quote (map f q)
    Text t -> Text t
    Function g -> Function g
-- | Folds the children of list, dictionary and quoted nodes. Leaf nodes
-- ('Text', 'Function') contribute the empty value 'zero'.
instance Foldable (Node k b) where
  fold = \case
    ValList xs -> fold xs
    Dictionary m -> fold m
    Quote q -> fold q
    Text _ -> zero
    Function _ -> zero
-- | Sequences the effects held in child positions, preserving node shape.
-- Leaf nodes are simply lifted with 'pure'.
instance Eq k => Traversable (Node k b) where
  sequence = \case
    ValList xs -> ValList <$> sequence xs
    Dictionary m -> Dictionary <$> sequence m
    Quote q -> Quote <$> sequence q
    Text t -> pure (Text t)
    Function g -> pure (Function g)
-- | No methods are defined here, so the class's default methods are used.
-- NOTE(review): 'NodeFunctor' is not defined in this module (presumably it
-- comes from "Data.Syntax"); its requirements are not visible here.
instance Eq k => NodeFunctor (Node k) m
-- | Debug rendering of a node.
--
-- * Lists use the list 'Show'.
-- * Dictionaries appear as @{...}@ around the shown list of (key, value) pairs.
-- * Text is shown via @k@'s 'Show'.
-- * Quotes get a leading @'@.
-- * Functions defer to @b@'s 'Show'.
--
-- NOTE(review): dictionary values are 'show'n before the pair list is shown,
-- so their renderings appear as quoted, escaped strings.
instance (Show k,Show b,Show a) => Show (Node k b a) where
  show node = case node of
    ValList items -> show items
    Dictionary entries -> "{" + show (toList (map show entries ^. keyed)) + "}"
    Text txt -> show txt
    Quote inner -> "'" + show inner
    Function fn -> show fn

-- | Shorthand for a thunk of 'Node' syntax with a unit result. It expands to
-- the same type that 'Env' stores for each key.
type ThunkN k m = ThunkT (Node k) m ()

-- | The empty-list node as a syntax value. 'lambda' and 'lambdaSum' use it as
-- their "no match" result.
nil :: SyntaxT (Node k) m a
nil = SyntaxT $ Join $ ValList zero
-- | Build the two-element call node @[fn, arg]@. 'reduce' evaluates such a
-- list by applying its head to the remaining elements.
funcall :: (Eq k,Unit m) => ThunkT (Node k) m a -> ThunkT (Node k) m a -> ThunkT (Node k) m a
funcall fn arg = liftNS $ ValList [fn, arg]
-- | Wrap a Haskell function on thunks as a 'Function' node. Its 'Lambda'
-- carries 'Nothing' in the first component.
builtin :: (Eq k,Unit m) => (ThunkN k m -> ThunkN k m) -> ThunkN k m
builtin impl = liftNS $ Function $ Lambda (Nothing, impl)
-- | Curried two-argument builtin. Applying it to the first argument yields
-- another builtin that waits for the second.
builtin2 :: (Eq k,Unit m) => (ThunkN k m -> ThunkN k m -> ThunkN k m) -> ThunkN k m
builtin2 f = builtin (builtin . f)
-- | Curried three-argument builtin. Applying it to the first argument yields
-- a 'builtin2' for the remaining two.
builtin3 :: (Eq k,Unit m) => (ThunkN k m -> ThunkN k m -> ThunkN k m -> ThunkN k m) -> ThunkN k m
builtin3 f = builtin (builtin2 . f)

-- | Prism-style traversal focusing the map inside a 'Dictionary' node.
-- Other nodes have no focus, and setting through them leaves them unchanged.
dict :: Traversal' (Node k b a) (Map k a)
dict = prism project rebuild
  where project = \case
          Dictionary m -> Right m
          other -> Left other
        rebuild node m = case node of
          Dictionary _ -> Dictionary m
          _ -> node
-- | Name of a node's constructor, for diagnostics. The empty 'ValList' is
-- reported as @"Nil"@.
shape :: Node k b a -> String
shape = \case
  ValList [] -> "Nil"
  ValList _ -> "ValList"
  Text _ -> "Text"
  Dictionary _ -> "Dictionary"
  Quote _ -> "Quote"
  Function _ -> "Function"

-- | Evaluate a thunk one step, according to the node it forces to:
--
-- * A non-empty 'ValList' @[f, x1, ..., xn]@ is a call. @f@ is forced and
--   applied to @x1@, and each result is forced and applied to the next
--   argument, left to right. The argument thunks are passed unevaluated.
-- * A 'Dictionary' is rebuilt with every entry reduced in an environment
--   extended (via 'local') with the reduced dictionary itself, defined
--   recursively with 'fix', so entries can refer to one another.
-- * A 'Quote' is unwrapped by one level without reducing its contents.
-- * Anything else, including the empty list, is returned unchanged.
--
-- Applying a node that is not a 'Function' is a runtime 'error'. The message
-- names the shape of the offending node.
reduce :: (Ord k,MonadReader (Env k m) m) => ThunkN k m -> ThunkN k m
reduce th = force th >>= \v -> case v of
  ValList (fun:args) -> foldl' (\f a -> force f >>= call a) fun args
  -- Tie the knot: every entry sees the whole reduced dictionary in its env.
  Dictionary d -> liftNS $ Dictionary $ fix (\d' -> map (local (d'+) . reduce) d)
  Quote n -> liftNS n
  a -> liftNS a
  where call x f = case f of
          Function (Lambda (_,f')) -> f' x
          -- Name what was found instead of a function so the error is actionable.
          other -> error ("Invalid function call: expected a Function, got "+shape other)

-- | Key types that can serve as regex patterns for 'Text' patterns in
-- 'matchPat'. @matchRe re s@ returns 'Nothing' when @s@ does not match @re@.
-- Otherwise it returns the captured pairs, which 'lambda' binds with the first
-- component as the variable name and the second as its text value.
class (Ord k,Monoid k) => Matching k where
  matchRe :: k -> k -> Maybe [(k,k)]
-- | Regex matching via 'runRegex', keeping the captures of the first result.
-- The partially applied @runRegex re@ is bound once per pattern so it is
-- shared across subjects.
-- NOTE(review): the second component of each 'runRegex' result is ignored.
-- Whether the whole subject must be consumed is not visible here; confirm.
instance Matching String where
  matchRe re = firstBindings . compiled
    where compiled = runRegex re
          firstBindings results = case results of
            (bindings,_):_ -> Just bindings
            _ -> Nothing

-- | Build a function value from a pattern @pat@ and a body @e@. When applied
-- to an argument, the argument is matched against the pattern (see
-- 'matchPat'). On success, @e@ is reduced with the captured (name, thunk)
-- bindings inserted into the environment. On failure, the result is 'nil'.
lambda :: (Matching k,MonadReader (Env k m) m) => ThunkN k m -> ThunkN k m -> ThunkN k m
lambda pat e = liftF (emerge (perform pat)) >>= tryAlt
  -- NOTE(review): 'liftF'/'emerge'/'perform' presumably evaluate the pattern
  -- thunk into the 'SyntaxT' value that 'matchPat' inspects; confirm.
  where tryAlt p = builtin b
          where b x = match x >>= maybe (liftS nil) bind
                -- Every binding is inserted into the reader environment for the body.
                bind vars = local (compose (_insert<$>c'list vars)) (reduce e)
                  where _insert (s,v) = insert s v
                match = matchPat p

-- | Match a value thunk against a pattern, yielding the captured bindings.
--
-- * 'Dictionary' pattern: the value must be a dictionary with exactly the
--   same key set. Each entry is matched recursively, and on success the
--   entry's key is additionally bound to the corresponding value thunk.
-- * 'ValList' pattern: the value must be a list of the same length, and the
--   elements are matched pairwise.
-- * 'Text' pattern: used as a regex via 'matchRe'. Each capture becomes a
--   binding to a 'Text' node.
-- * Any other pattern, or any other syntax shape, never matches and does not
--   force the value.
--
-- The result is 'Nothing' if any sub-match fails.
matchPat :: (Monad m,Matching k) => SyntaxT (Node k) m () -> ThunkN k m -> ThunkT (Node k) m (Maybe [(k,ThunkN k m)])
matchPat (SyntaxT (Join j)) = case j of
  Dictionary pats -> withForced $ \case
    Dictionary vals | keysSet pats == keysSet vals -> do
      let bindEntry (key,sub) val = map2 ((key,val):) (matchPat (SyntaxT sub) val)
      map join . sequence <$> sequence (toList (zipWith bindEntry (pats^.keyed) vals))
    _ -> pure Nothing
  ValList pats -> withForced $ \case
    ValList vals | length vals == length pats ->
      map join . sequence <$> sequence (zipWith (matchPat . SyntaxT) pats vals)
    _ -> pure Nothing
  Text re -> withForced $ \case
    Text txt | Just caps <- matchRe re txt ->
      pure (Just . map (second (liftNS . Text)) $ caps)
    _ -> pure Nothing
  _ -> \_ -> pure Nothing
  where withForced k x = force x >>= k
matchPat _ = \_ -> pure Nothing
      
-- | Try each alternative on the argument in order. The first result that does
-- not force to the empty list is returned. If every alternative yields the
-- empty list (or there are none), the result is 'nil'.
lambdaSum :: (Eq k,Monad m) => [ThunkN k m -> ThunkN k m] -> ThunkN k m -> ThunkN k m
lambdaSum [] = const (liftS nil)
lambdaSum (alt:rest) = \v -> force (alt v) >>= \case
  ValList [] -> lambdaSum rest v
  res -> liftNS res