
Add support for MPS backend #90

@schroedk


The current class `TorchModel` has the following `__init__`:

class TorchModel(ABC, ToStringMixin):
    """
    sensAI abstraction for torch models, which supports one-line training, allows for convenient model application,
    has basic mechanisms for data scaling, and soundly handles persistence (via pickle).
    An instance wraps a torch.nn.Module, which is constructed on demand during training via the factory method
    createTorchModule.
    """
    log: logging.Logger = log.getChild(__qualname__)

    def __init__(self, cuda=True) -> None:
        self.cuda: bool = cuda
        self.module: Optional[torch.nn.Module] = None
        self.outputScaler: Optional[TensorScaler] = None
        self.inputScaler: Optional[TensorScaler] = None
        self.trainingInfo: Optional[TrainingInfo] = None
        self._gpu: Optional[int] = None
        self._normalisationCheckThreshold: Optional[int] = 5

and is responsible for moving the inputs of the torch model to the corresponding device (here):

if self._is_cuda_enabled():
    torch.cuda.set_device(self._gpu)
    inputs = [t.cuda() for t in inputs]

I would like to suggest adding support for different torch backends, in particular the MPS backend for Apple machines.

My first impression is that this could be a breaking change, so let's discuss it here.
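One possibly non-breaking option would be to keep the existing `cuda` flag but resolve the actual backend at runtime, preferring CUDA, then MPS, then CPU. Below is a minimal sketch of such a resolution helper; the name `resolve_device` and its parameters are hypothetical, not part of sensAI. In real code, the availability flags would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`, and `t.cuda()` would become `t.to(device)`:

```python
from typing import Optional


def resolve_device(preferred: Optional[str],
                   cuda_available: bool,
                   mps_available: bool) -> str:
    """Pick a torch device string, falling back to CPU.

    Hypothetical helper: in TorchModel, the availability flags would be
    supplied by torch.cuda.is_available() and
    torch.backends.mps.is_available().
    """
    if preferred is not None:
        # caller explicitly chose a backend (e.g. "cpu" for debugging)
        return preferred
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"
```

Input placement would then become backend-agnostic, e.g. `inputs = [t.to(device) for t in inputs]`, instead of the CUDA-specific `t.cuda()`.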

Metadata


Labels: enhancement (New feature or request)
