module Lang where
import DBI
import qualified Prelude as P
import Prelude (($), (.), (+), (), (++), show, (>>=), (*), (/), undefined)
import qualified Control.Monad.Writer as P
import qualified Data.Functor.Identity as P
import qualified GHC.Float as P
import qualified Data.Tuple as P
import Data.Void
import Data.Proxy
import Data.Proxy
import Data.Constraint
import Data.Constraint.Forall
-- How each supported type is represented once it carries derivative
-- information of type @v@.

-- Scalars: the value paired with its derivative component.
type instance Diff v P.Double = (P.Double, v)
type instance Diff v P.Float = (P.Float, v)

-- The empty type stays empty.
type instance Diff v Void = Void

-- Containers and effects: the transformation is applied structurally to
-- every contained type.
type instance Diff v [a] = [Diff v a]
type instance Diff v (P.Maybe a) = P.Maybe (Diff v a)
type instance Diff v (P.Either a b) = P.Either (Diff v a) (Diff v b)
type instance Diff v (P.IO a) = P.IO (Diff v a)
type instance Diff v (P.Writer w a) = P.Writer (Diff v w) (Diff v a)
-- | Primitive operations of the object language, layered on the de Bruijn
-- core from 'DBI'. Every primitive is a closed term usable in any
-- environment @h@. Constants, list append and the pair/function
-- combinators have default definitions in terms of the other primitives.
class DBI repr => Lang repr where
  -- Products. Note the naming: 'zro' projects the first component and
  -- 'fst' the second.
  mkProd :: repr h (a -> b -> (a, b))
  zro :: repr h ((a, b) -> a)
  fst :: repr h ((a, b) -> b)

  -- Double-precision arithmetic.
  double :: P.Double -> repr h P.Double
  doubleZero, doubleOne :: repr h P.Double
  doublePlus, doubleMinus, doubleMult, doubleDivide
    :: repr h (P.Double -> P.Double -> P.Double)
  doubleExp :: repr h (P.Double -> P.Double)

  -- Single-precision arithmetic and conversions.
  float :: P.Float -> repr h P.Float
  floatZero, floatOne :: repr h P.Float
  floatPlus, floatMinus, floatMult, floatDivide
    :: repr h (P.Float -> P.Float -> P.Float)
  floatExp :: repr h (P.Float -> P.Float)
  float2Double :: repr h (P.Float -> P.Double)
  double2Float :: repr h (P.Double -> P.Float)

  -- General recursion.
  fix :: repr h ((a -> a) -> a)

  -- Sums, unit and the empty type.
  left :: repr h (a -> P.Either a b)
  right :: repr h (b -> P.Either a b)
  sumMatch :: repr h ((a -> c) -> (b -> c) -> P.Either a b -> c)
  unit :: repr h ()
  exfalso :: repr h (Void -> a)

  -- Optional values.
  nothing :: repr h (P.Maybe a)
  just :: repr h (a -> P.Maybe a)
  optionMatch :: repr h (b -> (a -> b) -> P.Maybe a -> b)

  -- IO.
  ioRet :: repr h (a -> P.IO a)
  ioBind :: repr h (P.IO a -> (a -> P.IO b) -> P.IO b)
  ioMap :: repr h ((a -> b) -> P.IO a -> P.IO b)

  -- Lists.
  nil :: repr h [a]
  cons :: repr h (a -> [a] -> [a])
  listMatch :: repr h (b -> (a -> [a] -> b) -> [a] -> b)
  listAppend :: repr h ([a] -> [a] -> [a])

  -- Writer.
  writer :: repr h ((a, w) -> P.Writer w a)
  runWriter :: repr h (P.Writer w a -> (a, w))

  -- Derived combinators on pairs and functions.
  swap :: repr h ((l, r) -> (r, l))
  curry :: repr h (((a, b) -> c) -> (a -> b -> c))
  uncurry :: repr h ((a -> b -> c) -> ((a, b) -> c))

  doubleZero = double 0
  doubleOne = double 1
  floatZero = float 0
  floatOne = float 1

  -- Recurse on the front list with 'fix'. The back list is returned when
  -- the front is empty, otherwise its head is consed onto the recursive
  -- call on its tail.
  listAppend = lam2 $ \front back ->
    app2 fix (lam $ \self -> app2 listMatch back (lam2 $ \hd tl -> app2 cons hd (app self tl))) front

  swap = lam $ \pr -> app2 mkProd (app fst pr) (app zro pr)
  curry = lam3 $ \g x y -> app g (app2 mkProd x y)
  uncurry = lam2 $ \g pr -> app2 g (app zro pr) (app fst pr)
-- | Embed a host-language value of type @x@ as a closed @repr@ term.
class Reify repr x where
  reify :: x -> repr h x

-- The unit value is not inspected; it always becomes 'unit'.
instance Lang repr => Reify repr () where
  reify _ = unit

-- Doubles become literals.
instance Lang repr => Reify repr P.Double where
  reify d = double d

-- Pairs are embedded component-wise and rebuilt with 'mkProd'.
instance (Lang repr, Reify repr l, Reify repr r) => Reify repr (l, r) where
  reify (a, b) = app2 mkProd (reify a) (reify b)
-- | The direct interpreter: each primitive is the matching host-Haskell
-- function, lifted into 'Eval' with 'comb'.
instance Lang Eval where
  -- Products: 'zro' projects the first component, 'fst' the second.
  mkProd = comb (,)
  zro = comb P.fst
  fst = comb P.snd

  -- Double arithmetic.
  double = comb
  doublePlus = comb (+)
  -- Binary subtraction. It is referenced qualified so it does not depend on
  -- the unqualified Prelude import list. Lifting '()' here would yield the
  -- unit value instead of a Double -> Double -> Double function.
  doubleMinus = comb (P.-)
  doubleMult = comb (*)
  doubleDivide = comb (/)
  doubleExp = comb P.exp

  -- Float arithmetic and conversions.
  float = comb
  floatPlus = comb (+)
  floatMinus = comb (P.-)
  floatMult = comb (*)
  floatDivide = comb (/)
  floatExp = comb P.exp
  float2Double = comb P.float2Double
  double2Float = comb P.double2Float

  -- General recursion by unfolding: loop f = f (loop f).
  fix = comb loop
    where
      loop f = f (loop f)

  -- Sums, unit and the empty type. 'P.either' has exactly the argument
  -- order of 'sumMatch'.
  left = comb P.Left
  right = comb P.Right
  sumMatch = comb P.either
  unit = comb ()
  exfalso = comb absurd

  -- Options. 'P.maybe' has exactly the argument order of 'optionMatch'.
  nothing = comb P.Nothing
  just = comb P.Just
  optionMatch = comb P.maybe

  -- IO.
  ioRet = comb P.return
  ioBind = comb (>>=)
  ioMap = comb P.fmap

  -- Lists. 'listMatch' takes the empty-list result first, then the
  -- handler for head and tail.
  nil = comb []
  cons = comb (:)
  listMatch = comb matchList
    where
      matchList onNil _ [] = onNil
      matchList _ onCons (x : xs) = onCons x xs

  -- Writer: build from, and unpack to, a (result, log) pair.
  writer = comb (P.WriterT . P.Identity)
  runWriter = comb P.runWriter
-- | A wrapper around an underlying @repr@ term.
newtype UnHOAS repr h x = UnHOAS
  { runUnHOAS :: repr h x
  }

-- Each de Bruijn operation defined here unwraps its arguments, applies the
-- same operation of @repr@, and re-wraps the result.
instance DBI repr => DBI (UnHOAS repr) where
  z = UnHOAS z
  s = UnHOAS . s . runUnHOAS
  abs = UnHOAS . abs . runUnHOAS
  app fun arg = UnHOAS (app (runUnHOAS fun) (runUnHOAS arg))
-- | Every primitive defined here is the corresponding @repr@ primitive
-- wrapped in 'UnHOAS'. The remaining methods use the class defaults.
instance Lang repr => Lang (UnHOAS repr) where
  -- Literals.
  double d = UnHOAS (double d)
  float f = UnHOAS (float f)

  -- Products.
  mkProd = UnHOAS mkProd
  zro = UnHOAS zro
  fst = UnHOAS fst

  -- Double arithmetic.
  doublePlus = UnHOAS doublePlus
  doubleMinus = UnHOAS doubleMinus
  doubleMult = UnHOAS doubleMult
  doubleDivide = UnHOAS doubleDivide
  doubleExp = UnHOAS doubleExp

  -- Float arithmetic and conversions.
  floatPlus = UnHOAS floatPlus
  floatMinus = UnHOAS floatMinus
  floatMult = UnHOAS floatMult
  floatDivide = UnHOAS floatDivide
  floatExp = UnHOAS floatExp
  float2Double = UnHOAS float2Double
  double2Float = UnHOAS double2Float

  -- Recursion, sums, unit and the empty type.
  fix = UnHOAS fix
  left = UnHOAS left
  right = UnHOAS right
  sumMatch = UnHOAS sumMatch
  unit = UnHOAS unit
  exfalso = UnHOAS exfalso

  -- Options.
  nothing = UnHOAS nothing
  just = UnHOAS just
  optionMatch = UnHOAS optionMatch

  -- IO.
  ioRet = UnHOAS ioRet
  ioBind = UnHOAS ioBind
  ioMap = UnHOAS ioMap

  -- Lists.
  nil = UnHOAS nil
  cons = UnHOAS cons
  listMatch = UnHOAS listMatch

  -- Writer.
  writer = UnHOAS writer
  runWriter = UnHOAS runWriter
-- | Pretty-printing interpretation. Literals are rendered with 'show'; every
-- other primitive is rendered as a fixed identifier. Double and Float
-- arithmetic share the same names ("plus", "minus", "mult", "divide",
-- "exp").
instance Lang Show where
  -- Literals.
  double d = name (show d)
  float f = name (show f)

  -- Products.
  mkProd = name "mkProd"
  zro = name "zro"
  fst = name "fst"

  -- Arithmetic (shared names for both precisions).
  doublePlus = name "plus"
  doubleMinus = name "minus"
  doubleMult = name "mult"
  doubleDivide = name "divide"
  doubleExp = name "exp"
  floatPlus = name "plus"
  floatMinus = name "minus"
  floatMult = name "mult"
  floatDivide = name "divide"
  floatExp = name "exp"
  float2Double = name "float2Double"
  double2Float = name "double2Float"

  -- Recursion, sums, unit and the empty type.
  fix = name "fix"
  left = name "left"
  right = name "right"
  sumMatch = name "sumMatch"
  unit = name "unit"
  exfalso = name "exfalso"

  -- Options.
  nothing = name "nothing"
  just = name "just"
  optionMatch = name "optionMatch"

  -- IO.
  ioRet = name "ioRet"
  ioBind = name "ioBind"
  ioMap = name "ioMap"

  -- Lists.
  nil = name "nil"
  cons = name "cons"
  listMatch = name "listMatch"

  -- Writer.
  writer = name "writer"
  runWriter = name "runWriter"
-- | Differentiating interpretation that works for any vector type @v@ chosen
-- through the 'Proxy' argument (which no primitive inspects). Numeric values
-- are (value, derivative) pairs: literals start with a 'zero' derivative and
-- each arithmetic primitive also computes the derivative of its result.
-- All other primitives are forwarded unchanged.
instance Lang repr => Lang (GWDiff repr) where
  -- Structural primitives: forwarded as-is.
  mkProd = GWDiff (\_ -> mkProd)
  zro = GWDiff (\_ -> zro)
  fst = GWDiff (\_ -> fst)
  fix = GWDiff (\_ -> fix)
  left = GWDiff (\_ -> left)
  right = GWDiff (\_ -> right)
  sumMatch = GWDiff (\_ -> sumMatch)
  unit = GWDiff (\_ -> unit)
  exfalso = GWDiff (\_ -> exfalso)
  nothing = GWDiff (\_ -> nothing)
  just = GWDiff (\_ -> just)
  optionMatch = GWDiff (\_ -> optionMatch)
  ioRet = GWDiff (\_ -> ioRet)
  ioBind = GWDiff (\_ -> ioBind)
  ioMap = GWDiff (\_ -> ioMap)
  nil = GWDiff (\_ -> nil)
  cons = GWDiff (\_ -> cons)
  listMatch = GWDiff (\_ -> listMatch)
  writer = GWDiff (\_ -> writer)
  runWriter = GWDiff (\_ -> runWriter)

  -- Literals have a zero derivative.
  double d = GWDiff (\_ -> app2 mkProd (double d) zero)
  float f = GWDiff (\_ -> app2 mkProd (float f) zero)

  -- Sum and difference: component-wise on (value, derivative).
  doublePlus = GWDiff (\_ -> lam2 (\x y ->
    app2 mkProd (plus2 (zro1 x) (zro1 y)) (plus2 (fst1 x) (fst1 y))))
  doubleMinus = GWDiff (\_ -> lam2 (\x y ->
    app2 mkProd (minus2 (zro1 x) (zro1 y)) (minus2 (fst1 x) (fst1 y))))

  -- Product rule: d(x*y) = x*dy + y*dx.
  doubleMult = GWDiff (\_ -> lam2 (\x y ->
    app2 mkProd
      (mult2 (zro1 x) (zro1 y))
      (plus2 (mult2 (zro1 x) (fst1 y)) (mult2 (zro1 y) (fst1 x)))))

  -- Quotient rule: d(x/y) = (y*dx - x*dy) / (y*y).
  doubleDivide = GWDiff (\_ -> lam2 (\x y ->
    app2 mkProd
      (divide2 (zro1 x) (zro1 y))
      (divide2
        (minus2 (mult2 (zro1 y) (fst1 x)) (mult2 (zro1 x) (fst1 y)))
        (mult2 (zro1 y) (zro1 y)))))

  -- d(exp t) = exp t * dt.
  doubleExp = GWDiff (\_ -> lam (\t ->
    app2 mkProd (doubleExp1 (zro1 t)) (mult2 (doubleExp1 (zro1 t)) (fst1 t))))

  -- Float versions of the same rules. Scalars that multiply a derivative
  -- are widened with 'float2Double1' first.
  floatPlus = GWDiff (\_ -> lam2 (\x y ->
    app2 mkProd (plus2 (zro1 x) (zro1 y)) (plus2 (fst1 x) (fst1 y))))
  floatMinus = GWDiff (\_ -> lam2 (\x y ->
    app2 mkProd (minus2 (zro1 x) (zro1 y)) (minus2 (fst1 x) (fst1 y))))
  floatMult = GWDiff (\_ -> lam2 (\x y ->
    app2 mkProd
      (mult2 (float2Double1 (zro1 x)) (zro1 y))
      (plus2
        (mult2 (float2Double1 (zro1 x)) (fst1 y))
        (mult2 (float2Double1 (zro1 y)) (fst1 x)))))
  floatDivide = GWDiff (\_ -> lam2 (\x y ->
    app2 mkProd
      (divide2 (zro1 x) (float2Double1 (zro1 y)))
      (divide2
        (minus2
          (mult2 (float2Double1 (zro1 y)) (fst1 x))
          (mult2 (float2Double1 (zro1 x)) (fst1 y)))
        (float2Double1 (mult2 (float2Double1 (zro1 y)) (zro1 y))))))
  floatExp = GWDiff (\_ -> lam (\t ->
    app2 mkProd (floatExp1 (zro1 t)) (mult2 (float2Double1 (floatExp1 (zro1 t))) (fst1 t))))

  -- Conversions touch only the value component; the derivative is kept.
  float2Double = GWDiff (\_ -> bimap2 float2Double id)
  double2Float = GWDiff (\_ -> bimap2 double2Float id)
-- | Differentiating interpretation for a fixed vector type @v@. Numeric
-- values are (value, derivative) pairs: literals start with a 'zero'
-- derivative and each arithmetic primitive also computes the derivative of
-- its result. All other primitives are forwarded unchanged.
instance (Vector repr v, Lang repr) => Lang (WDiff repr v) where
  -- Structural primitives: forwarded as-is.
  mkProd = WDiff mkProd
  zro = WDiff zro
  fst = WDiff fst
  fix = WDiff fix
  left = WDiff left
  right = WDiff right
  sumMatch = WDiff sumMatch
  unit = WDiff unit
  exfalso = WDiff exfalso
  nothing = WDiff nothing
  just = WDiff just
  optionMatch = WDiff optionMatch
  ioRet = WDiff ioRet
  ioBind = WDiff ioBind
  ioMap = WDiff ioMap
  nil = WDiff nil
  cons = WDiff cons
  listMatch = WDiff listMatch
  writer = WDiff writer
  runWriter = WDiff runWriter

  -- Literals have a zero derivative.
  double d = WDiff (app2 mkProd (double d) zero)
  float f = WDiff (app2 mkProd (float f) zero)

  -- Sum and difference: component-wise on (value, derivative).
  doublePlus = WDiff (lam2 (\x y ->
    app2 mkProd (plus2 (zro1 x) (zro1 y)) (plus2 (fst1 x) (fst1 y))))
  doubleMinus = WDiff (lam2 (\x y ->
    app2 mkProd (minus2 (zro1 x) (zro1 y)) (minus2 (fst1 x) (fst1 y))))

  -- Product rule: d(x*y) = x*dy + y*dx.
  doubleMult = WDiff (lam2 (\x y ->
    app2 mkProd
      (mult2 (zro1 x) (zro1 y))
      (plus2 (mult2 (zro1 x) (fst1 y)) (mult2 (zro1 y) (fst1 x)))))

  -- Quotient rule: d(x/y) = (y*dx - x*dy) / (y*y).
  doubleDivide = WDiff (lam2 (\x y ->
    app2 mkProd
      (divide2 (zro1 x) (zro1 y))
      (divide2
        (minus2 (mult2 (zro1 y) (fst1 x)) (mult2 (zro1 x) (fst1 y)))
        (mult2 (zro1 y) (zro1 y)))))

  -- d(exp t) = exp t * dt.
  doubleExp = WDiff (lam (\t ->
    app2 mkProd (doubleExp1 (zro1 t)) (mult2 (doubleExp1 (zro1 t)) (fst1 t))))

  -- Float versions of the same rules. Scalars that multiply a derivative
  -- are widened with 'float2Double1' first.
  floatPlus = WDiff (lam2 (\x y ->
    app2 mkProd (plus2 (zro1 x) (zro1 y)) (plus2 (fst1 x) (fst1 y))))
  floatMinus = WDiff (lam2 (\x y ->
    app2 mkProd (minus2 (zro1 x) (zro1 y)) (minus2 (fst1 x) (fst1 y))))
  floatMult = WDiff (lam2 (\x y ->
    app2 mkProd
      (mult2 (float2Double1 (zro1 x)) (zro1 y))
      (plus2
        (mult2 (float2Double1 (zro1 x)) (fst1 y))
        (mult2 (float2Double1 (zro1 y)) (fst1 x)))))
  floatDivide = WDiff (lam2 (\x y ->
    app2 mkProd
      (divide2 (zro1 x) (float2Double1 (zro1 y)))
      (divide2
        (minus2
          (mult2 (float2Double1 (zro1 y)) (fst1 x))
          (mult2 (float2Double1 (zro1 x)) (fst1 y)))
        (float2Double1 (mult2 (float2Double1 (zro1 y)) (zro1 y))))))
  floatExp = WDiff (lam (\t ->
    app2 mkProd (floatExp1 (zro1 t)) (mult2 (float2Double1 (floatExp1 (zro1 t))) (fst1 t))))

  -- Conversions touch only the value component; the derivative is kept.
  float2Double = WDiff (bimap2 float2Double id)
  double2Float = WDiff (bimap2 double2Float id)
-- Each witness is 'Sub Dict'; the dictionary for the pair is found by
-- instance resolution using the @(l, r)@ instances for these classes that
-- are defined in this module. (Presumably 'ProdCon con l r' witnesses
-- "con l and con r entail con (l, r)"; confirm against its declaration.)
instance Lang repr => ProdCon (Monoid repr) l r where
  prodCon = Sub Dict

instance Lang repr => ProdCon (Vector repr) l r where
  prodCon = Sub Dict

instance Lang repr => ProdCon (WithDiff repr) l r where
  prodCon = Sub Dict

instance Lang repr => ProdCon (Reify repr) l r where
  prodCon = Sub Dict
-- | No 'Lang' primitive needs a weight, so every one is a plain 'NoImpW'
-- term built from the underlying @repr@ primitive.
instance Lang repr => Lang (ImpW repr) where
  -- Literals.
  double d = NoImpW (double d)
  float f = NoImpW (float f)

  -- Products.
  mkProd = NoImpW mkProd
  zro = NoImpW zro
  fst = NoImpW fst

  -- Double arithmetic.
  doublePlus = NoImpW doublePlus
  doubleMinus = NoImpW doubleMinus
  doubleMult = NoImpW doubleMult
  doubleDivide = NoImpW doubleDivide
  doubleExp = NoImpW doubleExp

  -- Float arithmetic and conversions.
  floatPlus = NoImpW floatPlus
  floatMinus = NoImpW floatMinus
  floatMult = NoImpW floatMult
  floatDivide = NoImpW floatDivide
  floatExp = NoImpW floatExp
  float2Double = NoImpW float2Double
  double2Float = NoImpW double2Float

  -- Recursion, sums, unit and the empty type.
  fix = NoImpW fix
  left = NoImpW left
  right = NoImpW right
  sumMatch = NoImpW sumMatch
  unit = NoImpW unit
  exfalso = NoImpW exfalso

  -- Options.
  nothing = NoImpW nothing
  just = NoImpW just
  optionMatch = NoImpW optionMatch

  -- IO.
  ioRet = NoImpW ioRet
  ioBind = NoImpW ioBind
  ioMap = NoImpW ioMap

  -- Lists.
  nil = NoImpW nil
  cons = NoImpW cons
  listMatch = NoImpW listMatch

  -- Writer.
  writer = NoImpW writer
  runWriter = NoImpW runWriter
-- | Run two interpretations side by side: every primitive is built once in
-- @l@ and once in @r@.
instance (Lang l, Lang r) => Lang (Combine l r) where
  -- Literals.
  double d = Combine (double d) (double d)
  float f = Combine (float f) (float f)

  -- Products.
  mkProd = Combine mkProd mkProd
  zro = Combine zro zro
  fst = Combine fst fst

  -- Double arithmetic.
  doublePlus = Combine doublePlus doublePlus
  doubleMinus = Combine doubleMinus doubleMinus
  doubleMult = Combine doubleMult doubleMult
  doubleDivide = Combine doubleDivide doubleDivide
  doubleExp = Combine doubleExp doubleExp

  -- Float arithmetic and conversions.
  floatPlus = Combine floatPlus floatPlus
  floatMinus = Combine floatMinus floatMinus
  floatMult = Combine floatMult floatMult
  floatDivide = Combine floatDivide floatDivide
  floatExp = Combine floatExp floatExp
  float2Double = Combine float2Double float2Double
  double2Float = Combine double2Float double2Float

  -- Recursion, sums, unit and the empty type.
  fix = Combine fix fix
  left = Combine left left
  right = Combine right right
  sumMatch = Combine sumMatch sumMatch
  unit = Combine unit unit
  exfalso = Combine exfalso exfalso

  -- Options.
  nothing = Combine nothing nothing
  just = Combine just just
  optionMatch = Combine optionMatch optionMatch

  -- IO.
  ioRet = Combine ioRet ioRet
  ioBind = Combine ioBind ioBind
  ioMap = Combine ioMap ioMap

  -- Lists.
  nil = Combine nil nil
  cons = Combine cons cons
  listMatch = Combine listMatch listMatch

  -- Writer.
  writer = Combine writer writer
  runWriter = Combine runWriter runWriter
-- Unit: the supplied function is discarded (via 'const1') and the
-- identity is returned.
instance Lang repr => WithDiff repr () where
  withDiff = const1 id

-- Double: pair the value with the image of 'doubleOne' under the supplied
-- function.
instance Lang repr => WithDiff repr P.Double where
  withDiff = lam2 (\embed d -> app2 mkProd d (app embed doubleOne))

-- Pairs: recurse into each component. The component's result is placed in
-- the pair as (t, zero) or (zero, t) before the supplied function is
-- applied.
instance (Lang repr, WithDiff repr l, WithDiff repr r) => WithDiff repr (l, r) where
  withDiff = lam (\embed ->
    bimap2
      (withDiff1 (lam (\a -> app embed (app2 mkProd a zero))))
      (withDiff1 (lam (\b -> app embed (app2 mkProd zero b)))))
-- | A 'Monoid' with inverses. For a 'Lang' interpretation either method may
-- be omitted: 'invert' defaults to subtracting from 'zero', and 'minus'
-- defaults to adding the inverse. The two defaults are defined in terms of
-- each other, so an instance must provide at least one of them.
class Monoid r g => Group r g where
  invert :: r h (g -> g)
  default invert :: Lang r => r h (g -> g)
  invert = app minus zero

  minus :: r h (g -> g -> g)
  default minus :: Lang r => r h (g -> g -> g)
  minus = lam2 (\a b -> plus2 a (app invert b))

-- | Scaling by a Double. For a 'Lang' interpretation each method defaults to
-- the other one applied with the reciprocal of the scalar, so an instance
-- must provide at least one of them.
class Group r v => Vector r v where
  mult :: r h (P.Double -> v -> v)
  default mult :: Lang r => r h (P.Double -> v -> v)
  mult = lam2 (\k x -> app2 divide x (app recip k))

  divide :: r h (v -> P.Double -> v)
  default divide :: Lang r => r h (v -> P.Double -> v)
  divide = lam2 (\x k -> app2 mult (app recip k) x)
-- Unit: the trivial structure. Every operation discards its arguments
-- (via 'const1') and yields 'unit'.
instance Lang r => Monoid r () where
  zero = unit
  plus = const1 (const1 unit)

instance Lang r => Group r () where
  invert = const1 unit
  minus = const1 (const1 unit)

instance Lang r => Vector r () where
  mult = const1 (const1 unit)
  divide = const1 (const1 unit)

-- Double: the language's own arithmetic primitives. Only 'minus' is given
-- for the group, so 'invert' comes from the class default.
instance Lang r => Monoid r P.Double where
  plus = doublePlus
  zero = doubleZero

instance Lang r => Group r P.Double where
  minus = doubleMinus

instance Lang r => Vector r P.Double where
  divide = doubleDivide
  mult = doubleMult

-- Float: the Float primitives. In 'mult' and 'divide' the Double scalar is
-- routed through 'double2Float' before reaching the Float operation.
instance Lang r => Monoid r P.Float where
  plus = floatPlus
  zero = floatZero

instance Lang r => Group r P.Float where
  minus = floatMinus

instance Lang r => Vector r P.Float where
  divide = com2 (flip2 com double2Float) floatDivide
  mult = com2 floatMult double2Float

-- Pairs: every operation acts component-wise.
instance (Lang repr, Monoid repr l, Monoid repr r) => Monoid repr (l, r) where
  zero = app2 mkProd zero zero
  plus = lam2 (\a b -> app2 mkProd (plus2 (zro1 a) (zro1 b)) (plus2 (fst1 a) (fst1 b)))

instance (Lang repr, Group repr l, Group repr r) => Group repr (l, r) where
  invert = bimap2 invert invert

instance (Lang repr, Vector repr l, Vector repr r) => Vector repr (l, r) where
  mult = lam (\k -> bimap2 (app mult k) (app mult k))
-- Functions form a monoid pointwise: zero is the constant-zero function and
-- plus f g = \x -> f x + g x.
instance (Lang repr, Monoid repr l, Monoid repr r) => Monoid repr (l -> r) where
  zero = const1 zero
  plus = lam3 $ \l r x -> plus2 (app l x) (app r x)

-- Pointwise inverse: invert f = \x -> invert (f x), so that
-- plus f (invert f) = zero. Negating the argument instead
-- (\x -> f (invert x)) does not satisfy that law. 'minus' uses the class
-- default (plus with the inverse), which is therefore also pointwise.
-- 'Group repr l' is still required for the superclass 'Monoid repr (l -> r)'.
instance (Lang repr, Group repr l, Group repr r) => Group repr (l -> r) where
  invert = lam2 $ \f x -> invert1 (app f x)

-- Pointwise scaling: mult k f = \x -> mult k (f x). The result is scaled,
-- not the argument. 'divide' uses the class default (mult by the
-- reciprocal), which is therefore also pointwise.
instance (Lang repr, Vector repr l, Vector repr r) => Vector repr (l -> r) where
  mult = lam3 $ \k f x -> mult2 k (app f x)
-- Lists form a monoid under concatenation.
instance Lang r => Monoid r [a] where
  plus = listAppend
  zero = nil

-- Map over a list by recursion through 'fix': rebuild each cons cell with
-- the function applied to its head.
instance Lang r => Functor r [] where
  map = lam (\f -> app fix (lam (\self ->
    app2 listMatch nil (lam2 (\y ys -> app2 cons (app f y) (app self ys))))))

-- Map each component of a pair with its own function.
instance Lang r => BiFunctor r (,) where
  bimap = lam3 (\f g pr -> app2 mkProd (app f (app zro pr)) (app g (app fst pr)))
-- Writer as a functor: map over the result; the log is left untouched.
instance Lang r => Functor r (P.Writer w) where
  map = lam $ \f -> com2 writer (com2 (bimap2 f id) runWriter)

-- Applicative Writer: 'pure' logs 'zero'. 'ap' combines the logs with
-- 'plus', the function's log first and then the argument's.
instance (Lang r, Monoid r w) => Applicative r (P.Writer w) where
  pure = com2 writer (flip2 mkProd zero)
  ap = lam2 $ \f x -> writer1 $ mkProd2
    (app (zro1 (runWriter1 f)) (zro1 (runWriter1 x)))
    (plus2 (fst1 (runWriter1 f)) (fst1 (runWriter1 x)))

-- Monadic Writer: flatten a nested computation. The outer log comes before
-- the inner log. This matches the order used by 'ap' here and by
-- Control.Monad.Writer, and it matters for non-commutative logs such as
-- lists.
instance (Lang r, Monoid r w) => Monad r (P.Writer w) where
  join = lam $ \x -> writer1 $ mkProd2
    (zro1 $ runWriter1 $ zro1 $ runWriter1 x)
    (plus2 (fst1 $ runWriter1 x) (fst1 $ runWriter1 $ zro1 $ runWriter1 x))
-- IO: everything is expressed through the IO primitives of 'Lang'.
instance Lang r => Functor r P.IO where
  map = ioMap

instance Lang r => Applicative r P.IO where
  pure = ioRet
  -- Run the function action, then map the resulting function over the
  -- argument action.
  ap = lam2 (\mf mx -> app2 ioBind mf (flip2 ioMap mx))

instance Lang r => Monad r P.IO where
  bind = ioBind

-- Maybe: everything is expressed through 'optionMatch'.
instance Lang r => Functor r P.Maybe where
  map = lam (\g -> app2 optionMatch nothing (com2 just g))

instance Lang r => Applicative r P.Maybe where
  pure = just
  -- No function gives 'nothing'; a function f maps over the argument.
  ap = app2 optionMatch (const1 nothing) map

instance Lang r => Monad r P.Maybe where
  bind = lam2 (\m k -> app3 optionMatch nothing k m)
-- | Turn an 'ImpW' term into a 'RunImpW'. A weighted term is passed through
-- as-is. A plain 'NoImpW' term is given the unit weight type and ignores it
-- via 'const1'.
runImpW :: forall repr h x. Lang repr => ImpW repr h x -> RunImpW repr h x
runImpW term = case term of
  ImpW weighted -> RunImpW weighted
  NoImpW plain -> RunImpW (const1 plain :: repr h (() -> x))
-- | A term that, for whatever vector type @v@ the caller selects with a
-- 'Proxy', produces a @repr@ term whose environment and result types are
-- mapped through @Diff v@.
newtype GWDiff repr h x = GWDiff
  { runGWDiff :: forall v. Vector repr v => Proxy v -> repr (Diff v h) (Diff v x)
  }

-- The de Bruijn core is lifted directly: each operation passes the same
-- proxy to its sub-terms and applies the corresponding @repr@ operation.
instance DBI repr => DBI (GWDiff repr) where
  z = GWDiff (\_ -> z)
  s (GWDiff body) = GWDiff (\proxy -> s (body proxy))
  abs (GWDiff body) = GWDiff (\proxy -> abs (body proxy))
  app (GWDiff fun) (GWDiff arg) = GWDiff (\proxy -> app (fun proxy) (arg proxy))
-- Saturating helpers: a name ending in N applies the primitive of the same
-- stem to N arguments through 'app' / 'app2' / 'app3'.

-- Products and pair combinators.
mkProd1 = app mkProd
mkProd2 = app2 mkProd
zro1 = app zro
fst1 = app fst
uncurry1 = app uncurry

-- Recursion.
fix1 = app fix
fix2 = app2 fix

-- Lists and options.
cons2 = app2 cons
listMatch2 = app2 listMatch
optionMatch2 = app2 optionMatch
optionMatch3 = app3 optionMatch

-- Group and vector operations.
invert1 = app invert
minus1 = app minus
minus2 = app2 minus
mult1 = app mult
mult2 = app2 mult
divide1 = app divide
divide2 = app2 divide

-- Reciprocal of a Double: 'divide' with 'doubleOne' as the dividend.
recip = divide1 doubleOne
recip1 = app recip

-- Numeric primitives.
float2Double1 = app float2Double
doubleExp1 = app doubleExp
floatExp1 = app floatExp

-- Writer and IO.
writer1 = app writer
runWriter1 = app runWriter
ioBind2 = app2 ioBind
-- | De Bruijn core for 'ImpW'. A term is either 'NoImpW' (a plain @repr@
-- term) or 'ImpW' (a @repr@ term of shape @w -> x@ for some weight type @w@
-- with a 'Weight' instance). Operations on plain terms stay plain; when
-- either operand of 'app' is weighted, the result is weighted.
instance Lang repr => DBI (ImpW repr) where
  z = NoImpW z
  -- Weakening: add one binding @a@ to the environment. For the weighted case
  -- the local helper 'work' states the result type explicitly, using the
  -- scoped variables @a h b@ of this signature while leaving @w@ free.
  -- Presumably this helps inference with the existential weight type;
  -- NOTE(review): confirm whether the helper is still required.
  s :: forall a h b. ImpW repr h b -> ImpW repr (a, h) b
  s (ImpW x) = work x
    where
      work :: Weight w => repr h (w -> b) -> ImpW repr (a, h) b
      work x = ImpW (s x)
  s (NoImpW x) = NoImpW (s x)
  -- Both sides weighted: the new weight is a pair. Its first component
  -- ('zro1') feeds the function and its second ('fst1') feeds the argument.
  -- 'conv' is defined elsewhere; presumably it adapts a term to the
  -- environment inside 'lam'. TODO confirm in DBI.
  app (ImpW f) (ImpW x) = ImpW (lam $ \p -> app (app (conv f) (zro1 p)) (app (conv x) (fst1 p)))
  app (NoImpW f) (NoImpW x) = NoImpW (app f x)
  -- Only the function is weighted: pass it the weight, then the argument.
  app (ImpW f) (NoImpW x) = ImpW (lam $ \w -> app2 (conv f) w (conv x))
  -- Only the argument is weighted: apply it to the weight, then apply the
  -- function to the result.
  app (NoImpW f) (ImpW x) = ImpW (lam $ \w -> app (conv f) (app (conv x) w))
  -- Abstraction over a weighted body: 'flip1' reorders the arguments of the
  -- abstracted term so that the weight comes first, keeping the 'ImpW'
  -- shape @w -> ...@.
  abs (ImpW f) = ImpW (flip1 $ abs f)
  abs (NoImpW x) = NoImpW (abs x)