diff --git a/tools/library/src/blockwise_gemm_operation_3x.hpp b/tools/library/src/blockwise_gemm_operation_3x.hpp
index cdc9ca03be..db6c61ec2d 100644
--- a/tools/library/src/blockwise_gemm_operation_3x.hpp
+++ b/tools/library/src/blockwise_gemm_operation_3x.hpp
@@ -243,26 +243,37 @@ class BlockwiseGemmUniversal3xOperation : public GemmOperation3xBase
       operator_args.mainloop.ptr_A = static_cast(arguments->A);
       operator_args.mainloop.ptr_B = static_cast(arguments->B);
 
-      std::unordered_map mapping = {
-          {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3},
-          {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2},
-          {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2},
-          {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1}
-      };
-
-      auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a);
-      auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b);
-
-      if (iter_runtime_a != mapping.end()) {
-        operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second;
-      } else {
-        assert("invalid runtime argument for datatype A!");
+      // Converts a RuntimeDatatype tag into the matching cute::UMMA::MXF8F6F4Format,
+      // written through `out`. Any other tag trips the assert and, when asserts are
+      // compiled out, reports Status::kErrorInvalidProblem without touching `out`.
+      auto runtime_datatype_to_mxf8f6f4 =
+        [](RuntimeDatatype dtype, cute::UMMA::MXF8F6F4Format& out) -> Status {
+          switch (dtype) {
+            case RuntimeDatatype::kE4M3: out = cute::UMMA::MXF8F6F4Format::E4M3; break;
+            case RuntimeDatatype::kE5M2: out = cute::UMMA::MXF8F6F4Format::E5M2; break;
+            case RuntimeDatatype::kE3M2: out = cute::UMMA::MXF8F6F4Format::E3M2; break;
+            case RuntimeDatatype::kE2M1: out = cute::UMMA::MXF8F6F4Format::E2M1; break;
+            default:
+              assert(false && "invalid runtime argument for datatype!");
+              return Status::kErrorInvalidProblem;
+          }
+          return Status::kSuccess;
+        };
+
+      // Operand A: stop and hand back the failing status if the tag is not convertible.
+      status = runtime_datatype_to_mxf8f6f4(
+        arguments->runtime_input_datatype_a,
+        operator_args.mainloop.runtime_data_type_a);
+      if (status != Status::kSuccess) {
+        return status;
       }
 
-      if (iter_runtime_b != mapping.end()) {
-        operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second;
-      } else {
-        assert("invalid runtime argument for datatype B!");
+      // Operand B: same conversion and early return as for operand A.
+      status = runtime_datatype_to_mxf8f6f4(
+        arguments->runtime_input_datatype_b,
+        operator_args.mainloop.runtime_data_type_b);
+      if (status != Status::kSuccess) {
+        return status;
       }
     }
 
diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp
index ff03975675..cae9d25fc4 100644
--- a/tools/library/src/gemm_operation_3x.hpp
+++ b/tools/library/src/gemm_operation_3x.hpp
@@ -47,7 +47,6 @@
 #include "cutlass/util/reference/device/tensor_fill.h"
 #include "cutlass/util/reference/device/tensor_compare.h"
 #include "cute/tensor.hpp"
-#include
 
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -374,26 +373,37 @@ class GemmUniversal3xOperation : public GemmOperation3xBase {
       operator_args.mainloop.ptr_A = static_cast(arguments->A);
       operator_args.mainloop.ptr_B = static_cast(arguments->B);
 
-      std::unordered_map mapping = {
-          {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3},
-          {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2},
-          {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2},
-          {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1}
-      };
-
-      auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a);
-      auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b);
+      // Converts a RuntimeDatatype tag into the matching cute::UMMA::MXF8F6F4Format,
+      // written through `out`. Any other tag trips the assert and, when asserts are
+      // compiled out, reports Status::kErrorInvalidProblem without touching `out`.
+      auto runtime_datatype_to_mxf8f6f4 =
+        [](RuntimeDatatype dtype, cute::UMMA::MXF8F6F4Format& out) -> Status {
+          switch (dtype) {
+            case RuntimeDatatype::kE4M3: out = cute::UMMA::MXF8F6F4Format::E4M3; break;
+            case RuntimeDatatype::kE5M2: out = cute::UMMA::MXF8F6F4Format::E5M2; break;
+            case RuntimeDatatype::kE3M2: out = cute::UMMA::MXF8F6F4Format::E3M2; break;
+            case RuntimeDatatype::kE2M1: out = cute::UMMA::MXF8F6F4Format::E2M1; break;
+            default:
+              assert(false && "invalid runtime argument for datatype!");
+              return Status::kErrorInvalidProblem;
+          }
+          return Status::kSuccess;
+        };
 
-      if (iter_runtime_a != mapping.end()) {
-        operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second;
-      } else {
-        assert("invalid runtime argument for datatype A!");
+      // Operand A: stop and hand back the failing status if the tag is not convertible.
+      status = runtime_datatype_to_mxf8f6f4(
+        arguments->runtime_input_datatype_a,
+        operator_args.mainloop.runtime_data_type_a);
+      if (status != Status::kSuccess) {
+        return status;
       }
 
-      if (iter_runtime_b != mapping.end()) {
-        operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second;
-      } else {
-        assert("invalid runtime argument for datatype B!");
+      // Operand B: same conversion and early return as for operand A.
+      status = runtime_datatype_to_mxf8f6f4(
+        arguments->runtime_input_datatype_b,
+        operator_args.mainloop.runtime_data_type_b);
+      if (status != Status::kSuccess) {
+        return status;
       }
     }
 
diff --git a/tools/library/src/sparse_gemm_operation_3x.hpp b/tools/library/src/sparse_gemm_operation_3x.hpp
index 04781092f0..dbaa593ae8 100644
--- a/tools/library/src/sparse_gemm_operation_3x.hpp
+++ b/tools/library/src/sparse_gemm_operation_3x.hpp
@@ -51,7 +51,6 @@
 #include "cutlass/util/reference/device/tensor_fill.h"
 #include "cutlass/util/reference/device/tensor_compare.h"
 #include "cute/tensor.hpp"
-#include
 
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -196,26 +195,37 @@ class SparseGemmUniversal3xOperation : public GemmOperation3xBase {
       operator_args.mainloop.ptr_A = static_cast(device_a_compressed_ptr);
       operator_args.mainloop.ptr_B = static_cast(arguments->B);
 
-      std::unordered_map mapping = {
-          {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3},
-          {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2},
-          {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2},
-          {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1}
-      };
-
-      auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a);
-      auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b);
-
-      if (iter_runtime_a != mapping.end()) {
-        operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second;
-      } else {
-        assert("invalid runtime argument for datatype A!");
+      // Converts a RuntimeDatatype tag into the matching cute::UMMA::MXF8F6F4Format,
+      // written through `out`. Any other tag trips the assert and, when asserts are
+      // compiled out, reports Status::kErrorInvalidProblem without touching `out`.
+      auto runtime_datatype_to_mxf8f6f4 =
+        [](RuntimeDatatype dtype, cute::UMMA::MXF8F6F4Format& out) -> Status {
+          switch (dtype) {
+            case RuntimeDatatype::kE4M3: out = cute::UMMA::MXF8F6F4Format::E4M3; break;
+            case RuntimeDatatype::kE5M2: out = cute::UMMA::MXF8F6F4Format::E5M2; break;
+            case RuntimeDatatype::kE3M2: out = cute::UMMA::MXF8F6F4Format::E3M2; break;
+            case RuntimeDatatype::kE2M1: out = cute::UMMA::MXF8F6F4Format::E2M1; break;
+            default:
+              assert(false && "invalid runtime argument for datatype!");
+              return Status::kErrorInvalidProblem;
+          }
+          return Status::kSuccess;
+        };
+
+      // Operand A: stop and hand back the failing status if the tag is not convertible.
+      status = runtime_datatype_to_mxf8f6f4(
+        arguments->runtime_input_datatype_a,
+        operator_args.mainloop.runtime_data_type_a);
+      if (status != Status::kSuccess) {
+        return status;
       }
 
-      if (iter_runtime_b != mapping.end()) {
-        operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second;
-      } else {
-        assert("invalid runtime argument for datatype B!");
+      // Operand B: same conversion and early return as for operand A.
+      status = runtime_datatype_to_mxf8f6f4(
+        arguments->runtime_input_datatype_b,
+        operator_args.mainloop.runtime_data_type_b);
+      if (status != Status::kSuccess) {
+        return status;
       }
     }
 