{-# LANGUAGE GADTs #-}

import Hardware.Chalk
import Control.Applicative

-- | The expression language for the symbolic values that are passed through
-- the circuits: integer literals, named variables, arithmetic, comparison,
-- logical connectives and a conditional.
data Expr
  = Val Int             -- ^ integer literal
  | Var String          -- ^ named symbolic variable
  | Add Expr Expr       -- ^ sum
  | Mul Expr Expr       -- ^ product
  | Lt Expr Expr        -- ^ less-than comparison
  | And Expr Expr       -- ^ conjunction
  | If Expr Expr Expr   -- ^ conditional: condition, then-branch, else-branch
  | Not Expr            -- ^ logical negation
  deriving (Show, Eq)

-- | Arithmetic on symbolic expressions, so that numeric literals, '+' and '*'
-- build 'Expr' trees directly.
--
-- 'Expr' has no subtraction, absolute-value or sign constructors, so those
-- operations are encoded with the constructors that do exist. Negation is a
-- multiplication by @-1@, folded directly into literals. Subtraction adds the
-- negation. 'abs' and 'signum' are case splits built from 'If' and 'Lt'.
-- Defining 'negate' and '(-)' explicitly matters: the class defaults define
-- each in terms of the other, so leaving both out made any negative literal
-- or subtraction loop forever.
instance Num Expr where
  (+) = Add
  (*) = Mul
  negate (Val n) = Val (negate n)
  negate x = Mul (Val (-1)) x
  x - y = Add x (negate y)
  abs x = If (Lt x (Val 0)) (negate x) x
  signum x = If (Lt x (Val 0)) (Val (-1)) (If (Lt (Val 0) x) (Val 1) (Val 0))
  fromInteger n = Val (fromInteger n)

-- | Symbolic conjunction: builds an 'And' node.
--
-- Declared @infixr 3@ to match Prelude's '&&'. Without a fixity declaration
-- the operator would default to @infixl 9@, binding tighter than every
-- arithmetic operator.
infixr 3 .&&.
(.&&.) :: Expr -> Expr -> Expr
x .&&. y = And x y

-- | Symbolic less-than: builds an 'Lt' node.
--
-- Declared @infix 4@ to match Prelude's '<'. Without a fixity declaration it
-- would default to @infixl 9@ and bind tighter than '+'. Then @a + b .<. c@
-- would silently parse as @a + (b .<. c)@, which still type-checks because
-- both sides are 'Expr'.
infix 4 .<.
(.<.) :: Expr -> Expr -> Expr
x .<. y = Lt x y

-- | Symbolic conditional: @cond c t e@ builds @If c t e@.
cond :: Expr -> Expr -> Expr -> Expr
cond c t e = If c t e

-- | A "ticked" signal: a 'Signal' whose elements are 'Ticked' values, so each
-- element carries an 'Expr' recording the cost of producing it alongside
-- the value itself.
newtype TSignal a = TSignal
  { unT :: Signal (Ticked a)  -- ^ the underlying signal of costed values
  }

-- | Transforms the carried values; every cost annotation passes through
-- untouched.
instance Functor TSignal where
  fmap f = TSignal . fmap (fmap f) . unT

-- | Lifts 'Ticked''s own '<*>' through the underlying 'Signal' applicative,
-- so combining a costed function with a costed argument charges both costs.
-- 'pure' wraps a value with zero cost.
instance Applicative TSignal where
  pure = TSignal . pure . pure
  TSignal fs <*> TSignal xs = TSignal (pure (\tf tx -> tf <*> tx) <*> fs <*> xs)

-- | A value paired with a symbolic 'Expr' recording its cost.
data Ticked a = Ticked
  { tval :: a     -- ^ the carried value
  , cost :: Expr  -- ^ accumulated symbolic cost
  }
  deriving (Show)
-- | Transforms the value; the cost annotation is carried over unchanged.
instance Functor Ticked where
  fmap f t = t { tval = f (tval t) }

-- | A pure value costs @Val 0@. Applying a costed function to a costed
-- argument yields a cost of @Add@ of the two costs, function cost first.
instance Applicative Ticked where
  pure v = Ticked { tval = v, cost = Val 0 }
  Ticked f fnCost <*> Ticked x argCost = Ticked (f x) (Add fnCost argCost)

-- | Charge an extra cost @i@ on every element of the signal, appended after
-- the element's existing cost.
pay :: Expr -> TSignal a -> TSignal a
pay i = TSignal . fmap charge . unT
  where
    charge t = t { cost = Add (cost t) i }

-- | A constant signal whose elements carry the fixed literal cost @i@.
costed :: Int -> a -> TSignal a
costed i = pay (Val i) . pure

-- | Conditionally charge @c@ on a signal, driven by a boolean-valued 'Expr'
-- signal. The resulting element keeps the payload's value. Its cost is the
-- condition's own cost plus @If b c s_cost@: @c@ when the condition @b@
-- holds, and the payload's original cost @s_cost@ otherwise.
--
-- NOTE(review): when @b@ holds, the payload's own cost is replaced by @c@
-- rather than added to it. Confirm this is intended and not meant to be
-- @If b (Add s_cost c) s_cost@.
payIf :: TSignal Expr -> Expr -> TSignal a -> TSignal a
payIf (TSignal conds) c (TSignal xs) = TSignal (pure step <*> conds <*> xs)
  where
    step :: Ticked Expr -> Ticked b -> Ticked b
    step (Ticked b condCost) t = t { cost = Add condCost (If b c (cost t)) }

-- | The "clever" multiplier.
--
-- @cheapCost@ estimates the cost of the small multiplier in terms of the size
-- of its inputs; here it is applied to 'threshold'. The multiplier first tests
-- whether both operands are below 'threshold', and 'cmpCost' is charged for
-- that test. The cheap path is costed with @cheapCost threshold@ when the test
-- holds, and the dear path with 'dearCost' when it fails; see 'payIf' for what
-- each path costs otherwise. The product is then selected with 'mux'.
mult :: (Expr -> Expr) -> TSignal (Expr, Expr) -> TSignal Expr
mult cheapCost xys = mux chargedTest cheapPath dearPath
  where
    sizeTest :: TSignal Expr
    sizeTest = pure bothSmall <*> xys

    bothSmall (x, y) = (x .<. threshold) .&&. (y .<. threshold)

    -- The comparison itself is paid for on the selector signal.
    chargedTest = pay cmpCost sizeTest

    products = uncurry (*) <$> xys

    cheapPath = payIf sizeTest (cheapCost threshold) products
    dearPath = payIf (pure Not <*> sizeTest) dearCost products

-- | Symbolic operand-size bound that 'mult' compares both operands against.
threshold :: Expr
threshold = Var "threshold"

-- | Symbolic cost that 'mult' charges for its size comparison.
cmpCost :: Expr
cmpCost = Var "cmpCost"

-- | Symbolic cost that 'mult' charges on its expensive-multiplier path.
dearCost :: Expr
dearCost = Var "dearCost"

-- | The "dumb" multiplexer: selects between two signals with 'If', and is
-- charged the cost of the selector plus the costs of both branches.
mux :: TSignal Expr -> TSignal Expr -> TSignal Expr -> TSignal Expr
mux bs ts es = chooser <*> es
  where
    chooser = pure If <*> bs <*> ts

-- | Run 'mult' symbolically on the inputs @x1@ and @x2@ and take the first
-- element of the output. Returns the product expression, with its cost
-- simplified to a fixed point of 'simpl'. @cheap@ is the cheap-multiplier
-- cost model passed on to 'mult'.
test :: (Expr -> Expr) -> Ticked Expr
test cheap = Ticked v (fix simpl c)
  where
    inputs = TSignal (delay (pure (Var "x1", Var "x2")) laterInputs)
    -- Only the first output element is inspected, so the rest of the input
    -- signal is never supplied; fail loudly if it is ever demanded.
    laterInputs = error "test: the input signal is only defined at its first element"
    Ticked v c = first (unT (mult cheap inputs))

-- | One simplification pass over an expression. It drops additions of the
-- literal 0 and multiplications by the literal 1, and recurses into every
-- compound constructor, including 'Mul'. A single pass can expose new redexes
-- (e.g. @Add (Add (Val 0) (Val 0)) x@ becomes @Add (Val 0) x@), so callers
-- iterate it to a fixed point with 'fix'.
--
-- The leaf cases are listed explicitly rather than with a catch-all. That
-- way, adding a new 'Expr' constructor triggers an incomplete-pattern warning
-- here instead of leaving its children silently unsimplified.
simpl :: Expr -> Expr
simpl (Add (Val 0) x) = simpl x
simpl (Add x (Val 0)) = simpl x
simpl (Add y z) = Add (simpl y) (simpl z)
simpl (Mul (Val 1) x) = simpl x
simpl (Mul x (Val 1)) = simpl x
simpl (Mul y z) = Mul (simpl y) (simpl z)
simpl (If c t e) = If (simpl c) (simpl t) (simpl e)
simpl (And b1 b2) = And (simpl b1) (simpl b2)
simpl (Lt x1 x2) = Lt (simpl x1) (simpl x2)
simpl (Not x) = Not (simpl x)
simpl leaf@(Val _) = leaf
simpl leaf@(Var _) = leaf

-- | Iterate @f@ starting from @x@ until reaching a value that @f@ maps to
-- itself (as judged by '=='), and return that value. Diverges if no such
-- value is ever reached.
fix :: Eq a => (a -> a) -> a -> a
fix f x
  | next == x = x
  | otherwise = fix f next
  where
    -- Shared so that each step applies f once, not once for the test and
    -- again for the recursive call.
    next = f x
