import torch
[docs]
def construct_block_diag_coo_indices_and_shape(
    *block_shapes: tuple[int, int],
    device: torch.device = None,
) -> tuple[torch.Tensor, tuple[int, int]]:
    """Construct the indices for a block diagonal sparse matrix in COOrdinate format.

    Every entry of every block gets an index (blocks are treated as dense). The
    indices are emitted block after block, each block in row-major order, so they
    line up with the concatenation of the flattened blocks.

    >>> blocks = [torch.ones((2, 2)), torch.ones((3, 3))]
    >>> indices, shape = construct_block_diag_coo_indices_and_shape(*[b.shape for b in blocks])
    >>> shape
    (5, 5)
    >>> torch.sparse_coo_tensor(indices, torch.cat([b.flatten() for b in blocks]), size=shape).to_dense()
    tensor([[1., 1., 0., 0., 0.],
            [1., 1., 0., 0., 0.],
            [0., 0., 1., 1., 1.],
            [0., 0., 1., 1., 1.],
            [0., 0., 1., 1., 1.]])

    Args:
        block_shapes: The ``(rows, cols)`` shapes of the blocks, in diagonal order.
        device: Which device to create the indices on.

    Returns:
        torch.Tensor: The indices of the blocks, an int64 tensor of shape
            ``(2, total number of block entries)``. It has shape ``(2, 0)`` when no
            block shapes are given.
        tuple: The shape of the block diagonal matrix as plain Python ints.
    """
    # Offsets are tracked as Python ints so that the returned shape really is a
    # tuple of ints (not 0-dim tensors) and no device round-trip is needed.
    row_offset = 0
    col_offset = 0
    per_block_indices = []
    for n_rows, n_cols in block_shapes:
        n_rows, n_cols = int(n_rows), int(n_cols)
        # Starting the ranges at the running offsets places the block directly at
        # its position on the diagonal.
        rows = torch.arange(row_offset, row_offset + n_rows, device=device)
        cols = torch.arange(col_offset, col_offset + n_cols, device=device)
        grid = torch.meshgrid(rows, cols, indexing="ij")
        per_block_indices.append(torch.stack(grid).reshape(2, n_rows * n_cols))
        row_offset += n_rows
        col_offset += n_cols

    if per_block_indices:
        indices = torch.cat(per_block_indices, dim=-1)
    else:
        # torch.cat rejects an empty list; an empty matrix has no indices.
        indices = torch.empty((2, 0), dtype=torch.long, device=device)
    return indices, (row_offset, col_offset)
[docs]
def construct_block_diag_coo_tensor_indices_and_shape_from_sparse(
    *blocks: torch.Tensor,
    device: torch.device = None,
) -> tuple[torch.Tensor, tuple[int, int]]:
    """Construct the indices for a block diagonal sparse matrix in COOrdinate format where the
    blocks are sparse tensors.

    Only the entries stored in each (coalesced) block get an index. The indices are
    emitted block after block, in the order of each block's coalesced indices, so
    they line up with the concatenation of the blocks' coalesced values.

    Args:
        blocks: The 2-D sparse COO blocks of the block diagonal tensor, in diagonal order.
        device: Which device to place the returned indices on. If ``None``, the
            indices stay on the device of the blocks.

    Returns:
        torch.Tensor: The indices of the blocks, of shape ``(2, total number of
            stored entries)``. It has shape ``(2, 0)`` when no blocks are given.
        tuple: The shape of the block diagonal matrix as plain Python ints.
    """
    # Offsets are tracked as Python ints so that the returned shape really is a
    # tuple of ints and no offset tensor has to live on a particular device.
    row_offset = 0
    col_offset = 0
    shifted_indices = []
    for block in blocks:
        n_rows, n_cols = block.shape
        local = block.coalesce().indices()
        # Build the offset on the same device/dtype as the block's indices so the
        # addition never mixes devices.
        offset = torch.tensor(
            [[row_offset], [col_offset]], dtype=local.dtype, device=local.device
        )
        shifted = local + offset
        if device is not None:
            shifted = shifted.to(device)
        shifted_indices.append(shifted)
        row_offset += int(n_rows)
        col_offset += int(n_cols)

    if shifted_indices:
        # Concatenate once instead of growing the tensor inside the loop.
        indices = torch.cat(shifted_indices, dim=1)
    else:
        indices = torch.empty((2, 0), dtype=torch.long, device=device)
    return indices, (row_offset, col_offset)
[docs]
def construct_block_diag_coo_tensor(*blocks: torch.Tensor) -> torch.Tensor:
    """Construct a sparse COO block diagonal tensor from the given blocks.

    The blocks must be either all dense or all sparse. For dense blocks every
    entry is stored; for sparse blocks only the entries stored in the coalesced
    blocks are kept.

    >>> blocks = [torch.ones((2, 2)), torch.ones((3, 3))]
    >>> construct_block_diag_coo_tensor(*blocks).to_dense()
    tensor([[1., 1., 0., 0., 0.],
            [1., 1., 0., 0., 0.],
            [0., 0., 1., 1., 1.],
            [0., 0., 1., 1., 1.],
            [0., 0., 1., 1., 1.]])

    Args:
        blocks:         The blocks of the block diagonal tensor.

    Returns:
        torch.Tensor:   The block diagonal tensor in sparse COO format.

    Raises:
        AssertionError: If no blocks are given, or if sparse and dense blocks are mixed.
    """
    # Checked separately so empty input is not reported as a sparse/dense mix.
    assert blocks, "At least one block is required."
    n_sparse = sum(1 for b in blocks if b.is_sparse)
    assert n_sparse in (
        0,
        len(blocks),
    ), "Either all or none of the blocks must be sparse tensors."

    # The indices are created on the device of the first block.
    device = blocks[0].device
    if n_sparse:
        # Coalesce once and reuse the result for both indices and values.
        coalesced = [b.coalesce() for b in blocks]
        indices, shape = construct_block_diag_coo_tensor_indices_and_shape_from_sparse(
            *coalesced, device=device
        )
        values = torch.cat([b.values() for b in coalesced])
    else:
        indices, shape = construct_block_diag_coo_indices_and_shape(
            *[b.shape for b in blocks], device=device
        )
        values = torch.cat([b.flatten() for b in blocks])
    # Normalise to plain ints in case the helper returned 0-dim tensors.
    size = tuple(int(s) for s in shape)
    return torch.sparse_coo_tensor(indices, values, size=size)