module QIO.Qio where

import List
import qualified System.Random as Random
import Data.Monoid as Monoid
import Data.Maybe as Maybe
import Control.Monad.State
import QIO.QioSyn
import QIO.Vec
import QIO.VecEq
import QIO.Heap

-- | A pure state of the simulator. Elsewhere in this module it is handled as
-- a 'VecEqL'-wrapped list of @(heap, coefficient)@ pairs (see 'updateP' and
-- 'pa').
-- NOTE(review): 'VecEqL', 'CC' and 'HeapMap' are defined in the imported QIO
-- modules; confirm their exact meaning there.
type Pure = VecEqL CC HeapMap

-- | Set qubit @x@ to the value @b@ in every heap of the state @p@. The
-- coefficient paired with each heap is passed through unchanged.
updateP :: Pure -> Qbit -> Bool -> Pure
updateP p x b = VecEqL (map setQubit (unVecEqL p))
  where
    -- Rewrite the heap component only.
    setQubit (heap, coeff) = (update heap x b, coeff)

-- | Semantic form of a unitary operation: given an 'Int' counter and a
-- 'HeapMap', produce the resulting 'Pure' state. The counter is the index
-- that 'uLet' uses when it needs a fresh 'Qbit'.
newtype Unitary = U { unU :: Int -> HeapMap -> Pure }

-- | Sequential composition of unitaries. 'mempty' returns the input heap
-- unchanged; 'mappend' runs the first unitary and binds each heap of its
-- result into the second one. Both stages receive the same counter @fv@.
instance Monoid Unitary where
    mempty = U (\ _ heap -> unEmbed (return heap))
    mappend (U first) (U second) = U andThen
      where
        -- The final bind-and-return is kept on purpose so that the value
        -- passed to 'unEmbed' has the same shape as before.
        andThen fv heap = unEmbed $
            Embed (first fv heap) >>= \ mid ->
            Embed (second fv mid) >>= \ out ->
            return out

-- | Turn a 'Rotation' on qubit @q@ into a 'Unitary' by sampling the rotation
-- at its four @(Bool, Bool)@ arguments and handing the values to 'uMatrix'.
-- Raises an error when 'unitaryRot' rejects the rotation.
uRot :: Qbit -> Rotation -> Unitary
uRot q r
    | unitaryRot r = uMatrix q entries
    | otherwise    = error "Non unitary Rotation!"
  where
    -- Same order as the (m00, m01, m10, m11) tuple expected by 'uMatrix'.
    entries = ( r (False, False)
              , r (False, True)
              , r (True, False)
              , r (True, True)
              )

-- | Report whether a 'Rotation' is unitary.
--
-- TODO: this is a placeholder that accepts every rotation; update it to
-- actually check that the rotation is unitary.
unitaryRot :: Rotation -> Bool
unitaryRot _ = True

-- | Apply the matrix with entries @(m00,m01,m10,m11)@ to qubit @q@.
--
-- For a heap in which @q@ is 'True' the result is
-- @m01 * (heap with q = False) + m11 * heap@; for a heap in which @q@ is
-- 'False' it is @m00 * heap + m10 * (heap with q = True)@. The counter
-- argument of the 'Unitary' is not used.
--
-- Raises a descriptive error if the heap lookup for @q@ yields 'Nothing'
-- (previously this surfaced as an anonymous 'fromJust' failure).
-- NOTE(review): @\<*\>@ and @\<+\>@ are the operators from the imported QIO
-- vector modules, presumably scaling and addition -- confirm there.
uMatrix :: Qbit -> (CC,CC,CC,CC) -> Unitary
uMatrix q (m00,m01,m10,m11) = U apply
  where
    apply _ h = case h ? q of
        Just True  -> (m01 <*> basis (update h q False)) <+> (m11 <*> basis h)
        Just False -> (m00 <*> basis h) <+> (m10 <*> basis (update h q True))
        Nothing    -> error "QIO.Qio.uMatrix: heap lookup (?) returned Nothing for the target qubit"
    -- Wrap a single heap as a state, exactly as the original inline code did.
    basis heap = unEmbed (return heap)

-- | Unitary that applies 'hswap' to the heap for qubits @x@ and @y@ and
-- returns the resulting heap as the state. The counter argument is ignored.
uSwap :: Qbit -> Qbit -> Unitary
uSwap x y = U swapped
  where
    swapped _ heap = unEmbed (return (hswap heap x y))

-- | Conditional unitary: look up qubit @x@ in the heap and run the unitary
-- that @us@ selects for its value, on the unchanged heap and with the same
-- counter @fv@.
--
-- The lookup is only demanded if @us@ inspects its argument. When it is
-- demanded and yields 'Nothing', a descriptive error is raised (previously
-- this surfaced as an anonymous 'fromJust' failure).
uCond :: Qbit -> (Bool -> Unitary) -> Unitary
-- Earlier variant, kept for reference, that forgets the control qubit before
-- running the branch and writes it back afterwards:
--uCond x us = U (\ fv h -> updateP (unU (us (h ? x)) fv (forget h x)) x (h ? x))
uCond x us = U branch
  where
    branch fv h = unU (us control) fv h
      where
        -- Left as a lazy binding so a branch function that ignores its
        -- argument still never forces the lookup.
        control = fromMaybe
            (error "QIO.Qio.uCond: heap lookup (?) returned Nothing for the control qubit")
            (h ? x)
-- Open question: should the control qubit be forgotten while the branch runs?
-- (If it is not, conditionals raise no runtime error.)

-- | Run a unitary with a locally introduced qubit. The fresh qubit is
-- @Qbit fv@; it is set to @b@ in the heap, and the body @ux@ is run with the
-- counter advanced to @fv + 1@.
uLet :: Bool -> (Qbit -> Unitary) -> Unitary
uLet b ux = U scoped
  where
    scoped fv h = unU (ux fresh) (fv + 1) (update h fresh b)
      where
        fresh = Qbit fv
-- NOTE: nothing here enforces that the result is unitary.
-- Open question: is a function of type Unitary -> [Qbit] needed?

-- | Interpret a syntactic 'U' as a semantic 'Unitary'. Each constructor is
-- translated by the matching @u...@ function and composed, via 'mappend',
-- with the interpretation of the remainder of the program.
runU :: U -> Unitary
runU prog = case prog of
    UReturn         -> mempty
    Rot q rot rest  -> uRot q rot `followedBy` rest
    Swap a b rest   -> uSwap a b `followedBy` rest
    Cond q body rest -> uCond q (runU . body) `followedBy` rest
    Ulet b body rest -> uLet b (runU . body) `followedBy` rest
  where
    -- Compose one interpreted step with the interpreted continuation.
    followedBy step rest = step `mappend` runU rest

-- | Simulator state threaded through 'evalWith'.
data StateQ = StateQ
    { free :: Int   -- ^ index given to the next qubit created by 'MkQbit'
    , pure :: Pure  -- ^ current state of the simulated qubits
    }

-- | Starting state of a simulation: the counter is 0 and the state consists
-- of the single heap 'initial'.
-- NOTE(review): 'initial' comes from the imported heap module, presumably
-- the empty heap -- confirm there.
initialStateQ :: StateQ
initialStateQ = StateQ firstFree startState
  where
    firstFree  = 0
    startState = unEmbed (return initial)

-- | Sum of 'amp' applied to every coefficient in the state; the heaps
-- themselves are ignored.
-- NOTE(review): 'amp' is imported; presumably it turns a complex amplitude
-- into a probability weight -- confirm there.
pa :: Pure -> RR
pa (VecEqL entries) = total entries
  where
    -- Right-nested accumulation, so the floating-point order of the
    -- additions is the same as a right fold starting from 0.
    total []              = 0
    total ((_, k) : rest) = total rest + amp k

-- | Result of splitting a state on one qubit (see 'split').
data Split = Split
    { p       :: RR    -- ^ relative weight of the 'ifTrue' part
    , ifTrue  :: Pure  -- ^ entries whose heap has the qubit set to 'True'
    , ifFalse :: Pure  -- ^ the remaining entries
    }

-- | Split a state on the value of qubit @x@.
--
-- The entries are partitioned into those whose heap maps @x@ to 'True' and
-- the rest. The returned weight is @pa ifTrue / pa whole@, or 0 when the
-- whole state has weight 0 (which avoids a division by zero). The two parts
-- are returned as they are; no rescaling is applied to them.
--
-- Raises a descriptive error if the heap lookup for @x@ yields 'Nothing'
-- for an inspected entry (previously this surfaced as an anonymous
-- 'fromJust' failure).
split :: Pure -> Qbit -> Split
split (VecEqL as) x = Split probTrue (VecEqL trues) (VecEqL falses)
  where
    total = pa (VecEqL as)
    (trues, falses) = partition isSet as
    -- Predicate: is x set to True in this entry's heap?
    isSet (h, _) = fromMaybe
        (error "QIO.Qio.split: heap lookup (?) returned Nothing for the measured qubit")
        (h ? x)
    probTrue
        | total == 0 = 0
        | otherwise  = pa (VecEqL trues) / total

-- | Monads that can combine the two outcomes of a measurement.
class Monad m => PMonad m where
    -- | @merge pr ift iff@ combines the computation for the 'True' outcome
    -- (@ift@) with the one for the 'False' outcome (@iff@), where @pr@ is
    -- the weight of the 'True' outcome. 'evalWith' only calls it with
    -- @0 <= pr <= 1@.
    merge :: RR -> m a -> m a -> m a

-- | Sampling interpretation: draw a random number from the range @(0, 1.0)@
-- and run @ift@ when @pr@ exceeds it, otherwise run @iff@.
instance PMonad IO where
    merge pr ift iff = do
        roll <- Random.randomRIO (0, 1.0)
        -- The explicit @pr >= 1@ test covers a draw that lands exactly on the
        -- upper bound: without it a certain outcome (pr == 1) would fall
        -- through to the other branch. The draw itself is still performed,
        -- so the random generator advances exactly as before.
        if pr >= 1 || pr > roll then ift else iff

-- | Distribution interpretation of a computation: a 'Vec' of results with
-- 'RR' weights.
data Prob a = Prob { unProb :: Vec RR a }

-- | Shows the list of @(value, weight)@ pairs, leaving out every pair whose
-- weight is not strictly positive.
instance Show a => Show (Prob a) where
    show (Prob (Vec ps)) = show [ entry | entry@(_, weight) <- ps, weight > 0 ]

-- | 'Prob' reuses the monad structure of the underlying 'Vec', wrapping and
-- unwrapping the 'Prob' constructor around it.
instance Monad Prob where
    return a = Prob (return a)
    Prob ps >>= f = Prob (ps >>= \ a -> unProb (f a))

-- | Distribution interpretation of 'merge': the 'True' branch is weighted by
-- @pr@, the 'False' branch by @1 - pr@, and the two are combined.
-- NOTE(review): @\<**\>@ and @\<++\>@ come from the imported vector module,
-- presumably scaling and addition -- confirm there.
instance PMonad Prob where
    merge pr (Prob ift) (Prob iff) = Prob (weightedTrue <++> weightedFalse)
      where
        weightedTrue  = pr <**> ift
        weightedFalse = (1 - pr) <**> iff


-- | Interpret a 'QIO' program as a 'State' computation over 'StateQ' whose
-- result lives in a 'PMonad'.
--
-- * 'QReturn': return the value in the target monad.
-- * 'MkQbit': allocate @Qbit f@ for the current counter @f@, set it to the
--   requested value in every heap, advance the counter, and continue.
-- * 'ApplyU': interpret the unitary with 'runU' and bind every heap of the
--   current state through it; the counter is unchanged.
-- * 'Meas': 'split' the state on the qubit, evaluate the continuation for
--   'True' on the 'True' part and for 'False' on the 'False' part (both
--   starting from the same counter), and 'merge' the two results with the
--   weight reported by 'split'. A weight outside @[0, 1]@ is an error.
evalWith :: PMonad m => QIO a -> State StateQ (m a)
evalWith prog = case prog of
    QReturn a -> return (return a)

    MkQbit b g -> do
        StateQ f st <- get
        let fresh = Qbit f
        put (StateQ (f + 1) (updateP st fresh b))
        evalWith (g fresh)

    ApplyU u rest -> do
        StateQ f st <- get
        let U uu = runU u
            -- The trailing bind-and-return is kept so the value handed to
            -- 'unEmbed' has the same shape as before.
            st' = unEmbed $
                Embed st >>= \ h ->
                Embed (uu f h) >>= \ h' ->
                return h'
        put (StateQ f st')
        evalWith rest

    Meas x g -> do
        StateQ f st <- get
        let Split pr ift iff = split st x
        if pr < 0 || pr > 1
            then error "pr < 0 or >1"
            else do
                -- Explore the True outcome first, then reset the state and
                -- explore the False outcome.
                put (StateQ f ift)
                pift <- evalWith (g True)
                put (StateQ f iff)
                piff <- evalWith (g False)
                return (merge pr pift piff)

-- | Evaluate a 'QIO' program in any 'PMonad' by running 'evalWith' from
-- 'initialStateQ' and discarding the final simulator state.
eval :: PMonad m => QIO a -> m a
eval prog = evalWith prog `evalState` initialStateQ

-- | Run a 'QIO' program in 'IO'; this is 'eval' at the 'IO' instance of
-- 'PMonad', where each measurement outcome is picked at random.
run :: QIO a -> IO a
run prog = eval prog

-- | Simulate a 'QIO' program; this is 'eval' at the 'Prob' instance of
-- 'PMonad', yielding the weighted collection of possible results.
sim :: QIO a -> Prob a
sim prog = eval prog