fit_pytorch

obsidian.surrogates.utils.fit_pytorch(model: Module, X: Tensor, y: Tensor, loss_fcn: Module | None = None, verbose: bool = False, max_iter: int = 5000) → None

Fits a PyTorch model to the given input data.

Parameters:

model (Module) – The PyTorch model to be trained.

X (Tensor) – The input data.

y (Tensor) – The target data.

loss_fcn (Module, optional) – The loss function to be used. If not provided, the Mean Squared Error (MSE) loss function will be used. Defaults to None.

verbose (bool, optional) – Whether to print the loss at each epoch. Defaults to False.

max_iter (int, optional) – The maximum number of training iterations. Defaults to 5000.
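
Example:

The following is a minimal sketch of how a function with this signature might behave. It is not obsidian's implementation. The name fit_sketch is hypothetical, and the optimizer (Adam) and learning rate are assumptions, since the documentation above does not specify them. The sketch only mirrors what is documented: the MSE fallback when loss_fcn is None, optional per-epoch loss printing, an iteration cap, and a None return with the model updated in place.

```python
import torch
from torch import nn


def fit_sketch(model, X, y, loss_fcn=None, verbose=False, max_iter=5000):
    """Hypothetical stand-in mirroring the fit_pytorch signature (not obsidian's code)."""
    if loss_fcn is None:
        loss_fcn = nn.MSELoss()  # documented default when no loss is given
    # Assumption: Adam with lr=1e-2; the documentation does not name the optimizer.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    model.train()
    for epoch in range(max_iter):
        optimizer.zero_grad()
        loss = loss_fcn(model(X), y)
        loss.backward()
        optimizer.step()
        if verbose:
            print(f"epoch {epoch}: loss {loss.item():.6f}")
    # No return value: the model's parameters are updated in place.


torch.manual_seed(0)
X = torch.linspace(-1.0, 1.0, 64).unsqueeze(1)  # inputs, shape (64, 1)
y = 2.0 * X + 0.5                               # noiseless linear target
model = nn.Linear(1, 1)

loss_before = nn.MSELoss()(model(X), y).item()
result = fit_sketch(model, X, y, max_iter=500)
loss_after = nn.MSELoss()(model(X), y).item()
```

Assuming obsidian is installed, the real function would be called the same way, for example fit_pytorch(model, X, y, max_iter=500), and the trained model is then used directly because nothing is returned.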