{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}

module QIO.QArith where

import Data.Monoid as Monoid
import QIO.QioSyn
import QIO.Qdata
import QIO.QioClass
import QIO.Qio
import QIO.QExamples

-- | Build the unitary that applies 'swap' to each pair of corresponding
-- qubits of the two quantum integers, in list order, and sequences the
-- results with 'mappend'.
--
-- Both registers must contain the same number of qubits. A length mismatch
-- is reported with an explicit error message rather than an anonymous
-- pattern-match failure.
swapQInt :: QInt -> QInt -> U
swapQInt (QInt xs) (QInt ys) = go xs ys
  where
    -- Walk both qubit lists in lock-step.
    go [] [] = mempty
    go (p:ps) (q:qs) = swap p q `mappend` go ps qs
    go _ _ = error "swapQInt: quantum integers have different numbers of qubits"

-- | Two-armed conditional on a qubit: passes 'cond' a branch function that
-- yields @t@ for 'True' and @f@ for 'False', with @qa@ as the control.
ifElseQ :: Qbit -> U -> U -> U
ifElseQ qa t f = cond qa branch
  where
    branch True = t
    branch False = f

-- | One-armed conditional on a qubit: @t@ when the value supplied by 'cond'
-- is 'True', the empty unitary ('mempty') otherwise.
ifQ :: Qbit -> U -> U
ifQ qa t = cond qa (\ isSet -> if isSet then t else mempty)

-- | Controlled NOT: 'unot' on the second qubit, guarded by the first qubit
-- through 'ifQ'.
cnot :: Qbit -> Qbit -> U
cnot qa = ifQ qa . unot

-- | Sum-bit step of the adder: a 'cnot' from @qa@ onto @qb@ followed by a
-- 'cnot' from the carry @qc@ onto @qb@.  Only @qb@ is targeted.
-- (Presumably this leaves @qb@ holding @qa xor qb xor qc@.)
addBit :: Qbit -> Qbit -> Qbit -> U
addBit qc qa qb = mappend fromOperand fromCarry
  where
    fromOperand = cnot qa qb
    fromCarry = cnot qc qb

-- | Carry step of the adder: conditions (via nested 'cond') on the carry-in
-- @qci@ and the operand bits @qa@ and @qb@, and applies 'unot' to @qcsi@
-- exactly when at least two of the three classical values are 'True'.
carry :: Qbit -> Qbit -> Qbit -> Qbit -> U
carry qci qa qb qcsi =
    cond qci $ \ ci ->
    cond qa $ \ a ->
    cond qb $ \ b ->
        if majority ci a b then unot qcsi else mempty
  where
    -- True when two or more of the inputs are True.
    majority x y z = (x && y) || (x && z) || (y && z)

-- | Ripple-carry addition over two equally long lists of qubits.
--
-- The first carry qubit and every intermediate carry qubit are introduced
-- with @letU False@.  For each bit position the next carry is computed with
-- 'carry', the remaining positions are processed, the carry computation is
-- undone with 'urev', and finally 'addBit' updates the @qbs@ bit.  When both
-- lists are exhausted the last carry is copied onto @qc'@ with a controlled
-- 'unot'.
--
-- The two lists must have the same length; a mismatch raises an explicit
-- error instead of a pattern-match failure.
addBits :: [Qbit] -> [Qbit] -> Qbit -> U
addBits qas qbs qc' =
    letU False (go qas qbs)
  where
    -- No bits left: transfer the final carry to the caller's qubit qc'.
    go [] [] qCarry = ifQ qCarry (unot qc')
    go (qa:restA) (qb:restB) qCarry =
        letU False (\ qNext ->
            carry qCarry qa qb qNext `mappend`
            go restA restB qNext `mappend`
            urev (carry qCarry qa qb qNext)) `mappend`
        addBit qCarry qa qb
    go _ _ _ = error "addBits: qubit lists have different lengths"

-- | Variant of the ripple-carry recursion in which the carry qubits are
-- supplied by the caller (third argument) instead of being introduced with
-- @letU@.  The fourth argument is the carry into the current position.
--
-- For each position: compute the next carry with 'carry', recurse on the
-- remaining bits, undo the carry with 'urev', then apply 'addBit'.
--
-- All three lists must have the same length; a mismatch raises an explicit
-- error instead of a pattern-match failure.
addBits' :: [Qbit] -> [Qbit] -> [Qbit] -> Qbit -> U
addBits' [] [] [] _ = mempty
addBits' (qa:qas) (qb:qbs) (qc':qcs') qc =
    (carry qc qa qb qc' `mappend`
     addBits' qas qbs qcs' qc' `mappend`
     urev (carry qc qa qb qc')) `mappend`
    addBit qc qa qb
addBits' _ _ _ _ = error "addBits': qubit lists have different lengths"

-- | Adder on quantum integers: unwraps both 'QInt's and delegates to
-- 'addBits'; the trailing 'Qbit' argument is forwarded unchanged as
-- 'addBits'' third argument.
adder :: QInt -> QInt -> Qbit -> U
adder (QInt qas) (QInt qbs) = addBits qas qbs

-- | Test harness for 'adder': builds quantum data from the classical triple
-- with 'mkQ', applies @adder qx qy qc@, and returns the result of measuring
-- the whole structure with 'measQ'.
tadder :: (Int,(Int,Bool)) -> QIO (Int,(Int,Bool))
tadder xyc = do
    -- The as-pattern is written without surrounding spaces: GHC >= 9.0
    -- parses a spaced "q @ pat" as an infix operator, not an as-pattern.
    q@(qx, (qy, qc)) <- mkQ xyc
    applyU (adder qx qy qc)
    measQ q

-- | Test harness for the reversed adder: builds quantum data from the
-- classical triple with 'mkQ', applies @urev (adder qx qy qc)@, and returns
-- the result of measuring the whole structure with 'measQ'.
tRadder :: (Int,(Int,Bool)) -> QIO (Int,(Int,Bool))
tRadder xyc = do
    -- The as-pattern is written without surrounding spaces: GHC >= 9.0
    -- parses a spaced "q @ pat" as an infix operator, not an as-pattern.
    q@(qx, (qy, qc)) <- mkQ xyc
    applyU (urev (adder qx qy qc))
    measQ q

-- | Round-trip test harness: applies 'adder' and then its 'urev' to the same
-- quantum data, then measures everything with 'measQ'.
-- (Presumably the measured triple equals the input if 'urev' yields the
-- inverse of 'adder'.)
tBiAdder :: (Int,(Int,Bool)) -> QIO (Int,(Int,Bool))
tBiAdder xyc = do
    -- The as-pattern is written without surrounding spaces: GHC >= 9.0
    -- parses a spaced "q @ pat" as an infix operator, not an as-pattern.
    q@(qx, (qy, qc)) <- mkQ xyc
    applyU (adder qx qy qc)
    applyU (urev (adder qx qy qc))
    measQ q

-- | Modular adder built from 'adder' and its reverse.
--
-- Three temporaries are introduced with 'letU': @qn@ from the classical
-- modulus @n@, and two qubits @qz@ and @qc@ from 'False'.  The inline notes
-- give the register contents the author expects after each stage
-- (a, b, c, z name the values held by @qa@, @qb@, @qc@, @qz@; N is @n@).
adderMod :: Int -> QInt -> QInt -> U
adderMod n qa qb =
    letU n $ \ qn ->
    letU False $ \ qz ->
    letU False $ \ qc ->
        -- Stage 1.  Afterwards: b = a+b, c = False
        adder qa qb qc
        `mappend`
        -- Stage 2.  Afterwards: b = a+b-N, c = a+b < N
        urev (adder qn qb qc)
        `mappend`
        -- Stage 3.  Afterwards: z = c = a+b < N
        ifQ qc (unot qz)
        `mappend`
        -- Stage 4.  Afterwards: b = a+b mod N, c = False, z = a+b < N
        ifQ qz (adder qn qb qc)
        `mappend`
        -- Stage 5.  Afterwards: if a+b < N then a=a, b=b, c=False
        --           else a=a, b=a+b mod N - b, c=True;  z = not c
        urev (adder qa qb qc)
        `mappend`
        -- Stage 6.  Afterwards: z = False
        ifElseQ qc mempty (unot qz)
        `mappend`
        -- Stage 7.  Afterwards: b = a+b mod N, c = False, z = False
        adder qa qb qc

-- | Test harness for 'adderMod': builds quantum data from the classical pair
-- with 'mkQ', applies @adderMod n qa qb@, and returns the result of
-- measuring the pair with 'measQ'.
tadderMod :: Int -> (Int,Int) -> QIO (Int,Int)
tadderMod n ab = do
    -- The as-pattern is written without surrounding spaces: GHC >= 9.0
    -- parses a spaced "q @ pat" as an infix operator, not an as-pattern.
    q@(qa, qb) <- mkQ ab
    applyU (adderMod n qa qb)
    measQ q

-- | Modular multiplication by a classical constant.
--
-- Walks the qubits of the first 'QInt' in list order together with the
-- weights 1, 2, 4, ...  For a qubit with weight @p@, when the value supplied
-- by 'cond' is 'True', a temporary holding @(p*a) `mod` n@ is introduced
-- with 'letU' and added into @y@ with 'adderMod'.  The steps for the
-- remaining qubits sit inside both branches of the 'cond', exactly as in the
-- recursive formulation.
multMod :: Int -> Int -> QInt -> QInt -> U
multMod n a (QInt x) y = foldr step mempty (zip x (iterate (* 2) 1))
  where
    step (qbit, p) rest =
        cond qbit (\ isSet ->
            if isSet
              then letU ((p * a) `mod` n) (\ qa -> adderMod n qa y)
                       `mappend` rest
              else rest)
		               
-- | Test harness for 'multMod'.  Creates an accumulator from 0 and a
-- register from @x@, applies @multMod n a@ to them, then measures the
-- accumulator first and the register second.
--
-- Returns @(measured register, measured accumulator)@; the accumulator is
-- the output, a*x mod n.
tmultMod :: Int -> Int -> Int -> QIO (Int,Int)
tmultMod n a x = do
    acc <- mkQ 0
    reg <- mkQ x
    applyU (multMod n a reg acc)
    accValue <- measQ acc
    regValue <- measQ reg
    return (regValue, accValue)

-- | 'multMod' guarded by a control qubit: the multiplication unitary is
-- wrapped in 'ifQ' on @q@.
condMultMod :: Qbit -> Int -> Int -> QInt -> QInt -> U
condMultMod q n a x y = ifQ q multiplication
  where
    multiplication = multMod n a x y

------------------------------------------------------------------------------

-- | Multiplicative inverse of @a@ modulo @n@: collects the candidates with
-- 'inverseMod'' and lets 'inverseMod''' pick one or raise an error when
-- there are none.
inverseMod :: Int -> Int -> Int
inverseMod n a = inverseMod'' n a candidates
  where
    candidates = inverseMod' n a

-- | All @x@ in @[1..n]@, in ascending order, with @(x*a) `mod` n == 1@.
-- Exhaustive search; the result is empty when no such @x@ exists.
inverseMod' :: Int -> Int -> [Int]
inverseMod' n a = filter isInverse [1 .. n]
  where
    isInverse x = (x * a) `mod` n == 1


-- | Select the first candidate inverse, or fail with a message naming @a@
-- and @n@ when the candidate list is empty.
inverseMod'' :: Int -> Int -> [Int] -> Int
inverseMod'' _ _ (x:_) = x
inverseMod'' n a [] =
    error (concat ["inverseMod: no inverse of ", show a, " mod ", show n, " found"])

-------------------------------------------------------------------------------

-- | One step of modular exponentiation, controlled by the qubit @qc@.
--
-- With @p' = a^(2^p) mod n@, the step introduces a temporary @z@ via
-- @letU 0@ and sequences three unitaries:
--
--   1. @condMultMod qc n p' o z@
--   2. a swap of @o@ and @z@ (via 'swapQInt') guarded by @qc@
--   3. the 'urev' of @condMultMod qc n (inverseMod n p') o z@
--
-- (Presumably the net effect, when @qc@ is set, is @o := o * p' mod n@ with
-- @z@ returned to 0 — confirm against the definition of 'multMod'.)
--
-- @p'@ is obtained by squaring @a mod n@ @p@ times, reducing modulo @n@
-- after every squaring, with 'Integer' intermediates.  Computing
-- @a^(2^p)@ directly in 'Int' wraps around silently for larger @p@, and a
-- check for a zero result both misses most overflows and rejects @a == 0@.
-- For inputs where the direct 'Int' power does not overflow the value of
-- @p'@ is unchanged.
modExpStep :: Qbit -> Int -> Int -> QInt -> Int -> U
modExpStep qc n a o p =
    letU 0 (\ z ->
        condMultMod qc n p' o z
        `mappend` ifQ qc (swapQInt o z)
        `mappend` urev (condMultMod qc n (inverseMod n p') o z))
  where
    modulus :: Integer
    modulus = toInteger n

    -- a^(2^p) mod n; the result lies strictly between -|n| and |n|, so it
    -- fits in Int.
    p' :: Int
    p' = fromInteger (squareTimes p (toInteger a `mod` modulus))

    -- Square the value k times, reducing modulo n after each squaring.
    squareTimes :: Int -> Integer -> Integer
    squareTimes k v
      | k < 0 = error "modExpStep: negative exponent index"
      | k == 0 = v
      | otherwise = squareTimes (k - 1) ((v * v) `mod` modulus)

-- | Test harness for 'modExpStep': creates a control qubit from 'True' and a
-- register from @i@, applies @modExpStep@ with exponent index @p@, and
-- returns the measured register.
modExpStept :: Int -> Int -> Int -> Int -> QIO Int
modExpStept i n a p = do
    control <- mkQ True
    register <- mkQ i
    applyU (modExpStep control n a register p)
    measQ register

-- | Modular exponentiation unitary: one 'modExpStep' per qubit of the
-- exponent register, taken in list order with step indices 0, 1, 2, ...,
-- sequenced with 'mappend' and terminated by 'mempty'.
modExp :: Int -> Int -> QInt -> QInt -> U
modExp n a (QInt x) o = foldr step mempty (zip x [0 ..])
  where
    step (qbit, p) rest = modExpStep qbit n a o p `mappend` rest

--a^x mod N

-- | Test harness for 'modExp': creates the exponent register from @x@ and an
-- accumulator from 1, applies @modExp n a@, and returns the measured
-- accumulator.
modExpt :: Int -> (Int,Int) -> QIO Int
modExpt n (a,x) = do
    exponentReg <- mkQ x
    acc <- mkQ 1
    applyU (modExp n a exponentReg acc)
    measQ acc