module Quantum.Synthesis.RotationDecomposition where
import Quantum.Synthesis.Matrix
import Quantum.Synthesis.MultiQubitSynthesis
import Quantum.Synthesis.Ring
import Quantum.Synthesis.EulerAngles
import Quantum.Synthesis.ArcTan2
import Data.List
import System.Random
-- | A symbolic elementary rotation.
--
-- * @ERot_zx delta gamma j k@: a two-level rotation, parameterised by
--   the angles @delta@ and @gamma@, acting on the levels @j@ and @k@.
--
-- * @ERot_phase theta j@: a phase change by the angle @theta@, acting
--   on the single level @j@.
--
-- 'matrix_of_elementary' gives the concrete matrix for each
-- constructor.
data ElementaryRot a
  = ERot_zx a a Index Index
  | ERot_phase a Index
  deriving (Show)
-- | Convert a symbolic elementary rotation to a concrete matrix.
--
-- * @ERot_zx delta gamma j k@ becomes the two-level matrix, built by
--   'twolevel_matrix' at indices @j@ and @k@, whose 2×2 entries are
--   those of e^(−i·delta·Z\/2) · e^(−i·gamma·X\/2):
--
--   top row:    e^(−i·delta\/2)·cos(gamma\/2),  −i·e^(−i·delta\/2)·sin(gamma\/2)
--
--   bottom row: −i·e^(i·delta\/2)·sin(gamma\/2),  e^(i·delta\/2)·cos(gamma\/2)
--
-- * @ERot_phase theta j@ becomes the one-level matrix, built by
--   'onelevel_matrix' at index @j@, with the scalar
--   cos(theta) + i·sin(theta).
matrix_of_elementary :: (Ring a, Floating a, Nat n) => ElementaryRot a -> Matrix n n (Cplx a)
matrix_of_elementary (ERot_zx delta gamma j k) =
  twolevel_matrix (a, b) (c, d) j k
  where
    a = ed' * cg
    b = -i * ed' * sg
    c = -i * ed * sg
    d = ed * cg
    -- cos(gamma/2) and sin(gamma/2), embedded as real complex numbers.
    cg = Cplx (cos (gamma/2)) 0
    sg = Cplx (sin (gamma/2)) 0
    -- ed = e^(i*delta/2); ed' is its complex conjugate e^(-i*delta/2).
    ed = Cplx cd sd
    ed' = Cplx cd (-sd)
    cd = cos (delta/2)
    sd = sin (delta/2)
matrix_of_elementary (ERot_phase theta j) =
  onelevel_matrix (Cplx c s) j
  where
    c = cos theta
    s = sin theta
-- | Convert a list of elementary rotations to a single matrix.
--
-- Each rotation is converted with 'matrix_of_elementary' and the
-- results are multiplied together from left to right, starting from
-- the identity, so @[g1, g2, g3]@ yields @M(g1) * M(g2) * M(g3)@.
-- The empty list yields the identity matrix.
matrix_of_elementaries :: (Ring a, Floating a, Nat n) => [ElementaryRot a] -> Matrix n n (Cplx a)
matrix_of_elementaries = foldl' (*) 1 . map matrix_of_elementary
-- | Decompose a square matrix into a list of elementary rotations.
--
-- The matrix is reduced by applying 'rowop' to every below-diagonal
-- position @(i, j)@, column by column (@j@ from @0@ to @n-2@, and for
-- each column @i@ from @j+1@ to @n-1@). The rotations reported by the
-- individual 'rowop' steps come first in the result, in the order in
-- which they were produced. They are followed by one 'ERot_phase' per
-- diagonal entry of the reduced matrix, obtained with 'get_phase' and
-- listed from the last index down to the first.
--
-- The result is intended to be fed to 'matrix_of_elementaries' to
-- rebuild the input (see 'test').
--
-- NOTE(review): the input is presumably expected to be unitary, so
-- that the reduced matrix is diagonal with unit-modulus entries; this
-- function does not check it.
rotation_decomposition :: (Eq a, Fractional a, Floating a, Adjoint a, ArcTan2 a, Nat n) => Matrix n n (Cplx a) -> [ElementaryRot a]
rotation_decomposition op = concat gates ++ reverse gates'
  where
    -- Thread the matrix through all row operations, collecting the
    -- list of gates emitted by each step.
    (op', gates) = mapAccumL rowop op [ (i, j) | j <- [0 .. n-2], i <- [j+1 .. n-1] ]
    -- One phase gate per diagonal entry of the reduced matrix.
    gates' = [ get_phase op' i | i <- [0 .. n-1] ]
    -- The dimension of the (square) input matrix.
    (n', _) = matrix_size op
    n = fromInteger n'
-- | Embed a 2×2 matrix as a two-level matrix of arbitrary dimension.
--
-- The two rows of @u@ are extracted with 'from_matrix2x2' and handed
-- to 'twolevel_matrix' together with the indices @j@ and @k@.
twolevel_matrix_of_matrix :: (Ring a, Nat n) => Matrix Two Two a -> Index -> Index -> Matrix n n a
twolevel_matrix_of_matrix u j k = twolevel_matrix top bottom j k
  where
    (top, bottom) = from_matrix2x2 u
  
-- | Read off the phase of the /j/-th diagonal entry of a matrix.
--
-- The entry at position @(j, j)@ is split into its real and imaginary
-- parts, and the angle @arctan2 imaginary real@ is returned as an
-- 'ERot_phase' gate on index @j@. The modulus of the entry is ignored.
get_phase :: (ArcTan2 a) => Matrix n n (Cplx a) -> Index -> ElementaryRot a
get_phase op j = ERot_phase (arctan2 im re) j
  where
    Cplx re im = matrix_index op j j
             
-- | Perform a single row operation, as used by
-- 'rotation_decomposition'.
--
-- Given a matrix @op@ and a pair of indices @(j, k)@, let @a@ be the
-- entry at @(k, k)@ and @b@ the entry at @(j, k)@.
--
-- * If @b@ is zero there is nothing to do: the matrix is returned
--   unchanged, with no gates.
--
-- * Otherwise, form the normalised 2×2 matrix with rows
--   @(adj a, adj b)@ and @(-b, a)@. Applied to the column @(a, b)@,
--   its second row gives @-b*a + a*b = 0@. Its Euler angles are
--   computed, and the rotation rebuilt from only the last two angles
--   is applied to levels @k@ and @j@ of @op@ from the left. The single
--   returned gate, @ERot_zx (-delta) (-gamma) k j@, is meant to be the
--   inverse of that applied rotation, so that
--   @matrix_of_elementaries gates * op'@ gives back @op@.
--
-- NOTE(review): this assumes 'euler_angles' returns
-- @(alpha, beta, gamma, delta)@ with
-- U = e^(i·alpha)·Rz(beta)·Rx(gamma)·Rz(delta), and that
-- 'matrix_of_euler_angles' is its inverse, as in
-- "Quantum.Synthesis.EulerAngles". Dropping @alpha@ and @beta@ then
-- removes only a diagonal factor, which keeps the cleared entry at
-- zero. Confirm against that module.
rowop :: (Eq a, Fractional a, Floating a, Adjoint a, ArcTan2 a, Nat n) => Matrix n n (Cplx a) -> (Index, Index) -> (Matrix n n (Cplx a), [ElementaryRot a])
rowop op (j, k)
  | b == 0 = (op, [])
  | otherwise = (op', gates)
  where
    a = matrix_index op k k
    b = matrix_index op j k
    -- Reciprocal of the Euclidean length of the column (a, b).
    invNorm = 1 / Cplx (sqrt (real (a * adj a + b * adj b))) 0
    -- Unitary 2×2 matrix sending (a, b) to (length, 0).
    matrix = invNorm `scalarmult` matrix2x2 (adj a, adj b) (-b, a)
    -- Only the last two Euler angles are needed.
    (_, _, gamma, delta) = euler_angles matrix
    matrix2 = matrix_of_euler_angles (0, 0, gamma, delta)
    op' = twolevel_matrix_of_matrix matrix2 k j .*. op
    -- Negated angles: the gate undoes the rotation just applied.
    gates = [ ERot_zx (-delta) (-gamma) k j ]
-- | Generate a random matrix as a product of random elementary
-- rotations.
--
-- For an /n/×/n/ result, @20*n^2@ gates are drawn from the generator
-- @g@ and multiplied together with 'matrix_of_elementaries'. Each gate
-- is, with equal probability, either an 'ERot_zx' on a random pair of
-- indices @j < k@ or an 'ERot_phase' on index @j@. All angles are
-- drawn from the range @(0, 2*pi)@.
--
-- Note that @j@ is drawn from @0 .. n-2@ in both cases, so a phase
-- gate is never placed directly on the last index @n-1@.
--
-- NOTE(review): the dimension @n@ is obtained from @matrix_size op@,
-- where @op@ is the result being defined. This presumably works
-- because 'matrix_size' does not inspect the matrix entries; confirm
-- against "Quantum.Synthesis.Matrix".
random_unitary :: (RandomGen g, Nat n, Floating a, Random a) => g -> Matrix n n (Cplx a)
random_unitary g = op
  where
    op = matrix_of_elementaries gates
    gates = random_gates g (20*n^2)

    -- Produce the given number of random gates, threading the
    -- generator through each draw.
    random_gates _ 0 = []
    random_gates gen m = h : t
      where
        (gamma, g1) = randomR (0, 2*pi) gen
        (delta, g1') = randomR (0, 2*pi) g1
        (c, g2) = randomR (0, 1) g1'
        (j, g3) = randomR (0, n-2) g2
        (k, g4) = randomR (j+1, n-1) g3
        h = case c :: Int of
          0 -> ERot_zx delta gamma j k
          _ -> ERot_phase delta j
        t = random_gates g4 (m-1)

    -- The dimension of the (square) result matrix.
    (n', _) = matrix_size op
    n = fromInteger n'
-- | A simple round-trip demonstration: generate a random 4×4 matrix,
-- decompose it into elementary rotations, rebuild a matrix from those
-- rotations, and print all three for manual comparison.
test :: IO ()
test = do
  gen <- newStdGen
  let original = random_unitary gen :: Matrix Four Four CDouble
      gates = rotation_decomposition original
      rebuilt = matrix_of_elementaries gates :: Matrix Four Four CDouble
  mapM_ putStrLn
    [ "m = " ++ show original
    , "gates = " ++ show gates
    , "m' = " ++ show rebuilt
    ]