Source code for mldft.utils.sparse

import torch


def construct_block_diag_coo_indices_and_shape(
    *block_shapes: tuple[int, int],
    device: torch.device = None,
) -> tuple[torch.Tensor, tuple[int, int]]:
    """Construct the indices for a block diagonal sparse matrix in COOrdinate format.

    Every entry of every block is included (the blocks are treated as dense). Entries are in
    row-major order within each block, and the blocks are in the order given.

    >>> blocks = [torch.ones((2, 2)), torch.ones((3, 3))]
    >>> indices, shape = construct_block_diag_coo_indices_and_shape(*[b.shape for b in blocks])
    >>> torch.sparse_coo_tensor(indices, torch.cat([b.flatten() for b in blocks]), size=shape).to_dense()
    tensor([[1., 1., 0., 0., 0.],
            [1., 1., 0., 0., 0.],
            [0., 0., 1., 1., 1.],
            [0., 0., 1., 1., 1.],
            [0., 0., 1., 1., 1.]])

    Args:
        block_shapes: The ``(rows, cols)`` shapes of the blocks.
        device: Which device to create the indices on.

    Returns:
        torch.Tensor: The ``(2, nnz)`` int64 indices of all block entries.
        tuple: The ``(rows, cols)`` shape of the block diagonal matrix, as Python ints.
    """
    if not block_shapes:
        # torch.cat cannot concatenate an empty list; an empty matrix simply has no indices.
        return torch.empty((2, 0), dtype=torch.long, device=device), (0, 0)

    # Prepend a zero row so that offsets[i] is the top-left corner of block i
    # and offsets[-1] is the total shape of the block diagonal matrix.
    offsets = torch.cumsum(
        torch.tensor(
            [(0, 0)] + [tuple(s) for s in block_shapes], dtype=torch.long, device=device
        ),
        dim=0,
    )

    indices = []
    for (n_rows, n_cols), offset in zip(block_shapes, offsets):
        rows, cols = torch.meshgrid(
            torch.arange(n_rows, device=device),
            torch.arange(n_cols, device=device),
            indexing="ij",
        )
        indices.append(torch.stack((rows.reshape(-1), cols.reshape(-1))) + offset[:, None])

    # .tolist() yields Python ints (not 0-d tensors), matching the annotated return type.
    return torch.cat(indices, dim=-1), tuple(offsets[-1].tolist())
def construct_block_diag_coo_tensor_indices_and_shape_from_sparse(
    *blocks: torch.Tensor,
    device: torch.device = None,
) -> tuple[torch.Tensor, tuple[int, int]]:
    """Construct the indices for a block diagonal sparse matrix in COOrdinate format where the
    blocks are sparse tensors.

    Only the stored entries of each coalesced block are included. They appear in the order
    ``b.coalesce().indices()`` yields them, with the blocks in the order given.

    Args:
        blocks: The 2D sparse COO blocks of the block diagonal tensor.
        device: Which device to create the indices on.

    Returns:
        torch.Tensor: The ``(2, nnz)`` int64 indices of the stored block entries.
        tuple: The ``(rows, cols)`` shape of the block diagonal matrix, as Python ints.
    """
    if not blocks:
        # No blocks: return a well-formed (2, 0) index tensor for a 0x0 matrix.
        return torch.empty((2, 0), dtype=torch.long, device=device), (0, 0)

    shapes = torch.tensor([tuple(b.shape) for b in blocks], dtype=torch.long, device=device)
    # The zero row must live on the same device as `shapes`; a CPU-only zero row breaks
    # torch.cat for any other device. It makes offsets[i] the top-left corner of block i.
    zero_row = torch.zeros((1, 2), dtype=torch.long, device=device)
    offsets = torch.cumsum(torch.cat((zero_row, shapes), dim=0), dim=0)

    # Collect the shifted index blocks and concatenate once; concatenating inside the loop
    # re-copies the accumulated indices on every iteration (quadratic in the number of blocks).
    shifted = [
        b.coalesce().indices().to(offsets.device) + offset[:, None]
        for b, offset in zip(blocks, offsets)
    ]
    # .tolist() yields Python ints (not 0-d tensors), matching the annotated return type.
    return torch.cat(shifted, dim=1), tuple(offsets[-1].tolist())
def construct_block_diag_coo_tensor(*blocks: torch.Tensor) -> torch.Tensor:
    """Construct a sparse block diagonal COO tensor from the given blocks.

    The blocks must be either all dense or all sparse (COO). A dense block contributes every
    one of its entries as a stored value. A sparse block contributes only its stored entries.

    >>> blocks = [torch.ones((2, 2)), torch.ones((3, 3))]
    >>> construct_block_diag_coo_tensor(*blocks).to_dense()
    tensor([[1., 1., 0., 0., 0.],
            [1., 1., 0., 0., 0.],
            [0., 0., 1., 1., 1.],
            [0., 0., 1., 1., 1.],
            [0., 0., 1., 1., 1.]])

    Args:
        blocks: The blocks of the block diagonal tensor.

    Returns:
        torch.Tensor: The block diagonal tensor as a sparse COO tensor. With no blocks, an
        empty ``(0, 0)`` sparse tensor is returned.

    Raises:
        AssertionError: If some but not all of the blocks are sparse. This check is an
            ``assert`` and is skipped when Python runs with ``-O``.
    """
    if not blocks:
        # An empty block diagonal matrix is 0x0 with no stored entries.
        return torch.sparse_coo_tensor(
            torch.empty((2, 0), dtype=torch.long), torch.empty((0,)), size=(0, 0)
        )

    n_sparse = sum(b.is_sparse for b in blocks)
    assert n_sparse in (0, len(blocks)), "Either all or none of the blocks must be sparse tensors."

    device = blocks[0].device
    if n_sparse:
        # Coalesce once so the values below line up with the indices the helper derives
        # from the same (already coalesced) blocks.
        coalesced = [b.coalesce() for b in blocks]
        indices, shape = construct_block_diag_coo_tensor_indices_and_shape_from_sparse(
            *coalesced, device=device
        )
        values = torch.cat([b.values() for b in coalesced])
    else:
        indices, shape = construct_block_diag_coo_indices_and_shape(
            *[b.shape for b in blocks], device=device
        )
        # Row-major flattening matches the row-major index order of each dense block.
        values = torch.cat([b.flatten() for b in blocks])
    return torch.sparse_coo_tensor(indices, values, size=shape)