fit_pytorch
obsidian.surrogates.utils.fit_pytorch(model: Module, X: Tensor, y: Tensor, loss_fcn: Module | None = None, verbose: bool = False, max_iter: int = 5000) → None
Fits a PyTorch model to the given input data.

Parameters:
model (Module) – The PyTorch model to be trained.
X (Tensor) – The input data.
y (Tensor) – The target data.
loss_fcn (Module | None, optional) – The loss function to be used. If not provided, the Mean Squared Error (MSE) loss function will be used. Defaults to None.
verbose (bool, optional) – Whether to print the loss at each epoch. Defaults to False.
max_iter (int, optional) – Defaults to 5000.
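
The following is a minimal sketch of a training loop with the same signature, included only to illustrate how the documented parameters might interact. It is not obsidian's implementation: the function name `fit_sketch`, the Adam optimizer, the learning rate, and the toy linear-regression data are all assumptions made for this example. Only the MSE fallback for `loss_fcn`, the per-epoch printing for `verbose`, and the `None` return come from the documentation above; treating `max_iter` as the number of optimization steps is also an assumption.

```python
import torch
from torch import nn


def fit_sketch(model, X, y, loss_fcn=None, verbose=False, max_iter=5000):
    """Hypothetical stand-in for fit_pytorch; NOT the library's code."""
    if loss_fcn is None:
        # Documented behaviour: fall back to MSE when no loss is given.
        loss_fcn = nn.MSELoss()

    # Assumption: the optimizer and learning rate are not documented.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    model.train()
    # Assumption: max_iter is interpreted as the number of training epochs.
    for epoch in range(max_iter):
        optimizer.zero_grad()
        loss = loss_fcn(model(X), y)
        loss.backward()
        optimizer.step()
        if verbose:
            # Documented behaviour: print the loss at each epoch.
            print(f"epoch {epoch}: loss = {loss.item():.6f}")
    model.eval()
    return None


# Toy data (assumed for illustration): y = 2x + 1 on 32 points.
torch.manual_seed(0)
X = torch.linspace(-1.0, 1.0, 32).unsqueeze(1)  # shape (32, 1)
y = 2.0 * X + 1.0

model = nn.Linear(1, 1)
loss_before = nn.functional.mse_loss(model(X), y).item()
result = fit_sketch(model, X, y, max_iter=200)
loss_after = nn.functional.mse_loss(model(X), y).item()
```

With obsidian installed, the equivalent call based on the signature above would presumably be `fit_pytorch(model, X, y)` after `from obsidian.surrogates.utils import fit_pytorch`. The function returns `None`, so the fit is visible only through the updated model weights.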