diff --git a/clash-lib/src/Clash/Driver.hs b/clash-lib/src/Clash/Driver.hs index 6bf14d7848..023669b38e 100644 --- a/clash-lib/src/Clash/Driver.hs +++ b/clash-lib/src/Clash/Driver.hs @@ -120,8 +120,9 @@ import qualified Clash.Netlist.Id as Id import Clash.Netlist.Types (IdentifierText, BlackBox (..), Component (..), FilteredHWType, HWMap, SomeBackend (..), TopEntityT(..), TemplateFunction, ComponentMap, findClocks, ComponentMeta(..)) -import Clash.Normalize (checkNonRecursive, cleanupGraph, - normalize, runNormalization) +import Clash.Normalize (checkANF, checkNonRecursive, + cleanupGraph, normalize, + runNormalization) import Clash.Normalize.Util (callGraph, tvSubstWithTyEq) import qualified Clash.Primitives.Sized.Signed as P import qualified Clash.Primitives.Sized.ToInteger as P @@ -1015,9 +1016,11 @@ normalizeEntity normalizeEntity env bindingsMap typeTrans peEval eval topEntities supply tm = transformedBindings where doNorm = do norm <- normalize [tm] - let normChecked = checkNonRecursive norm - cleaned <- cleanupGraph tm normChecked - return cleaned + let normChecked = checkNonRecursive norm + anfChecked = checkANF normChecked + cleaned <- cleanupGraph tm anfChecked + let cleanedANF = checkANF cleaned + return cleanedANF transformedBindings = runNormalization env supply bindingsMap typeTrans peEval eval emptyVarEnv topEntities doNorm diff --git a/clash-lib/src/Clash/Normalize.hs b/clash-lib/src/Clash/Normalize.hs index a8ef45bb40..42815edabd 100644 --- a/clash-lib/src/Clash/Normalize.hs +++ b/clash-lib/src/Clash/Normalize.hs @@ -50,7 +50,8 @@ import Clash.Core.Pretty (PrettyOptions(..), showPpr, s import Clash.Core.Subst (extendGblSubstList, mkSubst, substTm) import Clash.Core.Term (Term (..), collectArgsTicks - ,mkApps, mkTicks) + ,collectBndrs, collectTicks + ,mkApps, mkTicks, stripTicks) import Clash.Core.Type (Type, splitCoreFunForallTy) import Clash.Core.TyCon (TyConMap) import Clash.Core.Type (isPolyTy) @@ -61,7 +62,7 @@ import Clash.Core.VarEnv mkVarEnv, mkVarSet, notElemVarEnv, notElemVarSet, nullVarEnv, unionVarEnv) import Clash.Debug (traceIf) import Clash.Driver.Types - (BindingMap, Binding(..), DebugOpts(..), ClashEnv(..)) + (BindingMap, Binding(..), DebugOpts(..), ClashEnv(..), IsPrim(..)) import Clash.Netlist.Types (HWMap, FilteredHWType(..)) import Clash.Netlist.Util @@ -248,6 +249,94 @@ checkNonRecursive norm = case mapMaybeVarEnv go norm of go (Binding nm _ _ _ tm r) = if r then Just (nm,tm) else Nothing +-- | Check whether the normalized bindings are in, what Clash calls, ANF. +-- Specifically, for each non-primitive binding, after stripping outer lambdas +-- and ticks, checks that: +-- +-- 1. The outermost expression is a 'Letrec' +-- 2. The body of the 'Letrec' is a plain 'Var' +-- 3. No RHS of a 'Letrec' binding is itself introduces variables through 'Let', +-- 'Lam', or 'TyLam'. This is must hold for any subterm. +-- +-- Typically, ANF would also make sure all arguments of application are variable +-- references. This isn't checked for two reasons: +-- +-- 1. Primitives like to inspect their arguments for certain values +-- 2. Field projections shouldn't create a bunch of indirections +-- +-- Note: we currently don't check any arguments to primitives. These arguments +-- can introduce binders through lambdas (e.g., in case of a HO-function) and +-- should themselves be in ANF. +checkANF + :: BindingMap + -- ^ Normalized binders to check + -> BindingMap +checkANF norm = foldr check norm (eltsVarEnv norm) + where + check (Binding _nm _ _ IsPrim _tm _) acc = acc + check (Binding nm _ _ IsFun tm _) acc = + case body1 of + Letrec xes result -> + case stripTicks result of + Var _ -> + let + badRhss = + [ showPpr (varName bid) ++ " = " ++ showPpr' opts be + | (bid, be) <- xes + , hasNestedBinder be + ] + in + case badRhss of + [] -> acc + bs -> error [i| + Binding '#{showPpr (varName nm)}' has non-ANF RHS(es) after normalization: + + #{unlines bs} + |] + other -> + error $ [i| + Binding '#{showPpr (varName nm)}': letrec body is not a simple Var after + normalization: + + #{showPpr other} + |] + other -> + error $ $(curLoc) ++ [i| + Binding '#{showPpr (varName nm)}': top-level expression is not a Letrec after + normalization: + + #{showPpr other} + |] + where + (_, body0) = collectBndrs tm + (body1, _) = collectTicks body0 + + -- | Recursively check whether a term contains a binder ('Let', 'Lam', or + -- 'TyLam') anywhere in its sub-terms. Arguments to primitives are not + -- checked, as they may legitimately contain lambdas (e.g. higher-order + -- primitives). + hasNestedBinder :: Term -> Bool + hasNestedBinder e = case stripTicks e of + Let {} -> True + Lam {} -> True + TyLam {} -> True + App {} -> + let (hd, args, _) = collectArgsTicks e + in case hd of + Prim {} -> False + _ -> hasNestedBinder hd || any (either hasNestedBinder (const False)) args + TyApp f _ -> hasNestedBinder f + Case subj _ alts -> hasNestedBinder subj || any (hasNestedBinder . snd) alts + Cast f _ _ -> hasNestedBinder f + Var {} -> False + Data {} -> False + Literal {} -> False + Prim {} -> False + Tick {} -> False + + opts = PrettyOptions { displayUniques = False, displayTypes = True + , displayQualifiers = False, displayTicks = False } + -- | Perform general \"clean up\" of the normalized (non-recursive) function -- hierarchy. This includes: --