-- | This module defines the functions that can be used to simulate the running of
-- QIO computations.
module QIO.Qio where

import Data.List
import qualified System.Random as Random
import Data.Monoid as Monoid
import Data.Maybe as Maybe
import Control.Monad.State
import QIO.QioSyn
import QIO.Vec
import QIO.VecEq
import QIO.Heap

-- | A "Pure" state is a 'VecEqL' whose entries pair a classical basis state,
-- stored as a 'HeapMap', with its complex ('CC') amplitude.
type Pure = VecEqL CC HeapMap

-- | Set qubit @x@ to the boolean @b@ in every basis state of a 'Pure' state.
-- Each 'HeapMap' entry is rewritten with 'update'; the amplitude paired with
-- it is carried over untouched.
updateP :: Pure -> Qbit -> Bool -> Pure
updateP p x b = VecEqL [ (update h x b, k) | (h, k) <- unVecEqL p ]

-- | A "Unitary" wraps a function from a single 'HeapMap' to the 'Pure' state
-- it is turned into. The 'Int' argument is the next free qubit reference,
-- which 'uLet' uses to allocate a fresh qubit.
newtype Unitary = U
    { unU :: Int -> HeapMap -> Pure
    }

-- | 'Unitary' forms a 'Monoid'. 'mempty' hands the heap back via 'return';
-- 'mappend' runs the left unitary and feeds every heap it produces to the
-- right one. Both sides are given the same free-qubit counter.
instance Monoid Unitary where
    mempty = U (\ _ heap -> unEmbed (return heap))
    mappend (U first) (U second) = U composed
      where
        -- NOTE(review): the trailing 'return' is kept on purpose; how
        -- 'unEmbed' treats it is defined in QIO.VecEq -- confirm before
        -- simplifying this block.
        composed fv heap = unEmbed $ do
            mid <- Embed $ first fv heap
            out <- Embed $ second fv mid
            return out

-- | Check whether a given 'Rotation' is in fact unitary. This is currently a
-- placeholder: it accepts every rotation without inspecting it. That is only
-- acceptable while all rotations come from the QIO library itself; a
-- non-unitary user-defined 'Rotation' is not caught.
unitaryRot :: Rotation -> Bool
unitaryRot _ = True
-- TODO: update to check that the rotation is unitary...

-- | Build the 'Unitary' that applies a 2-by-2 matrix, given as its four
-- complex entries, to qubit @q@.
--
-- When @q@ is currently True the result is @m01@ times the heap with @q@
-- flipped to False plus @m11@ times the unchanged heap; when it is False the
-- result is @m00@ times the unchanged heap plus @m10@ times the heap with @q@
-- flipped to True. The free-qubit counter is not used.
--
-- Fails via 'fromJust' if @q@ has no value in the heap.
uMatrix :: Qbit -> (CC,CC,CC,CC) -> Unitary
uMatrix q (m00,m01,m10,m11) = U rotate
  where
    -- A single basis state lifted into a vector.
    basis s = unEmbed $ return s

    rotate _ h
        | fromJust (h ? q) = (m01 <*> (basis (update h q False)))
                             <+> (m11 <*> (basis h))
        | otherwise        = (m00 <*> (basis h))
                             <+> (m10 <*> (basis (update h q True)))

-- | Convert a 'Rotation' on qubit @q@ into a 'Unitary' by sampling the
-- rotation at its four index pairs and passing them to 'uMatrix'. Raises an
-- error if 'unitaryRot' rejects the rotation.
uRot :: Qbit -> Rotation -> Unitary
uRot q r
    | unitaryRot r = uMatrix q entries
    | otherwise    = error "Non unitary Rotation!"
  where
    entries = ( r (False,False)
              , r (False,True)
              , r (True,False)
              , r (True,True)
              )

-- | The 'Unitary' that swaps qubits @x@ and @y@ by applying 'hswap' to the
-- heap. The free-qubit counter is ignored.
uSwap :: Qbit -> Qbit -> Unitary
uSwap x y = U swapped
  where
    swapped _ h = unEmbed (return (hswap h x y))

-- | A conditional 'Unitary': look up the control qubit @x@ in the heap, pick
-- the branch @us@ yields for that value, and run it on the same heap and
-- free-qubit counter. Fails via 'fromJust' if @x@ has no value in the heap.
uCond :: Qbit -> (Bool -> Unitary) -> Unitary
-- Alternative that hides the control qubit from the branch while it runs:
--uCond x us = U (\ fv h -> updateP (unU (us (h ? x)) fv (forget h x)) x (h ? x))
uCond x us = U branch
  where
    branch fv h = unU (us (fromJust (h ? x))) fv h
--whether or not to forget? (if not then no runtime error for conditionals)

-- | A let 'Unitary': allocate the qubit at the current free index,
-- initialise it to @b@ in the heap, and run the body @ux@ on that qubit with
-- the free-qubit counter advanced by one.
uLet :: Bool -> (Qbit -> Unitary) -> Unitary
uLet b ux = U bind
  where
    bind fv h =
        let fresh = Qbit fv
        in  unU (ux fresh) (fv + 1) (update h fresh b)
--doesn't enforce unitary
-- need Unitary -> [Qbit] ???

-- | Interpret a syntactic 'U' program as a 'Unitary'. Each constructor is
-- translated to its matching @u*@ operation and combined, via 'mappend', with
-- the translation of the remainder of the program.
runU :: U -> Unitary
runU prog = case prog of
    UReturn     -> mempty
    Rot x a u   -> uRot x a `mappend` runU u
    Swap x y u  -> uSwap x y `mappend` runU u
    Cond x us u -> uCond x (\ b -> runU (us b)) `mappend` runU u
    Ulet b xu u -> uLet b (\ q -> runU (xu q)) `mappend` runU u

-- | A quantum state is defined as the next free qubit reference together with
-- the 'Pure' state that represents the overall quantum state.
data StateQ = StateQ
    { free :: Int   -- ^ index that the next allocated qubit will receive
    , pure :: Pure  -- ^ the current overall quantum state
    }

-- | The initial 'StateQ': no qubit allocated yet (free index 0) and a 'Pure'
-- state built from the single heap 'initial'.
initialStateQ :: StateQ
initialStateQ = StateQ 0 startState
  where
    startState = unEmbed $ return initial

-- | Given a 'Pure' state, add up @amp k@ for the amplitude @k@ of every
-- entry.
--
-- NOTE(review): 'amp' is defined elsewhere; presumably it yields the squared
-- magnitude, which would make this the total probability mass -- confirm.
pa :: Pure -> RR
pa (VecEqL entries) = foldr accumulate 0 entries
  where
    -- The heap is irrelevant here; only the amplitude contributes.
    accumulate (_, k) acc = acc + amp k

-- | A 'Split' is a probability together with the two 'Pure' states that a
-- state is divided into by 'split'.
data Split = Split
    { p       :: RR    -- ^ share of the total weight carried by 'ifTrue'
    , ifTrue  :: Pure  -- ^ entries in which the qubit is True
    , ifFalse :: Pure  -- ^ entries in which the qubit is False
    }

-- | Divide a 'Pure' state on qubit @x@ into the entries where @x@ is True and
-- those where it is False. This is how measurements are implemented in QIO.
--
-- The probability stored in the result is the weight ('pa') of the True part
-- divided by the weight of the whole state, or 0 when the whole state has
-- weight 0. Neither part is rescaled. Fails via 'fromJust' if @x@ has no
-- value in one of the heaps.
split :: Pure -> Qbit -> Split
split (VecEqL entries) x = Split probTrue trueHalf falseHalf
  where
    qubitSet (h, _) = fromJust (h ? x)
    (setEntries, clearEntries) = partition qubitSet entries
    trueHalf  = VecEqL setEntries
    falseHalf = VecEqL clearEntries
    total     = pa (VecEqL entries)
    probTrue
        | total == 0 = 0
        | otherwise  = (pa trueHalf) / total

-- | We can extend a Monad into a PMonad if it defines a way of probabilistically
-- merging two computations, based on a given probability.
class Monad m => PMonad m where
    -- | @merge pr ift iff@ combines two computations, where @pr@ is the
    -- probability associated with the first one (@ift@); see the instances
    -- for how each monad interprets it.
    merge :: RR -> m a -> m a -> m a

-- | IO forms a PMonad, using the random number generator to pick one of the
-- \"merged\" computations probabilistically: @ift@ is run with probability
-- @pr@, otherwise @iff@ is run. Exactly one random sample is drawn per merge.
instance PMonad IO where
    merge pr ift iff = do
        pp <- Random.randomRIO (0,1.0)
        -- The sampled range is given as a closed interval, so the sample may
        -- equal the upper bound. With the strict comparison alone, pr = 1 and
        -- pp = 1 would select the branch that has probability zero; a certain
        -- outcome is therefore decided without relying on the sample.
        if pr >= 1 || pr > pp then ift else iff

-- | 'Prob' is a wrapper around a 'Vec' whose weights are real ('RR')
-- probabilities.
data Prob a = Prob
    { unProb :: Vec RR a
    }

-- | A probability distribution is shown as the list of its (value,
-- probability) pairs, leaving out every pair whose probability is not
-- strictly positive.
instance Show a => Show (Prob a) where
    show (Prob (Vec ps)) = show [ entry | entry@(_, weight) <- ps, weight > 0 ]

-- | 'Prob' forms a 'Monad' by delegating to the monad of the wrapped vector:
-- values are wrapped and unwrapped around the underlying 'return' and '>>='.
instance Monad Prob where
    return a = Prob (return a)
    Prob ps >>= f = Prob (ps >>= \ a -> unProb (f a))

-- | 'Prob' is also a 'PMonad': both computations are kept, the first scaled
-- by @pr@ and the second by @1-pr@, and the two are combined into a single
-- probability distribution.
instance PMonad Prob where
    merge pr (Prob ift) (Prob iff) = Prob (onTrue <++> onFalse)
      where
        onTrue  = pr <**> ift
        onFalse = (1-pr) <**> iff

-- | Given a PMonad, a QIO Computation can be converted into a Stateful computation
-- over a quantum state (of type 'StateQ'). The stateful computation yields the
-- result wrapped in the chosen PMonad @m@; the quantum state itself is tracked
-- in the 'State' layer.
evalWith :: PMonad m => QIO a -> State StateQ (m a)
-- A finished computation: wrap the result in the target PMonad and leave the
-- quantum state untouched.
evalWith (QReturn a) = return (return a)
-- Allocate the qubit at the current free index, set it to b in every basis
-- state, advance the free index, and continue with the new qubit.
evalWith (MkQbit b g) = do (StateQ f p) <- get 
                           put (StateQ (f+1) (updateP p (Qbit f) b))
                           evalWith (g (Qbit f))
-- Translate u with 'runU' and apply it to every basis state of the current
-- Pure state. The unitary is given the current free index, but the index
-- stored in the state is left unchanged.
evalWith (ApplyU u q) = do (StateQ f p) <- get
                           put (StateQ f (unEmbed $ do x <- Embed $ p
                                                       x' <-Embed $ uu f x
                                                       return x'
                                          )
                                )
                           evalWith q  
                               where U uu = runU u
-- Measure qubit x: split the state on x, evaluate the continuation for True on
-- the True part and the continuation for False on the False part (neither
-- part is rescaled by 'split'), then 'merge' the two results with the
-- probability of the True part. A probability outside [0,1] raises an error.
-- The state left behind is the one produced by evaluating the False branch.
evalWith (Meas x g) = do (StateQ f p) <- get
                         (let Split pr ift iff = split p x
                          in if pr < 0 || pr > 1 then error "pr < 0 or >1" 
                             else do put (StateQ f ift)
                                     pift <- evalWith (g True)
                                     put (StateQ f iff)
                                     piff <- evalWith (g False)
                                     return (merge pr pift piff))

-- | Evaluate a QIO computation in any 'PMonad': turn it into a stateful
-- computation with 'evalWith' and run that from 'initialStateQ', keeping
-- only the result and discarding the final quantum state.
eval :: PMonad m => QIO a -> m a
eval = (`evalState` initialStateQ) . evalWith

-- | Run a QIO computation: 'eval' specialised to the 'IO' 'PMonad', so each
-- measurement outcome is chosen at random.
run :: QIO a -> IO a
run computation = eval computation

-- | Simulate a QIO computation: 'eval' specialised to the 'Prob' 'PMonad', so
-- the result is the probability distribution over all measurement outcomes.
sim :: QIO a -> Prob a
sim computation = eval computation