Prediction metrics

prediction_metrics

Functions

compute_and_show_cf(pred_labels, gt_labels, labels=None, use_labels_from='both', vis=True, cf_plot_savefile=None, cf_np_savefile=None)

Compute the confusion matrix between predicted and ground truth labels, optionally visualize and save it, and report the overall accuracy.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| pred_labels | list | Predicted labels, one per sample. | required |
| gt_labels | list | Ground truth labels, one per sample. | required |
| labels | Union[None, List[str]] | Explicit label set used to index the confusion matrix; if None, derived according to use_labels_from. Defaults to None. | None |
| use_labels_from | str | Which label values to derive the label set from when labels is None: "gt", "pred", or "both". Defaults to "both". | 'both' |
| vis | bool | Whether to plot the confusion matrix. Defaults to True. | True |
| cf_plot_savefile | Union[None, PATH_TYPE] | Path to save the confusion matrix plot to; if None, the plot is shown interactively instead. Defaults to None. | None |
| cf_np_savefile | Union[None, PATH_TYPE] | Path to save the confusion matrix to as a numpy array; if None, it is not saved. Defaults to None. | None |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If use_labels_from is not one of "gt", "pred", or "both". |

Returns:

| Type | Description |
|------|-------------|
| tuple | (cf_matrix, labels, accuracy): the confusion matrix, the labels used to index it, and the overall accuracy. |
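Example:

A minimal usage sketch; the label values and save path below are invented for illustration:

from geograypher.utils.prediction_metrics import compute_and_show_cf

gt_labels = ["fir", "oak", "oak", "pine"]
pred_labels = ["fir", "oak", "fir", "pine"]

# Use the union of predicted and ground truth labels (the default) and
# save the plot to disk instead of displaying it interactively
cf_matrix, labels, accuracy = compute_and_show_cf(
    pred_labels=pred_labels,
    gt_labels=gt_labels,
    use_labels_from="both",
    cf_plot_savefile="vis/confusion_matrix.png",
)
print(labels)    # ['fir' 'oak' 'pine']
print(accuracy)  # 0.75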

Source code in geograypher/utils/prediction_metrics.py
def compute_and_show_cf(
    pred_labels: list,
    gt_labels: list,
    labels: typing.Union[None, typing.List[str]] = None,
    use_labels_from: str = "both",
    vis: bool = True,
    cf_plot_savefile: typing.Union[None, PATH_TYPE] = None,
    cf_np_savefile: typing.Union[None, PATH_TYPE] = None,
):
    """_summary_

    Args:
        pred_labels (list): _description_
        gt_labels (list): _description_
        labels (typing.Union[None, typing.List[str]], optional): _description_. Defaults to None.
        use_labels_from (str, optional): _description_. Defaults to "both".
        vis (bool, optional): _description_. Defaults to True.
        cf_plot_savefile (typing.Union[None, PATH_TYPE], optional): _description_. Defaults to None.
        cf_np_savefile (typing.Union[None, PATH_TYPE], optional): _description_. Defaults to None.

    Raises:
        ValueError: _description_

    Returns:
        _type_: _description_
    """
    if labels is None:
        if use_labels_from == "gt":
            labels = np.unique(list(gt_labels))
        elif use_labels_from == "pred":
            labels = np.unique(list(pred_labels))
        elif use_labels_from == "both":
            labels = np.unique(list(pred_labels) + list(gt_labels))
        else:
            raise ValueError(
                f"Must use labels from gt, pred, or both but instead was {use_labels_from}"
            )

    cf_matrix = confusion_matrix(y_true=gt_labels, y_pred=pred_labels, labels=labels)

    if vis:
        cf_disp = ConfusionMatrixDisplay(
            confusion_matrix=cf_matrix, display_labels=labels
        )
        cf_disp.plot()
        if cf_plot_savefile is None:
            plt.show()
        else:
            ensure_containing_folder(cf_plot_savefile)
            plt.savefig(cf_plot_savefile)

    if cf_np_savefile:
        ensure_containing_folder(cf_np_savefile)
        np.save(cf_np_savefile, cf_matrix)

    # TODO compute more comprehensive metrics here
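    # Overall accuracy: the sum of the diagonal (correct predictions) over the total count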
    accuracy = np.sum(cf_matrix * np.eye(cf_matrix.shape[0])) / np.sum(cf_matrix)

    return cf_matrix, labels, accuracy
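
The accuracy computation is equivalent to np.trace(cf_matrix) / np.sum(cf_matrix); for example, with cf_matrix = [[3, 1], [2, 4]] the returned accuracy is (3 + 4) / 10 = 0.7.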