utils

Utility functions for surrogate model handling

Functions

check_parameter_change(params, prev_params)

Check the average change in model parameters.

fit_pytorch(model, X, y[, loss_fcn, ...])

Fits a PyTorch model to the given input data.

obsidian.surrogates.utils.check_parameter_change(params: list[Parameter], prev_params: list[Parameter]) → float

Check the average change in model parameters.
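The docstring does not say how the "average change" is computed. One plausible reading is the mean absolute element-wise difference across all parameter tensors. The sketch below is built on that assumption. The function name `average_parameter_change` is hypothetical, and this is not the obsidian implementation.

```python
from __future__ import annotations

import torch
from torch import nn


def average_parameter_change(
    params: list[nn.Parameter], prev_params: list[torch.Tensor]
) -> float:
    """Hypothetical sketch: mean absolute element-wise change over all parameters.

    Assumes params and prev_params are aligned one-to-one with matching shapes.
    """
    if len(params) != len(prev_params):
        raise ValueError("params and prev_params must have the same length")
    total_change = 0.0  # running sum of |current - previous| over every element
    total_count = 0     # running number of scalar elements compared
    with torch.no_grad():  # no autograd graph is needed for a diagnostic metric
        for p, prev in zip(params, prev_params):
            total_change += (p - prev).abs().sum().item()
            total_count += p.numel()
    # Guard against an empty parameter list to avoid division by zero
    return total_change / total_count if total_count else 0.0


# Usage: snapshot detached copies, perturb the model, then compare
model = nn.Linear(2, 1)
snapshot = [p.detach().clone() for p in model.parameters()]
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.5)  # shift every parameter by 0.5
print(average_parameter_change(list(model.parameters()), snapshot))  # approximately 0.5
```

A caller could take such a snapshot before an optimizer step and compare afterwards. For example, it could stop training once the change falls below a tolerance. Whether obsidian uses the value this way is not stated here.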

obsidian.surrogates.utils.fit_pytorch(model: Module, X: Tensor, y: Tensor, loss_fcn: Module | None = None, verbose: bool = False, max_iter: int = 5000) → None

Fits a PyTorch model to the given input data.

Parameters:

model (Module) – The PyTorch model to be trained.
X (Tensor) – The input data.
y (Tensor) – The target data.
loss_fcn (Module, optional) – The loss function to be used. If not provided, the Mean Squared Error (MSE) loss function is used. Defaults to None.
verbose (bool, optional) – Whether to print the loss at each epoch. Defaults to False.
max_iter (int, optional) – The maximum number of training iterations. Defaults to 5000.
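The following is a minimal sketch of the kind of full-batch training loop these parameters describe. It assumes the Adam optimizer with a fixed learning rate and treats each of the `max_iter` iterations as one epoch over all of `X`. The function name `fit_model_sketch` and its `lr` argument are hypothetical and do not come from obsidian. Only the MSE default for `loss_fcn` and the meaning of `verbose` are taken from the docstring.

```python
from __future__ import annotations

import torch
from torch import nn


def fit_model_sketch(
    model: nn.Module,
    X: torch.Tensor,
    y: torch.Tensor,
    loss_fcn: nn.Module | None = None,
    verbose: bool = False,
    max_iter: int = 5000,
    lr: float = 1e-2,  # hypothetical: learning rate for the assumed optimizer
) -> None:
    """Hypothetical full-batch training loop; not the obsidian implementation."""
    if loss_fcn is None:
        loss_fcn = nn.MSELoss()  # documented default: Mean Squared Error
    # Assumption: Adam over all model parameters with a constant learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(max_iter):
        optimizer.zero_grad()           # clear gradients from the previous step
        loss = loss_fcn(model(X), y)    # forward pass on the full batch
        loss.backward()                 # backpropagate
        optimizer.step()                # update parameters in place
        if verbose:
            print(f"epoch {epoch}: loss = {loss.item():.6f}")


# Usage: recover y = 2x + 1 with a single linear layer
torch.manual_seed(0)
X = torch.linspace(-1.0, 1.0, 50).unsqueeze(1)
y = 2.0 * X + 1.0
model = nn.Linear(1, 1)
fit_model_sketch(model, X, y, max_iter=3000)
```

Passing `loss_fcn=nn.L1Loss()` swaps in a different criterion. Leaving it as `None` falls back to `nn.MSELoss()`, which mirrors the documented default. Like `fit_pytorch`, the sketch returns `None` and trains `model` in place.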