module Data.Logical.Knot
  ( Knot
  , KnotT
  , MonadKnot
  , askKnot
  , askKnotDef
  , (*=)
  , accKnot
  , tieKnot
  , accKnotT
  , tieKnotT
  ) where



import Data.Map (Map,fromList,fromListWith,(!),findWithDefault)
import Control.Monad.Reader
import Control.Monad.Writer



-- A single constraint: the key being constrained paired with the value
-- it is constrained to.  Constraints are what '*=' emits.
type Constraint i x = (i,x)
-- The solved knot: a map from keys to values.  This is the environment
-- that 'askKnot' and 'askKnotDef' read from.
type Solution   i x = Map i x

-- | A pure knot-tying computation with keys @i@, values @x@ and result @a@.
--
-- It is a reader of the 'Solution' layered over a writer that collects the
-- list of 'Constraint's.  'accKnot' runs it by feeding the solution built
-- from the collected constraints back in as the reader environment.
--
-- 'Functor' and 'Applicative' are derived alongside 'Monad' because they
-- are superclasses of 'Monad' since GHC 7.10; deriving 'Monad' alone is
-- rejected there for lack of those instances.
newtype Knot i x a =
          Knot (ReaderT (Solution i x) (Writer [Constraint i x]) a)
        deriving (Functor, Applicative, Monad, MonadFix)

-- | Transformer version of the knot-tying computation, layered over a base
-- monad @m@, with keys @i@, values @x@ and result @a@.
--
-- It is a reader of the 'Solution' over a writer transformer that collects
-- the list of 'Constraint's on top of @m@.  'accKnotT' runs it, which
-- additionally requires @m@ to be a 'MonadFix'.
--
-- 'Functor' and 'Applicative' are derived alongside 'Monad' because they
-- are superclasses of 'Monad' since GHC 7.10; deriving 'Monad' alone is
-- rejected there for lack of those instances.
newtype KnotT i x m a =
          KnotT (ReaderT (Solution i x) (WriterT [Constraint i x] m) a)
        deriving (Functor, Applicative, Monad, MonadFix)



-- | Lifts a base-monad action through both layers of 'KnotT': first the
-- writer transformer, then the reader transformer, then the newtype wrapper.
--
-- Written by hand: in the representation type the base monad @m@ sits
-- inside 'WriterT' rather than in the last-but-one argument position, so
-- the newtype cannot be eta-reduced to a transformer applied to @m@ and
-- newtype deriving has nothing to reuse.
instance MonadTrans (KnotT i x) where
    lift action = KnotT (lift (lift action))



-- | Monads that can read values out of a knot and add constraints to it.
--
-- The monad @m@ determines both the key type @i@ and the value type @x@.
class (Monad m, Ord i) => MonadKnot i x m | m -> i x where
    -- | Read the value associated with a key.  The instances in this
    -- module look the key up with 'Data.Map.!', which raises an error
    -- when the key is absent.
    askKnot :: i -> m x

    -- | Read the value associated with a key, falling back to the given
    -- default (first argument) when the key is absent.
    askKnotDef :: x -> i -> m x

    -- | Emit a constraint binding the key on the left to the value on
    -- the right.
    (*=) :: i -> x -> m ()



-- | Knot operations for the pure 'Knot' monad.
instance Ord i => MonadKnot i x (Knot i x) where
    -- Look the key up in the solution; '!' raises an error if it is absent.
    askKnot key = Knot (asks (\solution -> solution ! key))

    -- Look the key up in the solution, using the fallback if it is absent.
    askKnotDef fallback key = Knot (asks (findWithDefault fallback key))

    -- Emit one (key, value) constraint to the writer layer.
    key *= value = Knot (tell [(key, value)])

-- | Knot operations for the 'KnotT' transformer over any base monad.
instance (Monad m, Ord i) => MonadKnot i x (KnotT i x m) where
    -- Look the key up in the solution; '!' raises an error if it is absent.
    askKnot key = KnotT (asks (\solution -> solution ! key))

    -- Look the key up in the solution, using the fallback if it is absent.
    askKnotDef fallback key = KnotT (asks (findWithDefault fallback key))

    -- Emit one (key, value) constraint to the writer layer.
    key *= value = KnotT (tell [(key, value)])



-- | Run a 'Knot' computation, returning its result together with the
-- solution map assembled from the constraints it emitted.
--
-- The definition is deliberately circular: the computation is run with
-- @solution@ as its environment, and @solution@ is built from the
-- constraints that same run produces.  This only terminates thanks to
-- laziness, which is why the value-lazy "Data.Map" API is used here.
--
-- Constraints on the same key are merged with @acc@ via 'fromListWith';
-- @acc@ should be commutative and associative.
accKnot :: Ord i => (x -> x -> x) -> Knot i x a -> (a, Map i x)
accKnot acc (Knot body) =
    let (result, constraints) = runWriter (runReaderT body solution)
        solution              = fromListWith acc constraints
    in  (result, solution)

-- | Run a 'Knot' computation in which each key is expected to be
-- constrained at most once.
--
-- This is 'accKnot' with an 'error' call in place of the combining
-- function, so a key that receives more than one constraint ends up bound
-- to an error value instead of a merged one.
tieKnot :: Ord i => Knot i x a -> (a, Map i x)
tieKnot knot = accKnot overConstrained knot
  where
    -- Stands in for the combining function; only reached on a key clash.
    overConstrained = error "tieKnot: Over-constrained"

-- | Run a 'KnotT' computation in the base monad, returning its result
-- together with the solution map assembled from the constraints it emitted.
--
-- The @rec@ group ties the knot through 'MonadFix': the computation is run
-- with @solution@ as its environment, while @solution@ is built from the
-- constraints that very run writes out.
--
-- Constraints on the same key are merged with @acc@ via 'fromListWith';
-- @acc@ should be commutative and associative.
accKnotT
    :: (Ord i, MonadFix m)
    => (x -> x -> x) -> KnotT i x m a -> m (a, Map i x)
accKnotT acc (KnotT body) = do
    rec (result, constraints) <- runWriterT (runReaderT body solution)
        let solution = fromListWith acc constraints
    return (result, solution)

-- | Run a 'KnotT' computation in which each key is expected to be
-- constrained at most once.
--
-- This is 'accKnotT' with an 'error' call in place of the combining
-- function, so a key that receives more than one constraint ends up bound
-- to an error value instead of a merged one.  The message names this
-- function (not 'tieKnot') so the failure can be traced to the right entry
-- point.
tieKnotT :: (Ord i, MonadFix m) => KnotT i x m a -> m (a, Map i x)
tieKnotT = accKnotT (error "tieKnotT: Over-constrained")