diff --git a/.gitignore b/.gitignore index e7a026874b..7a77323dac 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,11 @@ __pycache__/ cutlass_library.egg-info/ /build* + +# CuTe DSL editable-install payload copied from nvidia-cutlass-dsl-libs-*. +python/CuTeDSL/VERSION.EDITABLE +python/CuTeDSL/cutlass/_mlir/ +python/CuTeDSL/cutlass/base_dsl/py.typed +python/CuTeDSL/cutlass/cute/py.typed +python/CuTeDSL/lib/ +python/CuTeDSL/nvidia_cutlass_dsl.egg-info/ diff --git a/python/CuTeDSL/cutlass/base_dsl/compiler.py b/python/CuTeDSL/cutlass/base_dsl/compiler.py index 2edacfb341..8879b5873a 100644 --- a/python/CuTeDSL/cutlass/base_dsl/compiler.py +++ b/python/CuTeDSL/cutlass/base_dsl/compiler.py @@ -18,16 +18,13 @@ from typing import Any import collections.abc import os -import sys import inspect +import importlib import types from .common import DSLRuntimeError from .utils.logger import log from .env_manager import EnvironmentVarManager -_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(_SCRIPT_PATH) - from .._mlir import ir @@ -477,11 +474,11 @@ def enable_tvm_ffi(self) -> bool: ret = self.options[EnableTVMFFI].value if ret: try: - import tvm_ffi - except ModuleNotFoundError: + _ = importlib.import_module("tvm_ffi") + except Exception as e: raise DSLRuntimeError( "TVM FFI is not installed, please install it via `pip install apache-tvm-ffi`" - ) + ) from e return ret def to_str(self) -> str: @@ -521,7 +518,6 @@ def _get_compile_option_from_str(option_str: str) -> type[CompileOption]: return mapping[option_str] import argparse - import shlex parser = argparse.ArgumentParser() parser.add_argument("--opt-level", nargs="?", type=int, default=3) diff --git a/python/CuTeDSL/cutlass/cute/algorithm.py b/python/CuTeDSL/cutlass/cute/algorithm.py index a6a99ea0c8..f5bd71f06e 100644 --- a/python/CuTeDSL/cutlass/cute/algorithm.py +++ b/python/CuTeDSL/cutlass/cute/algorithm.py @@ -42,7 +42,6 @@ CopyAtom, make_atom, _normalize_variadic_tensor_operand, - copy_atom_call, ) from .nvgpu.common import ( CacheEvictionPriority, diff --git a/python/CuTeDSL/prep_editable_install.py b/python/CuTeDSL/prep_editable_install.py index d3064f7bbc..1d7bb0a20a 100644 --- a/python/CuTeDSL/prep_editable_install.py +++ b/python/CuTeDSL/prep_editable_install.py @@ -23,7 +23,7 @@ import zipfile import re from pathlib import Path -from typing import Optional, Tuple, List +from typing import Optional, Tuple import logging # Configure logging @@ -32,6 +32,20 @@ # Constants PACKAGE_NAME = "nvidia-cutlass-dsl" +RUNTIME_PACKAGE_PREFIX = "nvidia-cutlass-dsl-libs-" +RUNTIME_PROVIDER_ENV = "CUTLASS_DSL_RUNTIME_PROVIDER" +DEFAULT_RUNTIME_PROVIDER = "base" +GENERATED_RUNTIME_DIRS = ( + Path("cutlass") / "_mlir", + Path("lib"), +) +GENERATED_RUNTIME_FILE_GLOBS = ( + str(Path("cutlass") / "**" / "py.typed"), +) +REQUIRED_RUNTIME_ARTIFACTS = ( + Path("cutlass") / "_mlir", + Path("lib"), +) class CutlassDSLSetupError(Exception): @@ -61,11 +75,12 @@ def get_package_spec(requirements_path: Optional[Path] = None) -> str: return PACKAGE_NAME -def download_wheel(temp_dir: Path) -> Path: +def download_requirement(requirement_spec: str, temp_dir: Path) -> Path: """ - Download the nvidia-cutlass-dsl wheel to a temporary directory. + Download one wheel to a temporary directory without dependencies. Args: + requirement_spec: pip requirement spec to download temp_dir: Temporary directory path for downloading Returns: @@ -74,10 +89,8 @@ def download_wheel(temp_dir: Path) -> Path: Raises: CutlassDSLSetupError: If download fails or wheel not found """ - # Resolve package spec from requirements, or fall back to PACKAGE_NAME - package_spec = get_package_spec() - - logger.info(f"Downloading {package_spec} wheel to {temp_dir}") + logger.info(f"Downloading {requirement_spec} wheel to {temp_dir}") + before = set(temp_dir.glob("*.whl")) try: subprocess.check_call( @@ -87,7 +100,7 @@ def download_wheel(temp_dir: Path) -> Path: "pip", "download", "--no-deps", - package_spec, + requirement_spec, "--dest", str(temp_dir), ], @@ -95,19 +108,25 @@ def download_wheel(temp_dir: Path) -> Path: stderr=subprocess.PIPE, ) except subprocess.CalledProcessError as e: - error_msg = f"Failed to download {PACKAGE_NAME}: {e}" + error_msg = f"Failed to download {requirement_spec}: {e}" if e.stdout: error_msg += f"\nstdout: {e.stdout.decode()}" if e.stderr: error_msg += f"\nstderr: {e.stderr.decode()}" raise CutlassDSLSetupError(error_msg) - # Find the downloaded wheel file - wheel_pattern = f"*.whl" - wheel_files = list(temp_dir.glob(wheel_pattern)) + # Find the newly downloaded wheel file. + wheel_pattern = "*.whl" + wheel_files = [path for path in temp_dir.glob(wheel_pattern) if path not in before] if not wheel_files: raise CutlassDSLSetupError( - f"No wheel file matching {wheel_pattern} found after download" + f"No wheel file matching {wheel_pattern} found after downloading " + f"{requirement_spec}" + ) + if len(wheel_files) != 1: + raise CutlassDSLSetupError( + f"Expected one wheel for {requirement_spec}, found " + f"{[path.name for path in wheel_files]}" ) wheel_path = wheel_files[0] @@ -115,6 +134,24 @@ def download_wheel(temp_dir: Path) -> Path: return wheel_path +def download_wheel(temp_dir: Path) -> Path: + """ + Download the nvidia-cutlass-dsl wheel to a temporary directory. + + Args: + temp_dir: Temporary directory path for downloading + + Returns: + Path to the downloaded wheel file + + Raises: + CutlassDSLSetupError: If download fails or wheel not found + """ + # Resolve package spec from requirements, or fall back to PACKAGE_NAME + package_spec = get_package_spec() + return download_requirement(package_spec, temp_dir) + + def extract_version_from_wheel(wheel_path: Path) -> str: """ Extract version from wheel filename and convert to dev version. @@ -159,6 +196,116 @@ def extract_version_from_wheel(wheel_path: Path) -> str: return "9.9.9.dev0" +def read_wheel_metadata(wheel_path: Path) -> str: + """Read the METADATA payload from a wheel.""" + try: + with zipfile.ZipFile(wheel_path, "r") as wheel_zip: + metadata_files = [ + name + for name in wheel_zip.namelist() + if name.endswith(".dist-info/METADATA") + ] + if len(metadata_files) != 1: + raise CutlassDSLSetupError( + f"Expected one METADATA file in {wheel_path.name}, found " + f"{metadata_files}" + ) + return wheel_zip.read(metadata_files[0]).decode("utf-8") + except zipfile.BadZipFile as e: + raise CutlassDSLSetupError(f"Invalid wheel file {wheel_path}: {e}") + + +def _canonical_package_name(name: str) -> str: + return name.lower().replace("_", "-") + + +def _extract_runtime_requires(metadata: str) -> dict[str, str]: + """Return exact runtime companion requirements keyed by provider name.""" + runtime_requires: dict[str, str] = {} + for raw_line in metadata.splitlines(): + if not raw_line.startswith("Requires-Dist:"): + continue + requirement = raw_line.split(":", 1)[1].strip() + requirement_name = re.split(r"[ ;(<>=!~]", requirement, maxsplit=1)[0] + requirement_name = _canonical_package_name(requirement_name) + if not requirement_name.startswith(RUNTIME_PACKAGE_PREFIX): + continue + provider = requirement_name.removeprefix(RUNTIME_PACKAGE_PREFIX) + version_match = re.search(r"==\s*([A-Za-z0-9_.!+\-]+)", requirement) + if version_match is None: + raise CutlassDSLSetupError( + "Runtime companion dependency must be exact-pinned, got " + f"{requirement!r}" + ) + runtime_requires[provider] = ( + f"{RUNTIME_PACKAGE_PREFIX}{provider}=={version_match.group(1)}" + ) + return runtime_requires + + +def _runtime_provider_from_package_spec(package_spec: str) -> str | None: + extras_match = re.search(r"\[([^\]]+)\]", package_spec) + if extras_match is None: + return None + extras = { + item.strip().lower().replace("_", "-") + for item in extras_match.group(1).split(",") + if item.strip() + } + providers = [extra for extra in extras if extra] + if len(providers) > 1: + raise CutlassDSLSetupError( + "Editable install runtime provider must be unique when selected " + f"through extras, got {providers}" + ) + return providers[0] if providers else None + + +def select_runtime_provider(runtime_requires: dict[str, str]) -> str: + """Select a runtime companion provider without consulting site-packages.""" + requested_provider = None + try: + import os + + requested_provider = os.environ.get(RUNTIME_PROVIDER_ENV) + except Exception: + requested_provider = None + + if requested_provider: + provider = requested_provider.strip().lower().replace("_", "-") + else: + provider = _runtime_provider_from_package_spec(get_package_spec()) + if provider is None: + provider = DEFAULT_RUNTIME_PROVIDER + + if provider not in runtime_requires: + raise CutlassDSLSetupError( + f"Requested runtime provider {provider!r} is not available. " + f"Available providers: {sorted(runtime_requires)}. Set " + f"{RUNTIME_PROVIDER_ENV}= to select one explicitly." + ) + logger.info(f"Selected runtime provider: {provider}") + return provider + + +def download_runtime_payload_wheel(dsl_wheel_path: Path, temp_dir: Path) -> Path | None: + """ + Download the exact runtime companion wheel required by nvidia-cutlass-dsl. + + The generated Python payload and shared libraries must come from a companion + wheel whose version and provider match the downloaded DSL wheel metadata. + Copying from ambient site-packages is intentionally avoided because + nvidia-cutlass-dsl-libs-* wheels can install overlapping payload paths. + """ + metadata = read_wheel_metadata(dsl_wheel_path) + runtime_requires = _extract_runtime_requires(metadata) + if not runtime_requires: + logger.info("No nvidia-cutlass-dsl runtime companion dependency found") + return None + provider = select_runtime_provider(runtime_requires) + return download_requirement(runtime_requires[provider], temp_dir) + + def extract_wheel_contents(wheel_path: Path, extract_dir: Path) -> None: """ Extract wheel contents to specified directory. @@ -182,6 +329,47 @@ def extract_wheel_contents(wheel_path: Path, extract_dir: Path) -> None: raise CutlassDSLSetupError(f"Failed to extract wheel: {e}") +def clean_generated_runtime_payload(package_root: Path) -> None: + """ + Remove generated runtime payload copied by previous editable setup runs. + + The _mlir Python package, runtime shared library directory, and py.typed + markers come from the downloaded runtime wheel. They must be replaced as a + unit to avoid stale Python/.so skew across runtime package versions. + """ + for rel_path in GENERATED_RUNTIME_DIRS: + path = package_root / rel_path + if path.exists(): + logger.info(f"Removing generated runtime directory {path}") + shutil.rmtree(path) + + for pattern in GENERATED_RUNTIME_FILE_GLOBS: + for path in package_root.glob(pattern): + if path.is_file(): + logger.info(f"Removing generated runtime file {path}") + path.unlink() + + +def validate_runtime_payload(package_root: Path) -> None: + """Validate that editable setup left the required runtime payload in place.""" + missing = [ + str(package_root / rel_path) + for rel_path in REQUIRED_RUNTIME_ARTIFACTS + if not (package_root / rel_path).exists() + ] + if missing: + raise CutlassDSLSetupError( + "Editable CuTe DSL runtime payload is incomplete; missing " + f"{missing}" + ) + + if not any((package_root / "lib").glob("*.so")): + raise CutlassDSLSetupError( + "Editable CuTe DSL runtime payload is incomplete; " + f"no shared libraries found in {package_root / 'lib'}" + ) + + def copy_library_files(extract_dir: Path, package_root: Path) -> int: """ Copy .so library files from extracted wheel to package lib directory. @@ -193,7 +381,6 @@ def copy_library_files(extract_dir: Path, package_root: Path) -> int: Returns: Number of files copied """ - lib_pattern = extract_dir / "**" / "lib" / "*.so" so_files = [f for f in extract_dir.rglob("lib/*.so")] if not so_files: @@ -240,7 +427,7 @@ def copy_python_packages(extract_dir: Path, package_root: Path) -> Tuple[int, in cutlass_source_dir = cutlass_source_dirs[0] cutlass_dest_dir = package_root / "cutlass" - logger.info(f"Found python_packages/cutlass/ directory") + logger.info("Found python_packages/cutlass/ directory") logger.info(f"Copying from {cutlass_source_dir} to {cutlass_dest_dir}") copied_count = 0 @@ -256,7 +443,8 @@ def copy_python_packages(extract_dir: Path, package_root: Path) -> Tuple[int, in # Create parent directories dest_file.parent.mkdir(parents=True, exist_ok=True) - # Copy file if it doesn't exist + # Copy runtime files only when they do not overlap with source files. + # Generated payload state is cleaned before this pass. if dest_file.exists(): skipped_count += 1 logger.debug(f" Skipping {rel_path} (already exists)") @@ -310,11 +498,23 @@ def prep_editable_install() -> None: version = extract_version_from_wheel(wheel_path) extract_wheel_contents(wheel_path, extract_dir) - # Copy files + # Download and extract the exact companion runtime payload when the + # DSL wheel declares one. The payload provider can be selected with + # CUTLASS_DSL_RUNTIME_PROVIDER=base/cu13. + runtime_wheel_path = download_runtime_payload_wheel(wheel_path, temp_dir) + if runtime_wheel_path is not None: + extract_wheel_contents(runtime_wheel_path, extract_dir) + + # Replace generated runtime files as a unit so stale _mlir/Python + # payloads cannot survive across runtime package version changes. + clean_generated_runtime_payload(package_root) + + # Copy files from the downloaded wheels. lib_files_copied = copy_library_files(extract_dir, package_root) py_files_copied, py_files_skipped = copy_python_packages( extract_dir, package_root ) + validate_runtime_payload(package_root) # Write version file write_version_file(version, package_root) diff --git a/test/utils/test_sharding.py b/test/utils/test_sharding.py index 4ca29817d9..06ade85fd8 100644 --- a/test/utils/test_sharding.py +++ b/test/utils/test_sharding.py @@ -34,7 +34,6 @@ from pathlib import Path import inspect import re -import sys import pytest from _pytest.assertion.util import running_on_ci @@ -256,17 +255,14 @@ def pytest_collection_modifyitems(config, items): else: compatible_SMs = {f"{target_cc}a", f"{target_cc}f", f"{target_cc}"} else: - script_dir = os.path.dirname(os.path.abspath(__file__)) - if script_dir not in sys.path: - sys.path.append(script_dir) + from cutlass.base_dsl.runtime.cuda import get_compute_capability_major_minor - from device_info import compute_capability - - if compute_capability: - target_cc = int(compute_capability) - else: + major, minor = get_compute_capability_major_minor() + if major is None or minor is None: raise SystemError("Failed to get CUDA compute capability!") + target_cc = major * 10 + minor + if target_cc < 90: compatible_SMs = {f"{target_cc}"} else: