csrtable module
Full Documentation for hippynn.layers.pairs.csr_pairs.csrtable module.
This file was written with assistance from an LLM.
Lightweight CSRTable container built on PyTorch tensors and utilities.
This module defines:
- starts_from_counts(): fast CSR row-pointer construction.
- row_and_offset(): decode flat indices to (row, in-row offset).
- CSRTable: a minimal, device/dtype-agnostic CSR container with constructors, filtering, row reindexing, and row-wise Cartesian joins.
Conventions
- All index tensors are torch.long.
- Shapes are given in brackets, e.g. [R] for rows and [nnz] for entries.
- device/dtype follow the tensors you pass in; there are no implicit moves or casts.
- We avoid heavyweight validation in hot paths; rely on the lightweight invariants below instead.
Invariants
Given a CSRTable with starts: [R+1], cols: [nnz], and data:
- starts is non-decreasing, with starts[0] == 0 and starts[-1] == nnz.
- For each row r, row_len[r] = starts[r+1] - starts[r] >= 0.
- Every v in data has length nnz and is aligned with the order of cols.
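Because these invariants are cheap to test, they can be checked outside hot paths, for example in unit tests. The following is a minimal sketch that works on raw tensors rather than a CSRTable instance; the helper name check_csr_invariants is hypothetical and is not part of this module.

```python
import torch


def check_csr_invariants(starts, cols, data=None):
    """Hypothetical helper: raise ValueError if the CSR invariants are violated."""
    data = data or {}
    nnz = cols.shape[0]
    if starts.dtype != torch.long or starts.dim() != 1 or starts.numel() < 1:
        raise ValueError("starts must be a 1-D torch.long tensor of shape [R+1]")
    if int(starts[0]) != 0 or int(starts[-1]) != nnz:
        raise ValueError("starts[0] must be 0 and starts[-1] must equal nnz")
    # Non-decreasing starts is equivalent to every row length being >= 0.
    if bool((starts[1:] < starts[:-1]).any()):
        raise ValueError("starts must be non-decreasing")
    # Every payload field must be aligned entry-for-entry with cols.
    for key, value in data.items():
        if value.shape[0] != nnz:
            raise ValueError(f"data[{key!r}] has length {value.shape[0]}, expected {nnz}")


# A valid 3-row table with 5 entries passes silently.
check_csr_invariants(
    torch.tensor([0, 2, 2, 5]),
    torch.tensor([7, 8, 1, 2, 3]),
    {"w": torch.ones(5)},
)
```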
- class CSRTable(starts: Tensor, cols: Tensor, data: Dict[str, Tensor] | None = None, reorder: bool = True)
Bases: object
A minimal CSR container allowing for multiple payload arrays.
Attributes
- starts : LongTensor, shape [R+1]
  Row pointer. starts[r+1] - starts[r] is the length of row r.
- cols : LongTensor, shape [nnz]
  Per-entry "column" payload. Its interpretation is user-defined: it can be a global column id or a local index 0..row_len-1.
- data : dict[str, Tensor]
  Additional per-entry fields, each of length nnz and aligned to cols.
Design
The class favors simple, tensorized transformations:
- No heavyweight validation on construction.
- filter_* and reindex preserve per-row stability as documented.
- outer performs row-wise Cartesian products with user-supplied ops.
See Also
starts_from_counts, row_and_offset
- classmethod from_coo(rows: Tensor, cols: Tensor, data: Dict[str, Tensor] | None = None, reorder: bool = True, nrows=None) -> CSRTable
Build a CSRTable from COO-style row/col indices.
Parameters
- rows : LongTensor, shape [nnz]
  Row id for each entry.
- cols : LongTensor, shape [nnz]
  Column payload for each entry (can be global column ids).
- data : dict[str, Tensor], optional
  Additional per-entry fields; each must have length nnz.
- reorder : bool, default True
  If True, rows are grouped and entries are made stable within each row (lexicographic by (row, original_position)). If False, the current order is assumed to be already grouped by row.
- nrows : int, optional
  If provided, fixes the number of rows; otherwise it is inferred as rows.max()+1 (or 0 if empty).
Returns
- CSRTable
  A table whose starts encodes per-row cardinalities and whose cols contains the provided per-entry payloads.
Notes
All inputs must live on the same device.
No deduplication is performed; repeated (row, col) pairs are allowed.
Complexity
O(nnz log nnz) if reorder=True (stable sort by row), otherwise O(nnz).
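With reorder=True, the grouping described above amounts to a stable sort by row followed by a per-row count. The sketch below assumes this can be expressed with torch.sort(..., stable=True), torch.bincount and torch.cumsum; coo_to_csr is a hypothetical name, it returns raw (starts, cols, data) instead of a CSRTable, and it is not claimed to be how from_coo is implemented.

```python
import torch


def coo_to_csr(rows, cols, data=None, nrows=None):
    """Hypothetical sketch: group COO entries by row with a stable sort."""
    data = data or {}
    if nrows is None:
        nrows = int(rows.max()) + 1 if rows.numel() > 0 else 0
    # A stable sort keeps entries of the same row in original order,
    # i.e. lexicographic by (row, original_position).
    order = torch.sort(rows, stable=True).indices
    # Assumes every row id lies in [0, nrows).
    counts = torch.bincount(rows, minlength=nrows)
    starts = torch.zeros(nrows + 1, dtype=torch.long, device=rows.device)
    starts[1:] = torch.cumsum(counts, dim=0)
    return starts, cols[order], {k: v[order] for k, v in data.items()}


starts, cols, data = coo_to_csr(
    torch.tensor([1, 0, 1, 0]),
    torch.tensor([10, 20, 30, 40]),
    {"w": torch.tensor([0.1, 0.2, 0.3, 0.4])},
)
```

Repeated (row, col) pairs pass through untouched, matching the note that no deduplication is performed.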
- classmethod from_counts(counts: Tensor, row_data: Dict[str, Tensor] | None = None, col_data: Dict[str, Tensor] | None = None) -> CSRTable
Create a CSR with locally indexed columns from per-row counts.
Parameters
- counts : LongTensor, shape [R]
  Number of entries per row.
- row_data : dict[str, Tensor], optional
  Per-row fields to expand to length nnz using row indexing.
- col_data : dict[str, Tensor], optional
  Per-column (local) fields to expand using the local column index 0..counts[r]-1 for each row.
Returns
- CSRTable
  With cols equal to the local in-row index and data containing the expanded fields from row_data/col_data.
Examples
>>> counts = torch.tensor([2, 1], dtype=torch.long)
>>> row_feat = torch.tensor([[10.0], [20.0]])
>>> col_feat = torch.arange(2)  # local 0..C-1
>>> t = CSRTable.from_counts(counts, row_data={"rf": row_feat}, col_data={"cf": col_feat})
>>> t.cols
tensor([0, 1, 0])
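The expansion can be pictured as: repeat each row id counts[r] times, then subtract the row start from the absolute position to get the local column index. The following is a minimal sketch of that idea on raw tensors, assuming torch.repeat_interleave for the row expansion; counts_to_local_csr is a hypothetical name and not the module's code.

```python
import torch


def counts_to_local_csr(counts, row_data=None, col_data=None):
    """Hypothetical sketch of the expansion described for from_counts."""
    counts = counts.long()
    n_rows = counts.numel()
    starts = torch.zeros(n_rows + 1, dtype=torch.long, device=counts.device)
    starts[1:] = torch.cumsum(counts, dim=0)
    nnz = int(starts[-1])
    # Owning row of every entry, e.g. counts [2, 1] -> rows [0, 0, 1].
    rows = torch.repeat_interleave(torch.arange(n_rows, device=counts.device), counts)
    # Local in-row index 0..counts[r]-1; this becomes cols.
    local = torch.arange(nnz, device=counts.device) - starts[:-1][rows]
    # Per-row fields are indexed by owning row.
    data = {k: v[rows] for k, v in (row_data or {}).items()}
    # Per-column fields are indexed by local index, so each needs >= max(counts) rows.
    data.update({k: v[local] for k, v in (col_data or {}).items()})
    return starts, local, data


starts, cols, data = counts_to_local_csr(
    torch.tensor([2, 1]),
    row_data={"rf": torch.tensor([[10.0], [20.0]])},
    col_data={"cf": torch.arange(2)},
)
```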
- classmethod from_mask(mask: Tensor, data: Dict[str, Tensor] | None = None) -> CSRTable
Build a CSR from a boolean mask.
Parameters
- mask : BoolTensor, shape [R, C]
  For each True at (r, c), one entry is emitted in row r with column payload c.
- data : dict[str, Tensor], optional
  Per-entry fields aligned to the True positions. If provided, each tensor must have shape [R, C, ...] and will be gathered at the True indices.
Returns
- CSRTable
  Grouped by row, with stable within-row ordering (scan order).
Notes
This constructor is convenient when you have a dense predicate.
If mask is all-false, an empty table with R rows is returned.
Complexity
O(R*C) for the boolean scan; storage scales with nnz.
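Since Tensor.nonzero scans in row-major order, its output is already grouped by row and in scan order within each row, which is the ordering documented above. A minimal sketch under that assumption follows; mask_to_csr is a hypothetical name, and it returns raw tensors rather than a CSRTable.

```python
import torch


def mask_to_csr(mask, data=None):
    """Hypothetical sketch: one entry per True in a [R, C] boolean mask."""
    n_rows = mask.shape[0]
    # Row-major scan: entries come out grouped by row, columns ascending.
    r, c = mask.nonzero(as_tuple=True)
    counts = torch.bincount(r, minlength=n_rows)
    starts = torch.zeros(n_rows + 1, dtype=torch.long, device=mask.device)
    starts[1:] = torch.cumsum(counts, dim=0)
    # Gather [R, C, ...] fields at the True positions -> [nnz, ...].
    gathered = {k: v[r, c] for k, v in (data or {}).items()}
    return starts, c, gathered


mask = torch.tensor([[True, False, True],
                     [False, False, False],
                     [False, True, False]])
starts, cols, data = mask_to_csr(mask, {"d": torch.arange(9).reshape(3, 3)})
```

An all-false mask yields starts filled with zeros and no entries, consistent with the note about returning an empty table with R rows.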
- classmethod make_empty(n_rows: int, *, dtype: dtype, device: device) -> CSRTable
Create an empty CSRTable with n_rows rows and zero entries.
Parameters
- n_rows : int
  Number of rows.
- dtype : torch.dtype
  Target dtype for the empty cols tensor.
- device : torch.device
  Target device for all tensors.
Returns
- CSRTable
  A well-formed empty table with starts == [0] * (n_rows+1).
Examples
>>> CSRTable.make_empty(3, dtype=torch.long, device=torch.device("cpu")).starts
tensor([0, 0, 0, 0])
- static expand_pairings(left_csr: CSRTable, right_csr: CSRTable, left_rows: Tensor, right_rows: Tensor, operations: Dict[Tuple[str | None, str | None, str], callable]) -> CSRTable
Expand specific row pairings of two CSR tables (generalized join).
Parameters
- left_csr : CSRTable
  Left-hand table.
- right_csr : CSRTable
  Right-hand table.
- left_rows : LongTensor, shape [P]
  Row ids in left_csr participating in pairings.
- right_rows : LongTensor, shape [P]
  Row ids in right_csr; left_rows[i] is paired with right_rows[i].
- operations : dict
  Same contract as in outer().
Returns
- CSRTable
  Output CSR with R = P rows (one row per pairing), whose entries are the Cartesian product of the paired source rows.
Notes
- Use this when you have a precomputed mapping between rows of two tables (e.g., voxel adjacency) that is not necessarily one-to-one by index.
- Stability is per paired row, based on the original per-row orders.
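The core step of a pairing join is enumerating, for each pairing, every (left entry, right entry) combination. Below is a minimal sketch on raw row pointers that returns the output starts plus absolute source positions on each side; pair_positions is a hypothetical helper, and applying the operations dict to those positions is deliberately left out.

```python
import torch


def pair_positions(left_starts, right_starts, left_rows, right_rows):
    """Hypothetical sketch: enumerate left_row x right_row for each pairing."""
    dev = left_starts.device
    left_len = (left_starts[1:] - left_starts[:-1])[left_rows]     # [P]
    right_len = (right_starts[1:] - right_starts[:-1])[right_rows]  # [P]
    out_counts = left_len * right_len  # one output row per pairing
    starts = torch.zeros(out_counts.numel() + 1, dtype=torch.long, device=dev)
    starts[1:] = torch.cumsum(out_counts, dim=0)
    nnz = int(starts[-1])
    # Which pairing each output entry belongs to, and its offset inside it.
    pair = torch.repeat_interleave(torch.arange(out_counts.numel(), device=dev), out_counts)
    offset = torch.arange(nnz, device=dev) - starts[:-1][pair]
    width = right_len[pair]  # >= 1 for every emitted entry
    # Row-major order over (left entry, right entry) keeps both source orders.
    left_pos = left_starts[left_rows[pair]] + torch.div(offset, width, rounding_mode="floor")
    right_pos = right_starts[right_rows[pair]] + offset % width
    return starts, left_pos, right_pos


starts, left_pos, right_pos = pair_positions(
    torch.tensor([0, 2, 3]), torch.tensor([0, 1, 3]),
    left_rows=torch.tensor([0, 1]), right_rows=torch.tensor([1, 1]),
)
```

Any per-entry field can then be expanded with left_field[left_pos] or right_field[right_pos] before an operation is applied.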
- filter_indices(keep_indices: Tensor) -> CSRTable
Filter entries by absolute positions in the current value domain.
Parameters
- keep_indices : LongTensor, shape [K]
  Absolute positions 0..nnz-1 to keep.
Returns
- CSRTable
  A new table with recomputed starts and per-row stable ordering (stability is defined by original positions within each row).
Notes
If K == 0, returns an empty table with the same row count.
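Because entries are stored grouped by row, sorting the kept absolute positions restores both row grouping and in-row order, after which only the per-row counts need recomputing. A minimal sketch, assuming keep_indices contains no duplicates; filter_by_indices is a hypothetical name operating on raw tensors.

```python
import torch


def filter_by_indices(starts, cols, data, keep_indices):
    """Hypothetical sketch; assumes keep_indices has no duplicates."""
    n_rows = starts.numel() - 1
    # Ascending absolute positions == grouped by row, original in-row order.
    pos = torch.sort(keep_indices).values
    # Owning row of every existing entry.
    rows = torch.repeat_interleave(
        torch.arange(n_rows, device=starts.device), starts[1:] - starts[:-1]
    )
    new_counts = torch.bincount(rows[pos], minlength=n_rows)
    new_starts = torch.zeros(n_rows + 1, dtype=torch.long, device=starts.device)
    new_starts[1:] = torch.cumsum(new_counts, dim=0)
    return new_starts, cols[pos], {k: v[pos] for k, v in data.items()}


starts = torch.tensor([0, 2, 5])
cols = torch.tensor([10, 11, 20, 21, 22])
new_starts, new_cols, _ = filter_by_indices(starts, cols, {}, torch.tensor([4, 0, 2]))
```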
- filter_mask(keep_mask: Tensor) -> CSRTable
Filter entries by a boolean mask while recomputing per-row pointers.
Parameters
- keep_mask : BoolTensor, shape [nnz]
  Entries with True are kept.
Returns
- CSRTable
  A new table with the same number of rows and with row-wise stable ordering among kept entries.
Notes
Stability is per-row: the relative order among kept entries from the same row matches their original order.
If all entries are dropped, returns an empty table with the same row count.
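With a mask, the new row pointers can be read off a prefix count of kept entries without any sort: the new start of row r is the number of kept entries before its old start. The following sketches that idea; it is one possible approach and is not confirmed to be what filter_mask does internally. filter_by_mask is a hypothetical name operating on raw tensors.

```python
import torch


def filter_by_mask(starts, cols, data, keep_mask):
    """Hypothetical sketch of mask filtering with recomputed row pointers."""
    nnz = keep_mask.numel()
    # kept_before[t] = number of kept entries at positions < t.
    kept_before = torch.zeros(nnz + 1, dtype=torch.long, device=keep_mask.device)
    kept_before[1:] = torch.cumsum(keep_mask.long(), dim=0)
    # Row r's new start is the count of kept entries before its old start.
    new_starts = kept_before[starts]
    # Boolean indexing preserves order, so per-row stability holds.
    return new_starts, cols[keep_mask], {k: v[keep_mask] for k, v in data.items()}


starts = torch.tensor([0, 2, 5])
cols = torch.tensor([10, 11, 20, 21, 22])
keep = torch.tensor([False, True, True, False, True])
new_starts, new_cols, _ = filter_by_mask(starts, cols, {}, keep)
```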
- outer(other: CSRTable, operations: Dict[Tuple[str | None, str | None, str], callable]) -> CSRTable
Row-wise Cartesian join between two CSR tables.
For each row r, forms the Cartesian product self[r] × other[r] and builds output fields by applying user-provided operations.
Parameters
- other : CSRTable
  Right-hand side table; must have the same number of rows.
- operations : dict[tuple[key_left, key_right, out_key], callable]
  Operation spec:
  - (None, None, out) → fn() (nullary)
  - (key, None, out) → fn(self[key]) (unary, left)
  - (None, key, out) → fn(other[key]) (unary, right)
  - (ka, kb, out) → fn(self[ka], other[kb]) (binary)
  Each callable is applied to the expanded positions; it must be broadcasting-compatible or indexable per expanded pair.
Returns
- CSRTable
  Output with the same row count and nnz[r] = len(self[r]) * len(other[r]).
Notes
The exact contents of the output cols are implementation-dependent (for example, the right-hand local index); rely on the output data for any payloads you need downstream.
If a row is empty on either side, the corresponding output row is empty.
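To make the operations contract concrete, here is a minimal sketch of a row-wise Cartesian join on raw tensors. It assumes that "self[key]" and "other[key]" refer to per-entry fields held in plain dicts (left_data, right_data), and that each callable receives the fields already expanded to one element per output pair; outer_join_sketch is a hypothetical name, not the method's implementation.

```python
import torch


def outer_join_sketch(left_starts, left_data, right_starts, right_data, operations):
    """Hypothetical sketch of a row-wise Cartesian join on raw tensors."""
    n_rows = left_starts.numel() - 1
    if right_starts.numel() - 1 != n_rows:
        raise ValueError("both tables must have the same number of rows")
    dev = left_starts.device
    left_len = left_starts[1:] - left_starts[:-1]
    right_len = right_starts[1:] - right_starts[:-1]
    # nnz[r] = len(self[r]) * len(other[r]); empty on either side -> empty row.
    out_counts = left_len * right_len
    starts = torch.zeros(n_rows + 1, dtype=torch.long, device=dev)
    starts[1:] = torch.cumsum(out_counts, dim=0)
    nnz = int(starts[-1])
    rows = torch.repeat_interleave(torch.arange(n_rows, device=dev), out_counts)
    offset = torch.arange(nnz, device=dev) - starts[:-1][rows]
    width = right_len[rows]  # >= 1 for every emitted entry
    left_pos = left_starts[:-1][rows] + torch.div(offset, width, rounding_mode="floor")
    right_pos = right_starts[:-1][rows] + offset % width
    out = {}
    for (ka, kb, out_key), fn in operations.items():
        if ka is not None and kb is not None:  # binary
            out[out_key] = fn(left_data[ka][left_pos], right_data[kb][right_pos])
        elif ka is not None:  # unary, left
            out[out_key] = fn(left_data[ka][left_pos])
        elif kb is not None:  # unary, right
            out[out_key] = fn(right_data[kb][right_pos])
        else:  # nullary
            out[out_key] = fn()
    return starts, out


starts, out = outer_join_sketch(
    torch.tensor([0, 2, 3]), {"x": torch.tensor([1, 2, 3])},
    torch.tensor([0, 1, 3]), {"y": torch.tensor([10, 20, 30])},
    {("x", "y", "sum"): torch.add, ("x", None, "x_left"): lambda a: a},
)
```

Here row 0 pairs left values [1, 2] with right value [10], and row 1 pairs [3] with [20, 30].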
- reindex(row_ids: Tensor, carry: Tuple[str, ...] | None = None, include_src_pos: bool = False, include_src_cols: bool = False) -> CSRTable
Materialize a new CSR by selecting rows in the given order.
Parameters
- row_ids : LongTensor, shape [E]
  Row ids to take from this table. Duplicates are allowed; unspecified rows are dropped. The output row count equals E.
- carry : tuple[str, …] or None, optional
  Names of data fields to carry over. None carries all fields.
- include_src_pos : bool, default False
  If True, adds data["src_pos"] with absolute source positions.
- include_src_cols : bool, default False
  If True, adds data["src_col"] with the source cols values.
Returns
- CSRTable
  New table with rows arranged to match row_ids and with cols set to local in-row indices 0..len(row)-1.
Stability
Within each emitted row, the original per-row order is preserved.
Examples
>>> t2 = t.reindex(torch.tensor([3, 1, 3]), include_src_pos=True)
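Reindexing reduces to gathering the selected row lengths, building new row pointers, and mapping each output entry back to an absolute source position. The sketch below illustrates this on raw tensors; reindex_sketch is a hypothetical name, and the src_pos/src_col keys simply mirror the parameter descriptions above.

```python
import torch


def reindex_sketch(starts, cols, data, row_ids, carry=None,
                   include_src_pos=False, include_src_cols=False):
    """Hypothetical sketch: gather rows in the order given by row_ids."""
    dev = starts.device
    counts = (starts[1:] - starts[:-1])[row_ids]  # [E], duplicates allowed
    new_starts = torch.zeros(row_ids.numel() + 1, dtype=torch.long, device=dev)
    new_starts[1:] = torch.cumsum(counts, dim=0)
    nnz = int(new_starts[-1])
    out_row = torch.repeat_interleave(torch.arange(row_ids.numel(), device=dev), counts)
    # Local in-row index 0..len(row)-1 becomes the new cols.
    local = torch.arange(nnz, device=dev) - new_starts[:-1][out_row]
    # Absolute source position; walking local upward keeps per-row order.
    src_pos = starts[row_ids][out_row] + local
    keys = data.keys() if carry is None else carry
    out = {k: data[k][src_pos] for k in keys}
    if include_src_pos:
        out["src_pos"] = src_pos
    if include_src_cols:
        out["src_col"] = cols[src_pos]
    return new_starts, local, out


starts = torch.tensor([0, 2, 2, 5])
cols = torch.tensor([10, 11, 30, 31, 32])
new_starts, new_cols, out = reindex_sketch(
    starts, cols, {"w": torch.arange(5)}, torch.tensor([2, 0, 2]),
    include_src_pos=True, include_src_cols=True,
)
```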
- find_indices(query_vals: Tensor, keys: Tensor) -> Tuple[Tensor]
Map query values to their indices in keys, or -1 if not found.
This function lets you take the CSR rows of one table and check whether they exist in another CSR table.
Args:
- keys: [V] long tensor, strictly increasing.
- query_vals: […] long tensor.
Returns:
Tuple (locations, indices):
- locations: indices of the query values that are present in keys.
- indices: index into keys corresponding to each of those query values.
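Because keys is strictly increasing, a binary search is a natural way to implement this lookup. The sketch below assumes torch.searchsorted performs the search, flattens query_vals, and returns the (locations, indices) pair described under Returns; find_indices_sketch is a hypothetical name, and this is not confirmed to match find_indices internally.

```python
import torch


def find_indices_sketch(query_vals, keys):
    """Hypothetical sketch: locate query values in strictly increasing keys."""
    flat = query_vals.reshape(-1)
    if keys.numel() == 0:
        empty = torch.empty(0, dtype=torch.long, device=flat.device)
        return empty, empty
    # Insertion point of each query in keys, in [0, V].
    pos = torch.searchsorted(keys, flat)
    # Clamp so an insertion point of V can still be compared safely;
    # keys[V-1] is then smaller than the query, so it is not a match.
    safe = pos.clamp(max=keys.numel() - 1)
    found = keys[safe] == flat
    locations = found.nonzero(as_tuple=True)[0]  # positions in the flattened query
    return locations, pos[locations]


keys = torch.tensor([2, 5, 9])
locations, indices = find_indices_sketch(torch.tensor([5, 3, 9, 0, 2, 12]), keys)
```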
- row_and_offset(starts: Tensor) -> Tuple[Tensor, Tensor]
Decode the flattened CSR value domain into row ids and in-row offsets.
Parameters
- starts : LongTensor, shape [R+1]
  CSR row pointer. Must satisfy the standard CSR invariants.
Returns
- rows_of_entry : LongTensor, shape [nnz]
  For each entry position t ∈ [0, nnz), the owning row id.
- offset_in_row : LongTensor, shape [nnz]
  For each entry position t, the offset within its row (i.e., 0..row_len[row]-1).
Notes
If nnz == 0, both outputs are empty tensors on the same device.
Examples
>>> starts = torch.tensor([0, 2, 2, 5], dtype=torch.long)
>>> rows, offs = row_and_offset(starts)
>>> rows, offs
(tensor([0, 0, 2, 2, 2]), tensor([0, 1, 0, 1, 2]))
Complexity
O(nnz) time, O(1) extra memory.
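The decoding can be expressed as "repeat each row id row_len[r] times, then subtract the row start from the absolute position". A minimal sketch of that idea follows; decode_rows_and_offsets is a hypothetical name, and because it allocates O(nnz) intermediates it does not by itself reflect the memory bound stated above.

```python
import torch


def decode_rows_and_offsets(starts):
    """Hypothetical sketch: owning row and in-row offset for every entry."""
    row_len = starts[1:] - starts[:-1]
    n_rows = row_len.numel()
    # Owning row: row r repeated row_len[r] times.
    rows = torch.repeat_interleave(torch.arange(n_rows, device=starts.device), row_len)
    # Offset = absolute position minus the start of the owning row.
    offsets = torch.arange(rows.numel(), device=starts.device) - starts[:-1][rows]
    return rows, offsets


rows, offs = decode_rows_and_offsets(torch.tensor([0, 2, 2, 5]))
```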
- starts_from_counts(counts: Tensor) -> Tensor
Build a CSR row pointer (starts) from per-row counts.
Parameters
- counts : LongTensor, shape [R]
  Number of entries in each row.
Returns
- starts : LongTensor, shape [R+1]
  CSR row pointer with starts[0] == 0 and torch.diff(starts) == counts. The final element equals nnz.
Notes
counts must be non-negative; no explicit check is performed.
device and dtype are inherited from counts (cast to long if needed).
Examples
>>> counts = torch.tensor([2, 0, 3], dtype=torch.long)
>>> starts = starts_from_counts(counts)
>>> starts
tensor([0, 2, 2, 5])
Complexity
O(R) time, O(1) extra memory.
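The row pointer is an exclusive prefix sum of counts with a leading zero. A minimal sketch assuming torch.cumsum; starts_from_counts_sketch is a hypothetical name used here to avoid implying it is the module's own function.

```python
import torch


def starts_from_counts_sketch(counts):
    """Hypothetical sketch: exclusive prefix sum with a leading zero."""
    counts = counts.long()  # keeps the device; casts to long if needed
    starts = torch.zeros(counts.numel() + 1, dtype=torch.long, device=counts.device)
    starts[1:] = torch.cumsum(counts, dim=0)  # starts[-1] == nnz
    return starts


counts = torch.tensor([2, 0, 3])
starts = starts_from_counts_sketch(counts)
```

The defining property torch.diff(starts) == counts also makes a convenient round-trip check.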