Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions include/cute/stride.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,20 @@ namespace detail {
template <class Major>
struct CompactLambda;

// Workaround for MSVC's inability to deduce a non-type parameter pack from a dependent template alias, causing error C3545.
template <class Major, class Shape>
struct CompactSeq;

template <class Shape>
struct CompactSeq<LayoutLeft, Shape> {
using type = tuple_seq<Shape>;
};

template <class Shape>
struct CompactSeq<LayoutRight, Shape> {
using type = tuple_rseq<Shape>;
};

// @pre is_integral<Current>
// Return (result, current * product(shape)) to enable recurrence
template <class Major, class Shape, class Current>
Expand All @@ -296,7 +310,7 @@ compact(Shape const& shape,
{
if constexpr (is_tuple<Shape>::value) { // Shape::tuple Current::int
using Lambda = CompactLambda<Major>; // Append or Prepend
using Seq = typename Lambda::template seq<Shape>; // Seq or RSeq
using Seq = typename CompactSeq<Major, Shape>::type; // Seq or RSeq
return cute::detail::fold(shape, cute::make_tuple(cute::make_tuple(), current), Lambda{}, Seq{});
} else { // Shape::int Current::int
if constexpr (is_constant<1, Shape>::value) {
Expand All @@ -319,9 +333,6 @@ struct CompactLambda<LayoutLeft>
auto result = detail::compact<LayoutLeft>(si, get<1>(init));
return cute::make_tuple(append(get<0>(init), get<0>(result)), get<1>(result)); // Append
}

template <class Shape>
using seq = tuple_seq<Shape>; // Seq
};

// For GCC8.5 -- Specialization LayoutRight
Expand All @@ -334,9 +345,6 @@ struct CompactLambda<LayoutRight>
auto result = detail::compact<LayoutRight>(si, get<1>(init));
return cute::make_tuple(prepend(get<0>(init), get<0>(result)), get<1>(result)); // Prepend
}

template <class Shape>
using seq = tuple_rseq<Shape>; // RSeq
};

} // end namespace detail
Expand Down