{-# language GADTs #-}
{-# language DeriveFunctor #-}
{-# language TypeOperators #-}
module Data.Tensor (
  -- * Tensor type
  Tensor(..), shape, nnz,
  -- * Shape type
  Sh(..),
  -- * Dimension types
  Dim.Dd(..), Dim.Sd(..)) where

import qualified Data.Vector as V
-- import qualified Data.Vector.Unboxed as VU

-- import Data.Word (Word32, Word64)
-- import Data.Int (Int32)

import Data.Shape (Sh(..), dim, rank,
                   Z,
                   D1, D2, CSR, COO, mkD2, mkCSR, mkCOO)
import qualified Data.Dim as Dim



{- |
Input: a tensor reduction written in Einstein notation.

Output: a stride program, i.e. a description of how to read and write memory.


taco compiles a tensor expression (e.g. C_{ij} = A_{ijk} B_{k}) into a series of nested loops.

Dimensions: each dimension can be either dense or sparse.

Internally, tensor data is stored in /dense/ vectors.

Example: "contract A_{ijk} B_{k} over the third index".

-}



-- | A generic tensor: a shape paired with a container of entries.
-- Unlike 'Tensor', the container type @c@ is left polymorphic rather than
-- being fixed to a particular vector type. The second type parameter is
-- always of the form @Sh i@, as fixed by the constructor's result type.
data GTensor c sh a where
  -- | Pair a shape with a container @c a@ holding the entries.
  GTensor :: Sh i -> c a -> GTensor c (Sh i) a
  
-- | Construct a 'GTensor' given a shape and a container of entries.
-- This is a plain synonym for the 'GTensor' constructor; no validation of
-- the container against the shape is performed.
mkGT :: Sh i -> c a -> GTensor c (Sh i) a
mkGT sh entries = GTensor sh entries


-- | The 'Tensor' type: a shape together with the tensor's data entries,
-- all kept in one single 'V.Vector'. The first type parameter is always of
-- the form @Sh i@, as fixed by the result type of the 'T' constructor.
data Tensor sh a where
  -- | Pair a shape with the vector of entries.
  T :: Sh i -> V.Vector a -> Tensor (Sh i) a

-- | Build a 'Tensor' from its shape and the vector holding its entries.
-- This is a plain synonym for the 'T' constructor; the vector is not
-- checked against the shape.
mkT :: Sh i -> V.Vector a -> Tensor (Sh i) a
mkT sh entries = T sh entries

-- | Mapping applies the function to every stored entry and leaves the
-- shape untouched.
instance Functor (Tensor i) where
  fmap g (T sh entries) = T sh (V.map g entries)

-- liftA2' :: (a -> a -> b) -> Tensor i a -> Tensor i a -> Tensor i a 
-- liftA2' f (T sh1 v1) (T sh2 v2) = mkT sh1 (V.zipWith f v1 v2)

-- | Wrap a single value as a tensor of shape 'Z' whose data vector holds
-- exactly that one entry.
pure' :: a -> Tensor (Sh Z) a
pure' x = mkT Z (V.singleton x)

-- | Two tensors are equal when their shapes are equal and their entry
-- vectors are equal. The shapes are compared first; the entries are only
-- compared when the shapes agree.
instance (Eq a) => Eq (Tensor i a) where
  T shL entriesL == T shR entriesR
    | shL == shR = entriesL == entriesR
    | otherwise  = False

-- | Renders the shape followed by a preview of at most the first five
-- entries. A trailing @...@ marker is appended only when some entries were
-- actually left out of the preview, so short (or empty) tensors are not
-- displayed as if they had been truncated.
instance (Show a) => Show (Tensor i a) where
  show (T sh d) = unwords (show sh : show preview : ellipsis)
    where
      -- Maximum number of entries included in the preview.
      previewLen = 5
      preview = V.take previewLen d
      -- Empty unless the vector holds more entries than the preview shows.
      ellipsis = ["..." | V.length d > previewLen]


 

-- | Project out the shape component of a tensor, discarding its entries.
shape :: Tensor sh a -> sh
shape (T tensorShape _) = tensorShape

-- | Number of entries stored in the tensor's data vector, i.e. its length.
-- The stored values themselves are not inspected, so any explicitly stored
-- zero is counted as well.
nnz :: Tensor i a -> Int
nnz (T _ entries) = V.length entries






-- * A possible abstract syntax

-- data Index i where
--   I1 :: i -> Index i
--   I2 :: i -> i -> Index (i, i)

-- -- | Expressions with tensor operands, e.g. "contract A_{ijk}B_{k} over the third index"

-- -- mkConstE = Const <$> mkT

-- data Expr a where
--   Const :: Tensor (Sh i) a -> Expr (Tensor (Sh i) a)

-- data Expr i a where
--   -- Const :: a -> Expr a
--   Contract :: Index i -> Expr (Sh i) a -> Expr (Sh i) a -> Expr (Sh i) a
--   -- (:*:) :: Expr a -> Expr a -> Expr a
--   -- (:+:) :: Expr a -> Expr a -> Expr a

-- eval (Const x) = x
-- eval (Contract ixs a b) = undefined



-- data Expr a =
--     Const a
--   | Contract Int (Expr a) (Expr a)
--   -- | Expr a :+: Expr a
--   -- | Expr a :*: Expr a
--   -- | Expr a :-: Expr a
--   -- | Expr a :/: Expr a
--   deriving (Eq, Show)

-- -- | trivial recursive evaluation function
-- eval :: Num t => Expr t -> t
-- eval (Const x) = x
-- eval (a :+: b) = eval a + eval b
-- eval (a :*: b) = eval a * eval b




-- | GADT syntax

-- data Expr a where
--   Const :: a -> Expr a 
--   -- ^ Sum (elementwise) two expressions
--   (:+:) :: Expr a -> Expr a -> Expr a
--   -- ^ Multiply (elementwise) two expressions
--   (:*:) :: Expr a -> Expr a -> Expr a
--   -- ^ Subtract (elementwise) two expressions
--   (:-:) :: Expr a -> Expr a -> Expr a