{-|
Module      : Quantum.QProgram
Description : Quantum program datatypes and the function that simulates them.
Copyright   : (c) Mihai Sebastian Ardelean, 2024
License     : BSD3
Maintainer  : ardeleanasm@gmail.com
Portability : POSIX
-}
module Quantum.QProgram
  (
    runQProg
  , makeQuantumState
  , Machine(..)
  , QInstruction(..)
  , QProgram(..)
  ) where

import Quantum.QDataTypes
import Quantum.Gates

import Data.List (nub, foldl')

import System.Random (randomRIO)

import qualified Numeric.LinearAlgebra as LA 

import Control.Monad (replicateM)

{-|
A 'Machine' couples a quantum state with a classical measurement register.

Fields:

* 'qstate' – the current quantum state vector.
* 'measurementRegister' – the basis-state index recorded by the most recent
  measurement ('observe' writes it; applying gates leaves it unchanged).
-}
data Machine = Machine
  { qstate              :: State -- ^ Quantum state.
  , measurementRegister :: Int   -- ^ Measurement register.
  } deriving (Eq, Show)


{-|
A 'QInstruction' pairs a quantum gate with the indices of the qubits it
acts on.

Fields:

* 'gateMatrix' – the unitary matrix that defines the quantum gate.
* 'affectedQubits' – the indices of the qubits the gate is applied to.
-}
data QInstruction = QInstruction
  { gateMatrix     :: Gate  -- ^ Quantum gate matrix.
  , affectedQubits :: [Int] -- ^ Indices of the qubits affected by the gate.
  } deriving (Eq, Show)

{-|
A 'QProgram' is an ordered list of quantum instructions. 'runQProg' executes
them from first to last.
-}
data QProgram = QProgram
  { instructions :: [QInstruction] -- ^ List of program instructions.
  } deriving (Eq, Show)

{-|
Build the initial state of an @n@-qubit register.

The result has length @2^n@. Its amplitude at index 0 is 1 and every other
amplitude is 0, i.e. the computational basis state in which every qubit is 0.
-}
makeQuantumState :: Int -> State
makeQuantumState n =
  LA.fromList [ if k == 0 then 1 else 0 | k <- [0 .. 2 ^ n - 1 :: Int] ]


-- | Number of qubits described by a vector or matrix dimension @size@,
-- i.e. @floor (log2 size)@, so that @dimensionQubits (2 ^ n) == n@.
--
-- The result is computed by repeated integer halving instead of
-- 'logBase', so it is exact and immune to floating-point rounding of the
-- logarithm. Sizes below 2 yield 0.
dimensionQubits :: Int -> Int
dimensionQubits size = length . takeWhile (> 1) $ iterate (`div` 2) size

-- | Matrix–vector product: transform a state by a gate.
apply :: Gate -> State -> State
apply gate st = gate LA.#> st

-- | Matrix product @outer · inner@. Applied to a state, the result acts as
-- @inner@ first and then @outer@.
compose :: Gate -> Gate -> Gate
compose outer inner = outer LA.<> inner

-- | Kronecker (tensor) product of two gates, left operand first.
kroneckerMul :: Gate -> Gate -> Gate
kroneckerMul = LA.kronecker

-- | @n@-fold Kronecker power of a gate: @gate kron gate kron ... kron gate@
-- with @n@ factors, grouped from the left.
--
-- For @n < 1@ the result is the 1x1 matrix @[1]@, which is the neutral
-- element of the Kronecker product.
kroneckerExp :: Gate -> Int -> Gate
kroneckerExp gate n
  | n < 1     = (1 LA.>< 1) [1]
  | otherwise = foldl' kroneckerMul gate (replicate (n - 1) gate)


-- | Embed a gate into an @n@-qubit operator, with its lowest qubit at index @i@.
--
-- If the gate acts on @k = dimensionQubits (rows gate)@ qubits, the result is
-- @iGate^(n-i-k) kron gate kron iGate^i@, where powers are Kronecker powers
-- computed by 'kroneckerExp'. This presumably pads with single-qubit
-- identities, assuming 'iGate' is the 2x2 identity (defined in Quantum.Gates).
lift :: Gate -> Int -> Int -> Gate
lift gate i n = kroneckerMul padAbove (kroneckerMul gate padBelow)
  where
    gateQubits :: Int
    gateQubits = dimensionQubits (LA.rows gate)

    -- Identity factors on the qubits above and below the gate.
    padAbove, padBelow :: Gate
    padAbove = kroneckerExp iGate (n - i - gateQubits)
    padBelow = kroneckerExp iGate i

-- | Decompose a permutation of @[0 .. len-1]@ into transpositions.
--
-- For each position @dest@ (in increasing order), the permutation is followed
-- from @permutation !! dest@ for as long as the index stays below @dest@. If
-- that chase stops at an index @src > dest@, the pair @(dest, src)@ is emitted;
-- if it stops at @dest@ itself, nothing is emitted. Every emitted pair
-- therefore has its smaller index first.
--
-- The argument is expected to be a permutation. Otherwise '!!' can fail, or
-- the chase can loop forever.
perm2trans :: [Int] -> [(Int, Int)]
perm2trans permutation =
  nub [ (dest, src)
      | dest <- [0 .. length permutation - 1]
      , let src = chase (permutation !! dest) dest
      , src /= dest
      ]
  where
    -- Follow the permutation until reaching an index at or above the target.
    chase :: Int -> Int -> Int
    chase idx target
      | idx < target = chase (permutation !! idx) target
      | otherwise    = idx

    
-- | Expand transpositions into sequences of adjacent transpositions.
--
-- An adjacent transposition is written as the lower index @k@ of the pair
-- @(k, k+1)@. A transposition of @lo < hi@ expands to the palindromic sequence
--
-- > [lo, lo+1 .. hi-1] ++ [hi-2, hi-3 .. lo]
--
-- The product of these adjacent swaps equals the swap of @lo@ and @hi@. For
-- neighbours (@hi == lo + 1@) this is just @[lo]@.
--
-- The function is total:
--
-- * a pair given with its larger index first is treated like its sorted form;
-- * a degenerate pair @(x, x)@ expands to no swaps.
trans2adj :: [(Int, Int)] -> [Int]
trans2adj = concatMap expandConsecutive
  where
    expandConsecutive :: (Int, Int) -> [Int]
    expandConsecutive (a, b) = [lo .. hi - 1] ++ reverse [lo .. hi - 2]
      where
        lo = min a b
        hi = max a b

-- | Apply a single-qubit gate @u@ to qubit @qubit@ of state @s@. The gate is
-- lifted to the full register size (derived from the state length) and the
-- state is transformed by the result.
apply1Q :: State -> Gate -> Int -> State
apply1Q s u qubit = apply (lift u qubit numQubits) s
  where
    numQubits :: Int
    numQubits = dimensionQubits (LA.size s)



-- | Apply a multi-qubit gate @u@ to the given @qubits@ of state @s@.
--
-- Steps:
--
-- 1. Lift the gate with @lift u 0 n@, so it acts on the lowest qubits.
-- 2. Build the target order @reverse qubits ++ (the other qubits, ascending)@.
-- 3. Decompose that permutation into transpositions ('perm2trans') and then
--    into adjacent swaps ('trans2adj').
-- 4. Transform the state by @toFrom · u01 · fromTo@, where @toFrom@ and
--    @fromTo@ are the products of those adjacent swaps in forward and reverse
--    order.
--
-- Qubit indices must lie in @[0, n)@ for an @n@-qubit state and must be
-- distinct; otherwise 'error' is called. Without the check, duplicate indices
-- make the target order a non-permutation, which can send 'perm2trans' into
-- an infinite loop. Out-of-range indices only fail later, with an opaque
-- matrix-dimension error.
applyNQ :: State -> Gate -> [Int] -> State
applyNQ s u qubits
  | any outOfRange qubits =
      error ("applyNQ: qubit index out of range: " ++ show qubits)
  | length (nub qubits) /= length qubits =
      error ("applyNQ: duplicate qubit index: " ++ show qubits)
  | otherwise = qubitsNState
  where
    n :: Int
    n = dimensionQubits (LA.size s)

    outOfRange :: Int -> Bool
    outOfRange q = q < 0 || q >= n

    -- Adjacent swap of qubits i and i+1 lifted to the full register
    -- (assumes 'swapGate' is the two-qubit SWAP from Quantum.Gates).
    swap :: Int -> Gate
    swap i = lift swapGate i n

    -- Left-to-right product of adjacent swaps; the identity when empty.
    trans2op :: [Int] -> Gate
    trans2op []       = LA.ident (2 ^ n)
    trans2op (t : ts) = foldl' compose (swap t) (map swap ts)

    u01 :: Gate
    u01 = lift u 0 n

    fromSpace :: [Int]
    fromSpace = reverse qubits ++ [i | i <- [0 .. n - 1], i `notElem` qubits]

    adj :: [Int]
    adj = trans2adj (perm2trans fromSpace)

    toFrom, fromTo :: Gate
    toFrom = trans2op adj
    fromTo = trans2op (reverse adj)

    upq :: Gate
    upq = compose toFrom (compose u01 fromTo)

    qubitsNState :: State
    qubitsNState = apply upq s

-- | Apply gate @u@ to the listed qubits of state @s@. A single target qubit
-- uses the direct one-qubit path ('apply1Q'); any other target list goes
-- through the permutation-based multi-qubit path ('applyNQ').
applyGate :: State -> Gate -> [Int] -> State
applyGate s u [qubit] = apply1Q s u qubit
applyGate s u qubits  = applyNQ s u qubits




-- | Draw a measurement outcome (a basis-state index) from a state vector.
--
-- A value @r@ is drawn with @randomRIO (0.0, 1.0)@. The result is the first
-- index @i@ whose cumulative probability @|c_0|^2 + ... + |c_i|^2@ exceeds @r@.
--
-- In floating point, the squared magnitudes of a normalised state often sum
-- to slightly less than 1.0. The sampled range is also closed, so @r@ may be
-- 1.0. @r@ can therefore lie above the final cumulative sum; in that case the
-- last index with non-zero probability is returned instead of failing.
--
-- Calls 'error' with @"Invalid state vector"@ if the state is empty or every
-- amplitude is zero.
sample :: State -> IO Int
sample s = do
  r <- randomRIO (0.0, 1.0)
  return $ sampleIndex (map probability (LA.toList s)) r
  where
    probability :: LA.Complex Double -> Double
    probability c = LA.magnitude c ** 2

    sampleIndex :: [Double] -> Double -> Int
    sampleIndex probs r = go 0 0.0 probs
      where
        go :: Int -> Double -> [Double] -> Int
        go _ _ [] = fallback
        go i accProb (p : ps)
          | r < cumulative = i
          | otherwise      = go (i + 1) cumulative ps
          where
            cumulative = accProb + p

        -- Rounding fallback: the last outcome that has non-zero probability.
        fallback :: Int
        fallback =
          case reverse [k | (k, p) <- zip [0 ..] probs, p > 0] of
            (k : _) -> k
            []      -> error "Invalid state vector"

-- | Replace a state with the basis vector for outcome @i@: a vector of the same
-- length as @st@ with a 1 at index @i@ and 0 everywhere else.
collapse :: State -> Int -> State
collapse st i = LA.fromList (zeros i ++ 1 : zeros (LA.size st - i - 1))
  where
    zeros :: Int -> [LA.Complex Double]
    zeros k = replicate k 0

-- | Measure the whole register.
--
-- An outcome is sampled from the current state, the state collapses to the
-- corresponding basis vector, and the outcome is stored in
-- 'measurementRegister'.
observe :: Machine -> IO Machine
observe machine = record <$> sample st
  where
    st :: State
    st = qstate machine

    record :: Int -> Machine
    record outcome =
      machine { qstate = collapse st outcome, measurementRegister = outcome }


-- | Apply one instruction's gate to the machine's state. The measurement
-- register is carried over unchanged and no measurement is performed.
evolveState :: QInstruction -> Machine -> Machine
evolveState (QInstruction gate targets) m = Machine evolved register
  where
    evolved :: State
    evolved = applyGate (qstate m) gate targets

    register :: Int
    register = measurementRegister m

{-|
Execute a 'QProgram' on a 'Machine'.

Each instruction is applied to the machine's state in program order. The
whole register is then measured once with 'observe', which collapses the
state and records the outcome in 'measurementRegister'.
-}
runQProg :: QProgram -> Machine -> IO Machine
runQProg qprog machine = observe finalMachine
  where
    finalMachine :: Machine
    finalMachine = foldl' (flip evolveState) machine (instructions qprog)