{-# language GADTs, RankNTypes #-}
module Data.Tensor.Compiler.PHOAS where


-- | Parametric higher-order abstract syntax (PHOAS), after B. Oliveira, A. Loeh, `Abstract Syntax Graphs for Domain Specific Languages`  
data Phoas a where
  Var :: a -> Phoas a
  Let :: Phoas a -> (a -> Phoas b) -> Phoas b
  -- Let :: Phoas a -> (a -> Phoas a) -> Phoas a

-- | Inject a constant into the abstract syntax
var :: a -> Phoas a
var = Var

-- | Bind a variable into a closure
let_ :: Phoas a -> (a -> Phoas b) -> Phoas b
let_ = Let

-- | Bind two variables into a closure
let2_ :: Phoas a -> Phoas b -> (a -> b -> Phoas c) -> Phoas c
let2_ a b f = let_ a $ \xa ->
  let_ b $ \xb -> f xa xb

-- letP_ :: Phoas a -> (Phoas a -> Phoas a) -> Phoas a
-- letP_ e f = let_ e (f . Var)

-- letP2_ :: Phoas a -> Phoas a -> (Phoas a -> Phoas a -> Phoas a) -> Phoas a
-- letP2_ a b f = let2_ a b (\x y -> f (Var x) (Var y))


-- instance Show a => Show (Phoas a) where
--   show e = case e of
--     Var x -> show x
--     -- Let e f -> unwords



-- | Helper functions  


lift1 :: (a -> b) -> a -> Phoas b
lift1 f = Var . f -- Lift1

lift2 :: (a -> b -> c) -> a -> b -> Phoas c
lift2 f a b = Var (f a b) -- Lift2


plus :: Num a => a -> a -> Phoas a
plus = lift2 (+)


-- | Benchmark: `tree 50` should compute the answer instantly.
--
-- This proves that the PHOAS formulation preserves variable sharing
treeE :: Integer -> Phoas Integer
treeE 0 = Var 1
treeE n = let_ (treeE (n - 1)) $ \a -> a `plus` a 




-- | Semantic function for evaluation
eval :: Phoas a -> a
eval expr = case expr of
  Var x -> x
  Let e f -> eval (f (eval e))


-- -- | Semantic function for pretty-printing
-- type ClosedExpr = forall a . Phoas a

-- pprint :: ClosedExpr -> String
-- pprint expr = go expr 0
--   where
--     go :: Phoas String -> Int -> String
--     go (Var x) _ = x
--     go (Let e f) c = unwords ["(let", v, "=", go e (c+1), "in", go (f v) (c+1),")"]
--       where
--         v = "v" ++ show c




-- * A possible abstract syntax










-- data Index i where
--   I1 :: i -> Index i
--   I2 :: i -> i -> Index (i, i)

-- -- | Expressions with tensor operands, e.g. "contract A_{ijk}B_{k} over the third index"

-- -- | User-facing grammar:
-- data Expr a where
--   -- | Introduce a constant in the AST
--   Konst :: a -> Expr a
--   -- | Tensor contraction
--   Contr :: i -> (Expr a -> Expr a -> Expr a) -> Expr a
--   -- | Binary componentwise operation
--   CW2 :: (a -> a -> a) -> Expr a -> Expr a -> Expr a

-- k :: a -> Expr a
-- k = Konst

-- (|*|), (|+|) :: Num a => Expr a -> Expr a -> Expr a
-- (|*|) = CW2 (*)  
-- (|+|) = CW2 (+)

-- dot = Contr 1 (|+|)





-- -- * PHOAS 2

-- data Phoas a where
--   Const :: a -> Phoas a
--   Lift1 :: (a -> b) -> Phoas (a -> b)
--   Let :: Phoas a -> (Phoas a -> Phoas b) -> Phoas b
--   Lambda :: (Phoas a -> Phoas b) -> Phoas (a -> b)
--   App1 :: Phoas (a -> b) -> (Phoas a -> Phoas b)

-- eval expr = case expr of
--   -- Const x -> x
--   Let e f -> f e

-- -- lift1 :: (a -> b) -> a -> Phoas b
-- -- lift1 f = Const . f

-- -- lift2 :: (t2 -> t1 -> t) -> Phoas (t2 -> t1 -> t)
-- -- lift2 f = Const $ \a b -> f a b 


-- -- 


-- data Phoas a =
--     Lit Int
--   -- | Lift1 (a -> a) (a -> Phoas a)
--   -- | Add (Phoas a) (Phoas a)
--   | Let (Phoas a) (a -> Phoas a)
--   | Let2 (Phoas a) (Phoas a) (a -> a -> Phoas a)
--   | Var a

-- evalPhoas expr = case expr of
--   Var x -> x
--   Lit i -> i
--   -- Add e0 e1 -> evalPhoas e0 + evalPhoas e1
--   Let e f -> evalPhoas $ f e' where e' = evalPhoas e
--   Let2 e0 e1 f -> evalPhoas $ f e0' e1' where
--     e0' = evalPhoas e0
--     e1' = evalPhoas e1


-- contract ixs (T sh)

-- eval (Const x) = x
-- eval (Contract ixs a b) = undefined



-- data Expr a =
--     Const a
--   | Contract Int (Expr a) (Expr a)
--   -- | Expr a :+: Expr a
--   -- | Expr a :*: Expr a
--   -- | Expr a :-: Expr a
--   -- | Expr a :/: Expr a
--   deriving (Eq, Show)

-- -- | trivial recursive evaluation function
-- eval :: Num t => Expr t -> t
-- eval (Const x) = x
-- eval (a :+: b) = eval a + eval b
-- eval (a :*: b) = eval a * eval b




-- | GADT syntax

-- data Expr a where
--   Const :: a -> Expr a 
--   -- ^ Sum (elementwise) two expressions
--   (:+:) :: Expr a -> Expr a -> Expr a
--   -- ^ Multiply (elementwise) two expressions
--   (:*:) :: Expr a -> Expr a -> Expr a
--   -- ^ Subtract (elementwise) two expressions
--   (:-:) :: Expr a -> Expr a -> Expr a