Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/finn/source_code/finn.custom_op.fpgadataflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ finn.custom\_op.fpgadataflow.labelselect
:undoc-members:
:show-inheritance:

finn.custom\_op.fpgadataflow.selecttoken
-----------------------------------------

.. automodule:: finn.custom_op.fpgadataflow.selecttoken
:members:
:undoc-members:
:show-inheritance:

finn.custom\_op.fpgadataflow.lookup
-----------------------------------------------

Expand Down
8 changes: 8 additions & 0 deletions docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ finn.custom\_op.fpgadataflow.streamingdatawidthconverter\_rtl
:undoc-members:
:show-inheritance:

finn.custom\_op.fpgadataflow.selecttoken\_rtl
---------------------------------------------------------------

.. automodule:: finn.custom_op.fpgadataflow.rtl.selecttoken_rtl
:members:
:undoc-members:
:show-inheritance:

finn.custom\_op.fpgadataflow.streamingfifo\_rtl
-------------------------------------------------

Expand Down
138 changes: 138 additions & 0 deletions finn-rtllib/selecttoken/hdl/select_token.sv
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/****************************************************************************
* Copyright Advanced Micro Devices, Inc.
* SPDX-License-Identifier: BSD-3-Clause
*
* @author Oliver Cassidy <oliver.cassidy@amd.com>
* @brief Select one token from a folded token stream.
*
* @description
* Consumes NUM_TOKENS token vectors, each consisting of TOKEN_BEATS stream
* beats. Beats belonging to TOKEN_INDEX are forwarded to the output; all
* other beats are consumed and discarded.
***************************************************************************/

`default_nettype none

module select_token #(
int unsigned NUM_TOKENS,
int unsigned TOKEN_BEATS,
int unsigned DATA_WIDTH,
int unsigned TOKEN_INDEX
)(
// Global Control
input wire logic clk,
input wire logic rst,

// Input Stream
output logic irdy,
input wire logic ivld,
input wire logic [DATA_WIDTH-1:0] idat,

// Output Stream - beats belonging to TOKEN_INDEX
input wire logic ordy,
output logic ovld,
output logic [DATA_WIDTH-1:0] odat
);

localparam int unsigned TOKEN_CNT_BITS = (NUM_TOKENS <= 1)? 1 : $clog2(NUM_TOKENS);
localparam int unsigned BEAT_CNT_BITS = (TOKEN_BEATS <= 1)? 1 : $clog2(TOKEN_BEATS);
typedef logic [TOKEN_CNT_BITS-1:0] token_cnt_t;
typedef logic [ BEAT_CNT_BITS-1:0] beat_cnt_t;
typedef logic [DATA_WIDTH-1:0] data_t;
localparam token_cnt_t TOKEN_INDEX_PRE = (TOKEN_INDEX == 0)? NUM_TOKENS-1 : TOKEN_INDEX-1;
localparam token_cnt_t TOKEN_PRE_LAST = (NUM_TOKENS == 1)? 0 : NUM_TOKENS-2;
localparam beat_cnt_t BEAT_PRE_LAST = (TOKEN_BEATS == 1)? 0 : TOKEN_BEATS-2;

initial begin
if(NUM_TOKENS < 1) begin
$error("%m: NUM_TOKENS must be positive.");
$finish;
end
if(TOKEN_BEATS < 1) begin
$error("%m: TOKEN_BEATS must be positive.");
$finish;
end
if(DATA_WIDTH < 1) begin
$error("%m: DATA_WIDTH must be positive.");
$finish;
end
if(TOKEN_INDEX >= NUM_TOKENS) begin
$error("%m: TOKEN_INDEX must be less than NUM_TOKENS.");
$finish;
end
end

// Beat and Token Position
token_cnt_t TokenCnt = '0; // 0, ..., NUM_TOKENS-1
beat_cnt_t BeatCnt = '0; // 0, ..., TOKEN_BEATS-1
logic Selected = TOKEN_INDEX == 0; // TokenCnt == TOKEN_INDEX
logic BeatLst = TOKEN_BEATS == 1; // BeatCnt == TOKEN_BEATS-1
logic TokenLst = NUM_TOKENS == 1; // TokenCnt == NUM_TOKENS-1

// Selected-Token Forwarding
data_t ADat = 'x;
logic AVld = 0;
data_t BDat = 'x;
logic BVld = 0;

assign irdy = !Selected || !AVld;
assign odat = BDat;
assign ovld = BVld;

uwire take = irdy && ivld;
uwire selected_take = Selected && take;
uwire bload = !BVld || ordy;

always_ff @(posedge clk) begin
if(rst) begin
TokenCnt <= '0;
BeatCnt <= '0;
Selected <= TOKEN_INDEX == 0;
BeatLst <= TOKEN_BEATS == 1;
TokenLst <= NUM_TOKENS == 1;
end
else if(take) begin
if(BeatLst) begin
BeatCnt <= '0;
BeatLst <= TOKEN_BEATS == 1;
Selected <= TokenCnt == TOKEN_INDEX_PRE;
if(TokenLst) begin
TokenCnt <= '0;
TokenLst <= NUM_TOKENS == 1;
end
else begin
TokenCnt <= TokenCnt + 1;
TokenLst <= TokenCnt == TOKEN_PRE_LAST;
end
end
else begin
BeatCnt <= BeatCnt + 1;
BeatLst <= BeatCnt == BEAT_PRE_LAST;
end
end
end

always_ff @(posedge clk) begin
if(rst) begin
ADat <= 'x;
AVld <= 0;
BDat <= 'x;
BVld <= 0;
end
else begin
if(bload) begin
BDat <= AVld? ADat : idat;
BVld <= AVld || selected_take;
end

if(bload) AVld <= 0;
else if(selected_take) begin
ADat <= idat;
AVld <= 1;
end
end
end

endmodule : select_token

`default_nettype wire
77 changes: 77 additions & 0 deletions finn-rtllib/selecttoken/hdl/select_token_template.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/******************************************************************************

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use the new lic header

Copyright Advanced Micro Devices, Inc.
SPDX-License-Identifier: BSD-3-Clause

* Copyright (C) 2026, Advanced Micro Devices, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
* THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
* OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
* ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/

module $TOP_MODULE_NAME$ #(
parameter FOLD_WIDTH = $FOLD_WIDTH$,
parameter AXI_WIDTH = ((FOLD_WIDTH + 7) / 8) * 8
)(
(* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *)
(* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:out0_V, ASSOCIATED_RESET ap_rst_n" *)
input ap_clk,
(* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *)
input ap_rst_n,

output in0_V_TREADY,
input in0_V_TVALID,
input [AXI_WIDTH-1:0] in0_V_TDATA,

input out0_V_TREADY,
output out0_V_TVALID,
output [AXI_WIDTH-1:0] out0_V_TDATA
);

wire [FOLD_WIDTH-1:0] core_out;

assign out0_V_TDATA[FOLD_WIDTH-1:0] = core_out;

generate
if (AXI_WIDTH > FOLD_WIDTH) begin : gen_pad_tdata
assign out0_V_TDATA[AXI_WIDTH-1:FOLD_WIDTH] = {(AXI_WIDTH-FOLD_WIDTH){1'b0}};
end
endgenerate

select_token #(
.NUM_TOKENS($NUM_TOKENS$),
.TOKEN_BEATS($TOKEN_BEATS$),
.DATA_WIDTH(FOLD_WIDTH),
.TOKEN_INDEX($TOKEN_INDEX$)
) impl (
.clk(ap_clk),
.rst(!ap_rst_n),
.irdy(in0_V_TREADY),
.ivld(in0_V_TVALID),
.idat(in0_V_TDATA[FOLD_WIDTH-1:0]),
.ordy(out0_V_TREADY),
.ovld(out0_V_TVALID),
.odat(core_out)
);

endmodule
1 change: 1 addition & 0 deletions src/finn/builder/build_dataflow_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ def apply_if_relevant(model, op_types, transform, desc=""):
)

# Lookup layers
model = apply_if_relevant(model, ["Gather"], to_hw.InferSelectTokenLayer(), "token selection")
model = apply_if_relevant(model, ["Gather"], to_hw.InferLookupLayer(), "lookup layers")

# Activation functions
Expand Down
2 changes: 2 additions & 0 deletions src/finn/custom_op/fpgadataflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def register_custom_op(cls):
from finn.custom_op.fpgadataflow.outer_shuffle import OuterShuffle
from finn.custom_op.fpgadataflow.pool import Pool
from finn.custom_op.fpgadataflow.requant import Requant
from finn.custom_op.fpgadataflow.selecttoken import SelectToken
from finn.custom_op.fpgadataflow.shuffle import Shuffle
from finn.custom_op.fpgadataflow.split import StreamingSplit
from finn.custom_op.fpgadataflow.streamingdataflowpartition import (
Expand Down Expand Up @@ -103,6 +104,7 @@ def register_custom_op(cls):
custom_op["Lookup"] = Lookup
custom_op["OuterShuffle"] = OuterShuffle
custom_op["Pool"] = Pool
custom_op["SelectToken"] = SelectToken
custom_op["Shuffle"] = Shuffle
custom_op["StreamingConcat"] = StreamingConcat
custom_op["StreamingSplit"] = StreamingSplit
Expand Down
2 changes: 2 additions & 0 deletions src/finn/custom_op/fpgadataflow/rtl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from finn.custom_op.fpgadataflow.rtl.layernorm_rtl import LayerNorm_rtl
from finn.custom_op.fpgadataflow.rtl.matrixvectoractivation_rtl import MVAU_rtl
from finn.custom_op.fpgadataflow.rtl.requant_rtl import Requant_rtl
from finn.custom_op.fpgadataflow.rtl.selecttoken_rtl import SelectToken_rtl
from finn.custom_op.fpgadataflow.rtl.streamingdatawidthconverter_rtl import (
StreamingDataWidthConverter_rtl,
)
Expand All @@ -60,6 +61,7 @@
custom_op["StreamingDataWidthConverter_rtl"] = StreamingDataWidthConverter_rtl
custom_op["StreamingFIFO_rtl"] = StreamingFIFO_rtl
custom_op["MVAU_rtl"] = MVAU_rtl
custom_op["SelectToken_rtl"] = SelectToken_rtl
custom_op["VVAU_rtl"] = VVAU_rtl
custom_op["Thresholding_rtl"] = Thresholding_rtl
custom_op["InnerShuffle_rtl"] = InnerShuffle_rtl
Expand Down
Loading
Loading