diff --git a/Cargo.lock b/Cargo.lock index 388c39500..558431fd7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4883,11 +4883,13 @@ dependencies = [ "serde", "serde_json", "sqlparser", + "tempfile", "thiserror 1.0.69", "tonic", "tonic-build", "tonic-prost", "tonic-prost-build", + "url", "vegafusion-common", ] @@ -4963,6 +4965,7 @@ dependencies = [ "protobuf-src", "regex", "serde_json", + "tempfile", "tokio", "tonic", "tonic-build", diff --git a/docs/source/features/features.md b/docs/source/features/features.md index 7c7ed791e..1b37fffe0 100644 --- a/docs/source/features/features.md +++ b/docs/source/features/features.md @@ -13,6 +13,7 @@ transform_spec transform_extract chart_state inline_datasets +plan_resolver grpc embed jupyter_widget diff --git a/docs/source/features/grpc.md b/docs/source/features/grpc.md index cd7b67900..d59f9b23a 100644 --- a/docs/source/features/grpc.md +++ b/docs/source/features/grpc.md @@ -2,7 +2,9 @@ The VegaFusion Runtime can run as a [gRPC](https://grpc.io/) service, which makes it possible for multiple clients to connect to the same runtime, and share a cache (See [How it Works](../about/how_it_works) for more details). This also makes it possible for the Runtime to reside on a different host than the client. :::{warning} -VegaFusion's gRPC server does not currently support authentication, and chart specifications may reference the local file system of the machine running the server. It is not currently recommended to use VegaFusion server with untrusted Vega specifications unless other measures are taken to isolate the service. +VegaFusion's gRPC server does not currently support authentication. If you use it with untrusted Vega specifications, lock down the server process with `--no-allowed-urls`, `--allowed-base-url`, `--base-url`, or `--no-base-url`, and apply any additional isolation your deployment requires. + +URL policy is enforced against the initial resolved URL only. VegaFusion does not re-check redirect destinations after a fetch begins. ::: ## VegaFusion Server @@ -18,6 +20,15 @@ The server may then be launched using a particular port as follows: vegafusion-server --port 50051 ``` +The server process owns URL resolution and access policy for all gRPC clients. For example: + +``` +vegafusion-server \ + --port 50051 \ + --base-url https://cdn.jsdelivr.net/npm/vega-datasets@v2.9.0/ \ + --allowed-base-url https://cdn.jsdelivr.net/ +``` + ## Python The `vf.runtime.grpc_connect` method is used to connect the Python client to a VegaFusion Server instance. diff --git a/docs/source/features/inline_datasets.md b/docs/source/features/inline_datasets.md index 3a054ce24..1c7f0d2c5 100644 --- a/docs/source/features/inline_datasets.md +++ b/docs/source/features/inline_datasets.md @@ -37,3 +37,5 @@ See [inline_datasets.py](https://github.com/vega/vegafusion/tree/main/examples/p In Rust, `inline_datasets` should be a `HashMap` from dataset names (e.g. `movies` in the example above) to `VegaFusionDataset` instances. `VegaFusionDataset` is an enum that may be either a `VegaFusionTable` (which is a thin wrapper around Arrow RecordBatches), or a DataFusion [`LocalPlan`](https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.LogicalPlan.html) (which represents an arbitrary DataFusion query). See [inline_datasets.rs](https://github.com/vega/vegafusion/tree/main/examples/rust-examples/examples/inline_datasets.rs) for a complete example using a `VegaFusionTable`, and see [inline_datasets_plan.rs](https://github.com/vega/vegafusion/tree/main/examples/rust-examples/examples/inline_datasets_plan.rs) for a complete example using a DataFusion ``LogicalPlan``. + +For more advanced data source integration (custom URL schemes, SQL transpilation, remote execution), see [Plan Resolver](./plan_resolver.md). diff --git a/docs/source/features/plan_resolver.md b/docs/source/features/plan_resolver.md new file mode 100644 index 000000000..03da4135d --- /dev/null +++ b/docs/source/features/plan_resolver.md @@ -0,0 +1,152 @@ +# Plan Resolver + +PlanResolver lets you connect custom data sources to VegaFusion. Use it when data lives in an external system (Spark, Snowflake, DuckDB, a custom API) and you want to push computation there instead of pulling it all into memory. For data you already have in Python as DataFrames or Arrow tables, [inline datasets](./inline_datasets.md) are simpler. + +:::{note} +`resolve_table`, `resolve_plan_proto` (bytes variant), and `unparse_to_sql` with bytes require no additional dependencies beyond `vegafusion`. + +`external_table_scan_node`, `inline_table_scan_node`, and `resolve_plan` (deserialized `LogicalPlanNode` variant) require the protobuf package: + +``` +pip install vegafusion[plan-resolver] +``` +::: + +## Python + +Override one of these methods on `PlanResolver` (simplest first): + +- `resolve_table`: return an Arrow table for a single external data source. VegaFusion handles the rest — it applies Vega transforms (filter, aggregate, etc.) via DataFusion after your resolver provides the data. +- `resolve_plan` / `resolve_plan_proto`: evaluate an entire logical plan, or the parts your backend supports. Use this to transpile the plan to SQL and execute it remotely, or to push supported operations to your query engine while letting DataFusion handle the rest. + +### scan_url + resolve_table + +For custom URL schemes in Vega specs (e.g. `"url": "mydb://warehouse/sales"`), override `scan_url()` and `resolve_table()`: + +```python +import vegafusion as vf +from vegafusion import PlanResolver +from vegafusion.plan_resolver import external_table_scan_node + +class MyResolver(PlanResolver): + def scan_url(self, parsed_url): + if parsed_url["scheme"] != "mydb": + return None # pass to next resolver + + # Look up the table schema from your data source. + # This is called at planning time, so avoid loading data here. + schema = get_table_schema(parsed_url["path"]) + + return external_table_scan_node( + table_name=parsed_url["url"], + schema=schema, + scheme="mydb", + metadata={"path": parsed_url["path"]}, + ) + + def resolve_table(self, name, scheme, schema, metadata=None, + projected_columns=None, filters=None): + # Called at execution time — load the actual data. + # projected_columns lists only the columns DataFusion needs, + # so you can avoid reading unnecessary columns. + return load_table(metadata["path"], columns=projected_columns) +``` + +`scan_url()` is called at planning time — it inspects the URL and returns an `ExternalTableProvider` plan node with the table's schema. `resolve_table()` is called at execution time to provide the actual data. + +Use `base_url` on the runtime to set a base path for relative URLs in Vega specs: + +```python +resolver = MyResolver() +rt = vf.VegaFusionRuntime( + plan_resolver=resolver, + base_url="mydb://warehouse/", +) + +# Vega spec with "url": "sales" resolves to "mydb://warehouse/sales" +``` + +See [plan_resolver_url_scanning.py](https://github.com/vega/vegafusion/tree/main/examples/python-examples/plan_resolver_url_scanning.py) for a complete example. + +### resolve_table only + +If data comes from `ExternalDataset` inline datasets (not URLs), you only need `resolve_table`: + +```python +import vegafusion as vf +from vegafusion import ExternalDataset, PlanResolver + +class MyResolver(PlanResolver): + def resolve_table(self, name, scheme, schema, metadata=None, + projected_columns=None, filters=None): + # Look up data by name from your data source + df = my_database.query(name, columns=projected_columns) + return df.to_arrow() + +ext = ExternalDataset(scheme="mydb", schema=table.schema, data=table) +rt = vf.VegaFusionRuntime(plan_resolver=MyResolver()) +datasets, _ = rt.pre_transform_datasets( + spec, datasets=["result"], + inline_datasets={"source": ext}, dataset_format="pyarrow", +) +``` + +No protobuf dependency is needed for this pattern. + +### resolve_plan + unparse_to_sql + +Override `resolve_plan_proto` to receive the full logical plan and transpile it to SQL for remote execution: + +```python +from vegafusion import PlanResolver +from vegafusion.plan_resolver import unparse_to_sql + +class SqlResolver(PlanResolver): + def __init__(self, connection): + self._conn = connection + + def resolve_plan_proto(self, plan_bytes, datasets): + # Convert the DataFusion logical plan to a SQL string + sql = unparse_to_sql(plan_bytes, dialect="default") + + # Execute the SQL against your database + cursor = self._conn.cursor() + cursor.execute(sql) + return cursor.fetch_arrow_all() +``` + +`resolve_plan_proto` receives protobuf bytes that can be passed directly to `unparse_to_sql()` without deserialization. To inspect or modify the plan tree, use `resolve_plan()` instead (it receives a deserialized `LogicalPlanNode`). + +Supported SQL dialects: `"default"`, `"postgres"`, `"mysql"`, `"sqlite"`, `"duckdb"`, `"bigquery"`. + +See [plan_resolver_sql.py](https://github.com/vega/vegafusion/tree/main/examples/python-examples/plan_resolver_sql.py) for a complete example. + +### Configuration + +`PlanResolver` cannot be used with `grpc_connect()` (resolvers run in-process). Class-level attributes control resolver behavior: + +- `thread_safe` (default `True`) — set to `False` for backends with thread-affine connections (e.g. DuckDB) +- `skip_when_no_external_tables` (default `True`) — set to `False` to receive all plans, not just those with external tables (e.g. for logging) +- `supports_arrow_tables` (default `False`) — set to `True` to let the runtime eagerly materialize plans into Arrow tables + +### API Reference + +```{eval-rst} +.. autoclass:: vegafusion.PlanResolver + :members: + +.. autoclass:: vegafusion.ExternalDataset + :members: + +.. autofunction:: vegafusion.plan_resolver.external_table_scan_node + +.. autofunction:: vegafusion.plan_resolver.unparse_to_sql + +.. autofunction:: vegafusion.plan_resolver.unparse_expr_to_sql + +.. autofunction:: vegafusion.plan_resolver.inline_table_scan_node +``` + +## Rust + +The `PlanResolver` trait in `vegafusion-runtime` provides the same two-phase architecture (scan_url at planning time, resolve_table/resolve_plan at execution time). See the [vegafusion-runtime docs on docs.rs](https://docs.rs/vegafusion-runtime/) for the full API. diff --git a/examples/editor-demo/README.md b/examples/editor-demo/README.md index 489b96b9c..02c0396e7 100644 --- a/examples/editor-demo/README.md +++ b/examples/editor-demo/README.md @@ -8,6 +8,11 @@ Launch gRPC-Web server with: ./vegafusion-server --port 50051 --web ``` +Add `--base-url`, `--no-base-url`, `--allowed-base-url`, or `--no-allowed-urls` +to control how the server resolves and accesses external data URLs. +Policy checks apply to the initial resolved URL only; redirect destinations are +not re-checked after a fetch begins. + Build and launch editor with ``` npm install diff --git a/examples/python-examples/plan_resolver_sql.py b/examples/python-examples/plan_resolver_sql.py new file mode 100644 index 000000000..375119f1b --- /dev/null +++ b/examples/python-examples/plan_resolver_sql.py @@ -0,0 +1,87 @@ +# Demonstrates SQL transpilation using resolve_plan_proto() + unparse_to_sql(). +# The resolver receives a serialized logical plan, converts it to SQL, and prints it. +# In a real application you would execute the SQL against a database. + +import json +from typing import Any + +import pyarrow as pa + +import vegafusion as vf +from vegafusion import ExternalDataset, PlanResolver +from vegafusion.plan_resolver import unparse_to_sql + + +def main() -> None: + source_table = pa.table({"x": [1, 5, 10], "y": ["a", "b", "c"]}) + ext = ExternalDataset(scheme="table", schema=source_table.schema, data=source_table) + + resolver = SqlTranspileResolver() + rt = vf.VegaFusionRuntime(plan_resolver=resolver) + + spec = get_spec() + datasets, warnings = rt.pre_transform_datasets( + spec, + datasets=["filtered"], + inline_datasets={"source": ext}, + dataset_format="pyarrow", + ) + + assert warnings == [] + result = datasets[0] + assert result.column("x").to_pylist() == [5, 10] + assert result.column("y").to_pylist() == ["b", "c"] + assert resolver.captured_sql is not None + assert "SELECT" in resolver.captured_sql + + print("Captured SQL (postgres dialect):") + print(resolver.captured_sql) + print() + print("Result table:") + print(result) + + +class SqlTranspileResolver(PlanResolver): + """Converts the logical plan to Postgres-dialect SQL.""" + + def __init__(self) -> None: + self.captured_sql: str | None = None + + def resolve_plan_proto( + self, plan_bytes: bytes, datasets: dict[str, Any] + ) -> pa.Table: + sql = unparse_to_sql(plan_bytes, dialect="postgres") + self.captured_sql = sql + + # In a real resolver, you would execute `sql` against your database + # and return the result as an Arrow table. Here we return hardcoded + # data matching the expected query result for demonstration. + return pa.table({"x": [5, 10], "y": ["b", "c"]}) + + +def get_spec() -> dict[str, Any]: + return json.loads(""" +{ + "$schema": "https://vega.github.io/schema/vega/v5.json", + "data": [ + { + "name": "source", + "url": "table://source" + }, + { + "name": "filtered", + "source": "source", + "transform": [ + { + "type": "filter", + "expr": "datum.x > 3" + } + ] + } + ] +} + """) + + +if __name__ == "__main__": + main() diff --git a/examples/python-examples/plan_resolver_url_scanning.py b/examples/python-examples/plan_resolver_url_scanning.py new file mode 100644 index 000000000..89915f127 --- /dev/null +++ b/examples/python-examples/plan_resolver_url_scanning.py @@ -0,0 +1,92 @@ +# Requires: pip install vegafusion[plan-resolver] +""" +Demonstrates the URL scanning pattern for custom URL schemes: +scan_url() + resolve_table() + +VegaFusion's PlanResolver lets you register custom URL schemes so that +data references like "mydata://database/sales" in a Vega spec are resolved +by your own Python code rather than fetched over HTTP. +""" + +from __future__ import annotations + +import json +from typing import Any + +import pyarrow as pa +import vegafusion as vf +from vegafusion import PlanResolver +from vegafusion.plan_resolver import external_table_scan_node + + +def main(): + resolver = SalesDataResolver() + rt = vf.VegaFusionRuntime(plan_resolver=resolver) + + spec = make_spec() + datasets, warnings = rt.pre_transform_datasets( + spec, datasets=["sales"], dataset_format="pyarrow" + ) + + assert warnings == [] + assert len(datasets) == 1 + + table = datasets[0] + assert table.column("product").to_pylist() == ["Widget", "Gadget", "Gizmo"] + assert table.column("revenue").to_pylist() == [1200, 3400, 560] + print("Result table:") + print(table) + print("\nAll assertions passed.") + + +class SalesDataResolver(PlanResolver): + """Resolves URLs with the 'mydata' scheme using in-memory data.""" + + def scan_url(self, parsed_url: dict[str, Any]) -> Any: + if parsed_url["scheme"] == "mydata": + schema = pa.schema([("product", pa.utf8()), ("revenue", pa.int64())]) + return external_table_scan_node( + # Use the full URL as the table name so multiple URLs + # produce distinct plan nodes + table_name=parsed_url["url"], + schema=schema, + scheme="mydata", + ) + return None + + def resolve_table( + self, + name: str, + scheme: str, + schema: Any, + metadata: dict[str, Any] | None = None, + projected_columns: list[str] | None = None, + filters: list[Any] | None = None, + ) -> pa.Table: + # In a real resolver, use `name` or `metadata` to look up data + # from your data source. Here we return a fixed table. + return pa.table( + { + "product": ["Widget", "Gadget", "Gizmo"], + "revenue": [1200, 3400, 560], + } + ) + + +def make_spec() -> dict[str, Any]: + spec_str = """ +{ + "$schema": "https://vega.github.io/schema/vega/v5.json", + "data": [ + { + "name": "sales", + "url": "mydata://database/sales" + } + ] +} + """ + return json.loads(spec_str) + + +if __name__ == "__main__": + main() diff --git a/examples/rust-examples/examples/custom_resolver.rs b/examples/rust-examples/examples/custom_resolver.rs deleted file mode 100644 index 71a515e92..000000000 --- a/examples/rust-examples/examples/custom_resolver.rs +++ /dev/null @@ -1,122 +0,0 @@ -use std::sync::Arc; -use vegafusion_common::datafusion_expr::LogicalPlan; -use vegafusion_common::error::Result; -use vegafusion_core::runtime::{PlanResolver, ResolutionResult, VegaFusionRuntimeTrait}; -use vegafusion_core::spec::chart::ChartSpec; -use vegafusion_runtime::task_graph::runtime::VegaFusionRuntime; - -/// A custom resolver that logs plan resolution and passes through to DataFusion -#[derive(Clone)] -struct LoggingResolver; - -#[async_trait::async_trait] -impl PlanResolver for LoggingResolver { - fn name(&self) -> &str { - "LoggingResolver" - } - - async fn resolve_plan(&self, plan: LogicalPlan) -> Result { - println!("Custom resolver received logical plan"); - println!("Plan details:\n{}\n", plan.display_indent()); - - // Return the plan unchanged — DataFusion will execute it - Ok(ResolutionResult::Plan(plan)) - } -} - -/// This example demonstrates how to use a custom plan resolver with VegaFusion. -/// The custom resolver logs each plan before letting DataFusion execute it. -#[tokio::main] -async fn main() { - let spec = get_spec(); - - // Create a custom resolver - let custom_resolver = Arc::new(LoggingResolver) as Arc; - - // Create runtime with custom resolver - let runtime = VegaFusionRuntime::new(None, vec![custom_resolver]); - - println!("Starting pre-transform with custom resolver\n"); - - let (_transformed_spec, warnings) = runtime - .pre_transform_spec( - &spec, - &Default::default(), // Inline datasets - &Default::default(), // Options - ) - .await - .unwrap(); - println!("Spec transformed"); - assert_eq!(warnings.len(), 0); -} - -fn get_spec() -> ChartSpec { - let spec_str = r##" - { - "$schema": "https://vega.github.io/schema/vega/v5.json", - "description": "A histogram demonstrating custom resolver usage", - "width": 400, - "height": 200, - "padding": 5, - "data": [ - { - "name": "table", - "url": "data/movies.json", - "transform": [ - { - "type": "extent", - "field": "IMDB Rating", - "signal": "extent" - }, - { - "type": "bin", - "signal": "bins", - "field": "IMDB Rating", - "extent": {"signal": "extent"}, - "maxbins": 10 - }, - { - "type": "aggregate", - "groupby": ["bin0", "bin1"], - "ops": ["count"], - "fields": [null], - "as": ["count"] - } - ] - } - ], - "scales": [ - { - "name": "xscale", - "type": "linear", - "range": "width", - "domain": {"signal": "extent"} - }, - { - "name": "yscale", - "type": "linear", - "range": "height", - "round": true, - "domain": {"data": "table", "field": "count"}, - "zero": true, - "nice": true - } - ], - "marks": [ - { - "type": "rect", - "from": {"data": "table"}, - "encode": { - "update": { - "x": {"scale": "xscale", "field": "bin0"}, - "x2": {"scale": "xscale", "field": "bin1"}, - "y": {"scale": "yscale", "field": "count"}, - "y2": {"scale": "yscale", "value": 0} - } - } - } - ] - } - "##; - serde_json::from_str(spec_str).unwrap() -} diff --git a/vegafusion-core/Cargo.toml b/vegafusion-core/Cargo.toml index 0bd0efc52..5b319b8f6 100644 --- a/vegafusion-core/Cargo.toml +++ b/vegafusion-core/Cargo.toml @@ -73,6 +73,9 @@ workspace = true [dependencies.log] workspace = true +[dependencies.url] +version = "2" + [dependencies.serde] workspace = true @@ -84,6 +87,9 @@ optional = true workspace = true optional = true +[dev-dependencies.tempfile] +workspace = true + [lints.clippy] module_inception = "allow" diff --git a/vegafusion-core/src/chart_state.rs b/vegafusion-core/src/chart_state.rs index 20398559e..d7fad7db8 100644 --- a/vegafusion-core/src/chart_state.rs +++ b/vegafusion-core/src/chart_state.rs @@ -2,7 +2,7 @@ use crate::{ data::dataset::VegaFusionDataset, planning::{ apply_pre_transform::apply_pre_transform_datasets, - plan::SpecPlan, + plan::{PlannerConfig, SpecPlan}, stitch::CommPlan, watch::{ExportUpdate, ExportUpdateJSON, ExportUpdateNamespace}, }, @@ -66,7 +66,7 @@ impl ChartState { .map(|(k, ds)| (k.clone(), ds.fingerprint())) .collect::>(); - let plan = SpecPlan::try_new(&spec, &Default::default())?; + let plan = SpecPlan::try_new(&spec, &PlannerConfig::default())?; let task_scope = plan .server_spec diff --git a/vegafusion-core/src/data/mod.rs b/vegafusion-core/src/data/mod.rs index f311152c4..d90e52fd2 100644 --- a/vegafusion-core/src/data/mod.rs +++ b/vegafusion-core/src/data/mod.rs @@ -1,2 +1,3 @@ pub mod dataset; pub mod tasks; +pub mod url; diff --git a/vegafusion-core/src/data/url.rs b/vegafusion-core/src/data/url.rs new file mode 100644 index 000000000..e08e52976 --- /dev/null +++ b/vegafusion-core/src/data/url.rs @@ -0,0 +1,728 @@ +use regex::Regex; +#[cfg(not(target_arch = "wasm32"))] +use std::fs; +use std::path::{Path, PathBuf}; +use std::sync::LazyLock; +use vegafusion_common::error::{Result, VegaFusionError}; + +/// Parsed URL representation passed to resolvers during the scan phase. +/// All fields are populated from the fully-resolved URL (after base URL +/// resolution and hash-stripping). Resolvers pattern-match on these fields +/// rather than doing their own URL string parsing. +#[derive(Clone, Debug, PartialEq)] +pub struct ParsedUrl { + /// Original URL string (after base URL resolution and hash-stripping) + pub url: String, + /// URL scheme (http, https, file, s3, spark, etc.) — always present + pub scheme: String, + /// Host/authority component (e.g. "example.com", S3 bucket name) + pub host: Option, + /// Path component + pub path: String, + /// Query parameters in URL order, preserving duplicates + pub query_params: Vec<(String, String)>, + /// File extension extracted from path (e.g. "csv", "parquet") + pub extension: Option, + /// Explicit format type from Vega spec (overrides extension) + pub format_type: Option, + /// Parse spec from Vega format (e.g., {"date": "date"} for CSV column typing) + pub parse: Option, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum AllowedBaseUrlPattern { + Any, + Scheme(String), + Prefix(String), + WildcardHost { + scheme: String, + host_suffix: String, + path_prefix: String, + }, + FilePathPrefix(PathBuf), +} + +static URL_SCHEME_RE: LazyLock = + LazyLock::new(|| Regex::new(r"^(//|[a-zA-Z][a-zA-Z0-9+.\-]*://)").unwrap()); +static SCHEME_PATTERN_RE: LazyLock = + LazyLock::new(|| Regex::new(r"^[a-zA-Z][a-zA-Z0-9+.\-]*:$").unwrap()); +static WILDCARD_HOST_RE: LazyLock = LazyLock::new(|| { + Regex::new(r"^([a-zA-Z][a-zA-Z0-9+.\-]*)://\*\.([^/?#]+)(/[^?#]*)?$").unwrap() +}); + +#[cfg(not(target_arch = "wasm32"))] +fn normalize_file_base_url(base_url: String) -> Result { + let parsed = match url::Url::parse(&base_url) { + Ok(parsed) => parsed, + Err(_) => return Ok(base_url), + }; + + if parsed.scheme() != "file" { + return Ok(base_url); + } + + let Ok(path) = parsed.to_file_path() else { + return Ok(base_url); + }; + + if path.is_dir() && !base_url.ends_with('/') { + Ok(format!("{base_url}/")) + } else { + Ok(base_url) + } +} + +#[cfg(target_arch = "wasm32")] +fn normalize_file_base_url(base_url: String) -> Result { + Ok(base_url) +} + +/// Returns true if the string is already a URL (has a scheme per RFC 3986) +/// or is scheme-relative (starts with //). +pub fn has_url_scheme(s: &str) -> bool { + URL_SCHEME_RE.is_match(s) +} + +/// Returns true if `path` is an absolute filesystem path. +/// Unix: starts with `/`. Windows: starts with a drive letter `[A-Za-z]:\` or `[A-Za-z]:/`. +pub fn is_absolute_path(path: &str) -> bool { + let bytes = path.as_bytes(); + if bytes.first() == Some(&b'/') { + return true; + } + bytes.len() >= 3 + && bytes[0].is_ascii_alphabetic() + && bytes[1] == b':' + && (bytes[2] == b'\\' || bytes[2] == b'/') +} + +/// Normalize a base URL so it always has a scheme. +/// Bare absolute paths become file:// URLs; scheme-relative URLs get +/// https: prepended; scheme URLs are preserved as-is; everything else is rejected. +pub fn normalize_base_url(base: String) -> Result { + if base.starts_with("//") { + // Protocol-relative URL — prepend https: so url::Url::parse works + Ok(format!("https:{base}")) + } else if has_url_scheme(&base) { + normalize_file_base_url(base) + } else if is_absolute_path(&base) { + normalize_file_base_url(path_to_file_url(&base)?) + } else { + Err(VegaFusionError::specification(format!( + "base_url must be absolute (scheme URL or absolute path), got: {base}" + ))) + } +} + +/// Convert an absolute local path to a file:// URL. +/// Uses url::Url::from_file_path() for correct percent-encoding. +#[cfg(not(target_arch = "wasm32"))] +pub fn path_to_file_url(path: &str) -> Result { + let normalized = path.replace('\\', "/"); + let p = std::path::Path::new(&normalized); + url::Url::from_file_path(p) + .map(|u| u.to_string()) + .map_err(|_| { + VegaFusionError::specification(format!( + "Cannot convert path to file URL: {}", + p.display() + )) + }) +} + +/// Browser-wasm fallback: `url::Url::from_file_path` is unavailable on +/// `wasm32-unknown-unknown` (not compiled for that target in the `url` crate), +/// and `std::path` absolute-path semantics on that target do not recognize +/// POSIX-like virtual paths such as `/foo`. +/// +/// We therefore synthesize a `file:` URL for the restricted path forms we +/// expect here. Unlike `Url::from_file_path`, this does **not** percent-encode +/// reserved characters, so inputs must not contain `#`, `?`, etc. +#[cfg(target_arch = "wasm32")] +pub fn path_to_file_url(path: &str) -> Result { + let normalized = path.replace('\\', "/"); + Ok(format!("file://{normalized}")) +} + +#[cfg(not(target_arch = "wasm32"))] +pub fn file_url_to_path(url: &str) -> Result { + let parsed = url::Url::parse(url) + .map_err(|e| VegaFusionError::specification(format!("Invalid file URL '{url}': {e}")))?; + parsed.to_file_path().map_err(|_| { + VegaFusionError::specification(format!("Cannot convert file URL to path: {url}")) + }) +} + +#[cfg(target_arch = "wasm32")] +pub fn file_url_to_path(url: &str) -> Result { + Err(VegaFusionError::specification(format!( + "Cannot convert file URL to path on wasm target: {url}" + ))) +} + +#[cfg(not(target_arch = "wasm32"))] +fn portable_canonicalize(path: &Path) -> Result { + let canonical = fs::canonicalize(path).map_err(|e| { + VegaFusionError::specification(format!("Failed to resolve path {}: {e}", path.display())) + })?; + // On Windows, fs::canonicalize returns extended-length paths (\\?\C:\...) + // which break prefix matching. Strip the prefix for consistent comparisons. + #[cfg(target_os = "windows")] + { + let s = canonical.to_string_lossy(); + if let Some(stripped) = s.strip_prefix(r"\\?\") { + return Ok(PathBuf::from(stripped)); + } + } + Ok(canonical) +} + +#[cfg(target_arch = "wasm32")] +fn portable_canonicalize(path: &Path) -> Result { + Err(VegaFusionError::specification(format!( + "Cannot canonicalize path on wasm target: {}", + path.display() + ))) +} + +pub fn canonicalize_path_for_policy_check(path: &Path) -> Result { + if path.exists() { + return portable_canonicalize(path); + } + + let parent = path.parent().unwrap_or_else(|| Path::new(".")); + let canonical_parent = portable_canonicalize(parent)?; + let Some(file_name) = path.file_name() else { + return Err(VegaFusionError::specification(format!( + "Failed to resolve local path {}: missing file name", + path.display() + ))); + }; + Ok(canonical_parent.join(file_name)) +} + +fn normalize_url_prefix(mut normalized: String) -> String { + if !normalized.ends_with('/') { + normalized.push('/'); + } + normalized +} + +pub fn normalize_allowed_base_urls( + allowed_base_urls: Option>, +) -> Result>> { + allowed_base_urls + .map(|urls| { + urls.into_iter() + .map(|url| normalize_allowed_base_url(&url)) + .collect::>>() + }) + .transpose() +} + +pub fn normalize_allowed_base_url(allowed_base_url: &str) -> Result { + if allowed_base_url == "*" { + return Ok(AllowedBaseUrlPattern::Any); + } + + if SCHEME_PATTERN_RE.is_match(allowed_base_url) { + return Ok(AllowedBaseUrlPattern::Scheme( + allowed_base_url[..allowed_base_url.len() - 1].to_ascii_lowercase(), + )); + } + + if is_absolute_path(allowed_base_url) || allowed_base_url.starts_with("file:///") { + let path = if allowed_base_url.starts_with("file:///") { + file_url_to_path(allowed_base_url)? + } else { + PathBuf::from(allowed_base_url) + }; + let canonical = portable_canonicalize(&path)?; + if !canonical.is_dir() { + return Err(VegaFusionError::specification(format!( + "Filesystem path in allowed_base_urls must be a directory: {}", + canonical.display() + ))); + } + return Ok(AllowedBaseUrlPattern::FilePathPrefix(canonical)); + } + + if let Some(captures) = WILDCARD_HOST_RE.captures(allowed_base_url) { + let scheme = captures.get(1).unwrap().as_str().to_ascii_lowercase(); + let host_suffix = captures.get(2).unwrap().as_str().to_ascii_lowercase(); + if host_suffix.is_empty() || host_suffix.contains('@') || host_suffix.contains(':') { + return Err(VegaFusionError::specification(format!( + "Invalid wildcard host pattern in allowed_base_urls: {allowed_base_url}" + ))); + } + let path_prefix = normalize_url_prefix( + captures + .get(3) + .map(|m| m.as_str().to_string()) + .unwrap_or_else(|| "/".to_string()), + ); + return Ok(AllowedBaseUrlPattern::WildcardHost { + scheme, + host_suffix, + path_prefix, + }); + } + + let parsed_url = url::Url::parse(allowed_base_url).map_err(|e| { + VegaFusionError::specification(format!( + "Invalid allowed_base_url '{allowed_base_url}': {e}" + )) + })?; + + if !parsed_url.username().is_empty() || parsed_url.password().is_some() { + return Err(VegaFusionError::specification(format!( + "allowed_base_url cannot include userinfo credentials: {allowed_base_url}" + ))); + } + + if parsed_url.query().is_some() { + return Err(VegaFusionError::specification(format!( + "allowed_base_url cannot include a query component: {allowed_base_url}" + ))); + } + + if parsed_url.fragment().is_some() { + return Err(VegaFusionError::specification(format!( + "allowed_base_url cannot include a fragment component: {allowed_base_url}" + ))); + } + + Ok(AllowedBaseUrlPattern::Prefix(normalize_url_prefix( + parsed_url.to_string(), + ))) +} + +fn url_to_local_path(url: &str) -> Result { + if url.starts_with("file://") { + file_url_to_path(url) + } else if is_absolute_path(url) { + Ok(PathBuf::from(url)) + } else { + Err(VegaFusionError::specification(format!( + "Expected local file path or file URL, got: {url}" + ))) + } +} + +pub fn is_url_allowed(url: &str, allowed_base_urls: &[AllowedBaseUrlPattern]) -> bool { + let parsed_url = url::Url::parse(url).ok(); + + allowed_base_urls.iter().any(|pattern| match pattern { + AllowedBaseUrlPattern::Any => true, + AllowedBaseUrlPattern::Scheme(scheme) => parsed_url + .as_ref() + .map(|parsed| parsed.scheme().eq_ignore_ascii_case(scheme)) + .unwrap_or(false), + AllowedBaseUrlPattern::Prefix(prefix) => parsed_url + .as_ref() + .map(|parsed| parsed.as_str().starts_with(prefix)) + .unwrap_or(false), + AllowedBaseUrlPattern::WildcardHost { + scheme, + host_suffix, + path_prefix, + } => parsed_url + .as_ref() + .and_then(|parsed| { + parsed.host_str().map(|host| { + parsed.scheme().eq_ignore_ascii_case(scheme) + && (host.eq_ignore_ascii_case(host_suffix) + || host + .to_ascii_lowercase() + .ends_with(&format!(".{host_suffix}"))) + && parsed.path().starts_with(path_prefix) + }) + }) + .unwrap_or(false), + AllowedBaseUrlPattern::FilePathPrefix(prefix) => url_to_local_path(url) + .and_then(|path| canonicalize_path_for_policy_check(&path)) + .map(|path| path.starts_with(prefix)) + .unwrap_or(false), + }) +} + +pub fn check_url_allowed( + url: &str, + allowed_base_urls: &Option>, +) -> Result<()> { + if allowed_base_urls + .as_ref() + .map(|patterns| is_url_allowed(url, patterns)) + .unwrap_or(true) + { + Ok(()) + } else { + Err(VegaFusionError::specification(format!( + "URL or path '{url}' blocked by allowed_base_urls. Add the URL prefix to allowed_base_urls or change base_url." + ))) + } +} + +/// Resolve a spec URL against a base URL. This is the shared function used by +/// both plan-time resolution (MakeTasksVisitor for Url::String) and eval-time +/// resolution (DataUrlTask::eval for Url::Expr). +pub fn resolve_url(url: &str, base_url: &Option) -> Result { + if url.starts_with("//") { + // Protocol-relative URL — prepend https: so downstream parsers work + Ok(format!("https:{url}")) + } else if has_url_scheme(url) { + Ok(url.to_string()) + } else if is_absolute_path(url) { + path_to_file_url(url) + } else { + // Relative path — resolve against base URL using RFC 3986 joining + match base_url { + Some(base) => { + let base_url = url::Url::parse(base).map_err(|e| { + VegaFusionError::specification(format!("Invalid base URL '{base}': {e}")) + })?; + let resolved = base_url.join(url).map_err(|e| { + VegaFusionError::specification(format!( + "Cannot resolve '{url}' against base '{base}': {e}" + )) + })?; + Ok(resolved.to_string()) + } + None => Err(VegaFusionError::specification(format!( + "Relative URL with no base_url configured: {url}" + ))), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_has_url_scheme_https() { + assert!(has_url_scheme("https://example.com/data.csv")); + } + + #[test] + fn test_has_url_scheme_custom() { + assert!(has_url_scheme("spark://org.users")); + } + + #[test] + fn test_has_url_scheme_scheme_relative() { + assert!(has_url_scheme("//example.com/data.csv")); + } + + #[test] + fn test_has_url_scheme_absolute_path() { + assert!(!has_url_scheme("/tmp/data.csv")); + } + + #[test] + fn test_has_url_scheme_relative() { + assert!(!has_url_scheme("data/cars.json")); + } + + #[test] + fn test_has_url_scheme_embedded_scheme_in_query() { + // Relative reference with "://" in a query parameter — must not be + // misclassified as an absolute URL. + assert!(!has_url_scheme("fetch?target=http://evil.com/data")); + } + + #[test] + fn test_has_url_scheme_embedded_scheme_in_path() { + assert!(!has_url_scheme("foo/http://bar")); + } + + #[test] + fn test_is_absolute_path_unix() { + assert!(is_absolute_path("/tmp/data.csv")); + } + + #[test] + fn test_is_absolute_path_windows_backslash() { + assert!(is_absolute_path("C:\\tmp\\foo.csv")); + } + + #[test] + fn test_is_absolute_path_windows_forward() { + assert!(is_absolute_path("C:/tmp/foo.csv")); + } + + #[test] + fn test_is_absolute_path_rejects_ambiguous_colon() { + assert!(!is_absolute_path("a:b")); + } + + #[test] + fn test_is_absolute_path_rejects_digit_colon() { + assert!(!is_absolute_path("1:/foo")); + } + + #[test] + fn test_is_absolute_path_rejects_relative() { + assert!(!is_absolute_path("relative/path")); + } + + #[test] + #[cfg(not(target_os = "windows"))] + fn test_path_to_file_url_unix() { + let result = path_to_file_url("/tmp/data.csv").unwrap(); + assert_eq!(result, "file:///tmp/data.csv"); + } + + #[test] + #[cfg(not(target_os = "windows"))] + fn test_path_to_file_url_spaces() { + let result = path_to_file_url("/tmp/my data/file.csv").unwrap(); + assert_eq!(result, "file:///tmp/my%20data/file.csv"); + } + + #[test] + #[cfg(not(target_os = "windows"))] + fn test_path_to_file_url_hash() { + let result = path_to_file_url("/tmp/file#1.csv").unwrap(); + assert!( + result.contains("%23"), + "Hash should be percent-encoded: {result}" + ); + } + + #[test] + fn test_normalize_base_url_scheme() { + let result = normalize_base_url("https://example.com/data/".to_string()).unwrap(); + assert_eq!(result, "https://example.com/data/"); + } + + #[test] + fn test_normalize_base_url_scheme_relative() { + let result = normalize_base_url("//example.com/data/".to_string()).unwrap(); + assert_eq!(result, "https://example.com/data/"); + } + + #[test] + #[cfg(not(target_os = "windows"))] + fn test_normalize_base_url_absolute_path() { + let result = normalize_base_url("/home/user/data".to_string()).unwrap(); + assert_eq!(result, "file:///home/user/data"); + } + + #[test] + fn test_normalize_base_url_rejects_relative() { + let result = normalize_base_url("relative/path".to_string()); + assert!(result.is_err()); + } + + #[test] + fn test_normalize_base_url_rejects_ambiguous_colon() { + let result = normalize_base_url("a:b".to_string()); + assert!(result.is_err()); + } + + #[test] + #[cfg(not(target_os = "windows"))] + fn test_normalize_base_url_existing_directory_adds_trailing_slash() { + let tempdir = tempfile::tempdir().unwrap(); + let result = normalize_base_url(tempdir.path().to_str().unwrap().to_string()).unwrap(); + assert!( + result.ends_with('/'), + "expected trailing slash, got {result}" + ); + } + + #[test] + fn test_resolve_url_scheme_passthrough() { + let base = Some("https://cdn.example.com/".to_string()); + let result = resolve_url("https://other.com/data.csv", &base).unwrap(); + assert_eq!(result, "https://other.com/data.csv"); + } + + #[test] + #[cfg(not(target_os = "windows"))] + fn test_resolve_url_absolute_path_to_file() { + let base = Some("https://cdn.example.com/".to_string()); + let result = resolve_url("/tmp/data.csv", &base).unwrap(); + assert_eq!(result, "file:///tmp/data.csv"); + } + + #[test] + fn test_resolve_url_relative_with_base() { + let base = Some("https://raw.githubusercontent.com/vega/vega-datasets/v2.3.0/".to_string()); + let result = resolve_url("data/cars.json", &base).unwrap(); + assert_eq!( + result, + "https://raw.githubusercontent.com/vega/vega-datasets/v2.3.0/data/cars.json" + ); + } + + #[test] + fn test_resolve_url_relative_without_trailing_slash() { + // Per RFC 3986, joining against a base without trailing slash replaces + // the last path segment: "data" is replaced by "cars.json" + let base = Some("https://example.com/data".to_string()); + let result = resolve_url("cars.json", &base).unwrap(); + assert_eq!(result, "https://example.com/cars.json"); + } + + #[test] + fn test_resolve_url_relative_parent_traversal() { + let base = Some("https://example.com/data/v2/".to_string()); + let result = resolve_url("../v1/cars.json", &base).unwrap(); + assert_eq!(result, "https://example.com/data/v1/cars.json"); + } + + #[test] + fn test_resolve_url_relative_no_base_errors() { + let result = resolve_url("data/cars.json", &None); + assert!(result.is_err()); + } + + #[test] + fn test_resolve_url_relative_with_embedded_scheme() { + // A relative reference that contains "://" in a query parameter should + // be joined against the base URL, not treated as absolute. + let base = Some("https://proxy.com/".to_string()); + let result = resolve_url("fetch?target=http://evil.com/data", &base).unwrap(); + assert_eq!( + result, + "https://proxy.com/fetch?target=http://evil.com/data" + ); + } + + #[test] + fn test_normalize_allowed_base_url_star() { + assert_eq!( + normalize_allowed_base_url("*").unwrap(), + AllowedBaseUrlPattern::Any + ); + } + + #[test] + fn test_normalize_allowed_base_url_generic_scheme() { + assert_eq!( + normalize_allowed_base_url("s3:").unwrap(), + AllowedBaseUrlPattern::Scheme("s3".to_string()) + ); + } + + #[test] + fn test_normalize_allowed_base_url_prefix() { + assert_eq!( + normalize_allowed_base_url("https://example.com/data").unwrap(), + AllowedBaseUrlPattern::Prefix("https://example.com/data/".to_string()) + ); + } + + #[test] + fn test_normalize_allowed_base_url_wildcard_host() { + assert_eq!( + normalize_allowed_base_url("https://*.example.com/data").unwrap(), + AllowedBaseUrlPattern::WildcardHost { + scheme: "https".to_string(), + host_suffix: "example.com".to_string(), + path_prefix: "/data/".to_string(), + } + ); + } + + #[test] + #[cfg(not(target_os = "windows"))] + fn test_normalize_allowed_base_url_filesystem_root() { + let tempdir = tempfile::tempdir().unwrap(); + let normalized = normalize_allowed_base_url(tempdir.path().to_str().unwrap()).unwrap(); + assert_eq!( + normalized, + AllowedBaseUrlPattern::FilePathPrefix(fs::canonicalize(tempdir.path()).unwrap()) + ); + } + + #[test] + fn test_normalize_allowed_base_url_rejects_query() { + assert!(normalize_allowed_base_url("https://example.com/data?q=1").is_err()); + } + + #[test] + fn test_is_url_allowed_generic_scheme() { + let patterns = vec![normalize_allowed_base_url("myproto:").unwrap()]; + assert!(is_url_allowed("myproto://warehouse/sales", &patterns)); + assert!(!is_url_allowed("otherproto://warehouse/sales", &patterns)); + } + + #[test] + fn test_is_url_allowed_prefix() { + let patterns = vec![normalize_allowed_base_url("https://example.com/data/").unwrap()]; + assert!(is_url_allowed( + "https://example.com/data/cars.json", + &patterns + )); + assert!(!is_url_allowed( + "https://example.com/other/cars.json", + &patterns + )); + } + + #[test] + fn test_is_url_allowed_wildcard_host() { + let patterns = vec![normalize_allowed_base_url("https://*.example.com/data/").unwrap()]; + assert!(is_url_allowed( + "https://example.com/data/cars.json", + &patterns + )); + assert!(is_url_allowed( + "https://cdn.example.com/data/cars.json", + &patterns + )); + assert!(!is_url_allowed( + "https://example.com.evil.com/data/cars.json", + &patterns + )); + assert!(!is_url_allowed( + "https://cdn.example.com/other/cars.json", + &patterns + )); + } + + #[test] + #[cfg(not(target_os = "windows"))] + fn test_is_url_allowed_filesystem_canonicalization() { + let root = tempfile::tempdir().unwrap(); + let nested = root.path().join("nested"); + std::fs::create_dir_all(&nested).unwrap(); + let file_path = nested.join("data.json"); + std::fs::write(&file_path, "{}").unwrap(); + + let patterns = vec![normalize_allowed_base_url(root.path().to_str().unwrap()).unwrap()]; + assert!(is_url_allowed( + &format!("file://{}", file_path.display()), + &patterns + )); + } + + #[test] + #[cfg(not(target_os = "windows"))] + fn test_is_url_allowed_rejects_parent_traversal() { + let root = tempfile::tempdir().unwrap(); + let allowed = root.path().join("allowed"); + std::fs::create_dir_all(&allowed).unwrap(); + let outside = root.path().join("outside"); + std::fs::create_dir_all(&outside).unwrap(); + let file_path = allowed.join("../outside/data.json"); + + let patterns = vec![normalize_allowed_base_url(allowed.to_str().unwrap()).unwrap()]; + assert!(!is_url_allowed( + &format!("file://{}", file_path.display()), + &patterns + )); + } + + #[test] + #[cfg(not(target_os = "windows"))] + fn test_file_url_to_path_roundtrip() { + let path = "/tmp/my data/file.csv"; + let url = path_to_file_url(path).unwrap(); + let roundtrip = file_url_to_path(&url).unwrap(); + assert_eq!(roundtrip, PathBuf::from(path)); + } +} diff --git a/vegafusion-core/src/proto/services.proto b/vegafusion-core/src/proto/services.proto index 20599ba5e..c4f06981e 100644 --- a/vegafusion-core/src/proto/services.proto +++ b/vegafusion-core/src/proto/services.proto @@ -12,6 +12,9 @@ service VegaFusionRuntime { rpc PreTransformExtract(pretransform.PreTransformExtractRequest) returns (PreTransformExtractResult) {} } +// Multiplexed envelope used by the WASM runtime's query_fn callback, +// including gRPC-Web mode. Bundles task graph queries into one message +// so a single JS function can handle them. message QueryRequest { oneof request { tasks.TaskGraphValueRequest task_graph_values = 1; @@ -44,4 +47,4 @@ message PreTransformExtractResult { errors.Error error = 1; pretransform.PreTransformExtractResponse response = 2; } -} \ No newline at end of file +} diff --git a/vegafusion-core/src/proto/tasks.proto b/vegafusion-core/src/proto/tasks.proto index f44ec540e..045cf564a 100644 --- a/vegafusion-core/src/proto/tasks.proto +++ b/vegafusion-core/src/proto/tasks.proto @@ -4,8 +4,6 @@ package tasks; import "expression.proto"; import "transforms.proto"; -// ## Materialized Task Value -// Represents a fully materialized (computed) task value, either a scalar or table message MaterializedTaskValue { oneof data { /* @@ -183,4 +181,5 @@ message InlineDataset { InlineDatasetTable table = 1; InlineDatasetPlan plan = 2; } -} \ No newline at end of file +} + diff --git a/vegafusion-core/src/runtime/mod.rs b/vegafusion-core/src/runtime/mod.rs index 5e2831d2f..15578fb7f 100644 --- a/vegafusion-core/src/runtime/mod.rs +++ b/vegafusion-core/src/runtime/mod.rs @@ -1,5 +1,8 @@ -mod plan_resolver; mod runtime; -pub use plan_resolver::{PlanResolver, ResolutionResult}; +pub use crate::data::url::{ + canonicalize_path_for_policy_check, check_url_allowed, file_url_to_path, has_url_scheme, + is_absolute_path, is_url_allowed, normalize_allowed_base_url, normalize_allowed_base_urls, + normalize_base_url, path_to_file_url, resolve_url, AllowedBaseUrlPattern, ParsedUrl, +}; pub use runtime::{PreTransformExtractTable, VegaFusionRuntimeTrait}; diff --git a/vegafusion-core/src/runtime/plan_resolver.rs b/vegafusion-core/src/runtime/plan_resolver.rs deleted file mode 100644 index 9798b7155..000000000 --- a/vegafusion-core/src/runtime/plan_resolver.rs +++ /dev/null @@ -1,18 +0,0 @@ -use async_trait::async_trait; -use vegafusion_common::data::table::VegaFusionTable; -use vegafusion_common::datafusion_expr::LogicalPlan; -use vegafusion_common::error::Result; - -pub enum ResolutionResult { - /// Resolver fully materialized the plan - Table(VegaFusionTable), - /// Resolver produced a rewritten plan for the next resolver to handle, - /// or for DataFusion to execute if this is the last resolver - Plan(LogicalPlan), -} - -#[async_trait] -pub trait PlanResolver: Send + Sync + 'static { - fn name(&self) -> &str; - async fn resolve_plan(&self, plan: LogicalPlan) -> Result; -} diff --git a/vegafusion-core/src/spec/data.rs b/vegafusion-core/src/spec/data.rs index c7fd3e092..e6efd4824 100644 --- a/vegafusion-core/src/spec/data.rs +++ b/vegafusion-core/src/spec/data.rs @@ -49,15 +49,19 @@ impl DataSpec { signals.into_iter().sorted().collect() } + /// Formats that VegaFusion can read server-side. Anything else (e.g. topojson) + /// stays client-side for Vega JS to handle. + const SUPPORTED_FORMATS: &'static [&'static str] = &["csv", "tsv", "json", "arrow", "parquet"]; + pub fn supported( &self, planner_config: &PlannerConfig, task_scope: &TaskScope, scope: &[u32], ) -> DependencyNodeSupported { + // Check if the URL format is one VegaFusion can read if let Some(Some(format_type)) = self.format.as_ref().map(|fmt| fmt.type_.clone()) { - if !matches!(format_type.as_str(), "csv" | "tsv" | "arrow" | "json") { - // We don't know how to read the data, so full node is unsupported + if !Self::SUPPORTED_FORMATS.contains(&format_type.as_str()) { return DependencyNodeSupported::Unsupported; } } diff --git a/vegafusion-core/src/spec/visitors.rs b/vegafusion-core/src/spec/visitors.rs index 9f4fc1eba..d2cc9372b 100644 --- a/vegafusion-core/src/spec/visitors.rs +++ b/vegafusion-core/src/spec/visitors.rs @@ -165,28 +165,31 @@ impl ChartVisitor for MakeTasksVisitor<'_> { }; let task = if let Some(url) = &data.url { - let mut proto_url = match url { - StringOrSignalSpec::String(url) => Url::String(url.clone()), + let proto_url = match url { + StringOrSignalSpec::String(url) => { + // Store raw URL string — resolution happens at eval time + let mut proto_url = Url::String(url.clone()); + + // Append fingerprint to URL that references an inline dataset + if let Url::String(url_str) = &proto_url { + if let Some(inline_name) = extract_inline_dataset(url_str) { + let inline_name = inline_name.trim().to_string(); + if let Some(fingerprint) = self.dataset_fingerprints.get(&inline_name) { + proto_url = Url::String(format!("{url_str}#{fingerprint}")); + } else { + let fingerprint = random::(); + proto_url = Url::String(format!("{url_str}#{fingerprint}")); + } + } + } + proto_url + } StringOrSignalSpec::Signal(expr) => { let url_expr = parse(&expr.signal)?; Url::Expr(url_expr) } }; - // Append fingerprint to URL that references an inline dataset - if let Url::String(url) = &proto_url { - if let Some(inline_name) = extract_inline_dataset(url) { - let inline_name = inline_name.trim().to_string(); - if let Some(fingerprint) = self.dataset_fingerprints.get(&inline_name) { - proto_url = Url::String(format!("{url}#{fingerprint}")); - } else { - // Unknown fingerprint, use random id to break cache - let fingerprint = random::(); - proto_url = Url::String(format!("{url}#{fingerprint}")); - } - } - } - Task::new_data_url( data_var, scope, diff --git a/vegafusion-python/src/lib.rs b/vegafusion-python/src/lib.rs index 8ba93e50c..843702056 100644 --- a/vegafusion-python/src/lib.rs +++ b/vegafusion-python/src/lib.rs @@ -20,7 +20,7 @@ use vegafusion_core::proto::gen::pretransform::{ use vegafusion_core::proto::gen::tasks::{TzConfig, Variable}; use vegafusion_runtime::task_graph::GrpcVegaFusionRuntime; -use vegafusion_runtime::task_graph::runtime::VegaFusionRuntime; +use vegafusion_runtime::task_graph::runtime::{VegaFusionRuntime, VegaFusionRuntimeOpts}; use env_logger::{Builder, Target}; use serde_json::json; @@ -32,7 +32,8 @@ use vegafusion_core::task_graph::graph::ScopedVariable; use vegafusion_core::task_graph::task_value::MaterializedTaskValue; use vegafusion_runtime::tokio_runtime::TOKIO_THREAD_STACK_SIZE; -use vegafusion_core::runtime::{PlanResolver, VegaFusionRuntimeTrait}; +use vegafusion_core::runtime::VegaFusionRuntimeTrait; +use vegafusion_runtime::data::plan_resolver::PlanResolver; use vegafusion_runtime::task_graph::cache::VegaFusionCache; use crate::chart_state::PyChartState; @@ -66,9 +67,13 @@ impl PyVegaFusionRuntime { worker_threads: Option, resolvers: Vec>, use_current_thread: bool, + base_url: Option<&Bound>, + allowed_base_urls: Option>, ) -> PyResult { initialize_logging(); + let base_url_setting = parse_base_url(base_url)?; + let tokio_runtime_connection = if use_current_thread { tokio::runtime::Builder::new_current_thread() .enable_all() @@ -88,23 +93,56 @@ impl PyVegaFusionRuntime { }; Ok(Self { - runtime: Arc::new(VegaFusionRuntime::new( - Some(VegaFusionCache::new(max_capacity, memory_limit)), - resolvers, - )), + runtime: Arc::new(VegaFusionRuntime::new(VegaFusionRuntimeOpts { + cache: Some(VegaFusionCache::new(max_capacity, memory_limit)), + plan_resolvers: resolvers, + base_url: base_url_setting, + allowed_base_urls, + })?), tokio_runtime: Arc::new(tokio_runtime_connection), }) } } +/// Parse Python `base_url` argument into `BaseUrlSetting`. +/// +/// - `None` or `True` -> `BaseUrlSetting::Default` (CDN) +/// - `str` -> `BaseUrlSetting::Custom(s)` +/// - `False` -> `BaseUrlSetting::Disabled` +fn parse_base_url( + value: Option<&Bound>, +) -> PyResult { + use vegafusion_runtime::data::pipeline::BaseUrlSetting; + match value { + None => Ok(BaseUrlSetting::Default), + Some(obj) => { + if let Ok(b) = obj.extract::() { + if b { + Ok(BaseUrlSetting::Default) + } else { + Ok(BaseUrlSetting::Disabled) + } + } else if let Ok(s) = obj.extract::() { + Ok(BaseUrlSetting::Custom(s)) + } else { + Err(PyValueError::new_err( + "base_url must be a str, bool, or None", + )) + } + } + } +} + #[pymethods] impl PyVegaFusionRuntime { #[staticmethod] - #[pyo3(signature = (max_capacity=None, memory_limit=None, worker_threads=None))] + #[pyo3(signature = (max_capacity=None, memory_limit=None, worker_threads=None, base_url=None, allowed_base_urls=None))] pub fn new_embedded( max_capacity: Option, memory_limit: Option, worker_threads: Option, + base_url: Option<&Bound>, + allowed_base_urls: Option>, ) -> PyResult { Self::build_with_resolvers( max_capacity, @@ -112,16 +150,20 @@ impl PyVegaFusionRuntime { worker_threads, Vec::new(), false, + base_url, + allowed_base_urls, ) } #[staticmethod] - #[pyo3(signature = (py_resolvers, max_capacity=None, memory_limit=None, worker_threads=None))] + #[pyo3(signature = (py_resolvers, max_capacity=None, memory_limit=None, worker_threads=None, base_url=None, allowed_base_urls=None))] pub fn new_with_resolvers( py_resolvers: Vec>, max_capacity: Option, memory_limit: Option, worker_threads: Option, + base_url: Option<&Bound>, + allowed_base_urls: Option>, ) -> PyResult { let py_resolvers: Vec = py_resolvers .into_iter() @@ -141,6 +183,8 @@ impl PyVegaFusionRuntime { worker_threads, resolvers, use_current_thread, + base_url, + allowed_base_urls, ) } @@ -585,6 +629,68 @@ pub fn inline_table_scan_node(name: String, schema: pyo3_arrow::PySchema) -> PyR Ok(bytes.to_vec()) } +/// Build a LogicalPlanNode protobuf (as bytes) for an external table scan. +/// +/// Use this in `scan_url` implementations to create ExternalTableProvider plan +/// nodes that will later be resolved by `resolve_plan`. +/// +/// Args: +/// table_name: Name for the table in the plan. +/// scheme: Scheme identifier (e.g. "spark"). +/// schema: Arrow schema (arro3.core.Schema) — required for logical planning. +/// metadata: Optional JSON-serializable dict of metadata. +/// +/// Returns: +/// bytes: Serialized LogicalPlanNode protobuf. +#[pyfunction] +#[pyo3(signature = (table_name, scheme, schema, metadata=None))] +pub fn external_table_scan_node( + table_name: String, + scheme: String, + schema: pyo3_arrow::PySchema, + metadata: Option<&Bound<'_, pyo3::types::PyAny>>, +) -> PyResult> { + use datafusion::datasource::provider_as_source; + use datafusion_proto::bytes::logical_plan_to_bytes_with_extension_codec; + use vegafusion_common::datafusion_expr::LogicalPlanBuilder; + use vegafusion_runtime::data::codec::VegaFusionCodec; + use vegafusion_runtime::data::external_table::ExternalTableProvider; + + let arrow_schema = schema.into_inner(); + + let metadata_value: serde_json::Value = match metadata { + Some(obj) => pythonize::depythonize(obj).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("Failed to convert metadata dict: {e}")) + })?, + None => serde_json::Value::Object(serde_json::Map::new()), + }; + + let provider = Arc::new(ExternalTableProvider::new( + scheme, + arrow_schema, + metadata_value, + )); + let table_source = provider_as_source(provider); + + let plan = LogicalPlanBuilder::scan(&table_name, table_source, None) + .map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("Failed to build scan plan: {e}")) + })? + .build() + .map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("Failed to build plan: {e}")) + })?; + + let codec = VegaFusionCodec::new(); + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "Failed to serialize external table plan: {e}" + )) + })?; + + Ok(bytes.to_vec()) +} + /// A Python module implemented in Rust. The name of this function must match /// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to /// import the module. @@ -597,7 +703,9 @@ fn _vegafusion(_py: Python, m: &Bound) -> PyResult<()> { m.add_function(wrap_pyfunction!(get_virtual_memory, m)?)?; m.add_function(wrap_pyfunction!(get_cpu_count, m)?)?; m.add_function(wrap_pyfunction!(inline_table_scan_node, m)?)?; + m.add_function(wrap_pyfunction!(external_table_scan_node, m)?)?; m.add_function(wrap_pyfunction!(unparse::unparse_plan_to_sql, m)?)?; + m.add_function(wrap_pyfunction!(unparse::unparse_expr_to_sql, m)?)?; m.add("__version__", env!("CARGO_PKG_VERSION"))?; Ok(()) } diff --git a/vegafusion-python/src/plan_resolver.rs b/vegafusion-python/src/plan_resolver.rs index cb8fd9a0f..fab976ef5 100644 --- a/vegafusion-python/src/plan_resolver.rs +++ b/vegafusion-python/src/plan_resolver.rs @@ -15,9 +15,11 @@ use vegafusion_common::arrow::record_batch::RecordBatch; use vegafusion_common::data::table::VegaFusionTable; use vegafusion_common::datafusion_expr::LogicalPlan; use vegafusion_common::error::{Result, VegaFusionError}; -use vegafusion_core::runtime::{PlanResolver, ResolutionResult}; +use vegafusion_core::runtime::ParsedUrl; use vegafusion_runtime::data::codec::VegaFusionCodec; use vegafusion_runtime::data::external_table::ExternalTableProvider; +use vegafusion_runtime::data::plan_resolver::PlanResolver; +use vegafusion_runtime::data::plan_resolver::ResolutionResult; /// A `PlanResolver` that delegates to a Python object. /// @@ -28,33 +30,39 @@ pub struct PyPlanResolver { name: String, skip_when_no_external_tables: bool, thread_safe: bool, + has_scan_url_override: bool, } impl PyPlanResolver { pub fn new(py_resolver: Py) -> Self { - let (name, skip_when_no_external_tables, thread_safe) = Python::attach(|py| { - let name = py_resolver - .bind(py) - .get_type() - .qualname() - .map(|q| q.to_string()) - .unwrap_or_else(|_| "PyPlanResolver".to_string()); - let skip = py_resolver - .getattr(py, "skip_when_no_external_tables") - .and_then(|v| v.extract::(py)) - .unwrap_or(true); - let safe = py_resolver - .getattr(py, "thread_safe") - .and_then(|v| v.extract::(py)) - .unwrap_or(true); - (name, skip, safe) - }); + let (name, skip_when_no_external_tables, thread_safe, has_scan_url_override) = + Python::attach(|py| { + let name = py_resolver + .bind(py) + .get_type() + .qualname() + .map(|q| q.to_string()) + .unwrap_or_else(|_| "PyPlanResolver".to_string()); + let skip = py_resolver + .getattr(py, "skip_when_no_external_tables") + .and_then(|v| v.extract::(py)) + .unwrap_or(true); + let safe = py_resolver + .getattr(py, "thread_safe") + .and_then(|v| v.extract::(py)) + .unwrap_or(true); + // Check if the Python class overrides scan_url or scan_url_proto + let has_scan_url = Self::check_method_override(py, &py_resolver, "scan_url") + || Self::check_method_override(py, &py_resolver, "scan_url_proto"); + (name, skip, safe, has_scan_url) + }); Self { py_resolver, name, skip_when_no_external_tables, thread_safe, + has_scan_url_override, } } @@ -62,13 +70,28 @@ impl PyPlanResolver { pub fn thread_safe(&self) -> bool { self.thread_safe } + + /// Check if a Python method is overridden from the base class. + fn check_method_override(py: Python, obj: &Py, method_name: &str) -> bool { + // Get the method from the instance's class and compare to the base PlanResolver class + let result: PyResult = (|| { + let bound = obj.bind(py); + let cls = bound.get_type(); + let base_cls = py + .import("vegafusion.plan_resolver")? + .getattr("PlanResolver")?; + let cls_method = cls.getattr(method_name)?; + let base_method = base_cls.getattr(method_name)?; + Ok(!cls_method.is(&base_method)) + })(); + result.unwrap_or(false) + } } /// Info extracted from an ExternalTableProvider node in the plan. struct ExternalTableInfo { + scheme: String, schema: SchemaRef, - protocol: Option, - source: Option, metadata: Value, ref_id: Option, } @@ -88,9 +111,8 @@ fn extract_external_tables(plan: &LogicalPlan) -> HashMap( // Convert metadata to Python dict let py_metadata = pythonize::pythonize(py, &info.metadata)?; - // Reconstruct ExternalDataset(protocol, schema, metadata, data, source) + // Reconstruct ExternalDataset(scheme, schema, metadata, data) let kwargs = PyDict::new(py); - kwargs.set_item("protocol", info.protocol.as_deref())?; + kwargs.set_item("scheme", &info.scheme)?; kwargs.set_item("schema", py_schema)?; kwargs.set_item("metadata", py_metadata)?; kwargs.set_item("data", &data)?; - kwargs.set_item("source", info.source.as_deref())?; let dataset = dataset_cls.call((), Some(&kwargs))?; dict.set_item(table_name, dataset)?; } @@ -163,6 +184,82 @@ impl PlanResolver for PyPlanResolver { &self.name } + fn supports_arrow_tables(&self) -> bool { + Python::attach(|py| { + self.py_resolver + .getattr(py, "supports_arrow_tables") + .and_then(|v| v.extract::(py)) + .unwrap_or(false) + }) + } + + async fn scan_url(&self, parsed_url: &ParsedUrl) -> Result> { + if !self.has_scan_url_override { + return Ok(None); + } + + Python::attach(|py| { + // Serialize ParsedUrl to a Python dict + let dict = PyDict::new(py); + dict.set_item("url", &parsed_url.url) + .map_err(|e| VegaFusionError::internal(format!("Failed to set url: {e}")))?; + dict.set_item("scheme", &parsed_url.scheme) + .map_err(|e| VegaFusionError::internal(format!("Failed to set scheme: {e}")))?; + dict.set_item("host", parsed_url.host.as_deref()) + .map_err(|e| VegaFusionError::internal(format!("Failed to set host: {e}")))?; + dict.set_item("path", &parsed_url.path) + .map_err(|e| VegaFusionError::internal(format!("Failed to set path: {e}")))?; + // query_params as list of [key, value] pairs + let params: Vec<(&str, &str)> = parsed_url + .query_params + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect(); + dict.set_item("query_params", params).map_err(|e| { + VegaFusionError::internal(format!("Failed to set query_params: {e}")) + })?; + dict.set_item("extension", parsed_url.extension.as_deref()) + .map_err(|e| VegaFusionError::internal(format!("Failed to set extension: {e}")))?; + dict.set_item("format_type", parsed_url.format_type.as_deref()) + .map_err(|e| { + VegaFusionError::internal(format!("Failed to set format_type: {e}")) + })?; + + let result = self + .py_resolver + .call_method1(py, "scan_url_proto", (&dict,)) + .map_err(|e| { + VegaFusionError::internal(format!("Python scan_url_proto failed: {e}")) + })?; + + let result_ref = result.bind(py); + + if result_ref.is_none() { + return Ok(None); + } + + // Result is bytes — deserialize into LogicalPlan + let plan_bytes: Vec = result_ref.extract().map_err(|e| { + VegaFusionError::internal(format!( + "scan_url_proto must return bytes or None, got: {e}" + )) + })?; + + let ctx = vegafusion_runtime::datafusion::context::make_datafusion_context(); + let codec = VegaFusionCodec::new(); + let plan = datafusion_proto::bytes::logical_plan_from_bytes_with_extension_codec( + &plan_bytes, + &ctx.task_ctx(), + &codec, + ) + .map_err(|e| { + VegaFusionError::internal(format!("Failed to deserialize scan_url plan: {e}")) + })?; + + Ok(Some(plan)) + }) + } + async fn resolve_plan(&self, plan: LogicalPlan) -> Result { let tables = extract_external_tables(&plan); diff --git a/vegafusion-python/src/unparse.rs b/vegafusion-python/src/unparse.rs index 149b291d7..bb7bf255e 100644 --- a/vegafusion-python/src/unparse.rs +++ b/vegafusion-python/src/unparse.rs @@ -2,12 +2,31 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use datafusion::prelude::SessionContext; +use datafusion_proto::generated::datafusion::LogicalExprNode; +use datafusion_proto::logical_plan::from_proto::parse_expr; use datafusion_sql::unparser::dialect::{ - BigQueryDialect, DefaultDialect, DuckDBDialect, MySqlDialect, PostgreSqlDialect, SqliteDialect, + BigQueryDialect, DefaultDialect, Dialect, DuckDBDialect, MySqlDialect, PostgreSqlDialect, + SqliteDialect, }; use datafusion_sql::unparser::Unparser; +use prost::Message; +use vegafusion_common::datafusion_expr::Expr; use vegafusion_runtime::data::codec::VegaFusionCodec; +fn make_dialect(dialect: &str) -> PyResult> { + match dialect { + "default" => Ok(Box::new(DefaultDialect {})), + "postgres" | "postgresql" => Ok(Box::new(PostgreSqlDialect {})), + "mysql" => Ok(Box::new(MySqlDialect {})), + "sqlite" => Ok(Box::new(SqliteDialect {})), + "duckdb" => Ok(Box::new(DuckDBDialect::new())), + "bigquery" => Ok(Box::new(BigQueryDialect {})), + _ => Err(PyValueError::new_err(format!( + "Unknown dialect '{dialect}'. Supported: default, postgres, mysql, sqlite, duckdb, bigquery" + ))), + } +} + /// Convert a protobuf-serialized LogicalPlan to a SQL string. /// /// Args: @@ -29,39 +48,58 @@ pub fn unparse_plan_to_sql(plan_bytes: Vec, dialect: &str) -> PyResult { - let d = DefaultDialect {}; - Unparser::new(&d).plan_to_sql(&plan) - } - "postgres" | "postgresql" => { - let d = PostgreSqlDialect {}; - Unparser::new(&d).plan_to_sql(&plan) - } - "mysql" => { - let d = MySqlDialect {}; - Unparser::new(&d).plan_to_sql(&plan) - } - "sqlite" => { - let d = SqliteDialect {}; - Unparser::new(&d).plan_to_sql(&plan) - } - "duckdb" => { - let d = DuckDBDialect::new(); - Unparser::new(&d).plan_to_sql(&plan) - } - "bigquery" => { - let d = BigQueryDialect {}; - Unparser::new(&d).plan_to_sql(&plan) - } - _ => { - return Err(PyValueError::new_err(format!( - "Unknown dialect '{}'. Supported: default, postgres, mysql, sqlite, duckdb, bigquery", - dialect - ))); - } - } - .map_err(|e| PyValueError::new_err(format!("Failed to unparse plan to SQL: {e}")))?; + let d = make_dialect(dialect)?; + let sql = Unparser::new(d.as_ref()) + .plan_to_sql(&plan) + .map_err(|e| PyValueError::new_err(format!("Failed to unparse plan to SQL: {e}")))?; Ok(sql.to_string()) } + +/// Convert protobuf-serialized filter expressions to a SQL WHERE clause string. +/// +/// Accepts a single expression or a list of expressions (joined with AND). +/// +/// Args: +/// expr_bytes: A single serialized LogicalExprNode (bytes) or a list of them. +/// dialect: SQL dialect name. One of "default", "postgres", "mysql", +/// "sqlite", "duckdb", "bigquery". +/// +/// Returns: +/// The SQL string representation of the expression(s). +#[pyfunction] +#[pyo3(signature = (expr_bytes, dialect="default"))] +pub fn unparse_expr_to_sql(expr_bytes: Vec>, dialect: &str) -> PyResult { + if expr_bytes.is_empty() { + return Err(PyValueError::new_err( + "expr_bytes must contain at least one expression", + )); + } + + let ctx = SessionContext::new(); + let codec = VegaFusionCodec::new(); + + let exprs: Vec = expr_bytes + .iter() + .map(|bytes| { + let proto = LogicalExprNode::decode(bytes.as_slice()).map_err(|e| { + PyValueError::new_err(format!("Failed to decode LogicalExprNode: {e}")) + })?; + parse_expr(&proto, &ctx, &codec) + .map_err(|e| PyValueError::new_err(format!("Failed to parse expression: {e}"))) + }) + .collect::>>()?; + + // Join multiple expressions with AND + let combined = exprs + .into_iter() + .reduce(|a, b| a.and(b)) + .expect("non-empty after validation"); + + let d = make_dialect(dialect)?; + let sql_expr = Unparser::new(d.as_ref()) + .expr_to_sql(&combined) + .map_err(|e| PyValueError::new_err(format!("Failed to unparse expression to SQL: {e}")))?; + + Ok(sql_expr.to_string()) +} diff --git a/vegafusion-python/src/utils.rs b/vegafusion-python/src/utils.rs index 9de34cc1e..1da72cbfd 100644 --- a/vegafusion-python/src/utils.rs +++ b/vegafusion-python/src/utils.rs @@ -24,13 +24,12 @@ pub fn process_inline_datasets( .iter() .map(|(name, inline_dataset)| { let inline_dataset = inline_dataset; - let dataset = if inline_dataset.hasattr("protocol")? + let dataset = if inline_dataset.hasattr("scheme")? && inline_dataset.hasattr("schema")? && inline_dataset.hasattr("metadata")? { - // Handle ExternalDataset with .protocol, .schema, .metadata - let protocol: Option = - inline_dataset.getattr("protocol")?.extract()?; + // Handle ExternalDataset with .scheme, .schema, .metadata + let scheme: String = inline_dataset.getattr("scheme")?.extract()?; let pyschema = inline_dataset.getattr("schema")?.extract::()?; let schema = pyschema.into_inner(); let metadata_obj = inline_dataset.getattr("metadata")?; @@ -42,7 +41,7 @@ pub fn process_inline_datasets( })?; let provider = - Arc::new(ExternalTableProvider::new(schema, protocol, metadata)); + Arc::new(ExternalTableProvider::new(scheme, schema, metadata)); let table_source = provider_as_source(provider); let logical_plan = LogicalPlanBuilder::scan(name.to_string(), table_source, None) diff --git a/vegafusion-python/tests/test_plan_resolver.py b/vegafusion-python/tests/test_plan_resolver.py index 1f0ed761e..0ac1208aa 100644 --- a/vegafusion-python/tests/test_plan_resolver.py +++ b/vegafusion-python/tests/test_plan_resolver.py @@ -83,9 +83,7 @@ def test_passthrough_resolver() -> None: source_table = pa.table({"x": [1, 5, 10], "y": ["a", "b", "c"]}) expected_result = pa.table({"x": [5, 10], "y": ["b", "c"]}) - ext = ExternalDataset( - protocol="test", schema=source_table.schema, data=source_table - ) + ext = ExternalDataset(scheme="test", schema=source_table.schema, data=source_table) resolver = PassthroughResolver(result_table=expected_result) rt = vf.VegaFusionRuntime(plan_resolver=resolver) @@ -115,9 +113,7 @@ def test_deserializing_resolver() -> None: source_table = pa.table({"x": [1, 5, 10], "y": ["a", "b", "c"]}) expected_result = pa.table({"x": [5, 10], "y": ["b", "c"]}) - ext = ExternalDataset( - protocol="test", schema=source_table.schema, data=source_table - ) + ext = ExternalDataset(scheme="test", schema=source_table.schema, data=source_table) resolver = DeserializingResolver(result_table=expected_result) rt = vf.VegaFusionRuntime(plan_resolver=resolver) @@ -146,11 +142,10 @@ def test_external_dataset_registry() -> None: """ExternalDataset with data registers data in weakref registry.""" table = pa.table({"a": [1, 2, 3]}) ext = ExternalDataset( - protocol="test", schema=table.schema, data=table, metadata={"engine": "test"} + scheme="test", schema=table.schema, data=table, metadata={"engine": "test"} ) - assert ext.protocol == "test" - assert "_vf_protocol" not in ext.metadata # protocol is separate from metadata + assert ext.scheme == "test" assert "_vf_ref_id" in ext.metadata ref_id = ext.metadata["_vf_ref_id"] assert ExternalDataset.resolve_data(ref_id) is table @@ -161,7 +156,7 @@ def test_external_dataset_registry() -> None: def test_external_dataset_schema_only() -> None: """ExternalDataset without data does not register.""" schema = pa.schema([("x", pa.int64())]) - ext = ExternalDataset(protocol="test", schema=schema) + ext = ExternalDataset(scheme="test", schema=schema) assert "_vf_ref_id" not in ext.metadata assert ext.data is None @@ -317,9 +312,11 @@ def __init__(self) -> None: def resolve_table( self, name: str, + scheme: str, schema: Any, - metadata: dict[str, Any], + metadata: dict[str, Any] | None = None, projected_columns: list[str] | None = None, + filters: list[Any] | None = None, ) -> pa.Table: self.resolve_calls.append( { @@ -331,9 +328,7 @@ def resolve_table( return source_table resolver = TableResolver() - ext = ExternalDataset( - protocol="test", schema=source_table.schema, data=source_table - ) + ext = ExternalDataset(scheme="test", schema=source_table.schema, data=source_table) rt = vf.VegaFusionRuntime(plan_resolver=resolver) spec = simple_spec() @@ -406,9 +401,7 @@ def _replace_custom_scan( self._replace_custom_scan(child, target_name, replacement) resolver = ManualResolver() - ext = ExternalDataset( - protocol="test", schema=source_table.schema, data=source_table - ) + ext = ExternalDataset(scheme="test", schema=source_table.schema, data=source_table) rt = vf.VegaFusionRuntime(plan_resolver=resolver) spec = simple_spec() @@ -440,9 +433,11 @@ def __init__(self) -> None: def resolve_table( self, name: str, + scheme: str, schema: Any, - metadata: dict[str, Any], + metadata: dict[str, Any] | None = None, projected_columns: list[str] | None = None, + filters: list[Any] | None = None, ) -> pa.Table: self.resolved_names.append(name) if name == "source_a": @@ -469,8 +464,8 @@ def resolve_table( ], } - ext_a = ExternalDataset(protocol="test", schema=table_a.schema, data=table_a) - ext_b = ExternalDataset(protocol="test", schema=table_b.schema, data=table_b) + ext_a = ExternalDataset(scheme="test", schema=table_a.schema, data=table_a) + ext_b = ExternalDataset(scheme="test", schema=table_b.schema, data=table_b) resolver = MultiTableResolver() rt = vf.VegaFusionRuntime(plan_resolver=resolver) @@ -502,15 +497,15 @@ class FailingResolver(PlanResolver): def resolve_table( self, name: str, + scheme: str, schema: Any, - metadata: dict[str, Any], + metadata: dict[str, Any] | None = None, projected_columns: list[str] | None = None, + filters: list[Any] | None = None, ) -> pa.Table: raise ValueError("Simulated resolver failure") - ext = ExternalDataset( - protocol="test", schema=source_table.schema, data=source_table - ) + ext = ExternalDataset(scheme="test", schema=source_table.schema, data=source_table) resolver = FailingResolver() rt = vf.VegaFusionRuntime(plan_resolver=resolver) spec = simple_spec() @@ -581,9 +576,7 @@ def resolve_plan_proto( return source_table resolver = SqlCapturingResolver() - ext = ExternalDataset( - protocol="test", schema=source_table.schema, data=source_table - ) + ext = ExternalDataset(scheme="test", schema=source_table.schema, data=source_table) rt = vf.VegaFusionRuntime(plan_resolver=resolver) spec = simple_spec() @@ -632,9 +625,7 @@ def resolve_plan(self, logical_plan: Any, datasets: dict[str, Any]) -> pa.Table: return source_table resolver = ProtoCapturingResolver() - ext = ExternalDataset( - protocol="test", schema=source_table.schema, data=source_table - ) + ext = ExternalDataset(scheme="test", schema=source_table.schema, data=source_table) rt = vf.VegaFusionRuntime(plan_resolver=resolver) spec = simple_spec() @@ -648,19 +639,14 @@ def resolve_plan(self, logical_plan: Any, datasets: dict[str, Any]) -> pa.Table: assert resolver.sql_from_proto is not None assert resolver.sql_from_bytes is not None + # Both paths (bytes and proto message) produce identical SQL assert resolver.sql_from_proto == resolver.sql_from_bytes - # Verify the SQL references the external table name - assert resolver.sql_from_proto == snapshot( - 'SELECT "x", "y" FROM (SELECT "_vf_order" AS "_vf_order", "source"."x" AS "x", "source"."y" AS "y" FROM (SELECT row_number() OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "_vf_order", "source"."x", "source"."y" FROM "source") AS "derived_projection") AS "derived_projection" WHERE CASE WHEN ("x" > 3.0) IS NULL THEN false ELSE ("x" > 3.0) END ORDER BY "_vf_order" ASC NULLS LAST' - ) def test_external_dataset_without_resolver_raises() -> None: """ExternalDataset without a plan resolver raises ValueError with helpful message.""" source_table = pa.table({"x": [1, 2, 3]}) - ext = ExternalDataset( - protocol="spark", schema=source_table.schema, data=source_table - ) + ext = ExternalDataset(scheme="spark", schema=source_table.schema, data=source_table) rt = vf.VegaFusionRuntime() # No resolver spec = simple_spec() @@ -694,9 +680,7 @@ def resolve_plan_proto( return source_table resolver = DialectTestResolver() - ext = ExternalDataset( - protocol="test", schema=source_table.schema, data=source_table - ) + ext = ExternalDataset(scheme="test", schema=source_table.schema, data=source_table) rt = vf.VegaFusionRuntime(plan_resolver=resolver) rt.pre_transform_datasets( @@ -708,3 +692,292 @@ def resolve_plan_proto( assert resolver.error is not None assert "Unknown dialect" in str(resolver.error) + + +def test_scan_url_called_with_structured_dict() -> None: + """scan_url receives a structured dict with parsed URL fields.""" + from vegafusion.plan_resolver import external_table_scan_node + + received_urls: list[dict[str, Any]] = [] + + class UrlCapturingResolver(PlanResolver): + def scan_url(self, parsed_url: dict[str, Any]) -> Any: + received_urls.append(parsed_url) + # Create an ExternalTableProvider plan node + schema = pa.schema([("x", pa.int64()), ("y", pa.utf8())]) + return external_table_scan_node( + table_name="captured", + schema=schema, + scheme="test", + metadata={"source_url": parsed_url["url"]}, + ) + + def resolve_table( + self, + name: str, + scheme: str, + schema: Any, + metadata: dict[str, Any] | None = None, + projected_columns: list[str] | None = None, + filters: list[Any] | None = None, + ) -> pa.Table: + return pa.table({"x": [1, 2], "y": ["a", "b"]}) + + resolver = UrlCapturingResolver() + rt = vf.VegaFusionRuntime(plan_resolver=resolver) + + spec = { + "$schema": "https://vega.github.io/schema/vega/v5.json", + "data": [ + { + "name": "source", + "url": "https://example.com/data.csv?limit=10&format=raw", + "format": {"type": "csv"}, + } + ], + } + + rt.pre_transform_datasets(spec, datasets=["source"], dataset_format="pyarrow") + + assert len(received_urls) == 1 + url_dict = received_urls[0] + assert url_dict["scheme"] == "https" + assert url_dict["host"] == "example.com" + assert url_dict["url"].startswith("https://example.com/data.csv") + assert url_dict["extension"] == "csv" + assert url_dict["format_type"] == "csv" + # Query params preserved + assert isinstance(url_dict["query_params"], list) + + +def test_scan_url_none_falls_back_to_datafusion() -> None: + """scan_url returning None causes DataFusion to handle the URL.""" + + class NoOpScanner(PlanResolver): + def __init__(self) -> None: + self.scan_url_called = False + + def scan_url(self, parsed_url: dict[str, Any]) -> Any: + self.scan_url_called = True + return None # Pass to next resolver (DataFusion) + + csv_path = os.path.join(tempfile.gettempdir(), "vf_scan_fallback.csv") + table = pa.table({"x": [1, 5, 10]}) + pcsv.write_csv(table, csv_path) + + resolver = NoOpScanner() + rt = vf.VegaFusionRuntime(plan_resolver=resolver) + + spec = { + "$schema": "https://vega.github.io/schema/vega/v5.json", + "data": [ + { + "name": "source", + "url": csv_path, + "format": {"type": "csv"}, + } + ], + } + + datasets, _warnings = rt.pre_transform_datasets( + spec, datasets=["source"], dataset_format="pyarrow" + ) + + assert resolver.scan_url_called + assert len(datasets) == 1 + assert datasets[0].num_rows == 3 + + +def test_custom_scheme_via_scan_url() -> None: + """Custom scheme URLs are handled via scan_url at runtime.""" + from vegafusion.plan_resolver import external_table_scan_node + + class CustomSchemeResolver(PlanResolver): + def scan_url(self, parsed_url: dict[str, Any]) -> Any: + if parsed_url["scheme"] == "myproto": + schema = pa.schema([("val", pa.int64())]) + return external_table_scan_node( + table_name="custom_data", + schema=schema, + scheme="myproto", + ) + return None + + def resolve_table( + self, + name: str, + scheme: str, + schema: Any, + metadata: dict[str, Any] | None = None, + projected_columns: list[str] | None = None, + filters: list[Any] | None = None, + ) -> pa.Table: + return pa.table({"val": [42, 99]}) + + resolver = CustomSchemeResolver() + rt = vf.VegaFusionRuntime(plan_resolver=resolver) + + spec = { + "$schema": "https://vega.github.io/schema/vega/v5.json", + "data": [ + { + "name": "source", + "url": "myproto://database/table1", + } + ], + } + + datasets, _warnings = rt.pre_transform_datasets( + spec, datasets=["source"], dataset_format="pyarrow" + ) + + assert len(datasets) == 1 + assert datasets[0].column("val").to_pylist() == [42, 99] + + +def test_scan_url_not_called_without_override() -> None: + """Resolver without scan_url override does not trigger Python roundtrip.""" + + class SimpleResolver(PlanResolver): + """Only overrides resolve_table — scan_url is not overridden.""" + + def resolve_table( + self, + name: str, + scheme: str, + schema: Any, + metadata: dict[str, Any] | None = None, + projected_columns: list[str] | None = None, + filters: list[Any] | None = None, + ) -> pa.Table: + return pa.table({"x": [1, 2, 3]}) + + source_table = pa.table({"x": [1, 2, 3]}) + ext = ExternalDataset(scheme="test", schema=source_table.schema, data=source_table) + resolver = SimpleResolver() + rt = vf.VegaFusionRuntime(plan_resolver=resolver) + + spec = simple_spec() + # This exercises the code path where check_method_override detects no + # scan_url override, so the Rust side skips the Python call entirely. + # If the detection were wrong, the base class scan_url (returning None) + # would still work, but we'd pay an unnecessary Python roundtrip. + datasets, _warnings = rt.pre_transform_datasets( + spec, + datasets=["filtered"], + inline_datasets={"source": ext}, + dataset_format="pyarrow", + ) + + assert len(datasets) == 1 + + +def test_resolve_table_with_filter_transform() -> None: + """resolve_table works with a Vega filter transform; filter is applied after resolution.""" + from vegafusion.plan_resolver import external_table_scan_node + + class FilterCapturingResolver(PlanResolver): + def __init__(self) -> None: + self.captured_filters: list[Any] = [] + + def scan_url(self, parsed_url: dict[str, Any]) -> Any: + if parsed_url["scheme"] == "myproto": + schema = pa.schema([("x", pa.int64()), ("y", pa.utf8())]) + return external_table_scan_node( + table_name="data", + schema=schema, + scheme="myproto", + ) + return None + + def resolve_table( + self, + name: str, + scheme: str, + schema: Any, + metadata: dict[str, Any] | None = None, + projected_columns: list[str] | None = None, + filters: list[Any] | None = None, + ) -> pa.Table: + self.captured_filters.extend(filters or []) + return pa.table({"x": [1, 5, 10], "y": ["a", "b", "c"]}) + + resolver = FilterCapturingResolver() + rt = vf.VegaFusionRuntime(plan_resolver=resolver) + + spec = { + "$schema": "https://vega.github.io/schema/vega/v5.json", + "data": [ + { + "name": "source", + "url": "myproto://db/table", + "transform": [{"type": "filter", "expr": "datum.x > 3"}], + } + ], + } + + datasets, _warnings = rt.pre_transform_datasets( + spec, + datasets=["source"], + dataset_format="pyarrow", + ) + + assert len(datasets) == 1 + # Filter is applied by DataFusion after resolve_table returns + result = datasets[0] + assert result.column("x").to_pylist() == [5, 10] + + # TODO: filters should be pushed down to resolve_table so resolvers can + # optimize data loading. Currently blocked because VegaFusion's _vf_order + # window sits between the scan and user filters, preventing DataFusion's + # PushDownFilter from reaching the ExternalTableProvider. + assert resolver.captured_filters == [] + + +def test_unparse_expr_to_sql() -> None: + """unparse_expr_to_sql converts proto expressions to SQL strings.""" + from vegafusion.plan_resolver import unparse_expr_to_sql + from vegafusion.proto.datafusion_pb2 import ( + BinaryExprNode, + LogicalExprNode, + ) + from vegafusion.proto.datafusion.proto_common.proto.datafusion_common_pb2 import ( + Column as ColumnProto, + ScalarValue, + ) + + # Build proto for: x > 3 + col_x = LogicalExprNode(column=ColumnProto(name="x")) + lit_3 = LogicalExprNode( + literal=ScalarValue(int64_value=3), + ) + gt_expr = LogicalExprNode( + binary_expr=BinaryExprNode( + operands=[col_x, lit_3], + op="Gt", + ) + ) + + # Build proto for: y = 'hello' + col_y = LogicalExprNode(column=ColumnProto(name="y")) + lit_hello = LogicalExprNode( + literal=ScalarValue(utf8_value="hello"), + ) + eq_expr = LogicalExprNode( + binary_expr=BinaryExprNode( + operands=[col_y, lit_hello], + op="Eq", + ) + ) + + # Single expression + sql_single = unparse_expr_to_sql(gt_expr) + assert sql_single == snapshot("(x > 3)") + + # Multiple expressions joined with AND + sql_multi = unparse_expr_to_sql([gt_expr, eq_expr]) + assert sql_multi == snapshot("((x > 3) AND (y = 'hello'))") + + # With postgres dialect + sql_pg = unparse_expr_to_sql([gt_expr, eq_expr], dialect="postgres") + assert sql_pg == snapshot('(("x" > 3) AND ("y" = \'hello\'))') diff --git a/vegafusion-python/tests/test_runtime_config.py b/vegafusion-python/tests/test_runtime_config.py new file mode 100644 index 000000000..2721679e7 --- /dev/null +++ b/vegafusion-python/tests/test_runtime_config.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import pytest + +import vegafusion as vf +import vegafusion._vegafusion as _core + + +def test_runtime_exposes_url_policy_properties() -> None: + rt = vf.VegaFusionRuntime( + memory_limit=1, + worker_threads=1, + base_url="https://example.com/data/", + allowed_base_urls=["https://example.com/data/"], + ) + + assert rt.base_url == "https://example.com/data/" + assert rt.allowed_base_urls == ["https://example.com/data/"] + + +def test_runtime_passes_url_policy_to_embedded_runtime( + monkeypatch: pytest.MonkeyPatch, +) -> None: + calls: list[dict[str, object]] = [] + + class FakeRuntime: + def clear_cache(self) -> None: + return None + + class FakePyVegaFusionRuntime: + @staticmethod + def new_embedded( + cache_capacity: int, + memory_limit: int, + worker_threads: int, + base_url: str | bool | None = None, + allowed_base_urls: list[str] | None = None, + ) -> FakeRuntime: + calls.append( + { + "cache_capacity": cache_capacity, + "memory_limit": memory_limit, + "worker_threads": worker_threads, + "base_url": base_url, + "allowed_base_urls": allowed_base_urls, + } + ) + return FakeRuntime() + + monkeypatch.setattr(_core, "PyVegaFusionRuntime", FakePyVegaFusionRuntime) + + rt = vf.VegaFusionRuntime( + cache_capacity=8, + memory_limit=256, + worker_threads=2, + base_url=False, + allowed_base_urls=["file:///tmp/allowed/"], + ) + + _ = rt.runtime + + assert calls == [ + { + "cache_capacity": 8, + "memory_limit": 256, + "worker_threads": 2, + "base_url": False, + "allowed_base_urls": ["file:///tmp/allowed/"], + } + ] + + +def test_grpc_connect_rejects_local_url_policy() -> None: + rt = vf.VegaFusionRuntime(base_url=False) + + with pytest.raises(ValueError, match="base_url or allowed_base_urls"): + rt.grpc_connect("http://127.0.0.1:50051") + + rt = vf.VegaFusionRuntime(allowed_base_urls=[]) + + with pytest.raises(ValueError, match="base_url or allowed_base_urls"): + rt.grpc_connect("http://127.0.0.1:50051") + + +def test_url_policy_setters_reject_changes_while_using_grpc( + monkeypatch: pytest.MonkeyPatch, +) -> None: + calls: list[str] = [] + + class FakeRuntime: + def clear_cache(self) -> None: + calls.append("clear_cache") + + class FakePyVegaFusionRuntime: + @staticmethod + def new_grpc(url: str) -> FakeRuntime: + calls.append(url) + return FakeRuntime() + + monkeypatch.setattr(_core, "PyVegaFusionRuntime", FakePyVegaFusionRuntime) + + rt = vf.VegaFusionRuntime() + rt.grpc_connect("http://127.0.0.1:50051") + + with pytest.raises(ValueError, match="vegafusion-server"): + rt.base_url = False + + with pytest.raises(ValueError, match="vegafusion-server"): + rt.allowed_base_urls = [] + + assert calls == ["http://127.0.0.1:50051"] diff --git a/vegafusion-python/vegafusion/dataset.py b/vegafusion-python/vegafusion/dataset.py index 3ad9e10c2..13fed4a89 100644 --- a/vegafusion-python/vegafusion/dataset.py +++ b/vegafusion-python/vegafusion/dataset.py @@ -17,12 +17,12 @@ def __init__(self, data: Any) -> None: # noqa: ANN401 class ExternalDataset: - """External dataset with protocol, schema, metadata, and optional data ref. + """External dataset with scheme, schema, metadata, and optional data ref. - The ``protocol`` parameter is an optional short identifier for the data - source type (e.g. ``"spark"``, ``"snowflake"``, ``"duckdb"``). It is - propagated through protobuf separately from metadata so that error - messages can name the source when no resolver is registered. + The ``scheme`` parameter identifies the data source type + (e.g. ``"spark"``, ``"snowflake"``, ``"duckdb"``). It is propagated + through protobuf separately from metadata so that error messages can + name the source when no resolver is registered. When ``data`` is provided, it is registered in a class-level :class:`weakref.WeakValueDictionary` keyed by a UUID. The UUID is @@ -42,17 +42,15 @@ class ExternalDataset: def __init__( self, - protocol: str | None = None, - schema: Any = None, # noqa: ANN401 + scheme: str, + schema: Any, # noqa: ANN401 metadata: dict[str, Any] | None = None, data: Any = None, # noqa: ANN401 - source: str | None = None, ) -> None: self._schema: Schema = ( Schema.from_arrow(schema) if not isinstance(schema, Schema) else schema ) - self._protocol = protocol - self._source = source + self._scheme = scheme self._metadata: dict[str, Any] = dict(metadata) if metadata else {} self._data: Any = data self._data_ref: _DataRef | None = None @@ -74,18 +72,21 @@ def resolve_data(cls, ref_id: str) -> Any | None: # noqa: ANN401 return data_ref.data if data_ref is not None else None @property - def protocol(self) -> str | None: - """Optional short identifier for the data source type (e.g. ``"spark"``).""" - return self._protocol + def scheme(self) -> str: + """Short identifier for the data source type (e.g. ``"spark"``).""" + return self._scheme @property def schema(self) -> Schema: + """Arrow schema of the external table (``arro3.core.Schema``).""" return self._schema @property def metadata(self) -> dict[str, Any]: + """JSON-serializable metadata dict propagated through the plan.""" return self._metadata @property def data(self) -> Any: # noqa: ANN401 + """The opaque data object, or ``None`` if not provided.""" return self._data diff --git a/vegafusion-python/vegafusion/plan_resolver.py b/vegafusion-python/vegafusion/plan_resolver.py index c29f893d5..34cab1009 100644 --- a/vegafusion-python/vegafusion/plan_resolver.py +++ b/vegafusion-python/vegafusion/plan_resolver.py @@ -12,13 +12,14 @@ _PROTOBUF_INSTALL_HINT = ( "The 'protobuf' package is required for plan-level resolvers " "(resolve_plan / resolve_plan_proto) and related utilities " - "(inline_table_scan_node, unparse_to_sql). " + "(inline_table_scan_node, unparse_to_sql, unparse_expr_to_sql). " "Install it with: pip install vegafusion[plan-resolver]" ) if TYPE_CHECKING: from vegafusion.dataset import ExternalDataset from vegafusion.proto.datafusion_pb2 import ( + LogicalExprNode, # type: ignore[attr-defined] LogicalPlanNode, # type: ignore[attr-defined] ) @@ -46,10 +47,15 @@ class ResolvedPlan: class PlanResolver: """Base class for plan resolvers. - Override one of these (checked in priority order): + Override one of these (simplest first): - 1. ``resolve_table`` — provide data for each external table independently - 2. ``resolve_plan_proto`` / ``resolve_plan`` — full control over resolution + - ``resolve_table``: return data for each external table independently. + The default ``resolve_plan`` walks the plan and calls this for every + ``ExternalTableProvider`` node. + - ``resolve_plan_proto`` / ``resolve_plan``: receive the entire logical + plan. Overriding this supersedes ``resolve_table`` since the runtime + calls ``resolve_plan`` directly; ``resolve_table`` is only reached + via the default implementation. For ``resolve_plan``, override either the ``_proto`` variant (raw bytes) or the non-``_proto`` variant (deserialized ``LogicalPlanNode``). The ``_proto`` @@ -75,12 +81,59 @@ class PlanResolver: callbacks run on the main thread. Set to False for backends with thread-affine connections (e.g. DuckDB in-memory databases).""" + supports_arrow_tables: bool = False + """Whether this resolver can efficiently consume in-memory Arrow tables. + When all resolvers in the pipeline return True, the runtime may eagerly + materialize LogicalPlans into tables. When False, data is kept as lazy + plans so resolvers that need plan-level access can intercept them.""" + + def scan_url_proto(self, parsed_url: dict[str, Any]) -> bytes | None: + """Handle a URL during the scan phase (raw bytes variant). + + The default implementation delegates to :meth:`scan_url` which works + with deserialized ``LogicalPlanNode`` messages. + + Args: + parsed_url: Dict with keys ``url``, ``scheme``, ``host``, ``path``, + ``query_params``, ``extension``, ``format_type``. + + Returns: + Serialized ``LogicalPlanNode`` bytes, or None to pass to the next + resolver. + """ + result = self.scan_url(parsed_url) + if result is None: + return None + if isinstance(result, bytes): + return result + # It's a LogicalPlanNode proto message + return result.SerializeToString() + + def scan_url(self, parsed_url: dict[str, Any]) -> LogicalPlanNode | bytes | None: + """Handle a URL during the scan phase. + + Override to claim URLs by returning a ``LogicalPlanNode`` or raw bytes. + Use :func:`external_table_scan_node` to build ``ExternalTableProvider`` + plan nodes that will later be resolved by :meth:`resolve_plan`. + + Args: + parsed_url: Dict with keys ``url``, ``scheme``, ``host``, ``path``, + ``query_params``, ``extension``, ``format_type``. + + Returns: + A ``LogicalPlanNode``, raw bytes, or None to pass to the next + resolver. + """ + return None + def resolve_table( self, name: str, + scheme: str, schema: Schema, - metadata: dict[str, Any], + metadata: dict[str, Any] | None = None, projected_columns: list[str] | None = None, + filters: list[Any] | None = None, ) -> Table: """Provide data for an external table reference. @@ -88,10 +141,18 @@ def resolve_table( Args: name: Table name from the plan. + scheme: URL scheme identifier (e.g. ``"spark"``, + ``"snowflake"``). schema: Full schema of the external table. metadata: JSON metadata dict from ExternalTableProvider. projected_columns: Column names DataFusion actually needs. None if no projection (all columns needed). + filters: Pushed-down filter predicates from DataFusion as + ``LogicalExprNode`` protobuf messages, already split into + a conjunction (individual expressions from AND). These are + hints — resolvers may apply some, all, or none. DataFusion + re-applies all filters on the output regardless. Use + :func:`unparse_expr_to_sql` to convert to SQL strings. Returns: An Arrow-compatible table (arro3, PyArrow, etc.). @@ -106,7 +167,20 @@ def resolve_plan_proto( """Resolve a plan given raw protobuf bytes. The default implementation deserializes into a - LogicalPlanNode and calls resolve_plan(). + ``LogicalPlanNode`` and delegates to :meth:`resolve_plan`. + + Override this (instead of ``resolve_plan``) when you only need + the serialized bytes, e.g. to pass them directly to + :func:`unparse_to_sql` without a deserialization round-trip. + + Args: + plan_bytes: Serialized ``LogicalPlanNode`` protobuf bytes. + datasets: Dict mapping table names to :class:`ExternalDataset` + instances for every ``ExternalTableProvider`` in the plan. + + Returns: + An Arrow-compatible table (full execution) or a + :class:`ResolvedPlan` (plan rewriting with sidecar data). """ try: from vegafusion.proto.datafusion_pb2 import ( @@ -133,12 +207,24 @@ def resolve_plan( logical_plan: LogicalPlanNode, datasets: dict[str, ExternalDataset], ) -> ResolutionResult: - """Resolve a plan given a deserialized LogicalPlanNode. + """Resolve a plan given a deserialized ``LogicalPlanNode``. + + The default implementation walks the plan tree, finds + ``ExternalTableProvider`` nodes, calls :meth:`resolve_table` for + each, and replaces them with :func:`inline_table_scan_node` markers. - The default implementation walks the plan tree looking for - CustomTableScanNode nodes that correspond to ExternalTableProvider - entries. For each, it calls resolve_table() and replaces the node - with an inline_table_scan_node. + Override this for full control over plan rewriting, e.g. + to transpile the plan to SQL and execute it remotely. + + Args: + logical_plan: Deserialized ``LogicalPlanNode`` protobuf message. + datasets: Dict mapping table names to :class:`ExternalDataset` + instances for every ``ExternalTableProvider`` in the plan. + + Returns: + An Arrow-compatible table (for full execution by the resolver) + or a :class:`ResolvedPlan` (rewritten plan with sidecar Arrow + data for DataFusion to execute). """ sidecar: dict[str, Table] = {} self._resolve_external_tables(logical_plan, datasets, sidecar) @@ -178,11 +264,15 @@ def _resolve_external_tables( table_name, ) + filters = list(inner.filters) if inner.filters else None + table_data = self.resolve_table( name=table_name, + scheme=dataset.scheme, schema=dataset.schema, metadata=metadata, projected_columns=projected_columns, + filters=filters, ) replacement = inline_table_scan_node( @@ -295,6 +385,48 @@ def inline_table_scan_node( return node +def external_table_scan_node( + table_name: str, + scheme: str, + schema: Schema, + metadata: dict[str, Any] | None = None, +) -> LogicalPlanNode: + """Build a LogicalPlanNode for an external table scan. + + Use this in :meth:`PlanResolver.scan_url` implementations to create + ``ExternalTableProvider`` plan nodes that will later be resolved by + :meth:`PlanResolver.resolve_plan`. + + Args: + table_name: Name for the table in the plan. + scheme: Scheme identifier (e.g. ``"spark"``). + schema: Arrow schema (arro3.core.Schema) — required for logical planning. + metadata: Optional JSON-serializable dict of metadata. + + Returns: + A deserialized LogicalPlanNode protobuf message. + """ + from vegafusion._vegafusion import external_table_scan_node as _native + + try: + from vegafusion.proto.datafusion_pb2 import ( + LogicalPlanNode, # type: ignore[attr-defined] + ) + except ImportError as e: + raise ImportError(_PROTOBUF_INSTALL_HINT) from e + + node = LogicalPlanNode() + node.ParseFromString( + _native( + table_name=table_name, + scheme=scheme, + schema=schema, + metadata=metadata, + ) + ) + return node + + def unparse_to_sql( plan: bytes | LogicalPlanNode, dialect: str = "default", @@ -317,3 +449,39 @@ def unparse_to_sql( if not isinstance(plan, bytes): plan = plan.SerializeToString() return str(_native(plan, dialect)) + + +def unparse_expr_to_sql( + exprs: LogicalExprNode | bytes | list[LogicalExprNode | bytes], + dialect: str = "default", +) -> str: + """Convert filter expression(s) to a SQL string. + + Accepts a single ``LogicalExprNode`` protobuf message or a list of them. + Multiple expressions are joined with ``AND``. + + This is useful for converting the ``filters`` parameter of + :meth:`PlanResolver.resolve_table` into a SQL WHERE clause that can + be passed to external data sources. + + Args: + exprs: A single ``LogicalExprNode`` or a list of them. + dialect: SQL dialect. One of ``"default"``, ``"postgres"``, + ``"mysql"``, ``"sqlite"``, ``"duckdb"``, ``"bigquery"``. + + Returns: + The SQL string representation of the expression(s). + """ + from vegafusion._vegafusion import unparse_expr_to_sql as _native + + if not isinstance(exprs, list): + exprs = [exprs] + + expr_bytes = [] + for expr in exprs: + if isinstance(expr, bytes): + expr_bytes.append(expr) + else: + expr_bytes.append(expr.SerializeToString()) + + return str(_native(expr_bytes, dialect)) diff --git a/vegafusion-python/vegafusion/runtime.py b/vegafusion-python/vegafusion/runtime.py index 19d8bf464..9209a46e2 100644 --- a/vegafusion-python/vegafusion/runtime.py +++ b/vegafusion-python/vegafusion/runtime.py @@ -207,6 +207,8 @@ def __init__( | list[PlanResolver] | tuple[PlanResolver, ...] | None = None, + base_url: str | bool | None = None, + allowed_base_urls: list[str] | None = None, ) -> None: """ Initialize a VegaFusionRuntime. @@ -220,6 +222,17 @@ def __init__( Can be a single resolver or a list of resolvers that form a pipeline (executed in order; short-circuits on first Table result). + base_url: Base URL for resolving relative data URLs. + - None or True: use the default CDN + (https://raw.githubusercontent.com/vega/vega-datasets/v2.3.0/) + - str: custom base URL (scheme URL or absolute path) + - False: disabled; relative paths produce an error + allowed_base_urls: Optional allowlist for external data access. + - None: unrestricted for embedded VegaFusion runtimes + - []: deny all external data access + - list[str]: allow matching URL/path patterns only + Policy checks apply to the initial resolved URL only; redirect + destinations are not re-checked after a fetch begins. """ self._runtime = None self._grpc_url: str | None = None @@ -227,6 +240,20 @@ def __init__( self._memory_limit = memory_limit self._worker_threads = worker_threads self._plan_resolvers = _normalize_resolvers(plan_resolver) + self._base_url = base_url + self._allowed_base_urls = allowed_base_urls + + def _has_non_default_url_policy(self) -> bool: + return self._base_url not in (None, True) or self._allowed_base_urls is not None + + def _ensure_not_using_grpc_for_url_policy_change(self) -> None: + if self._grpc_url is not None: + raise ValueError( + "Cannot change base_url or allowed_base_urls " + "while using a gRPC runtime. " + "Configure these on the vegafusion-server " + "process instead." + ) @property def runtime(self) -> PyVegaFusionRuntime: @@ -240,6 +267,10 @@ def runtime(self) -> PyVegaFusionRuntime: # Try to initialize a VegaFusion runtime from vegafusion._vegafusion import PyVegaFusionRuntime + if self._grpc_url is not None: + self._runtime = PyVegaFusionRuntime.new_grpc(self._grpc_url) + return self._runtime + if self.memory_limit is None: self.memory_limit = get_virtual_memory() // 2 if self.worker_threads is None: @@ -251,12 +282,16 @@ def runtime(self) -> PyVegaFusionRuntime: self.cache_capacity, self.memory_limit, self.worker_threads, + base_url=self._base_url, + allowed_base_urls=self._allowed_base_urls, ) else: self._runtime = PyVegaFusionRuntime.new_embedded( self.cache_capacity, self.memory_limit, self.worker_threads, + base_url=self._base_url, + allowed_base_urls=self._allowed_base_urls, ) return self._runtime @@ -273,6 +308,13 @@ def grpc_connect(self, url: str) -> None: "Plan resolvers run locally and are not supported " "with remote gRPC runtimes." ) + if self._has_non_default_url_policy(): + raise ValueError( + "Cannot use grpc_connect with local " + "base_url or allowed_base_urls settings. " + "Configure URL policy on the " + "vegafusion-server process instead." + ) from vegafusion._vegafusion import PyVegaFusionRuntime @@ -396,7 +438,7 @@ def _import_inline_datasets( # Validate: ExternalDatasets require a plan resolver if external_dataset_refs and not self._plan_resolvers: details = [ - f" - {name!r} (protocol={value.protocol!r})" + f" - {name!r} (scheme={value.scheme!r})" for name, value in inline_datasets.items() if isinstance(value, ExternalDataset) ] @@ -883,6 +925,56 @@ def cache_capacity(self, value: int) -> None: self._cache_capacity = value self.reset() + @property + def base_url(self) -> str | bool | None: + """ + Get the base URL setting. + + Returns: + The current base_url setting. + """ + return self._base_url + + @base_url.setter + def base_url(self, value: str | bool | None) -> None: + """ + Set the base URL and restart the runtime. + + Args: + value: Base URL for resolving relative data URLs. + - None or True: use the default CDN + - str: custom base URL + - False: disabled + """ + if value != self._base_url: + self._ensure_not_using_grpc_for_url_policy_change() + self._base_url = value + self.reset() + + @property + def allowed_base_urls(self) -> list[str] | None: + """ + Get the allowed_base_urls setting. + + Returns: + The current allowed_base_urls setting. + """ + return self._allowed_base_urls + + @allowed_base_urls.setter + def allowed_base_urls(self, value: list[str] | None) -> None: + """ + Set the external data allowlist and restart the runtime. + + Args: + value: None for unrestricted embedded access, [] to deny all external + access, or a list of URL/path patterns to allow. + """ + if value != self._allowed_base_urls: + self._ensure_not_using_grpc_for_url_policy_change() + self._allowed_base_urls = value + self.reset() + @property def plan_resolver(self) -> PlanResolver | tuple[PlanResolver, ...] | None: if not self._plan_resolvers: diff --git a/vegafusion-runtime/src/data/codec.rs b/vegafusion-runtime/src/data/codec.rs index fbb187e7a..121d77976 100644 --- a/vegafusion-runtime/src/data/codec.rs +++ b/vegafusion-runtime/src/data/codec.rs @@ -62,12 +62,9 @@ impl LogicalExtensionCodec for VegaFusionCodec { _ctx: &datafusion::execution::TaskContext, ) -> Result> { if buf.is_empty() { - // Backward compatibility: empty buf treated as ExternalTableProvider - return Ok(Arc::new(ExternalTableProvider::new( - schema, - None, - Value::Null, - ))); + return Err(DataFusionError::Plan( + "Empty custom_table_data buffer — expected JSON envelope".to_string(), + )); } let envelope: Value = serde_json::from_slice(buf).map_err(|e| { @@ -76,18 +73,20 @@ impl LogicalExtensionCodec for VegaFusionCodec { match envelope.get("type").and_then(|t| t.as_str()) { Some("external") => { - let protocol = envelope - .get("protocol") + let scheme = envelope + .get("scheme") .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - let source = envelope - .get("source") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); + .ok_or_else(|| { + DataFusionError::Plan( + "ExternalTableProvider envelope missing required 'scheme' field" + .to_string(), + ) + })? + .to_string(); let metadata = envelope.get("metadata").cloned().unwrap_or(Value::Null); - Ok(Arc::new( - ExternalTableProvider::new(schema, protocol, metadata).with_source(source), - )) + Ok(Arc::new(ExternalTableProvider::new( + scheme, schema, metadata, + ))) } Some("inline") => { let name = envelope @@ -117,11 +116,9 @@ impl LogicalExtensionCodec for VegaFusionCodec { Some(other) => Err(DataFusionError::Plan(format!( "Unknown table provider type in envelope: '{other}'" ))), - None => { - // No "type" field — treat as legacy ExternalTableProvider where - // the entire JSON value is the metadata - Ok(Arc::new(ExternalTableProvider::new(schema, None, envelope))) - } + None => Err(DataFusionError::Plan( + "Table provider envelope missing required 'type' field".to_string(), + )), } } @@ -132,14 +129,11 @@ impl LogicalExtensionCodec for VegaFusionCodec { buf: &mut Vec, ) -> Result<()> { if let Some(ext) = node.as_any().downcast_ref::() { - let mut envelope = serde_json::json!({ + let envelope = serde_json::json!({ "type": "external", - "protocol": ext.protocol(), + "scheme": ext.scheme(), "metadata": ext.metadata(), }); - if let Some(source) = ext.source() { - envelope["source"] = serde_json::Value::String(source.to_string()); - } let json_bytes = serde_json::to_vec(&envelope).map_err(|e| { DataFusionError::Plan(format!( "Failed to encode ExternalTableProvider envelope: {e}" diff --git a/vegafusion-runtime/src/data/datafusion_resolver.rs b/vegafusion-runtime/src/data/datafusion_resolver.rs index 8a46c71df..c3b1c1428 100644 --- a/vegafusion-runtime/src/data/datafusion_resolver.rs +++ b/vegafusion-runtime/src/data/datafusion_resolver.rs @@ -1,18 +1,28 @@ use std::sync::Arc; use async_trait::async_trait; +use cfg_if::cfg_if; use datafusion::prelude::{DataFrame, SessionContext}; use vegafusion_common::datafusion_expr::LogicalPlan; use vegafusion_common::error::Result; -use vegafusion_core::runtime::{PlanResolver, ResolutionResult}; +#[cfg(not(feature = "parquet"))] +use vegafusion_common::error::VegaFusionError; +use vegafusion_core::runtime::ParsedUrl; +use super::plan_resolver::ResolutionResult; + +use super::plan_resolver::PlanResolver; + +use super::tasks::{read_arrow, read_csv, read_json}; use super::util::DataFrameUtils; -/// Terminal `PlanResolver` that executes plans via DataFusion. +#[cfg(feature = "parquet")] +use super::tasks::read_parquet; + +/// Terminal `PlanResolver` that handles built-in URL formats and executes plans via DataFusion. /// -/// This is the final resolver in a `ResolverPipeline`. It always returns -/// `ResolutionResult::Table` by executing the plan against the shared -/// `SessionContext`. +/// Handles http, https, file, and s3 schemes with csv, tsv, json, arrow, and parquet formats. +/// This is always the last resolver in a `ResolverPipeline`. pub struct DataFusionResolver { pub(crate) ctx: Arc, } @@ -29,6 +39,59 @@ impl PlanResolver for DataFusionResolver { "DataFusionResolver" } + fn supports_arrow_tables(&self) -> bool { + true + } + + async fn scan_url(&self, parsed_url: &ParsedUrl) -> Result> { + // Only handle schemes that DataFusion supports natively + const SUPPORTED_SCHEMES: &[&str] = &["http", "https", "file", "s3"]; + if !SUPPORTED_SCHEMES.contains(&parsed_url.scheme.as_str()) { + return Ok(None); + } + + // Determine file type: format_type takes precedence over extension. + // "json" format_type is treated as None (json is Vega-Lite's default, + // shouldn't override extension detection). + let file_type = match &parsed_url.format_type { + Some(ft) if ft != "json" => Some(ft.as_str()), + _ => None, + }; + let ext = parsed_url.extension.as_deref(); + + let url = &parsed_url.url; + let ctx = self.ctx.clone(); + + let df = if file_type == Some("csv") || (file_type.is_none() && ext == Some("csv")) { + read_csv(url, &parsed_url.parse, ctx, false).await? + } else if file_type == Some("tsv") || (file_type.is_none() && ext == Some("tsv")) { + read_csv(url, &parsed_url.parse, ctx, true).await? + } else if file_type == Some("json") + || (file_type.is_none() && matches!(ext, Some("json") | None)) + { + read_json(url, ctx).await? + } else if file_type == Some("arrow") + || (file_type.is_none() && matches!(ext, Some("arrow") | Some("feather"))) + { + read_arrow(url, ctx).await? + } else if file_type == Some("parquet") || (file_type.is_none() && ext == Some("parquet")) { + cfg_if! { + if #[cfg(feature = "parquet")] { + read_parquet(url, ctx).await? + } else { + return Err(VegaFusionError::internal( + "Enable parquet support by enabling the `parquet` feature flag" + )) + } + } + } else { + // Unrecognized format — pass to next resolver + return Ok(None); + }; + + Ok(Some(df.logical_plan().clone())) + } + async fn resolve_plan(&self, plan: LogicalPlan) -> Result { let table = DataFrame::new(self.ctx.state(), plan) .collect_to_table() diff --git a/vegafusion-runtime/src/data/external_table.rs b/vegafusion-runtime/src/data/external_table.rs index 71adaa026..da2474509 100644 --- a/vegafusion-runtime/src/data/external_table.rs +++ b/vegafusion-runtime/src/data/external_table.rs @@ -8,7 +8,7 @@ use datafusion::catalog::TableProvider; use datafusion::datasource::TableType; use datafusion::physical_plan::ExecutionPlan; use datafusion_common::{plan_err, Result}; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, TableProviderFilterPushDown}; use serde_json::Value; use vegafusion_common::arrow::datatypes::SchemaRef; @@ -22,33 +22,22 @@ use vegafusion_common::arrow::datatypes::SchemaRef; /// Optionally carries arbitrary JSON metadata in [`Self::metadata`], /// which is serialized into `custom_table_data` by [`super::codec::VegaFusionCodec`]. pub struct ExternalTableProvider { + scheme: String, schema: SchemaRef, - protocol: Option, - source: Option, metadata: Value, } impl ExternalTableProvider { - pub fn new(schema: SchemaRef, protocol: Option, metadata: Value) -> Self { + pub fn new(scheme: String, schema: SchemaRef, metadata: Value) -> Self { Self { + scheme, schema, - protocol, - source: None, metadata, } } - pub fn with_source(mut self, source: Option) -> Self { - self.source = source; - self - } - - pub fn protocol(&self) -> Option<&str> { - self.protocol.as_deref() - } - - pub fn source(&self) -> Option<&str> { - self.source.as_deref() + pub fn scheme(&self) -> &str { + &self.scheme } pub fn metadata(&self) -> &Value { @@ -59,8 +48,7 @@ impl ExternalTableProvider { impl Debug for ExternalTableProvider { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ExternalTableProvider") - .field("protocol", &self.protocol) - .field("source", &self.source) + .field("scheme", &self.scheme) .field("schema", &self.schema) .field("metadata", &self.metadata) .finish() @@ -81,6 +69,16 @@ impl TableProvider for ExternalTableProvider { TableType::Base } + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + // Report Inexact so DataFusion pushes filters into the TableScan + // (where resolve_table can access them) while still re-applying + // them on the output for correctness. + Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()]) + } + async fn scan( &self, _state: &dyn Session, @@ -88,9 +86,9 @@ impl TableProvider for ExternalTableProvider { _filters: &[Expr], _limit: Option, ) -> Result> { - let protocol = self.protocol().unwrap_or("unknown"); + let scheme = self.scheme(); plan_err!( - "ExternalTableProvider (protocol: {protocol}) cannot be executed directly. \ + "ExternalTableProvider (scheme: {scheme}) cannot be executed directly. \ This table represents an external data source that must be resolved \ before execution. Set a PlanResolver on the VegaFusionRuntime to \ handle external table references." diff --git a/vegafusion-runtime/src/data/mod.rs b/vegafusion-runtime/src/data/mod.rs index 8ad3a1379..6bea15bf9 100644 --- a/vegafusion-runtime/src/data/mod.rs +++ b/vegafusion-runtime/src/data/mod.rs @@ -4,5 +4,6 @@ pub mod datafusion_resolver; pub mod external_table; pub mod inline_table; pub mod pipeline; +pub mod plan_resolver; pub mod tasks; pub mod util; diff --git a/vegafusion-runtime/src/data/pipeline.rs b/vegafusion-runtime/src/data/pipeline.rs index 02ed64570..389c20a03 100644 --- a/vegafusion-runtime/src/data/pipeline.rs +++ b/vegafusion-runtime/src/data/pipeline.rs @@ -1,59 +1,139 @@ use std::sync::Arc; +use datafusion::datasource::source_as_provider; use datafusion::prelude::SessionContext; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_expr::LogicalPlan as DFLogicalPlan; use vegafusion_common::data::table::VegaFusionTable; use vegafusion_common::datafusion_expr::LogicalPlan; -use vegafusion_common::error::Result; -use vegafusion_core::runtime::{PlanResolver, ResolutionResult}; +use vegafusion_common::error::{Result, VegaFusionError}; +use vegafusion_core::data::url::normalize_base_url; +use vegafusion_core::runtime::ParsedUrl; + +use super::plan_resolver::ResolutionResult; use super::datafusion_resolver::DataFusionResolver; +use super::external_table::ExternalTableProvider; +use super::plan_resolver::PlanResolver; + +/// CDN base URL for vega-datasets, used as the default base_url. +pub const VEGA_DATASETS_CDN_BASE: &str = + "https://raw.githubusercontent.com/vega/vega-datasets/v2.3.0/"; + +/// Three-state base URL setting for public API boundaries. +#[derive(Clone, Debug, Default)] +pub enum BaseUrlSetting { + /// Use the default CDN base URL (vega-datasets) + #[default] + Default, + /// Disable base URL; relative paths produce an error + Disabled, + /// Use a custom base URL (scheme URL or absolute path) + Custom(String), +} + +/// Map a `BaseUrlSetting` to the two-state `Option` used internally. +/// Custom base URLs are normalized (bare absolute paths become file:// URLs). +pub fn resolve_base_url(setting: &BaseUrlSetting) -> Result> { + match setting { + BaseUrlSetting::Default => Ok(Some(VEGA_DATASETS_CDN_BASE.to_string())), + BaseUrlSetting::Disabled => Ok(None), + BaseUrlSetting::Custom(s) => Ok(Some(normalize_base_url(s.clone())?)), + } +} -/// Chains user-supplied resolvers with a terminal `DataFusionResolver`. +/// Chains resolvers with a terminal `DataFusionResolver`. +/// +/// All resolvers (user-supplied + DataFusionResolver) live in a single vec. +/// DataFusionResolver is always the last resolver in the chain. /// -/// Each user resolver either returns a fully materialized `Table` (short-circuiting -/// the pipeline) or a rewritten `Plan` that is passed to the next resolver. -/// The `DataFusionResolver` at the end always executes the plan and returns a table. +/// For `scan_url`, resolvers are tried in order; the first `Some(plan)` wins. +/// For `resolve`, each resolver either returns a `Table` (short-circuiting) +/// or a rewritten `Plan` passed to the next resolver. #[derive(Clone)] pub struct ResolverPipeline { - user_resolvers: Arc>>, - datafusion_resolver: Arc, + resolvers: Arc>>, + ctx: Arc, } impl ResolverPipeline { pub fn new(user_resolvers: Vec>, ctx: Arc) -> Self { + let mut resolvers: Vec> = user_resolvers; + resolvers.push(Arc::new(DataFusionResolver::new(ctx.clone()))); Self { - user_resolvers: Arc::new(user_resolvers), - datafusion_resolver: Arc::new(DataFusionResolver::new(ctx)), + resolvers: Arc::new(resolvers), + ctx, } } - /// Whether any user-supplied resolvers are registered. - pub fn has_user_resolvers(&self) -> bool { - !self.user_resolvers.is_empty() + /// Whether the runtime should eagerly materialize a `LogicalPlan` into + /// an in-memory Arrow table. + /// + /// Materializes when: + /// 1. All resolvers support in-memory Arrow tables, OR + /// 2. The plan contains no `ExternalTableProvider` nodes (no resolver + /// will need to intercept it) + /// + /// Keeps the plan lazy otherwise, so resolvers that need plan-level + /// access (e.g. a Spark connector) can intercept external tables. + pub fn should_materialize(&self, plan: &LogicalPlan) -> bool { + if self.resolvers.iter().all(|r| r.supports_arrow_tables()) { + return true; + } + !has_external_table_scans(plan) } /// Access the shared `SessionContext`. pub fn ctx(&self) -> &SessionContext { - &self.datafusion_resolver.ctx + &self.ctx + } + + /// Try each resolver's `scan_url` in order. Returns the first `Some(plan)`. + pub async fn scan_url(&self, parsed_url: &ParsedUrl) -> Result> { + for resolver in self.resolvers.iter() { + if let Some(plan) = resolver.scan_url(parsed_url).await? { + return Ok(Some(plan)); + } + } + Ok(None) } /// Resolve a `LogicalPlan` to a `VegaFusionTable`. /// - /// Iterates through user resolvers first; if any returns `Table`, that result - /// is returned immediately. Otherwise the (possibly rewritten) plan is executed - /// by the terminal `DataFusionResolver`. + /// Iterates through all resolvers; if any returns `Table`, that result + /// is returned immediately. Otherwise the (possibly rewritten) plan is + /// passed to the next resolver. pub async fn resolve(&self, plan: LogicalPlan) -> Result { let mut current = plan; - for resolver in self.user_resolvers.iter() { + for resolver in self.resolvers.iter() { match resolver.resolve_plan(current).await? { ResolutionResult::Table(table) => return Ok(table), ResolutionResult::Plan(p) => current = p, } } - // Terminal: DataFusionResolver always returns Table - match self.datafusion_resolver.resolve_plan(current).await? { - ResolutionResult::Table(table) => Ok(table), - ResolutionResult::Plan(_) => unreachable!("DataFusionResolver always returns Table"), - } + Err(VegaFusionError::internal( + "No resolver produced a final table", + )) } } + +/// Returns true if the plan contains any `ExternalTableProvider` table scans. +fn has_external_table_scans(plan: &LogicalPlan) -> bool { + let mut found = false; + let _ = plan.apply(|node| { + if let DFLogicalPlan::TableScan(scan) = node { + if let Ok(provider) = source_as_provider(&scan.source) { + if provider + .as_any() + .downcast_ref::() + .is_some() + { + found = true; + return Ok(TreeNodeRecursion::Stop); + } + } + } + Ok(TreeNodeRecursion::Continue) + }); + found +} diff --git a/vegafusion-runtime/src/data/plan_resolver.rs b/vegafusion-runtime/src/data/plan_resolver.rs new file mode 100644 index 000000000..24772a259 --- /dev/null +++ b/vegafusion-runtime/src/data/plan_resolver.rs @@ -0,0 +1,211 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::catalog::TableProvider; +use datafusion::datasource::{provider_as_source, source_as_provider, MemTable}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_expr::{Expr, LogicalPlan as DFLogicalPlan}; +use vegafusion_common::arrow::datatypes::SchemaRef; +use vegafusion_common::data::table::VegaFusionTable; +use vegafusion_common::datafusion_expr::LogicalPlan; +use vegafusion_common::error::{Result, VegaFusionError}; +use vegafusion_core::runtime::ParsedUrl; + +use super::external_table::ExternalTableProvider; + +pub enum ResolutionResult { + /// Resolver fully materialized the plan + Table(VegaFusionTable), + /// Resolver produced a rewritten plan for the next resolver to handle, + /// or for DataFusion to execute if this is the last resolver + Plan(LogicalPlan), +} + +/// Trait for custom data source integration with VegaFusion. +/// +/// Resolvers participate in a two-phase pipeline: +/// +/// 1. **URL scanning**: [`scan_url`](Self::scan_url) converts URLs into +/// `LogicalPlan` nodes (typically `ExternalTableProvider` markers). +/// +/// 2. **Execution**: [`resolve_table`](Self::resolve_table) or +/// [`resolve_plan`](Self::resolve_plan) provides data for external table +/// references or rewrites the plan for remote execution. +/// +/// Override one of (simplest first): +/// - [`resolve_table`](Self::resolve_table): per-table data provider, called +/// by the default `resolve_plan` for each `ExternalTableProvider` node. +/// - [`resolve_plan`](Self::resolve_plan): receives the entire plan. Overriding +/// this supersedes `resolve_table` since the runtime calls `resolve_plan` directly. +#[async_trait] +pub trait PlanResolver: Send + Sync + 'static { + /// Human-readable name for logging and error messages. + fn name(&self) -> &str; + + /// Whether this resolver can efficiently consume in-memory Arrow tables. + /// When all resolvers in the pipeline return true, the runtime may eagerly + /// materialize LogicalPlans into tables. When false, data is kept as lazy + /// plans so resolvers that need plan-level access can intercept them. + fn supports_arrow_tables(&self) -> bool { + false + } + + /// Given a parsed URL, optionally return a LogicalPlan to handle it. + /// Return Ok(None) to pass the URL to the next resolver in the chain. + async fn scan_url(&self, _parsed_url: &ParsedUrl) -> Result> { + Ok(None) + } + + /// Provide data for a single external table reference. + /// + /// Called once per `ExternalTableProvider` node in the plan. + /// Override this instead of [`resolve_plan`](Self::resolve_plan) when + /// each external table can be resolved independently. + /// + /// The default `resolve_plan` walks the plan tree and calls this method + /// for every `ExternalTableProvider` it finds, replacing each with an + /// in-memory table. + /// + /// # Arguments + /// * `name` - table name from the plan + /// * `scheme` - URL scheme identifier (e.g. `"spark"`, `"snowflake"`) + /// * `schema` - full Arrow schema of the external table + /// * `metadata` - JSON metadata from ExternalTableProvider + /// * `projected_columns` - column names DataFusion actually needs, + /// or `None` if all columns are needed + /// * `filters` - pushed-down filter predicates from DataFusion, already + /// split into a conjunction. These are hints — resolvers may apply + /// some, all, or none. DataFusion re-applies all filters regardless. + async fn resolve_table( + &self, + _name: &str, + _scheme: &str, + _schema: SchemaRef, + _metadata: &serde_json::Value, + _projected_columns: Option>, + _filters: &[Expr], + ) -> Result { + Err(VegaFusionError::internal( + "resolve_table not implemented — override resolve_table or resolve_plan", + )) + } + + /// Resolve a LogicalPlan containing external table references. + /// + /// The default implementation walks the plan tree, finds + /// `ExternalTableProvider` nodes, calls [`resolve_table`](Self::resolve_table) + /// for each, and replaces them with in-memory table scans. Plans with + /// no external tables are passed through unchanged. + /// + /// Override this for full control over plan rewriting (e.g. SQL + /// transpilation or remote execution). + async fn resolve_plan(&self, plan: LogicalPlan) -> Result { + let external_tables = extract_external_tables(&plan); + + if external_tables.is_empty() { + return Ok(ResolutionResult::Plan(plan)); + } + + // Resolve each external table, then wrap as MemTable + let mut mem_tables: HashMap> = HashMap::new(); + for (table_name, info) in &external_tables { + let table = self + .resolve_table( + table_name, + &info.scheme, + info.schema.clone(), + &info.metadata, + info.projected_columns.clone(), + &info.filters, + ) + .await?; + let mem_table = + MemTable::try_new(table.schema.clone(), vec![table.batches]).map_err(|e| { + VegaFusionError::internal(format!("Failed to create MemTable: {e}")) + })?; + mem_tables.insert(table_name.clone(), Arc::new(mem_table)); + } + + // Rewrite the plan, replacing ExternalTableProvider with MemTable + let mut rewriter = ResolvedTableRewriter { tables: mem_tables }; + let rewritten = plan + .rewrite(&mut rewriter) + .map_err(|e| VegaFusionError::internal(format!("Failed to rewrite plan: {e}")))? + .data; + + Ok(ResolutionResult::Plan(rewritten)) + } +} + +/// Info extracted from an ExternalTableProvider node in a LogicalPlan. +struct ExternalTableInfo { + scheme: String, + schema: SchemaRef, + metadata: serde_json::Value, + projected_columns: Option>, + filters: Vec, +} + +/// Walk a LogicalPlan and collect ExternalTableProvider info for each table scan. +/// +/// Filters come from `scan.filters` on the `TableScan`, which are populated +/// when DataFusion's optimizer pushes filter predicates down to the scan. +/// `ExternalTableProvider` reports `Inexact` for all filters to enable this. +fn extract_external_tables(plan: &LogicalPlan) -> HashMap { + let mut tables = HashMap::new(); + let _ = plan.apply(|node| { + if let DFLogicalPlan::TableScan(scan) = node { + if let Ok(provider) = source_as_provider(&scan.source) { + if let Some(ext) = provider.as_any().downcast_ref::() { + let projected_columns = scan.projection.as_ref().map(|indices| { + let schema = ext.schema(); + indices + .iter() + .map(|&i| schema.field(i).name().clone()) + .collect() + }); + tables.insert( + scan.table_name.table().to_string(), + ExternalTableInfo { + scheme: ext.scheme().to_string(), + schema: ext.schema(), + metadata: ext.metadata().clone(), + projected_columns, + filters: scan.filters.clone(), + }, + ); + } + } + } + Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue) + }); + tables +} + +/// Rewriter that replaces ExternalTableProvider scans with MemTable scans. +struct ResolvedTableRewriter { + tables: HashMap>, +} + +impl TreeNodeRewriter for ResolvedTableRewriter { + type Node = DFLogicalPlan; + + fn f_up(&mut self, node: Self::Node) -> datafusion_common::Result> { + if let DFLogicalPlan::TableScan(scan) = &node { + let table_name = scan.table_name.table(); + if let Some(mem_table) = self.tables.get(table_name) { + let new_scan = DFLogicalPlan::TableScan(datafusion_expr::TableScan { + table_name: scan.table_name.clone(), + source: provider_as_source(mem_table.clone()), + projection: scan.projection.clone(), + projected_schema: scan.projected_schema.clone(), + filters: scan.filters.clone(), + fetch: scan.fetch, + }); + return Ok(Transformed::yes(new_scan)); + } + } + Ok(Transformed::no(node)) + } +} diff --git a/vegafusion-runtime/src/data/tasks.rs b/vegafusion-runtime/src/data/tasks.rs index a42c3dbd2..4bb204e38 100644 --- a/vegafusion-runtime/src/data/tasks.rs +++ b/vegafusion-runtime/src/data/tasks.rs @@ -2,13 +2,13 @@ use crate::data::pipeline::ResolverPipeline; use crate::expression::compiler::compile; use crate::expression::compiler::config::CompilationConfig; use crate::expression::compiler::utils::ExprHelpers; -use crate::task_graph::task::TaskCall; +use crate::task_graph::task::{TaskCall, TaskContext}; use std::borrow::Cow; use async_trait::async_trait; use datafusion_expr::{lit, Expr}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::path::Path; use std::sync::Arc; use vegafusion_core::data::dataset::VegaFusionDataset; @@ -24,12 +24,13 @@ use datafusion_common::config::TableOptions; use datafusion_functions::expr_fn::make_date; use vegafusion_common::data::scalar::{ScalarValue, ScalarValueHelpers}; -use vegafusion_common::error::{Result, ResultWithContext, VegaFusionError}; +use vegafusion_common::error::{Result, ResultWithContext, ToExternalError, VegaFusionError}; use vegafusion_core::proto::gen::tasks::data_url_task::Url; use vegafusion_core::proto::gen::tasks::scan_url_format; use vegafusion_core::proto::gen::tasks::scan_url_format::Parse; use vegafusion_core::proto::gen::tasks::{DataSourceTask, DataUrlTask, DataValuesTask}; +use vegafusion_core::runtime::{check_url_allowed, file_url_to_path, path_to_file_url}; use vegafusion_core::task_graph::task::{InputVariable, TaskDependencies}; use vegafusion_core::task_graph::task_value::TaskValue; @@ -54,24 +55,26 @@ use object_store::{http::HttpBuilder, ClientOptions}; use tokio::io::AsyncReadExt; #[cfg(feature = "parquet")] -use {datafusion::prelude::ParquetReadOptions, vegafusion_common::error::ToExternalError}; +use datafusion::prelude::ParquetReadOptions; #[cfg(target_arch = "wasm32")] use object_store_wasm::HttpStore; -/// If no user resolvers are configured, eagerly materialize a `TaskValue::Plan` -/// into a `TaskValue::Table` via DataFusion execution. Passthrough otherwise. +/// Eagerly materialize a `TaskValue::Plan` into a `TaskValue::Table` when safe: +/// either all resolvers support Arrow tables, or the plan has no external table +/// nodes that a resolver would need to intercept. Otherwise keep it lazy. async fn maybe_materialize_plan( task_value: TaskValue, pipeline: &ResolverPipeline, ) -> Result { - if !pipeline.has_user_resolvers() { - if let TaskValue::Plan(plan) = task_value { + if let TaskValue::Plan(plan) = task_value { + if pipeline.should_materialize(&plan) { let table = DataFrame::new(pipeline.ctx().state(), plan) .collect_to_table() .await?; return Ok(TaskValue::Table(table)); } + return Ok(TaskValue::Plan(plan)); } Ok(task_value) } @@ -122,22 +125,25 @@ impl TaskCall for DataUrlTask { async fn eval( &self, values: &[TaskValue], - tz_config: &Option, - inline_datasets: HashMap, - pipeline: ResolverPipeline, + ctx: &TaskContext, ) -> Result<(TaskValue, Vec)> { - let ctx = Arc::new(pipeline.ctx().clone()); + let session_ctx = Arc::new(ctx.pipeline.ctx().clone()); // Build compilation config for url signal (if any) and transforms (if any) - let config = - build_compilation_config(&self.input_vars(), values, tz_config, pipeline.clone()); - - // Build url string + let config = build_compilation_config( + &self.input_vars(), + values, + &ctx.tz_config, + ctx.pipeline.clone(), + ); + + // Build url string — resolve at eval time for both static and signal URLs let url = match self.url.as_ref().unwrap() { - Url::String(url) => url.clone(), + Url::String(url) => vegafusion_core::runtime::resolve_url(url, &ctx.base_url)?, Url::Expr(expr) => { let compiled = compile(expr, &config, None).await?; let url_scalar = compiled.eval_to_scalar()?; - url_scalar.to_scalar_string()? + let raw_url = url_scalar.to_scalar_string()?; + vegafusion_core::runtime::resolve_url(&raw_url, &ctx.base_url)? } }; @@ -145,70 +151,56 @@ impl TaskCall for DataUrlTask { let url_parts: Vec<&str> = url.splitn(2, '#').collect(); let url = url_parts.first().cloned().unwrap_or(&url).to_string(); - // Handle references to vega default datasets (e.g. "data/us-10m.json") - let url = check_builtin_dataset(url); - // Load data from URL let parse = self.format_type.as_ref().and_then(|fmt| fmt.parse.clone()); let file_type = self.format_type.as_ref().and_then(|fmt| fmt.r#type.clone()); // Vega-Lite sets unspecified file types to "json", so we don't want this to take // precedence over file extension - let file_type = if file_type == Some("json".to_string()) { + let format_type = if file_type == Some("json".to_string()) { None } else { - file_type.as_deref() + file_type }; let inline_name = extract_inline_dataset(&url).map(|name| name.trim().to_string()); let inline_dataset_info: Option<&VegaFusionDataset> = inline_name .as_ref() - .and_then(|name| inline_datasets.get(name)); + .and_then(|name| ctx.inline_datasets.get(name)); + + if inline_name.is_none() { + check_url_allowed(&url, &ctx.allowed_base_urls)?; + } let df = if let Some(inline_name) = &inline_name { if let Some(inline_dataset) = inline_dataset_info { match inline_dataset { VegaFusionDataset::Table { table, .. } => { let table = table.clone().with_ordering()?; - ctx.vegafusion_table(table).await? + session_ctx.vegafusion_table(table).await? } VegaFusionDataset::Plan { plan } => { - DataFrame::new(ctx.state(), plan.clone()).with_index()? + DataFrame::new(session_ctx.state(), plan.clone()).with_index()? } } - } else if let Ok(df) = ctx.table(inline_name).await { + } else if let Ok(df) = session_ctx.table(inline_name).await { df } else { return Err(VegaFusionError::internal(format!( "No inline dataset named {inline_name}" ))); } - } else if file_type == Some("csv") || (file_type.is_none() && url.ends_with(".csv")) { - read_csv(&url, &parse, ctx.clone(), false).await? - } else if file_type == Some("tsv") || (file_type.is_none() && url.ends_with(".tsv")) { - read_csv(&url, &parse, ctx.clone(), true).await? - } else if file_type == Some("json") || (file_type.is_none() && url.ends_with(".json")) { - read_json(&url, ctx.clone()).await? - } else if file_type == Some("arrow") - || (file_type.is_none() && (url.ends_with(".arrow") || url.ends_with(".feather"))) - { - read_arrow(&url, ctx.clone()).await? - } else if file_type == Some("parquet") - || (file_type.is_none() && (url.ends_with(".parquet"))) - { - cfg_if! { - if #[cfg(any(feature = "parquet"))] { - read_parquet(&url, ctx.clone()).await? - } else { + } else { + // Construct ParsedUrl and dispatch to pipeline.scan_url() + let parsed_url = build_parsed_url(&url, format_type.as_deref(), parse.clone())?; + match ctx.pipeline.scan_url(&parsed_url).await? { + Some(plan) => DataFrame::new(session_ctx.state(), plan), + None => { return Err(VegaFusionError::internal(format!( - "Enable parquet support by enabling the `parquet` feature flag" - ))) + "No resolver handled URL: {url}" + ))); } } - } else { - return Err(VegaFusionError::internal(format!( - "Invalid url file extension {url}" - ))); }; // Ensure there is an ordering column present @@ -235,108 +227,66 @@ impl TaskCall for DataUrlTask { // Return value based on whether inline dataset was used let task_value = if let Some(inline_dataset) = inline_dataset_info { let task_value = result_df.to_task_value(inline_dataset).await?; - maybe_materialize_plan(task_value, &pipeline).await? + maybe_materialize_plan(task_value, &ctx.pipeline).await? } else { - TaskValue::Table(result_df.collect_to_table().await?) + // URL-sourced data: use Plan when user resolvers exist for lazy evaluation + let task_value = TaskValue::Plan(result_df.logical_plan().clone()); + maybe_materialize_plan(task_value, &ctx.pipeline).await? }; Ok((task_value, output_values)) } } -lazy_static! { - static ref BUILT_IN_DATASETS: HashSet<&'static str> = vec![ - "7zip.png", - "airports.csv", - "annual-precip.json", - "anscombe.json", - "barley.json", - "birdstrikes.csv", - "budget.json", - "budgets.json", - "burtin.json", - "cars.json", - "co2-concentration.csv", - "countries.json", - "crimea.json", - "disasters.csv", - "driving.json", - "earthquakes.json", - "ffox.png", - "flare-dependencies.json", - "flare.json", - "flights-10k.json", - "flights-200k.arrow", - "flights-200k.json", - "flights-20k.json", - "flights-2k.json", - "flights-3m.csv", - "flights-5k.json", - "flights-airport.csv", - "football.json", - "gapminder-health-income.csv", - "gapminder.json", - "gimp.png", - "github.csv", - "income.json", - "iowa-electricity.csv", - "jobs.json", - "la-riots.csv", - "londonBoroughs.json", - "londonCentroids.json", - "londonTubeLines.json", - "lookup_groups.csv", - "lookup_people.csv", - "miserables.json", - "monarchs.json", - "movies.json", - "normal-2d.json", - "obesity.json", - "ohlc.json", - "penguins.json", - "platformer-terrain.json", - "points.json", - "political-contributions.json", - "population_engineers_hurricanes.csv", - "population.json", - "seattle-weather.csv", - "seattle-weather-hourly-normals.csv", - "sp500-2000.csv", - "sp500.csv", - "stocks.csv", - "udistrict.json", - "unemployment-across-industries.json", - "unemployment.tsv", - "uniform-2d.json", - "us-10m.json", - "us-employment.csv", - "us-state-capitals.json", - "volcano.json", - "weather.csv", - "weather.json", - "wheat.json", - "windvectors.csv", - "world-110m.json", - "zipcodes.csv", - ] - .into_iter() - .collect(); +/// Construct a `ParsedUrl` from a fully-resolved URL string and optional format type. +fn build_parsed_url( + url: &str, + format_type: Option<&str>, + parse: Option, +) -> Result { + let parsed = url::Url::parse(url) + .map_err(|e| VegaFusionError::internal(format!("Failed to parse URL '{url}': {e}")))?; + + let extension = std::path::Path::new(parsed.path()) + .extension() + .and_then(|ext| ext.to_str()) + .map(|s| s.to_string()); + + let query_params: Vec<(String, String)> = parsed + .query_pairs() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + + Ok(vegafusion_core::runtime::ParsedUrl { + url: url.to_string(), + scheme: parsed.scheme().to_string(), + host: parsed.host_str().map(|s| s.to_string()), + path: parsed.path().to_string(), + query_params, + extension, + format_type: format_type.map(|s| s.to_string()), + parse, + }) } -const DATASET_BASE: &str = "https://raw.githubusercontent.com/vega/vega-datasets"; -const DATASET_TAG: &str = "v2.3.0"; +#[cfg(feature = "http")] +async fn fetch_http_bytes(url: &str) -> Result> { + let client = reqwest::Client::new(); + let response = client + .get(url) + .send() + .await + .external(format!("Failed to fetch URL: {url}"))?; + + let response = response + .error_for_status() + .external(format!("Failed to fetch URL: {url}"))?; -fn check_builtin_dataset(url: String) -> String { - if let Some(dataset) = url.strip_prefix("data/") { - let path = std::path::Path::new(&url); - if !path.exists() && BUILT_IN_DATASETS.contains(dataset) { - format!("{DATASET_BASE}/{DATASET_TAG}/data/{dataset}") - } else { - url - } - } else { - url - } + Ok(response + .bytes() + .await + .external("Failed to read response bytes")? + .to_vec()) } /// After processing, all datetime columns are converted to Timestamptz and Date32 @@ -485,11 +435,9 @@ impl TaskCall for DataValuesTask { async fn eval( &self, values: &[TaskValue], - tz_config: &Option, - _inline_datasets: HashMap, - pipeline: ResolverPipeline, + ctx: &TaskContext, ) -> Result<(TaskValue, Vec)> { - let ctx = Arc::new(pipeline.ctx().clone()); + let session_ctx = Arc::new(ctx.pipeline.ctx().clone()); // Deserialize data into table let values_table = VegaFusionTable::from_ipc_bytes(&self.values)?; if values_table.schema.fields.is_empty() { @@ -523,11 +471,15 @@ impl TaskCall for DataValuesTask { { let transform_pipeline = self.pipeline.as_ref().unwrap(); - let config = - build_compilation_config(&self.input_vars(), values, tz_config, pipeline.clone()); + let config = build_compilation_config( + &self.input_vars(), + values, + &ctx.tz_config, + ctx.pipeline.clone(), + ); // Process datetime columns - let df = ctx.vegafusion_table(values_table).await?; + let df = session_ctx.vegafusion_table(values_table).await?; let sql_df = process_datetimes(&parse, df, &config.tz_config).await?; let (df, output_values) = transform_pipeline.eval_sql(sql_df, &config).await?; @@ -535,8 +487,8 @@ impl TaskCall for DataValuesTask { (table, output_values) } else { // No transforms - let values_df = ctx.vegafusion_table(values_table).await?; - let values_df: DataFrame = process_datetimes(&parse, values_df, tz_config).await?; + let values_df = session_ctx.vegafusion_table(values_table).await?; + let values_df: DataFrame = process_datetimes(&parse, values_df, &ctx.tz_config).await?; ( values_df.drop_index()?.collect_to_table().await?, Vec::new(), @@ -554,13 +506,12 @@ impl TaskCall for DataSourceTask { async fn eval( &self, values: &[TaskValue], - tz_config: &Option, - _inline_datasets: HashMap, - pipeline: ResolverPipeline, + ctx: &TaskContext, ) -> Result<(TaskValue, Vec)> { - let ctx = Arc::new(pipeline.ctx().clone()); + let session_ctx = Arc::new(ctx.pipeline.ctx().clone()); let input_vars = self.input_vars(); - let mut config = build_compilation_config(&input_vars, values, tz_config, pipeline.clone()); + let mut config = + build_compilation_config(&input_vars, values, &ctx.tz_config, ctx.pipeline.clone()); // Remove source dataset from config let source_dataset = config.data_scope.remove(&self.source).with_context(|| { @@ -580,7 +531,7 @@ impl TaskCall for DataSourceTask { match source_dataset { VegaFusionDataset::Plan { plan } => { let task_value = - maybe_materialize_plan(TaskValue::Plan(plan), &pipeline).await?; + maybe_materialize_plan(TaskValue::Plan(plan), &ctx.pipeline).await?; return Ok((task_value, Vec::new())); } VegaFusionDataset::Table { table, .. } => { @@ -591,8 +542,10 @@ impl TaskCall for DataSourceTask { } let source_df = match &source_dataset { - VegaFusionDataset::Table { table, .. } => ctx.vegafusion_table(table.clone()).await?, - VegaFusionDataset::Plan { plan } => DataFrame::new(ctx.state(), plan.clone()), + VegaFusionDataset::Table { table, .. } => { + session_ctx.vegafusion_table(table.clone()).await? + } + VegaFusionDataset::Plan { plan } => DataFrame::new(session_ctx.state(), plan.clone()), }; let source_df = source_df.with_index()?; @@ -601,7 +554,7 @@ impl TaskCall for DataSourceTask { let (df, output_values) = transform_pipeline.eval_sql(source_df, &config).await?; let df = df.drop_index()?; let task_value = df.to_task_value(&source_dataset).await?; - let task_value = maybe_materialize_plan(task_value, &pipeline).await?; + let task_value = maybe_materialize_plan(task_value, &ctx.pipeline).await?; Ok((task_value, output_values)) } } @@ -642,18 +595,8 @@ async fn read_csv_with_reqwest( is_tsv: bool, ext: &str, ) -> Result { - // Fetch CSV content using reqwest - let client = reqwest::Client::new(); - let response = client - .get(url) - .send() - .await - .external(format!("Failed to fetch URL: {url}"))?; - - let text = response - .text() - .await - .external("Failed to read response as text")?; + let bytes = fetch_http_bytes(url).await?; + let text: Cow = String::from_utf8_lossy(&bytes); // Create a temporary file to store the CSV content use std::io::Write; @@ -664,7 +607,7 @@ async fn read_csv_with_reqwest( temp_file.sync_all()?; // Read the CSV from the temporary file - let temp_url = format!("file://{}", temp_path.display()); + let temp_url = path_to_file_url(temp_path.to_str().unwrap())?; // Build CSV options let mut csv_opts = if is_tsv { @@ -696,7 +639,7 @@ async fn read_csv_with_reqwest( ctx.vegafusion_table(table).await } -async fn read_csv( +pub(crate) async fn read_csv( url: &str, parse: &Option, ctx: Arc, @@ -736,7 +679,7 @@ async fn read_csv( } /// Build final schema by combining the input and inferred schemas -async fn build_csv_schema( +pub(crate) async fn build_csv_schema( csv_opts: &CsvReadOptions<'_>, parse: &Option, uri: impl Into, @@ -798,84 +741,82 @@ async fn build_csv_schema( Ok(Schema::new(new_fields)) } -async fn read_json(url: &str, ctx: Arc) -> Result { - let value: serde_json::Value = - if let Some(base_url) = maybe_register_object_stores_for_url(&ctx, url)? { - // Create single use object store that points directly to file - let store = ctx.runtime_env().object_store(&base_url)?; - let child_url = url.strip_prefix(&base_url.to_string()).unwrap(); - match store.get(&child_url.into()).await { - Ok(get_res) => { - let bytes = get_res.bytes().await?.to_vec(); - let text: Cow = String::from_utf8_lossy(&bytes); - serde_json::from_str(text.as_ref())? - } - Err(e) => { - cfg_if::cfg_if! { - if #[cfg(feature="http")] { - if url.starts_with("http://") || url.starts_with("https://") { - // Fallback to direct reqwest implementation. This is needed in some cases because - // the object-store http implementation has stricter requirements on what the - // server provides. For example the content-length header is required. - let client = reqwest::Client::new(); - let response = client - .get(url) - .send() - .await - .external(format!("Failed to fetch URL: {url}"))?; - - let text = response - .text() - .await - .external("Failed to read response as text")?; - serde_json::from_str(&text)? - } else { - return Err(VegaFusionError::from(e)); - } +async fn read_json_via_store_or_file( + url: &str, + ctx: Arc, +) -> Result { + if let Some(base_url) = maybe_register_object_stores_for_url(&ctx, url)? { + // Create single use object store that points directly to file + let store = ctx.runtime_env().object_store(&base_url)?; + let child_url = url.strip_prefix(&base_url.to_string()).unwrap(); + match store.get(&child_url.into()).await { + Ok(get_res) => { + let bytes = get_res.bytes().await?.to_vec(); + let text: Cow = String::from_utf8_lossy(&bytes); + Ok(serde_json::from_str(text.as_ref())?) + } + Err(e) => { + cfg_if::cfg_if! { + if #[cfg(feature="http")] { + if url.starts_with("http://") || url.starts_with("https://") { + let bytes = fetch_http_bytes(url).await?; + Ok(serde_json::from_slice(&bytes)?) } else { - return Err(VegaFusionError::from(e)); + Err(VegaFusionError::from(e)) } + } else { + Err(VegaFusionError::from(e)) } } } - } else { - cfg_if::cfg_if! { - if #[cfg(feature="fs")] { - // Assume local file - let mut file = tokio::fs::File::open(url) - .await - .external(format!("Failed to open as local file: {url}"))?; - - let mut json_str = String::new(); - file.read_to_string(&mut json_str) - .await - .external("Failed to read file contents to string")?; - - serde_json::from_str(&json_str)? + } + } else { + cfg_if::cfg_if! { + if #[cfg(feature="fs")] { + let local_path = if url.starts_with("file://") { + file_url_to_path(url)? } else { - return Err(VegaFusionError::internal( - "The `fs` feature flag must be enabled for file system support" - )); - } + std::path::PathBuf::from(url) + }; + + let mut file = tokio::fs::File::open(&local_path) + .await + .external(format!("Failed to open as local file: {}", local_path.display()))?; + + let mut json_str = String::new(); + file.read_to_string(&mut json_str) + .await + .external("Failed to read file contents to string")?; + + Ok(serde_json::from_str(&json_str)?) + } else { + Err(VegaFusionError::internal( + "The `fs` feature flag must be enabled for file system support" + )) } - }; + } + } +} + +pub(crate) async fn read_json(url: &str, ctx: Arc) -> Result { + let value: serde_json::Value = read_json_via_store_or_file(url, ctx.clone()).await?; let table = VegaFusionTable::from_json(&value)?.with_ordering()?; ctx.vegafusion_table(table).await } -async fn read_arrow(url: &str, ctx: Arc) -> Result { +pub(crate) async fn read_arrow(url: &str, ctx: Arc) -> Result { maybe_register_object_stores_for_url(&ctx, url)?; Ok(ctx.read_arrow(url, ArrowReadOptions::default()).await?) } #[cfg(feature = "parquet")] -async fn read_parquet(url: &str, ctx: Arc) -> Result { +pub(crate) async fn read_parquet(url: &str, ctx: Arc) -> Result { maybe_register_object_stores_for_url(&ctx, url)?; Ok(ctx.read_parquet(url, ParquetReadOptions::default()).await?) } -fn maybe_register_object_stores_for_url( +pub(crate) fn maybe_register_object_stores_for_url( ctx: &SessionContext, url: &str, ) -> Result> { @@ -886,10 +827,11 @@ fn maybe_register_object_stores_for_url( if let Some(path) = url.strip_prefix(prefix) { let Some((root, _)) = path.split_once('/') else { return Err(VegaFusionError::specification(format!( - "Invalid https URL: {url}" + "Invalid {prefix} URL: {url}" ))); }; - let base_url_str = format!("https://{root}"); + let scheme = prefix.trim_end_matches("://"); + let base_url_str = format!("{scheme}://{root}"); let base_url = url::Url::parse(&base_url_str)?; // Register store for url if not already registered diff --git a/vegafusion-runtime/src/signal/mod.rs b/vegafusion-runtime/src/signal/mod.rs index f22feb510..8a3a37f27 100644 --- a/vegafusion-runtime/src/signal/mod.rs +++ b/vegafusion-runtime/src/signal/mod.rs @@ -1,13 +1,9 @@ -use crate::data::pipeline::ResolverPipeline; use crate::data::tasks::build_compilation_config; use crate::expression::compiler::compile; use crate::expression::compiler::utils::ExprHelpers; -use crate::task_graph::task::TaskCall; +use crate::task_graph::task::{TaskCall, TaskContext}; use async_trait::async_trait; -use std::collections::HashMap; -use vegafusion_core::data::dataset::VegaFusionDataset; -use crate::task_graph::timezone::RuntimeTzConfig; use vegafusion_core::error::Result; use vegafusion_core::proto::gen::tasks::SignalTask; use vegafusion_core::task_graph::task::TaskDependencies; @@ -18,11 +14,14 @@ impl TaskCall for SignalTask { async fn eval( &self, values: &[TaskValue], - tz_config: &Option, - _inline_datasets: HashMap, - pipeline: ResolverPipeline, + ctx: &TaskContext, ) -> Result<(TaskValue, Vec)> { - let config = build_compilation_config(&self.input_vars(), values, tz_config, pipeline); + let config = build_compilation_config( + &self.input_vars(), + values, + &ctx.tz_config, + ctx.pipeline.clone(), + ); let expression = self.expr.as_ref().unwrap(); let expr = compile(expression, &config, None).await?; let value = expr.eval_to_scalar()?; diff --git a/vegafusion-runtime/src/task_graph/runtime.rs b/vegafusion-runtime/src/task_graph/runtime.rs index 8c0df325d..2fa174c06 100644 --- a/vegafusion-runtime/src/task_graph/runtime.rs +++ b/vegafusion-runtime/src/task_graph/runtime.rs @@ -1,7 +1,8 @@ use crate::data::pipeline::ResolverPipeline; +use crate::data::plan_resolver::PlanResolver; use crate::datafusion::context::make_datafusion_context; use crate::task_graph::cache::VegaFusionCache; -use crate::task_graph::task::TaskCall; +use crate::task_graph::task::{TaskCall, TaskContext}; use crate::task_graph::timezone::RuntimeTzConfig; use async_recursion::async_recursion; use cfg_if::cfg_if; @@ -18,8 +19,8 @@ use vegafusion_core::proto::gen::tasks::inline_dataset::Dataset; use vegafusion_core::proto::gen::tasks::{ task::TaskKind, InlineDataset, InlineDatasetTable, NodeValueIndex, TaskGraph, }; -use vegafusion_core::runtime::PlanResolver; use vegafusion_core::runtime::VegaFusionRuntimeTrait; +use vegafusion_core::runtime::{normalize_allowed_base_urls, AllowedBaseUrlPattern}; use vegafusion_core::task_graph::task_value::{MaterializedTaskValue, NamedTaskValue, TaskValue}; #[cfg(feature = "proto")] @@ -33,19 +34,47 @@ use { type CacheValue = (TaskValue, Vec); +use crate::data::pipeline::{resolve_base_url, BaseUrlSetting}; + +pub struct VegaFusionRuntimeOpts { + pub plan_resolvers: Vec>, + pub base_url: BaseUrlSetting, + pub allowed_base_urls: Option>, + pub cache: Option, +} + +impl Default for VegaFusionRuntimeOpts { + fn default() -> Self { + Self { + plan_resolvers: Vec::new(), + base_url: BaseUrlSetting::Default, + allowed_base_urls: None, + cache: None, + } + } +} + #[derive(Clone)] pub struct VegaFusionRuntime { pub cache: VegaFusionCache, pub pipeline: ResolverPipeline, + pub base_url: Option, + pub allowed_base_urls: Option>, } impl VegaFusionRuntime { - pub fn new(cache: Option, plan_resolvers: Vec>) -> Self { + pub fn new(opts: VegaFusionRuntimeOpts) -> vegafusion_core::error::Result { let ctx = Arc::new(make_datafusion_context()); - Self { - cache: cache.unwrap_or_else(|| VegaFusionCache::new(Some(32), None)), - pipeline: ResolverPipeline::new(plan_resolvers, ctx), - } + let base_url = resolve_base_url(&opts.base_url)?; + let allowed_base_urls = normalize_allowed_base_urls(opts.allowed_base_urls)?; + Ok(Self { + cache: opts + .cache + .unwrap_or_else(|| VegaFusionCache::new(Some(32), None)), + pipeline: ResolverPipeline::new(opts.plan_resolvers, ctx), + base_url, + allowed_base_urls, + }) } pub async fn get_node_value( @@ -56,13 +85,18 @@ impl VegaFusionRuntime { ) -> Result { // We shouldn't panic inside get_or_compute_node_value, but since this may be used // in a server context, wrap in catch_unwind just in case. - let pipeline = self.pipeline.clone(); + let task_ctx = TaskContext { + tz_config: None, // overridden per-task from task.tz_config + inline_datasets, + pipeline: self.pipeline.clone(), + base_url: self.base_url.clone(), + allowed_base_urls: self.allowed_base_urls.clone(), + }; let node_value = AssertUnwindSafe(get_or_compute_node_value( task_graph, node_value_index.node_index as usize, self.cache.clone(), - inline_datasets, - pipeline, + task_ctx, )) .catch_unwind() .await; @@ -84,7 +118,7 @@ impl VegaFusionRuntime { impl Default for VegaFusionRuntime { fn default() -> Self { - Self::new(None, Vec::new()) + Self::new(VegaFusionRuntimeOpts::default()).expect("default opts should not fail") } } @@ -163,8 +197,7 @@ async fn get_or_compute_node_value( task_graph: Arc, node_index: usize, cache: VegaFusionCache, - inline_datasets: HashMap, - pipeline: ResolverPipeline, + task_ctx: TaskContext, ) -> Result { // Get the cache key for requested node let node = task_graph.node(node_index).unwrap(); @@ -195,8 +228,7 @@ async fn get_or_compute_node_value( task_graph.clone(), input_node_index, cloned_cache.clone(), - inline_datasets.clone(), - pipeline.clone(), + task_ctx.clone(), ); cfg_if! { @@ -244,8 +276,12 @@ async fn get_or_compute_node_value( }) .collect::>>()?; - task.eval(&input_values, &tz_config, inline_datasets, pipeline) - .await + // Override tz_config from task + let task_ctx = TaskContext { + tz_config, + ..task_ctx + }; + task.eval(&input_values, &task_ctx).await }; // get or construct from cache diff --git a/vegafusion-runtime/src/task_graph/task.rs b/vegafusion-runtime/src/task_graph/task.rs index 3d7bc9ed3..2946c4c45 100644 --- a/vegafusion-runtime/src/task_graph/task.rs +++ b/vegafusion-runtime/src/task_graph/task.rs @@ -7,16 +7,25 @@ use vegafusion_core::data::dataset::VegaFusionDataset; use vegafusion_core::error::Result; use vegafusion_core::proto::gen::tasks::task::TaskKind; use vegafusion_core::proto::gen::tasks::Task; +use vegafusion_core::runtime::AllowedBaseUrlPattern; use vegafusion_core::task_graph::task_value::TaskValue; +/// Ambient context available to all tasks during evaluation. +#[derive(Clone)] +pub struct TaskContext { + pub tz_config: Option, + pub inline_datasets: HashMap, + pub pipeline: ResolverPipeline, + pub base_url: Option, + pub allowed_base_urls: Option>, +} + #[async_trait] pub trait TaskCall { async fn eval( &self, values: &[TaskValue], - tz_config: &Option, - inline_datasets: HashMap, - pipeline: ResolverPipeline, + ctx: &TaskContext, ) -> Result<(TaskValue, Vec)>; } @@ -25,28 +34,14 @@ impl TaskCall for Task { async fn eval( &self, values: &[TaskValue], - tz_config: &Option, - inline_datasets: HashMap, - pipeline: ResolverPipeline, + ctx: &TaskContext, ) -> Result<(TaskValue, Vec)> { match self.task_kind() { TaskKind::Value(value) => Ok((value.try_into()?, Default::default())), - TaskKind::DataUrl(task) => { - task.eval(values, tz_config, inline_datasets, pipeline) - .await - } - TaskKind::DataValues(task) => { - task.eval(values, tz_config, inline_datasets, pipeline) - .await - } - TaskKind::DataSource(task) => { - task.eval(values, tz_config, inline_datasets, pipeline) - .await - } - TaskKind::Signal(task) => { - task.eval(values, tz_config, inline_datasets, pipeline) - .await - } + TaskKind::DataUrl(task) => task.eval(values, ctx).await, + TaskKind::DataValues(task) => task.eval(values, ctx).await, + TaskKind::DataSource(task) => task.eval(values, ctx).await, + TaskKind::Signal(task) => task.eval(values, ctx).await, } } } diff --git a/vegafusion-runtime/tests/test_plan_resolver.rs b/vegafusion-runtime/tests/test_plan_resolver.rs index 02ec22f16..f0eddb274 100644 --- a/vegafusion-runtime/tests/test_plan_resolver.rs +++ b/vegafusion-runtime/tests/test_plan_resolver.rs @@ -14,11 +14,14 @@ use vegafusion_common::datafusion_expr::LogicalPlan; use vegafusion_common::error::{Result, VegaFusionError}; use vegafusion_core::data::dataset::VegaFusionDataset; use vegafusion_core::proto::gen::pretransform::PreTransformSpecOpts; -use vegafusion_core::runtime::{PlanResolver, ResolutionResult, VegaFusionRuntimeTrait}; +use vegafusion_core::runtime::{ParsedUrl, VegaFusionRuntimeTrait}; use vegafusion_core::spec::chart::ChartSpec; use vegafusion_runtime::data::external_table::ExternalTableProvider; use vegafusion_runtime::data::pipeline::ResolverPipeline; +use vegafusion_runtime::data::plan_resolver::PlanResolver; +use vegafusion_runtime::data::plan_resolver::ResolutionResult; use vegafusion_runtime::task_graph::runtime::VegaFusionRuntime; +use vegafusion_runtime::task_graph::runtime::VegaFusionRuntimeOpts; #[derive(Clone, Debug)] struct ResolverEvent { @@ -176,8 +179,8 @@ fn table_row_count(table: &VegaFusionTable) -> usize { fn build_external_scan_plan(table_name: &str) -> LogicalPlan { let schema = get_movies_schema(); let provider = Arc::new(ExternalTableProvider::new( + "test".to_string(), schema, - Some("test".to_string()), serde_json::Value::Null, )); let table_source = provider_as_source(provider); @@ -197,7 +200,11 @@ async fn test_custom_executor_called_in_pre_transform_spec() { ); let resolver_clone = resolver.clone(); - let runtime = VegaFusionRuntime::new(None, vec![Arc::new(resolver)]); + let runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts { + plan_resolvers: vec![Arc::new(resolver)], + ..Default::default() + }) + .unwrap(); let spec = get_simple_spec(); let inline_datasets = get_inline_datasets(); @@ -236,7 +243,11 @@ async fn test_custom_executor_called_in_pre_transform_extract() { ); let resolver_clone = resolver.clone(); - let runtime = VegaFusionRuntime::new(None, vec![Arc::new(resolver)]); + let runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts { + plan_resolvers: vec![Arc::new(resolver)], + ..Default::default() + }) + .unwrap(); let spec = get_simple_spec(); let inline_datasets = get_inline_datasets(); @@ -275,7 +286,11 @@ async fn test_custom_executor_called_in_pre_transform_values() { ); let resolver_clone = resolver.clone(); - let runtime = VegaFusionRuntime::new(None, vec![Arc::new(resolver)]); + let runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts { + plan_resolvers: vec![Arc::new(resolver)], + ..Default::default() + }) + .unwrap(); let spec = get_simple_spec(); let inline_datasets = get_inline_datasets(); @@ -323,7 +338,11 @@ async fn test_bin_transform_uses_custom_executor() { ); let resolver_clone = resolver.clone(); - let runtime = VegaFusionRuntime::new(None, vec![Arc::new(resolver)]); + let runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts { + plan_resolvers: vec![Arc::new(resolver)], + ..Default::default() + }) + .unwrap(); let spec_str = r#"{ "$schema": "https://vega.github.io/schema/vega/v5.json", @@ -433,7 +452,11 @@ async fn test_mixed_data_only_executes_plans() { ); let resolver_clone = resolver.clone(); - let runtime = VegaFusionRuntime::new(None, vec![Arc::new(resolver)]); + let runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts { + plan_resolvers: vec![Arc::new(resolver)], + ..Default::default() + }) + .unwrap(); let spec_str = r#"{ "$schema": "https://vega.github.io/schema/vega/v5.json", @@ -796,6 +819,137 @@ fn get_inline_datasets() -> std::collections::HashMap datasets } +/// A resolver that claims custom:// URLs by returning an ExternalTableProvider plan +struct CustomSchemeScanner { + schema: Arc, +} + +#[async_trait] +impl PlanResolver for CustomSchemeScanner { + fn name(&self) -> &str { + "custom_scheme_scanner" + } + + async fn scan_url(&self, parsed_url: &ParsedUrl) -> Result> { + if parsed_url.scheme == "custom" { + let provider = Arc::new(ExternalTableProvider::new( + "custom".to_string(), + self.schema.clone(), + serde_json::json!({"url": parsed_url.url}), + )); + let plan = LogicalPlanBuilder::scan("custom_table", provider_as_source(provider), None) + .unwrap() + .build() + .unwrap(); + Ok(Some(plan)) + } else { + Ok(None) + } + } + + async fn resolve_plan(&self, plan: LogicalPlan) -> Result { + // Rewrite ExternalTableProvider to MemTable for execution + let movies = create_movies_table(); + let mem_table = Arc::new( + MemTable::try_new(movies.schema.clone(), vec![movies.batches.clone()]).unwrap(), + ) as Arc; + let mut rewriter = TableRewriter { + movies_table: mem_table, + }; + let rewritten = plan.rewrite(&mut rewriter).unwrap().data; + Ok(ResolutionResult::Plan(rewritten)) + } +} + +#[tokio::test] +async fn test_scan_url_custom_scheme_first_wins() { + let schema = get_movies_schema(); + let scanner = CustomSchemeScanner { + schema: schema.clone(), + }; + + let ctx = Arc::new(datafusion::prelude::SessionContext::new()); + let pipeline = ResolverPipeline::new(vec![Arc::new(scanner)], ctx); + + let parsed = ParsedUrl { + url: "custom://mydb/table1".to_string(), + scheme: "custom".to_string(), + host: Some("mydb".to_string()), + path: "/table1".to_string(), + query_params: vec![], + extension: None, + format_type: None, + parse: None, + }; + + let result = pipeline.scan_url(&parsed).await.unwrap(); + assert!( + result.is_some(), + "Custom scanner should handle custom:// URLs" + ); +} + +#[tokio::test] +async fn test_scan_url_unknown_scheme_falls_through() { + let ctx = Arc::new(datafusion::prelude::SessionContext::new()); + // Pipeline with only DataFusionResolver (no user resolvers) + let pipeline = ResolverPipeline::new(vec![], ctx); + + let parsed = ParsedUrl { + url: "spark://cluster/table1".to_string(), + scheme: "spark".to_string(), + host: Some("cluster".to_string()), + path: "/table1".to_string(), + query_params: vec![], + extension: None, + format_type: None, + parse: None, + }; + + let result = pipeline.scan_url(&parsed).await.unwrap(); + assert!( + result.is_none(), + "DataFusionResolver should return None for unknown schemes" + ); +} + +#[tokio::test] +async fn test_should_materialize() { + let ctx = Arc::new(datafusion::prelude::SessionContext::new()); + let schema = get_movies_schema(); + + // A plan with no external tables (just an empty MemTable) + let empty_batch = RecordBatch::new_empty(schema.clone()); + let mem_table = MemTable::try_new(schema.clone(), vec![vec![empty_batch]]).unwrap(); + let plain_plan = + LogicalPlanBuilder::scan("plain", provider_as_source(Arc::new(mem_table)), None) + .unwrap() + .build() + .unwrap(); + + // A plan with an ExternalTableProvider + let ext_provider = + ExternalTableProvider::new("custom".to_string(), schema.clone(), serde_json::json!({})); + let external_plan = + LogicalPlanBuilder::scan("ext", provider_as_source(Arc::new(ext_provider)), None) + .unwrap() + .build() + .unwrap(); + + // DataFusion-only: all support arrow → always materialize + let pipeline = ResolverPipeline::new(vec![], ctx.clone()); + assert!(pipeline.should_materialize(&plain_plan)); + assert!(pipeline.should_materialize(&external_plan)); + + // With a non-arrow resolver: materialize plain plans, not external ones + let scanner = CustomSchemeScanner { + schema: schema.clone(), + }; + let pipeline = ResolverPipeline::new(vec![Arc::new(scanner)], ctx); + assert!(pipeline.should_materialize(&plain_plan)); + assert!(!pipeline.should_materialize(&external_plan)); +} + /// Test a resolver that returns ResolutionResult::Table directly (bypassing DataFusion execution). #[tokio::test] async fn test_table_returning_resolver() { @@ -836,7 +990,11 @@ async fn test_table_returning_resolver() { let resolver = TableResolver { movies_table: mem_table, }; - let runtime = VegaFusionRuntime::new(None, vec![Arc::new(resolver)]); + let runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts { + plan_resolvers: vec![Arc::new(resolver)], + ..Default::default() + }) + .unwrap(); let spec = get_simple_spec(); let inline_datasets = get_inline_datasets(); @@ -862,7 +1020,7 @@ async fn test_table_returning_resolver() { /// Test that VegaFusionRuntime works with no resolver (None) when inline datasets are tables. #[tokio::test] async fn test_no_resolver() { - let runtime = VegaFusionRuntime::new(None, Vec::new()); + let runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts::default()).unwrap(); let spec_str = r#"{ "$schema": "https://vega.github.io/schema/vega/v5.json", @@ -944,8 +1102,8 @@ mod serialization_tests { async fn test_external_table_proto_round_trip() { let schema = get_movies_schema(); let provider = Arc::new(ExternalTableProvider::new( + "test".to_string(), schema, - Some("test".to_string()), serde_json::Value::Null, )); let table_source = provider_as_source(provider); @@ -978,8 +1136,8 @@ mod serialization_tests { async fn test_external_table_raw_proto_inspection() { let schema = get_movies_schema(); let provider = Arc::new(ExternalTableProvider::new( + "test".to_string(), schema.clone(), - Some("test".to_string()), serde_json::Value::Null, )); let table_source = provider_as_source(provider); @@ -1058,8 +1216,8 @@ mod serialization_tests { "filters": [{"col": "year", "op": ">", "val": 2000}], }); let provider = Arc::new(ExternalTableProvider::new( + "test".to_string(), schema.clone(), - Some("test".to_string()), metadata.clone(), )); let table_source = provider_as_source(provider); @@ -1082,7 +1240,7 @@ mod serialization_tests { .as_any() .downcast_ref::() .expect("Expected ExternalTableProvider"); - assert_eq!(ext.protocol(), Some("test")); + assert_eq!(ext.scheme(), "test"); assert_eq!(ext.metadata(), &metadata); } else { panic!("Expected TableScan, got {:?}", round_tripped); @@ -1174,7 +1332,11 @@ async fn test_resolver_error_propagation() { } } - let runtime = VegaFusionRuntime::new(None, vec![Arc::new(FailingResolver)]); + let runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts { + plan_resolvers: vec![Arc::new(FailingResolver)], + ..Default::default() + }) + .unwrap(); let spec = get_simple_spec(); let inline_datasets = get_inline_datasets(); @@ -1248,21 +1410,48 @@ async fn test_datafusion_resolver_executes_simple_plan() { } #[tokio::test] -async fn test_resolver_pipeline_has_user_resolvers() { +async fn test_resolver_pipeline_should_materialize() { let ctx = Arc::new(datafusion::prelude::SessionContext::new()); + let schema = get_movies_schema(); + + // Plan with no external tables + let empty_batch = RecordBatch::new_empty(schema.clone()); + let mem_table = MemTable::try_new(schema.clone(), vec![vec![empty_batch]]).unwrap(); + let plain_plan = + LogicalPlanBuilder::scan("plain", provider_as_source(Arc::new(mem_table)), None) + .unwrap() + .build() + .unwrap(); + + // Plan with an external table + let ext = ExternalTableProvider::new("test".to_string(), schema, serde_json::json!({})); + let external_plan = LogicalPlanBuilder::scan("ext", provider_as_source(Arc::new(ext)), None) + .unwrap() + .build() + .unwrap(); + // DataFusion-only: always materialize let empty_pipeline = ResolverPipeline::new(vec![], ctx.clone()); assert!( - !empty_pipeline.has_user_resolvers(), - "Empty pipeline should report no user resolvers" + empty_pipeline.should_materialize(&plain_plan), + "DataFusion-only pipeline should materialize plain plans" + ); + assert!( + empty_pipeline.should_materialize(&external_plan), + "DataFusion-only pipeline should materialize even external plans" ); + // With non-arrow resolver: materialize plain, not external let events = Arc::new(Mutex::new(Vec::new())); let resolver = ScriptedResolver::new("test", ResolverBehavior::PassThroughPlan, events); let resolvers: Vec> = vec![Arc::new(resolver)]; let pipeline_with_resolvers = ResolverPipeline::new(resolvers, ctx); assert!( - pipeline_with_resolvers.has_user_resolvers(), - "Pipeline with resolvers should report has user resolvers" + pipeline_with_resolvers.should_materialize(&plain_plan), + "Non-arrow pipeline should still materialize plans with no external tables" + ); + assert!( + !pipeline_with_resolvers.should_materialize(&external_plan), + "Non-arrow pipeline should not materialize plans with external tables" ); } diff --git a/vegafusion-runtime/tests/test_url_policy.rs b/vegafusion-runtime/tests/test_url_policy.rs new file mode 100644 index 000000000..25b83e73c --- /dev/null +++ b/vegafusion-runtime/tests/test_url_policy.rs @@ -0,0 +1,263 @@ +use async_trait::async_trait; +use datafusion::datasource::{provider_as_source, MemTable}; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; +use serde_json::json; +use std::fs; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use tempfile::TempDir; +use vegafusion_common::arrow::array::{ArrayRef, Float64Array}; +use vegafusion_common::arrow::record_batch::RecordBatch; +use vegafusion_common::data::scalar::ScalarValueHelpers; +use vegafusion_common::error::Result; +use vegafusion_core::proto::gen::tasks::{TaskGraph, TzConfig, Variable}; +use vegafusion_core::spec::chart::ChartSpec; +use vegafusion_runtime::data::pipeline::BaseUrlSetting; +use vegafusion_runtime::data::plan_resolver::{PlanResolver, ResolutionResult}; +use vegafusion_runtime::task_graph::runtime::{VegaFusionRuntime, VegaFusionRuntimeOpts}; + +fn write_json_rows(dir: &Path, name: &str, values: &[f64]) -> PathBuf { + let path = dir.join(name); + let rows: Vec<_> = values.iter().map(|value| json!({ "x": value })).collect(); + fs::write(&path, serde_json::to_string(&rows).unwrap()).unwrap(); + path +} + +fn extent_spec(url: serde_json::Value) -> ChartSpec { + serde_json::from_value(json!({ + "$schema": "https://vega.github.io/schema/vega/v5.json", + "data": [ + { + "name": "source", + "url": url, + "format": {"type": "json"}, + }, + { + "name": "derived", + "source": "source", + "transform": [ + { + "type": "extent", + "signal": "my_extent", + "field": "x", + } + ], + } + ] + })) + .unwrap() +} + +fn extent_spec_with_url_signal(signal_url: &str) -> ChartSpec { + serde_json::from_value(json!({ + "$schema": "https://vega.github.io/schema/vega/v5.json", + "signals": [ + { + "name": "url", + "value": signal_url, + } + ], + "data": [ + { + "name": "source", + "url": {"signal": "url"}, + "format": {"type": "json"}, + }, + { + "name": "derived", + "source": "source", + "transform": [ + { + "type": "extent", + "signal": "my_extent", + "field": "x", + } + ], + } + ] + })) + .unwrap() +} + +async fn query_extent(runtime: &VegaFusionRuntime, spec: &ChartSpec) -> Result<[f64; 2]> { + let tz_config = TzConfig { + local_tz: "UTC".to_string(), + default_input_tz: None, + }; + let task_scope = spec.to_task_scope().unwrap(); + let tasks = spec.to_tasks(&tz_config, &Default::default()).unwrap(); + let graph = Arc::new(TaskGraph::new(tasks, &task_scope).unwrap()); + let mapping = graph.build_mapping(); + let node = mapping + .get(&(Variable::new_signal("my_extent"), Vec::new())) + .cloned() + .unwrap(); + let value = runtime + .get_node_value(graph, &node, Default::default()) + .await?; + value.as_scalar()?.to_f64x2() +} + +struct CustomSchemeResolver; + +#[async_trait] +impl PlanResolver for CustomSchemeResolver { + fn name(&self) -> &str { + "custom_scheme_resolver" + } + + async fn scan_url( + &self, + parsed_url: &vegafusion_core::runtime::ParsedUrl, + ) -> Result> { + if parsed_url.scheme != "custom" { + return Ok(None); + } + + let batch = RecordBatch::try_from_iter(vec![( + "x", + Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])) as ArrayRef, + )]) + .unwrap(); + let mem_table = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap(); + let plan = LogicalPlanBuilder::scan( + "custom_table", + provider_as_source(Arc::new(mem_table)), + None, + ) + .unwrap() + .build() + .unwrap(); + Ok(Some(plan)) + } + + async fn resolve_plan(&self, plan: LogicalPlan) -> Result { + Ok(ResolutionResult::Plan(plan)) + } +} + +fn tempdir_str(tempdir: &TempDir) -> String { + tempdir.path().to_str().unwrap().to_string() +} + +#[tokio::test] +async fn test_relative_url_resolves_against_base_url_and_allowlist() { + let tempdir = tempfile::tempdir().unwrap(); + write_json_rows(tempdir.path(), "data.json", &[1.0, 2.0, 3.0]); + + let runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts { + base_url: BaseUrlSetting::Custom(tempdir_str(&tempdir)), + allowed_base_urls: Some(vec![tempdir_str(&tempdir)]), + ..Default::default() + }) + .unwrap(); + + let extent = query_extent(&runtime, &extent_spec(json!("data.json"))) + .await + .unwrap(); + assert_eq!(extent, [1.0, 3.0]); +} + +#[tokio::test] +async fn test_relative_url_fails_when_base_url_disabled() { + let tempdir = tempfile::tempdir().unwrap(); + write_json_rows(tempdir.path(), "data.json", &[1.0, 2.0, 3.0]); + + let runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts { + base_url: BaseUrlSetting::Disabled, + ..Default::default() + }) + .unwrap(); + + let err = query_extent(&runtime, &extent_spec(json!("data.json"))) + .await + .unwrap_err(); + let message = err.to_string(); + assert!( + message.contains("Relative URL with no base_url configured"), + "unexpected error: {message}" + ); +} + +#[tokio::test] +async fn test_allowed_base_urls_block_local_file_access() { + let allowed_dir = tempfile::tempdir().unwrap(); + let blocked_dir = tempfile::tempdir().unwrap(); + write_json_rows(blocked_dir.path(), "data.json", &[1.0, 2.0, 3.0]); + + let runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts { + base_url: BaseUrlSetting::Custom(tempdir_str(&blocked_dir)), + allowed_base_urls: Some(vec![tempdir_str(&allowed_dir)]), + ..Default::default() + }) + .unwrap(); + + let err = query_extent(&runtime, &extent_spec(json!("data.json"))) + .await + .unwrap_err(); + let message = err.to_string(); + assert!( + message.contains("blocked by allowed_base_urls"), + "unexpected error: {message}" + ); +} + +#[tokio::test] +async fn test_allowed_base_urls_gate_custom_scheme_resolvers() { + let runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts { + plan_resolvers: vec![Arc::new(CustomSchemeResolver)], + allowed_base_urls: Some(vec!["custom://allowed-host/".to_string()]), + ..Default::default() + }) + .unwrap(); + + let allowed_extent = query_extent( + &runtime, + &extent_spec(json!("custom://allowed-host/warehouse/table")), + ) + .await + .unwrap(); + assert_eq!(allowed_extent, [10.0, 30.0]); + + let err = query_extent( + &runtime, + &extent_spec(json!("custom://blocked-host/warehouse/table")), + ) + .await + .unwrap_err(); + let message = err.to_string(); + assert!( + message.contains("blocked by allowed_base_urls"), + "unexpected error: {message}" + ); +} + +#[tokio::test] +async fn test_signal_updated_urls_are_revalidated_against_policy() { + let runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts { + plan_resolvers: vec![Arc::new(CustomSchemeResolver)], + allowed_base_urls: Some(vec!["custom://allowed-host/".to_string()]), + ..Default::default() + }) + .unwrap(); + + let allowed_extent = query_extent( + &runtime, + &extent_spec_with_url_signal("custom://allowed-host/warehouse/table"), + ) + .await + .unwrap(); + assert_eq!(allowed_extent, [10.0, 30.0]); + + let err = query_extent( + &runtime, + &extent_spec_with_url_signal("custom://blocked-host/warehouse/table"), + ) + .await + .unwrap_err(); + let message = err.to_string(); + assert!( + message.contains("blocked by allowed_base_urls"), + "unexpected error: {message}" + ); +} diff --git a/vegafusion-server/Cargo.toml b/vegafusion-server/Cargo.toml index b364264d2..8f0d4124a 100644 --- a/vegafusion-server/Cargo.toml +++ b/vegafusion-server/Cargo.toml @@ -21,6 +21,9 @@ h2 = "0.4" assert_cmd = "2.0.17" predicates = "3.1.3" +[dev-dependencies.tempfile] +workspace = true + [dependencies.regex] workspace = true diff --git a/vegafusion-server/src/main.rs b/vegafusion-server/src/main.rs index ea46f714f..d0dd5f8e9 100644 --- a/vegafusion-server/src/main.rs +++ b/vegafusion-server/src/main.rs @@ -19,9 +19,12 @@ use vegafusion_core::proto::gen::tasks::{ use vegafusion_core::runtime::VegaFusionRuntimeTrait; use vegafusion_core::spec::chart::ChartSpec; use vegafusion_core::task_graph::graph::ScopedVariable; -use vegafusion_runtime::task_graph::runtime::{decode_inline_datasets, VegaFusionRuntime}; +use vegafusion_runtime::data::pipeline::BaseUrlSetting; +use vegafusion_runtime::task_graph::runtime::{ + decode_inline_datasets, VegaFusionRuntime, VegaFusionRuntimeOpts, +}; -use clap::Parser; +use clap::{ArgAction, Parser}; use regex::Regex; use vegafusion_core::proto::gen::pretransform::{ PreTransformExtractDataset, PreTransformExtractRequest, PreTransformExtractResponse, @@ -347,6 +350,22 @@ struct Args { /// Include compatibility with gRPC-Web #[clap(long, num_args = 0)] pub web: bool, + + /// Base URL for resolving relative data URLs + #[clap(long, conflicts_with = "no_base_url")] + pub base_url: Option, + + /// Disable base URL resolution for relative data URLs + #[clap(long, action = ArgAction::SetTrue, conflicts_with = "base_url")] + pub no_base_url: bool, + + /// Allowlist entry for external data access. Repeat for multiple entries. + #[clap(long = "allowed-base-url", action = ArgAction::Append, conflicts_with = "no_allowed_urls")] + pub allowed_base_url: Vec, + + /// Disable all external data access + #[clap(long, action = ArgAction::SetTrue, conflicts_with = "allowed_base_url")] + pub no_allowed_urls: bool, } fn main() -> Result<(), VegaFusionError> { @@ -368,16 +387,35 @@ fn main() -> Result<(), VegaFusionError> { None }; + let base_url = if args.no_base_url { + BaseUrlSetting::Disabled + } else if let Some(base_url) = args.base_url.clone() { + BaseUrlSetting::Custom(base_url) + } else { + BaseUrlSetting::Default + }; + + let allowed_base_urls = if args.no_allowed_urls { + Some(vec![]) + } else if args.allowed_base_url.is_empty() { + None + } else { + Some(args.allowed_base_url.clone()) + }; + let tokio_runtime = tokio::runtime::Builder::new_multi_thread() .enable_all() .thread_stack_size(TOKIO_THREAD_STACK_SIZE) .build() .expect("Failed to create tokio runtime"); - let tg_runtime = VegaFusionRuntime::new( - Some(VegaFusionCache::new(Some(args.capacity), memory_limit)), - Vec::new(), - ); + let tg_runtime = VegaFusionRuntime::new(VegaFusionRuntimeOpts { + cache: Some(VegaFusionCache::new(Some(args.capacity), memory_limit)), + base_url, + allowed_base_urls, + ..Default::default() + }) + .expect("Failed to create VegaFusionRuntime"); tokio_runtime.block_on(async move { grpc_server(grpc_address, tg_runtime.clone(), args.web) diff --git a/vegafusion-server/tests/test_task_graph_runtime.rs b/vegafusion-server/tests/test_task_graph_runtime.rs index b77faffa2..9332a0005 100644 --- a/vegafusion-server/tests/test_task_graph_runtime.rs +++ b/vegafusion-server/tests/test_task_graph_runtime.rs @@ -1,12 +1,146 @@ +use serde_json::json; +use std::fs; +use std::net::TcpListener; +use std::path::{Path, PathBuf}; +use std::process::{Child, Command, Stdio}; use std::time::Duration; +use tokio::time::sleep; use vegafusion_common::data::scalar::ScalarValueHelpers; use vegafusion_core::proto::gen::services::query_result::Response; use vegafusion_core::proto::gen::services::vega_fusion_runtime_client::VegaFusionRuntimeClient; use vegafusion_core::proto::gen::services::{query_request, QueryRequest}; use vegafusion_core::proto::gen::tasks::{ - NodeValueIndex, TaskGraph, TaskGraphValueRequest, TzConfig, VariableNamespace, + TaskGraph, TaskGraphValueRequest, TzConfig, Variable, VariableNamespace, }; -use vegafusion_core::spec::chart::ChartSpec; // Add methods on commands +use vegafusion_core::spec::chart::ChartSpec; + +struct ServerProcess { + child: Child, +} + +impl Drop for ServerProcess { + fn drop(&mut self) { + let _ = self.child.kill(); + let _ = self.child.wait(); + } +} + +fn pick_unused_port() -> u16 { + TcpListener::bind("127.0.0.1:0") + .unwrap() + .local_addr() + .unwrap() + .port() +} + +fn write_json_rows(dir: &Path, name: &str, values: &[f64]) -> PathBuf { + let path = dir.join(name); + let rows: Vec<_> = values.iter().map(|value| json!({ "x": value })).collect(); + fs::write(&path, serde_json::to_string(&rows).unwrap()).unwrap(); + path +} + +fn extent_spec(url: serde_json::Value) -> ChartSpec { + serde_json::from_value(json!({ + "$schema": "https://vega.github.io/schema/vega/v5.json", + "data": [ + { + "name": "source", + "url": url, + "format": {"type": "json"}, + }, + { + "name": "derived", + "source": "source", + "transform": [ + { + "type": "extent", + "signal": "my_extent", + "field": "x", + } + ], + } + ] + })) + .unwrap() +} + +fn build_request(chart: &ChartSpec) -> QueryRequest { + let tz_config = TzConfig { + local_tz: "UTC".to_string(), + default_input_tz: None, + }; + let task_scope = chart.to_task_scope().unwrap(); + let tasks = chart.to_tasks(&tz_config, &Default::default()).unwrap(); + let graph = TaskGraph::new(tasks, &task_scope).unwrap(); + let mapping = graph.build_mapping(); + let extent_node = mapping + .get(&(Variable::new_signal("my_extent"), Vec::new())) + .cloned() + .unwrap(); + + QueryRequest { + request: Some(query_request::Request::TaskGraphValues( + TaskGraphValueRequest { + task_graph: Some(graph), + indices: vec![extent_node], + inline_datasets: vec![], + }, + )), + } +} + +async fn spawn_server(extra_args: &[String]) -> (ServerProcess, String) { + let port = pick_unused_port(); + let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("vegafusion-server")); + cmd.arg("--host") + .arg("127.0.0.1") + .arg("--port") + .arg(port.to_string()) + .args(extra_args) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + + let child = cmd.spawn().expect("Failed to spawn vegafusion-server"); + let address = format!("http://127.0.0.1:{port}"); + + for _ in 0..60 { + if VegaFusionRuntimeClient::connect(address.clone()) + .await + .is_ok() + { + return (ServerProcess { child }, address); + } + sleep(Duration::from_millis(100)).await; + } + + panic!("Timed out waiting for vegafusion-server to start on port {port}"); +} + +async fn query_extent(address: String, chart: &ChartSpec) -> std::result::Result<[f64; 2], String> { + let mut client = VegaFusionRuntimeClient::connect(address) + .await + .map_err(|err| err.to_string())?; + let response = client + .task_graph_query(build_request(chart)) + .await + .map_err(|err| err.to_string())?; + + let query_result = response.into_inner(); + match query_result.response.unwrap() { + Response::Error(error) => Err(format!("{error:?}")), + Response::TaskGraphValues(values_response) => { + let response_values = values_response.deserialize().unwrap(); + let (_var, scope, value) = &response_values[0]; + assert_eq!(scope, &Vec::::new()); + value + .as_scalar() + .map_err(|err| err.to_string())? + .to_f64x2() + .map_err(|err| err.to_string()) + } + } +} #[tokio::test(flavor = "multi_thread")] async fn try_it_from_spec() { @@ -53,23 +187,22 @@ async fn try_it_from_spec() { let tasks = chart.to_tasks(&tz_config, &Default::default()).unwrap(); let graph = TaskGraph::new(tasks, &task_scope).unwrap(); + let mapping = graph.build_mapping(); let request = QueryRequest { request: Some(query_request::Request::TaskGraphValues( TaskGraphValueRequest { task_graph: Some(graph), - indices: vec![NodeValueIndex::new(2, Some(0))], + indices: vec![mapping + .get(&(Variable::new_signal("my_extent"), Vec::new())) + .cloned() + .unwrap()], inline_datasets: vec![], }, )), }; - let mut bin = std::process::Command::new(assert_cmd::cargo::cargo_bin!("vegafusion-server")); - let cmd = bin.args(["--port", "50059"]); - - let mut proc = cmd.spawn().expect("Failed to spawn vegafusion-server"); - std::thread::sleep(Duration::from_millis(2000)); - - let mut client = VegaFusionRuntimeClient::connect("http://127.0.0.1:50059") + let (_server, address) = spawn_server(&[]).await; + let mut client = VegaFusionRuntimeClient::connect(address) .await .expect("Failed to connect to gRPC server"); let response = client.task_graph_query(request).await.unwrap(); @@ -94,5 +227,97 @@ async fn try_it_from_spec() { ) } } - proc.kill().ok(); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_server_base_url_flag_resolves_relative_urls() { + let tempdir = tempfile::tempdir().unwrap(); + write_json_rows(tempdir.path(), "data.json", &[1.0, 2.0, 3.0]); + + let args = vec![ + "--base-url".to_string(), + tempdir.path().to_str().unwrap().to_string(), + "--allowed-base-url".to_string(), + tempdir.path().to_str().unwrap().to_string(), + ]; + let (_server, address) = spawn_server(&args).await; + + let extent = query_extent(address, &extent_spec(json!("data.json"))) + .await + .unwrap(); + assert_eq!(extent, [1.0, 3.0]); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_server_no_base_url_rejects_relative_urls() { + let tempdir = tempfile::tempdir().unwrap(); + write_json_rows(tempdir.path(), "data.json", &[1.0, 2.0, 3.0]); + + let args = vec!["--no-base-url".to_string()]; + let (_server, address) = spawn_server(&args).await; + + let err = query_extent(address, &extent_spec(json!("data.json"))) + .await + .unwrap_err(); + assert!( + err.contains("Relative URL with no base_url configured"), + "unexpected error: {err}" + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_server_no_allowed_urls_blocks_external_access() { + let tempdir = tempfile::tempdir().unwrap(); + let data_path = write_json_rows(tempdir.path(), "data.json", &[1.0, 2.0, 3.0]); + + let args = vec!["--no-allowed-urls".to_string()]; + let (_server, address) = spawn_server(&args).await; + + let err = query_extent( + address, + &extent_spec(json!(data_path.to_str().unwrap().to_string())), + ) + .await + .unwrap_err(); + assert!( + err.contains("blocked by allowed_base_urls"), + "unexpected error: {err}" + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_server_repeatable_allowed_base_url_flags_allow_multiple_roots() { + let first_dir = tempfile::tempdir().unwrap(); + let second_dir = tempfile::tempdir().unwrap(); + let blocked_dir = tempfile::tempdir().unwrap(); + write_json_rows(first_dir.path(), "first.json", &[1.0, 2.0, 3.0]); + let second_path = write_json_rows(second_dir.path(), "second.json", &[4.0, 5.0, 6.0]); + let blocked_path = write_json_rows(blocked_dir.path(), "blocked.json", &[7.0, 8.0, 9.0]); + + let args = vec![ + "--allowed-base-url".to_string(), + first_dir.path().to_str().unwrap().to_string(), + "--allowed-base-url".to_string(), + second_dir.path().to_str().unwrap().to_string(), + ]; + let (_server, address) = spawn_server(&args).await; + + let extent = query_extent( + address.clone(), + &extent_spec(json!(second_path.to_str().unwrap().to_string())), + ) + .await + .unwrap(); + assert_eq!(extent, [4.0, 6.0]); + + let err = query_extent( + address, + &extent_spec(json!(blocked_path.to_str().unwrap().to_string())), + ) + .await + .unwrap_err(); + assert!( + err.contains("blocked by allowed_base_urls"), + "unexpected error: {err}" + ); } diff --git a/vegafusion-wasm/src/lib.rs b/vegafusion-wasm/src/lib.rs index 4105f48ad..758122799 100644 --- a/vegafusion-wasm/src/lib.rs +++ b/vegafusion-wasm/src/lib.rs @@ -446,6 +446,7 @@ pub async fn vegafusion_embed( js_sys::JSON::stringify(&e).unwrap() )) })?; + Box::new(QueryFnVegaFusionRuntime::new(query_fn)) };