Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog/2025-12-14T16_40_22+01_00_improve_deriveAutoReg
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
CHANGED: `deriveAutoReg` to improve the calculated constraints

In addition to looking for constraints on AutoReg instances fields,
it now also looks for constraints on the NFDataX instance of the whole type.
53 changes: 48 additions & 5 deletions clash-prelude/src/Clash/Class/AutoReg/Internal.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{-|
Copyright : (C) 2019 , Google Inc.,
2021-2022, QBayLogic B.V.,
2021-2025, QBayLogic B.V.,
2021-2022, Myrtle.ai
License : BSD2 (see the file LICENSE)
Maintainer : QBayLogic B.V. <devops@qbaylogic.com>
Expand All @@ -20,7 +20,7 @@ module Clash.Class.AutoReg.Internal
where

import Data.List (nub,zipWith4)
import Data.Maybe (fromMaybe,isJust)
import Data.Maybe (catMaybes,fromMaybe,isJust)

import GHC.Stack (HasCallStack)
import GHC.TypeNats (KnownNat,Nat,type (+))
Expand Down Expand Up @@ -326,11 +326,36 @@ deriveAutoRegProduct tyInfo conInfo = go (constructorName conInfo) fieldInfos
[] -> [| $tyConE |]

autoRegDec <- funD 'autoReg [clause argsP (normalB body) decls]
ctx <- calculateRequiredContext conInfo
ctxAutoReg <- fmap nub $ calculateRequiredContext conInfo
-- look up if the NFDataX superclass has any (extra) constraints
ctxNFDataX <- fmap nub $ constraintsWantedFor ''NFDataX [ty]
let ctxNFDataXfiltered = removedImpliedNFXby ctxAutoReg ctxNFDataX
let ctx = nub (ctxAutoReg ++ ctxNFDataXfiltered)
return [InstanceD Nothing ctx (AppT (ConT ''AutoReg) ty)
[ autoRegDec
, PragmaD (InlineP 'autoReg Inline FunLike AllPhases) ]]

-- | Looks through the constraits in the 2nd argument and
-- drops any `NFDataX a`, when there is a corresponding `AutoReg a` in the first argument
removedImpliedNFXby :: Cxt -> Cxt -> Cxt
removedImpliedNFXby autoreg nfdatx = filter (not . isImplied) nfdatx
where
autoregs = catMaybes $ map (isTyClass ''AutoReg) autoreg
isImplied x = case isTyClass ''NFDataX x of
Nothing -> False
Just tys -> elem (map viewType tys) autoregs

isTyClass :: Name -> Pred -> Maybe [Type]
isTyClass nm x = case unfoldType x of
(ConT nm',tys) | nm == nm' -> Just tys
_ -> Nothing

-- | look through kind signatures
viewType :: Type -> Type
viewType x = case x of
SigT ty _kind -> ty
ty -> ty

-- Calculate the required constraint to call autoReg on all the fields of a
-- given constructor
calculateRequiredContext :: ConstructorInfo -> Q Cxt
Expand All @@ -341,7 +366,7 @@ calculateRequiredContext conInfo = do

constraintsWantedFor :: Name -> [Type] -> Q Cxt
constraintsWantedFor clsNm tys
| show clsNm == "GHC.TypeNats.KnownNat" = do
| show clsNm == show ''KnownNat = do
-- KnownNat is special, you can't just lookup instances with reifyInstances.
-- So we just pass KnownNat constraints.
-- This will most likely require UndecidableInstances.
Expand All @@ -367,9 +392,10 @@ constraintsWantedFor clsNm [ty] = case ty of
_ -> fail $ "Got unexpected instance: " ++ pprint insts
where
isOk :: Type -> Bool
isOk (unfoldType -> (_cls,tys)) =
isOk (unfoldType -> (cls,tys)) =
case tys of
[VarT _] -> True
[SigT t _] -> isOk (AppT cls t) -- look through a kind signature
[_] -> False
_ -> True -- see [NOTE: MultiParamTypeClasses]
needRecurse :: Type -> Bool
Expand All @@ -380,6 +406,7 @@ constraintsWantedFor clsNm [ty] = case ty of
[ConT _] -> False -- we can just drop constraints like: "AutoReg Bool => ..."
[LitT _] -> False -- or "KnownNat 4 =>"
[TupleT 0] -> False -- handle Unit ()
[SigT t _] -> needRecurse (AppT cls t) -- look through a kind signature
[_] -> error ( "Error while deriveAutoReg: don't know how to handle: "
++ pprint cls ++ " (" ++ pprint tys ++ ")" )
_ -> False -- see [NOTE: MultiParamTypeClasses]
Expand Down Expand Up @@ -421,6 +448,7 @@ findTyVarSubsts = go
(ImplicitParamT _ x1, ImplicitParamT _ x2) -> go x1 x2
(PromotedT _ , PromotedT _ ) -> []
(TupleT _ , TupleT _ ) -> []
(TupleT _ , ConT _ ) -> [] -- see NOTE [TupleT ConT equality]
(UnboxedTupleT _ , UnboxedTupleT _ ) -> []
(UnboxedSumT _ , UnboxedSumT _ ) -> []
(ArrowT , ArrowT ) -> []
Expand All @@ -435,6 +463,21 @@ findTyVarSubsts = go
(WildCardT , WildCardT ) -> []
_ -> error $ unlines [ "findTyVarSubsts: Unexpected types"
, "ty1:", pprint ty1,"ty2:", pprint ty2]
{-
NOTE [TupleT ConT equality]

When looking up typeclass instances for tuples we get a slight mismatch
in the representation of the tuple types.

Internally in GHC a 3-tuple is represented as `TupleT 3`.
But our API is Name based, so we look them up as `ConT "(,,)"`, see 'deriveAutoRegTuple'.
This confused 'findTyVarSubsts', as it gets both forms and tries to unify them.
To allow it to continue I just allowed it to see them as equal.
This may seem dangerous, but isn't any more dangerous then seeing arbitratry ConT's
or LitT's as equal.
And we should only get things from GHC that unify.
-}


applyTyVarSubsts :: [(Name,Type)] -> Type -> Type
applyTyVarSubsts substs ty = go ty
Expand Down
1 change: 1 addition & 0 deletions tests/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ runClashTest = defaultMain
, runTest "T1507" def{hdlSim=[]}
, let _opts = def{hdlSim=[], hdlTargets=[VHDL]}
in runTest "T1632" _opts
, runTest "T3111_Constrained_NFDataX" def{hdlSim=[], hdlLoad=[], hdlTargets=[VHDL]}
]
, clashTestGroup "Basic"
[ runTest "AES" def{hdlSim=[]}
Expand Down
37 changes: 37 additions & 0 deletions tests/shouldwork/AutoReg/T3111_Constrained_NFDataX.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{-# LANGUAGE UndecidableInstances #-}
{- # OPTIONS_GHC -ddump-splices #-}
module T3111_Constrained_NFDataX where

{-
Tests that deriveAutoReg can work with types with constrained NFDataX instances
-}

import Clash.Prelude

data MsgConfig = MsgConfig
{ _msgChannelWidth :: Nat
, _msgDataLen :: Nat
}

type family MsgChannelWidth (cfg :: MsgConfig) :: Nat where
MsgChannelWidth ('MsgConfig x _) = x

type family MsgDataLen (cfg :: MsgConfig) :: Nat where
MsgDataLen ('MsgConfig _ x) = x

type KnownMsgConfig cfg = (KnownNat (MsgDataLen cfg), KnownNat (MsgChannelWidth cfg))

data Message (cfg :: MsgConfig) = Status
{ _msg_channel :: Unsigned (MsgChannelWidth cfg)
, _msg_data :: Vec (MsgDataLen cfg) (Unsigned 8)
, _msg_chksum :: Unsigned 4
}
deriving (Bundle, Generic)
-- deriving instance (KnownMsgConfig cfg) => BitPack (Message cfg)
deriving instance (KnownMsgConfig cfg) => NFDataX (Message cfg)
deriveAutoReg ''Message

type MyConfig = Message ('MsgConfig 3 4)

topEntity :: SystemClockResetEnable => Signal System MyConfig -> Signal System MyConfig
topEntity = autoReg (Status 1 (repeat 2) 3)