{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}

module QIO.VecEq where

import QIO.QioSyn
import QIO.Heap

-- | Vector-like containers @v x a@ that associate coefficients of type @x@
-- with basis elements of type @a@. Operations that must identify equal
-- basis elements carry an 'Eq' constraint on @a@.
class VecEq v where
    -- | The zero vector (the 'VecEqL' instance uses an empty list).
    vzero :: v x a
    -- | Vector addition; entries with equal basis elements may be merged.
    (<+>) :: (Eq a, Num x) => v x a -> v x a -> v x a
    -- | Scalar multiplication of every coefficient.
    (<*>) :: (Num x) => x -> v x a -> v x a
    -- | The coefficient associated with a given basis element.
    (<@>) :: (Eq a, Num x) => a -> v x a -> x
    -- | Build a vector from (basis element, coefficient) pairs.
    fromList  :: [(a,x)] -> v x a
    -- | View a vector as (basis element, coefficient) pairs.
    toList    :: v x a -> [(a,x)] 

-- | Association-list vector: each pair holds a basis element of type @a@
-- and its coefficient of type @x@.
newtype VecEqL x a = VecEqL {unVecEqL :: [(a,x)]} deriving Show

-- | The zero vector: an association list with no entries.
vEqZero :: VecEqL x a
vEqZero = VecEqL { unVecEqL = [] }
          

-- | Vector addition. Entries of the first vector are inserted, right to
-- left, into the second vector using 'add', so coefficients of equal basis
-- elements are summed and new basis elements are appended.
vEqPlus :: (Eq a, Num x) => VecEqL x a -> VecEqL x a -> VecEqL x a
vEqPlus (VecEqL entries) acc = go entries
  where
    -- Right-recursive insertion: the last entry is added first.
    go []       = acc
    go (e : es) = add e (go es)


-- | Scalar multiplication: every coefficient is multiplied (on the left) by
-- @l@. Scaling by zero yields the empty vector instead of a list of
-- zero-coefficient entries.
--
-- NOTE(review): @l == 0@ needs 'Eq x', which a bare 'Num x' constraint only
-- provides on GHC < 7.4 (where 'Eq' was a superclass of 'Num'); confirm the
-- target compiler.
vEqTimes :: (Num x) => x -> VecEqL x a -> VecEqL x a
vEqTimes l (VecEqL bs) = VecEqL scaled
  where
    scaled
      | l == 0    = []
      | otherwise = [ (b, l * k) | (b, k) <- bs ]
          

-- | Coefficient of basis element @a@: the coefficient of the first entry
-- whose basis element equals @a@, or 0 if there is no such entry.
vEqAt :: (Eq a, Num x) => a -> VecEqL x a -> x
vEqAt a (VecEqL entries) = maybe 0 id (lookup a entries)
          

-- | Insert one (basis element, coefficient) pair into a vector. If the basis
-- element is already present, the first matching entry keeps its stored key
-- and its coefficient becomes @x + y@. Otherwise the pair is appended at the
-- end of the list.
add :: (Eq a,Num x) => (a,x) -> VecEqL x a -> VecEqL x a
add (a, x) (VecEqL axs) = VecEqL (before ++ merged)
  where
    -- Lazy split at the first entry whose key equals @a@, so the result
    -- list can still be consumed incrementally.
    (before, rest) = break ((a ==) . fst) axs
    merged = case rest of
               (b, y) : after -> (b, x + y) : after
               []             -> [(a, x)]

-- | Association-list instance. The arithmetic methods delegate to the
-- top-level 'VecEqL' functions, and the list conversions are the newtype
-- constructor and field accessor.
instance VecEq VecEqL where
      vzero    = vEqZero
      (<+>)    = vEqPlus
      (<*>)    = vEqTimes
      (<@>)    = vEqAt
      fromList = VecEqL
      toList   = unVecEqL

-- | A restricted monad interface whose element types must support 'Eq'. The
-- 'VecEq'-based instance in this module needs this constraint to combine
-- results with '<+>'.
class EqMonad m where
    -- | Inject a single value.
    eqReturn :: Eq a => a -> m a
    -- | Sequencing; both the source and result element types need 'Eq'.
    eqBind   :: (Eq a, Eq b) => m a -> (a -> m b) -> m b 

-- | Every 'VecEq' vector with numeric coefficients is an 'EqMonad'.
-- 'eqReturn' gives a single entry with coefficient 1. 'eqBind' applies the
-- continuation to each basis element, scales that result by the element's
-- coefficient, and sums the scaled results with '<+>'.
instance (VecEq v, Num x) => EqMonad (v x) where
    eqReturn a = fromList [(a, 1)]
    eqBind va f = combine (toList va)
      where
        -- The tail is rebuilt with 'fromList' and bound recursively, so the
        -- instance's own 'fromList'/'toList' round-trip is preserved.
        combine []              = vzero
        combine [(a, x)]        = x <*> f a
        combine ((a, x) : rest) = (x <*> f a) <+> (fromList rest `eqBind` f)


-- | A syntax tree of monadic operations over an 'EqMonad' @m@. Only 'Embed'
-- requires 'Eq' on the element type, so 'Return' and 'Bind' can support an
-- ordinary 'Monad' instance. 'unEmbed' interprets the tree back into @m@.
data AsMonad m a where
   -- Lift an underlying computation; its element type must have 'Eq'.
   Embed  :: (EqMonad m, Eq a) => m a -> AsMonad m a
   -- A pure value; no 'Eq' requirement on @a@.
   Return :: EqMonad m => a -> AsMonad m a
   -- A deferred bind; the intermediate type @a@ carries no 'Eq' requirement.
   Bind   :: EqMonad m => AsMonad m a -> (a -> AsMonad m b) -> AsMonad m b
 
-- | Monad structure for 'AsMonad'. Both operations only build syntax
-- ('Return' and 'Bind' nodes); 'unEmbed' performs the interpretation later.
instance EqMonad m => Monad (AsMonad m) where
   return x = Return x
   m >>= k  = Bind m k

-- | Interpret an 'AsMonad' computation in its underlying 'EqMonad'. The
-- caller supplies 'Eq' only for the final result type. The 'EqMonad m'
-- dictionary comes from the constructors, which all carry it.
unEmbed :: Eq a => AsMonad m a -> m a
unEmbed (Embed m) = m
unEmbed (Return a) = eqReturn a
-- Bind on an embedded computation: 'Embed' supplies 'Eq' for its element
-- type, so the restricted 'eqBind' can be used directly.
unEmbed (Bind (Embed m) f) = m `eqBind` (unEmbed.f)
-- Bind on a pure value is just application (monad left identity).
unEmbed (Bind (Return a) f) = unEmbed (f a)
-- Left-nested binds are re-associated to the right (monad associativity)
-- until the innermost computation is an 'Embed' or a 'Return'.
unEmbed (Bind (Bind m f) g) = unEmbed (Bind m (\x -> Bind (f x) g))