Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions include/cute/arch/mma_sm100_umma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1207,8 +1207,10 @@ struct SM100_MMA_S8_2x1SM_SS_SPARSE
}
};

struct SM100_MMA_F8F6F4_SS
{
template <class a_type, class b_type, class c_type, int M, int N,
UMMA::Major a_major, UMMA::Major b_major, UMMA::ScaleIn a_neg,
UMMA::ScaleIn b_neg>
struct SM100_MMA_F8F6F4_SS {
using DRegisters = void;
using ARegisters = uint64_t[1];
using BRegisters = uint64_t[1];
Expand Down
20 changes: 8 additions & 12 deletions include/cute/atom/mma_traits_sm100.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3324,16 +3324,11 @@ struct MMA_Traits<SM100_MMA_S8_2x1SM_SS_SPARSE<a_type, b_type, c_type,
}
};

template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
cute::C<M>, cute::C<N>,
cute::integral_constant<UMMA::Major, a_major>,
cute::integral_constant<UMMA::Major, b_major>,
cute::integral_constant<UMMA::ScaleIn, a_neg>,
cute::integral_constant<UMMA::ScaleIn, b_neg>>
{
template <class a_type, class b_type, class c_type, int M, int N,
UMMA::Major a_major, UMMA::Major b_major, UMMA::ScaleIn a_neg,
UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F8F6F4_SS<a_type, b_type, c_type, M, N, a_major,
b_major, a_neg, b_neg>> {
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
Expand Down Expand Up @@ -3390,11 +3385,12 @@ struct MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);

SM100_MMA_F8F6F4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
SM100_MMA_F8F6F4_SS<a_type, b_type, c_type, M, N, a_major, b_major, a_neg,
b_neg>::fma(desc_a, desc_b, tmem_c,
uint32_t(traits.accumulate_), idesc);
}
};


template <class a_type, class b_type, class c_type, class sf_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
Expand Down