module Graphics.HaGL.Shader (
    Shader(..),
    ShaderFn(..),
    ShaderParam(..),
    ShaderDecl(..),
    ShaderStmt(..),
    ShaderExpr(..),
    VarName,
    addFn,
    addDecl,
    addStmt
) where

import Data.Char (isAlpha)
import Data.List (intercalate)

import Graphics.HaGL.GLType


-- | A complete shader program: helper functions, global declarations, and
-- the statements that make up the body of @main@.
data Shader
    = Shader [ShaderFn] [ShaderDecl] [ShaderStmt]

-- | A helper function emitted ahead of @main@.
data ShaderFn
    -- | A function with a straight-line body that ends by returning an
    -- expression.
    = ShaderFn
        FnName          -- ^ function name
        ExprType        -- ^ return type, as GLSL text
        [ShaderParam]   -- ^ parameters, in order
        [ShaderStmt]    -- ^ body statements
        ShaderExpr      -- ^ returned expression
    -- | A function whose body is an infinite @while (true)@ loop. Every
    -- iteration first runs the condition statements; if the condition then
    -- holds, the return statements run and the return expression is
    -- returned, otherwise the update statements run and the loop repeats.
    | ShaderLoopFn
        FnName          -- ^ function name
        ExprType        -- ^ return type, as GLSL text
        [ShaderParam]   -- ^ parameters, in order
        ShaderExpr      -- ^ loop exit condition
        ShaderExpr      -- ^ returned expression
        [ShaderStmt]    -- ^ statements run before the condition is tested
        [ShaderStmt]    -- ^ statements run just before returning
        [ShaderStmt]    -- ^ statements run when the condition is false

-- | A function parameter, rendered as @type name@.
data ShaderParam
    = ShaderParam
        VarName     -- ^ parameter name
        ExprType    -- ^ parameter type, as GLSL text

-- | A global (file-scope) declaration.
data ShaderDecl
    -- | @uniform type name;@
    = UniformDecl VarName ExprType
    -- | @qualifier in type name;@ — an empty qualifier is omitted.
    | InpDecl TypeQual VarName ExprType
    -- | @qualifier out type name;@ — an empty qualifier is omitted.
    | OutDecl TypeQual VarName ExprType

-- | A statement inside a helper function or @main@.
data ShaderStmt
    -- | Assignment to an existing variable: @name = expr;@
    = VarAsmt VarName ShaderExpr
    -- | Declaration without an initializer: @type name;@
    | VarDecl VarName ExprType
    -- | Declaration with an initializer: @type name = expr;@
    | VarDeclAsmt VarName ExprType ShaderExpr
    -- | Conditional discard: @if (cond) discard;@
    | DiscardStmt ShaderExpr

-- | A GLSL expression tree.
data ShaderExpr where
    -- | A constant of any 'GLType', rendered with 'showGlslVal'.
    ShaderConst :: GLType t => t -> ShaderExpr
    -- | A reference to a variable by name.
    ShaderVarRef :: VarName -> ShaderExpr
    -- | A named function or operator applied to operands. The name's first
    -- character (and the exact names @[]@ and @?:@) selects how the node is
    -- rendered: function call, component selection, subscript, matrix
    -- column, conditional, or prefix/infix operator.
    ShaderExpr :: String -> [ShaderExpr] -> ShaderExpr

-- | Name of a shader variable.
type VarName = String
-- | Name of a generated helper function.
type FnName = String
-- | A GLSL type, written as source text.
type ExprType = String
-- | Qualifier text placed before @in@/@out@ in a declaration; may be empty.
type TypeQual = String

-- | Renders a complete GLSL 4.30 core shader, section by section:
--
--   * the @#version@ line followed by a blank line;
--   * one global declaration per line, followed by a blank line when there
--     is at least one declaration;
--   * each helper function followed by a blank line;
--   * @main@, containing the top-level statements indented by two spaces.
instance Show Shader where
    show (Shader fns decls stmts) = concat
        [ "#version 430 core\n\n"
        , endWith "\n" (concatMap (\decl -> show decl ++ "\n") decls)
        , concatMap (\fn -> show fn ++ "\n\n") fns
        , "void main() {\n"
        , concatMap (\stmt -> "  " ++ show stmt ++ "\n") stmts
        , "}\n"
        ]

-- | Renders a helper function definition. The text ends with the closing
-- brace and no trailing newline; the enclosing 'Shader' adds separators.
--
-- A 'ShaderFn' becomes its body statements followed by @return@.
-- A 'ShaderLoopFn' becomes an infinite @while (true)@ loop. Each iteration
-- runs the condition statements. If the condition holds, it runs the return
-- statements and returns; otherwise it runs the update statements.
instance Show ShaderFn where
    show fn = case fn of
        ShaderFn name retType params stmts ret ->
            header name retType params
                ++ indented 2 stmts
                ++ "  return " ++ show ret ++ ";"
                ++ "\n}"
        ShaderLoopFn name retType params cond ret condStmts retStmts updateStmts ->
            header name retType params
                ++ "  while (true) {\n"
                ++ indented 6 condStmts
                ++ "      if (" ++ show cond ++ ") {\n"
                ++ indented 8 retStmts
                ++ "        return " ++ show ret ++ ";\n"
                ++ "      }\n"
                ++ indented 6 updateStmts
                ++ "  }\n}"
      where
        -- Signature line shared by both forms: @type name(params) {@.
        header fnName retTy ps =
            retTy ++ " " ++ fnName ++ "(" ++ intercalate ", " (map show ps) ++ ") {\n"

        -- One statement per line, each prefixed by @n@ spaces.
        indented n = concatMap (\stmt -> replicate n ' ' ++ show stmt ++ "\n")

-- | Renders a parameter as @type name@.
instance Show ShaderParam where
    show (ShaderParam name exprType) = unwords [exprType, name]

-- | Renders a global declaration terminated by @;@. For inputs and outputs,
-- the qualifier, when non-empty, is emitted first and followed by a space.
instance Show ShaderDecl where
    show decl = case decl of
        UniformDecl varName exprType ->
            "uniform " ++ typedName exprType varName
        InpDecl qual varName exprType ->
            endWith " " qual ++ "in " ++ typedName exprType varName
        OutDecl qual varName exprType ->
            endWith " " qual ++ "out " ++ typedName exprType varName
      where
        -- Tail common to every declaration form: @type name;@.
        typedName exprType varName = exprType ++ " " ++ varName ++ ";"

-- | Renders a single statement on one line, without indentation.
instance Show ShaderStmt where
    show (VarAsmt varName expr) = unwords [varName, "=", show expr] ++ ";"
    show (VarDecl varName exprType) = unwords [exprType, varName] ++ ";"
    show (VarDeclAsmt varName exprType expr) =
        unwords [exprType, varName, "=", show expr] ++ ";"
    show (DiscardStmt cond) = "if (" ++ show cond ++ ") discard;"

-- | Renders an expression as GLSL source text.
--
-- A 'ShaderExpr' node is rendered according to its name:
--
--   * a name starting with a letter is a function call, e.g. @f(a, b)@;
--   * a name starting with @.@ is appended to its single operand, e.g. @v.xy@;
--   * the name @[]@ subscripts its first operand by its second, e.g. @a[i]@;
--   * any other name starting with @[@ is appended to its single operand,
--     e.g. @m[1]@;
--   * the name @?:@ is the conditional operator, e.g. @c ? a : b@;
--   * any other name is an operator: prefix with one operand (@-a@), infix
--     between the operands otherwise (@a + b@).
--
-- Operands that are themselves operator or conditional expressions are
-- wrapped in parentheses, so the text follows the tree structure regardless
-- of GLSL operator precedence. Variable references, constants and calls are
-- emitted unchanged. A node with an empty name, or with the wrong number of
-- operands for its form, is a construction error and raises 'error' with a
-- descriptive message.
instance Show ShaderExpr where
    show (ShaderConst c) = showGlslVal c
    show (ShaderVarRef varName) = varName
    show (ShaderExpr funcName xs) = case funcName of
        [] -> malformed "empty function/operator name"
        (c : _)
            | isAlpha c ->
                funcName ++ "(" ++ intercalate ", " (map show xs) ++ ")"
            | c == '.' -> case xs of
                [x] -> showOperand x ++ funcName
                _ -> wrongArity 1
            | funcName == "[]" -> case xs of
                [arr, i] -> showOperand arr ++ "[" ++ show i ++ "]"
                _ -> wrongArity 2
            | c == '[' -> case xs of
                [x] -> showOperand x ++ funcName
                _ -> wrongArity 1
            | funcName == "?:" -> case xs of
                [x, y, z] ->
                    showOperand x ++ " ? " ++ showOperand y
                        ++ " : " ++ showOperand z
                _ -> wrongArity 3
            | otherwise -> case xs of
                [x] -> funcName ++ showOperand x
                _ -> intercalate (" " ++ funcName ++ " ") (map showOperand xs)
      where
        malformed :: String -> String
        malformed msg =
            error ("Graphics.HaGL.Shader: malformed ShaderExpr: " ++ msg)

        wrongArity :: Int -> String
        wrongArity expected = malformed
            ("'" ++ funcName ++ "' expects " ++ show expected
                ++ " operand(s), got " ++ show (length xs))

        -- Operator and conditional nodes bind more loosely than every form
        -- they can be nested in, so they are parenthesized as operands.
        showOperand :: ShaderExpr -> String
        showOperand e
            | isLoose e = "(" ++ show e ++ ")"
            | otherwise = show e

        -- True for nodes rendered as prefix/infix operators or as @?:@.
        isLoose :: ShaderExpr -> Bool
        isLoose (ShaderExpr (h : _) _) = not (isAlpha h || h == '.' || h == '[')
        isLoose _ = False

-- | @endWith sep s@ is @s ++ sep@ when @s@ is non-empty, and @""@ otherwise.
-- Used to emit an optional piece of text together with its separator only
-- when the piece is actually present.
endWith :: String -> String -> String
endWith sep s
    | null s = ""
    | otherwise = s ++ sep

-- | Appends a helper function to a shader. Functions are rendered in the
-- order they were added.
addFn :: ShaderFn -> Shader -> Shader
addFn fn (Shader fns decls stmts) = Shader (fns ++ [fn]) decls stmts

-- | Appends a global declaration to a shader. Declarations are rendered in
-- the order they were added.
addDecl :: ShaderDecl -> Shader -> Shader
addDecl decl (Shader fns decls stmts) = Shader fns (decls ++ [decl]) stmts

-- | Appends a statement to the body of @main@. Statements are rendered in
-- the order they were added.
addStmt :: ShaderStmt -> Shader -> Shader
addStmt stmt (Shader fns decls stmts) = Shader fns decls (stmts ++ [stmt])