seconohe.torch

  1# -*- coding: utf-8 -*-
  2# Copyright (c) 2025 Salvador E. Tropea
  3# Copyright (c) 2025 Instituto Nacional de Tecnología Industrial
  4# License: GPL-3.0
  5# Project: SeCoNoHe
  6# PyTorch helpers
  7import contextlib  # For context manager
  8import logging
  9import torch
 10from typing import Optional, Iterator, cast
 11try:
 12    import comfy.model_management as mm
 13    with_comfy = True
 14except Exception:
 15    with_comfy = False
 16
 17
 18def get_torch_device_options() -> tuple[list[str], str]:
 19    """
 20    Detects available PyTorch devices and returns a list and a suitable default.
 21
 22    Scans for CPU, CUDA devices, and MPS (for Apple Silicon), providing a list
 23    of device strings (e.g., 'cpu', 'cuda', 'cuda:0') and a recommended default.
 24
 25    :return: A tuple containing the list of available device strings and the
 26             recommended default device string.
 27    :rtype: tuple[list[str], str]
 28    """
 29    # We always have CPU
 30    default = "cpu"
 31    options = [default]
 32    # Do we have CUDA?
 33    if torch.cuda.is_available():
 34        default = "cuda"
 35        options.append(default)
 36        for i in range(torch.cuda.device_count()):
 37            options.append(f"cuda:{i}")  # Specific CUDA devices
 38    # Is this a Mac?
 39    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
 40        options.append("mps")
 41        if default == "cpu":
 42            default = "mps"
 43    return options, default
 44
 45
 46def get_offload_device() -> torch.device:
 47    """
 48    Gets the appropriate device for offloading models.
 49
 50    Uses `comfy.model_management.unet_offload_device()` if available in a ComfyUI
 51    environment, otherwise defaults to the CPU.
 52
 53    :return: The torch.device object to use for offloading.
 54    :rtype: torch.device
 55    """
 56    if with_comfy:
 57        return cast(torch.device, mm.unet_offload_device())
 58    return torch.device("cpu")
 59
 60
 61def get_canonical_device(device: str | torch.device) -> torch.device:
 62    """
 63    Converts a device string or object into a canonical torch.device object.
 64
 65    Ensures that a device string like 'cuda' is converted to its fully
 66    indexed form, e.g., 'cuda:0', by checking the current default device.
 67
 68    :param device: The device identifier to canonicalize.
 69    :type device: str | torch.device
 70    :return: A torch.device object with an explicit index if applicable.
 71    :rtype: torch.device
 72    """
 73    if not isinstance(device, torch.device):
 74        device = torch.device(device)
 75
 76    # If it's a CUDA device and no index is specified, get the default one.
 77    if device.type == 'cuda' and device.index is None:
 78        # NOTE: This adds a dependency on torch.cuda.current_device()
 79        # The first solution is often better as it doesn't need this.
 80        return torch.device(f'cuda:{torch.cuda.current_device()}')
 81    return device
 82
 83
 84# ##################################################################################
 85# # Helper for inference (Target device, offload, eval, no_grad and cuDNN Benchmark)
 86# ##################################################################################
 87
 88@contextlib.contextmanager
 89def model_to_target(logger: logging.Logger, model: torch.nn.Module) -> Iterator[None]:
 90    """
 91    A context manager for safe model inference.
 92
 93    Handles device placement, inference state (`eval()`, `no_grad()`),
 94    optional cuDNN benchmark settings, and safe offloading in a `finally` block
 95    to ensure resources are managed even if errors occur.
 96
 97    The model object is expected to have two optional custom attributes:
 98    - ``target_device`` (torch.device): The device to run inference on.
 99    - ``cudnn_benchmark_setting`` (bool): The desired cuDNN benchmark state.
100
101    Example Usage::
102
103        with model_to_target(logger, my_model):
104            # Your inference code here
105            output = my_model(input_tensor)
106
107    :param logger: A logger instance for logging state changes.
108    :type logger: logging.Logger
109    :param model: The torch.nn.Module to manage.
110    :type model: torch.nn.Module
111    :yield: None
112    """
113    if not isinstance(model, torch.nn.Module):
114        with torch.no_grad():
115            yield  # The code inside the 'with' statement runs here
116        return
117
118    # --- EXPLICIT TYPE ANNOTATIONS ---
119    # Tell mypy the expected types for these variables.
120    target_device: torch.device
121    original_device: torch.device
122    original_cudnn_benchmark_state: Optional[bool] = None  # Default is to keep the current setting
123
124    # 1. Determine target device from the model object
125    # Use a hasattr check for robustness, and type hinting
126    if hasattr(model, 'target_device') and isinstance(model.target_device, torch.device):
127        # Mypy now understands that model.target_device exists and is a torch.device
128        target_device = model.target_device
129    else:
130        logger.warning("model_to_target: 'target_device' attribute not found or is not a torch.device on the model. "
131                       "Defaulting to model's current device.")
132        # Ensure model has parameters before calling next()
133        try:
134            target_device = next(model.parameters()).device
135        except StopIteration:
136            logger.warning("model_to_target: Model has no parameters. Cannot determine device. Assuming CPU.")
137            target_device = torch.device("cpu")
138
139    # 2. Get CUDNN benchmark setting from the model object (optional)
140    # Use hasattr as this is an optional setting that not all models might have.
141    cudnn_benchmark_enabled: Optional[bool] = None
142    if hasattr(model, 'cudnn_benchmark_setting'):
143        setting = model.cudnn_benchmark_setting
144        if isinstance(setting, bool):
145            cudnn_benchmark_enabled = setting
146        else:
147            logger.warning(f"model.cudnn_benchmark_setting was not a bool, but {type(setting)}. Ignoring.")
148
149    try:
150        # Get original_device from model after ensuring it has parameters
151        original_device = next(model.parameters()).device
152    except StopIteration:
153        original_device = torch.device("cpu")  # Match fallback from above
154
155    is_cuda_target = target_device.type == 'cuda'
156
157    try:
158        # 3. Manage cuDNN benchmark state
159        if (cudnn_benchmark_enabled is not None and is_cuda_target and hasattr(torch.backends, 'cudnn') and
160           torch.backends.cudnn.is_available()):
161            if torch.backends.cudnn.benchmark != cudnn_benchmark_enabled:
162                original_cudnn_benchmark_state = torch.backends.cudnn.benchmark
163                torch.backends.cudnn.benchmark = cudnn_benchmark_enabled
164                logger.debug(f"Temporarily set cuDNN benchmark to {torch.backends.cudnn.benchmark}")
165
166        # 4. Move model to target device if not already there
167        if original_device != target_device:
168            logger.debug(f"Moving model from `{original_device}` to target device `{target_device}` for inference.")
169            model.to(target_device)
170
171        # 5. Set to eval mode and disable gradients for the operation
172        model.eval()
173        with torch.no_grad():
174            yield  # The code inside the 'with' statement runs here
175
176    finally:
177        # 6. Restore original cuDNN benchmark state
178        if original_cudnn_benchmark_state is not None:
179            # This check is sufficient because it will only be not None if we set it inside the try block
180            torch.backends.cudnn.benchmark = original_cudnn_benchmark_state
181            logger.debug(f"Restored cuDNN benchmark to {original_cudnn_benchmark_state}")
182
183        # 7. Offload model back to CPU
184        if with_comfy:
185            offload_device = get_offload_device()
186            try:
187                current_device_after_yield = next(model.parameters()).device
188                if current_device_after_yield != offload_device:
189                    logger.debug(f"Offloading model from `{current_device_after_yield}` to offload device `{offload_device}`.")
190                    model.to(offload_device)
191                    # Clear cache if we were on a CUDA device
192                    if 'cuda' in str(current_device_after_yield):
193                        torch.cuda.empty_cache()
194            except StopIteration:
195                pass  # Model with no parameters doesn't need offloading.
def get_torch_device_options() -> tuple[list[str], str]:
    """
    Lists the PyTorch devices usable on this machine and picks a default.

    The CPU is always listed first. When CUDA is available, the generic
    'cuda' entry is added followed by one 'cuda:<index>' entry per visible
    GPU. When the MPS backend (Apple Silicon) is available and built, 'mps'
    is appended.

    Default selection order: 'cuda' if present, else 'mps' if present,
    else 'cpu'.

    :return: A tuple containing the list of available device strings and the
             recommended default device string.
    :rtype: tuple[list[str], str]
    """
    devices: list[str] = ["cpu"]
    preferred = "cpu"

    if torch.cuda.is_available():
        preferred = "cuda"
        devices.append("cuda")
        # One explicit entry per GPU so users can pin a specific card
        devices.extend(f"cuda:{idx}" for idx in range(torch.cuda.device_count()))

    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        devices.append("mps")
        # CUDA keeps priority; MPS only wins over the CPU fallback
        preferred = "mps" if preferred == "cpu" else preferred

    return devices, preferred

Detects available PyTorch devices and returns a list and a suitable default.

Scans for CPU, CUDA devices, and MPS (for Apple Silicon), providing a list of device strings (e.g., 'cpu', 'cuda', 'cuda:0') and a recommended default.

Returns

A tuple containing the list of available device strings and the recommended default device string.

def get_offload_device() -> torch.device:
    """
    Returns the device models should be moved to when they are not in use.

    Outside ComfyUI there is no model manager to ask, so the CPU is used.
    Inside ComfyUI the decision is delegated to
    `comfy.model_management.unet_offload_device()`.

    :return: The torch.device object to use for offloading.
    :rtype: torch.device
    """
    if not with_comfy:
        return torch.device("cpu")
    # ComfyUI owns the memory policy; the cast only informs the type checker
    comfy_offload = mm.unet_offload_device()
    return cast(torch.device, comfy_offload)

Gets the appropriate device for offloading models.

Uses comfy.model_management.unet_offload_device() if available in a ComfyUI environment, otherwise defaults to the CPU.

Returns

The torch.device object to use for offloading.

def get_canonical_device(device: str | torch.device) -> torch.device:
62def get_canonical_device(device: str | torch.device) -> torch.device:
63    """
64    Converts a device string or object into a canonical torch.device object.
65
66    Ensures that a device string like 'cuda' is converted to its fully
67    indexed form, e.g., 'cuda:0', by checking the current default device.
68
69    :param device: The device identifier to canonicalize.
70    :type device: str | torch.device
71    :return: A torch.device object with an explicit index if applicable.
72    :rtype: torch.device
73    """
74    if not isinstance(device, torch.device):
75        device = torch.device(device)
76
77    # If it's a CUDA device and no index is specified, get the default one.
78    if device.type == 'cuda' and device.index is None:
79        # NOTE: This adds a dependency on torch.cuda.current_device()
80        # The first solution is often better as it doesn't need this.
81        return torch.device(f'cuda:{torch.cuda.current_device()}')
82    return device

Converts a device string or object into a canonical torch.device object.

Ensures that a device string like 'cuda' is converted to its fully indexed form, e.g., 'cuda:0', by checking the current default device.

Parameters
  • device: The device identifier to canonicalize.
Returns

A torch.device object with an explicit index if applicable.

def _first_tensor_device(model: torch.nn.Module) -> Optional[torch.device]:
    """
    Finds the device a module currently lives on.

    Looks at the first parameter and, if the module has no parameters, at its
    first registered buffer, so buffer-only modules are not mistaken for
    device-less ones.

    :param model: The module to inspect.
    :type model: torch.nn.Module
    :return: The device of the first parameter (or buffer), or ``None`` if the
             module holds neither.
    :rtype: Optional[torch.device]
    """
    for tensor in model.parameters():
        return tensor.device
    for tensor in model.buffers():
        return tensor.device
    return None


@contextlib.contextmanager
def model_to_target(logger: logging.Logger, model: torch.nn.Module) -> Iterator[None]:
    """
    A context manager for safe model inference.

    Handles device placement, inference state (`eval()`, `no_grad()`),
    optional cuDNN benchmark settings, and safe offloading in a `finally` block
    to ensure resources are managed even if errors occur.

    The model object is expected to have two optional custom attributes:
    - ``target_device`` (torch.device): The device to run inference on. When
      missing (or not a torch.device) the model's current device is used.
    - ``cudnn_benchmark_setting`` (bool): The desired cuDNN benchmark state.
      Only applied when the target is a CUDA device; the previous global state
      is restored on exit.

    The model's current device is taken from its first parameter or, for
    modules without parameters, from its first buffer.

    On exit, and only when running under ComfyUI, the model is moved to the
    device returned by ``get_offload_device()`` (emptying the CUDA cache if it
    was on a CUDA device). Outside ComfyUI it stays on the target device. The
    model is left in eval mode.

    If ``model`` is not a torch.nn.Module only ``no_grad()`` is applied.

    Example Usage::

        with model_to_target(logger, my_model):
            # Your inference code here
            output = my_model(input_tensor)

    :param logger: A logger instance for logging state changes.
    :type logger: logging.Logger
    :param model: The torch.nn.Module to manage.
    :type model: torch.nn.Module
    :yield: None
    """
    if not isinstance(model, torch.nn.Module):
        # Nothing to place or offload: just disable autograd for the caller
        with torch.no_grad():
            yield  # The code inside the 'with' statement runs here
        return

    # Stays None unless we actually change the global cuDNN flag
    original_cudnn_benchmark_state: Optional[bool] = None

    # Where the model lives right now; None for a module with no tensors
    current_device = _first_tensor_device(model)

    # 1. Determine target device from the model object
    target_device: torch.device
    requested_device = getattr(model, 'target_device', None)
    if isinstance(requested_device, torch.device):
        target_device = requested_device
    else:
        logger.warning("model_to_target: 'target_device' attribute not found or is not a torch.device on the model. "
                       "Defaulting to model's current device.")
        if current_device is None:
            logger.warning("model_to_target: Model has no parameters or buffers. Cannot determine device. "
                           "Assuming CPU.")
            target_device = torch.device("cpu")
        else:
            target_device = current_device

    # 2. Get cuDNN benchmark setting from the model object (optional)
    cudnn_benchmark_enabled: Optional[bool] = None
    if hasattr(model, 'cudnn_benchmark_setting'):
        setting = model.cudnn_benchmark_setting
        if isinstance(setting, bool):
            cudnn_benchmark_enabled = setting
        else:
            logger.warning("model.cudnn_benchmark_setting was not a bool, but %s. Ignoring.", type(setting))

    # A tensor-less module is treated as living on the CPU (same fallback as above)
    original_device = current_device if current_device is not None else torch.device("cpu")

    try:
        # 3. Manage cuDNN benchmark state (only meaningful for CUDA targets)
        if (cudnn_benchmark_enabled is not None and target_device.type == 'cuda' and
                hasattr(torch.backends, 'cudnn') and torch.backends.cudnn.is_available()):
            if torch.backends.cudnn.benchmark != cudnn_benchmark_enabled:
                original_cudnn_benchmark_state = torch.backends.cudnn.benchmark
                torch.backends.cudnn.benchmark = cudnn_benchmark_enabled
                logger.debug("Temporarily set cuDNN benchmark to %s", torch.backends.cudnn.benchmark)

        # 4. Move model to target device if not already there
        if original_device != target_device:
            logger.debug("Moving model from `%s` to target device `%s` for inference.", original_device, target_device)
            model.to(target_device)

        # 5. Set to eval mode and disable gradients for the operation
        model.eval()
        with torch.no_grad():
            yield  # The code inside the 'with' statement runs here

    finally:
        # 6. Restore original cuDNN benchmark state (only set if we changed it)
        if original_cudnn_benchmark_state is not None:
            torch.backends.cudnn.benchmark = original_cudnn_benchmark_state
            logger.debug("Restored cuDNN benchmark to %s", original_cudnn_benchmark_state)

        # 7. Offload the model (ComfyUI only; elsewhere it stays on the target device)
        if with_comfy:
            offload_device = get_offload_device()
            device_after_yield = _first_tensor_device(model)
            # A module without parameters or buffers holds no device memory to release
            if device_after_yield is not None and device_after_yield != offload_device:
                logger.debug("Offloading model from `%s` to offload device `%s`.", device_after_yield, offload_device)
                model.to(offload_device)
                # Release cached allocator blocks if the model came off a CUDA device
                if device_after_yield.type == 'cuda':
                    torch.cuda.empty_cache()

A context manager for safe model inference.

Handles device placement, inference state (eval(), no_grad()), optional cuDNN benchmark settings, and safe offloading in a finally block to ensure resources are managed even if errors occur.

The model object is expected to have two optional custom attributes:

  • target_device (torch.device): The device to run inference on.
  • cudnn_benchmark_setting (bool): The desired cuDNN benchmark state.

Example Usage:

with model_to_target(logger, my_model):
    # Your inference code here
    output = my_model(input_tensor)
Parameters
  • logger: A logger instance for logging state changes.
  • model: The torch.nn.Module to manage.

Yields

None