module Language.Mecha.Octree
  ( Octree (..)
  , Vertex
  , Color
  , sphere
  , mesh
  , union
  , intersection
  , difference
  ) where

import Control.Monad
import qualified Data.IntMap as IM
import qualified Data.Map as M
import qualified Graphics.Rendering.OpenGL as GL

import Language.Mecha.OpenGL

-- | A triple of 'Double's @(x, y, z)@; used in this module both for positions
-- (octree centers, surface points) and for surface normals.
type Vertex = (Double, Double, Double)
-- | A triple of 'Float's handed component-wise to @color3@ when rendering
-- (presumably red, green, blue -- confirm against "Language.Mecha.OpenGL").
type Color  = (Float, Float, Float)

-- | A solid described by recursive subdivision of a cube into eight octants.
--
-- * 'Octree' is a branch: a cube with a 'center', a 'radius' (used as the
--   half-width of the cube by 'sphere') and eight children.  @u1@..@u4@ are the
--   children on the positive-z side and @l1@..@l4@ those on the negative-z
--   side; within each group the (x, y) signs are @1 = (+, +)@, @2 = (-, +)@,
--   @3 = (-, -)@, @4 = (+, -)@.
-- * 'Surface' is a leaf that the solid's boundary passes through, carrying a
--   representative 'point', the 'normal' at that point and the 'color' to draw.
-- * 'Inside' is a leaf lying entirely within the solid.
-- * 'Outside' is a leaf lying entirely outside the solid.
--
-- Note that 'center', 'radius', 'u1'.. 'l4', 'point', 'normal' and 'color' are
-- partial record selectors: each is only defined for its own constructor.
data Octree
  = Octree  { center :: Vertex, radius :: Double, u1, u2, u3, u4, l1, l2, l3, l4 :: Octree }
  | Surface { point  :: Vertex, normal :: Vertex, color :: Color }
  | Inside
  | Outside deriving (Show, Eq)

-- | Octree approximation of the unit sphere centered on the origin.
--
-- The first argument is the color stored in every 'Surface' leaf.  The second
-- is the precision: a cube that straddles the sphere is subdivided until its
-- half-width drops below this value, at which point it becomes a 'Surface'
-- leaf.  The root is always a branch of half-width 1 centered at the origin.
--
-- The tree is built lazily; with a precision of zero or less, cubes straddling
-- the sphere are never turned into leaves, so the tree is unbounded.
sphere :: Color -> Double -> Octree
sphere color precision = split 1 (0, 0, 0)
  where
    -- Decide what the cube with the given half-width and center becomes.
    classify :: Double -> (Double, Double, Double) -> Octree
    classify half c@(x, y, z)
      | farthest < 1     = Inside
      | nearest  > 1     = Outside
      | half < precision = Surface onSphere onSphere color
      | otherwise        = split half c
      where
        len      = sqrt $ x ** 2 + y ** 2 + z ** 2
        -- The cube center projected onto the unit sphere; it serves as both
        -- the surface point and the outward normal.
        onSphere = (x / len, y / len, z / len)
        -- Distance from the origin to the corner obtained by moving each
        -- absolute coordinate with the supplied adjustment.
        cornerDistance adjust =
          sqrt (adjust x ** 2 + adjust y ** 2 + adjust z ** 2)
        nearest  = cornerDistance (\v -> abs v - half)
        farthest = cornerDistance (\v -> abs v + half)

    -- Branch node for the cube, with its eight children classified in turn.
    split :: Double -> (Double, Double, Double) -> Octree
    split half c@(x, y, z) = Octree c half
      (child hi hi hi)
      (child lo hi hi)
      (child lo lo hi)
      (child hi lo hi)
      (child hi hi lo)
      (child lo hi lo)
      (child lo lo lo)
      (child hi lo lo)
      where
        childHalf = half / 2
        hi v = v + childHalf
        lo v = v - childHalf
        child fx fy fz = classify childHalf (fx x, fy y, fz z)
              


-- | Approximate CSG union of two octrees.
--
-- Two branches are merged child by child, keeping the center and radius of the
-- first operand (the second operand's are ignored), and the result collapses
-- to 'Inside' when every merged child is 'Inside'.  'Inside' absorbs anything
-- and 'Outside' is the identity.  For every remaining combination (a 'Surface'
-- against a 'Surface' or a branch) the first operand is returned unchanged.
union :: Octree -> Octree -> Octree
union lhs rhs = case (lhs, rhs) of
  (Octree c r a0 a1 a2 a3 a4 a5 a6 a7, Octree _ _ b0 b1 b2 b3 b4 b5 b6 b7) ->
    let k0 = union a0 b0
        k1 = union a1 b1
        k2 = union a2 b2
        k3 = union a3 b3
        k4 = union a4 b4
        k5 = union a5 b5
        k6 = union a6 b6
        k7 = union a7 b7
    in if allInside [k0, k1, k2, k3, k4, k5, k6, k7]
         then Inside
         else Octree c r k0 k1 k2 k3 k4 k5 k6 k7
  (Inside, _)  -> Inside
  (_, Inside)  -> Inside
  (Outside, _) -> rhs
  (_, Outside) -> lhs
  _            -> lhs

-- | Approximate CSG intersection of two octrees.
--
-- Two branches are intersected child by child, keeping the center and radius
-- of the first operand, and the result collapses to 'Outside' when every
-- child is 'Outside'.  'Inside' is the identity and 'Outside' absorbs
-- anything.  For every remaining combination (a 'Surface' against a 'Surface'
-- or a branch) the first operand is returned unchanged.
intersection :: Octree -> Octree -> Octree
intersection lhs rhs = case (lhs, rhs) of
  (Octree c r a0 a1 a2 a3 a4 a5 a6 a7, Octree _ _ b0 b1 b2 b3 b4 b5 b6 b7) ->
    let k0 = intersection a0 b0
        k1 = intersection a1 b1
        k2 = intersection a2 b2
        k3 = intersection a3 b3
        k4 = intersection a4 b4
        k5 = intersection a5 b5
        k6 = intersection a6 b6
        k7 = intersection a7 b7
    in if allOutside [k0, k1, k2, k3, k4, k5, k6, k7]
         then Outside
         else Octree c r k0 k1 k2 k3 k4 k5 k6 k7
  (Inside, _)  -> rhs
  (_, Inside)  -> lhs
  (Outside, _) -> Outside
  (_, Outside) -> Outside
  _            -> lhs

-- | Approximate CSG difference: the part of the first octree that is not
-- covered by the second.
--
-- * Two branches are subtracted child by child, keeping the center and radius
--   of the first operand; the result collapses to 'Outside' when every child
--   is 'Outside'.
-- * 'Outside' minus anything is 'Outside'; anything minus 'Inside' is
--   'Outside'; anything minus 'Outside' is itself.
-- * 'Inside' minus a branch is the complement of that branch, computed child
--   by child (using the branch's own center and radius).
-- * A 'Surface' minus a branch is returned unchanged: a leaf cannot be split
--   to follow the finer subdivision, so the subtraction is ignored there.
-- * Otherwise the second operand is a 'Surface', which is returned with its
--   normal negated, since the boundary of the removed solid becomes a boundary
--   of the result facing the other way.
difference :: Octree -> Octree -> Octree
difference lhs rhs = case (lhs, rhs) of
  (Octree c r a0 a1 a2 a3 a4 a5 a6 a7, Octree _ _ b0 b1 b2 b3 b4 b5 b6 b7) ->
    prune c r
      (difference a0 b0) (difference a1 b1) (difference a2 b2) (difference a3 b3)
      (difference a4 b4) (difference a5 b5) (difference a6 b6) (difference a7 b7)
  -- Empty space stays empty whatever is subtracted from it.
  (Outside, _) -> Outside
  (_, Inside)  -> Outside
  (_, Outside) -> lhs
  -- Solid minus a subdivided region: complement the region recursively.  A
  -- record update on the branch would fail at run time, as 'Octree' has no
  -- 'normal' field.
  (Inside, Octree c r b0 b1 b2 b3 b4 b5 b6 b7) ->
    prune c r
      (difference Inside b0) (difference Inside b1)
      (difference Inside b2) (difference Inside b3)
      (difference Inside b4) (difference Inside b5)
      (difference Inside b6) (difference Inside b7)
  -- Only a 'Surface' first operand reaches this alternative.
  (_, Octree {})  -> lhs
  (_, Surface {}) -> rhs { normal = (-x, -y, -z) }
  where
    -- Only demanded in the final alternative, where @rhs@ is a 'Surface'.
    (x, y, z) = normal rhs

    -- Rebuild a branch, collapsing it when nothing of the solid remains.
    prune c r x0 x1 x2 x3 x4 x5 x6 x7
      | allOutside [x0, x1, x2, x3, x4, x5, x6, x7] = Outside
      | otherwise = Octree c r x0 x1 x2 x3 x4 x5 x6 x7

-- | A sequence of octants to descend through; 'subOctree' consumes it head
-- first, so the head is the outermost step.
type Path = [Octant]
-- | Identifies one child of a branch by its @(x, y, z)@ side; 'True' selects
-- the positive side on that axis (so @(True, True, True)@ is 'u1').
type Octant = (Bool, Bool, Bool)
-- | A coordinate axis.
data Axis = X | Y | Z deriving Eq
-- | An axis together with a sign; 'True' means toward the positive side.
type Direction = (Axis, Bool)
-- | The ancestors of a node, innermost first: each entry is a parent branch
-- paired with the octant of that parent in which the path continues.
type Context = [(Octree, Octant)]

-- | Finds the node adjacent to the current one in the given direction.
--
-- The context describes where the current node sits (innermost ancestor
-- first).  The search climbs until it reaches an ancestor in which the node
-- can step across to a sibling along the requested axis, then descends into
-- that sibling along the mirrored path via 'subOctree'.  Returns the
-- neighbor together with its own context, or @([], Outside)@ when the search
-- runs off the root.
neighbor :: Context -> Direction -> (Context, Octree)
neighbor context (axis, sign) = climb context []
  where
    climb :: Context -> Path -> (Context, Octree)
    climb [] _ = ([], Outside)
    climb ((parent, here) : ancestors) path
      -- The sibling on the requested side lives in this same parent.
      | xor sign (along here) = subOctree ancestors path' parent
      -- Already on the far side of this parent: keep climbing.
      | otherwise             = climb ancestors path'
      where
        path' = mirrored here : path

    -- The octant's component on the axis being travelled.
    along :: Octant -> Bool
    along (x, y, z) = case axis of
      X -> x
      Y -> y
      Z -> z

    -- The octant reflected across the axis being travelled.
    mirrored :: Octant -> Octant
    mirrored (x, y, z) = case axis of
      X -> (not x, y, z)
      Y -> (x, not y, z)
      Z -> (x, y, not z)

-- | Selects the child accessor ('u1'..'l4') that corresponds to an octant.
-- The resulting accessor is only defined on 'Octree' branches.
octant :: Octant -> Octree -> Octree
octant which = case which of
  (True,  True,  True ) -> u1
  (False, True,  True ) -> u2
  (False, False, True ) -> u3
  (True,  False, True ) -> u4
  (True,  True,  False) -> l1
  (False, True,  False) -> l2
  (False, False, False) -> l3
  (True,  False, False) -> l4

-- | Descends from a node along a path of octants, extending the context with
-- every branch passed through.  Stops early, returning the node reached so
-- far, as soon as the path is exhausted or a leaf is encountered.
subOctree :: Context -> Path -> Octree -> (Context, Octree)
subOctree context path octree = case (path, octree) of
  (step : rest, Octree {}) ->
    subOctree ((octree, step) : context) rest (octant step octree)
  _ -> (context, octree)

-- | Exclusive or: 'True' exactly when the two arguments differ.
xor :: Bool -> Bool -> Bool
xor a b = a /= b

-- | 'True' when every node in the list is the 'Inside' leaf (vacuously 'True'
-- for an empty list).
allInside :: [Octree] -> Bool
allInside = all isInside
  where
    isInside Inside = True
    isInside _      = False

-- | 'True' when every node in the list is the 'Outside' leaf (vacuously 'True'
-- for an empty list).
allOutside :: [Octree] -> Bool
allOutside = all isOutside
  where
    isOutside Outside = True
    isOutside _       = False

-- | 'True' when every node in the list is a 'Surface' leaf (vacuously 'True'
-- for an empty list).
allSurface :: [Octree] -> Bool
allSurface nodes = and [isSurface node | node <- nodes]

-- | Tests whether a node is a 'Surface' leaf.
isSurface :: Octree -> Bool
isSurface node = case node of
  Surface {} -> True
  _          -> False

{-
add :: Vertex -> Vertex -> Vertex
add (a,b,c) (d,e,f) = (a + d, b + e, c + f)

sub :: Vertex -> Vertex -> Vertex
sub (a,b,c) (d,e,f) = (a - d, b - e, c - f)

cross :: Vertex -> Vertex -> Vertex
cross (v1, v2, v3) (w1, w2, w3) = (v2 * w3 - v3 * w2, v3 * w1 - v1 * w3, v1 * w2 - v2 * w1)  -- Cross product.
-}

-- | Renders the surface of an octree as OpenGL quads.
--
-- Every surface point found by 'meshVertices' is given an integer index; each
-- quad is then emitted corner by corner as a color, a normal and a vertex
-- command, with repeated color commands dropped by 'glCmdOpt'.
mesh :: Octree -> IO ()
mesh octree = GL.renderPrimitive GL.Quads (mapM_ glCmd commands)
  where
    entries   = meshVertices [] octree
    positions = map fst entries

    -- Lookup tables between surface points and their indices.
    indexOf    = M.fromList (zip positions [0 ..])
    positionAt = IM.fromList (zip [0 ..] positions)
    colorAt    = IM.fromList [ (indexOf M.! p, c) | (p, (_, _, _, c)) <- entries ]
    normalAt   = IM.fromList [ (indexOf M.! p, n) | (p, (n, _, _, _)) <- entries ]

    -- Indices of the quad corners, four per quad, in drawing order.
    corners :: [Int]
    corners =
      [ indexOf M.! v
      | (_, (_, _, faces, _)) <- entries
      , (a, b, c, d) <- faces
      , v <- [a, b, c, d]
      ]

    commands = glCmdOpt (concatMap emit corners)

    -- The three commands needed to draw one corner.
    emit :: Int -> [GlCmd]
    emit i = [C c1 c2 c3, N n1 n2 n3, V v1 v2 v3]
      where
        (c1, c2, c3) = colorAt    IM.! i
        (n1, n2, n3) = normalAt   IM.! i
        (v1, v2, v3) = positionAt IM.! i

-- | Executes a single drawing command by forwarding its three components to
-- @color3@, @normal3@ or @vertex3@ respectively.
glCmd :: GlCmd -> IO ()
glCmd (C c1 c2 c3) = color3  c1 c2 c3
glCmd (N n1 n2 n3) = normal3 n1 n2 n3
glCmd (V v1 v2 v3) = vertex3 v1 v2 v3

-- | Removes redundant color commands from a command stream.
--
-- The first command is always kept and is taken as the current color.  After
-- that, a 'C' command equal to the current color is dropped, a different 'C'
-- command is kept and becomes the current color, and every other command is
-- passed through untouched.
glCmdOpt :: [GlCmd] -> [GlCmd]
glCmdOpt [] = []
glCmdOpt (first : rest) = first : go first rest
  where
    go _ [] = []
    go current (cmd : cmds)
      | not (isColor cmd) = cmd : go current cmds
      | cmd == current    = go current cmds
      | otherwise         = cmd : go cmd cmds

    isColor (C _ _ _) = True
    isColor _         = False

-- | One drawing command, as executed by 'glCmd':
--
-- * 'C' sets the current color (forwarded to @color3@),
-- * 'N' sets the current normal (forwarded to @normal3@),
-- * 'V' emits a vertex (forwarded to @vertex3@).
--
-- The 'Eq' instance is what 'glCmdOpt' uses to detect repeated colors.
data GlCmd
  = C Float  Float  Float
  | N Double Double Double
  | V Double Double Double
  deriving Eq

-- | Collects every 'Surface' leaf of an octree together with the data needed
-- to draw it.
--
-- Each result is @(point, (normal, neighborPoints, quads, color))@ where
-- @neighborPoints@ are the points of the +x, +y and +z neighbors that are
-- themselves surfaces, and @quads@ holds one quad per coordinate plane
-- (xy, yz, xz) for which the three other corners are all surfaces.  The
-- context must describe the position of the given node within the whole tree
-- (empty for the root).
meshVertices :: Context -> Octree -> [(Vertex, (Vertex, [Vertex], [(Vertex, Vertex, Vertex, Vertex)], Color))]
meshVertices _ Inside  = []
meshVertices _ Outside = []
meshVertices context (Surface here facing tint) =
  [(here, (facing, adjacent, faces, tint))]
  where
    -- Neighbors one step along each positive axis, and the diagonal
    -- neighbors reached by a second step from those.
    (ctxX, alongX) = neighbor context (X, True)
    (ctxY, alongY) = neighbor context (Y, True)
    alongZ         = snd (neighbor context (Z, True))
    alongXY        = snd (neighbor ctxX (Y, True))
    alongYZ        = snd (neighbor ctxY (Z, True))
    alongXZ        = snd (neighbor ctxX (Z, True))

    adjacent = concatMap surfacePoint [alongX, alongY, alongZ]
    faces    = concat
      [ quad alongX alongXY alongY
      , quad alongY alongYZ alongZ
      , quad alongX alongXZ alongZ
      ]

    surfacePoint node = [point node | isSurface node]

    -- A quad from this point through three neighbors, if all are surfaces.
    quad p q r
      | allSurface [p, q, r] = [(here, point p, point q, point r)]
      | otherwise            = []
meshVertices context branch = concat
  [ meshVertices ((branch, slot) : context) (child branch)
  | (slot, child) <- children
  ]
  where
    children =
      [ ((True,  True,  True ), u1)
      , ((False, True,  True ), u2)
      , ((False, False, True ), u3)
      , ((True,  False, True ), u4)
      , ((True,  True,  False), l1)
      , ((False, True,  False), l2)
      , ((False, False, False), l3)
      , ((True,  False, False), l4)
      ]