diff --git a/docs/finn/source_code/finn.custom_op.fpgadataflow.rst b/docs/finn/source_code/finn.custom_op.fpgadataflow.rst index 84e9633304..24d02a5d93 100644 --- a/docs/finn/source_code/finn.custom_op.fpgadataflow.rst +++ b/docs/finn/source_code/finn.custom_op.fpgadataflow.rst @@ -111,6 +111,14 @@ finn.custom\_op.fpgadataflow.labelselect :undoc-members: :show-inheritance: +finn.custom\_op.fpgadataflow.selecttoken +----------------------------------------- + +.. automodule:: finn.custom_op.fpgadataflow.selecttoken + :members: + :undoc-members: + :show-inheritance: + finn.custom\_op.fpgadataflow.lookup ----------------------------------------------- diff --git a/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst b/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst index 346eddb073..a5c47470ee 100644 --- a/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst +++ b/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst @@ -37,6 +37,14 @@ finn.custom\_op.fpgadataflow.streamingdatawidthconverter\_rtl :undoc-members: :show-inheritance: +finn.custom\_op.fpgadataflow.selecttoken\_rtl +--------------------------------------------------------------- + +.. automodule:: finn.custom_op.fpgadataflow.rtl.selecttoken_rtl + :members: + :undoc-members: + :show-inheritance: + finn.custom\_op.fpgadataflow.streamingfifo\_rtl ------------------------------------------------- diff --git a/finn-rtllib/selecttoken/hdl/select_token.sv b/finn-rtllib/selecttoken/hdl/select_token.sv new file mode 100644 index 0000000000..6a0c4dae1e --- /dev/null +++ b/finn-rtllib/selecttoken/hdl/select_token.sv @@ -0,0 +1,138 @@ +/**************************************************************************** + * Copyright Advanced Micro Devices, Inc. + * SPDX-License-Identifier: BSD-3-Clause + * + * @author Oliver Cassidy + * @brief Select one token from a folded token stream. + * + * @description + * Consumes NUM_TOKENS token vectors, each consisting of TOKEN_BEATS stream + * beats. 
Beats belonging to TOKEN_INDEX are forwarded to the output; all + * other beats are consumed and discarded. + ***************************************************************************/ + +`default_nettype none + +module select_token #( + int unsigned NUM_TOKENS, + int unsigned TOKEN_BEATS, + int unsigned DATA_WIDTH, + int unsigned TOKEN_INDEX +)( + // Global Control + input wire logic clk, + input wire logic rst, + + // Input Stream + output logic irdy, + input wire logic ivld, + input wire logic [DATA_WIDTH-1:0] idat, + + // Output Stream - beats belonging to TOKEN_INDEX + input wire logic ordy, + output logic ovld, + output logic [DATA_WIDTH-1:0] odat +); + + localparam int unsigned TOKEN_CNT_BITS = (NUM_TOKENS <= 1)? 1 : $clog2(NUM_TOKENS); + localparam int unsigned BEAT_CNT_BITS = (TOKEN_BEATS <= 1)? 1 : $clog2(TOKEN_BEATS); + typedef logic [TOKEN_CNT_BITS-1:0] token_cnt_t; + typedef logic [ BEAT_CNT_BITS-1:0] beat_cnt_t; + typedef logic [DATA_WIDTH-1:0] data_t; + localparam token_cnt_t TOKEN_INDEX_PRE = (TOKEN_INDEX == 0)? NUM_TOKENS-1 : TOKEN_INDEX-1; + localparam token_cnt_t TOKEN_PRE_LAST = (NUM_TOKENS == 1)? 0 : NUM_TOKENS-2; + localparam beat_cnt_t BEAT_PRE_LAST = (TOKEN_BEATS == 1)? 
0 : TOKEN_BEATS-2; + + initial begin + if(NUM_TOKENS < 1) begin + $error("%m: NUM_TOKENS must be positive."); + $finish; + end + if(TOKEN_BEATS < 1) begin + $error("%m: TOKEN_BEATS must be positive."); + $finish; + end + if(DATA_WIDTH < 1) begin + $error("%m: DATA_WIDTH must be positive."); + $finish; + end + if(TOKEN_INDEX >= NUM_TOKENS) begin + $error("%m: TOKEN_INDEX must be less than NUM_TOKENS."); + $finish; + end + end + + // Beat and Token Position + token_cnt_t TokenCnt = '0; // 0, ..., NUM_TOKENS-1 + beat_cnt_t BeatCnt = '0; // 0, ..., TOKEN_BEATS-1 + logic Selected = TOKEN_INDEX == 0; // TokenCnt == TOKEN_INDEX + logic BeatLst = TOKEN_BEATS == 1; // BeatCnt == TOKEN_BEATS-1 + logic TokenLst = NUM_TOKENS == 1; // TokenCnt == NUM_TOKENS-1 + + // Selected-Token Forwarding + data_t ADat = 'x; + logic AVld = 0; + data_t BDat = 'x; + logic BVld = 0; + + assign irdy = !Selected || !AVld; + assign odat = BDat; + assign ovld = BVld; + + uwire take = irdy && ivld; + uwire selected_take = Selected && take; + uwire bload = !BVld || ordy; + + always_ff @(posedge clk) begin + if(rst) begin + TokenCnt <= '0; + BeatCnt <= '0; + Selected <= TOKEN_INDEX == 0; + BeatLst <= TOKEN_BEATS == 1; + TokenLst <= NUM_TOKENS == 1; + end + else if(take) begin + if(BeatLst) begin + BeatCnt <= '0; + BeatLst <= TOKEN_BEATS == 1; + Selected <= TokenCnt == TOKEN_INDEX_PRE; + if(TokenLst) begin + TokenCnt <= '0; + TokenLst <= NUM_TOKENS == 1; + end + else begin + TokenCnt <= TokenCnt + 1; + TokenLst <= TokenCnt == TOKEN_PRE_LAST; + end + end + else begin + BeatCnt <= BeatCnt + 1; + BeatLst <= BeatCnt == BEAT_PRE_LAST; + end + end + end + + always_ff @(posedge clk) begin + if(rst) begin + ADat <= 'x; + AVld <= 0; + BDat <= 'x; + BVld <= 0; + end + else begin + if(bload) begin + BDat <= AVld? 
ADat : idat; + BVld <= AVld || selected_take; + end + + if(bload) AVld <= 0; + else if(selected_take) begin + ADat <= idat; + AVld <= 1; + end + end + end + +endmodule : select_token + +`default_nettype wire diff --git a/finn-rtllib/selecttoken/hdl/select_token_template.v b/finn-rtllib/selecttoken/hdl/select_token_template.v new file mode 100644 index 0000000000..234b4bfaaa --- /dev/null +++ b/finn-rtllib/selecttoken/hdl/select_token_template.v @@ -0,0 +1,77 @@ +/****************************************************************************** + * Copyright (C) 2026, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). 
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + *****************************************************************************/ + +module $TOP_MODULE_NAME$ #( + parameter FOLD_WIDTH = $FOLD_WIDTH$, + parameter AXI_WIDTH = ((FOLD_WIDTH + 7) / 8) * 8 +)( + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *) + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:out0_V, ASSOCIATED_RESET ap_rst_n" *) + input ap_clk, + (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *) + input ap_rst_n, + + output in0_V_TREADY, + input in0_V_TVALID, + input [AXI_WIDTH-1:0] in0_V_TDATA, + + input out0_V_TREADY, + output out0_V_TVALID, + output [AXI_WIDTH-1:0] out0_V_TDATA +); + + wire [FOLD_WIDTH-1:0] core_out; + + assign out0_V_TDATA[FOLD_WIDTH-1:0] = core_out; + + generate + if (AXI_WIDTH > FOLD_WIDTH) begin : gen_pad_tdata + assign out0_V_TDATA[AXI_WIDTH-1:FOLD_WIDTH] = {(AXI_WIDTH-FOLD_WIDTH){1'b0}}; + end + endgenerate + + select_token #( + .NUM_TOKENS($NUM_TOKENS$), + .TOKEN_BEATS($TOKEN_BEATS$), + .DATA_WIDTH(FOLD_WIDTH), + .TOKEN_INDEX($TOKEN_INDEX$) + ) impl ( + .clk(ap_clk), + .rst(!ap_rst_n), + .irdy(in0_V_TREADY), + .ivld(in0_V_TVALID), + .idat(in0_V_TDATA[FOLD_WIDTH-1:0]), + .ordy(out0_V_TREADY), + .ovld(out0_V_TVALID), + .odat(core_out) + ); + +endmodule diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py index f2164ca2c1..840526f031 100644 --- a/src/finn/builder/build_dataflow_steps.py +++ b/src/finn/builder/build_dataflow_steps.py @@ -538,6 +538,7 @@ def apply_if_relevant(model, op_types, transform, desc=""): ) # Lookup layers + model = apply_if_relevant(model, ["Gather"], to_hw.InferSelectTokenLayer(), "token selection") model = apply_if_relevant(model, ["Gather"], to_hw.InferLookupLayer(), "lookup layers") # Activation 
functions diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py index f05198837b..ec0b5c27b4 100644 --- a/src/finn/custom_op/fpgadataflow/__init__.py +++ b/src/finn/custom_op/fpgadataflow/__init__.py @@ -70,6 +70,7 @@ def register_custom_op(cls): from finn.custom_op.fpgadataflow.outer_shuffle import OuterShuffle from finn.custom_op.fpgadataflow.pool import Pool from finn.custom_op.fpgadataflow.requant import Requant +from finn.custom_op.fpgadataflow.selecttoken import SelectToken from finn.custom_op.fpgadataflow.shuffle import Shuffle from finn.custom_op.fpgadataflow.split import StreamingSplit from finn.custom_op.fpgadataflow.streamingdataflowpartition import ( @@ -103,6 +104,7 @@ def register_custom_op(cls): custom_op["Lookup"] = Lookup custom_op["OuterShuffle"] = OuterShuffle custom_op["Pool"] = Pool +custom_op["SelectToken"] = SelectToken custom_op["Shuffle"] = Shuffle custom_op["StreamingConcat"] = StreamingConcat custom_op["StreamingSplit"] = StreamingSplit diff --git a/src/finn/custom_op/fpgadataflow/rtl/__init__.py b/src/finn/custom_op/fpgadataflow/rtl/__init__.py index 520fcdcd12..15baa0e191 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/__init__.py +++ b/src/finn/custom_op/fpgadataflow/rtl/__init__.py @@ -40,6 +40,7 @@ from finn.custom_op.fpgadataflow.rtl.layernorm_rtl import LayerNorm_rtl from finn.custom_op.fpgadataflow.rtl.matrixvectoractivation_rtl import MVAU_rtl from finn.custom_op.fpgadataflow.rtl.requant_rtl import Requant_rtl +from finn.custom_op.fpgadataflow.rtl.selecttoken_rtl import SelectToken_rtl from finn.custom_op.fpgadataflow.rtl.streamingdatawidthconverter_rtl import ( StreamingDataWidthConverter_rtl, ) @@ -60,6 +61,7 @@ custom_op["StreamingDataWidthConverter_rtl"] = StreamingDataWidthConverter_rtl custom_op["StreamingFIFO_rtl"] = StreamingFIFO_rtl custom_op["MVAU_rtl"] = MVAU_rtl +custom_op["SelectToken_rtl"] = SelectToken_rtl custom_op["VVAU_rtl"] = VVAU_rtl 
custom_op["Thresholding_rtl"] = Thresholding_rtl custom_op["InnerShuffle_rtl"] = InnerShuffle_rtl diff --git a/src/finn/custom_op/fpgadataflow/rtl/selecttoken_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/selecttoken_rtl.py new file mode 100644 index 0000000000..1afd6b56c7 --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/rtl/selecttoken_rtl.py @@ -0,0 +1,132 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import os
import shutil

from finn.custom_op.fpgadataflow.rtlbackend import RTLBackend
from finn.custom_op.fpgadataflow.selecttoken import SelectToken


def _rtlsrc_dir():
    """Return the directory of the SelectToken HDL sources below ``FINN_ROOT``."""
    finn_root = os.environ["FINN_ROOT"]
    return finn_root + "/finn-rtllib/selecttoken/hdl"


class SelectToken_rtl(SelectToken, RTLBackend):
    """RTL implementation of SelectToken."""

    def __init__(self, onnx_node, **kwargs):
        super().__init__(onnx_node, **kwargs)

    def get_nodeattr_types(self):
        """Return the SelectToken attributes extended by those of RTLBackend."""
        return {
            **SelectToken.get_nodeattr_types(self),
            **RTLBackend.get_nodeattr_types(self),
        }

    def generate_hdl(self, model, fpgapart, clk):
        """Instantiate the Verilog wrapper template for this node.

        Fills the ``$...$`` placeholders of ``select_token_template.v`` from the
        node attributes, writes the result as ``<top module name>.v`` into
        ``code_gen_dir_ipgen`` and copies ``select_token.sv`` next to it. The
        attributes ``gen_top_module``, ``ipgen_path`` and ``ip_path`` are set.
        ``model``, ``fpgapart`` and ``clk`` are not used here.
        """
        simd = self.get_nodeattr("SIMD")
        num_channels = self.get_nodeattr("NumChannels")
        num_tokens = self.get_nodeattr("NumTokens")
        token_index = self.get_nodeattr("TokenIndex")
        # A negative index counts from the end of the token sequence.
        if token_index < 0:
            token_index += num_tokens
        assert num_channels % simd == 0, "SIMD must divide NumChannels"
        assert 0 <= token_index < num_tokens, "TokenIndex must select an existing token"

        hdl_dir = _rtlsrc_dir()
        with open(hdl_dir + "/select_token_template.v", "r") as f:
            wrapper = f.read()

        topname = self.get_verilog_top_module_name()
        self.set_nodeattr("gen_top_module", topname)

        substitutions = {
            "TOP_MODULE_NAME": topname,
            "NUM_TOKENS": num_tokens,
            "TOKEN_BEATS": num_channels // simd,
            "TOKEN_INDEX": token_index,
            "FOLD_WIDTH": self.get_input_datatype().bitwidth() * simd,
        }
        for placeholder, value in substitutions.items():
            wrapper = wrapper.replace(f"${placeholder}$", str(value))

        code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
        with open(os.path.join(code_gen_dir, topname + ".v"), "w") as f:
            f.write(wrapper)
        shutil.copy(hdl_dir + "/select_token.sv", code_gen_dir)

        self.set_nodeattr("ipgen_path", code_gen_dir)
        self.set_nodeattr("ip_path", code_gen_dir)

    def get_rtl_file_list(self, abspath=False):
        """List the core source and the generated wrapper.

        With ``abspath`` the core is located in the rtllib directory and the
        wrapper in ``code_gen_dir_ipgen``; otherwise bare file names are returned.
        """
        lib_prefix = _rtlsrc_dir() + "/" if abspath else ""
        gen_prefix = self.get_nodeattr("code_gen_dir_ipgen") + "/" if abspath else ""
        return [
            lib_prefix + "select_token.sv",
            gen_prefix + self.get_nodeattr("gen_top_module") + ".v",
        ]

    def code_generation_ipi(self):
        """Return the Tcl commands that add the sources and create the block-design cell."""
        code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
        cmd = [
            "add_files -norecurse %s" % os.path.join(code_gen_dir, src)
            for src in self.get_rtl_file_list()
        ]
        cmd.append(
            "create_bd_cell -type module -reference %s %s"
            % (self.get_nodeattr("gen_top_module"), self.onnx_node.name)
        )
        return cmd

    def execute_node(self, context, graph):
        """Run the node via SelectToken ("cppsim") or RTLBackend ("rtlsim")."""
        mode = self.get_nodeattr("exec_mode")
        if mode not in ("cppsim", "rtlsim"):
            raise Exception(
                """Invalid value for attribute exec_mode! Is currently set to: {}
            has to be set to one of the following values ("cppsim", "rtlsim")""".format(
                    mode
                )
            )
        executor = SelectToken if mode == "cppsim" else RTLBackend
        executor.execute_node(self, context, graph)
import numpy as np
import warnings
from qonnx.core.datatype import DataType

from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp


class SelectToken(HWCustomOp):
    """Select one token vector from a sequence of token vectors.

    The input has the normal shape (1, NumTokens, NumChannels); the output is
    the token at position TokenIndex with shape (1, NumChannels). SIMD
    channels are carried per stream beat.
    """

    def __init__(self, onnx_node, **kwargs):
        super().__init__(onnx_node, **kwargs)

    def get_nodeattr_types(self):
        my_attrs = super().get_nodeattr_types()
        my_attrs.update(
            {
                # number of token vectors in the input sequence
                "NumTokens": ("i", True, 0),
                # elements per token vector
                "NumChannels": ("i", True, 0),
                # token to forward; negative values count from the end
                "TokenIndex": ("i", True, 0),
                # channels per stream beat, must divide NumChannels
                "SIMD": ("i", False, 1),
                "inputDataType": ("s", True, ""),
                # empty means "same as inputDataType"
                "outputDataType": ("s", False, ""),
            }
        )
        return my_attrs

    def get_normal_input_shape(self, ind=0):
        if ind != 0:
            raise Exception("SelectToken only has one input")
        return (1, self.get_nodeattr("NumTokens"), self.get_nodeattr("NumChannels"))

    def get_folded_input_shape(self, ind=0):
        """Return the input shape with the channel axis split into (folds, SIMD)."""
        normal_shape = self.get_normal_input_shape(ind)
        simd = self.get_nodeattr("SIMD")
        num_channels = normal_shape[-1]
        assert num_channels % simd == 0, "SIMD must divide NumChannels"
        return normal_shape[:-1] + (num_channels // simd, simd)

    def get_normal_output_shape(self, ind=0):
        return (1, self.get_nodeattr("NumChannels"))

    def get_folded_output_shape(self, ind=0):
        """Return the output shape with the channel axis split into (folds, SIMD)."""
        normal_shape = self.get_normal_output_shape(ind)
        simd = self.get_nodeattr("SIMD")
        num_channels = normal_shape[-1]
        assert num_channels % simd == 0, "SIMD must divide NumChannels"
        return normal_shape[:-1] + (num_channels // simd, simd)

    def make_shape_compatible_op(self, model):
        exp_ishape = self.get_normal_input_shape()
        ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
        assert ishape == exp_ishape, "Unexpected input shape for token sequence."
        return super().make_const_shape_op(self.get_normal_output_shape())

    def infer_node_datatype(self, model):
        """Propagate the input tensor datatype to both attributes and the output tensor.

        The datatype annotated on the input tensor wins over ``inputDataType``;
        a warning is issued whenever an attribute that was already set changes.
        """
        node = self.onnx_node
        attr_idt = None
        if self.get_nodeattr("inputDataType") != "":
            attr_idt = self.get_input_datatype()

        idt = model.get_tensor_datatype(node.input[0])
        if idt is None:
            idt = attr_idt
        if idt is None:
            raise Exception("SelectToken input datatype is not set")

        if attr_idt is not None and attr_idt != idt:
            warnings.warn(
                "inputDataType changing for %s: %s -> %s" % (node.name, str(attr_idt), str(idt))
            )
        self.set_nodeattr("inputDataType", idt.name)

        attr_odt = self.get_nodeattr("outputDataType")
        if attr_odt != "" and DataType[attr_odt] != idt:
            warnings.warn(
                "outputDataType changing for %s: %s -> %s"
                % (node.name, str(DataType[attr_odt]), str(idt))
            )
        self.set_nodeattr("outputDataType", idt.name)
        model.set_tensor_datatype(node.output[0], idt)

    def verify_node(self):
        pass

    def get_input_datatype(self, ind=0):
        return DataType[self.get_nodeattr("inputDataType")]

    def get_output_datatype(self, ind=0):
        odt = self.get_nodeattr("outputDataType")
        if odt == "":
            return self.get_input_datatype(ind)
        return DataType[odt]

    def get_instream_width(self, ind=0):
        if ind != 0:
            return 0
        return self.get_input_datatype().bitwidth() * self.get_nodeattr("SIMD")

    def get_outstream_width(self, ind=0):
        return self.get_output_datatype().bitwidth() * self.get_nodeattr("SIMD")

    def get_number_output_values(self):
        return int(np.prod(self.get_folded_output_shape()[:-1]))

    def get_exp_cycles(self):
        # One cycle per input beat: every token is consumed, not only the selected one.
        return int(np.prod(self.get_folded_input_shape()[:-1]))

    def execute_node(self, context, graph):
        """Write the selected token of the input tensor to the output as float32."""
        node = self.onnx_node
        inp = context[node.input[0]]
        token_index = self.get_nodeattr("TokenIndex")
        num_tokens = self.get_nodeattr("NumTokens")
        if token_index < 0:
            token_index += num_tokens
        assert 0 <= token_index < num_tokens, "TokenIndex must select an existing token."

        result = inp[:, token_index, :]
        context[node.output[0]] = np.asarray(result, dtype=np.float32).reshape(
            self.get_normal_output_shape()
        )

    def bram_estimation(self):
        return 0

    def lut_estimation(self):
        return 200


# --- src/finn/transformation/fpgadataflow/convert_to_hw_layers.py ---
class InferSelectTokenLayer(Transformation):
    """Convert scalar Gather(input, token_index, axis=1) into SelectToken.

    A Gather node is replaced only if all of the following hold:

    * it has an ``axis`` attribute and exactly two inputs,
    * the index is a rank-0 constant and the data input is not a constant,
    * the data input has a fully known rank-3 shape with leading dimension 1
      and the gather axis resolves to 1,
    * the index, after wrapping negative values, addresses an existing token,
    * the annotated output shape, if any, is (1, channels),
    * the input datatype is an integer type and equals the output datatype.
    """

    def apply(self, model):
        graph = model.graph
        node_ind = 0
        graph_modified = False
        for node in graph.node:
            node_ind += 1
            if node.op_type != "Gather":
                continue

            axis = get_by_name(node.attribute, "axis")
            if axis is None or len(node.input) != 2:
                continue

            seq_name = node.input[0]
            idx_name = node.input[1]
            idx_init = model.get_initializer(idx_name)
            # Only a rank-0 index removes the gathered axis. An index of shape
            # (1,) keeps a length-1 axis in the Gather output, which the
            # rank-2 SelectToken output cannot represent. This check does not
            # rely on the output shape being annotated.
            if idx_init is None or idx_init.ndim != 0:
                continue
            if model.get_initializer(seq_name) is not None:
                continue

            seq_shape = model.get_tensor_shape(seq_name)
            if seq_shape is None or any(x is None for x in seq_shape):
                continue

            rank = len(seq_shape)
            gather_axis = axis.i if axis.i >= 0 else axis.i + rank
            if rank != 3 or gather_axis != 1:
                continue

            token_index = int(idx_init.flatten()[0])
            num_tokens = int(seq_shape[1])
            if token_index < 0:
                token_index += num_tokens
            if token_index < 0 or token_index >= num_tokens:
                continue

            out_shape = model.get_tensor_shape(node.output[0])
            exp_oshape = [int(seq_shape[0]), int(seq_shape[2])]
            if out_shape is not None and list(out_shape) != exp_oshape:
                continue
            if seq_shape[0] != 1:
                continue

            idt = model.get_tensor_datatype(seq_name)
            if idt is None or not idt.is_integer():
                continue
            odt = model.get_tensor_datatype(node.output[0])
            if odt is None:
                odt = idt
            elif odt != idt:
                continue

            new_node = helper.make_node(
                "SelectToken",
                [seq_name],
                node.output,
                domain="finn.custom_op.fpgadataflow",
                backend="fpgadataflow",
                name="SelectToken_" + node.name,
                NumTokens=num_tokens,
                NumChannels=int(seq_shape[2]),
                TokenIndex=token_index,
                SIMD=1,
                inputDataType=idt.name,
                outputDataType=odt.name,
            )
            graph.node.insert(node_ind, new_node)
            graph.node.remove(node)
            graph_modified = True

        if graph_modified:
            model = model.transform(InferShapes())
            model = model.transform(InferDataTypes())
        return (model, graph_modified)
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import pytest

import numpy as np
from functools import partial
from onnx import TensorProto, helper, numpy_helper
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.general import GiveUniqueNodeNames

from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer
from finn.analysis.fpgadataflow.res_estimation import (
    res_estimation,
    res_estimation_complete,
)
from finn.core.onnx_exec import execute_onnx
from finn.transformation.fpgadataflow.convert_to_hw_layers import InferSelectTokenLayer
from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP
from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO
from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers
from finn.util.vivado import parse_ooc_synth_results

FPGA_PART = "xc7z020clg400-1"
CLK_NS = 10


def _make_graph(nodes, output_shape, idx_values=None, finn_dtype=DataType["INT8"]):
    """Wrap ``nodes`` into a model with a (1, 4, 4) input "tokens" and an output "out".

    ``idx_values``, when given, becomes the initializer "idx"; both graph
    tensors are annotated with ``finn_dtype``.
    """
    tokens_info = helper.make_tensor_value_info("tokens", TensorProto.FLOAT, [1, 4, 4])
    out_info = helper.make_tensor_value_info("out", TensorProto.FLOAT, output_shape)
    if idx_values is None:
        initializers = []
    else:
        initializers = [numpy_helper.from_array(idx_values, name="idx")]
    graph = helper.make_graph(
        nodes, "selecttoken_test", [tokens_info], [out_info], initializers
    )
    onnx_model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])
    model = ModelWrapper(onnx_model)
    model.set_tensor_datatype("tokens", finn_dtype)
    model.set_tensor_datatype("out", finn_dtype)
    return model


def _make_gather_model(token_index=0):
    """Build a model with a single Gather(tokens, scalar idx, axis=1) node."""
    gather_node = helper.make_node(
        "Gather",
        ["tokens", "idx"],
        ["out"],
        axis=1,
        name="gather_token",
    )
    scalar_idx = np.asarray(token_index, dtype=np.int64)
    return _make_graph([gather_node], [1, 4], scalar_idx)


def _make_selecttoken_model(token_index=0, simd=1, finn_dtype=DataType["INT8"]):
    """Build a model with a single SelectToken node over 4 tokens of 4 channels."""
    dtype_name = finn_dtype.name
    select_node = helper.make_node(
        "SelectToken",
        ["tokens"],
        ["out"],
        domain="finn.custom_op.fpgadataflow",
        backend="fpgadataflow",
        name="SelectToken_0",
        NumTokens=4,
        NumChannels=4,
        TokenIndex=token_index,
        SIMD=simd,
        inputDataType=dtype_name,
        outputDataType=dtype_name,
    )
    return _make_graph([select_node], [1, 4], None, finn_dtype)


def _prepare_selecttoken_stitched_ip_model(simd=1, token_index=0, run_pnr=False):
    """Take a SelectToken model through specialization, FIFO insertion and IP stitching."""
    model = _make_selecttoken_model(token_index=token_index, simd=simd)
    flow = [
        SpecializeLayers(FPGA_PART),
        InsertFIFO(create_shallow_fifos=True),
        SpecializeLayers(FPGA_PART),
        GiveUniqueNodeNames(),
        PrepareIP(FPGA_PART, CLK_NS),
        HLSSynthIP(),
        CreateStitchedIP(FPGA_PART, CLK_NS, run_pnr=run_pnr),
    ]
    for step in flow:
        model = model.transform(step)
    return model


def _make_input_dict(model, tokens):
    """Map the model's first graph input to ``tokens``."""
    return {model.graph.input[0].name: tokens}


@pytest.mark.fpgadataflow
def test_convert_gather_to_selecttoken():
    model = _make_gather_model(token_index=2)
    tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4)
    expected = tokens[:, 2, :]

    # Reference result from the plain ONNX Gather
    produced = execute_onnx(model, _make_input_dict(model, tokens))
    assert (produced["out"] == expected).all()

    model = model.transform(InferSelectTokenLayer())
    hw_node = model.graph.node[0]
    assert hw_node.op_type == "SelectToken"
    assert hw_node.domain == "finn.custom_op.fpgadataflow"
    assert list(hw_node.input) == ["tokens"]

    op = getCustomOp(hw_node)
    assert op.get_normal_output_shape() == (1, 4)
    assert op.get_exp_cycles() == 16
    assert op.get_nodeattr("TokenIndex") == 2

    produced = execute_onnx(model, _make_input_dict(model, tokens))
    assert (produced["out"] == expected).all()

    model = model.transform(SpecializeLayers(FPGA_PART))
    model = model.transform(GiveUniqueNodeNames())
    specialized = model.graph.node[0]
    assert specialized.op_type == "SelectToken_rtl"
    assert specialized.domain == "finn.custom_op.fpgadataflow.rtl"


@pytest.mark.fpgadataflow
@pytest.mark.parametrize("token_index", [0, 1, 3])
def test_selecttoken_python_execution(token_index):
    tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4)
    expected = tokens[:, token_index, :]
    model = _make_selecttoken_model(token_index=token_index)

    produced = execute_onnx(model, _make_input_dict(model, tokens))
    assert (produced["out"] == expected).all()


@pytest.mark.fpgadataflow
@pytest.mark.parametrize(
    "finn_dtype,fold_width",
    [(DataType["INT8"], 16), (DataType["UINT4"], 8), (DataType["BIPOLAR"], 2)],
)
def test_selecttoken_rtl_codegen(tmp_path, finn_dtype, fold_width):
    model = _make_selecttoken_model(token_index=3, simd=2, finn_dtype=finn_dtype)
    model = model.transform(SpecializeLayers(FPGA_PART))
    model = model.transform(GiveUniqueNodeNames())

    node = model.graph.node[0]
    op = getCustomOp(node)
    op.set_nodeattr("code_gen_dir_ipgen", str(tmp_path))
    op.code_generation_ipgen(model, FPGA_PART, CLK_NS)

    topname = op.get_nodeattr("gen_top_module")
    assert topname == node.name
    wrapper_file = tmp_path / (topname + ".v")
    core_file = tmp_path / "select_token.sv"
    assert wrapper_file.is_file()
    assert core_file.is_file()

    wrapper_text = wrapper_file.read_text()
    assert "parameter FOLD_WIDTH = %d" % fold_width in wrapper_text
    assert ".TOKEN_BEATS(2)" in wrapper_text
    assert ".DATA_WIDTH(FOLD_WIDTH)" in wrapper_text
    assert ".TOKEN_INDEX(3)" in wrapper_text
    assert "select_token #(" in wrapper_text
    assert "out0_V_TVALID" in wrapper_text

    ipi_cmds = op.code_generation_ipi()
    assert any("select_token.sv" in cmd for cmd in ipi_cmds)
    assert any("create_bd_cell" in cmd and topname in cmd for cmd in ipi_cmds)


@pytest.mark.fpgadataflow
def test_selecttoken_resource_estimation():
    model = _make_selecttoken_model(token_index=1, simd=2)
    model = model.transform(SpecializeLayers(FPGA_PART))
    model = model.transform(GiveUniqueNodeNames())

    expected = {
        "BRAM_18K": 0,
        "BRAM_efficiency": 1,
        "LUT": 200,
        "URAM": 0,
        "URAM_efficiency": 1,
        "DSP": 0,
    }

    estimates = model.analysis(partial(res_estimation, fpgapart=FPGA_PART))
    assert len(estimates) == 1
    assert list(estimates.values())[0] == expected

    complete_estimates = model.analysis(partial(res_estimation_complete, fpgapart=FPGA_PART))
    assert len(complete_estimates) == 1
    assert list(complete_estimates.values())[0] == [expected]


@pytest.mark.fpgadataflow
@pytest.mark.vivado
@pytest.mark.slow
@pytest.mark.parametrize("simd,token_index", [(1, 0), (2, 3)])
def test_selecttoken_rtlsim(simd, token_index):
    tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4)
    expected = tokens[:, token_index, :]

    model = _make_selecttoken_model(token_index=token_index, simd=simd)
    for step in (
        SpecializeLayers(FPGA_PART),
        GiveUniqueNodeNames(),
        PrepareIP(FPGA_PART, CLK_NS),
        SetExecMode("rtlsim"),
        PrepareRTLSim(),
    ):
        model = model.transform(step)

    produced = execute_onnx(model, _make_input_dict(model, tokens))
    assert (produced["out"] == expected).all()

    # The simulated cycle count must stay close to the analytical estimate.
    node = model.get_nodes_by_op_type("SelectToken_rtl")[0]
    cycles_rtlsim = getCustomOp(node).get_nodeattr("cycles_rtlsim")
    exp_cycles = model.analysis(exp_cycles_per_layer)[node.name]
    assert np.isclose(exp_cycles, cycles_rtlsim, atol=10)
    assert exp_cycles != 0


@pytest.mark.fpgadataflow
@pytest.mark.vivado
@pytest.mark.slow
@pytest.mark.parametrize("simd,token_index", [(1, 0), (2, 3)])
def test_selecttoken_stitched_ip_rtlsim(simd, token_index):
    model = _prepare_selecttoken_stitched_ip_model(simd=simd, token_index=token_index)
    tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4)
    expected = tokens[:, token_index, :]

    model.set_metadata_prop("exec_mode", "rtlsim")

    produced = execute_onnx(model, _make_input_dict(model, tokens))
    assert (produced["out"] == expected).all()


@pytest.mark.fpgadataflow
@pytest.mark.vivado
@pytest.mark.slow
def test_selecttoken_stitched_ip_synth_ooc():
    model = _prepare_selecttoken_stitched_ip_model(simd=2, token_index=1, run_pnr=True)
    synth = parse_ooc_synth_results(model.get_metadata_prop("vivado_stitch_proj"))
    assert synth is not None

    assert synth["LUT"] > 0
    assert synth["FF"] > 0
    assert synth.get("DSP", 0) == 0
    assert synth.get("BRAM_18K", 0) == 0
    assert synth.get("BRAM_36K", 0) == 0
    assert synth["WNS"] >= 0