diff --git a/changelog/2025-12-14T16_40_22+01_00_improve_deriveAutoReg b/changelog/2025-12-14T16_40_22+01_00_improve_deriveAutoReg new file mode 100644 index 0000000000..a336e4fa39 --- /dev/null +++ b/changelog/2025-12-14T16_40_22+01_00_improve_deriveAutoReg @@ -0,0 +1,4 @@ +CHANGED: `deriveAutoReg` to improve the calculated constraints + +In addition to looking for constraints on AutoReg instances fields, +it now also looks for constraints on the NFDataX instance of the whole type. diff --git a/clash-prelude/src/Clash/Class/AutoReg/Internal.hs b/clash-prelude/src/Clash/Class/AutoReg/Internal.hs index 747f7afd4d..8afe7221fe 100644 --- a/clash-prelude/src/Clash/Class/AutoReg/Internal.hs +++ b/clash-prelude/src/Clash/Class/AutoReg/Internal.hs @@ -1,6 +1,6 @@ {-| Copyright : (C) 2019 , Google Inc., - 2021-2022, QBayLogic B.V., + 2021-2025, QBayLogic B.V., 2021-2022, Myrtle.ai License : BSD2 (see the file LICENSE) Maintainer : QBayLogic B.V. @@ -20,7 +20,7 @@ module Clash.Class.AutoReg.Internal where import Data.List (nub,zipWith4) -import Data.Maybe (fromMaybe,isJust) +import Data.Maybe (catMaybes,fromMaybe,isJust) import GHC.Stack (HasCallStack) import GHC.TypeNats (KnownNat,Nat,type (+)) @@ -326,11 +326,36 @@ deriveAutoRegProduct tyInfo conInfo = go (constructorName conInfo) fieldInfos [] -> [| $tyConE |] autoRegDec <- funD 'autoReg [clause argsP (normalB body) decls] - ctx <- calculateRequiredContext conInfo + ctxAutoReg <- fmap nub $ calculateRequiredContext conInfo + -- look up if the NFDataX superclass has any (extra) constraints + ctxNFDataX <- fmap nub $ constraintsWantedFor ''NFDataX [ty] + let ctxNFDataXfiltered = removedImpliedNFXby ctxAutoReg ctxNFDataX + let ctx = nub (ctxAutoReg ++ ctxNFDataXfiltered) return [InstanceD Nothing ctx (AppT (ConT ''AutoReg) ty) [ autoRegDec , PragmaD (InlineP 'autoReg Inline FunLike AllPhases) ]] +-- | Looks through the constraits in the 2nd argument and +-- drops any `NFDataX a`, when there is a corresponding `AutoReg a` in the first argument +removedImpliedNFXby :: Cxt -> Cxt -> Cxt +removedImpliedNFXby autoreg nfdatx = filter (not . isImplied) nfdatx + where + autoregs = catMaybes $ map (isTyClass ''AutoReg) autoreg + isImplied x = case isTyClass ''NFDataX x of + Nothing -> False + Just tys -> elem (map viewType tys) autoregs + + isTyClass :: Name -> Pred -> Maybe [Type] + isTyClass nm x = case unfoldType x of + (ConT nm',tys) | nm == nm' -> Just tys + _ -> Nothing + + -- | look through kind signatures + viewType :: Type -> Type + viewType x = case x of + SigT ty _kind -> ty + ty -> ty + -- Calculate the required constraint to call autoReg on all the fields of a -- given constructor calculateRequiredContext :: ConstructorInfo -> Q Cxt @@ -341,7 +366,7 @@ calculateRequiredContext conInfo = do constraintsWantedFor :: Name -> [Type] -> Q Cxt constraintsWantedFor clsNm tys - | show clsNm == "GHC.TypeNats.KnownNat" = do + | show clsNm == show ''KnownNat = do -- KnownNat is special, you can't just lookup instances with reifyInstances. -- So we just pass KnownNat constraints. -- This will most likely require UndecidableInstances. @@ -367,9 +392,10 @@ constraintsWantedFor clsNm [ty] = case ty of _ -> fail $ "Got unexpected instance: " ++ pprint insts where isOk :: Type -> Bool - isOk (unfoldType -> (_cls,tys)) = + isOk (unfoldType -> (cls,tys)) = case tys of [VarT _] -> True + [SigT t _] -> isOk (AppT cls t) -- look through a kind signature [_] -> False _ -> True -- see [NOTE: MultiParamTypeClasses] needRecurse :: Type -> Bool @@ -380,6 +406,7 @@ constraintsWantedFor clsNm [ty] = case ty of [ConT _] -> False -- we can just drop constraints like: "AutoReg Bool => ..." [LitT _] -> False -- or "KnownNat 4 =>" [TupleT 0] -> False -- handle Unit () + [SigT t _] -> needRecurse (AppT cls t) -- look through a kind signature [_] -> error ( "Error while deriveAutoReg: don't know how to handle: " ++ pprint cls ++ " (" ++ pprint tys ++ ")" ) _ -> False -- see [NOTE: MultiParamTypeClasses] @@ -421,6 +448,7 @@ findTyVarSubsts = go (ImplicitParamT _ x1, ImplicitParamT _ x2) -> go x1 x2 (PromotedT _ , PromotedT _ ) -> [] (TupleT _ , TupleT _ ) -> [] + (TupleT _ , ConT _ ) -> [] -- see NOTE [TupleT ConT equality] (UnboxedTupleT _ , UnboxedTupleT _ ) -> [] (UnboxedSumT _ , UnboxedSumT _ ) -> [] (ArrowT , ArrowT ) -> [] @@ -435,6 +463,21 @@ findTyVarSubsts = go (WildCardT , WildCardT ) -> [] _ -> error $ unlines [ "findTyVarSubsts: Unexpected types" , "ty1:", pprint ty1,"ty2:", pprint ty2] +{- +NOTE [TupleT ConT equality] + +When looking up typeclass instances for tuples we get a slight mismatch +in the representation of the tuple types. + +Internally in GHC a 3-tuple is represented as `TupleT 3`. +But our API is Name based, so we look them up as `ConT "(,,)"`, see 'deriveAutoRegTuple'. +This confused 'findTyVarSubsts', as it gets both forms and tries to unify them. +To allow it to continue I just allowed it to see them as equal. +This may seem dangerous, but isn't any more dangerous then seeing arbitratry ConT's +or LitT's as equal. +And we should only get things from GHC that unify. +-} + applyTyVarSubsts :: [(Name,Type)] -> Type -> Type applyTyVarSubsts substs ty = go ty diff --git a/tests/Main.hs b/tests/Main.hs index 647211f79d..443c81a01e 100755 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -355,6 +355,7 @@ runClashTest = defaultMain , runTest "T1507" def{hdlSim=[]} , let _opts = def{hdlSim=[], hdlTargets=[VHDL]} in runTest "T1632" _opts + , runTest "T3111_Constrained_NFDataX" def{hdlSim=[], hdlLoad=[], hdlTargets=[VHDL]} ] , clashTestGroup "Basic" [ runTest "AES" def{hdlSim=[]} diff --git a/tests/shouldwork/AutoReg/T3111_Constrained_NFDataX.hs b/tests/shouldwork/AutoReg/T3111_Constrained_NFDataX.hs new file mode 100644 index 0000000000..ad00b6b104 --- /dev/null +++ b/tests/shouldwork/AutoReg/T3111_Constrained_NFDataX.hs @@ -0,0 +1,37 @@ +{-# LANGUAGE UndecidableInstances #-} +{- # OPTIONS_GHC -ddump-splices #-} +module T3111_Constrained_NFDataX where + +{- +Tests that deriveAutoReg can work with types with constrained NFDataX instances +-} + +import Clash.Prelude + +data MsgConfig = MsgConfig + { _msgChannelWidth :: Nat + , _msgDataLen :: Nat + } + +type family MsgChannelWidth (cfg :: MsgConfig) :: Nat where + MsgChannelWidth ('MsgConfig x _) = x + +type family MsgDataLen (cfg :: MsgConfig) :: Nat where + MsgDataLen ('MsgConfig _ x) = x + +type KnownMsgConfig cfg = (KnownNat (MsgDataLen cfg), KnownNat (MsgChannelWidth cfg)) + +data Message (cfg :: MsgConfig) = Status + { _msg_channel :: Unsigned (MsgChannelWidth cfg) + , _msg_data :: Vec (MsgDataLen cfg) (Unsigned 8) + , _msg_chksum :: Unsigned 4 + } + deriving (Bundle, Generic) +-- deriving instance (KnownMsgConfig cfg) => BitPack (Message cfg) +deriving instance (KnownMsgConfig cfg) => NFDataX (Message cfg) +deriveAutoReg ''Message + +type MyConfig = Message ('MsgConfig 3 4) + +topEntity :: SystemClockResetEnable => Signal System MyConfig -> Signal System MyConfig +topEntity = autoReg (Status 1 (repeat 2) 3)