# Source code for clusx.utils

"""
Utility functions for the clusx package.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from typing import Union

    import torch
    from numpy.typing import NDArray

    EmbeddingTensor = Union[torch.Tensor, NDArray[np.float32]]


def to_numpy(embedding: EmbeddingTensor) -> NDArray[np.float32]:
    """
    Convert an embedding to a numpy array.

    Objects exposing a ``detach`` attribute are treated as PyTorch tensors
    and converted via ``detach().cpu().numpy()``. Anything else is passed
    through :func:`numpy.asarray`, which returns an existing ``ndarray``
    unchanged and converts other array-likes (e.g. lists).

    Parameters
    ----------
    embedding : EmbeddingTensor
        The value to convert. Can be a PyTorch tensor or a numpy array
        (or another array-like accepted by ``numpy.asarray``).

    Returns
    -------
    numpy.ndarray
        The input as a numpy array. No dtype conversion is performed here,
        so the result keeps whatever dtype the input carried.
    """
    # Duck typing keeps torch an optional (type-checking-only) import:
    # PyTorch tensors have a detach() method, numpy arrays do not.
    if hasattr(embedding, "detach"):
        # detach() drops autograd tracking and cpu() moves the data to host
        # memory; both are required before numpy() can be called.
        return embedding.detach().cpu().numpy()  # type: ignore[attr-defined]

    # Already numpy or another array-like.
    return np.asarray(embedding)