diff --git a/python/infinity_sdk/infinity/rag_tokenizer.py b/python/infinity_sdk/infinity/rag_tokenizer.py index 916f853c8d..2c51dac06f 100644 --- a/python/infinity_sdk/infinity/rag_tokenizer.py +++ b/python/infinity_sdk/infinity/rag_tokenizer.py @@ -570,11 +570,16 @@ def naive_qie(txt): parser.add_argument('--fine-grained', action='store_true', help='Use fine-grained tokenization') parser.add_argument('--user-dict', help='User dictionary file') + parser.add_argument('-l', '--language', help='Language for stemming (e.g., english, dutch)') args = parser.parse_args() tokenizer = RagTokenizer(debug=True, user_dict=args.user_dict) + # Set language if specified + if args.language: + tokenizer.set_language(args.language) + # Process input if args.file: # File mode diff --git a/src/common/analyzer/rag_analyzer.cppm b/src/common/analyzer/rag_analyzer.cppm index 0b778c8c41..9ec9d22002 100644 --- a/src/common/analyzer/rag_analyzer.cppm +++ b/src/common/analyzer/rag_analyzer.cppm @@ -24,6 +24,7 @@ import :darts_trie; import :stemmer; import :analyzer; import :wordnet_lemmatizer; +import :logger; import third_party; @@ -41,7 +42,9 @@ public: ~RAGAnalyzer(); - void InitStemmer(Language language) { stemmer_->Init(language); } + void InitStemmer(Language language); + + void SetLanguage(const std::string &language); Status Load(); @@ -132,6 +135,8 @@ public: WordNetLemmatizer *wordnet_lemma_{nullptr}; + bool use_lemmatizer_{true}; // WordNet only supports English + std::unique_ptr stemmer_; OpenCC *opencc_{nullptr}; diff --git a/src/common/analyzer/rag_analyzer_impl.cpp b/src/common/analyzer/rag_analyzer_impl.cpp index e85bf36575..24b4f471da 100644 --- a/src/common/analyzer/rag_analyzer_impl.cpp +++ b/src/common/analyzer/rag_analyzer_impl.cpp @@ -46,6 +46,28 @@ static const std::string WORDNET_PATH = "wordnet"; static const std::string OPENCC_PATH = "opencc"; +// Map language names (lowercase) to Stemmer Language enum. +// Used by SetLanguage() to configure language-specific stemming. +static const std::pair SNOWBALL_LANGUAGE_MAP[] = { + {"english", STEM_LANG_ENGLISH}, + {"dutch", STEM_LANG_DUTCH}, + {"german", STEM_LANG_GERMAN}, + {"french", STEM_LANG_FRENCH}, + {"spanish", STEM_LANG_SPANISH}, + {"italian", STEM_LANG_ITALIAN}, + {"portuguese", STEM_LANG_PORTUGUESE}, + {"portuguese br", STEM_LANG_PORTUGUESE}, + {"russian", STEM_LANG_RUSSIAN}, + {"arabic", STEM_LANG_UNKNOWN}, // No Arabic entry in Language enum (stemmer.cppm), so use UNKNOWN to keep default stemming behavior + {"danish", STEM_LANG_DANISH}, + {"finnish", STEM_LANG_FINNISH}, + {"hungarian", STEM_LANG_HUNGARIAN}, + {"norwegian", STEM_LANG_NORWEGIAN}, + {"romanian", STEM_LANG_ROMANIAN}, + {"swedish", STEM_LANG_SWEDISH}, + {"turkish", STEM_LANG_TURKISH}, +}; + static const std::string REGEX_SPLIT_CHAR = R"#(([ ,\.<>/?;'\[\]\`!@#$%^&*$$\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z\.-]+|[0-9,\.-]+))#"; @@ -662,6 +684,40 @@ RAGAnalyzer::~RAGAnalyzer() { } } +void RAGAnalyzer::InitStemmer(Language language) { + stemmer_->Init(language); + use_lemmatizer_ = (language == STEM_LANG_ENGLISH); +} + +void RAGAnalyzer::SetLanguage(const std::string &language) { + std::string lang_key = language; + // Convert to lowercase + std::transform(lang_key.begin(), lang_key.end(), lang_key.begin(), [](unsigned char c) { return std::tolower(c); }); + // Trim whitespace + lang_key.erase(lang_key.find_last_not_of(" \t") + 1); + lang_key.erase(0, lang_key.find_first_not_of(" \t")); + + Language stem_lang = STEM_LANG_UNKNOWN; + std::string snowball_lang; + for (const auto &pair : SNOWBALL_LANGUAGE_MAP) { + if (pair.first == lang_key) { + stem_lang = pair.second; + snowball_lang = pair.first; + break; + } + } + + if (stem_lang != STEM_LANG_UNKNOWN) { + stemmer_->Init(stem_lang); + use_lemmatizer_ = (stem_lang == STEM_LANG_ENGLISH); + LOG_DEBUG(fmt::format("Tokenizer language set to '{}' (Snowball: {}, lemmatizer: {})", language, snowball_lang, use_lemmatizer_)); + } else { + // Unsupported language (Chinese, Japanese, Korean, etc.) – + // keep defaults. CJK text uses dictionary segmentation, not stemming. + LOG_DEBUG(fmt::format("Language '{}' has no Snowball stemmer; keeping defaults", language)); + } +} + Status RAGAnalyzer::Load() { fs::path root(dict_path_); fs::path dict_path(root / DICT_PATH); @@ -1332,9 +1388,14 @@ void RAGAnalyzer::EnglishNormalize(const std::vector &tokens, std:: // Apply lowercase before lemmatization to match Python NLTK behavior char *lowercase_term = lowercase_string_buffer_.data(); ToLower(t.c_str(), t.size(), lowercase_term, term_string_buffer_limit_); - std::string lemma_term = wordnet_lemma_->Lemmatize(lowercase_term); + std::string term_to_stem; + if (use_lemmatizer_) { + term_to_stem = wordnet_lemma_->Lemmatize(lowercase_term); + } else { + term_to_stem = lowercase_term; + } std::string stem_term; - stemmer_->Stem(lemma_term, stem_term); + stemmer_->Stem(term_to_stem, stem_term); res.push_back(stem_term); } else { res.push_back(t); @@ -1694,9 +1755,14 @@ std::string RAGAnalyzer::Tokenize(const std::string &line) { // Apply lowercase before lemmatization to match Python NLTK behavior char *lowercase_term = lowercase_string_buffer_.data(); ToLower(term_list[i].c_str(), term_list[i].size(), lowercase_term, term_string_buffer_limit_); - std::string lemma_term = wordnet_lemma_->Lemmatize(lowercase_term); + std::string term_to_stem; + if (use_lemmatizer_) { + term_to_stem = wordnet_lemma_->Lemmatize(lowercase_term); + } else { + term_to_stem = lowercase_term; + } std::string stem_term; - stemmer_->Stem(lemma_term, stem_term); + stemmer_->Stem(term_to_stem, stem_term); res.push_back(stem_term); } continue; @@ -1811,9 +1877,14 @@ std::pair, std::vector>> // Apply lowercase before lemmatization to match Python NLTK behavior char *lowercase_term = lowercase_string_buffer_.data(); ToLower(term.c_str(), term.size(), lowercase_term, term_string_buffer_limit_); - std::string lemma_term = wordnet_lemma_->Lemmatize(lowercase_term); + std::string term_to_stem; + if (use_lemmatizer_) { + term_to_stem = wordnet_lemma_->Lemmatize(lowercase_term); + } else { + term_to_stem = lowercase_term; + } std::string stem_term; - stemmer_->Stem(lemma_term, stem_term); + stemmer_->Stem(term_to_stem, stem_term); tokens.push_back(stem_term); @@ -2136,9 +2207,14 @@ void RAGAnalyzer::EnglishNormalizeWithPosition(const std::vector &t // Apply lowercase before lemmatization to match Python NLTK behavior char *lowercase_term = lowercase_string_buffer_.data(); ToLower(token.c_str(), token.size(), lowercase_term, term_string_buffer_limit_); - std::string lemma_term = wordnet_lemma_->Lemmatize(lowercase_term); + std::string term_to_stem; + if (use_lemmatizer_) { + term_to_stem = wordnet_lemma_->Lemmatize(lowercase_term); + } else { + term_to_stem = lowercase_term; + } std::string stem_term; - stemmer_->Stem(lemma_term, stem_term); + stemmer_->Stem(term_to_stem, stem_term); normalize_tokens.push_back(stem_term); normalize_positions.emplace_back(start_pos, end_pos); diff --git a/src/unit_test/common/analyzer/rag_analyzer_ut.cpp b/src/unit_test/common/analyzer/rag_analyzer_ut.cpp index 6a0041daec..3df679344c 100644 --- a/src/unit_test/common/analyzer/rag_analyzer_ut.cpp +++ b/src/unit_test/common/analyzer/rag_analyzer_ut.cpp @@ -308,3 +308,33 @@ TEST_F(RAGAnalyzerTest, test_fine_grained_tokenize_consistency_with_python) { } infile.close(); } + +TEST_F(RAGAnalyzerTest, test_set_language_dutch) { + if (!analyzer_) { + FAIL() << "RAGAnalyzer not loaded, skipping test"; + } + // Dutch word "huizen" (houses) should stem to "huiz" with Dutch stemmer + // Compare C++ result with Python result + std::string python_cmd = "uv run " + rag_tokenizer_path_ + "/rag_tokenizer.py " + "-l dutch \"huizen\""; + std::cout << "Call Python tokenizer: " << python_cmd << std::endl; + + FILE *pipe = popen(python_cmd.c_str(), "r"); + std::string python_result; + char buffer[128]; + if (pipe) { + while (fgets(buffer, sizeof(buffer), pipe) != nullptr) { + python_result += buffer; + } + pclose(pipe); + } + // Remove trailing newline + python_result.erase(python_result.find_last_not_of(" \n\r\t") + 1); + std::cout << "Python 'huizen' tokenized (Dutch): " << python_result << std::endl; + + analyzer_->SetLanguage("dutch"); + std::string cxx_result = analyzer_->Tokenize("huizen"); + std::cout << "C++ 'huizen' tokenized (Dutch): " << cxx_result << std::endl; + + EXPECT_TRUE(cxx_result.find("huiz") != std::string::npos); + EXPECT_EQ(cxx_result, python_result); +}