module Language.Mecha.Solid
  ( Solid
  , PrimQuery (..)
  , primitive
  , transform
  , translate
  , scale
  , rotate
  , union
  , intersection
  , difference
  -- , mesh
  ) where

import Language.Mecha.Octree (Vertex)
import qualified Language.Mecha.Octree as OT

-- | A homogeneous 4-vector of 'Double's.
type Vector = (Double, Double, Double, Double)

-- | A 4x4 matrix stored as four rows; 'mm' and 'mv' treat the n-th inner
-- 'Vector' as the n-th row.
type Matrix = (Vector, Vector, Vector, Vector)

-- | The 4x4 identity matrix. 'primitive' uses it as the starting forward
-- and reverse matrix of every primitive.
identity :: Matrix
identity = (e1, e2, e3, e4)
  where
  e1 = (1, 0, 0, 0)
  e2 = (0, 1, 0, 0)
  e3 = (0, 0, 1, 0)
  e4 = (0, 0, 0, 1)

-- | Matrix product @a * b@ of two row-stored 4x4 matrices: entry (i, j) of
-- the result is the dot product of row i of @a@ with column j of @b@.
mm :: Matrix -> Matrix -> Matrix
mm a b = (row r1, row r2, row r3, row r4)
  where
  ((a11, a12, a13, a14), (a21, a22, a23, a24), (a31, a32, a33, a34), (a41, a42, a43, a44)) = a
  ((b11, b12, b13, b14), (b21, b22, b23, b24), (b31, b32, b33, b34), (b41, b42, b43, b44)) = b
  -- Rows of the left operand.
  r1 = (a11, a12, a13, a14)
  r2 = (a21, a22, a23, a24)
  r3 = (a31, a32, a33, a34)
  r4 = (a41, a42, a43, a44)
  -- Columns of the right operand.
  c1 = (b11, b21, b31, b41)
  c2 = (b12, b22, b32, b42)
  c3 = (b13, b23, b33, b43)
  c4 = (b14, b24, b34, b44)
  -- One result row: the given left row against every right column.
  row r = (dot r c1, dot r c2, dot r c3, dot r c4)
  -- Summed left to right, the same association as a written-out sum.
  dot (p1, p2, p3, p4) (q1, q2, q3, q4) = p1 * q1 + p2 * q2 + p3 * q3 + p4 * q4

-- | Apply a row-stored 4x4 matrix to a homogeneous 4-vector (the vector is
-- treated as a column on the right). Component i of the result is the dot
-- product of row i of @a@ with @b@.
mv :: Matrix -> Vector -> Vector
mv a b = x
  where
  ((a11, a12, a13, a14), (a21, a22, a23, a24), (a31, a32, a33, a34), (a41, a42, a43, a44)) = a
  (b1, b2, b3, b4) = b
  -- Every product term is added; the first two terms must not be
  -- multiplied together.
  x1 = a11 * b1 + a12 * b2 + a13 * b3 + a14 * b4
  x2 = a21 * b1 + a22 * b2 + a23 * b3 + a24 * b4
  x3 = a31 * b1 + a32 * b2 + a33 * b3 + a34 * b4
  x4 = a41 * b1 + a42 * b2 + a43 * b3 + a44 * b4
  x = (x1, x2, x3, x4)

-- | A list of pairs of 'Double's; presumably 2-D points along a curve
-- (TODO confirm). Currently unused within this module.
type Curve = [(Double, Double)]

-- | Result of a primitive's query function (the function stored in a
-- 'Primitive' node of 'Solid').
data PrimQuery
  = PrimInside
    -- ^ Presumably: the queried region lies inside the primitive (TODO confirm).
  | PrimOutside
    -- ^ Presumably: the queried region lies outside the primitive (TODO confirm).
  | PrimSurface Double Double Double
    -- ^ Presumably: the region crosses the surface, carrying an (x, y, z)
    -- surface normal (the 'Solid' notes mention transforming the surface
    -- normal). Confirm.

-- | A solid: a (possibly transformed) primitive, or one of the boolean
-- combinations union, intersection and difference of two solids.
data Solid
  = Primitive Matrix Matrix (Vertex -> Double -> PrimQuery)
    -- ^ Forward matrix, reverse matrix, and the primitive's query function.
    -- The reverse matrix is for transforming the octree center and radius;
    -- the forward matrix is for transforming the surface normal.
    -- 'primitive' starts both at 'identity' and 'transform' updates them
    -- together.
  | Union Solid Solid
    -- ^ Union of the two operands.
  | Intersection Solid Solid
    -- ^ Intersection of the two operands.
  | Difference Solid Solid
    -- ^ Difference of the two operands; presumably the first minus the
    -- second (TODO confirm).

-- | Wrap a query function as an untransformed 'Solid': both its forward and
-- reverse matrices start as 'identity'.
primitive :: (Vertex -> Double -> PrimQuery) -> Solid
primitive query = Primitive identity identity query

-- | Apply a transformation to every primitive in a solid.
--
-- @m@ is the forward matrix and @m'@ the reverse one ('translate' and
-- 'scale' pass a matrix together with its inverse; nothing here checks
-- that). Each primitive's forward matrix @n@ becomes @m * n@ and its reverse
-- matrix @n'@ becomes @n' * m'@, which keeps them inverse to each other
-- when both input pairs are inverse pairs.
transform :: Matrix -> Matrix -> Solid -> Solid
transform m m' = go
  where
  go (Primitive n n' f)   = Primitive (mm m n) (mm n' m') f
  go (Union a b)          = Union (go a) (go b)
  go (Intersection a b)   = Intersection (go a) (go b)
  go (Difference a b)     = Difference (go a) (go b)

-- | Move a solid by the offset @(x, y, z)@. The forward matrix is a
-- homogeneous translation by the offset; the reverse matrix translates by
-- the negated offset.
translate :: Double -> Double -> Double -> Solid -> Solid
translate x y z = transform forward inverse
  where
  forward = offsetBy x y z
  inverse = offsetBy (-x) (-y) (-z)
  -- Translation lives in the last column of the row-stored matrix.
  offsetBy dx dy dz =
    ( (1, 0, 0, dx)
    , (0, 1, 0, dy)
    , (0, 0, 1, dz)
    , (0, 0, 0, 1)
    )

-- | Scale a solid by the per-axis factors @(x, y, z)@. The reverse matrix
-- uses the reciprocals of the factors. A zero factor is not rejected here;
-- its reciprocal is a non-finite 'Double'.
scale :: Double -> Double -> Double -> Solid -> Solid
scale x y z = transform (diagonal x y z) (diagonal (1/x) (1/y) (1/z))
  where
  -- Homogeneous diagonal scaling matrix; the w component is left at 1.
  diagonal sx sy sz =
    ( (sx,  0,  0, 0)
    , ( 0, sy,  0, 0)
    , ( 0,  0, sz, 0)
    , ( 0,  0,  0, 1)
    )

-- | Rotation is not implemented: evaluating @rotate x y z a@ raises an
-- error. The four parameters are presumably an axis (x, y, z) and an angle
-- (TODO confirm before implementing).
rotate :: Double -> Double -> Double -> Double -> Solid -> Solid
rotate _ _ _ _ = error "rotations not supported yet"

-- | Combine two solids into a 'Union' node.
union :: Solid -> Solid -> Solid
union a b = Union a b

-- | Combine two solids into an 'Intersection' node.
intersection :: Solid -> Solid -> Solid
intersection a b = Intersection a b

-- | Combine two solids into a 'Difference' node; presumably the first
-- solid with the second removed (TODO confirm).
difference :: Solid -> Solid -> Solid
difference a b = Difference a b

{-
mesh :: Double -> Double -> Solid -> IO ()
mesh radius precision solid = OT.mesh 4 $ octree solid
  where
  octree :: Solid -> Octree
  octree solid = case solid of
    Union        a b -> OT.union        (octree a) (octree b)
    Intersection a b -> OT.intersection (octree a) (octree b)
    Difference   a b -> OT.difference   (octree a) (octree b)
    Primitive n n' f -> 
      where
      prim :: (Double, Double, Double) -> Double -> Octree
      prim (x, y, z) r = case f center radius of
        where
        (x', y', z', _) = mv n' (x, y, z, 1)
        --XXX How to transform the radius?  Make it (1,1,1)?
-}