# Source code for pyCP_APR.datasets
"""
datasets.py is used to load the example tensors.
@author Maksim Ekin Eren
"""
import numpy as np
import pandas as pd
import os
import sys
import requests
def list_datasets():
    """
    Return the list of tensor names that are available to load.

    If the listing is requested for the first time, a new directory for the
    datasets is created (see :func:`_get_file_paths`).

    Returns
    -------
    datasets : list
        List of tensor names that are available to load.

    .. note::

        **Example**

        .. code-block:: python

            from pyCP_APR.datasets import list_datasets
            list_datasets()

        .. code-block:: console

            ['TOY']

    """
    files = _get_file_paths()["files"]
    # Only compressed NumPy archives (.npz) are datasets; report them without
    # the extension. splitext (rather than split(".")) keeps dataset names
    # that themselves contain dots intact.
    return [os.path.splitext(file)[0] for file in files if file.endswith(".npz")]
def load_dataset(name="TOY"):
    """
    Load the tensor specified by its name.

    .. warning::

        If a dataset is requested for the first time, it gets downloaded
        from GitHub.

    Parameters
    ----------
    name : string, optional
        The name of the tensor to load. The default is ``name="TOY"``.

    Returns
    -------
    data : Numpy NPZ
        Tensor contents compressed in Numpy NPZ format.

    .. note::

        **Example**

        .. code-block:: python

            from pyCP_APR.datasets import load_dataset

            # Load a sample authentication training and test tensors along with the labels
            data = load_dataset(name="TOY")
            coords_train, nnz_train = data['train_coords'], data['train_count']
            coords_test, nnz_test = data['test_coords'], data['test_count']

        Available tensor data can be listed as follows:

        .. code-block:: python

            data = load_dataset(name = "TOY")
            list(data)

        .. code-block:: console

            ['train_coords',
             'train_count',
             'test_coords',
             'test_count']

    """
    file_name = name + ".npz"
    # Ensures the dataset file is present locally (downloads on first use).
    datasets = _get_file_paths(file_name)
    if file_name not in datasets["files"]:
        sys.exit("Dataset is not found! Available datasets are: " + ", ".join(list_datasets()))
    # allow_pickle is required because the archives store object arrays.
    return np.load(os.path.join(datasets["path"], file_name), allow_pickle=True)
def _get_file_paths(name=""):
    """
    Helper function to extract the absolute path to the dataset, and the list
    of files in the data folder.

    .. warning::

        If the directory for the data is not available, i.e. when the data is
        requested for the first time, a new directory for the data is created.

        If a dataset is requested for the first time, this helper function
        downloads the data from the GitHub repository of pyCP_APR.

    Parameters
    ----------
    name : string
        The name of the dataset file (including extension) to download if it
        does not exist. The default is "", which only lists the datasets.

    Returns
    -------
    dataset information : dict
        ``{"path":string, "files":list}``.

    Raises
    ------
    requests.HTTPError
        If the dataset download fails (e.g. the dataset does not exist
        upstream).
    """
    download = False
    dirname = os.path.dirname(__file__)
    files_dirname = os.path.join(dirname, 'data/')
    try:
        # attempt to read the data directory
        files = os.listdir(files_dirname)
    except FileNotFoundError:
        # first time calling, create the datasets directory
        os.mkdir(files_dirname)
        print("Created:", files_dirname)
        files = os.listdir(files_dirname)
    # if not only listing the datasets, and if dataset is not downloaded yet
    if name != "" and name not in files:
        download = True
    # if only listing the datasets and none are downloaded
    elif name == "" and len(files) == 0:
        print("No datasets are downloaded.")
        print("See https://github.com/lanl/pyCP_APR/tree/main/data/tensors for available datasets.")
        print("Example dataset name: \"TOY\".")
        print("Datasets will be downloaded when pyCP_APR.datasets.load_dataset(name=\"TOY\") is called first time.")
    # if we need to download the dataset
    if download:
        print("Downloading the dataset:", name)
        url = "https://raw.githubusercontent.com/lanl/pyCP_APR/main/data/tensors/" + name
        r = requests.get(url, allow_redirects=True)
        # Fail loudly on a bad response rather than silently saving an HTML
        # error page as a corrupt .npz file.
        r.raise_for_status()
        # Context manager guarantees the file handle is closed.
        with open(os.path.join(files_dirname, name), 'wb') as f:
            f.write(r.content)
        files = os.listdir(files_dirname)
    return {"path": files_dirname, "files": files}