module Agda.Compiler.JS.Substitution where

import Prelude hiding ( map, lookup )
import Data.Map ( empty, unionWith, singleton, findWithDefault )
import qualified Data.Map as Map
import Data.List ( genericIndex )
import qualified Data.List as List

import Agda.Syntax.Common ( Nat )
import Agda.Compiler.JS.Syntax
  ( Exp(Self,Undefined,Local,Lambda,Object,Apply,Lookup,If,BinOp,PreOp),
    MemberId, LocalId(LocalId) )
import Agda.Utils.Function ( iterate' )

-- Map for expressions

-- | Generic traversal over the local variables of an expression.
--
-- @map m f e@ rebuilds @e@, replacing every @Local i@ by @f d i@, where
-- @d@ is @m@ plus the arities of all 'Lambda' binders between the root of
-- @e@ and that occurrence. The callback therefore always knows how many
-- variables are bound locally at the point where it is invoked.
--
-- Only 'Local', 'Lambda', 'Object', 'Apply', 'Lookup', 'If', 'PreOp' and
-- 'BinOp' are traversed; every other expression is returned unchanged.
-- Nodes are rebuilt with the plain constructors (no beta-reduction); see
-- map' for the reducing variant.
map :: Nat -> (Nat -> LocalId -> Exp) -> Exp -> Exp
map m f (Local i)       = f m i
-- A lambda of arity i binds i more variables for its body.
map m f (Lambda i e)    = Lambda i (map (m + i) f e)
map m f (Object o)      = Object (Map.map (map m f) o)
map m f (Apply e es)    = Apply (map m f e) (List.map (map m f) es)
map m f (Lookup e l)    = Lookup (map m f e) l
map m f (If e e' e'')   = If (map m f e) (map m f e') (map m f e'')
map m f (PreOp op e)    = PreOp op (map m f e)
map m f (BinOp e op e') = BinOp (map m f e) op (map m f e')
map _ _ e               = e

-- Shifting

-- | @shift n e@ adds @n@ to the index of every local variable of @e@ that
-- is not bound by a lambda inside @e@ itself (i.e. 'shiftFrom' starting at
-- depth 0).
shift :: Nat -> Exp -> Exp
shift = shiftFrom 0

-- | @shiftFrom m n e@ adds @n@ to each local variable index in @e@ that is
-- at least @m@ plus the number of variables bound by the lambdas enclosing
-- the occurrence inside @e@; smaller indices are left untouched (see
-- 'shifter').
--
-- Shifting by 0 is the identity and skips the traversal entirely.
shiftFrom :: Nat -> Nat -> Exp -> Exp
shiftFrom _ 0 e = e
shiftFrom m n e = map m (shifter n) e

-- | Per-variable callback used by 'shiftFrom'.
--
-- @shifter n m (LocalId i)@: at binder depth @m@, an index below @m@ refers
-- to a locally bound variable and is kept; any other index is increased by
-- @n@.
shifter :: Nat -> Nat -> LocalId -> Exp
shifter n m (LocalId i)
  | i < m     = Local (LocalId i)
  | otherwise = Local (LocalId (i + n))

-- Substitution

-- | @subst n es e@ replaces the @n@ outermost free variables of @e@ by the
-- expressions in @es@ (see 'substituter' for the exact index mapping) and
-- renumbers the remaining free variables down by @n@.
--
-- No beta-reduction is performed; see subst' for the reducing variant.
-- Substituting for 0 variables returns @e@ unchanged.
subst :: Nat -> [Exp] -> Exp -> Exp
subst 0 _  e = e
subst n es e = map 0 (substituter n es) e

-- | Per-variable callback used by 'subst' and subst'.
--
-- @substituter n es m (LocalId i)@ at binder depth @m@:
--
-- * @i < m@: the variable is bound locally and is kept.
-- * @j = i - m < n@: the variable is one of the @n@ being substituted.
--   It is replaced by element @n - 1 - j@ of @es@, so the innermost
--   variable (@j = 0@) takes the last element. The replacement is shifted
--   by @m@ so that its own free variables skip the @m@ local binders.
--   Missing elements (when @es@ has fewer than @n@ entries) become
--   'Undefined'.
-- * otherwise: a free variable beyond the substituted ones, renumbered
--   down by @n@.
substituter :: Nat -> [Exp] -> Nat -> LocalId -> Exp
substituter n es m (LocalId i)
  | i < m     = Local (LocalId i)
  -- 0 <= j < n here, so the index n - 1 - j is never negative.
  | j < n     = shift m (genericIndex (es ++ repeat Undefined) (n - 1 - j))
  | otherwise = Local (LocalId (i - n))
  where
    -- Index relative to the current binder depth.
    j = i - m

-- A variant on substitution which performs beta-reduction

-- | Like 'map', but rebuilds applications and field accesses with the
-- reducing smart constructors 'apply' and 'lookup'.
--
-- As a result, a substituted variable that turns out to be a 'Lambda' in
-- head position is beta-reduced, and a 'Lookup' on an 'Object' is resolved.
-- The callback receives @m@ plus the arities of the enclosing 'Lambda'
-- binders. Constructors not listed here are returned unchanged.
map' :: Nat -> (Nat -> LocalId -> Exp) -> Exp -> Exp
map' m f (Local i)       = f m i
-- A lambda of arity i binds i more variables for its body.
map' m f (Lambda i e)    = Lambda i (map' (m + i) f e)
map' m f (Object o)      = Object (Map.map (map' m f) o)
map' m f (Apply e es)    = apply (map' m f e) (List.map (map' m f) es)
map' m f (Lookup e l)    = lookup (map' m f e) l
map' m f (If e e' e'')   = If (map' m f e) (map' m f e') (map' m f e'')
map' m f (PreOp op e)    = PreOp op (map' m f e)
map' m f (BinOp e op e') = BinOp (map' m f e) op (map' m f e')
map' _ _ e               = e

-- | Beta-reducing variant of 'subst': the same variable mapping (via
-- 'substituter'), but the result is rebuilt with map', so newly created
-- redexes are reduced.
--
-- Substituting for 0 variables returns @e@ unchanged.
subst' :: Nat -> [Exp] -> Exp -> Exp
subst' 0 _  e = e
subst' n es e = map' 0 (substituter n es) e

-- Beta-reducing application and field access

-- | Smart constructor for 'Apply' that beta-reduces direct lambda calls.
--
-- @apply (Lambda i body) es@ substitutes @es@ for the @i@ parameters with
-- subst'. Because 'substituter' only looks at the first @i@ arguments and
-- pads with 'Undefined', surplus arguments are dropped and missing ones
-- become 'Undefined' (this appears to mirror JavaScript call semantics).
-- Any other head expression produces a plain 'Apply' node.
apply :: Exp -> [Exp] -> Exp
apply (Lambda i e) es = subst' i es e
apply e            es = Apply e es

-- | Smart constructor for 'Lookup' that resolves field access on literal
-- objects.
--
-- On an 'Object', the member is read directly, yielding 'Undefined' when
-- the member is absent. Any other expression produces a plain 'Lookup'
-- node.
lookup :: Exp -> MemberId -> Exp
lookup (Object o) l = findWithDefault Undefined l o
lookup e          l = Lookup e l

-- Replace any top-level occurrences of self
-- (needed because JS is a call-by-value language, so any top-level
-- recursion would be evaluated before the module has been defined;
-- e.g. exports = { x: 1, y: exports.x } results in an exception,
-- because exports is still undefined when exports.x is evaluated).

-- | @self e f@ replaces every top-level occurrence of 'Self' in @f@ by @e@.
--
-- "Top-level" means outside any lambda. 'Lambda' (and any constructor not
-- matched below) falls into the catch-all and is returned untouched, so
-- occurrences of 'Self' under a lambda are not replaced.
--
-- While rewriting, it also reduces what becomes reducible. If the head of
-- an application resolves to a 'Lambda', the call is beta-reduced with
-- subst', and the result is processed again. Field accesses go through
-- 'lookup', so accesses on objects are resolved.
self :: Exp -> Exp -> Exp
self e Self           = e
self e (Object o)     = Object (Map.map (self e) o)
self e (Apply f es)   =
  case self e f of
    -- The arguments are substituted as-is; the outer 'self' call then
    -- rewrites any Self they contributed to the reduced body.
    Lambda n g -> self e (subst' n es g)
    g          -> Apply g (List.map (self e) es)
self e (Lookup f l)   = lookup (self e f) l
self e (If f g h)     = If (self e f) (self e g) (self e h)
self e (BinOp f op g) = BinOp (self e f) op (self e g)
self e (PreOp op f)   = PreOp op (self e f)
self _ f              = f

-- Find the fixed point of an expression, with no top-level occurrences
-- of self.

-- | Ties the knot: the result @e@ is @f@ with every top-level 'Self'
-- replaced by @e@ itself.
--
-- This relies on laziness. It works when each 'Self' ends up inside a
-- constructor (e.g. an 'Object' field) or is resolved by a 'Lookup' on an
-- object. An expression that needs its own value to produce its outermost
-- constructor diverges; for example, @fix Self@ reduces to @e = e@.
fix :: Exp -> Exp
fix f = e
  where
    e = self e f

-- Some helper functions

-- | Applies a curried function to its arguments one at a time, using the
-- beta-reducing 'apply' at each step:
-- @curriedApply f [a, b] == apply (apply f [a]) [b]@.
--
-- Uses a strict left fold so that no chain of thunks builds up. This does
-- not force anything extra: 'apply' pattern-matches on its first argument,
-- so demanding the result already forces every intermediate application.
curriedApply :: Exp -> [Exp] -> Exp
curriedApply = List.foldl' (\f e -> apply f [e])

-- | @curriedLambda n body@ wraps @body@ in @n@ nested unary lambdas
-- (@Lambda 1 (Lambda 1 (... body))@). This assumes 'iterate'' applies its
-- function @n@ times; confirm against "Agda.Utils.Function".
curriedLambda :: Nat -> Exp -> Exp
curriedLambda n = iterate' n (Lambda 1)

-- | The empty object literal; the unit of 'union'.
emp :: Exp
emp = Object empty

-- | Left-biased deep merge of two expressions.
--
-- Two objects are merged member-wise; members present in both are merged
-- recursively. In every other case the left argument wins, and the right
-- one is discarded.
union :: Exp -> Exp -> Exp
union (Object o) (Object p) = Object (unionWith union o p)
union e          _          = e

-- | Builds a chain of single-member objects leading to a value:
-- @vine [a, b] e@ is the object @{a: {b: e}}@. An empty path returns @e@.
vine :: [MemberId] -> Exp -> Exp
vine ls e = foldr (\l inner -> Object (singleton l inner)) e ls

-- | Builds an object literal from @(path, value)@ pairs by merging the
-- corresponding 'vine's with 'union', starting from 'emp'.
--
-- Because the fold is right-nested and 'union' is left-biased, an earlier
-- pair wins when two paths collide on a non-object value. A pair with an
-- empty path whose value is not an object therefore discards every pair
-- after it.
object :: [([MemberId], Exp)] -> Exp
object = foldr (\(ls, e) acc -> union (vine ls e) acc) emp