diff --git a/app/app.go b/app/app.go index a5f5ff78d0..aee0830a87 100644 --- a/app/app.go +++ b/app/app.go @@ -22,6 +22,7 @@ import ( icahosttypes "github.com/cosmos/ibc-go/v10/modules/apps/27-interchain-accounts/host/types" icatypes "github.com/cosmos/ibc-go/v10/modules/apps/27-interchain-accounts/types" ibccallbacks "github.com/cosmos/ibc-go/v10/modules/apps/callbacks" + ibccallbacksv2 "github.com/cosmos/ibc-go/v10/modules/apps/callbacks/v2" "github.com/cosmos/ibc-go/v10/modules/apps/transfer" ibctransferkeeper "github.com/cosmos/ibc-go/v10/modules/apps/transfer/keeper" ibctransfertypes "github.com/cosmos/ibc-go/v10/modules/apps/transfer/types" @@ -621,8 +622,9 @@ func NewWasmApp( wasmOpts..., ) + wasmContractKeeper := wasmkeeper.NewDefaultPermissionKeeper(&app.WasmKeeper) // Create fee enabled wasm ibc Stack - wasmStackIBCHandler := wasm.NewIBCHandler(app.WasmKeeper, app.IBCKeeper.ChannelKeeper, app.TransferKeeper, app.IBCKeeper.ChannelKeeper) + wasmStackIBCHandler := wasm.NewIBCHandler(app.WasmKeeper, app.IBCKeeper.ChannelKeeper, app.TransferKeeper, app.IBCKeeper.ChannelKeeper, wasmContractKeeper) // Create Interchain Accounts Stack // SendPacket, since it is originating from the application to core IBC: @@ -646,10 +648,14 @@ func NewWasmApp( // Create Transfer Stack var transferStack porttypes.IBCModule transferStack = transfer.NewIBCModule(app.TransferKeeper) + transferStack = wasm.NewIBCV1CallbacksPlusMiddleware(transferStack) transferStack = ibccallbacks.NewIBCMiddleware(transferStack, app.IBCKeeper.ChannelKeeper, wasmStackIBCHandler, wasm.DefaultMaxIBCCallbackGas) - transferICS4Wrapper := transferStack.(porttypes.ICS4Wrapper) + // Chains that also wire the IBC Hooks middleware should wrap the stack + // with IBCDedupMiddleware to reject Hooks/Callbacks same-side memo collisions. + // transferStack = wasm.NewIBCDedupMiddleware(transferStack, transferStack.(porttypes.ICS4Wrapper)) + // Since the callbacks middleware itself is an ics4wrapper, it needs to be passed to the ica controller keeper - app.TransferKeeper.WithICS4Wrapper(transferICS4Wrapper) + app.TransferKeeper.WithICS4Wrapper(transferStack.(porttypes.ICS4Wrapper)) // Create static IBC router, add app routes, then set and seal it ibcRouter := porttypes.NewRouter(). @@ -660,8 +666,16 @@ func NewWasmApp( app.IBCKeeper.SetRouter(ibcRouter) ibcRouterV2 := ibcapi.NewRouter() + transferV2Stack := ibcapi.IBCModule(wasm.NewIBCV2CallbacksPlusMiddleware(transferv2.NewIBCModule(app.TransferKeeper))) + transferV2Stack = ibccallbacksv2.NewIBCMiddleware( + transferV2Stack, + app.IBCKeeper.ChannelKeeperV2, + wasmStackIBCHandler, + app.IBCKeeper.ChannelKeeperV2, + wasm.DefaultMaxIBCCallbackGas, + ) ibcRouterV2 = ibcRouterV2. - AddRoute(ibctransfertypes.PortID, transferv2.NewIBCModule(app.TransferKeeper)). + AddRoute(ibctransfertypes.PortID, transferV2Stack). AddPrefixRoute(wasmkeeper.PortIDPrefixV2, wasmkeeper.NewIBC2Handler(app.WasmKeeper)) app.IBCKeeper.SetRouterV2(ibcRouterV2) diff --git a/go.mod b/go.mod index a554d28273..f44925940e 100644 --- a/go.mod +++ b/go.mod @@ -49,7 +49,7 @@ require ( cosmossdk.io/x/upgrade v0.2.0 github.com/cometbft/cometbft v0.38.21 github.com/cosmos/cosmos-db v1.1.3 - github.com/cosmos/ibc-go/v10 v10.5.0 + github.com/cosmos/ibc-go/v10 v10.6.0 github.com/distribution/reference v0.5.0 github.com/rs/zerolog v1.34.0 github.com/spf13/viper v1.21.0 diff --git a/go.sum b/go.sum index aa78eb9a78..62a08a24e0 100644 --- a/go.sum +++ b/go.sum @@ -833,8 +833,8 @@ github.com/cosmos/gogoproto v1.7.2 h1:5G25McIraOC0mRFv9TVO139Uh3OklV2hczr13KKVHC github.com/cosmos/gogoproto v1.7.2/go.mod h1:8S7w53P1Y1cHwND64o0BnArT6RmdgIvsBuco6uTllsk= github.com/cosmos/iavl v1.2.6 h1:Hs3LndJbkIB+rEvToKJFXZvKo6Vy0Ex1SJ54hhtioIs= github.com/cosmos/iavl v1.2.6/go.mod h1:GiM43q0pB+uG53mLxLDzimxM9l/5N9UuSY3/D0huuVw= -github.com/cosmos/ibc-go/v10 v10.5.0 h1:NI+cX04fXdu9JfP0V0GYeRi1ENa7PPdq0BYtVYo8Zrs= -github.com/cosmos/ibc-go/v10 v10.5.0/go.mod h1:a74pAPUSJ7NewvmvELU74hUClJhwnmm5MGbEaiTw/kE= +github.com/cosmos/ibc-go/v10 v10.6.0 h1:k7PZVSLXFtCdoWlU+ERGn2m1Np4Tw8BF8WyPGl0DOi4= +github.com/cosmos/ibc-go/v10 v10.6.0/go.mod h1:a74pAPUSJ7NewvmvELU74hUClJhwnmm5MGbEaiTw/kE= github.com/cosmos/ics23/go v0.11.0 h1:jk5skjT0TqX5e5QJbEnwXIS2yI2vnmLOgpQPeM5RtnU= github.com/cosmos/ics23/go v0.11.0/go.mod h1:A8OjxPE67hHST4Icw94hOxxFEJMBG031xIGF/JHNIY0= github.com/cosmos/keyring v1.2.0 h1:8C1lBP9xhImmIabyXW4c3vFjjLiBdGCmfLUfeZlV1Yo= diff --git a/x/wasm/ibc.go b/x/wasm/ibc.go index 2b59c8ebcd..64931bc670 100644 --- a/x/wasm/ibc.go +++ b/x/wasm/ibc.go @@ -1,9 +1,11 @@ package wasm import ( + "encoding/json" "math" wasmvmtypes "github.com/CosmWasm/wasmvm/v3/types" + callbackstypes "github.com/cosmos/ibc-go/v10/modules/apps/callbacks/types" transfertypes "github.com/cosmos/ibc-go/v10/modules/apps/transfer/types" clienttypes "github.com/cosmos/ibc-go/v10/modules/core/02-client/types" channeltypes "github.com/cosmos/ibc-go/v10/modules/core/04-channel/types" @@ -11,6 +13,7 @@ import ( ibcexported "github.com/cosmos/ibc-go/v10/modules/core/exported" errorsmod "cosmossdk.io/errors" + sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" @@ -37,10 +40,20 @@ type IBCHandler struct { channelKeeper types.ChannelKeeper transferKeeper types.ICS20TransferPortSource appVersionGetter appVersionGetter + contractKeeper types.ContractOpsKeeper } -func NewIBCHandler(k types.IBCContractKeeper, ck types.ChannelKeeper, tk types.ICS20TransferPortSource, vg appVersionGetter) IBCHandler { - return IBCHandler{keeper: k, channelKeeper: ck, transferKeeper: tk, appVersionGetter: vg} +func NewIBCHandler(k types.IBCContractKeeper, ck types.ChannelKeeper, tk types.ICS20TransferPortSource, vg appVersionGetter, contractKeeper types.ContractOpsKeeper) IBCHandler { + return IBCHandler{ + keeper: k, + channelKeeper: ck, + transferKeeper: tk, + appVersionGetter: vg, + contractKeeper: contractKeeper, + } +} + +func (i IBCHandler) SetICS4Wrapper(_ porttypes.ICS4Wrapper) { } // OnChanOpenInit implements the IBCModule interface @@ -331,15 +344,30 @@ func (i IBCHandler) IBCSendPacketCallback( packetSenderAddress string, version string, ) error { - _, err := validateSender(contractAddress, packetSenderAddress) - if err != nil { + if _, err := validateSender(contractAddress, packetSenderAddress); err != nil { return err } - - // no-op, since we are not interested in this callback + // reject src_callback.calldata + if srcCallbackHasCalldata(packetData) { + return errorsmod.Wrap(types.ErrInvalid, "src_callback must not contain a calldata field") + } return nil } +func srcCallbackHasCalldata(packetData []byte) bool { + var pd transfertypes.FungibleTokenPacketData + if err := json.Unmarshal(packetData, &pd); err != nil { + return false + } + _, obj := jsonStringHasKey(pd.Memo, "src_callback") + srcObj, ok := obj["src_callback"].(map[string]any) + if !ok { + return false + } + _, has := srcObj["calldata"] + return has +} + // IBCOnAcknowledgementPacketCallback implements the IBC Callbacks ContractKeeper interface // see https://github.com/cosmos/ibc-go/blob/main/docs/architecture/adr-008-app-caller-cbs.md#contractkeeper func (i IBCHandler) IBCOnAcknowledgementPacketCallback( @@ -447,13 +475,52 @@ func (i IBCHandler) IBCReceivePacketCallback( transferData.Token.Denom.Trace = append(trace, transferData.Token.Denom.Trace...) } + denom := transferData.Token.GetDenom().IBCDenom() + amount := transferData.Token.GetAmount() + + // dest_callback.calldata present: dispatch via Execute with the + // transferred funds. Otherwise fall through to ibc_destination_callback. + cbData, isCb, cbErr := callbackstypes.GetCallbackData( + transferData, version, packet.GetSourcePort(), 0, + DefaultMaxIBCCallbackGas, callbackstypes.DestinationCallbackKey, + ) + if isCb && cbErr != nil { + return errorsmod.Wrap(cbErr, "parse dest_callback") + } + if isCb && len(cbData.Calldata) != 0 { + amountInt, ok := sdkmath.NewIntFromString(amount) + if !ok { + return errorsmod.Wrapf(types.ErrInvalid, "invalid token amount: %s", amount) + } + funds := sdk.NewCoins(sdk.NewCoin(denom, amountInt)) + // Re-derive: ibccallbacks passes packet by value, so the + // rewriter's Receiver mutation doesn't reach this callback. + // https://github.com/cosmos/ibc-go/blob/v10.6.0/modules/apps/callbacks/ibc_middleware.go#L217 + intermediateBech32, err := DeriveIntermediateSender( + packet.GetDestChannel(), transferData.Sender, + sdk.GetConfig().GetBech32AccountAddrPrefix(), + ) + if err != nil { + return errorsmod.Wrap(err, "derive intermediate sender") + } + intermediate, err := sdk.AccAddressFromBech32(intermediateBech32) + if err != nil { + return errorsmod.Wrap(err, "parse intermediate sender") + } + _, err = i.contractKeeper.Execute(cachedCtx, contractAddr, intermediate, cbData.Calldata, funds) + if err != nil { + return errorsmod.Wrap(err, "execute contract via calldata") + } + return nil + } + transfer = &wasmvmtypes.IBCTransferCallback{ Receiver: receiverAddr.String(), Sender: transferData.Sender, Funds: wasmvmtypes.Array[wasmvmtypes.Coin]{ { - Denom: transferData.Token.GetDenom().IBCDenom(), - Amount: transferData.Token.GetAmount(), + Denom: denom, + Amount: amount, }, }, } @@ -533,3 +600,25 @@ func ValidateChannelParams(channelID string) error { func CreateErrorAcknowledgement(err error) ibcexported.Acknowledgement { return channeltypes.NewErrorAcknowledgementWithCodespace(err) } + +// jsonStringHasKey parses the memo as a json object and checks if it contains the key. +func jsonStringHasKey(memo, key string) (found bool, jsonObject map[string]interface{}) { + jsonObject = make(map[string]interface{}) + + // If there is no memo, the packet was either sent with an earlier version of IBC, or the memo was + // intentionally left blank. Nothing to do here. Ignore the packet and pass it down the stack. + if len(memo) == 0 { + return false, jsonObject + } + + // the jsonObject must be a valid JSON object + err := json.Unmarshal([]byte(memo), &jsonObject) + if err != nil { + return false, jsonObject + } + + // If the key doesn't exist, there's nothing to do on this hook. Continue by passing the packet + // down the stack + _, ok := jsonObject[key] + return ok, jsonObject +} diff --git a/x/wasm/ibc_callbacks_plus_middleware.go b/x/wasm/ibc_callbacks_plus_middleware.go new file mode 100644 index 0000000000..1cf179203a --- /dev/null +++ b/x/wasm/ibc_callbacks_plus_middleware.go @@ -0,0 +1,106 @@ +package wasm + +import ( + "encoding/json" + "fmt" + "math" + + callbackstypes "github.com/cosmos/ibc-go/v10/modules/apps/callbacks/types" + transfertypes "github.com/cosmos/ibc-go/v10/modules/apps/transfer/types" + channeltypes "github.com/cosmos/ibc-go/v10/modules/core/04-channel/types" + channeltypesv2 "github.com/cosmos/ibc-go/v10/modules/core/04-channel/v2/types" + porttypes "github.com/cosmos/ibc-go/v10/modules/core/05-port/types" + ibcapi "github.com/cosmos/ibc-go/v10/modules/core/api" + ibcexported "github.com/cosmos/ibc-go/v10/modules/core/exported" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/address" +) + +// Verbatim from https://github.com/cosmos/ibc-apps/blob/main/modules/ibc-hooks/types/keys.go +const SenderPrefix = "ibc-wasm-hook-intermediary" + +// Verbatim from https://github.com/cosmos/ibc-apps/blob/main/modules/ibc-hooks/keeper/keeper.go +func DeriveIntermediateSender(channel, originalSender, bech32Prefix string) (string, error) { + senderStr := fmt.Sprintf("%s/%s", channel, originalSender) + senderHash32 := address.Hash(SenderPrefix, []byte(senderStr)) + sender := sdk.AccAddress(senderHash32) + return sdk.Bech32ifyAddressBytes(bech32Prefix, sender) +} + +// rewriteReceiverForCalldata replaces Receiver with the intermediate sender when memo has dest_callback.calldata. +// Returns the data unchanged otherwise. +func rewriteReceiverForCalldata(data []byte, destChannel string) []byte { + var pd transfertypes.FungibleTokenPacketData + if err := json.Unmarshal(data, &pd); err != nil { + return data + } + if !hasDestCalldata(pd) { + return data + } + intermediate, err := DeriveIntermediateSender(destChannel, pd.Sender, sdk.GetConfig().GetBech32AccountAddrPrefix()) + if err != nil { + return data + } + pd.Receiver = intermediate + out, err := json.Marshal(pd) + if err != nil { + return data + } + return out +} + +// hasDestCalldata returns whether the packet carries a valid, non-empty dest_callback.calldata. +func hasDestCalldata(pd transfertypes.FungibleTokenPacketData) bool { + cbData, isCb, err := callbackstypes.GetCallbackData( + pd, "", "", 0, math.MaxUint64, callbackstypes.DestinationCallbackKey, + ) + return isCb && err == nil && len(cbData.Calldata) != 0 +} + +// IBCV1CallbacksPlusMiddleware rewrites the recv packet's Receiver to the +// intermediate sender when memo carries dest_callback.calldata. +type IBCV1CallbacksPlusMiddleware struct { + callbackstypes.CallbacksCompatibleModule +} + +func NewIBCV1CallbacksPlusMiddleware(app porttypes.IBCModule) *IBCV1CallbacksPlusMiddleware { + compat, ok := app.(callbackstypes.CallbacksCompatibleModule) + if !ok { + panic(fmt.Errorf("wrapped app must implement %T", (*callbackstypes.CallbacksCompatibleModule)(nil))) + } + return &IBCV1CallbacksPlusMiddleware{CallbacksCompatibleModule: compat} +} + +func (m *IBCV1CallbacksPlusMiddleware) OnRecvPacket(ctx sdk.Context, channelVersion string, packet channeltypes.Packet, relayer sdk.AccAddress) ibcexported.Acknowledgement { + packet.Data = rewriteReceiverForCalldata(packet.Data, packet.DestinationChannel) + return m.CallbacksCompatibleModule.OnRecvPacket(ctx, channelVersion, packet, relayer) +} + +// IBCV2CallbacksPlusMiddleware rewrites the recv packet's Receiver to the +// intermediate sender when memo carries dest_callback.calldata. +type IBCV2CallbacksPlusMiddleware struct { + callbackstypes.CallbacksCompatibleModuleV2 +} + +func NewIBCV2CallbacksPlusMiddleware(app ibcapi.IBCModule) *IBCV2CallbacksPlusMiddleware { + compat, ok := app.(callbackstypes.CallbacksCompatibleModuleV2) + if !ok { + panic(fmt.Errorf("wrapped app must implement %T", (*callbackstypes.CallbacksCompatibleModuleV2)(nil))) + } + return &IBCV2CallbacksPlusMiddleware{CallbacksCompatibleModuleV2: compat} +} + +func (m *IBCV2CallbacksPlusMiddleware) OnRecvPacket( + ctx sdk.Context, + sourceClient string, + destinationClient string, + sequence uint64, + payload channeltypesv2.Payload, + relayer sdk.AccAddress, +) channeltypesv2.RecvPacketResult { + if payload.SourcePort == transfertypes.PortID && payload.DestinationPort == transfertypes.PortID { + payload.Value = rewriteReceiverForCalldata(payload.Value, destinationClient) + } + return m.CallbacksCompatibleModuleV2.OnRecvPacket(ctx, sourceClient, destinationClient, sequence, payload, relayer) +} diff --git a/x/wasm/ibc_callbacks_plus_middleware_test.go b/x/wasm/ibc_callbacks_plus_middleware_test.go new file mode 100644 index 0000000000..b2c1720b3d --- /dev/null +++ b/x/wasm/ibc_callbacks_plus_middleware_test.go @@ -0,0 +1,342 @@ +package wasm + +import ( + "encoding/hex" + "encoding/json" + "testing" + + wasmvmtypes "github.com/CosmWasm/wasmvm/v3/types" + "github.com/cometbft/cometbft/libs/rand" + transfertypes "github.com/cosmos/ibc-go/v10/modules/apps/transfer/types" + clienttypes "github.com/cosmos/ibc-go/v10/modules/core/02-client/types" + channeltypes "github.com/cosmos/ibc-go/v10/modules/core/04-channel/types" + channeltypesv2 "github.com/cosmos/ibc-go/v10/modules/core/04-channel/v2/types" + porttypes "github.com/cosmos/ibc-go/v10/modules/core/05-port/types" + ibcexported "github.com/cosmos/ibc-go/v10/modules/core/exported" + mockv2 "github.com/cosmos/ibc-go/v10/testing/mock/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sdkmath "cosmossdk.io/math" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/address" + + "github.com/CosmWasm/wasmd/x/wasm/keeper/wasmtesting" + "github.com/CosmWasm/wasmd/x/wasm/types" +) + +type mockContractOpsKeeper struct { + types.ContractOpsKeeper + executeFn func(ctx sdk.Context, contractAddress, caller sdk.AccAddress, msg []byte, coins sdk.Coins) ([]byte, error) +} + +func (m *mockContractOpsKeeper) Execute(ctx sdk.Context, contractAddress, caller sdk.AccAddress, msg []byte, coins sdk.Coins) ([]byte, error) { + if m.executeFn == nil { + panic("Execute not expected to be called") + } + return m.executeFn(ctx, contractAddress, caller, msg, coins) +} + +type mockDestCallbackKeeper struct { + types.IBCContractKeeper + fn func(ctx sdk.Context, contractAddr sdk.AccAddress, msg wasmvmtypes.IBCDestinationCallbackMsg) error +} + +func (m *mockDestCallbackKeeper) IBCDestinationCallback(ctx sdk.Context, contractAddr sdk.AccAddress, msg wasmvmtypes.IBCDestinationCallbackMsg) error { + return m.fn(ctx, contractAddr, msg) +} + +func TestIBCReceivePacketCallback(t *testing.T) { + myContractAddr := sdk.AccAddress(rand.Bytes(address.Len)) + contractMsg := []byte(`{"swap":{"output_denom":"uatom","min_output":"1000"}}`) + calldataMemo := mustMarshalJSON(t, map[string]any{ + "dest_callback": map[string]any{ + "address": myContractAddr.String(), + "calldata": hex.EncodeToString(contractMsg), + }, + }) + intermediateBech32, err := DeriveIntermediateSender("channel-1", "cosmos1sender", sdk.GetConfig().GetBech32AccountAddrPrefix()) + require.NoError(t, err) + intermediate, err := sdk.AccAddressFromBech32(intermediateBech32) + require.NoError(t, err) + ibcDenom := transfertypes.Denom{ + Base: "uosmo", + Trace: []transfertypes.Hop{transfertypes.NewHop("transfer", "channel-1")}, + }.IBCDenom() + + specs := map[string]struct { + memo string + receiver string + execErr error + expErr string + expExec bool + expDestCB bool + }{ + "rewritten receiver: execute with intermediate caller": { + memo: calldataMemo, + receiver: intermediate.String(), + expExec: true, + }, + "untouched receiver: execute still derives intermediate locally": { + memo: calldataMemo, + receiver: myContractAddr.String(), + expExec: true, + }, + "execute returns error: callback wraps it": { + memo: calldataMemo, + receiver: myContractAddr.String(), + execErr: types.ErrExecuteFailed.Wrap("contract reverted"), + expErr: "execute contract via calldata", + }, + "no calldata: falls through to ibc_destination_callback": { + memo: "", + receiver: myContractAddr.String(), + expDestCB: true, + }, + } + + const amount = "5000" + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + pkt := channeltypes.Packet{ + Sequence: 1, + SourcePort: "transfer", + SourceChannel: "channel-0", + DestinationPort: "transfer", + DestinationChannel: "channel-1", + Data: transfertypes.NewFungibleTokenPacketData( + "uosmo", amount, "cosmos1sender", spec.receiver, spec.memo, + ).GetBytes(), + TimeoutHeight: clienttypes.Height{RevisionHeight: 100}, + } + + var gotExec, gotDestCB bool + executor := &mockContractOpsKeeper{ + executeFn: func(_ sdk.Context, gotContract, gotCaller sdk.AccAddress, gotMsg []byte, gotCoins sdk.Coins) ([]byte, error) { + gotExec = true + assert.Equal(t, myContractAddr, gotContract) + assert.Equal(t, intermediate, gotCaller) + assert.JSONEq(t, string(contractMsg), string(gotMsg)) + require.Len(t, gotCoins, 1) + assert.Equal(t, ibcDenom, gotCoins[0].Denom) + expAmount, ok := sdkmath.NewIntFromString(amount) + require.True(t, ok) + assert.Equal(t, expAmount, gotCoins[0].Amount) + if spec.execErr != nil { + return nil, spec.execErr + } + return []byte("ok"), nil + }, + } + + contractKeeper := &wasmtesting.IBCContractKeeperMock{} + if spec.expDestCB { + contractKeeper.IBCContractKeeper = &mockDestCallbackKeeper{ + fn: func(_ sdk.Context, gotAddr sdk.AccAddress, msg wasmvmtypes.IBCDestinationCallbackMsg) error { + gotDestCB = true + assert.Equal(t, myContractAddr, gotAddr) + require.NotNil(t, msg.Transfer) + assert.Equal(t, amount, msg.Transfer.Funds[0].Amount) + return nil + }, + } + } + + h := NewIBCHandler( + contractKeeper, nil, + &wasmtesting.MockIBCTransferKeeper{GetPortFn: func(ctx sdk.Context) string { return "transfer" }}, + nil, executor, + ) + ctx := sdk.Context{}.WithEventManager(&sdk.EventManager{}) + + gotErr := h.IBCReceivePacketCallback(ctx, pkt, channeltypes.NewResultAcknowledgement([]byte{1}), myContractAddr.String(), "ics20-1") + if spec.expErr != "" { + require.Error(t, gotErr) + assert.Contains(t, gotErr.Error(), spec.expErr) + return + } + require.NoError(t, gotErr) + assert.Equal(t, spec.expExec, gotExec) + assert.Equal(t, spec.expDestCB, gotDestCB) + }) + } +} + +func TestIBCSendPacketCallback(t *testing.T) { + myContractAddr := sdk.AccAddress(rand.Bytes(address.Len)).String() + + specs := map[string]struct { + memo string + expErr string + }{ + "src_callback.calldata rejected": { + memo: mustMarshalJSON(t, map[string]any{ + "src_callback": map[string]any{"address": myContractAddr, "calldata": "deadbeef"}, + }), + expErr: "src_callback must not contain a calldata field", + }, + "src_callback without calldata accepted": { + memo: mustMarshalJSON(t, map[string]any{ + "src_callback": map[string]any{"address": myContractAddr}, + }), + }, + } + + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + transferData := transfertypes.NewFungibleTokenPacketData("uosmo", "100", myContractAddr, "cosmos1receiver", spec.memo) + + h := NewIBCHandler( + &wasmtesting.IBCContractKeeperMock{}, nil, + &wasmtesting.MockIBCTransferKeeper{GetPortFn: func(ctx sdk.Context) string { return "transfer" }}, + nil, nil, + ) + + gotErr := h.IBCSendPacketCallback( + sdk.Context{}, "transfer", "channel-0", + clienttypes.Height{RevisionHeight: 100}, 0, + transferData.GetBytes(), + myContractAddr, myContractAddr, "ics20-1", + ) + if spec.expErr != "" { + require.Error(t, gotErr) + assert.Contains(t, gotErr.Error(), spec.expErr) + return + } + require.NoError(t, gotErr) + }) + } +} + +func mustMarshalJSON(t *testing.T, m map[string]any) string { + t.Helper() + bz, err := json.Marshal(m) + require.NoError(t, err) + return string(bz) +} + +type recordingIBCModule struct { + porttypes.IBCModule + received []byte +} + +func (r *recordingIBCModule) OnRecvPacket(ctx sdk.Context, channelVersion string, packet channeltypes.Packet, relayer sdk.AccAddress) ibcexported.Acknowledgement { + r.received = packet.Data + return channeltypes.NewResultAcknowledgement([]byte{1}) +} + +func (r *recordingIBCModule) UnmarshalPacketData(_ sdk.Context, _, _ string, _ []byte) (any, string, error) { + return nil, "", nil +} + +func TestIBCV1CallbacksPlusMiddleware(t *testing.T) { + calldataHex := hex.EncodeToString([]byte(`{"swap":{}}`)) + + specs := map[string]struct { + memo string + expRewrite bool + }{ + "dest_callback with calldata rewrites receiver": { + memo: mustMarshalJSON(t, map[string]any{ + "dest_callback": map[string]any{"address": "cosmos1contract", "calldata": calldataHex}, + }), + expRewrite: true, + }, + "no memo": {memo: ""}, + "dest_callback without calldata": {memo: `{"dest_callback":{"address":"cosmos1ccc"}}`}, + "malformed memo (not json)": {memo: `{not-json`}, + } + + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + transferData := transfertypes.NewFungibleTokenPacketData("uosmo", "100", "cosmos1sender", "cosmos1receiver", spec.memo) + pkt := channeltypes.Packet{ + Sequence: 1, + SourcePort: "transfer", + SourceChannel: "channel-0", + DestinationPort: "transfer", + DestinationChannel: "channel-1", + Data: transferData.GetBytes(), + TimeoutHeight: clienttypes.Height{RevisionHeight: 100}, + } + + inner := &recordingIBCModule{} + m := NewIBCV1CallbacksPlusMiddleware(inner) + m.OnRecvPacket(sdk.Context{}, "ics20-1", pkt, sdk.AccAddress("relayer")) + + require.NotNil(t, inner.received) + if !spec.expRewrite { + assert.Equal(t, transferData.GetBytes(), inner.received) + return + } + var gotData transfertypes.FungibleTokenPacketData + require.NoError(t, json.Unmarshal(inner.received, &gotData)) + wantReceiver, err := DeriveIntermediateSender("channel-1", "cosmos1sender", sdk.GetConfig().GetBech32AccountAddrPrefix()) + require.NoError(t, err) + assert.Equal(t, wantReceiver, gotData.Receiver) + assert.Equal(t, "cosmos1sender", gotData.Sender) + assert.Equal(t, spec.memo, gotData.Memo) + }) + } +} + +func TestIBCV2CallbacksPlusMiddleware(t *testing.T) { + calldataHex := hex.EncodeToString([]byte(`{"swap":{}}`)) + calldataMemo := mustMarshalJSON(t, map[string]any{ + "dest_callback": map[string]any{"address": "cosmos1contract", "calldata": calldataHex}, + }) + payloadValue := transfertypes.NewFungibleTokenPacketData("uosmo", "100", "cosmos1sender", "cosmos1receiver", calldataMemo).GetBytes() + + specs := map[string]struct { + payload channeltypesv2.Payload + expRewrite bool + }{ + "transfer port with dest_callback.calldata rewrites receiver": { + payload: channeltypesv2.Payload{ + SourcePort: transfertypes.PortID, + DestinationPort: transfertypes.PortID, + Version: "ics20-1", + Encoding: transfertypes.EncodingJSON, + Value: payloadValue, + }, + expRewrite: true, + }, + "non-transfer port passes through unchanged": { + payload: channeltypesv2.Payload{ + SourcePort: "different-port", + DestinationPort: "different-port", + Version: "v1", + Encoding: transfertypes.EncodingJSON, + Value: payloadValue, + }, + }, + } + + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + origValue := append([]byte(nil), spec.payload.Value...) + var gotRecv bool + var gotPayload channeltypesv2.Payload + inner := mockv2.NewIBCModule() + inner.IBCApp.OnRecvPacket = func(_ sdk.Context, _, _ string, _ uint64, payload channeltypesv2.Payload, _ sdk.AccAddress) channeltypesv2.RecvPacketResult { + gotRecv = true + gotPayload = payload + return channeltypesv2.RecvPacketResult{Status: channeltypesv2.PacketStatus_Success, Acknowledgement: []byte{1}} + } + m := NewIBCV2CallbacksPlusMiddleware(inner) + _ = m.OnRecvPacket(sdk.Context{}, "client-0", "client-1", 1, spec.payload, sdk.AccAddress("relayer")) + + require.True(t, gotRecv) + if !spec.expRewrite { + assert.Equal(t, origValue, gotPayload.Value) + return + } + var gotData transfertypes.FungibleTokenPacketData + require.NoError(t, json.Unmarshal(gotPayload.Value, &gotData)) + wantReceiver, err := DeriveIntermediateSender("client-1", "cosmos1sender", sdk.GetConfig().GetBech32AccountAddrPrefix()) + require.NoError(t, err) + assert.Equal(t, wantReceiver, gotData.Receiver) + }) + } +} diff --git a/x/wasm/ibc_dedup_middleware.go b/x/wasm/ibc_dedup_middleware.go new file mode 100644 index 0000000000..4fa0499df4 --- /dev/null +++ b/x/wasm/ibc_dedup_middleware.go @@ -0,0 +1,81 @@ +package wasm + +import ( + "encoding/json" + + transfertypes "github.com/cosmos/ibc-go/v10/modules/apps/transfer/types" + clienttypes "github.com/cosmos/ibc-go/v10/modules/core/02-client/types" + channeltypes "github.com/cosmos/ibc-go/v10/modules/core/04-channel/types" + porttypes "github.com/cosmos/ibc-go/v10/modules/core/05-port/types" + ibcexported "github.com/cosmos/ibc-go/v10/modules/core/exported" + + errorsmod "cosmossdk.io/errors" + + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/CosmWasm/wasmd/x/wasm/types" +) + +var ( + _ porttypes.IBCModule = (*IBCDedupMiddleware)(nil) + _ porttypes.ICS4Wrapper = (*IBCDedupMiddleware)(nil) + _ porttypes.Middleware = (*IBCDedupMiddleware)(nil) +) + +// IBCDedupMiddleware rejects same-side Hooks/Callbacks memo collisions. +type IBCDedupMiddleware struct { + porttypes.IBCModule + porttypes.ICS4Wrapper +} + +func NewIBCDedupMiddleware(app porttypes.IBCModule, ics4Wrapper porttypes.ICS4Wrapper) *IBCDedupMiddleware { + return &IBCDedupMiddleware{IBCModule: app, ICS4Wrapper: ics4Wrapper} +} + +func (m *IBCDedupMiddleware) OnRecvPacket(ctx sdk.Context, channelVersion string, packet channeltypes.Packet, relayer sdk.AccAddress) ibcexported.Acknowledgement { + if hasMemoCollision(packet.Data, "wasm", "dest_callback") { + return CreateErrorAcknowledgement(errorsmod.Wrap(types.ErrInvalid, "memo must not contain both wasm (Hooks) and dest_callback (Callbacks)")) + } + return m.IBCModule.OnRecvPacket(ctx, channelVersion, packet, relayer) +} + +func (m *IBCDedupMiddleware) SendPacket( + ctx sdk.Context, + sourcePort string, + sourceChannel string, + timeoutHeight clienttypes.Height, + timeoutTimestamp uint64, + data []byte, +) (uint64, error) { + if hasMemoCollision(data, "ibc_callback", "src_callback") { + return 0, errorsmod.Wrap(types.ErrInvalid, "memo must not contain both ibc_callback (Hooks) and src_callback (Callbacks)") + } + return m.ICS4Wrapper.SendPacket(ctx, sourcePort, sourceChannel, timeoutHeight, timeoutTimestamp, data) +} + +func (m *IBCDedupMiddleware) SetUnderlyingApplication(app porttypes.IBCModule) { + m.IBCModule = app +} + +func (m *IBCDedupMiddleware) SetICS4Wrapper(wrapper porttypes.ICS4Wrapper) { + m.ICS4Wrapper = wrapper +} + +func (m *IBCDedupMiddleware) UnmarshalPacketData(ctx sdk.Context, portID, channelID string, bz []byte) (any, string, error) { + if unmarshaler, ok := m.IBCModule.(porttypes.PacketDataUnmarshaler); ok { + return unmarshaler.UnmarshalPacketData(ctx, portID, channelID, bz) + } + return nil, "", errorsmod.Wrap(types.ErrInvalid, "underlying app does not implement PacketDataUnmarshaler") +} + +// hasMemoCollision returns true if the packet memo contains both +// hooksKey and callbacksKey. Returns false otherwise. +func hasMemoCollision(data []byte, hooksKey, callbacksKey string) bool { + var pd transfertypes.FungibleTokenPacketData + if err := json.Unmarshal(data, &pd); err != nil { + return false + } + hasHooks, memo := jsonStringHasKey(pd.Memo, hooksKey) + _, hasCallbacks := memo[callbacksKey] + return hasHooks && hasCallbacks +} diff --git a/x/wasm/ibc_dedup_middleware_test.go b/x/wasm/ibc_dedup_middleware_test.go new file mode 100644 index 0000000000..c335fc34a6 --- /dev/null +++ b/x/wasm/ibc_dedup_middleware_test.go @@ -0,0 +1,136 @@ +package wasm + +import ( + "testing" + + transfertypes "github.com/cosmos/ibc-go/v10/modules/apps/transfer/types" + clienttypes "github.com/cosmos/ibc-go/v10/modules/core/02-client/types" + channeltypes "github.com/cosmos/ibc-go/v10/modules/core/04-channel/types" + porttypes "github.com/cosmos/ibc-go/v10/modules/core/05-port/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +type mockICS4Wrapper struct { + porttypes.ICS4Wrapper + sentData []byte +} + +func (m *mockICS4Wrapper) SendPacket(_ sdk.Context, _, _ string, _ clienttypes.Height, _ uint64, data []byte) (uint64, error) { + m.sentData = data + return 1, nil +} + +func TestIBCDedupMiddlewareOnRecvPacket(t *testing.T) { + specs := map[string]struct { + memo map[string]any + rawData []byte + expFail bool + }{ + "wasm + dest_callback (same dest side) rejected": { + memo: map[string]any{ + "wasm": map[string]any{"contract": "cosmos1ccc", "msg": map[string]any{}}, + "dest_callback": map[string]any{"address": "cosmos1ccc"}, + }, + expFail: true, + }, + "wasm + src_callback (cross-side) passes through": { + memo: map[string]any{ + "wasm": map[string]any{"contract": "cosmos1ccc", "msg": map[string]any{}}, + "src_callback": map[string]any{"address": "cosmos1ccc"}, + }, + }, + "dest_callback alone passes through": {memo: map[string]any{"dest_callback": map[string]any{"address": "cosmos1ccc"}}}, + "non-transfer payload passes through": {rawData: []byte("not-a-transfer-packet")}, + } + + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + inner := &recordingIBCModule{} + m := NewIBCDedupMiddleware(inner, &mockICS4Wrapper{}) + + data := spec.rawData + if data == nil { + data = transferPacketFixture(mustMarshalJSON(t, spec.memo)).Data + } + pkt := channeltypes.Packet{ + Sequence: 1, SourcePort: "transfer", SourceChannel: "channel-0", + DestinationPort: "transfer", DestinationChannel: "channel-1", + Data: data, + TimeoutHeight: clienttypes.Height{RevisionHeight: 100}, + } + + gotAck := m.OnRecvPacket(sdk.Context{}, "ics20-1", pkt, sdk.AccAddress("relayer")) + require.NotNil(t, gotAck) + + if spec.expFail { + assert.False(t, gotAck.Success()) + assert.Nil(t, inner.received) + return + } + assert.Equal(t, data, inner.received) + }) + } +} + +func TestIBCDedupMiddlewareSendPacket(t *testing.T) { + specs := map[string]struct { + memo map[string]any + rawData []byte + expFail bool + }{ + "ibc_callback + src_callback (same src side) rejected": { + memo: map[string]any{ + "ibc_callback": "cosmos1ccc", + "src_callback": map[string]any{"address": "cosmos1ccc"}, + }, + expFail: true, + }, + "wasm + src_callback (cross-side) passes through": { + memo: map[string]any{ + "wasm": map[string]any{"contract": "cosmos1ccc", "msg": map[string]any{}}, + "src_callback": map[string]any{"address": "cosmos1ccc"}, + }, + }, + "src_callback alone passes through": {memo: map[string]any{"src_callback": map[string]any{"address": "cosmos1ccc"}}}, + "non-transfer payload passes through": {rawData: []byte("not-a-transfer-packet")}, + } + + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + ics4 := &mockICS4Wrapper{} + m := NewIBCDedupMiddleware(&recordingIBCModule{}, ics4) + + data := spec.rawData + if data == nil { + data = transferPacketFixture(mustMarshalJSON(t, spec.memo)).Data + } + + _, gotErr := m.SendPacket(sdk.Context{}, "transfer", "channel-0", + clienttypes.Height{RevisionHeight: 100}, 0, data) + + if spec.expFail { + require.Error(t, gotErr) + assert.Nil(t, ics4.sentData) + return + } + require.NoError(t, gotErr) + assert.Equal(t, data, ics4.sentData) + }) + } +} + +func transferPacketFixture(memo string) channeltypes.Packet { + td := transfertypes.NewFungibleTokenPacketData("uosmo", "1000", "cosmos1sender", "cosmos1receiver", memo) + return channeltypes.Packet{ + Sequence: 1, + SourcePort: "transfer", + SourceChannel: "channel-0", + DestinationPort: "transfer", + DestinationChannel: "channel-1", + Data: td.GetBytes(), + TimeoutHeight: clienttypes.Height{RevisionHeight: 100}, + } +} diff --git a/x/wasm/ibc_test.go b/x/wasm/ibc_test.go index 66a99a48fe..9f2e081592 100644 --- a/x/wasm/ibc_test.go +++ b/x/wasm/ibc_test.go @@ -110,7 +110,7 @@ func TestOnRecvPacket(t *testing.T) { }, } channelVersion := "" - h := NewIBCHandler(&mock, nil, nil, nil) + h := NewIBCHandler(&mock, nil, nil, nil, nil) em := &sdk.EventManager{} ctx := sdk.Context{}.WithEventManager(em) if spec.expPanic {