{-# language GADTs #-}
{-# language DeriveFunctor #-}
{-# language TypeOperators #-}
{-# language PackageImports #-}
{-# language TypeFamilies #-}
{-# language FlexibleInstances #-}
module Data.Tensor (
  -- * Tensor type
  Tensor(..),
  tshape, tdata, nnz, rank, dim, 
  -- * Shape type
  Sh(..),
  -- * Dimension types
  Dim.Dd(..), Dim.Sd(..)) where

import qualified Data.Vector as V
-- import qualified Data.Vector.Unboxed as VU

-- import Data.Int (Int32)

import Control.Applicative


import qualified Data.Shape as Shape (Shape(..), dim, rank)
import Data.Shape (Sh(..), 
                   Z,
                   D1, D2, CSR, COO, mkD2, mkCSR, mkCOO)
import qualified Data.Shape.Dynamic as ShDyn       
import qualified Data.Dim as Dim



-- | A tensor type with dimensions only known at runtime
data Tenzor i a = Tenzor {tzShape :: ShDyn.ShD i, tzData :: V.Vector a }

instance Integral i => Shape.Shape (Tenzor i a) where
  type ShapeT (Tenzor i a) = ShDyn.ShD i
  shape = tzShape 
  shRank = ShDyn.rank . Shape.shape
  shDim = ShDyn.dim . Shape.shape

instance Shape.Shape (Tensor (Sh i) a) where
  type ShapeT (Tensor (Sh i) a) = Sh i
  shape = tshape 
  shRank = rank
  shDim = dim


-- | The 'Tensor' type with statically known shape. Tensor data entries are stored as one single array
data Tensor i a where
  Tensor :: Sh i -> V.Vector a -> Tensor (Sh i) a 

-- | Construct a tensor given a shape and a vector of entries
mkT :: Sh i -> V.Vector a -> Tensor (Sh i) a
mkT = Tensor

instance Functor (Tensor i) where
  fmap f (Tensor sh v) = Tensor sh (f <$> v)

-- liftA2' :: (a -> a -> b) -> Tensor i a -> Tensor i a -> Tensor i a 
-- liftA2' f (T sh1 v1) (T sh2 v2) = mkT sh1 (V.zipWith f v1 v2)

pure' :: a -> Tensor (Sh Z) a
pure' = mkT Z . V.singleton

instance (Eq a) => Eq (Tensor i a) where
  (Tensor sh1 d1) == (Tensor sh2 d2) = sh1 == sh2 && d1 == d2

instance (Show a) => Show (Tensor i a) where
  show (Tensor sh d) = unwords [show sh, show $ V.take 5 d, "..."]


 

-- | Access the shape of a 'Tensor'
tshape :: Tensor sh a -> sh
tshape (Tensor sh _) = sh

-- | Access the raw data of a 'Tensor'
tdata :: Tensor sh a -> V.Vector a
tdata (Tensor _ td) = td

-- | Number of nonzero tensor elements
nnz :: Tensor i a -> Int
nnz (Tensor _ td) = V.length td

-- | Tensor rank
rank :: Tensor i a -> Int
rank (Tensor sh _) = Shape.rank sh

-- | Tensor dimensions
dim :: Tensor i a -> [Int]
dim (Tensor sh _) = Shape.dim sh






-- | playground, for future use

-- -- | A generic tensor type, polymorphic in the container type as well
-- data GTensor c i a where
--   GTensor :: Sh i -> c a -> GTensor c (Sh i) a
  
-- mkGT :: Sh i -> c a -> GTensor c (Sh i) a
-- mkGT = GTensor