{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE IncoherentInstances #-}
{-# LANGUAGE EmptyDataDecls #-}

{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}

-- | This module provides fixed but arbitrary sized vectors and
-- matrices. The dimensions of the vectors and matrices are determined
-- by the type, for example,
-- 
-- > Matrix Two Three Complex
-- 
-- for complex 2×3-matrices. The type system ensures that there are no
-- run-time dimension errors.

module Quantum.Synthesis.Matrix where

import Quantum.Synthesis.Ring

-- ----------------------------------------------------------------------
-- * Type-level natural numbers
  
-- $ Note: with the DataKinds extension of GHC 7.4.2, this could be
-- replaced by a tighter definition; however, the following works just
-- fine in GHC 7.2.

-- | Type-level representation of zero. This is an empty data type; it
-- is used only as a type index, never as a value.
data Zero

-- | Type-level representation of successor: @Succ n@ stands for the
-- natural number /n/+1. Like 'Zero', it is an empty data type used
-- only as a type index.
data Succ a

-- | The type-level natural number 1.
type One = Succ Zero

-- | The type-level natural number 2.
type Two = Succ One

-- | The type-level natural number 3.
type Three = Succ Two

-- | The type-level natural number 4.
type Four = Succ Three

-- | The type-level natural number 5.
type Five = Succ Four

-- | The type-level natural number 6.
type Six = Succ Five

-- | The type-level natural number 7.
type Seven = Succ Six

-- | The type-level natural number 8.
type Eight = Succ Seven

-- | The type-level natural number 9.
type Nine = Succ Eight

-- | The type-level natural number 10.
type Ten = Succ Nine

-- | @Ten_and a@ is /a/+10, i.e., 'Succ' applied ten times to /a/.
-- This makes larger numerals easy to write; for instance, 18 is
-- 
-- > Ten_and Eight
type Ten_and a = Succ (Succ (Succ (Succ (Succ (Succ (Succ (Succ (Succ (Succ a)))))))))

-- | Term-level singletons for the type-level natural numbers. For a
-- type-level natural number /n/, the type
-- 
-- > NNat n
-- 
-- contains exactly one (non-bottom) value, namely the representation
-- of /n/ itself. Pattern matching on it lets a function recurse on
-- the type-level number.
data NNat n where
  Zero :: NNat Zero
  Succ :: (Nat n) => NNat n -> NNat (Succ n)

-- | Turn an 'NNat' singleton into the 'Integer' it represents, by
-- counting its 'Succ' layers.
fromNNat :: NNat n -> Integer
fromNNat = count 0 where
  count :: Integer -> NNat k -> Integer
  count acc Zero = acc
  count acc (Succ k) = count (acc + 1) k

-- Singletons are shown as the plain integer they represent.
instance Show (NNat n) where
  show n = show (fromNNat n)

-- | The class of type-level natural numbers. Its members are exactly
-- the type-level natural numbers built from 'Zero' and 'Succ'.
class Nat n where
  -- | The term-level singleton ('NNat') for this type-level natural
  -- number.
  nnat :: NNat n

  -- | The value of this type-level natural number as an 'Integer'.
  -- The argument only selects the type; it is not evaluated, so a
  -- dummy argument may be passed.
  nat :: n -> Integer
  
-- Zero is the base case: its singleton is 'Zero' and its value is 0.
instance Nat Zero where
  nnat = Zero
  nat _ = 0

-- The successor of a natural number. The type variable @a@ from the
-- instance head is in scope (ScopedTypeVariables), so the predecessor
-- is reached via a dummy value of type @a@; no argument is evaluated.
instance (Nat a) => Nat (Succ a) where
  nnat = Succ nnat
  nat _ = 1 + nat (undefined :: a)

-- | Type-level addition of natural numbers, defined by recursion on
-- the first argument.
type family Plus n m
type instance Plus Zero m = m
type instance Plus (Succ n) m = Succ (Plus n m)

-- | Type-level multiplication of natural numbers, defined by
-- recursion on the first argument in terms of 'Plus'.
type family Times n m
type instance Times Zero m = Zero
type instance Times (Succ n) m = Plus m (Times n m)

-- ----------------------------------------------------------------------
-- * Fixed-length vectors

-- | @Vector /n/ /a/@ is the type of lists of exactly /n/ elements of
-- type /a/. It is called a \"vector\" rather than a tuple or a list
-- because it is homogeneous (every element has the same type) and
-- strict (both constructor fields are strict, so if any component is
-- undefined, the whole vector is undefined).
data Vector n a where
  Nil :: Vector Zero a
  Cons :: !a -> !(Vector n a) -> Vector (Succ n) a

infixr 5 `Cons`

-- Two vectors of the same length are equal when their elements are
-- pairwise equal (compared front to back, stopping at the first
-- mismatch).
instance (Eq a) => Eq (Vector n a) where
  v == w = list_of_vector v == list_of_vector w

-- A vector is shown as the expression @vector [...]@, parenthesized
-- when it appears as a function argument.
instance (Show a) => Show (Vector n a) where
  showsPrec d v =
    showParen (d > 10) (showString "vector " . shows (list_of_vector v))

-- A vector converts to dyadic form exactly when every component does.
instance (ToDyadic a b) => ToDyadic (Vector n a) (Vector n b) where
  maybe_dyadic = vector_sequence . vector_map maybe_dyadic

-- The whole-part conversions act componentwise on vectors.
instance (WholePart a b) => WholePart (Vector n a) (Vector n b) where
  from_whole v = vector_map from_whole v
  to_whole v = vector_map to_whole v
  
-- The denominator exponent of a vector is taken from the list of its
-- components; factoring applies componentwise with the same exponent.
instance (DenomExp a) => DenomExp (Vector n a) where
  denomexp = denomexp . list_of_vector
  denomexp_factor v k = vector_map (`denomexp_factor` k) v
  
-- | The one-element vector containing the given value.
vector_singleton :: a -> Vector One a
vector_singleton x = Cons x Nil

-- | The length of a vector, read off from its type. The vector itself
-- is never evaluated, so a dummy (undefined) argument is fine.
vector_length :: (Nat n) => Vector n a -> Integer
vector_length v = nat (lengthType v) where
  lengthType :: Vector n a -> n
  lengthType _ = undefined

-- | Forget the length index: turn a vector into an ordinary list with
-- the same elements in the same order.
list_of_vector :: Vector n a -> [a]
list_of_vector v = case v of
  Nil -> []
  Cons x rest -> x : list_of_vector rest

-- | Combine two vectors of the same length elementwise with a binary
-- function. The type guarantees the lengths agree.
vector_zipwith :: (a -> b -> c) -> Vector n a -> Vector n b -> Vector n c
vector_zipwith _ Nil Nil = Nil
vector_zipwith f (Cons x xs) (Cons y ys) = f x y `Cons` vector_zipwith f xs ys

-- | Apply a function to every element of a vector, preserving its
-- length.
vector_map :: (a -> b) -> Vector n a -> Vector n b
vector_map _ Nil = Nil
vector_map f (Cons x rest) = f x `Cons` vector_map f rest

-- | The vector (0, 1, …, /n/-1), whose length /n/ is taken from the
-- type.
vector_enum :: (Num a, Nat n) => Vector n a
vector_enum = count nnat 0 where
  count :: (Num b) => NNat k -> b -> Vector k b
  count Zero _ = Nil
  count (Succ k) start = start `Cons` count k (start + 1)

-- | Tabulate a function: the vector (/f/(0), /f/(1), …, /f/(/n/-1)),
-- with /n/ taken from the type.
vector_of_function :: (Num a, Nat n) => (a -> b) -> Vector n b
vector_of_function f = vector_map f indices where
  indices = vector_enum

-- | Build a vector from a list. The length of the result is a
-- type-level number, so it cannot be inferred from the list; it must
-- be fixed by the type (e.g., by an annotation). Calls 'error' if the
-- list does not have exactly that many elements.
vector :: (Nat n) => [a] -> Vector n a
vector xs = fill nnat xs where
  fill :: NNat k -> [b] -> Vector k b
  fill Zero [] = Nil
  fill (Succ k) (y:ys) = y `Cons` fill k ys
  fill _ _ = error "vector: length mismatch"

-- | The element at 0-based position /i/ of a vector. Calls 'error' if
-- /i/ is negative or not less than the length.
vector_index :: (Integral i) => Vector n a -> i -> a
vector_index v i = elems !! fromIntegral i where
  elems = list_of_vector v

-- | Return a fixed-length list consisting of a repetition of the
-- given element. Unlike 'replicate', no count is needed, because this
-- information is already contained in the type. However, the type
-- must of course be inferable from the context.
vector_repeat :: (Nat n) => a -> Vector n a
vector_repeat x = go nnat x where
  -- Recurse directly on the singleton for n. This avoids building an
  -- index vector whose numeric type would otherwise be left to type
  -- defaulting.
  go :: NNat k -> b -> Vector k b
  go Zero _ = Nil
  go (Succ k) y = Cons y (go k y)

-- | Transpose a vector of vectors: a vector of /n/ columns of length
-- /m/ becomes a vector of /m/ rows of length /n/.
vector_transpose :: (Nat m) => Vector n (Vector m a) -> Vector m (Vector n a)
vector_transpose Nil = vector_repeat Nil
vector_transpose (Cons col cols) = vector_zipwith Cons col rest where
  rest = vector_transpose cols

-- | Strict left fold over a fixed-length list. The accumulator is
-- forced to weak head normal form after each step (like 'foldl''), so
-- no chain of unevaluated thunks builds up.
vector_foldl :: (a -> b -> a) -> a -> Vector n b -> a
vector_foldl f x0 v = go x0 (list_of_vector v) where
  go acc [] = acc
  go acc (y:ys) = let acc' = f acc y in acc' `seq` go acc' ys

-- | Right fold over the elements of a fixed-length list.
vector_foldr :: (a -> b -> b) -> b -> Vector n a -> b
vector_foldr f z = foldr f z . list_of_vector

-- | Drop the first element of a non-empty vector. Because the input
-- type has length 'Succ' /n/, this cannot fail.
vector_tail :: Vector (Succ n) a -> Vector n a
vector_tail (Cons _ rest) = rest

-- | The first element of a non-empty vector. Because the input type
-- has length 'Succ' /n/, this cannot fail.
vector_head :: Vector (Succ n) a -> a
vector_head (Cons x _) = x

-- | Concatenate two vectors; the result length is the type-level sum
-- of the input lengths.
vector_append :: Vector n a -> Vector m a -> Vector (n `Plus` m) a
vector_append Nil w = w
vector_append (Cons x rest) w = x `Cons` vector_append rest w

-- | Run a vector of monadic actions from first to last and collect
-- their results in a vector of the same length (the analogue of
-- 'sequence').
vector_sequence :: (Monad m) => Vector n (m a) -> m (Vector n a)
vector_sequence Nil = return Nil
vector_sequence (Cons act acts) =
  act >>= \x ->
  vector_sequence acts >>= \xs ->
  return (x `Cons` xs)

-- ----------------------------------------------------------------------
-- * Matrices

-- | An /m/×/n/-matrix, stored as a vector of /n/ columns, each of which
-- is a vector of /m/ scalars. The type of square matrices of any fixed
-- dimension is an instance of the 'Ring' class, so the usual symbols,
-- such as \"'+'\" and \"'*'\", can be used on them. For non-square
-- matrices, the symbols \"'.+.'\" and \"'.*.'\" must be used instead.
data Matrix m n a =
  Matrix !(Vector n (Vector m a))
  deriving (Eq)

-- A matrix is shown as the expression @matrix [rows...]@, parenthesized
-- when it appears as a function argument.
instance (Nat m, Show a) => Show (Matrix m n a) where
  showsPrec d mat =
    showParen (d > 10) (showString "matrix " . shows (rows_of_matrix mat))
  
-- The three instances below overlap the generic @Show (Matrix m n a)@
-- instance (the module enables IncoherentInstances). For these
-- particular scalar types, printing is delegated to
-- 'showsPrec_DenomExp'.
instance (Nat m) => Show (Matrix m n DRootTwo) where
  showsPrec d = showsPrec_DenomExp d

instance (Nat m) => Show (Matrix m n DRComplex) where
  showsPrec d = showsPrec_DenomExp d

instance (Nat m) => Show (Matrix m n DOmega) where
  showsPrec d = showsPrec_DenomExp d
  
-- A matrix converts to dyadic form exactly when its column vector
-- does; the result is rewrapped as a matrix.
instance (ToDyadic a b) => ToDyadic (Matrix m n a) (Matrix m n b) where
  maybe_dyadic (Matrix cols) = maybe_dyadic cols >>= return . Matrix

-- The whole-part conversions act on the underlying column vector.
instance (WholePart a b) => WholePart (Matrix m n a) (Matrix m n b) where
  from_whole (Matrix cols) = Matrix $ from_whole cols
  to_whole (Matrix cols) = Matrix $ to_whole cols

-- Denominator exponents are computed on the underlying column vector.
instance (DenomExp a) => DenomExp (Matrix m n a) where
  denomexp (Matrix cols) = denomexp cols
  denomexp_factor (Matrix cols) k = Matrix $ denomexp_factor cols k

-- | The underlying vector of columns of a matrix.
unMatrix :: Matrix m n a -> Vector n (Vector m a)
unMatrix (Matrix cols) = cols

-- | The dimensions (/m/, /n/) of a matrix: /m/ rows and /n/ columns.
-- Both are read off from the type, so the matrix argument is never
-- evaluated and may be a dummy (undefined) value.
matrix_size :: (Nat m, Nat n) => Matrix m n a -> (Integer, Integer)
matrix_size mat = (nat (rowType mat), nat (colType mat)) where
  rowType :: Matrix m n a -> m
  rowType _ = undefined
  colType :: Matrix m n a -> n
  colType _ = undefined

-- ----------------------------------------------------------------------
-- ** Basic matrix operations

-- | Entrywise sum of two /m/×/n/-matrices. A separate symbol is used
-- because only square matrices form a ring; for /n/×/n/-matrices the
-- ordinary \"'+'\" works as well.
(.+.) :: (Num a) => Matrix m n a -> Matrix m n a -> Matrix m n a
Matrix xs .+. Matrix ys = Matrix (vector_zipwith (vector_zipwith (+)) xs ys)

infixl 6 .+.

-- | Entrywise difference of two /m/×/n/-matrices. A separate symbol is
-- used because only square matrices form a ring; for /n/×/n/-matrices
-- the ordinary \"'-'\" works as well.
(.-.) :: (Num a) => Matrix m n a -> Matrix m n a -> Matrix m n a
Matrix xs .-. Matrix ys = Matrix (vector_zipwith (vector_zipwith (-)) xs ys)

infixl 6 .-.

-- | Apply a function to every entry of a matrix.
matrix_map :: (a -> b) -> Matrix m n a -> Matrix m n b
matrix_map f (Matrix cols) = Matrix (vector_map (vector_map f) cols)

-- | The matrix whose entry in row /i/ and column /j/ is the pair
-- (/i/,/j/), with 0-based indices; the top left entry is (0,0).
matrix_enum :: (Num a, Nat n, Nat m) => Matrix m n (a,a)
matrix_enum = Matrix (vector_of_function column) where
  -- column c holds the pairs (r, c) for every row index r
  column c = vector_of_function (\r -> (r, c))

-- | Tabulate a function of two indices: the entry in row /i/ and
-- column /j/ is @f i j@, with 0-based indices (the top left entry is
-- @f 0 0@).
matrix_of_function :: (Num a, Nat n, Nat m) => (a -> a -> b) -> Matrix m n b
matrix_of_function f = matrix_map (\(i, j) -> f i j) matrix_enum

-- | Multiply every entry of an /m/×/n/-matrix on the left by a scalar.
scalarmult :: (Num a) => a -> Matrix m n a -> Matrix m n a
scalarmult x = matrix_map (x *)

infixl 7 `scalarmult`

-- | Divide every entry of an /m/×/n/-matrix by a scalar.
scalardiv :: (Fractional a) => Matrix m n a -> a -> Matrix m n a
scalardiv m x = matrix_map (\e -> e / x) m

infixl 7 `scalardiv`

-- | Product of an /m/×/n/-matrix and an /n/×/p/-matrix. A separate
-- symbol is used because only square matrices form a ring; for
-- /n/×/n/-matrices the ordinary \"'*'\" works as well.
(.*.) :: (Num a, Nat m) => Matrix m n a -> Matrix n p a -> Matrix m p a
Matrix a .*. Matrix b = Matrix (vector_map column b) where
  -- Column k of the product is the sum of the columns of the left
  -- factor, each scaled (on the left) by the matching entry of column
  -- k of the right factor. The terms are summed right-to-left; a
  -- single term is returned unchanged, and an empty sum is the zero
  -- vector.
  column ks = case zipWith scale (list_of_vector ks) (list_of_vector a) of
    [] -> vector_repeat 0
    terms -> foldr1 (vector_zipwith (+)) terms
  scale k v = vector_map (k *) v

infixl 7 .*.

-- | The all-zero matrix; its dimensions are taken from the type.
null_matrix :: (Num a, Nat n, Nat m) => Matrix m n a
null_matrix = Matrix (vector_repeat zero_column) where
  zero_column = vector_repeat 0

-- | The transpose of an /m/×/n/-matrix, an /n/×/m/-matrix.
matrix_transpose :: (Nat m) => Matrix m n a -> Matrix n m a
matrix_transpose (Matrix cols) = Matrix (vector_transpose cols)

-- | The adjoint (conjugate transpose) of an /m/×/n/-matrix: 'adj' is
-- applied to every entry and the result is transposed. Unlike 'adj'
-- on matrices, this also works for non-square matrices.
adjoint :: (Nat m, Adjoint a) => Matrix m n a -> Matrix n m a
adjoint mat = matrix_transpose (matrix_map adj mat)
  
-- | The entry in row /i/ and column /j/ of a matrix, with 0-based
-- indices. Calls 'error' if either index is out of range.
matrix_index :: (Integral i) => Matrix m n a -> i -> i -> a
matrix_index (Matrix cols) i j = vector_index (vector_index cols j) i

-- | All entries of a matrix as a single list. Callers should not rely
-- on the particular order.
matrix_entries :: Matrix m n a -> [a]
matrix_entries (Matrix cols) = concatMap list_of_vector (list_of_vector cols)

-- | Run a matrix of monadic actions (column by column, top to bottom)
-- and collect the results in a matrix of the same shape.
matrix_sequence :: (Monad m) => Matrix n p (m a) -> m (Matrix n p a)
matrix_sequence (Matrix cols) =
  vector_sequence (vector_map vector_sequence cols) >>= return . Matrix

-- | The trace (sum of the diagonal entries) of a square matrix.
tr :: (Ring a) => Matrix n n a -> a
tr (Matrix cols) = diagonal_sum cols where
  -- Take the top entry of the first column, then drop the first row
  -- from the remaining columns and recurse on that smaller square.
  diagonal_sum :: (Num b) => Vector k (Vector k b) -> b
  diagonal_sum Nil = 0
  diagonal_sum (Cons (Cons d _) rest) = d + diagonal_sum (vector_map vector_tail rest)

-- | Return the square of the Hilbert-Schmidt norm of an
-- /m/×/n/-matrix, defined by ‖/M/‖² = tr /M/[sup †]/M/. It is computed
-- here as tr /M//M/[sup †]. By the cyclic property of the trace this
-- has the same value, and it only needs the row count to be a known
-- 'Nat'.
hs_sqnorm :: (Ring a, Adjoint a, Nat n) => Matrix n m a -> a
hs_sqnorm mat = tr (mat .*. adjoint mat)

-- ----------------------------------------------------------------------
-- Class instances for the ring of square matrices

-- Square matrices of a fixed dimension form a ring: addition,
-- subtraction and multiplication are the matrix operations, and an
-- integer literal x stands for x times the identity matrix.
instance (Num a, Nat n) => Num (Matrix n n a) where
  (+) = (.+.)
  (*) = (.*.)
  negate = scalarmult (-1)
  (-) = (.-.)
  -- The index type of matrix_of_function is pinned to Integer so this
  -- does not depend on implicit type defaulting.
  fromInteger x = matrix_of_function (\i j -> if i == (j :: Integer) then fromInteger x else 0)
  -- NOTE(review): abs and signum have no natural meaning for matrices;
  -- these are placeholders (abs is the identity, signum is the
  -- identity matrix).
  abs a = a
  signum _ = 1
        
-- The adjoint of a square matrix is its conjugate transpose, which is
-- exactly what 'adjoint' computes.
instance (Nat n, Adjoint a) => Adjoint (Matrix n n a) where
  adj mat = adjoint mat

-- 'adj2' acts entrywise on square matrices, without transposing.
instance (Nat n, Adjoint2 a) => Adjoint2 (Matrix n n a) where
  adj2 mat = matrix_map adj2 mat

-- Each distinguished ring constant is represented by the corresponding
-- scalar times the identity matrix (the literal 1 at matrix type).
instance (HalfRing a, Nat n) => HalfRing (Matrix n n a) where
  half = half `scalarmult` 1

instance (RootHalfRing a, Nat n) => RootHalfRing (Matrix n n a) where
  roothalf = roothalf `scalarmult` 1

instance (RootTwoRing a, Nat n) => RootTwoRing (Matrix n n a) where
  roottwo = roottwo `scalarmult` 1

instance (ComplexRing a, Nat n) => ComplexRing (Matrix n n a) where
  i = i `scalarmult` 1

-- ----------------------------------------------------------------------
-- ** Operations on block matrices

-- | Place the first matrix on top of the second; both must have the
-- same number of columns.
stack_vertical :: Matrix m n a -> Matrix p n a -> Matrix (m `Plus` p) n a
stack_vertical (Matrix top) (Matrix bottom) =
  Matrix (vector_zipwith vector_append top bottom)

-- | Place the first matrix to the left of the second; both must have
-- the same number of rows.
stack_horizontal :: Matrix m n a -> Matrix m p a -> Matrix m (n `Plus` p) a
stack_horizontal (Matrix left) (Matrix right) = Matrix (left `vector_append` right)
  
-- | Stack copies of a matrix vertically, the /k/th copy scaled by the
-- /k/th entry of the given vector.
tensor_vertical :: (Num a, Nat n) => Vector p a -> Matrix m n a -> Matrix (p `Times` m) n a
tensor_vertical v m = concat_vertical (vector_map (\x -> x `scalarmult` m) v)
                               
-- | Stack a vector of equally-shaped matrices on top of each other,
-- first one at the top.
concat_vertical :: (Num a, Nat n) => Vector p (Matrix m n a) -> Matrix (p `Times` m) n a
concat_vertical Nil = null_matrix
concat_vertical (Cons top rest) = top `stack_vertical` concat_vertical rest

-- | Place copies of a matrix side by side, the /k/th copy scaled by
-- the /k/th entry of the given vector.
tensor_horizontal :: (Num a, Nat m) => Vector p a -> Matrix m n a -> Matrix m (p `Times` n) a
tensor_horizontal v m = concat_horizontal (vector_map (\x -> x `scalarmult` m) v)
  
-- | Place a vector of equally-shaped matrices side by side, first one
-- on the left.
concat_horizontal :: (Num a, Nat m) => Vector p (Matrix m n a) -> Matrix m (p `Times` n) a
concat_horizontal Nil = null_matrix
concat_horizontal (Cons left rest) = left `stack_horizontal` concat_horizontal rest

-- | Kronecker (tensor) product of two matrices: each entry of the
-- first matrix is replaced by that scalar times the second matrix.
tensor :: (Num a, Nat n, Nat (p `Times` m)) => Matrix p q a -> Matrix m n a -> Matrix (p `Times` m) (q `Times` n) a
tensor a b = concat_horizontal (vector_map concat_vertical blocks) where
  -- blocks: for each column of a, the column of scaled copies of b
  blocks = unMatrix (matrix_map (`scalarmult` b) a)

-- | Direct sum: the block-diagonal matrix with the first argument in
-- the top left, the second in the bottom right, and zeros elsewhere.
oplus :: (Num a, Nat m, Nat q, Nat n, Nat p) => Matrix p q a -> Matrix m n a -> Matrix (p `Plus` m) (q `Plus` n) a
oplus (a :: Matrix p q a) (b :: Matrix m n a) =
  let left_columns = a `stack_vertical` (null_matrix :: Matrix m q a)
      right_columns = (null_matrix :: Matrix p n a) `stack_vertical` b
  in left_columns `stack_horizontal` right_columns

-- | The controlled version of an /n/×/n/ operator: the direct sum of
-- the /n/×/n/ identity matrix and the operator.
matrix_controlled :: (Eq a, Num a, Nat n) => Matrix n n a -> Matrix (n `Plus` n) (n `Plus` n) a
matrix_controlled m = identity_like m `oplus` m where
  -- the identity matrix of the same type as the argument
  identity_like :: (Num b, Nat k) => Matrix k k b -> Matrix k k b
  identity_like _ = 1

-- ----------------------------------------------------------------------
-- ** Constructors and destructors

-- | Shorthand for the type of 2×2-matrices with entries in /a/.
type U2 a = Matrix Two Two a

-- | Shorthand for the type of 3×3-matrices with entries in /a/.
type SO3 a = Matrix Three Three a

-- | Build a matrix from a list of its columns.
-- 
-- The dimensions are type-level numbers, so they cannot be inferred
-- from the input lists and must be fixed by the type. Calls 'error'
-- if the input does not have exactly those dimensions.
matrix_of_columns :: (Nat n, Nat m) => [[a]] -> Matrix n m a
matrix_of_columns = Matrix . vector . map vector

-- | Build a matrix from a list of its rows.
-- 
-- The dimensions are type-level numbers, so they cannot be inferred
-- from the input lists and must be fixed by the type. Calls 'error'
-- if the input does not have exactly those dimensions.
matrix_of_rows :: (Nat n, Nat m) => [[a]] -> Matrix n m a
matrix_of_rows rows = matrix_transpose (matrix_of_columns rows)

-- | Build a matrix from a list of rows; the same as 'matrix_of_rows'.
matrix :: (Nat n, Nat m) => [[a]] -> Matrix n m a
matrix rows = matrix_of_rows rows

-- | The columns of a matrix, each as an ordinary list.
columns_of_matrix :: Matrix n m a -> [[a]]
columns_of_matrix (Matrix cols) = [ list_of_vector c | c <- list_of_vector cols ]

-- | The rows of a matrix, each as an ordinary list.
rows_of_matrix :: (Nat n) => Matrix n m a -> [[a]]
rows_of_matrix mat = columns_of_matrix (matrix_transpose mat)

-- | Build a 2×2-matrix from its two rows.
matrix2x2 :: (a, a) -> (a, a) -> Matrix Two Two a
matrix2x2 (a, b) (c, d) = Matrix (left `Cons` right `Cons` Nil) where
  left = a `Cons` c `Cons` Nil
  right = b `Cons` d `Cons` Nil

-- | Split a 2×2-matrix into its two rows; the inverse of 'matrix2x2'.
from_matrix2x2 :: Matrix Two Two a -> ((a, a), (a, a))
from_matrix2x2 (Matrix (Cons (Cons a (Cons c Nil)) (Cons (Cons b (Cons d Nil)) Nil))) =
  ((a, b), (c, d))

-- | Build a 3×3-matrix from its three rows.
matrix3x3 :: (a, a, a) -> (a, a, a) -> (a, a, a) -> Matrix Three Three a
matrix3x3 (a0, a1, a2) (b0, b1, b2) (c0, c1, c2) =
  matrix_of_rows [[a0, a1, a2], [b0, b1, b2], [c0, c1, c2]]

-- | Build a 4×4-matrix from its four rows.
matrix4x4 :: (a, a, a, a) -> (a, a, a, a) -> (a, a, a, a) -> (a, a, a, a) -> Matrix Four Four a
matrix4x4 (a0, a1, a2, a3) (b0, b1, b2, b3) (c0, c1, c2, c3) (d0, d1, d2, d3) =
  matrix_of_rows
    [ [a0, a1, a2, a3]
    , [b0, b1, b2, b3]
    , [c0, c1, c2, c3]
    , [d0, d1, d2, d3]
    ]

-- | Build a 3-dimensional column vector (a 3×1-matrix) from a triple.
column3 :: (a, a, a) -> Matrix Three One a
column3 (a, b, c) = column_matrix (a `Cons` b `Cons` c `Cons` Nil)

-- | Extract the three entries of a 3-dimensional column vector; the
-- inverse of 'column3'.
from_column3 :: Matrix Three One a -> (a, a, a)
from_column3 (Matrix (Cons (Cons a (Cons b (Cons c Nil))) Nil)) = (a, b, c)

-- | View a vector as a matrix with a single column.
column_matrix :: Vector n a -> Matrix n One a
column_matrix = Matrix . vector_singleton

-- ----------------------------------------------------------------------
-- ** Particular matrices

-- | The controlled-not gate as a 4×4-matrix: it swaps the last two
-- basis states and fixes the first two.
cnot :: (Num a) => Matrix Four Four a
cnot = matrix_of_rows
  [ [1, 0, 0, 0]
  , [0, 1, 0, 0]
  , [0, 0, 0, 1]
  , [0, 0, 1, 0]
  ]

-- | The swap gate as a 4×4-matrix: it exchanges the two middle basis
-- states and fixes the first and last.
swap :: (Num a) => Matrix Four Four a
swap = matrix_of_rows
  [ [1, 0, 0, 0]
  , [0, 0, 1, 0]
  , [0, 1, 0, 0]
  , [0, 0, 0, 1]
  ]

-- | The /z/-rotation gate /R/[sub /z/](θ) = [exp −/i/θ/Z/\/2]. This is
-- the diagonal matrix with entries /u/ = cos(θ\/2) − /i/ sin(θ\/2) and
-- its adjoint.
zrot :: (Eq r, Floating r, Adjoint r) => r -> Matrix Two Two (Cplx r)
zrot theta = matrix2x2 (u, 0) (0, adj u) where
  phi = theta / 2
  u = Cplx (cos phi) (negate (sin phi))