-
Notifications
You must be signed in to change notification settings - Fork 0
feat: evaluate translations #234
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,7 +18,8 @@ | |
| from ml_filter.llm_client import LLMClient | ||
| from ml_filter.sample_from_hf_dataset import sample_from_hf_dataset, upload_file_to_hf | ||
| from ml_filter.training.annotator_model_pipeline import run_annotator_training_pipeline | ||
| from ml_filter.translate import TranslationServiceType, TranslatorFactory | ||
| from ml_filter.translation.translate import TranslationServiceType, TranslatorFactory | ||
| from ml_filter.translation.translation_evaluation import evaluate_translations | ||
| from ml_filter.utils.chunk_data import chunk_jsonl | ||
| from ml_filter.utils.manipulate_datasets import apply_score_transforms, convert_hf_dataset_to_jsonl, split_dataset | ||
| from ml_filter.utils.manipulate_documents import merge_and_sort_jsonl_files | ||
|
|
@@ -757,5 +758,28 @@ def _get_target_language_codes_list_helper(target_language_codes: str) -> list[s | |
| return [lang_code.strip().lower() for lang_code in target_language_codes.split(",")] | ||
|
|
||
|
|
||
| @main.command(name="evaluate_translations") | ||
| @click.option("--data-dir", required=True, help="Directory containing translation JSONL files") | ||
| @click.option("--gold-path", required=True, help="Path to gold reference JSONL file") | ||
| @click.option("--model-name", default="Unbabel/wmt22-cometkiwi-da", help="COMET model to use") | ||
| @click.option("--languages", type=str, required=True, help="Comma-separated list of supported language codes") | ||
| @click.option("--batch-size", help="Batch size for processing translations") | ||
| def evaluate_translations_cli( | ||
| data_dir: str, | ||
| gold_path: str, | ||
| model_name: str, | ||
| languages: str, | ||
| batch_size: int, | ||
| ): | ||
| """CLI entry point for evaluating translation quality.""" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should document that the files in |
||
| evaluate_translations( | ||
| data_dir=data_dir, | ||
| gold_path=gold_path, | ||
| languages=languages.split(","), | ||
| model_name=model_name, | ||
| batch_size=batch_size, | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,102 @@ | ||||||
| import json | ||||||
| import logging | ||||||
| import os | ||||||
|
|
||||||
| import numpy as np | ||||||
| from comet import download_model, load_from_checkpoint | ||||||
|
|
||||||
| logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") | ||||||
|
|
||||||
|
|
||||||
| def _load_gold_dict(gold_path: str) -> dict[str, str]: | ||||||
| """Load reference translations from a JSONL file. | ||||||
|
|
||||||
| Args: | ||||||
| gold_path: Path to the gold reference JSONL file. | ||||||
|
|
||||||
| Returns: | ||||||
| A dictionary mapping document IDs to reference texts. | ||||||
| """ | ||||||
| gold_dict = {} | ||||||
| with open(gold_path, "r") as f: | ||||||
| for line in f: | ||||||
| item = json.loads(line) | ||||||
| gold_dict[item["document_id"]] = item["text"] | ||||||
| return gold_dict | ||||||
|
|
||||||
|
|
||||||
| def _prepare_translation_input(file_path: str, gold_dict: dict[str, str]) -> list[dict[str, str]]: | ||||||
| """Extract source and machine-translated texts from a JSONL file. | ||||||
|
|
||||||
| Args: | ||||||
| file_path: Path to the target JSONL file. | ||||||
| lang: Language code. | ||||||
| gold_dict: Dictionary of gold references. | ||||||
|
|
||||||
| Returns: | ||||||
| A list of dictionaries containing 'src' and 'mt' keys. | ||||||
| """ | ||||||
| target_texts = [] | ||||||
| with open(file_path, "r") as f: | ||||||
| for line_num, line in enumerate(f, 1): | ||||||
| if not line: | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. line is never a boolean, but it treated here like one. It also works with "None" but in my opinion it is bad style.
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we maybe even raise an exception here?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. line will never be None since it is a string. |
||||||
| continue | ||||||
| try: | ||||||
| document = json.loads(line) | ||||||
| doc_id = document["document_id"] | ||||||
| text = document["text"] | ||||||
|
|
||||||
| if doc_id not in gold_dict: | ||||||
| logging.warning(f"doc_id {doc_id} not found in gold references.") | ||||||
| continue | ||||||
|
|
||||||
| target_texts.append({"src": gold_dict[doc_id], "mt": text}) | ||||||
| except json.JSONDecodeError as e: | ||||||
| logging.warning(f"Skipping invalid line {line_num} in {file_path}: {e}") | ||||||
| continue | ||||||
| return target_texts | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should add a warning if len(target_texts) != len(gold_dict) |
||||||
|
|
||||||
|
|
||||||
| def evaluate_translations( | ||||||
| data_dir: str, | ||||||
| gold_path: str, | ||||||
| languages: list[str], | ||||||
| batch_size: int, | ||||||
| model_name: str = "Unbabel/wmt22-cometkiwi-da", | ||||||
| ) -> None: | ||||||
| """Evaluate translation quality for a set of files using a COMET model. | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe explain how we actually want to evaluate it? Like 2-3 sentences about pitching the idea. |
||||||
|
|
||||||
| Args: | ||||||
| data_dir: Directory containing translation JSONL files. | ||||||
| gold_path: Path to gold reference JSONL file. | ||||||
| languages: List of supported language codes. | ||||||
| model_name: COMET model to use. | ||||||
| """ | ||||||
| model_path = download_model(model_name) | ||||||
| model = load_from_checkpoint(model_path) | ||||||
|
|
||||||
| gold_dict = _load_gold_dict(gold_path) | ||||||
| quality_dict = {} | ||||||
|
|
||||||
| for filename in os.listdir(data_dir): | ||||||
| if filename.endswith(".jsonl"): | ||||||
| file_path = os.path.join(data_dir, filename) | ||||||
| lang = filename.split("_")[5] | ||||||
|
||||||
|
|
||||||
| if lang not in languages: | ||||||
| logging.info(f"Skipping file with unsupported language: {file_path}") | ||||||
| continue | ||||||
|
|
||||||
| target_texts = _prepare_translation_input(file_path, gold_dict) | ||||||
|
|
||||||
| if target_texts: | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| # TODO: ;ultiple GPUs handling | ||||||
|
mali-git marked this conversation as resolved.
Outdated
|
||||||
| model_output = model.predict(target_texts, batch_size=batch_size, gpus=1, accelerator="gpu") | ||||||
| quality_dict[lang] = model_output.scores | ||||||
| logging.info(f"Processed {len(target_texts)} documents for language '{lang}' in file {file_path}") | ||||||
| else: | ||||||
| logging.info(f"No valid documents for language '{lang}' in file {file_path}") | ||||||
|
|
||||||
| logging.info("Translation quality scores:") | ||||||
| for lang, scores in quality_dict.items(): | ||||||
| logging.info(f"Mean score for {lang}: {np.mean(scores):.4f}") | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CLI option
--batch-sizehas no type specified; it will be parsed as a string. Consider addingtype=intto ensurebatch_sizeis passed as an integer.