Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 51 additions & 20 deletions python/grass/jupyter/baseseriesmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,54 @@
import shutil
import multiprocessing

from functools import partial

import grass.script as gs

from .map import Map
from .utils import get_number_of_cores, save_gif


def _render_worker_base(
i: int,
tmpdir: str | None = None,
base_file: str | None = None,
width: int | None = None,
height: int | None = None,
calls: list | None = None,
indices: list | None = None,
env: dict | None = None,
):
"""Render a single layer.

Being at top-level, this function isolates rendering
from the BaseSeriesMap object or object derived from it,
because any of those objects contains any attribute that
creates a lock, parallel processing with spawn would fail,
if this function was part of the object itself.
"""
filename = os.path.join(tmpdir, f"{i}.png")
shutil.copyfile(base_file, filename)
img = Map(
width=width,
height=height,
filename=filename,
use_region=True,
env=env,
read_file=True,
)
for grass_module, kwargs in calls[i]:
if grass_module is not None:
img.run(grass_module, **kwargs)
return indices[i], filename


class BaseSeriesMap:
"""
Base class for SeriesMap and TimeSeriesMap
"""

def __init__(self, width=None, height=None, env=None):
def __init__(self, width: int = 600, height: int = 400, env=None):
"""Creates an instance of the visualizations class.

:param int width: width of map in pixels
Expand Down Expand Up @@ -101,25 +137,8 @@ def _render_baselayers(self, img):
for grass_module, kwargs in self._base_layer_calls:
img.run(grass_module, **kwargs)

def _render_worker(self, i):
"""Function to render a single layer."""
filename = os.path.join(self._tmpdir.name, f"{i}.png")
shutil.copyfile(self.base_file, filename)
img = Map(
width=self._width,
height=self._height,
filename=filename,
use_region=True,
env=self._env,
read_file=True,
)
for grass_module, kwargs in self._calls[i]:
if grass_module is not None:
img.run(grass_module, **kwargs)
return self._indices[i], filename

def render(self):
"""Renders image for each raster in series.
"""Render an image for each map in series.

Save PNGs to temporary directory. Must be run before creating a visualization
(i.e. show or save).
Expand All @@ -136,6 +155,7 @@ def render(self):
# Random name needed to avoid potential conflict with layer names
random_name_base = gs.append_random("base", 8) + ".png"
self.base_file = os.path.join(self._tmpdir.name, random_name_base)
base_file_path = os.path.join(self._tmpdir.name, random_name_base)
img = Map(
width=self._width,
height=self._height,
Expand All @@ -151,9 +171,20 @@ def render(self):
self._render_baselayers(img)

# Render layers in respective classes

cores = get_number_of_cores(len(tasks), env=self._env)
render_worker = partial(
_render_worker_base,
tmpdir=str(self._tmpdir.name),
base_file=base_file_path,
width=int(self._width),
height=int(self._height),
calls=self._calls,
indices=self._indices,
env=dict(self._env) if self._env else None,
)
with multiprocessing.Pool(processes=cores) as pool:
results = pool.starmap(self._render_worker, tasks)
results = pool.starmap(render_worker, tasks)

for i, filename in results:
self._base_filename_dict[i] = filename
Expand Down
47 changes: 27 additions & 20 deletions python/grass/jupyter/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def __init__(self, use_region, saved_region, src_env, tgt_env):
self._set_bbox(self._src_env)
if self._saved_region:
self._src_env["GRASS_REGION"] = gs.region_env(
region=self._saved_region, env=self._src_env
region=self._saved_region,
env=self._src_env,
)
set_target_region(src_env=self._src_env, tgt_env=self._tgt_env)
self._resolution = self._get_psmerc_region_resolution()
Expand Down Expand Up @@ -145,7 +146,8 @@ def __init__(self, use_region, saved_region, width, height, env):

def set_region_from_env(self, env):
"""Copies GRASS_REGION from provided environment
to local environment to set the computational region"""
to local environment to set the computational region
"""
if "GRASS_REGION" in env:
self._env["GRASS_REGION"] = env["GRASS_REGION"]

Expand Down Expand Up @@ -177,7 +179,8 @@ def set_region_from_command(self, module, **kwargs):
"""
if self._saved_region:
self._env["GRASS_REGION"] = gs.region_env(
region=self._saved_region, env=self._env
region=self._saved_region,
env=self._env,
)
return
if self._use_region:
Expand All @@ -194,7 +197,8 @@ def set_region_from_command(self, module, **kwargs):
if module.startswith("d.vect"):
if not self._resolution_set and not self._extent_set:
self._env["GRASS_REGION"] = gs.region_env(
vector=name, env=self._env
vector=name,
env=self._env,
)
self._extent_set = True
elif not self._resolution_set and not self._extent_set:
Expand Down Expand Up @@ -243,7 +247,8 @@ def set_region_from_rasters(self, rasters):
"""
if self._saved_region:
self._env["GRASS_REGION"] = gs.region_env(
region=self._saved_region, env=self._env
region=self._saved_region,
env=self._env,
)
return
if self._use_region:
Expand All @@ -267,7 +272,8 @@ def set_region_from_vectors(self, vectors):
"""
if self._saved_region:
self._env["GRASS_REGION"] = gs.region_env(
region=self._saved_region, env=self._env
region=self._saved_region,
env=self._env,
)
return
if self._use_region:
Expand Down Expand Up @@ -334,7 +340,7 @@ def __init__(self, use_region, saved_region, env):
self._use_region = use_region
self._saved_region = saved_region

def set_region_from_timeseries(self, timeseries, element_type="strds"):
def set_region_from_timeseries(self, stds: object | None = None) -> None:
"""Sets computational region for rendering.

This function sets the computation region from the extent of
Expand All @@ -346,24 +352,25 @@ def set_region_from_timeseries(self, timeseries, element_type="strds"):
"""
if self._saved_region:
self._env["GRASS_REGION"] = gs.region_env(
region=self._saved_region, env=self._env
region=self._saved_region,
env=self._env,
)
return
if self._use_region:
# use current
return
# Get extent, resolution from space time dataset
info = gs.parse_command(
"t.info", input=timeseries, type=element_type, flags="g", env=self._env
)
# Set grass region from extent
if not stds:
raise RuntimeError(_("No SpaceTimeDataset provided to set region from."))
# Set grass region from STDS extent
params = {
"n": info["north"],
"s": info["south"],
"e": info["east"],
"w": info["west"],
"n": stds.spatial_extent.north,
"s": stds.spatial_extent.south,
"e": stds.spatial_extent.east,
"w": stds.spatial_extent.west,
}
if "nsres_min" in info:
params["nsres"] = info["nsres_min"]
params["ewres"] = info["ewres_min"]
region_dict = stds.metadata.__dict__["D"]
for resolution in ("nsres", "ewres"):
resolution_value = region_dict.get(f"{resolution}_min")
if resolution_value:
params[resolution] = resolution_value
self._env["GRASS_REGION"] = gs.region_env(**params, env=self._env)
10 changes: 5 additions & 5 deletions python/grass/jupyter/seriesmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ class SeriesMap(BaseSeriesMap):

def __init__(
self,
width=None,
height=None,
env=None,
use_region=False,
saved_region=None,
width: int = 600,
height: int = 400,
env: dict | None = None,
use_region: bool = False,
saved_region: str | None = None,
):
"""Creates an instance of the SeriesMap visualizations class.

Expand Down
Loading
Loading