Skip to content

Ortho segmentor

ortho_segmentor

Functions

assemble_tiled_predictions(raster_input_file, pred_files, class_savefile, num_classes, counts_savefile=None, downweight_edge_frac=0.25, nodataval=NULL_TEXTURE_INT_VALUE, count_dtype=np.uint8, max_overlapping_tiles=4)

Take tiled predictions on disk and aggregate them into a raster

Parameters:

Name Type Description Default
pred_files list[PATH_TYPE]

List of filenames where predictions are written

required
class_savefile PATH_TYPE

Where to save the merged raster.

required
counts_savefile Union[PATH_TYPE, NoneType]

Where to save the counts for the merged predictions raster. A tempfile will be created and then deleted if not specified. Defaults to None.

None
downweight_edge_frac float

Downweight this fraction of predictions at the edge of each tile using a linear ramp. Defaults to 0.25.

0.25
nodataval Union[int, None]

Value for unassigned pixels. If None, will be set to num_classes, the first unused class. Defaults to 255

NULL_TEXTURE_INT_VALUE
count_dtype type

What type to use for aggregation. Float uses more space but is more accurate. Defaults to np.uint8

uint8
max_overlapping_tiles int

The max number of prediction tiles that may overlap at a given point. This is used to upper bound the value in the count matrix, because we use scaled np.uint8 values rather than floats for efficiency. Setting a lower value enables slightly more accuracy in the aggregation process, but too low can lead to overflow. Defaults to 4

4
Source code in geograypher/predictors/ortho_segmentor.py
def assemble_tiled_predictions(
    raster_input_file: PATH_TYPE,
    pred_files: list[PATH_TYPE],
    class_savefile: PATH_TYPE,
    num_classes: int,
    counts_savefile: typing.Union[PATH_TYPE, None] = None,
    downweight_edge_frac: float = 0.25,
    nodataval: typing.Union[int, None] = NULL_TEXTURE_INT_VALUE,
    count_dtype: type = np.uint8,
    max_overlapping_tiles: int = 4,
):
    """Take tiled predictions on disk and aggregate them into a raster

    Each prediction tile is a per-pixel class-ID image whose placement within the
    source raster is encoded in its filename (parsed by parse_windows_from_files).
    Per-class weighted counts are accumulated into a multi-band counts raster, then
    reduced to a single-band argmax (most-voted class) raster.

    Args:
        raster_input_file (PATH_TYPE): Path to the source raster the tiles were
            generated from; used for the CRS and the window-to-geospatial transform.
        pred_files (list[PATH_TYPE]): List of filenames where predictions are written
        class_savefile (PATH_TYPE): Where to save the merged raster.
        num_classes (int): Number of prediction classes; one band per class is
            created in the counts raster.
        counts_savefile (typing.Union[PATH_TYPE, NoneType], optional):
            Where to save the counts for the merged predictions raster.
            A tempfile will be created and then deleted if not specified. Defaults to None.
        downweight_edge_frac (float, optional): Downweight this fraction of predictions at the edge of each tile using a linear ramp. Defaults to 0.25.
        nodataval (typing.Union[int, None]): Value for unassigned pixels. If None,
            will be set to num_classes, the first unused class ID. Defaults to
            NULL_TEXTURE_INT_VALUE.
        count_dtype (type, optional): What type to use for aggregation. Float uses more space but is more accurate. Defaults to np.uint8
        max_overlapping_tiles (int):
            The max number of prediction tiles that may overlap at a given point. This is used to upper bound the value in the count matrix,
            because we use scaled np.uint8 values rather than floats for efficiency. Setting a lower value enables slightly more accuracy in the
            aggregation process, but too low can lead to overflow. Defaults to 4
    """
    # Set nodataval to the first unused class ID
    if nodataval is None:
        nodataval = num_classes

    # If the user didn't specify where to write the counts, create a tempfile that will be deleted
    # NOTE(review): the NamedTemporaryFile handle is kept in a local, so the file is
    # deleted when this function returns and the manager is garbage collected — confirm
    if counts_savefile is None:
        # Create the containing folder if required
        # NOTE(review): class_savefile.parent assumes a pathlib.Path, not a str — verify PATH_TYPE
        ensure_containing_folder(class_savefile)
        counts_savefile_manager = tempfile.NamedTemporaryFile(
            mode="w+", suffix=".tif", dir=class_savefile.parent
        )
        counts_savefile = counts_savefile_manager.name

    # Parse the filenames to get the windows
    # TODO consider using the extent to only write a file for the minimum enclosing rectangle
    windows, extent = parse_windows_from_files(pred_files, return_in_extent_coords=True)

    # Aggregate predictions
    with rio.open(raster_input_file) as src:
        # Create file to store counts that is the same as the input raster except it has num_classes number of bands
        # TODO make this only the size of the extent computed by parse_windows_from_files
        extent_transform = src.window_transform(extent)

        with rio.open(
            counts_savefile,
            "w+",  # read-write mode: each window is read back, updated, and rewritten below
            driver="GTiff",
            height=extent.height,
            width=extent.width,
            count=num_classes,
            dtype=count_dtype,
            crs=src.crs,
            transform=extent_transform,
        ) as dst:
            # Cache of edge-downweighting masks, keyed by tile shape
            pred_weighting_dict = {}
            for pred_file, window in tqdm(
                zip(pred_files, windows),
                desc="Aggregating raster predictions",
                total=len(pred_files),
            ):
                # Read the prediction from disk
                pred = read_image_or_numpy(pred_file)

                if pred.shape != (window.height, window.width):
                    raise ValueError("Size of pred does not match window")

                # We want to downweight portions at the edge so we create a ramped weighting mask
                # but we don't want to duplicate this computation because it's the same for each same sized chip
                if pred.shape not in pred_weighting_dict:
                    # We want to keep this as a uint8
                    pred_weighting = create_ramped_weighting(
                        pred.shape, downweight_edge_frac
                    )

                    # Allow us to get as much granularity as possible given the datatype:
                    # scale so that max_overlapping_tiles fully-weighted tiles sum to the
                    # dtype's max value without overflowing
                    if count_dtype is not float:
                        pred_weighting = pred_weighting * (
                            np.iinfo(count_dtype).max / max_overlapping_tiles
                        )
                    # Convert weighting to desired type
                    pred_weighting_dict[pred.shape] = pred_weighting.astype(count_dtype)

                # Get weighting
                pred_weighting = pred_weighting_dict[pred.shape]

                # Update each band in the counts file within the window
                for i in range(num_classes):
                    # Bands in rasterio are 1-indexed
                    band_ind = i + 1
                    class_i_window_counts = dst.read(band_ind, window=window)
                    class_i_preds = pred == i
                    # If nothing matches this class, don't waste computation
                    if not np.any(class_i_preds):
                        continue
                    # Weight the predictions to downweight the ones at the edge
                    weighted_preds = (class_i_preds * pred_weighting).astype(
                        count_dtype
                    )
                    # Add the new predictions to the previous counts
                    class_i_window_counts += weighted_preds
                    # Write out the updated results for this window
                    dst.write(class_i_window_counts, band_ind, window=window)

    ## Convert counts file to max-class file

    with rio.open(counts_savefile, "r") as src:
        # Create a one-band file to store the index of the most predicted class
        with rio.open(
            class_savefile,
            "w",
            driver="GTiff",
            height=src.shape[0],
            width=src.shape[1],
            count=1,
            dtype=np.uint8,
            crs=src.crs,
            transform=src.transform,
            nodata=nodataval,
        ) as dst:
            # Iterate over the blocks corresponding to the tiff driver in the dataset
            # to compute the max class and write it out
            for _, window in tqdm(
                list(src.block_windows()), desc="Writing out max class"
            ):
                # Read in the counts
                counts_array = src.read(window=window)
                # Compute which pixels have no recorded predictions and mask them out
                nodata_mask = np.sum(counts_array, axis=0) == 0

                # If it's all nodata, don't write it out
                # TODO make sure this works as expected
                if np.all(nodata_mask):
                    continue

                # Compute which class had the highest counts
                max_class = np.argmax(counts_array, axis=0)
                max_class[nodata_mask] = nodataval
                # TODO, it would be good to check if it's all nodata and skip the write because that's unneeded
                dst.write(max_class, 1, window=window)

parse_windows_from_files(files, sep=':', return_in_extent_coords=True)

Return the boxes and extent from a list of filenames

Parameters:

Name Type Description Default
files list[Path]

List of filenames

required
sep str

Separator between elements

':'
return_in_extent_coords bool

Return in the coordinate frame of the extent

True

Returns:

Type Description
tuple[list[Window], Window]

tuple[list[Window], Window]: List of windows for each file and extent

Source code in geograypher/predictors/ortho_segmentor.py
def parse_windows_from_files(
    files: list[Path], sep: str = ":", return_in_extent_coords: bool = True
) -> tuple[list[Window], Window]:
    """Return the boxes and extent from a list of filenames

    Each filename stem is expected to encode a window as
    ``<name><sep><col_off><sep><row_off><sep><width><sep><height>``; the leading
    name element is ignored.

    Args:
        files (list[Path]): List of filenames
        sep (str): Separator between elements. Defaults to ":".
        return_in_extent_coords (bool): Return the per-file windows in the
            coordinate frame of the computed extent rather than the original
            raster. Defaults to True.

    Returns:
        tuple[list[Window], Window]: List of windows for each file and the
        extent window that bounds them all (always in original coordinates).
    """
    # Split the coords out, ignoring the filename as the first element
    coords = [file.stem.split(sep)[1:] for file in files]

    # Parse to integers so min/max and arithmetic are exact
    coords_array = np.array(coords).astype(int)

    # Compute the extent as the min offset and max (offset + size) over all boxes
    xmin = np.min(coords_array[:, 0])
    ymin = np.min(coords_array[:, 1])
    xmax = np.max(coords_array[:, 2] + coords_array[:, 0])
    ymax = np.max(coords_array[:, 3] + coords_array[:, 1])
    extent = Window(row_off=ymin, col_off=xmin, width=xmax - xmin, height=ymax - ymin)

    if return_in_extent_coords:
        # Subtract out x and y min so offsets are w.r.t. the extent coordinates
        coords_array[:, 0] = coords_array[:, 0] - xmin
        coords_array[:, 1] = coords_array[:, 1] - ymin

    # Create windows from coords; coords_array is already integer-typed
    windows = [
        Window(
            col_off=col_off,
            row_off=row_off,
            width=width,
            height=height,
        )
        for col_off, row_off, width, height in coords_array
    ]

    return windows, extent