csrtable module

Full Documentation for hippynn.layers.pairs.csr_pairs.csrtable module.

This file was written with assistance from an LLM.

Light-weight CSRTable container built on PyTorch tensors and utilities.

This module defines:

  • starts_from_counts(): fast CSR row-pointer construction.

  • row_and_offset(): decode flat indices to (row, in-row offset).

  • CSRTable: a minimal, device/dtype-agnostic CSR container with constructors, filtering, row reindexing, and row-wise Cartesian joins.

Conventions

  • All index tensors are torch.long.

  • Shapes are given in brackets, e.g. [R] for rows, [nnz] for entries.

  • device/dtype follow the tensors you pass in; no implicit moves/casts.

  • Heavyweight validation is avoided in hot paths; rely on the lightweight invariants below instead.

Invariants

Given a CSRTable with starts: [R+1], cols: [nnz], and data:

  • starts is non-decreasing with starts[0] == 0 and starts[-1] == nnz.

  • For each row r, row_len[r] = starts[r+1] - starts[r] >= 0.

  • Every v in data has length nnz and aligns with cols order.
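Because the class skips heavyweight validation, it can help to check these invariants explicitly in tests or while debugging. The following is a minimal sketch: check_csr_invariants is a hypothetical helper (not part of hippynn) that takes raw starts/cols/data tensors rather than a CSRTable.

```python
import torch


def check_csr_invariants(starts, cols, data=None):
    """Hypothetical debug helper: assert the CSRTable invariants listed above.

    Not part of hippynn; meant for tests, not hot paths.
    """
    data = {} if data is None else data
    nnz = cols.shape[0]
    assert starts.dtype == torch.long, "starts must be torch.long"
    assert starts.dim() == 1 and starts.numel() >= 1, "starts must be 1-D with R+1 entries"
    assert int(starts[0]) == 0, "starts[0] must be 0"
    assert int(starts[-1]) == nnz, "starts[-1] must equal nnz"
    # Non-decreasing starts is equivalent to every row length being >= 0.
    assert bool((torch.diff(starts) >= 0).all()), "row lengths must be >= 0"
    for name, v in data.items():
        assert v.shape[0] == nnz, f"data[{name!r}] must have length nnz"
    return True
```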

class CSRTable(starts: Tensor, cols: Tensor, data: Dict[str, Tensor] | None = None, reorder: bool = True)

Bases: object

A minimal CSR container supporting multiple per-entry payload arrays.

Attributes

starts : LongTensor, shape [R+1]

Row pointer. starts[r+1] - starts[r] is the length of row r.

cols : LongTensor, shape [nnz]

Per-entry “column” payload. Its interpretation is user-defined: it can be a global column id, or a local index 0..row_len-1.

data : dict[str, Tensor]

Additional per-entry fields, each of length nnz and aligned to cols.

Design

The class favors simple, tensorized transformations:

  • No heavyweight validation on construction.

  • filter_* and reindex preserve per-row stability as documented.

  • outer performs row-wise Cartesian products with user-supplied ops.

See Also

starts_from_counts, row_and_offset

classmethod from_coo(rows: Tensor, cols: Tensor, data: Dict[str, Tensor] | None = None, reorder: bool = True, nrows=None) → CSRTable

Build a CSRTable from COO-style row/col indices.

Parameters

rows : LongTensor, shape [nnz]

Row id for each entry.

cols : LongTensor, shape [nnz]

Column payload for each entry (can be global column ids).

data : dict[str, Tensor], optional

Additional per-entry fields; each must be length nnz.

reorder : bool, default True

If True, entries are grouped by row with a stable sort, so within each row they keep their original relative order (lexicographic by (row, original_position)). If False, the input is assumed to be already grouped by row.

nrows : int, optional

If provided, fixes the number of rows; otherwise inferred as rows.max()+1 (or 0 if empty).

Returns

CSRTable

A table whose starts encodes per-row cardinalities and whose cols contains the provided per-entry payloads.

Notes

  • All inputs must live on the same device.

  • No deduplication is performed; repeated (row, col) pairs are allowed.

Complexity

O(nnz log nnz) if reorder=True (stable sort by row), otherwise O(nnz).
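For reorder=True, the grouping step can be expressed as a stable sort by row followed by a histogram of row ids. This is an illustrative sketch only: coo_to_csr_sketch is a hypothetical name, it returns raw (starts, cols, data) instead of a CSRTable, and hippynn's actual implementation may differ.

```python
import torch


def coo_to_csr_sketch(rows, cols, data=None, nrows=None):
    """Hypothetical sketch of COO -> CSR grouping (reorder=True behavior)."""
    data = {} if data is None else data
    if nrows is None:
        nrows = int(rows.max()) + 1 if rows.numel() > 0 else 0
    # A stable sort keeps the original relative order of entries within a row.
    order = torch.sort(rows, stable=True).indices
    counts = torch.bincount(rows, minlength=nrows)
    starts = torch.zeros(nrows + 1, dtype=torch.long, device=rows.device)
    starts[1:] = torch.cumsum(counts, dim=0)
    return starts, cols[order], {k: v[order] for k, v in data.items()}
```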

classmethod from_counts(counts: Tensor, row_data: Dict[str, Tensor] | None = None, col_data: Dict[str, Tensor] | None = None) → CSRTable

Create a CSR with locally indexed columns from per-row counts.

Parameters

counts : LongTensor, shape [R]

Number of entries per row.

row_data : dict[str, Tensor], optional

Per-row fields to expand to length nnz using row indexing.

col_data : dict[str, Tensor], optional

Per-column (local) fields to expand using the local column index 0..counts[r]-1 for each row.

Returns

CSRTable

With cols equal to the local in-row index and data containing expanded fields from row_data/col_data.

Examples

>>> counts = torch.tensor([2, 1], dtype=torch.long)
>>> row_feat = torch.tensor([[10.0],[20.0]])
>>> col_feat = torch.arange(2)  # local 0..C-1
>>> t = CSRTable.from_counts(counts, row_data={"rf": row_feat}, col_data={"cf": col_feat})
>>> t.cols
tensor([0, 1, 0])
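One way to read the from_counts contract: each entry's row id comes from repeating row indices by count, and its local column index is its flat position minus the start of its row. The sketch below (expand_from_counts_sketch is a hypothetical helper, not hippynn code) reproduces the example above and shows how row_data and col_data would be gathered under that reading.

```python
import torch


def expand_from_counts_sketch(counts, row_data=None, col_data=None):
    """Hypothetical sketch of the from_counts expansion; returns raw tensors."""
    row_data = {} if row_data is None else row_data
    col_data = {} if col_data is None else col_data
    nrows = counts.shape[0]
    starts = torch.zeros(nrows + 1, dtype=torch.long, device=counts.device)
    starts[1:] = torch.cumsum(counts, dim=0)
    # Row id of every entry, e.g. counts [2, 1] -> [0, 0, 1].
    rows = torch.repeat_interleave(torch.arange(nrows, device=counts.device), counts)
    # Local in-row index: flat position minus the start of the owning row.
    local = torch.arange(rows.shape[0], device=counts.device) - starts[rows]
    data = {k: v[rows] for k, v in row_data.items()}  # per-row -> per-entry
    data.update({k: v[local] for k, v in col_data.items()})  # per-local-col -> per-entry
    return starts, local, data
```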

classmethod from_mask(mask: Tensor, data: Dict[str, Tensor] | None = None) → CSRTable

Build a CSR from a boolean mask.

Parameters

mask : BoolTensor, shape [R, C]

For each True at (r, c), one entry is emitted in row r with column payload c.

data : dict[str, Tensor], optional

Per-entry fields aligned to True positions. If provided, each tensor must have shape [R, C, ...] and will be gathered at True indices.

Returns

CSRTable

Grouped by row with stable within-row ordering (scan order).

Notes

  • This constructor is convenient when you have a dense predicate.

  • If mask is all-false, an empty table with R rows is returned.

Complexity

O(R*C) for the boolean scan; storage scales with nnz.
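A dense mask maps naturally onto torch.nonzero, whose indices come out in row-major order, so entries are already grouped by row in scan order. mask_to_csr_sketch is a hypothetical helper illustrating this under that assumption; it is not hippynn's implementation and returns raw tensors.

```python
import torch


def mask_to_csr_sketch(mask, data=None):
    """Hypothetical sketch: CSR pieces from a dense [R, C] boolean mask."""
    data = {} if data is None else data
    nrows = mask.shape[0]
    # nonzero yields indices in row-major order, i.e. grouped by row.
    rows, cols = mask.nonzero(as_tuple=True)
    counts = torch.bincount(rows, minlength=nrows)  # all-false -> zeros
    starts = torch.zeros(nrows + 1, dtype=torch.long, device=mask.device)
    starts[1:] = torch.cumsum(counts, dim=0)
    # Fields of shape [R, C, ...] are gathered to shape [nnz, ...].
    gathered = {k: v[rows, cols] for k, v in data.items()}
    return starts, cols, gathered
```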

classmethod make_empty(n_rows: int, *, dtype: dtype, device: device) → CSRTable

Create an empty CSRTable with n_rows rows and zero entries.

Parameters

n_rows : int

Number of rows.

dtype : torch.dtype

Target dtype for the empty cols tensor.

device : torch.device

Target device for all tensors.

Target device for all tensors.

Returns

CSRTable

A well-formed empty table with starts == [0] * (n_rows + 1).

Examples

>>> CSRTable.make_empty(3, dtype=torch.long, device=torch.device("cpu")).starts
tensor([0, 0, 0, 0])

static expand_pairings(left_csr: CSRTable, right_csr: CSRTable, left_rows: Tensor, right_rows: Tensor, operations: Dict[Tuple[str | None, str | None, str], callable]) → CSRTable

Expand specific row pairings of two CSR tables (generalized join).

Parameters

left_csr : CSRTable

Left-hand table.

right_csr : CSRTable

Right-hand table.

left_rows : LongTensor, shape [P]

Row ids in left_csr participating in pairings.

right_rows : LongTensor, shape [P]

Row ids in right_csr; left_rows[i] is paired with right_rows[i].

operations : dict

Same contract as in outer().

Same contract as in outer().

Returns

CSRTable

Output CSR with R = P rows (one row per pairing) and with entries equal to the Cartesian product of the paired source rows.

Notes

  • Use this when you have a precomputed mapping between rows of two tables (e.g., voxel adjacency), not necessarily one-to-one by index.

  • Stability is per paired row, based on the original per-row orders.
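The core of a generalized join is computing, for each pairing p, the flat source positions of every (left entry, right entry) pair. The sketch below does this with tensor ops only. pairing_positions_sketch is a hypothetical name, and the left-major ordering it uses (right index varying fastest) is an assumption, not a documented property of expand_pairings. Output fields would then be produced by gathering source fields at left_pos/right_pos and applying the operations.

```python
import torch


def pairing_positions_sketch(left_starts, right_starts, left_rows, right_rows):
    """Hypothetical sketch: flat source positions for paired-row Cartesian products.

    Output row p holds every pair (i, j) with i from left row left_rows[p] and
    j from right row right_rows[p]; j varies fastest (assumed ordering).
    Returns (out_starts, left_pos, right_pos).
    """
    device = left_starts.device
    left_len = (left_starts[1:] - left_starts[:-1])[left_rows]
    right_len = (right_starts[1:] - right_starts[:-1])[right_rows]
    counts = left_len * right_len  # one output row per pairing
    out_starts = torch.zeros(counts.shape[0] + 1, dtype=torch.long, device=device)
    out_starts[1:] = torch.cumsum(counts, dim=0)
    # Owning pairing p and offset k inside output row p, for every output entry.
    pair = torch.repeat_interleave(torch.arange(counts.shape[0], device=device), counts)
    k = torch.arange(pair.shape[0], device=device) - out_starts[pair]
    # Decode k -> (i, j); right_len[pair] > 0 whenever pairing p emits entries.
    rl = right_len[pair]
    left_pos = left_starts[left_rows[pair]] + k // rl
    right_pos = right_starts[right_rows[pair]] + k % rl
    return out_starts, left_pos, right_pos
```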

filter_indices(keep_indices: Tensor) → CSRTable

Filter entries by their absolute (flat) positions in the current value domain.

Parameters

keep_indices : LongTensor, shape [K]

Absolute positions 0..nnz-1 to keep.

Returns

CSRTable

A new table with recomputed starts and per-row stable ordering (stability is defined by original positions within each row).

Notes

If K == 0, returns an empty table with the same row count.
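A sketch of how the new starts can be recomputed, assuming keep_indices contains no duplicates: sort the kept positions to restore original order, find each position's owning row with torch.searchsorted, and count kept entries per row. filter_indices_sketch is hypothetical and returns raw tensors rather than a CSRTable.

```python
import torch


def filter_indices_sketch(starts, cols, keep_indices):
    """Hypothetical sketch of filter_indices: keep flat positions, rebuild starts."""
    nrows = starts.shape[0] - 1
    # Sorting restores the original order, which keeps each row stable.
    keep = torch.sort(keep_indices).values
    # Owning row of each kept position: the last r with starts[r] <= pos.
    rows = torch.searchsorted(starts, keep, right=True) - 1
    new_starts = torch.zeros_like(starts)
    new_starts[1:] = torch.cumsum(torch.bincount(rows, minlength=nrows), dim=0)
    return new_starts, cols[keep]
```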

filter_mask(keep_mask: Tensor) → CSRTable

Filter entries by a boolean mask while recomputing per-row pointers.

Parameters

keep_mask : BoolTensor, shape [nnz]

Entries with True are kept.

Returns

CSRTable

A new table with the same number of rows and with row-wise stable ordering among kept entries.

Notes

  • Stability is per-row: relative order among kept entries from the same row matches their original order.

  • If all entries are dropped, returns an empty table with the same row count.
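Because boolean indexing preserves order, per-row stability comes for free; only the row pointer needs rebuilding from per-row counts of kept entries. A minimal sketch with a hypothetical helper name (filter_mask_sketch), returning raw tensors:

```python
import torch


def filter_mask_sketch(starts, cols, keep_mask):
    """Hypothetical sketch of filter_mask: row lengths become per-row True counts."""
    nrows = starts.shape[0] - 1
    row_len = starts[1:] - starts[:-1]
    # Owning row of every entry, e.g. starts [0, 2, 2, 5] -> [0, 0, 2, 2, 2].
    rows = torch.repeat_interleave(torch.arange(nrows, device=starts.device), row_len)
    new_counts = torch.bincount(rows[keep_mask], minlength=nrows)
    new_starts = torch.zeros_like(starts)
    new_starts[1:] = torch.cumsum(new_counts, dim=0)
    # Boolean indexing keeps the original order, so each row stays stable.
    return new_starts, cols[keep_mask]
```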

outer(other: CSRTable, operations: Dict[Tuple[str | None, str | None, str], callable]) → CSRTable

Row-wise Cartesian join between two CSR tables.

For each row r, forms the Cartesian product self[r] × other[r] and builds output fields by applying user-provided operations.

Parameters

other : CSRTable

Right-hand side table; must have the same number of rows.

operations : dict[tuple[key_left, key_right, out_key], callable]

Operation spec:

  • (None, None, out) → fn() (nullary)

  • (key, None, out) → fn(self[key]) (unary, left)

  • (None, key, out) → fn(other[key]) (unary, right)

  • (ka, kb, out) → fn(self[ka], other[kb]) (binary)

Each callable is applied to the expanded positions; it must be broadcasting-compatible or indexable per expanded pair.

Returns

CSRTable

Output with the same row count and nnz[r] = len(self[r]) * len(other[r]).

Notes

  • The meaning of the output cols is implementation-dependent (it may encode the right-hand local index or another local index); rely on the output data for any payloads you need downstream.

  • If a row is empty on either side, the corresponding output row is empty.
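To make the operations contract concrete, here is a sketch of a row-wise join that expands each row into all (left, right) pairs and dispatches on the four key patterns. outer_sketch is hypothetical: it works on raw starts and data dicts, assumes both tables have the same row count, and orders pairs left-major, which may not match hippynn's actual output order.

```python
import torch


def outer_sketch(left_starts, left_data, right_starts, right_data, operations):
    """Hypothetical sketch of a row-wise Cartesian join with an operations dict."""
    device = left_starts.device
    left_len = left_starts[1:] - left_starts[:-1]
    right_len = right_starts[1:] - right_starts[:-1]
    counts = left_len * right_len  # nnz[r] = len(self[r]) * len(other[r])
    out_starts = torch.zeros_like(left_starts)
    out_starts[1:] = torch.cumsum(counts, dim=0)
    # Row r and in-row offset k of every output entry; empty rows emit nothing.
    row = torch.repeat_interleave(torch.arange(counts.shape[0], device=device), counts)
    k = torch.arange(row.shape[0], device=device) - out_starts[row]
    rl = right_len[row]
    left_pos = left_starts[row] + k // rl  # left index varies slowest
    right_pos = right_starts[row] + k % rl  # right index varies fastest
    out = {}
    for (ka, kb, out_key), fn in operations.items():
        if ka is None and kb is None:
            out[out_key] = fn()  # nullary
        elif kb is None:
            out[out_key] = fn(left_data[ka][left_pos])  # unary, left
        elif ka is None:
            out[out_key] = fn(right_data[kb][right_pos])  # unary, right
        else:
            out[out_key] = fn(left_data[ka][left_pos], right_data[kb][right_pos])
    return out_starts, out
```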

reindex(row_ids: Tensor, carry: Tuple[str, ...] | None = None, include_src_pos: bool = False, include_src_cols: bool = False) → CSRTable

Materialize a new CSR by selecting rows in the given order.

Parameters

row_ids : LongTensor, shape [E]

Row ids to take from this table. Duplicates are allowed; rows not listed in row_ids are dropped. The output row count equals E.

carry : tuple[str, …] or None, optional

Names of data fields to carry over. None carries all fields.

include_src_pos : bool, default False

If True, adds data["src_pos"] with absolute source positions.

include_src_cols : bool, default False

If True, adds data["src_col"] with source cols values.

If True, adds data["src_col"] with source cols values.

Returns

CSRTable

New table with rows arranged to match row_ids and with cols set to local in-row indices 0..len(row)-1.

Stability

Within each emitted row, the original per-row order is preserved.

Examples

>>> # assumes `t` is a CSRTable with at least 4 rows
>>> t2 = t.reindex(torch.tensor([3, 1, 3]), include_src_pos=True)
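Reindexing amounts to gathering row lengths in the requested order and mapping every output entry back to its source position. The sketch below (reindex_sketch, hypothetical) returns the new row pointer, the local in-row cols, and the absolute source positions that include_src_pos=True is documented to expose as data["src_pos"]; carried data fields could then be gathered with src_pos.

```python
import torch


def reindex_sketch(starts, row_ids):
    """Hypothetical sketch of reindex; returns (new_starts, local_cols, src_pos)."""
    device = starts.device
    lens = (starts[1:] - starts[:-1])[row_ids]  # duplicate ids repeat a row
    new_starts = torch.zeros(row_ids.shape[0] + 1, dtype=torch.long, device=device)
    new_starts[1:] = torch.cumsum(lens, dim=0)
    # Output row of every emitted entry, then its local in-row index.
    out_row = torch.repeat_interleave(torch.arange(row_ids.shape[0], device=device), lens)
    local = torch.arange(out_row.shape[0], device=device) - new_starts[out_row]
    # Same offset inside the source row, so per-row order is preserved.
    src_pos = starts[row_ids[out_row]] + local
    return new_starts, local, src_pos
```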

property nnz

Total number of entries (non-zeros) across all rows.

Returns

int

Equal to cols.numel() and starts[-1].

property nrows

Number of rows in the table (R).

Returns

int

Equal to starts.numel() - 1.

find_indices(query_vals: Tensor, keys: Tensor) → Tuple[Tensor]

Map query values to their indices in keys, or -1 if not found.

This can be used to take the rows of one CSR table and check whether they exist in another CSR table.

Args:

keys: [V] long tensor, strictly increasing.

query_vals: […] long tensor.

Returns:

Tuple (locations, indices):

  • locations: indices of query values that are present in keys.

  • indices: index into keys corresponding to each of those query values.
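With strictly increasing keys, the lookup can be done with a binary search via torch.searchsorted followed by an exact-match check. This is a sketch under assumptions: find_indices_sketch is hypothetical, and it flattens query_vals, so the returned locations are flat positions into query_vals.

```python
import torch


def find_indices_sketch(query_vals, keys):
    """Hypothetical sketch: look up query_vals in strictly increasing keys."""
    flat = query_vals.reshape(-1)
    if keys.numel() == 0:
        empty = torch.zeros(0, dtype=torch.long, device=flat.device)
        return empty, empty
    # Binary search for the first key >= each query; clamp past-the-end slots.
    pos = torch.searchsorted(keys, flat).clamp(max=keys.shape[0] - 1)
    found = keys[pos] == flat  # exact match <=> the query is present in keys
    locations = found.nonzero(as_tuple=True)[0]
    return locations, pos[found]
```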

row_and_offset(starts: Tensor) → Tuple[Tensor, Tensor]

Decode the flattened CSR value domain into row ids and in-row offsets.

Parameters

starts : LongTensor, shape [R+1]

CSR row pointer. Must satisfy the standard CSR invariants.

Returns

rows_of_entry : LongTensor, shape [nnz]

For each entry position t ∈ [0, nnz), the owning row id.

offset_in_row : LongTensor, shape [nnz]

For each entry position t, the offset within its row (i.e., 0..row_len[row]-1).

Notes

If nnz == 0, both outputs are empty tensors on the same device.

Examples

>>> starts = torch.tensor([0, 2, 2, 5], dtype=torch.long)
>>> rows, offs = row_and_offset(starts)
>>> rows, offs
(tensor([0, 0, 2, 2, 2]), tensor([0, 1, 0, 1, 2]))

Complexity

O(nnz) time, O(1) extra memory.
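A compact way to compute both outputs with standard tensor ops is torch.repeat_interleave over the row lengths. This is a sketch (row_and_offset_sketch is a hypothetical name), not necessarily how hippynn implements row_and_offset, and it allocates O(nnz) temporaries; it reproduces the example above.

```python
import torch


def row_and_offset_sketch(starts):
    """Hypothetical sketch: decode flat CSR positions into (row, offset-in-row)."""
    nrows = starts.shape[0] - 1
    row_len = starts[1:] - starts[:-1]
    # Repeat each row id by its length: [0, 2, 2, 5] -> [0, 0, 2, 2, 2].
    rows = torch.repeat_interleave(torch.arange(nrows, device=starts.device), row_len)
    # Offset = flat position minus the start of the owning row.
    offsets = torch.arange(rows.shape[0], device=starts.device) - starts[rows]
    return rows, offsets
```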

starts_from_counts(counts: Tensor) → Tensor

Build a CSR row-pointer (starts) from per-row counts.

Parameters

counts : LongTensor, shape [R]

Number of entries in each row.

Returns

starts : LongTensor, shape [R+1]

CSR row pointer with starts[0] == 0 and torch.diff(starts) == counts. The final element equals nnz.

Notes

  • counts must be non-negative. No explicit check is performed.

  • device and dtype are inherited from counts (cast to long if needed).

Examples

>>> counts = torch.tensor([2, 0, 3], dtype=torch.long)
>>> starts = starts_from_counts(counts)
>>> starts
tensor([0, 2, 2, 5])

Complexity

O(R) time, O(1) extra memory.
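The row pointer is an exclusive prefix sum of the counts, that is, a zero followed by the inclusive cumulative sum. A minimal sketch (starts_from_counts_sketch is a hypothetical name; hippynn's version may differ in details):

```python
import torch


def starts_from_counts_sketch(counts):
    """Hypothetical sketch: starts = [0, cumsum(counts)...] as torch.long."""
    counts = counts.long()  # follow the documented cast-to-long behavior
    starts = torch.zeros(counts.shape[0] + 1, dtype=torch.long, device=counts.device)
    starts[1:] = torch.cumsum(counts, dim=0)  # starts[-1] == nnz
    return starts
```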