deeprank2 main modules

deeprank2.dataset

class deeprank2.dataset.DeeprankDataset(*args: Any, **kwargs: Any)[source]

Bases: Dataset

Parent class of GridDataset and GraphDataset.

This class inherits from torch_geometric.data.dataset.Dataset. More detailed information about the parameters can be found in GridDataset and GraphDataset.

len() int[source]

Gets the length of the dataset, either a GridDataset or a GraphDataset object.

Returns

Number of complexes in the dataset.

Return type

int

hdf5_to_pandas() pandas.DataFrame[source]

Loads feature data from the HDF5 files into a Pandas DataFrame stored in the df attribute of the class.

Returns

Pandas DataFrame containing the selected features as columns for all data points in the hdf5_path files.

Return type

pd.DataFrame

save_hist(features: str | list[str], fname: str = 'features_hist.png', bins: int | list[float] | str = 10, figsize: tuple = (15, 15), log: bool = False) None[source]

After a pd.DataFrame has been generated with the hdf5_to_pandas method, histograms of the features can be saved as an image.

Parameters
  • features – Features to be plotted.

  • fname – Output file name; can be a str, a path-like object, or a binary file-like object. Defaults to ‘features_hist.png’.

  • bins – If bins is an integer, it defines the number of equal-width bins in the range. If bins is a sequence, it defines the bin edges, including the left edge of the first bin and the right edge of the last bin; in this case, bins may be unequally spaced. All but the last (righthand-most) bin is half-open. If bins is a string, it is one of the binning strategies supported by numpy.histogram_bin_edges: ‘auto’, ‘fd’, ‘doane’, ‘scott’, ‘stone’, ‘rice’, ‘sturges’, or ‘sqrt’. Defaults to 10.

  • figsize – Size of the saved figure. Defaults to (15, 15).

  • log – Whether to apply log transformation to the data indicated by the features parameter. Defaults to False.
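
Example: a minimal sketch of plotting feature histograms from a processed dataset (the file path and feature names below are placeholders):

    from deeprank2.dataset import GraphDataset

    # Load a processed dataset; "data.hdf5" and the feature names are illustrative.
    dataset = GraphDataset(hdf5_path="data.hdf5", target="binary")
    dataset.hdf5_to_pandas()  # populates dataset.df
    dataset.save_hist(
        features=["res_type", "polarity"],  # columns of dataset.df to plot
        fname="features_hist.png",
        bins="auto",  # one of the numpy.histogram_bin_edges strategies
        log=True,     # log-transform the data before plotting
    )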

class deeprank2.dataset.GridDataset(*args: Any, **kwargs: Any)[source]

Bases: DeeprankDataset

Class to load .HDF5 file data into grids.

Parameters
  • hdf5_path – Path to .HDF5 file(s). For multiple .HDF5 files, insert the paths in a list. Defaults to None.

  • subset – list of keys from .HDF5 file to include. Defaults to None (meaning include all).

  • train_source – Data to inherit information from: the training dataset or a pre-trained model. If None, the current dataset is considered the training set. Otherwise, train_source needs to be a dataset of the same class or the path of a DeepRank2 pre-trained model. If set, the parameters features, target, target_transform, task, and classes will be inherited from train_source (see the sketch after this parameter list). Defaults to None.

  • features – Consider all pre-computed features (“all”) or some defined node features (provide a list, example: [“res_type”, “polarity”, “bsa”]). The complete list can be found in deeprank2.domain.gridstorage. Value will be ignored and inherited from train_source if train_source is assigned. Defaults to “all”.

  • target – Default options are irmsd, lrmsd, fnat, binary, capri_class, and dockq. It can also be a custom-defined target given to the Query class as input (see: deeprank2.query); in this case, the task parameter needs to be explicitly specified as well. Only numerical target variables are supported, not categorical. If your target is categorical, convert the classes into numerical class indices before defining the GridDataset instance. Value will be ignored and inherited from train_source if train_source is assigned. Defaults to None.

  • target_transform – Apply a log and then a sigmoid transformation to the target (for regression only). This puts the target value between 0 and 1, and can result in a more uniform target distribution and speed up the optimization. Value will be ignored and inherited from train_source if train_source is assigned. Defaults to False.

  • target_filter – Dictionary of type [target: cond] used to filter the molecules. Note that you can filter on a different target than the one selected as the dataset target. Defaults to None.

  • task – ‘regress’ for regression or ‘classif’ for classification. Required if target is not in [‘irmsd’, ‘lrmsd’, ‘fnat’, ‘binary’, ‘capri_class’, ‘dockq’]; otherwise this setting is ignored. Automatically set to ‘classif’ if the target is ‘binary’ or ‘capri_class’. Automatically set to ‘regress’ if the target is ‘irmsd’, ‘lrmsd’, ‘fnat’, or ‘dockq’. Value will be ignored and inherited from train_source if train_source is assigned. Defaults to None.

  • classes – Define the dataset target classes in classification mode. Value will be ignored and inherited from train_source if train_source is assigned. Defaults to None.

  • use_tqdm – Show progress bar. Defaults to True.

  • root – Root directory where the dataset should be saved. Defaults to “./”.

  • check_integrity – Whether to check the integrity of the hdf5 files. Defaults to True.
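
Example: a minimal instantiation sketch (file paths are placeholders):

    from deeprank2.dataset import GridDataset

    # Training set: defines features, target, task, and classes.
    dataset_train = GridDataset(
        hdf5_path=["train1.hdf5", "train2.hdf5"],
        target="binary",
    )

    # Test set: inherits features, target, task, and classes from dataset_train.
    dataset_test = GridDataset(
        hdf5_path="test.hdf5",
        train_source=dataset_train,
    )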

get(idx: int) torch_geometric.data.data.Data[source]

Gets one grid item from its unique index.

Parameters

idx – Index of the item, ranging from 0 to len(dataset).

Returns

Item with tensor x, tensor y (if present), and entry_names.

Return type

torch_geometric.data.data.Data

load_one_grid(hdf5_path: str, entry_name: str) torch_geometric.data.data.Data[source]

Loads one grid.

Parameters
  • hdf5_path – .HDF5 file name.

  • entry_name – Name of the entry.

Returns

Item with tensor x, tensor y (if present), and entry_names.

Return type

torch_geometric.data.data.Data

hdf5_to_pandas() pandas.DataFrame

Loads feature data from the HDF5 files into a Pandas DataFrame stored in the df attribute of the class.

Returns

Pandas DataFrame containing the selected features as columns for all data points in the hdf5_path files.

Return type

pd.DataFrame

len() int

Gets the length of the dataset, either a GridDataset or a GraphDataset object.

Returns

Number of complexes in the dataset.

Return type

int

save_hist(features: str | list[str], fname: str = 'features_hist.png', bins: int | list[float] | str = 10, figsize: tuple = (15, 15), log: bool = False) None

After a pd.DataFrame has been generated with the hdf5_to_pandas method, histograms of the features can be saved as an image.

Parameters
  • features – Features to be plotted.

  • fname – Output file name; can be a str, a path-like object, or a binary file-like object. Defaults to ‘features_hist.png’.

  • bins – If bins is an integer, it defines the number of equal-width bins in the range. If bins is a sequence, it defines the bin edges, including the left edge of the first bin and the right edge of the last bin; in this case, bins may be unequally spaced. All but the last (righthand-most) bin is half-open. If bins is a string, it is one of the binning strategies supported by numpy.histogram_bin_edges: ‘auto’, ‘fd’, ‘doane’, ‘scott’, ‘stone’, ‘rice’, ‘sturges’, or ‘sqrt’. Defaults to 10.

  • figsize – Size of the saved figure. Defaults to (15, 15).

  • log – Whether to apply log transformation to the data indicated by the features parameter. Defaults to False.

class deeprank2.dataset.GraphDataset(*args: Any, **kwargs: Any)[source]

Bases: DeeprankDataset

Class to load .HDF5 file data into graphs.

Parameters
  • hdf5_path – Path to .HDF5 file(s). For multiple .HDF5 files, insert the paths in a list. Defaults to None.

  • subset – list of keys from .HDF5 file to include. Defaults to None (meaning include all).

  • train_source – Data to inherit information from: the training dataset or a pre-trained model. If None, the current dataset is considered the training set. Otherwise, train_source needs to be a dataset of the same class or the path of a DeepRank2 pre-trained model. If set, the parameters node_features, edge_features, features_transform, target, target_transform, task, and classes will be inherited from train_source. If standardization was performed in the training dataset/step, the attributes means and devs will also be inherited from train_source and used to scale the validation/testing set. Defaults to None.

  • node_features – Consider all pre-computed node features (“all”) or some defined node features (provide a list, e.g.: [“res_type”, “polarity”, “bsa”]). The complete list can be found in deeprank2.domain.nodestorage. Value will be ignored and inherited from train_source if train_source is assigned. Defaults to “all”.

  • edge_features – Consider all pre-computed edge features (“all”) or some defined edge features (provide a list, e.g.: [“dist”, “coulomb”]). The complete list can be found in deeprank2.domain.edgestorage. Value will be ignored and inherited from train_source if train_source is assigned. Defaults to “all”.

  • features_transform – Dictionary indicating the transformations to apply to each feature, where the transformations are lambda functions and/or standardization flags. Example: features_transform = {‘bsa’: {‘transform’: lambda t: np.log(t+1), ‘standardize’: True}} for the feature bsa. An all key can be set in the dictionary to apply the same transform and standardize to all the features. Example: features_transform = {‘all’: {‘transform’: lambda t: np.log(t+1), ‘standardize’: True}}. If both all and individual feature names are present, the latter take priority over what is indicated under all (see the sketch after this parameter list). Value will be ignored and inherited from train_source if train_source is assigned. Defaults to None.

  • clustering_method – “mcl” for the Markov cluster algorithm (see https://micans.org/mcl/), or “louvain” for the Louvain method (see https://en.wikipedia.org/wiki/Louvain_method). In both options, for each graph the chosen method first finds communities (clusters) of nodes and generates a torch tensor whose elements represent the cluster to which each node belongs. Each tensor is then saved in the .HDF5 file as a Dataset called “depth_0”. Then, all cluster members belonging to the same community are pooled into a single node, and the resulting tensor is used to find communities among the pooled clusters. The latter tensor is saved into the .HDF5 file as a Dataset called “depth_1”. Both “depth_0” and “depth_1” Datasets belong to the “cluster” Group. They are saved in the .HDF5 file to make them available to networks that make use of clustering methods. Defaults to None.

  • target – Default options are irmsd, lrmsd, fnat, binary, capri_class, and dockq. It can also be a custom-defined target given to the Query class as input (see: deeprank2.query); in this case, the task parameter needs to be explicitly specified as well. Only numerical target variables are supported, not categorical. If your target is categorical, convert the classes into numerical class indices before defining the GraphDataset instance. Value will be ignored and inherited from train_source if train_source is assigned. Defaults to None.

  • target_transform – Apply a log and then a sigmoid transformation to the target (for regression only). This puts the target value between 0 and 1, and can result in a more uniform target distribution and speed up the optimization. Value will be ignored and inherited from train_source if train_source is assigned. Defaults to False.

  • target_filter – Dictionary of type [target: cond] used to filter the molecules. Note that you can filter on a different target than the one selected as the dataset target. Defaults to None.

  • task – ‘regress’ for regression or ‘classif’ for classification. Required if target is not in [‘irmsd’, ‘lrmsd’, ‘fnat’, ‘binary’, ‘capri_class’, ‘dockq’]; otherwise this setting is ignored. Automatically set to ‘classif’ if the target is ‘binary’ or ‘capri_class’. Automatically set to ‘regress’ if the target is ‘irmsd’, ‘lrmsd’, ‘fnat’, or ‘dockq’. Value will be ignored and inherited from train_source if train_source is assigned. Defaults to None.

  • classes – Define the dataset target classes in classification mode. Value will be ignored and inherited from train_source if train_source is assigned. Defaults to None.

  • use_tqdm – Show progress bar. Defaults to True.

  • root – Root directory where the dataset should be saved. Defaults to “./”.

  • check_integrity – Whether to check the integrity of the hdf5 files. Defaults to True.
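
Example: a sketch of a GraphDataset with a features_transform dictionary, plus a validation set inheriting from the training set (file paths are placeholders):

    import numpy as np

    from deeprank2.dataset import GraphDataset

    # Standardize all features, and additionally log-transform 'bsa';
    # per-feature entries take priority over the 'all' entry.
    features_transform = {
        "all": {"standardize": True},
        "bsa": {"transform": lambda t: np.log(t + 1), "standardize": True},
    }

    dataset_train = GraphDataset(
        hdf5_path="train.hdf5",
        node_features=["res_type", "polarity", "bsa"],
        features_transform=features_transform,
        target="binary",
    )

    # Inherits features, transforms, means, and devs from dataset_train.
    dataset_val = GraphDataset(
        hdf5_path="valid.hdf5",
        train_source=dataset_train,
    )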

get(idx: int) torch_geometric.data.data.Data[source]

Gets one graph item from its unique index.

Parameters

idx – Index of the item, ranging from 0 to len(dataset).

Returns

Item with tensors x, y (if present), edge_index, edge_attr, pos, and entry_names.

Return type

torch_geometric.data.data.Data

hdf5_to_pandas() pandas.DataFrame

Loads feature data from the HDF5 files into a Pandas DataFrame stored in the df attribute of the class.

Returns

Pandas DataFrame containing the selected features as columns for all data points in the hdf5_path files.

Return type

pd.DataFrame

len() int

Gets the length of the dataset, either a GridDataset or a GraphDataset object.

Returns

Number of complexes in the dataset.

Return type

int

load_one_graph(fname: str, entry_name: str) torch_geometric.data.data.Data[source]

Loads one graph.

Parameters
  • fname – .HDF5 file name.

  • entry_name – Name of the entry.

Returns

Item with tensors x, y (if present), edge_index, edge_attr, pos, and entry_names.

Return type

torch_geometric.data.data.Data

save_hist(features: str | list[str], fname: str = 'features_hist.png', bins: int | list[float] | str = 10, figsize: tuple = (15, 15), log: bool = False) None

After a pd.DataFrame has been generated with the hdf5_to_pandas method, histograms of the features can be saved as an image.

Parameters
  • features – Features to be plotted.

  • fname – Output file name; can be a str, a path-like object, or a binary file-like object. Defaults to ‘features_hist.png’.

  • bins – If bins is an integer, it defines the number of equal-width bins in the range. If bins is a sequence, it defines the bin edges, including the left edge of the first bin and the right edge of the last bin; in this case, bins may be unequally spaced. All but the last (righthand-most) bin is half-open. If bins is a string, it is one of the binning strategies supported by numpy.histogram_bin_edges: ‘auto’, ‘fd’, ‘doane’, ‘scott’, ‘stone’, ‘rice’, ‘sturges’, or ‘sqrt’. Defaults to 10.

  • figsize – Size of the saved figure. Defaults to (15, 15).

  • log – Whether to apply log transformation to the data indicated by the features parameter. Defaults to False.

deeprank2.dataset.save_hdf5_keys(f_src_path: str, src_ids: list[str], f_dest_path: str, hardcopy: bool = False) None[source]

Save references to keys in src_ids in a new .HDF5 file.

Parameters
  • f_src_path – The path to the .HDF5 file containing the keys.

  • src_ids – Keys to be saved in the new .HDF5 file. It should be a list containing at least one key.

  • f_dest_path – The path to the new .HDF5 file.

  • hardcopy – If False, the new file contains only references (external links, see the ExternalLink class from h5py) to the original .HDF5 file. If True, the new file contains a copy of the objects specified in src_ids (see the HardLink class from h5py). Defaults to False.
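
Example: a sketch of splitting the entries of an existing file into a new one (file paths are placeholders):

    import h5py

    from deeprank2.dataset import save_hdf5_keys

    with h5py.File("all_data.hdf5", "r") as f:
        ids = list(f.keys())

    # The new file contains external links to the first 100 entries;
    # set hardcopy=True to copy the underlying data instead.
    save_hdf5_keys("all_data.hdf5", ids[:100], "train_split.hdf5", hardcopy=False)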

deeprank2.query

class deeprank2.query.Query(*, pdb_path: str, resolution: ~typing.Literal['residue', 'atom'], chain_ids: list[str] | str, pssm_paths: dict[str, str] = <factory>, targets: dict[str, float] = <factory>, influence_radius: float | None = None, max_edge_length: float | None = None, suppress_pssm_errors: bool = False)[source]

Bases: object

Parent class of SingleResidueVariantQuery and ProteinProteinInterfaceQuery.

More detailed information about the parameters can be found in SingleResidueVariantQuery and ProteinProteinInterfaceQuery.

pdb_path: str
resolution: Literal['residue', 'atom']
chain_ids: list[str] | str
pssm_paths: dict[str, str]
targets: dict[str, float]
influence_radius: float | None = None
max_edge_length: float | None = None
suppress_pssm_errors: bool = False
property model_id: str

The ID of the model, usually a .PDB accession code.

build(feature_modules: list[str | module]) Graph[source]

Builds the graph from the .PDB structure.

Parameters

feature_modules – the feature modules used to build the graph. These must be filenames existing inside the deeprank2.features subpackage.

Returns

The resulting Graph object with all the features and targets.

Return type

Graph

get_query_id() str[source]

Returns the string representing the complete query ID.

class deeprank2.query.SingleResidueVariantQuery(*, pdb_path: str, resolution: Literal['residue', 'atom'], chain_ids: list[str] | str, pssm_paths: dict[str, str] = <factory>, targets: dict[str, float] = <factory>, influence_radius: float | None = None, max_edge_length: float | None = None, suppress_pssm_errors: bool = False, variant_residue_number: int, insertion_code: str | None, wildtype_amino_acid: AminoAcid, variant_amino_acid: AminoAcid)[source]

Bases: Query

A query that builds a single residue variant graph.

Parameters
  • pdb_path – the path to the PDB file to query.

  • resolution – sets whether each node is a residue or atom.

  • chain_ids – the chain identifier of the variant residue (generally a single capital letter). Note that this does not limit the structure to residues from this chain.

  • pssm_paths – the name of the chain(s) (key) and path to the pssm file(s) (value).

  • targets – Name(s) (key) and target value(s) (value) associated with this query.

  • influence_radius – all residues within this radius (in Å) from the variant residue will be included in the graph, irrespective of the chain they are on.

  • max_edge_length – the maximum distance between two nodes to generate an edge connecting them.

  • suppress_pssm_errors – Whether to suppress the error raised if the .pssm files do not match the .pdb files. If True, a warning is issued instead.

  • variant_residue_number – the residue number of the variant residue.

  • insertion_code – the insertion code of the variant residue.

  • wildtype_amino_acid – the amino acid at the above position in the wildtype protein.

  • variant_amino_acid – the amino acid at the above position in the variant protein.

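Example: a sketch of a residue-level variant query (the PDB path, residue number, and amino acids are placeholders; the aminoacidlist import path is an assumption):

    from deeprank2.domain.aminoacidlist import alanine, glycine  # assumed path
    from deeprank2.query import SingleResidueVariantQuery

    query = SingleResidueVariantQuery(
        pdb_path="1abc.pdb",
        resolution="residue",
        chain_ids="A",
        variant_residue_number=42,
        insertion_code=None,
        wildtype_amino_acid=alanine,
        variant_amino_acid=glycine,
        targets={"binary": 0},
        influence_radius=10.0,
        max_edge_length=4.5,
    )
    graph = query.build(feature_modules=["components", "contact"])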

variant_residue_number: int
insertion_code: str | None
wildtype_amino_acid: AminoAcid
variant_amino_acid: AminoAcid
property residue_id: str

String representation of the residue number and insertion code.

get_query_id() str[source]

Returns the string representing the complete query ID.

build(feature_modules: list[str | module]) Graph

Builds the graph from the .PDB structure.

Parameters

feature_modules – the feature modules used to build the graph. These must be filenames existing inside the deeprank2.features subpackage.

Returns

The resulting Graph object with all the features and targets.

Return type

Graph

influence_radius: float | None = None
max_edge_length: float | None = None
property model_id: str

The ID of the model, usually a .PDB accession code.

suppress_pssm_errors: bool = False
pdb_path: str
resolution: Literal['residue', 'atom']
chain_ids: list[str] | str
pssm_paths: dict[str, str]
targets: dict[str, float]
class deeprank2.query.ProteinProteinInterfaceQuery(*, pdb_path: str, resolution: ~typing.Literal['residue', 'atom'], chain_ids: list[str] | str, pssm_paths: dict[str, str] = <factory>, targets: dict[str, float] = <factory>, influence_radius: float | None = None, max_edge_length: float | None = None, suppress_pssm_errors: bool = False)[source]

Bases: Query

A query that builds a protein-protein interface graph.

Parameters
  • pdb_path – the path to the PDB file to query.

  • resolution – sets whether each node is a residue or atom.

  • chain_ids – the chain identifiers of the interacting interfaces (generally a single capital letter each). Note that this does not limit the structure to residues from these chains.

  • pssm_paths – the name of the chain(s) (key) and path to the pssm file(s) (value).

  • targets – Name(s) (key) and target value(s) (value) associated with this query.

  • influence_radius – all residues within this radius from the interacting interface will be included in the graph, irrespective of the chain they are on.

  • max_edge_length – the maximum distance between two nodes to generate an edge connecting them.

  • suppress_pssm_errors – Whether to suppress the error raised if the .pssm files do not match the .pdb files. If True, a warning is issued instead.
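
Example: a sketch of a residue-level interface query (the PDB path and target value are placeholders):

    from deeprank2.query import ProteinProteinInterfaceQuery

    query = ProteinProteinInterfaceQuery(
        pdb_path="1abc.pdb",
        resolution="residue",
        chain_ids=["A", "B"],  # the two interacting chains
        targets={"binary": 1},
        influence_radius=10.0,
        max_edge_length=4.5,
    )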

get_query_id() str[source]

Returns the string representing the complete query ID.

build(feature_modules: list[str | module]) Graph

Builds the graph from the .PDB structure.

Parameters

feature_modules – the feature modules used to build the graph. These must be filenames existing inside the deeprank2.features subpackage.

Returns

The resulting Graph object with all the features and targets.

Return type

Graph

influence_radius: float | None = None
max_edge_length: float | None = None
property model_id: str

The ID of the model, usually a .PDB accession code.

suppress_pssm_errors: bool = False
pdb_path: str
resolution: Literal['residue', 'atom']
chain_ids: list[str] | str
pssm_paths: dict[str, str]
targets: dict[str, float]
class deeprank2.query.QueryCollection[source]

Bases: object

Represents the collection of data queries that will be processed.

The class attributes are set either while adding queries to the collection (_queries and _ids_count), or when processing the collection (other attributes).

_queries

The Query objects in the collection.

Type

list[Query]

_ids_count

The original query_id and the repeat number for this id. This is used to rename the query_id to ensure that there are no duplicate ids.

Type

dict[str, int]

_prefix, _cpu_count, _grid_settings, etc.

See docstring for QueryCollection.process.

Notes

Queries can be saved as a dictionary to easily navigate through their data, using QueryCollection.export_dict().

add(query: Query, verbose: bool = False, warn_duplicate: bool = True) None[source]

Add a new query to the collection.

Parameters
  • query – The Query to add to the collection.

  • verbose – Whether to log the query ID when it is added. Defaults to False.

  • warn_duplicate – Log a warning before renaming if a duplicate query is identified. Defaults to True.

export_dict(dataset_path: str) None[source]

Exports the collection of all queries to a dictionary file.

Parameters

dataset_path – The path where to save the list of queries.

property queries: list[deeprank2.query.Query]

The list of queries added to the collection.

process(prefix: str = 'processed-queries', feature_modules: list[module | str] | module | str | None = None, cpu_count: int | None = None, combine_output: bool = True, grid_settings: deeprank2.utils.grid.GridSettings | None = None, grid_map_method: deeprank2.utils.grid.MapMethod | None = None, grid_augmentation_count: int = 0, log_error_traceback: bool = False) list[str][source]

Render queries into graphs (and optionally grids).

Parameters
  • prefix – Prefix for naming the output files. Defaults to “processed-queries”.

  • feature_modules – Feature module or list of feature modules used to generate features (given as a string or as an imported module). Each module must implement the add_features() function, and all feature modules must exist inside the deeprank2.features subpackage. If set to ‘all’, all available modules in deeprank2.features are used to generate the features. Defaults to the two primary feature modules deeprank2.features.components and deeprank2.features.contact.

  • cpu_count – The number of processes to be run in parallel (i.e. number of CPUs used), capped by the number of CPUs available to the system. Defaults to None, which uses all available CPU cores.

  • combine_output – If True (default): all processes are combined into a single HDF5 file. If False: separate HDF5 files are created for each process (i.e. for each CPU used).

  • grid_settings – If valid together with grid_map_method, the grid data will be stored as well. Defaults to None.

  • grid_map_method – If valid together with grid_settings, the grid data will be stored as well. Defaults to None.

  • grid_augmentation_count – Number of grid data augmentations (must be >= 0). Defaults to 0.

  • log_error_traceback – If True, logs the full error traceback if a query fails; otherwise, only the error message is logged. Defaults to False.

Returns

The list of paths of the generated HDF5 files.
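
Example: a sketch of building and processing a collection (PDB paths are placeholders):

    from deeprank2.query import ProteinProteinInterfaceQuery, QueryCollection

    collection = QueryCollection()
    for pdb in ["model1.pdb", "model2.pdb"]:
        collection.add(
            ProteinProteinInterfaceQuery(
                pdb_path=pdb,
                resolution="residue",
                chain_ids=["A", "B"],
                targets={"binary": 0},
            ),
        )

    # Render the queries into graphs, combined into a single HDF5 file.
    output_paths = collection.process(
        prefix="processed-queries",
        feature_modules=["components", "contact"],
        cpu_count=4,
    )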

deeprank2.trainer

class deeprank2.trainer.Trainer(neuralnet: nn.Module | None = None, dataset_train: GraphDataset | GridDataset | None = None, dataset_val: GraphDataset | GridDataset | None = None, dataset_test: GraphDataset | GridDataset | None = None, val_size: float | int | None = None, test_size: float | int | None = None, class_weights: bool = False, pretrained_model: str | None = None, cuda: bool = False, ngpu: int = 0, output_exporters: list[OutputExporter] | None = None)[source]

Bases: object

Class from which the network is trained, evaluated and tested.

Parameters
  • neuralnet – Neural network class (e.g. GINet, Foutnet). It should subclass torch.nn.Module, and it should not be specific to regression or classification in terms of output shape (the Trainer class takes care of formatting the output shape according to the task). In particular, in classification tasks, softmax should not be used as the last activation function. Defaults to None.

  • dataset_train – Training set used during training. Can’t be None if pretrained_model is also None. Defaults to None.

  • dataset_val – Evaluation set used during training. If None, the training set will be split randomly into a training set and a validation set during training, using the val_size parameter. Defaults to None.

  • dataset_test – Independent evaluation set. Defaults to None.

  • val_size – Fraction of dataset (if float) or number of datapoints (if int) to use for validation. Only used if dataset_val is not specified. Can be set to 0 if no validation set is needed. Defaults to None (in _divide_dataset function).

  • test_size – Fraction of dataset (if float) or number of datapoints (if int) to use for test dataset. Only used if dataset_test is not specified. Can be set to 0 if no test set is needed. Defaults to None.

  • class_weights – Assign class weights based on the dataset content. Defaults to False.

  • pretrained_model – Path to pre-trained model. Defaults to None.

  • cuda – Whether to use CUDA. Defaults to False.

  • ngpu – Number of GPUs to be used. Defaults to 0.

  • output_exporters – The output exporters to use for saving/exploring/plotting predictions/targets/losses over the epochs. If None, defaults to HDF5OutputExporter, which saves all the results in an .HDF5 file stored in the ./output directory. Defaults to None.
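
Example: a minimal instantiation sketch (file paths are placeholders; NaiveNetwork and its module path are an assumption, and any torch.nn.Module subclass with a task-agnostic output shape works):

    from deeprank2.dataset import GraphDataset
    from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork  # assumed path
    from deeprank2.trainer import Trainer

    dataset_train = GraphDataset(hdf5_path="train.hdf5", target="binary")
    dataset_val = GraphDataset(hdf5_path="valid.hdf5", train_source=dataset_train)

    trainer = Trainer(
        neuralnet=NaiveNetwork,  # the class itself, not an instance
        dataset_train=dataset_train,
        dataset_val=dataset_val,
    )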

configure_optimizers(optimizer: torch.optim = None, lr: float = 0.001, weight_decay: float = 1e-05) None[source]

Configure optimizer and its main parameters.

Parameters
  • optimizer – PyTorch optimizer object. If None, defaults to torch.optim.Adam. Defaults to None.

  • lr – Learning rate. Defaults to 0.001.

  • weight_decay – Weight decay (L2 penalty). This is fundamental for GNNs, otherwise, parameters can become too big and the gradient may explode. Defaults to 1e-05.

set_lossfunction(lossfunction: nn.modules.loss._Loss | None = None, override_invalid: bool = False) None[source]

Set the loss function.

Parameters
  • lossfunction – Make sure to use a loss function that is appropriate for your task (classification or regression). All loss functions from torch.nn.modules.loss are listed as belonging to either category (or to neither) and an exception is raised if an invalid loss function is chosen for the set task. Default for regression: MSELoss. Default for classification: CrossEntropyLoss.

  • override_invalid – If True, loss functions that are considered invalid for the task no longer automatically raise an exception. Defaults to False.

train(nepoch: int = 1, batch_size: int = 32, shuffle: bool = True, earlystop_patience: int | None = None, earlystop_maxgap: float | None = None, min_epoch: int = 10, validate: bool = False, num_workers: int = 0, best_model: bool = True, filename: str | None = 'model.pth.tar') None[source]

Performs the training of the model.

Parameters
  • nepoch – Maximum number of epochs to run. Defaults to 1.

  • batch_size – Sets the size of the batch. Defaults to 32.

  • shuffle – Whether to shuffle the data in the training dataloaders (train set and validation set). Defaults to True.

  • earlystop_patience – Training ends if the model has run for this number of epochs without improving the validation loss. Defaults to None.

  • earlystop_maxgap – Training ends if the difference between validation and training loss exceeds this value. Defaults to None.

  • min_epoch – Minimum epoch to be reached before looking at maxgap. Defaults to 10.

  • validate – Perform validation on independent data set (requires a validation data set). Defaults to False.

  • num_workers – How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. Defaults to 0.

  • best_model – If True (default), the best model (in terms of validation loss) is selected for later testing or saving. If False, the last model tried is selected.

  • filename – Name of the file in which to save the selected model. If not None, the model is saved to filename; if None, the model is not saved. Defaults to ‘model.pth.tar’.

test(batch_size: int = 32, num_workers: int = 0) None[source]

Performs the testing of the model.

Parameters
  • batch_size – Sets the size of the batch. Defaults to 32.

  • num_workers – How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. Defaults to 0.
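
Example: a sketch of a full training and testing run, continuing from the Trainer instantiated above (hyperparameter values are illustrative; test() assumes a dataset_test was passed to the Trainer):

    import torch

    trainer.configure_optimizers(torch.optim.Adam, lr=0.001, weight_decay=1e-05)
    trainer.set_lossfunction(torch.nn.CrossEntropyLoss())  # classification task

    trainer.train(
        nepoch=50,
        batch_size=64,
        validate=True,             # evaluate on the validation set each epoch
        earlystop_patience=5,      # stop after 5 epochs without improvement
        filename="model.pth.tar",  # save the best model here
    )
    trainer.test(batch_size=64)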