{-# LANGUAGE UndecidableInstances #-}
module Algebra.Monad.Logic where

import Algebra.Monad.Base

-- | A backtracking computation in fold-style (Church-encoded) form.
--
-- A value is run by supplying two continuations:
--
-- * a success continuation, which receives each result together with the
--   base-monad computation of all remaining results;
-- * a failure continuation, which is what the computation evaluates to when
--   no further results remain.
newtype LogicT m a = LogicT
  { runLogicT :: forall r. (a -> m r -> m r) -> m r -> m r
  }

-- | Maps over every result by pre-composing the success continuation with the
-- function; the failure continuation is untouched.
instance Functor (LogicT m) where
  map f l = LogicT (\k -> runLogicT l (k . f))
-- | A single result: the success continuation is called once with the value,
-- and the failure continuation is passed through unchanged.
instance Unit (LogicT m) where
  pure x = LogicT (\succeed -> succeed x)
-- | No methods are given here, so every Applicative operation comes from the
-- class's default definitions in "Algebra.Monad.Base" (presumably derived
-- from the 'Unit' and 'Monad' instances in this module — confirm).
instance Applicative (LogicT m)
-- | Flattens a computation of computations: each inner computation produced
-- by the outer one is run with the same success continuation, so its results
-- are spliced into the outer result stream.
instance Monad (LogicT m) where
  join outer = LogicT (\k -> runLogicT outer (\inner -> runLogicT inner k))
-- | NOTE(review): as written this ties the knot @x = x >>= f@. Running the
-- fixed point immediately runs the same fixed point again with a wrapped
-- success continuation. That recursion never reaches the base monad and never
-- invokes the continuation that would call @f@, so the instance appears to
-- diverge whenever it is run. Confirm, then fix or remove this instance.
instance MonadFix (LogicT m) where
  mfix f = fix (\(LogicT l) -> LogicT (\k -> l (\a m -> runLogicT (f a) k m)))
-- | Lifts a base-monad action into a computation with exactly one result.
-- The action runs first, then its value goes to the success continuation
-- along with the unchanged failure continuation.
instance MonadTrans LogicT where
  lift action = LogicT (\onResult onFail -> action >>= \x -> onResult x onFail)

-- | Choice: all results of the left computation, followed by all results of
-- the right one. The right computation, run with the same success
-- continuation, serves as the left one's failure continuation.
instance Semigroup (LogicT m a) where
  left + right = LogicT (\k z -> runLogicT left k (runLogicT right k z))
-- | The computation with no results: it ignores the success continuation and
-- evaluates to the failure continuation as-is.
instance Monoid (LogicT m a) where
  zero = LogicT (\_ onFail -> onFail)
-- | Multiplication delegates to 'plusA' from "Algebra.Monad.Base".
-- Presumably that lifts the element type's '+' through the Applicative
-- instance, combining every pair of results — confirm against its definition.
instance Semigroup a => Semiring (LogicT m a) where
  x * y = plusA x y
-- | The multiplicative unit is 'zeroA' from "Algebra.Monad.Base". Presumably
-- this is a single result holding the element type's 'zero', which would be
-- the identity for the 'plusA'-based '*' of the Semiring instance above —
-- confirm against its definition.
instance Monoid a => Ring (LogicT m a) where
  one = zeroA

-- | State operations are taken from the base monad and lifted with 'lift'.
instance MonadState s m => MonadState s (LogicT m) where
  get = lift get
  modify = lift . modify

-- | Backtracking computations @l@ that can be taken apart into, and rebuilt
-- from, a single step in an underlying monad @m@. The functional dependency
-- @l -> m@ makes the computation type determine its base monad.
class Monad m => MonadLogic l m | l -> m where
  -- | Observe the first result paired with the computation of the remaining
  -- results, or no result when there is none. The 'LogicT' instance in this
  -- module returns @pure zero@ in that case.
  deduce :: l a -> m (Maybe (a,l a))
  -- | Rebuild a computation from such a step. Presumably this is the inverse
  -- of 'deduce' — confirm for each instance.
  induce :: m (Maybe (a,l a)) -> l a
-- | 'deduce' runs the computation with a success continuation that returns at
-- the first result, without running the remainder; the remainder is wrapped
-- back up with 'induce'.
-- 'induce' runs the step. If there is no result it falls through to the
-- failure continuation; otherwise it emits the result and then continues
-- with the rest.
instance Monad m => MonadLogic (LogicT m) m where
  deduce l = runLogicT l firstResult (pure zero)
    where firstResult a rest = pure (pure (a, induce rest))
  induce step = LogicT (\k z ->
    step >>= maybe z (\(a, more) -> k a (runLogicT more k z)))

-- | Iso between a logic computation and a base-monad action returning the
-- list of its results.
--
-- Viewing an action through it (as in @pure xs ^. listLogic@) builds the
-- computation one element at a time with 'induce'. The other direction
-- collects every result with 'deduceAll'.
listLogic :: (MonadLogic l m,MonadLogic l' n) => Iso (l a) (l' b) (m [a]) (n [b])
listLogic = iso fromActions deduceAll
  where
    fromActions ml = induce (map step ml)
      where step [] = Nothing
            step (x:xs) = Just (x, fromActions (pure xs))

-- | Deduce at most @n@ results from a logic computation, in order. It stops
-- early if the computation runs out of results.
--
-- A non-positive count yields no results, like 'take'. An exact-@0@ base case
-- alone would never be reached from a negative count, which would then
-- silently behave like 'deduceAll'.
deduceMany :: MonadLogic l m => Int -> l a -> m [a]
deduceMany n l =
  if n <= 0
    then pure []
    else deduce l >>= maybe (pure []) (\(a,t) -> (a:) <$> deduceMany (n-1) t)
-- | Collect every result of a logic computation, in order. It repeatedly
-- splits off the first result with 'deduce' until none remain.
deduceAll :: MonadLogic l m => l a -> m [a]
deduceAll = collect
  where
    collect l = deduce l >>= \r -> case r of
      Nothing -> pure []
      Just (a, rest) -> (a:) <$> collect rest

-- | Turn a list into a computation that yields its elements in order. It does
-- this by viewing the pure list action through 'listLogic'.
choose :: MonadLogic l m => [a] -> l a
choose xs = pure xs ^. listLogic