{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses #-}
-----------------------------------------------------------------------------
-- |
-- Module     : Data.Vector.Dense.Internal
-- Copyright  : Copyright (c) 2008, Patrick Perry <patperry@stanford.edu>
-- License    : BSD3
-- Maintainer : Patrick Perry <patperry@stanford.edu>
-- Stability  : experimental
--
module Data.Vector.Dense.Internal (
    -- * The Vector type
    Vector(..),

    -- * Vector shape
    dim,
    coerceVector,
    module BLAS.Tensor.Base,

    -- * Conjugating vectors
    module BLAS.Conj,

    -- * Creating new vectors
    vector, 
    listVector,
    unsafeVector,

    -- * Reading vector elements
    module BLAS.Tensor.Immutable,

    -- * Special vectors
    zeroVector,
    constantVector,
    basisVector,

    -- * Vector views
    subvector,
    subvectorWithStride,
    unsafeSubvector,
    unsafeSubvectorWithStride,

    -- * Vector properties
    sumAbs,
    norm2,
    whichMaxAbs,
    (<.>),

    -- * Low-level vector properties
    stride,
    isConj,
    ) where

import Data.AEq
import System.IO.Unsafe

import BLAS.Conj
import BLAS.Tensor.Base
import BLAS.Tensor.Immutable

import BLAS.Elem ( BLAS1 )
import BLAS.Internal ( inlinePerformIO )
import BLAS.UnsafeIOToM

import Data.Vector.Dense.IO
import Data.Vector.Dense.Class


infixl 7 <.>


newtype Vector n e = V (IOVector n e)


unsafeFreezeIOVector :: IOVector n e -> Vector n e
unsafeFreezeIOVector = V
unsafeThawIOVector :: Vector n e -> IOVector n e
unsafeThawIOVector (V x) = x


liftVector :: (IOVector n e -> a) -> Vector n e -> a
liftVector f (V x) = f x
{-# INLINE liftVector #-}

liftVector2 :: 
    (IOVector n e -> IOVector n e -> a) -> 
        Vector n e -> Vector n e -> a
liftVector2 f x = liftVector (liftVector f x)
{-# INLINE liftVector2 #-}

unsafeLiftVector :: (IOVector n e -> IO a) -> Vector n e -> a
unsafeLiftVector f = unsafePerformIO . liftVector f
{-# NOINLINE unsafeLiftVector #-}

unsafeLiftVector2 :: 
    (IOVector n e -> IOVector n e -> IO a) -> 
        Vector n e -> Vector n e -> a
unsafeLiftVector2 f x y = unsafePerformIO $ liftVector2 f x y
{-# NOINLINE unsafeLiftVector2 #-}

inlineLiftVector :: (IOVector n e -> IO a) -> Vector n e -> a
inlineLiftVector f = inlinePerformIO . liftVector f
{-# INLINE inlineLiftVector #-}

-- | Create a vector with the given dimension and elements.  The elements
-- given in the association list must all have unique indices, otherwise
-- the result is undefined.
vector :: (BLAS1 e) => Int -> [(Int, e)] -> Vector n e
vector n ies = unsafeFreezeIOVector $ unsafePerformIO $ newVector n ies
{-# NOINLINE vector #-}

-- | Same as 'vector', but does not range-check the indices.
unsafeVector :: (BLAS1 e) => Int -> [(Int, e)] -> Vector n e
unsafeVector n ies = unsafeFreezeIOVector $ unsafePerformIO $ unsafeNewVector n ies
{-# NOINLINE unsafeVector #-}

-- | Create a vector of the given dimension with elements initialized
-- to the values from the list.  @listVector n es@ is equivalent to 
-- @vector n (zip [0..(n-1)] es)@, except that the result is undefined 
-- if @length es@ is less than @n@.
listVector :: (BLAS1 e) => Int -> [e] -> Vector n e
listVector n es = unsafeFreezeIOVector $ unsafePerformIO $ newListVector n es
{-# NOINLINE listVector #-}

-- | @zeroVector n@ creates a vector of dimension @n@ with all values
-- set to zero.
zeroVector :: (BLAS1 e) => Int -> Vector n e
zeroVector n = unsafeFreezeIOVector $ unsafePerformIO $ newZeroVector n
{-# NOINLINE zeroVector #-}

-- | @constantVector n e@ creates a vector of dimension @n@ with all values
-- set to @e@.
constantVector :: (BLAS1 e) => Int -> e -> Vector n e
constantVector n e = unsafeFreezeIOVector $ unsafePerformIO $ newConstantVector n e
{-# NOINLINE constantVector #-}

-- | @basisVector n i@ creates a vector of dimension @n@ with zeros 
-- everywhere but position @i@, where there is a one.
basisVector :: (BLAS1 e) => Int -> Int -> Vector n e
basisVector n i = unsafeFreezeIOVector $ unsafePerformIO $ newBasisVector n i
{-# NOINLINE basisVector #-}

-- | Compute the sum of absolute values of entries in the vector.
sumAbs :: (BLAS1 e) => Vector n e -> Double
sumAbs = unsafeLiftVector getSumAbs
{-# NOINLINE sumAbs #-}

-- | Compute the 2-norm of a vector.
norm2 :: (BLAS1 e) => Vector n e -> Double
norm2 = unsafeLiftVector getNorm2
{-# NOINLINE norm2 #-}

-- | Get the index and norm of the element with absulte value.  Not valid 
-- if any of the vector entries are @NaN@.  Raises an exception if the 
-- vector has length @0@.
whichMaxAbs :: (BLAS1 e) => Vector n e -> (Int, e)
whichMaxAbs = unsafeLiftVector getWhichMaxAbs
{-# NOINLINE whichMaxAbs #-}

-- | Compute the dot product of two vectors.
(<.>) :: (BLAS1 e) => Vector n e -> Vector n e -> e
(<.>) = unsafeLiftVector2 getDot
{-# NOINLINE (<.>) #-}


instance BaseTensor Vector Int where
    shape  = liftVector shape
    bounds = liftVector bounds

instance ITensor Vector Int where
    (//)          = replaceHelp writeElem
    unsafeReplace = replaceHelp unsafeWriteElem
    
    unsafeAt x i  = inlineLiftVector (flip unsafeReadElem i) x
    {-# INLINE unsafeAt #-}
    
    size          = inlineLiftVector getSize
    elems         = inlineLiftVector getElems
    indices       = inlineLiftVector getIndices
    assocs        = inlineLiftVector getAssocs

    tmap f x      = listVector (dim x) (map f $ elems x)

    (*>) k x = unsafeFreezeIOVector $ unsafeLiftVector (getScaledVector k) x
    {-# NOINLINE (*>) #-}

    shift k x = unsafeFreezeIOVector $ unsafeLiftVector (getShiftedVector k) x
    {-# NOINLINE shift #-}


replaceHelp :: (BLAS1 e) => 
    (IOVector n e -> Int -> e -> IO ()) ->
        Vector n e -> [(Int, e)] -> Vector n e
replaceHelp set x ies =
    unsafePerformIO $ do
        y  <- newCopyVector (unsafeThawIOVector x)
        mapM_ (uncurry $ set y) ies
        return (unsafeFreezeIOVector y)
{-# NOINLINE replaceHelp #-}

instance (Monad m) => ReadTensor Vector Int m where
    getSize        = return . size
    getAssocs      = return . assocs
    getIndices     = return . indices
    getElems       = return . elems
    getAssocs'     = return . assocs
    getIndices'    = return . indices
    getElems'      = return . elems
    unsafeReadElem x i = return $ unsafeAt x i
    


instance BaseVector Vector where
    vectorViewArray f o n s c = V $ vectorViewArray f o n s c
    arrayFromVector           = liftVector arrayFromVector

instance (UnsafeIOToM m) => ReadVector Vector m where
    
instance (BLAS1 e) => Num (Vector n e) where
    (+) x y     = unsafeFreezeIOVector $ unsafeLiftVector2 getAddVector x y
    (-) x y     = unsafeFreezeIOVector $ unsafeLiftVector2 getSubVector x y
    (*) x y     = unsafeFreezeIOVector $ unsafeLiftVector2 getMulVector x y
    negate      = ((-1) *>)
    abs         = tmap abs
    signum      = tmap signum
    fromInteger = (constantVector 1) . fromInteger
    
instance (BLAS1 e) => Fractional (Vector n e) where
    (/) x y      = unsafeFreezeIOVector $ unsafeLiftVector2 getDivVector x y
    recip        = tmap recip
    fromRational = (constantVector 1) . fromRational 
    
instance (BLAS1 e, Floating e) => Floating (Vector n e) where
    pi    = constantVector 1 pi
    exp   = tmap exp
    sqrt  = tmap sqrt
    log   = tmap log
    (**)  = tzipWith (**)
    sin   = tmap sin
    cos   = tmap cos
    tan   = tmap tan
    asin  = tmap asin
    acos  = tmap acos
    atan  = tmap atan
    sinh  = tmap sinh
    cosh  = tmap cosh
    tanh  = tmap tanh
    asinh = tmap asinh
    acosh = tmap acosh
    atanh = tmap atanh

tzipWith :: (BLAS1 e) =>
    (e -> e -> e) -> Vector n e -> Vector n e -> Vector n e
tzipWith f x y
    | dim y /= n =
        error ("tzipWith: vector lengths differ; first has length `" ++
                show n ++ "' and second has length `" ++
                show (dim y) ++ "'")
    | otherwise =
        listVector n (zipWith f (elems x) (elems y))
  where
    n = dim x


instance (BLAS1 e, Show e) => Show (Vector n e) where
    show x
        | isConj x  = "conj (" ++ show (conj x) ++ ")"
        | otherwise = "listVector " ++ show (dim x) ++ " " ++ show (elems x)

instance (BLAS1 e, Eq e) => Eq (Vector n e) where
    (==) = compareHelp (==)

instance (BLAS1 e, AEq e) => AEq (Vector n e) where
    (===) = compareHelp (===)
    (~==) = compareHelp (~==)

compareHelp :: (BLAS1 e) => 
    (e -> e -> Bool) -> 
        Vector n e -> Vector n e -> Bool
compareHelp cmp x y
    | isConj x && isConj y =
        compareHelp cmp (conj x) (conj y)
    | otherwise =
        (dim x == dim y) && (and $ zipWith cmp (elems x) (elems y))