Source code for deeprank2.dataset

from __future__ import annotations

import inspect
import logging
import os
import pickle
import re
import sys
import warnings
from typing import Literal

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch_geometric.data.data import Data
from torch_geometric.data.dataset import Dataset
from tqdm import tqdm

from deeprank2.domain import edgestorage as Efeat
from deeprank2.domain import gridstorage
from deeprank2.domain import nodestorage as Nfeat
from deeprank2.domain import targetstorage as targets

_log = logging.getLogger(__name__)


[docs]class DeeprankDataset(Dataset): """Parent class of :class:`GridDataset` and :class:`GraphDataset`. This class inherits from :class:`torch_geometric.data.dataset.Dataset`. More detailed information about the parameters can be found in :class:`GridDataset` and :class:`GraphDataset`. """ def __init__( self, hdf5_path: str | list[str], subset: list[str] | None, train_source: str | GridDataset | GraphDataset | None, target: str | None, target_transform: bool | None, target_filter: dict[str, str] | None, task: str | None, classes: list[str] | list[int] | list[float] | None, use_tqdm: bool, root: str, check_integrity: bool, ): super().__init__(root) if isinstance(hdf5_path, str): self.hdf5_paths = [hdf5_path] elif isinstance(hdf5_path, list): self.hdf5_paths = hdf5_path else: msg = f"hdf5_path: unexpected type: {type(hdf5_path)}" raise TypeError(msg) self.subset = subset self.train_source = train_source self.target = target self.target_transform = target_transform self.target_filter = target_filter if check_integrity: self._check_hdf5_files() self._check_task_and_classes(task, classes) self.use_tqdm = use_tqdm # create the indexing system # alows to associate each mol to an index # and get fname and mol name from the index self._create_index_entries() self.df = None self.means = None self.devs = None self.train_means = None self.train_devs = None # get the device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def _check_and_inherit_train( # noqa: C901 self, data_type: GridDataset | GraphDataset, inherited_params: list[str], ) -> None: """Check if the pre-trained model or training set provided are valid for validation and/or testing, and inherit the parameters.""" if isinstance(self.train_source, str): try: if torch.cuda.is_available(): data = torch.load(self.train_source) else: data = torch.load(self.train_source, map_location=torch.device("cpu")) if data["data_type"] is not data_type: msg = ( f"The pre-trained model has been trained with data of type {data['data_type']}, but you are trying \n\t" f"to define a {data_type}-class validation/testing dataset. Please provide a valid DeepRank2 \n\t" f"model trained with {data_type}-class type data, or define the dataset using the appropriate class." ) raise TypeError(msg) if data_type is GraphDataset: self.train_means = data["means"] self.train_devs = data["devs"] # convert strings in 'transform' key to lambda functions if data["features_transform"]: for key in data["features_transform"].values(): if key["transform"] is None: continue key["transform"] = eval(key["transform"]) # noqa: S307 except pickle.UnpicklingError as e: msg = "The path provided to `train_source` is not a valid DeepRank2 pre-trained model." raise ValueError(msg) from e elif isinstance(self.train_source, data_type): data = self.train_source if data_type is GraphDataset: self.train_means = self.train_source.means self.train_devs = self.train_source.devs else: msg = ( f"The train data provided is invalid: {type(self.train_source)}.\n\t" f"Please provide a valid training {data_type} or the path to a valid DeepRank2 pre-trained model." 
) raise TypeError(msg) # match parameters with the ones in the training set self._check_inherited_params(inherited_params, data) def _check_hdf5_files(self) -> None: """Checks if the data contained in the .HDF5 file is valid.""" _log.info("\nChecking dataset Integrity...") to_be_removed = [] for hdf5_path in self.hdf5_paths: try: with h5py.File(hdf5_path, "r") as f5: entry_names = list(f5.keys()) if len(entry_names) == 0: _log.info(f" -> {hdf5_path} is empty ") to_be_removed.append(hdf5_path) except Exception as e: # noqa: BLE001, PERF203 _log.error(e) _log.info(f" -> {hdf5_path} is corrupted ") to_be_removed.append(hdf5_path) for hdf5_path in to_be_removed: self.hdf5_paths.remove(hdf5_path) def _check_task_and_classes(self, task: str, classes: str | None = None) -> None: # Determine the task based on the target or use the provided task if task is None: target_to_task_map = { targets.IRMSD: targets.REGRESS, targets.LRMSD: targets.REGRESS, targets.FNAT: targets.REGRESS, targets.DOCKQ: targets.REGRESS, targets.BINARY: targets.CLASSIF, targets.CAPRI: targets.CLASSIF, } self.task = target_to_task_map.get(self.target) else: self.task = task # Validate the task if self.task not in [targets.CLASSIF, targets.REGRESS] and self.target is not None: msg = f"User target detected: {self.target} -> The task argument must be 'classif' or 'regress', currently set as {self.task}" raise ValueError(msg) # Warn if the user-set task does not match the determined task if task and task != self.task: warnings.warn( f"Target {self.target} expects {self.task}, but was set to task {task} by user. User set task is ignored and {self.task} will be used.", ) # Handle classification task if self.task == targets.CLASSIF: if classes is None: self.classes = [0, 1, 2, 3, 4, 5] if self.target == targets.CAPRI else [0, 1] self.classes_to_index = {class_: index for index, class_ in enumerate(self.classes)} _log.info(f"Target classes set to: {self.classes}") else: self.classes = None self.classes_to_index = None def _check_inherited_params( self, inherited_params: list[str], data: dict | GraphDataset | GridDataset, ) -> None: """Check if the parameters for validation and/or testing are the same as in the pre-trained model or training set provided. Args: inherited_params: List of parameters that need to be checked for inheritance. data: The parameters in `inherited_param` will be inherited from the information contained in `data`. """ self_vars = vars(self) if not isinstance(data, dict): data = vars(data) for param in inherited_params: if self_vars[param] != data[param]: if self_vars[param] != self.default_vars[param]: _log.warning( f"The {param} parameter set here is: {self_vars[param]}, " f"which is not equivalent to the one in the training phase: {data[param]}./n" f"Overwriting {param} parameter with the one used in the training phase.", ) setattr(self, param, data[param]) def _create_index_entries(self) -> None: """Creates the indexing of each molecule in the dataset. Creates the indexing: [ ('1ak4.hdf5,1AK4_100w),...,('1fqj.hdf5,1FGJ_400w)]. This allows to refer to one entry with its index in the list. 
""" _log.debug(f"Processing data set with .HDF5 files: {self.hdf5_paths}") self.index_entries = [] desc = f" {self.hdf5_paths}{' dataset':25s}" if self.use_tqdm: hdf5_path_iterator = tqdm(self.hdf5_paths, desc=desc, file=sys.stdout) else: _log.info(f" {self.hdf5_paths} dataset\n") hdf5_path_iterator = self.hdf5_paths sys.stdout.flush() for hdf5_path in hdf5_path_iterator: if self.use_tqdm: hdf5_path_iterator.set_postfix(entry_name=os.path.basename(hdf5_path)) try: with h5py.File(hdf5_path, "r") as hdf5_file: if self.subset is None: entry_names = list(hdf5_file.keys()) else: entry_names = [entry_name for entry_name in self.subset if entry_name in list(hdf5_file.keys())] # skip self._filter_targets when target_filter is None, improve performance using list comprehension. if self.target_filter is None: self.index_entries += [(hdf5_path, entry_name) for entry_name in entry_names] else: self.index_entries += [(hdf5_path, entry_name) for entry_name in entry_names if self._filter_targets(hdf5_file[entry_name])] except Exception: # noqa: BLE001 _log.exception(f"on {hdf5_path}") def _filter_targets(self, grp: h5py.Group) -> bool: """Filters the entry according to a dictionary. The filter is based on the attribute self.target_filter that must be either of the form: { target_name : target_condition } or None. Args: grp: The entry group in the .HDF5 file. Returns: True if we keep the entry; False otherwise. Raises: ValueError: If an unsuported condition is provided. """ if self.target_filter is None: return True for target_name, target_condition in self.target_filter.items(): present_target_names = list(grp[targets.VALUES].keys()) if target_name in present_target_names: # If we have a given target_condition, see if it's met. if isinstance(target_condition, str): operation = target_condition target_value = grp[targets.VALUES][target_name][()] for operator_string in [">", "<", "==", "<=", ">=", "!="]: operation = operation.replace(operator_string, f"{target_value}" + operator_string) if not eval(operation): # noqa: S307 return False elif target_condition is not None: msg = "Conditions not supported" raise ValueError(msg, target_condition) else: _log.warning(f" :Filter {target_name} not found for entry {grp}\n :Filter options are: {present_target_names}") return True
[docs]    def len(self) -> int:
        """Gets the length of the dataset, either :class:`GridDataset` or :class:`GraphDataset` object.

        Returns:
            int: Number of complexes in the dataset.
        """
        return len(self.index_entries)
[docs] def hdf5_to_pandas( # noqa: C901 self, ) -> pd.DataFrame: """Loads features data from the HDF5 files into a Pandas DataFrame in the attribute `df` of the class. Returns: :class:`pd.DataFrame`: Pandas DataFrame containing the selected features as columns per all data points in hdf5_path files. """ df_final = pd.DataFrame() for fname in self.hdf5_paths: with h5py.File(fname, "r") as f: entry_name = next(iter(f.keys())) if self.subset is not None: entry_names = [entry for entry, _ in f.items() if entry in self.subset] else: entry_names = [entry for entry, _ in f.items()] df_dict = {} df_dict["id"] = entry_names for feat_type in self.features_dict: for feat in self.features_dict[feat_type]: # reset transform for each feature transform = None if self.features_transform: transform = self.features_transform.get("all", {}).get("transform") if (transform is None) and (feat in self.features_transform): transform = self.features_transform.get(feat, {}).get("transform") # Check the number of channels the features have if f[entry_name][feat_type][feat][()].ndim == 2: # noqa:PLR2004 for i in range(f[entry_name][feat_type][feat][:].shape[1]): df_dict[feat + "_" + str(i)] = [f[entry_name][feat_type][feat][:][:, i] for entry_name in entry_names] # apply transformation for each channel in this feature if transform: df_dict[feat + "_" + str(i)] = [transform(row) for row in df_dict[feat + "_" + str(i)]] else: df_dict[feat] = [ f[entry_name][feat_type][feat][:] if f[entry_name][feat_type][feat][()].ndim == 1 else f[entry_name][feat_type][feat][()] for entry_name in entry_names ] # apply transformation if transform: df_dict[feat] = [transform(row) for row in df_dict[feat]] df_temp = pd.DataFrame(data=df_dict) df_concat = pd.concat([df_final, df_temp]) self.df = df_concat.reset_index(drop=True) return self.df
[docs] def save_hist( # noqa: C901 self, features: str | list[str], fname: str = "features_hist.png", bins: int | list[float] | str = 10, figsize: tuple = (15, 15), log: bool = False, ) -> None: """After having generated a pd.DataFrame using hdf5_to_pandas method, histograms of the features can be saved in an image. Args: features: Features to be plotted. fname: str or path-like or binary file-like object. Defaults to 'features_hist.png'. bins: If bins is an integer, it defines the number of equal-width bins in the range. If bins is a sequence, it defines the bin edges, including the left edge of the first bin and the right edge of the last bin; in this case, bins may be unequally spaced. All but the last (righthand-most) bin is half-open. If bins is a string, it is one of the binning strategies supported by numpy.histogram_bin_edges: 'auto', 'fd', 'doane', 'scott', 'stone', 'rice', 'sturges', or 'sqrt'. Defaults to 10. figsize: Saved figure sizes. Defaults to (15, 15). log: Whether to apply log transformation to the data indicated by the `features` parameter. Defaults to False. """ if self.df is None: self.hdf5_to_pandas() if not isinstance(features, list): features = [features] features_df = [col for feat in features for col in self.df.columns.to_numpy().tolist() if feat in col] means = [ round(np.concatenate(self.df[feat].to_numpy()).mean(), 1) if isinstance(self.df[feat].to_numpy()[0], np.ndarray) else round(self.df[feat].to_numpy().mean(), 1) for feat in features_df ] devs = [ round(np.concatenate(self.df[feat].to_numpy()).std(), 1) if isinstance(self.df[feat].to_numpy()[0], np.ndarray) else round(self.df[feat].to_numpy().std(), 1) for feat in features_df ] if len(features_df) > 1: fig, axs = plt.subplots(len(features_df), figsize=figsize) for row, feat in enumerate(features_df): if isinstance(self.df[feat].to_numpy()[0], np.ndarray): if log: log_data = np.log(np.concatenate(self.df[feat].to_numpy())) log_data[log_data == -np.inf] = 0 axs[row].hist(log_data, bins=bins) else: axs[row].hist(np.concatenate(self.df[feat].to_numpy()), bins=bins) elif log: log_data = np.log(self.df[feat].to_numpy()) log_data[log_data == -np.inf] = 0 axs[row].hist(log_data, bins=bins) else: axs[row].hist(self.df[feat].to_numpy(), bins=bins) axs[row].set( xlabel=f"{feat} (mean {means[row]}, std {devs[row]})", ylabel="Count", ) fig.tight_layout() elif len(features_df) == 1: fig = plt.figure(figsize=figsize) ax = fig.add_subplot(111) if isinstance(self.df[features_df[0]].to_numpy()[0], np.ndarray): if log: log_data = np.log(np.concatenate(self.df[features_df[0]].to_numpy())) log_data[log_data == -np.inf] = 0 ax.hist(log_data, bins=bins) else: ax.hist(np.concatenate(self.df[features_df[0]].to_numpy()), bins=bins) elif log: log_data = np.log(self.df[features_df[0]].to_numpy()) log_data[log_data == -np.inf] = 0 ax.hist(log_data, bins=bins) else: ax.hist(self.df[features_df[0]].values, bins=bins) ax.set( xlabel=f"{features_df[0]} (mean {means[0]}, std {devs[0]})", ylabel="Count", ) else: msg = "Please provide valid features names. They must be present in the current :class:`DeeprankDataset` children instance." raise ValueError(msg) with warnings.catch_warnings(): warnings.simplefilter("ignore") fig.tight_layout() fig.savefig(fname) plt.close(fig)
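    # --- Usage sketch (illustrative; not part of deeprank2.dataset) ---
    # `hdf5_to_pandas` collects the selected features into `self.df` (one row per entry,
    # one column per feature); `save_hist` then plots histograms for the requested feature
    # columns. The file and feature names below are assumptions for the example.
    #
    #     dataset = GraphDataset(hdf5_path="processed_data.hdf5", target="binary")
    #     df = dataset.hdf5_to_pandas()
    #     dataset.save_hist(features=["bsa", "res_type"], fname="features_hist.png", bins=30)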
    def _compute_mean_std(self) -> None:
        means = {
            col: round(np.nanmean(np.concatenate(self.df[col].values)), 1)
            if isinstance(self.df[col].to_numpy()[0], np.ndarray)
            else round(np.nanmean(self.df[col].to_numpy()), 1)
            for col in self.df.columns[1:]
        }
        devs = {
            col: round(np.nanstd(np.concatenate(self.df[col].to_numpy())), 1)
            if isinstance(self.df[col].to_numpy()[0], np.ndarray)
            else round(np.nanstd(self.df[col].to_numpy()), 1)
            for col in self.df.columns[1:]
        }
        self.means = means
        self.devs = devs
# Grid features are stored per dimension and named accordingly.
# Example: position_001, position_002, position_003 (for x,y,z)
# Use this regular expression to take the feature name apart
GRID_PARTIAL_FEATURE_NAME_PATTERN = re.compile(r"^([a-zA-Z_]+)_([0-9]{3})$")
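# --- Usage sketch (illustrative; not part of deeprank2.dataset) ---
# The pattern splits a per-dimension grid feature name into its base name and
# dimension index; names without a three-digit suffix do not match.
#
#     match = GRID_PARTIAL_FEATURE_NAME_PATTERN.match("position_001")
#     match.group(1)  # "position"
#     match.group(2)  # "001"
#     GRID_PARTIAL_FEATURE_NAME_PATTERN.match("res_type") is None  # True: no dimension suffix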
[docs]class GridDataset(DeeprankDataset):
    """Class to load the .HDF5 files data into grids.

    Args:
        hdf5_path: Path to .HDF5 file(s). For multiple .HDF5 files, insert the paths in a list. Defaults to None.
        subset: List of keys from .HDF5 file to include. Defaults to None (meaning include all).
        train_source: Data to inherit information from the training dataset or the pre-trained model.
            If None, the current dataset is considered as the training set.
            Otherwise, `train_source` needs to be a dataset of the same class or the path of a DeepRank2 pre-trained model.
            If set, the parameters `features`, `target`, `target_transform`, `task`, and `classes` will be inherited
            from `train_source`. Defaults to None.
        features: Consider all pre-computed features ("all") or some defined node features
            (provide a list, example: ["res_type", "polarity", "bsa"]).
            The complete list can be found in `deeprank2.domain.gridstorage`.
            Value will be ignored and inherited from `train_source` if `train_source` is assigned.
            Defaults to "all".
        target: Default options are irmsd, lrmsd, fnat, binary, capri_class, and dockq.
            It can also be a custom-defined target given to the Query class as input (see: `deeprank2.query`);
            in this case, the task parameter needs to be explicitly specified as well.
            Only numerical target variables are supported, not categorical.
            If the latter is your case, please convert the categorical classes into numerical class indices
            before defining the :class:`GridDataset` instance.
            Value will be ignored and inherited from `train_source` if `train_source` is assigned.
            Defaults to None.
        target_transform: Apply a log and then a sigmoid transformation to the target (for regression only).
            This puts the target value between 0 and 1, and can result in a more uniform target distribution
            and speed up the optimization.
            Value will be ignored and inherited from `train_source` if `train_source` is assigned.
            Defaults to False.
        target_filter: Dictionary of type [target: cond] to filter the molecules.
            Note that you can filter on a different target than the one selected as the dataset target.
            Defaults to None.
        task: 'regress' for regression or 'classif' for classification. Required if target not in
            ['irmsd', 'lrmsd', 'fnat', 'binary', 'capri_class', 'dockq'], otherwise this setting is ignored.
            Automatically set to 'classif' if the target is 'binary' or 'capri_class'.
            Automatically set to 'regress' if the target is 'irmsd', 'lrmsd', 'fnat', or 'dockq'.
            Value will be ignored and inherited from `train_source` if `train_source` is assigned.
            Defaults to None.
        classes: Define the dataset target classes in classification mode.
            Value will be ignored and inherited from `train_source` if `train_source` is assigned.
            Defaults to None.
        use_tqdm: Show progress bar. Defaults to True.
        root: Root directory where the dataset should be saved. Defaults to "./".
        check_integrity: Whether to check the integrity of the hdf5 files. Defaults to True.
""" def __init__( self, hdf5_path: str | list, subset: list[str] | None = None, train_source: str | GridDataset | None = None, features: list[str] | str | None = "all", target: str | None = None, target_transform: bool = False, target_filter: dict[str, str] | None = None, task: Literal["regress", "classif"] | None = None, classes: list[str] | list[int] | list[float] | None = None, use_tqdm: bool = True, root: str = "./", check_integrity: bool = True, ): super().__init__( hdf5_path, subset, train_source, target, target_transform, target_filter, task, classes, use_tqdm, root, check_integrity, ) self.default_vars = {k: v.default for k, v in inspect.signature(self.__init__).parameters.items() if v.default is not inspect.Parameter.empty} self.default_vars["classes_to_index"] = None self.features = features self.target_transform = target_transform if train_source is not None: self.inherited_params = [ "features", "target", "target_transform", "task", "classes", "classes_to_index", ] self._check_and_inherit_train(GridDataset, self.inherited_params) self._check_features() else: self._check_features() self.inherited_params = None try: fname, mol = self.index_entries[0] except IndexError as e: msg = "No entries found in the dataset. Please check the dataset parameters." raise IndexError(msg) from e with h5py.File(fname, "r") as f5: grp = f5[mol] possible_targets = grp[targets.VALUES].keys() if self.target is None: msg = f"Please set the target during training dataset definition; targets present in the file/s are {possible_targets}." raise ValueError(msg) if self.target not in possible_targets: msg = f"Target {self.target} not present in the file/s; targets present in the file/s are {possible_targets}." raise ValueError(msg) self.features_dict = {} self.features_dict[gridstorage.MAPPED_FEATURES] = self.features if self.target is not None: if isinstance(self.target, str): self.features_dict[targets.VALUES] = [self.target] else: self.features_dict[targets.VALUES] = self.target def _check_features(self) -> None: # noqa: C901 """Checks if the required features exist.""" hdf5_path = self.hdf5_paths[0] # read available features with h5py.File(hdf5_path, "r") as f: mol_key = next(iter(f.keys())) if isinstance(self.features, list): self.features = [ GRID_PARTIAL_FEATURE_NAME_PATTERN.match(feature_name).group(1) if GRID_PARTIAL_FEATURE_NAME_PATTERN.match(feature_name) is not None else feature_name for feature_name in self.features ] # remove the dimension number suffix self.features = list(set(self.features)) # remove duplicates available_features = list(f[f"{mol_key}/{gridstorage.MAPPED_FEATURES}"].keys()) available_features = [key for key in available_features if key[0] != "_"] # ignore metafeatures hdf5_matching_feature_names = [] # feature names that match with the requested list of names unpartial_feature_names = [] # feature names without their dimension number suffix for feature_name in available_features: partial_feature_match = GRID_PARTIAL_FEATURE_NAME_PATTERN.match(feature_name) if partial_feature_match is not None: # there's a dimension number in the feature name unpartial_feature_name = partial_feature_match.group(1) if self.features == "all" or (isinstance(self.features, list) and unpartial_feature_name in self.features): hdf5_matching_feature_names.append(feature_name) unpartial_feature_names.append(unpartial_feature_name) else: # no numbers, it's a one-dimensional feature name if self.features == "all" or (isinstance(self.features, list) and feature_name in self.features): 
hdf5_matching_feature_names.append(feature_name) unpartial_feature_names.append(feature_name) # check for the requested features missing_features = [] if self.features == "all": self.features = sorted(available_features) self.default_vars["features"] = self.features else: if not isinstance(self.features, list): if self.features is None: self.features = [] else: self.features = [self.features] for feature_name in self.features: if feature_name not in unpartial_feature_names: _log.info(f"The feature {feature_name} was not found in the file {hdf5_path}.") missing_features.append(feature_name) self.features = sorted(hdf5_matching_feature_names) # raise error if any features are missing if len(missing_features) > 0: msg = ( f"Not all features could be found in the file {hdf5_path} under entry {mol_key}.\n\t" f"Missing features are: {missing_features}.\n\t" "Check feature_modules passed to the preprocess function.\n\t" "Probably, the feature wasn't generated during the preprocessing step.\n\t" f"Available features: {available_features}" ) raise ValueError(msg)
[docs]    def get(self, idx: int) -> Data:
        """Gets one grid item from its unique index.

        Args:
            idx: Index of the item, ranging from 0 to len(dataset).

        Returns:
            :class:`torch_geometric.data.data.Data`: item with tensors x, y if present, entry_names.
        """
        file_path, entry_name = self.index_entries[idx]
        return self.load_one_grid(file_path, entry_name)
[docs]    def load_one_grid(self, hdf5_path: str, entry_name: str) -> Data:
        """Loads one grid.

        Args:
            hdf5_path: .HDF5 file name.
            entry_name: Name of the entry.

        Returns:
            :class:`torch_geometric.data.data.Data`: item with tensors x, y if present, entry_names.
        """
        with h5py.File(hdf5_path, "r") as hdf5_file:
            grp = hdf5_file[entry_name]

            mapped_features_group = grp[gridstorage.MAPPED_FEATURES]
            feature_data = [mapped_features_group[feature_name][:] for feature_name in self.features if feature_name[0] != "_"]
            x = torch.tensor(np.expand_dims(np.array(feature_data), axis=0), dtype=torch.float)

            # target
            if self.target is None:
                y = None
            elif targets.VALUES in grp and self.target in grp[targets.VALUES]:
                y = torch.tensor([grp[targets.VALUES][self.target][()]], dtype=torch.float)

                if self.task == targets.REGRESS and self.target_transform is True:
                    y = torch.sigmoid(torch.log(y))
                elif self.task is not targets.REGRESS and self.target_transform is True:
                    msg = f'Sigmoid transformation not possible for {self.task} tasks. Please change `task` to "regress" or set `target_transform` to `False`.'
                    raise ValueError(msg)
            else:
                y = None
                possible_targets = grp[targets.VALUES].keys()
                if self.train_source is None:
                    msg = (
                        f"Target {self.target} missing in entry {entry_name} in file {hdf5_path}, possible targets are {possible_targets}.\n\t"
                        "Use the query class to add more target values to input data."
                    )
                    raise ValueError(msg)

            # Wrap up the data in this object, for the collate_fn to handle it properly:
            data = Data(x=x, y=y)
            data.entry_names = entry_name

        return data
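# --- Usage sketch (illustrative; not part of deeprank2.dataset) ---
# A typical train/test pair: the training set defines `features`, `target`, etc.,
# and the test set inherits them through `train_source`. File names are assumptions.
#
#     train_dataset = GridDataset(
#         hdf5_path="train.hdf5",
#         target="binary",
#     )
#     test_dataset = GridDataset(
#         hdf5_path="test.hdf5",
#         train_source=train_dataset,  # inherits features, target, task, classes, ...
#     )
#     item = train_dataset.get(0)     # torch_geometric Data object with x (grid tensor) and y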
[docs]class GraphDataset(DeeprankDataset):
    """Class to load the .HDF5 files data into graphs.

    Args:
        hdf5_path: Path to .HDF5 file(s). For multiple .HDF5 files, insert the paths in a list. Defaults to None.
        subset: List of keys from .HDF5 file to include. Defaults to None (meaning include all).
        train_source: Data to inherit information from the training dataset or the pre-trained model.
            If None, the current dataset is considered as the training set.
            Otherwise, `train_source` needs to be a dataset of the same class or the path of a DeepRank2 pre-trained model.
            If set, the parameters `node_features`, `edge_features`, `features_transform`, `target`, `target_transform`,
            `task`, and `classes` will be inherited from `train_source`.
            If standardization was performed in the training dataset/step, the attributes `means` and `devs` will also be
            inherited from `train_source`, and they will be used to scale the validation/testing set.
            Defaults to None.
        node_features: Consider all pre-computed node features ("all") or some defined node features
            (provide a list, e.g.: ["res_type", "polarity", "bsa"]).
            The complete list can be found in `deeprank2.domain.nodestorage`.
            Value will be ignored and inherited from `train_source` if `train_source` is assigned.
            Defaults to "all".
        edge_features: Consider all pre-computed edge features ("all") or some defined edge features
            (provide a list, e.g.: ["dist", "coulomb"]).
            The complete list can be found in `deeprank2.domain.edgestorage`.
            Value will be ignored and inherited from `train_source` if `train_source` is assigned.
            Defaults to "all".
        features_transform: Dictionary indicating the transformations to apply to each feature in the dictionary,
            the transformations being lambda functions and/or standardization.
            Example: `features_transform = {'bsa': {'transform': lambda t: np.log(t+1), 'standardize': True}}` for the feature `bsa`.
            An `all` key can be set in the dictionary to apply the same `standardize` and `transform` to all the features.
            Example: `features_transform = {'all': {'transform': lambda t: np.log(t+1), 'standardize': True}}`.
            If both `all` and feature name/s are present, the latter take priority over what is indicated in `all`.
            Value will be ignored and inherited from `train_source` if `train_source` is assigned.
            Defaults to None.
        clustering_method: "mcl" for Markov cluster algorithm (see https://micans.org/mcl/), or "louvain" for Louvain method
            (see https://en.wikipedia.org/wiki/Louvain_method).
            In both options, for each graph, the chosen method first finds communities (clusters) of nodes and generates
            a torch tensor whose elements represent the cluster to which the node belongs.
            Each tensor is then saved in the .HDF5 file as a :class:`Dataset` called "depth_0".
            Then, all cluster members belonging to the same community are pooled into a single node, and the resulting
            tensor is used to find communities among the pooled clusters.
            The latter tensor is saved into the .HDF5 file as a :class:`Dataset` called "depth_1".
            Both "depth_0" and "depth_1" :class:`Datasets` belong to the "cluster" Group.
            They are saved in the .HDF5 file to make them available to networks that make use of clustering methods.
            Defaults to None.
        target: Default options are irmsd, lrmsd, fnat, binary, capri_class, and dockq.
            It can also be a custom-defined target given to the Query class as input (see: `deeprank2.query`);
            in this case, the task parameter needs to be explicitly specified as well.
            Only numerical target variables are supported, not categorical.
If the latter is your case, please convert the categorical classes into numerical class indices before defining the :class:`GraphDataset` instance. Value will be ignored and inherited from `train_source` if `train_source` is assigned. Defaults to None. target_transform: Apply a log and then a sigmoid transformation to the target (for regression only). This puts the target value between 0 and 1, and can result in a more uniform target distribution and speed up the optimization. Value will be ignored and inherited from `train_source` if `train_source` is assigned. Defaults to False. target_filter: Dictionary of type [target: cond] to filter the molecules. Note that the you can filter on a different target than the one selected as the dataset target. Defaults to None. task: 'regress' for regression or 'classif' for classification. Required if target not in ['irmsd', 'lrmsd', 'fnat', 'binary', 'capri_class', or 'dockq'], otherwise this setting is ignored. Automatically set to 'classif' if the target is 'binary' or 'capri_classes'. Automatically set to 'regress' if the target is 'irmsd', 'lrmsd', 'fnat', or 'dockq'. Value will be ignored and inherited from `train_source` if `train_source` is assigned. Defaults to None. classes: Define the dataset target classes in classification mode. Value will be ignored and inherited from `train_source` if `train_source` is assigned. Defaults to None. use_tqdm: Show progress bar. Defaults to True. root: Root directory where the dataset should be saved. Defaults to "./". check_integrity: Whether to check the integrity of the hdf5 files. Defaults to True. """ def __init__( # noqa: C901 self, hdf5_path: str | list, subset: list[str] | None = None, train_source: str | GridDataset | None = None, node_features: list[str] | str | None = "all", edge_features: list[str] | str | None = "all", features_transform: dict | None = None, clustering_method: str | None = None, target: str | None = None, target_transform: bool = False, target_filter: dict[str, str] | None = None, task: Literal["regress", "classif"] | None = None, classes: list[str] | list[int] | list[float] | None = None, use_tqdm: bool = True, root: str = "./", check_integrity: bool = True, ): super().__init__( hdf5_path, subset, train_source, target, target_transform, target_filter, task, classes, use_tqdm, root, check_integrity, ) self.default_vars = {k: v.default for k, v in inspect.signature(self.__init__).parameters.items() if v.default is not inspect.Parameter.empty} self.default_vars["classes_to_index"] = None self.node_features = node_features self.edge_features = edge_features self.clustering_method = clustering_method self.target_transform = target_transform self.features_transform = features_transform if train_source is not None: self.inherited_params = [ "node_features", "edge_features", "features_transform", "target", "target_transform", "task", "classes", "classes_to_index", ] self._check_and_inherit_train(GraphDataset, self.inherited_params) self._check_features() else: self._check_features() self.inherited_params = None try: fname, mol = self.index_entries[0] except IndexError as e: msg = "No entries found in the dataset. Please check the dataset parameters." raise IndexError(msg) from e with h5py.File(fname, "r") as f5: grp = f5[mol] possible_targets = grp[targets.VALUES].keys() if self.target is None: msg = f"Please set the target during training dataset definition; targets present in the file/s are {possible_targets}." 
                raise ValueError(msg)

            if self.target not in possible_targets:
                msg = f"Target {self.target} not present in the file/s; targets present in the file/s are {possible_targets}."
                raise ValueError(msg)

        self.features_dict = {}
        self.features_dict[Nfeat.NODE] = self.node_features
        self.features_dict[Efeat.EDGE] = self.edge_features
        if self.target is not None:
            if isinstance(self.target, str):
                self.features_dict[targets.VALUES] = [self.target]
            else:
                self.features_dict[targets.VALUES] = self.target

        standardize = False
        if self.features_transform:
            standardize = any(self.features_transform[key].get("standardize") for key, _ in self.features_transform.items())

        if standardize and (train_source is None):
            if self.means is None or self.devs is None:
                if self.df is None:
                    self.hdf5_to_pandas()
                self._compute_mean_std()
        elif standardize and (train_source is not None):
            self.means = self.train_means
            self.devs = self.train_devs
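    # --- Usage sketch (illustrative; not part of deeprank2.dataset) ---
    # `features_transform` maps feature names (or "all") to a lambda transform and/or a
    # `standardize` flag; with `standardize: True` the per-feature means/devs are computed
    # from the training data (or inherited via `train_source`). Names are assumptions.
    #
    #     features_transform = {
    #         "all": {"transform": lambda t: np.log(t + 1), "standardize": True},
    #         "bsa": {"transform": None, "standardize": True},  # overrides "all" for bsa
    #     }
    #     dataset = GraphDataset(
    #         hdf5_path="train.hdf5",
    #         target="binary",
    #         features_transform=features_transform,
    #     )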
[docs]    def get(self, idx: int) -> Data:
        """Gets one graph item from its unique index.

        Args:
            idx: Index of the item, ranging from 0 to len(dataset).

        Returns:
            :class:`torch_geometric.data.data.Data`: item with tensors x, y if present, edge_index, edge_attr, pos, entry_names.
        """
        fname, mol = self.index_entries[idx]
        return self.load_one_graph(fname, mol)
[docs] def load_one_graph(self, fname: str, entry_name: str) -> Data: # noqa: PLR0915, C901 """Loads one graph. Args: fname: .HDF5 file name. entry_name: Name of the entry. Returns: :class:`torch_geometric.data.data.Data`: item with tensors x, y if present, edge_index, edge_attr, pos, entry_names. """ with h5py.File(fname, "r") as f5: grp = f5[entry_name] # node features if len(self.node_features) > 0: node_data = () for feat in self.node_features: # resetting transformation and standardization for each feature transform = None standard = None if feat[0] != "_": # ignore metafeatures vals = grp[f"{Nfeat.NODE}/{feat}"][()] # get feat transformation and standardization if self.features_transform is not None: transform = self.features_transform.get("all", {}).get("transform") standard = self.features_transform.get("all", {}).get("standardize") # if no transformation is set for all features, check if one is set for the current feature if (transform is None) and (feat in self.features_transform): transform = self.features_transform.get(feat, {}).get("transform") # if no standardization is set for all features, check if one is set for the current feature if (standard is None) and (feat in self.features_transform): standard = self.features_transform.get(feat, {}).get("standardize") # apply transformation if transform: with warnings.catch_warnings(record=True) as w: vals = transform(vals) if len(w) > 0: msg = ( f"Invalid value occurs in {entry_name}, file {fname},when applying {transform} for feature {feat}.\n\t" f"Please change the transformation function for {feat}." ) raise ValueError(msg) if vals.ndim == 1: # features with only one channel vals = vals.reshape(-1, 1) if standard: vals = (vals - self.means[feat]) / self.devs[feat] elif standard: reshaped_mean = [mean_value for mean_key, mean_value in self.means.items() if feat in mean_key] reshaped_dev = [dev_value for dev_key, dev_value in self.devs.items() if feat in dev_key] vals = (vals - reshaped_mean) / reshaped_dev node_data += (vals,) x = torch.tensor(np.hstack(node_data), dtype=torch.float) else: x = None _log.warning("No node features set.") # edge index, # we have to have all the edges i.e : (i,j) and (j,i) if Efeat.INDEX in grp[Efeat.EDGE]: ind = grp[f"{Efeat.EDGE}/{Efeat.INDEX}"][()] if ind.ndim == 2: # noqa: PLR2004 ind = np.vstack((ind, np.flip(ind, 1))).T edge_index = torch.tensor(ind, dtype=torch.long).contiguous() else: edge_index = torch.empty((2, 0), dtype=torch.long) # edge feature # we have to have all the edges i.e : (i,j) and (j,i) if len(self.edge_features) > 0: edge_data = () for feat in self.edge_features: # resetting transformation and standardization for each feature transform = None standard = None if feat[0] != "_": # ignore metafeatures vals = grp[f"{Efeat.EDGE}/{feat}"][()] # get feat transformation and standardization if self.features_transform is not None: transform = self.features_transform.get("all", {}).get("transform") standard = self.features_transform.get("all", {}).get("standardize") # if no transformation is set for all features, check if one is set for the current feature if (transform is None) and (feat in self.features_transform): transform = self.features_transform.get(feat, {}).get("transform") # if no standardization is set for all features, check if one is set for the current feature if (standard is None) and (feat in self.features_transform): standard = self.features_transform.get(feat, {}).get("standardize") # apply transformation if transform: with warnings.catch_warnings(record=True) as w: 
vals = transform(vals) if len(w) > 0: msg = ( f"Invalid value occurs in {entry_name}, file {fname}, when applying {transform} for feature {feat}.\n\t" f"Please change the transformation function for {feat}." ) raise ValueError(msg) if vals.ndim == 1: vals = vals.reshape(-1, 1) if standard: vals = (vals - self.means[feat]) / self.devs[feat] elif standard: reshaped_mean = [mean_value for mean_key, mean_value in self.means.items() if feat in mean_key] reshaped_dev = [dev_value for dev_key, dev_value in self.devs.items() if feat in dev_key] vals = (vals - reshaped_mean) / reshaped_dev edge_data += (vals,) edge_data = np.hstack(edge_data) edge_data = np.vstack((edge_data, edge_data)) edge_attr = torch.tensor(edge_data, dtype=torch.float).contiguous() else: edge_attr = torch.empty((edge_index.shape[1], 0), dtype=torch.float).contiguous() # target if self.target is None: y = None elif targets.VALUES in grp and self.target in grp[targets.VALUES]: y = torch.tensor([grp[f"{targets.VALUES}/{self.target}"][()]], dtype=torch.float).contiguous() if self.task == targets.REGRESS and self.target_transform is True: y = torch.sigmoid(torch.log(y)) elif self.task is not targets.REGRESS and self.target_transform is True: msg = f'Sigmoid transformation not possible for {self.task} tasks. Please change `task` to "regress" or set `target_transform` to `False`.' raise ValueError(msg) else: y = None possible_targets = grp[targets.VALUES].keys() if self.train_source is None: msg = ( f"Target {self.target} missing in entry {entry_name} in file {fname}, possible targets are {possible_targets}.\n\t" "Use the query class to add more target values to input data." ) raise ValueError(msg) # positions pos = torch.tensor(grp[f"{Nfeat.NODE}/{Nfeat.POSITION}/"][()], dtype=torch.float).contiguous() # cluster cluster0 = None cluster1 = None if self.clustering_method is not None and "clustering" in grp: if self.clustering_method in grp["clustering"]: if "depth_0" in grp[f"clustering/{self.clustering_method}"] and "depth_1" in grp[f"clustering/{self.clustering_method}"]: cluster0 = torch.tensor( grp["clustering/" + self.clustering_method + "/depth_0"][()], dtype=torch.long, ) cluster1 = torch.tensor( grp["clustering/" + self.clustering_method + "/depth_1"][()], dtype=torch.long, ) else: _log.warning("no clusters detected") else: _log.warning(f"no clustering/{self.clustering_method} detected") # load data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, pos=pos) data.cluster0 = cluster0 data.cluster1 = cluster1 data.entry_names = entry_name return data
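    # --- Usage sketch (illustrative; not part of deeprank2.dataset) ---
    # Each item returned by `get`/`load_one_graph` is a torch_geometric `Data` object, so the
    # dataset can be batched with the PyTorch Geometric DataLoader. The import path below
    # assumes a recent torch_geometric version.
    #
    #     from torch_geometric.loader import DataLoader
    #
    #     loader = DataLoader(dataset, batch_size=32, shuffle=True)
    #     for batch in loader:
    #         batch.x, batch.edge_index, batch.edge_attr, batch.y  # batched tensors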
def _check_features(self) -> None: # noqa: C901 """Checks if the required features exist.""" f = h5py.File(self.hdf5_paths[0], "r") mol_key = next(iter(f.keys())) # read available node features self.available_node_features = list(f[f"{mol_key}/{Nfeat.NODE}/"].keys()) self.available_node_features = [key for key in self.available_node_features if key[0] != "_"] # ignore metafeatures # read available edge features self.available_edge_features = list(f[f"{mol_key}/{Efeat.EDGE}/"].keys()) self.available_edge_features = [key for key in self.available_edge_features if key[0] != "_"] # ignore metafeatures f.close() # check node features missing_node_features = [] if self.node_features == "all": self.node_features = self.available_node_features self.default_vars["node_features"] = self.node_features else: if not isinstance(self.node_features, list): if self.node_features is None: self.node_features = [] else: self.node_features = [self.node_features] for feat in self.node_features: if feat not in self.available_node_features: _log.info(f"The node feature _{feat}_ was not found in the file {self.hdf5_paths[0]}.") missing_node_features.append(feat) # check edge features missing_edge_features = [] if self.edge_features == "all": self.edge_features = self.available_edge_features self.default_vars["edge_features"] = self.edge_features else: if not isinstance(self.edge_features, list): if self.edge_features is None: self.edge_features = [] else: self.edge_features = [self.edge_features] for feat in self.edge_features: if feat not in self.available_edge_features: _log.info(f"The edge feature _{feat}_ was not found in the file {self.hdf5_paths[0]}.") missing_edge_features.append(feat) # raise error if any features are missing if missing_node_features + missing_edge_features: miss_node_error, miss_edge_error = "", "" _log.info( "\nCheck feature_modules passed to the preprocess function.\ Probably, the feature wasn't generated during the preprocessing step.", ) if missing_node_features: _log.info(f"\nAvailable node features: {self.available_node_features}\n") miss_node_error = f"\nMissing node features: {missing_node_features} \ \nAvailable node features: {self.available_node_features}" if missing_edge_features: _log.info(f"\nAvailable edge features: {self.available_edge_features}\n") miss_edge_error = f"\nMissing edge features: {missing_edge_features} \ \nAvailable edge features: {self.available_edge_features}" msg = ( f"Not all features could be found in the file {self.hdf5_paths[0]}.\n\t" "Check feature_modules passed to the preprocess function.\n\t" "Probably, the feature wasn't generated during the preprocessing step.\n\t" f"{miss_node_error}{miss_edge_error}" ) raise ValueError(msg)
[docs]def save_hdf5_keys(
    f_src_path: str,
    src_ids: list[str],
    f_dest_path: str,
    hardcopy: bool = False,
) -> None:
    """Save references to keys in src_ids in a new .HDF5 file.

    Args:
        f_src_path: The path to the .HDF5 file containing the keys.
        src_ids: Keys to be saved in the new .HDF5 file. It should be a list containing at least one key.
        f_dest_path: The path to the new .HDF5 file.
        hardcopy: If False, the new file contains only references (external links, see :class:`ExternalLink` class from `h5py`)
            to the original .HDF5 file.
            If True, the new file contains a copy of the objects specified in src_ids (see :class:`HardLink` from `h5py`).
            Defaults to False.
    """
    if not all(isinstance(d, str) for d in src_ids):
        msg = "data_ids should be a list containing strings."
        raise TypeError(msg)

    with h5py.File(f_dest_path, "w") as f_dest, h5py.File(f_src_path, "r") as f_src:
        for key in src_ids:
            if hardcopy:
                f_src.copy(f_src[key], f_dest)
            else:
                f_dest[key] = h5py.ExternalLink(f_src_path, "/" + key)
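# --- Usage sketch (illustrative; not part of deeprank2.dataset) ---
# Splitting one processed .HDF5 file into train/validation files by entry name,
# using external links (hardcopy=False) so the data is not duplicated on disk.
# File names and the 80/20 split are assumptions for the example.
#
#     with h5py.File("processed_data.hdf5", "r") as f:
#         entry_names = list(f.keys())
#     n_train = int(0.8 * len(entry_names))
#     save_hdf5_keys("processed_data.hdf5", entry_names[:n_train], "train.hdf5")
#     save_hdf5_keys("processed_data.hdf5", entry_names[n_train:], "valid.hdf5")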