utils
Utility functions for surrogate model handling
Functions
| check_parameter_change(params, prev_params) | Check the average change in model parameters. |
| fit_pytorch(model, X, y[, loss_fcn, verbose, max_iter]) | Fits a PyTorch model to the given input data. |
- obsidian.surrogates.utils.check_parameter_change(params: list[Parameter], prev_params: list[Parameter]) → float
Check the average change in model parameters.
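The page does not say how the "average change" is computed. The following is a minimal sketch that assumes it is the mean absolute element-wise difference between matching tensors in params and prev_params. The name `average_parameter_change` is hypothetical and is not part of obsidian.

```python
import torch
from torch import nn


def average_parameter_change(params, prev_params):
    """Hypothetical sketch: mean absolute element-wise change across all parameters.

    Assumes `params` and `prev_params` are aligned lists of tensors with matching shapes.
    """
    total_change = 0.0
    total_count = 0
    for p, p_prev in zip(params, prev_params):
        # Detach so the comparison never enters the autograd graph.
        diff = (p.detach() - p_prev.detach()).abs()
        total_change += diff.sum().item()
        total_count += diff.numel()
    # Guard against empty parameter lists.
    return total_change / total_count if total_count else 0.0


# Usage: snapshot the parameters, shift each one by 0.1, then measure the change.
model = nn.Linear(3, 1)
prev = [p.detach().clone() for p in model.parameters()]
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.1)
change = average_parameter_change(list(model.parameters()), prev)
print(f"{change:.4f}")  # 0.1000
```

Averaging over every element, rather than summing, keeps the value comparable across models with different numbers of parameters.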
- obsidian.surrogates.utils.fit_pytorch(model: Module, X: Tensor, y: Tensor, loss_fcn: Module | None = None, verbose: bool = False, max_iter: int = 5000) → None
Fits a PyTorch model to the given input data.
- Parameters:
  - model (Module) – The PyTorch model to be trained.
  - X (Tensor) – The input data.
  - y (Tensor) – The target data.
  - loss_fcn (Module, optional) – The loss function to be used. If not provided, the Mean Squared Error (MSE) loss function is used. Defaults to None.
  - verbose (bool, optional) – Whether to print the loss at each epoch. Defaults to False.
  - max_iter (int, optional) – Upper limit on training iterations. Defaults to 5000.
- Return type: None
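The documented behavior (MSE loss by default, optional per-epoch loss printing, and an iteration cap) is consistent with a conventional full-batch training loop. The sketch below is hypothetical: `fit_pytorch_sketch`, the Adam optimizer, and the `lr` argument are assumptions and do not describe obsidian's actual implementation.

```python
import torch
from torch import nn


def fit_pytorch_sketch(model, X, y, loss_fcn=None, verbose=False, max_iter=5000, lr=0.01):
    """Hypothetical full-batch training loop mirroring fit_pytorch's documented parameters.

    Assumptions (not taken from obsidian): Adam optimizer and the `lr` argument.
    """
    if loss_fcn is None:
        loss_fcn = nn.MSELoss()  # documented default: MSE loss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for epoch in range(max_iter):
        # Full batch, so one iteration here is one pass over the data.
        optimizer.zero_grad()
        loss = loss_fcn(model(X), y)
        loss.backward()
        optimizer.step()
        if verbose:
            print(f"epoch {epoch}: loss = {loss.item():.6f}")
    model.eval()
    return None  # the model is updated in place, matching the None return type


# Usage: fit a linear model to noise-free data y = 2x + 1.
torch.manual_seed(0)
X = torch.linspace(-1, 1, 50).unsqueeze(1)
y = 2 * X + 1
model = nn.Linear(1, 1)
fit_pytorch_sketch(model, X, y, max_iter=2000)
print(model.weight.item(), model.bias.item())
```

Passing a different `loss_fcn` (for example `nn.L1Loss()`) swaps the objective without changing the loop. Because the model is trained in place, the caller keeps using the same object afterwards.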