{-# language GADTs #-}
{-# language PackageImports #-}

module Data.Tensor.Compiler (
  --   contract
  --   -- * Tensor types
  -- , Tensor(..), Sh(..), Dd(..), Sd(..)
    -- * Syntax
    Phoas, eval, var, let_, let2_
    -- * Exceptions
  , CException (..)
  )where

import Data.Typeable
-- import "exceptions" Control.Monad.Catch (MonadThrow(..), throwM, MonadCatch(..), catch)
import Control.Exception (Exception(..))

-- import Control.Applicative (liftA2, (<|>))

import Data.Tensor -- (Tensor(..), Sh(..), Dd(..), Sd(..), tshape, tdata, nnz, rank, dim)
import Data.Tensor.Compiler.PHOAS -- (Phoas(..), let_, let2_, var, lift1, lift2, eval)



-- IN: Tensor reduction syntax (Einstein notation)

-- OUT: stride program (how to read/write memory)


-- taco compiles a tensor expression (e.g. C_{ij} = A_{ijk} B_{k}) into a series of nested loops.

-- Dimensions can be either dense or sparse.

-- Internally, tensor data is stored in /dense/ vectors.

-- "Contract A_{ijk} B_{k} over the third index", i.e. sum over the repeated index k.




-- Helpers for validating indices and wrapping tensors as 'Phoas' variables (currently commented out).

-- mkVar :: MonadThrow m => [Int] -> Tensor i a -> m (Phoas (Tensor i a))
-- mkVar ixs0 t = do
--   ixs <- mkIxs ixs0 (rank t)
--   return $ var t


-- mkIxs :: MonadThrow m => [Int] -> Int -> m [Int]
-- mkIxs ixs mm = go ixs []
--   where
--     go [] acc = pure acc
--     go (i:is) acc | i < 0 =
--                     throwM $ IncompatIx "Index must be non-negative"
--                   | i > mm - 1 =
--                     throwM $ IncompatIx $ unwords ["Index must be smaller than", show mm]
--                   | otherwise = go is (i : acc)
                  

-- Tensor contraction (currently disabled; the definition below is commented out)
--
-- Injects two 'Tensor' constants into 'Var's, while ensuring that all the contraction indices are compatible with those of the tensors.
--
-- Throws a 'CException' if any index is negative or too large for the rank of either tensor.

-- -- contract :: MonadThrow m =>
-- --                   [Int]           -- ^ Tensor contraction indices
-- --                   -> Tensor i1 a
-- --                   -> Tensor i2 b
-- --                   -> ([Int] -> Tensor i1 a -> Tensor i2 b -> Phoas c) -- ^ Contraction function
-- --                   -> m (Phoas c)
-- contract ixs0 t1 t2 f = do
--   _ <- mkIxs ixs0 (rank t1)
--   ixs <- mkIxs ixs0 (rank t2)
--   pure $ let_ (var ixs) $ \ixs' ->
--     let2_ (var t1) (var t2) (f ixs')


-- | Errors raised by the tensor compiler.
--
-- Every constructor carries a free-form description of the problem; the
-- 'Show' instance prefixes that description with the kind of incompatibility.
data CException
  = IncompatShape String  -- ^ incompatible tensor shapes
  | IncompatIx String     -- ^ incompatible index (e.g. outside a tensor's rank)
  deriving (Eq, Typeable)

-- | Renders an error as the kind prefix, a single space, then the message,
-- e.g. @show (IncompatIx "foo") == "Incompatible index: foo"@.
instance Show CException where
  show err = prefix ++ ' ' : msg
    where
      (prefix, msg) = case err of
        IncompatShape s -> ("Incompatible shape:", s)
        IncompatIx s    -> ("Incompatible index:", s)
instance Exception CException where